Skip to content

Commit aa86fbc

Browse files
sirtorrybusunkim96
authored and committed
feat(tables): update samples to show explainability [(#2523)](GoogleCloudPlatform/python-docs-samples#2523)
* show xai
* local feature importance
* use updated client
* use fixed library
* use new model
1 parent d0a2d74 commit aa86fbc

7 files changed

+267
-171
lines changed

samples/tables/automl_tables_dataset.py

Lines changed: 129 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -79,23 +79,38 @@ def list_datasets(project_id, compute_region, filter_=None):
7979
print("Dataset id: {}".format(dataset.name.split("/")[-1]))
8080
print("Dataset display name: {}".format(dataset.display_name))
8181
metadata = dataset.tables_dataset_metadata
82-
print("Dataset primary table spec id: {}".format(
83-
metadata.primary_table_spec_id))
84-
print("Dataset target column spec id: {}".format(
85-
metadata.target_column_spec_id))
86-
print("Dataset target column spec id: {}".format(
87-
metadata.target_column_spec_id))
88-
print("Dataset weight column spec id: {}".format(
89-
metadata.weight_column_spec_id))
90-
print("Dataset ml use column spec id: {}".format(
91-
metadata.ml_use_column_spec_id))
82+
print(
83+
"Dataset primary table spec id: {}".format(
84+
metadata.primary_table_spec_id
85+
)
86+
)
87+
print(
88+
"Dataset target column spec id: {}".format(
89+
metadata.target_column_spec_id
90+
)
91+
)
92+
print(
93+
"Dataset target column spec id: {}".format(
94+
metadata.target_column_spec_id
95+
)
96+
)
97+
print(
98+
"Dataset weight column spec id: {}".format(
99+
metadata.weight_column_spec_id
100+
)
101+
)
102+
print(
103+
"Dataset ml use column spec id: {}".format(
104+
metadata.ml_use_column_spec_id
105+
)
106+
)
92107
print("Dataset example count: {}".format(dataset.example_count))
93108
print("Dataset create time:")
94109
print("\tseconds: {}".format(dataset.create_time.seconds))
95110
print("\tnanos: {}".format(dataset.create_time.nanos))
96111
print("\n")
97112

98-
# [END automl_tables_list_datasets]
113+
# [END automl_tables_list_datasets]
99114
result.append(dataset)
100115

101116
return result
@@ -119,28 +134,31 @@ def list_table_specs(
119134

120135
# List all the table specs in the dataset by applying filter.
121136
response = client.list_table_specs(
122-
dataset_display_name=dataset_display_name, filter_=filter_)
137+
dataset_display_name=dataset_display_name, filter_=filter_
138+
)
123139

124140
print("List of table specs:")
125141
for table_spec in response:
126142
# Display the table_spec information.
127143
print("Table spec name: {}".format(table_spec.name))
128144
print("Table spec id: {}".format(table_spec.name.split("/")[-1]))
129-
print("Table spec time column spec id: {}".format(
130-
table_spec.time_column_spec_id))
145+
print(
146+
"Table spec time column spec id: {}".format(
147+
table_spec.time_column_spec_id
148+
)
149+
)
131150
print("Table spec row count: {}".format(table_spec.row_count))
132151
print("Table spec column count: {}".format(table_spec.column_count))
133152

134-
# [END automl_tables_list_specs]
153+
# [END automl_tables_list_specs]
135154
result.append(table_spec)
136155

137156
return result
138157

139158

140-
def list_column_specs(project_id,
141-
compute_region,
142-
dataset_display_name,
143-
filter_=None):
159+
def list_column_specs(
160+
project_id, compute_region, dataset_display_name, filter_=None
161+
):
144162
"""List all column specs."""
145163
result = []
146164
# [START automl_tables_list_column_specs]
@@ -156,7 +174,8 @@ def list_column_specs(project_id,
156174

157175
# List all the table specs in the dataset by applying filter.
158176
response = client.list_column_specs(
159-
dataset_display_name=dataset_display_name, filter_=filter_)
177+
dataset_display_name=dataset_display_name, filter_=filter_
178+
)
160179

161180
print("List of column specs:")
162181
for column_spec in response:
@@ -166,7 +185,7 @@ def list_column_specs(project_id,
166185
print("Column spec display name: {}".format(column_spec.display_name))
167186
print("Column spec data type: {}".format(column_spec.data_type))
168187

169-
# [END automl_tables_list_column_specs]
188+
# [END automl_tables_list_column_specs]
170189
result.append(column_spec)
171190

172191
return result
@@ -227,19 +246,20 @@ def get_table_spec(project_id, compute_region, dataset_id, table_spec_id):
227246
# Display the table spec information.
228247
print("Table spec name: {}".format(table_spec.name))
229248
print("Table spec id: {}".format(table_spec.name.split("/")[-1]))
230-
print("Table spec time column spec id: {}".format(
231-
table_spec.time_column_spec_id))
249+
print(
250+
"Table spec time column spec id: {}".format(
251+
table_spec.time_column_spec_id
252+
)
253+
)
232254
print("Table spec row count: {}".format(table_spec.row_count))
233255
print("Table spec column count: {}".format(table_spec.column_count))
234256

235257
# [END automl_tables_get_table_spec]
236258

237259

238-
def get_column_spec(project_id,
239-
compute_region,
240-
dataset_id,
241-
table_spec_id,
242-
column_spec_id):
260+
def get_column_spec(
261+
project_id, compute_region, dataset_id, table_spec_id, column_spec_id
262+
):
243263
"""Get the column spec."""
244264
# [START automl_tables_get_column_spec]
245265
# TODO(developer): Uncomment and set the following variables
@@ -288,7 +308,7 @@ def import_data(project_id, compute_region, dataset_display_name, path):
288308
client = automl.TablesClient(project=project_id, region=compute_region)
289309

290310
response = None
291-
if path.startswith('bq'):
311+
if path.startswith("bq"):
292312
response = client.import_data(
293313
dataset_display_name=dataset_display_name, bigquery_input_uri=path
294314
)
@@ -297,7 +317,7 @@ def import_data(project_id, compute_region, dataset_display_name, path):
297317
input_uris = path.split(",")
298318
response = client.import_data(
299319
dataset_display_name=dataset_display_name,
300-
gcs_input_uris=input_uris
320+
gcs_input_uris=input_uris,
301321
)
302322

303323
print("Processing import...")
@@ -321,8 +341,10 @@ def export_data(project_id, compute_region, dataset_display_name, gcs_uri):
321341
client = automl.TablesClient(project=project_id, region=compute_region)
322342

323343
# Export the dataset to the output URI.
324-
response = client.export_data(dataset_display_name=dataset_display_name,
325-
gcs_output_uri_prefix=gcs_uri)
344+
response = client.export_data(
345+
dataset_display_name=dataset_display_name,
346+
gcs_output_uri_prefix=gcs_uri,
347+
)
326348

327349
print("Processing export...")
328350
# synchronous check of operation status.
@@ -331,12 +353,14 @@ def export_data(project_id, compute_region, dataset_display_name, gcs_uri):
331353
# [END automl_tables_export_data]
332354

333355

334-
def update_dataset(project_id,
335-
compute_region,
336-
dataset_display_name,
337-
target_column_spec_name=None,
338-
weight_column_spec_name=None,
339-
test_train_column_spec_name=None):
356+
def update_dataset(
357+
project_id,
358+
compute_region,
359+
dataset_display_name,
360+
target_column_spec_name=None,
361+
weight_column_spec_name=None,
362+
test_train_column_spec_name=None,
363+
):
340364
"""Update dataset."""
341365
# [START automl_tables_update_dataset]
342366
# TODO(developer): Uncomment and set the following variables
@@ -354,29 +378,31 @@ def update_dataset(project_id,
354378
if target_column_spec_name is not None:
355379
response = client.set_target_column(
356380
dataset_display_name=dataset_display_name,
357-
column_spec_display_name=target_column_spec_name
381+
column_spec_display_name=target_column_spec_name,
358382
)
359383
print("Target column updated. {}".format(response))
360384
if weight_column_spec_name is not None:
361385
response = client.set_weight_column(
362386
dataset_display_name=dataset_display_name,
363-
column_spec_display_name=weight_column_spec_name
387+
column_spec_display_name=weight_column_spec_name,
364388
)
365389
print("Weight column updated. {}".format(response))
366390
if test_train_column_spec_name is not None:
367391
response = client.set_test_train_column(
368392
dataset_display_name=dataset_display_name,
369-
column_spec_display_name=test_train_column_spec_name
393+
column_spec_display_name=test_train_column_spec_name,
370394
)
371395
print("Test/train column updated. {}".format(response))
372396

373397
# [END automl_tables_update_dataset]
374398

375399

376-
def update_table_spec(project_id,
377-
compute_region,
378-
dataset_display_name,
379-
time_column_spec_display_name):
400+
def update_table_spec(
401+
project_id,
402+
compute_region,
403+
dataset_display_name,
404+
time_column_spec_display_name,
405+
):
380406
"""Update table spec."""
381407
# [START automl_tables_update_table_spec]
382408
# TODO(developer): Uncomment and set the following variables
@@ -391,20 +417,22 @@ def update_table_spec(project_id,
391417

392418
response = client.set_time_column(
393419
dataset_display_name=dataset_display_name,
394-
column_spec_display_name=time_column_spec_display_name
420+
column_spec_display_name=time_column_spec_display_name,
395421
)
396422

397423
# synchronous check of operation status.
398424
print("Table spec updated. {}".format(response))
399425
# [END automl_tables_update_table_spec]
400426

401427

402-
def update_column_spec(project_id,
403-
compute_region,
404-
dataset_display_name,
405-
column_spec_display_name,
406-
type_code,
407-
nullable=None):
428+
def update_column_spec(
429+
project_id,
430+
compute_region,
431+
dataset_display_name,
432+
column_spec_display_name,
433+
type_code,
434+
nullable=None,
435+
):
408436
"""Update column spec."""
409437
# [START automl_tables_update_column_spec]
410438
# TODO(developer): Uncomment and set the following variables
@@ -423,7 +451,8 @@ def update_column_spec(project_id,
423451
response = client.update_column_spec(
424452
dataset_display_name=dataset_display_name,
425453
column_spec_display_name=column_spec_display_name,
426-
type_code=type_code, nullable=nullable
454+
type_code=type_code,
455+
nullable=nullable,
427456
)
428457

429458
# synchronous check of operation status.
@@ -546,56 +575,62 @@ def delete_dataset(project_id, compute_region, dataset_display_name):
546575
if args.command == "list_datasets":
547576
list_datasets(project_id, compute_region, args.filter_)
548577
if args.command == "list_table_specs":
549-
list_table_specs(project_id,
550-
compute_region,
551-
args.dataset_display_name,
552-
args.filter_)
578+
list_table_specs(
579+
project_id, compute_region, args.dataset_display_name, args.filter_
580+
)
553581
if args.command == "list_column_specs":
554-
list_column_specs(project_id,
555-
compute_region,
556-
args.dataset_display_name,
557-
args.filter_)
582+
list_column_specs(
583+
project_id, compute_region, args.dataset_display_name, args.filter_
584+
)
558585
if args.command == "get_dataset":
559586
get_dataset(project_id, compute_region, args.dataset_display_name)
560587
if args.command == "get_table_spec":
561-
get_table_spec(project_id,
562-
compute_region,
563-
args.dataset_display_name,
564-
args.table_spec_id)
588+
get_table_spec(
589+
project_id,
590+
compute_region,
591+
args.dataset_display_name,
592+
args.table_spec_id,
593+
)
565594
if args.command == "get_column_spec":
566-
get_column_spec(project_id,
567-
compute_region,
568-
args.dataset_display_name,
569-
args.table_spec_id,
570-
args.column_spec_id)
595+
get_column_spec(
596+
project_id,
597+
compute_region,
598+
args.dataset_display_name,
599+
args.table_spec_id,
600+
args.column_spec_id,
601+
)
571602
if args.command == "import_data":
572-
import_data(project_id,
573-
compute_region,
574-
args.dataset_display_name,
575-
args.path)
603+
import_data(
604+
project_id, compute_region, args.dataset_display_name, args.path
605+
)
576606
if args.command == "export_data":
577-
export_data(project_id,
578-
compute_region,
579-
args.dataset_display_name,
580-
args.gcs_uri)
607+
export_data(
608+
project_id, compute_region, args.dataset_display_name, args.gcs_uri
609+
)
581610
if args.command == "update_dataset":
582-
update_dataset(project_id,
583-
compute_region,
584-
args.dataset_display_name,
585-
args.target_column_spec_name,
586-
args.weight_column_spec_name,
587-
args.ml_use_column_spec_name)
611+
update_dataset(
612+
project_id,
613+
compute_region,
614+
args.dataset_display_name,
615+
args.target_column_spec_name,
616+
args.weight_column_spec_name,
617+
args.ml_use_column_spec_name,
618+
)
588619
if args.command == "update_table_spec":
589-
update_table_spec(project_id,
590-
compute_region,
591-
args.dataset_display_name,
592-
args.time_column_spec_display_name)
620+
update_table_spec(
621+
project_id,
622+
compute_region,
623+
args.dataset_display_name,
624+
args.time_column_spec_display_name,
625+
)
593626
if args.command == "update_column_spec":
594-
update_column_spec(project_id,
595-
compute_region,
596-
args.dataset_display_name,
597-
args.column_spec_display_name,
598-
args.type_code,
599-
args.nullable)
627+
update_column_spec(
628+
project_id,
629+
compute_region,
630+
args.dataset_display_name,
631+
args.column_spec_display_name,
632+
args.type_code,
633+
args.nullable,
634+
)
600635
if args.command == "delete_dataset":
601636
delete_dataset(project_id, compute_region, args.dataset_display_name)

0 commit comments

Comments (0)