Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@
^revdep$
^cran-comments\.md$
^man-roxygen$
^[\.]?air\.toml$
^\.vscode$
1 change: 0 additions & 1 deletion .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ on:
push:
branches: [main, master]
pull_request:
branches: [main, master]

name: R-CMD-check.yaml

Expand Down
1 change: 0 additions & 1 deletion .github/workflows/pkgdown.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ on:
push:
branches: [main, master]
pull_request:
branches: [main, master]
release:
types: [published]
workflow_dispatch:
Expand Down
11 changes: 6 additions & 5 deletions .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ on:
push:
branches: [main, master]
pull_request:
branches: [main, master]

name: test-coverage.yaml

Expand Down Expand Up @@ -35,14 +34,16 @@ jobs:
clean = FALSE,
install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package")
)
print(cov)
covr::to_cobertura(cov)
shell: Rscript {0}

- uses: codecov/codecov-action@v4
- uses: codecov/codecov-action@v5
with:
fail_ci_if_error: ${{ github.event_name != 'pull_request' && true || false }}
file: ./cobertura.xml
plugin: noop
# Fail if error if not on PR, or if on PR and token is given
fail_ci_if_error: ${{ github.event_name != 'pull_request' || secrets.CODECOV_TOKEN }}
files: ./cobertura.xml
plugins: noop
disable_search: true
token: ${{ secrets.CODECOV_TOKEN }}

Expand Down
5 changes: 5 additions & 0 deletions .vscode/extensions.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"recommendations": [
"Posit.air-vscode"
]
}
6 changes: 6 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"[r]": {
"editor.formatOnSave": true,
"editor.defaultFormatter": "Posit.air-vscode"
}
}
16 changes: 6 additions & 10 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ Package: workflowsets
Title: Create a Collection of 'tidymodels' Workflows
Version: 1.1.0.9000
Authors@R: c(
person("Max", "Kuhn", , "max@posit.co", role = c("aut"),
person("Max", "Kuhn", , "max@posit.co", role = "aut",
comment = c(ORCID = "0000-0003-2402-136X")),
person("Simon", "Couch", , "simon.couch@posit.co", role = c("aut", "cre"),
comment = c(ORCID = "0000-0001-5676-5107")),
person(given = "Posit Software, PBC", role = c("cph", "fnd"))
person("Posit Software, PBC", role = c("cph", "fnd"),
comment = c(ROR = "03wc8by49"))
)
Description: A workflow is a combination of a model and preprocessors
(e.g, a formula, recipe, etc.) (Kuhn and Silge (2021)
Expand All @@ -19,7 +20,7 @@ URL: https://github.com/tidymodels/workflowsets,
https://workflowsets.tidymodels.org
BugReports: https://github.com/tidymodels/workflowsets/issues
Depends:
R (>= 3.6)
R (>= 4.1)
Imports:
cli,
dplyr (>= 1.0.0),
Expand Down Expand Up @@ -55,15 +56,10 @@ Suggests:
yardstick (>= 1.3.0)
VignetteBuilder:
knitr
Config/Needs/website:
discrim,
rpart,
mda,
klaR,
earth,
tidymodels,
Config/Needs/website: discrim, rpart, mda, klaR, earth, tidymodels,
tidyverse/tidytemplate
Config/testthat/edition: 3
Config/usethis/last-upkeep: 2025-04-25
Encoding: UTF-8
Language: en-US
LazyData: true
Expand Down
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
YEAR: 2023
YEAR: 2025
COPYRIGHT HOLDER: workflowsets authors
2 changes: 1 addition & 1 deletion LICENSE.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# MIT License

Copyright (c) 2023 workflowsets authors
Copyright (c) 2025 workflowsets authors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

* Added a `collect_extracts()` method for workflow sets (#156).

* Increased the minimum required R version to R 4.1.

# workflowsets 1.1.0

* Ellipses (...) are now used consistently in the package to require optional arguments to be named; `collect_metrics()` and `collect_predictions()` are the only functions that received changes (#151, tidymodels/tune#863).
Expand Down
35 changes: 30 additions & 5 deletions R/0_imports.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,36 @@ NULL

utils::globalVariables(
c(
".config", ".estimator", ".metric", "info", "metric", "mod_nm",
"model", "n", "pp_nm", "preprocessor", "preproc", "object", "engine",
"result", "std_err", "wflow_id", "func", "is_race", "num_rs", "option",
"metrics", "predictions", "hash", "id", "workflow", "comment", "get_from_env",
".get_tune_metric_names", "select_best", "notes"
".config",
".estimator",
".metric",
"info",
"metric",
"mod_nm",
"model",
"n",
"pp_nm",
"preprocessor",
"preproc",
"object",
"engine",
"result",
"std_err",
"wflow_id",
"func",
"is_race",
"num_rs",
"option",
"metrics",
"predictions",
"hash",
"id",
"workflow",
"comment",
"get_from_env",
".get_tune_metric_names",
"select_best",
"notes"
)
)

Expand Down
27 changes: 15 additions & 12 deletions R/as_workflow_set.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
#' # objects to a workflow set
#' two_class_res
#'
#' results <- two_class_res %>% purrr::pluck("result")
#' results <- two_class_res |> purrr::pluck("result")
#' names(results) <- two_class_res$wflow_id
#'
#' # These are all objects that have been resampled or tuned:
#' purrr::map_chr(results, ~ class(.x)[1])
#' purrr::map_chr(results, \(x) class(x)[1])
#'
#' # Use rlang's !!! operator to splice in the elements of the list
#' new_set <- as_workflow_set(!!!results)
Expand All @@ -40,13 +40,13 @@
#' lr_spec <- logistic_reg()
#'
#' main_effects <-
#' workflow() %>%
#' add_model(lr_spec) %>%
#' workflow() |>
#' add_model(lr_spec) |>
#' add_formula(Class ~ .)
#'
#' interactions <-
#' workflow() %>%
#' add_model(lr_spec) %>%
#' workflow() |>
#' add_model(lr_spec) |>
#' add_formula(Class ~ (.)^2)
#'
#' as_workflow_set(main = main_effects, int = interactions)
Expand All @@ -55,27 +55,30 @@ as_workflow_set <- function(...) {
object <- rlang::list2(...)

# These could be workflows or objects of class `tune_result`
is_workflow <- purrr::map_lgl(object, ~ inherits(.x, "workflow"))
is_workflow <- purrr::map_lgl(object, \(x) inherits(x, "workflow"))
wflows <- vector("list", length(is_workflow))
wflows[is_workflow] <- object[is_workflow]
wflows[!is_workflow] <- purrr::map(object[!is_workflow], tune::.get_tune_workflow)
wflows[!is_workflow] <- purrr::map(
object[!is_workflow],
tune::.get_tune_workflow
)
names(wflows) <- names(object)

check_names(wflows)
check_for_workflow(wflows)

res <- tibble::tibble(wflow_id = names(wflows))
res <-
res %>%
res |>
dplyr::mutate(
workflow = unname(wflows),
info = purrr::map(workflow, get_info),
option = purrr::map(1:nrow(res), ~ new_workflow_set_options())
option = purrr::map(1:nrow(res), \(i) new_workflow_set_options())
)
res$result <- vector(mode = "list", length = nrow(res))
res$result[!is_workflow] <- object[!is_workflow]

res %>%
dplyr::select(wflow_id, info, option, result) %>%
res |>
dplyr::select(wflow_id, info, option, result) |>
new_workflow_set()
}
61 changes: 41 additions & 20 deletions R/autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,34 +44,56 @@
#' autoplot(two_class_res, select_best = TRUE)
#' autoplot(two_class_res, id = "yj_trans_cart", metric = "roc_auc")
#' @export
autoplot.workflow_set <- function(object, rank_metric = NULL, metric = NULL,
id = "workflow_set",
select_best = FALSE,
std_errs = qnorm(0.95),
type = "class",
...) {
autoplot.workflow_set <- function(
object,
rank_metric = NULL,
metric = NULL,
id = "workflow_set",
select_best = FALSE,
std_errs = qnorm(0.95),
type = "class",
...
) {
rlang::arg_match(type, c("class", "wflow_id"))
check_string(rank_metric, allow_null = TRUE)
check_character(metric, allow_null = TRUE)
check_number_decimal(std_errs)
check_bool(select_best)

if (id == "workflow_set") {
p <- rank_plot(object,
rank_metric = rank_metric, metric = metric,
select_best = select_best, std_errs = std_errs, type = type
p <- rank_plot(
object,
rank_metric = rank_metric,
metric = metric,
select_best = select_best,
std_errs = std_errs,
type = type
)
} else {
p <- autoplot(object$result[[which(object$wflow_id == id)]], metric = metric, ...)
p <- autoplot(
object$result[[which(object$wflow_id == id)]],
metric = metric,
...
)
}
p
}

rank_plot <- function(object, rank_metric = NULL, metric = NULL,
select_best = FALSE, std_errs = 1, type = "class") {
rank_plot <- function(
object,
rank_metric = NULL,
metric = NULL,
select_best = FALSE,
std_errs = 1,
type = "class"
) {
metric_info <- pick_metric(object, rank_metric, metric)
metrics <- collate_metrics(object)
res <- rank_results(object, rank_metric = metric_info$metric, select_best = select_best)
res <- rank_results(
object,
rank_metric = metric_info$metric,
select_best = select_best
)

if (!is.null(metric)) {
keep_metrics <- unique(c(rank_metric, metric))
Expand All @@ -82,13 +104,12 @@ rank_plot <- function(object, rank_metric = NULL, metric = NULL,
has_std_error <- !all(is.na(res$std_err))

p <-
switch(type,
class =
ggplot(res, aes(x = rank, y = mean, col = model)) +
geom_point(aes(shape = preprocessor)),
wflow_id =
ggplot(res, aes(x = rank, y = mean, col = wflow_id)) +
geom_point()
switch(
type,
class = ggplot(res, aes(x = rank, y = mean, col = model)) +
geom_point(aes(shape = preprocessor)),
wflow_id = ggplot(res, aes(x = rank, y = mean, col = wflow_id)) +
geom_point()
)

if (num_metrics > 1) {
Expand Down
Loading