Get test data rework #463

Draft · wants to merge 3 commits into base: dev

2 changes: 1 addition & 1 deletion NAMESPACE
@@ -167,7 +167,7 @@ export(flatline_forecaster)
 export(flusight_hub_formatter)
 export(forecast)
 export(frosting)
-export(get_test_data)
+export(get_predict_data)
 export(is_epi_recipe)
 export(is_epi_workflow)
 export(is_layer)

60 changes: 31 additions & 29 deletions R/arx_forecaster.R
@@ -47,20 +47,19 @@ arx_forecaster <- function(
   if (!is_regression(trainer)) {
     cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.")
   }
 
   wf <- arx_fcast_epi_workflow(epi_data, outcome, predictors, trainer, args_list)
   wf <- fit(wf, epi_data)
 
   # get the forecast date for the forecast function
   if (args_list$adjust_latency == "none") {
-    forecast_date_default <- max(epi_data$time_value)
+    reference_date_default <- max(epi_data$time_value)
   } else {
-    forecast_date_default <- attributes(epi_data)$metadata$as_of
+    reference_date_default <- attributes(epi_data)$metadata$as_of
   }
-  forecast_date <- args_list$forecast_date %||% forecast_date_default
+  reference_date <- args_list$reference_date %||% reference_date_default
+  predict_interval <- args_list$predict_interval
 
-  preds <- forecast(wf, forecast_date = forecast_date) %>%
+  preds <- forecast(wf, reference_dates = reference_date, predict_interval = predict_interval) %>%
     as_tibble() %>%
     select(-time_value)
 
@@ -126,21 +125,21 @@ arx_fcast_epi_workflow <- function(
   # if they don't and they're not adjusting latency, it defaults to the max time_value
   # if they're adjusting, it defaults to the as_of
   if (args_list$adjust_latency == "none") {
-    forecast_date_default <- max(epi_data$time_value)
-    if (!is.null(args_list$forecast_date) && args_list$forecast_date != forecast_date_default) {
+    reference_date_default <- max(epi_data$time_value)
+    if (!is.null(args_list$reference_date) && args_list$reference_date != reference_date_default) {
       cli_warn(
-        "The specified forecast date {args_list$forecast_date} doesn't match the date from which the forecast is actually occurring {forecast_date_default}.",
+        "The specified forecast date {args_list$reference_date} doesn't match the date from which the forecast is actually occurring {reference_date_default}.",
         class = "epipredict__arx_forecaster__forecast_date_defaulting"
       )
     }
   } else {
-    forecast_date_default <- attributes(epi_data)$metadata$as_of
+    reference_date_default <- attributes(epi_data)$metadata$as_of
   }
-  forecast_date <- args_list$forecast_date %||% forecast_date_default
-  target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
-  if (forecast_date + args_list$ahead != target_date) {
-    cli_abort("`forecast_date` {.val {forecast_date}} + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}.",
-      class = "epipredict__arx_forecaster__inconsistent_target_ahead_forecaste_date"
+  reference_date <- args_list$reference_date %||% reference_date_default
+  target_date <- args_list$target_date %||% (reference_date + args_list$ahead)
+  if (reference_date + args_list$ahead != target_date) {
+    cli_abort("`reference_date` {.val {reference_date}} + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}.",
+      class = "epipredict__arx_forecaster__inconsistent_target_ahead_forecast_date"
     )
   }
 
@@ -153,12 +152,12 @@
   if (!is.null(method_adjust_latency)) {
     if (method_adjust_latency == "extend_ahead") {
       r <- r %>% step_adjust_latency(all_outcomes(),
-        fixed_forecast_date = forecast_date,
+        fixed_reference_date = reference_date,
         method = method_adjust_latency
       )
     } else if (method_adjust_latency == "extend_lags") {
       r <- r %>% step_adjust_latency(all_predictors(),
-        fixed_forecast_date = forecast_date,
+        fixed_reference_date = reference_date,
         method = method_adjust_latency
       )
     }
@@ -218,7 +217,7 @@ arx_fcast_epi_workflow <- function(
       by_key = args_list$quantile_by_key
     )
   }
-  f <- layer_add_forecast_date(f, forecast_date = forecast_date) %>%
+  f <- layer_add_forecast_date(f, forecast_date = reference_date) %>%
     layer_add_target_date(target_date = target_date)
   if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred"))
 
@@ -238,19 +237,19 @@ arx_fcast_epi_workflow <- function(
 #' @param n_training Integer. An upper limit for the number of rows per
 #'   key that are used for training
 #'   (in the time unit of the `epi_df`).
-#' @param forecast_date Date. The date from which the forecast is occurring.
+#' @param reference_date Date. The date from which the forecast is occurring.
 #'   The default `NULL` will determine this automatically from either
 #'   1. the maximum time value for which there's data if there is no latency
 #'      adjustment (the default case), or
 #'   2. the `as_of` date of `epi_data` if `adjust_latency` is
 #'      non-`NULL`.
 #' @param target_date Date. The date that is being forecast. The default `NULL`
-#'   will determine this automatically as `forecast_date + ahead`.
+#'   will determine this automatically as `reference_date + ahead`.
 #' @param adjust_latency Character. One of the `method`s of
 #'   [step_adjust_latency()], or `"none"` (in which case there is no adjustment).
-#'   If the `forecast_date` is after the last day of data, this determines how
+#'   If the `reference_date` is after the last day of data, this determines how
 #'   to shift the model to account for this difference. The options are:
-#'   - `"none"` the default, assumes the `forecast_date` is the last day of data
+#'   - `"none"` the default, assumes the `reference_date` is the last day of data
 #'   - `"extend_ahead"`: increase the `ahead` by the latency so it's relative to
 #'     the last day of data. For example, if the last day of data was 3 days ago,
 #'     the ahead becomes `ahead+3`.
@@ -280,6 +279,7 @@ arx_fcast_epi_workflow <- function(
 #'   column names on which to group the data and check threshold within each
 #'   group. Useful if training per group (for example, per geo_value).
 #' @param ... Space to handle future expansions (unused).
+#' @inheritParams get_predict_data
 #'
 #'
 #' @return A list containing updated parameter choices with class `arx_flist`.
@@ -294,7 +294,7 @@ arx_args_list <- function(
     lags = c(0L, 7L, 14L),
     ahead = 7L,
     n_training = Inf,
-    forecast_date = NULL,
+    reference_date = NULL,
     target_date = NULL,
     adjust_latency = c("none", "extend_ahead", "extend_lags", "locf"),
     warn_latency = TRUE,
@@ -304,6 +304,7 @@
     quantile_by_key = character(0L),
     check_enough_data_n = NULL,
     check_enough_data_epi_keys = NULL,
+    predict_interval = NULL,
     ...) {
   # error checking if lags is a list
   rlang::check_dots_empty()
@@ -313,8 +314,8 @@
   adjust_latency <- rlang::arg_match(adjust_latency)
   arg_is_scalar(ahead, n_training, symmetrize, nonneg, adjust_latency, warn_latency)
   arg_is_chr(quantile_by_key, allow_empty = TRUE)
-  arg_is_scalar(forecast_date, target_date, allow_null = TRUE)
-  arg_is_date(forecast_date, target_date, allow_null = TRUE)
+  arg_is_scalar(reference_date, target_date, allow_null = TRUE)
+  arg_is_date(reference_date, target_date, allow_null = TRUE)
   arg_is_nonneg_int(ahead, lags)
   arg_is_lgl(symmetrize, nonneg)
   arg_is_probabilities(quantile_levels, allow_null = TRUE)
@@ -323,9 +324,9 @@
   arg_is_pos(check_enough_data_n, allow_null = TRUE)
   arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE)
 
-  if (!is.null(forecast_date) && !is.null(target_date)) {
-    if (forecast_date + ahead != target_date) {
-      cli_abort("`forecast_date` {.val {forecast_date}} + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}.",
+  if (!is.null(reference_date) && !is.null(target_date)) {
+    if (reference_date + ahead != target_date) {
+      cli_abort("`reference_date` {.val {reference_date}} + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}.",
        class = "epipredict__arx_args__inconsistent_target_ahead_forecaste_date"
      )
    }
@@ -338,8 +339,9 @@
       ahead,
       n_training,
       quantile_levels,
-      forecast_date,
+      reference_date,
       target_date,
+      predict_interval,
       adjust_latency,
       warn_latency,
       symmetrize,

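For context when reviewing, here is a minimal usage sketch (not part of the diff) of a call to `arx_forecaster()` after the rename. It assumes `reference_date` and `predict_interval` behave as documented above and reuses the `covid_case_death_rates` dataset from the package's existing examples; the specific values are illustrative.

library(epipredict)

# Sketch: forecast death_rate 7 days after the reference date.
# `reference_date` and `predict_interval` are the arguments this PR
# renames/adds in `arx_args_list()`; everything else is the existing API.
out <- arx_forecaster(
  covid_case_death_rates,
  outcome = "death_rate",
  predictors = c("case_rate", "death_rate"),
  args_list = arx_args_list(
    ahead = 7L,
    reference_date = max(covid_case_death_rates$time_value),
    predict_interval = as.difftime(60, units = "days")
  )
)
out$predictions
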
2 changes: 1 addition & 1 deletion R/cdc_baseline_forecaster.R
@@ -78,7 +78,7 @@ cdc_baseline_forecaster <- function(
   # target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
 
 
-  latest <- get_test_data(epi_recipe(epi_data), epi_data)
+  latest <- get_predict_data(epi_recipe(epi_data), epi_data)
 
   f <- frosting() %>%
     layer_predict() %>%

34 changes: 25 additions & 9 deletions R/epi_workflow.R
@@ -132,6 +132,9 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor
 #' @param new_data A data frame containing the new predictors to preprocess
 #'   and predict on
 #'
+#' @param reference_dates A vector matching the type of `time_value` in
+#'   `new_data` giving the dates of the predictions to keep. Defaults to the `reference_date` of the `object`'s recipe.
+#'
 #' @inheritParams parsnip::predict.model_fit
 #'
 #' @return
@@ -155,7 +158,7 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor
 #'
 #' preds <- predict(wf, latest)
 #' preds
-predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), ...) {
+predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), reference_dates = NULL, ...) {
   if (!workflows::is_trained_workflow(object)) {
     cli_abort(c(
       "Can't predict on an untrained epi_workflow.",
@@ -170,7 +173,19 @@ predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), .
 
   components$keys <- grab_forged_keys(components$forged, object, new_data)
   components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...)
-  components$predictions
+  reference_dates <- reference_dates %||% extract_recipe(object)$reference_date
+  predictions <- components$predictions %>% filter(time_value %in% reference_dates)
+  if (nrow(predictions) == 0) {
+    last_pred_date <- components$predictions %>% pull(time_value) %>% max()
+    last_data_date <- new_data %>% pull(time_value) %>% max()
+    cli_warn(
+      "No predictions on the reference date(s) {reference_dates}. The last prediction was on {last_pred_date}; the most recent data is from {last_data_date}.",
+      class = "epipredict__predict_epi_workflow__no_predictions"
+    )
+  }
+  predictions
 }
 
 
@@ -238,14 +253,12 @@ print.epi_workflow <- function(x, ...) {
 #'   example, suppose n_recent = 3, then if the 3 most recent observations in any
 #'   geo_value are all NA’s, we won’t be able to fill anything, and an error
 #'   message will be thrown. (See details.)
-#' @param forecast_date By default, this is set to the maximum time_value in x.
-#'   But if there is data latency such that recent NA's should be filled, this may
-#'   be after the last available time_value.
+#' @inheritParams get_predict_data
 #'
 #' @return A forecast tibble.
 #'
 #' @export
-forecast.epi_workflow <- function(object, ..., n_recent = NULL, forecast_date = NULL) {
+forecast.epi_workflow <- function(object, ..., n_recent = NULL, reference_dates = NULL, predict_interval = NULL) {
   rlang::check_dots_empty()
 
   if (!object$trained) {
@@ -255,6 +268,7 @@
     ))
   }
 
   frosting_fd <- NULL
   if (has_postprocessor(object) && detect_layer(object, "layer_add_forecast_date")) {
     frosting_fd <- extract_argument(object, "layer_add_forecast_date", "forecast_date")
@@ -266,10 +280,12 @@
     }
   }
 
-  test_data <- get_test_data(
+  predict_data <- get_predict_data(
     hardhat::extract_preprocessor(object),
-    object$original_data
+    object$original_data,
+    reference_date = reference_dates,
+    predict_interval = predict_interval
   )
 
-  predict(object, new_data = test_data)
+  predict(object, new_data = predict_data, reference_dates = reference_dates)
 }
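
A sketch (not part of the diff) of the two call paths through the renamed API, assuming a fitted epi_workflow `wf` and the `epi_df` `edf` it was trained on; both names are hypothetical placeholders and the 90-day interval is illustrative.

# High-level path: forecast() trims the workflow's original data with
# get_predict_data() and keeps only predictions at `reference_dates`.
rd <- max(edf$time_value)
preds <- forecast(wf,
  reference_dates = rd,
  predict_interval = as.difftime(90, units = "days")
)

# Equivalent lower-level path: trim explicitly, then predict.
newd <- get_predict_data(
  hardhat::extract_preprocessor(wf), edf,
  predict_interval = as.difftime(90, units = "days"),
  reference_date = rd
)
preds2 <- predict(wf, new_data = newd, reference_dates = rd)
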
64 changes: 64 additions & 0 deletions R/get_predict_data.R
@@ -0,0 +1,64 @@
#' Get data for prediction based on longest lag period
#'
#' Based on the longest lag period in the recipe,
#' `get_predict_data()` creates an [epi_df][epiprocess::as_epi_df]
#' with columns `geo_value`, `time_value`
#' and other variables in the original dataset,
#' which will be used to create features necessary to produce forecasts.
#'
#' The minimum required (recent) data to produce a forecast is equal to
#' the maximum lag requested (on any predictor) plus the longest horizon
#' used if growth rate calculations are requested by the recipe. This is
#' calculated internally.
#'
#' @param recipe A recipe object.
#' @param x An epi_df. The typical usage is to
#'   pass the same data as that used for fitting the recipe.
#' @param predict_interval A time interval or integer. The length of time before
#'   the `reference_date` to consider for the forecast. The default is 1 year,
#'   which you will likely only need to make longer if you are doing long
#'   forecast horizons, or shorter if you are forecasting using an expensive
#'   model.
#' @param reference_date By default, this is set to the maximum `time_value` in
#'   `x`. But if there is data latency such that recent NA's should be filled,
#'   this may be after the last available `time_value`.
#'
#' @return An object of the same type as `x` with columns `geo_value`,
#'   `time_value`, any additional keys, as well other variables in the original
#'   dataset.
#' @examples
#' # create recipe
#' rec <- epi_recipe(covid_case_death_rates) %>%
#'   step_epi_ahead(death_rate, ahead = 7) %>%
#'   step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
#'   step_epi_lag(case_rate, lag = c(0, 7, 14))
#' get_predict_data(recipe = rec, x = covid_case_death_rates)
#' @importFrom rlang %@%
#' @importFrom stats na.omit
#' @export
get_predict_data <- function(recipe,
                             x,
                             predict_interval = NULL,
                             reference_date = NULL) {
  if (!is_epi_df(x)) cli_abort("`x` must be an `epi_df`.")
  check <- hardhat::check_column_names(x, colnames(recipe$template))
  if (!check$ok) {
    cli_abort(c(
      "Some variables used for training are not available in {.arg x}.",
      i = "The following required columns are missing: {check$missing_names}"
    ))
  }
  reference_date <- reference_date %||% recipe$reference_date
  predict_interval <- predict_interval %||% as.difftime(365, units = "days")
  trimmed_x <- x %>%
    filter((reference_date - time_value) < predict_interval)

  if (nrow(trimmed_x) == 0) {
    cli_abort(
      "Predict data is filtered to no rows; check your `predict_interval = {predict_interval}`, `reference_date = {reference_date}`, and the latest data ({max(x$time_value)}).",
      class = "epipredict__get_predict_data__no_predict_data"
    )
  }

  trimmed_x
}
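
To make the trimming semantics concrete, a small usage sketch (not part of the diff); the 30-day interval is illustrative.

rec <- epi_recipe(covid_case_death_rates) %>%
  step_epi_lag(death_rate, lag = c(0, 7, 14))

# Keeps rows with time_value within 30 days of the reference date,
# since the filter is `(reference_date - time_value) < predict_interval`.
trimmed <- get_predict_data(
  recipe = rec,
  x = covid_case_death_rates,
  predict_interval = as.difftime(30, units = "days"),
  reference_date = max(covid_case_death_rates$time_value)
)
range(trimmed$time_value)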