Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion airflow/providers/amazon/aws/hooks/redshift_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
# under the License.

import sys
from typing import Dict, List, Optional, Union
from contextlib import closing
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import redshift_connector
import sqlparse
from redshift_connector import Connection as RedshiftConnection
from sqlalchemy import create_engine
from sqlalchemy.engine.url import URL
Expand Down Expand Up @@ -133,3 +135,55 @@ def get_conn(self) -> RedshiftConnection:
conn: RedshiftConnection = redshift_connector.connect(**conn_kwargs)

return conn

def set_autocommit(self, conn, autocommit: Any) -> None:
conn.autocommit = autocommit

def get_autocommit(self, conn):
return getattr(conn, 'autocommit_mode', False)

def run(
self,
sql,
autocommit: bool = False,
parameters: Optional[Union[Sequence[Any], Dict[Any, Any]]] = None,
handler: Optional[Callable] = None,
) -> None:
Copy link
Member

@uranusjr uranusjr Apr 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This return type annotation is wrong

"""
Runs a command or a list of commands. Pass a list of sql
statements to the sql parameter to get them to execute
sequentially.

:param sql: the sql string to be executed with possibly multiple statements,
or a list of sql statements to execute
:param autocommit: What to set the connection's autocommit setting to
before executing the query.
:param parameters: The parameters to render the SQL query with.
:param handler: The result handler which is called with the result of each statement.
:return: query results if handler was provided.
"""

with closing(self.get_conn()) as conn:
self.set_autocommit(conn, autocommit)
sql = sqlparse.split(sql)
self.log.debug("Executing %d statements against Redshift DB", len(sql))
with closing(conn.cursor()) as cur:
if self.supports_autocommit:
self.set_autocommit(conn, autocommit)
results = []
for stmt in sql:
if parameters:
cur.execute(stmt, parameters)
else:
cur.execute(stmt)

if handler is not None:
result = handler(cur)
results.append(result)

self.log.info("Rows affected: %s", cur.rowcount)

if handler is None:
return None

return results
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/redshift_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Sequence, Union

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(
*,
sql: Union[str, Iterable[str]],
redshift_conn_id: str = 'redshift_default',
parameters: Optional[dict] = None,
parameters: Optional[Union[Sequence[Any], Dict[Any, Any]]] = None,
autocommit: bool = True,
**kwargs,
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/transfers/redshift_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
"""Transfers data from AWS Redshift into a S3 Bucket."""
from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(
unload_options: Optional[List] = None,
autocommit: bool = False,
include_header: bool = False,
parameters: Optional[Union[Mapping, Iterable]] = None,
parameters: Optional[Union[Sequence[Any], Dict[Any, Any]]] = None,
table_as_file_name: bool = True, # Set to True by default for not breaking current workflows
**kwargs,
) -> None:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
pandas_requirement,
'mypy-boto3-rds>=1.21.0',
'mypy-boto3-redshift-data>=1.21.0',
'sqlparse>=0.4.1',
]
apache_beam = [
'apache-beam>=2.33.0',
Expand Down