Skip to content

Commit 741b202

Browse files
committed
refactor: centralize credential and param extraction for eks workload list tools
1 parent 9a8bcbc commit 741b202

5 files changed

Lines changed: 97 additions & 33 deletions

File tree

app/tools/EKSListClustersTool/__init__.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from app.services.eks.eks_client import EKSClient
1111
from app.tools.tool_decorator import tool
12+
from app.tools.utils.eks_workload_helper import extract_workload_params
1213

1314
logger = logging.getLogger(__name__)
1415

@@ -26,14 +27,6 @@ def _eks_creds(eks: dict) -> dict:
2627
}
2728

2829

29-
def _list_clusters_extract_params(sources: dict[str, dict]) -> dict[str, Any]:
30-
eks = sources["eks"]
31-
return {
32-
"cluster_names": eks.get("cluster_names", []),
33-
**_eks_creds(eks),
34-
}
35-
36-
3730
@tool(
3831
name="list_eks_clusters",
3932
source="eks",
@@ -55,7 +48,7 @@ def _list_clusters_extract_params(sources: dict[str, dict]) -> dict[str, Any]:
5548
"required": ["role_arn"],
5649
},
5750
is_available=_eks_available,
58-
extract_params=_list_clusters_extract_params,
51+
extract_params=extract_workload_params,
5952
)
6053
def list_eks_clusters(
6154
role_arn: str,

app/tools/EKSListDeploymentsTool/__init__.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,13 @@
66
from typing import Any, cast
77

88
from app.services.eks.eks_k8s_client import build_k8s_clients
9-
from app.tools.EKSListClustersTool import _eks_creds
109
from app.tools.tool_decorator import tool
1110
from app.tools.utils.availability import eks_available_or_backend
11+
from app.tools.utils.eks_workload_helper import extract_workload_params
1212

1313
logger = logging.getLogger(__name__)
1414

1515

16-
def _list_deployments_extract_params(sources: dict[str, dict]) -> dict[str, Any]:
17-
eks = sources["eks"]
18-
return {
19-
"cluster_name": eks.get("cluster_name", ""),
20-
"namespace": eks.get("namespace") or "all",
21-
"eks_backend": eks.get("_backend"),
22-
**_eks_creds(eks),
23-
}
24-
25-
2616
@tool(
2717
name="list_eks_deployments",
2818
source="eks",
@@ -45,7 +35,7 @@ def _list_deployments_extract_params(sources: dict[str, dict]) -> dict[str, Any]
4535
"required": ["cluster_name", "namespace", "role_arn"],
4636
},
4737
is_available=eks_available_or_backend,
48-
extract_params=_list_deployments_extract_params,
38+
extract_params=extract_workload_params,
4939
)
5040
def list_eks_deployments(
5141
cluster_name: str,

app/tools/EKSListPodsTool/__init__.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,13 @@
66
from typing import Any, cast
77

88
from app.services.eks.eks_k8s_client import build_k8s_clients
9-
from app.tools.EKSListClustersTool import _eks_creds
109
from app.tools.tool_decorator import tool
1110
from app.tools.utils.availability import eks_available_or_backend
11+
from app.tools.utils.eks_workload_helper import extract_workload_params
1212

1313
logger = logging.getLogger(__name__)
1414

1515

16-
def _list_pods_extract_params(sources: dict[str, dict]) -> dict[str, Any]:
17-
eks = sources["eks"]
18-
return {
19-
"cluster_name": eks.get("cluster_name", ""),
20-
"namespace": eks.get("namespace") or "all",
21-
"eks_backend": eks.get("_backend"),
22-
**_eks_creds(eks),
23-
}
24-
25-
2616
@tool(
2717
name="list_eks_pods",
2818
source="eks",
@@ -46,7 +36,7 @@ def _list_pods_extract_params(sources: dict[str, dict]) -> dict[str, Any]:
4636
"required": ["cluster_name", "namespace", "role_arn"],
4737
},
4838
is_available=eks_available_or_backend,
49-
extract_params=_list_pods_extract_params,
39+
extract_params=extract_workload_params,
5040
)
5141
def list_eks_pods(
5242
cluster_name: str,
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""Shared helpers for EKS workload investigation tools"""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any
6+
7+
8+
def _eks_creds(eks: dict) -> dict:
9+
"""Extract AWS credentials from EKS source"""
10+
11+
return {
12+
"role_arn": eks.get("role_arn", ""),
13+
"external_id": eks.get("external_id", ""),
14+
"region": eks.get("region", "us-east-1"),
15+
"credentials": eks.get("credentials"),
16+
}
17+
18+
19+
def extract_workload_params(sources: dict[str, dict]) -> dict[str, Any]:
20+
"""Extract common parameters for workload list operations (pods/deployments)"""
21+
22+
eks = sources.get("eks")
23+
if eks is None:
24+
raise ValueError("Sources dictionary must contain an 'eks' key with cluster configuration")
25+
26+
return {
27+
"cluster_name": eks.get("cluster_name", ""),
28+
"namespace": eks.get("namespace") or "all",
29+
"eks_backend": eks.get("_backend"),
30+
**_eks_creds(eks),
31+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""Tests for eks workload helpers"""
2+
3+
from __future__ import annotations
4+
5+
from app.tools.utils.eks_workload_helper import extract_workload_params
6+
7+
8+
def test_extract_basic_params():
9+
"""Test basic parameter extraction with minimal config"""
10+
sources = {"eks": {"cluster_name": "test-cluster", "namespace": "default"}}
11+
12+
result = extract_workload_params(sources)
13+
14+
assert result["cluster_name"] == "test-cluster"
15+
assert result["namespace"] == "default"
16+
assert result["region"] == "us-east-1"
17+
assert result["role_arn"] == ""
18+
assert result["external_id"] == ""
19+
assert result["credentials"] is None
20+
21+
22+
def test_namespace_defaults_to_all():
23+
"""Test namespace defaults to 'all' when not provided"""
24+
sources = {"eks": {"cluster_name": "test-cluster"}}
25+
26+
result = extract_workload_params(sources)
27+
28+
assert result["namespace"] == "all"
29+
30+
31+
def test_handles_all_optional_fields():
32+
"""Test extraction includes all optional AWS fields"""
33+
sources = {
34+
"eks": {
35+
"cluster_name": "prod-cluster",
36+
"role_arn": "arn:aws:iam::123:role/test",
37+
"external_id": "external-123",
38+
"region": "us-west-2",
39+
"credentials": {"access_key": "key123"},
40+
}
41+
}
42+
43+
result = extract_workload_params(sources)
44+
45+
assert result["cluster_name"] == "prod-cluster"
46+
assert result["role_arn"] == "arn:aws:iam::123:role/test"
47+
assert result["external_id"] == "external-123"
48+
assert result["region"] == "us-west-2"
49+
assert result["credentials"] == {"access_key": "key123"}
50+
51+
52+
def test_missing_eks_raises_error():
53+
"""Test ValueError when 'eks' key is missing"""
54+
sources = {"other": {}}
55+
56+
try:
57+
extract_workload_params(sources)
58+
raise AssertionError("Should have raised ValueError")
59+
except ValueError as e:
60+
assert "must contain an 'eks' key" in str(e)

0 commit comments

Comments
 (0)