Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 300 additions & 1 deletion airflow/providers/amazon/aws/hooks/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
"""Interact with AWS RDS."""
from __future__ import annotations

from typing import TYPE_CHECKING
import time
from typing import TYPE_CHECKING, Callable

from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook

if TYPE_CHECKING:
Expand Down Expand Up @@ -48,3 +50,300 @@ class RdsHook(AwsGenericHook["RDSClient"]):
def __init__(self, *args, **kwargs) -> None:
    """Set up the hook, pinning the client type passed to the parent class to ``"rds"``."""
    # A caller-supplied ``client_type`` is deliberately overwritten.
    kwargs.update(client_type="rds")
    super().__init__(*args, **kwargs)

def get_db_snapshot_state(self, snapshot_id: str) -> str:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For each of the following resources, I added:

  1. A method to get the resource state
  2. A method to wait for the resource to reach a certain state

Resources:

  • db instance
  • db cluster
  • db snapshot
  • db cluster snapshot
  • export task
  • event subscription

In the current implementation, RdsBaseOperator and RdsBaseSensor each have a method that calls the appropriate boto "describe" function based on a string (like "db_snapshot" or "event_subscription").

if item_type == "instance_snapshot":
db_snaps = self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=item_name)
return db_snaps["DBSnapshots"]
elif item_type == "cluster_snapshot":
cl_snaps = self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=item_name)
return cl_snaps["DBClusterSnapshots"]
elif item_type == "export_task":
exports = self.hook.conn.describe_export_tasks(ExportTaskIdentifier=item_name)
return exports["ExportTasks"]
elif item_type == "event_subscription":
subscriptions = self.hook.conn.describe_event_subscriptions(SubscriptionName=item_name)
return subscriptions["EventSubscriptionsList"]
elif item_type == "db_instance":
instances = self.hook.conn.describe_db_instances(DBInstanceIdentifier=item_name)
return instances["DBInstances"]
elif item_type == "db_cluster":
clusters = self.hook.conn.describe_db_clusters(DBClusterIdentifier=item_name)
return clusters["DBClusters"]

I considered just moving this logic over to RdsHook, but decided to split it up because the existence (and variety) of boto waiters depends on the resource. For example, the db snapshot resource has 3 boto waiters, the db cluster snapshot resource has 2, and the export task resource has none.

One could perhaps reduce the amount of code by keeping a separate list of resources and their waiters and choosing dynamically at runtime, but I didn't go that far.

"""
Get the current state of a DB instance snapshot.

:param snapshot_id: The ID of the target DB instance snapshot
:return: Returns the status of the DB snapshot as a string (eg. "available")
:rtype: str
:raises AirflowNotFoundException: If the DB instance snapshot does not exist.
"""
try:
response = self.conn.describe_db_snapshots(DBSnapshotIdentifier=snapshot_id)
except self.conn.exceptions.ClientError as e:
if e.response["Error"]["Code"] == "DBSnapshotNotFound":
raise AirflowNotFoundException(e)
raise e
return response["DBSnapshots"][0]["Status"].lower()

def wait_for_db_snapshot_state(
self, snapshot_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
) -> None:
"""
Polls :py:meth:`RDS.Client.describe_db_snapshots` until the target state is reached.
An error is raised after a max number of attempts.

:param snapshot_id: The ID of the target DB instance snapshot
:param target_state: Wait until this state is reached
:param check_interval: The amount of time in seconds to wait between attempts
:param max_attempts: The maximum number of attempts to be made
"""

def poke():
return self.get_db_snapshot_state(snapshot_id)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When boto waiters aren't available, I pass a closure like this to a generic waiter method so that we can reuse waiting logic.


target_state = target_state.lower()
if target_state in ("available", "deleted", "completed"):
waiter = self.conn.get_waiter(f"db_snapshot_{target_state}") # type: ignore
Copy link
Contributor Author

@hankehly hankehly Oct 31, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see "no overload exists for ..." mypy warnings here because I'm creating the waiter name dynamically. The if statement one line above ensures we aren't creating a waiter that doesn't exist, so I ignore the warning. If someone has a more appropriate solution, please let me know.

waiter.wait(
DBSnapshotIdentifier=snapshot_id,
WaiterConfig={"Delay": check_interval, "MaxAttempts": max_attempts},
)
else:
self._wait_for_state(poke, target_state, check_interval, max_attempts)
self.log.info("DB snapshot '%s' reached the '%s' state", snapshot_id, target_state)

def get_db_cluster_snapshot_state(self, snapshot_id: str) -> str:
    """
    Get the current state of a DB cluster snapshot.

    :param snapshot_id: The ID of the target DB cluster snapshot.
    :return: Returns the status of the DB cluster snapshot as a lower-cased string (eg. "available")
    :rtype: str
    :raises AirflowNotFoundException: If the DB cluster snapshot does not exist.
    """
    try:
        response = self.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=snapshot_id)
    except self.conn.exceptions.ClientError as e:
        if e.response["Error"]["Code"] == "DBClusterSnapshotNotFoundFault":
            # Translate the "not found" fault, keeping the original error as the explicit cause.
            raise AirflowNotFoundException(e) from e
        # Any other client error is propagated untouched.
        raise
    # Only the first returned snapshot is inspected.
    return response["DBClusterSnapshots"][0]["Status"].lower()

def wait_for_db_cluster_snapshot_state(
    self, snapshot_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
) -> None:
    """
    Block until a DB cluster snapshot reports the requested state.

    The requested state is lower-cased first. For "available" and "deleted" the wait is handed to the
    boto waiter named ``db_cluster_snapshot_<state>``; every other state is polled through
    :py:meth:`get_db_cluster_snapshot_state`. A log line is written once the wait returns.

    :param snapshot_id: The ID of the target DB cluster snapshot
    :param target_state: Wait until this state is reached
    :param check_interval: The amount of time in seconds to wait between attempts
    :param max_attempts: The maximum number of attempts to be made

    .. seealso::
        A list of possible values for target_state:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_db_cluster_snapshots
    """
    wanted = target_state.lower()

    if wanted not in ("available", "deleted"):
        # No waiter is requested for this state: poll the describe call through the generic helper.
        def current_state():
            return self.get_db_cluster_snapshot_state(snapshot_id)

        self._wait_for_state(current_state, wanted, check_interval, max_attempts)
    else:
        waiter_config = {"Delay": check_interval, "MaxAttempts": max_attempts}
        boto_waiter = self.conn.get_waiter(f"db_cluster_snapshot_{wanted}")  # type: ignore
        boto_waiter.wait(DBClusterSnapshotIdentifier=snapshot_id, WaiterConfig=waiter_config)

    self.log.info("DB cluster snapshot '%s' reached the '%s' state", snapshot_id, wanted)

def get_export_task_state(self, export_task_id: str) -> str:
    """
    Look up the status of a snapshot export task.

    :param export_task_id: The identifier of the target snapshot export task.
    :return: The status of the first export task returned by the describe call, lower-cased
        (eg. "canceled")
    :rtype: str
    :raises AirflowNotFoundException: If the export task does not exist.
    """
    try:
        described = self.conn.describe_export_tasks(ExportTaskIdentifier=export_task_id)
    except self.conn.exceptions.ClientError as err:
        error_code = err.response["Error"]["Code"]
        if error_code != "ExportTaskNotFoundFault":
            # Not a "missing task" fault: hand the client error back unchanged.
            raise err
        raise AirflowNotFoundException(err)

    first_task = described["ExportTasks"][0]
    return first_task["Status"].lower()

def wait_for_export_task_state(
self, export_task_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default boto waiter delay seems to be 30 for all waiters, but max_attempts varies from 40 to 60. I (arbitrarily) opted to go with 40 everywhere.

) -> None:
"""
Polls :py:meth:`RDS.Client.describe_export_tasks` until the target state is reached.
An error is raised after a max number of attempts.

:param export_task_id: The identifier of the target snapshot export task.
:param target_state: Wait until this state is reached
:param check_interval: The amount of time in seconds to wait between attempts
:param max_attempts: The maximum number of attempts to be made

.. seealso::
A list of possible values for target_state:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_export_tasks
"""

def poke():
return self.get_export_task_state(export_task_id)

target_state = target_state.lower()
self._wait_for_state(poke, target_state, check_interval, max_attempts)
self.log.info("export task '%s' reached the '%s' state", export_task_id, target_state)

def get_event_subscription_state(self, subscription_name: str) -> str:
    """
    Gets the current state of an RDS event notification subscription.

    :param subscription_name: The name of the target RDS event notification subscription.
    :return: Returns the status of the event subscription as a lower-cased string (eg. "active")
    :rtype: str
    :raises AirflowNotFoundException: If the event subscription does not exist.
    """
    try:
        response = self.conn.describe_event_subscriptions(SubscriptionName=subscription_name)
    except self.conn.exceptions.ClientError as e:
        if e.response["Error"]["Code"] == "SubscriptionNotFoundFault":
            # Translate the "not found" fault, keeping the original error as the explicit cause.
            raise AirflowNotFoundException(e) from e
        # Any other client error is propagated untouched.
        raise
    # Only the first returned subscription is inspected.
    return response["EventSubscriptionsList"][0]["Status"].lower()

def wait_for_event_subscription_state(
    self, subscription_name: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
) -> None:
    """
    Block until an RDS event notification subscription reports the requested state.

    The requested state is lower-cased, then :py:meth:`get_event_subscription_state` is polled
    through the generic ``_wait_for_state`` helper. A log line is written once the wait returns.

    :param subscription_name: The name of the target RDS event notification subscription.
    :param target_state: Wait until this state is reached
    :param check_interval: The amount of time in seconds to wait between attempts
    :param max_attempts: The maximum number of attempts to be made

    .. seealso::
        A list of possible values for target_state:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_event_subscriptions
    """
    wanted = target_state.lower()

    def current_state():
        return self.get_event_subscription_state(subscription_name)

    self._wait_for_state(current_state, wanted, check_interval, max_attempts)
    self.log.info("event subscription '%s' reached the '%s' state", subscription_name, wanted)

def get_db_instance_state(self, db_instance_id: str) -> str:
    """
    Get the current state of a DB instance.

    :param db_instance_id: The ID of the target DB instance.
    :return: Returns the status of the DB instance as a lower-cased string (eg. "available")
    :rtype: str
    :raises AirflowNotFoundException: If the DB instance does not exist.
    """
    try:
        response = self.conn.describe_db_instances(DBInstanceIdentifier=db_instance_id)
    except self.conn.exceptions.ClientError as e:
        if e.response["Error"]["Code"] == "DBInstanceNotFoundFault":
            # Translate the "not found" fault, keeping the original error as the explicit cause.
            raise AirflowNotFoundException(e) from e
        # Any other client error is propagated untouched.
        raise
    # Only the first returned instance is inspected.
    return response["DBInstances"][0]["DBInstanceStatus"].lower()

def wait_for_db_instance_state(
    self, db_instance_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
) -> None:
    """
    Polls :py:meth:`RDS.Client.describe_db_instances` until the target state is reached.
    An error is raised after a max number of attempts.

    ``target_state`` is lower-cased before use. For "available" and "deleted" the wait is delegated to
    the boto waiter named ``db_instance_<target_state>``, and errors raised by that waiter are not
    caught here. Any other state is polled through :py:meth:`get_db_instance_state`.

    :param db_instance_id: The ID of the target DB instance.
    :param target_state: Wait until this state is reached
    :param check_interval: The amount of time in seconds to wait between attempts
    :param max_attempts: The maximum number of attempts to be made

    .. seealso::
        For information about DB instance statuses, see Viewing DB instance status in the Amazon RDS
        User Guide.
        https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/accessing-monitoring.html#Overview.DBInstance.Status
    """

    def poke():
        return self.get_db_instance_state(db_instance_id)

    target_state = target_state.lower()
    if target_state in ("available", "deleted"):
        # The membership check above guarantees the dynamically built waiter name is one of the two
        # requested here, hence the ignored typing warning.
        waiter = self.conn.get_waiter(f"db_instance_{target_state}")  # type: ignore
        waiter.wait(
            DBInstanceIdentifier=db_instance_id,
            WaiterConfig={"Delay": check_interval, "MaxAttempts": max_attempts},
        )
    else:
        self._wait_for_state(poke, target_state, check_interval, max_attempts)
    # The message names the resource this method actually waits on (a DB instance).
    self.log.info("DB instance '%s' reached the '%s' state", db_instance_id, target_state)

def get_db_cluster_state(self, db_cluster_id: str) -> str:
    """
    Get the current state of a DB cluster.

    :param db_cluster_id: The ID of the target DB cluster.
    :return: Returns the status of the DB cluster as a lower-cased string (eg. "available")
    :rtype: str
    :raises AirflowNotFoundException: If the DB cluster does not exist.
    """
    try:
        response = self.conn.describe_db_clusters(DBClusterIdentifier=db_cluster_id)
    except self.conn.exceptions.ClientError as e:
        if e.response["Error"]["Code"] == "DBClusterNotFoundFault":
            # Translate the "not found" fault, keeping the original error as the explicit cause.
            raise AirflowNotFoundException(e) from e
        # Any other client error is propagated untouched.
        raise
    # Only the first returned cluster is inspected.
    return response["DBClusters"][0]["Status"].lower()

def wait_for_db_cluster_state(
    self, db_cluster_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
) -> None:
    """
    Polls :py:meth:`RDS.Client.describe_db_clusters` until the target state is reached.
    An error is raised after a max number of attempts.

    ``target_state`` is lower-cased before use. For "available" and "deleted" the wait is delegated to
    the boto waiter named ``db_cluster_<target_state>``, and errors raised by that waiter are not
    caught here. Any other state is polled through :py:meth:`get_db_cluster_state`.

    :param db_cluster_id: The ID of the target DB cluster.
    :param target_state: Wait until this state is reached
    :param check_interval: The amount of time in seconds to wait between attempts
    :param max_attempts: The maximum number of attempts to be made

    .. seealso::
        The ``Status`` field of the DB cluster description returned by:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_db_clusters
    """

    def poke():
        return self.get_db_cluster_state(db_cluster_id)

    target_state = target_state.lower()
    if target_state in ("available", "deleted"):
        # The membership check above guarantees the dynamically built waiter name is one of the two
        # requested here, hence the ignored typing warning.
        waiter = self.conn.get_waiter(f"db_cluster_{target_state}")  # type: ignore
        waiter.wait(
            DBClusterIdentifier=db_cluster_id,
            WaiterConfig={"Delay": check_interval, "MaxAttempts": max_attempts},
        )
    else:
        self._wait_for_state(poke, target_state, check_interval, max_attempts)
    # The message names the resource this method actually waits on (a DB cluster).
    self.log.info("DB cluster '%s' reached the '%s' state", db_cluster_id, target_state)

def _wait_for_state(
    self,
    poke: Callable[..., str],
    target_state: str,
    check_interval: int,
    max_attempts: int,
) -> None:
    """
    Call ``poke`` repeatedly until it returns ``target_state``.

    ``poke`` is always called at least once. Each non-matching state is logged, and the method sleeps
    for ``check_interval`` seconds between calls.

    :param poke: A function that returns the current state of the target resource as a string.
    :param target_state: Wait until this state is reached
    :param check_interval: The amount of time in seconds to wait between attempts
    :param max_attempts: The maximum number of attempts to be made
    :raises AirflowException: If ``target_state`` is not seen within ``max_attempts`` calls.
    """
    attempt = 1
    current = poke()
    while True:
        if current == target_state:
            return
        self.log.info("Current state is %s", current)
        # Give up before sleeping again once the attempt budget is spent.
        if attempt >= max_attempts:
            raise AirflowException("Max attempts exceeded")
        time.sleep(check_interval)
        current = poke()
        attempt += 1
Loading