Skip to content

Commit 66d39c5

Browse files
authored
[Workflows] Auth secret token mounting for Workflows (#9151)
### 📝 Description <!-- A short summary of what this PR does. --> <!-- Include any relevant context or background information. --> Add support for IG4 authentication on workflows by mounting the auth secret on the Argo pods. This PR moves the core logic of `enrich_and_validate_auth_token_name` out of the launcher to a more common place so that it can be used by workflows, since they don't go through the launcher/runtime handler. --- ### 🛠️ Changes Made <!-- - Key changes (e.g., added feature X, refactored Y, fixed Z) --> - Move the core logic of `enrich_and_validate_auth_token_name` from the launcher to `mlrun.auth.utils`. - Create a helper function, `resolve_auth_token_secret_name`, for pipelines; it resolves the token name and then extracts the secret name. - Refactor `replace_kfp_plaintext_secret_env_vars_with_secret_refs` into `process_kfp_workflow_secret_references`, which accepts the `auth_secret_name` param so that the secret gets mounted to the Argo pods during `_enrich_kfp_workflow_yaml_credentials`. --- ### ✅ Checklist - [ ] I updated the documentation (if applicable) - [x] I have tested the changes in this PR - [ ] I confirmed whether my changes are covered by system tests - [ ] If yes, I ran all relevant system tests and ensured they passed before submitting this PR - [ ] I updated existing system tests and/or added new ones if needed to cover my changes - [ ] If I introduced a deprecation: - [ ] I followed the [Deprecation Guidelines](./DEPRECATION.md) - [ ] I updated the relevant Jira ticket for documentation --- ### 🧪 Testing <!-- - How it was tested (unit tests, manual, integration) --> <!-- - Any special cases covered. --> Unit tests - `test_resolve_auth_secret_name` - `test_enrich_and_validate_auth_token_name` --- ### 🔗 References - Ticket link: https://iguazio.atlassian.net/browse/ML-11588 - Design docs links: - External links: --- ### 🚨 Breaking Changes? 
- [ ] Yes (explain below) - [ ] No <!-- If yes, describe what needs to be changed downstream: --> --- ### 🔍️ Additional Notes <!-- Anything else reviewers should know (follow-up tasks, known issues, affected areas etc.). --> <!-- ### 📸 Screenshots / Logs -->
1 parent c512081 commit 66d39c5

File tree

10 files changed

+194
-33
lines changed

10 files changed

+194
-33
lines changed

mlrun/auth/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import yaml
1919

20+
import mlrun.common.constants
2021
import mlrun.common.schemas
2122
import mlrun.utils.helpers
2223
from mlrun.config import config as mlconf

mlrun/common/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040

4141
MLRUN_JOB_AUTH_SECRET_PATH = "/var/mlrun-secrets/auth"
4242
MLRUN_JOB_AUTH_SECRET_FILE = ".igz.yml"
43-
MLRUN_JOB_AUTH_DEFAULT_TOKEN_NAME = "default"
43+
MLRUN_RUNTIME_AUTH_DEFAULT_TOKEN_NAME = "default"
4444

4545

4646
class MLRunInternalLabels:

pipeline-adapters/mlrun-pipelines-kfp-common/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "mlrun-pipelines-kfp-common"
3-
version = "0.6.0"
3+
version = "0.6.1"
44
description = "MLRun Pipelines adapter package for providing KFP common functionality"
55
readme = "README.md"
66
requires-python = ">=3.11, <3.12"

pipeline-adapters/mlrun-pipelines-kfp-common/src/mlrun_pipelines/common/ops.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -730,11 +730,12 @@ def _enrich_gpu_limits(function, task):
730730
task.container.add_resource_limit(resource_name, resource_value)
731731

732732

733-
def replace_kfp_plaintext_secret_env_vars_with_secret_refs(
733+
def process_kfp_workflow_secret_references(
734734
byte_buffer: bytes,
735735
content_type: str,
736736
env_var_names: list[str],
737737
secrets_store: "SecretsStore",
738+
auth_secret_name: typing.Optional[str] = None,
738739
) -> bytes:
739740
if content_type.endswith(
740741
"zip"
@@ -744,13 +745,15 @@ def replace_kfp_plaintext_secret_env_vars_with_secret_refs(
744745
byte_buffer=byte_buffer,
745746
env_var_names=env_var_names,
746747
secrets_store=secrets_store,
748+
auth_secret_name=auth_secret_name,
747749
)
748750
return modified_zip_bytes
749751
elif content_type.endswith(("yaml", "plain")):
750752
modified_yaml_bytes = _enrich_kfp_workflow_yaml_credentials(
751753
yaml_bytes=byte_buffer,
752754
env_var_names=env_var_names,
753755
secrets_store=secrets_store,
756+
auth_secret_name=auth_secret_name,
754757
)
755758
return modified_yaml_bytes
756759
else:
@@ -761,11 +764,12 @@ def _enrich_kfp_workflow_credentials_in_subprocess(
761764
byte_buffer: bytes,
762765
env_var_names: list[str],
763766
secrets_store: "SecretsStore",
767+
auth_secret_name: typing.Optional[str] = None,
764768
) -> bytes:
765769
queue = multiprocessing.Queue()
766770
process = multiprocessing.Process(
767771
target=_enrich_wrapper,
768-
args=(queue, byte_buffer, env_var_names, secrets_store),
772+
args=(queue, byte_buffer, env_var_names, secrets_store, auth_secret_name),
769773
)
770774
process.start()
771775
result = queue.get()
@@ -778,11 +782,13 @@ def _enrich_wrapper(
778782
byte_buffer: bytes,
779783
env_var_names: list[str],
780784
secrets_store: "SecretsStore",
785+
auth_secret_name: typing.Optional[str] = None,
781786
):
782787
result = _enrich_kfp_workflow_zip_credentials(
783788
byte_buffer=byte_buffer,
784789
env_var_names=env_var_names,
785790
secrets_store=secrets_store,
791+
auth_secret_name=auth_secret_name,
786792
)
787793
queue.put(result)
788794

@@ -791,6 +797,7 @@ def _enrich_kfp_workflow_zip_credentials(
791797
byte_buffer: bytes,
792798
env_var_names: list[str],
793799
secrets_store: "SecretsStore",
800+
auth_secret_name: typing.Optional[str] = None,
794801
) -> bytes:
795802
in_memory_zip = io.BytesIO(byte_buffer)
796803
with zipfile.ZipFile(in_memory_zip, "r") as zip_read:
@@ -806,6 +813,7 @@ def _enrich_kfp_workflow_zip_credentials(
806813
yaml_bytes=file_data,
807814
env_var_names=env_var_names,
808815
secrets_store=secrets_store,
816+
auth_secret_name=auth_secret_name,
809817
)
810818
files_data[file_name] = modified_yaml
811819

@@ -821,13 +829,16 @@ def _enrich_kfp_workflow_yaml_credentials(
821829
yaml_bytes: bytes,
822830
env_var_names: list[str],
823831
secrets_store: "SecretsStore",
832+
auth_secret_name: typing.Optional[str] = None,
824833
) -> bytes:
825834
"""
826835
Modifies the given workflow YAML to add secret environment variables to container specifications.
827836
The function checks if the workflow uses Argo Workflows or Tekton Pipelines and injects the
828837
environment variables accordingly.
829838
"""
830839
workflow_dict = yaml.safe_load(yaml_bytes)
840+
workflow_dict = add_auth_mount_to_argo_pods(workflow_dict, auth_secret_name)
841+
831842
# Determine the KFP version by checking the 'apiVersion' field
832843
api_version = (
833844
workflow_dict.get("api_version") or workflow_dict.get("apiVersion", "").lower()
@@ -867,6 +878,41 @@ def _enrich_kfp_workflow_yaml_credentials(
867878
)
868879

869880

881+
def add_auth_mount_to_argo_pods(
    workflow_dict: dict, auth_secret_name: typing.Optional[str] = None
) -> dict:
    """
    Mount the auth token secret into every container template of an Argo workflow.

    For each entry of ``spec.templates`` that defines a ``container``, a secret volume named
    ``secret`` is added to the template and mounted into the container at
    ``mlrun.common.constants.MLRUN_JOB_AUTH_SECRET_PATH``. The secret's ``tokensFile`` key is
    projected as the file ``mlrun.common.constants.MLRUN_JOB_AUTH_SECRET_FILE``.

    The workflow is modified in place and the same object is returned.

    :param workflow_dict:    parsed workflow document (the result of loading the workflow YAML)
    :param auth_secret_name: name of the k8s secret holding the auth tokens file; when empty or
                             ``None`` the workflow is returned untouched
    :return: the given ``workflow_dict``
    """
    if not auth_secret_name or not isinstance(workflow_dict, dict):
        return workflow_dict

    # This function is invoked before the caller has determined the workflow flavor, so the
    # document is not guaranteed to carry an Argo-style `spec.templates` list. Anything that does
    # not have one is returned unchanged instead of failing on a missing key.
    spec = workflow_dict.get("spec")
    templates = spec.get("templates") if isinstance(spec, dict) else None
    if not isinstance(templates, list):
        return workflow_dict

    volume_name = "secret"
    for template in templates:
        # Skip DAG-only templates (and malformed entries) - there is no container to mount into
        if not isinstance(template, dict):
            continue
        container = template.get("container")
        if not isinstance(container, dict):
            continue

        # Build fresh dicts per template: appending the very same object to several templates
        # couples them together and may be emitted as YAML anchors/aliases when dumped with
        # PyYAML's default alias handling.
        volume = {
            "name": volume_name,
            "secret": {
                "secretName": auth_secret_name,
                "items": [
                    {
                        "key": "tokensFile",
                        "path": mlrun.common.constants.MLRUN_JOB_AUTH_SECRET_FILE,
                    }
                ],
            },
        }
        volume_mount = {
            "name": volume_name,
            "mountPath": mlrun.common.constants.MLRUN_JOB_AUTH_SECRET_PATH,
        }

        # `get(...) or []` also covers an explicit `volumes: null` / `volumeMounts: null`
        volumes = template.get("volumes") or []
        volume_mounts = container.get("volumeMounts") or []

        # Only skip exact duplicates, so processing the same workflow twice stays idempotent
        # while a conflicting pre-existing entry is not silently treated as ours
        if volume not in volumes:
            volumes.append(volume)
        if volume_mount not in volume_mounts:
            volume_mounts.append(volume_mount)

        template["volumes"] = volumes
        container["volumeMounts"] = volume_mounts

    return workflow_dict
914+
915+
870916
def _replace_secret_envs_in_argocd_template(
871917
env_var_names: list[str],
872918
container: dict,

server/py/services/api/api/endpoints/pipelines.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,7 @@ async def _create_pipeline(
487487
content_type,
488488
data,
489489
arguments,
490+
auth_info,
490491
)
491492

492493
return {

server/py/services/api/crud/pipelines.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import sqlalchemy.orm
2626

2727
import mlrun
28+
import mlrun.auth.utils
2829
import mlrun.common.constants as mlrun_constants
2930
import mlrun.common.formatters
3031
import mlrun.common.helpers
@@ -47,9 +48,10 @@
4748

4849
import framework.api.utils
4950
import framework.utils.singletons.db
51+
import framework.utils.singletons.k8s
5052
import services.api.crud
53+
import services.api.utils.helpers
5154
from services.api.crud.workflows import RerunRunner
52-
from services.api.utils.helpers import resolve_client_default_kfp_image
5355

5456

5557
class Pipelines(
@@ -432,7 +434,7 @@ def rerun_pipeline_via_runner(
432434
- status: `"running"`
433435
- run_id: the new MLRun-run UID for the RerunRunner job
434436
"""
435-
client_image = resolve_client_default_kfp_image(
437+
client_image = services.api.utils.helpers.resolve_client_default_kfp_image(
436438
project,
437439
workflow_spec=None,
438440
client_version=client_version,
@@ -601,6 +603,7 @@ def create_pipeline(
601603
content_type: str,
602604
data: bytes,
603605
arguments: typing.Optional[dict] = None,
606+
auth_info: typing.Optional[mlrun.common.schemas.AuthInfo] = None,
604607
):
605608
if arguments is None:
606609
arguments = {}
@@ -616,11 +619,21 @@ def create_pipeline(
616619
mlrun.utils.logger.debug(
617620
"Writing pipeline to temp file", content_type=content_type
618621
)
619-
data = mlrun_pipelines.common.ops.replace_kfp_plaintext_secret_env_vars_with_secret_refs(
622+
623+
# TODO In ML-11600, pass the token name from the request
624+
provided_token_name = None
625+
# Workflows do not go through launcher/runtime handler
626+
# So enrichment, validation and secret retrieval need to be done here
627+
auth_secret_name = services.api.utils.helpers.resolve_auth_token_secret_name(
628+
provided_token_name=provided_token_name, username=auth_info.username
629+
)
630+
631+
data = mlrun_pipelines.common.ops.process_kfp_workflow_secret_references(
620632
byte_buffer=data,
621633
content_type=content_type,
622634
env_var_names=["MLRUN_AUTH_SESSION", "V3IO_ACCESS_KEY"],
623635
secrets_store=services.api.crud.Secrets(),
636+
auth_secret_name=auth_secret_name,
624637
)
625638
pipeline_file = tempfile.NamedTemporaryFile(suffix=content_type)
626639
with open(pipeline_file.name, "wb") as fp:

server/py/services/api/launcher.py

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from dependency_injector import containers, providers
2020

21+
import mlrun.auth.utils
2122
import mlrun.common.constants as mlrun_constants
2223
import mlrun.common.runtimes.constants
2324
import mlrun.common.schemas.schedule
@@ -674,36 +675,24 @@ def _validate_retry(runtime_kind: str, retry: Optional["mlrun.model.Retry"]):
674675
f"must be less than {staleness_threshold_seconds} seconds, got {max_delay} seconds"
675676
)
676677

677-
# TODO In ML-11600, implement token name resolution and validation + tests
678678
def enrich_and_validate_auth_token_name(
679679
self, object: Union[mlrun.run.RunObject, mlrun.runtimes.RemoteRuntime]
680680
):
681-
if mlrun.mlconf.is_iguazio_v4_mode():
682-
if object.spec.auth is None:
683-
object.spec.auth = {}
684-
685-
# Get the provided token name, if any
686-
provided_token_name = object.spec.auth.get("token_name")
687-
688-
# Resolve token name and raise error only if token is explicitly provided by the user
689-
# in ML-11600, we will implement a proper resolution logic that checks all secret tokens
690-
# of the user and finds a valid one if no token name is provided
691-
raise_error_on_failure = bool(provided_token_name)
692-
token_name = (
693-
provided_token_name
694-
or mlrun.common.constants.MLRUN_JOB_AUTH_DEFAULT_TOKEN_NAME
695-
)
696-
self._validate_token_name(
697-
token_name, raise_error_on_failure=raise_error_on_failure
698-
)
699-
700-
object.spec.auth["token_name"] = token_name
681+
if object.spec.auth is None:
682+
object.spec.auth = {}
683+
684+
# Get the provided token name, if any
685+
provided_token_name = object.spec.auth.get("token_name")
686+
687+
# In ML-11600, we will implement a proper resolution logic that checks all secret tokens
688+
# of the user and finds a valid one if no token name is provided
689+
# If token name not provided, use default
690+
token_name = (
691+
provided_token_name
692+
or mlrun.common.constants.MLRUN_RUNTIME_AUTH_DEFAULT_TOKEN_NAME
693+
)
701694

702-
# TODO implement validation in ML-11600
703-
def _validate_token_name(
704-
self, token_name: str, raise_error_on_failure: bool = False
705-
):
706-
pass
695+
object.spec.auth["token_name"] = token_name
707696

708697

709698
# Once this file is imported it will set the container server side launcher

server/py/services/api/tests/unit/api/test_utils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import mlrun.runtimes.mounts
3535
import mlrun.runtimes.pod
3636
import mlrun.utils
37+
from mlrun.common.types import AuthenticationMode
3738
from server.py.framework.api.utils import (
3839
_generate_function_and_task_from_submit_run_body,
3940
)
@@ -2043,3 +2044,57 @@ def test_setenv_from_the_project_secret(secret_name, expect_exception, kind):
20432044
else:
20442045
# Should not raise
20452046
framework.api.utils.validate_function_env_vars(function)
2047+
2048+
2049+
@pytest.mark.parametrize(
    "provided_token, secret_name, expected_secret_name, expected_token_name",
    [
        # no token provided, secret found -> looked up with the default token name
        (
            None,
            "secret-1",
            "secret-1",
            mlrun.common.constants.MLRUN_RUNTIME_AUTH_DEFAULT_TOKEN_NAME,
        ),
        # token provided, secret found -> looked up with the provided token name
        ("custom-token", "secret-2", "secret-2", "custom-token"),
        # no token provided, secret not found -> nothing resolved
        (
            None,
            None,
            None,
            mlrun.common.constants.MLRUN_RUNTIME_AUTH_DEFAULT_TOKEN_NAME,
        ),
        # token provided, secret not found -> nothing resolved
        ("custom-token", None, None, "custom-token"),
    ],
)
def test_resolve_auth_secret_name(
    monkeypatch, provided_token, secret_name, expected_secret_name, expected_token_name
):
    """
    Check that ``resolve_auth_token_secret_name`` returns the name of the secret handed back by
    the k8s helper (or ``None`` when there is no secret), and that the helper is queried with
    the provided token name, falling back to the default one.
    """
    mlrun.mlconf.httpdb.authentication.mode = AuthenticationMode.IGUAZIO_V4

    # The k8s helper hands back either a secret object exposing `metadata.name`, or nothing
    if secret_name:
        token_secret = unittest.mock.Mock()
        token_secret.metadata.name = secret_name
    else:
        token_secret = None

    k8s_helper_mock = unittest.mock.Mock()
    k8s_helper_mock._get_user_token_secret.return_value = token_secret
    monkeypatch.setattr(
        "framework.utils.singletons.k8s.get_k8s_helper",
        lambda: k8s_helper_mock,
    )

    resolved_secret_name = services.api.utils.helpers.resolve_auth_token_secret_name(
        provided_token, "test-user"
    )

    assert resolved_secret_name == expected_secret_name

    # The lookup must use the effective token name (provided one, or the default)
    k8s_helper_mock._get_user_token_secret.assert_called_once_with(
        username="test-user",
        token_name=expected_token_name,
    )

server/py/services/api/tests/unit/test_launcher.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,3 +468,30 @@ def test_launcher_skips_aborted_or_deleted_run(monkeypatch):
468468
# Validate result
469469
assert run.status.state == mlrun.common.runtimes.constants.RunStates.aborted
470470
assert not runtime_handler_mock.called
471+
472+
473+
@pytest.mark.parametrize(
    "initial_auth, expected_token",
    [
        # no auth section at all -> default token name is set
        (None, mlrun.common.constants.MLRUN_RUNTIME_AUTH_DEFAULT_TOKEN_NAME),
        # empty auth section -> default token name is set
        ({}, mlrun.common.constants.MLRUN_RUNTIME_AUTH_DEFAULT_TOKEN_NAME),
        # token name given explicitly -> kept as is
        ({"token_name": "custom-token"}, "custom-token"),
    ],
)
def test_enrich_and_validate_auth_token_name(
    initial_auth,
    expected_token,
):
    """
    Check that the server side launcher fills ``spec.auth["token_name"]`` with the default
    token name when none is given, and keeps an explicitly provided one.
    """
    run_spec = mlrun.model.RunSpec(auth=initial_auth)
    run_object = mlrun.run.RunObject(spec=run_spec)
    server_side_launcher = services.api.launcher.ServerSideLauncher(
        auth_info=mlrun.common.schemas.AuthInfo()
    )

    server_side_launcher.enrich_and_validate_auth_token_name(run_object)

    assert run_object.spec.auth["token_name"] == expected_token

0 commit comments

Comments
 (0)