@@ -114,21 +114,20 @@ def testStartCloudTraining(self, mock_discovery):
114
114
115
115
default_image = 'gcr.io/tfx-oss-public/tfx:{}' .format (
116
116
version_utils .get_image_version ())
117
- self .assertDictContainsSubset (
118
- {
119
- 'masterConfig' : {
120
- 'imageUri' :
121
- default_image ,
122
- 'containerCommand' :
123
- runner ._CONTAINER_COMMAND + [
124
- '--executor_class_path' , class_path , '--inputs' , '{}' ,
125
- '--outputs' , '{}' , '--exec-properties' ,
126
- ('{"custom_config": '
127
- '"{\\ "ai_platform_training_args\\ ": {\\ "project\\ ": \\ "12345\\ "'
128
- '}}"}' )
129
- ],
130
- },
131
- }, body ['training_input' ])
117
+ self .assertLessEqual ({
118
+ 'masterConfig' : {
119
+ 'imageUri' :
120
+ default_image ,
121
+ 'containerCommand' :
122
+ runner ._CONTAINER_COMMAND + [
123
+ '--executor_class_path' , class_path , '--inputs' , '{}' ,
124
+ '--outputs' , '{}' , '--exec-properties' ,
125
+ ('{"custom_config": '
126
+ '"{\\ "ai_platform_training_args\\ ": {\\ "project\\ ": \\ "12345\\ "'
127
+ '}}"}' )
128
+ ],
129
+ },
130
+ }.items (), body ['training_input' ].items ())
132
131
self .assertNotIn ('project' , body ['training_input' ])
133
132
self .assertStartsWith (body ['job_id' ], 'tfx_' )
134
133
self ._mock_get .execute .assert_called_with ()
@@ -239,28 +238,27 @@ def testStartCloudTrainingWithUserContainer_Vertex(self, mock_gapic):
239
238
custom_job = mock .ANY )
240
239
kwargs = self ._mock_create .call_args [1 ]
241
240
body = kwargs ['custom_job' ]
242
- self .assertDictContainsSubset (
243
- {
244
- 'worker_pool_specs' : [{
245
- 'container_spec' : {
246
- 'image_uri' :
247
- 'my-custom-image' ,
248
- 'command' :
249
- runner ._CONTAINER_COMMAND + [
250
- '--executor_class_path' , class_path , '--inputs' ,
251
- '{}' , '--outputs' , '{}' , '--exec-properties' ,
252
- ('{"custom_config": '
253
- '"{\\ "ai_platform_training_args\\ ": '
254
- '{\\ "project\\ ": \\ "12345\\ ", '
255
- '\\ "worker_pool_specs\\ ": '
256
- '[{\\ "container_spec\\ ": '
257
- '{\\ "image_uri\\ ": \\ "my-custom-image\\ "}}]}, '
258
- '\\ "ai_platform_training_job_id\\ ": '
259
- '\\ "my_jobid\\ "}"}' )
260
- ],
261
- },
262
- },],
263
- }, body ['job_spec' ])
241
+ self .assertLessEqual ({
242
+ 'worker_pool_specs' : [{
243
+ 'container_spec' : {
244
+ 'image_uri' :
245
+ 'my-custom-image' ,
246
+ 'command' :
247
+ runner ._CONTAINER_COMMAND + [
248
+ '--executor_class_path' , class_path , '--inputs' , '{}' ,
249
+ '--outputs' , '{}' , '--exec-properties' ,
250
+ ('{"custom_config": '
251
+ '"{\\ "ai_platform_training_args\\ ": '
252
+ '{\\ "project\\ ": \\ "12345\\ ", '
253
+ '\\ "worker_pool_specs\\ ": '
254
+ '[{\\ "container_spec\\ ": '
255
+ '{\\ "image_uri\\ ": \\ "my-custom-image\\ "}}]}, '
256
+ '\\ "ai_platform_training_job_id\\ ": '
257
+ '\\ "my_jobid\\ "}"}' )
258
+ ],
259
+ },
260
+ },],
261
+ }.items (), body ['job_spec' ].items ())
264
262
self .assertEqual (body ['display_name' ], 'my_jobid' )
265
263
self ._mock_get .assert_called_with (name = 'vertex_job_study_id' )
266
264
@@ -329,7 +327,7 @@ def testStartCloudTrainingWithVertexCustomJob(self, mock_gapic):
329
327
}, body ['job_spec' ])
330
328
self .assertEqual (body ['display_name' ], 'valid_name' )
331
329
self .assertDictEqual (body ['encryption_spec' ], expected_encryption_spec )
332
- self .assertDictContainsSubset (user_provided_labels , body ['labels' ])
330
+ self .assertLessEqual (user_provided_labels . items () , body ['labels' ]. items () )
333
331
self ._mock_get .assert_called_with (name = 'vertex_job_study_id' )
334
332
335
333
def _setUpPredictionMocks (self ):
0 commit comments