Skip to content

Commit 3ca11eb

Browse files
Kubernetes executor can adopt tasks from other schedulers (#10996)
* KubernetesExecutor can adopt tasks from other schedulers * simplify * recreate tables properly * fix pylint Co-authored-by: Daniel Imberman <[email protected]>
1 parent 427a4a8 commit 3ca11eb

File tree

13 files changed

+330
-218
lines changed

13 files changed

+330
-218
lines changed

airflow/cli/commands/dag_command.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def generate_pod_yaml(args):
385385
"""Generates yaml files for each task in the DAG. Used for testing output of KubernetesExecutor"""
386386
from kubernetes.client.api_client import ApiClient
387387

388-
from airflow.executors.kubernetes_executor import AirflowKubernetesScheduler, KubeConfig
388+
from airflow.executors.kubernetes_executor import KubeConfig, create_pod_id
389389
from airflow.kubernetes import pod_generator
390390
from airflow.kubernetes.pod_generator import PodGenerator
391391
from airflow.settings import pod_mutation_hook
@@ -399,14 +399,14 @@ def generate_pod_yaml(args):
399399
pod = PodGenerator.construct_pod(
400400
dag_id=args.dag_id,
401401
task_id=ti.task_id,
402-
pod_id=AirflowKubernetesScheduler._create_pod_id( # pylint: disable=W0212
402+
pod_id=create_pod_id(
403403
args.dag_id, ti.task_id),
404404
try_number=ti.try_number,
405405
kube_image=kube_config.kube_image,
406406
date=ti.execution_date,
407407
command=ti.command_as_list(),
408408
pod_override_object=PodGenerator.from_obj(ti.executor_config),
409-
worker_uuid="worker-config",
409+
scheduler_job_id="worker-config",
410410
namespace=kube_config.executor_namespace,
411411
base_worker_pod=PodGenerator.deserialize_model_file(kube_config.pod_template_file)
412412
)

airflow/executors/base_executor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ class BaseExecutor(LoggingMixin):
5656
``0`` for infinity
5757
"""
5858

59+
job_id: Optional[str] = None
60+
5961
def __init__(self, parallelism: int = PARALLELISM):
6062
super().__init__()
6163
self.parallelism: int = parallelism

airflow/executors/kubernetes_executor.py

Lines changed: 140 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import multiprocessing
2828
import time
2929
from queue import Empty, Queue # pylint: disable=unused-import
30-
from typing import Any, Dict, Optional, Tuple, Union
30+
from typing import Any, Dict, List, Optional, Tuple, Union
3131

3232
import kubernetes
3333
from dateutil import parser
@@ -44,7 +44,7 @@
4444
from airflow.kubernetes.kube_client import get_kube_client
4545
from airflow.kubernetes.pod_generator import MAX_POD_ID_LEN, PodGenerator
4646
from airflow.kubernetes.pod_launcher import PodLauncher
47-
from airflow.models import KubeResourceVersion, KubeWorkerIdentifier, TaskInstance
47+
from airflow.models import TaskInstance
4848
from airflow.models.taskinstance import TaskInstanceKey
4949
from airflow.utils.log.logging_mixin import LoggingMixin
5050
from airflow.utils.session import provide_session
@@ -60,6 +60,18 @@
6060
KubernetesWatchType = Tuple[str, str, Optional[str], Dict[str, str], str]
6161

6262

63+
class ResourceVersion:
64+
"""Singleton for tracking resourceVersion from Kubernetes"""
65+
66+
_instance = None
67+
resource_version = "0"
68+
69+
def __new__(cls):
70+
if cls._instance is None:
71+
cls._instance = super().__new__(cls)
72+
return cls._instance
73+
74+
6375
class KubeConfig: # pylint: disable=too-many-instance-attributes
6476
"""Configuration for Kubernetes"""
6577

@@ -134,25 +146,25 @@ def __init__(self,
134146
multi_namespace_mode: bool,
135147
watcher_queue: 'Queue[KubernetesWatchType]',
136148
resource_version: Optional[str],
137-
worker_uuid: Optional[str],
149+
scheduler_job_id: Optional[str],
138150
kube_config: Configuration):
139151
super().__init__()
140152
self.namespace = namespace
141153
self.multi_namespace_mode = multi_namespace_mode
142-
self.worker_uuid = worker_uuid
154+
self.scheduler_job_id = scheduler_job_id
143155
self.watcher_queue = watcher_queue
144156
self.resource_version = resource_version
145157
self.kube_config = kube_config
146158

147159
def run(self) -> None:
148160
"""Performs watching"""
149161
kube_client: client.CoreV1Api = get_kube_client()
150-
if not self.worker_uuid:
162+
if not self.scheduler_job_id:
151163
raise AirflowException(NOT_STARTED_MESSAGE)
152164
while True:
153165
try:
154166
self.resource_version = self._run(kube_client, self.resource_version,
155-
self.worker_uuid, self.kube_config)
167+
self.scheduler_job_id, self.kube_config)
156168
except ReadTimeoutError:
157169
self.log.warning("There was a timeout error accessing the Kube API. "
158170
"Retrying request.", exc_info=True)
@@ -167,15 +179,15 @@ def run(self) -> None:
167179
def _run(self,
168180
kube_client: client.CoreV1Api,
169181
resource_version: Optional[str],
170-
worker_uuid: str,
182+
scheduler_job_id: str,
171183
kube_config: Any) -> Optional[str]:
172184
self.log.info(
173185
'Event: and now my watch begins starting at resource_version: %s',
174186
resource_version
175187
)
176188
watcher = watch.Watch()
177189

178-
kwargs = {'label_selector': 'airflow-worker={}'.format(worker_uuid)}
190+
kwargs = {'label_selector': 'airflow-worker={}'.format(scheduler_job_id)}
179191
if resource_version:
180192
kwargs['resource_version'] = resource_version
181193
if kube_config.kube_client_request_args:
@@ -277,7 +289,7 @@ def __init__(self,
277289
task_queue: 'Queue[KubernetesJobType]',
278290
result_queue: 'Queue[KubernetesResultsType]',
279291
kube_client: client.CoreV1Api,
280-
worker_uuid: str):
292+
scheduler_job_id: str):
281293
super().__init__()
282294
self.log.debug("Creating Kubernetes executor")
283295
self.kube_config = kube_config
@@ -289,16 +301,16 @@ def __init__(self,
289301
self.launcher = PodLauncher(kube_client=self.kube_client)
290302
self._manager = multiprocessing.Manager()
291303
self.watcher_queue = self._manager.Queue()
292-
self.worker_uuid = worker_uuid
304+
self.scheduler_job_id = scheduler_job_id
293305
self.kube_watcher = self._make_kube_watcher()
294306

295307
def _make_kube_watcher(self) -> KubernetesJobWatcher:
296-
resource_version = KubeResourceVersion.get_current_resource_version()
308+
resource_version = ResourceVersion().resource_version
297309
watcher = KubernetesJobWatcher(watcher_queue=self.watcher_queue,
298310
namespace=self.kube_config.kube_namespace,
299311
multi_namespace_mode=self.kube_config.multi_namespace_mode,
300312
resource_version=resource_version,
301-
worker_uuid=self.worker_uuid,
313+
scheduler_job_id=self.scheduler_job_id,
302314
kube_config=self.kube_config)
303315
watcher.start()
304316
return watcher
@@ -333,8 +345,8 @@ def run_next(self, next_job: KubernetesJobType) -> None:
333345

334346
pod = PodGenerator.construct_pod(
335347
namespace=self.namespace,
336-
worker_uuid=self.worker_uuid,
337-
pod_id=self._create_pod_id(dag_id, task_id),
348+
scheduler_job_id=self.scheduler_job_id,
349+
pod_id=create_pod_id(dag_id, task_id),
338350
dag_id=dag_id,
339351
task_id=task_id,
340352
kube_image=self.kube_config.kube_image,
@@ -404,21 +416,6 @@ def _annotations_to_key(self, annotations: Dict[str, str]) -> Optional[TaskInsta
404416

405417
return TaskInstanceKey(dag_id, task_id, execution_date, try_number)
406418

407-
@staticmethod
408-
def _strip_unsafe_kubernetes_special_chars(string: str) -> str:
409-
"""
410-
Kubernetes only supports lowercase alphanumeric characters and "-" and "." in
411-
the pod name
412-
However, there are special rules about how "-" and "." can be used so let's
413-
only keep
414-
alphanumeric chars see here for detail:
415-
https://kubernetes.io/docs/concepts/overview/working-with-objects/names/
416-
417-
:param string: The requested Pod name
418-
:return: ``str`` Pod name stripped of any unsafe characters
419-
"""
420-
return ''.join(ch.lower() for ind, ch in enumerate(string) if ch.isalnum())
421-
422419
@staticmethod
423420
def _make_safe_pod_id(safe_dag_id: str, safe_task_id: str, safe_uuid: str) -> str:
424421
r"""
@@ -437,14 +434,6 @@ def _make_safe_pod_id(safe_dag_id: str, safe_task_id: str, safe_uuid: str) -> st
437434

438435
return safe_pod_id
439436

440-
@staticmethod
441-
def _create_pod_id(dag_id: str, task_id: str) -> str:
442-
safe_dag_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
443-
dag_id)
444-
safe_task_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
445-
task_id)
446-
return safe_dag_id + safe_task_id
447-
448437
def _flush_watcher_queue(self) -> None:
449438
self.log.debug('Executor shutting down, watcher_queue approx. size=%d', self.watcher_queue.qsize())
450439
while True:
@@ -470,6 +459,36 @@ def terminate(self) -> None:
470459
self._manager.shutdown()
471460

472461

462+
def _strip_unsafe_kubernetes_special_chars(string: str) -> str:
463+
"""
464+
Kubernetes only supports lowercase alphanumeric characters, "-" and "." in
465+
the pod name.
466+
However, there are special rules about how "-" and "." can be used so let's
467+
only keep
468+
alphanumeric chars see here for detail:
469+
https://kubernetes.io/docs/concepts/overview/working-with-objects/names/
470+
471+
:param string: The requested Pod name
472+
:return: ``str`` Pod name stripped of any unsafe characters
473+
"""
474+
return ''.join(ch.lower() for ind, ch in enumerate(string) if ch.isalnum())
475+
476+
477+
def create_pod_id(dag_id: str, task_id: str) -> str:
478+
"""
479+
Generates the kubernetes safe pod_id. Note that this is
480+
NOT the full ID that will be launched to k8s. We will add a uuid
481+
to ensure uniqueness.
482+
483+
:param dag_id: DAG ID
484+
:param task_id: Task ID
485+
:return: The non-unique pod_id for this task/DAG pairing
486+
"""
487+
safe_dag_id = _strip_unsafe_kubernetes_special_chars(dag_id)
488+
safe_task_id = _strip_unsafe_kubernetes_special_chars(task_id)
489+
return safe_dag_id + safe_task_id
490+
491+
473492
class KubernetesExecutor(BaseExecutor, LoggingMixin):
474493
"""Executor for Kubernetes"""
475494

@@ -480,7 +499,7 @@ def __init__(self):
480499
self.result_queue: 'Queue[KubernetesResultsType]' = self._manager.Queue()
481500
self.kube_scheduler: Optional[AirflowKubernetesScheduler] = None
482501
self.kube_client: Optional[client.CoreV1Api] = None
483-
self.worker_uuid: Optional[str] = None
502+
self.scheduler_job_id: Optional[str] = None
484503
super().__init__(parallelism=self.kube_config.parallelism)
485504

486505
@provide_session
@@ -519,7 +538,7 @@ def clear_not_launched_queued_tasks(self, session=None) -> None:
519538
pod_generator.datetime_to_label_safe_datestring(
520539
task.execution_date
521540
),
522-
self.worker_uuid
541+
self.scheduler_job_id
523542
)
524543
)
525544
# pylint: enable=protected-access
@@ -568,19 +587,14 @@ def _create_or_update_secret(secret_name, secret_path):
568587
def start(self) -> None:
569588
"""Starts the executor"""
570589
self.log.info('Start Kubernetes executor')
571-
self.worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid()
572-
if not self.worker_uuid:
573-
raise AirflowException("Could not get worker uuid")
574-
self.log.debug('Start with worker_uuid: %s', self.worker_uuid)
575-
# always need to reset resource version since we don't know
576-
# when we last started, note for behavior below
577-
# https://github.com/kubernetes-client/python/blob/master/kubernetes/docs
578-
# /CoreV1Api.md#list_namespaced_pod
579-
KubeResourceVersion.reset_resource_version()
590+
if not self.job_id:
591+
raise AirflowException("Could not get scheduler_job_id")
592+
self.scheduler_job_id = self.job_id
593+
self.log.debug('Start with scheduler_job_id: %s', self.scheduler_job_id)
580594
self.kube_client = get_kube_client()
581595
self.kube_scheduler = AirflowKubernetesScheduler(
582596
self.kube_config, self.task_queue, self.result_queue,
583-
self.kube_client, self.worker_uuid
597+
self.kube_client, self.scheduler_job_id
584598
)
585599
self._inject_secrets()
586600
self.clear_not_launched_queued_tasks()
@@ -595,10 +609,10 @@ def execute_async(self,
595609
'Add task %s with command %s with executor_config %s',
596610
key, command, executor_config
597611
)
598-
599612
kube_executor_config = PodGenerator.from_obj(executor_config)
600613
if not self.task_queue:
601614
raise AirflowException(NOT_STARTED_MESSAGE)
615+
self.event_buffer[key] = (State.QUEUED, self.scheduler_job_id)
602616
self.task_queue.put((key, command, kube_executor_config))
603617

604618
def sync(self) -> None:
@@ -607,7 +621,7 @@ def sync(self) -> None:
607621
self.log.debug('self.running: %s', self.running)
608622
if self.queued_tasks:
609623
self.log.debug('self.queued: %s', self.queued_tasks)
610-
if not self.worker_uuid:
624+
if not self.scheduler_job_id:
611625
raise AirflowException(NOT_STARTED_MESSAGE)
612626
if not self.kube_scheduler:
613627
raise AirflowException(NOT_STARTED_MESSAGE)
@@ -640,7 +654,8 @@ def sync(self) -> None:
640654
except Empty:
641655
break
642656

643-
KubeResourceVersion.checkpoint_resource_version(last_resource_version)
657+
resource_instance = ResourceVersion()
658+
resource_instance.resource_version = last_resource_version or resource_instance.resource_version
644659

645660
# pylint: disable=too-many-nested-blocks
646661
for _ in range(self.kube_config.worker_pods_creation_batch_size):
@@ -681,6 +696,79 @@ def _change_state(self,
681696
self.log.debug('Could not find key: %s', str(key))
682697
self.event_buffer[key] = state, None
683698

699+
def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
700+
tis_to_flush = [ti for ti in tis if not ti.external_executor_id]
701+
scheduler_job_ids = [ti.external_executor_id for ti in tis]
702+
pod_ids = {
703+
create_pod_id(dag_id=ti.dag_id, task_id=ti.task_id): ti
704+
for ti in tis if ti.external_executor_id
705+
}
706+
kube_client: client.CoreV1Api = self.kube_client
707+
for scheduler_job_id in scheduler_job_ids:
708+
kwargs = {
709+
'label_selector': f'airflow-worker={scheduler_job_id}'
710+
}
711+
pod_list = kube_client.list_namespaced_pod(
712+
namespace=self.kube_config.kube_namespace,
713+
**kwargs
714+
)
715+
for pod in pod_list.items:
716+
self.adopt_launched_task(kube_client, pod, pod_ids)
717+
self._adopt_completed_pods(kube_client)
718+
tis_to_flush.extend(pod_ids.values())
719+
return tis_to_flush
720+
721+
def adopt_launched_task(self, kube_client, pod, pod_ids: dict):
722+
"""
723+
Patch existing pod so that the current KubernetesJobWatcher can monitor it via label selectors
724+
725+
:param kube_client: kubernetes client for speaking to kube API
726+
:param pod: V1Pod spec that we will patch with new label
727+
:param pod_ids: pod_ids we expect to patch.
728+
"""
729+
self.log.info("attempting to adopt pod %s", pod.metadata.name)
730+
pod.metadata.labels['airflow-worker'] = str(self.scheduler_job_id)
731+
dag_id = pod.metadata.labels['dag_id']
732+
task_id = pod.metadata.labels['task_id']
733+
pod_id = create_pod_id(dag_id=dag_id, task_id=task_id)
734+
if pod_id not in pod_ids:
735+
self.log.error("attempting to adopt task %s in dag %s"
736+
" which was not specified by database", task_id, dag_id)
737+
else:
738+
try:
739+
kube_client.patch_namespaced_pod(
740+
name=pod.metadata.name,
741+
namespace=pod.metadata.namespace,
742+
body=PodGenerator.serialize_pod(pod),
743+
)
744+
pod_ids.pop(pod_id)
745+
except ApiException as e:
746+
self.log.info("Failed to adopt pod %s. Reason: %s", pod.metadata.name, e)
747+
748+
def _adopt_completed_pods(self, kube_client: kubernetes.client.CoreV1Api):
749+
"""
750+
751+
Patch completed pod so that the KubernetesJobWatcher can delete it.
752+
753+
:param kube_client: kubernetes client for speaking to kube API
754+
"""
755+
kwargs = {
756+
'field_selector': "status.phase=Succeeded",
757+
'label_selector': 'kubernetes_executor=True',
758+
}
759+
pod_list = kube_client.list_namespaced_pod(namespace=self.kube_config.kube_namespace, **kwargs)
760+
for pod in pod_list.items:
761+
self.log.info("Attempting to adopt pod %s", pod.metadata.name)
762+
pod.metadata.labels['airflow-worker'] = str(self.scheduler_job_id)
763+
try:
764+
kube_client.patch_namespaced_pod(
765+
name=pod.metadata.name,
766+
namespace=pod.metadata.namespace,
767+
body=PodGenerator.serialize_pod(pod),
768+
)
769+
except ApiException as e:
770+
self.log.info("Failed to adopt pod %s. Reason: %s", pod.metadata.name, e)
771+
684772
def _flush_task_queue(self) -> None:
685773
if not self.task_queue:
686774
raise AirflowException(NOT_STARTED_MESSAGE)

0 commit comments

Comments
 (0)