Skip to content

Commit 49c193f

Browse files
authored
[AIP-34] TaskGroup: A UI task grouping concept as an alternative to SubDagOperator (apache#10153)
This commit introduces TaskGroup, which is a simple UI task grouping concept. - TaskGroups can be collapsed/expanded in Graph View when clicked - TaskGroups can be nested - TaskGroups can be put upstream/downstream of tasks or other TaskGroups with >> and << operators - Search box, hovering, focusing in Graph View treats TaskGroup properly. E.g. searching for tasks also highlights TaskGroup that contains matching task_id. When TaskGroup is expanded/collapsed, the affected TaskGroup is put in focus and moved to the centre of the graph. What this commit does not do: - This commit does not change or remove SubDagOperator. Although TaskGroup is intended as an alternative for SubDagOperator, deprecating SubDagOperator will need to be discussed/implemented in the future. - This PR only implemented TaskGroup handling in the Graph View. In places such as Tree View, it will look like as-if - TaskGroup does not exist and all tasks are in the same flat DAG. GitHub Issue: apache#8078 AIP: https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-34+TaskGroup%3A+A+UI+task+grouping+concept+as+an+alternative+to+SubDagOperator
1 parent f16f474 commit 49c193f

File tree

16 files changed

+1857
-145
lines changed

16 files changed

+1857
-145
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
19+
"""Example DAG demonstrating the usage of the TaskGroup."""
20+
21+
from airflow.models.dag import DAG
22+
from airflow.operators.dummy_operator import DummyOperator
23+
from airflow.utils.dates import days_ago
24+
from airflow.utils.task_group import TaskGroup
25+
26+
# [START howto_task_group]
27+
with DAG(dag_id="example_task_group", start_date=days_ago(2)) as dag:
28+
start = DummyOperator(task_id="start")
29+
30+
# [START howto_task_group_section_1]
31+
with TaskGroup("section_1", tooltip="Tasks for section_1") as section_1:
32+
task_1 = DummyOperator(task_id="task_1")
33+
task_2 = DummyOperator(task_id="task_2")
34+
task_3 = DummyOperator(task_id="task_3")
35+
36+
task_1 >> [task_2, task_3]
37+
# [END howto_task_group_section_1]
38+
39+
# [START howto_task_group_section_2]
40+
with TaskGroup("section_2", tooltip="Tasks for section_2") as section_2:
41+
task_1 = DummyOperator(task_id="task_1")
42+
43+
# [START howto_task_group_inner_section_2]
44+
with TaskGroup("inner_section_2", tooltip="Tasks for inner_section2") as inner_section_2:
45+
task_2 = DummyOperator(task_id="task_2")
46+
task_3 = DummyOperator(task_id="task_3")
47+
task_4 = DummyOperator(task_id="task_4")
48+
49+
[task_2, task_3] >> task_4
50+
# [END howto_task_group_inner_section_2]
51+
52+
# [END howto_task_group_section_2]
53+
54+
end = DummyOperator(task_id='end')
55+
56+
start >> section_1 >> section_2 >> end
57+
# [END howto_task_group]

airflow/models/baseoperator.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
from abc import ABCMeta, abstractmethod
2828
from datetime import datetime, timedelta
2929
from typing import (
30-
Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, Sequence, Set, Tuple, Type, Union,
30+
TYPE_CHECKING, Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, Sequence, Set, Tuple,
31+
Type, Union,
3132
)
3233

3334
import attr
@@ -58,6 +59,9 @@
5859
from airflow.utils.trigger_rule import TriggerRule
5960
from airflow.utils.weight_rule import WeightRule
6061

62+
if TYPE_CHECKING:
63+
from airflow.utils.task_group import TaskGroup # pylint: disable=cyclic-import
64+
6165
ScheduleInterval = Union[str, timedelta, relativedelta]
6266

6367
TaskStateChangeCallback = Callable[[Context], None]
@@ -360,9 +364,12 @@ def __init__(
360364
do_xcom_push: bool = True,
361365
inlets: Optional[Any] = None,
362366
outlets: Optional[Any] = None,
367+
task_group: Optional["TaskGroup"] = None,
363368
**kwargs
364369
):
365370
from airflow.models.dag import DagContext
371+
from airflow.utils.task_group import TaskGroupContext
372+
366373
super().__init__()
367374
if kwargs:
368375
if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
@@ -382,6 +389,11 @@ def __init__(
382389
)
383390
validate_key(task_id)
384391
self.task_id = task_id
392+
self.label = task_id
393+
task_group = task_group or TaskGroupContext.get_current_task_group(dag)
394+
if task_group:
395+
self.task_id = task_group.child_id(task_id)
396+
task_group.add(self)
385397
self.owner = owner
386398
self.email = email
387399
self.email_on_retry = email_on_retry
@@ -609,7 +621,7 @@ def dag(self, dag: Any):
609621
elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] is not self:
610622
dag.add_task(self)
611623

612-
self._dag = dag # pylint: disable=attribute-defined-outside-init
624+
self._dag = dag
613625

614626
def has_dag(self):
615627
"""
@@ -1120,21 +1132,25 @@ def roots(self) -> List["BaseOperator"]:
11201132
"""Required by TaskMixin"""
11211133
return [self]
11221134

1135+
@property
1136+
def leaves(self) -> List["BaseOperator"]:
1137+
"""Required by TaskMixin"""
1138+
return [self]
1139+
11231140
def _set_relatives(
11241141
self,
11251142
task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
11261143
upstream: bool = False,
11271144
) -> None:
11281145
"""Sets relatives for the task or task list."""
1129-
1130-
if isinstance(task_or_task_list, Sequence):
1131-
task_like_object_list = task_or_task_list
1132-
else:
1133-
task_like_object_list = [task_or_task_list]
1146+
if not isinstance(task_or_task_list, Sequence):
1147+
task_or_task_list = [task_or_task_list]
11341148

11351149
task_list: List["BaseOperator"] = []
1136-
for task_object in task_like_object_list:
1137-
task_list.extend(task_object.roots)
1150+
for task_object in task_or_task_list:
1151+
task_object.update_relative(self, not upstream)
1152+
relatives = task_object.leaves if upstream else task_object.roots
1153+
task_list.extend(relatives)
11381154

11391155
for task in task_list:
11401156
if not isinstance(task, BaseOperator):

airflow/models/dag.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
import warnings
2828
from collections import OrderedDict
2929
from datetime import datetime, timedelta
30-
from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union, cast
30+
from typing import (
31+
TYPE_CHECKING, Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union, cast,
32+
)
3133

3234
import jinja2
3335
import pendulum
@@ -59,6 +61,9 @@
5961
from airflow.utils.state import State
6062
from airflow.utils.types import DagRunType
6163

64+
if TYPE_CHECKING:
65+
from airflow.utils.task_group import TaskGroup
66+
6267
log = logging.getLogger(__name__)
6368

6469
ScheduleInterval = Union[str, timedelta, relativedelta]
@@ -238,6 +243,8 @@ def __init__(
238243
jinja_environment_kwargs: Optional[Dict] = None,
239244
tags: Optional[List[str]] = None
240245
):
246+
from airflow.utils.task_group import TaskGroup
247+
241248
self.user_defined_macros = user_defined_macros
242249
self.user_defined_filters = user_defined_filters
243250
self.default_args = copy.deepcopy(default_args or {})
@@ -329,6 +336,7 @@ def __init__(
329336

330337
self.jinja_environment_kwargs = jinja_environment_kwargs
331338
self.tags = tags
339+
self._task_group = TaskGroup.create_root(self)
332340

333341
def __repr__(self):
334342
return "<DAG: {self.dag_id}>".format(self=self)
@@ -570,6 +578,10 @@ def tasks(self, val):
570578
def task_ids(self) -> List[str]:
571579
return list(self.task_dict.keys())
572580

581+
@property
582+
def task_group(self) -> "TaskGroup":
583+
return self._task_group
584+
573585
@property
574586
def filepath(self) -> str:
575587
"""
@@ -1240,7 +1252,6 @@ def sub_dag(self, task_regex, include_downstream=False,
12401252
based on a regex that should match one or many tasks, and includes
12411253
upstream and downstream neighbours based on the flag passed.
12421254
"""
1243-
12441255
# deep-copying self.task_dict takes a long time, and we don't want all
12451256
# the tasks anyway, so we copy the tasks manually later
12461257
task_dict = self.task_dict
@@ -1261,9 +1272,38 @@ def sub_dag(self, task_regex, include_downstream=False,
12611272
# Make sure to not recursively deepcopy the dag while copying the task
12621273
dag.task_dict = {t.task_id: copy.deepcopy(t, {id(t.dag): dag})
12631274
for t in regex_match + also_include}
1275+
1276+
# Remove tasks not included in the subdag from task_group
1277+
def remove_excluded(group):
1278+
for child in list(group.children.values()):
1279+
if isinstance(child, BaseOperator):
1280+
if child.task_id not in dag.task_dict:
1281+
group.children.pop(child.task_id)
1282+
else:
1283+
# The tasks in the subdag are a copy of tasks in the original dag
1284+
# so update the reference in the TaskGroups too.
1285+
group.children[child.task_id] = dag.task_dict[child.task_id]
1286+
else:
1287+
remove_excluded(child)
1288+
1289+
# Remove this TaskGroup if it doesn't contain any tasks in this subdag
1290+
if not child.children:
1291+
group.children.pop(child.group_id)
1292+
1293+
remove_excluded(dag.task_group)
1294+
1295+
# Removing upstream/downstream references to tasks and TaskGroups that did not make
1296+
# the cut.
1297+
subdag_task_groups = dag.task_group.get_task_group_dict()
1298+
for group in subdag_task_groups.values():
1299+
group.upstream_group_ids = group.upstream_group_ids.intersection(subdag_task_groups.keys())
1300+
group.downstream_group_ids = group.downstream_group_ids.intersection(subdag_task_groups.keys())
1301+
group.upstream_task_ids = group.upstream_task_ids.intersection(dag.task_dict.keys())
1302+
group.downstream_task_ids = group.downstream_task_ids.intersection(dag.task_dict.keys())
1303+
12641304
for t in dag.tasks:
12651305
# Removing upstream/downstream references to tasks that did not
1266-
# made the cut
1306+
# make the cut
12671307
t._upstream_task_ids = t.upstream_task_ids.intersection(dag.task_dict.keys())
12681308
t._downstream_task_ids = t.downstream_task_ids.intersection(
12691309
dag.task_dict.keys())
@@ -1357,12 +1397,15 @@ def add_task(self, task):
13571397
elif task.end_date and self.end_date:
13581398
task.end_date = min(task.end_date, self.end_date)
13591399

1360-
if task.task_id in self.task_dict and self.task_dict[task.task_id] is not task:
1400+
if ((task.task_id in self.task_dict and self.task_dict[task.task_id] is not task)
1401+
or task.task_id in self._task_group.used_group_ids):
13611402
raise DuplicateTaskIdFound(
13621403
"Task id '{}' has already been added to the DAG".format(task.task_id))
13631404
else:
13641405
self.task_dict[task.task_id] = task
13651406
task.dag = self
1407+
# Add task_id to used_group_ids to prevent group_id and task_id collisions.
1408+
self._task_group.used_group_ids.add(task.task_id)
13661409

13671410
self.task_count = len(self.task_dict)
13681411

airflow/models/taskmixin.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ def roots(self):
3333
"""Should return list of root operator List[BaseOperator]"""
3434
raise NotImplementedError()
3535

36+
@property
37+
def leaves(self):
38+
"""Should return list of leaf operator List[BaseOperator]"""
39+
raise NotImplementedError()
40+
3641
@abstractmethod
3742
def set_upstream(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
3843
"""
@@ -47,6 +52,12 @@ def set_downstream(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
4752
"""
4853
raise NotImplementedError()
4954

55+
def update_relative(self, other: "TaskMixin", upstream=True) -> None:
56+
"""
57+
Update relationship information about another TaskMixin. Default is no-op.
58+
Override if necessary.
59+
"""
60+
5061
def __lshift__(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
5162
"""
5263
Implements Task << Task

airflow/models/xcom_arg.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ def roots(self) -> List[BaseOperator]:
102102
"""Required by TaskMixin"""
103103
return [self._operator]
104104

105+
@property
106+
def leaves(self) -> List[BaseOperator]:
107+
"""Required by TaskMixin"""
108+
return [self._operator]
109+
105110
@property
106111
def key(self) -> str:
107112
"""Returns keys of this XComArg"""

airflow/serialization/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,4 @@ class DagAttributeTypes(str, Enum):
4343
SET = 'set'
4444
TUPLE = 'tuple'
4545
POD = 'k8s.V1Pod'
46+
TASK_GROUP = 'taskgroup'

airflow/serialization/schema.json

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,11 @@
9696
"_default_view": { "type" : "string"},
9797
"_access_control": {"$ref": "#/definitions/dict" },
9898
"is_paused_upon_creation": { "type": "boolean" },
99-
"tags": { "type": "array" }
99+
"tags": { "type": "array" },
100+
"_task_group": {"anyOf": [
101+
{ "type": "null" },
102+
{ "$ref": "#/definitions/task_group" }
103+
]}
100104
},
101105
"required": [
102106
"_dag_id",
@@ -125,6 +129,7 @@
125129
"_task_module": { "type": "string" },
126130
"_operator_extra_links": { "$ref": "#/definitions/extra_links" },
127131
"task_id": { "type": "string" },
132+
"label": { "type": "string" },
128133
"owner": { "type": "string" },
129134
"start_date": { "$ref": "#/definitions/datetime" },
130135
"end_date": { "$ref": "#/definitions/datetime" },
@@ -156,6 +161,47 @@
156161
}
157162
},
158163
"additionalProperties": true
164+
},
165+
"task_group": {
166+
"$comment": "A TaskGroup containing tasks",
167+
"type": "object",
168+
"required": [
169+
"_group_id",
170+
"prefix_group_id",
171+
"children",
172+
"tooltip",
173+
"ui_color",
174+
"ui_fgcolor",
175+
"upstream_group_ids",
176+
"downstream_group_ids",
177+
"upstream_task_ids",
178+
"downstream_task_ids"
179+
],
180+
"properties": {
181+
"_group_id": {"anyOf": [{"type": "null"}, { "type": "string" }]},
182+
"prefix_group_id": { "type": "boolean" },
183+
"children": { "$ref": "#/definitions/dict" },
184+
"tooltip": { "type": "string" },
185+
"ui_color": { "type": "string" },
186+
"ui_fgcolor": { "type": "string" },
187+
"upstream_group_ids": {
188+
"type": "array",
189+
"items": { "type": "string" }
190+
},
191+
"downstream_group_ids": {
192+
"type": "array",
193+
"items": { "type": "string" }
194+
},
195+
"upstream_task_ids": {
196+
"type": "array",
197+
"items": { "type": "string" }
198+
},
199+
"downstream_task_ids": {
200+
"type": "array",
201+
"items": { "type": "string" }
202+
}
203+
},
204+
"additionalProperties": false
159205
}
160206
},
161207

0 commit comments

Comments
 (0)