Skip to content

Commit 96baa3d

Browse files
authored
feat(automl): expose disable_early_stopping option for create_model (#9779)
Disable early stopping is turned off by default. It's defined in proto here: https://github.com/googleapis/google-cloud-python/blob/bfb4da8542981d2eedffe20f64e87ab528a17592/automl/google/cloud/automl_v1beta1/proto/tables.proto#L196 This PR: - exposes disable_early_stopping option for TablesClient.create_model. - changes the default value of create_model's parameter model_metadata from empty dict to None. Default parameter values in Python is non-intuitive: http://effbot.org/zone/default-values.htm
1 parent dfe4667 commit 96baa3d

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

automl/google/cloud/automl_v1beta1/tables/tables_client.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2106,9 +2106,10 @@ def create_model(
21062106
optimization_objective=None,
21072107
project=None,
21082108
region=None,
2109-
model_metadata={},
2109+
model_metadata=None,
21102110
include_column_spec_names=None,
21112111
exclude_column_spec_names=None,
2112+
disable_early_stopping=False,
21122113
**kwargs
21132114
):
21142115
"""Create a model. This will train your model on the given dataset.
@@ -2168,6 +2169,10 @@ def create_model(
21682169
exclude_column_spec_names(Optional[str]):
21692170
The list of the names of the columns you want to exclude and
21702171
not train your model on.
2172+
disable_early_stopping(Optional[bool]):
2173+
True if disable early stopping. By default, the early stopping
2174+
feature is enabled, which means that AutoML Tables might stop
2175+
training before the entire training budget has been used.
21712176
Returns:
21722177
google.api_core.operation.Operation:
21732178
An operation future that can be used to check for
@@ -2180,6 +2185,9 @@ def create_model(
21802185
to a retryable error and retry attempts failed.
21812186
ValueError: If required parameters are missing.
21822187
"""
2188+
if model_metadata is None:
2189+
model_metadata = {}
2190+
21832191
if (
21842192
train_budget_milli_node_hours is None
21852193
or train_budget_milli_node_hours < 1000
@@ -2212,6 +2220,8 @@ def create_model(
22122220
model_metadata["train_budget_milli_node_hours"] = train_budget_milli_node_hours
22132221
if optimization_objective is not None:
22142222
model_metadata["optimization_objective"] = optimization_objective
2223+
if disable_early_stopping:
2224+
model_metadata["disable_early_stopping"] = True
22152225

22162226
dataset_id = dataset_name.rsplit("/", 1)[-1]
22172227
columns = [

0 commit comments

Comments
 (0)