Commit 96dd703

fix SagemakerProcessingOperator ThrottlingException (#19195)
1 parent eb12bb2 commit 96dd703

4 files changed: 40 additions & 16 deletions
airflow/providers/amazon/aws/hooks/sagemaker.py

Lines changed: 10 additions & 0 deletions

@@ -954,3 +954,13 @@ def _list_request(
                 return results
             else:
                 next_token = response["NextToken"]
+
+    def find_processing_job_by_name(self, processing_job_name: str) -> bool:
+        """Query processing job by name"""
+        try:
+            self.get_conn().describe_processing_job(ProcessingJobName=processing_job_name)
+            return True
+        except ClientError as e:
+            if e.response['Error']['Code'] in ['ValidationException', 'ResourceNotFound']:
+                return False
+            raise
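
The new hook method replaces the paginated list_processing_jobs scan with a single describe_processing_job call, so accounts with many processing jobs no longer hammer the List API and trip a ThrottlingException. A minimal usage sketch, not part of the commit; it assumes the amazon provider is installed and that an "aws_default" connection exists:

# Sketch only: check for a name collision before submitting a processing job.
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

hook = SageMakerHook(aws_conn_id="aws_default")  # "aws_default" is an assumed connection id

# True when DescribeProcessingJob succeeds; False when SageMaker answers with
# ValidationException/ResourceNotFound; any other ClientError (e.g. throttling)
# is re-raised to the caller.
if hook.find_processing_job_by_name("my-processing-job"):
    print("A processing job with this name already exists.")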

airflow/providers/amazon/aws/operators/sagemaker_processing.py

Lines changed: 4 additions & 12 deletions

@@ -95,19 +95,11 @@ def execute(self, context) -> dict:
         self.preprocess_config()
 
         processing_job_name = self.config["ProcessingJobName"]
-        processing_jobs = self.hook.list_processing_jobs(NameContains=processing_job_name)
 
-        # Check if given ProcessingJobName already exists
-        if processing_job_name in [pj["ProcessingJobName"] for pj in processing_jobs]:
-            if self.action_if_job_exists == "fail":
-                raise AirflowException(
-                    f"A SageMaker processing job with name {processing_job_name} already exists."
-                )
-            if self.action_if_job_exists == "increment":
-                self.log.info("Found existing processing job with name '%s'.", processing_job_name)
-                new_processing_job_name = f"{processing_job_name}-{len(processing_jobs) + 1}"
-                self.config["ProcessingJobName"] = new_processing_job_name
-                self.log.info("Incremented processing job name to '%s'.", new_processing_job_name)
+        if self.hook.find_processing_job_by_name(processing_job_name):
+            raise AirflowException(
+                f"A SageMaker processing job with name {processing_job_name} already exists."
+            )
 
         self.log.info("Creating SageMaker processing job %s.", self.config["ProcessingJobName"])
         response = self.hook.create_processing_job(
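
With this change the operator fails fast on a name collision instead of listing jobs and optionally incrementing the name. A hedged sketch of how a DAG can keep names unique on its side now that the increment branch is gone; the config dict, task id, and connection id below are illustrative, not from the commit:

# Sketch only: make the job name unique up front, since a duplicate now raises.
from datetime import datetime

from airflow.providers.amazon.aws.operators.sagemaker_processing import SageMakerProcessingOperator

processing_config = {"ProcessingJobName": "my-processing-job"}  # plus the rest of the CreateProcessingJob payload
processing_config["ProcessingJobName"] += datetime.utcnow().strftime("-%Y%m%d%H%M%S")

preprocess = SageMakerProcessingOperator(
    task_id="preprocess",  # illustrative task id
    config=processing_config,
    aws_conn_id="aws_default",
    wait_for_completion=True,
)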

tests/providers/amazon/aws/hooks/test_sagemaker.py

Lines changed: 20 additions & 0 deletions

@@ -647,3 +647,23 @@ def test_training_with_logs(self, mock_describe, mock_client, mock_log_client, m
         )
         assert mock_describe.call_count == 3
         assert mock_session.describe_training_job.call_count == 1
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    def test_find_processing_job_by_name(self, mock_conn):
+        hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
+        mock_conn.describe_processing_job.return_value = {}
+        ret = hook.find_processing_job_by_name("existing_job")
+        assert ret
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    def test_find_processing_job_by_name_job_not_exists_should_return_false(self, mock_conn):
+        from botocore.exceptions import ClientError
+
+        error_resp = {"Error": {"Code": "ValidationException"}}
+        mock_conn().describe_processing_job.side_effect = ClientError(
+            error_response=error_resp, operation_name="dummy"
+        )
+        hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
+
+        ret = hook.find_processing_job_by_name("existing_job")
+        assert not ret
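
The two new tests cover the found and not-found paths. A hypothetical third test, written in the same style and assuming the surrounding unittest.TestCase class, could pin down that other error codes such as ThrottlingException still propagate:

    @mock.patch.object(SageMakerHook, 'get_conn')
    def test_find_processing_job_by_name_reraises_other_client_errors(self, mock_conn):
        from botocore.exceptions import ClientError

        # ThrottlingException is neither ValidationException nor ResourceNotFound,
        # so find_processing_job_by_name should let it bubble up.
        error_resp = {"Error": {"Code": "ThrottlingException"}}
        mock_conn().describe_processing_job.side_effect = ClientError(
            error_response=error_resp, operation_name="dummy"
        )
        hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')

        with self.assertRaises(ClientError):
            hook.find_processing_job_by_name("existing_job")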

tests/providers/amazon/aws/operators/test_sagemaker_processing.py

Lines changed: 6 additions & 4 deletions

@@ -115,12 +115,13 @@ def test_integer_fields_are_set(self, config, expected_fields):
         assert sagemaker.integer_fields == expected_fields
 
     @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=False)
     @mock.patch.object(
         SageMakerHook,
         'create_processing_job',
         return_value={'ProcessingJobArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 200}},
     )
-    def test_execute(self, mock_processing, mock_client):
+    def test_execute(self, mock_processing, mock_hook, mock_client):
         sagemaker = SageMakerProcessingOperator(
             **self.processing_config_kwargs, config=create_processing_params
         )

@@ -142,13 +143,14 @@ def test_execute_with_failure(self, mock_processing, mock_client):
         with pytest.raises(AirflowException):
             sagemaker.execute(None)
 
+    @unittest.skip("Currently, the auto-increment jobname functionality is not missing.")
     @mock.patch.object(SageMakerHook, "get_conn")
-    @mock.patch.object(SageMakerHook, "list_processing_jobs", return_value=[{"ProcessingJobName": job_name}])
+    @mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=True)
     @mock.patch.object(
         SageMakerHook, "create_processing_job", return_value={"ResponseMetadata": {"HTTPStatusCode": 200}}
     )
     def test_execute_with_existing_job_increment(
-        self, mock_create_processing_job, mock_list_processing_jobs, mock_client
+        self, mock_create_processing_job, find_processing_job_by_name, mock_client
     ):
         sagemaker = SageMakerProcessingOperator(
             **self.processing_config_kwargs, config=create_processing_params

@@ -167,7 +169,7 @@ def test_execute_with_existing_job_increment(
         )
 
     @mock.patch.object(SageMakerHook, "get_conn")
-    @mock.patch.object(SageMakerHook, "list_processing_jobs", return_value=[{"ProcessingJobName": job_name}])
+    @mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=True)
     @mock.patch.object(
         SageMakerHook, "create_processing_job", return_value={"ResponseMetadata": {"HTTPStatusCode": 200}}
     )
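
For reference, stacked mock.patch.object decorators inject their mocks bottom-up, which is why the new find_processing_job_by_name patch shows up as the middle argument of test_execute (mock_processing, mock_hook, mock_client). A standalone illustration of that ordering; the class and function names below are made up:

# Sketch only: the decorator closest to the function supplies the first mock argument.
from unittest import mock


class Api:
    def lowest(self):
        ...

    def middle(self):
        ...

    def highest(self):
        ...


@mock.patch.object(Api, "highest")  # -> third mock argument
@mock.patch.object(Api, "middle")   # -> second mock argument
@mock.patch.object(Api, "lowest")   # -> first mock argument
def check(mock_lowest, mock_middle, mock_highest):
    assert all(isinstance(m, mock.MagicMock) for m in (mock_lowest, mock_middle, mock_highest))


check()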
