Skip to content

Commit b453a96

Browse files
authored
Merge branch 'main' into sycai_rolling_window
2 parents d9888d2 + b5297f9 commit b453a96

File tree

16 files changed

+689
-10
lines changed

16 files changed

+689
-10
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,5 +60,6 @@ coverage.xml
6060
system_tests/local_test_setup
6161

6262
# Make sure a generated file isn't accidentally committed.
63+
demo.ipynb
6364
pylintrc
6465
pylintrc.test

bigframes/ml/core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ def model(self) -> bigquery.Model:
117117
"""Get the BQML model associated with this wrapper"""
118118
return self._model
119119

120+
def recommend(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
121+
return self._apply_ml_tvf(
122+
input_data,
123+
self._model_manipulation_sql_generator.ml_recommend,
124+
)
125+
120126
def predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
121127
return self._apply_ml_tvf(
122128
input_data,

bigframes/ml/decomposition.py

Lines changed: 166 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from typing import List, Literal, Optional, Union
2121

22+
import bigframes_vendored.sklearn.decomposition._mf
2223
import bigframes_vendored.sklearn.decomposition._pca
2324
from google.cloud import bigquery
2425

@@ -27,7 +28,15 @@
2728
import bigframes.pandas as bpd
2829
import bigframes.session
2930

30-
_BQML_PARAMS_MAPPING = {"svd_solver": "pcaSolver"}
31+
_BQML_PARAMS_MAPPING = {
32+
"svd_solver": "pcaSolver",
33+
"feedback_type": "feedbackType",
34+
"num_factors": "numFactors",
35+
"user_col": "userColumn",
36+
"item_col": "itemColumn",
37+
"_input_label_columns": "inputLabelColumns",
38+
"l2_reg": "l2Regularization",
39+
}
3140

3241

3342
@log_adapter.class_logger
@@ -197,3 +206,159 @@ def score(
197206

198207
# TODO(b/291973741): X param is ignored. Update BQML supports input in ML.EVALUATE.
199208
return self._bqml_model.evaluate()
209+
210+
211+
@log_adapter.class_logger
212+
class MatrixFactorization(
213+
base.UnsupervisedTrainablePredictor,
214+
bigframes_vendored.sklearn.decomposition._mf.MatrixFactorization,
215+
):
216+
__doc__ = bigframes_vendored.sklearn.decomposition._mf.MatrixFactorization.__doc__
217+
218+
def __init__(
219+
self,
220+
*,
221+
feedback_type: Literal["explicit", "implicit"] = "explicit",
222+
num_factors: int,
223+
user_col: str,
224+
item_col: str,
225+
rating_col: str = "rating",
226+
# TODO: Add support for hyperparameter tuning.
227+
l2_reg: float = 1.0,
228+
):
229+
230+
feedback_type = feedback_type.lower() # type: ignore
231+
if feedback_type not in ("explicit", "implicit"):
232+
raise ValueError("Expected feedback_type to be `explicit` or `implicit`.")
233+
234+
self.feedback_type = feedback_type
235+
236+
if not isinstance(num_factors, int):
237+
raise TypeError(
238+
f"Expected num_factors to be an int, but got {type(num_factors)}."
239+
)
240+
241+
if num_factors < 0:
242+
raise ValueError(
243+
f"Expected num_factors to be a positive integer, but got {num_factors}."
244+
)
245+
246+
self.num_factors = num_factors
247+
248+
if not isinstance(user_col, str):
249+
raise TypeError(f"Expected user_col to be a str, but got {type(user_col)}.")
250+
251+
self.user_col = user_col
252+
253+
if not isinstance(item_col, str):
254+
raise TypeError(f"Expected item_col to be STR, but got {type(item_col)}.")
255+
256+
self.item_col = item_col
257+
258+
if not isinstance(rating_col, str):
259+
raise TypeError(
260+
f"Expected rating_col to be a str, but got {type(rating_col)}."
261+
)
262+
263+
self._input_label_columns = [rating_col]
264+
265+
if not isinstance(l2_reg, (float, int)):
266+
raise TypeError(
267+
f"Expected l2_reg to be a float or int, but got {type(l2_reg)}."
268+
)
269+
270+
self.l2_reg = l2_reg
271+
self._bqml_model: Optional[core.BqmlModel] = None
272+
self._bqml_model_factory = globals.bqml_model_factory()
273+
274+
@property
275+
def rating_col(self) -> str:
276+
"""str: The rating column name. Defaults to 'rating'."""
277+
return self._input_label_columns[0]
278+
279+
@classmethod
280+
def _from_bq(
281+
cls, session: bigframes.session.Session, bq_model: bigquery.Model
282+
) -> MatrixFactorization:
283+
assert bq_model.model_type == "MATRIX_FACTORIZATION"
284+
285+
kwargs = utils.retrieve_params_from_bq_model(
286+
cls, bq_model, _BQML_PARAMS_MAPPING
287+
)
288+
289+
model = cls(**kwargs)
290+
model._bqml_model = core.BqmlModel(session, bq_model)
291+
return model
292+
293+
@property
294+
def _bqml_options(self) -> dict:
295+
"""The model options as they will be set for BQML"""
296+
options: dict = {
297+
"model_type": "matrix_factorization",
298+
"feedback_type": self.feedback_type,
299+
"user_col": self.user_col,
300+
"item_col": self.item_col,
301+
"rating_col": self.rating_col,
302+
"l2_reg": self.l2_reg,
303+
}
304+
305+
if self.num_factors is not None:
306+
options["num_factors"] = self.num_factors
307+
308+
return options
309+
310+
def _fit(
311+
self,
312+
X: utils.ArrayType,
313+
y=None,
314+
transforms: Optional[List[str]] = None,
315+
) -> MatrixFactorization:
316+
if y is not None:
317+
raise ValueError(
318+
"Label column not supported for Matrix Factorization model but y was not `None`"
319+
)
320+
321+
(X,) = utils.batch_convert_to_dataframe(X)
322+
323+
self._bqml_model = self._bqml_model_factory.create_model(
324+
X_train=X,
325+
transforms=transforms,
326+
options=self._bqml_options,
327+
)
328+
return self
329+
330+
def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
331+
if not self._bqml_model:
332+
raise RuntimeError("A model must be fitted before recommend")
333+
334+
(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
335+
336+
return self._bqml_model.recommend(X)
337+
338+
def to_gbq(self, model_name: str, replace: bool = False) -> MatrixFactorization:
339+
"""Save the model to BigQuery.
340+
341+
Args:
342+
model_name (str):
343+
The name of the model.
344+
replace (bool, default False):
345+
Determine whether to replace if the model already exists. Default to False.
346+
347+
Returns:
348+
MatrixFactorization: Saved model."""
349+
if not self._bqml_model:
350+
raise RuntimeError("A model must be fitted before it can be saved")
351+
352+
new_model = self._bqml_model.copy(model_name, replace)
353+
return new_model.session.read_gbq_model(model_name)
354+
355+
def score(
356+
self,
357+
X=None,
358+
y=None,
359+
) -> bpd.DataFrame:
360+
if not self._bqml_model:
361+
raise RuntimeError("A model must be fitted before score")
362+
363+
# TODO(b/291973741): X param is ignored. Update BQML supports input in ML.EVALUATE.
364+
return self._bqml_model.evaluate()

bigframes/ml/loader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
"LINEAR_REGRESSION": linear_model.LinearRegression,
4343
"LOGISTIC_REGRESSION": linear_model.LogisticRegression,
4444
"KMEANS": cluster.KMeans,
45+
"MATRIX_FACTORIZATION": decomposition.MatrixFactorization,
4546
"PCA": decomposition.PCA,
4647
"BOOSTED_TREE_REGRESSOR": ensemble.XGBRegressor,
4748
"BOOSTED_TREE_CLASSIFIER": ensemble.XGBClassifier,
@@ -80,6 +81,7 @@
8081
def from_bq(
8182
session: bigframes.session.Session, bq_model: bigquery.Model
8283
) -> Union[
84+
decomposition.MatrixFactorization,
8385
decomposition.PCA,
8486
cluster.KMeans,
8587
linear_model.LinearRegression,

bigframes/ml/sql.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,11 @@ def alter_model(
299299
return "\n".join(parts)
300300

301301
# ML prediction TVFs
302+
def ml_recommend(self, source_sql: str) -> str:
303+
"""Encode ML.RECOMMEND for BQML"""
304+
return f"""SELECT * FROM ML.RECOMMEND(MODEL {self._model_ref_sql()},
305+
({source_sql}))"""
306+
302307
def ml_predict(self, source_sql: str) -> str:
303308
"""Encode ML.PREDICT for BQML"""
304309
return f"""SELECT * FROM ML.PREDICT(MODEL {self._model_ref_sql()},

owlbot.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@
6464
# Fixup files
6565
# ----------------------------------------------------------------------------
6666

67+
# Add scratch space for experimentation to .gitignore.
68+
assert 1 == s.replace(
69+
[".gitignore"],
70+
re.escape("# Make sure a generated file isn't accidentally committed.\n"),
71+
"# Make sure a generated file isn't accidentally committed.\ndemo.ipynb\n",
72+
)
73+
6774
# Encourage sharring all relevant versions in bug reports.
6875
assert 1 == s.replace( # bug_report.md
6976
[".github/ISSUE_TEMPLATE/bug_report.md"],

samples/snippets/create_multiple_timeseries_forecasting_model_test.py

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,26 +73,91 @@ def test_multiple_timeseries_forecasting_model(random_model_id: str) -> None:
7373
from bigframes.ml import forecasting
7474
import bigframes.pandas as bpd
7575

76+
model = forecasting.ARIMAPlus(
77+
# To reduce the query runtime with the compromise of a potential slight
78+
# drop in model quality, you could decrease the value of the
79+
# auto_arima_max_order. This shrinks the search space of hyperparameter
80+
# tuning in the auto.ARIMA algorithm.
81+
auto_arima_max_order=5,
82+
)
83+
7684
df = bpd.read_gbq("bigquery-public-data.new_york.citibike_trips")
7785

86+
# This query creates twelve time series models, one for each of the twelve
87+
# Citi Bike start stations in the input data. If you remove this row
88+
# filter, there would be 600+ time series to forecast.
89+
df = df[df["start_station_name"].str.contains("Central Park")]
90+
7891
features = bpd.DataFrame(
7992
{
80-
"num_trips": df.starttime,
93+
"start_station_name": df["start_station_name"],
94+
"num_trips": df["starttime"],
8195
"date": df["starttime"].dt.date,
8296
}
8397
)
84-
num_trips = features.groupby(["date"], as_index=False).count()
85-
model = forecasting.ARIMAPlus()
98+
num_trips = features.groupby(
99+
["start_station_name", "date"],
100+
as_index=False,
101+
).count()
86102

87103
X = num_trips["date"].to_frame()
88104
y = num_trips["num_trips"].to_frame()
89105

90-
model.fit(X, y)
106+
model.fit(
107+
X,
108+
y,
109+
# The input data that you want to get forecasts for,
110+
# in this case the Citi Bike station, as represented by the
111+
# start_station_name column.
112+
id_col=num_trips["start_station_name"].to_frame(),
113+
)
114+
91115
# The model.fit() call above created a temporary model.
92116
# Use the to_gbq() method to write to a permanent location.
93-
94117
model.to_gbq(
95118
your_model_id, # For example: "bqml_tutorial.nyc_citibike_arima_model",
96119
replace=True,
97120
)
98121
# [END bigquery_dataframes_bqml_arima_multiple_step_3_fit]
122+
123+
# [START bigquery_dataframes_bqml_arima_multiple_step_4_evaluate]
124+
# Evaluate the time series models by using the summary() function. The summary()
125+
# function shows you the evaluation metrics of all the candidate models evaluated
126+
# during the process of automatic hyperparameter tuning.
127+
summary = model.summary()
128+
print(summary.peek())
129+
130+
# Expected output:
131+
# start_station_name non_seasonal_p non_seasonal_d non_seasonal_q has_drift log_likelihood AIC variance ...
132+
# 1 Central Park West & W 72 St 0 1 5 False -1966.449243 3944.898487 1215.689281 ...
133+
# 8 Central Park W & W 96 St 0 0 5 False -274.459923 562.919847 655.776577 ...
134+
# 9 Central Park West & W 102 St 0 0 0 False -226.639918 457.279835 258.83582 ...
135+
# 11 Central Park West & W 76 St 1 1 2 False -1700.456924 3408.913848 383.254161 ...
136+
# 4 Grand Army Plaza & Central Park S 0 1 5 False -5507.553498 11027.106996 624.138741 ...
137+
# [END bigquery_dataframes_bqml_arima_multiple_step_4_evaluate]
138+
139+
# [START bigquery_dataframes_bqml_arima_multiple_step_5_coefficients]
140+
coef = model.coef_
141+
print(coef.peek())
142+
143+
# Expected output:
144+
# start_station_name ar_coefficients ma_coefficients intercept_or_drift
145+
# 5 Central Park West & W 68 St [] [-0.41014089 0.21979212 -0.59854213 -0.251438... 0.0
146+
# 6 Central Park S & 6 Ave [] [-0.71488957 -0.36835772 0.61008532 0.183290... 0.0
147+
# 0 Central Park West & W 85 St [] [-0.39270166 -0.74494638 0.76432596 0.489146... 0.0
148+
# 3 W 82 St & Central Park West [-0.50219511 -0.64820817] [-0.20665325 0.67683137 -0.68108631] 0.0
149+
# 11 W 106 St & Central Park West [-0.70442887 -0.66885553 -0.25030325 -0.34160669] [] 0.0
150+
# [END bigquery_dataframes_bqml_arima_multiple_step_5_coefficients]
151+
152+
# [START bigquery_dataframes_bqml_arima_multiple_step_6_forecast]
153+
prediction = model.predict(horizon=3, confidence_level=0.9)
154+
155+
print(prediction.peek())
156+
# Expected output:
157+
# forecast_timestamp start_station_name forecast_value standard_error confidence_level ...
158+
# 4 2016-10-01 00:00:00+00:00 Central Park S & 6 Ave 302.377201 32.572948 0.9 ...
159+
# 14 2016-10-02 00:00:00+00:00 Central Park North & Adam Clayton Powell Blvd 263.917567 45.284082 0.9 ...
160+
# 1 2016-09-25 00:00:00+00:00 Central Park West & W 85 St 189.574706 39.874856 0.9 ...
161+
# 20 2016-10-02 00:00:00+00:00 Central Park West & W 72 St 175.474862 40.940794 0.9 ...
162+
# 12 2016-10-01 00:00:00+00:00 W 106 St & Central Park West 63.88163 18.088868 0.9 ...
163+
# [END bigquery_dataframes_bqml_arima_multiple_step_6_forecast]

scratch/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Ignore all files in this directory.
2+
*

tests/data/ratings.jsonl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
{"user_id": 1, "item_id": 2, "rating": 4.0}
2+
{"user_id": 1, "item_id": 5, "rating": 3.0}
3+
{"user_id": 2, "item_id": 1, "rating": 5.0}
4+
{"user_id": 2, "item_id": 3, "rating": 2.0}
5+
{"user_id": 3, "item_id": 4, "rating": 4.5}
6+
{"user_id": 3, "item_id": 7, "rating": 3.5}
7+
{"user_id": 4, "item_id": 2, "rating": 1.0}
8+
{"user_id": 4, "item_id": 8, "rating": 5.0}
9+
{"user_id": 5, "item_id": 3, "rating": 4.0}
10+
{"user_id": 5, "item_id": 9, "rating": 2.5}
11+
{"user_id": 6, "item_id": 1, "rating": 3.0}
12+
{"user_id": 6, "item_id": 6, "rating": 4.5}
13+
{"user_id": 7, "item_id": 5, "rating": 5.0}
14+
{"user_id": 7, "item_id": 10, "rating": 1.5}
15+
{"user_id": 8, "item_id": 4, "rating": 2.0}
16+
{"user_id": 8, "item_id": 7, "rating": 4.0}
17+
{"user_id": 9, "item_id": 2, "rating": 3.5}
18+
{"user_id": 9, "item_id": 9, "rating": 5.0}
19+
{"user_id": 10, "item_id": 3, "rating": 4.5}
20+
{"user_id": 10, "item_id": 8, "rating": 2.5}

tests/data/ratings_schema.json

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
[
2+
{
3+
"mode": "NULLABLE",
4+
"name": "user_id",
5+
"type": "STRING"
6+
},
7+
{
8+
"mode": "NULLABLE",
9+
"name": "item_id",
10+
"type": "INT64"
11+
},
12+
{
13+
"mode": "NULLABLE",
14+
"name": "rating",
15+
"type": "FLOAT"
16+
}
17+
]

0 commit comments

Comments
 (0)