From 74963c317814e2e76bd050e03db1face0937d67c Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 25 Apr 2025 14:04:36 -0500 Subject: [PATCH 01/11] `usethis::use_tidy_upkeep_issue(2024)` --- DESCRIPTION | 1 + 1 file changed, 1 insertion(+) diff --git a/DESCRIPTION b/DESCRIPTION index ece826c..a95b717 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -69,3 +69,4 @@ Language: en-US LazyData: true Roxygen: list(markdown = TRUE) RoxygenNote: 7.3.2 +Config/usethis/last-upkeep: 2025-04-25 From e87b77a632b15587e5f0c85ce6ea655e4c12c752 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 25 Apr 2025 14:04:50 -0500 Subject: [PATCH 02/11] `usethis::use_air()` --- .Rbuildignore | 2 ++ .vscode/extensions.json | 5 +++++ .vscode/settings.json | 6 ++++++ air.toml | 0 4 files changed, 13 insertions(+) create mode 100644 .vscode/extensions.json create mode 100644 .vscode/settings.json create mode 100644 air.toml diff --git a/.Rbuildignore b/.Rbuildignore index b52b5c1..39a8253 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -11,3 +11,5 @@ ^revdep$ ^cran-comments\.md$ ^man-roxygen$ +^[\.]?air\.toml$ +^\.vscode$ diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..344f76e --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,5 @@ +{ + "recommendations": [ + "Posit.air-vscode" + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..f2d0b79 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "[r]": { + "editor.formatOnSave": true, + "editor.defaultFormatter": "Posit.air-vscode" + } +} diff --git a/air.toml b/air.toml new file mode 100644 index 0000000..e69de29 From bf1df8c7e334007ae3fc4b8c2cb55871d31339db Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 25 Apr 2025 14:05:37 -0500 Subject: [PATCH 03/11] `usethis::use_package("R", "Depends", "4.1")` --- DESCRIPTION | 2 +- NEWS.md | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index a95b717..abc8727 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -19,7 +19,7 @@ URL: https://github.com/tidymodels/workflowsets, https://workflowsets.tidymodels.org BugReports: https://github.com/tidymodels/workflowsets/issues Depends: - R (>= 3.6) + R (>= 4.1) Imports: cli, dplyr (>= 1.0.0), diff --git a/NEWS.md b/NEWS.md index ee7c4e1..f97597d 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,8 @@ * Added a `collect_extracts()` method for workflow sets (#156). +* Increased the minimum required R version to R 4.1. + # workflowsets 1.1.0 * Ellipses (...) are now used consistently in the package to require optional arguments to be named; `collect_metrics()` and `collect_predictions()` are the only functions that received changes (#151, tidymodels/tune#863). From cfc76cb3a7ec8ef5be62ecd92b1c31821a9f3124 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 25 Apr 2025 14:06:27 -0500 Subject: [PATCH 04/11] `air format .` --- R/0_imports.R | 35 +++++++++-- R/as_workflow_set.R | 5 +- R/autoplot.R | 61 +++++++++++++------ R/checks.R | 40 +++++++++--- R/collect.R | 16 +++-- R/comments.R | 4 +- R/compat-vctrs.R | 56 ++++++++++++++--- R/extract.R | 4 +- R/fit.R | 6 +- R/leave_var_out_formulas.R | 5 +- R/misc.R | 11 +++- R/options.R | 8 ++- R/rank_results.R | 31 ++++++++-- R/update.R | 6 +- R/workflow_map.R | 51 ++++++++++++---- R/workflow_set.R | 10 ++- tests/spelling.R | 3 +- tests/testthat/helper-extract_parameter_set.R | 10 ++- tests/testthat/test-autoplot.R | 33 ++++++++-- tests/testthat/test-collect-extracts.R | 10 +-- tests/testthat/test-collect-metrics.R | 8 ++- tests/testthat/test-collect-notes.R | 3 +- tests/testthat/test-collect-predictions.R | 36 ++++++++--- tests/testthat/test-compat-dplyr.R | 14 ++++- tests/testthat/test-compat-vctrs.R | 10 ++- tests/testthat/test-extract.R | 39 +++++++++--- tests/testthat/test-fit.R | 3 +- tests/testthat/test-fit_best.R | 10 ++- tests/testthat/test-options.R | 1 - tests/testthat/test-predict.R | 3 +- tests/testthat/test-pull.R | 3 +- tests/testthat/test-updates.R | 10 ++- tests/testthat/test-workflow-map.R | 1 - tests/testthat/test-workflow_set.R | 44 +++++++++---- 34 files changed, 453 insertions(+), 137 deletions(-) diff --git a/R/0_imports.R b/R/0_imports.R index ad73a7e..fa780e6 100644 --- a/R/0_imports.R +++ b/R/0_imports.R @@ -15,11 +15,36 @@ NULL utils::globalVariables( c( - ".config", ".estimator", ".metric", "info", "metric", "mod_nm", - "model", "n", "pp_nm", "preprocessor", "preproc", "object", "engine", - "result", "std_err", "wflow_id", "func", "is_race", "num_rs", "option", - "metrics", "predictions", "hash", "id", "workflow", "comment", "get_from_env", - ".get_tune_metric_names", "select_best", "notes" + ".config", + ".estimator", + ".metric", + "info", + "metric", + "mod_nm", + "model", + "n", + "pp_nm", + "preprocessor", + "preproc", + "object", + "engine", + "result", + "std_err", + "wflow_id", + "func", + "is_race", + "num_rs", + "option", + "metrics", + "predictions", + "hash", + "id", + "workflow", + "comment", + "get_from_env", + ".get_tune_metric_names", + "select_best", + "notes" ) ) diff --git a/R/as_workflow_set.R b/R/as_workflow_set.R index 5f2ee8e..121e76b 100644 --- a/R/as_workflow_set.R +++ b/R/as_workflow_set.R @@ -58,7 +58,10 @@ as_workflow_set <- function(...) { is_workflow <- purrr::map_lgl(object, ~ inherits(.x, "workflow")) wflows <- vector("list", length(is_workflow)) wflows[is_workflow] <- object[is_workflow] - wflows[!is_workflow] <- purrr::map(object[!is_workflow], tune::.get_tune_workflow) + wflows[!is_workflow] <- purrr::map( + object[!is_workflow], + tune::.get_tune_workflow + ) names(wflows) <- names(object) check_names(wflows) diff --git a/R/autoplot.R b/R/autoplot.R index d8d45f6..64039ef 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -44,12 +44,16 @@ #' autoplot(two_class_res, select_best = TRUE) #' autoplot(two_class_res, id = "yj_trans_cart", metric = "roc_auc") #' @export -autoplot.workflow_set <- function(object, rank_metric = NULL, metric = NULL, - id = "workflow_set", - select_best = FALSE, - std_errs = qnorm(0.95), - type = "class", - ...) { +autoplot.workflow_set <- function( + object, + rank_metric = NULL, + metric = NULL, + id = "workflow_set", + select_best = FALSE, + std_errs = qnorm(0.95), + type = "class", + ... +) { rlang::arg_match(type, c("class", "wflow_id")) check_string(rank_metric, allow_null = TRUE) check_character(metric, allow_null = TRUE) @@ -57,21 +61,39 @@ autoplot.workflow_set <- function(object, rank_metric = NULL, metric = NULL, check_bool(select_best) if (id == "workflow_set") { - p <- rank_plot(object, - rank_metric = rank_metric, metric = metric, - select_best = select_best, std_errs = std_errs, type = type + p <- rank_plot( + object, + rank_metric = rank_metric, + metric = metric, + select_best = select_best, + std_errs = std_errs, + type = type ) } else { - p <- autoplot(object$result[[which(object$wflow_id == id)]], metric = metric, ...) + p <- autoplot( + object$result[[which(object$wflow_id == id)]], + metric = metric, + ... + ) } p } -rank_plot <- function(object, rank_metric = NULL, metric = NULL, - select_best = FALSE, std_errs = 1, type = "class") { +rank_plot <- function( + object, + rank_metric = NULL, + metric = NULL, + select_best = FALSE, + std_errs = 1, + type = "class" +) { metric_info <- pick_metric(object, rank_metric, metric) metrics <- collate_metrics(object) - res <- rank_results(object, rank_metric = metric_info$metric, select_best = select_best) + res <- rank_results( + object, + rank_metric = metric_info$metric, + select_best = select_best + ) if (!is.null(metric)) { keep_metrics <- unique(c(rank_metric, metric)) @@ -82,13 +104,12 @@ rank_plot <- function(object, rank_metric = NULL, metric = NULL, has_std_error <- !all(is.na(res$std_err)) p <- - switch(type, - class = - ggplot(res, aes(x = rank, y = mean, col = model)) + - geom_point(aes(shape = preprocessor)), - wflow_id = - ggplot(res, aes(x = rank, y = mean, col = wflow_id)) + - geom_point() + switch( + type, + class = ggplot(res, aes(x = rank, y = mean, col = model)) + + geom_point(aes(shape = preprocessor)), + wflow_id = ggplot(res, aes(x = rank, y = mean, col = wflow_id)) + + geom_point() ) if (num_metrics > 1) { diff --git a/R/checks.R b/R/checks.R index ab1ffeb..4a006c6 100644 --- a/R/checks.R +++ b/R/checks.R @@ -13,7 +13,11 @@ check_consistent_metrics <- function(x, fail = TRUE, call = caller_env()) { metric_info <- dplyr::distinct(x, .metric, wflow_id) %>% dplyr::mutate(has = TRUE) %>% - tidyr::pivot_wider(names_from = ".metric", values_from = "has", values_fill = FALSE) %>% + tidyr::pivot_wider( + names_from = ".metric", + values_from = "has", + values_fill = FALSE + ) %>% dplyr::select(-wflow_id) %>% purrr::map_dbl(~ sum(!.x)) @@ -58,7 +62,6 @@ check_incompete <- function(x, fail = TRUE, call = caller_env()) { # TODO check for consistent resamples - # if global in local, overwrite or fail? common_options <- function(model, global) { @@ -73,7 +76,13 @@ common_options <- function(model, global) { res } -check_options <- function(model, id, global, action = "fail", call = caller_env()) { +check_options <- function( + model, + id, + global, + action = "fail", + call = caller_env() +) { res <- purrr::map_chr(model, common_options, global) flag <- nchar(res) > 0 if (any(flag)) { @@ -91,8 +100,15 @@ check_options <- function(model, id, global, action = "fail", call = caller_env( check_tune_args <- function(x, call = caller_env()) { arg_names <- c( - "resamples", "param_info", "grid", "metrics", "control", - "iter", "objective", "initial", "eval_time" + "resamples", + "param_info", + "grid", + "metrics", + "control", + "iter", + "objective", + "initial", + "eval_time" ) bad_args <- setdiff(x, arg_names) if (length(bad_args) > 0) { @@ -261,7 +277,11 @@ has_valid_column_option_structure <- function(x) { } has_valid_column_option_inner_types <- function(x) { option <- x$option - valid_options_indicator <- purrr::map_lgl(option, inherits, "workflow_set_options") + valid_options_indicator <- purrr::map_lgl( + option, + inherits, + "workflow_set_options" + ) all(valid_options_indicator) } @@ -290,7 +310,10 @@ has_valid_column_wflow_id_strings <- function(x) { has_all_pkgs <- function(w) { pkgs <- generics::required_pkgs(w, infra = FALSE) if (length(pkgs) > 0) { - is_inst <- purrr::map_lgl(pkgs, ~ rlang::is_true(requireNamespace(.x, quietly = TRUE))) + is_inst <- purrr::map_lgl( + pkgs, + ~ rlang::is_true(requireNamespace(.x, quietly = TRUE)) + ) if (!all(is_inst)) { cols <- tune::get_tune_colors() msg <- paste0( @@ -299,7 +322,8 @@ has_all_pkgs <- function(w) { ". Skipping this workflow." ) message( - cols$symbol$danger(cli::symbol$cross), " ", + cols$symbol$danger(cli::symbol$cross), + " ", cols$message$warning(msg) ) res <- FALSE diff --git a/R/collect.R b/R/collect.R index 28d4c70..38e43fd 100644 --- a/R/collect.R +++ b/R/collect.R @@ -99,8 +99,14 @@ reorder_cols <- function(x) { #' @export #' @rdname collect_metrics.workflow_set collect_predictions.workflow_set <- - function(x, ..., summarize = TRUE, parameters = NULL, select_best = FALSE, - metric = NULL) { + function( + x, + ..., + summarize = TRUE, + parameters = NULL, + select_best = FALSE, + metric = NULL + ) { rlang::check_dots_empty() check_incompete(x, fail = TRUE) check_bool(summarize) @@ -108,7 +114,8 @@ collect_predictions.workflow_set <- check_string(metric, allow_null = TRUE) if (select_best) { x <- - dplyr::mutate(x, + dplyr::mutate( + x, predictions = purrr::map( result, ~ select_bare_predictions( @@ -141,7 +148,8 @@ collect_predictions.workflow_set <- select_bare_predictions <- function(x, metric, summarize) { res <- - tune::collect_predictions(x, + tune::collect_predictions( + x, summarize = summarize, parameters = tune::select_best(x, metric = metric) ) diff --git a/R/comments.R b/R/comments.R index 832303c..5eaac4b 100644 --- a/R/comments.R +++ b/R/comments.R @@ -53,7 +53,9 @@ comment_add <- function(x, id, ..., append = TRUE, collapse = "\n") { id_index <- which(has_id) current_val <- x$info[[id_index]]$comment if (!is.na(current_val) && !append) { - cli::cli_abort("There is already a comment for this id and {.code append = FALSE}.") + cli::cli_abort( + "There is already a comment for this id and {.code append = FALSE}." + ) } new_value <- c(x$info[[id_index]]$comment, unlist(dots)) new_value <- new_value[!is.na(new_value) & nchar(new_value) > 0] diff --git a/R/compat-vctrs.R b/R/compat-vctrs.R index 7e36fe6..cdaadd7 100644 --- a/R/compat-vctrs.R +++ b/R/compat-vctrs.R @@ -30,7 +30,13 @@ vec_restore.workflow_set <- function(x, to, ...) { # workflow_set, so instead we always return a tibble. #' @export -vec_ptype2.workflow_set.workflow_set <- function(x, y, ..., x_arg = "", y_arg = "") { +vec_ptype2.workflow_set.workflow_set <- function( + x, + y, + ..., + x_arg = "", + y_arg = "" +) { out <- vctrs::df_ptype2(x, y, ..., x_arg = x_arg, y_arg = y_arg) workflow_set_maybe_reconstruct(out) } @@ -43,11 +49,23 @@ vec_ptype2.tbl_df.workflow_set <- function(x, y, ..., x_arg = "", y_arg = "") { vctrs::tib_ptype2(x, y, ..., x_arg = x_arg, y_arg = y_arg) } #' @export -vec_ptype2.workflow_set.data.frame <- function(x, y, ..., x_arg = "", y_arg = "") { +vec_ptype2.workflow_set.data.frame <- function( + x, + y, + ..., + x_arg = "", + y_arg = "" +) { vctrs::tib_ptype2(x, y, ..., x_arg = x_arg, y_arg = y_arg) } #' @export -vec_ptype2.data.frame.workflow_set <- function(x, y, ..., x_arg = "", y_arg = "") { +vec_ptype2.data.frame.workflow_set <- function( + x, + y, + ..., + x_arg = "", + y_arg = "" +) { vctrs::tib_ptype2(x, y, ..., x_arg = x_arg, y_arg = y_arg) } @@ -74,7 +92,13 @@ vec_ptype2.data.frame.workflow_set <- function(x, y, ..., x_arg = "", y_arg = "" # a common type of tibble, and then each input will be cast to tibble. #' @export -vec_cast.workflow_set.workflow_set <- function(x, to, ..., x_arg = "", to_arg = "") { +vec_cast.workflow_set.workflow_set <- function( + x, + to, + ..., + x_arg = "", + to_arg = "" +) { out <- vctrs::df_cast(x, to, ..., x_arg = x_arg, to_arg = to_arg) workflow_set_maybe_reconstruct(out) } @@ -87,11 +111,23 @@ vec_cast.tbl_df.workflow_set <- function(x, to, ..., x_arg = "", to_arg = "") { vctrs::tib_cast(x, to, ..., x_arg = x_arg, to_arg = to_arg) } #' @export -vec_cast.workflow_set.data.frame <- function(x, to, ..., x_arg = "", to_arg = "") { +vec_cast.workflow_set.data.frame <- function( + x, + to, + ..., + x_arg = "", + to_arg = "" +) { stop_incompatible_cast_workflow_set(x, to, x_arg = x_arg, to_arg = to_arg) } #' @export -vec_cast.data.frame.workflow_set <- function(x, to, ..., x_arg = "", to_arg = "") { +vec_cast.data.frame.workflow_set <- function( + x, + to, + ..., + x_arg = "", + to_arg = "" +) { vctrs::df_cast(x, to, ..., x_arg = x_arg, to_arg = to_arg) } @@ -99,5 +135,11 @@ vec_cast.data.frame.workflow_set <- function(x, to, ..., x_arg = "", to_arg = "" stop_incompatible_cast_workflow_set <- function(x, to, ..., x_arg, to_arg) { details <- "Can't cast to a because the resulting structure is likely invalid." - vctrs::stop_incompatible_cast(x, to, x_arg = x_arg, to_arg = to_arg, details = details) + vctrs::stop_incompatible_cast( + x, + to, + x_arg = x_arg, + to_arg = to_arg, + details = details + ) } diff --git a/R/extract.R b/R/extract.R index 14be1d7..4e03de7 100644 --- a/R/extract.R +++ b/R/extract.R @@ -94,7 +94,9 @@ extract_spec_parsnip.workflow_set <- function(x, id, ...) { extract_recipe.workflow_set <- function(x, id, ..., estimated = TRUE) { check_empty_dots(...) if (!rlang::is_bool(estimated)) { - cli::cli_abort("{.arg estimated} must be a single {.code TRUE} or {.code FALSE}.") + cli::cli_abort( + "{.arg estimated} must be a single {.code TRUE} or {.code FALSE}." + ) } y <- filter_id(x, id) extract_recipe(y$info[[1]]$workflow[[1]], estimated = estimated) diff --git a/R/fit.R b/R/fit.R index 6093d11..1aac8d3 100644 --- a/R/fit.R +++ b/R/fit.R @@ -11,13 +11,15 @@ fit.workflow_set <- function(object, ...) { if (!all(purrr::map_lgl(object$result, ~ identical(.x, list())))) { # if fitted: msg <- - c(msg, + c( + msg, "i" = "Please see {.help [{.fun fit_best}](workflowsets::fit_best.workflow_set)}." ) } else { # if not fitted: msg <- - c(msg, + c( + msg, "i" = "Please see {.help [{.fun workflow_map}](workflowsets::workflow_map)}." ) } diff --git a/R/leave_var_out_formulas.R b/R/leave_var_out_formulas.R index cf43e2e..0a59bb6 100644 --- a/R/leave_var_out_formulas.R +++ b/R/leave_var_out_formulas.R @@ -45,7 +45,10 @@ leave_var_out_formulas <- function(formula, data, full_model = TRUE, ...) { y_vars <- as.character(formula[[2]]) form_terms <- purrr::map(x_vars, rm_vars, lst = x_vars) - form <- purrr::map_chr(form_terms, ~ paste(y_vars, "~", paste(.x, collapse = " + "))) + form <- purrr::map_chr( + form_terms, + ~ paste(y_vars, "~", paste(.x, collapse = " + ")) + ) form <- purrr::map(form, as.formula) form <- purrr::map(form, rm_formula_env) names(form) <- x_vars diff --git a/R/misc.R b/R/misc.R index 04dff60..c06b13f 100644 --- a/R/misc.R +++ b/R/misc.R @@ -21,7 +21,6 @@ make_workflow <- function(x, y, call = caller_env()) { # ------------------------------------------------------------------------------ - metric_to_df <- function(x, ...) { metrics <- attributes(x)$metrics names <- names(metrics) @@ -44,7 +43,8 @@ collate_metrics <- function(x) { metrics %>% dplyr::group_by(metric) %>% dplyr::summarize( - order = mean(order, na.rm = TRUE), n = dplyr::n(), + order = mean(order, na.rm = TRUE), + n = dplyr::n(), .groups = "drop" ) @@ -56,7 +56,12 @@ collate_metrics <- function(x) { dplyr::arrange(order) } -pick_metric <- function(x, rank_metric, select_metrics = NULL, call = caller_env()) { +pick_metric <- function( + x, + rank_metric, + select_metrics = NULL, + call = caller_env() +) { # mostly to check for completeness and consistency: tmp <- collect_metrics(x) metrics <- collate_metrics(x) diff --git a/R/options.R b/R/options.R index 6cb2d01..a59e735 100644 --- a/R/options.R +++ b/R/options.R @@ -101,7 +101,6 @@ option_remove <- function(x, ...) { } - maybe_param <- function(x) { prm <- hardhat::extract_parameter_set_dials(x) if (nrow(prm) == 0) { @@ -132,7 +131,12 @@ option_add_parameters <- function(x, id = NULL, strict = FALSE) { if (length(ind) == 0) { cli::cli_warn("Don't have an {.arg id} value {i}") } else { - check_options(x$option[[ind]], x$wflow_id[[ind]], prm[[ind]], action = act) + check_options( + x$option[[ind]], + x$wflow_id[[ind]], + prm[[ind]], + action = act + ) x$option[[ind]] <- append_options(x$option[[ind]], prm[[ind]]) } } diff --git a/R/rank_results.R b/R/rank_results.R index bef8e9e..db9967c 100644 --- a/R/rank_results.R +++ b/R/rank_results.R @@ -27,7 +27,12 @@ #' rank_results(chi_features_res, select_best = TRUE) #' rank_results(chi_features_res, rank_metric = "rsq") #' @export -rank_results <- function(x, rank_metric = NULL, eval_time = NULL, select_best = FALSE) { +rank_results <- function( + x, + rank_metric = NULL, + eval_time = NULL, + select_best = FALSE +) { check_wf_set(x) check_string(rank_metric, allow_null = TRUE) check_bool(select_best) @@ -40,13 +45,21 @@ rank_results <- function(x, rank_metric = NULL, eval_time = NULL, select_best = metric_info <- pick_metric(x, rank_metric) metric <- metric_info$metric direction <- metric_info$direction - wflow_info <- dplyr::bind_cols(purrr::map_dfr(x$info, I), dplyr::select(x, wflow_id)) + wflow_info <- dplyr::bind_cols( + purrr::map_dfr(x$info, I), + dplyr::select(x, wflow_id) + ) eval_time <- tune::choose_eval_time(result_1, metric, eval_time = eval_time) results <- collect_metrics(x) %>% dplyr::select( - wflow_id, .config, .metric, mean, std_err, n, + wflow_id, + .config, + .metric, + mean, + std_err, + n, dplyr::any_of(".eval_time") ) %>% dplyr::full_join(wflow_info, by = "wflow_id") %>% @@ -77,7 +90,11 @@ rank_results <- function(x, rank_metric = NULL, eval_time = NULL, select_best = dplyr::distinct() if (nrow(rm_rows) > 0) { ranked <- dplyr::anti_join(ranked, rm_rows, by = c("wflow_id", ".config")) - results <- dplyr::anti_join(results, rm_rows, by = c("wflow_id", ".config")) + results <- dplyr::anti_join( + results, + rm_rows, + by = c("wflow_id", ".config") + ) } } @@ -91,7 +108,11 @@ rank_results <- function(x, rank_metric = NULL, eval_time = NULL, select_best = dplyr::slice_min(mean, with_ties = FALSE) %>% dplyr::ungroup() %>% dplyr::select(wflow_id, .config) - ranked <- dplyr::inner_join(ranked, best_by_wflow, by = c("wflow_id", ".config")) + ranked <- dplyr::inner_join( + ranked, + best_by_wflow, + by = c("wflow_id", ".config") + ) } # ensure reproducible rankings when there are ties diff --git a/R/update.R b/R/update.R index a0532fc..9e96550 100644 --- a/R/update.R +++ b/R/update.R @@ -55,7 +55,11 @@ update_workflow_recipe <- function(x, id, recipe, blueprint = NULL) { check_string(id) wflow <- extract_workflow(x, id = id) - wflow <- workflows::update_recipe(wflow, recipe = recipe, blueprint = blueprint) + wflow <- workflows::update_recipe( + wflow, + recipe = recipe, + blueprint = blueprint + ) id_ind <- which(x$wflow_id == id) x$info[[id_ind]]$workflow[[1]] <- wflow # Remove any existing results since they are now inconsistent diff --git a/R/workflow_map.R b/R/workflow_map.R index 7b9a71a..6437931 100644 --- a/R/workflow_map.R +++ b/R/workflow_map.R @@ -152,8 +152,13 @@ #' #' chi_features_res_new #' @export -workflow_map <- function(object, fn = "tune_grid", verbose = FALSE, - seed = sample.int(10^4, 1), ...) { +workflow_map <- function( + object, + fn = "tune_grid", + verbose = FALSE, + seed = sample.int(10^4, 1), + ... +) { check_wf_set(object) rlang::arg_match(fn, allowed_fn$func) @@ -187,8 +192,13 @@ workflow_map <- function(object, fn = "tune_grid", verbose = FALSE, .fn_info <- dplyr::filter(allowed_fn, func == .fn) log_progress( - verbose, object$wflow_id[[iter]], NULL, iter_chr[iter], - n, .fn, NULL + verbose, + object$wflow_id[[iter]], + NULL, + iter_chr[iter], + n, + .fn, + NULL ) if (has_all_pkgs(wflow)) { @@ -202,8 +212,13 @@ workflow_map <- function(object, fn = "tune_grid", verbose = FALSE, }) object <- new_workflow_set(object) log_progress( - verbose, object$wflow_id[[iter]], object$result[[iter]], - iter_chr[iter], n, .fn, run_time + verbose, + object$wflow_id[[iter]], + object$result[[iter]], + iter_chr[iter], + n, + .fn, + run_time ) } } @@ -214,8 +229,13 @@ workflow_map <- function(object, fn = "tune_grid", verbose = FALSE, allowed_fn <- tibble::tibble( func = c( - "tune_grid", "tune_bayes", "fit_resamples", "tune_race_anova", - "tune_race_win_loss", "tune_sim_anneal", "tune_cluster" + "tune_grid", + "tune_bayes", + "fit_resamples", + "tune_race_anova", + "tune_race_win_loss", + "tune_sim_anneal", + "tune_cluster" ), pkg = c(rep("tune", 3), rep("finetune", 3), "tidyclust") ) @@ -225,7 +245,8 @@ allowed_fn_list <- paste0("'", allowed_fn$func, "'", collapse = ", ") # --------------------------------------------- check_object_fn <- function(object, fn, call = rlang::caller_env()) { wf_specs <- purrr::map( - object$wflow_id, ~ extract_spec_parsnip(object, id = .x) + object$wflow_id, + ~ extract_spec_parsnip(object, id = .x) ) is_cluster_spec <- purrr::map_lgl(wf_specs, inherits, "cluster_spec") @@ -279,7 +300,8 @@ log_progress <- function(verbose, id, res, iter, n, .fn, elapsed) { errors_msg <- gsub("\n", "", as.character(res)) errors_msg <- gsub("Error : ", "", errors_msg, fixed = TRUE) message( - cols$symbol$danger(cli::symbol$cross), " ", + cols$symbol$danger(cli::symbol$cross), + " ", cols$message$info(msg), cols$message$info(" failed with: "), cols$message$danger(errors_msg) @@ -289,7 +311,8 @@ log_progress <- function(verbose, id, res, iter, n, .fn, elapsed) { if (is.null(res)) { message( - cols$symbol$info("i"), " ", + cols$symbol$info("i"), + " ", cols$message$info(msg) ) } else { @@ -301,7 +324,8 @@ log_progress <- function(verbose, id, res, iter, n, .fn, elapsed) { errors_msg <- gsub("\n", "", as.character(res)) errors_msg <- gsub("Error : ", "", errors_msg, fixed = TRUE) message( - cols$symbol$danger(cli::symbol$cross), " ", + cols$symbol$danger(cli::symbol$cross), + " ", cols$message$info(msg), cols$message$info(" failed with "), cols$message$danger(errors_msg) @@ -309,7 +333,8 @@ log_progress <- function(verbose, id, res, iter, n, .fn, elapsed) { } else { time_msg <- paste0(" (", prettyunits::pretty_sec(elapsed[3]), ")") message( - cols$symbol$success(cli::symbol$tick), " ", + cols$symbol$success(cli::symbol$tick), + " ", cols$message$info(msg), cols$message$info(time_msg) ) diff --git a/R/workflow_set.R b/R/workflow_set.R index 2d26c8c..32285a6 100644 --- a/R/workflow_set.R +++ b/R/workflow_set.R @@ -128,8 +128,9 @@ workflow_set <- function(preproc, models, cross = TRUE, case_weights = NULL) { check_bool(cross) - if (length(preproc) != length(models) & - (length(preproc) != 1 & length(models) != 1 & !cross) + if ( + length(preproc) != length(models) & + (length(preproc) != 1 & length(models) != 1 & !cross) ) { cli::cli_abort( "The lengths of {.arg preproc} and {.arg models} are different @@ -337,7 +338,10 @@ new_workflow_set <- function(x, call = caller_env()) { cli::cli_abort("The {.field info} column should be a list.", call = call) } if (!has_valid_column_info_inner_types(x)) { - cli::cli_abort("All elements of {.field info} must be tibbles.", call = call) + cli::cli_abort( + "All elements of {.field info} must be tibbles.", + call = call + ) } if (!has_valid_column_info_inner_names(x)) { columns <- required_info_inner_names() diff --git a/tests/spelling.R b/tests/spelling.R index 13f77d9..d60e024 100644 --- a/tests/spelling.R +++ b/tests/spelling.R @@ -1,6 +1,7 @@ if (requireNamespace("spelling", quietly = TRUE)) { spelling::spell_check_test( - vignettes = TRUE, error = FALSE, + vignettes = TRUE, + error = FALSE, skip_on_cran = TRUE ) } diff --git a/tests/testthat/helper-extract_parameter_set.R b/tests/testthat/helper-extract_parameter_set.R index 5112cae..ec3d393 100644 --- a/tests/testthat/helper-extract_parameter_set.R +++ b/tests/testthat/helper-extract_parameter_set.R @@ -1,5 +1,8 @@ check_parameter_set_tibble <- function(x) { - expect_equal(names(x), c("name", "id", "source", "component", "component_id", "object")) + expect_equal( + names(x), + c("name", "id", "source", "component", "component_id", "object") + ) expect_equal(class(x$name), "character") expect_equal(class(x$id), "character") expect_equal(class(x$source), "character") @@ -8,7 +11,10 @@ check_parameter_set_tibble <- function(x) { expect_true(!any(duplicated(x$id))) expect_equal(class(x$object), "list") - obj_check <- purrr::map_lgl(x$object, ~ inherits(.x, "param") | all(is.na(.x))) + obj_check <- purrr::map_lgl( + x$object, + ~ inherits(.x, "param") | all(is.na(.x)) + ) expect_true(all(obj_check)) invisible(TRUE) diff --git a/tests/testthat/test-autoplot.R b/tests/testthat/test-autoplot.R index 65e1b9d..6846c86 100644 --- a/tests/testthat/test-autoplot.R +++ b/tests/testthat/test-autoplot.R @@ -4,8 +4,15 @@ test_that("autoplot with error bars (class)", { expect_equal( names(p_1$data), c( - "wflow_id", ".config", ".metric", "mean", "std_err", "n", - "preprocessor", "model", "rank" + "wflow_id", + ".config", + ".metric", + "mean", + "std_err", + "n", + "preprocessor", + "model", + "rank" ) ) expect_equal(rlang::get_expr(p_1$mapping$x), expr(rank)) @@ -30,8 +37,15 @@ test_that("autoplot with error bars (wflow_id)", { expect_equal( names(p_1$data), c( - "wflow_id", ".config", ".metric", "mean", "std_err", "n", - "preprocessor", "model", "rank" + "wflow_id", + ".config", + ".metric", + "mean", + "std_err", + "n", + "preprocessor", + "model", + "rank" ) ) expect_equal(rlang::get_expr(p_1$mapping$x), expr(rank)) @@ -63,8 +77,15 @@ test_that("autoplot with without error bars", { expect_equal( names(p_2$data), c( - "wflow_id", ".config", ".metric", "mean", "std_err", "n", - "preprocessor", "model", "rank" + "wflow_id", + ".config", + ".metric", + "mean", + "std_err", + "n", + "preprocessor", + "model", + "rank" ) ) expect_equal(rlang::get_expr(p_2$mapping$x), expr(rank)) diff --git a/tests/testthat/test-collect-extracts.R b/tests/testthat/test-collect-extracts.R index f0c5eb6..758faa9 100644 --- a/tests/testthat/test-collect-extracts.R +++ b/tests/testthat/test-collect-extracts.R @@ -12,7 +12,8 @@ test_that("collect_extracts works", { wflow_set_trained <- wflow_set %>% - workflow_map("fit_resamples", + workflow_map( + "fit_resamples", resamples = folds, control = tune::control_resamples(extract = function(x) { x @@ -23,7 +24,8 @@ test_that("collect_extracts works", { expect_equal(nrow(extracts), 6) expect_contains( - class(extracts$.extracts[[1]]), "workflow" + class(extracts$.extracts[[1]]), + "workflow" ) expect_named(extracts, c("wflow_id", "id", ".extracts", ".config")) }) @@ -41,9 +43,7 @@ test_that("collect_extracts fails gracefully without .extracts column", { wflow_set_trained <- wflow_set %>% - workflow_map("fit_resamples", - resamples = folds - ) + workflow_map("fit_resamples", resamples = folds) expect_snapshot( error = TRUE, diff --git a/tests/testthat/test-collect-metrics.R b/tests/testthat/test-collect-metrics.R index 3dc4958..aafdf0e 100644 --- a/tests/testthat/test-collect-metrics.R +++ b/tests/testthat/test-collect-metrics.R @@ -40,8 +40,12 @@ test_that("ranking models", { # expected number of rows per metric per model param_lines <- c( - none_cart = 10, none_glm = 1, none_mars = 2, - yj_trans_cart = 10, yj_trans_glm = 1, yj_trans_mars = 2 + none_cart = 10, + none_glm = 1, + none_mars = 2, + yj_trans_cart = 10, + yj_trans_glm = 1, + yj_trans_mars = 2 ) expect_no_error(ranking_1 <- rank_results(two_class_res)) diff --git a/tests/testthat/test-collect-notes.R b/tests/testthat/test-collect-notes.R index 7b2802a..9eb9e67 100644 --- a/tests/testthat/test-collect-notes.R +++ b/tests/testthat/test-collect-notes.R @@ -12,7 +12,8 @@ test_that("collect_notes works", { wflow_set_trained <- wflow_set %>% - workflow_map("fit_resamples", + workflow_map( + "fit_resamples", resamples = folds, control = tune::control_resamples(extract = function(x) { warn("hey!") diff --git a/tests/testthat/test-collect-predictions.R b/tests/testthat/test-collect-predictions.R index 1c06d58..3ea8fe4 100644 --- a/tests/testthat/test-collect-predictions.R +++ b/tests/testthat/test-collect-predictions.R @@ -22,7 +22,8 @@ car_set_1 <- list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), list(lm = lr_spec) ) %>% - workflow_map("fit_resamples", + workflow_map( + "fit_resamples", resamples = vfold_cv(mtcars, v = 3), control = tune::control_resamples(save_pred = TRUE) ) @@ -36,7 +37,8 @@ car_set_2 <- list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), list(lm = lr_spec) ) %>% - workflow_map("fit_resamples", + workflow_map( + "fit_resamples", resamples = resamples, control = tune::control_resamples(save_pred = TRUE) ) @@ -47,17 +49,19 @@ car_set_3 <- list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), list(knn = knn_spec) ) %>% - workflow_map("tune_bayes", + workflow_map( + "tune_bayes", resamples = resamples, control = tune::control_bayes(save_pred = TRUE), - seed = 1, iter = 2, initial = 3 + seed = 1, + iter = 2, + initial = 3 ) car_set_23 <- dplyr::bind_rows(car_set_2, car_set_3) # ------------------------------------------------------------------------------ - check_prediction_results <- function(ind, x, summarize = FALSE, ...) { id_val <- x$wflow_id[ind] @@ -104,7 +108,9 @@ test_that("collect predictions", { expect_no_error( res_car_set_3_reps <- collect_predictions(car_set_3, summarize = FALSE) ) - expect_true(nrow(mtcars) * nrow(car_set_2) * 5 * 2 == nrow(res_car_set_3_reps)) + expect_true( + nrow(mtcars) * nrow(car_set_2) * 5 * 2 == nrow(res_car_set_3_reps) + ) # --------------------------------------------------------------------------- # These don't seem to get captured by covr @@ -144,12 +150,26 @@ test_that("dropping tuning parameter columns", { ) expect_named( collect_predictions(car_set_2, summarize = FALSE), - c("wflow_id", ".config", "preproc", "model", "id", "id2", ".pred", ".row", "mpg"), + c( + "wflow_id", + ".config", + "preproc", + "model", + "id", + "id2", + ".pred", + ".row", + "mpg" + ), ignore.order = TRUE ) expect_no_error( - best_iter <- collect_predictions(car_set_3, select_best = TRUE, metric = "rmse") + best_iter <- collect_predictions( + car_set_3, + select_best = TRUE, + metric = "rmse" + ) ) expect_true( nrow(dplyr::distinct(best_iter[, c(".config", "wflow_id")])) == 2 diff --git a/tests/testthat/test-compat-dplyr.R b/tests/testthat/test-compat-dplyr.R index 06e15ad..428591f 100644 --- a/tests/testthat/test-compat-dplyr.R +++ b/tests/testthat/test-compat-dplyr.R @@ -100,7 +100,13 @@ test_that("workflow_set subclass is kept if row order is changed", { test_that("summarise() always drops the workflow_set class", { for (x in workflow_set_objects) { expect_s3_class_bare_tibble(summarise(x, y = 1)) - expect_s3_class_bare_tibble(summarise(x, wflow_id = wflow_id[1], info = info[1], option = option[1], result = result[1])) + expect_s3_class_bare_tibble(summarise( + x, + wflow_id = wflow_id[1], + info = info[1], + option = option[1], + result = result[1] + )) } }) @@ -110,7 +116,11 @@ test_that("summarise() always drops the workflow_set class", { test_that("group_by() always returns a bare grouped-df or bare tibble", { for (x in workflow_set_objects) { expect_s3_class_bare_tibble(group_by(x)) - expect_s3_class(group_by(x, wflow_id), c("grouped_df", "tbl_df", "tbl", "data.frame"), exact = TRUE) + expect_s3_class( + group_by(x, wflow_id), + c("grouped_df", "tbl_df", "tbl", "data.frame"), + exact = TRUE + ) } }) diff --git a/tests/testthat/test-compat-vctrs.R b/tests/testthat/test-compat-vctrs.R index aa77b2a..8b42d28 100644 --- a/tests/testthat/test-compat-vctrs.R +++ b/tests/testthat/test-compat-vctrs.R @@ -133,7 +133,13 @@ test_that("vec_cbind() returns a bare tibble", { # Unlike vec_c() and vec_rbind(), the prototype of the output comes # from doing `x[0]`, which will drop the workflow_set class expect_identical(vec_cbind(x), vec_cbind(tbl)) - expect_identical(vec_cbind(x, x, .name_repair = "minimal"), vec_cbind(tbl, tbl, .name_repair = "minimal")) - expect_identical(vec_cbind(x, tbl, .name_repair = "minimal"), vec_cbind(tbl, tbl, .name_repair = "minimal")) + expect_identical( + vec_cbind(x, x, .name_repair = "minimal"), + vec_cbind(tbl, tbl, .name_repair = "minimal") + ) + expect_identical( + vec_cbind(x, tbl, .name_repair = "minimal"), + vec_cbind(tbl, tbl, .name_repair = "minimal") + ) } }) diff --git a/tests/testthat/test-extract.R b/tests/testthat/test-extract.R index 98b9a6b..d877ea5 100644 --- a/tests/testthat/test-extract.R +++ b/tests/testthat/test-extract.R @@ -16,7 +16,8 @@ car_set_1 <- ), list(lm = lr_spec) ) %>% - workflow_map("fit_resamples", + workflow_map( + "fit_resamples", resamples = vfold_cv(mtcars, v = 3), control = tune::control_resamples(save_pred = TRUE) ) @@ -85,7 +86,10 @@ test_that("extract parameter set from workflow set with untunable workflow", { lm_model <- parsnip::linear_reg() %>% parsnip::set_engine("lm") bst_model <- - parsnip::boost_tree(mode = "classification", trees = hardhat::tune("funky name \n")) %>% + parsnip::boost_tree( + mode = "classification", + trees = hardhat::tune("funky name \n") + ) %>% parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE) wf_set <- workflow_set( list(reg = rm_rec), @@ -103,7 +107,10 @@ test_that("extract parameter set from workflow set with tunable workflow", { lm_model <- parsnip::linear_reg() %>% parsnip::set_engine("lm") bst_model <- - parsnip::boost_tree(mode = "classification", trees = hardhat::tune("funky name \n")) %>% + parsnip::boost_tree( + mode = "classification", + trees = hardhat::tune("funky name \n") + ) %>% parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE) wf_set <- workflow_set( list(reg = rm_rec), @@ -131,7 +138,8 @@ test_that("extract parameter set from workflow set with tunable workflow", { c5_new_info <- c5_info %>% update( - rules = dials::new_qual_param("logical", + rules = dials::new_qual_param( + "logical", values = c(TRUE, FALSE), label = c(rules = "Rules") ) @@ -150,14 +158,16 @@ test_that("extract parameter set from workflow set with tunable workflow", { }) - test_that("extract single parameter from workflow set with untunable workflow", { rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>% recipes::step_rm(date, ends_with("away")) lm_model <- parsnip::linear_reg() %>% parsnip::set_engine("lm") bst_model <- - parsnip::boost_tree(mode = "classification", trees = hardhat::tune("funky name \n")) %>% + parsnip::boost_tree( + mode = "classification", + trees = hardhat::tune("funky name \n") + ) %>% parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE) wf_set <- workflow_set( list(reg = rm_rec), @@ -166,7 +176,11 @@ test_that("extract single parameter from workflow set with untunable workflow", expect_snapshot( error = TRUE, - hardhat::extract_parameter_dials(wf_set, id = "reg_lm", parameter = "non there") + hardhat::extract_parameter_dials( + wf_set, + id = "reg_lm", + parameter = "non there" + ) ) }) @@ -176,7 +190,10 @@ test_that("extract single parameter from workflow set with tunable workflow", { lm_model <- parsnip::linear_reg() %>% parsnip::set_engine("lm") bst_model <- - parsnip::boost_tree(mode = "classification", trees = hardhat::tune("funky name \n")) %>% + parsnip::boost_tree( + mode = "classification", + trees = hardhat::tune("funky name \n") + ) %>% parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE) wf_set <- workflow_set( list(reg = rm_rec), @@ -184,7 +201,11 @@ test_that("extract single parameter from workflow set with tunable workflow", { ) expect_equal( - hardhat::extract_parameter_dials(wf_set, id = "reg_bst", parameter = "funky name \n"), + hardhat::extract_parameter_dials( + wf_set, + id = "reg_bst", + parameter = "funky name \n" + ), dials::trees(c(1, 100)) ) expect_equal( diff --git a/tests/testthat/test-fit.R b/tests/testthat/test-fit.R index 5bcd649..f7cfc36 100644 --- a/tests/testthat/test-fit.R +++ b/tests/testthat/test-fit.R @@ -23,7 +23,8 @@ car_set_1 <- car_set_2 <- car_set_1 %>% - workflow_map("fit_resamples", + workflow_map( + "fit_resamples", resamples = vfold_cv(mtcars, v = 3), control = tune::control_resamples(save_pred = TRUE) ) diff --git a/tests/testthat/test-fit_best.R b/tests/testthat/test-fit_best.R index c944018..b994da9 100644 --- a/tests/testthat/test-fit_best.R +++ b/tests/testthat/test-fit_best.R @@ -37,7 +37,10 @@ test_that("fit_best fits with correct hyperparameters", { expect_s3_class(fit_best_wf, "workflow") rankings <- rank_results(chi_features_map, "rmse") - tune_res <- extract_workflow_set_result(chi_features_map, rankings$wflow_id[1]) + tune_res <- extract_workflow_set_result( + chi_features_map, + rankings$wflow_id[1] + ) tune_params <- select_best(tune_res, metric = "rmse") manual_wf <- fit_best(tune_res, parameters = tune_params) @@ -52,7 +55,10 @@ test_that("fit_best fits with correct hyperparameters", { expect_s3_class(fit_best_wf_2, "workflow") rankings_2 <- rank_results(chi_features_map, "iic") - tune_res_2 <- extract_workflow_set_result(chi_features_map, rankings_2$wflow_id[1]) + tune_res_2 <- extract_workflow_set_result( + chi_features_map, + rankings_2$wflow_id[1] + ) tune_params_2 <- select_best(tune_res_2, metric = "iic") manual_wf_2 <- fit_best(tune_res_2, parameters = tune_params_2) diff --git a/tests/testthat/test-options.R b/tests/testthat/test-options.R index ba31ae9..193efa4 100644 --- a/tests/testthat/test-options.R +++ b/tests/testthat/test-options.R @@ -41,7 +41,6 @@ test_that("option management", { }) - test_that("option printing", { expect_output( print(two_class_res$option[[1]]), diff --git a/tests/testthat/test-predict.R b/tests/testthat/test-predict.R index 962d4c6..1838530 100644 --- a/tests/testthat/test-predict.R +++ b/tests/testthat/test-predict.R @@ -23,7 +23,8 @@ car_set_1 <- car_set_2 <- car_set_1 %>% - workflow_map("fit_resamples", + workflow_map( + "fit_resamples", resamples = vfold_cv(mtcars, v = 3), control = tune::control_resamples(save_pred = TRUE) ) diff --git a/tests/testthat/test-pull.R b/tests/testthat/test-pull.R index cd7eb3d..a1f7cd9 100644 --- a/tests/testthat/test-pull.R +++ b/tests/testthat/test-pull.R @@ -9,7 +9,8 @@ car_set_1 <- list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), list(lm = lr_spec) ) %>% - workflow_map("fit_resamples", + workflow_map( + "fit_resamples", resamples = vfold_cv(mtcars, v = 3), control = tune::control_resamples(save_pred = TRUE) ) diff --git a/tests/testthat/test-updates.R b/tests/testthat/test-updates.R index 5116cae..a91ad72 100644 --- a/tests/testthat/test-updates.R +++ b/tests/testthat/test-updates.R @@ -30,7 +30,8 @@ test_that("update model", { expect_no_error( new_new_set <- - update_workflow_model(new_set, + update_workflow_model( + new_set, "none_glm", spec = xgb, formula = Class ~ log(A) + B @@ -45,11 +46,14 @@ test_that("update model", { test_that("update recipe", { expect_no_error( - new_set <- update_workflow_recipe(two_class_res, "yj_trans_cart", recipe = rec) + new_set <- update_workflow_recipe( + two_class_res, + "yj_trans_cart", + recipe = rec + ) ) new_rec <- extract_recipe(new_set, id = "yj_trans_cart", estimated = FALSE) - expect_true(all(tidy(new_rec)$type == "normalize")) expect_equal(new_set$result[[4]], list()) diff --git a/tests/testthat/test-workflow-map.R b/tests/testthat/test-workflow-map.R index e685651..13153fd 100644 --- a/tests/testthat/test-workflow-map.R +++ b/tests/testthat/test-workflow-map.R @@ -107,7 +107,6 @@ test_that("missing packages", { }) - test_that("failers", { skip_on_cran() car_set_3 <- diff --git a/tests/testthat/test-workflow_set.R b/tests/testthat/test-workflow_set.R index 2524ed5..12fbe36 100644 --- a/tests/testthat/test-workflow_set.R +++ b/tests/testthat/test-workflow_set.R @@ -17,9 +17,13 @@ test_that("creating workflow sets", { list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), list(lm = lr_spec) ) %>% - workflow_map("fit_resamples", + workflow_map( + "fit_resamples", resamples = vfold_cv(mtcars, v = 3), - control = tune::control_resamples(save_pred = TRUE, save_workflow = TRUE) + control = tune::control_resamples( + save_pred = TRUE, + save_workflow = TRUE + ) ) }) @@ -44,7 +48,10 @@ test_that("creating workflow sets", { ) expect_true( - all(purrr::map_lgl(car_set_1$info, ~ inherits(.x$workflow[[1]], "workflow"))) + all(purrr::map_lgl( + car_set_1$info, + ~ inherits(.x$workflow[[1]], "workflow") + )) ) expect_true( all(purrr::map_lgl(car_set_1$option, ~ inherits(.x, "list"))) @@ -84,7 +91,10 @@ test_that("creating workflow sets", { ) expect_true( - all(purrr::map_lgl(car_set_2$info, ~ inherits(.x$workflow[[1]], "workflow"))) + all(purrr::map_lgl( + car_set_2$info, + ~ inherits(.x$workflow[[1]], "workflow") + )) ) expect_true( all(purrr::map_lgl(car_set_2$option, ~ inherits(.x, "list"))) @@ -108,7 +118,6 @@ test_that("creating workflow sets", { all(purrr::map_lgl(car_set_3$info, tibble::is_tibble)) ) - # ------------------------------------------------------------------------------ # mixed inputs @@ -247,7 +256,11 @@ test_that("correct object type and resamples", { # same resamples since the seed is set expect_no_error( - res_1 <- workflow_map(set_1, "fit_resamples", resamples = bootstraps(mtcars, 3)) + res_1 <- workflow_map( + set_1, + "fit_resamples", + resamples = bootstraps(mtcars, 3) + ) ) res_1$result[[1]] <- lm(pp[[1]], data = mtcars) expect_identical( @@ -270,7 +283,6 @@ test_that("correct object type and resamples", { }) - # ------------------------------------------------------------------------------ test_that("crossing", { @@ -316,7 +328,11 @@ test_that("crossing", { error = TRUE, nrow( workflow_set( - list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp), two = mpg ~ wt + disp), + list( + reg = mpg ~ ., + nonlin = mpg ~ wt + 1 / sqrt(disp), + two = mpg ~ wt + disp + ), list(lm = lr_spec, knn = knn_spec), cross = FALSE ) @@ -332,10 +348,12 @@ test_that("checking resamples", { ctrl <- tune::control_resamples(save_workflow = TRUE) set.seed(1) cv_1 <- vfold_cv(mtcars, v = 5) - f_1 <- lr_spec %>% tune::fit_resamples(mpg ~ wt, resamples = cv_1, control = ctrl) + f_1 <- lr_spec %>% + tune::fit_resamples(mpg ~ wt, resamples = cv_1, control = ctrl) set.seed(2) cv_2 <- vfold_cv(mtcars, v = 5) - f_2 <- lr_spec %>% tune::fit_resamples(mpg ~ disp, resamples = cv_2, control = ctrl) + f_2 <- lr_spec %>% + tune::fit_resamples(mpg ~ disp, resamples = cv_2, control = ctrl) expect_snapshot( error = TRUE, as_workflow_set(wt = f_1, disp = f_2) @@ -343,7 +361,8 @@ test_that("checking resamples", { # Emulate old rset objects attr(cv_2, "fingerprint") <- NULL - f_3 <- lr_spec %>% tune::fit_resamples(mpg ~ disp, resamples = cv_2, control = ctrl) + f_3 <- lr_spec %>% + tune::fit_resamples(mpg ~ disp, resamples = cv_2, control = ctrl) expect_no_error(as_workflow_set(wt = f_1, disp = f_3)) }) @@ -356,7 +375,8 @@ test_that("constructor", { list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), list(lm = lr_spec) ) %>% - workflow_map("fit_resamples", + workflow_map( + "fit_resamples", resamples = vfold_cv(mtcars, v = 3), control = tune::control_resamples(save_pred = TRUE, save_workflow = TRUE) ) From a9638507b99efb86767ef8e0d88f66672d634da6 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 25 Apr 2025 14:10:38 -0500 Subject: [PATCH 05/11] transition to the base pipe --- R/as_workflow_set.R | 16 ++-- R/checks.R | 8 +- R/collect.R | 24 +++--- R/comments.R | 12 +-- R/fit_best.R | 4 +- R/misc.R | 16 ++-- R/options.R | 8 +- R/pull.R | 4 +- R/rank_results.R | 38 ++++----- R/update.R | 4 +- R/workflow_map.R | 50 ++++++------ R/workflow_set.R | 46 +++++------ README.Rmd | 37 +++++---- README.md | 36 ++++----- man-roxygen/chi_features_set.Rmd | 50 ++++++------ man-roxygen/two_class_set.Rmd | 14 ++-- man/as_workflow_set.Rd | 10 +-- man/chi_features_set.Rd | 50 ++++++------ man/collect_metrics.workflow_set.Rd | 8 +- man/comment_add.Rd | 12 +-- man/fit_best.workflow_set.Rd | 4 +- man/option_add.Rd | 8 +- man/two_class_set.Rd | 14 ++-- man/update_workflow_model.Rd | 4 +- man/workflow_map.Rd | 50 ++++++------ man/workflow_set.Rd | 16 ++-- tests/testthat/_snaps/comments.md | 14 ++-- tests/testthat/_snaps/extract.md | 4 +- tests/testthat/_snaps/pull.md | 8 +- tests/testthat/_snaps/workflow-map.md | 8 +- tests/testthat/_snaps/workflow_set.md | 38 +++++---- tests/testthat/test-collect-extracts.R | 4 +- tests/testthat/test-collect-metrics.R | 6 +- tests/testthat/test-collect-notes.R | 2 +- tests/testthat/test-collect-predictions.R | 18 ++--- tests/testthat/test-comments.R | 32 ++++---- tests/testthat/test-extract.R | 42 +++++----- tests/testthat/test-fit.R | 8 +- tests/testthat/test-fit_best.R | 8 +- tests/testthat/test-options.R | 14 ++-- tests/testthat/test-predict.R | 8 +- tests/testthat/test-pull.R | 12 +-- tests/testthat/test-updates.R | 6 +- tests/testthat/test-workflow-map.R | 32 ++++---- tests/testthat/test-workflow_set.R | 80 +++++++++---------- .../articles/tuning-and-comparing-models.Rmd | 46 +++++------ .../evaluating-different-predictor-sets.Rmd | 29 ++++--- 47 files changed, 484 insertions(+), 478 deletions(-) diff --git a/R/as_workflow_set.R b/R/as_workflow_set.R index 121e76b..83dd237 100644 --- a/R/as_workflow_set.R +++ b/R/as_workflow_set.R @@ -22,7 +22,7 @@ #' # objects to a workflow set #' two_class_res #' -#' results <- two_class_res %>% purrr::pluck("result") +#' results <- two_class_res |> purrr::pluck("result") #' names(results) <- two_class_res$wflow_id #' #' # These are all objects that have been resampled or tuned: @@ -40,13 +40,13 @@ #' lr_spec <- logistic_reg() #' #' main_effects <- -#' workflow() %>% -#' add_model(lr_spec) %>% +#' workflow() |> +#' add_model(lr_spec) |> #' add_formula(Class ~ .) #' #' interactions <- -#' workflow() %>% -#' add_model(lr_spec) %>% +#' workflow() |> +#' add_model(lr_spec) |> #' add_formula(Class ~ (.)^2) #' #' as_workflow_set(main = main_effects, int = interactions) @@ -69,7 +69,7 @@ as_workflow_set <- function(...) { res <- tibble::tibble(wflow_id = names(wflows)) res <- - res %>% + res |> dplyr::mutate( workflow = unname(wflows), info = purrr::map(workflow, get_info), @@ -78,7 +78,7 @@ as_workflow_set <- function(...) { res$result <- vector(mode = "list", length = nrow(res)) res$result[!is_workflow] <- object[!is_workflow] - res %>% - dplyr::select(wflow_id, info, option, result) %>% + res |> + dplyr::select(wflow_id, info, option, result) |> new_workflow_set() } diff --git a/R/checks.R b/R/checks.R index 4a006c6..dd4d6b3 100644 --- a/R/checks.R +++ b/R/checks.R @@ -11,14 +11,14 @@ check_wf_set <- function(x, arg = caller_arg(x), call = caller_env()) { check_consistent_metrics <- function(x, fail = TRUE, call = caller_env()) { metric_info <- - dplyr::distinct(x, .metric, wflow_id) %>% - dplyr::mutate(has = TRUE) %>% + dplyr::distinct(x, .metric, wflow_id) |> + dplyr::mutate(has = TRUE) |> tidyr::pivot_wider( names_from = ".metric", values_from = "has", values_fill = FALSE - ) %>% - dplyr::select(-wflow_id) %>% + ) |> + dplyr::select(-wflow_id) |> purrr::map_dbl(~ sum(!.x)) if (any(metric_info > 0)) { diff --git a/R/collect.R b/R/collect.R index 38e43fd..968e554 100644 --- a/R/collect.R +++ b/R/collect.R @@ -48,10 +48,10 @@ #' collect_metrics(two_class_res) #' #' # Alternatively, if the tuning parameter values are needed: -#' two_class_res %>% -#' dplyr::filter(grepl("cart", wflow_id)) %>% -#' mutate(metrics = map(result, collect_metrics)) %>% -#' dplyr::select(wflow_id, metrics) %>% +#' two_class_res |> +#' dplyr::filter(grepl("cart", wflow_id)) |> +#' mutate(metrics = map(result, collect_metrics)) |> +#' dplyr::select(wflow_id, metrics) |> #' tidyr::unnest(cols = metrics) #' } #' @@ -71,11 +71,11 @@ collect_metrics.workflow_set <- function(x, ..., summarize = TRUE) { ), metrics = purrr::map2(metrics, result, remove_parameters) ) - info <- dplyr::bind_rows(x$info) %>% dplyr::select(-workflow, -comment) + info <- dplyr::bind_rows(x$info) |> dplyr::select(-workflow, -comment) x <- - dplyr::select(x, wflow_id, metrics) %>% - dplyr::bind_cols(info) %>% - tidyr::unnest(cols = c(metrics)) %>% + dplyr::select(x, wflow_id, metrics) |> + dplyr::bind_cols(info) |> + tidyr::unnest(cols = c(metrics)) |> reorder_cols() check_consistent_metrics(x, fail = FALSE) x @@ -137,11 +137,11 @@ collect_predictions.workflow_set <- ) ) } - info <- dplyr::bind_rows(x$info) %>% dplyr::select(-workflow, -comment) + info <- dplyr::bind_rows(x$info) |> dplyr::select(-workflow, -comment) x <- - dplyr::select(x, wflow_id, predictions) %>% - dplyr::bind_cols(info) %>% - tidyr::unnest(cols = c(predictions)) %>% + dplyr::select(x, wflow_id, predictions) |> + dplyr::bind_cols(info) |> + tidyr::unnest(cols = c(predictions)) |> reorder_cols() x } diff --git a/R/comments.R b/R/comments.R index 5eaac4b..66d5128 100644 --- a/R/comments.R +++ b/R/comments.R @@ -17,19 +17,19 @@ #' @examples #' two_class_set #' -#' two_class_set %>% comment_get("none_cart") +#' two_class_set |> comment_get("none_cart") #' #' new_set <- -#' two_class_set %>% -#' comment_add("none_cart", "What does 'cart' stand for\u2753") %>% +#' two_class_set |> +#' comment_add("none_cart", "What does 'cart' stand for\u2753") |> #' comment_add("none_cart", "Classification And Regression Trees.") #' #' comment_print(new_set) #' -#' new_set %>% comment_get("none_cart") +#' new_set |> comment_get("none_cart") #' -#' new_set %>% -#' comment_reset("none_cart") %>% +#' new_set |> +#' comment_reset("none_cart") |> #' comment_get("none_cart") comment_add <- function(x, id, ..., append = TRUE, collapse = "\n") { check_wf_set(x) diff --git a/R/fit_best.R b/R/fit_best.R index 40f4901..636b661 100644 --- a/R/fit_best.R +++ b/R/fit_best.R @@ -52,9 +52,9 @@ tune::fit_best #' chi_features_set #' #' chi_features_res_new <- -#' chi_features_set %>% +#' chi_features_set |> #' # note: must set `save_workflow = TRUE` to use `fit_best()` -#' option_add(control = control_grid(save_workflow = TRUE)) %>% +#' option_add(control = control_grid(save_workflow = TRUE)) |> #' # evaluate with resamples #' workflow_map(resamples = time_val_split, grid = 21, seed = 1, verbose = TRUE) #' diff --git a/R/misc.R b/R/misc.R index c06b13f..cf2edbd 100644 --- a/R/misc.R +++ b/R/misc.R @@ -1,7 +1,7 @@ make_workflow <- function(x, y, call = caller_env()) { exp_classes <- c("formula", "recipe", "workflow_variables") w <- - workflows::workflow() %>% + workflows::workflow() |> workflows::add_model(y) if (inherits(x, "formula")) { w <- workflows::add_formula(w, x) @@ -34,14 +34,14 @@ metric_to_df <- function(x, ...) { collate_metrics <- function(x) { metrics <- - x$result %>% - purrr::map(tune::.get_tune_metrics) %>% - purrr::map(metric_to_df) %>% + x$result |> + purrr::map(tune::.get_tune_metrics) |> + purrr::map(metric_to_df) |> purrr::map_dfr(~ dplyr::mutate(.x, order = 1:nrow(.x))) mean_order <- - metrics %>% - dplyr::group_by(metric) %>% + metrics |> + dplyr::group_by(metric) |> dplyr::summarize( order = mean(order, na.rm = TRUE), n = dplyr::n(), @@ -49,10 +49,10 @@ collate_metrics <- function(x) { ) dplyr::full_join( - dplyr::distinct(metrics) %>% dplyr::select(-order), + dplyr::distinct(metrics) |> dplyr::select(-order), mean_order, by = "metric" - ) %>% + ) |> dplyr::arrange(order) } diff --git a/R/options.R b/R/options.R index a59e735..61f3d72 100644 --- a/R/options.R +++ b/R/options.R @@ -42,14 +42,14 @@ #' #' two_class_set #' -#' two_class_set %>% +#' two_class_set |> #' option_add(grid = 10) #' -#' two_class_set %>% -#' option_add(grid = 10) %>% +#' two_class_set |> +#' option_add(grid = 10) |> #' option_add(grid = 50, id = "none_cart") #' -#' two_class_set %>% +#' two_class_set |> #' option_add_parameters() option_add <- function(x, ..., id = NULL, strict = FALSE) { check_wf_set(x) diff --git a/R/pull.R b/R/pull.R index 3267ff5..2aa3798 100644 --- a/R/pull.R +++ b/R/pull.R @@ -33,7 +33,7 @@ pull_workflow_set_result <- function(x, id) { if (length(id) != 1) { cli::cli_abort("{.arg id} should have a single value.") } - y <- x %>% dplyr::filter(wflow_id == id[1]) + y <- x |> dplyr::filter(wflow_id == id[1]) if (nrow(y) != 1) { cli::cli_abort("No workflow ID found for {.val {id[1]}}.") } @@ -47,7 +47,7 @@ pull_workflow <- function(x, id) { if (length(id) != 1) { cli::cli_abort("{.arg id} should have a single value.") } - y <- x %>% dplyr::filter(wflow_id == id[1]) + y <- x |> dplyr::filter(wflow_id == id[1]) if (nrow(y) != 1) { cli::cli_abort("No workflow ID found for {.val {id[1]}}.") } diff --git a/R/rank_results.R b/R/rank_results.R index db9967c..054cd9d 100644 --- a/R/rank_results.R +++ b/R/rank_results.R @@ -52,7 +52,7 @@ rank_results <- function( eval_time <- tune::choose_eval_time(result_1, metric, eval_time = eval_time) - results <- collect_metrics(x) %>% + results <- collect_metrics(x) |> dplyr::select( wflow_id, .config, @@ -61,32 +61,32 @@ rank_results <- function( std_err, n, dplyr::any_of(".eval_time") - ) %>% - dplyr::full_join(wflow_info, by = "wflow_id") %>% + ) |> + dplyr::full_join(wflow_info, by = "wflow_id") |> dplyr::select(-comment, -workflow) if (!is.null(eval_time) && ".eval_time" %in% names(results)) { results <- results[results$.eval_time == eval_time, ] } - types <- x %>% - dplyr::full_join(wflow_info, by = "wflow_id") %>% + types <- x |> + dplyr::full_join(wflow_info, by = "wflow_id") |> dplyr::mutate( is_race = purrr::map_lgl(result, ~ inherits(.x, "tune_race")), num_rs = purrr::map_int(result, get_num_resamples) - ) %>% + ) |> dplyr::select(wflow_id, is_race, num_rs) ranked <- - dplyr::full_join(results, types, by = "wflow_id") %>% + dplyr::full_join(results, types, by = "wflow_id") |> dplyr::filter(.metric == metric) if (any(ranked$is_race)) { # remove any racing results with less resamples than the total number rm_rows <- - ranked %>% - dplyr::filter(is_race & (num_rs > n)) %>% - dplyr::select(wflow_id, .config) %>% + ranked |> + dplyr::filter(is_race & (num_rs > n)) |> + dplyr::select(wflow_id, .config) |> dplyr::distinct() if (nrow(rm_rows) > 0) { ranked <- dplyr::anti_join(ranked, rm_rows, by = c("wflow_id", ".config")) @@ -104,9 +104,9 @@ rank_results <- function( if (select_best) { best_by_wflow <- - dplyr::group_by(ranked, wflow_id) %>% - dplyr::slice_min(mean, with_ties = FALSE) %>% - dplyr::ungroup() %>% + dplyr::group_by(ranked, wflow_id) |> + dplyr::slice_min(mean, with_ties = FALSE) |> + dplyr::ungroup() |> dplyr::select(wflow_id, .config) ranked <- dplyr::inner_join( ranked, @@ -120,19 +120,19 @@ rank_results <- function( 1, { ranked <- - ranked %>% - dplyr::mutate(rank = rank(mean, ties.method = "random")) %>% + ranked |> + dplyr::mutate(rank = rank(mean, ties.method = "random")) |> dplyr::select(wflow_id, .config, rank) } ) - dplyr::inner_join(results, ranked, by = c("wflow_id", ".config")) %>% - dplyr::arrange(rank) %>% + dplyr::inner_join(results, ranked, by = c("wflow_id", ".config")) |> + dplyr::arrange(rank) |> dplyr::rename(preprocessor = preproc) } get_num_resamples <- function(x) { - purrr::map_dfr(x$splits, ~ .x$id) %>% - dplyr::distinct() %>% + purrr::map_dfr(x$splits, ~ .x$id) |> + dplyr::distinct() |> nrow() } diff --git a/R/update.R b/R/update.R index 9e96550..677f4e0 100644 --- a/R/update.R +++ b/R/update.R @@ -21,8 +21,8 @@ #' library(parsnip) #' #' new_mod <- -#' decision_tree() %>% -#' set_engine("rpart", method = "anova") %>% +#' decision_tree() |> +#' set_engine("rpart", method = "anova") |> #' set_mode("classification") #' #' new_set <- update_workflow_model(two_class_res, "none_cart", spec = new_mod) diff --git a/R/workflow_map.R b/R/workflow_map.R index 6437931..f0036b4 100644 --- a/R/workflow_map.R +++ b/R/workflow_map.R @@ -73,61 +73,61 @@ #' # --------------------------------------------------------------------------- #' #' base_recipe <- -#' recipe(ridership ~ ., data = Chicago) %>% +#' recipe(ridership ~ ., data = Chicago) |> #' # create date features -#' step_date(date) %>% -#' step_holiday(date) %>% +#' step_date(date) |> +#' step_holiday(date) |> #' # remove date from the list of predictors -#' update_role(date, new_role = "id") %>% +#' update_role(date, new_role = "id") |> #' # create dummy variables from factor columns -#' step_dummy(all_nominal()) %>% +#' step_dummy(all_nominal()) |> #' # remove any columns with a single unique value -#' step_zv(all_predictors()) %>% +#' step_zv(all_predictors()) |> #' step_normalize(all_predictors()) #' #' date_only <- -#' recipe(ridership ~ ., data = Chicago) %>% +#' recipe(ridership ~ ., data = Chicago) |> #' # create date features -#' step_date(date) %>% -#' update_role(date, new_role = "id") %>% +#' step_date(date) |> +#' update_role(date, new_role = "id") |> #' # create dummy variables from factor columns -#' step_dummy(all_nominal()) %>% +#' step_dummy(all_nominal()) |> #' # remove any columns with a single unique value #' step_zv(all_predictors()) #' #' date_and_holidays <- -#' recipe(ridership ~ ., data = Chicago) %>% +#' recipe(ridership ~ ., data = Chicago) |> #' # create date features -#' step_date(date) %>% -#' step_holiday(date) %>% +#' step_date(date) |> +#' step_holiday(date) |> #' # remove date from the list of predictors -#' update_role(date, new_role = "id") %>% +#' update_role(date, new_role = "id") |> #' # create dummy variables from factor columns -#' step_dummy(all_nominal()) %>% +#' step_dummy(all_nominal()) |> #' # remove any columns with a single unique value #' step_zv(all_predictors()) #' #' date_and_holidays_and_pca <- -#' recipe(ridership ~ ., data = Chicago) %>% +#' recipe(ridership ~ ., data = Chicago) |> #' # create date features -#' step_date(date) %>% -#' step_holiday(date) %>% +#' step_date(date) |> +#' step_holiday(date) |> #' # remove date from the list of predictors -#' update_role(date, new_role = "id") %>% +#' update_role(date, new_role = "id") |> #' # create dummy variables from factor columns -#' step_dummy(all_nominal()) %>% +#' step_dummy(all_nominal()) |> #' # remove any columns with a single unique value -#' step_zv(all_predictors()) %>% +#' step_zv(all_predictors()) |> #' step_pca(!!stations, num_comp = tune()) #' #' # --------------------------------------------------------------------------- #' -#' lm_spec <- linear_reg() %>% set_engine("lm") +#' lm_spec <- linear_reg() |> set_engine("lm") #' #' # --------------------------------------------------------------------------- #' #' pca_param <- -#' parameters(num_comp()) %>% +#' parameters(num_comp()) |> #' update(num_comp = num_comp(c(0, 20))) #' #' # --------------------------------------------------------------------------- @@ -146,8 +146,8 @@ #' # --------------------------------------------------------------------------- #' #' chi_features_res_new <- -#' chi_features_set %>% -#' option_add(param_info = pca_param, id = "plus_pca_lm") %>% +#' chi_features_set |> +#' option_add(param_info = pca_param, id = "plus_pca_lm") |> #' workflow_map(resamples = time_val_split, grid = 21, seed = 1, verbose = TRUE) #' #' chi_features_res_new diff --git a/R/workflow_set.R b/R/workflow_set.R index 32285a6..124a47a 100644 --- a/R/workflow_set.R +++ b/R/workflow_set.R @@ -73,7 +73,7 @@ #' # ------------------------------------------------------------------------------ #' #' data(cells) -#' cells <- cells %>% dplyr::select(-case) +#' cells <- cells |> dplyr::select(-case) #' #' set.seed(1) #' val_set <- validation_split(cells) @@ -81,27 +81,27 @@ #' # ------------------------------------------------------------------------------ #' #' basic_recipe <- -#' recipe(class ~ ., data = cells) %>% -#' step_YeoJohnson(all_predictors()) %>% +#' recipe(class ~ ., data = cells) |> +#' step_YeoJohnson(all_predictors()) |> #' step_normalize(all_predictors()) #' #' pca_recipe <- -#' basic_recipe %>% +#' basic_recipe |> #' step_pca(all_predictors(), num_comp = tune()) #' #' ss_recipe <- -#' basic_recipe %>% +#' basic_recipe |> #' step_spatialsign(all_predictors()) #' #' # ------------------------------------------------------------------------------ #' #' knn_mod <- -#' nearest_neighbor(neighbors = tune(), weight_func = tune()) %>% -#' set_engine("kknn") %>% +#' nearest_neighbor(neighbors = tune(), weight_func = tune()) |> +#' set_engine("kknn") |> #' set_mode("classification") #' #' lr_mod <- -#' logistic_reg() %>% +#' logistic_reg() |> #' set_engine("glm") #' #' # ------------------------------------------------------------------------------ @@ -151,18 +151,18 @@ workflow_set <- function(preproc, models, cross = TRUE, case_weights = NULL) { # call set_weights outside of mutate call so that dplyr # doesn't prepend possible warnings with "Problem while computing..." wfs <- - purrr::map2(res$preproc, res$model, make_workflow) %>% - set_weights(case_weights) %>% + purrr::map2(res$preproc, res$model, make_workflow) |> + set_weights(case_weights) |> unname() res <- - res %>% + res |> dplyr::mutate( workflow = wfs, info = purrr::map(workflow, get_info), option = purrr::map(1:nrow(res), ~ new_workflow_set_options()), result = purrr::map(1:nrow(res), ~ list()) - ) %>% + ) |> dplyr::select(wflow_id, info, option, result) new_workflow_set(res) @@ -202,9 +202,9 @@ fix_list_names <- function(x) { } cross_objects <- function(preproc, models) { - tidyr::crossing(preproc, models) %>% - dplyr::mutate(pp_nm = names(preproc), mod_nm = names(models)) %>% - dplyr::mutate(wflow_id = paste(pp_nm, mod_nm, sep = "_")) %>% + tidyr::crossing(preproc, models) |> + dplyr::mutate(pp_nm = names(preproc), mod_nm = names(models)) |> + dplyr::mutate(wflow_id = paste(pp_nm, mod_nm, sep = "_")) |> dplyr::select(wflow_id, preproc, model = models) } @@ -215,7 +215,7 @@ fuse_objects <- function(preproc, models) { nms <- tibble::tibble(wflow_id = paste(names(preproc), names(models), sep = "_")) - tibble::tibble(preproc = preproc, model = models) %>% + tibble::tibble(preproc = preproc, model = models) |> dplyr::bind_cols(nms) } @@ -227,16 +227,16 @@ set_weights <- function(workflows, case_weights) { } allowed <- - workflows %>% - purrr::map(extract_spec_parsnip) %>% + workflows |> + purrr::map(extract_spec_parsnip) |> purrr::map_lgl(case_weights_allowed) if (any(!allowed)) { disallowed <- - workflows[!allowed] %>% - purrr::map(extract_spec_parsnip) %>% - purrr::map(purrr::pluck, "engine") %>% - unlist() %>% + workflows[!allowed] |> + purrr::map(extract_spec_parsnip) |> + purrr::map(purrr::pluck, "engine") |> + unlist() |> unique() cli::cli_warn( @@ -268,7 +268,7 @@ case_weights_allowed <- function(spec) { mod_mode <- spec$mode model_info <- - parsnip::get_from_env(paste0(mod_type, "_fit")) %>% + parsnip::get_from_env(paste0(mod_type, "_fit")) |> dplyr::filter(engine == mod_eng & mode == mod_mode) # If weights are used, they are protected data arguments with the canonical diff --git a/README.Rmd b/README.Rmd index af5e3c7..cef05e3 100644 --- a/README.Rmd +++ b/README.Rmd @@ -69,19 +69,19 @@ theme_set(theme_bw()) library(tidymodels) data(Chicago) # Use a small sample to keep file sizes down: -Chicago <- Chicago %>% slice(1:365) +Chicago <- Chicago |> slice(1:365) base_recipe <- - recipe(ridership ~ ., data = Chicago) %>% + recipe(ridership ~ ., data = Chicago) |> # create date features - step_date(date) %>% - step_holiday(date) %>% + step_date(date) |> + step_holiday(date) |> # remove date from the list of predictors - update_role(date, new_role = "id") %>% + update_role(date, new_role = "id") |> # create dummy variables from factor columns - step_dummy(all_nominal()) %>% + step_dummy(all_nominal()) |> # remove any columns with a single unique value - step_zv(all_predictors()) %>% + step_zv(all_predictors()) |> step_normalize(all_predictors()) ``` @@ -89,7 +89,7 @@ To enact a correlation filter, an additional step is used: ```{r filter} filter_rec <- - base_recipe %>% + base_recipe |> step_corr(all_of(stations), threshold = tune()) ``` @@ -98,8 +98,8 @@ Similarly, for PCA: ```{r pca} pca_rec <- - base_recipe %>% - step_pca(all_of(stations), num_comp = tune()) %>% + base_recipe |> + step_pca(all_of(stations), num_comp = tune()) |> step_normalize(all_predictors()) ``` @@ -107,17 +107,17 @@ We might want to assess a few different models, including a regularized method ( ```{r models} regularized_spec <- - linear_reg(penalty = tune(), mixture = tune()) %>% + linear_reg(penalty = tune(), mixture = tune()) |> set_engine("glmnet") cart_spec <- - decision_tree(cost_complexity = tune(), min_n = tune()) %>% - set_engine("rpart") %>% + decision_tree(cost_complexity = tune(), min_n = tune()) |> + set_engine("rpart") |> set_mode("regression") knn_spec <- - nearest_neighbor(neighbors = tune(), weight_func = tune()) %>% - set_engine("kknn") %>% + nearest_neighbor(neighbors = tune(), weight_func = tune()) |> + set_engine("kknn") |> set_mode("regression") ``` @@ -142,7 +142,7 @@ It doesn't make sense to use PCA or a filter with a `glmnet` model. We can remov ```{r rm} chi_models <- - chi_models %>% + chi_models |> anti_join(tibble(wflow_id = c("pca_glmnet", "filter_glmnet")), by = "wflow_id" ) @@ -169,7 +169,7 @@ We'll use simple grid search for these models by running `workflow_map()`. This ```{r tune} set.seed(123) chi_models <- - chi_models %>% + chi_models |> # The first argument is a function name from the {{tune}} package # such as `tune_grid()`, `fit_resamples()`, etc. workflow_map("tune_grid", @@ -197,7 +197,7 @@ autoplot(chi_models, select_best = TRUE) We can determine how well each combination did by looking at the best results per workflow: ```{r best} -rank_results(chi_models, rank_metric = "mae", select_best = TRUE) %>% +rank_results(chi_models, rank_metric = "mae", select_best = TRUE) |> select(rank, mean, model, wflow_id, .config) ``` @@ -217,4 +217,3 @@ This project is released with a [Contributor Code of Conduct](https://contributo - Either way, learn how to create and share a [reprex](https://reprex.tidyverse.org/articles/articles/learn-reprex.html) (a minimal, reproducible example), to clearly communicate about your code. - Check out further details on [contributing guidelines for tidymodels packages](https://www.tidymodels.org/contribute/) and [how to get help](https://www.tidymodels.org/help/). - diff --git a/README.md b/README.md index 4676c2b..d8710a5 100644 --- a/README.md +++ b/README.md @@ -64,19 +64,19 @@ we will build on: library(tidymodels) data(Chicago) # Use a small sample to keep file sizes down: -Chicago <- Chicago %>% slice(1:365) +Chicago <- Chicago |> slice(1:365) base_recipe <- - recipe(ridership ~ ., data = Chicago) %>% + recipe(ridership ~ ., data = Chicago) |> # create date features - step_date(date) %>% - step_holiday(date) %>% + step_date(date) |> + step_holiday(date) |> # remove date from the list of predictors - update_role(date, new_role = "id") %>% + update_role(date, new_role = "id") |> # create dummy variables from factor columns - step_dummy(all_nominal()) %>% + step_dummy(all_nominal()) |> # remove any columns with a single unique value - step_zv(all_predictors()) %>% + step_zv(all_predictors()) |> step_normalize(all_predictors()) ``` @@ -84,7 +84,7 @@ To enact a correlation filter, an additional step is used: ``` r filter_rec <- - base_recipe %>% + base_recipe |> step_corr(all_of(stations), threshold = tune()) ``` @@ -92,8 +92,8 @@ Similarly, for PCA: ``` r pca_rec <- - base_recipe %>% - step_pca(all_of(stations), num_comp = tune()) %>% + base_recipe |> + step_pca(all_of(stations), num_comp = tune()) |> step_normalize(all_predictors()) ``` @@ -102,17 +102,17 @@ method (`glmnet`): ``` r regularized_spec <- - linear_reg(penalty = tune(), mixture = tune()) %>% + linear_reg(penalty = tune(), mixture = tune()) |> set_engine("glmnet") cart_spec <- - decision_tree(cost_complexity = tune(), min_n = tune()) %>% - set_engine("rpart") %>% + decision_tree(cost_complexity = tune(), min_n = tune()) |> + set_engine("rpart") |> set_mode("regression") knn_spec <- - nearest_neighbor(neighbors = tune(), weight_func = tune()) %>% - set_engine("kknn") %>% + nearest_neighbor(neighbors = tune(), weight_func = tune()) |> + set_engine("kknn") |> set_mode("regression") ``` @@ -152,7 +152,7 @@ can remove these easily: ``` r chi_models <- - chi_models %>% + chi_models |> anti_join(tibble(wflow_id = c("pca_glmnet", "filter_glmnet")), by = "wflow_id" ) @@ -194,7 +194,7 @@ the workflows in the `workflow` column: ``` r set.seed(123) chi_models <- - chi_models %>% + chi_models |> # The first argument is a function name from the {{tune}} package # such as `tune_grid()`, `fit_resamples()`, etc. workflow_map("tune_grid", @@ -251,7 +251,7 @@ We can determine how well each combination did by looking at the best results per workflow: ``` r -rank_results(chi_models, rank_metric = "mae", select_best = TRUE) %>% +rank_results(chi_models, rank_metric = "mae", select_best = TRUE) |> select(rank, mean, model, wflow_id, .config) #> # A tibble: 7 × 5 #> rank mean model wflow_id .config diff --git a/man-roxygen/chi_features_set.Rmd b/man-roxygen/chi_features_set.Rmd index 9b43477..014a286 100644 --- a/man-roxygen/chi_features_set.Rmd +++ b/man-roxygen/chi_features_set.Rmd @@ -29,61 +29,61 @@ time_val_split <- # ------------------------------------------------------------------------------ base_recipe <- - recipe(ridership ~ ., data = Chicago) %>% + recipe(ridership ~ ., data = Chicago) |> # create date features - step_date(date) %>% - step_holiday(date) %>% + step_date(date) |> + step_holiday(date) |> # remove date from the list of predictors - update_role(date, new_role = "id") %>% + update_role(date, new_role = "id") |> # create dummy variables from factor columns - step_dummy(all_nominal()) %>% + step_dummy(all_nominal()) |> # remove any columns with a single unique value - step_zv(all_predictors()) %>% + step_zv(all_predictors()) |> step_normalize(all_predictors()) date_only <- - recipe(ridership ~ ., data = Chicago) %>% + recipe(ridership ~ ., data = Chicago) |> # create date features - step_date(date) %>% - update_role(date, new_role = "id") %>% + step_date(date) |> + update_role(date, new_role = "id") |> # create dummy variables from factor columns - step_dummy(all_nominal()) %>% + step_dummy(all_nominal()) |> # remove any columns with a single unique value step_zv(all_predictors()) date_and_holidays <- - recipe(ridership ~ ., data = Chicago) %>% + recipe(ridership ~ ., data = Chicago) |> # create date features - step_date(date) %>% - step_holiday(date) %>% + step_date(date) |> + step_holiday(date) |> # remove date from the list of predictors - update_role(date, new_role = "id") %>% + update_role(date, new_role = "id") |> # create dummy variables from factor columns - step_dummy(all_nominal()) %>% + step_dummy(all_nominal()) |> # remove any columns with a single unique value step_zv(all_predictors()) date_and_holidays_and_pca <- - recipe(ridership ~ ., data = Chicago) %>% + recipe(ridership ~ ., data = Chicago) |> # create date features - step_date(date) %>% - step_holiday(date) %>% + step_date(date) |> + step_holiday(date) |> # remove date from the list of predictors - update_role(date, new_role = "id") %>% + update_role(date, new_role = "id") |> # create dummy variables from factor columns - step_dummy(all_nominal()) %>% + step_dummy(all_nominal()) |> # remove any columns with a single unique value - step_zv(all_predictors()) %>% + step_zv(all_predictors()) |> step_pca(!!stations, num_comp = tune()) # ------------------------------------------------------------------------------ -lm_spec <- linear_reg() %>% set_engine("lm") +lm_spec <- linear_reg() |> set_engine("lm") # ------------------------------------------------------------------------------ pca_param <- - parameters(num_comp()) %>% + parameters(num_comp()) |> update(num_comp = num_comp(c(0, 20))) # ------------------------------------------------------------------------------ @@ -100,8 +100,8 @@ chi_features_set <- # ------------------------------------------------------------------------------ chi_features_res <- - chi_features_set %>% - option_add(param_info = pca_param, id = "plus_pca_lm") %>% + chi_features_set |> + option_add(param_info = pca_param, id = "plus_pca_lm") |> workflow_map(resamples = time_val_split, grid = 21, seed = 1, verbose = TRUE) ``` diff --git a/man-roxygen/two_class_set.Rmd b/man-roxygen/two_class_set.Rmd index 6504a6b..88452b0 100644 --- a/man-roxygen/two_class_set.Rmd +++ b/man-roxygen/two_class_set.Rmd @@ -21,23 +21,23 @@ folds <- vfold_cv(two_class_dat, v = 5) # ------------------------------------------------------------------------------ decision_tree_rpart_spec <- - decision_tree(min_n = tune(), cost_complexity = tune()) %>% - set_engine('rpart') %>% + decision_tree(min_n = tune(), cost_complexity = tune()) |> + set_engine('rpart') |> set_mode('classification') logistic_reg_glm_spec <- - logistic_reg() %>% + logistic_reg() |> set_engine('glm') mars_earth_spec <- - mars(prod_degree = tune()) %>% - set_engine('earth') %>% + mars(prod_degree = tune()) |> + set_engine('earth') |> set_mode('classification') # ------------------------------------------------------------------------------ yj_recipe <- - recipe(Class ~ ., data = two_class_dat) %>% + recipe(Class ~ ., data = two_class_dat) |> step_YeoJohnson(A, B) # ------------------------------------------------------------------------------ @@ -52,7 +52,7 @@ two_class_set <- # ------------------------------------------------------------------------------ two_class_res <- - two_class_set %>% + two_class_set |> workflow_map( resamples = folds, grid = 10, diff --git a/man/as_workflow_set.Rd b/man/as_workflow_set.Rd index bc4cb0d..8a6bd00 100644 --- a/man/as_workflow_set.Rd +++ b/man/as_workflow_set.Rd @@ -50,7 +50,7 @@ sequence of models built in Section 1.3 of Kuhn and Johnson (2019). See # objects to a workflow set two_class_res -results <- two_class_res \%>\% purrr::pluck("result") +results <- two_class_res |> purrr::pluck("result") names(results) <- two_class_res$wflow_id # These are all objects that have been resampled or tuned: @@ -68,13 +68,13 @@ library(workflows) lr_spec <- logistic_reg() main_effects <- - workflow() \%>\% - add_model(lr_spec) \%>\% + workflow() |> + add_model(lr_spec) |> add_formula(Class ~ .) interactions <- - workflow() \%>\% - add_model(lr_spec) \%>\% + workflow() |> + add_model(lr_spec) |> add_formula(Class ~ (.)^2) as_workflow_set(main = main_effects, int = interactions) diff --git a/man/chi_features_set.Rd b/man/chi_features_set.Rd index ca26883..8c10a27 100644 --- a/man/chi_features_set.Rd +++ b/man/chi_features_set.Rd @@ -56,61 +56,61 @@ time_val_split <- # ------------------------------------------------------------------------------ base_recipe <- - recipe(ridership ~ ., data = Chicago) \%>\% + recipe(ridership ~ ., data = Chicago) |> # create date features - step_date(date) \%>\% - step_holiday(date) \%>\% + step_date(date) |> + step_holiday(date) |> # remove date from the list of predictors - update_role(date, new_role = "id") \%>\% + update_role(date, new_role = "id") |> # create dummy variables from factor columns - step_dummy(all_nominal()) \%>\% + step_dummy(all_nominal()) |> # remove any columns with a single unique value - step_zv(all_predictors()) \%>\% + step_zv(all_predictors()) |> step_normalize(all_predictors()) date_only <- - recipe(ridership ~ ., data = Chicago) \%>\% + recipe(ridership ~ ., data = Chicago) |> # create date features - step_date(date) \%>\% - update_role(date, new_role = "id") \%>\% + step_date(date) |> + update_role(date, new_role = "id") |> # create dummy variables from factor columns - step_dummy(all_nominal()) \%>\% + step_dummy(all_nominal()) |> # remove any columns with a single unique value step_zv(all_predictors()) date_and_holidays <- - recipe(ridership ~ ., data = Chicago) \%>\% + recipe(ridership ~ ., data = Chicago) |> # create date features - step_date(date) \%>\% - step_holiday(date) \%>\% + step_date(date) |> + step_holiday(date) |> # remove date from the list of predictors - update_role(date, new_role = "id") \%>\% + update_role(date, new_role = "id") |> # create dummy variables from factor columns - step_dummy(all_nominal()) \%>\% + step_dummy(all_nominal()) |> # remove any columns with a single unique value step_zv(all_predictors()) date_and_holidays_and_pca <- - recipe(ridership ~ ., data = Chicago) \%>\% + recipe(ridership ~ ., data = Chicago) |> # create date features - step_date(date) \%>\% - step_holiday(date) \%>\% + step_date(date) |> + step_holiday(date) |> # remove date from the list of predictors - update_role(date, new_role = "id") \%>\% + update_role(date, new_role = "id") |> # create dummy variables from factor columns - step_dummy(all_nominal()) \%>\% + step_dummy(all_nominal()) |> # remove any columns with a single unique value - step_zv(all_predictors()) \%>\% + step_zv(all_predictors()) |> step_pca(!!stations, num_comp = tune()) # ------------------------------------------------------------------------------ -lm_spec <- linear_reg() \%>\% set_engine("lm") +lm_spec <- linear_reg() |> set_engine("lm") # ------------------------------------------------------------------------------ pca_param <- - parameters(num_comp()) \%>\% + parameters(num_comp()) |> update(num_comp = num_comp(c(0, 20))) # ------------------------------------------------------------------------------ @@ -127,8 +127,8 @@ chi_features_set <- # ------------------------------------------------------------------------------ chi_features_res <- - chi_features_set \%>\% - option_add(param_info = pca_param, id = "plus_pca_lm") \%>\% + chi_features_set |> + option_add(param_info = pca_param, id = "plus_pca_lm") |> workflow_map(resamples = time_val_split, grid = 21, seed = 1, verbose = TRUE) }\if{html}{\out{}} } diff --git a/man/collect_metrics.workflow_set.Rd b/man/collect_metrics.workflow_set.Rd index d22e19a..188847c 100644 --- a/man/collect_metrics.workflow_set.Rd +++ b/man/collect_metrics.workflow_set.Rd @@ -94,10 +94,10 @@ two_class_res collect_metrics(two_class_res) # Alternatively, if the tuning parameter values are needed: -two_class_res \%>\% - dplyr::filter(grepl("cart", wflow_id)) \%>\% - mutate(metrics = map(result, collect_metrics)) \%>\% - dplyr::select(wflow_id, metrics) \%>\% +two_class_res |> + dplyr::filter(grepl("cart", wflow_id)) |> + mutate(metrics = map(result, collect_metrics)) |> + dplyr::select(wflow_id, metrics) |> tidyr::unnest(cols = metrics) } diff --git a/man/comment_add.Rd b/man/comment_add.Rd index f398de4..728df62 100644 --- a/man/comment_add.Rd +++ b/man/comment_add.Rd @@ -41,18 +41,18 @@ its results as you work. Comments can be appended or removed. \examples{ two_class_set -two_class_set \%>\% comment_get("none_cart") +two_class_set |> comment_get("none_cart") new_set <- - two_class_set \%>\% - comment_add("none_cart", "What does 'cart' stand for\u2753") \%>\% + two_class_set |> + comment_add("none_cart", "What does 'cart' stand for\u2753") |> comment_add("none_cart", "Classification And Regression Trees.") comment_print(new_set) -new_set \%>\% comment_get("none_cart") +new_set |> comment_get("none_cart") -new_set \%>\% - comment_reset("none_cart") \%>\% +new_set |> + comment_reset("none_cart") |> comment_get("none_cart") } diff --git a/man/fit_best.workflow_set.Rd b/man/fit_best.workflow_set.Rd index 15bcb9c..49c920f 100644 --- a/man/fit_best.workflow_set.Rd +++ b/man/fit_best.workflow_set.Rd @@ -79,9 +79,9 @@ time_val_split <- chi_features_set chi_features_res_new <- - chi_features_set \%>\% + chi_features_set |> # note: must set `save_workflow = TRUE` to use `fit_best()` - option_add(control = control_grid(save_workflow = TRUE)) \%>\% + option_add(control = control_grid(save_workflow = TRUE)) |> # evaluate with resamples workflow_map(resamples = time_val_split, grid = 21, seed = 1, verbose = TRUE) diff --git a/man/option_add.Rd b/man/option_add.Rd index 45097de..5f55957 100644 --- a/man/option_add.Rd +++ b/man/option_add.Rd @@ -60,13 +60,13 @@ library(tune) two_class_set -two_class_set \%>\% +two_class_set |> option_add(grid = 10) -two_class_set \%>\% - option_add(grid = 10) \%>\% +two_class_set |> + option_add(grid = 10) |> option_add(grid = 50, id = "none_cart") -two_class_set \%>\% +two_class_set |> option_add_parameters() } diff --git a/man/two_class_set.Rd b/man/two_class_set.Rd index 7dd11c6..3087df3 100644 --- a/man/two_class_set.Rd +++ b/man/two_class_set.Rd @@ -48,23 +48,23 @@ folds <- vfold_cv(two_class_dat, v = 5) # ------------------------------------------------------------------------------ decision_tree_rpart_spec <- - decision_tree(min_n = tune(), cost_complexity = tune()) \%>\% - set_engine('rpart') \%>\% + decision_tree(min_n = tune(), cost_complexity = tune()) |> + set_engine('rpart') |> set_mode('classification') logistic_reg_glm_spec <- - logistic_reg() \%>\% + logistic_reg() |> set_engine('glm') mars_earth_spec <- - mars(prod_degree = tune()) \%>\% - set_engine('earth') \%>\% + mars(prod_degree = tune()) |> + set_engine('earth') |> set_mode('classification') # ------------------------------------------------------------------------------ yj_recipe <- - recipe(Class ~ ., data = two_class_dat) \%>\% + recipe(Class ~ ., data = two_class_dat) |> step_YeoJohnson(A, B) # ------------------------------------------------------------------------------ @@ -79,7 +79,7 @@ two_class_set <- # ------------------------------------------------------------------------------ two_class_res <- - two_class_set \%>\% + two_class_set |> workflow_map( resamples = folds, grid = 10, diff --git a/man/update_workflow_model.Rd b/man/update_workflow_model.Rd index 03b7a2a..ab01362 100644 --- a/man/update_workflow_model.Rd +++ b/man/update_workflow_model.Rd @@ -68,8 +68,8 @@ sequence of models built in Section 1.3 of Kuhn and Johnson (2019). See library(parsnip) new_mod <- - decision_tree() \%>\% - set_engine("rpart", method = "anova") \%>\% + decision_tree() |> + set_engine("rpart", method = "anova") |> set_mode("classification") new_set <- update_workflow_model(two_class_res, "none_cart", spec = new_mod) diff --git a/man/workflow_map.Rd b/man/workflow_map.Rd index aec17f5..1b6338e 100644 --- a/man/workflow_map.Rd +++ b/man/workflow_map.Rd @@ -111,61 +111,61 @@ time_val_split <- # --------------------------------------------------------------------------- base_recipe <- - recipe(ridership ~ ., data = Chicago) \%>\% + recipe(ridership ~ ., data = Chicago) |> # create date features - step_date(date) \%>\% - step_holiday(date) \%>\% + step_date(date) |> + step_holiday(date) |> # remove date from the list of predictors - update_role(date, new_role = "id") \%>\% + update_role(date, new_role = "id") |> # create dummy variables from factor columns - step_dummy(all_nominal()) \%>\% + step_dummy(all_nominal()) |> # remove any columns with a single unique value - step_zv(all_predictors()) \%>\% + step_zv(all_predictors()) |> step_normalize(all_predictors()) date_only <- - recipe(ridership ~ ., data = Chicago) \%>\% + recipe(ridership ~ ., data = Chicago) |> # create date features - step_date(date) \%>\% - update_role(date, new_role = "id") \%>\% + step_date(date) |> + update_role(date, new_role = "id") |> # create dummy variables from factor columns - step_dummy(all_nominal()) \%>\% + step_dummy(all_nominal()) |> # remove any columns with a single unique value step_zv(all_predictors()) date_and_holidays <- - recipe(ridership ~ ., data = Chicago) \%>\% + recipe(ridership ~ ., data = Chicago) |> # create date features - step_date(date) \%>\% - step_holiday(date) \%>\% + step_date(date) |> + step_holiday(date) |> # remove date from the list of predictors - update_role(date, new_role = "id") \%>\% + update_role(date, new_role = "id") |> # create dummy variables from factor columns - step_dummy(all_nominal()) \%>\% + step_dummy(all_nominal()) |> # remove any columns with a single unique value step_zv(all_predictors()) date_and_holidays_and_pca <- - recipe(ridership ~ ., data = Chicago) \%>\% + recipe(ridership ~ ., data = Chicago) |> # create date features - step_date(date) \%>\% - step_holiday(date) \%>\% + step_date(date) |> + step_holiday(date) |> # remove date from the list of predictors - update_role(date, new_role = "id") \%>\% + update_role(date, new_role = "id") |> # create dummy variables from factor columns - step_dummy(all_nominal()) \%>\% + step_dummy(all_nominal()) |> # remove any columns with a single unique value - step_zv(all_predictors()) \%>\% + step_zv(all_predictors()) |> step_pca(!!stations, num_comp = tune()) # --------------------------------------------------------------------------- -lm_spec <- linear_reg() \%>\% set_engine("lm") +lm_spec <- linear_reg() |> set_engine("lm") # --------------------------------------------------------------------------- pca_param <- - parameters(num_comp()) \%>\% + parameters(num_comp()) |> update(num_comp = num_comp(c(0, 20))) # --------------------------------------------------------------------------- @@ -184,8 +184,8 @@ chi_features_set <- # --------------------------------------------------------------------------- chi_features_res_new <- - chi_features_set \%>\% - option_add(param_info = pca_param, id = "plus_pca_lm") \%>\% + chi_features_set |> + option_add(param_info = pca_param, id = "plus_pca_lm") |> workflow_map(resamples = time_val_split, grid = 21, seed = 1, verbose = TRUE) chi_features_res_new diff --git a/man/workflow_set.Rd b/man/workflow_set.Rd index 57394b5..80fcb99 100644 --- a/man/workflow_set.Rd +++ b/man/workflow_set.Rd @@ -106,7 +106,7 @@ library(yardstick) # ------------------------------------------------------------------------------ data(cells) -cells <- cells \%>\% dplyr::select(-case) +cells <- cells |> dplyr::select(-case) set.seed(1) val_set <- validation_split(cells) @@ -114,27 +114,27 @@ val_set <- validation_split(cells) # ------------------------------------------------------------------------------ basic_recipe <- - recipe(class ~ ., data = cells) \%>\% - step_YeoJohnson(all_predictors()) \%>\% + recipe(class ~ ., data = cells) |> + step_YeoJohnson(all_predictors()) |> step_normalize(all_predictors()) pca_recipe <- - basic_recipe \%>\% + basic_recipe |> step_pca(all_predictors(), num_comp = tune()) ss_recipe <- - basic_recipe \%>\% + basic_recipe |> step_spatialsign(all_predictors()) # ------------------------------------------------------------------------------ knn_mod <- - nearest_neighbor(neighbors = tune(), weight_func = tune()) \%>\% - set_engine("kknn") \%>\% + nearest_neighbor(neighbors = tune(), weight_func = tune()) |> + set_engine("kknn") |> set_mode("classification") lr_mod <- - logistic_reg() \%>\% + logistic_reg() |> set_engine("glm") # ------------------------------------------------------------------------------ diff --git a/tests/testthat/_snaps/comments.md b/tests/testthat/_snaps/comments.md index 6a87ad0..d504a3a 100644 --- a/tests/testthat/_snaps/comments.md +++ b/tests/testthat/_snaps/comments.md @@ -1,7 +1,7 @@ # test comments Code - two_class_set %>% comment_add("toe", "foot") + comment_add(two_class_set, "toe", "foot") Condition Error in `comment_add()`: ! The `id` value is not in `wflow_id`. @@ -9,7 +9,7 @@ --- Code - two_class_set %>% comment_add(letters, "foot") + comment_add(two_class_set, letters, "foot") Condition Error in `comment_add()`: ! `id` must be a single string, not a character vector. @@ -17,7 +17,7 @@ --- Code - two_class_set %>% comment_add(1:2, "foot") + comment_add(two_class_set, 1:2, "foot") Condition Error in `comment_add()`: ! `id` must be a single string, not an integer vector. @@ -25,7 +25,7 @@ --- Code - two_class_set %>% comment_add("none_cart", 1:2) + comment_add(two_class_set, "none_cart", 1:2) Condition Error in `comment_add()`: ! The comments should be character strings. @@ -33,7 +33,7 @@ --- Code - comments_1 %>% comment_add("none_cart", "Stuff.", append = FALSE) + comment_add(comments_1, "none_cart", "Stuff.", append = FALSE) Condition Error in `comment_add()`: ! There is already a comment for this id and `append = FALSE`. @@ -57,7 +57,7 @@ --- Code - comments_1 %>% comment_reset(letters) + comment_reset(comments_1, letters) Condition Error in `comment_reset()`: ! `id` should be a single character value. @@ -65,7 +65,7 @@ --- Code - comments_1 %>% comment_reset("none_carts") + comment_reset(comments_1, "none_carts") Condition Error in `comment_reset()`: ! The `id` value is not in `wflow_id`. diff --git a/tests/testthat/_snaps/extract.md b/tests/testthat/_snaps/extract.md index ee68802..14c0802 100644 --- a/tests/testthat/_snaps/extract.md +++ b/tests/testthat/_snaps/extract.md @@ -37,7 +37,7 @@ --- Code - car_set_1 %>% extract_workflow_set_result("Gideon Nav") + extract_workflow_set_result(car_set_1, "Gideon Nav") Condition Error in `extract_workflow_set_result()`: ! `id` must correspond to a single row in `x`. @@ -45,7 +45,7 @@ --- Code - car_set_1 %>% extract_workflow("Coronabeth Tridentarius") + extract_workflow(car_set_1, "Coronabeth Tridentarius") Condition Error in `extract_workflow()`: ! `id` must correspond to a single row in `x`. diff --git a/tests/testthat/_snaps/pull.md b/tests/testthat/_snaps/pull.md index 0c2abaf..0279b56 100644 --- a/tests/testthat/_snaps/pull.md +++ b/tests/testthat/_snaps/pull.md @@ -1,7 +1,7 @@ # pulling objects Code - res <- car_set_1 %>% pull_workflow("reg_lm") + res <- pull_workflow(car_set_1, "reg_lm") Condition Warning: `pull_workflow()` was deprecated in workflowsets 0.1.0. @@ -10,7 +10,7 @@ --- Code - res <- car_set_1 %>% pull_workflow_set_result("reg_lm") + res <- pull_workflow_set_result(car_set_1, "reg_lm") Condition Warning: `pull_workflow_set_result()` was deprecated in workflowsets 0.1.0. @@ -19,7 +19,7 @@ --- Code - car_set_1 %>% pull_workflow_set_result("Gideon Nav") + pull_workflow_set_result(car_set_1, "Gideon Nav") Condition Warning: `pull_workflow_set_result()` was deprecated in workflowsets 0.1.0. @@ -30,7 +30,7 @@ --- Code - car_set_1 %>% pull_workflow("Coronabeth Tridentarius") + pull_workflow(car_set_1, "Coronabeth Tridentarius") Condition Warning: `pull_workflow()` was deprecated in workflowsets 0.1.0. diff --git a/tests/testthat/_snaps/workflow-map.md b/tests/testthat/_snaps/workflow-map.md index 4dbe45c..28f624d 100644 --- a/tests/testthat/_snaps/workflow-map.md +++ b/tests/testthat/_snaps/workflow-map.md @@ -1,7 +1,7 @@ # basic mapping Code - two_class_set %>% workflow_map("foo", seed = 1, resamples = folds, grid = 2) + workflow_map(two_class_set, "foo", seed = 1, resamples = folds, grid = 2) Condition Error in `workflow_map()`: ! `fn` must be one of "tune_grid", "tune_bayes", "fit_resamples", "tune_race_anova", "tune_race_win_loss", "tune_sim_anneal", or "tune_cluster", not "foo". @@ -9,7 +9,7 @@ --- Code - two_class_set %>% workflow_map(fn = 1L, seed = 1, resamples = folds, grid = 2) + workflow_map(two_class_set, fn = 1L, seed = 1, resamples = folds, grid = 2) Condition Error in `workflow_map()`: ! `fn` must be a character vector, not the number 1. @@ -17,7 +17,7 @@ --- Code - two_class_set %>% workflow_map(fn = tune::tune_grid, seed = 1, resamples = folds, + workflow_map(two_class_set, fn = tune::tune_grid, seed = 1, resamples = folds, grid = 2) Condition Error in `workflow_map()`: @@ -37,7 +37,7 @@ # failers Code - res_loud <- car_set_3 %>% workflow_map(resamples = folds, seed = 2, verbose = TRUE, + res_loud <- workflow_map(car_set_3, resamples = folds, seed = 2, verbose = TRUE, grid = "a") Message i 1 of 2 tuning: reg_knn diff --git a/tests/testthat/_snaps/workflow_set.md b/tests/testthat/_snaps/workflow_set.md index 128da6a..2952b79 100644 --- a/tests/testthat/_snaps/workflow_set.md +++ b/tests/testthat/_snaps/workflow_set.md @@ -1,25 +1,30 @@ # specifying a column that is not case weights Code - car_set_2 <- workflow_set(list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), - list(lm = lr_spec), case_weights = non_wts) %>% workflow_map("fit_resamples", - resamples = vfold_cv(cars, v = 5)) + car_set_2 <- workflow_map(workflow_set(list(reg = mpg ~ ., nonlin = mpg ~ wt + + 1 / sqrt(disp)), list(lm = lr_spec), case_weights = non_wts), "fit_resamples", + resamples = vfold_cv(cars, v = 5)) Message x Fold1: preprocessor 1/1: Error in `fit()`: ! `col` must select a classed case weights column, as determined by `h... + i For example, it could be a column created by `hardhat::frequency_wei... x Fold2: preprocessor 1/1: Error in `fit()`: ! `col` must select a classed case weights column, as determined by `h... + i For example, it could be a column created by `hardhat::frequency_wei... x Fold3: preprocessor 1/1: Error in `fit()`: ! `col` must select a classed case weights column, as determined by `h... + i For example, it could be a column created by `hardhat::frequency_wei... x Fold4: preprocessor 1/1: Error in `fit()`: ! `col` must select a classed case weights column, as determined by `h... + i For example, it could be a column created by `hardhat::frequency_wei... x Fold5: preprocessor 1/1: Error in `fit()`: ! `col` must select a classed case weights column, as determined by `h... + i For example, it could be a column created by `hardhat::frequency_wei... Condition Warning: All models failed. Run `show_notes(.Last.tune.result)` for more information. @@ -27,18 +32,23 @@ x Fold1: preprocessor 1/1: Error in `fit()`: ! `col` must select a classed case weights column, as determined by `h... + i For example, it could be a column created by `hardhat::frequency_wei... x Fold2: preprocessor 1/1: Error in `fit()`: ! `col` must select a classed case weights column, as determined by `h... + i For example, it could be a column created by `hardhat::frequency_wei... x Fold3: preprocessor 1/1: Error in `fit()`: ! `col` must select a classed case weights column, as determined by `h... + i For example, it could be a column created by `hardhat::frequency_wei... x Fold4: preprocessor 1/1: Error in `fit()`: ! `col` must select a classed case weights column, as determined by `h... + i For example, it could be a column created by `hardhat::frequency_wei... x Fold5: preprocessor 1/1: Error in `fit()`: ! `col` must select a classed case weights column, as determined by `h... + i For example, it could be a column created by `hardhat::frequency_wei... Condition Warning: All models failed. Run `show_notes(.Last.tune.result)` for more information. @@ -48,7 +58,7 @@ Code class_note$note[1] Output - [1] "Error in `fit()`:\n! `col` must select a classed case weights column, as determined by `hardhat::is_case_weights()`. For example, it could be a column created by `hardhat::frequency_weights()` or `hardhat::importance_weights()`." + [1] "Error in `fit()`:\n! `col` must select a classed case weights column, as determined by `hardhat::is_case_weights()`.\ni For example, it could be a column created by `hardhat::frequency_weights()` or `hardhat::importance_weights()`." # specifying an engine that does not allow case weights @@ -63,9 +73,9 @@ # specifying a case weight column that isn't in the resamples Code - car_set_4 <- workflow_set(list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), - list(lm = lr_spec), case_weights = boop) %>% workflow_map("fit_resamples", - resamples = vfold_cv(cars, v = 5)) + car_set_4 <- workflow_map(workflow_set(list(reg = mpg ~ ., nonlin = mpg ~ wt + + 1 / sqrt(disp)), list(lm = lr_spec), case_weights = boop), "fit_resamples", + resamples = vfold_cv(cars, v = 5)) Message x Fold1: preprocessor 1/1: Error in `fit()`: @@ -136,14 +146,14 @@ Code as_workflow_set(wt = f_1, disp = f_2) Condition - Error: + Error in `as_workflow_set()`: ! Different resamples were used in the workflow results. i All elements of result must use the same resamples. # constructor Code - new_workflow_set(car_set_1 %>% dplyr::select(-info)) + new_workflow_set(dplyr::select(car_set_1, -info)) Condition Error: ! The object should have columns wflow_id, info, option, and result. @@ -151,7 +161,7 @@ --- Code - new_workflow_set(car_set_1 %>% dplyr::mutate(info = "a")) + new_workflow_set(dplyr::mutate(car_set_1, info = "a")) Condition Error: ! The info column should be a list. @@ -159,7 +169,7 @@ --- Code - new_workflow_set(car_set_1 %>% dplyr::mutate(result = "a")) + new_workflow_set(dplyr::mutate(car_set_1, result = "a")) Condition Error: ! The result column should be a list. @@ -167,7 +177,7 @@ --- Code - new_workflow_set(car_set_1 %>% dplyr::mutate(option = "a")) + new_workflow_set(dplyr::mutate(car_set_1, option = "a")) Condition Error: ! The option column should be a list. @@ -175,7 +185,7 @@ --- Code - new_workflow_set(car_set_1 %>% dplyr::mutate(wflow_id = 1)) + new_workflow_set(dplyr::mutate(car_set_1, wflow_id = 1)) Condition Error: ! The wflow_id column should be character. @@ -183,7 +193,7 @@ --- Code - new_workflow_set(car_set_1 %>% dplyr::mutate(wflow_id = "a")) + new_workflow_set(dplyr::mutate(car_set_1, wflow_id = "a")) Condition Error: ! The wflow_id column should contain unique, non-missing character strings. diff --git a/tests/testthat/test-collect-extracts.R b/tests/testthat/test-collect-extracts.R index 758faa9..7ac19f6 100644 --- a/tests/testthat/test-collect-extracts.R +++ b/tests/testthat/test-collect-extracts.R @@ -11,7 +11,7 @@ test_that("collect_extracts works", { ) wflow_set_trained <- - wflow_set %>% + wflow_set |> workflow_map( "fit_resamples", resamples = folds, @@ -42,7 +42,7 @@ test_that("collect_extracts fails gracefully without .extracts column", { ) wflow_set_trained <- - wflow_set %>% + wflow_set |> workflow_map("fit_resamples", resamples = folds) expect_snapshot( diff --git a/tests/testthat/test-collect-metrics.R b/tests/testthat/test-collect-metrics.R index aafdf0e..66f78d9 100644 --- a/tests/testthat/test-collect-metrics.R +++ b/tests/testthat/test-collect-metrics.R @@ -8,12 +8,12 @@ check_metric_results <- function(ind, x, ...) { } orig <- - collect_metrics(x$result[[ind]], ...) %>% + collect_metrics(x$result[[ind]], ...) |> dplyr::select(dplyr::all_of(cols)) everythng <- - collect_metrics(x, ...) %>% - dplyr::filter(wflow_id == id_val) %>% + collect_metrics(x, ...) |> + dplyr::filter(wflow_id == id_val) |> dplyr::select(dplyr::all_of(cols)) all.equal(orig, everythng) } diff --git a/tests/testthat/test-collect-notes.R b/tests/testthat/test-collect-notes.R index 9eb9e67..c8f4d05 100644 --- a/tests/testthat/test-collect-notes.R +++ b/tests/testthat/test-collect-notes.R @@ -11,7 +11,7 @@ test_that("collect_notes works", { ) wflow_set_trained <- - wflow_set %>% + wflow_set |> workflow_map( "fit_resamples", resamples = folds, diff --git a/tests/testthat/test-collect-predictions.R b/tests/testthat/test-collect-predictions.R index 3ea8fe4..f218dfb 100644 --- a/tests/testthat/test-collect-predictions.R +++ b/tests/testthat/test-collect-predictions.R @@ -10,10 +10,10 @@ suppressPackageStartupMessages(library(tune)) # ------------------------------------------------------------------------------ -lr_spec <- linear_reg() %>% set_engine("lm") +lr_spec <- linear_reg() |> set_engine("lm") knn_spec <- - nearest_neighbor(neighbors = tune()) %>% - set_engine("kknn") %>% + nearest_neighbor(neighbors = tune()) |> + set_engine("kknn") |> set_mode("regression") set.seed(1) @@ -21,7 +21,7 @@ car_set_1 <- workflow_set( list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), list(lm = lr_spec) - ) %>% + ) |> workflow_map( "fit_resamples", resamples = vfold_cv(mtcars, v = 3), @@ -36,7 +36,7 @@ car_set_2 <- workflow_set( list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), list(lm = lr_spec) - ) %>% + ) |> workflow_map( "fit_resamples", resamples = resamples, @@ -48,7 +48,7 @@ car_set_3 <- workflow_set( list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), list(knn = knn_spec) - ) %>% + ) |> workflow_map( "tune_bayes", resamples = resamples, @@ -68,7 +68,7 @@ check_prediction_results <- function(ind, x, summarize = FALSE, ...) { cols <- c(".row", "mpg", ".config", ".pred") orig <- - collect_predictions(x$result[[ind]], summarize = summarize, ...) %>% + collect_predictions(x$result[[ind]], summarize = summarize, ...) |> dplyr::select(dplyr::all_of(cols)) if (any(names(list(...)) == "summarize")) { @@ -76,8 +76,8 @@ check_prediction_results <- function(ind, x, summarize = FALSE, ...) { } everythng <- - collect_predictions(x, summarize = summarize, ...) %>% - dplyr::filter(wflow_id == id_val) %>% + collect_predictions(x, summarize = summarize, ...) |> + dplyr::filter(wflow_id == id_val) |> dplyr::select(dplyr::all_of(cols)) all.equal(orig, everythng) } diff --git a/tests/testthat/test-comments.R b/tests/testthat/test-comments.R index 94b0018..b6f9f3a 100644 --- a/tests/testthat/test-comments.R +++ b/tests/testthat/test-comments.R @@ -2,7 +2,7 @@ test_that("test comments", { comments_1 <- - two_class_set %>% + two_class_set |> comment_add("none_cart", "What does 'cart' stand for\u2753") expect_equal( @@ -14,46 +14,46 @@ test_that("test comments", { expect_equal(comments_1$info[[i]]$comment, character(1)) } comments_2 <- - comments_1 %>% + comments_1 |> comment_add("none_cart", "Stuff.") expect_equal( - comment_get(comments_2, id = "none_cart") %>% paste0(collapse = "\n"), + comment_get(comments_2, id = "none_cart") |> paste0(collapse = "\n"), "What does 'cart' stand for\u2753\nStuff." ) comments_3 <- - comments_2 %>% + comments_2 |> comment_reset("none_cart") expect_equal( comments_3$info[[1]]$comment, character(1) ) expect_equal( - two_class_set %>% comment_add(), + two_class_set |> comment_add(), two_class_set ) expect_equal( - two_class_set %>% comment_add("none_cart"), + two_class_set |> comment_add("none_cart"), two_class_set ) expect_snapshot( error = TRUE, - two_class_set %>% comment_add("toe", "foot") + two_class_set |> comment_add("toe", "foot") ) expect_snapshot( error = TRUE, - two_class_set %>% comment_add(letters, "foot") + two_class_set |> comment_add(letters, "foot") ) expect_snapshot( error = TRUE, - two_class_set %>% comment_add(1:2, "foot") + two_class_set |> comment_add(1:2, "foot") ) expect_snapshot( error = TRUE, - two_class_set %>% comment_add("none_cart", 1:2) + two_class_set |> comment_add("none_cart", 1:2) ) expect_snapshot( error = TRUE, - comments_1 %>% comment_add("none_cart", "Stuff.", append = FALSE) + comments_1 |> comment_add("none_cart", "Stuff.", append = FALSE) ) expect_snapshot( error = TRUE, @@ -65,11 +65,11 @@ test_that("test comments", { ) expect_snapshot( error = TRUE, - comments_1 %>% comment_reset(letters) + comments_1 |> comment_reset(letters) ) expect_snapshot( error = TRUE, - comments_1 %>% comment_reset("none_carts") + comments_1 |> comment_reset("none_carts") ) }) @@ -79,9 +79,9 @@ test_that("print comments", { gatsby_3 <- "Across the courtesy bay the white palaces of fashionable East Egg glittered along the water, and the history of the summer really begins on the evening I drove over there to have dinner with the Tom Buchanans. Daisy was my second cousin once removed and I'd known Tom in college. And just after the war I spent two days with them in Chicago." test <- - two_class_res %>% - comment_add("none_cart", gatsby_1) %>% - comment_add("none_cart", gatsby_2) %>% + two_class_res |> + comment_add("none_cart", gatsby_1) |> + comment_add("none_cart", gatsby_2) |> comment_add("none_glm", gatsby_3) expect_snapshot(comment_print(test)) diff --git a/tests/testthat/test-extract.R b/tests/testthat/test-extract.R index d877ea5..e9d9d02 100644 --- a/tests/testthat/test-extract.R +++ b/tests/testthat/test-extract.R @@ -5,17 +5,17 @@ library(rsample) library(recipes) data(Chicago, package = "modeldata") -lr_spec <- linear_reg() %>% set_engine("lm") +lr_spec <- linear_reg() |> set_engine("lm") set.seed(1) car_set_1 <- workflow_set( list( - reg = recipe(mpg ~ ., data = mtcars) %>% step_log(disp), + reg = recipe(mpg ~ ., data = mtcars) |> step_log(disp), nonlin = mpg ~ wt + 1 / sqrt(disp) ), list(lm = lr_spec) - ) %>% + ) |> workflow_map( "fit_resamples", resamples = vfold_cv(mtcars, v = 3), @@ -61,35 +61,35 @@ test_that("extracts", { ) expect_equal( - car_set_1 %>% extract_workflow("reg_lm"), + car_set_1 |> extract_workflow("reg_lm"), car_set_1$info[[1]]$workflow[[1]] ) expect_equal( - car_set_1 %>% extract_workflow_set_result("reg_lm"), + car_set_1 |> extract_workflow_set_result("reg_lm"), car_set_1$result[[1]] ) expect_snapshot(error = TRUE, { - car_set_1 %>% extract_workflow_set_result("Gideon Nav") + car_set_1 |> extract_workflow_set_result("Gideon Nav") }) expect_snapshot(error = TRUE, { - car_set_1 %>% extract_workflow("Coronabeth Tridentarius") + car_set_1 |> extract_workflow("Coronabeth Tridentarius") }) }) test_that("extract parameter set from workflow set with untunable workflow", { - rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>% + rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) |> recipes::step_rm(date, ends_with("away")) - lm_model <- parsnip::linear_reg() %>% + lm_model <- parsnip::linear_reg() |> parsnip::set_engine("lm") bst_model <- parsnip::boost_tree( mode = "classification", trees = hardhat::tune("funky name \n") - ) %>% + ) |> parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE) wf_set <- workflow_set( list(reg = rm_rec), @@ -102,15 +102,15 @@ test_that("extract parameter set from workflow set with untunable workflow", { }) test_that("extract parameter set from workflow set with tunable workflow", { - rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>% + rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) |> recipes::step_rm(date, ends_with("away")) - lm_model <- parsnip::linear_reg() %>% + lm_model <- parsnip::linear_reg() |> parsnip::set_engine("lm") bst_model <- parsnip::boost_tree( mode = "classification", trees = hardhat::tune("funky name \n") - ) %>% + ) |> parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE) wf_set <- workflow_set( list(reg = rm_rec), @@ -136,7 +136,7 @@ test_that("extract parameter set from workflow set with tunable workflow", { expect_equal(c5_info$object[[2]], NA) c5_new_info <- - c5_info %>% + c5_info |> update( rules = dials::new_qual_param( "logical", @@ -146,7 +146,7 @@ test_that("extract parameter set from workflow set with tunable workflow", { ) wf_set_2 <- - wf_set %>% + wf_set |> option_add(id = "reg_bst", param_info = c5_new_info) check_parameter_set_tibble(c5_new_info) @@ -159,15 +159,15 @@ test_that("extract parameter set from workflow set with tunable workflow", { test_that("extract single parameter from workflow set with untunable workflow", { - rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>% + rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) |> recipes::step_rm(date, ends_with("away")) - lm_model <- parsnip::linear_reg() %>% + lm_model <- parsnip::linear_reg() |> parsnip::set_engine("lm") bst_model <- parsnip::boost_tree( mode = "classification", trees = hardhat::tune("funky name \n") - ) %>% + ) |> parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE) wf_set <- workflow_set( list(reg = rm_rec), @@ -185,15 +185,15 @@ test_that("extract single parameter from workflow set with untunable workflow", }) test_that("extract single parameter from workflow set with tunable workflow", { - rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>% + rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) |> recipes::step_rm(date, ends_with("away")) - lm_model <- parsnip::linear_reg() %>% + lm_model <- parsnip::linear_reg() |> parsnip::set_engine("lm") bst_model <- parsnip::boost_tree( mode = "classification", trees = hardhat::tune("funky name \n") - ) %>% + ) |> parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE) wf_set <- workflow_set( list(reg = rm_rec), diff --git a/tests/testthat/test-fit.R b/tests/testthat/test-fit.R index f7cfc36..8eca763 100644 --- a/tests/testthat/test-fit.R +++ b/tests/testthat/test-fit.R @@ -8,10 +8,10 @@ suppressPackageStartupMessages(library(tune)) # ------------------------------------------------------------------------------ -lr_spec <- linear_reg() %>% set_engine("lm") +lr_spec <- linear_reg() |> set_engine("lm") knn_spec <- - nearest_neighbor(neighbors = tune()) %>% - set_engine("kknn") %>% + nearest_neighbor(neighbors = tune()) |> + set_engine("kknn") |> set_mode("regression") set.seed(1) @@ -22,7 +22,7 @@ car_set_1 <- ) car_set_2 <- - car_set_1 %>% + car_set_1 |> workflow_map( "fit_resamples", resamples = vfold_cv(mtcars, v = 3), diff --git a/tests/testthat/test-fit_best.R b/tests/testthat/test-fit_best.R index b994da9..8b1400f 100644 --- a/tests/testthat/test-fit_best.R +++ b/tests/testthat/test-fit_best.R @@ -22,12 +22,12 @@ test_that("fit_best fits with correct hyperparameters", { ) chi_features_map <- - chi_features_set %>% + chi_features_set |> option_add( control = control_grid(save_workflow = TRUE), # choose metrics resulting in different rankings metrics = metric_set(rmse, iic) - ) %>% + ) |> workflow_map(resamples = time_val_split, grid = 21, seed = 1) chi_features_map @@ -90,11 +90,11 @@ test_that("fit_best errors informatively with bad inputs", { ) chi_features_map <- - chi_features_set %>% + chi_features_set |> option_add( # set needed `save_workflow` option control = control_grid(save_workflow = TRUE) - ) %>% + ) |> workflow_map(resamples = time_val_split, grid = 21, seed = 1) expect_snapshot( diff --git a/tests/testthat/test-options.R b/tests/testthat/test-options.R index 193efa4..3d1061b 100644 --- a/tests/testthat/test-options.R +++ b/tests/testthat/test-options.R @@ -2,26 +2,26 @@ test_that("option management", { expect_no_error( - set_1 <- two_class_set %>% option_add(grid = 1) + set_1 <- two_class_set |> option_add(grid = 1) ) for (i in 1:nrow(set_1)) { expect_equal(unclass(set_1$option[[i]]), list(grid = 1)) } expect_no_error( - set_2 <- two_class_set %>% option_remove(grid) + set_2 <- two_class_set |> option_remove(grid) ) for (i in 1:nrow(set_2)) { expect_equal(unclass(set_2$option[[i]]), list()) } expect_no_error( - set_3 <- two_class_set %>% option_add(grid = 1, id = "none_cart") + set_3 <- two_class_set |> option_add(grid = 1, id = "none_cart") ) expect_equal(unclass(set_3$option[[1]]), list(grid = 1)) for (i in 2:nrow(set_3)) { expect_equal(unclass(set_3$option[[i]]), list()) } expect_no_error( - set_4 <- two_class_set %>% option_add_parameters() + set_4 <- two_class_set |> option_add_parameters() ) for (i in which(!grepl("glm", set_4$wflow_id))) { expect_true(all(names(set_4$option[[i]]) == "param_info")) @@ -31,7 +31,7 @@ test_that("option management", { expect_equal(unclass(set_4$option[[i]]), list()) } expect_no_error( - set_5 <- two_class_set %>% option_add_parameters(id = "none_cart") + set_5 <- two_class_set |> option_add_parameters(id = "none_cart") ) expect_true(all(names(set_5$option[[1]]) == "param_info")) expect_true(inherits(set_5$option[[1]]$param_info, "parameters")) @@ -55,9 +55,9 @@ test_that("option printing", { test_that("check for bad options", { expect_snapshot_error( - two_class_set %>% option_add(grid2 = 1) + two_class_set |> option_add(grid2 = 1) ) expect_snapshot_error( - two_class_set %>% option_add(grid = 1, blueprint = 2) + two_class_set |> option_add(grid = 1, blueprint = 2) ) }) diff --git a/tests/testthat/test-predict.R b/tests/testthat/test-predict.R index 1838530..e31fe69 100644 --- a/tests/testthat/test-predict.R +++ b/tests/testthat/test-predict.R @@ -8,10 +8,10 @@ suppressPackageStartupMessages(library(tune)) # ------------------------------------------------------------------------------ -lr_spec <- linear_reg() %>% set_engine("lm") +lr_spec <- linear_reg() |> set_engine("lm") knn_spec <- - nearest_neighbor(neighbors = tune()) %>% - set_engine("kknn") %>% + nearest_neighbor(neighbors = tune()) |> + set_engine("kknn") |> set_mode("regression") set.seed(1) @@ -22,7 +22,7 @@ car_set_1 <- ) car_set_2 <- - car_set_1 %>% + car_set_1 |> workflow_map( "fit_resamples", resamples = vfold_cv(mtcars, v = 3), diff --git a/tests/testthat/test-pull.R b/tests/testthat/test-pull.R index a1f7cd9..66db751 100644 --- a/tests/testthat/test-pull.R +++ b/tests/testthat/test-pull.R @@ -1,14 +1,14 @@ library(parsnip) library(rsample) -lr_spec <- linear_reg() %>% set_engine("lm") +lr_spec <- linear_reg() |> set_engine("lm") set.seed(1) car_set_1 <- workflow_set( list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), list(lm = lr_spec) - ) %>% + ) |> workflow_map( "fit_resamples", resamples = vfold_cv(mtcars, v = 3), @@ -18,18 +18,18 @@ car_set_1 <- # ------------------------------------------------------------------------------ test_that("pulling objects", { - expect_snapshot(res <- car_set_1 %>% pull_workflow("reg_lm")) + expect_snapshot(res <- car_set_1 |> pull_workflow("reg_lm")) expect_equal(res, car_set_1$info[[1]]$workflow[[1]]) - expect_snapshot(res <- car_set_1 %>% pull_workflow_set_result("reg_lm")) + expect_snapshot(res <- car_set_1 |> pull_workflow_set_result("reg_lm")) expect_equal(res, car_set_1$result[[1]]) expect_snapshot( error = TRUE, - car_set_1 %>% pull_workflow_set_result("Gideon Nav") + car_set_1 |> pull_workflow_set_result("Gideon Nav") ) expect_snapshot( error = TRUE, - car_set_1 %>% pull_workflow("Coronabeth Tridentarius") + car_set_1 |> pull_workflow("Coronabeth Tridentarius") ) }) diff --git a/tests/testthat/test-updates.R b/tests/testthat/test-updates.R index a91ad72..72e35ee 100644 --- a/tests/testthat/test-updates.R +++ b/tests/testthat/test-updates.R @@ -7,10 +7,10 @@ library(hardhat) data(two_class_dat, package = "modeldata") -xgb <- boost_tree(trees = 3) %>% set_mode("classification") +xgb <- boost_tree(trees = 3) |> set_mode("classification") rec <- - recipe(Class ~ A + B, two_class_dat) %>% - step_normalize(A) %>% + recipe(Class ~ A + B, two_class_dat) |> + step_normalize(A) |> step_normalize(B) sparse_bp <- default_recipe_blueprint(composition = "dgCMatrix") diff --git a/tests/testthat/test-workflow-map.R b/tests/testthat/test-workflow-map.R index 13153fd..51e6213 100644 --- a/tests/testthat/test-workflow-map.R +++ b/tests/testthat/test-workflow-map.R @@ -8,13 +8,13 @@ library(kknn) # ------------------------------------------------------------------------------ -lr_spec <- linear_reg() %>% set_engine("lm") +lr_spec <- linear_reg() |> set_engine("lm") knn_spec <- - nearest_neighbor(neighbors = tune()) %>% - set_engine("kknn") %>% + nearest_neighbor(neighbors = tune()) |> + set_engine("kknn") |> set_mode("regression") glmn_spec <- - linear_reg(penalty = tune()) %>% + linear_reg(penalty = tune()) |> set_engine("glmnet") set.seed(1) @@ -24,7 +24,7 @@ car_set_1 <- workflow_set( list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), list(lm = lr_spec, knn = knn_spec) - ) %>% + ) |> dplyr::slice(-4) # ------------------------------------------------------------------------------ @@ -32,14 +32,14 @@ car_set_1 <- test_that("basic mapping", { expect_no_error({ res_1 <- - car_set_1 %>% + car_set_1 |> workflow_map(resamples = folds, seed = 2, grid = 2) }) # check reproducibility expect_no_error({ res_2 <- - car_set_1 %>% + car_set_1 |> workflow_map(resamples = folds, seed = 2, grid = 2) }) expect_equal(collect_metrics(res_1), collect_metrics(res_2)) @@ -48,19 +48,19 @@ test_that("basic mapping", { expect_snapshot( error = TRUE, - two_class_set %>% + two_class_set |> workflow_map("foo", seed = 1, resamples = folds, grid = 2) ) expect_snapshot( error = TRUE, - two_class_set %>% + two_class_set |> workflow_map(fn = 1L, seed = 1, resamples = folds, grid = 2) ) expect_snapshot( error = TRUE, - two_class_set %>% + two_class_set |> workflow_map(fn = tune::tune_grid, seed = 1, resamples = folds, grid = 2) ) }) @@ -73,7 +73,7 @@ test_that("map logging", { logging_res <- capture.output( res <- - car_set_1 %>% + car_set_1 |> workflow_map(resamples = folds, seed = 2, verbose = TRUE), type = "message" ) @@ -95,7 +95,7 @@ test_that("missing packages", { expect_snapshot( { res <- - car_set_2 %>% + car_set_2 |> workflow_map(resamples = folds, seed = 2, verbose = FALSE) }, transform = function(lines) { @@ -117,7 +117,7 @@ test_that("failers", { expect_no_error({ res_quiet <- - car_set_3 %>% + car_set_3 |> workflow_map(resamples = folds, seed = 2, verbose = FALSE, grid = "a") }) expect_true(inherits(res_quiet, "workflow_set")) @@ -126,7 +126,7 @@ test_that("failers", { expect_snapshot( { res_loud <- - car_set_3 %>% + car_set_3 |> workflow_map(resamples = folds, seed = 2, verbose = TRUE, grid = "a") }, transform = function(lines) { @@ -144,7 +144,7 @@ test_that("workflow_map can handle cluster specifications", { library(recipes) set.seed(1) - mtcars_tbl <- mtcars %>% dplyr::select(where(is.numeric)) + mtcars_tbl <- mtcars |> dplyr::select(where(is.numeric)) folds <- vfold_cv(mtcars_tbl, v = 3) wf_set_spec <- @@ -165,7 +165,7 @@ test_that("fail informatively on mismatched spec/tuning function", { library(tidyclust) set.seed(1) - mtcars_tbl <- mtcars %>% dplyr::select(where(is.numeric)) + mtcars_tbl <- mtcars |> dplyr::select(where(is.numeric)) folds <- vfold_cv(mtcars_tbl, v = 3) wf_set_1 <- diff --git a/tests/testthat/test-workflow_set.R b/tests/testthat/test-workflow_set.R index 12fbe36..2c8a246 100644 --- a/tests/testthat/test-workflow_set.R +++ b/tests/testthat/test-workflow_set.R @@ -2,9 +2,9 @@ library(parsnip) library(rsample) library(rlang) -lr_spec <- linear_reg() %>% set_engine("lm") -knn_spec <- nearest_neighbor() %>% - set_engine("kknn") %>% +lr_spec <- linear_reg() |> set_engine("lm") +knn_spec <- nearest_neighbor() |> + set_engine("kknn") |> set_mode("regression") # ------------------------------------------------------------------------------ @@ -16,7 +16,7 @@ test_that("creating workflow sets", { workflow_set( list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), list(lm = lr_spec) - ) %>% + ) |> workflow_map( "fit_resamples", resamples = vfold_cv(mtcars, v = 3), @@ -127,13 +127,13 @@ test_that("creating workflow sets", { }) test_that("workflow_set can handle correctly passed case weights", { - lr_spec <- linear_reg() %>% set_engine("lm") + lr_spec <- linear_reg() |> set_engine("lm") cars <- - mtcars %>% + mtcars |> dplyr::mutate( - wts = hardhat::importance_weights(1:nrow(.)), - non_wts = 1:nrow(.) + wts = hardhat::importance_weights(1:nrow(mtcars)), + non_wts = 1:nrow(mtcars) ) expect_silent({ @@ -142,7 +142,7 @@ test_that("workflow_set can handle correctly passed case weights", { list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), list(lm = lr_spec), case_weights = wts - ) %>% + ) |> workflow_map( "fit_resamples", resamples = vfold_cv(cars, v = 5) @@ -153,13 +153,13 @@ test_that("workflow_set can handle correctly passed case weights", { }) test_that("specifying a column that is not case weights", { - lr_spec <- linear_reg() %>% set_engine("lm") + lr_spec <- linear_reg() |> set_engine("lm") cars <- - mtcars %>% + mtcars |> dplyr::mutate( - wts = hardhat::importance_weights(1:nrow(.)), - non_wts = 1:nrow(.) + wts = hardhat::importance_weights(1:nrow(mtcars)), + non_wts = 1:nrow(mtcars) ) expect_snapshot({ @@ -168,31 +168,31 @@ test_that("specifying a column that is not case weights", { list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), list(lm = lr_spec), case_weights = non_wts - ) %>% + ) |> workflow_map( "fit_resamples", resamples = vfold_cv(cars, v = 5) ) }) - class_note <- extract_workflow_set_result(car_set_2, "reg_lm") %>% - tune::collect_notes() %>% + class_note <- extract_workflow_set_result(car_set_2, "reg_lm") |> + tune::collect_notes() |> dplyr::select(note) expect_snapshot(class_note$note[1]) }) test_that("specifying an engine that does not allow case weights", { - lr_spec <- linear_reg() %>% set_engine("lm") - knn_spec <- nearest_neighbor() %>% - set_engine("kknn") %>% + lr_spec <- linear_reg() |> set_engine("lm") + knn_spec <- nearest_neighbor() |> + set_engine("kknn") |> set_mode("regression") cars <- - mtcars %>% + mtcars |> dplyr::mutate( - wts = hardhat::importance_weights(1:nrow(.)), - non_wts = 1:nrow(.) + wts = hardhat::importance_weights(1:nrow(mtcars)), + non_wts = 1:nrow(mtcars) ) expect_snapshot({ @@ -209,13 +209,13 @@ test_that("specifying an engine that does not allow case weights", { }) test_that("specifying a case weight column that isn't in the resamples", { - lr_spec <- linear_reg() %>% set_engine("lm") + lr_spec <- linear_reg() |> set_engine("lm") cars <- - mtcars %>% + mtcars |> dplyr::mutate( - wts = hardhat::importance_weights(1:nrow(.)), - non_wts = 1:nrow(.) + wts = hardhat::importance_weights(1:nrow(mtcars)), + non_wts = 1:nrow(mtcars) ) expect_snapshot({ @@ -224,15 +224,15 @@ test_that("specifying a case weight column that isn't in the resamples", { list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), list(lm = lr_spec), case_weights = boop - ) %>% + ) |> workflow_map( "fit_resamples", resamples = vfold_cv(cars, v = 5) ) }) - class_note <- extract_workflow_set_result(car_set_4, "reg_lm") %>% - tune::collect_notes() %>% + class_note <- extract_workflow_set_result(car_set_4, "reg_lm") |> + tune::collect_notes() |> dplyr::select(note) expect_snapshot(class_note$note[1]) @@ -270,7 +270,7 @@ test_that("correct object type and resamples", { res_2 <- set_1 res_2$result <- - purrr::map(res_2$wflow_id, ~ extract_workflow(res_2, id = .x)) %>% + purrr::map(res_2$wflow_id, ~ extract_workflow(res_2, id = .x)) |> purrr::map(~ tune::fit_resamples(.x, resamples = bootstraps(mtcars, 3))) expect_identical( has_valid_column_result_inner_types(res_2), @@ -348,11 +348,11 @@ test_that("checking resamples", { ctrl <- tune::control_resamples(save_workflow = TRUE) set.seed(1) cv_1 <- vfold_cv(mtcars, v = 5) - f_1 <- lr_spec %>% + f_1 <- lr_spec |> tune::fit_resamples(mpg ~ wt, resamples = cv_1, control = ctrl) set.seed(2) cv_2 <- vfold_cv(mtcars, v = 5) - f_2 <- lr_spec %>% + f_2 <- lr_spec |> tune::fit_resamples(mpg ~ disp, resamples = cv_2, control = ctrl) expect_snapshot( error = TRUE, @@ -361,7 +361,7 @@ test_that("checking resamples", { # Emulate old rset objects attr(cv_2, "fingerprint") <- NULL - f_3 <- lr_spec %>% + f_3 <- lr_spec |> tune::fit_resamples(mpg ~ disp, resamples = cv_2, control = ctrl) expect_no_error(as_workflow_set(wt = f_1, disp = f_3)) }) @@ -374,7 +374,7 @@ test_that("constructor", { workflow_set( list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)), list(lm = lr_spec) - ) %>% + ) |> workflow_map( "fit_resamples", resamples = vfold_cv(mtcars, v = 3), @@ -383,28 +383,28 @@ test_that("constructor", { expect_snapshot( error = TRUE, - new_workflow_set(car_set_1 %>% dplyr::select(-info)) + new_workflow_set(car_set_1 |> dplyr::select(-info)) ) expect_snapshot( error = TRUE, - new_workflow_set(car_set_1 %>% dplyr::mutate(info = "a")) + new_workflow_set(car_set_1 |> dplyr::mutate(info = "a")) ) expect_snapshot( error = TRUE, - new_workflow_set(car_set_1 %>% dplyr::mutate(result = "a")) + new_workflow_set(car_set_1 |> dplyr::mutate(result = "a")) ) expect_snapshot( error = TRUE, - new_workflow_set(car_set_1 %>% dplyr::mutate(option = "a")) + new_workflow_set(car_set_1 |> dplyr::mutate(option = "a")) ) expect_snapshot( error = TRUE, - new_workflow_set(car_set_1 %>% dplyr::mutate(wflow_id = 1)) + new_workflow_set(car_set_1 |> dplyr::mutate(wflow_id = 1)) ) expect_snapshot( error = TRUE, - new_workflow_set(car_set_1 %>% dplyr::mutate(wflow_id = "a")) + new_workflow_set(car_set_1 |> dplyr::mutate(wflow_id = "a")) ) }) diff --git a/vignettes/articles/tuning-and-comparing-models.Rmd b/vignettes/articles/tuning-and-comparing-models.Rmd index 5b0d65f..bc0cec3 100644 --- a/vignettes/articles/tuning-and-comparing-models.Rmd +++ b/vignettes/articles/tuning-and-comparing-models.Rmd @@ -62,16 +62,16 @@ We'll fit two types of discriminant analysis (DA) models (regularized DA and fle library(discrim) mars_disc_spec <- - discrim_flexible(prod_degree = tune()) %>% + discrim_flexible(prod_degree = tune()) |> set_engine("earth") reg_disc_sepc <- - discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) %>% + discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) |> set_engine("klaR") cart_spec <- - decision_tree(cost_complexity = tune(), min_n = tune()) %>% - set_engine("rpart") %>% + decision_tree(cost_complexity = tune(), min_n = tune()) |> + set_engine("rpart") |> set_mode("classification") ``` @@ -103,7 +103,7 @@ For illustration, let's use the `extract` argument of the [control function](htt ```{r option} all_workflows <- - all_workflows %>% + all_workflows |> option_add( id = "formula_cart", control = control_grid(extract = function(x) x) @@ -123,7 +123,7 @@ The `verbose` option provides a concise listing for which workflow is being proc ```{r tuning} all_workflows <- - all_workflows %>% + all_workflows |> # Specifying arguments here adds to any previously set with `option_add()`: workflow_map(resamples = train_resamples, grid = 20, verbose = TRUE) all_workflows @@ -154,7 +154,7 @@ We can also pull out the results of `tune_grid()` for this model: ```{r mars-results-print} mars_results <- - all_workflows %>% + all_workflows |> extract_workflow_set_result("formula_mars") mars_results ``` @@ -163,13 +163,13 @@ Let's get that workflow object and finalize the model: ```{r final-mars} mars_workflow <- - all_workflows %>% + all_workflows |> extract_workflow("formula_mars") mars_workflow mars_workflow_fit <- - mars_workflow %>% - finalize_workflow(tibble(prod_degree = 1)) %>% + mars_workflow |> + finalize_workflow(tibble(prod_degree = 1)) |> fit(data = train_set) mars_workflow_fit ``` @@ -185,7 +185,7 @@ grid <- ) grid <- - grid %>% + grid |> bind_cols(predict(mars_workflow_fit, grid, type = "prob")) ``` @@ -207,7 +207,7 @@ Recall that we added an option to the CART model to extract the model results. L ```{r extraction-res} cart_res <- - all_workflows %>% + all_workflows |> extract_workflow_set_result("formula_cart") cart_res ``` @@ -221,9 +221,9 @@ Let's slim that down by keeping the ones that correspond to the best tuning para best_cart <- select_best(cart_res, metric = "roc_auc") cart_wflows <- - cart_res %>% - select(id, .extracts) %>% - unnest(cols = .extracts) %>% + cart_res |> + select(id, .extracts) |> + unnest(cols = .extracts) |> inner_join(best_cart) cart_wflows @@ -234,32 +234,30 @@ What can we do with these? Let's write a function to return the number of termin ```{r cart-nodes} num_nodes <- function(wflow) { var_imps <- - wflow %>% + wflow |> # Pull out the rpart model - extract_fit_engine() %>% + extract_fit_engine() |> # The 'frame' element is a matrix with a column that # indicates which leaves are terminal - pluck("frame") %>% + pluck("frame") |> # Convert to a data frame - as_tibble() %>% + as_tibble() |> # Save only the rows that are terminal nodes - filter(var == "") %>% + filter(var == "") |> # Count them nrow() } -cart_wflows$.extracts[[1]] %>% num_nodes() +cart_wflows$.extracts[[1]] |> num_nodes() ``` Now let's create a column with the results for each resample: ```{r num-nodes-counts} cart_wflows <- - cart_wflows %>% + cart_wflows |> mutate(num_nodes = map_int(.extracts, num_nodes)) cart_wflows ``` The average number of terminal nodes for this model is `r round(mean(cart_wflows$num_nodes), 1)` nodes. - - diff --git a/vignettes/evaluating-different-predictor-sets.Rmd b/vignettes/evaluating-different-predictor-sets.Rmd index 6681f46..6c1f6c6 100644 --- a/vignettes/evaluating-different-predictor-sets.Rmd +++ b/vignettes/evaluating-different-predictor-sets.Rmd @@ -45,7 +45,7 @@ library(rsample) library(dplyr) library(ggplot2) -lr_model <- logistic_reg() %>% set_engine("glm") +lr_model <- logistic_reg() |> set_engine("glm") set.seed(1) trn_tst_split <- initial_split(mlc_churn, strata = churn) @@ -79,7 +79,7 @@ Since we are using basic logistic regression, there is nothing to tune for these ```{r churn-wflow-set-fits} churn_workflows <- - churn_workflows %>% + churn_workflows |> workflow_map("fit_resamples", resamples = folds) churn_workflows ``` @@ -88,38 +88,38 @@ To assess how to measure the effect of each predictor, let's subtract the area u ```{r churn-metrics, fig.width=6, fig.height=5} roc_values <- - churn_workflows %>% - collect_metrics(summarize = FALSE) %>% - filter(.metric == "roc_auc") %>% + churn_workflows |> + collect_metrics(summarize = FALSE) |> + filter(.metric == "roc_auc") |> mutate(wflow_id = gsub("_logistic", "", wflow_id)) full_model <- - roc_values %>% - filter(wflow_id == "everything") %>% + roc_values |> + filter(wflow_id == "everything") |> select(full_model = .estimate, id) differences <- - roc_values %>% - filter(wflow_id != "everything") %>% - full_join(full_model, by = "id") %>% + roc_values |> + filter(wflow_id != "everything") |> + full_join(full_model, by = "id") |> mutate(performance_drop = full_model - .estimate) summary_stats <- - differences %>% - group_by(wflow_id) %>% + differences |> + group_by(wflow_id) |> summarize( std_err = sd(performance_drop) / sum(!is.na(performance_drop)), performance_drop = mean(performance_drop), lower = performance_drop - qnorm(0.975) * std_err, upper = performance_drop + qnorm(0.975) * std_err, .groups = "drop" - ) %>% + ) |> mutate( wflow_id = factor(wflow_id), wflow_id = reorder(wflow_id, performance_drop) ) -summary_stats %>% filter(lower > 0) +summary_stats |> filter(lower > 0) ggplot(summary_stats, aes(x = performance_drop, y = wflow_id)) + geom_point() + @@ -128,4 +128,3 @@ ggplot(summary_stats, aes(x = performance_drop, y = wflow_id)) + ``` From this, there are a predictors that, when not included in the model, have a significant effect on the performance metric. - From 586c070fc2cf182b72dcadc1f3071817b58b3864 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 25 Apr 2025 14:20:46 -0500 Subject: [PATCH 06/11] transition to the base anonymous function syntax --- R/as_workflow_set.R | 6 +++--- R/checks.R | 10 +++++----- R/collect.R | 11 ++++++----- R/comments.R | 8 ++++---- R/leave_var_out_formulas.R | 4 ++-- R/misc.R | 6 +++--- R/options.R | 2 +- R/rank_results.R | 4 ++-- R/workflow_map.R | 2 +- R/workflow_set.R | 8 ++++---- man/as_workflow_set.Rd | 2 +- man/workflow_set.Rd | 2 +- tests/testthat/helper-extract_parameter_set.R | 2 +- 13 files changed, 34 insertions(+), 33 deletions(-) diff --git a/R/as_workflow_set.R b/R/as_workflow_set.R index 83dd237..8eef16d 100644 --- a/R/as_workflow_set.R +++ b/R/as_workflow_set.R @@ -26,7 +26,7 @@ #' names(results) <- two_class_res$wflow_id #' #' # These are all objects that have been resampled or tuned: -#' purrr::map_chr(results, ~ class(.x)[1]) +#' purrr::map_chr(results, \(x) class(x)[1]) #' #' # Use rlang's !!! operator to splice in the elements of the list #' new_set <- as_workflow_set(!!!results) @@ -55,7 +55,7 @@ as_workflow_set <- function(...) { object <- rlang::list2(...) # These could be workflows or objects of class `tune_result` - is_workflow <- purrr::map_lgl(object, ~ inherits(.x, "workflow")) + is_workflow <- purrr::map_lgl(object, \(x) inherits(x, "workflow")) wflows <- vector("list", length(is_workflow)) wflows[is_workflow] <- object[is_workflow] wflows[!is_workflow] <- purrr::map( @@ -73,7 +73,7 @@ as_workflow_set <- function(...) { dplyr::mutate( workflow = unname(wflows), info = purrr::map(workflow, get_info), - option = purrr::map(1:nrow(res), ~ new_workflow_set_options()) + option = purrr::map(1:nrow(res), \(i) new_workflow_set_options()) ) res$result <- vector(mode = "list", length = nrow(res)) res$result[!is_workflow] <- object[!is_workflow] diff --git a/R/checks.R b/R/checks.R index dd4d6b3..c8c5062 100644 --- a/R/checks.R +++ b/R/checks.R @@ -19,7 +19,7 @@ check_consistent_metrics <- function(x, fail = TRUE, call = caller_env()) { values_fill = FALSE ) |> dplyr::select(-wflow_id) |> - purrr::map_dbl(~ sum(!.x)) + purrr::map_dbl(\(.x) sum(!.x)) if (any(metric_info > 0)) { incp_metrics <- names(metric_info)[metric_info > 0] @@ -40,8 +40,8 @@ check_consistent_metrics <- function(x, fail = TRUE, call = caller_env()) { } check_incompete <- function(x, fail = TRUE, call = caller_env()) { - empty_res <- purrr::map_lgl(x$result, ~ identical(.x, list())) - failed_res <- purrr::map_lgl(x$result, ~ inherits(.x, "try-error")) + empty_res <- purrr::map_lgl(x$result, \(.x) identical(.x, list())) + failed_res <- purrr::map_lgl(x$result, \(.x) inherits(.x, "try-error")) n_empty <- sum(empty_res | failed_res) if (n_empty > 0) { @@ -168,7 +168,7 @@ check_names <- function(x, call = caller_env()) { } check_for_workflow <- function(x, call = caller_env()) { - no_wflow <- purrr::map_lgl(x, ~ !inherits(.x, "workflow")) + no_wflow <- purrr::map_lgl(x, \(.x) !inherits(.x, "workflow")) if (any(no_wflow)) { bad <- names(no_wflow)[no_wflow] cli::cli_abort( @@ -312,7 +312,7 @@ has_all_pkgs <- function(w) { if (length(pkgs) > 0) { is_inst <- purrr::map_lgl( pkgs, - ~ rlang::is_true(requireNamespace(.x, quietly = TRUE)) + \(.x) rlang::is_true(requireNamespace(.x, quietly = TRUE)) ) if (!all(is_inst)) { cols <- tune::get_tune_colors() diff --git a/R/collect.R b/R/collect.R index 968e554..901ddc1 100644 --- a/R/collect.R +++ b/R/collect.R @@ -118,11 +118,12 @@ collect_predictions.workflow_set <- x, predictions = purrr::map( result, - ~ select_bare_predictions( - .x, - summarize = summarize, - metric = metric - ) + \(.x) + select_bare_predictions( + .x, + summarize = summarize, + metric = metric + ) ) ) } else { diff --git a/R/comments.R b/R/comments.R index 66d5128..5706f27 100644 --- a/R/comments.R +++ b/R/comments.R @@ -105,8 +105,8 @@ comment_print <- function(x, id = NULL, ...) { } x <- dplyr::filter(x, wflow_id %in% id) - chr_x <- purrr::map(x$wflow_id, ~ comment_get(x, id = .x)) - has_comment <- purrr::map_lgl(chr_x, ~ nchar(.x) > 0) + chr_x <- purrr::map(x$wflow_id, \(.x) comment_get(x, id = .x)) + has_comment <- purrr::map_lgl(chr_x, \(.x) nchar(.x) > 0) chr_x <- chr_x[which(has_comment)] id <- x$wflow_id[which(has_comment)] @@ -124,8 +124,8 @@ comment_print <- function(x, id = NULL, ...) { comment_format <- function(x, id, ...) { x <- strsplit(x, "\n")[[1]] - x <- purrr::map(x, ~ strwrap(.x)) - x <- purrr::map(x, ~ add_returns(.x)) + x <- purrr::map(x, \(.x) strwrap(.x)) + x <- purrr::map(x, \(.x) add_returns(.x)) paste0(x, collapse = "\n\n") } diff --git a/R/leave_var_out_formulas.R b/R/leave_var_out_formulas.R index 0a59bb6..00ee469 100644 --- a/R/leave_var_out_formulas.R +++ b/R/leave_var_out_formulas.R @@ -47,7 +47,7 @@ leave_var_out_formulas <- function(formula, data, full_model = TRUE, ...) { form_terms <- purrr::map(x_vars, rm_vars, lst = x_vars) form <- purrr::map_chr( form_terms, - ~ paste(y_vars, "~", paste(.x, collapse = " + ")) + \(.x) paste(y_vars, "~", paste(.x, collapse = " + ")) ) form <- purrr::map(form, as.formula) form <- purrr::map(form, rm_formula_env) @@ -63,7 +63,7 @@ rm_vars <- function(x, lst) { } remaining_terms <- function(x, lst) { - has_x <- purrr::map_lgl(lst, ~ x %in% all_terms(.x)) + has_x <- purrr::map_lgl(lst, \(.x) x %in% all_terms(.x)) is_x <- lst == x lst[!has_x & !is_x] } diff --git a/R/misc.R b/R/misc.R index cf2edbd..2520ad5 100644 --- a/R/misc.R +++ b/R/misc.R @@ -25,8 +25,8 @@ metric_to_df <- function(x, ...) { metrics <- attributes(x)$metrics names <- names(metrics) metrics <- unname(metrics) - classes <- purrr::map_chr(metrics, ~ class(.x)[[1]]) - directions <- purrr::map_chr(metrics, ~ attr(.x, "direction")) + classes <- purrr::map_chr(metrics, \(.x) class(.x)[[1]]) + directions <- purrr::map_chr(metrics, \(.x) attr(.x, "direction")) info <- data.frame(metric = names, class = classes, direction = directions) info } @@ -37,7 +37,7 @@ collate_metrics <- function(x) { x$result |> purrr::map(tune::.get_tune_metrics) |> purrr::map(metric_to_df) |> - purrr::map_dfr(~ dplyr::mutate(.x, order = 1:nrow(.x))) + purrr::map_dfr(\(.x) dplyr::mutate(.x, order = 1:nrow(.x))) mean_order <- metrics |> diff --git a/R/options.R b/R/options.R index 61f3d72..56d8d64 100644 --- a/R/options.R +++ b/R/options.R @@ -113,7 +113,7 @@ maybe_param <- function(x) { #' @export #' @rdname option_add option_add_parameters <- function(x, id = NULL, strict = FALSE) { - prm <- purrr::map(x$info, ~ maybe_param(.x$workflow[[1]])) + prm <- purrr::map(x$info, \(.x) maybe_param(.x$workflow[[1]])) num <- purrr::map_int(prm, length) if (all(num == 0)) { return(x) diff --git a/R/rank_results.R b/R/rank_results.R index 054cd9d..27beb42 100644 --- a/R/rank_results.R +++ b/R/rank_results.R @@ -72,7 +72,7 @@ rank_results <- function( types <- x |> dplyr::full_join(wflow_info, by = "wflow_id") |> dplyr::mutate( - is_race = purrr::map_lgl(result, ~ inherits(.x, "tune_race")), + is_race = purrr::map_lgl(result, \(.x) inherits(.x, "tune_race")), num_rs = purrr::map_int(result, get_num_resamples) ) |> dplyr::select(wflow_id, is_race, num_rs) @@ -132,7 +132,7 @@ rank_results <- function( } get_num_resamples <- function(x) { - purrr::map_dfr(x$splits, ~ .x$id) |> + purrr::map_dfr(x$splits, \(.x) .x$id) |> dplyr::distinct() |> nrow() } diff --git a/R/workflow_map.R b/R/workflow_map.R index f0036b4..0fd0702 100644 --- a/R/workflow_map.R +++ b/R/workflow_map.R @@ -246,7 +246,7 @@ allowed_fn_list <- paste0("'", allowed_fn$func, "'", collapse = ", ") check_object_fn <- function(object, fn, call = rlang::caller_env()) { wf_specs <- purrr::map( object$wflow_id, - ~ extract_spec_parsnip(object, id = .x) + \(.x) extract_spec_parsnip(object, id = .x) ) is_cluster_spec <- purrr::map_lgl(wf_specs, inherits, "cluster_spec") diff --git a/R/workflow_set.R b/R/workflow_set.R index 124a47a..d427a42 100644 --- a/R/workflow_set.R +++ b/R/workflow_set.R @@ -117,7 +117,7 @@ #' #' # Select predictors by their names #' channels <- paste0("ch_", 1:4) -#' preproc <- purrr::map(channels, ~ workflow_variables(class, c(contains(!!.x)))) +#' preproc <- purrr::map(channels, \(.x) workflow_variables(class, c(contains(!!.x)))) #' names(preproc) <- channels #' preproc$everything <- class ~ . #' preproc @@ -160,8 +160,8 @@ workflow_set <- function(preproc, models, cross = TRUE, case_weights = NULL) { dplyr::mutate( workflow = wfs, info = purrr::map(workflow, get_info), - option = purrr::map(1:nrow(res), ~ new_workflow_set_options()), - result = purrr::map(1:nrow(res), ~ list()) + option = purrr::map(1:nrow(res), \(i) new_workflow_set_options()), + result = purrr::map(1:nrow(res), \(i) list()) ) |> dplyr::select(wflow_id, info, option, result) @@ -188,7 +188,7 @@ model_type <- function(x) { } fix_list_names <- function(x) { - prefix <- purrr::map_chr(x, ~ class(.x)[1]) + prefix <- purrr::map_chr(x, \(.x) class(.x)[1]) prefix <- vctrs::vec_as_names(prefix, repair = "unique", quiet = TRUE) prefix <- gsub("\\.\\.\\.", "_", prefix) nms <- names(x) diff --git a/man/as_workflow_set.Rd b/man/as_workflow_set.Rd index 8a6bd00..81a8601 100644 --- a/man/as_workflow_set.Rd +++ b/man/as_workflow_set.Rd @@ -54,7 +54,7 @@ results <- two_class_res |> purrr::pluck("result") names(results) <- two_class_res$wflow_id # These are all objects that have been resampled or tuned: -purrr::map_chr(results, ~ class(.x)[1]) +purrr::map_chr(results, \(x) class(x)[1]) # Use rlang's !!! operator to splice in the elements of the list new_set <- as_workflow_set(!!!results) diff --git a/man/workflow_set.Rd b/man/workflow_set.Rd index 80fcb99..6385c58 100644 --- a/man/workflow_set.Rd +++ b/man/workflow_set.Rd @@ -150,7 +150,7 @@ cell_set # Select predictors by their names channels <- paste0("ch_", 1:4) -preproc <- purrr::map(channels, ~ workflow_variables(class, c(contains(!!.x)))) +preproc <- purrr::map(channels, \(.x) workflow_variables(class, c(contains(!!.x)))) names(preproc) <- channels preproc$everything <- class ~ . preproc diff --git a/tests/testthat/helper-extract_parameter_set.R b/tests/testthat/helper-extract_parameter_set.R index ec3d393..66ca349 100644 --- a/tests/testthat/helper-extract_parameter_set.R +++ b/tests/testthat/helper-extract_parameter_set.R @@ -13,7 +13,7 @@ check_parameter_set_tibble <- function(x) { expect_equal(class(x$object), "list") obj_check <- purrr::map_lgl( x$object, - ~ inherits(.x, "param") | all(is.na(.x)) + \(.x) inherits(.x, "param") | all(is.na(.x)) ) expect_true(all(obj_check)) From caab4caa08123f890ed82459a40df82e21c5a534 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 25 Apr 2025 14:21:32 -0500 Subject: [PATCH 07/11] add ROR for Posit --- DESCRIPTION | 4 +++- man/workflowsets-package.Rd | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index abc8727..a11f276 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -6,7 +6,9 @@ Authors@R: c( comment = c(ORCID = "0000-0003-2402-136X")), person("Simon", "Couch", , "simon.couch@posit.co", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-5676-5107")), - person(given = "Posit Software, PBC", role = c("cph", "fnd")) + person(given = "Posit Software, PBC", + role = c("cph", "fnd"), + comment = c(ROR = "03wc8by49")) ) Description: A workflow is a combination of a model and preprocessors (e.g, a formula, recipe, etc.) (Kuhn and Silge (2021) diff --git a/man/workflowsets-package.Rd b/man/workflowsets-package.Rd index 6a640f0..647c025 100644 --- a/man/workflowsets-package.Rd +++ b/man/workflowsets-package.Rd @@ -27,7 +27,7 @@ Authors: Other contributors: \itemize{ - \item Posit Software, PBC [copyright holder, funder] + \item Posit Software, PBC (03wc8by49) [copyright holder, funder] } } From f2dfdb7f2c523f805cfc099b513a39523f37535c Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 25 Apr 2025 14:21:58 -0500 Subject: [PATCH 08/11] `knitr::convert_chunk_header(type = "yaml")` --- .../articles/tuning-and-comparing-models.Rmd | 67 +++++++++++++------ .../evaluating-different-predictor-sets.Rmd | 23 +++++-- 2 files changed, 64 insertions(+), 26 deletions(-) diff --git a/vignettes/articles/tuning-and-comparing-models.Rmd b/vignettes/articles/tuning-and-comparing-models.Rmd index bc0cec3..a577931 100644 --- a/vignettes/articles/tuning-and-comparing-models.Rmd +++ b/vignettes/articles/tuning-and-comparing-models.Rmd @@ -7,7 +7,8 @@ vignette: > %\VignetteEncoding{UTF-8} --- -```{r, include = FALSE} +```{r} +#| include: false knitr::opts_chunk$set( collapse = TRUE, comment = "#>", @@ -28,7 +29,8 @@ For some problems, users might want to try different combinations of preprocessi In this example we'll use a small, two-dimensional data set for illustrating classification models. The data are in the [modeldata](https://modeldata.tidymodels.org/) package: -```{r parabolic} +```{r} +#| label: parabolic library(tidymodels) data(parabolic) @@ -37,7 +39,8 @@ str(parabolic) Let's hold back 25% of the data for a test set: -```{r 2d-splits} +```{r} +#| label: 2d-splits set.seed(1) split <- initial_split(parabolic) @@ -47,7 +50,10 @@ test_set <- testing(split) Visually, we can see that the predictors are mildly correlated and some type of nonlinear class boundary is probably needed. -```{r 2d-plot, fig.width=5, fig.height=5.1} +```{r} +#| label: 2d-plot +#| fig-width: 5 +#| fig-height: 5.1 ggplot(train_set, aes(x = X1, y = X2, col = class)) + geom_point(alpha = 0.5) + coord_fixed(ratio = 1) + @@ -58,7 +64,8 @@ ggplot(train_set, aes(x = X1, y = X2, col = class)) + We'll fit two types of discriminant analysis (DA) models (regularized DA and flexible DA using MARS, multivariate adaptive regression splines) as well as a simple classification tree. Let's create those parsnip model objects: -```{r models} +```{r} +#| label: models library(discrim) mars_disc_spec <- @@ -77,7 +84,8 @@ cart_spec <- Next, we'll need a resampling method. Let's use the bootstrap: -```{r resamples} +```{r} +#| label: resamples set.seed(2) train_resamples <- bootstraps(train_set) ``` @@ -86,7 +94,8 @@ We have a simple data set so a basic formula will suffice for our preprocessing. The workflow set takes a named list of preprocessors and a named list of parsnip model specifications, and can cross them to find all combinations. For our case, it will just make a set of workflows for our models: -```{r wflow-set} +```{r} +#| label: wflow-set all_workflows <- workflow_set( preproc = list("formula" = class ~ .), @@ -101,7 +110,8 @@ We can add any specific options that we think are important for tuning or resamp For illustration, let's use the `extract` argument of the [control function](https://tune.tidymodels.org/reference/control_grid.html) to save the fitted workflow. We can then pick which workflow should use this option with the `id` argument: -```{r option} +```{r} +#| label: option all_workflows <- all_workflows |> option_add( @@ -121,7 +131,8 @@ Let's use the same grid size for each model. For the MARS model, there are only The `verbose` option provides a concise listing for which workflow is being processed: -```{r tuning} +```{r} +#| label: tuning all_workflows <- all_workflows |> # Specifying arguments here adds to any previously set with `option_add()`: @@ -133,7 +144,11 @@ The `result` column now has the results of each `tune_grid()` call. From these results, we can get quick assessments of how well these models classified the data: -```{r rank_res, fig.width=8, fig.height=5.5, out.width="100%"} +```{r} +#| label: rank_res +#| fig-width: 8 +#| fig-height: 5.5 +#| out-width: 100% rank_results(all_workflows, rank_metric = "roc_auc") # or a handy plot: @@ -144,7 +159,10 @@ autoplot(all_workflows, metric = "roc_auc") It looks like the MARS model did well. We can plot its results and also pull out the tuning object too: -```{r mars, fig.width=6, fig.height=4.25} +```{r} +#| label: mars +#| fig-width: 6 +#| fig-height: 4.25 autoplot(all_workflows, metric = "roc_auc", id = "formula_mars") ``` @@ -152,7 +170,8 @@ Not much of a difference in performance; it may be prudent to use the additive m We can also pull out the results of `tune_grid()` for this model: -```{r mars-results-print} +```{r} +#| label: mars-results-print mars_results <- all_workflows |> extract_workflow_set_result("formula_mars") @@ -161,7 +180,8 @@ mars_results Let's get that workflow object and finalize the model: -```{r final-mars} +```{r} +#| label: final-mars mars_workflow <- all_workflows |> extract_workflow("formula_mars") @@ -176,7 +196,8 @@ mars_workflow_fit Let's see how well these data work on the test set: -```{r grid-pred} +```{r} +#| label: grid-pred # Make a grid to predict the whole space: grid <- crossing( @@ -191,7 +212,11 @@ grid <- We can produce a contour plot for the class boundary, then overlay the data: -```{r 2d-boundary, fig.width=5, fig.height=5.1, warning=FALSE} +```{r} +#| label: 2d-boundary +#| warning: false +#| fig-width: 5 +#| fig-height: 5.1 ggplot(grid, aes(x = X1, y = X2)) + geom_contour(aes(z = .pred_Class2), breaks = 0.5, col = "black") + geom_point(data = test_set, aes(col = class), alpha = 0.5) + @@ -205,7 +230,8 @@ The workflow set allows us to screen many models to find one that does very well Recall that we added an option to the CART model to extract the model results. Let's pull out the CART tuning results and see what we have: -```{r extraction-res} +```{r} +#| label: extraction-res cart_res <- all_workflows |> extract_workflow_set_result("formula_cart") @@ -216,7 +242,8 @@ The `.extracts` has 20 rows for each resample (since there were 20 tuning parame Let's slim that down by keeping the ones that correspond to the best tuning parameters: -```{r extract-subset} +```{r} +#| label: extract-subset # Get the best results best_cart <- select_best(cart_res, metric = "roc_auc") @@ -231,7 +258,8 @@ cart_wflows What can we do with these? Let's write a function to return the number of terminal nodes in the tree. -```{r cart-nodes} +```{r} +#| label: cart-nodes num_nodes <- function(wflow) { var_imps <- wflow |> @@ -253,7 +281,8 @@ cart_wflows$.extracts[[1]] |> num_nodes() Now let's create a column with the results for each resample: -```{r num-nodes-counts} +```{r} +#| label: num-nodes-counts cart_wflows <- cart_wflows |> mutate(num_nodes = map_int(.extracts, num_nodes)) diff --git a/vignettes/evaluating-different-predictor-sets.Rmd b/vignettes/evaluating-different-predictor-sets.Rmd index 6c1f6c6..ab1cef8 100644 --- a/vignettes/evaluating-different-predictor-sets.Rmd +++ b/vignettes/evaluating-different-predictor-sets.Rmd @@ -7,7 +7,8 @@ vignette: > %\VignetteEncoding{UTF-8} --- -```{r, include = FALSE} +```{r} +#| include: false knitr::opts_chunk$set( collapse = TRUE, eval = rlang::is_installed(c("modeldata", "recipes")), @@ -29,7 +30,8 @@ In this example, we'll fit the same model but specify different predictor sets i Let's take a look at the customer churn data from the `modeldata` package: -```{r tidymodels} +```{r} +#| label: tidymodels data(mlc_churn, package = "modeldata") ncol(mlc_churn) ``` @@ -38,7 +40,8 @@ There are 19 predictors, mostly numeric. This include aspects of their account, We'll use a logistic regression to model the data. Since the data set is not small, we'll use basic 10-fold cross-validation to get resampled performance estimates. -```{r churn-objects} +```{r} +#| label: churn-objects library(workflowsets) library(parsnip) library(rsample) @@ -57,7 +60,8 @@ folds <- vfold_cv(training(trn_tst_split), strata = churn) We would make a basic workflow that uses this model specification and a basic formula. However, in this application, we'd like to know which predictors are associated with the best area under the ROC curve. -```{r churn-formulas} +```{r} +#| label: churn-formulas formulas <- leave_var_out_formulas(churn ~ ., data = mlc_churn) length(formulas) @@ -66,7 +70,8 @@ formulas[["area_code"]] We create our workflow set: -```{r churn-wflow-sets} +```{r} +#| label: churn-wflow-sets churn_workflows <- workflow_set( preproc = formulas, @@ -77,7 +82,8 @@ churn_workflows Since we are using basic logistic regression, there is nothing to tune for these models. Instead of `tune_grid()`, we'll use `tune::fit_resamples()` instead by giving that function name as the first argument: -```{r churn-wflow-set-fits} +```{r} +#| label: churn-wflow-set-fits churn_workflows <- churn_workflows |> workflow_map("fit_resamples", resamples = folds) @@ -86,7 +92,10 @@ churn_workflows To assess how to measure the effect of each predictor, let's subtract the area under the ROC curve for each predictor from the same metric from the full model. We'll match first by resampling ID, the compute the mean difference. -```{r churn-metrics, fig.width=6, fig.height=5} +```{r} +#| label: churn-metrics +#| fig-width: 6 +#| fig-height: 5 roc_values <- churn_workflows |> collect_metrics(summarize = FALSE) |> From 02dd00503326533ef51f49f9403e38e9ef04b31a Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 25 Apr 2025 14:22:17 -0500 Subject: [PATCH 09/11] update copyright year --- LICENSE | 2 +- LICENSE.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/LICENSE b/LICENSE index 2449626..4738cf0 100644 --- a/LICENSE +++ b/LICENSE @@ -1,2 +1,2 @@ -YEAR: 2023 +YEAR: 2025 COPYRIGHT HOLDER: workflowsets authors diff --git a/LICENSE.md b/LICENSE.md index 921d934..75eb8f6 100644 --- a/LICENSE.md +++ b/LICENSE.md @@ -1,6 +1,6 @@ # MIT License -Copyright (c) 2023 workflowsets authors +Copyright (c) 2025 workflowsets authors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal From 1b0f9b40628434beb66afb325eaaa788b9f465f9 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 25 Apr 2025 14:22:36 -0500 Subject: [PATCH 10/11] `usethis::use_tidy_description()` --- DESCRIPTION | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index a11f276..492ca5b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -2,12 +2,11 @@ Package: workflowsets Title: Create a Collection of 'tidymodels' Workflows Version: 1.1.0.9000 Authors@R: c( - person("Max", "Kuhn", , "max@posit.co", role = c("aut"), + person("Max", "Kuhn", , "max@posit.co", role = "aut", comment = c(ORCID = "0000-0003-2402-136X")), person("Simon", "Couch", , "simon.couch@posit.co", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-5676-5107")), - person(given = "Posit Software, PBC", - role = c("cph", "fnd"), + person("Posit Software, PBC", role = c("cph", "fnd"), comment = c(ROR = "03wc8by49")) ) Description: A workflow is a combination of a model and preprocessors @@ -57,18 +56,12 @@ Suggests: yardstick (>= 1.3.0) VignetteBuilder: knitr -Config/Needs/website: - discrim, - rpart, - mda, - klaR, - earth, - tidymodels, +Config/Needs/website: discrim, rpart, mda, klaR, earth, tidymodels, tidyverse/tidytemplate Config/testthat/edition: 3 +Config/usethis/last-upkeep: 2025-04-25 Encoding: UTF-8 Language: en-US LazyData: true Roxygen: list(markdown = TRUE) RoxygenNote: 7.3.2 -Config/usethis/last-upkeep: 2025-04-25 From f5484a663d5696e3433e2c25c205bf789eae96ee Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 25 Apr 2025 14:23:10 -0500 Subject: [PATCH 11/11] `usethis::use_tidy_github_actions()` --- .github/workflows/R-CMD-check.yaml | 1 - .github/workflows/pkgdown.yaml | 1 - .github/workflows/test-coverage.yaml | 11 ++++++----- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index 064677b..69cfc6a 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -8,7 +8,6 @@ on: push: branches: [main, master] pull_request: - branches: [main, master] name: R-CMD-check.yaml diff --git a/.github/workflows/pkgdown.yaml b/.github/workflows/pkgdown.yaml index 4bbce75..bfc9f4d 100644 --- a/.github/workflows/pkgdown.yaml +++ b/.github/workflows/pkgdown.yaml @@ -4,7 +4,6 @@ on: push: branches: [main, master] pull_request: - branches: [main, master] release: types: [published] workflow_dispatch: diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index 9882260..0ab748d 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -4,7 +4,6 @@ on: push: branches: [main, master] pull_request: - branches: [main, master] name: test-coverage.yaml @@ -35,14 +34,16 @@ jobs: clean = FALSE, install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package") ) + print(cov) covr::to_cobertura(cov) shell: Rscript {0} - - uses: codecov/codecov-action@v4 + - uses: codecov/codecov-action@v5 with: - fail_ci_if_error: ${{ github.event_name != 'pull_request' && true || false }} - file: ./cobertura.xml - plugin: noop + # Fail if error if not on PR, or if on PR and token is given + fail_ci_if_error: ${{ github.event_name != 'pull_request' || secrets.CODECOV_TOKEN }} + files: ./cobertura.xml + plugins: noop disable_search: true token: ${{ secrets.CODECOV_TOKEN }}