diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 250c00a9c..c3b7db59b 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -3,6 +3,7 @@ ## Release v0.75.0 ### New Features and Improvements +* Add support for unified hosts with experimental flag ### Security diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index bbb490ac7..fa61bdbb9 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -6,6 +6,7 @@ import pathlib import sys import urllib.parse +from enum import Enum from typing import Dict, Iterable, List, Optional import requests @@ -19,11 +20,26 @@ DatabricksEnvironment, get_environment_for_hostname) from .oauth import (OidcEndpoints, Token, get_account_endpoints, get_azure_entra_id_workspace_endpoints, - get_workspace_endpoints) + get_unified_endpoints, get_workspace_endpoints) logger = logging.getLogger("databricks.sdk") +class HostType(Enum): + """Enum representing the type of Databricks host.""" + + ACCOUNTS = "accounts" + WORKSPACE = "workspace" + UNIFIED = "unified" + + +class ClientType(Enum): + """Enum representing the type of client configuration.""" + + ACCOUNT = "account" + WORKSPACE = "workspace" + + class ConfigAttribute: """Configuration attribute metadata and descriptor protocols.""" @@ -61,6 +77,10 @@ def with_user_agent_extra(key: str, value: str): class Config: host: str = ConfigAttribute(env="DATABRICKS_HOST") account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID") + workspace_id: str = ConfigAttribute(env="DATABRICKS_WORKSPACE_ID") + + # Experimental flag to indicate if the host is a unified host (supports both workspace and account APIs) + experimental_is_unified_host: bool = ConfigAttribute(env="DATABRICKS_EXPERIMENTAL_IS_UNIFIED_HOST") # PAT token. token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True) @@ -338,8 +358,65 @@ def is_gcp(self) -> bool: def is_aws(self) -> bool: return self.environment.cloud == Cloud.AWS + @property + def host_type(self) -> HostType: + """Determine the type of host based on the configuration. + + Returns the HostType which can be ACCOUNTS, WORKSPACE, or UNIFIED. + """ + # Check if explicitly marked as unified host + if self.experimental_is_unified_host: + return HostType.UNIFIED + + if not self.host: + return HostType.WORKSPACE + + # Check for accounts host pattern + if self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod."): + return HostType.ACCOUNTS + + return HostType.WORKSPACE + + @property + def client_type(self) -> ClientType: + """Determine the type of client configuration. + + This is separate from host_type. For example, a unified host can support both + workspace and account client types. + + Returns ClientType.ACCOUNT or ClientType.WORKSPACE based on the configuration. + + For unified hosts, account_id must be set. If workspace_id is also set, + returns WORKSPACE, otherwise returns ACCOUNT. + """ + host_type = self.host_type + + if host_type == HostType.ACCOUNTS: + return ClientType.ACCOUNT + + if host_type == HostType.WORKSPACE: + return ClientType.WORKSPACE + + if host_type == HostType.UNIFIED: + if not self.account_id: + raise ValueError("Unified host requires account_id to be set") + if self.workspace_id: + return ClientType.WORKSPACE + return ClientType.ACCOUNT + + # Default to workspace for backward compatibility + return ClientType.WORKSPACE + @property def is_account_client(self) -> bool: + """[Deprecated] Use host_type or client_type instead. + + Determines if this is an account client based on the host URL. + """ + if self.experimental_is_unified_host: + raise ValueError( + "is_account_client cannot be used with unified hosts; use host_type or client_type instead" + ) if not self.host: return False return self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod.") @@ -394,8 +471,18 @@ def oidc_endpoints(self) -> Optional[OidcEndpoints]: return None if self.is_azure and self.azure_client_id: return get_azure_entra_id_workspace_endpoints(self.host) - if self.is_account_client and self.account_id: + + # Handle unified hosts + if self.host_type == HostType.UNIFIED: + if not self.account_id: + raise ValueError("Unified host requires account_id to be set for OAuth endpoints") + return get_unified_endpoints(self.host, self.account_id) + + # Handle traditional account hosts + if self.host_type == HostType.ACCOUNTS and self.account_id: return get_account_endpoints(self.host, self.account_id) + + # Default to workspace endpoints return get_workspace_endpoints(self.host) def debug_string(self) -> str: diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py index 92e3dbf89..9c701fb86 100644 --- a/databricks/sdk/core.py +++ b/databricks/sdk/core.py @@ -22,11 +22,20 @@ class ApiClient: def __init__(self, cfg: Config): self._cfg = cfg + + # Create header factory that includes both auth and org ID headers + def combined_header_factory(): + headers = cfg.authenticate() + # Add X-Databricks-Org-Id header for workspace clients on unified hosts + if cfg.workspace_id and cfg.host_type == HostType.UNIFIED: + headers["X-Databricks-Org-Id"] = cfg.workspace_id + return headers + self._api_client = _BaseClient( debug_truncate_bytes=cfg.debug_truncate_bytes, retry_timeout_seconds=cfg.retry_timeout_seconds, user_agent_base=cfg.user_agent, - header_factory=cfg.authenticate, + header_factory=combined_header_factory, max_connection_pools=cfg.max_connection_pools, max_connections_per_pool=cfg.max_connections_per_pool, pool_block=True, diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index 72681669f..3e6a97abb 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -418,6 +418,19 @@ def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> O return OidcEndpoints.from_dict(resp) +def get_unified_endpoints(host: str, account_id: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints: + """ + Get the OIDC endpoints for a unified host. + :param host: The Databricks unified host. + :param account_id: The account ID. + :return: The OIDC endpoints for the unified host. + """ + host = _fix_host_if_needed(host) + oidc = f"{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server" + resp = client.do("GET", oidc) + return OidcEndpoints.from_dict(resp) + + def get_azure_entra_id_workspace_endpoints( host: str, ) -> Optional[OidcEndpoints]: diff --git a/tests/test_config.py b/tests/test_config.py index 59fbf8712..00e7540d9 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -8,7 +8,8 @@ import pytest from databricks.sdk import oauth, useragent -from databricks.sdk.config import Config, with_product, with_user_agent_extra +from databricks.sdk.config import (ClientType, Config, HostType, with_product, + with_user_agent_extra) from databricks.sdk.version import __version__ from .conftest import noop_credentials, set_az_path @@ -260,3 +261,195 @@ def test_oauth_token_reuses_existing_provider(mocker): # Both calls should work and use the same provider instance assert token1 == token2 == mock_token assert mock_oauth_provider.oauth_token.call_count == 2 + + +def test_host_type_workspace(): + """Test that a regular workspace host is identified correctly.""" + config = Config(host="https://test.databricks.com", token="test-token") + assert config.host_type == HostType.WORKSPACE + + +def test_host_type_accounts(): + """Test that an accounts host is identified correctly.""" + config = Config(host="https://accounts.cloud.databricks.com", account_id="test-account", token="test-token") + assert config.host_type == HostType.ACCOUNTS + + +def test_host_type_accounts_dod(): + """Test that an accounts-dod host is identified correctly.""" + config = Config(host="https://accounts-dod.cloud.databricks.us", account_id="test-account", token="test-token") + assert config.host_type == HostType.ACCOUNTS + + +def test_host_type_unified(): + """Test that a unified host is identified when experimental flag is set.""" + config = Config( + host="https://unified.databricks.com", + workspace_id="test-workspace", + experimental_is_unified_host=True, + token="test-token", + ) + assert config.host_type == HostType.UNIFIED + + +def test_client_type_workspace(): + """Test that client type is workspace when workspace_id is set on unified host.""" + config = Config( + host="https://unified.databricks.com", + workspace_id="test-workspace", + account_id="test-account", + experimental_is_unified_host=True, + token="test-token", + ) + assert config.client_type == ClientType.WORKSPACE + + +def test_client_type_account(): + """Test that client type is account when account_id is set without workspace_id.""" + config = Config( + host="https://unified.databricks.com", + account_id="test-account", + experimental_is_unified_host=True, + token="test-token", + ) + assert config.client_type == ClientType.ACCOUNT + + +def test_client_type_workspace_default(): + """Test that client type defaults to workspace.""" + config = Config(host="https://test.databricks.com", token="test-token") + assert config.client_type == ClientType.WORKSPACE + + +def test_client_type_accounts_host(): + """Test that client type is account for accounts host.""" + config = Config( + host="https://accounts.cloud.databricks.com", + account_id="test-account", + token="test-token", + ) + assert config.client_type == ClientType.ACCOUNT + + +def test_client_type_unified_without_account_id(): + """Test that client type raises error for unified host without account_id.""" + config = Config( + host="https://unified.databricks.com", + experimental_is_unified_host=True, + token="test-token", + ) + with pytest.raises(ValueError, match="Unified host requires account_id"): + _ = config.client_type + + +def test_is_account_client_backward_compatibility(): + """Test that is_account_client property still works for backward compatibility.""" + config_workspace = Config(host="https://test.databricks.com", token="test-token") + assert not config_workspace.is_account_client + + config_account = Config(host="https://accounts.cloud.databricks.com", account_id="test-account", token="test-token") + assert config_account.is_account_client + + +def test_is_account_client_raises_on_unified_host(): + """Test that is_account_client raises ValueError when used with unified hosts.""" + config = Config( + host="https://unified.databricks.com", + experimental_is_unified_host=True, + workspace_id="test-workspace", + token="test-token", + ) + with pytest.raises(ValueError, match="is_account_client cannot be used with unified hosts"): + _ = config.is_account_client + + +def test_oidc_endpoints_unified_workspace(mocker, requests_mock): + """Test that oidc_endpoints returns unified endpoints for workspace on unified host.""" + requests_mock.get( + "https://unified.databricks.com/oidc/accounts/test-account/.well-known/oauth-authorization-server", + json={ + "authorization_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/authorize", + "token_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/token", + }, + ) + + config = Config( + host="https://unified.databricks.com", + workspace_id="test-workspace", + account_id="test-account", + experimental_is_unified_host=True, + token="test-token", + ) + + endpoints = config.oidc_endpoints + assert endpoints is not None + assert "accounts/test-account" in endpoints.authorization_endpoint + assert "accounts/test-account" in endpoints.token_endpoint + + +def test_oidc_endpoints_unified_account(mocker, requests_mock): + """Test that oidc_endpoints returns account endpoints for account on unified host.""" + requests_mock.get( + "https://unified.databricks.com/oidc/accounts/test-account/.well-known/oauth-authorization-server", + json={ + "authorization_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/authorize", + "token_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/token", + }, + ) + + config = Config( + host="https://unified.databricks.com", + account_id="test-account", + experimental_is_unified_host=True, + token="test-token", + ) + + endpoints = config.oidc_endpoints + assert endpoints is not None + assert "accounts/test-account" in endpoints.authorization_endpoint + assert "accounts/test-account" in endpoints.token_endpoint + + +def test_oidc_endpoints_unified_missing_ids(): + """Test that oidc_endpoints raises error when unified host lacks required account_id.""" + config = Config(host="https://unified.databricks.com", experimental_is_unified_host=True, token="test-token") + + with pytest.raises(ValueError) as exc_info: + _ = config.oidc_endpoints + + assert "Unified host requires account_id" in str(exc_info.value) + + +def test_workspace_org_id_header_on_unified_host(requests_mock): + """Test that X-Databricks-Org-Id header is added for workspace clients on unified hosts.""" + from databricks.sdk.core import ApiClient + + requests_mock.get("https://unified.databricks.com/api/2.0/test", json={"result": "success"}) + + config = Config( + host="https://unified.databricks.com", + workspace_id="test-workspace-123", + experimental_is_unified_host=True, + token="test-token", + ) + + api_client = ApiClient(config) + api_client.do("GET", "/api/2.0/test") + + # Verify the request was made with the X-Databricks-Org-Id header + assert requests_mock.last_request.headers.get("X-Databricks-Org-Id") == "test-workspace-123" + + +def test_no_org_id_header_on_regular_workspace(requests_mock): + """Test that X-Databricks-Org-Id header is NOT added for regular workspace hosts.""" + from databricks.sdk.core import ApiClient + + requests_mock.get("https://test.databricks.com/api/2.0/test", json={"result": "success"}) + + config = Config(host="https://test.databricks.com", token="test-token") + + api_client = ApiClient(config) + api_client.do("GET", "/api/2.0/test") + + # Verify the X-Databricks-Org-Id header was NOT added + assert "X-Databricks-Org-Id" not in requests_mock.last_request.headers