implement plotter for history state in mlr3viz #227

Open
sebffischer opened this issue Jun 13, 2024 · 2 comments
Labels
enhancement New feature or request good first issue Good for newcomers

@sebffischer
Member

No description provided.

@sebffischer
Member Author

old code:

    #' @description Plots the history.
    #' @param measures (`character()`)\cr
    #'   Which measures to plot. No default.
    #' @param set (`character(1)`)\cr
    #'   Which set to plot. Either `"train"` or `"valid"`. Default is `"valid"`.
    #' @param epochs (`integer()`)\cr
    #'   An integer vector restricting which epochs to plot. Default is `NULL`, which plots all epochs.
    #' @param theme ([ggplot2::theme()])\cr
    #'   The theme, [ggplot2::theme_minimal()] is the default.
    #' @param ... (any)\cr
    #'   Currently unused.
    plot = function(measures, set = "valid", epochs = NULL, theme = ggplot2::theme_minimal(), ...) {
      assert_choice(set, c("valid", "train"))
      data = self[[set]]
      assert_subset(measures, colnames(data))

      if (is.null(epochs)) {
        data = data[, c("epoch", measures), with = FALSE]
      } else {
        assert_integerish(epochs, unique = TRUE)
        data = data[get("epoch") %in% epochs, c("epoch", measures), with = FALSE]
      }

      if ((!nrow(data)) || (ncol(data) < 2)) {
        stopf("No eligible measures to plot for set '%s'.", set)
      }

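      # Bind the names used in non-standard evaluation below so that
      # R CMD check does not report them as undefined global variables.
      epoch = score = measure = .data = NULL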
      if (ncol(data) == 2L) {
        ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = .data[[measures]])) +
          ggplot2::geom_line() +
          ggplot2::geom_point() +
          ggplot2::labs(
            x = "Epoch",
            y = measures,
            title = sprintf("%s Loss", switch(set, valid = "Validation", train = "Training"))
          ) +
          theme
      } else {
        data = melt(data, id.vars = "epoch", variable.name = "measure", value.name = "score")
        ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = score, color = measure)) +
          viridis::scale_color_viridis(discrete = TRUE) +
          ggplot2::geom_line() +
          ggplot2::geom_point() +
          ggplot2::labs(
            x = "Epoch",
            y = "Score",
            title = sprintf("%s Loss", switch(set, valid = "Validation", train = "Training"))
          ) +
          theme
      }
    }

@sebffischer added the good first issue label Jun 13, 2024
@sebffischer modified the milestone: 0.2 Jun 13, 2024
@sebffischer
Member Author

this should dispatch on LearnerTorch
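
A rough sketch of what dispatching on `LearnerTorch` could look like, assuming the plotter would be an S3 method for ggplot2's `autoplot()` generic. The `mock_learner` object and its `history` field are hypothetical stand-ins so the example runs without mlr3torch; where the real learner keeps its history, and the final argument names, are not settled in this issue.

```r
library(ggplot2)

# Hypothetical stand-in for a trained LearnerTorch: a plain list carrying a
# history table with an `epoch` column and one column per tracked measure.
# The location of the history in the real learner is an assumption here.
mock_learner = structure(
  list(history = data.frame(
    epoch = 1:5,
    train.loss = c(0.90, 0.70, 0.55, 0.45, 0.40),
    valid.loss = c(0.95, 0.80, 0.70, 0.66, 0.65)
  )),
  class = "LearnerTorch"
)

# S3 method for ggplot2's autoplot() generic; dispatch happens on the class name.
autoplot.LearnerTorch = function(object, measures = NULL,
                                 theme = ggplot2::theme_minimal(), ...) {
  history = object$history
  if (is.null(measures)) {
    measures = setdiff(colnames(history), "epoch")
  }
  stopifnot(length(measures) > 0L, all(measures %in% colnames(history)))

  # Reshape to long format with base R: one row per (epoch, measure) pair.
  long = data.frame(
    epoch = rep(history$epoch, times = length(measures)),
    measure = rep(measures, each = nrow(history)),
    score = unlist(history[measures], use.names = FALSE)
  )

  ggplot(long, aes(x = epoch, y = score, color = measure)) +
    geom_line() +
    geom_point() +
    labs(x = "Epoch", y = "Score") +
    theme
}

p = autoplot(mock_learner, measures = c("train.loss", "valid.loss"))
```

Reshaping to long format up front lets one code path handle both the single-measure and multi-measure cases that the old method treated separately.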

@sebffischer added the enhancement label Jun 17, 2024