From 6eaad66b896211d38374e4ee914d1120b56d03f6 Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Fri, 7 Feb 2025 13:11:03 -0600 Subject: [PATCH 01/12] adding 0.5 as a default canned quantile --- R/arx_forecaster.R | 2 +- R/autoplot.R | 13 +- R/flatline_forecaster.R | 2 +- R/layer_quantile_distn.R | 2 +- R/layer_residual_quantiles.R | 2 +- man/arx_args_list.Rd | 2 +- man/flatline_args_list.Rd | 2 +- man/layer_quantile_distn.Rd | 2 +- man/layer_residual_quantiles.Rd | 2 +- man/step_adjust_latency.Rd | 4 +- tests/testthat/_snaps/snapshots.md | 262 ++++++++++++++-------------- tests/testthat/test-arx_args_list.R | 2 +- 12 files changed, 153 insertions(+), 144 deletions(-) diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index b386fe456..09e009cd4 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -295,7 +295,7 @@ arx_args_list <- function( target_date = NULL, adjust_latency = c("none", "extend_ahead", "extend_lags", "locf"), warn_latency = TRUE, - quantile_levels = c(0.05, 0.95), + quantile_levels = c(0.05, 0.5, 0.95), symmetrize = TRUE, nonneg = TRUE, quantile_by_key = character(0L), diff --git a/R/autoplot.R b/R/autoplot.R index 870dcb8d8..65e68f6c1 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -236,19 +236,20 @@ plot_bands <- function( alpha = 0.6, linewidth = 0.05) { innames <- names(predictions) - n <- length(levels) - alpha <- alpha / (n - 1) - l <- (1 - levels) / 2 - l <- c(rev(l), 1 - l) + na_levels <- length(levels) + alpha <- alpha / (n_levels - 1) + # generate the corresponding level that is 1 - level + levels <- (1 - levels) / 2 + levels <- c(rev(levels), 1 - levels) ntarget_dates <- dplyr::n_distinct(predictions$time_value) predictions <- predictions %>% - mutate(.pred_distn = dist_quantiles(quantile(.pred_distn, l), l)) %>% + mutate(.pred_distn = dist_quantiles(quantile(.pred_distn, levels), levels)) %>% pivot_quantiles_wider(.pred_distn) qnames <- setdiff(names(predictions), innames) - for (i in 1:n) { + for (i in 1:n_levels) { bottom <- qnames[i] top <- rev(qnames)[i] if (i == 1) { diff --git a/R/flatline_forecaster.R b/R/flatline_forecaster.R index 7efda3efd..b3578b1be 100644 --- a/R/flatline_forecaster.R +++ b/R/flatline_forecaster.R @@ -108,7 +108,7 @@ flatline_args_list <- function( n_training = Inf, forecast_date = NULL, target_date = NULL, - quantile_levels = c(0.05, 0.95), + quantile_levels = c(0.05, 0.5, 0.95), symmetrize = TRUE, nonneg = TRUE, quantile_by_key = character(0L), diff --git a/R/layer_quantile_distn.R b/R/layer_quantile_distn.R index f7bc9259d..2f0357bf5 100644 --- a/R/layer_quantile_distn.R +++ b/R/layer_quantile_distn.R @@ -44,7 +44,7 @@ #' p layer_quantile_distn <- function(frosting, ..., - quantile_levels = c(.25, .75), + quantile_levels = c(0.25, 0.5, 0.75), truncate = c(-Inf, Inf), name = ".pred_distn", id = rand_id("quantile_distn")) { diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index e9b5b7c19..624932ce6 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -48,7 +48,7 @@ #' p2 <- forecast(wf2) layer_residual_quantiles <- function( frosting, ..., - quantile_levels = c(0.05, 0.95), + quantile_levels = c(0.05, 0.5, 0.95), symmetrize = TRUE, by_key = character(0L), name = ".pred_distn", diff --git a/man/arx_args_list.Rd b/man/arx_args_list.Rd index f28cdefab..7a9cd592b 100644 --- a/man/arx_args_list.Rd +++ b/man/arx_args_list.Rd @@ -12,7 +12,7 @@ arx_args_list( target_date = NULL, adjust_latency = c("none", "extend_ahead", "extend_lags", "locf"), warn_latency = TRUE, - quantile_levels = c(0.05, 0.95), + quantile_levels = c(0.05, 0.5, 0.95), symmetrize = TRUE, nonneg = TRUE, quantile_by_key = character(0L), diff --git a/man/flatline_args_list.Rd b/man/flatline_args_list.Rd index 401850efe..634d30e5d 100644 --- a/man/flatline_args_list.Rd +++ b/man/flatline_args_list.Rd @@ -9,7 +9,7 @@ flatline_args_list( n_training = Inf, forecast_date = NULL, target_date = NULL, - quantile_levels = c(0.05, 0.95), + quantile_levels = c(0.05, 0.5, 0.95), symmetrize = TRUE, nonneg = TRUE, quantile_by_key = character(0L), diff --git a/man/layer_quantile_distn.Rd b/man/layer_quantile_distn.Rd index 3a5cb60e2..a0de7b669 100644 --- a/man/layer_quantile_distn.Rd +++ b/man/layer_quantile_distn.Rd @@ -7,7 +7,7 @@ layer_quantile_distn( frosting, ..., - quantile_levels = c(0.25, 0.75), + quantile_levels = c(0.25, 0.5, 0.75), truncate = c(-Inf, Inf), name = ".pred_distn", id = rand_id("quantile_distn") diff --git a/man/layer_residual_quantiles.Rd b/man/layer_residual_quantiles.Rd index a7deded71..40e7b0303 100644 --- a/man/layer_residual_quantiles.Rd +++ b/man/layer_residual_quantiles.Rd @@ -7,7 +7,7 @@ layer_residual_quantiles( frosting, ..., - quantile_levels = c(0.05, 0.95), + quantile_levels = c(0.05, 0.5, 0.95), symmetrize = TRUE, by_key = character(0L), name = ".pred_distn", diff --git a/man/step_adjust_latency.Rd b/man/step_adjust_latency.Rd index 0078de100..53098504e 100644 --- a/man/step_adjust_latency.Rd +++ b/man/step_adjust_latency.Rd @@ -267,8 +267,8 @@ while this will not: \if{html}{\out{
}}\preformatted{toy_recipe <- epi_recipe(toy_df) \%>\% step_epi_lag(a, lag=0) \%>\% step_adjust_latency(a, method = "extend_lags") -#> Warning: If `method` is "extend_lags" or "locf", then the previous `step_epi_lag`s won't -#> work with modified data. +#> Warning: If `method` is "extend_lags" or "locf", then the previous `step_epi_lag`s won't work with +#> modified data. }\if{html}{\out{
}} If you create columns that you then apply lags to (such as diff --git a/tests/testthat/_snaps/snapshots.md b/tests/testthat/_snaps/snapshots.md index 9fd339ded..213f6b7be 100644 --- a/tests/testthat/_snaps/snapshots.md +++ b/tests/testthat/_snaps/snapshots.md @@ -3,24 +3,24 @@ structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" ), .pred = c(0.1393442, 0.103199, 0.3121244, 0.4218461, 0.7319844, 0.1975426), .pred_distn = structure(list(structure(list(values = c(0, - 0.34820911), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + 0.1393442, 0.34820911), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0, 0.31206391), quantile_levels = c(0.05, 0.95 - )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.10325949, 0.52098931 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0, 0.103199, 0.31206391), quantile_levels = c(0.05, + 0.5, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.10325949, + 0.3121244, 0.52098931), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.21298119, 0.63071101), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.52311949, 0.94084931 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0.21298119, 0.4218461, 0.63071101), quantile_levels = c(0.05, + 0.5, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.52311949, + 0.7319844, 0.94084931), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0, 0.40640751), quantile_levels = c(0.05, 0.95 - )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", "list" - )), forecast_date = structure(c(18992, 18992, 18992, 18992, 18992, - 18992), class = "Date"), target_date = structure(c(18999, 18999, - 18999, 18999, 18999, 18999), class = "Date")), row.names = c(NA, + values = c(0, 0.1975426, 0.40640751), quantile_levels = c(0.05, + 0.5, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", + "list")), forecast_date = structure(c(18992, 18992, 18992, 18992, + 18992, 18992), class = "Date"), target_date = structure(c(18999, + 18999, 18999, 18999, 18999, 18999), class = "Date")), row.names = c(NA, -6L), class = c("tbl_df", "tbl", "data.frame")) --- @@ -28,24 +28,24 @@ structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" ), .pred = c(0.1393442, 0.103199, 0.3121244, 0.4218461, 0.7319844, 0.1975426), .pred_distn = structure(list(structure(list(values = c(0.084583345, - 0.194105055), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + 0.1393442, 0.194105055), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.048438145, 0.157959855), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.257363545, 0.366885255 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0.048438145, 0.103199, 0.157959855), quantile_levels = c(0.05, + 0.5, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.257363545, + 0.3121244, 0.366885255), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.367085245, 0.476606955), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.677223545, 0.786745255 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0.367085245, 0.4218461, 0.476606955), quantile_levels = c(0.05, + 0.5, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.677223545, + 0.7319844, 0.786745255), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.142781745, 0.252303455), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", "list" - )), forecast_date = structure(c(18992, 18992, 18992, 18992, 18992, - 18992), class = "Date"), target_date = structure(c(18993, 18993, - 18993, 18993, 18993, 18993), class = "Date")), row.names = c(NA, + values = c(0.142781745, 0.1975426, 0.252303455), quantile_levels = c(0.05, + 0.5, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", + "list")), forecast_date = structure(c(18992, 18992, 18992, 18992, + 18992, 18992), class = "Date"), target_date = structure(c(18993, + 18993, 18993, 18993, 18993, 18993), class = "Date")), row.names = c(NA, -6L), class = c("tbl_df", "tbl", "data.frame")) --- @@ -53,24 +53,24 @@ structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" ), .pred = c(0.1393442, 0.103199, 0.3121244, 0.4218461, 0.7319844, 0.1975426), .pred_distn = structure(list(structure(list(values = c(0, - 0.34820911), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + 0.1393442, 0.34820911), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0, 0.31206391), quantile_levels = c(0.05, 0.95 - )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.10325949, 0.52098931 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0, 0.103199, 0.31206391), quantile_levels = c(0.05, + 0.5, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.10325949, + 0.3121244, 0.52098931), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.21298119, 0.63071101), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.52311949, 0.94084931 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0.21298119, 0.4218461, 0.63071101), quantile_levels = c(0.05, + 0.5, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.52311949, + 0.7319844, 0.94084931), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0, 0.40640751), quantile_levels = c(0.05, 0.95 - )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", "list" - )), forecast_date = structure(c(18992, 18992, 18992, 18992, 18992, - 18992), class = "Date"), target_date = structure(c(18999, 18999, - 18999, 18999, 18999, 18999), class = "Date")), row.names = c(NA, + values = c(0, 0.1975426, 0.40640751), quantile_levels = c(0.05, + 0.5, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", + "list")), forecast_date = structure(c(18992, 18992, 18992, 18992, + 18992, 18992), class = "Date"), target_date = structure(c(18999, + 18999, 18999, 18999, 18999, 18999), class = "Date")), row.names = c(NA, -6L), class = c("tbl_df", "tbl", "data.frame")) --- @@ -78,24 +78,24 @@ structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" ), .pred = c(0.1393442, 0.103199, 0.3121244, 0.4218461, 0.7319844, 0.1975426), .pred_distn = structure(list(structure(list(values = c(0, - 0.34820911), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + 0.1393442, 0.34820911), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0, 0.31206391), quantile_levels = c(0.05, 0.95 - )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.10325949, 0.52098931 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0, 0.103199, 0.31206391), quantile_levels = c(0.05, + 0.5, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.10325949, + 0.3121244, 0.52098931), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.21298119, 0.63071101), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.52311949, 0.94084931 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0.21298119, 0.4218461, 0.63071101), quantile_levels = c(0.05, + 0.5, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.52311949, + 0.7319844, 0.94084931), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0, 0.40640751), quantile_levels = c(0.05, 0.95 - )), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", "list" - )), forecast_date = structure(c(18992, 18992, 18992, 18992, 18992, - 18992), class = "Date"), target_date = structure(c(18993, 18993, - 18993, 18993, 18993, 18993), class = "Date")), row.names = c(NA, + values = c(0, 0.1975426, 0.40640751), quantile_levels = c(0.05, + 0.5, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", + "list")), forecast_date = structure(c(18992, 18992, 18992, 18992, + 18992, 18992), class = "Date"), target_date = structure(c(18993, + 18993, 18993, 18993, 18993, 18993), class = "Date")), row.names = c(NA, -6L), class = c("tbl_df", "tbl", "data.frame")) # cdc_baseline_forecaster snapshots @@ -981,26 +981,28 @@ # arx_forecaster snapshots structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" - ), .pred = c(0.353013358779435, 0.648525432444877, 0.667670289394328, - 1.1418673907239, 0.830448695683587, 0.329799431948649), .pred_distn = structure(list( - structure(list(values = c(0.171022956902288, 0.535003760656582 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + ), .pred = c(0.353013358779434, 0.648525432444876, 0.667670289394327, + 1.1418673907239, 0.830448695683588, 0.329799431948648), .pred_distn = structure(list( + structure(list(values = c(0.171022956902287, 0.353013358779434, + 0.535003760656581), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.46653503056773, 0.830515834322024), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.485679887517181, - 0.849660691271475), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0.466535030567729, 0.648525432444876, 0.830515834322023 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.959876988846753, 1.32385779260105), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.64845829380644, - 1.01243909756073), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0.48567988751718, 0.667670289394327, 0.849660691271474 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.147809030071502, 0.511789833825796), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", - "list")), forecast_date = structure(c(18992, 18992, 18992, 18992, - 18992, 18992), class = "Date"), target_date = structure(c(18999, + values = c(0.959876988846751, 1.1418673907239, 1.32385779260105 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.648458293806441, 0.830448695683588, 1.01243909756074 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.147809030071501, 0.329799431948648, 0.511789833825795 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", + "vctrs_vctr", "list")), forecast_date = structure(c(18992, 18992, + 18992, 18992, 18992, 18992), class = "Date"), target_date = structure(c(18999, 18999, 18999, 18999, 18999, 18999), class = "Date")), row.names = c(NA, -6L), class = c("tbl_df", "tbl", "data.frame")) @@ -1008,77 +1010,83 @@ structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" ), .pred = c(0.149303403634372, 0.139764664505947, 0.333186321066645, - 0.470345577837143, 0.725986105412007, 0.212686665274007), .pred_distn = structure(list( - structure(list(values = c(0.0961118191398633, 0.202494988128882 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + 0.470345577837144, 0.725986105412009, 0.212686665274006), .pred_distn = structure(list( + structure(list(values = c(0.0961118191398626, 0.149303403634372, + 0.202494988128881), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.0865730800114382, 0.192956249000457), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.279994736572135, - 0.386377905561154), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0.0865730800114375, 0.139764664505947, 0.192956249000456 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.417153993342634, 0.523537162331653), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.672794520917498, - 0.779177689906517), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0.279994736572135, 0.333186321066645, 0.386377905561154 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.159495080779498, 0.265878249768516), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", - "list")), forecast_date = structure(c(18992, 18992, 18992, 18992, - 18992, 18992), class = "Date"), target_date = structure(c(18993, + values = c(0.417153993342635, 0.470345577837144, 0.523537162331653 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.6727945209175, 0.725986105412009, 0.779177689906518 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.159495080779497, 0.212686665274006, 0.265878249768516 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", + "vctrs_vctr", "list")), forecast_date = structure(c(18992, 18992, + 18992, 18992, 18992, 18992), class = "Date"), target_date = structure(c(18993, 18993, 18993, 18993, 18993, 18993), class = "Date")), row.names = c(NA, -6L), class = c("tbl_df", "tbl", "data.frame")) --- structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" - ), .pred = c(0.303244704017742, 0.531332853311081, 0.588827944685979, - 0.98869024921623, 0.79480199700164, 0.306895457225321), .pred_distn = structure(list( - structure(list(values = c(0.136509784083987, 0.469979623951498 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + ), .pred = c(0.303244704017743, 0.531332853311082, 0.588827944685979, + 0.988690249216229, 0.794801997001639, 0.306895457225321), .pred_distn = structure(list( + structure(list(values = c(0.136509784083987, 0.303244704017743, + 0.469979623951498), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.364597933377326, 0.531332853311082, 0.698067773244837 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.364597933377326, 0.698067773244837), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.422093024752224, - 0.755562864619735), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0.422093024752224, 0.588827944685979, 0.755562864619735 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.821955329282475, 1.15542516914999), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.628067077067884, - 0.961536916935395), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0.821955329282474, 0.988690249216229, 1.15542516914998 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.140160537291565, 0.473630377159077), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", - "list")), forecast_date = structure(c(18997, 18997, 18997, 18997, - 18997, 18997), class = "Date"), target_date = structure(c(18998, + values = c(0.628067077067883, 0.794801997001639, 0.961536916935394 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.140160537291566, 0.306895457225321, 0.473630377159077 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", + "vctrs_vctr", "list")), forecast_date = structure(c(18997, 18997, + 18997, 18997, 18997, 18997), class = "Date"), target_date = structure(c(18998, 18998, 18998, 18998, 18998, 18998), class = "Date")), row.names = c(NA, -6L), class = c("tbl_df", "tbl", "data.frame")) --- structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" - ), .pred = c(0.303244704017742, 0.531332853311081, 0.588827944685979, - 0.98869024921623, 0.79480199700164, 0.306895457225321), .pred_distn = structure(list( - structure(list(values = c(0.136509784083987, 0.469979623951498 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + ), .pred = c(0.303244704017743, 0.531332853311082, 0.588827944685979, + 0.988690249216229, 0.794801997001639, 0.306895457225321), .pred_distn = structure(list( + structure(list(values = c(0.136509784083987, 0.303244704017743, + 0.469979623951498), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.364597933377326, 0.531332853311082, 0.698067773244837 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.422093024752224, 0.588827944685979, 0.755562864619735 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.364597933377326, 0.698067773244837), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.422093024752224, - 0.755562864619735), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0.821955329282474, 0.988690249216229, 1.15542516914998 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.821955329282475, 1.15542516914999), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.628067077067884, - 0.961536916935395), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0.628067077067883, 0.794801997001639, 0.961536916935394 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.140160537291565, 0.473630377159077), quantile_levels = c(0.05, - 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", - "list")), forecast_date = structure(c(18997, 18997, 18997, 18997, - 18997, 18997), class = "Date"), target_date = structure(c(18998, + values = c(0.140160537291566, 0.306895457225321, 0.473630377159077 + ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", + "vctrs_vctr", "list")), forecast_date = structure(c(18997, 18997, + 18997, 18997, 18997, 18997), class = "Date"), target_date = structure(c(18998, 18998, 18998, 18998, 18998, 18998), class = "Date")), row.names = c(NA, -6L), class = c("tbl_df", "tbl", "data.frame")) diff --git a/tests/testthat/test-arx_args_list.R b/tests/testthat/test-arx_args_list.R index 226444e37..8102c31d8 100644 --- a/tests/testthat/test-arx_args_list.R +++ b/tests/testthat/test-arx_args_list.R @@ -41,7 +41,7 @@ test_that("arx forecaster disambiguates quantiles", { tlist <- eval(formals(quantile_reg)$quantile_levels) expect_identical( # both default compare_quantile_args(alist, tlist), - sort(c(alist, tlist)) + c(0.05, 0.5, 0.95) ) expect_snapshot( error = TRUE, From 5126b72bd68cc0a18502320a883e3cd11256c7d2 Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Fri, 7 Feb 2025 13:12:39 -0600 Subject: [PATCH 02/12] news & description --- DESCRIPTION | 2 +- NEWS.md | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 1b52822a5..e45633b0e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: epipredict Title: Basic epidemiology forecasting methods -Version: 0.1.7 +Version: 0.1.8 Authors@R: c( person("Daniel J.", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"), diff --git a/NEWS.md b/NEWS.md index 2f138fd96..84aaffadc 100644 --- a/NEWS.md +++ b/NEWS.md @@ -25,6 +25,7 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat - Shifting no columns results in no error for either `step_epi_ahead` and `step_epi_lag` - Quantiles produced by `grf` were sometimes out of order. - dist_quantiles can have all `NA` values without causing unrelated errors +- add `0.5` as a default quantile for canned forecasters to avoid strange thresholding behavior # epipredict 0.1 From aa121d342cde58404d57f8e2cc26c3b43edaa247 Mon Sep 17 00:00:00 2001 From: David Weber Date: Fri, 7 Feb 2025 12:21:51 -0800 Subject: [PATCH 03/12] Update R/autoplot.R Co-authored-by: Dmitry Shemetov --- R/autoplot.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/autoplot.R b/R/autoplot.R index 65e68f6c1..6761ce6f7 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -236,7 +236,7 @@ plot_bands <- function( alpha = 0.6, linewidth = 0.05) { innames <- names(predictions) - na_levels <- length(levels) + n_levels <- length(levels) alpha <- alpha / (n_levels - 1) # generate the corresponding level that is 1 - level levels <- (1 - levels) / 2 From 42a74f548ce03bedc6536365d86720018ea8e51b Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Fri, 7 Feb 2025 14:54:31 -0600 Subject: [PATCH 04/12] autoplot doesn't extrapolate by default --- R/autoplot.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/autoplot.R b/R/autoplot.R index 6761ce6f7..8ae9226bf 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -85,7 +85,7 @@ NULL #' @rdname autoplot-epipred autoplot.epi_workflow <- function( object, predictions = NULL, - .levels = c(.5, .8, .95), ..., + .levels = c(.5, .8, .9), ..., .color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"), .facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"), .base_color = "dodgerblue4", @@ -231,7 +231,7 @@ starts_with_impl <- function(x, vars) { plot_bands <- function( base_plot, predictions, - levels = c(.5, .8, .95), + levels = c(.5, .8, .9), fill = "blue4", alpha = 0.6, linewidth = 0.05) { From 7d31ee3bf02156f3f91f51177018f5797d8b3006 Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Fri, 7 Feb 2025 15:28:15 -0600 Subject: [PATCH 05/12] docs mismatch, missing `dplyr::` --- R/autoplot.R | 2 +- man/autoplot-epipred.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/autoplot.R b/R/autoplot.R index 8ae9226bf..6b8a5d9c8 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -183,7 +183,7 @@ autoplot.epi_workflow <- function( } if (".pred" %in% names(predictions)) { - ntarget_dates <- n_distinct(predictions$time_value) + ntarget_dates <- dplyr::n_distinct(predictions$time_value) if (ntarget_dates > 1L) { bp <- bp + geom_line( diff --git a/man/autoplot-epipred.Rd b/man/autoplot-epipred.Rd index 1025759b3..c3b3e902b 100644 --- a/man/autoplot-epipred.Rd +++ b/man/autoplot-epipred.Rd @@ -9,7 +9,7 @@ \method{autoplot}{epi_workflow}( object, predictions = NULL, - .levels = c(0.5, 0.8, 0.95), + .levels = c(0.5, 0.8, 0.9), ..., .color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"), .facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"), From 94de99d9f327d91a522007c60e7b73127b1f889c Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Fri, 7 Feb 2025 13:49:36 -0800 Subject: [PATCH 06/12] better defaults in layer_residual_quantiles, force 0.5 and document --- R/layer_residual_quantiles.R | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index 624932ce6..09a374579 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -3,7 +3,8 @@ #' @param frosting a `frosting` postprocessor #' @param ... Unused, include for consistency with other layers. #' @param quantile_levels numeric vector of probabilities with values in (0,1) -#' referring to the desired quantile. +#' referring to the desired quantile. Note that 0.5 will always be included +#' even if left out by the user. #' @param symmetrize logical. If `TRUE` then interval will be symmetric. #' @param by_key A character vector of keys to group the residuals by before #' calculating quantiles. The default, `c()` performs no grouping. @@ -28,7 +29,7 @@ #' f <- frosting() %>% #' layer_predict() %>% #' layer_residual_quantiles( -#' quantile_levels = c(0.0275, 0.975), +#' quantile_levels = c(0.025, 0.975), #' symmetrize = FALSE #' ) %>% #' layer_naomit(.pred) @@ -48,7 +49,7 @@ #' p2 <- forecast(wf2) layer_residual_quantiles <- function( frosting, ..., - quantile_levels = c(0.05, 0.5, 0.95), + quantile_levels = c(0.025, 0.1, 0.25, 0.5, 0.75, 0.9, 0.975), symmetrize = TRUE, by_key = character(0L), name = ".pred_distn", @@ -59,6 +60,7 @@ layer_residual_quantiles <- function( arg_is_chr(by_key, allow_empty = TRUE) arg_is_probabilities(quantile_levels) arg_is_lgl(symmetrize) + quantile_levels <- sort(unique(c(0.5, quantile_levels))) add_layer( frosting, layer_residual_quantiles_new( From 521756db0c796b998abf966acb860a590292f6b7 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Fri, 7 Feb 2025 13:55:54 -0800 Subject: [PATCH 07/12] redocument, pass layer test --- man/layer_residual_quantiles.Rd | 7 ++++--- man/step_adjust_latency.Rd | 4 ++-- tests/testthat/test-layer_residual_quantiles.R | 8 ++++---- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/man/layer_residual_quantiles.Rd b/man/layer_residual_quantiles.Rd index 40e7b0303..96a9da56d 100644 --- a/man/layer_residual_quantiles.Rd +++ b/man/layer_residual_quantiles.Rd @@ -7,7 +7,7 @@ layer_residual_quantiles( frosting, ..., - quantile_levels = c(0.05, 0.5, 0.95), + quantile_levels = c(0.025, 0.1, 0.25, 0.5, 0.75, 0.9, 0.975), symmetrize = TRUE, by_key = character(0L), name = ".pred_distn", @@ -20,7 +20,8 @@ layer_residual_quantiles( \item{...}{Unused, include for consistency with other layers.} \item{quantile_levels}{numeric vector of probabilities with values in (0,1) -referring to the desired quantile.} +referring to the desired quantile. Note that 0.5 will always be included +even if left out by the user.} \item{symmetrize}{logical. If \code{TRUE} then interval will be symmetric.} @@ -53,7 +54,7 @@ wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) f <- frosting() \%>\% layer_predict() \%>\% layer_residual_quantiles( - quantile_levels = c(0.0275, 0.975), + quantile_levels = c(0.025, 0.975), symmetrize = FALSE ) \%>\% layer_naomit(.pred) diff --git a/man/step_adjust_latency.Rd b/man/step_adjust_latency.Rd index 53098504e..0078de100 100644 --- a/man/step_adjust_latency.Rd +++ b/man/step_adjust_latency.Rd @@ -267,8 +267,8 @@ while this will not: \if{html}{\out{
}}\preformatted{toy_recipe <- epi_recipe(toy_df) \%>\% step_epi_lag(a, lag=0) \%>\% step_adjust_latency(a, method = "extend_lags") -#> Warning: If `method` is "extend_lags" or "locf", then the previous `step_epi_lag`s won't work with -#> modified data. +#> Warning: If `method` is "extend_lags" or "locf", then the previous `step_epi_lag`s won't +#> work with modified data. }\if{html}{\out{
}} If you create columns that you then apply lags to (such as diff --git a/tests/testthat/test-layer_residual_quantiles.R b/tests/testthat/test-layer_residual_quantiles.R index 12e44809e..498b80f95 100644 --- a/tests/testthat/test-layer_residual_quantiles.R +++ b/tests/testthat/test-layer_residual_quantiles.R @@ -26,8 +26,8 @@ test_that("Returns expected number or rows and columns", { nested <- p %>% dplyr::mutate(.quantiles = nested_quantiles(.pred_distn)) unnested <- nested %>% tidyr::unnest(.quantiles) - expect_equal(nrow(unnested), 9L) - expect_equal(unique(unnested$quantile_levels), c(.0275, .8, .95)) + expect_equal(nrow(unnested), 12L) + expect_equal(unique(unnested$quantile_levels), c(.0275, .5, .8, .95)) }) @@ -65,9 +65,9 @@ test_that("Grouping by keys is supported", { expect_warning(p2 <- forecast(wf2)) pivot1 <- pivot_quantiles_wider(p1, .pred_distn) %>% - mutate(width = `0.95` - `0.05`) + mutate(width = `0.9` - `0.1`) pivot2 <- pivot_quantiles_wider(p2, .pred_distn) %>% - mutate(width = `0.95` - `0.05`) + mutate(width = `0.9` - `0.1`) expect_equal(pivot1$width, rep(pivot1$width[1], nrow(pivot1))) expect_false(all(pivot2$width == pivot2$width[1])) }) From f9c1d3294338ac8be84cbdd6e007d5f755f8e466 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Fri, 7 Feb 2025 14:12:02 -0800 Subject: [PATCH 08/12] default quantiles are the same everywhere --- R/arx_forecaster.R | 8 ++++---- R/autoplot.R | 4 +--- R/extract.R | 2 +- R/flatline_forecaster.R | 2 +- R/layer_quantile_distn.R | 2 +- R/layer_residual_quantiles.R | 2 +- R/make_grf_quantiles.R | 2 +- R/make_quantile_reg.R | 6 ++++-- R/make_smooth_quantile_reg.R | 9 ++------- tests/testthat/test-layer_residual_quantiles.R | 4 ++-- vignettes/epipredict.Rmd | 7 ++++--- vignettes/panel-data.Rmd | 1 - vignettes/preprocessing-and-models.Rmd | 2 +- 13 files changed, 23 insertions(+), 28 deletions(-) diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index 09e009cd4..228fa1217 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -253,8 +253,8 @@ arx_fcast_epi_workflow <- function( #' the last day of data. For example, if the last day of data was 3 days ago, #' the ahead becomes `ahead+3`. #' - `"extend_lags"`: increase the lags so they're relative to the actual -#' forecast date. For example, if the lags are `c(0,7,14)` and the last day of -#' data was 3 days ago, the lags become `c(3,10,17)`. +#' forecast date. For example, if the lags are `c(0, 7, 14)` and the last day of +#' data was 3 days ago, the lags become `c(3, 10, 17)`. #' @param warn_latency by default, `step_adjust_latency` warns the user if the #' latency is large. If this is `FALSE`, that warning is turned off. #' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce @@ -295,7 +295,7 @@ arx_args_list <- function( target_date = NULL, adjust_latency = c("none", "extend_ahead", "extend_lags", "locf"), warn_latency = TRUE, - quantile_levels = c(0.05, 0.5, 0.95), + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), symmetrize = TRUE, nonneg = TRUE, quantile_by_key = character(0L), @@ -362,7 +362,7 @@ compare_quantile_args <- function(alist, tlist, train_method = c("qr", "grf")) { default_alist <- eval(formals(arx_args_list)$quantile_levels) default_tlist <- switch(train_method, "qr" = eval(formals(quantile_reg)$quantile_levels), - "grf" = c(.1, .5, .9) + "grf" = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95) ) if (setequal(alist, default_alist)) { if (setequal(tlist, default_tlist)) { diff --git a/R/autoplot.R b/R/autoplot.R index 6b8a5d9c8..c0e3c68dd 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -39,9 +39,7 @@ ggplot2::autoplot #' step_epi_naomit() #' #' f <- frosting() %>% -#' layer_residual_quantiles( -#' quantile_levels = c(.025, .1, .25, .75, .9, .975) -#' ) %>% +#' layer_residual_quantiles() %>% #' layer_threshold(starts_with(".pred")) %>% #' layer_add_target_date() #' diff --git a/R/extract.R b/R/extract.R index e227b59b1..2e06567e2 100644 --- a/R/extract.R +++ b/R/extract.R @@ -13,7 +13,7 @@ #' @examples #' f <- frosting() %>% #' layer_predict() %>% -#' layer_residual_quantiles(quantile_levels = c(0.0275, 0.975), symmetrize = FALSE) %>% +#' layer_residual_quantiles(symmetrize = FALSE) %>% #' layer_naomit(.pred) #' #' extract_argument(f, "layer_residual_quantiles", "symmetrize") diff --git a/R/flatline_forecaster.R b/R/flatline_forecaster.R index b3578b1be..7faad31b3 100644 --- a/R/flatline_forecaster.R +++ b/R/flatline_forecaster.R @@ -108,7 +108,7 @@ flatline_args_list <- function( n_training = Inf, forecast_date = NULL, target_date = NULL, - quantile_levels = c(0.05, 0.5, 0.95), + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), symmetrize = TRUE, nonneg = TRUE, quantile_by_key = character(0L), diff --git a/R/layer_quantile_distn.R b/R/layer_quantile_distn.R index 2f0357bf5..b39b58f4a 100644 --- a/R/layer_quantile_distn.R +++ b/R/layer_quantile_distn.R @@ -44,7 +44,7 @@ #' p layer_quantile_distn <- function(frosting, ..., - quantile_levels = c(0.25, 0.5, 0.75), + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), truncate = c(-Inf, Inf), name = ".pred_distn", id = rand_id("quantile_distn")) { diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index 09a374579..96ad88411 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -49,7 +49,7 @@ #' p2 <- forecast(wf2) layer_residual_quantiles <- function( frosting, ..., - quantile_levels = c(0.025, 0.1, 0.25, 0.5, 0.75, 0.9, 0.975), + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), symmetrize = TRUE, by_key = character(0L), name = ".pred_distn", diff --git a/R/make_grf_quantiles.R b/R/make_grf_quantiles.R index 00e7d0e71..fbd221d22 100644 --- a/R/make_grf_quantiles.R +++ b/R/make_grf_quantiles.R @@ -141,7 +141,7 @@ make_grf_quantiles <- function() { data = c(x = "X", y = "Y"), func = c(pkg = "grf", fun = "quantile_forest"), defaults = list( - quantiles = c(0.1, 0.5, 0.9), + quantiles = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), num.threads = 1L, seed = rlang::expr(stats::runif(1, 0, .Machine$integer.max)) ) diff --git a/R/make_quantile_reg.R b/R/make_quantile_reg.R index 1388dd859..223881c85 100644 --- a/R/make_quantile_reg.R +++ b/R/make_quantile_reg.R @@ -12,7 +12,7 @@ #' @param engine Character string naming the fitting function. Currently, only #' "rq" and "grf" are supported. #' @param quantile_levels A scalar or vector of values in (0, 1) to determine which -#' quantiles to estimate (default is 0.5). +#' quantiles to estimate (default is the set 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95). #' @param method A fitting method used by [quantreg::rq()]. See the #' documentation for a list of options. #' @@ -27,7 +27,9 @@ #' rq_spec <- quantile_reg(quantile_levels = c(.2, .8)) %>% set_engine("rq") #' ff <- rq_spec %>% fit(y ~ ., data = tib) #' predict(ff, new_data = tib) -quantile_reg <- function(mode = "regression", engine = "rq", quantile_levels = 0.5, method = "br") { +quantile_reg <- function(mode = "regression", engine = "rq", + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), + method = "br") { # Check for correct mode if (mode != "regression") { cli_abort("`mode` must be 'regression'") diff --git a/R/make_smooth_quantile_reg.R b/R/make_smooth_quantile_reg.R index 448ee0fa5..f31ade3cb 100644 --- a/R/make_smooth_quantile_reg.R +++ b/R/make_smooth_quantile_reg.R @@ -5,12 +5,7 @@ #' the [tidymodels](https://www.tidymodels.org/) framework. Currently, the #' only supported engine is [smoothqr::smooth_qr()]. #' -#' @param mode A single character string for the type of model. -#' The only possible value for this model is "regression". -#' @param engine Character string naming the fitting function. Currently, only -#' "smooth_qr" is supported. -#' @param quantile_levels A scalar or vector of values in (0, 1) to determine which -#' quantiles to estimate (default is 0.5). +#' @inheritParams quantile_reg #' @param outcome_locations Defaults to the vector `1:ncol(y)` but if the #' responses are observed at a different spacing (or appear in a different #' order), that information should be used here. This @@ -76,7 +71,7 @@ smooth_quantile_reg <- function( mode = "regression", engine = "smoothqr", outcome_locations = NULL, - quantile_levels = 0.5, + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), degree = 3L) { # Check for correct mode if (mode != "regression") cli_abort("`mode` must be 'regression'") diff --git a/tests/testthat/test-layer_residual_quantiles.R b/tests/testthat/test-layer_residual_quantiles.R index 498b80f95..1736ded2d 100644 --- a/tests/testthat/test-layer_residual_quantiles.R +++ b/tests/testthat/test-layer_residual_quantiles.R @@ -65,9 +65,9 @@ test_that("Grouping by keys is supported", { expect_warning(p2 <- forecast(wf2)) pivot1 <- pivot_quantiles_wider(p1, .pred_distn) %>% - mutate(width = `0.9` - `0.1`) + mutate(width = `0.95` - `0.05`) pivot2 <- pivot_quantiles_wider(p2, .pred_distn) %>% - mutate(width = `0.9` - `0.1`) + mutate(width = `0.95` - `0.05`) expect_equal(pivot1$width, rep(pivot1$width[1], nrow(pivot1))) expect_false(all(pivot2$width == pivot2$width[1])) }) diff --git a/vignettes/epipredict.Rmd b/vignettes/epipredict.Rmd index 2cf7037c7..f5d57a071 100644 --- a/vignettes/epipredict.Rmd +++ b/vignettes/epipredict.Rmd @@ -208,7 +208,7 @@ quantiles. ```{r differential-levels} out_q <- arx_forecaster(jhu, "death_rate", c("case_rate", "death_rate"), args_list = arx_args_list( - quantile_levels = c(.01, .025, seq(.05, .95, by = .05), .975, .99) + quantile_levels = c(.01, .025, 1:19 / 20, .975, .99) ) ) ``` @@ -237,7 +237,8 @@ function: ```{r, eval = FALSE} arx_args_list( lags = c(0L, 7L, 14L), ahead = 7L, n_training = Inf, - forecast_date = NULL, target_date = NULL, quantile_levels = c(0.05, 0.95), + forecast_date = NULL, target_date = NULL, + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), symmetrize = TRUE, nonneg = TRUE, quantile_by_key = character(0L), nafill_buffer = Inf ) @@ -407,7 +408,7 @@ intervals at 0. The code to do this (inside the forecaster) is f <- frosting() %>% layer_predict() %>% layer_residual_quantiles( - quantile_levels = c(.01, .025, seq(.05, .95, by = .05), .975, .99), + quantile_levels = c(.01, .025, 1:19 / 20, .975, .99), symmetrize = TRUE ) %>% layer_add_forecast_date() %>% diff --git a/vignettes/panel-data.Rmd b/vignettes/panel-data.Rmd index 1faf5b56f..e99057897 100644 --- a/vignettes/panel-data.Rmd +++ b/vignettes/panel-data.Rmd @@ -364,7 +364,6 @@ f <- frosting() %>% layer_threshold(.pred, lower = 0) %>% # 90% prediction interval layer_residual_quantiles( - quantile_levels = c(0.1, 0.9), symmetrize = FALSE ) %>% layer_population_scaling( diff --git a/vignettes/preprocessing-and-models.Rmd b/vignettes/preprocessing-and-models.Rmd index 8d1d2f19f..6bff45611 100644 --- a/vignettes/preprocessing-and-models.Rmd +++ b/vignettes/preprocessing-and-models.Rmd @@ -381,7 +381,7 @@ f <- frosting() %>% df_pop_col = "pop" ) -wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.05, .5, .95))) %>% +wf <- epi_workflow(r, quantile_reg()) %>% fit(jhu) %>% add_frosting(f) From 82a76743b4abeb5f839da0323e16d09d83d73f1e Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Fri, 7 Feb 2025 14:12:46 -0800 Subject: [PATCH 09/12] redocument --- man/arx_args_list.Rd | 6 +++--- man/arx_class_args_list.Rd | 4 ++-- man/autoplot-epipred.Rd | 4 +--- man/extract_argument.Rd | 2 +- man/flatline_args_list.Rd | 2 +- man/grf_quantiles.Rd | 4 ++-- man/layer_quantile_distn.Rd | 2 +- man/layer_residual_quantiles.Rd | 2 +- man/quantile_reg.Rd | 4 ++-- man/smooth_quantile_reg.Rd | 6 +++--- 10 files changed, 17 insertions(+), 19 deletions(-) diff --git a/man/arx_args_list.Rd b/man/arx_args_list.Rd index 7a9cd592b..5ece7109b 100644 --- a/man/arx_args_list.Rd +++ b/man/arx_args_list.Rd @@ -12,7 +12,7 @@ arx_args_list( target_date = NULL, adjust_latency = c("none", "extend_ahead", "extend_lags", "locf"), warn_latency = TRUE, - quantile_levels = c(0.05, 0.5, 0.95), + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), symmetrize = TRUE, nonneg = TRUE, quantile_by_key = character(0L), @@ -55,8 +55,8 @@ to shift the model to account for this difference. The options are: the last day of data. For example, if the last day of data was 3 days ago, the ahead becomes \code{ahead+3}. \item \code{"extend_lags"}: increase the lags so they're relative to the actual -forecast date. For example, if the lags are \code{c(0,7,14)} and the last day of -data was 3 days ago, the lags become \code{c(3,10,17)}. +forecast date. For example, if the lags are \code{c(0, 7, 14)} and the last day of +data was 3 days ago, the lags become \code{c(3, 10, 17)}. }} \item{warn_latency}{by default, \code{step_adjust_latency} warns the user if the diff --git a/man/arx_class_args_list.Rd b/man/arx_class_args_list.Rd index a229b67c0..dbf275355 100644 --- a/man/arx_class_args_list.Rd +++ b/man/arx_class_args_list.Rd @@ -57,8 +57,8 @@ to shift the model to account for this difference. The options are: the last day of data. For example, if the last day of data was 3 days ago, the ahead becomes \code{ahead+3}. \item \code{"extend_lags"}: increase the lags so they're relative to the actual -forecast date. For example, if the lags are \code{c(0,7,14)} and the last day of -data was 3 days ago, the lags become \code{c(3,10,17)}. +forecast date. For example, if the lags are \code{c(0, 7, 14)} and the last day of +data was 3 days ago, the lags become \code{c(3, 10, 17)}. }} \item{warn_latency}{by default, \code{step_adjust_latency} warns the user if the diff --git a/man/autoplot-epipred.Rd b/man/autoplot-epipred.Rd index c3b3e902b..066f55383 100644 --- a/man/autoplot-epipred.Rd +++ b/man/autoplot-epipred.Rd @@ -81,9 +81,7 @@ r <- epi_recipe(jhu) \%>\% step_epi_naomit() f <- frosting() \%>\% - layer_residual_quantiles( - quantile_levels = c(.025, .1, .25, .75, .9, .975) - ) \%>\% + layer_residual_quantiles() \%>\% layer_threshold(starts_with(".pred")) \%>\% layer_add_target_date() diff --git a/man/extract_argument.Rd b/man/extract_argument.Rd index 69c610c98..a276d59a6 100644 --- a/man/extract_argument.Rd +++ b/man/extract_argument.Rd @@ -24,7 +24,7 @@ Extract an argument made to a frosting layer or recipe step \examples{ f <- frosting() \%>\% layer_predict() \%>\% - layer_residual_quantiles(quantile_levels = c(0.0275, 0.975), symmetrize = FALSE) \%>\% + layer_residual_quantiles(symmetrize = FALSE) \%>\% layer_naomit(.pred) extract_argument(f, "layer_residual_quantiles", "symmetrize") diff --git a/man/flatline_args_list.Rd b/man/flatline_args_list.Rd index 634d30e5d..626bcb6f1 100644 --- a/man/flatline_args_list.Rd +++ b/man/flatline_args_list.Rd @@ -9,7 +9,7 @@ flatline_args_list( n_training = Inf, forecast_date = NULL, target_date = NULL, - quantile_levels = c(0.05, 0.5, 0.95), + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), symmetrize = TRUE, nonneg = TRUE, quantile_by_key = character(0L), diff --git a/man/grf_quantiles.Rd b/man/grf_quantiles.Rd index f6400edcf..ad33ce11a 100644 --- a/man/grf_quantiles.Rd +++ b/man/grf_quantiles.Rd @@ -52,8 +52,8 @@ details, see \href{https://grf-labs.github.io/grf/articles/categorical_inputs.ht #> Model fit template: #> grf::quantile_forest(X = missing_arg(), Y = missing_arg(), mtry = min_cols(~integer(1), #> x), num.trees = integer(1), min.node.size = min_rows(~integer(1), -#> x), quantiles = c(0.1, 0.5, 0.9), num.threads = 1L, seed = stats::runif(1, -#> 0, .Machine$integer.max)) +#> x), quantiles = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), +#> num.threads = 1L, seed = stats::runif(1, 0, .Machine$integer.max)) }\if{html}{\out{}} } diff --git a/man/layer_quantile_distn.Rd b/man/layer_quantile_distn.Rd index a0de7b669..ed4762fa3 100644 --- a/man/layer_quantile_distn.Rd +++ b/man/layer_quantile_distn.Rd @@ -7,7 +7,7 @@ layer_quantile_distn( frosting, ..., - quantile_levels = c(0.25, 0.5, 0.75), + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), truncate = c(-Inf, Inf), name = ".pred_distn", id = rand_id("quantile_distn") diff --git a/man/layer_residual_quantiles.Rd b/man/layer_residual_quantiles.Rd index 96a9da56d..4efea525f 100644 --- a/man/layer_residual_quantiles.Rd +++ b/man/layer_residual_quantiles.Rd @@ -7,7 +7,7 @@ layer_residual_quantiles( frosting, ..., - quantile_levels = c(0.025, 0.1, 0.25, 0.5, 0.75, 0.9, 0.975), + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), symmetrize = TRUE, by_key = character(0L), name = ".pred_distn", diff --git a/man/quantile_reg.Rd b/man/quantile_reg.Rd index 5079c3434..31a5dd123 100644 --- a/man/quantile_reg.Rd +++ b/man/quantile_reg.Rd @@ -7,7 +7,7 @@ quantile_reg( mode = "regression", engine = "rq", - quantile_levels = 0.5, + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), method = "br" ) } @@ -19,7 +19,7 @@ The only possible value for this model is "regression".} "rq" and "grf" are supported.} \item{quantile_levels}{A scalar or vector of values in (0, 1) to determine which -quantiles to estimate (default is 0.5).} +quantiles to estimate (default is the set 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95).} \item{method}{A fitting method used by \code{\link[quantreg:rq]{quantreg::rq()}}. See the documentation for a list of options.} diff --git a/man/smooth_quantile_reg.Rd b/man/smooth_quantile_reg.Rd index c6b17dd86..90b2c104f 100644 --- a/man/smooth_quantile_reg.Rd +++ b/man/smooth_quantile_reg.Rd @@ -8,7 +8,7 @@ smooth_quantile_reg( mode = "regression", engine = "smoothqr", outcome_locations = NULL, - quantile_levels = 0.5, + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), degree = 3L ) } @@ -17,7 +17,7 @@ smooth_quantile_reg( The only possible value for this model is "regression".} \item{engine}{Character string naming the fitting function. Currently, only -"smooth_qr" is supported.} +"rq" and "grf" are supported.} \item{outcome_locations}{Defaults to the vector \code{1:ncol(y)} but if the responses are observed at a different spacing (or appear in a different @@ -25,7 +25,7 @@ order), that information should be used here. This argument will be mapped to the \code{ahead} argument of \code{\link[smoothqr:smooth_qr]{smoothqr::smooth_qr()}}.} \item{quantile_levels}{A scalar or vector of values in (0, 1) to determine which -quantiles to estimate (default is 0.5).} +quantiles to estimate (default is the set 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95).} \item{degree}{the number of polynomials used for response smoothing. Must be no more than the number of responses.} From a37df952f06fa4fb4b85001684b4e8eec591b815 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Fri, 7 Feb 2025 14:31:15 -0800 Subject: [PATCH 10/12] pass all tests --- R/arx_forecaster.R | 3 +- tests/testthat/_snaps/snapshots.md | 315 ++++++++++++-------- tests/testthat/test-arx_args_list.R | 2 +- tests/testthat/test-extract_argument.R | 6 +- tests/testthat/test-grf_quantiles.R | 12 +- tests/testthat/test-layer_threshold_preds.R | 5 +- 6 files changed, 211 insertions(+), 132 deletions(-) diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index 228fa1217..5397ca881 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -200,7 +200,8 @@ arx_fcast_epi_workflow <- function( } else { quantile_levels <- sort(compare_quantile_args( args_list$quantile_levels, - rlang::eval_tidy(trainer$eng_args$quantiles) %||% c(.1, .5, .9), + rlang::eval_tidy(trainer$eng_args$quantiles) %||% + c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), "grf" )) trainer$eng_args$quantiles <- rlang::enquo(quantile_levels) diff --git a/tests/testthat/_snaps/snapshots.md b/tests/testthat/_snaps/snapshots.md index 213f6b7be..17191e041 100644 --- a/tests/testthat/_snaps/snapshots.md +++ b/tests/testthat/_snaps/snapshots.md @@ -3,20 +3,28 @@ structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" ), .pred = c(0.1393442, 0.103199, 0.3121244, 0.4218461, 0.7319844, 0.1975426), .pred_distn = structure(list(structure(list(values = c(0, - 0.1393442, 0.34820911), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + 0.00989957999999999, 0.09353595, 0.1393442, 0.18515245, 0.26878882, + 0.34820911), quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, + 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0.05739075, 0.103199, + 0.14900725, 0.23264362, 0.31206391), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0, 0.103199, 0.31206391), quantile_levels = c(0.05, - 0.5, 0.95)), class = c("dist_quantiles", "dist_default", - "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.10325949, - 0.3121244, 0.52098931), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + values = c(0.10325949, 0.18267978, 0.26631615, 0.3121244, + 0.35793265, 0.44156902, 0.52098931), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.21298119, 0.4218461, 0.63071101), quantile_levels = c(0.05, - 0.5, 0.95)), class = c("dist_quantiles", "dist_default", - "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.52311949, - 0.7319844, 0.94084931), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + values = c(0.21298119, 0.29240148, 0.37603785, 0.4218461, + 0.46765435, 0.55129072, 0.63071101), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0, 0.1975426, 0.40640751), quantile_levels = c(0.05, - 0.5, 0.95)), class = c("dist_quantiles", "dist_default", + values = c(0.52311949, 0.60253978, 0.68617615, 0.7319844, + 0.77779265, 0.86142902, 0.94084931), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0.06809798, 0.15173435, 0.1975426, 0.24335085, + 0.32698722, 0.40640751), quantile_levels = c(0.05, 0.1, 0.25, + 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", "list")), forecast_date = structure(c(18992, 18992, 18992, 18992, 18992, 18992), class = "Date"), target_date = structure(c(18999, @@ -28,23 +36,30 @@ structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" ), .pred = c(0.1393442, 0.103199, 0.3121244, 0.4218461, 0.7319844, 0.1975426), .pred_distn = structure(list(structure(list(values = c(0.084583345, - 0.1393442, 0.194105055), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + 0.1073314, 0.1292864, 0.1393442, 0.149402, 0.171357, 0.194105055 + ), quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.048438145, 0.103199, 0.157959855), quantile_levels = c(0.05, - 0.5, 0.95)), class = c("dist_quantiles", "dist_default", + values = c(0.048438145, 0.0711862, 0.0931412, 0.103199, 0.1132568, + 0.1352118, 0.157959855), quantile_levels = c(0.05, 0.1, 0.25, + 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.257363545, - 0.3121244, 0.366885255), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + 0.2801116, 0.3020666, 0.3121244, 0.3221822, 0.3441372, 0.366885255 + ), quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.367085245, 0.4218461, 0.476606955), quantile_levels = c(0.05, - 0.5, 0.95)), class = c("dist_quantiles", "dist_default", - "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.677223545, - 0.7319844, 0.786745255), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + values = c(0.367085245, 0.3898333, 0.4117883, 0.4218461, + 0.4319039, 0.4538589, 0.476606955), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.142781745, 0.1975426, 0.252303455), quantile_levels = c(0.05, - 0.5, 0.95)), class = c("dist_quantiles", "dist_default", - "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", - "list")), forecast_date = structure(c(18992, 18992, 18992, 18992, - 18992, 18992), class = "Date"), target_date = structure(c(18993, + values = c(0.677223545, 0.6999716, 0.7219266, 0.7319844, + 0.7420422, 0.7639972, 0.786745255), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.142781745, 0.1655298, 0.1874848, 0.1975426, + 0.2076004, 0.2295554, 0.252303455), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", + "vctrs_vctr", "list")), forecast_date = structure(c(18992, 18992, + 18992, 18992, 18992, 18992), class = "Date"), target_date = structure(c(18993, 18993, 18993, 18993, 18993, 18993), class = "Date")), row.names = c(NA, -6L), class = c("tbl_df", "tbl", "data.frame")) @@ -53,20 +68,28 @@ structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" ), .pred = c(0.1393442, 0.103199, 0.3121244, 0.4218461, 0.7319844, 0.1975426), .pred_distn = structure(list(structure(list(values = c(0, - 0.1393442, 0.34820911), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + 0.00989957999999999, 0.09353595, 0.1393442, 0.18515245, 0.26878882, + 0.34820911), quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, + 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0.05739075, 0.103199, + 0.14900725, 0.23264362, 0.31206391), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.10325949, 0.18267978, 0.26631615, 0.3121244, + 0.35793265, 0.44156902, 0.52098931), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0, 0.103199, 0.31206391), quantile_levels = c(0.05, - 0.5, 0.95)), class = c("dist_quantiles", "dist_default", - "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.10325949, - 0.3121244, 0.52098931), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + values = c(0.21298119, 0.29240148, 0.37603785, 0.4218461, + 0.46765435, 0.55129072, 0.63071101), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.21298119, 0.4218461, 0.63071101), quantile_levels = c(0.05, - 0.5, 0.95)), class = c("dist_quantiles", "dist_default", - "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.52311949, - 0.7319844, 0.94084931), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + values = c(0.52311949, 0.60253978, 0.68617615, 0.7319844, + 0.77779265, 0.86142902, 0.94084931), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0, 0.1975426, 0.40640751), quantile_levels = c(0.05, - 0.5, 0.95)), class = c("dist_quantiles", "dist_default", + values = c(0, 0.06809798, 0.15173435, 0.1975426, 0.24335085, + 0.32698722, 0.40640751), quantile_levels = c(0.05, 0.1, 0.25, + 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", "list")), forecast_date = structure(c(18992, 18992, 18992, 18992, 18992, 18992), class = "Date"), target_date = structure(c(18999, @@ -78,20 +101,28 @@ structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" ), .pred = c(0.1393442, 0.103199, 0.3121244, 0.4218461, 0.7319844, 0.1975426), .pred_distn = structure(list(structure(list(values = c(0, - 0.1393442, 0.34820911), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + 0.00989957999999999, 0.09353595, 0.1393442, 0.18515245, 0.26878882, + 0.34820911), quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, + 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0.05739075, 0.103199, + 0.14900725, 0.23264362, 0.31206391), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0, 0.103199, 0.31206391), quantile_levels = c(0.05, - 0.5, 0.95)), class = c("dist_quantiles", "dist_default", - "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.10325949, - 0.3121244, 0.52098931), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + values = c(0.10325949, 0.18267978, 0.26631615, 0.3121244, + 0.35793265, 0.44156902, 0.52098931), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.21298119, 0.4218461, 0.63071101), quantile_levels = c(0.05, - 0.5, 0.95)), class = c("dist_quantiles", "dist_default", - "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.52311949, - 0.7319844, 0.94084931), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + values = c(0.21298119, 0.29240148, 0.37603785, 0.4218461, + 0.46765435, 0.55129072, 0.63071101), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0, 0.1975426, 0.40640751), quantile_levels = c(0.05, - 0.5, 0.95)), class = c("dist_quantiles", "dist_default", + values = c(0.52311949, 0.60253978, 0.68617615, 0.7319844, + 0.77779265, 0.86142902, 0.94084931), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0, 0.06809798, 0.15173435, 0.1975426, 0.24335085, + 0.32698722, 0.40640751), quantile_levels = c(0.05, 0.1, 0.25, + 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", "list")), forecast_date = structure(c(18992, 18992, 18992, 18992, 18992, 18992), class = "Date"), target_date = structure(c(18993, @@ -981,82 +1012,114 @@ # arx_forecaster snapshots structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" - ), .pred = c(0.353013358779434, 0.648525432444876, 0.667670289394327, - 1.1418673907239, 0.830448695683588, 0.329799431948648), .pred_distn = structure(list( - structure(list(values = c(0.171022956902287, 0.353013358779434, - 0.535003760656581), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.466535030567729, 0.648525432444876, 0.830515834322023 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + ), .pred = c(0.353013358779435, 0.648525432444877, 0.667670289394328, + 1.1418673907239, 0.830448695683587, 0.329799431948649), .pred_distn = structure(list( + structure(list(values = c(0.171022956902288, 0.244945899624723, + 0.308032696431071, 0.353013358779435, 0.397994021127798, + 0.461080817934147, 0.535003760656582), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.48567988751718, 0.667670289394327, 0.849660691271474 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + values = c(0.46653503056773, 0.540457973290166, 0.603544770096514, + 0.648525432444877, 0.693506094793241, 0.756592891599589, + 0.830515834322024), quantile_levels = c(0.05, 0.1, 0.25, + 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.485679887517181, + 0.559602830239616, 0.622689627045964, 0.667670289394328, + 0.712650951742692, 0.77573774854904, 0.849660691271475), + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.959876988846753, + 1.03379993156919, 1.09688672837554, 1.1418673907239, 1.18684805307226, + 1.24993484987861, 1.32385779260105), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.959876988846751, 1.1418673907239, 1.32385779260105 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.648458293806441, 0.830448695683588, 1.01243909756074 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.147809030071501, 0.329799431948648, 0.511789833825795 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", - "vctrs_vctr", "list")), forecast_date = structure(c(18992, 18992, - 18992, 18992, 18992, 18992), class = "Date"), target_date = structure(c(18999, + values = c(0.64845829380644, 0.722381236528875, 0.785468033335223, + 0.830448695683587, 0.875429358031951, 0.938516154838299, + 1.01243909756073), quantile_levels = c(0.05, 0.1, 0.25, + 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.147809030071502, + 0.221731972793937, 0.284818769600285, 0.329799431948649, + 0.374780094297013, 0.437866891103361, 0.511789833825796), + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", + "list")), forecast_date = structure(c(18992, 18992, 18992, 18992, + 18992, 18992), class = "Date"), target_date = structure(c(18999, 18999, 18999, 18999, 18999, 18999), class = "Date")), row.names = c(NA, -6L), class = c("tbl_df", "tbl", "data.frame")) --- structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" - ), .pred = c(0.149303403634372, 0.139764664505947, 0.333186321066645, - 0.470345577837144, 0.725986105412009, 0.212686665274006), .pred_distn = structure(list( - structure(list(values = c(0.0961118191398626, 0.149303403634372, - 0.202494988128881), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.0865730800114375, 0.139764664505947, 0.192956249000456 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.279994736572135, 0.333186321066645, 0.386377905561154 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + ), .pred = c(0.149303403634373, 0.139764664505948, 0.333186321066645, + 0.470345577837144, 0.725986105412008, 0.212686665274007), .pred_distn = structure(list( + structure(list(values = c(0.0961118191398634, 0.118312393281548, + 0.13840396557592, 0.149303403634373, 0.160202841692825, 0.180294413987198, + 0.202494988128882), quantile_levels = c(0.05, 0.1, 0.25, + 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.0865730800114383, + 0.108773654153123, 0.128865226447495, 0.139764664505948, + 0.1506641025644, 0.170755674858773, 0.192956249000457), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.417153993342635, 0.470345577837144, 0.523537162331653 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.6727945209175, 0.725986105412009, 0.779177689906518 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.159495080779497, 0.212686665274006, 0.265878249768516 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", - "vctrs_vctr", "list")), forecast_date = structure(c(18992, 18992, - 18992, 18992, 18992, 18992), class = "Date"), target_date = structure(c(18993, + values = c(0.279994736572136, 0.30219531071382, 0.322286883008193, + 0.333186321066645, 0.344085759125097, 0.36417733141947, + 0.386377905561154), quantile_levels = c(0.05, 0.1, 0.25, + 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.417153993342634, + 0.439354567484319, 0.459446139778691, 0.470345577837144, + 0.481245015895596, 0.501336588189969, 0.523537162331653), + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.672794520917498, + 0.694995095059183, 0.715086667353556, 0.725986105412008, + 0.73688554347046, 0.756977115764833, 0.779177689906517), + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.159495080779498, + 0.181695654921182, 0.201787227215555, 0.212686665274007, + 0.223586103332459, 0.243677675626832, 0.265878249768516), + quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", + "list")), forecast_date = structure(c(18992, 18992, 18992, 18992, + 18992, 18992), class = "Date"), target_date = structure(c(18993, 18993, 18993, 18993, 18993, 18993), class = "Date")), row.names = c(NA, -6L), class = c("tbl_df", "tbl", "data.frame")) --- structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" - ), .pred = c(0.303244704017743, 0.531332853311082, 0.588827944685979, - 0.988690249216229, 0.794801997001639, 0.306895457225321), .pred_distn = structure(list( - structure(list(values = c(0.136509784083987, 0.303244704017743, - 0.469979623951498), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.364597933377326, 0.531332853311082, 0.698067773244837 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + ), .pred = c(0.303244704017742, 0.531332853311081, 0.58882794468598, + 0.98869024921623, 0.79480199700164, 0.306895457225321), .pred_distn = structure(list( + structure(list(values = c(0.136509784083987, 0.202348949370703, + 0.263837900408968, 0.303244704017742, 0.342651507626517, + 0.404140458664782, 0.469979623951498), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.422093024752224, 0.588827944685979, 0.755562864619735 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + values = c(0.364597933377326, 0.430437098664042, 0.491926049702307, + 0.531332853311081, 0.570739656919856, 0.632228607958121, + 0.698067773244837), quantile_levels = c(0.05, 0.1, 0.25, + 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.422093024752224, + 0.48793219003894, 0.549421141077205, 0.58882794468598, 0.628234748294754, + 0.689723699333019, 0.755562864619735), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.821955329282474, 0.988690249216229, 1.15542516914998 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + values = c(0.821955329282475, 0.887794494569191, 0.949283445607456, + 0.98869024921623, 1.028097052825, 1.08958600386327, 1.15542516914999 + ), quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.628067077067884, + 0.693906242354601, 0.755395193392866, 0.79480199700164, 0.834208800610414, + 0.895697751648679, 0.961536916935395), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.628067077067883, 0.794801997001639, 0.961536916935394 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.140160537291566, 0.306895457225321, 0.473630377159077 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", + values = c(0.140160537291566, 0.205999702578282, 0.267488653616547, + 0.306895457225321, 0.346302260834096, 0.407791211872361, + 0.473630377159077), quantile_levels = c(0.05, 0.1, 0.25, + 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", "list")), forecast_date = structure(c(18997, 18997, 18997, 18997, 18997, 18997), class = "Date"), target_date = structure(c(18998, 18998, 18998, 18998, 18998, 18998), class = "Date")), row.names = c(NA, @@ -1065,26 +1128,36 @@ --- structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" - ), .pred = c(0.303244704017743, 0.531332853311082, 0.588827944685979, - 0.988690249216229, 0.794801997001639, 0.306895457225321), .pred_distn = structure(list( - structure(list(values = c(0.136509784083987, 0.303244704017743, - 0.469979623951498), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.364597933377326, 0.531332853311082, 0.698067773244837 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.422093024752224, 0.588827944685979, 0.755562864619735 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + ), .pred = c(0.303244704017742, 0.531332853311081, 0.58882794468598, + 0.98869024921623, 0.79480199700164, 0.306895457225321), .pred_distn = structure(list( + structure(list(values = c(0.136509784083987, 0.202348949370703, + 0.263837900408968, 0.303244704017742, 0.342651507626517, + 0.404140458664782, 0.469979623951498), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.821955329282474, 0.988690249216229, 1.15542516914998 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + values = c(0.364597933377326, 0.430437098664042, 0.491926049702307, + 0.531332853311081, 0.570739656919856, 0.632228607958121, + 0.698067773244837), quantile_levels = c(0.05, 0.1, 0.25, + 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0.422093024752224, + 0.48793219003894, 0.549421141077205, 0.58882794468598, 0.628234748294754, + 0.689723699333019, 0.755562864619735), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.628067077067883, 0.794801997001639, 0.961536916935394 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", + values = c(0.821955329282475, 0.887794494569191, 0.949283445607456, + 0.98869024921623, 1.028097052825, 1.08958600386327, 1.15542516914999 + ), quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.628067077067884, + 0.693906242354601, 0.755395193392866, 0.79480199700164, 0.834208800610414, + 0.895697751648679, 0.961536916935395), quantile_levels = c(0.05, + 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.140160537291566, 0.306895457225321, 0.473630377159077 - ), quantile_levels = c(0.05, 0.5, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", + values = c(0.140160537291566, 0.205999702578282, 0.267488653616547, + 0.306895457225321, 0.346302260834096, 0.407791211872361, + 0.473630377159077), quantile_levels = c(0.05, 0.1, 0.25, + 0.5, 0.75, 0.9, 0.95)), class = c("dist_quantiles", "dist_default", + "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", "list")), forecast_date = structure(c(18997, 18997, 18997, 18997, 18997, 18997), class = "Date"), target_date = structure(c(18998, 18998, 18998, 18998, 18998, 18998), class = "Date")), row.names = c(NA, diff --git a/tests/testthat/test-arx_args_list.R b/tests/testthat/test-arx_args_list.R index 8102c31d8..50379357a 100644 --- a/tests/testthat/test-arx_args_list.R +++ b/tests/testthat/test-arx_args_list.R @@ -41,7 +41,7 @@ test_that("arx forecaster disambiguates quantiles", { tlist <- eval(formals(quantile_reg)$quantile_levels) expect_identical( # both default compare_quantile_args(alist, tlist), - c(0.05, 0.5, 0.95) + c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95) ) expect_snapshot( error = TRUE, diff --git a/tests/testthat/test-extract_argument.R b/tests/testthat/test-extract_argument.R index 7ac160e67..e60632289 100644 --- a/tests/testthat/test-extract_argument.R +++ b/tests/testthat/test-extract_argument.R @@ -8,20 +8,20 @@ test_that("layer argument extractor works", { expect_snapshot(error = TRUE, extract_argument(f$layers[[1]], "layer_predict", "bubble")) expect_identical( extract_argument(f$layers[[2]], "layer_residual_quantiles", "quantile_levels"), - c(0.0275, 0.9750) + c(0.0275, 0.5, 0.9750) ) expect_snapshot(error = TRUE, extract_argument(f, "layer_thresh", "quantile_levels")) expect_identical( extract_argument(f, "layer_residual_quantiles", "quantile_levels"), - c(0.0275, 0.9750) + c(0.0275, 0.5, 0.9750) ) wf <- epi_workflow(postprocessor = f) expect_snapshot(error = TRUE, extract_argument(epi_workflow(), "layer_residual_quantiles", "quantile_levels")) expect_identical( extract_argument(wf, "layer_residual_quantiles", "quantile_levels"), - c(0.0275, 0.9750) + c(0.0275, 0.5, 0.9750) ) expect_snapshot(error = TRUE, extract_argument(wf, "layer_predict", c("type", "opts"))) diff --git a/tests/testthat/test-grf_quantiles.R b/tests/testthat/test-grf_quantiles.R index e2cf90cf7..5adbf6518 100644 --- a/tests/testthat/test-grf_quantiles.R +++ b/tests/testthat/test-grf_quantiles.R @@ -9,7 +9,10 @@ test_that("quantile_rand_forest defaults work", { spec <- rand_forest(engine = "grf_quantiles", mode = "regression") expect_silent(out <- fit(spec, formula = y ~ x + z, data = tib)) pars <- parsnip::extract_fit_engine(out) - manual <- quantile_forest(as.matrix(tib[, 2:3]), tib$y, quantiles = c(0.1, 0.5, 0.9)) + manual <- quantile_forest( + as.matrix(tib[, 2:3]), tib$y, + quantiles = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95) + ) expect_identical(pars$quantiles.orig, manual$quantiles.orig) expect_identical(pars$`_num_trees`, manual$`_num_trees`) @@ -43,7 +46,8 @@ test_that("quantile_rand_forest handles alternative quantiles", { test_that("quantile_rand_forest handles allows setting the trees and mtry", { - spec <- rand_forest(mode = "regression", mtry = 2, trees = 100, engine = "grf_quantiles") + spec <- rand_forest(mode = "regression", mtry = 2, trees = 100) %>% + set_engine(engine = "grf_quantiles", quantiles = c(0.1, 0.5, 0.9)) expect_silent(out <- fit(spec, formula = y ~ x + z, data = tib)) pars <- parsnip::extract_fit_engine(out) manual <- quantile_forest(as.matrix(tib[, 2:3]), tib$y, mtry = 2, num.trees = 100) @@ -68,13 +72,13 @@ test_that("quantile_rand_forest operates with arx_forecaster", { spec2 <- parsnip::extract_spec_parsnip(o) expect_identical( rlang::eval_tidy(spec2$eng_args$quantiles), - c(.05, .1, .5, .9, .95) # merged with arx_args default + c(.05, .1, 0.25, .5, 0.75, .9, .95) # merged with arx_args default ) df <- epidatasets::counts_subset %>% filter(time_value >= "2021-10-01") z <- arx_forecaster(df, "cases", "cases", spec2) expect_identical( nested_quantiles(z$predictions$.pred_distn[1])[[1]]$quantile_levels, - c(.05, .1, .5, .9, .95) + c(.05, .1, 0.25, .5, 0.75, .9, .95) ) }) diff --git a/tests/testthat/test-layer_threshold_preds.R b/tests/testthat/test-layer_threshold_preds.R index f3e90c21a..2201d7fc4 100644 --- a/tests/testthat/test-layer_threshold_preds.R +++ b/tests/testthat/test-layer_threshold_preds.R @@ -58,6 +58,7 @@ test_that("thresholds additional columns", { p <- p %>% dplyr::mutate(.quantiles = nested_quantiles(.pred_distn)) %>% tidyr::unnest(.quantiles) - expect_equal(round(p$values, digits = 3), c(0.180, 0.31, 0.180, .18, 0.310, .31)) - expect_equal(p$quantile_levels, rep(c(.1, .9), times = 3)) + expect_equal(round(p$values, digits = 3), + c(0.180, 0.180, 0.31, 0.180, 0.180, .18, 0.310, .31, .31)) + expect_equal(p$quantile_levels, rep(c(.1, 0.5, .9), times = 3)) }) From dd6b564c4abf2ed8431a0e3ce08a61835d8db6f8 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Fri, 7 Feb 2025 16:40:59 -0800 Subject: [PATCH 11/12] fu styler --- tests/testthat/test-layer_threshold_preds.R | 6 ++++-- vignettes/epipredict.Rmd | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/testthat/test-layer_threshold_preds.R b/tests/testthat/test-layer_threshold_preds.R index 2201d7fc4..2ff43a165 100644 --- a/tests/testthat/test-layer_threshold_preds.R +++ b/tests/testthat/test-layer_threshold_preds.R @@ -58,7 +58,9 @@ test_that("thresholds additional columns", { p <- p %>% dplyr::mutate(.quantiles = nested_quantiles(.pred_distn)) %>% tidyr::unnest(.quantiles) - expect_equal(round(p$values, digits = 3), - c(0.180, 0.180, 0.31, 0.180, 0.180, .18, 0.310, .31, .31)) + expect_equal( + round(p$values, digits = 3), + c(0.180, 0.180, 0.31, 0.180, 0.180, .18, 0.310, .31, .31) + ) expect_equal(p$quantile_levels, rep(c(.1, 0.5, .9), times = 3)) }) diff --git a/vignettes/epipredict.Rmd b/vignettes/epipredict.Rmd index f5d57a071..ce0a7e38e 100644 --- a/vignettes/epipredict.Rmd +++ b/vignettes/epipredict.Rmd @@ -237,7 +237,7 @@ function: ```{r, eval = FALSE} arx_args_list( lags = c(0L, 7L, 14L), ahead = 7L, n_training = Inf, - forecast_date = NULL, target_date = NULL, + forecast_date = NULL, target_date = NULL, quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), symmetrize = TRUE, nonneg = TRUE, quantile_by_key = character(0L), nafill_buffer = Inf From 10b5b615655e1d997ddc7ca70c26184e2841a1b6 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Fri, 7 Feb 2025 17:02:02 -0800 Subject: [PATCH 12/12] bump news --- NEWS.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 84aaffadc..633086cb0 100644 --- a/NEWS.md +++ b/NEWS.md @@ -25,7 +25,8 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat - Shifting no columns results in no error for either `step_epi_ahead` and `step_epi_lag` - Quantiles produced by `grf` were sometimes out of order. - dist_quantiles can have all `NA` values without causing unrelated errors -- add `0.5` as a default quantile for canned forecasters to avoid strange thresholding behavior +- adjust default quantiles throughout so that they match. +- force `layer_residual_quantiles()` to always include `0.5`. # epipredict 0.1