Skip to content

Commit c48681c

Browse files
authored
[Feature] Add support for Unified Host with experimental flag (#1135)
## What changes are proposed in this pull request? This PR adds support for unified host: - Separates client type from host type determination, deprecating is_account_client and replacing it with host_type and client_type properties using new HostType and ClientType enums - Adds an experimental flag to indicate if a host is unified: experimental_is_unified_host - Adds a workspace_id attribute to Config, which is necessary for workspace clients that talk to unified hosts - Adds get_unified_endpoints() function, which is used in the OIDC endpoint resolution logic to discover OAuth endpoints on unified hosts - Adds header injection logic in ApiClient which adds an X-Databricks-Org-Id header to requests made by workspace clients on unified hosts - Adds comprehensive test coverage including unit tests for host/config type detection, OIDC endpoint resolution, header injection, and integration tests Similar to what is done for databricks/databricks-sdk-go#1307 ## How is this tested? - Unit tests - Manually integration tested <img width="1275" height="331" alt="image" src="https://github.com/user-attachments/assets/f456e104-6210-43a4-a99a-1c2ff0b134a2" /> - All existing integration tests pass Note: Integration test would be added in another PR once the test infra supports spog
1 parent 37611d6 commit c48681c

File tree

5 files changed

+307
-4
lines changed

5 files changed

+307
-4
lines changed

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Release v0.76.0
44

55
### New Features and Improvements
6+
* Add support for unified hosts with experimental flag
67

78
### Security
89

databricks/sdk/config.py

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pathlib
77
import sys
88
import urllib.parse
9+
from enum import Enum
910
from typing import Dict, Iterable, List, Optional
1011

1112
import requests
@@ -19,11 +20,26 @@
1920
DatabricksEnvironment, get_environment_for_hostname)
2021
from .oauth import (OidcEndpoints, Token, get_account_endpoints,
2122
get_azure_entra_id_workspace_endpoints,
22-
get_workspace_endpoints)
23+
get_unified_endpoints, get_workspace_endpoints)
2324

2425
logger = logging.getLogger("databricks.sdk")
2526

2627

28+
class HostType(Enum):
29+
"""Enum representing the type of Databricks host."""
30+
31+
ACCOUNTS = "accounts"
32+
WORKSPACE = "workspace"
33+
UNIFIED = "unified"
34+
35+
36+
class ClientType(Enum):
37+
"""Enum representing the type of client configuration."""
38+
39+
ACCOUNT = "account"
40+
WORKSPACE = "workspace"
41+
42+
2743
class ConfigAttribute:
2844
"""Configuration attribute metadata and descriptor protocols."""
2945

@@ -61,6 +77,10 @@ def with_user_agent_extra(key: str, value: str):
6177
class Config:
6278
host: str = ConfigAttribute(env="DATABRICKS_HOST")
6379
account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID")
80+
workspace_id: str = ConfigAttribute(env="DATABRICKS_WORKSPACE_ID")
81+
82+
# Experimental flag to indicate if the host is a unified host (supports both workspace and account APIs)
83+
experimental_is_unified_host: bool = ConfigAttribute(env="DATABRICKS_EXPERIMENTAL_IS_UNIFIED_HOST")
6484

6585
# PAT token.
6686
token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True)
@@ -338,8 +358,65 @@ def is_gcp(self) -> bool:
338358
def is_aws(self) -> bool:
339359
return self.environment.cloud == Cloud.AWS
340360

361+
@property
362+
def host_type(self) -> HostType:
363+
"""Determine the type of host based on the configuration.
364+
365+
Returns the HostType which can be ACCOUNTS, WORKSPACE, or UNIFIED.
366+
"""
367+
# Check if explicitly marked as unified host
368+
if self.experimental_is_unified_host:
369+
return HostType.UNIFIED
370+
371+
if not self.host:
372+
return HostType.WORKSPACE
373+
374+
# Check for accounts host pattern
375+
if self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod."):
376+
return HostType.ACCOUNTS
377+
378+
return HostType.WORKSPACE
379+
380+
@property
381+
def client_type(self) -> ClientType:
382+
"""Determine the type of client configuration.
383+
384+
This is separate from host_type. For example, a unified host can support both
385+
workspace and account client types.
386+
387+
Returns ClientType.ACCOUNT or ClientType.WORKSPACE based on the configuration.
388+
389+
For unified hosts, account_id must be set. If workspace_id is also set,
390+
returns WORKSPACE, otherwise returns ACCOUNT.
391+
"""
392+
host_type = self.host_type
393+
394+
if host_type == HostType.ACCOUNTS:
395+
return ClientType.ACCOUNT
396+
397+
if host_type == HostType.WORKSPACE:
398+
return ClientType.WORKSPACE
399+
400+
if host_type == HostType.UNIFIED:
401+
if not self.account_id:
402+
raise ValueError("Unified host requires account_id to be set")
403+
if self.workspace_id:
404+
return ClientType.WORKSPACE
405+
return ClientType.ACCOUNT
406+
407+
# Default to workspace for backward compatibility
408+
return ClientType.WORKSPACE
409+
341410
@property
342411
def is_account_client(self) -> bool:
412+
"""[Deprecated] Use host_type or client_type instead.
413+
414+
Determines if this is an account client based on the host URL.
415+
"""
416+
if self.experimental_is_unified_host:
417+
raise ValueError(
418+
"is_account_client cannot be used with unified hosts; use host_type or client_type instead"
419+
)
343420
if not self.host:
344421
return False
345422
return self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod.")
@@ -394,8 +471,18 @@ def oidc_endpoints(self) -> Optional[OidcEndpoints]:
394471
return None
395472
if self.is_azure and self.azure_client_id:
396473
return get_azure_entra_id_workspace_endpoints(self.host)
397-
if self.is_account_client and self.account_id:
474+
475+
# Handle unified hosts
476+
if self.host_type == HostType.UNIFIED:
477+
if not self.account_id:
478+
raise ValueError("Unified host requires account_id to be set for OAuth endpoints")
479+
return get_unified_endpoints(self.host, self.account_id)
480+
481+
# Handle traditional account hosts
482+
if self.host_type == HostType.ACCOUNTS and self.account_id:
398483
return get_account_endpoints(self.host, self.account_id)
484+
485+
# Default to workspace endpoints
399486
return get_workspace_endpoints(self.host)
400487

401488
def debug_string(self) -> str:

databricks/sdk/core.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,20 @@ class ApiClient:
2222

2323
def __init__(self, cfg: Config):
2424
self._cfg = cfg
25+
26+
# Create header factory that includes both auth and org ID headers
27+
def combined_header_factory():
28+
headers = cfg.authenticate()
29+
# Add X-Databricks-Org-Id header for workspace clients on unified hosts
30+
if cfg.workspace_id and cfg.host_type == HostType.UNIFIED:
31+
headers["X-Databricks-Org-Id"] = cfg.workspace_id
32+
return headers
33+
2534
self._api_client = _BaseClient(
2635
debug_truncate_bytes=cfg.debug_truncate_bytes,
2736
retry_timeout_seconds=cfg.retry_timeout_seconds,
2837
user_agent_base=cfg.user_agent,
29-
header_factory=cfg.authenticate,
38+
header_factory=combined_header_factory,
3039
max_connection_pools=cfg.max_connection_pools,
3140
max_connections_per_pool=cfg.max_connections_per_pool,
3241
pool_block=True,

databricks/sdk/oauth.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,19 @@ def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> O
418418
return OidcEndpoints.from_dict(resp)
419419

420420

421+
def get_unified_endpoints(host: str, account_id: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints:
422+
"""
423+
Get the OIDC endpoints for a unified host.
424+
:param host: The Databricks unified host.
425+
:param account_id: The account ID.
426+
:return: The OIDC endpoints for the unified host.
427+
"""
428+
host = _fix_host_if_needed(host)
429+
oidc = f"{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server"
430+
resp = client.do("GET", oidc)
431+
return OidcEndpoints.from_dict(resp)
432+
433+
421434
def get_azure_entra_id_workspace_endpoints(
422435
host: str,
423436
) -> Optional[OidcEndpoints]:

tests/test_config.py

Lines changed: 194 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import pytest
99

1010
from databricks.sdk import oauth, useragent
11-
from databricks.sdk.config import Config, with_product, with_user_agent_extra
11+
from databricks.sdk.config import (ClientType, Config, HostType, with_product,
12+
with_user_agent_extra)
1213
from databricks.sdk.version import __version__
1314

1415
from .conftest import noop_credentials, set_az_path
@@ -260,3 +261,195 @@ def test_oauth_token_reuses_existing_provider(mocker):
260261
# Both calls should work and use the same provider instance
261262
assert token1 == token2 == mock_token
262263
assert mock_oauth_provider.oauth_token.call_count == 2
264+
265+
266+
def test_host_type_workspace():
267+
"""Test that a regular workspace host is identified correctly."""
268+
config = Config(host="https://test.databricks.com", token="test-token")
269+
assert config.host_type == HostType.WORKSPACE
270+
271+
272+
def test_host_type_accounts():
273+
"""Test that an accounts host is identified correctly."""
274+
config = Config(host="https://accounts.cloud.databricks.com", account_id="test-account", token="test-token")
275+
assert config.host_type == HostType.ACCOUNTS
276+
277+
278+
def test_host_type_accounts_dod():
279+
"""Test that an accounts-dod host is identified correctly."""
280+
config = Config(host="https://accounts-dod.cloud.databricks.us", account_id="test-account", token="test-token")
281+
assert config.host_type == HostType.ACCOUNTS
282+
283+
284+
def test_host_type_unified():
285+
"""Test that a unified host is identified when experimental flag is set."""
286+
config = Config(
287+
host="https://unified.databricks.com",
288+
workspace_id="test-workspace",
289+
experimental_is_unified_host=True,
290+
token="test-token",
291+
)
292+
assert config.host_type == HostType.UNIFIED
293+
294+
295+
def test_client_type_workspace():
296+
"""Test that client type is workspace when workspace_id is set on unified host."""
297+
config = Config(
298+
host="https://unified.databricks.com",
299+
workspace_id="test-workspace",
300+
account_id="test-account",
301+
experimental_is_unified_host=True,
302+
token="test-token",
303+
)
304+
assert config.client_type == ClientType.WORKSPACE
305+
306+
307+
def test_client_type_account():
308+
"""Test that client type is account when account_id is set without workspace_id."""
309+
config = Config(
310+
host="https://unified.databricks.com",
311+
account_id="test-account",
312+
experimental_is_unified_host=True,
313+
token="test-token",
314+
)
315+
assert config.client_type == ClientType.ACCOUNT
316+
317+
318+
def test_client_type_workspace_default():
319+
"""Test that client type defaults to workspace."""
320+
config = Config(host="https://test.databricks.com", token="test-token")
321+
assert config.client_type == ClientType.WORKSPACE
322+
323+
324+
def test_client_type_accounts_host():
325+
"""Test that client type is account for accounts host."""
326+
config = Config(
327+
host="https://accounts.cloud.databricks.com",
328+
account_id="test-account",
329+
token="test-token",
330+
)
331+
assert config.client_type == ClientType.ACCOUNT
332+
333+
334+
def test_client_type_unified_without_account_id():
335+
"""Test that client type raises error for unified host without account_id."""
336+
config = Config(
337+
host="https://unified.databricks.com",
338+
experimental_is_unified_host=True,
339+
token="test-token",
340+
)
341+
with pytest.raises(ValueError, match="Unified host requires account_id"):
342+
_ = config.client_type
343+
344+
345+
def test_is_account_client_backward_compatibility():
346+
"""Test that is_account_client property still works for backward compatibility."""
347+
config_workspace = Config(host="https://test.databricks.com", token="test-token")
348+
assert not config_workspace.is_account_client
349+
350+
config_account = Config(host="https://accounts.cloud.databricks.com", account_id="test-account", token="test-token")
351+
assert config_account.is_account_client
352+
353+
354+
def test_is_account_client_raises_on_unified_host():
355+
"""Test that is_account_client raises ValueError when used with unified hosts."""
356+
config = Config(
357+
host="https://unified.databricks.com",
358+
experimental_is_unified_host=True,
359+
workspace_id="test-workspace",
360+
token="test-token",
361+
)
362+
with pytest.raises(ValueError, match="is_account_client cannot be used with unified hosts"):
363+
_ = config.is_account_client
364+
365+
366+
def test_oidc_endpoints_unified_workspace(mocker, requests_mock):
367+
"""Test that oidc_endpoints returns unified endpoints for workspace on unified host."""
368+
requests_mock.get(
369+
"https://unified.databricks.com/oidc/accounts/test-account/.well-known/oauth-authorization-server",
370+
json={
371+
"authorization_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/authorize",
372+
"token_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/token",
373+
},
374+
)
375+
376+
config = Config(
377+
host="https://unified.databricks.com",
378+
workspace_id="test-workspace",
379+
account_id="test-account",
380+
experimental_is_unified_host=True,
381+
token="test-token",
382+
)
383+
384+
endpoints = config.oidc_endpoints
385+
assert endpoints is not None
386+
assert "accounts/test-account" in endpoints.authorization_endpoint
387+
assert "accounts/test-account" in endpoints.token_endpoint
388+
389+
390+
def test_oidc_endpoints_unified_account(mocker, requests_mock):
391+
"""Test that oidc_endpoints returns account endpoints for account on unified host."""
392+
requests_mock.get(
393+
"https://unified.databricks.com/oidc/accounts/test-account/.well-known/oauth-authorization-server",
394+
json={
395+
"authorization_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/authorize",
396+
"token_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/token",
397+
},
398+
)
399+
400+
config = Config(
401+
host="https://unified.databricks.com",
402+
account_id="test-account",
403+
experimental_is_unified_host=True,
404+
token="test-token",
405+
)
406+
407+
endpoints = config.oidc_endpoints
408+
assert endpoints is not None
409+
assert "accounts/test-account" in endpoints.authorization_endpoint
410+
assert "accounts/test-account" in endpoints.token_endpoint
411+
412+
413+
def test_oidc_endpoints_unified_missing_ids():
414+
"""Test that oidc_endpoints raises error when unified host lacks required account_id."""
415+
config = Config(host="https://unified.databricks.com", experimental_is_unified_host=True, token="test-token")
416+
417+
with pytest.raises(ValueError) as exc_info:
418+
_ = config.oidc_endpoints
419+
420+
assert "Unified host requires account_id" in str(exc_info.value)
421+
422+
423+
def test_workspace_org_id_header_on_unified_host(requests_mock):
424+
"""Test that X-Databricks-Org-Id header is added for workspace clients on unified hosts."""
425+
from databricks.sdk.core import ApiClient
426+
427+
requests_mock.get("https://unified.databricks.com/api/2.0/test", json={"result": "success"})
428+
429+
config = Config(
430+
host="https://unified.databricks.com",
431+
workspace_id="test-workspace-123",
432+
experimental_is_unified_host=True,
433+
token="test-token",
434+
)
435+
436+
api_client = ApiClient(config)
437+
api_client.do("GET", "/api/2.0/test")
438+
439+
# Verify the request was made with the X-Databricks-Org-Id header
440+
assert requests_mock.last_request.headers.get("X-Databricks-Org-Id") == "test-workspace-123"
441+
442+
443+
def test_no_org_id_header_on_regular_workspace(requests_mock):
444+
"""Test that X-Databricks-Org-Id header is NOT added for regular workspace hosts."""
445+
from databricks.sdk.core import ApiClient
446+
447+
requests_mock.get("https://test.databricks.com/api/2.0/test", json={"result": "success"})
448+
449+
config = Config(host="https://test.databricks.com", token="test-token")
450+
451+
api_client = ApiClient(config)
452+
api_client.do("GET", "/api/2.0/test")
453+
454+
# Verify the X-Databricks-Org-Id header was NOT added
455+
assert "X-Databricks-Org-Id" not in requests_mock.last_request.headers

0 commit comments

Comments
 (0)