-
Notifications
You must be signed in to change notification settings - Fork 16.3k
Use Boto waiters instead of custom _await_status method for RDS Operators #27410
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3505784
372e345
ed479fb
207ac8b
7c96dc0
843f7c6
f22fcb2
ae0331b
acb82d9
6465b61
5ca70f9
3c2309a
7fe1c14
a527a78
c94fbd4
5adeb79
dd10249
a1d6d43
f9e9819
cf469e7
95c1f7d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,8 +18,10 @@ | |
| """Interact with AWS RDS.""" | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
| import time | ||
| from typing import TYPE_CHECKING, Callable | ||
|
|
||
| from airflow.exceptions import AirflowException, AirflowNotFoundException | ||
| from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook | ||
|
|
||
| if TYPE_CHECKING: | ||
|
|
@@ -48,3 +50,300 @@ class RdsHook(AwsGenericHook["RDSClient"]): | |
def __init__(self, *args, **kwargs) -> None:
    """
    Build the hook with ``client_type`` pinned to ``"rds"``.

    Whatever ``client_type`` the caller supplied is overwritten before all
    positional and keyword arguments are forwarded to the parent constructor.
    """
    kwargs.update(client_type="rds")
    super().__init__(*args, **kwargs)
|
|
||
def get_db_snapshot_state(self, snapshot_id: str) -> str:
    """
    Look up the status currently reported for a DB instance snapshot.

    :param snapshot_id: Identifier of the DB instance snapshot to inspect
    :return: The lower-cased status of the first snapshot in the response (eg. "available")
    :rtype: str
    :raises AirflowNotFoundException: If the client error carries the ``DBSnapshotNotFound`` code.
    """
    try:
        described = self.conn.describe_db_snapshots(DBSnapshotIdentifier=snapshot_id)
    except self.conn.exceptions.ClientError as e:
        error_code = e.response["Error"]["Code"]
        if error_code != "DBSnapshotNotFound":
            # Any other client error is passed on to the caller untouched.
            raise e
        raise AirflowNotFoundException(e)
    first_snapshot = described["DBSnapshots"][0]
    return first_snapshot["Status"].lower()
|
|
||
def wait_for_db_snapshot_state(
    self, snapshot_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
) -> None:
    """
    Block until a DB instance snapshot reports the requested state.

    The "available", "deleted" and "completed" states are delegated to the boto waiter
    named ``db_snapshot_<state>``; every other state is polled through
    :py:meth:`get_db_snapshot_state`. An error is raised after a max number of attempts.

    :param snapshot_id: The ID of the target DB instance snapshot
    :param target_state: Wait until this state is reached (compared case-insensitively)
    :param check_interval: The amount of time in seconds to wait between attempts
    :param max_attempts: The maximum number of attempts to be made
    """
    wanted = target_state.lower()
    boto_waiter_states = ("available", "deleted", "completed")
    if wanted not in boto_waiter_states:
        # When boto waiters aren't available, a closure is handed to the generic
        # waiter method so that the waiting logic can be reused.
        self._wait_for_state(
            lambda: self.get_db_snapshot_state(snapshot_id), wanted, check_interval, max_attempts
        )
    else:
        waiter_config = {"Delay": check_interval, "MaxAttempts": max_attempts}
        boto_waiter = self.conn.get_waiter(f"db_snapshot_{wanted}")  # type: ignore
        boto_waiter.wait(DBSnapshotIdentifier=snapshot_id, WaiterConfig=waiter_config)
    self.log.info("DB snapshot '%s' reached the '%s' state", snapshot_id, wanted)
|
|
||
def get_db_cluster_snapshot_state(self, snapshot_id: str) -> str:
    """
    Look up the status currently reported for a DB cluster snapshot.

    :param snapshot_id: Identifier of the DB cluster snapshot to inspect
    :return: The lower-cased status of the first cluster snapshot in the response (eg. "available")
    :rtype: str
    :raises AirflowNotFoundException: If the client error carries the
        ``DBClusterSnapshotNotFoundFault`` code.
    """
    try:
        described = self.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=snapshot_id)
    except self.conn.exceptions.ClientError as e:
        error_code = e.response["Error"]["Code"]
        if error_code != "DBClusterSnapshotNotFoundFault":
            # Any other client error is passed on to the caller untouched.
            raise e
        raise AirflowNotFoundException(e)
    first_snapshot = described["DBClusterSnapshots"][0]
    return first_snapshot["Status"].lower()
|
|
||
def wait_for_db_cluster_snapshot_state(
    self, snapshot_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
) -> None:
    """
    Block until a DB cluster snapshot reports the requested state.

    The "available" and "deleted" states are delegated to the boto waiter named
    ``db_cluster_snapshot_<state>``; every other state is polled through
    :py:meth:`get_db_cluster_snapshot_state`. An error is raised after a max number of attempts.

    :param snapshot_id: The ID of the target DB cluster snapshot
    :param target_state: Wait until this state is reached (compared case-insensitively)
    :param check_interval: The amount of time in seconds to wait between attempts
    :param max_attempts: The maximum number of attempts to be made

    .. seealso::
        A list of possible values for target_state:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_db_cluster_snapshots
    """
    wanted = target_state.lower()
    boto_waiter_states = ("available", "deleted")
    if wanted not in boto_waiter_states:
        # No boto waiter is used for this state: poll with the shared helper instead.
        self._wait_for_state(
            lambda: self.get_db_cluster_snapshot_state(snapshot_id),
            wanted,
            check_interval,
            max_attempts,
        )
    else:
        waiter_config = {"Delay": check_interval, "MaxAttempts": max_attempts}
        boto_waiter = self.conn.get_waiter(f"db_cluster_snapshot_{wanted}")  # type: ignore
        boto_waiter.wait(DBClusterSnapshotIdentifier=snapshot_id, WaiterConfig=waiter_config)
    self.log.info("DB cluster snapshot '%s' reached the '%s' state", snapshot_id, wanted)
|
|
||
def get_export_task_state(self, export_task_id: str) -> str:
    """
    Gets the current state of an RDS snapshot export to Amazon S3.

    :param export_task_id: The identifier of the target snapshot export task.
    :return: Returns the lower-cased status of the snapshot export task as a string (eg. "canceled")
    :rtype: str
    :raises AirflowNotFoundException: If the export task does not exist.
    """
    try:
        response = self.conn.describe_export_tasks(ExportTaskIdentifier=export_task_id)
    except self.conn.exceptions.ClientError as e:
        # NOTE(review): the service error code is believed to be "ExportTaskNotFound" (the
        # "...Fault" suffix belongs to the botocore exception class name), so both spellings
        # are accepted - confirm against the RDS API reference.
        if e.response["Error"]["Code"] in ("ExportTaskNotFound", "ExportTaskNotFoundFault"):
            raise AirflowNotFoundException(e) from e
        # Unrelated client errors propagate with their original traceback.
        raise
    return response["ExportTasks"][0]["Status"].lower()
|
|
||
def wait_for_export_task_state(
    self, export_task_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
) -> None:
    """
    Block until a snapshot export task reports the requested state.

    The state is always polled through :py:meth:`get_export_task_state`; no boto waiter is
    used by this method. An error is raised after a max number of attempts.

    :param export_task_id: The identifier of the target snapshot export task.
    :param target_state: Wait until this state is reached (compared case-insensitively)
    :param check_interval: The amount of time in seconds to wait between attempts
    :param max_attempts: The maximum number of attempts to be made

    .. seealso::
        A list of possible values for target_state:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_export_tasks
    """
    # NOTE(review): the 30 second default interval reportedly mirrors the default boto
    # waiter delay - confirm before changing it.
    wanted = target_state.lower()
    self._wait_for_state(
        lambda: self.get_export_task_state(export_task_id), wanted, check_interval, max_attempts
    )
    self.log.info("export task '%s' reached the '%s' state", export_task_id, wanted)
|
|
||
def get_event_subscription_state(self, subscription_name: str) -> str:
    """
    Gets the current state of an RDS event notification subscription.

    :param subscription_name: The name of the target RDS event notification subscription.
    :return: Returns the lower-cased status of the event subscription as a string (eg. "active")
    :rtype: str
    :raises AirflowNotFoundException: If the event subscription does not exist.
    """
    try:
        response = self.conn.describe_event_subscriptions(SubscriptionName=subscription_name)
    except self.conn.exceptions.ClientError as e:
        # NOTE(review): the service error code is believed to be "SubscriptionNotFound" (the
        # "...Fault" suffix belongs to the botocore exception class name), so both spellings
        # are accepted - confirm against the RDS API reference.
        if e.response["Error"]["Code"] in ("SubscriptionNotFound", "SubscriptionNotFoundFault"):
            raise AirflowNotFoundException(e) from e
        # Unrelated client errors propagate with their original traceback.
        raise
    return response["EventSubscriptionsList"][0]["Status"].lower()
|
|
||
def wait_for_event_subscription_state(
    self, subscription_name: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
) -> None:
    """
    Block until an event notification subscription reports the requested state.

    The state is always polled through :py:meth:`get_event_subscription_state`; no boto
    waiter is used by this method. An error is raised after a max number of attempts.

    :param subscription_name: The name of the target RDS event notification subscription.
    :param target_state: Wait until this state is reached (compared case-insensitively)
    :param check_interval: The amount of time in seconds to wait between attempts
    :param max_attempts: The maximum number of attempts to be made

    .. seealso::
        A list of possible values for target_state:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_event_subscriptions
    """
    wanted = target_state.lower()
    self._wait_for_state(
        lambda: self.get_event_subscription_state(subscription_name),
        wanted,
        check_interval,
        max_attempts,
    )
    self.log.info("event subscription '%s' reached the '%s' state", subscription_name, wanted)
|
|
||
def get_db_instance_state(self, db_instance_id: str) -> str:
    """
    Get the current state of a DB instance.

    :param db_instance_id: The ID of the target DB instance.
    :return: Returns the lower-cased status of the DB instance as a string (eg. "available")
    :rtype: str
    :raises AirflowNotFoundException: If the DB instance does not exist.
    """
    try:
        response = self.conn.describe_db_instances(DBInstanceIdentifier=db_instance_id)
    except self.conn.exceptions.ClientError as e:
        # NOTE(review): the service error code is believed to be "DBInstanceNotFound" (the
        # "...Fault" suffix belongs to the botocore exception class name), so both spellings
        # are accepted - confirm against the RDS API reference.
        if e.response["Error"]["Code"] in ("DBInstanceNotFound", "DBInstanceNotFoundFault"):
            raise AirflowNotFoundException(e) from e
        # Unrelated client errors propagate with their original traceback.
        raise
    return response["DBInstances"][0]["DBInstanceStatus"].lower()
|
|
||
def wait_for_db_instance_state(
    self, db_instance_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
) -> None:
    """
    Polls :py:meth:`RDS.Client.describe_db_instances` until the target state is reached.
    An error is raised after a max number of attempts.

    The "available" and "deleted" states are delegated to the boto waiter named
    ``db_instance_<state>``; every other state is polled through
    :py:meth:`get_db_instance_state`.

    :param db_instance_id: The ID of the target DB instance.
    :param target_state: Wait until this state is reached (compared case-insensitively)
    :param check_interval: The amount of time in seconds to wait between attempts
    :param max_attempts: The maximum number of attempts to be made

    .. seealso::
        For information about DB instance statuses, see Viewing DB instance status in the Amazon RDS
        User Guide.
        https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/accessing-monitoring.html#Overview.DBInstance.Status
    """

    def poke():
        return self.get_db_instance_state(db_instance_id)

    target_state = target_state.lower()
    if target_state in ("available", "deleted"):
        waiter = self.conn.get_waiter(f"db_instance_{target_state}")  # type: ignore
        waiter.wait(
            DBInstanceIdentifier=db_instance_id,
            WaiterConfig={"Delay": check_interval, "MaxAttempts": max_attempts},
        )
    else:
        # No boto waiter is used for this state: poll with the shared helper instead.
        self._wait_for_state(poke, target_state, check_interval, max_attempts)
    # The message names the resource actually waited on (a DB instance).
    self.log.info("DB instance '%s' reached the '%s' state", db_instance_id, target_state)
|
|
||
def get_db_cluster_state(self, db_cluster_id: str) -> str:
    """
    Look up the status currently reported for a DB cluster.

    :param db_cluster_id: Identifier of the DB cluster to inspect
    :return: The lower-cased status of the first cluster in the response (eg. "available")
    :rtype: str
    :raises AirflowNotFoundException: If the client error carries the ``DBClusterNotFoundFault`` code.
    """
    try:
        described = self.conn.describe_db_clusters(DBClusterIdentifier=db_cluster_id)
    except self.conn.exceptions.ClientError as e:
        error_code = e.response["Error"]["Code"]
        if error_code != "DBClusterNotFoundFault":
            # Any other client error is passed on to the caller untouched.
            raise e
        raise AirflowNotFoundException(e)
    first_cluster = described["DBClusters"][0]
    return first_cluster["Status"].lower()
|
|
||
def wait_for_db_cluster_state(
    self, db_cluster_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
) -> None:
    """
    Polls :py:meth:`RDS.Client.describe_db_clusters` until the target state is reached.
    An error is raised after a max number of attempts.

    The "available" and "deleted" states are delegated to the boto waiter named
    ``db_cluster_<state>``; every other state is polled through
    :py:meth:`get_db_cluster_state`.

    :param db_cluster_id: The ID of the target DB cluster.
    :param target_state: Wait until this state is reached (compared case-insensitively)
    :param check_interval: The amount of time in seconds to wait between attempts
    :param max_attempts: The maximum number of attempts to be made

    .. seealso::
        The status field returned for a DB cluster is described at:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_db_clusters
    """

    def poke():
        return self.get_db_cluster_state(db_cluster_id)

    target_state = target_state.lower()
    if target_state in ("available", "deleted"):
        waiter = self.conn.get_waiter(f"db_cluster_{target_state}")  # type: ignore
        waiter.wait(
            DBClusterIdentifier=db_cluster_id,
            WaiterConfig={"Delay": check_interval, "MaxAttempts": max_attempts},
        )
    else:
        # No boto waiter is used for this state: poll with the shared helper instead.
        self._wait_for_state(poke, target_state, check_interval, max_attempts)
    # The message names the resource actually waited on (a DB cluster, not a snapshot).
    self.log.info("DB cluster '%s' reached the '%s' state", db_cluster_id, target_state)
|
|
||
| def _wait_for_state( | ||
| self, | ||
| poke: Callable[..., str], | ||
| target_state: str, | ||
| check_interval: int, | ||
| max_attempts: int, | ||
| ) -> None: | ||
| """ | ||
| Polls the poke function for the current state until it reaches the target_state. | ||
|
|
||
| :param poke: A function that returns the current state of the target resource as a string. | ||
| :param target_state: Wait until this state is reached | ||
| :param check_interval: The amount of time in seconds to wait between attempts | ||
| :param max_attempts: The maximum number of attempts to be made | ||
| """ | ||
| state = poke() | ||
| tries = 1 | ||
| while state != target_state: | ||
| self.log.info("Current state is %s", state) | ||
| if tries >= max_attempts: | ||
| raise AirflowException("Max attempts exceeded") | ||
| time.sleep(check_interval) | ||
| state = poke() | ||
| tries += 1 | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For each of the following resources, I added:
Resources:
In the current implementation,
`RdsBaseOperator` and `RdsBaseSensor` each have a method that calls the appropriate boto "describe" function based on a string (like "db_snapshot" or "event_subscription"). See airflow/airflow/providers/amazon/aws/operators/rds.py
Lines 49 to 66 in 64174ce
I considered just moving this logic over to
`RdsHook`, but decided to split it up because the existence (and variety) of boto waiters depends on the resource. For example, the db snapshot resource has 3 boto waiters, the db cluster snapshot resource has 2, and the export task resource has none. One could maybe reduce the amount of code by keeping a separate list of resources and their waiters and dynamically choosing at runtime, but I didn't go that far.