Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ _cache$
^\.vscode$
^[.]?air[.]toml$
^\.claude$
^\.positai$
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ tests/testthat/Rplots.pdf
docs
revdep
tests/testthat/_snaps/notes.new.md
.positai
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# tune (development version)

## Bug Fixes

* Resampling and tuning would fail for quantile regression models if they passed a quantile regression metric (#1186)

# tune 2.1.0

* Model tuning has been enabled for quantile regression models. (#1125)
Expand Down
23 changes: 15 additions & 8 deletions R/metric-selection.R
Original file line number Diff line number Diff line change
Expand Up @@ -328,23 +328,30 @@ check_metrics_arg <- function(mtr_set, wflow, ..., call = rlang::caller_env()) {
)

return(mtr_set)
} else {
Comment thread
topepo marked this conversation as resolved.
Outdated
if (!inherits(mtr_set, "metric_set")) {
cli::cli_abort(
"The {.arg metrics} argument should have class {.cls metric_set}, not {.cls {class(mtr_set)}}.",
call = call
)
}
}

is_numeric_metric_set <- inherits(mtr_set, "numeric_metric_set")
is_class_prob_metric_set <- inherits(mtr_set, "class_prob_metric_set")
is_surv_metric_set <- inherits(mtr_set, c("survival_metric_set"))
is_qnt_metric_set <- inherits(mtr_set, c("quantile_metric_set"))

if (
!is_numeric_metric_set && !is_class_prob_metric_set && !is_surv_metric_set
) {
if (mode == "regression" && !is_numeric_metric_set) {
cli::cli_abort(
"The {.arg metrics} argument should be the results of
{.fn yardstick::metric_set}.",
"The parsnip model has {.code mode} value of {.val {mode}},
but the {.arg metrics} is a metric set for a
different model mode.",
Comment thread
topepo marked this conversation as resolved.
Outdated
call = call
)
}

if (mode == "regression" && !is_numeric_metric_set) {
if (mode == "classification" && !is_class_prob_metric_set) {
cli::cli_abort(
"The parsnip model has {.code mode} value of {.val {mode}},
but the {.arg metrics} is a metric set for a
Expand All @@ -353,7 +360,7 @@ check_metrics_arg <- function(mtr_set, wflow, ..., call = rlang::caller_env()) {
)
}

if (mode == "classification" && !is_class_prob_metric_set) {
if (mode == "censored regression" && !is_surv_metric_set) {
cli::cli_abort(
"The parsnip model has {.code mode} value of {.val {mode}},
but the {.arg metrics} is a metric set for a
Expand All @@ -362,7 +369,7 @@ check_metrics_arg <- function(mtr_set, wflow, ..., call = rlang::caller_env()) {
)
}

if (mode == "censored regression" && !is_surv_metric_set) {
if (mode == "quantile regression" && !is_qnt_metric_set) {
cli::cli_abort(
"The parsnip model has {.code mode} value of {.val {mode}},
but the {.arg metrics} is a metric set for a
Expand Down
32 changes: 32 additions & 0 deletions tests/testthat/_snaps/metric-args.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
- `rmse()`, a numeric metric | direction: minimize
- `rsq()`, a numeric metric | direction: maximize

---

Code
check_metrics_arg(rmse, wflow)
Condition
Error:
! The `metrics` argument should have class <metric_set>, not <numeric_metric/metric/function>.

---

Code
Expand Down Expand Up @@ -289,3 +297,27 @@
Error in `last_fit()`:
! The parsnip model has `mode` value of "censored regression", but the `metrics` is a metric set for a different model mode.

# metric inputs are checked for quantile regression models

Code
check_metrics_arg(NULL, wflow)
Output
A metric set, consisting of:
- `weighted_interval_score()`, a quantile metric | direction: minimize

---

Code
check_metrics_arg(metric_set(rmse), wflow)
Condition
Error:
! The parsnip model has `mode` value of "quantile regression", but the `metrics` is a metric set for a different model mode.
Comment thread
topepo marked this conversation as resolved.
Outdated

---

Code
check_metrics_arg(metric_set(weighted_interval_score), wflow)
Output
A metric set, consisting of:
- `weighted_interval_score()`, a quantile metric | direction: minimize

17 changes: 17 additions & 0 deletions tests/testthat/test-metric-args.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ test_that("metric inputs are checked for regression models", {
metric_set(brier_survival_integrated, brier_survival, concordance_survival)
met_reg <- metric_set(rmse)
met_cls <- metric_set(brier_class)
met_qnt <- metric_set(weighted_interval_score)

# ------------------------------------------------------------------------------
# check inputs

expect_snapshot(check_metrics_arg(NULL, wflow))

expect_snapshot(check_metrics_arg(rmse, wflow), error = TRUE)

expect_snapshot(check_metrics_arg(met_reg, wflow))
expect_snapshot(check_metrics_arg(met_cls, wflow), error = TRUE)
expect_snapshot(check_metrics_arg(met_mix_int, wflow), error = TRUE)
Expand Down Expand Up @@ -180,3 +183,17 @@ test_that("metric inputs are checked for censored regression models", {
expect_snapshot(last_fit(wflow, split, metrics = met_cls), error = TRUE)
expect_snapshot(last_fit(wflow, split, metrics = met_reg), error = TRUE)
})

test_that("metric inputs are checked for quantile regression models", {
wflow <- workflow(
y ~ x,
linear_reg(engine = "quantreg", mode = "quantile regression")
)

expect_snapshot(check_metrics_arg(NULL, wflow))

expect_snapshot(check_metrics_arg(metric_set(rmse), wflow), error = TRUE)
expect_snapshot(
check_metrics_arg(metric_set(weighted_interval_score), wflow)
)
})
Loading