2727import multiprocessing
2828import time
2929from queue import Empty , Queue # pylint: disable=unused-import
30- from typing import Any , Dict , Optional , Tuple , Union
30+ from typing import Any , Dict , List , Optional , Tuple , Union
3131
3232import kubernetes
3333from dateutil import parser
4444from airflow .kubernetes .kube_client import get_kube_client
4545from airflow .kubernetes .pod_generator import MAX_POD_ID_LEN , PodGenerator
4646from airflow .kubernetes .pod_launcher import PodLauncher
47- from airflow .models import KubeResourceVersion , KubeWorkerIdentifier , TaskInstance
47+ from airflow .models import TaskInstance
4848from airflow .models .taskinstance import TaskInstanceKey
4949from airflow .utils .log .logging_mixin import LoggingMixin
5050from airflow .utils .session import provide_session
6060KubernetesWatchType = Tuple [str , str , Optional [str ], Dict [str , str ], str ]
6161
6262
class ResourceVersion:
    """
    Singleton that tracks the Kubernetes resourceVersion.

    Every ``ResourceVersion()`` call returns the same object, so a value
    stored on ``resource_version`` by one caller is seen by the next caller
    that constructs the class.
    """

    # The single shared instance; stays ``None`` until the first construction.
    _instance = None
    # Value reported until a caller assigns a newer one.
    resource_version = "0"

    def __new__(cls):
        existing = cls._instance
        if existing is not None:
            return existing
        # First construction: create the instance and remember it on the class.
        cls._instance = super().__new__(cls)
        return cls._instance
73+
74+
6375class KubeConfig : # pylint: disable=too-many-instance-attributes
6476 """Configuration for Kubernetes"""
6577
@@ -134,25 +146,25 @@ def __init__(self,
134146 multi_namespace_mode : bool ,
135147 watcher_queue : 'Queue[KubernetesWatchType]' ,
136148 resource_version : Optional [str ],
137- worker_uuid : Optional [str ],
149+ scheduler_job_id : Optional [str ],
138150 kube_config : Configuration ):
139151 super ().__init__ ()
140152 self .namespace = namespace
141153 self .multi_namespace_mode = multi_namespace_mode
142- self .worker_uuid = worker_uuid
154+ self .scheduler_job_id = scheduler_job_id
143155 self .watcher_queue = watcher_queue
144156 self .resource_version = resource_version
145157 self .kube_config = kube_config
146158
147159 def run (self ) -> None :
148160 """Performs watching"""
149161 kube_client : client .CoreV1Api = get_kube_client ()
150- if not self .worker_uuid :
162+ if not self .scheduler_job_id :
151163 raise AirflowException (NOT_STARTED_MESSAGE )
152164 while True :
153165 try :
154166 self .resource_version = self ._run (kube_client , self .resource_version ,
155- self .worker_uuid , self .kube_config )
167+ self .scheduler_job_id , self .kube_config )
156168 except ReadTimeoutError :
157169 self .log .warning ("There was a timeout error accessing the Kube API. "
158170 "Retrying request." , exc_info = True )
@@ -167,15 +179,15 @@ def run(self) -> None:
167179 def _run (self ,
168180 kube_client : client .CoreV1Api ,
169181 resource_version : Optional [str ],
170- worker_uuid : str ,
182+ scheduler_job_id : str ,
171183 kube_config : Any ) -> Optional [str ]:
172184 self .log .info (
173185 'Event: and now my watch begins starting at resource_version: %s' ,
174186 resource_version
175187 )
176188 watcher = watch .Watch ()
177189
178- kwargs = {'label_selector' : 'airflow-worker={}' .format (worker_uuid )}
190+ kwargs = {'label_selector' : 'airflow-worker={}' .format (scheduler_job_id )}
179191 if resource_version :
180192 kwargs ['resource_version' ] = resource_version
181193 if kube_config .kube_client_request_args :
@@ -277,7 +289,7 @@ def __init__(self,
277289 task_queue : 'Queue[KubernetesJobType]' ,
278290 result_queue : 'Queue[KubernetesResultsType]' ,
279291 kube_client : client .CoreV1Api ,
280- worker_uuid : str ):
292+ scheduler_job_id : str ):
281293 super ().__init__ ()
282294 self .log .debug ("Creating Kubernetes executor" )
283295 self .kube_config = kube_config
@@ -289,16 +301,16 @@ def __init__(self,
289301 self .launcher = PodLauncher (kube_client = self .kube_client )
290302 self ._manager = multiprocessing .Manager ()
291303 self .watcher_queue = self ._manager .Queue ()
292- self .worker_uuid = worker_uuid
304+ self .scheduler_job_id = scheduler_job_id
293305 self .kube_watcher = self ._make_kube_watcher ()
294306
def _make_kube_watcher(self) -> KubernetesJobWatcher:
    """
    Create and start a KubernetesJobWatcher for this scheduler.

    The watcher is seeded with the resource version currently held by the
    ``ResourceVersion`` singleton and with this scheduler's job id,
    namespace settings and watcher queue.

    :return: the watcher, after ``start()`` has been called on it
    """
    kube_watcher = KubernetesJobWatcher(
        namespace=self.kube_config.kube_namespace,
        multi_namespace_mode=self.kube_config.multi_namespace_mode,
        watcher_queue=self.watcher_queue,
        resource_version=ResourceVersion().resource_version,
        scheduler_job_id=self.scheduler_job_id,
        kube_config=self.kube_config,
    )
    kube_watcher.start()
    return kube_watcher
@@ -333,8 +345,8 @@ def run_next(self, next_job: KubernetesJobType) -> None:
333345
334346 pod = PodGenerator .construct_pod (
335347 namespace = self .namespace ,
336- worker_uuid = self .worker_uuid ,
337- pod_id = self . _create_pod_id (dag_id , task_id ),
348+ scheduler_job_id = self .scheduler_job_id ,
349+ pod_id = create_pod_id (dag_id , task_id ),
338350 dag_id = dag_id ,
339351 task_id = task_id ,
340352 kube_image = self .kube_config .kube_image ,
@@ -404,21 +416,6 @@ def _annotations_to_key(self, annotations: Dict[str, str]) -> Optional[TaskInsta
404416
405417 return TaskInstanceKey (dag_id , task_id , execution_date , try_number )
406418
407- @staticmethod
408- def _strip_unsafe_kubernetes_special_chars (string : str ) -> str :
409- """
410- Kubernetes only supports lowercase alphanumeric characters and "-" and "." in
411- the pod name
412- However, there are special rules about how "-" and "." can be used so let's
413- only keep
414- alphanumeric chars see here for detail:
415- https://kubernetes.io/docs/concepts/overview/working-with-objects/names/
416-
417- :param string: The requested Pod name
418- :return: ``str`` Pod name stripped of any unsafe characters
419- """
420- return '' .join (ch .lower () for ind , ch in enumerate (string ) if ch .isalnum ())
421-
422419 @staticmethod
423420 def _make_safe_pod_id (safe_dag_id : str , safe_task_id : str , safe_uuid : str ) -> str :
424421 r"""
@@ -437,14 +434,6 @@ def _make_safe_pod_id(safe_dag_id: str, safe_task_id: str, safe_uuid: str) -> st
437434
438435 return safe_pod_id
439436
440- @staticmethod
441- def _create_pod_id (dag_id : str , task_id : str ) -> str :
442- safe_dag_id = AirflowKubernetesScheduler ._strip_unsafe_kubernetes_special_chars (
443- dag_id )
444- safe_task_id = AirflowKubernetesScheduler ._strip_unsafe_kubernetes_special_chars (
445- task_id )
446- return safe_dag_id + safe_task_id
447-
448437 def _flush_watcher_queue (self ) -> None :
449438 self .log .debug ('Executor shutting down, watcher_queue approx. size=%d' , self .watcher_queue .qsize ())
450439 while True :
@@ -470,6 +459,36 @@ def terminate(self) -> None:
470459 self ._manager .shutdown ()
471460
472461
462+ def _strip_unsafe_kubernetes_special_chars (string : str ) -> str :
463+ """
464+ Kubernetes only supports lowercase alphanumeric characters, "-" and "." in
465+ the pod name.
466+ However, there are special rules about how "-" and "." can be used so let's
467+ only keep
468+ alphanumeric chars see here for detail:
469+ https://kubernetes.io/docs/concepts/overview/working-with-objects/names/
470+
471+ :param string: The requested Pod name
472+ :return: ``str`` Pod name stripped of any unsafe characters
473+ """
474+ return '' .join (ch .lower () for ind , ch in enumerate (string ) if ch .isalnum ())
475+
476+
def create_pod_id(dag_id: str, task_id: str) -> str:
    """
    Generate the Kubernetes-safe pod_id for a DAG/task pairing.

    Note that this is NOT the full ID that will be launched to k8s; a uuid
    is added elsewhere to ensure uniqueness.

    :param dag_id: DAG ID
    :param task_id: Task ID
    :return: The non-unique pod_id for this task/DAG pairing
    """
    # Sanitize each identifier separately, then concatenate DAG part first.
    return ''.join(
        _strip_unsafe_kubernetes_special_chars(identifier)
        for identifier in (dag_id, task_id)
    )
490+
491+
473492class KubernetesExecutor (BaseExecutor , LoggingMixin ):
474493 """Executor for Kubernetes"""
475494
@@ -480,7 +499,7 @@ def __init__(self):
480499 self .result_queue : 'Queue[KubernetesResultsType]' = self ._manager .Queue ()
481500 self .kube_scheduler : Optional [AirflowKubernetesScheduler ] = None
482501 self .kube_client : Optional [client .CoreV1Api ] = None
483- self .worker_uuid : Optional [str ] = None
502+ self .scheduler_job_id : Optional [str ] = None
484503 super ().__init__ (parallelism = self .kube_config .parallelism )
485504
486505 @provide_session
@@ -519,7 +538,7 @@ def clear_not_launched_queued_tasks(self, session=None) -> None:
519538 pod_generator .datetime_to_label_safe_datestring (
520539 task .execution_date
521540 ),
522- self .worker_uuid
541+ self .scheduler_job_id
523542 )
524543 )
525544 # pylint: enable=protected-access
@@ -568,19 +587,14 @@ def _create_or_update_secret(secret_name, secret_path):
568587 def start (self ) -> None :
569588 """Starts the executor"""
570589 self .log .info ('Start Kubernetes executor' )
571- self .worker_uuid = KubeWorkerIdentifier .get_or_create_current_kube_worker_uuid ()
572- if not self .worker_uuid :
573- raise AirflowException ("Could not get worker uuid" )
574- self .log .debug ('Start with worker_uuid: %s' , self .worker_uuid )
575- # always need to reset resource version since we don't know
576- # when we last started, note for behavior below
577- # https://github.com/kubernetes-client/python/blob/master/kubernetes/docs
578- # /CoreV1Api.md#list_namespaced_pod
579- KubeResourceVersion .reset_resource_version ()
590+ if not self .job_id :
591+ raise AirflowException ("Could not get scheduler_job_id" )
592+ self .scheduler_job_id = self .job_id
593+ self .log .debug ('Start with scheduler_job_id: %s' , self .scheduler_job_id )
580594 self .kube_client = get_kube_client ()
581595 self .kube_scheduler = AirflowKubernetesScheduler (
582596 self .kube_config , self .task_queue , self .result_queue ,
583- self .kube_client , self .worker_uuid
597+ self .kube_client , self .scheduler_job_id
584598 )
585599 self ._inject_secrets ()
586600 self .clear_not_launched_queued_tasks ()
@@ -595,10 +609,10 @@ def execute_async(self,
595609 'Add task %s with command %s with executor_config %s' ,
596610 key , command , executor_config
597611 )
598-
599612 kube_executor_config = PodGenerator .from_obj (executor_config )
600613 if not self .task_queue :
601614 raise AirflowException (NOT_STARTED_MESSAGE )
615+ self .event_buffer [key ] = (State .QUEUED , self .scheduler_job_id )
602616 self .task_queue .put ((key , command , kube_executor_config ))
603617
604618 def sync (self ) -> None :
@@ -607,7 +621,7 @@ def sync(self) -> None:
607621 self .log .debug ('self.running: %s' , self .running )
608622 if self .queued_tasks :
609623 self .log .debug ('self.queued: %s' , self .queued_tasks )
610- if not self .worker_uuid :
624+ if not self .scheduler_job_id :
611625 raise AirflowException (NOT_STARTED_MESSAGE )
612626 if not self .kube_scheduler :
613627 raise AirflowException (NOT_STARTED_MESSAGE )
@@ -640,7 +654,8 @@ def sync(self) -> None:
640654 except Empty :
641655 break
642656
643- KubeResourceVersion .checkpoint_resource_version (last_resource_version )
657+ resource_instance = ResourceVersion ()
658+ resource_instance .resource_version = last_resource_version or resource_instance .resource_version
644659
645660 # pylint: disable=too-many-nested-blocks
646661 for _ in range (self .kube_config .worker_pods_creation_batch_size ):
@@ -681,6 +696,79 @@ def _change_state(self,
681696 self .log .debug ('Could not find key: %s' , str (key ))
682697 self .event_buffer [key ] = state , None
683698
def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
    """
    Try to adopt the pods belonging to the given task instances.

    For every distinct ``external_executor_id`` found on ``tis``, the pods
    labelled ``airflow-worker=<that id>`` are listed and handed to
    ``adopt_launched_task``. Succeeded executor pods are then passed to
    ``_adopt_completed_pods``.

    :param tis: task instances whose pods should be adopted
    :return: the task instances that were not adopted: those without an
        ``external_executor_id`` plus those still left in the pod_id mapping
        after the adoption attempts
    :raises AirflowException: if the executor has not been started, i.e. no
        Kubernetes client is available
    """
    if self.kube_client is None:
        raise AirflowException(NOT_STARTED_MESSAGE)
    kube_client: client.CoreV1Api = self.kube_client

    tis_to_flush = [ti for ti in tis if not ti.external_executor_id]
    # NOTE(review): the key is derived from dag_id and task_id only, so two
    # task instances of the same task map to one entry - confirm intended.
    pod_ids = {
        create_pod_id(dag_id=ti.dag_id, task_id=ti.task_id): ti
        for ti in tis if ti.external_executor_id
    }
    # Query each scheduler job id once: drop missing ids (which would select
    # on "airflow-worker=None") and duplicates (which would repeat the same
    # API call and adoption attempts). First-seen order is kept.
    scheduler_job_ids = dict.fromkeys(
        ti.external_executor_id for ti in tis if ti.external_executor_id
    )
    for scheduler_job_id in scheduler_job_ids:
        pod_list = kube_client.list_namespaced_pod(
            namespace=self.kube_config.kube_namespace,
            label_selector=f'airflow-worker={scheduler_job_id}',
        )
        for pod in pod_list.items:
            self.adopt_launched_task(kube_client, pod, pod_ids)
    self._adopt_completed_pods(kube_client)
    # Whatever is still in pod_ids was not patched, so hand it back too.
    tis_to_flush.extend(pod_ids.values())
    return tis_to_flush
720+
def adopt_launched_task(self, kube_client, pod, pod_ids: dict):
    """
    Patch existing pod so that the current KubernetesJobWatcher can monitor it via label selectors

    The pod's ``airflow-worker`` label is set to this executor's
    ``scheduler_job_id``. The pod is only patched when the pod_id derived
    from its ``dag_id``/``task_id`` labels is present in ``pod_ids``; after
    a successful patch that entry is removed from ``pod_ids``.

    :param kube_client: kubernetes client for speaking to kube API
    :param pod: V1Pod spec that we will patch with new label
    :param pod_ids: pod_ids we expect to patch.
    """
    metadata = pod.metadata
    self.log.info("attempting to adopt pod %s", metadata.name)
    labels = metadata.labels
    labels['airflow-worker'] = str(self.scheduler_job_id)
    dag_id, task_id = labels['dag_id'], labels['task_id']
    pod_id = create_pod_id(dag_id=dag_id, task_id=task_id)

    # Pods that the caller did not ask for are reported and left unpatched.
    if pod_id not in pod_ids:
        self.log.error("attempting to adopt task %s in dag %s"
                       " which was not specified by database", task_id, dag_id)
        return

    try:
        kube_client.patch_namespaced_pod(
            name=metadata.name,
            namespace=metadata.namespace,
            body=PodGenerator.serialize_pod(pod),
        )
        # Only reached when the patch call did not raise.
        pod_ids.pop(pod_id)
    except ApiException as e:
        self.log.info("Failed to adopt pod %s. Reason: %s", metadata.name, e)
747+
def _adopt_completed_pods(self, kube_client: kubernetes.client.CoreV1Api):
    """
    Patch completed pod so that the KubernetesJobWatcher can delete it.

    Lists the pods in the configured namespace that are in phase
    ``Succeeded`` and carry the ``kubernetes_executor=True`` label, sets
    their ``airflow-worker`` label to this executor's ``scheduler_job_id``
    and patches each one. A failed patch is logged and the loop continues.

    :param kube_client: kubernetes client for speaking to kube API
    """
    completed_pods = kube_client.list_namespaced_pod(
        namespace=self.kube_config.kube_namespace,
        field_selector="status.phase=Succeeded",
        label_selector='kubernetes_executor=True',
    )
    for completed_pod in completed_pods.items:
        metadata = completed_pod.metadata
        self.log.info("Attempting to adopt pod %s", metadata.name)
        metadata.labels['airflow-worker'] = str(self.scheduler_job_id)
        try:
            kube_client.patch_namespaced_pod(
                name=metadata.name,
                namespace=metadata.namespace,
                body=PodGenerator.serialize_pod(completed_pod),
            )
        except ApiException as e:
            self.log.info("Failed to adopt pod %s. Reason: %s", metadata.name, e)
771+
684772 def _flush_task_queue (self ) -> None :
685773 if not self .task_queue :
686774 raise AirflowException (NOT_STARTED_MESSAGE )
0 commit comments