Skip to content

Commit f7fd850

Browse files
authored
Merge pull request #826 from facebookexperimental/add_objective_weights
2 parents 534f5f2 + eb2e8f8 commit f7fd850

File tree

16 files changed

+158
-64
lines changed

16 files changed

+158
-64
lines changed

R/DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Package: Robyn
22
Type: Package
33
Title: Semi-Automated Marketing Mix Modeling (MMM) from Meta Marketing Science
4-
Version: 3.10.4.9012
4+
Version: 3.10.5.9000
55
Authors@R: c(
66
person("Gufeng", "Zhou", , "gufeng@meta.com", c("cre","aut")),
77
person("Leonel", "Sentana", , "leonelsentana@meta.com", c("aut")),

R/R/allocator.R

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ robyn_allocator <- function(robyn_object = NULL,
108108
quiet = FALSE,
109109
ui = FALSE,
110110
...) {
111-
112111
### Use previously exported model using json_file
113112
if (!is.null(json_file)) {
114113
if (is.null(InputCollect)) {
@@ -141,12 +140,12 @@ robyn_allocator <- function(robyn_object = NULL,
141140
# OutputCollect <- imported$OutputCollect
142141
# select_model <- imported$select_model
143142
# } else {
144-
if (is.null(select_model) && length(OutputCollect$allSolutions == 1)) {
145-
select_model <- OutputCollect$allSolutions
146-
}
147-
if (any(is.null(InputCollect), is.null(OutputCollect), is.null(select_model))) {
148-
stop("When 'robyn_object' is not provided, then InputCollect, OutputCollect, select_model must be provided")
149-
}
143+
if (is.null(select_model) && length(OutputCollect$allSolutions == 1)) {
144+
select_model <- OutputCollect$allSolutions
145+
}
146+
if (any(is.null(InputCollect), is.null(OutputCollect), is.null(select_model))) {
147+
stop("When 'robyn_object' is not provided, then InputCollect, OutputCollect, select_model must be provided")
148+
}
150149
# }
151150

152151
if (length(InputCollect$paid_media_spends) <= 1) {
@@ -238,9 +237,12 @@ robyn_allocator <- function(robyn_object = NULL,
238237
simulation_period <- initial_mean_period <- unlist(summarise_all(select(histFiltered, any_of(mediaSpendSorted)), length))
239238
nDates <- lapply(mediaSpendSorted, function(x) histFiltered$ds)
240239
names(nDates) <- mediaSpendSorted
241-
if (!quiet) message(sprintf(
242-
"Date Window: %s:%s (%s %ss)",
243-
date_min, date_max, unique(initial_mean_period), InputCollect$intervalType))
240+
if (!quiet) {
241+
message(sprintf(
242+
"Date Window: %s:%s (%s %ss)",
243+
date_min, date_max, unique(initial_mean_period), InputCollect$intervalType
244+
))
245+
}
244246
zero_spend_channel <- names(histSpendWindow[histSpendWindow == 0])
245247

246248
initSpendUnitTotal <- sum(initSpendUnit)
@@ -253,7 +255,7 @@ robyn_allocator <- function(robyn_object = NULL,
253255
if (usecase == "all_historical_vec") {
254256
ndates_loc <- which(InputCollect$dt_mod$ds %in% histFiltered$ds)
255257
} else {
256-
ndates_loc <- 1:length(histFiltered$ds)
258+
ndates_loc <- seq_along(histFiltered$ds)
257259
}
258260
usecase <- paste(usecase, ifelse(!is.null(total_budget), "+ defined_budget", "+ historical_budget"))
259261

@@ -359,14 +361,19 @@ robyn_allocator <- function(robyn_object = NULL,
359361
skip_these <- (channel_constr_low == 0 & channel_constr_up == 0)
360362
zero_constraint_channel <- mediaSpendSorted[skip_these]
361363
if (any(skip_these) && !quiet) {
362-
message("Excluded variables (constrained to 0): ",
363-
paste(zero_constraint_channel, collapse = ", "))
364+
message(
365+
"Excluded variables (constrained to 0): ",
366+
paste(zero_constraint_channel, collapse = ", ")
367+
)
364368
}
365369
if (!all(coefSelectorSorted)) {
366370
zero_coef_channel <- setdiff(names(coefSelectorSorted), mediaSpendSorted[coefSelectorSorted])
367-
if (!quiet) message(
368-
"Excluded variables (coefficients are 0): ",
369-
paste(zero_coef_channel, collapse = ", "))
371+
if (!quiet) {
372+
message(
373+
"Excluded variables (coefficients are 0): ",
374+
paste(zero_coef_channel, collapse = ", ")
375+
)
376+
}
370377
} else {
371378
zero_coef_channel <- as.character()
372379
}
@@ -754,7 +761,9 @@ robyn_allocator <- function(robyn_object = NULL,
754761
select_model, scenario, eval_list,
755762
export, plot_folder, quiet
756763
)
757-
} else plots <- NULL
764+
} else {
765+
plots <- NULL
766+
}
758767

759768
output <- list(
760769
dt_optimOut = dt_optimOut,

R/R/checks.R

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,14 @@ check_novar <- function(dt_input, InputCollect = NULL) {
3838
"There are %s column(s) with no-variance: %s. \nPlease, remove the variable(s) to proceed...",
3939
length(novar), v2t(novar)
4040
)
41-
if (!is.null(InputCollect)) msg <- sprintf(
42-
"%s\n>>> Note: there's no variance on these variables because of the modeling window filter (%s:%s)",
43-
msg,
44-
InputCollect$window_start,
45-
InputCollect$window_end
46-
)
41+
if (!is.null(InputCollect)) {
42+
msg <- sprintf(
43+
"%s\n>>> Note: there's no variance on these variables because of the modeling window filter (%s:%s)",
44+
msg,
45+
InputCollect$window_start,
46+
InputCollect$window_end
47+
)
48+
}
4749
stop(msg)
4850
}
4951
}
@@ -164,9 +166,12 @@ check_prophet <- function(dt_holidays, prophet_country, prophet_vars, prophet_si
164166
prophet_vars <- tolower(prophet_vars)
165167
opts <- c("trend", "season", "monthly", "weekday", "holiday")
166168
if (!"holiday" %in% prophet_vars) {
167-
if (!is.null(prophet_country)) warning(paste(
168-
"Input 'prophet_country' is defined as", prophet_country,
169-
"but 'holiday' is not setup within 'prophet_vars' parameter"))
169+
if (!is.null(prophet_country)) {
170+
warning(paste(
171+
"Input 'prophet_country' is defined as", prophet_country,
172+
"but 'holiday' is not setup within 'prophet_vars' parameter"
173+
))
174+
}
170175
prophet_country <- NULL
171176
}
172177
if (!all(prophet_vars %in% opts)) {
@@ -177,7 +182,7 @@ check_prophet <- function(dt_holidays, prophet_country, prophet_vars, prophet_si
177182
}
178183
if ("holiday" %in% prophet_vars && (
179184
is.null(prophet_country) || length(prophet_country) > 1 |
180-
isTRUE(!prophet_country %in% unique(dt_holidays$country)))) {
185+
isTRUE(!prophet_country %in% unique(dt_holidays$country)))) {
181186
stop(paste(
182187
"You must provide 1 country code in 'prophet_country' input.",
183188
length(unique(dt_holidays$country)), "countries are included:",
@@ -649,6 +654,26 @@ check_calibration <- function(dt_input, date_var, calibration_input, dayInterval
649654
return(calibration_input)
650655
}
651656

657+
check_obj_weight <- function(calibration_input, objective_weights, refresh) {
658+
obj_len <- ifelse(is.null(calibration_input), 2, 3)
659+
if (!is.null(objective_weights)) {
660+
if ((length(objective_weights) != obj_len)) {
661+
stop(paste0("objective_weights must have length of ", obj_len))
662+
}
663+
if (any(objective_weights < 0) | any(objective_weights > 10)) {
664+
stop("objective_weights must be >= 0 & <= 10")
665+
}
666+
}
667+
if (is.null(objective_weights) & refresh) {
668+
if (obj_len == 2) {
669+
objective_weights <- c(1, 10)
670+
} else {
671+
objective_weights <- c(1, 10, 10)
672+
}
673+
}
674+
return(objective_weights)
675+
}
676+
652677
check_iteration <- function(calibration_input, iterations, trials, hyps_fixed, refresh) {
653678
if (!refresh) {
654679
if (!hyps_fixed) {

R/R/clusters.R

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,10 @@ robyn_clusters <- function(input, dep_var_type,
7777
{
7878
suppressMessages(
7979
clusterKmeans(df,
80-
k = NULL, limit = limit_clusters, ignore = ignore,
81-
dim_red = dim_red, quiet = TRUE, seed = seed
82-
))
80+
k = NULL, limit = limit_clusters, ignore = ignore,
81+
dim_red = dim_red, quiet = TRUE, seed = seed
82+
)
83+
)
8384
},
8485
error = function(err) {
8586
message(paste("Couldn't automatically create clusters:", err))
@@ -109,8 +110,10 @@ robyn_clusters <- function(input, dep_var_type,
109110
stopifnot(k %in% min_clusters:30)
110111
suppressMessages(
111112
cls <- clusterKmeans(
112-
df, k = k, limit = limit_clusters, ignore = ignore,
113-
dim_red = dim_red, quiet = TRUE, seed = seed)
113+
df,
114+
k = k, limit = limit_clusters, ignore = ignore,
115+
dim_red = dim_red, quiet = TRUE, seed = seed
116+
)
114117
)
115118

116119
# Select top models by minimum (weighted) distance to zero
@@ -181,8 +184,9 @@ confidence_calcs <- function(
181184
if (length(unique(df_outcome$solID)) < 3) {
182185
warning(paste("Cluster", j, "does not contain enough models to calculate CI"))
183186
} else {
184-
if (cluster_by == "hyperparameters")
187+
if (cluster_by == "hyperparameters") {
185188
all_paid <- unique(gsub(paste(paste0("_", HYPS_NAMES), collapse = "|"), "", all_paid))
189+
}
186190
for (i in all_paid) {
187191
# Bootstrap CI
188192
if (dep_var_type == "conversion") {
@@ -317,7 +321,8 @@ errors_scores <- function(df, balance = rep(1, 3), ts_validation = TRUE, ...) {
317321
if (cluster_by == "hyperparameters") {
318322
outcome <- select(
319323
x, .data$solID, contains(HYPS_NAMES),
320-
contains(c("nrmse", "decomp.rssd", "mape"))) %>%
324+
contains(c("nrmse", "decomp.rssd", "mape"))
325+
) %>%
321326
removenacols(all = FALSE)
322327
}
323328
}

R/R/convergence.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ robyn_converge <- function(OutputModels,
148148
moo_cloud_plot <- df %>%
149149
mutate(nrmse = lares::winsorize(.data$nrmse, nrmse_win)) %>%
150150
ggplot(aes(
151-
x = .data$nrmse, y = .data$decomp.rssd, colour = .data$ElapsedAccum
152-
)) +
151+
x = .data$nrmse, y = .data$decomp.rssd, colour = .data$ElapsedAccum
152+
)) +
153153
scale_colour_gradient(low = "skyblue", high = "navyblue") +
154154
labs(
155155
title = ifelse(!calibrated, "Multi-objective evolutionary performance",

R/R/inputs.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,10 @@ robyn_inputs <- function(dt_input = NULL,
361361
# Check for no-variance columns (after filtering modeling window)
362362
dt_mod_model_window <- InputCollect$dt_mod %>%
363363
select(-any_of(InputCollect$unused_vars)) %>%
364-
filter(.data$ds >= InputCollect$window_start,
365-
.data$ds <= InputCollect$window_end)
364+
filter(
365+
.data$ds >= InputCollect$window_start,
366+
.data$ds <= InputCollect$window_end
367+
)
366368
check_novar(dt_mod_model_window, InputCollect)
367369
}
368370

R/R/json.R

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ robyn_write <- function(InputCollect,
6464
outputs_time <- sprintf("%s min", attr(OutputCollect, "runTime"))
6565
total_time <- sprintf(
6666
"%s min",
67-
attr(OutputCollect, "runTime") + attr(OutputCollect$OutputModels, "runTime"))
67+
attr(OutputCollect, "runTime") + attr(OutputCollect$OutputModels, "runTime")
68+
)
6869
if (!is.null(OutputCollect)) {
6970
outputs <- list()
7071
outputs$select_model <- select_model
@@ -160,12 +161,12 @@ print.robyn_write <- function(x, ...) {
160161
print(glued("\n\nSummary Values on Selected Model:"))
161162

162163
print(x$ExportedModel$summary %>%
163-
select(-contains("boot"), -contains("ci_")) %>%
164-
dplyr::rename_at("performance", list(~ ifelse(x$InputCollect$dep_var_type == "revenue", "ROI", "CPA"))) %>%
165-
mutate(decompPer = formatNum(100 * .data$decompPer, pos = "%")) %>%
166-
dplyr::mutate_if(is.numeric, function(x) ifelse(!is.infinite(x), x, 0)) %>%
167-
dplyr::mutate_if(is.numeric, function(x) formatNum(x, 4, abbr = TRUE)) %>%
168-
replace(., . == "NA", "-") %>% as.data.frame())
164+
select(-contains("boot"), -contains("ci_")) %>%
165+
dplyr::rename_at("performance", list(~ ifelse(x$InputCollect$dep_var_type == "revenue", "ROI", "CPA"))) %>%
166+
mutate(decompPer = formatNum(100 * .data$decompPer, pos = "%")) %>%
167+
dplyr::mutate_if(is.numeric, function(x) ifelse(!is.infinite(x), x, 0)) %>%
168+
dplyr::mutate_if(is.numeric, function(x) formatNum(x, 4, abbr = TRUE)) %>%
169+
replace(., . == "NA", "-") %>% as.data.frame())
169170

170171
print(glued(
171172
"\n\nHyper-parameters:\n Adstock: {x$InputCollect$adstock}"
@@ -178,8 +179,8 @@ print.robyn_write <- function(x, ...) {
178179
select(-contains("lambda"), -any_of(HYPS_OTHERS)) %>%
179180
tidyr::gather() %>%
180181
tidyr::separate(.data$key,
181-
into = c("channel", "none"),
182-
sep = regex, remove = FALSE
182+
into = c("channel", "none"),
183+
sep = regex, remove = FALSE
183184
) %>%
184185
mutate(hyperparameter = gsub("^.*_", "", .data$key)) %>%
185186
select(.data$channel, .data$hyperparameter, .data$value) %>%

R/R/model.R

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@
4949
#' In other words, given the same DECOMP.RSSD score, a model with 50\% 0-coef
5050
#' variables will get penalized by DECOMP.RSSD * 1.5 (larger error), while
5151
#' another model with no 0-coef variables gets un-penalized with DECOMP.RSSD * 1.
52+
#' @param objective_weights Numeric. Default to NULL that gives equal weights
53+
#' to all objective functions. Set c(2, 1) to give double weight to the 1st.
54+
#' This is an experimental feature. There's no research on optimal weight
55+
#' setting. Subjective weights might strongly bias modelling result.
5256
#' @param seed Integer. For reproducible results when running nevergrad.
5357
#' @param outputs Boolean. Process results with \code{robyn_outputs()}?
5458
#' @param lambda_control Deprecated in v3.6.0.
@@ -81,6 +85,7 @@ robyn_run <- function(InputCollect = NULL,
8185
trials = 5,
8286
iterations = 2000,
8387
rssd_zero_penalty = TRUE,
88+
objective_weights = NULL,
8489
nevergrad_algo = "TwoPointsDE",
8590
intercept = TRUE,
8691
intercept_sign = "non_negative",
@@ -135,6 +140,7 @@ robyn_run <- function(InputCollect = NULL,
135140
check_run_inputs(cores, iterations, trials, intercept_sign, nevergrad_algo)
136141
check_iteration(InputCollect$calibration_input, iterations, trials, hyps_fixed, refresh)
137142
init_msgs_run(InputCollect, refresh, lambda_control = NULL, quiet)
143+
objective_weights <- check_obj_weight(InputCollect$calibration_input, objective_weights, refresh)
138144

139145
#####################################
140146
#### Prepare hyper-parameters
@@ -160,6 +166,7 @@ robyn_run <- function(InputCollect = NULL,
160166
ts_validation = ts_validation,
161167
add_penalty_factor = add_penalty_factor,
162168
rssd_zero_penalty = rssd_zero_penalty,
169+
objective_weights = objective_weights,
163170
refresh, seed, quiet
164171
)
165172

@@ -290,6 +297,7 @@ robyn_train <- function(InputCollect, hyper_collect,
290297
dt_hyper_fixed = NULL,
291298
ts_validation = TRUE,
292299
add_penalty_factor = FALSE,
300+
objective_weights = NULL,
293301
rssd_zero_penalty = TRUE,
294302
refresh = FALSE, seed = 123,
295303
quiet = FALSE) {
@@ -309,6 +317,7 @@ robyn_train <- function(InputCollect, hyper_collect,
309317
ts_validation = ts_validation,
310318
add_penalty_factor = add_penalty_factor,
311319
rssd_zero_penalty = rssd_zero_penalty,
320+
objective_weights = objective_weights,
312321
seed = seed,
313322
quiet = quiet
314323
)
@@ -345,6 +354,7 @@ robyn_train <- function(InputCollect, hyper_collect,
345354
ts_validation = ts_validation,
346355
add_penalty_factor = add_penalty_factor,
347356
rssd_zero_penalty = rssd_zero_penalty,
357+
objective_weights = objective_weights,
348358
refresh = refresh,
349359
trial = ngt,
350360
seed = seed + ngt,
@@ -401,6 +411,7 @@ robyn_mmm <- function(InputCollect,
401411
intercept_sign,
402412
ts_validation = TRUE,
403413
add_penalty_factor = FALSE,
414+
objective_weights = NULL,
404415
dt_hyper_fixed = NULL,
405416
# lambda_fixed = NULL,
406417
rssd_zero_penalty = TRUE,
@@ -529,11 +540,24 @@ robyn_mmm <- function(InputCollect,
529540
my_tuple <- tuple(hyper_count)
530541
instrumentation <- ng$p$Array(shape = my_tuple, lower = 0, upper = 1)
531542
optimizer <- ng$optimizers$registry[optimizer_name](instrumentation, budget = iterTotal, num_workers = cores)
543+
532544
# Set multi-objective dimensions for objective functions (errors)
533545
if (is.null(calibration_input)) {
534546
optimizer$tell(ng$p$MultiobjectiveReference(), tuple(1, 1))
547+
if (is.null(objective_weights)) {
548+
objective_weights <- tuple(1, 1)
549+
} else {
550+
objective_weights <- tuple(objective_weights[1], objective_weights[2])
551+
}
552+
optimizer$set_objective_weights(objective_weights)
535553
} else {
536554
optimizer$tell(ng$p$MultiobjectiveReference(), tuple(1, 1, 1))
555+
if (is.null(objective_weights)) {
556+
objective_weights <- tuple(1, 1, 1)
557+
} else {
558+
objective_weights <- tuple(objective_weights[1], objective_weights[2], objective_weights[3])
559+
}
560+
optimizer$set_objective_weights(objective_weights)
537561
}
538562
}
539563

R/R/plots.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -589,10 +589,10 @@ robyn_onepagers <- function(
589589
rver <- utils::sessionInfo()$R.version
590590
onepagerTitle <- sprintf("One-pager for Model ID: %s", sid)
591591
onepagerCaption <- sprintf("Robyn v%s [R-%s.%s]", ver, rver$major, rver$minor)
592-
get_height <- length(unique(plotMediaShareLoopLine$rn)) / 5
592+
get_height <- length(unique(plotMediaShareLoopLine$rn)) / 5
593593
pg <- (p2 + p5) / (p1 + p8) / (p3 + p7) / (p4 + p6) +
594594
patchwork::plot_layout(heights = c(get_height, get_height, get_height, 1), guides = "collect") +
595-
#pg <- wrap_plots(p2, p5, p1, p8, p3, p7, p4, p6, ncol = 2) +
595+
# pg <- wrap_plots(p2, p5, p1, p8, p3, p7, p4, p6, ncol = 2) +
596596
plot_annotation(
597597
title = onepagerTitle, subtitle = errors,
598598
theme = theme_lares(background = "white", legend = "none"),
@@ -1339,7 +1339,8 @@ refresh_plots_json <- function(OutputCollectRF, json_file, export = TRUE, ...) {
13391339
ggsave(
13401340
filename = paste0(
13411341
chainData[[length(chainData)]]$ExportedModel$plot_folder,
1342-
"report_decomposition.png"),
1342+
"report_decomposition.png"
1343+
),
13431344
plot = pBarRF,
13441345
dpi = 900, width = 12, height = 8, limitsize = FALSE
13451346
)

0 commit comments

Comments
 (0)