Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def __init__(
self._load_from_env()
self._known_file_config_loader()
self._fix_host_if_needed()
self._resolve_host_metadata()
self._validate()
self.init_auth()
self._init_product(product, product_version)
Expand Down Expand Up @@ -629,23 +630,29 @@ def _resolve_host_metadata(self) -> None:
Fills in account_id, workspace_id, and discovery_url (derived from oidc_endpoint,
with any {account_id} placeholder substituted) if not already set.
"""
# TODO: Enable this everywhere
if not self.host_type == HostType.UNIFIED:
return
if not self.host:
return
meta = get_host_metadata(self.host)
try:
meta = get_host_metadata(self.host)
except Exception as e:
logger.warning(
f"Failed to automatically resolve config from host metadata: {e}. Falling back to explicit user provided configuration."
)
return
if not self.account_id and meta.account_id:
logger.debug(f"Resolved account_id from host metadata: {meta.account_id}")
self.account_id = meta.account_id
if not self.account_id:
raise ValueError("account_id is not configured and could not be resolved from host metadata")
if not self.workspace_id and meta.workspace_id:
logger.debug(f"Resolved workspace_id from host metadata: {meta.workspace_id}")
self.workspace_id = meta.workspace_id
if not self.discovery_url:
if meta.oidc_endpoint:
logger.debug(f"Resolved discovery_url from host metadata: {meta.oidc_endpoint}")
self.discovery_url = meta.oidc_endpoint.replace("{account_id}", self.account_id)
else:
raise ValueError("discovery_url is not configured and could not be resolved from host metadata")
if not self.discovery_url and meta.oidc_endpoint:
if "{account_id}" in meta.oidc_endpoint and not self.account_id:
raise ValueError("account_id is required to resolve discovery_url from host metadata")
logger.debug(f"Resolved discovery_url from host metadata: {meta.oidc_endpoint}")
self.discovery_url = meta.oidc_endpoint.replace("{account_id}", self.account_id or "")

def _fix_host_if_needed(self):
updated_host = _fix_host_if_needed(self.host)
Expand Down
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
from .integration.conftest import restorable_env # type: ignore


@pytest.fixture(autouse=True)
def stub_host_metadata(mocker):
    """Patch ``databricks.sdk.config.get_host_metadata`` for every test.

    The patched function returns a ``HostMetadata`` whose ``oidc_endpoint`` is
    an empty string. Individual tests may re-patch the same target to supply
    their own metadata.
    """
    # Imported lazily, inside the fixture, exactly as the fixture has always done.
    from databricks.sdk.oauth import HostMetadata

    empty_metadata = HostMetadata(oidc_endpoint="")
    mocker.patch("databricks.sdk.config.get_host_metadata", return_value=empty_metadata)


@credentials_strategy("noop", [])
def noop_credentials(_: any):
    """Credentials strategy registered under the name ``"noop"``.

    The single argument is ignored; the result is a zero-argument callable
    that returns an empty dict.
    """

    def _empty():
        return {}

    return _empty
Expand Down
92 changes: 49 additions & 43 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,69 +820,75 @@ def test_databricks_oidc_endpoints_uses_discovery_url(requests_mock):
"account_id": _DUMMY_ACCOUNT_ID,
"workspace_id": _DUMMY_WORKSPACE_ID,
},
{},
{"experimental_is_unified_host": True},
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're replacing a test for the general case with a test for the experimental-flag case. Shouldn't we keep both until the flag is removed?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, this test was always for the unified path.
The test always exercised the resolve-metadata path, but that path is now skipped if the flag is false.

{
"account_id": _DUMMY_ACCOUNT_ID,
"workspace_id": _DUMMY_WORKSPACE_ID,
"discovery_url": f"{_DUMMY_WS_HOST}/oidc",
},
id="workspace-populates-all-fields",
id="unified-populates-all-fields",
),
pytest.param(
_DUMMY_ACC_HOST,
{"oidc_endpoint": f"{_DUMMY_ACC_HOST}/oidc/accounts/{{account_id}}"},
{"account_id": _DUMMY_ACCOUNT_ID},
{"account_id": _DUMMY_ACCOUNT_ID, "experimental_is_unified_host": True},
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

{"discovery_url": f"{_DUMMY_ACC_HOST}/oidc/accounts/{_DUMMY_ACCOUNT_ID}"},
id="account-substitutes-account-id",
id="unified-substitutes-account-id",
),
pytest.param(
_DUMMY_WS_HOST,
{"oidc_endpoint": f"{_DUMMY_WS_HOST}/oidc", "account_id": "other-account", "workspace_id": "other-ws"},
{
"account_id": _DUMMY_ACCOUNT_ID,
"workspace_id": _DUMMY_WORKSPACE_ID,
"experimental_is_unified_host": True,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

},
{"account_id": _DUMMY_ACCOUNT_ID, "workspace_id": _DUMMY_WORKSPACE_ID},
{"account_id": _DUMMY_ACCOUNT_ID, "workspace_id": _DUMMY_WORKSPACE_ID},
id="does-not-overwrite-existing-fields",
id="unified-does-not-overwrite-existing-fields",
),
],
)
def test_resolve_host_metadata(requests_mock, host, response_json, config_kwargs, expected_fields):
requests_mock.get(f"{host}/.well-known/databricks-config", json=response_json)
def test_resolve_host_metadata(mocker, host, response_json, config_kwargs, expected_fields):
    """Check the fields a Config ends up with after host-metadata resolution.

    ``get_host_metadata`` is patched to return the metadata parsed from
    ``response_json``; every entry in ``expected_fields`` must then match the
    corresponding attribute on the resulting Config.
    """
    stubbed_metadata = oauth.HostMetadata.from_dict(response_json)
    mocker.patch("databricks.sdk.config.get_host_metadata", return_value=stubbed_metadata)

    cfg = Config(host=host, token="t", **config_kwargs)
    cfg._resolve_host_metadata()

    for name in expected_fields:
        assert getattr(cfg, name) == expected_fields[name]


# NOTE(review): in this diff view this test appears to be the removed
# (pre-change) side — confirm before relying on it.
@pytest.mark.parametrize(
    "host,response_json,status_code,config_kwargs,error_match",
    [
        pytest.param(
            _DUMMY_ACC_HOST,
            {"oidc_endpoint": f"{_DUMMY_ACC_HOST}/oidc/accounts/{{account_id}}"},
            200,
            {},
            "account_id is not configured",
            id="missing-account-id",
        ),
        pytest.param(
            _DUMMY_WS_HOST,
            {"account_id": _DUMMY_ACCOUNT_ID},
            200,
            {},
            "discovery_url is not configured",
            id="missing-oidc-endpoint",
        ),
        pytest.param(
            _DUMMY_WS_HOST,
            {},
            500,
            {},
            "Failed to fetch host metadata",
            id="http-error",
        ),
    ],
)
def test_resolve_host_metadata_raises(requests_mock, host, response_json, status_code, config_kwargs, error_match):
    """Expect ``_resolve_host_metadata`` to raise ValueError matching ``error_match``.

    The ``.well-known/databricks-config`` endpoint of ``host`` is mocked to
    answer with ``status_code`` and ``response_json``.
    """
    well_known_url = f"{host}/.well-known/databricks-config"
    requests_mock.get(well_known_url, status_code=status_code, json=response_json)

    cfg = Config(host=host, token="t", **config_kwargs)

    with pytest.raises(ValueError, match=error_match):
        cfg._resolve_host_metadata()
def test_resolve_host_metadata_missing_account_id(mocker):
    """Constructing a Config raises when oidc_endpoint needs an account_id that is not set.

    The stubbed metadata carries an ``{account_id}`` placeholder in its
    ``oidc_endpoint`` and no ``account_id``; none is passed to Config either.
    """
    templated_endpoint = f"{_DUMMY_ACC_HOST}/oidc/accounts/{{account_id}}"
    stubbed_metadata = oauth.HostMetadata.from_dict({"oidc_endpoint": templated_endpoint})
    mocker.patch("databricks.sdk.config.get_host_metadata", return_value=stubbed_metadata)

    with pytest.raises(ValueError, match="account_id is required to resolve discovery_url"):
        Config(host=_DUMMY_ACC_HOST, token="t", experimental_is_unified_host=True)


def test_resolve_host_metadata_no_oidc_endpoint(mocker):
    """Metadata without an oidc_endpoint does not raise.

    ``account_id`` is still taken from the metadata, while ``discovery_url``
    remains ``None``.
    """
    stubbed_metadata = oauth.HostMetadata.from_dict({"account_id": _DUMMY_ACCOUNT_ID})
    mocker.patch("databricks.sdk.config.get_host_metadata", return_value=stubbed_metadata)

    cfg = Config(host=_DUMMY_WS_HOST, token="t", experimental_is_unified_host=True)

    assert cfg.account_id == _DUMMY_ACCOUNT_ID
    assert cfg.discovery_url is None


def test_resolve_host_metadata_http_error(mocker):
    """A failing metadata fetch does not break Config construction.

    ``get_host_metadata`` is patched to raise ValueError; the Config is still
    built and both ``account_id`` and ``discovery_url`` stay ``None``.
    """
    fetch_failure = ValueError(f"Failed to fetch host metadata from {_DUMMY_WS_HOST}/.well-known/databricks-config")
    mocker.patch("databricks.sdk.config.get_host_metadata", side_effect=fetch_failure)

    cfg = Config(host=_DUMMY_WS_HOST, token="t", experimental_is_unified_host=True)

    assert cfg.account_id is None
    assert cfg.discovery_url is None


def test_resolve_host_metadata_skipped_for_non_unified(mocker):
    """Without ``experimental_is_unified_host``, ``get_host_metadata`` is never called."""
    metadata_fetcher = mocker.patch("databricks.sdk.config.get_host_metadata")

    # Build a Config with no unified-host flag; only the side effect matters here.
    Config(host=_DUMMY_WS_HOST, token="t")

    metadata_fetcher.assert_not_called()
Loading