2828from airflow .exceptions import AirflowException
2929from airflow .models .base import ID_LEN , Base
3030from airflow .stats import Stats
31- from airflow .ti_deps .dep_context import DepContext
31+ from airflow .ti_deps .dep_context import SCHEDULEABLE_STATES , DepContext
3232from airflow .utils import timezone
3333from airflow .utils .log .logging_mixin import LoggingMixin
3434from airflow .utils .session import provide_session
@@ -201,7 +201,6 @@ def get_task_instances(self, state=None, session=None):
201201
202202 if self .dag and self .dag .partial :
203203 tis = tis .filter (TaskInstance .task_id .in_ (self .dag .task_ids ))
204-
205204 return tis .all ()
206205
207206 @provide_session
@@ -268,49 +267,33 @@ def update_state(self, session=None):
268267 Determines the overall state of the DagRun based on the state
269268 of its TaskInstances.
270269
271- :return: State
270+ :return: ready_tis: the tis that can be scheduled in the current loop
271+ :rtype ready_tis: list[airflow.models.TaskInstance]
272272 """
273273
274274 dag = self .get_dag ()
275-
276- tis = self .get_task_instances (session = session )
277- self . log . debug ( "Updating state for %s considering %s task(s)" , self , len ( tis ))
278-
275+ ready_tis = []
276+ tis = [ ti for ti in self .get_task_instances (session = session ,
277+ state = State . task_states + ( State . SHUTDOWN ,))]
278+ self . log . debug ( "number of tis tasks for %s: %s task(s)" , self , len ( tis ))
279279 for ti in list (tis ):
280- # skip in db?
281- if ti .state == State .REMOVED :
282- tis .remove (ti )
283- else :
284- ti .task = dag .get_task (ti .task_id )
280+ ti .task = dag .get_task (ti .task_id )
285281
286- # pre-calculate
287- # db is faster
288282 start_dttm = timezone .utcnow ()
289- unfinished_tasks = self .get_task_instances (
290- state = State .unfinished (),
291- session = session
292- )
283+ unfinished_tasks = [t for t in tis if t .state in State .unfinished ()]
284+ finished_tasks = [t for t in tis if t .state in State .finished () + [State .UPSTREAM_FAILED ]]
293285 none_depends_on_past = all (not t .task .depends_on_past for t in unfinished_tasks )
294286 none_task_concurrency = all (t .task .task_concurrency is None
295287 for t in unfinished_tasks )
296288 # small speed up
297289 if unfinished_tasks and none_depends_on_past and none_task_concurrency :
298- # todo: this can actually get pretty slow: one task costs between 0.01-015s
299- no_dependencies_met = True
300- for ut in unfinished_tasks :
301- # We need to flag upstream and check for changes because upstream
302- # failures/re-schedules can result in deadlock false positives
303- old_state = ut .state
304- deps_met = ut .are_dependencies_met (
305- dep_context = DepContext (
306- flag_upstream_failed = True ,
307- ignore_in_retry_period = True ,
308- ignore_in_reschedule_period = True ),
309- session = session )
310- if deps_met or old_state != ut .current_state (session = session ):
311- no_dependencies_met = False
312- break
290+ scheduleable_tasks = [ut for ut in unfinished_tasks if ut .state in SCHEDULEABLE_STATES ]
313291
292+ self .log .debug ("number of scheduleable tasks for %s: %s task(s)" , self , len (scheduleable_tasks ))
293+ ready_tis , changed_tis = self ._get_ready_tis (scheduleable_tasks , finished_tasks , session )
294+ self .log .debug ("ready tis length for %s: %s task(s)" , self , len (ready_tis ))
295+ are_runnable_tasks = ready_tis or self ._are_premature_tis (
296+ unfinished_tasks , finished_tasks , session ) or changed_tis
314297 duration = (timezone .utcnow () - start_dttm )
315298 Stats .timing ("dagrun.dependency-check.{}" .format (self .dag_id ), duration )
316299
@@ -335,7 +318,7 @@ def update_state(self, session=None):
335318
336319 # if *all tasks* are deadlocked, the run failed
337320 elif (unfinished_tasks and none_depends_on_past and
338- none_task_concurrency and no_dependencies_met ):
321+ none_task_concurrency and not are_runnable_tasks ):
339322 self .log .info ('Deadlock; marking run %s failed' , self )
340323 self .set_state (State .FAILED )
341324 dag .handle_callback (self , success = False , reason = 'all_tasks_deadlocked' ,
@@ -351,7 +334,35 @@ def update_state(self, session=None):
351334 session .merge (self )
352335 session .commit ()
353336
354- return self .state
337+ return ready_tis
338+
def _get_ready_tis(self, scheduleable_tasks, finished_tasks, session):
    """
    Split the schedulable task instances into those whose dependencies are met.

    Each task instance is checked with a ``DepContext`` built with
    ``flag_upstream_failed=True`` and the supplied ``finished_tasks``.

    :param scheduleable_tasks: task instances to evaluate
    :param finished_tasks: task instances handed to ``DepContext`` as
        ``finished_tasks``
    :param session: database session forwarded to the dependency check and
        to the state lookup
    :return: ``(ready_tis, changed_tis)`` where ``ready_tis`` is the list of
        task instances whose dependencies are met, and ``changed_tis`` is
        ``True`` when at least one task instance that was *not* ready reports
        a ``current_state`` different from the state it had before the check
    :rtype: tuple[list, bool]
    """
    runnable = []
    state_changed = False
    for task_instance in scheduleable_tasks:
        # Remember the state before the dependency check so that a change
        # made while evaluating dependencies can be detected afterwards.
        previous_state = task_instance.state
        context = DepContext(
            flag_upstream_failed=True,
            finished_tasks=finished_tasks)
        deps_met = task_instance.are_dependencies_met(
            dep_context=context,
            session=session)
        if deps_met:
            runnable.append(task_instance)
            continue
        # Only look the state up again for task instances that are not ready.
        if previous_state != task_instance.current_state(session=session):
            state_changed = True
    return runnable, state_changed
353+
def _are_premature_tis(self, unfinished_tasks, finished_tasks, session):
    """
    Tell whether any unfinished task instance would be runnable if retry and
    reschedule waiting periods were ignored.

    Some task instances may be up for retry or reschedule and, for some
    reason (retry delay, etc.), are not ready yet. The dependency check is
    therefore run with ``ignore_in_retry_period`` and
    ``ignore_in_reschedule_period`` set so that they are counted as runnable.

    :param unfinished_tasks: task instances to evaluate
    :param finished_tasks: task instances handed to ``DepContext`` as
        ``finished_tasks``
    :param session: database session forwarded to the dependency check
    :return: ``True`` as soon as one task instance has its dependencies met
        under the relaxed context, ``False`` otherwise
    :rtype: bool
    """
    for ut in unfinished_tasks:
        if ut.are_dependencies_met(
                dep_context=DepContext(
                    flag_upstream_failed=True,
                    ignore_in_retry_period=True,
                    ignore_in_reschedule_period=True,
                    finished_tasks=finished_tasks),
                session=session):
            return True
    # Explicit falsy result: the original fell off the end and returned None,
    # leaving the predicate with inconsistent return types.
    return False
355366
356367 def _emit_duration_stats_for_finished_state (self ):
357368 if self .state == State .RUNNING :
0 commit comments