Skip to content

Commit 47dd4c9

Browse files
BasPHashb
authored andcommitted
[AIRFLOW-4835] Refactor operator render_template (#5461)
- Refactors `BaseOperator.render_template()` and removes `render_template_from_field()`. The functionality could be greatly simplified into a single `render_template()` function. - Removes six usage. - Improves performance by removing two `hasattr` calls and avoiding recreating Jinja environments. - Removes the argument `attr` to `render_template()` which wasn't used. - Squashes multiple similar tests into two parameterized tests. - Adheres to 110 line length. - Adds support for templating sets. - Adds Pydoc. - Adds typing.
1 parent 8cf0635 commit 47dd4c9

File tree

9 files changed

+230
-313
lines changed

9 files changed

+230
-313
lines changed

airflow/models/baseoperator.py

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import sys
2626
import warnings
2727
from datetime import timedelta, datetime
28-
from typing import Callable, Dict, Iterable, List, Optional, Set
28+
from typing import Callable, Dict, Iterable, List, Optional, Set, Any
2929

3030
import jinja2
3131

@@ -636,51 +636,76 @@ def __setstate__(self, state):
636636
self.__dict__ = state
637637
self._log = logging.getLogger("airflow.task.operators")
638638

639-
def render_template_from_field(self, attr, content, context, jinja_env):
639+
def render_template_fields(self, context: Dict, jinja_env: Optional[jinja2.Environment] = None) -> None:
640640
"""
641-
Renders a template from a field. If the field is a string, it will
642-
simply render the string and return the result. If it is a collection or
643-
nested set of collections, it will traverse the structure and render
644-
all elements in it. If the field has another type, it will return it as it is.
641+
Template all attributes listed in template_fields. Note this operation is irreversible.
642+
643+
:param context: Dict with values to apply on content
644+
:type context: dict
645+
:param jinja_env: Jinja environment
646+
:type jinja_env: jinja2.Environment
647+
"""
648+
649+
if not jinja_env:
650+
jinja_env = self.get_template_env()
651+
652+
for attr_name in self.template_fields:
653+
content = getattr(self, attr_name)
654+
if content:
655+
rendered_content = self.render_template(content, context, jinja_env)
656+
setattr(self, attr_name, rendered_content)
657+
658+
def render_template(
659+
self, content: Any, context: Dict, jinja_env: Optional[jinja2.Environment] = None
660+
) -> Any:
661+
"""
662+
Render a templated string. The content can be a collection holding multiple templated strings and will
663+
be templated recursively.
664+
665+
:param content: Content to template. Only strings can be templated (may be inside collection).
666+
:type content: Any
667+
:param context: Dict with values to apply on templated content
668+
:type context: dict
669+
:param jinja_env: Jinja environment. Can be provided to avoid re-creating Jinja environments during
670+
recursion.
671+
:type jinja_env: jinja2.Environment
672+
:return: Templated content
645673
"""
646-
rt = self.render_template
674+
675+
if not jinja_env:
676+
jinja_env = self.get_template_env()
677+
647678
if isinstance(content, str):
648-
result = jinja_env.from_string(content).render(**context)
649-
elif isinstance(content, tuple):
679+
if any(content.endswith(ext) for ext in self.template_ext):
680+
# Content contains a filepath
681+
return jinja_env.get_template(content).render(**context)
682+
else:
683+
return jinja_env.from_string(content).render(**context)
684+
685+
if isinstance(content, tuple):
650686
if type(content) is not tuple:
651687
# Special case for named tuples
652-
result = content.__class__(*(rt(attr, e, context) for e in content))
688+
return content.__class__(
689+
*(self.render_template(element, context, jinja_env) for element in content)
690+
)
653691
else:
654-
result = tuple(rt(attr, e, context) for e in content)
692+
return tuple(self.render_template(element, context, jinja_env) for element in content)
693+
655694
elif isinstance(content, list):
656-
result = [rt(attr, e, context) for e in content]
695+
return [self.render_template(element, context, jinja_env) for element in content]
696+
657697
elif isinstance(content, dict):
658-
result = {
659-
k: rt("{}[{}]".format(attr, k), v, context)
660-
for k, v in list(content.items())}
661-
else:
662-
result = content
663-
return result
698+
return {key: self.render_template(value, context, jinja_env) for key, value in content.items()}
664699

665-
def render_template(self, attr, content, context):
666-
"""
667-
Renders a template either from a file or directly in a field, and returns
668-
the rendered result.
669-
"""
670-
jinja_env = self.get_template_env()
700+
elif isinstance(content, set):
701+
return {self.render_template(element, context, jinja_env) for element in content}
671702

672-
exts = self.__class__.template_ext
673-
if (
674-
isinstance(content, str) and
675-
any([content.endswith(ext) for ext in exts])):
676-
return jinja_env.get_template(content).render(**context)
677703
else:
678-
return self.render_template_from_field(attr, content, context, jinja_env)
704+
return content
679705

680-
def get_template_env(self):
681-
return self.dag.get_template_env() \
682-
if hasattr(self, 'dag') \
683-
else jinja2.Environment(cache_size=0)
706+
def get_template_env(self) -> jinja2.Environment:
707+
"""Fetch a Jinja template environment from the DAG or instantiate empty environment if no DAG."""
708+
return self.dag.get_template_env() if self.has_dag() else jinja2.Environment(cache_size=0)
684709

685710
def prepare_template(self):
686711
"""

airflow/models/dag.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -544,9 +544,7 @@ def filepath(self):
544544

545545
@property
546546
def folder(self):
547-
"""
548-
Folder location of where the dag object is instantiated
549-
"""
547+
"""Folder location of where the DAG object is instantiated."""
550548
return os.path.dirname(self.full_filepath)
551549

552550
@property
@@ -701,11 +699,10 @@ def resolve_template_files(self):
701699
for t in self.tasks:
702700
t.resolve_template_files()
703701

704-
def get_template_env(self):
705-
"""
706-
Returns a jinja2 Environment while taking into account the DAGs
707-
template_searchpath, user_defined_macros and user_defined_filters
708-
"""
702+
def get_template_env(self) -> jinja2.Environment:
703+
"""Build a Jinja2 environment."""
704+
705+
# Collect directories to search for template files
709706
searchpath = [self.folder]
710707
if self.template_searchpath:
711708
searchpath += self.template_searchpath
@@ -714,7 +711,11 @@ def get_template_env(self):
714711
loader=jinja2.FileSystemLoader(searchpath),
715712
undefined=self.template_undefined,
716713
extensions=["jinja2.ext.do"],
717-
cache_size=0)
714+
cache_size=0,
715+
)
716+
717+
# Add any user defined items. Safe to edit globals as long as no templates are rendered yet.
718+
# http://jinja.pocoo.org/docs/2.10/api/#jinja2.Environment.globals
718719
if self.user_defined_macros:
719720
env.globals.update(self.user_defined_macros)
720721
if self.user_defined_filters:

airflow/models/taskinstance.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,21 +1239,12 @@ def overwrite_params_with_dag_run_conf(self, params, dag_run):
12391239
if dag_run and dag_run.conf:
12401240
params.update(dag_run.conf)
12411241

1242-
def render_templates(self, context=None):
1243-
task = self.task
1242+
def render_templates(self, context=None) -> None:
1243+
"""Render templates in the operator fields."""
12441244
if not context:
12451245
context = self.get_template_context()
12461246

1247-
if hasattr(self, 'task') and hasattr(self.task, 'dag'):
1248-
if self.task.dag.user_defined_macros:
1249-
context.update(self.task.dag.user_defined_macros)
1250-
1251-
rt = self.task.render_template # shortcut to method
1252-
for attr in task.__class__.template_fields:
1253-
content = getattr(task, attr)
1254-
if content:
1255-
rendered_content = rt(attr, content, context)
1256-
setattr(task, attr, rendered_content)
1247+
self.task.render_template_fields(context)
12571248

12581249
def email_alert(self, exception):
12591250
exception_html = str(exception).replace('\n', '<br>')

tests/contrib/operators/test_databricks_operator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def test_init_with_templating(self):
165165
}
166166
dag = DAG('test', start_date=datetime.now())
167167
op = DatabricksSubmitRunOperator(dag=dag, task_id=TASK_ID, json=json)
168-
op.json = op.render_template('json', op.json, {'ds': DATE})
168+
op.render_template_fields(context={'ds': DATE})
169169
expected = databricks_operator._deep_string_coerce({
170170
'new_cluster': NEW_CLUSTER,
171171
'notebook_task': RENDERED_TEMPLATED_NOTEBOOK_TASK,
@@ -332,7 +332,7 @@ def test_init_with_templating(self):
332332

333333
dag = DAG('test', start_date=datetime.now())
334334
op = DatabricksRunNowOperator(dag=dag, task_id=TASK_ID, job_id=JOB_ID, json=json)
335-
op.json = op.render_template('json', op.json, {'ds': DATE})
335+
op.render_template_fields(context={'ds': DATE})
336336
expected = databricks_operator._deep_string_coerce({
337337
'notebook_params': NOTEBOOK_PARAMS,
338338
'jar_params': RENDERED_TEMPLATED_JAR_PARAMS,

tests/contrib/operators/test_qubole_check_operator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ def __construct_operator(self, query, pass_value, tolerance=None,
5252
def test_pass_value_template(self):
5353
pass_value_str = "2018-03-22"
5454
operator = self.__construct_operator('select date from tab1;', "{{ ds }}")
55-
result = operator.render_template('pass_value', operator.pass_value,
56-
{'ds': pass_value_str})
55+
result = operator.render_template(operator.pass_value, {'ds': pass_value_str})
5756

5857
self.assertEqual(operator.task_id, self.task_id)
5958
self.assertEqual(result, pass_value_str)

tests/contrib/operators/test_qubole_operator.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,12 @@ def test_init_with_default_connection(self):
5858
self.assertEqual(op.qubole_conn_id, DEFAULT_CONN)
5959

6060
def test_init_with_template_connection(self):
61-
dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
62-
63-
with dag:
64-
task = QuboleOperator(task_id=TASK_ID, dag=dag,
65-
qubole_conn_id="{{ dag_run.conf['qubole_conn_id'] }}")
61+
with DAG(DAG_ID, start_date=DEFAULT_DATE):
62+
task = QuboleOperator(task_id=TASK_ID, qubole_conn_id="{{ qubole_conn_id }}")
6663

67-
result = task.render_template('qubole_conn_id', "{{ qubole_conn_id }}",
68-
{'qubole_conn_id': TEMPLATE_CONN})
64+
task.render_template_fields({'qubole_conn_id': TEMPLATE_CONN})
6965
self.assertEqual(task.task_id, TASK_ID)
70-
self.assertEqual(result, TEMPLATE_CONN)
66+
self.assertEqual(task.qubole_conn_id, TEMPLATE_CONN)
7167

7268
def test_init_with_template_cluster_label(self):
7369
dag = DAG(DAG_ID, start_date=DEFAULT_DATE)

tests/models/test_baseoperator.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
20+
import datetime
21+
import unittest
22+
from unittest import mock
23+
import uuid
24+
from collections import namedtuple
25+
26+
import jinja2
27+
from parameterized import parameterized
28+
29+
from airflow.models import DAG, BaseOperator
30+
from airflow.operators.dummy_operator import DummyOperator
31+
from airflow.utils.decorators import apply_defaults
32+
from tests.models import DEFAULT_DATE
33+
34+
35+
class TestOperator(BaseOperator):
36+
"""Operator for testing purposes."""
37+
38+
template_fields = ("arg1", "arg2")
39+
40+
@apply_defaults
41+
def __init__(self, arg1: str = "", arg2: str = "", **kwargs):
42+
super().__init__(**kwargs)
43+
self.arg1 = arg1
44+
self.arg2 = arg2
45+
46+
def execute(self, context):
47+
pass
48+
49+
50+
# Namedtuple for testing purposes
51+
TestNamedTuple = namedtuple("TestNamedTuple", ["var1", "var2"])
52+
53+
54+
class BaseOperatorTest(unittest.TestCase):
55+
@parameterized.expand(
56+
[
57+
("{{ foo }}", {"foo": "bar"}, "bar"),
58+
("{{ foo }}", {}, ""),
59+
(["{{ foo }}_1", "{{ foo }}_2"], {"foo": "bar"}, ["bar_1", "bar_2"]),
60+
(("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, ("bar_1", "bar_2")),
61+
(
62+
{"key1": "{{ foo }}_1", "key2": "{{ foo }}_2"},
63+
{"foo": "bar"},
64+
{"key1": "bar_1", "key2": "bar_2"},
65+
),
66+
(
67+
{"key_{{ foo }}_1": 1, "key_2": "{{ foo }}_2"},
68+
{"foo": "bar"},
69+
{"key_{{ foo }}_1": 1, "key_2": "bar_2"},
70+
),
71+
(datetime.date(2018, 12, 6), {"foo": "bar"}, datetime.date(2018, 12, 6)),
72+
(datetime.datetime(2018, 12, 6, 10, 55), {"foo": "bar"}, datetime.datetime(2018, 12, 6, 10, 55)),
73+
(TestNamedTuple("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, TestNamedTuple("bar_1", "bar_2")),
74+
({"{{ foo }}_1", "{{ foo }}_2"}, {"foo": "bar"}, {"bar_1", "bar_2"}),
75+
]
76+
)
77+
def test_render_template(self, content, context, expected_output):
78+
"""Test render_template given various input types."""
79+
with DAG("test-dag", start_date=DEFAULT_DATE):
80+
task = DummyOperator(task_id="op1")
81+
82+
result = task.render_template(content, context)
83+
self.assertEqual(result, expected_output)
84+
85+
def test_render_template_fields(self):
86+
"""Verify if operator attributes are correctly templated."""
87+
with DAG("test-dag", start_date=DEFAULT_DATE):
88+
task = TestOperator(task_id="op1", arg1="{{ foo }}", arg2="{{ bar }}")
89+
90+
# Assert nothing is templated yet
91+
self.assertEqual(task.arg1, "{{ foo }}")
92+
self.assertEqual(task.arg2, "{{ bar }}")
93+
94+
# Trigger templating and verify if attributes are templated correctly
95+
task.render_template_fields(context={"foo": "footemplated", "bar": "bartemplated"})
96+
self.assertEqual(task.arg1, "footemplated")
97+
self.assertEqual(task.arg2, "bartemplated")
98+
99+
@parameterized.expand(
100+
[
101+
({"user_defined_macros": {"foo": "bar"}}, "{{ foo }}", {}, "bar"),
102+
({"user_defined_macros": {"foo": "bar"}}, 1, {}, 1),
103+
(
104+
{"user_defined_filters": {"hello": lambda name: "Hello %s" % name}},
105+
"{{ 'world' | hello }}",
106+
{},
107+
"Hello world",
108+
),
109+
]
110+
)
111+
def test_render_template_fields_with_dag_settings(self, dag_kwargs, content, context, expected_output):
112+
"""Test render_template with additional DAG settings."""
113+
with DAG("test-dag", start_date=DEFAULT_DATE, **dag_kwargs):
114+
task = DummyOperator(task_id="op1")
115+
116+
result = task.render_template(content, context)
117+
self.assertEqual(result, expected_output)
118+
119+
@parameterized.expand([(object(),), (uuid.uuid4(),)])
120+
def test_render_template_fields_no_change(self, content):
121+
"""Tests if non-templatable types remain unchanged."""
122+
with DAG("test-dag", start_date=DEFAULT_DATE):
123+
task = DummyOperator(task_id="op1")
124+
125+
result = task.render_template(content, {"foo": "bar"})
126+
self.assertEqual(content, result)
127+
128+
def test_render_template_field_undefined_strict(self):
129+
"""Test render_template with template_undefined configured."""
130+
with DAG("test-dag", start_date=DEFAULT_DATE, template_undefined=jinja2.StrictUndefined):
131+
task = DummyOperator(task_id="op1")
132+
133+
with self.assertRaises(jinja2.UndefinedError):
134+
task.render_template("{{ foo }}", {})
135+
136+
@mock.patch("jinja2.Environment", autospec=True)
137+
def test_jinja_env_creation(self, mock_jinja_env):
138+
"""Verify if a Jinja environment is created only once when templating."""
139+
with DAG("test-dag", start_date=DEFAULT_DATE):
140+
task = TestOperator(task_id="op1", arg1="{{ foo }}", arg2="{{ bar }}")
141+
142+
task.render_template_fields(context={"foo": "whatever", "bar": "whatever"})
143+
self.assertEqual(mock_jinja_env.call_count, 1)

0 commit comments

Comments
 (0)