Skip to content

Commit 56bed8c

Browse files
committed
support col_names as tidyselect
1 parent c6ee7f9 commit 56bed8c

File tree

4 files changed

+89
-61
lines changed

4 files changed

+89
-61
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ importFrom(rlang,"!!")
132132
importFrom(rlang,.data)
133133
importFrom(rlang,.env)
134134
importFrom(rlang,arg_match)
135+
importFrom(rlang,as_label)
135136
importFrom(rlang,caller_arg)
136137
importFrom(rlang,caller_env)
137138
importFrom(rlang,enquo)
@@ -146,6 +147,7 @@ importFrom(rlang,is_missing)
146147
importFrom(rlang,is_quosure)
147148
importFrom(rlang,missing_arg)
148149
importFrom(rlang,new_function)
150+
importFrom(rlang,quo_get_expr)
149151
importFrom(rlang,quo_is_missing)
150152
importFrom(rlang,sym)
151153
importFrom(rlang,syms)

R/slide.R

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -393,8 +393,8 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values,
393393
#' @param x The `epi_df` object under consideration, [grouped][dplyr::group_by]
394394
#' or ungrouped. If ungrouped, all data in `x` will be treated as part of a
395395
#' single data group.
396-
#' @param col_names A character vector of the names of one or more columns for
397-
#' which to calculate the rolling mean.
396+
#' @param col_names A single tidyselection or a tidyselection vector of the
397+
#' names of one or more columns for which to calculate the rolling mean.
398398
#' @param ... Additional arguments to pass to `data.table::frollmean`, for
399399
#' example, `na.rm` and `algo`. `data.table::frollmean` is automatically
400400
#' passed the data `x` to operate on, the window size `n`, and the alignment
@@ -473,7 +473,8 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values,
473473
#' leading window was intended, but the `after` argument was forgotten or
474474
#' misspelled.)
475475
#'
476-
#' @importFrom dplyr bind_rows mutate %>% arrange tibble
476+
#' @importFrom dplyr bind_rows mutate %>% arrange tibble select
477+
#' @importFrom rlang enquo quo_get_expr as_label
477478
#' @importFrom purrr map
478479
#' @importFrom data.table frollmean
479480
#' @importFrom lubridate as.period
@@ -484,7 +485,7 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values,
484485
#' # slide a 7-day trailing average formula on cases
485486
#' jhu_csse_daily_subset %>%
486487
#' group_by(geo_value) %>%
487-
#' epi_slide_mean("cases", new_col_names = "cases_7dav", names_sep = NULL, before = 6) %>%
488+
#' epi_slide_mean(cases, new_col_names = "cases_7dav", names_sep = NULL, before = 6) %>%
488489
#' # Remove a nonessential var. to ensure new col is printed
489490
#' dplyr::select(geo_value, time_value, cases, cases_7dav) %>%
490491
#' ungroup()
@@ -493,7 +494,7 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values,
493494
#' # and accuracy, and to allow partially-missing windows.
494495
#' jhu_csse_daily_subset %>%
495496
#' group_by(geo_value) %>%
496-
#' epi_slide_mean("cases",
497+
#' epi_slide_mean(cases,
497498
#' new_col_names = "cases_7dav", names_sep = NULL, before = 6,
498499
#' # `frollmean` options
499500
#' na.rm = TRUE, algo = "exact", hasNA = TRUE
@@ -504,23 +505,23 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values,
504505
#' # slide a 7-day leading average
505506
#' jhu_csse_daily_subset %>%
506507
#' group_by(geo_value) %>%
507-
#' epi_slide_mean("cases", new_col_names = "cases_7dav", names_sep = NULL, after = 6) %>%
508+
#' epi_slide_mean(cases, new_col_names = "cases_7dav", names_sep = NULL, after = 6) %>%
508509
#' # Remove a nonessential var. to ensure new col is printed
509510
#' dplyr::select(geo_value, time_value, cases, cases_7dav) %>%
510511
#' ungroup()
511512
#'
512513
#' # slide a 7-day centre-aligned average
513514
#' jhu_csse_daily_subset %>%
514515
#' group_by(geo_value) %>%
515-
#' epi_slide_mean("cases", new_col_names = "cases_7dav", names_sep = NULL, before = 3, after = 3) %>%
516+
#' epi_slide_mean(cases, new_col_names = "cases_7dav", names_sep = NULL, before = 3, after = 3) %>%
516517
#' # Remove a nonessential var. to ensure new col is printed
517518
#' dplyr::select(geo_value, time_value, cases, cases_7dav) %>%
518519
#' ungroup()
519520
#'
520521
#' # slide a 14-day centre-aligned average
521522
#' jhu_csse_daily_subset %>%
522523
#' group_by(geo_value) %>%
523-
#' epi_slide_mean("cases", new_col_names = "cases_14dav", names_sep = NULL, before = 6, after = 7) %>%
524+
#' epi_slide_mean(cases, new_col_names = "cases_14dav", names_sep = NULL, before = 6, after = 7) %>%
524525
#' # Remove a nonessential var. to ensure new col is printed
525526
#' dplyr::select(geo_value, time_value, cases, cases_14dav) %>%
526527
#' ungroup()
@@ -604,29 +605,46 @@ epi_slide_mean <- function(x, col_names, ..., before, after, ref_time_values,
604605
# `before` and `after` params.
605606
m <- before + after + 1L
606607

608+
col_names_quo <- enquo(col_names)
609+
col_names_chr <- as.character(rlang::quo_get_expr(col_names_quo))
610+
if (startsWith(rlang::as_label(col_names_quo), "c(")) {
611+
# List or vector of col names. We need to drop the first element since it
612+
# will be either "c" (if built as a vector) or "list" (if built as a
613+
# list).
614+
col_names_chr <- col_names_chr[-1]
615+
} else if (startsWith(rlang::as_label(col_names_quo), "list(")) {
616+
cli_abort(
617+
"`col_names` must be a single tidy column name or a vector
618+
(`c()`) of tidy column names",
619+
class = "epiprocess__epi_slide_mean__col_names_in_list",
620+
epiprocess__col_names = col_names_chr
621+
)
622+
}
623+
# If single column name, do nothing.
624+
607625
if (is.null(names_sep)) {
608-
if (length(new_col_names) != length(col_names)) {
626+
if (length(new_col_names) != length(col_names_chr)) {
609627
cli_abort(
610628
c(
611629
"`new_col_names` must be the same length as `col_names` when
612630
`names_sep` is NULL to avoid duplicate output column names."
613631
),
614632
class = "epiprocess__epi_slide_mean__col_names_length_mismatch",
615633
epiprocess__new_col_names = new_col_names,
616-
epiprocess__col_names = col_names
634+
epiprocess__col_names = col_names_chr
617635
)
618636
}
619637
result_col_names <- new_col_names
620638
} else {
621-
if (length(new_col_names) != 1L && length(new_col_names) != length(col_names)) {
639+
if (length(new_col_names) != 1L && length(new_col_names) != length(col_names_chr)) {
622640
cli_abort(
623641
"`new_col_names` must be either length 1 or the same length as `col_names`.",
624642
class = "epiprocess__epi_slide_mean__col_names_length_mismatch_and_not_one",
625643
epiprocess__new_col_names = new_col_names,
626-
epiprocess__col_names = col_names
644+
epiprocess__col_names = col_names_chr
627645
)
628646
}
629-
result_col_names <- paste(new_col_names, col_names, sep = names_sep)
647+
result_col_names <- paste(new_col_names, col_names_chr, sep = names_sep)
630648
}
631649

632650
slide_one_grp <- function(.data_group, .group_key, ...) {
@@ -675,7 +693,7 @@ epi_slide_mean <- function(x, col_names, ..., before, after, ref_time_values,
675693
}
676694

677695
roll_output <- data.table::frollmean(
678-
x = .data_group[, col_names], n = m, align = "right", ...
696+
x = select(.data_group, {{ col_names }}), n = m, align = "right", ...
679697
)
680698

681699
if (after >= 1) {

man/epi_slide_mean.Rd

Lines changed: 8 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)