Commit 50efda5

amichai07 authored and ashb committed
[AIRFLOW-3607] Only query DB once per DAG run for TriggerRuleDep (apache#4751)
This decreases scheduler delay between tasks by about 20% for larger DAGs, and sometimes more for very large or complex DAGs. Delay between tasks can be a major issue, especially for DAGs with many subdags. Profiling showed that the scheduling process spends much of its time in dependency checking, and that the trigger rule dependency queried the database once per task instance; it now queries the database only once per dag_run.
1 parent: e54fba5 · commit: 50efda5
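
In outline, the change replaces a per-task-instance aggregate query with a single fetch per dag_run whose result is tallied in memory. A minimal, self-contained sketch of that idea (plain tuples stand in for TaskInstance objects; the names are illustrative, not Airflow's API):

from collections import Counter

# One "DB fetch" for the whole run: (task_id, state) pairs of finished TIs.
run_finished = [("extract", "success"), ("clean", "skipped"), ("load", "failed")]

# Every task instance in the run reuses that single result set.
upstream_of = {"report": {"extract", "clean"}, "alert": {"load"}}
for ti, upstream_ids in upstream_of.items():
    counts = Counter(state for task_id, state in run_finished
                     if task_id in upstream_ids)
    print(ti, dict(counts))
# report {'success': 1, 'skipped': 1}
# alert {'failed': 1}

Previously each of those per-TI tallies was a separate SQL aggregate, so the cost grew with the number of task instances rather than the number of dag runs.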

8 files changed: +157 additions, -101 deletions

airflow/jobs/scheduler_job.py

Lines changed: 4 additions & 22 deletions

@@ -43,7 +43,7 @@
 from airflow.models import DAG, DagRun, SlaMiss, errors
 from airflow.models.taskinstance import SimpleTaskInstance
 from airflow.stats import Stats
-from airflow.ti_deps.dep_context import SCHEDULEABLE_STATES, SCHEDULED_DEPS, DepContext
+from airflow.ti_deps.dep_context import SCHEDULED_DEPS, DepContext
 from airflow.ti_deps.deps.pool_slots_available_dep import STATES_TO_COUNT_AS_RUNNING
 from airflow.utils import asciiart, helpers, timezone
 from airflow.utils.dag_processing import (
@@ -648,28 +648,10 @@ def _process_task_instances(self, dag, task_instances_list, session=None):
             run.dag = dag
             # todo: preferably the integrity check happens at dag collection time
             run.verify_integrity(session=session)
-            run.update_state(session=session)
+            ready_tis = run.update_state(session=session)
             if run.state == State.RUNNING:
-                make_transient(run)
-                active_dag_runs.append(run)
-
-        for run in active_dag_runs:
-            self.log.debug("Examining active DAG run: %s", run)
-            tis = run.get_task_instances(state=SCHEDULEABLE_STATES)
-
-            # this loop is quite slow as it uses are_dependencies_met for
-            # every task (in ti.is_runnable). This is also called in
-            # update_state above which has already checked these tasks
-            for ti in tis:
-                task = dag.get_task(ti.task_id)
-
-                # fixme: ti.task is transient but needs to be set
-                ti.task = task
-
-                if ti.are_dependencies_met(
-                        dep_context=DepContext(flag_upstream_failed=True),
-                        session=session
-                ):
+                self.log.debug("Examining active DAG run: %s", run)
+                for ti in ready_tis:
                     self.log.debug('Queuing task: %s', ti)
                     task_instances_list.append(ti.key)

airflow/models/dagrun.py

Lines changed: 46 additions & 35 deletions

@@ -28,7 +28,7 @@
 from airflow.exceptions import AirflowException
 from airflow.models.base import ID_LEN, Base
 from airflow.stats import Stats
-from airflow.ti_deps.dep_context import DepContext
+from airflow.ti_deps.dep_context import SCHEDULEABLE_STATES, DepContext
 from airflow.utils import timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import provide_session
@@ -201,7 +201,6 @@ def get_task_instances(self, state=None, session=None):
 
         if self.dag and self.dag.partial:
             tis = tis.filter(TaskInstance.task_id.in_(self.dag.task_ids))
-
         return tis.all()
 
     @provide_session
@@ -268,49 +267,33 @@ def update_state(self, session=None):
         Determines the overall state of the DagRun based on the state
         of its TaskInstances.
 
-        :return: State
+        :return: ready_tis: the tis that can be scheduled in the current loop
+        :rtype ready_tis: list[airflow.models.TaskInstance]
         """
 
         dag = self.get_dag()
-
-        tis = self.get_task_instances(session=session)
-        self.log.debug("Updating state for %s considering %s task(s)", self, len(tis))
-
+        ready_tis = []
+        tis = [ti for ti in self.get_task_instances(session=session,
+                                                    state=State.task_states + (State.SHUTDOWN,))]
+        self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
         for ti in list(tis):
-            # skip in db?
-            if ti.state == State.REMOVED:
-                tis.remove(ti)
-            else:
-                ti.task = dag.get_task(ti.task_id)
+            ti.task = dag.get_task(ti.task_id)
 
-        # pre-calculate
-        # db is faster
         start_dttm = timezone.utcnow()
-        unfinished_tasks = self.get_task_instances(
-            state=State.unfinished(),
-            session=session
-        )
+        unfinished_tasks = [t for t in tis if t.state in State.unfinished()]
+        finished_tasks = [t for t in tis if t.state in State.finished() + [State.UPSTREAM_FAILED]]
         none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks)
         none_task_concurrency = all(t.task.task_concurrency is None
                                     for t in unfinished_tasks)
         # small speed up
         if unfinished_tasks and none_depends_on_past and none_task_concurrency:
-            # todo: this can actually get pretty slow: one task costs between 0.01-015s
-            no_dependencies_met = True
-            for ut in unfinished_tasks:
-                # We need to flag upstream and check for changes because upstream
-                # failures/re-schedules can result in deadlock false positives
-                old_state = ut.state
-                deps_met = ut.are_dependencies_met(
-                    dep_context=DepContext(
-                        flag_upstream_failed=True,
-                        ignore_in_retry_period=True,
-                        ignore_in_reschedule_period=True),
-                    session=session)
-                if deps_met or old_state != ut.current_state(session=session):
-                    no_dependencies_met = False
-                    break
+            scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES]
 
+            self.log.debug("number of scheduleable tasks for %s: %s task(s)", self, len(scheduleable_tasks))
+            ready_tis, changed_tis = self._get_ready_tis(scheduleable_tasks, finished_tasks, session)
+            self.log.debug("ready tis length for %s: %s task(s)", self, len(ready_tis))
+            are_runnable_tasks = ready_tis or self._are_premature_tis(
+                unfinished_tasks, finished_tasks, session) or changed_tis
         duration = (timezone.utcnow() - start_dttm)
         Stats.timing("dagrun.dependency-check.{}".format(self.dag_id), duration)
@@ -335,7 +318,7 @@ def update_state(self, session=None):
 
         # if *all tasks* are deadlocked, the run failed
         elif (unfinished_tasks and none_depends_on_past and
-                none_task_concurrency and no_dependencies_met):
+                none_task_concurrency and not are_runnable_tasks):
             self.log.info('Deadlock; marking run %s failed', self)
             self.set_state(State.FAILED)
             dag.handle_callback(self, success=False, reason='all_tasks_deadlocked',
@@ -351,7 +334,35 @@ def update_state(self, session=None):
         session.merge(self)
         session.commit()
 
-        return self.state
+        return ready_tis
+
+    def _get_ready_tis(self, scheduleable_tasks, finished_tasks, session):
+        ready_tis = []
+        changed_tis = False
+        for st in scheduleable_tasks:
+            st_old_state = st.state
+            if st.are_dependencies_met(
+                    dep_context=DepContext(
+                        flag_upstream_failed=True,
+                        finished_tasks=finished_tasks),
+                    session=session):
+                ready_tis.append(st)
+            elif st_old_state != st.current_state(session=session):
+                changed_tis = True
+        return ready_tis, changed_tis
+
+    def _are_premature_tis(self, unfinished_tasks, finished_tasks, session):
+        # there might be runnable tasks that are up for retry and from some reason(retry delay, etc) are
+        # not ready yet so we set the flags to count them in
+        for ut in unfinished_tasks:
+            if ut.are_dependencies_met(
+                    dep_context=DepContext(
+                        flag_upstream_failed=True,
+                        ignore_in_retry_period=True,
+                        ignore_in_reschedule_period=True,
+                        finished_tasks=finished_tasks),
+                    session=session):
+                return True
 
     def _emit_duration_stats_for_finished_state(self):
         if self.state == State.RUNNING:
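
Deadlock detection now hinges on three in-memory signals instead of a fresh query per unfinished task: ready_tis (dependencies met), changed_tis (a state moved during the check), and premature TIs (runnable once a retry/reschedule window passes). A toy version of the _get_ready_tis loop, assuming caller-supplied predicates in place of are_dependencies_met and current_state:

def get_ready_tis(scheduleable_tasks, deps_met, current_state):
    ready_tis, changed_tis = [], False
    for st in scheduleable_tasks:
        old_state = st["state"]
        if deps_met(st):
            ready_tis.append(st)
        elif old_state != current_state(st):
            # an upstream failure/re-schedule moved this TI: not a deadlock
            changed_tis = True
    return ready_tis, changed_tis

tis = [{"state": "scheduled"}, {"state": "up_for_retry"}]
print(get_ready_tis(tis,
                    deps_met=lambda st: st["state"] == "scheduled",
                    current_state=lambda st: st["state"]))
# ([{'state': 'scheduled'}], False)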

airflow/ti_deps/dep_context.py

Lines changed: 5 additions & 1 deletion

@@ -67,6 +67,8 @@ class DepContext:
     :type ignore_task_deps: bool
     :param ignore_ti_state: Ignore the task instance's previous failure/success
     :type ignore_ti_state: bool
+    :param finished_tasks: A list of all the finished tasks of this run
+    :type finished_tasks: list[airflow.models.TaskInstance]
     """
     def __init__(
             self,
@@ -77,7 +79,8 @@ def __init__(
             ignore_in_retry_period=False,
             ignore_in_reschedule_period=False,
             ignore_task_deps=False,
-            ignore_ti_state=False):
+            ignore_ti_state=False,
+            finished_tasks=None):
         self.deps = deps or set()
         self.flag_upstream_failed = flag_upstream_failed
         self.ignore_all_deps = ignore_all_deps
@@ -86,6 +89,7 @@ def __init__(
         self.ignore_in_reschedule_period = ignore_in_reschedule_period
         self.ignore_task_deps = ignore_task_deps
         self.ignore_ti_state = ignore_ti_state
+        self.finished_tasks = finished_tasks
 
 
 # In order to be able to get queued a task must have one of these states
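
Downstream deps receive the pre-fetched list through the context object. A stubbed illustration of the new keyword (StubDepContext is hypothetical and mirrors only the relevant parameters from the diff above):

class StubDepContext:
    def __init__(self, flag_upstream_failed=False, finished_tasks=None):
        self.flag_upstream_failed = flag_upstream_failed
        # None means "not pre-fetched": TriggerRuleDep falls back to a query.
        self.finished_tasks = finished_tasks

ctx = StubDepContext(flag_upstream_failed=True, finished_tasks=[])
print(ctx.finished_tasks is None)  # False: the dep can skip its own DB work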

airflow/ti_deps/deps/trigger_rule_dep.py

Lines changed: 28 additions & 30 deletions

@@ -17,7 +17,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from sqlalchemy import case, func
+from collections import Counter
 
 import airflow
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
@@ -34,11 +34,32 @@ class TriggerRuleDep(BaseTIDep):
     IGNOREABLE = True
     IS_TASK_DEP = True
 
+    @staticmethod
+    @provide_session
+    def _get_states_count_upstream_ti(ti, finished_tasks, session):
+        """
+        This function returns the states of the upstream tis for a specific ti in order to determine
+        whether this ti can run in this iteration
+
+        :param ti: the ti that we want to calculate deps for
+        :type ti: airflow.models.TaskInstance
+        :param finished_tasks: all the finished tasks of the dag_run
+        :type finished_tasks: list[airflow.models.TaskInstance]
+        """
+        if finished_tasks is None:
+            # this is for the strange feature of running tasks without dag_run
+            finished_tasks = ti.task.dag.get_task_instances(
+                start_date=ti.execution_date,
+                end_date=ti.execution_date,
+                state=State.finished() + [State.UPSTREAM_FAILED],
+                session=session)
+        counter = Counter(task.state for task in finished_tasks if task.task_id in ti.task.upstream_task_ids)
+        return counter.get(State.SUCCESS, 0), counter.get(State.SKIPPED, 0), counter.get(State.FAILED, 0), \
+            counter.get(State.UPSTREAM_FAILED, 0), sum(counter.values())
+
     @provide_session
     def _get_dep_statuses(self, ti, session, dep_context):
-        TI = airflow.models.TaskInstance
         TR = airflow.utils.trigger_rule.TriggerRule
-
         # Checking that all upstream dependencies have succeeded
         if not ti.task.upstream_list:
             yield self._passing_status(
@@ -48,34 +69,11 @@ def _get_dep_statuses(self, ti, session, dep_context):
         if ti.task.trigger_rule == TR.DUMMY:
             yield self._passing_status(reason="The task had a dummy trigger rule set.")
             return
+        # see if the task name is in the task upstream for our task
+        successes, skipped, failed, upstream_failed, done = self._get_states_count_upstream_ti(
+            ti=ti,
+            finished_tasks=dep_context.finished_tasks)
 
-        # TODO(unknown): this query becomes quite expensive with dags that have many
-        # tasks. It should be refactored to let the task report to the dag run and get the
-        # aggregates from there.
-        qry = (
-            session
-            .query(
-                func.coalesce(func.sum(
-                    case([(TI.state == State.SUCCESS, 1)], else_=0)), 0),
-                func.coalesce(func.sum(
-                    case([(TI.state == State.SKIPPED, 1)], else_=0)), 0),
-                func.coalesce(func.sum(
-                    case([(TI.state == State.FAILED, 1)], else_=0)), 0),
-                func.coalesce(func.sum(
-                    case([(TI.state == State.UPSTREAM_FAILED, 1)], else_=0)), 0),
-                func.count(TI.task_id),
-            )
-            .filter(
-                TI.dag_id == ti.dag_id,
-                TI.task_id.in_(ti.task.upstream_task_ids),
-                TI.execution_date == ti.execution_date,
-                TI.state.in_([
-                    State.SUCCESS, State.FAILED,
-                    State.UPSTREAM_FAILED, State.SKIPPED]),
-            )
-        )
-
-        successes, skipped, failed, upstream_failed, done = qry.first()
         yield from self._evaluate_trigger_rule(
             ti=ti,
             successes=successes,
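
_get_states_count_upstream_ti collapses the old five-column SQL aggregate into one Counter pass over the pre-fetched list: only the TI's upstream tasks are tallied, and done is the total across the four counted states. A self-contained check of that return shape (plain strings stand in for State members):

from collections import Counter

upstream_ids = {"t1", "t2", "t3"}
finished = [("t1", "success"), ("t2", "upstream_failed"),
            ("t4", "failed")]  # t4 is not upstream, so it must not count

counter = Counter(state for task_id, state in finished
                  if task_id in upstream_ids)
successes = counter.get("success", 0)
skipped = counter.get("skipped", 0)
failed = counter.get("failed", 0)
upstream_failed = counter.get("upstream_failed", 0)
done = sum(counter.values())
print(successes, skipped, failed, upstream_failed, done)  # 1 0 0 1 2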

tests/jobs/test_backfill_job.py

Lines changed: 1 addition & 0 deletions

@@ -1261,6 +1261,7 @@ def test_backfill_execute_subdag_with_removed_task(self):
 
         session = settings.Session()
         session.merge(removed_task_ti)
+        session.commit()
 
         with timeout(seconds=30):
             job.run()

tests/jobs/test_scheduler_job.py

Lines changed: 2 additions & 2 deletions

@@ -2034,8 +2034,8 @@ def test_dagrun_root_fail_unfinished(self):
         ti = dr.get_task_instance('test_dagrun_unfinished', session=session)
         ti.state = State.NONE
         session.commit()
-        dr_state = dr.update_state()
-        self.assertEqual(dr_state, State.RUNNING)
+        dr.update_state()
+        self.assertEqual(dr.state, State.RUNNING)
 
     def test_dagrun_root_after_dagrun_unfinished(self):
         """

tests/models/test_dagrun.py

Lines changed: 10 additions & 10 deletions

@@ -159,8 +159,8 @@ def test_dagrun_success_when_all_skipped(self):
         dag_run = self.create_dag_run(dag=dag,
                                       state=State.RUNNING,
                                       task_states=initial_task_states)
-        updated_dag_state = dag_run.update_state()
-        self.assertEqual(State.SUCCESS, updated_dag_state)
+        dag_run.update_state()
+        self.assertEqual(State.SUCCESS, dag_run.state)
 
     def test_dagrun_success_conditions(self):
         session = settings.Session()
@@ -198,15 +198,15 @@ def test_dagrun_success_conditions(self):
         ti_op4 = dr.get_task_instance(task_id=op4.task_id)
 
         # root is successful, but unfinished tasks
-        state = dr.update_state()
-        self.assertEqual(State.RUNNING, state)
+        dr.update_state()
+        self.assertEqual(State.RUNNING, dr.state)
 
         # one has failed, but root is successful
         ti_op2.set_state(state=State.FAILED, session=session)
         ti_op3.set_state(state=State.SUCCESS, session=session)
         ti_op4.set_state(state=State.SUCCESS, session=session)
-        state = dr.update_state()
-        self.assertEqual(State.SUCCESS, state)
+        dr.update_state()
+        self.assertEqual(State.SUCCESS, dr.state)
 
     def test_dagrun_deadlock(self):
         session = settings.Session()
@@ -321,8 +321,8 @@ def on_success_callable(context):
         dag_run = self.create_dag_run(dag=dag,
                                       state=State.RUNNING,
                                       task_states=initial_task_states)
-        updated_dag_state = dag_run.update_state()
-        self.assertEqual(State.SUCCESS, updated_dag_state)
+        dag_run.update_state()
+        self.assertEqual(State.SUCCESS, dag_run.state)
 
     def test_dagrun_failure_callback(self):
         def on_failure_callable(context):
@@ -352,8 +352,8 @@ def on_failure_callable(context):
         dag_run = self.create_dag_run(dag=dag,
                                       state=State.RUNNING,
                                       task_states=initial_task_states)
-        updated_dag_state = dag_run.update_state()
-        self.assertEqual(State.FAILED, updated_dag_state)
+        dag_run.update_state()
+        self.assertEqual(State.FAILED, dag_run.state)
 
     def test_dagrun_set_state_end_date(self):
         session = settings.Session()
