Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ def __repr__(self) -> str:
return f"<ConfigAttribute '{self.name}' {self.transform.__name__}>"


def _parse_cloud(value) -> Optional[Cloud]:
"""Parse a cloud value from string or Cloud instance; returns None for unknown or empty."""
if value is None:
return None
if isinstance(value, Cloud):
return value
return Cloud.parse(str(value))


def _parse_scopes(value):
"""Parse scopes into a deduplicated, sorted list."""
if value is None:
Expand Down Expand Up @@ -84,6 +93,10 @@ class Config:
# Experimental flag to indicate if the host is a unified host (supports both workspace and account APIs)
experimental_is_unified_host: bool = ConfigAttribute(env="DATABRICKS_EXPERIMENTAL_IS_UNIFIED_HOST")

# [Experimental] Cloud provider. When set, is_aws/is_azure/is_gcp use this value directly
# instead of inferring from hostname. Populated automatically from /.well-known/databricks-config.
cloud: Cloud = ConfigAttribute(env="DATABRICKS_CLOUD", transform=_parse_cloud)

# [Experimental] OpenID Connect discovery URL. When set, OIDC endpoints are fetched directly
# from this URL instead of the default host-type-based well-known endpoint logic.
discovery_url: str = ConfigAttribute(env="DATABRICKS_DISCOVERY_URL")
Expand Down Expand Up @@ -380,14 +393,20 @@ def environment(self) -> DatabricksEnvironment:
def is_azure(self) -> bool:
if self.azure_workspace_resource_id:
return True
if self.cloud:
return self.cloud == Cloud.AZURE
return self.environment.cloud == Cloud.AZURE

@property
def is_gcp(self) -> bool:
if self.cloud:
return self.cloud == Cloud.GCP
return self.environment.cloud == Cloud.GCP

@property
def is_aws(self) -> bool:
if self.cloud:
return self.cloud == Cloud.AWS
return self.environment.cloud == Cloud.AWS

@property
Expand Down Expand Up @@ -653,6 +672,9 @@ def _resolve_host_metadata(self) -> None:
raise ValueError("account_id is required to resolve discovery_url from host metadata")
logger.debug(f"Resolved discovery_url from host metadata: {meta.oidc_endpoint}")
self.discovery_url = meta.oidc_endpoint.replace("{account_id}", self.account_id or "")
if not self.cloud and meta.cloud:
logger.debug(f"Resolved cloud from host metadata: {meta.cloud.value}")
self.cloud = meta.cloud

def _fix_host_if_needed(self):
updated_host = _fix_host_if_needed(self.host)
Expand Down
10 changes: 10 additions & 0 deletions databricks/sdk/environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ class Cloud(Enum):
AZURE = "AZURE"
GCP = "GCP"

@classmethod
def parse(cls, value: str) -> Optional["Cloud"]:
"""Case-insensitive parse. Returns None for empty or unrecognized values."""
if not value:
return None
try:
return cls(value.upper())
except ValueError:
return None


@dataclass
class DatabricksEnvironment:
Expand Down
4 changes: 4 additions & 0 deletions databricks/sdk/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import requests.auth

from ._base_client import _BaseClient, _fix_host_if_needed
from .environments import Cloud

# Error code for PKCE flow in Azure Active Directory, that gets additional retry.
# See https://stackoverflow.com/a/75466778/277035 for more info
Expand Down Expand Up @@ -426,20 +427,23 @@ class HostMetadata:
oidc_endpoint: str
account_id: Optional[str] = None
workspace_id: Optional[str] = None
cloud: Optional[Cloud] = None

@staticmethod
def from_dict(d: dict) -> "HostMetadata":
return HostMetadata(
oidc_endpoint=d.get("oidc_endpoint", ""),
account_id=d.get("account_id"),
workspace_id=d.get("workspace_id"),
cloud=Cloud.parse(d.get("cloud", "")),
)

def as_dict(self) -> dict:
return {
"oidc_endpoint": self.oidc_endpoint,
"account_id": self.account_id,
"workspace_id": self.workspace_id,
"cloud": self.cloud.value if self.cloud else None,
}


Expand Down
119 changes: 119 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from databricks.sdk import AccountClient, WorkspaceClient, oauth, useragent
from databricks.sdk.config import (ClientType, Config, HostType, with_product,
with_user_agent_extra)
from databricks.sdk.environments import Cloud
from databricks.sdk.version import __version__

from .conftest import noop_credentials, set_az_path, set_home
Expand Down Expand Up @@ -892,3 +893,121 @@ def test_resolve_host_metadata_skipped_for_non_unified(mocker):
mock_get = mocker.patch("databricks.sdk.config.get_host_metadata")
Config(host=_DUMMY_WS_HOST, token="t")
mock_get.assert_not_called()


# ---------------------------------------------------------------------------
# Cloud field tests
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
"cloud_str,expected_cloud",
[
("AWS", Cloud.AWS),
("aws", Cloud.AWS),
("Azure", Cloud.AZURE),
("AZURE", Cloud.AZURE),
("GCP", Cloud.GCP),
("gcp", Cloud.GCP),
],
)
def test_cloud_parse_case_insensitive(cloud_str, expected_cloud):
"""Cloud.parse handles any casing."""
assert Cloud.parse(cloud_str) == expected_cloud


def test_cloud_parse_unknown_returns_none():
"""Cloud.parse returns None for unrecognized values (forward compatibility)."""
assert Cloud.parse("UNKNOWN_FUTURE_CLOUD") is None


def test_cloud_parse_empty_returns_none():
assert Cloud.parse("") is None
assert Cloud.parse(None) is None


def test_cloud_field_overrides_dns_detection_aws():
"""Explicit cloud=AWS wins over hostname-based detection."""
config = Config(host="https://myworkspace.azuredatabricks.net", token="t", cloud="AWS")
assert config.is_aws
assert not config.is_azure
assert not config.is_gcp


def test_cloud_field_overrides_dns_detection_azure():
"""Explicit cloud=AZURE wins over hostname-based detection."""
config = Config(host="https://myworkspace.cloud.databricks.com", token="t", cloud="AZURE")
assert config.is_azure
assert not config.is_aws
assert not config.is_gcp


def test_cloud_field_overrides_dns_detection_gcp():
"""Explicit cloud=GCP wins over hostname-based detection."""
config = Config(host="https://myworkspace.cloud.databricks.com", token="t", cloud="GCP")
assert config.is_gcp
assert not config.is_aws
assert not config.is_azure


def test_cloud_field_falls_back_to_dns_when_unset():
"""When cloud is not set, falls back to DNS-based detection."""
config = Config(host="https://myworkspace.azuredatabricks.net", token="t")
assert config.is_azure
assert not config.is_aws


def test_cloud_field_azure_resource_id_still_wins():
"""azure_workspace_resource_id takes precedence over cloud field for is_azure."""
config = Config(
host="https://myworkspace.cloud.databricks.com",
credentials_strategy=noop_credentials,
cloud="AWS",
azure_workspace_resource_id="/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws",
)
assert config.is_azure


def test_resolve_host_metadata_populates_cloud(mocker):
"""Cloud is populated from the discovery endpoint."""
mocker.patch(
"databricks.sdk.config.get_host_metadata",
return_value=oauth.HostMetadata.from_dict(
{
"oidc_endpoint": f"{_DUMMY_WS_HOST}/oidc",
"cloud": "AWS",
}
),
)
config = Config(host=_DUMMY_WS_HOST, token="t", experimental_is_unified_host=True)
assert config.cloud == Cloud.AWS


def test_resolve_host_metadata_cloud_not_overwritten(mocker):
"""Explicit cloud config is not overwritten by the discovery endpoint."""
mocker.patch(
"databricks.sdk.config.get_host_metadata",
return_value=oauth.HostMetadata.from_dict(
{
"oidc_endpoint": f"{_DUMMY_WS_HOST}/oidc",
"cloud": "AZURE",
}
),
)
config = Config(
host=_DUMMY_WS_HOST,
token="t",
experimental_is_unified_host=True,
cloud="AWS",
)
assert config.cloud == Cloud.AWS


def test_resolve_host_metadata_cloud_missing_in_response(mocker):
"""When endpoint omits cloud, the field stays unset (falls back to DNS)."""
mocker.patch(
"databricks.sdk.config.get_host_metadata",
return_value=oauth.HostMetadata.from_dict({"oidc_endpoint": f"{_DUMMY_WS_HOST}/oidc"}),
)
config = Config(host=_DUMMY_WS_HOST, token="t", experimental_is_unified_host=True)
assert config.cloud is None
Loading