@@ -114,21 +114,20 @@ def testStartCloudTraining(self, mock_discovery):
114114
115115 default_image = 'gcr.io/tfx-oss-public/tfx:{}' .format (
116116 version_utils .get_image_version ())
117- self .assertDictContainsSubset (
118- {
119- 'masterConfig' : {
120- 'imageUri' :
121- default_image ,
122- 'containerCommand' :
123- runner ._CONTAINER_COMMAND + [
124- '--executor_class_path' , class_path , '--inputs' , '{}' ,
125- '--outputs' , '{}' , '--exec-properties' ,
126- ('{"custom_config": '
127- '"{\\ "ai_platform_training_args\\ ": {\\ "project\\ ": \\ "12345\\ "'
128- '}}"}' )
129- ],
130- },
131- }, body ['training_input' ])
117+ self .assertLessEqual ({
118+ 'masterConfig' : {
119+ 'imageUri' :
120+ default_image ,
121+ 'containerCommand' :
122+ runner ._CONTAINER_COMMAND + [
123+ '--executor_class_path' , class_path , '--inputs' , '{}' ,
124+ '--outputs' , '{}' , '--exec-properties' ,
125+ ('{"custom_config": '
126+ '"{\\ "ai_platform_training_args\\ ": {\\ "project\\ ": \\ "12345\\ "'
127+ '}}"}' )
128+ ],
129+ },
130+ }.items (), body ['training_input' ].items ())
132131 self .assertNotIn ('project' , body ['training_input' ])
133132 self .assertStartsWith (body ['job_id' ], 'tfx_' )
134133 self ._mock_get .execute .assert_called_with ()
@@ -239,28 +238,27 @@ def testStartCloudTrainingWithUserContainer_Vertex(self, mock_gapic):
239238 custom_job = mock .ANY )
240239 kwargs = self ._mock_create .call_args [1 ]
241240 body = kwargs ['custom_job' ]
242- self .assertDictContainsSubset (
243- {
244- 'worker_pool_specs' : [{
245- 'container_spec' : {
246- 'image_uri' :
247- 'my-custom-image' ,
248- 'command' :
249- runner ._CONTAINER_COMMAND + [
250- '--executor_class_path' , class_path , '--inputs' ,
251- '{}' , '--outputs' , '{}' , '--exec-properties' ,
252- ('{"custom_config": '
253- '"{\\ "ai_platform_training_args\\ ": '
254- '{\\ "project\\ ": \\ "12345\\ ", '
255- '\\ "worker_pool_specs\\ ": '
256- '[{\\ "container_spec\\ ": '
257- '{\\ "image_uri\\ ": \\ "my-custom-image\\ "}}]}, '
258- '\\ "ai_platform_training_job_id\\ ": '
259- '\\ "my_jobid\\ "}"}' )
260- ],
261- },
262- },],
263- }, body ['job_spec' ])
241+ self .assertLessEqual ({
242+ 'worker_pool_specs' : [{
243+ 'container_spec' : {
244+ 'image_uri' :
245+ 'my-custom-image' ,
246+ 'command' :
247+ runner ._CONTAINER_COMMAND + [
248+ '--executor_class_path' , class_path , '--inputs' , '{}' ,
249+ '--outputs' , '{}' , '--exec-properties' ,
250+ ('{"custom_config": '
251+ '"{\\ "ai_platform_training_args\\ ": '
252+ '{\\ "project\\ ": \\ "12345\\ ", '
253+ '\\ "worker_pool_specs\\ ": '
254+ '[{\\ "container_spec\\ ": '
255+ '{\\ "image_uri\\ ": \\ "my-custom-image\\ "}}]}, '
256+ '\\ "ai_platform_training_job_id\\ ": '
257+ '\\ "my_jobid\\ "}"}' )
258+ ],
259+ },
260+ },],
261+ }.items (), body ['job_spec' ].items ())
264262 self .assertEqual (body ['display_name' ], 'my_jobid' )
265263 self ._mock_get .assert_called_with (name = 'vertex_job_study_id' )
266264
@@ -329,7 +327,7 @@ def testStartCloudTrainingWithVertexCustomJob(self, mock_gapic):
329327 }, body ['job_spec' ])
330328 self .assertEqual (body ['display_name' ], 'valid_name' )
331329 self .assertDictEqual (body ['encryption_spec' ], expected_encryption_spec )
332- self .assertDictContainsSubset (user_provided_labels , body ['labels' ])
330+ self .assertLessEqual (user_provided_labels . items () , body ['labels' ]. items () )
333331 self ._mock_get .assert_called_with (name = 'vertex_job_study_id' )
334332
335333 def _setUpPredictionMocks (self ):
0 commit comments