Skip to content

Commit 5332aa6

Browse files
Merge pull request #936 from facebookexperimental/bl_branch
fix: mediaVecCollect's organic vars weren't being multiplied by coef values + InputCollect print when no prophet country set
2 parents cf7dae7 + f911903 commit 5332aa6

File tree

6 files changed

+15
-17
lines changed

6 files changed

+15
-17
lines changed

R/DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Package: Robyn
22
Type: Package
33
Title: Semi-Automated Marketing Mix Modeling (MMM) from Meta Marketing Science
4-
Version: 3.10.6.9001
4+
Version: 3.10.6.9002
55
Authors@R: c(
66
person("Gufeng", "Zhou", , "gufeng@meta.com", c("cre","aut")),
77
person("Bernardo", "Lares", , "laresbernardo@gmail.com", c("aut")),

R/R/checks.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ check_paidmedia <- function(dt_input, paid_media_vars, paid_media_signs, paid_me
243243
check_vector(paid_media_vars)
244244
check_vector(paid_media_signs)
245245
check_vector(paid_media_spends)
246-
mediaVarCount <- length(paid_media_vars)
246+
expVarCount <- length(paid_media_vars)
247247
spendVarCount <- length(paid_media_spends)
248248

249249
temp <- paid_media_vars %in% names(dt_input)
@@ -261,7 +261,7 @@ check_paidmedia <- function(dt_input, paid_media_vars, paid_media_signs, paid_me
261261
))
262262
}
263263
if (is.null(paid_media_signs)) {
264-
paid_media_signs <- rep("positive", mediaVarCount)
264+
paid_media_signs <- rep("positive", expVarCount)
265265
}
266266
if (!all(paid_media_signs %in% OPTS_PDN)) {
267267
stop("Allowed values for 'paid_media_signs' are: ", paste(OPTS_PDN, collapse = ", "))
@@ -272,7 +272,7 @@ check_paidmedia <- function(dt_input, paid_media_vars, paid_media_signs, paid_me
272272
if (length(paid_media_signs) != length(paid_media_vars)) {
273273
stop("Input 'paid_media_signs' must have same length as 'paid_media_vars'")
274274
}
275-
if (spendVarCount != mediaVarCount) {
275+
if (spendVarCount != expVarCount) {
276276
stop("Input 'paid_media_spends' must have same length as 'paid_media_vars'")
277277
}
278278
is_num <- unlist(lapply(dt_input[, paid_media_vars], is.numeric))
@@ -295,7 +295,7 @@ check_paidmedia <- function(dt_input, paid_media_vars, paid_media_signs, paid_me
295295
}
296296
return(invisible(list(
297297
paid_media_signs = paid_media_signs,
298-
mediaVarCount = mediaVarCount,
298+
expVarCount = expVarCount,
299299
paid_media_vars = paid_media_vars
300300
)))
301301
}
@@ -706,7 +706,7 @@ check_InputCollect <- function(list) {
706706
names_list <- c(
707707
"dt_input", "paid_media_vars", "paid_media_spends", "context_vars",
708708
"organic_vars", "all_ind_vars", "date_var", "dep_var",
709-
"rollingWindowStartWhich", "rollingWindowEndWhich", "mediaVarCount",
709+
"rollingWindowStartWhich", "rollingWindowEndWhich",
710710
"factor_vars", "prophet_vars", "prophet_signs", "prophet_country",
711711
"intervalType", "dt_holidays"
712712
)

R/R/clusters.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
robyn_clusters <- function(input, dep_var_type,
4242
cluster_by = "hyperparameters",
4343
all_media = NULL,
44-
k = "auto", wss_var = 0.05, limit = 1,
44+
k = "auto", wss_var = 0.06, limit = 1,
4545
weights = rep(1, 3), dim_red = "PCA",
4646
quiet = FALSE, export = FALSE, seed = 123,
4747
...) {
@@ -315,8 +315,7 @@ errors_scores <- function(df, balance = rep(1, 3), ts_validation = TRUE, ...) {
315315
select(any_of(c("solID", all_media)))
316316
}
317317
errors <- distinct(
318-
x, .data$solID, .data$nrmse, .data$nrmse_test,
319-
.data$nrmse_train, .data$decomp.rssd, .data$mape
318+
x, .data$solID, starts_with("nrmse"), .data$decomp.rssd, .data$mape
320319
)
321320
outcome <- left_join(outcome, errors, "solID") %>% ungroup()
322321
} else {

R/R/inputs.R

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,10 @@ robyn_inputs <- function(dt_input = NULL,
231231
context <- check_context(dt_input, context_vars, context_signs)
232232
context_signs <- context$context_signs
233233

234-
## Check paid media variables (set mediaVarCount and maybe transform paid_media_signs)
234+
## Check paid media variables (and maybe transform paid_media_signs)
235235
if (is.null(paid_media_vars)) paid_media_vars <- paid_media_spends
236236
paidmedia <- check_paidmedia(dt_input, paid_media_vars, paid_media_signs, paid_media_spends)
237237
paid_media_signs <- paidmedia$paid_media_signs
238-
mediaVarCount <- paidmedia$mediaVarCount
239238
exposure_vars <- paid_media_vars[!(paid_media_vars == paid_media_spends)]
240239

241240
## Check organic media variables (and maybe transform organic_signs)
@@ -309,7 +308,6 @@ robyn_inputs <- function(dt_input = NULL,
309308
paid_media_signs = paid_media_signs,
310309
paid_media_spends = paid_media_spends,
311310
paid_media_total = paid_media_total,
312-
mediaVarCount = mediaVarCount,
313311
exposure_vars = exposure_vars,
314312
organic_vars = organic_vars,
315313
organic_signs = organic_signs,
@@ -435,7 +433,8 @@ Adstock: {x$adstock}
435433
windows = paste(x$window_start, x$window_end, sep = ":"),
436434
custom_params = if (length(x$custom_params) > 0) paste("\n", flatten_hyps(x$custom_params)) else "None",
437435
prophet = if (length(x$prophet_vars) > 0) {
438-
sprintf("%s on %s", paste(x$prophet_vars, collapse = ", "), x$prophet_country)
436+
sprintf("%s on %s", paste(x$prophet_vars, collapse = ", "),
437+
ifelse(!is.null(x$prophet_country), x$prophet_country, "data"))
439438
} else {
440439
"\033[0;31mDeactivated\033[0m"
441440
},
@@ -622,7 +621,7 @@ robyn_engineering <- function(x, quiet = FALSE, ...) {
622621
mediaCostFactor <- colSums(subset(dt_inputRollWind, select = paid_media_spends), na.rm = TRUE) /
623622
colSums(subset(dt_inputRollWind, select = paid_media_vars), na.rm = TRUE)
624623

625-
for (i in 1:InputCollect$mediaVarCount) {
624+
for (i in seq_along(paid_media_spends)) {
626625
if (exposure_selector[i]) {
627626
# Run models (NLS and/or LM)
628627
dt_spendModInput <- subset(dt_inputRollWind, select = c(paid_media_spends[i], paid_media_vars[i]))

R/R/pareto.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ robyn_pareto <- function(InputCollect, OutputModels,
407407
)
408408
}
409409
dt_transformSaturationDecomp <- dt_transformSaturation
410-
for (i in 1:InputCollect$mediaVarCount) {
410+
for (i in seq_along(InputCollect$all_media)) {
411411
coef <- plotWaterfallLoop$coef[plotWaterfallLoop$rn == InputCollect$all_media[i]]
412412
dt_transformSaturationDecomp[InputCollect$all_media[i]] <- coef *
413413
dt_transformSaturationDecomp[InputCollect$all_media[i]]
@@ -418,7 +418,7 @@ robyn_pareto <- function(InputCollect, OutputModels,
418418

419419
## Reverse MM fitting
420420
# dt_transformSaturationSpendReverse <- copy(dt_transformAdstock[, c("ds", InputCollect$all_media), with = FALSE])
421-
# for (i in 1:InputCollect$mediaVarCount) {
421+
# for (i in seq_along(InputCollect$paid_media_spends)) {
422422
# chn <- InputCollect$paid_media_vars[i]
423423
# if (chn %in% InputCollect$paid_media_vars[InputCollect$exposure_selector]) {
424424
# # Get Michaelis Menten nls fitting param

R/man/robyn_clusters.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)