Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Release v0.75.0

### New Features and Improvements
* Add support for unified hosts with experimental flag

### Security

Expand Down
91 changes: 89 additions & 2 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pathlib
import sys
import urllib.parse
from enum import Enum
from typing import Dict, Iterable, List, Optional

import requests
Expand All @@ -19,11 +20,26 @@
DatabricksEnvironment, get_environment_for_hostname)
from .oauth import (OidcEndpoints, Token, get_account_endpoints,
get_azure_entra_id_workspace_endpoints,
get_workspace_endpoints)
get_unified_endpoints, get_workspace_endpoints)

logger = logging.getLogger("databricks.sdk")


class HostType(Enum):
"""Enum representing the type of Databricks host."""

ACCOUNTS = "accounts"
WORKSPACE = "workspace"
UNIFIED = "unified"


class ClientType(Enum):
"""Enum representing the type of client configuration."""

ACCOUNT = "account"
WORKSPACE = "workspace"


class ConfigAttribute:
"""Configuration attribute metadata and descriptor protocols."""

Expand Down Expand Up @@ -61,6 +77,10 @@ def with_user_agent_extra(key: str, value: str):
class Config:
host: str = ConfigAttribute(env="DATABRICKS_HOST")
account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID")
workspace_id: str = ConfigAttribute(env="DATABRICKS_WORKSPACE_ID")

# Experimental flag to indicate if the host is a unified host (supports both workspace and account APIs)
experimental_is_unified_host: bool = ConfigAttribute(env="DATABRICKS_EXPERIMENTAL_IS_UNIFIED_HOST")

# PAT token.
token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True)
Expand Down Expand Up @@ -338,8 +358,65 @@ def is_gcp(self) -> bool:
def is_aws(self) -> bool:
return self.environment.cloud == Cloud.AWS

@property
def host_type(self) -> HostType:
"""Determine the type of host based on the configuration.

Returns the HostType which can be ACCOUNTS, WORKSPACE, or UNIFIED.
"""
# Check if explicitly marked as unified host
if self.experimental_is_unified_host:
return HostType.UNIFIED

if not self.host:
return HostType.WORKSPACE

# Check for accounts host pattern
if self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod."):
return HostType.ACCOUNTS

return HostType.WORKSPACE

@property
def client_type(self) -> ClientType:
"""Determine the type of client configuration.

This is separate from host_type. For example, a unified host can support both
workspace and account client types.

Returns ClientType.ACCOUNT or ClientType.WORKSPACE based on the configuration.

For unified hosts, account_id must be set. If workspace_id is also set,
returns WORKSPACE, otherwise returns ACCOUNT.
"""
host_type = self.host_type

if host_type == HostType.ACCOUNTS:
return ClientType.ACCOUNT

if host_type == HostType.WORKSPACE:
return ClientType.WORKSPACE

if host_type == HostType.UNIFIED:
if not self.account_id:
raise ValueError("Unified host requires account_id to be set")
if self.workspace_id:
return ClientType.WORKSPACE
return ClientType.ACCOUNT

# Default to workspace for backward compatibility
return ClientType.WORKSPACE

@property
def is_account_client(self) -> bool:
"""[Deprecated] Use host_type or client_type instead.

Determines if this is an account client based on the host URL.
"""
if self.experimental_is_unified_host:
raise ValueError(
"is_account_client cannot be used with unified hosts; use host_type or client_type instead"
)
if not self.host:
return False
return self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod.")
Expand Down Expand Up @@ -394,8 +471,18 @@ def oidc_endpoints(self) -> Optional[OidcEndpoints]:
return None
if self.is_azure and self.azure_client_id:
return get_azure_entra_id_workspace_endpoints(self.host)
if self.is_account_client and self.account_id:

# Handle unified hosts
if self.host_type == HostType.UNIFIED:
if not self.account_id:
raise ValueError("Unified host requires account_id to be set for OAuth endpoints")
return get_unified_endpoints(self.host, self.account_id)

# Handle traditional account hosts
if self.host_type == HostType.ACCOUNTS and self.account_id:
return get_account_endpoints(self.host, self.account_id)

# Default to workspace endpoints
return get_workspace_endpoints(self.host)

def debug_string(self) -> str:
Expand Down
11 changes: 10 additions & 1 deletion databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,20 @@ class ApiClient:

def __init__(self, cfg: Config):
self._cfg = cfg

# Create header factory that includes both auth and org ID headers
def combined_header_factory():
headers = cfg.authenticate()
# Add X-Databricks-Org-Id header for workspace clients on unified hosts
if cfg.workspace_id and cfg.host_type == HostType.UNIFIED:
headers["X-Databricks-Org-Id"] = cfg.workspace_id
return headers

self._api_client = _BaseClient(
debug_truncate_bytes=cfg.debug_truncate_bytes,
retry_timeout_seconds=cfg.retry_timeout_seconds,
user_agent_base=cfg.user_agent,
header_factory=cfg.authenticate,
header_factory=combined_header_factory,
max_connection_pools=cfg.max_connection_pools,
max_connections_per_pool=cfg.max_connections_per_pool,
pool_block=True,
Expand Down
13 changes: 13 additions & 0 deletions databricks/sdk/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,19 @@ def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> O
return OidcEndpoints.from_dict(resp)


def get_unified_endpoints(host: str, account_id: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints:
"""
Get the OIDC endpoints for a unified host.
:param host: The Databricks unified host.
:param account_id: The account ID.
:return: The OIDC endpoints for the unified host.
"""
host = _fix_host_if_needed(host)
oidc = f"{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server"
resp = client.do("GET", oidc)
return OidcEndpoints.from_dict(resp)


def get_azure_entra_id_workspace_endpoints(
host: str,
) -> Optional[OidcEndpoints]:
Expand Down
195 changes: 194 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pytest

from databricks.sdk import oauth, useragent
from databricks.sdk.config import Config, with_product, with_user_agent_extra
from databricks.sdk.config import (ClientType, Config, HostType, with_product,
with_user_agent_extra)
from databricks.sdk.version import __version__

from .conftest import noop_credentials, set_az_path
Expand Down Expand Up @@ -260,3 +261,195 @@ def test_oauth_token_reuses_existing_provider(mocker):
# Both calls should work and use the same provider instance
assert token1 == token2 == mock_token
assert mock_oauth_provider.oauth_token.call_count == 2


def test_host_type_workspace():
"""Test that a regular workspace host is identified correctly."""
config = Config(host="https://test.databricks.com", token="test-token")
assert config.host_type == HostType.WORKSPACE


def test_host_type_accounts():
"""Test that an accounts host is identified correctly."""
config = Config(host="https://accounts.cloud.databricks.com", account_id="test-account", token="test-token")
assert config.host_type == HostType.ACCOUNTS


def test_host_type_accounts_dod():
"""Test that an accounts-dod host is identified correctly."""
config = Config(host="https://accounts-dod.cloud.databricks.us", account_id="test-account", token="test-token")
assert config.host_type == HostType.ACCOUNTS


def test_host_type_unified():
"""Test that a unified host is identified when experimental flag is set."""
config = Config(
host="https://unified.databricks.com",
workspace_id="test-workspace",
experimental_is_unified_host=True,
token="test-token",
)
assert config.host_type == HostType.UNIFIED


def test_client_type_workspace():
"""Test that client type is workspace when workspace_id is set on unified host."""
config = Config(
host="https://unified.databricks.com",
workspace_id="test-workspace",
account_id="test-account",
experimental_is_unified_host=True,
token="test-token",
)
assert config.client_type == ClientType.WORKSPACE


def test_client_type_account():
"""Test that client type is account when account_id is set without workspace_id."""
config = Config(
host="https://unified.databricks.com",
account_id="test-account",
experimental_is_unified_host=True,
token="test-token",
)
assert config.client_type == ClientType.ACCOUNT


def test_client_type_workspace_default():
"""Test that client type defaults to workspace."""
config = Config(host="https://test.databricks.com", token="test-token")
assert config.client_type == ClientType.WORKSPACE


def test_client_type_accounts_host():
"""Test that client type is account for accounts host."""
config = Config(
host="https://accounts.cloud.databricks.com",
account_id="test-account",
token="test-token",
)
assert config.client_type == ClientType.ACCOUNT


def test_client_type_unified_without_account_id():
"""Test that client type raises error for unified host without account_id."""
config = Config(
host="https://unified.databricks.com",
experimental_is_unified_host=True,
token="test-token",
)
with pytest.raises(ValueError, match="Unified host requires account_id"):
_ = config.client_type


def test_is_account_client_backward_compatibility():
"""Test that is_account_client property still works for backward compatibility."""
config_workspace = Config(host="https://test.databricks.com", token="test-token")
assert not config_workspace.is_account_client

config_account = Config(host="https://accounts.cloud.databricks.com", account_id="test-account", token="test-token")
assert config_account.is_account_client


def test_is_account_client_raises_on_unified_host():
"""Test that is_account_client raises ValueError when used with unified hosts."""
config = Config(
host="https://unified.databricks.com",
experimental_is_unified_host=True,
workspace_id="test-workspace",
token="test-token",
)
with pytest.raises(ValueError, match="is_account_client cannot be used with unified hosts"):
_ = config.is_account_client


def test_oidc_endpoints_unified_workspace(mocker, requests_mock):
"""Test that oidc_endpoints returns unified endpoints for workspace on unified host."""
requests_mock.get(
"https://unified.databricks.com/oidc/accounts/test-account/.well-known/oauth-authorization-server",
json={
"authorization_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/authorize",
"token_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/token",
},
)

config = Config(
host="https://unified.databricks.com",
workspace_id="test-workspace",
account_id="test-account",
experimental_is_unified_host=True,
token="test-token",
)

endpoints = config.oidc_endpoints
assert endpoints is not None
assert "accounts/test-account" in endpoints.authorization_endpoint
assert "accounts/test-account" in endpoints.token_endpoint


def test_oidc_endpoints_unified_account(mocker, requests_mock):
"""Test that oidc_endpoints returns account endpoints for account on unified host."""
requests_mock.get(
"https://unified.databricks.com/oidc/accounts/test-account/.well-known/oauth-authorization-server",
json={
"authorization_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/authorize",
"token_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/token",
},
)

config = Config(
host="https://unified.databricks.com",
account_id="test-account",
experimental_is_unified_host=True,
token="test-token",
)

endpoints = config.oidc_endpoints
assert endpoints is not None
assert "accounts/test-account" in endpoints.authorization_endpoint
assert "accounts/test-account" in endpoints.token_endpoint


def test_oidc_endpoints_unified_missing_ids():
"""Test that oidc_endpoints raises error when unified host lacks required account_id."""
config = Config(host="https://unified.databricks.com", experimental_is_unified_host=True, token="test-token")

with pytest.raises(ValueError) as exc_info:
_ = config.oidc_endpoints

assert "Unified host requires account_id" in str(exc_info.value)


def test_workspace_org_id_header_on_unified_host(requests_mock):
"""Test that X-Databricks-Org-Id header is added for workspace clients on unified hosts."""
from databricks.sdk.core import ApiClient

requests_mock.get("https://unified.databricks.com/api/2.0/test", json={"result": "success"})

config = Config(
host="https://unified.databricks.com",
workspace_id="test-workspace-123",
experimental_is_unified_host=True,
token="test-token",
)

api_client = ApiClient(config)
api_client.do("GET", "/api/2.0/test")

# Verify the request was made with the X-Databricks-Org-Id header
assert requests_mock.last_request.headers.get("X-Databricks-Org-Id") == "test-workspace-123"


def test_no_org_id_header_on_regular_workspace(requests_mock):
"""Test that X-Databricks-Org-Id header is NOT added for regular workspace hosts."""
from databricks.sdk.core import ApiClient

requests_mock.get("https://test.databricks.com/api/2.0/test", json={"result": "success"})

config = Config(host="https://test.databricks.com", token="test-token")

api_client = ApiClient(config)
api_client.do("GET", "/api/2.0/test")

# Verify the X-Databricks-Org-Id header was NOT added
assert "X-Databricks-Org-Id" not in requests_mock.last_request.headers
Loading