Skip to content
75 changes: 60 additions & 15 deletions airflow/providers/amazon/aws/hooks/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import warnings
from datetime import datetime
from functools import partial
from typing import Any, Callable, Dict, Generator, List, Optional, Set, cast
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, cast

from botocore.exceptions import ClientError

Expand Down Expand Up @@ -844,24 +844,38 @@ def list_training_jobs(
:param kwargs: (optional) kwargs to boto3's list_training_jobs method
:return: results of the list_training_jobs request
"""
config = {}
config, max_results = self._preprocess_list_request_args(name_contains, max_results, **kwargs)
list_training_jobs_request = partial(self.get_conn().list_training_jobs, **config)
results = self._list_request(
list_training_jobs_request, "TrainingJobSummaries", max_results=max_results
)
return results

if name_contains:
if "NameContains" in kwargs:
raise AirflowException("Either name_contains or NameContains can be provided, not both.")
config["NameContains"] = name_contains
def list_transform_jobs(
self, name_contains: Optional[str] = None, max_results: Optional[int] = None, **kwargs
) -> List[Dict]:
"""
This method wraps boto3's `list_transform_jobs`.
The transform job name and max results are configurable via arguments.
Other arguments are not, and should be provided via kwargs. Note boto3 expects these in
CamelCase format, for example:

if "MaxResults" in kwargs and kwargs["MaxResults"] is not None:
if max_results:
raise AirflowException("Either max_results or MaxResults can be provided, not both.")
# Unset MaxResults, we'll use the SageMakerHook's internal method for iteratively fetching results
max_results = kwargs["MaxResults"]
del kwargs["MaxResults"]
.. code-block:: python

config.update(kwargs)
list_training_jobs_request = partial(self.get_conn().list_training_jobs, **config)
list_transform_jobs(name_contains="myjob", StatusEquals="Failed")

.. seealso::
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_transform_jobs

:param name_contains: (optional) partial name to match
:param max_results: (optional) maximum number of results to return. None returns infinite results
:param kwargs: (optional) kwargs to boto3's list_transform_jobs method
:return: results of the list_transform_jobs request
"""
config, max_results = self._preprocess_list_request_args(name_contains, max_results, **kwargs)
list_transform_jobs_request = partial(self.get_conn().list_transform_jobs, **config)
results = self._list_request(
list_training_jobs_request, "TrainingJobSummaries", max_results=max_results
list_transform_jobs_request, "TransformJobSummaries", max_results=max_results
)
return results

Expand All @@ -886,6 +900,37 @@ def list_processing_jobs(self, **kwargs) -> List[Dict]:
)
return results

def _preprocess_list_request_args(
self, name_contains: Optional[str] = None, max_results: Optional[int] = None, **kwargs
) -> Tuple[Dict[str, Any], Optional[int]]:
"""
This method preprocesses the arguments to the boto3's list_* methods.
It will turn arguments name_contains and max_results as boto3 compliant CamelCase format.
This method also makes sure that these two arguments are only set once.

:param name_contains: boto3 function with arguments
:param max_results: the result key to iterate over
:param kwargs: (optional) kwargs to boto3's list_* method
:return: Tuple with config dict to be passed to boto3's list_* method and max_results parameter
"""
config = {}

if name_contains:
if "NameContains" in kwargs:
raise AirflowException("Either name_contains or NameContains can be provided, not both.")
config["NameContains"] = name_contains

if "MaxResults" in kwargs and kwargs["MaxResults"] is not None:
if max_results:
raise AirflowException("Either max_results or MaxResults can be provided, not both.")
# Unset MaxResults, we'll use the SageMakerHook's internal method for iteratively fetching results
max_results = kwargs["MaxResults"]
del kwargs["MaxResults"]

config.update(kwargs)

return config, max_results

def _list_request(
self, partial_func: Callable, result_key: str, max_results: Optional[int] = None
) -> List[Dict]:
Expand Down
34 changes: 33 additions & 1 deletion airflow/providers/amazon/aws/operators/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,11 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
:param max_ingestion_time: If wait is set to True, the operation fails
if the transform job doesn't finish within max_ingestion_time seconds. If you
set this parameter to None, the operation does not timeout.
:param check_if_job_exists: If set to true, then the operator will check whether a transform job
already exists for the name in the config.
:param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
(default) and "fail".
This is only relevant if check_if_job_exists is True.
:return Dict: Returns The ARN of the model created in Amazon SageMaker.
"""

Expand All @@ -411,6 +416,8 @@ def __init__(
wait_for_completion: bool = True,
check_interval: int = CHECK_INTERVAL_SECOND,
max_ingestion_time: Optional[int] = None,
check_if_job_exists: bool = True,
action_if_job_exists: str = 'increment',
**kwargs,
):
super().__init__(config=config, **kwargs)
Expand All @@ -419,6 +426,14 @@ def __init__(
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
self.check_if_job_exists = check_if_job_exists
if action_if_job_exists in ('increment', 'fail'):
self.action_if_job_exists = action_if_job_exists
else:
raise AirflowException(
f"Argument action_if_job_exists accepts only 'increment' and 'fail'. \
Provided value: '{action_if_job_exists}'."
)

def _create_integer_fields(self) -> None:
"""Set fields which should be cast to integers."""
Expand All @@ -444,6 +459,8 @@ def execute(self, context: 'Context') -> Dict:
self.preprocess_config()
model_config = self.config.get('Model')
transform_config = self.config.get('Transform', self.config)
if self.check_if_job_exists:
self._check_if_transform_job_exists()
if model_config:
self.log.info('Creating SageMaker Model %s for transform job', model_config['ModelName'])
self.hook.create_model(model_config)
Expand All @@ -462,6 +479,21 @@ def execute(self, context: 'Context') -> Dict:
'Transform': self.hook.describe_transform_job(transform_config['TransformJobName']),
}

def _check_if_transform_job_exists(self) -> None:
transform_config = self.config.get('Transform', self.config)
transform_job_name = transform_config['TransformJobName']
transform_jobs = self.hook.list_transform_jobs(name_contains=transform_job_name)
if transform_job_name in [tj['TransformJobName'] for tj in transform_jobs]:
if self.action_if_job_exists == 'increment':
self.log.info("Found existing transform job with name '%s'.", transform_job_name)
new_transform_job_name = f'{transform_job_name}-{(len(transform_jobs) + 1)}'
transform_config['TransformJobName'] = new_transform_job_name
self.log.info("Incremented transform job name to '%s'.", new_transform_job_name)
elif self.action_if_job_exists == 'fail':
raise AirflowException(
f'A SageMaker transform job with name {transform_job_name} already exists.'
)


class SageMakerTuningOperator(SageMakerBaseOperator):
"""
Expand Down Expand Up @@ -605,7 +637,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
already exists for the name in the config.
:param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
(default) and "fail".
This is only relevant if check_if
This is only relevant if check_if_job_exists is True.
:return Dict: Returns The ARN of the training job created in Amazon SageMaker.
"""

Expand Down
57 changes: 57 additions & 0 deletions tests/providers/amazon/aws/operators/test_sagemaker_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,60 @@ def test_execute_with_failure(self, mock_transform, mock_model, mock_client):
}
with pytest.raises(AirflowException):
self.sagemaker.execute(None)

@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_transform_job')
def test_execute_with_check_if_job_exists(self, mock_transform, mock_client):
mock_transform.return_value = {
'TransformJobArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},
}
self.sagemaker._check_if_transform_job_exists = mock.MagicMock()
self.sagemaker.execute(None)
self.sagemaker._check_if_transform_job_exists.assert_called_once()
mock_transform.assert_called_once_with(
CREATE_TRANSFORM_PARAMS,
wait_for_completion=False,
check_interval=5,
max_ingestion_time=None,
)

@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_transform_job')
def test_execute_without_check_if_job_exists(self, mock_transform, mock_client):
mock_transform.return_value = {
'TransformJobArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},
}
self.sagemaker.check_if_job_exists = False
self.sagemaker._check_if_transform_job_exists = mock.MagicMock()
self.sagemaker.execute(None)
self.sagemaker._check_if_transform_job_exists.assert_not_called()
mock_transform.assert_called_once_with(
CREATE_TRANSFORM_PARAMS,
wait_for_completion=False,
check_interval=5,
max_ingestion_time=None,
)

@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'list_transform_jobs')
def test_check_if_job_exists_increment(self, mock_list_transform_jobs, mock_client):
self.sagemaker.check_if_job_exists = True
self.sagemaker.action_if_job_exists = 'increment'
mock_list_transform_jobs.return_value = [{'TransformJobName': 'job_name'}]
self.sagemaker._check_if_transform_job_exists()

expected_config = CONFIG.copy()
# Expect to see TransformJobName suffixed with "-2" because we return one existing job
expected_config["Transform"]['TransformJobName'] = 'job_name-2'
assert self.sagemaker.config == expected_config

@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'list_transform_jobs')
def test_check_if_job_exists_fail(self, mock_list_transform_jobs, mock_client):
self.sagemaker.check_if_job_exists = True
self.sagemaker.action_if_job_exists = 'fail'
mock_list_transform_jobs.return_value = [{'TransformJobName': 'job_name'}]
with pytest.raises(AirflowException):
self.sagemaker._check_if_transform_job_exists()