Skip to content

Commit

Permalink
Merge pull request #27 from topepo/summarize
Browse files Browse the repository at this point in the history
estimate -> summarize = summarise
  • Loading branch information
topepo authored Sep 4, 2019
2 parents c74e608 + 73216ce commit 839ad82
Show file tree
Hide file tree
Showing 31 changed files with 290 additions and 85 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
^docs$
^\.travis\.yml$
^vignettes
^inst/examples/finefoods.txt$
5 changes: 3 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Generated by roxygen2: do not edit by hand

S3method(estimate,grid_results)
S3method(merge,model_spec)
S3method(merge,recipe)
S3method(min_grid,boost_tree)
Expand Down Expand Up @@ -34,7 +33,6 @@ S3method(tune_args,step)
S3method(tune_args,workflow)
export(Bayes_control)
export(conf_bound)
export(estimate)
export(exp_improve)
export(expo_decay)
export(grid_control)
Expand All @@ -49,6 +47,8 @@ export(no_param)
export(param_set)
export(plot_perf_vs_iter)
export(prob_improve)
export(summarise.grid_results)
export(summarize)
export(tunable)
export(tune)
export(tune_Bayes)
Expand Down Expand Up @@ -77,6 +77,7 @@ importFrom(dplyr,rename)
importFrom(dplyr,sample_n)
importFrom(dplyr,select)
importFrom(dplyr,slice)
importFrom(dplyr,summarize)
importFrom(dplyr,ungroup)
importFrom(ggplot2,aes)
importFrom(ggplot2,facet_wrap)
Expand Down
4 changes: 4 additions & 0 deletions R/0_imports.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
#' @export
dials::param_set

#' @importFrom dplyr summarize
#' @export
dplyr::summarize

# ------------------------------------------------------------------------------

utils::globalVariables(
Expand Down
7 changes: 5 additions & 2 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,22 @@ check_initial <- function(x, pset, wflow, rs, perf, ctrl) {
message(msg)
}
x <- tune_grid(wflow, rs = rs, grid = x, perf = perf)
x <- estimate(x)
x <- summarize(x)
if (ctrl$verbose) {
msg <- paste(crayon::green(cli::symbol$tick), "Initialization complete")
message(msg)
message()
}
} else {
if (inherits(x, "grid_results")) {
x <- estimate(x)
x <- summarize(x)
} else {
x <- x
}
}
if (!any(names(x) == ".iter")) {
x <- x %>% dplyr::mutate(.iter = 0)
}
x
}

Expand Down
11 changes: 3 additions & 8 deletions R/estimate.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,8 @@
#' and the number of non-missing values (`n`). These are computed for each
#' metric and estimator type.
#' @export
estimate <- function(x, ...) {
UseMethod("estimate")
}


#' @export
#' @rdname estimate
estimate.grid_results <- function(x, ...) {
#' @rdname summarize.grid_results
summarise.grid_results <- function(x, ...) {
all_bad <- is_cataclysmic(x)
if (all_bad) {
stop("All of the models failed.", call. = FALSE)
Expand All @@ -37,3 +31,4 @@ estimate.grid_results <- function(x, ...) {
) %>%
ungroup()
}

15 changes: 2 additions & 13 deletions R/tune_Bayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,10 @@ tune_Bayes <-
param_info <- param_set(object)
}

on.exit({
warning("Optimization stopped prematurely; returning current results.", call. = FALSE)
return(initial)
})

initial_grid <- check_initial(initial, param_info, object, rs, perf, control)

check_time(start_time, control$time_limit)

if (!any(names(initial_grid) == ".iter")) {
res <- initial_grid %>% dplyr::mutate(.iter = 0)
} else {
res <- initial_grid
}

on.exit({
warning("Optimization stopped prematurely; returning current results.", call. = FALSE)
return(res)
Expand Down Expand Up @@ -144,11 +133,11 @@ tune_Bayes <-
all_bad <- is_cataclysmic(tmp_res)

if (!inherits(tmp_res, "try-error") & !all_bad) {
rs_estimate <- estimate(tmp_res)
rs_estimate <- summarize(tmp_res)
res <- dplyr::bind_rows(res, rs_estimate %>% dplyr::mutate(.iter = i))
current_val <-
tmp_res %>%
estimate() %>%
summarize() %>%
dplyr::filter(.metric == perf_name) %>%
dplyr::pull(mean)

Expand Down
2 changes: 1 addition & 1 deletion _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ reference:
- conf_bound
- Bayes_control
- grid_control
- estimate
- summarise.grid_results
- plot_perf_vs_iter
- title: Low-Level Functions
contents:
Expand Down
2 changes: 1 addition & 1 deletion docs/reference/index.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion docs/reference/reexports.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

187 changes: 187 additions & 0 deletions docs/reference/summarize.grid_results.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions inst/examples/Chicago_corr_knn.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ chi_grid <-

res <- tune_grid(chi_wflow, data_folds, chi_grid, control = grid_control(verbose = TRUE))

estimate(res) %>%
summarize(res) %>%
dplyr::filter(.metric == "rmse") %>%
select(-n, -std_err, -.estimator, -.metric) %>%
ggplot(aes(x = neighbors, y = mean, col = weight_func)) +
geom_point() + geom_line() +
facet_wrap(~threshold, scales = "free_x")

estimate(res) %>%
summarize(res) %>%
dplyr::filter(.metric == "rmse") %>%
arrange(mean) %>%
slice(1)
Expand Down
4 changes: 2 additions & 2 deletions inst/examples/Chicago_corr_lm.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ chi_grid <-

res <- tune_grid(chi_wflow, data_folds, chi_grid, control = grid_control(verbose = TRUE))

estimate(res) %>%
summarize(res) %>%
dplyr::filter(.metric == "rmse") %>%
select(-n, -std_err, -.estimator, -.metric) %>%
ggplot(aes(x = threshold, y = mean)) +
geom_point() +
geom_line()

estimate(res) %>%
summarize(res) %>%
dplyr::filter(.metric == "rmse") %>%
arrange(mean) %>%
slice(1)
Expand Down
Loading

0 comments on commit 839ad82

Please sign in to comment.