
Commit ed92e5d

ephraimbuddy and ashb authored
Fix mini scheduler expansion of mapped task (#27506)
We have a case where the mini scheduler tries to expand a downstream mapped task even when that task's upstream tasks are not yet done. The mini scheduler extracts a partial subset of the DAG, and in the process some upstream tasks are dropped. If a task in the subset happens to be a mapped task, the expansion fails, since it needs the upstream output to compute the expansion. When the expansion fails, the task is marked as `upstream_failed`, which in turn marks its downstream tasks as `upstream_failed`. The solution is to ignore this error and not mark the mapped task as `upstream_failed` when the expansion fails and the DAG is a partial subset.

Co-authored-by: Ash Berlin-Taylor <[email protected]>
1 parent 47a2b9e commit ed92e5d
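For context, the subset the mini scheduler works on is built with `DAG.partial_subset`; the call and its flags below are taken from the diffs in this commit, while the annotations are a sketch of the failure mode, not part of the commit:

```python
# Sketch of the failure mode this commit fixes. The subset is built
# around the just-finished task's downstream IDs:
partial_dag = task.dag.partial_subset(
    task.downstream_task_ids,
    include_downstream=True,        # everything below the finished task
    include_upstream=False,         # upstreams of those tasks are dropped...
    include_direct_upstream=True,   # ...except their direct upstreams
)
# A mapped task inside this subset can therefore be missing the upstream
# that feeds its expansion input. expand_mapped_task() then hits
# NotFullyPopulated internally, and before this fix the task was marked
# upstream_failed even though the upstream simply hadn't run yet.
```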

File tree

5 files changed: +165, −69 lines

airflow/jobs/local_task_job.py

Lines changed: 1 addition & 58 deletions
```diff
@@ -18,25 +18,20 @@
 from __future__ import annotations
 
 import signal
-from typing import TYPE_CHECKING
 
 import psutil
-from sqlalchemy.exc import OperationalError
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.jobs.base_job import BaseJob
 from airflow.listeners.events import register_task_instance_state_events
 from airflow.listeners.listener import get_listener_manager
-from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
-from airflow.sentry import Sentry
 from airflow.stats import Stats
 from airflow.task.task_runner import get_task_runner
 from airflow.utils import timezone
 from airflow.utils.net import get_hostname
 from airflow.utils.session import provide_session
-from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.state import State
 
 
@@ -165,7 +160,7 @@ def handle_task_exit(self, return_code: int) -> None:
 
         if not self.task_instance.test_mode:
             if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
-                self._run_mini_scheduler_on_child_tasks()
+                self.task_instance.schedule_downstream_tasks()
 
     def on_kill(self):
         self.task_runner.terminate()
@@ -230,58 +225,6 @@ def heartbeat_callback(self, session=None):
             self.terminating = True
         self._state_change_checks += 1
 
-    @provide_session
-    @Sentry.enrich_errors
-    def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
-        try:
-            # Re-select the row with a lock
-            dag_run = with_row_locks(
-                session.query(DagRun).filter_by(
-                    dag_id=self.dag_id,
-                    run_id=self.task_instance.run_id,
-                ),
-                session=session,
-            ).one()
-
-            task = self.task_instance.task
-            if TYPE_CHECKING:
-                assert task.dag
-
-            # Get a partial DAG with just the specific tasks we want to examine.
-            # In order for dep checks to work correctly, we include ourself (so
-            # TriggerRuleDep can check the state of the task we just executed).
-            partial_dag = task.dag.partial_subset(
-                task.downstream_task_ids,
-                include_downstream=True,
-                include_upstream=False,
-                include_direct_upstream=True,
-            )
-
-            dag_run.dag = partial_dag
-            info = dag_run.task_instance_scheduling_decisions(session)
-
-            skippable_task_ids = {
-                task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
-            }
-
-            schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
-            for schedulable_ti in schedulable_tis:
-                if not hasattr(schedulable_ti, "task"):
-                    schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)
-
-            num = dag_run.schedule_tis(schedulable_tis)
-            self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)
-
-            session.commit()
-        except OperationalError as e:
-            # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
-            self.log.info(
-                "Skipping mini scheduling run due to exception: %s",
-                e.statement,
-                exc_info=True,
-            )
-            session.rollback()
-
     @staticmethod
     def _enable_task_listeners():
         """
```

airflow/models/mappedoperator.py

Lines changed: 20 additions & 10 deletions
```diff
@@ -620,13 +620,18 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
         try:
             total_length = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
         except NotFullyPopulated as e:
-            self.log.info(
-                "Cannot expand %r for run %s; missing upstream values: %s",
-                self,
-                run_id,
-                sorted(e.missing),
-            )
             total_length = None
+            # Partial DAGs come from the mini scheduler. It's
+            # possible that the upstream tasks are not yet done,
+            # but we don't have upstreams of upstreams in partial
+            # DAGs, so we ignore this exception.
+            if not self.dag or not self.dag.partial:
+                self.log.error(
+                    "Cannot expand %r for run %s; missing upstream values: %s",
+                    self,
+                    run_id,
+                    sorted(e.missing),
+                )
 
         state: TaskInstanceState | None = None
         unmapped_ti: TaskInstance | None = (
@@ -647,10 +652,15 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
             # The unmapped task instance still exists and is unfinished, i.e. we
             # haven't tried to run it before.
             if total_length is None:
-                # If the map length cannot be calculated (due to unavailable
-                # upstream sources), fail the unmapped task.
-                unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
-                indexes_to_map: Iterable[int] = ()
+                if self.dag and self.dag.partial:
+                    # If the DAG is partial, it's likely that the upstream tasks
+                    # are not done yet, so we do nothing.
+                    indexes_to_map: Iterable[int] = ()
+                else:
+                    # If the map length cannot be calculated (due to unavailable
+                    # upstream sources), fail the unmapped task.
+                    unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
+                    indexes_to_map = ()
             elif total_length < 1:
                 # If the upstream maps this to a zero-length value, simply mark
                 # the unmapped task instance as SKIPPED (if needed).
```
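The two hunks above implement one rule: a missing map length is fatal only outside the mini scheduler. A condensed restatement of the new control flow (a sketch, not the verbatim code):

```python
# Condensed sketch of the decision in expand_mapped_task after this change.
is_mini_scheduler_subset = bool(self.dag and self.dag.partial)
if total_length is None:  # NotFullyPopulated was raised above
    indexes_to_map: Iterable[int] = ()
    if not is_mini_scheduler_subset:
        # Real scheduler: upstream values are genuinely missing, hard-fail.
        unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
    # Mini scheduler: leave the unmapped TI alone; the real scheduler will
    # expand it once the upstreams are actually done.
```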

airflow/models/taskinstance.py

Lines changed: 61 additions & 0 deletions
```diff
@@ -2459,6 +2459,67 @@ def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> Colum
             return filters[0]
         return or_(*filters)
 
+    @Sentry.enrich_errors
+    @provide_session
+    def schedule_downstream_tasks(self, session=None):
+        """
+        The mini-scheduler for scheduling downstream tasks of this task instance
+        :meta: private
+        """
+        from sqlalchemy.exc import OperationalError
+
+        from airflow.models import DagRun
+
+        try:
+            # Re-select the row with a lock
+            dag_run = with_row_locks(
+                session.query(DagRun).filter_by(
+                    dag_id=self.dag_id,
+                    run_id=self.run_id,
+                ),
+                session=session,
+            ).one()
+
+            task = self.task
+            if TYPE_CHECKING:
+                assert task.dag
+
+            # Get a partial DAG with just the specific tasks we want to examine.
+            # In order for dep checks to work correctly, we include ourself (so
+            # TriggerRuleDep can check the state of the task we just executed).
+            partial_dag = task.dag.partial_subset(
+                task.downstream_task_ids,
+                include_downstream=True,
+                include_upstream=False,
+                include_direct_upstream=True,
+            )
+
+            dag_run.dag = partial_dag
+            info = dag_run.task_instance_scheduling_decisions(session)
+
+            skippable_task_ids = {
+                task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
+            }
+
+            schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
+            for schedulable_ti in schedulable_tis:
+                if not hasattr(schedulable_ti, "task"):
+                    schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)
+
+            num = dag_run.schedule_tis(schedulable_tis, session=session)
+            self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)
+
+            session.flush()
+
+        except OperationalError as e:
+            # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
+            self.log.info(
+                "Skipping mini scheduling run due to exception: %s",
+                e.statement,
+                exc_info=True,
+            )
+            session.rollback()
+
 
     # State of the task instance.
     # Stores string version of the task state.
```
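One subtlety in `schedule_downstream_tasks()` is the `skippable_task_ids` filter: tasks that were only pulled into the partial subset as direct upstreams (or the just-finished task itself) must not be scheduled from here. A toy illustration with hypothetical task IDs:

```python
# Toy example (hypothetical IDs). Suppose the finished task "a" has
# downstream_task_ids == {"b", "c"}, and partial_subset also pulled in
# "a" itself plus a direct upstream "x" of "b":
downstream_task_ids = {"b", "c"}
partial_dag_task_ids = ["a", "x", "b", "c"]

skippable_task_ids = {tid for tid in partial_dag_task_ids if tid not in downstream_task_ids}
assert skippable_task_ids == {"a", "x"}  # only "b" and "c" remain schedulable
```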

tests/jobs/test_local_task_job.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -739,7 +739,6 @@ def test_mini_scheduler_works_with_wait_for_upstream(self, caplog, get_test_dag)
         ti2_l.refresh_from_db()
         assert ti2_k.state == State.SUCCESS
         assert ti2_l.state == State.NONE
-        assert "0 downstream tasks scheduled from follow-on schedule" in caplog.text
 
         failed_deps = list(ti2_l.get_failed_dep_statuses())
         assert len(failed_deps) == 1
```

tests/models/test_taskinstance.py

Lines changed: 83 additions & 0 deletions
```diff
@@ -3613,3 +3613,86 @@ def get_extra_env():
 
     echo_task = dag.get_task("echo")
     assert "get_extra_env" in echo_task.upstream_task_ids
+
+
+def test_mapped_task_does_not_error_in_mini_scheduler_if_upstreams_are_not_done(dag_maker, caplog, session):
+    """
+    This tests that when scheduling child tasks of a task and there's a mapped downstream task,
+    if the mapped downstream task has upstreams that are not yet done, the mapped downstream task is
+    not marked as `upstream_failed`
+    """
+    with dag_maker() as dag:
+
+        @dag.task
+        def second_task():
+            return [0, 1, 2]
+
+        @dag.task
+        def first_task():
+            print(2)
+
+        @dag.task
+        def middle_task(id):
+            return id
+
+        middle = middle_task.expand(id=second_task())
+
+        @dag.task
+        def last_task():
+            print(3)
+
+        [first_task(), middle] >> last_task()
+
+    dag_run = dag_maker.create_dagrun()
+    first_ti = dag_run.get_task_instance(task_id="first_task")
+    second_ti = dag_run.get_task_instance(task_id="second_task")
+    first_ti.state = State.SUCCESS
+    second_ti.state = State.RUNNING
+    session.merge(first_ti)
+    session.merge(second_ti)
+    session.commit()
+    first_ti.schedule_downstream_tasks(session=session)
+    middle_ti = dag_run.get_task_instance(task_id="middle_task")
+    assert middle_ti.state != State.UPSTREAM_FAILED
+    assert "0 downstream tasks scheduled from follow-on schedule" in caplog.text
+
+
+def test_mapped_task_expands_in_mini_scheduler_if_upstreams_are_done(dag_maker, caplog, session):
+    """Test that the mini scheduler expands a mapped task when its upstreams are done"""
+    with dag_maker() as dag:
+
+        @dag.task
+        def second_task():
+            return [0, 1, 2]
+
+        @dag.task
+        def first_task():
+            print(2)
+
+        @dag.task
+        def middle_task(id):
+            return id
+
+        middle = middle_task.expand(id=second_task())
+
+        @dag.task
+        def last_task():
+            print(3)
+
+        [first_task(), middle] >> last_task()
+
+    dr = dag_maker.create_dagrun()
+
+    first_ti = dr.get_task_instance(task_id="first_task")
+    first_ti.state = State.SUCCESS
+    session.merge(first_ti)
+    session.commit()
+    second_task = dag.get_task("second_task")
+    second_ti = dr.get_task_instance(task_id="second_task")
+    second_ti.refresh_from_task(second_task)
+    second_ti.run()
+    second_ti.schedule_downstream_tasks(session=session)
+    for i in range(3):
+        middle_ti = dr.get_task_instance(task_id="middle_task", map_index=i)
+        assert middle_ti.state == State.SCHEDULED
+    assert "3 downstream tasks scheduled from follow-on schedule" in caplog.text
```
