diff --git a/README.md b/README.md index d57efda1f..5ce0af221 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,37 @@ or to a Databricks Runtime interactive cluster (e.g. /sql/protocolv1/o/123456789 > to authenticate the target Databricks user account and needs to open the browser for authentication. So it > can only run on the user's machine. +## Transaction Support + +The connector supports multi-statement transactions with manual commit/rollback control: + +```python +import os +from databricks import sql + +connection = sql.connect( + server_hostname=os.getenv("DATABRICKS_HOST"), + http_path=os.getenv("DATABRICKS_HTTP_PATH") +) + +# Disable autocommit to use explicit transactions +connection.autocommit = False + +cursor = connection.cursor() +try: + cursor.execute("INSERT INTO table1 VALUES (1, 'a')") + cursor.execute("INSERT INTO table2 VALUES (2, 'b')") + connection.commit() # Commit both inserts atomically +except Exception as e: + connection.rollback() # Rollback on error + raise + +cursor.close() +connection.close() +``` + +For detailed information about transaction behavior, isolation levels, error handling, and best practices, see [TRANSACTIONS.md](TRANSACTIONS.md). + ## SQLAlchemy Starting from `databricks-sql-connector` version 4.0.0 SQLAlchemy support has been extracted to a new library `databricks-sqlalchemy`. diff --git a/TRANSACTIONS.md b/TRANSACTIONS.md new file mode 100644 index 000000000..7def1230c --- /dev/null +++ b/TRANSACTIONS.md @@ -0,0 +1,368 @@ +# Transaction Support + +The Databricks SQL Connector for Python supports multi-statement transactions (MST). This allows you to group multiple SQL statements into atomic units that either succeed completely or fail completely. + +## Autocommit Behavior + +By default, every SQL statement executes in its own transaction and commits immediately (autocommit mode). This is the standard behavior for most database connectors. + +```python +from databricks import sql + +connection = sql.connect( + server_hostname="your-workspace.cloud.databricks.com", + http_path="/sql/1.0/warehouses/abc123" +) + +# Default: autocommit is True +print(connection.autocommit) # True + +# Each statement commits immediately +cursor = connection.cursor() +cursor.execute("INSERT INTO my_table VALUES (1, 'data')") +# Already committed - data is visible to other connections +``` + +To use explicit transactions, disable autocommit: + +```python +connection.autocommit = False + +# Now statements are grouped into a transaction +cursor = connection.cursor() +cursor.execute("INSERT INTO my_table VALUES (1, 'data')") +# Not committed yet - must call connection.commit() + +connection.commit() # Now it's visible +``` + +## Basic Transaction Operations + +### Committing Changes + +When autocommit is disabled, you must explicitly commit your changes: + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + cursor.execute("INSERT INTO orders VALUES (1, 100.00)") + cursor.execute("INSERT INTO order_items VALUES (1, 'Widget', 2)") + connection.commit() # Both inserts succeed together +except Exception as e: + connection.rollback() # Neither insert is saved + raise +``` + +### Rolling Back Changes + +Use `rollback()` to discard all changes made in the current transaction: + +```python +connection.autocommit = False +cursor = connection.cursor() + +cursor.execute("INSERT INTO accounts VALUES (1, 1000)") +cursor.execute("UPDATE accounts SET balance = balance - 500 WHERE id = 1") + +# Changed your mind? 
+connection.rollback() # All changes discarded +``` + +Note: Calling `rollback()` when autocommit is enabled is safe (it's a no-op), but calling `commit()` will raise a `TransactionError`. + +### Sequential Transactions + +After a commit or rollback, a new transaction starts automatically: + +```python +connection.autocommit = False + +# First transaction +cursor.execute("INSERT INTO logs VALUES (1, 'event1')") +connection.commit() + +# Second transaction starts automatically +cursor.execute("INSERT INTO logs VALUES (2, 'event2')") +connection.rollback() # Only the second insert is discarded +``` + +## Multi-Table Transactions + +Transactions span multiple tables atomically. Either all changes are committed, or all are rolled back: + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + # Insert into multiple tables + cursor.execute("INSERT INTO customers VALUES (1, 'Alice')") + cursor.execute("INSERT INTO orders VALUES (1, 1, 100.00)") + cursor.execute("INSERT INTO shipments VALUES (1, 1, 'pending')") + + connection.commit() # All three inserts succeed atomically +except Exception as e: + connection.rollback() # All three inserts are discarded + raise +``` + +This is particularly useful for maintaining data consistency across related tables. + +## Transaction Isolation + +Databricks uses **Snapshot Isolation** (mapped to `REPEATABLE_READ` in standard SQL terminology). This means: + +- **Repeatable reads**: Once you read data in a transaction, subsequent reads will see the same data (even if other transactions modify it) +- **Atomic commits**: Changes are visible to other connections only after commit +- **Write serializability within a single table**: Concurrent writes to the same table will cause conflicts +- **Snapshot isolation across tables**: Concurrent writes to different tables can succeed + +### Getting the Isolation Level + +```python +level = connection.get_transaction_isolation() +print(level) # Output: REPEATABLE_READ +``` + +### Setting the Isolation Level + +Currently, only `REPEATABLE_READ` is supported: + +```python +from databricks import sql + +# Using the constant +connection.set_transaction_isolation(sql.TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ) + +# Or using a string +connection.set_transaction_isolation("REPEATABLE_READ") + +# Other levels will raise NotSupportedError +connection.set_transaction_isolation("READ_COMMITTED") # Raises NotSupportedError +``` + +### What Repeatable Read Means in Practice + +Within a transaction, you'll always see a consistent snapshot of the data: + +```python +connection.autocommit = False +cursor = connection.cursor() + +# First read +cursor.execute("SELECT balance FROM accounts WHERE id = 1") +balance1 = cursor.fetchone()[0] # Returns 1000 + +# Another connection updates the balance +# (In a separate connection: UPDATE accounts SET balance = 500 WHERE id = 1) + +# Second read in the same transaction +cursor.execute("SELECT balance FROM accounts WHERE id = 1") +balance2 = cursor.fetchone()[0] # Still returns 1000 (repeatable read!) 
+ +connection.commit() + +# After commit, new transactions will see the updated value (500) +``` + +## Error Handling + +### Setting Autocommit During a Transaction + +You cannot change autocommit mode while a transaction is active: + +```python +connection.autocommit = False +cursor.execute("INSERT INTO logs VALUES (1, 'data')") + +# This will raise TransactionError +try: + connection.autocommit = True +except sql.TransactionError as e: + print(f"Cannot change autocommit: {e}") + connection.rollback() # Clean up +``` + +### Committing Without an Active Transaction + +If autocommit is enabled, there's no active transaction, so calling `commit()` will fail: + +```python +connection.autocommit = True # Default + +try: + connection.commit() # Raises TransactionError +except sql.TransactionError as e: + print(f"No active transaction: {e}") +``` + +However, `rollback()` is safe in this case (it's a no-op). + +### Recovering from Query Failures + +If a statement fails during a transaction, roll back and start a new transaction: + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + cursor.execute("INSERT INTO valid_table VALUES (1, 'data')") + cursor.execute("INSERT INTO nonexistent_table VALUES (2, 'data')") # Fails + connection.commit() +except Exception as e: + connection.rollback() # Discard the partial transaction + + # Now you can start a fresh transaction + cursor.execute("INSERT INTO error_log VALUES (1, 'Query failed')") + connection.commit() +``` + +## Querying Server State + +By default, the `autocommit` property returns a cached value for performance. If you need to query the server each time (for example, if you're debugging or the state might change externally): + +```python +connection = sql.connect( + server_hostname="your-workspace.cloud.databricks.com", + http_path="/sql/1.0/warehouses/abc123", + fetch_autocommit_from_server=True +) + +# Each access queries the server +state = connection.autocommit # Executes "SET AUTOCOMMIT" query +``` + +This is generally not needed for normal usage. + +## Write Conflicts + +### Within a Single Table + +Databricks enforces **write serializability** within a single table. If two transactions try to modify the same table concurrently, one will fail: + +```python +# Connection 1 +conn1.autocommit = False +cursor1 = conn1.cursor() +cursor1.execute("INSERT INTO accounts VALUES (1, 100)") + +# Connection 2 (concurrent) +conn2.autocommit = False +cursor2 = conn2.cursor() +cursor2.execute("INSERT INTO accounts VALUES (2, 200)") + +# First commit succeeds +conn1.commit() # OK + +# Second commit fails with concurrent write conflict +try: + conn2.commit() # Raises error about concurrent writes +except Exception as e: + conn2.rollback() + print(f"Concurrent write detected: {e}") +``` + +This happens even when the rows being modified are different. The conflict detection is at the table level. + +### Across Multiple Tables + +Concurrent writes to *different* tables can succeed. Each table tracks its own write conflicts independently: + +```python +# Connection 1: writes to table_a +conn1.autocommit = False +cursor1 = conn1.cursor() +cursor1.execute("INSERT INTO table_a VALUES (1, 'data')") + +# Connection 2: writes to table_b (different table) +conn2.autocommit = False +cursor2 = conn2.cursor() +cursor2.execute("INSERT INTO table_b VALUES (1, 'data')") + +# Both commits succeed (different tables) +conn1.commit() # OK +conn2.commit() # Also OK +``` + +## Best Practices + +1. 
**Keep transactions short**: Long-running transactions can cause conflicts with other connections. Commit as soon as your atomic unit of work is complete. + +2. **Always handle exceptions**: Wrap transaction code in try/except and call `rollback()` on errors. + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + cursor.execute("INSERT INTO table1 VALUES (1, 'data')") + cursor.execute("UPDATE table2 SET status = 'updated'") + connection.commit() +except Exception as e: + connection.rollback() + logger.error(f"Transaction failed: {e}") + raise +``` + +3. **Use context managers**: If you're writing helper functions, consider using a context manager pattern: + +```python +from contextlib import contextmanager + +@contextmanager +def transaction(connection): + connection.autocommit = False + try: + yield connection + connection.commit() + except Exception: + connection.rollback() + raise + finally: + connection.autocommit = True + +# Usage +with transaction(connection): + cursor = connection.cursor() + cursor.execute("INSERT INTO logs VALUES (1, 'message')") + # Auto-commits on success, auto-rolls back on exception +``` + +4. **Reset autocommit when done**: After using explicit transactions, consider resetting autocommit to True for subsequent operations: + +```python +connection.autocommit = False +try: + # ... transaction code ... + connection.commit() +finally: + connection.autocommit = True # Reset to default +``` + +5. **Be aware of isolation semantics**: Remember that repeatable read means you see a snapshot from the start of your transaction. If you need to see recent changes from other transactions, commit your current transaction and start a new one. + +## Requirements + +To use transactions, you need: +- A Databricks SQL warehouse that supports Multi-Statement Transactions (MST) +- Tables created with the `delta.feature.catalogOwned-preview` table property: + +```sql +CREATE TABLE my_table (id INT, value STRING) +USING DELTA +TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') +``` + +## Related APIs + +- `connection.autocommit` - Get or set autocommit mode (boolean) +- `connection.commit()` - Commit the current transaction +- `connection.rollback()` - Roll back the current transaction +- `connection.get_transaction_isolation()` - Get the isolation level (returns `"REPEATABLE_READ"`) +- `connection.set_transaction_isolation(level)` - Validate/set isolation level (only `"REPEATABLE_READ"` supported) +- `sql.TransactionError` - Exception raised for transaction-specific errors + +All of these are extensions to [PEP 249](https://www.python.org/dev/peps/pep-0249/) (Python Database API Specification v2.0). 
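+## Putting It Together
+
+The snippet below is a minimal end-to-end sketch that combines the requirements and APIs described above. It assumes an MST-capable warehouse and uses the hypothetical table name `main.default.orders_demo`; adjust the catalog, schema, and credentials for your environment.
+
+```python
+import os
+from databricks import sql
+
+connection = sql.connect(
+    server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
+    http_path=os.getenv("DATABRICKS_HTTP_PATH"),
+    access_token=os.getenv("DATABRICKS_ACCESS_TOKEN"),
+)
+
+cursor = connection.cursor()
+
+# The table must opt in to catalog-owned commits for MST support
+cursor.execute("""
+    CREATE TABLE IF NOT EXISTS main.default.orders_demo (id INT, amount DOUBLE)
+    USING DELTA
+    TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported')
+""")
+
+# Switch to explicit transactions
+connection.autocommit = False
+
+try:
+    cursor.execute("INSERT INTO main.default.orders_demo VALUES (1, 100.0)")
+    cursor.execute("INSERT INTO main.default.orders_demo VALUES (2, 250.0)")
+    connection.commit()  # Both rows become visible together
+except Exception:
+    connection.rollback()  # Discard the partial transaction
+    raise
+finally:
+    connection.autocommit = True  # Restore the default for later operations
+    cursor.close()
+    connection.close()
+```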
diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 403a4d130..df44dd534 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -8,6 +8,9 @@ paramstyle = "named" +# Transaction isolation level constants (extension to PEP 249) +TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ = "REPEATABLE_READ" + import re from typing import TYPE_CHECKING diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 5bb191ca2..4db1ad118 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -20,6 +20,8 @@ InterfaceError, NotSupportedError, ProgrammingError, + TransactionError, + DatabaseError, ) from databricks.sql.thrift_api.TCLIService import ttypes @@ -86,6 +88,9 @@ NO_NATIVE_PARAMS: List = [] +# Transaction isolation level constants (extension to PEP 249) +TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ = "REPEATABLE_READ" + class Connection: def __init__( @@ -206,6 +211,11 @@ def read(self) -> Optional[OAuthToken]: This allows 1. cursor.tables() to return METRIC_VIEW table type 2. cursor.columns() to return "measure" column type + :param fetch_autocommit_from_server: `bool`, optional (default is False) + When True, the connection.autocommit property queries the server for current state + using SET AUTOCOMMIT instead of returning cached value. + Set to True if autocommit might be changed by external means (e.g., external SQL commands). + When False (default), uses cached state for better performance. """ # Internal arguments in **kwargs: @@ -304,6 +314,9 @@ def read(self) -> Optional[OAuthToken]: kwargs.get("use_inline_params", False) ) self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None) + self._fetch_autocommit_from_server = kwargs.get( + "fetch_autocommit_from_server", False + ) self.force_enable_telemetry = kwargs.get("force_enable_telemetry", False) self.enable_telemetry = kwargs.get("enable_telemetry", False) @@ -473,15 +486,261 @@ def _close(self, close_cursors=True) -> None: if self.http_client: self.http_client.close() - def commit(self): - """No-op because Databricks does not support transactions""" - pass + @property + def autocommit(self) -> bool: + """ + Get auto-commit mode for this connection. - def rollback(self): - raise NotSupportedError( - "Transactions are not supported on Databricks", - session_id_hex=self.get_session_id_hex(), - ) + Extension to PEP 249. Returns cached value by default. + If fetch_autocommit_from_server=True was set during connection, + queries server for current state. + + Returns: + bool: True if auto-commit is enabled, False otherwise + + Raises: + InterfaceError: If connection is closed + TransactionError: If fetch_autocommit_from_server=True and query fails + """ + if not self.open: + raise InterfaceError( + "Cannot get autocommit on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + if self._fetch_autocommit_from_server: + return self._fetch_autocommit_state_from_server() + + return self.session.get_autocommit() + + @autocommit.setter + def autocommit(self, value: bool) -> None: + """ + Set auto-commit mode for this connection. + + Extension to PEP 249. Executes SET AUTOCOMMIT command on server. 
+ + Args: + value: True to enable auto-commit, False to disable + + Raises: + InterfaceError: If connection is closed + TransactionError: If server rejects the change + """ + if not self.open: + raise InterfaceError( + "Cannot set autocommit on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + # Create internal cursor for transaction control + cursor = None + try: + cursor = self.cursor() + sql = f"SET AUTOCOMMIT = {'TRUE' if value else 'FALSE'}" + cursor.execute(sql) + + # Update cached state on success + self.session.set_autocommit(value) + + except DatabaseError as e: + # Wrap in TransactionError with context + raise TransactionError( + f"Failed to set autocommit to {value}: {e.message}", + context={ + **e.context, + "operation": "set_autocommit", + "autocommit_value": value, + }, + session_id_hex=self.get_session_id_hex(), + ) from e + finally: + if cursor: + cursor.close() + + def _fetch_autocommit_state_from_server(self) -> bool: + """ + Query server for current autocommit state using SET AUTOCOMMIT. + + Returns: + bool: Server's autocommit state + + Raises: + TransactionError: If query fails + """ + cursor = None + try: + cursor = self.cursor() + cursor.execute("SET AUTOCOMMIT") + + # Fetch result: should return row with value column + result = cursor.fetchone() + if result is None: + raise TransactionError( + "No result returned from SET AUTOCOMMIT query", + context={"operation": "fetch_autocommit"}, + session_id_hex=self.get_session_id_hex(), + ) + + # Parse value (first column should be "true" or "false") + value_str = str(result[0]).lower() + autocommit_state = value_str == "true" + + # Update cache + self.session.set_autocommit(autocommit_state) + + return autocommit_state + + except TransactionError: + # Re-raise TransactionError as-is + raise + except DatabaseError as e: + # Wrap other DatabaseErrors + raise TransactionError( + f"Failed to fetch autocommit state from server: {e.message}", + context={**e.context, "operation": "fetch_autocommit"}, + session_id_hex=self.get_session_id_hex(), + ) from e + finally: + if cursor: + cursor.close() + + def commit(self) -> None: + """ + Commit the current transaction. + + Per PEP 249. Should be called only when autocommit is disabled. + + When autocommit is False: + - Commits the current transaction + - Server automatically starts new transaction + + When autocommit is True: + - Server may throw error if no active transaction + + Raises: + InterfaceError: If connection is closed + TransactionError: If commit fails (e.g., no active transaction) + """ + if not self.open: + raise InterfaceError( + "Cannot commit on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + cursor = None + try: + cursor = self.cursor() + cursor.execute("COMMIT") + + except DatabaseError as e: + raise TransactionError( + f"Failed to commit transaction: {e.message}", + context={**e.context, "operation": "commit"}, + session_id_hex=self.get_session_id_hex(), + ) from e + finally: + if cursor: + cursor.close() + + def rollback(self) -> None: + """ + Rollback the current transaction. + + Per PEP 249. Should be called only when autocommit is disabled. + + When autocommit is False: + - Rolls back the current transaction + - Server automatically starts new transaction + + When autocommit is True: + - ROLLBACK is forgiving (no-op, doesn't throw exception) + + Note: ROLLBACK is safe to call even without active transaction. 
+ + Raises: + InterfaceError: If connection is closed + TransactionError: If rollback fails + """ + if not self.open: + raise InterfaceError( + "Cannot rollback on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + cursor = None + try: + cursor = self.cursor() + cursor.execute("ROLLBACK") + + except DatabaseError as e: + raise TransactionError( + f"Failed to rollback transaction: {e.message}", + context={**e.context, "operation": "rollback"}, + session_id_hex=self.get_session_id_hex(), + ) from e + finally: + if cursor: + cursor.close() + + def get_transaction_isolation(self) -> str: + """ + Get the transaction isolation level. + + Extension to PEP 249. + + Databricks supports REPEATABLE_READ isolation level (Snapshot Isolation), + which is the default and only supported level. + + Returns: + str: "REPEATABLE_READ" - the transaction isolation level constant + + Raises: + InterfaceError: If connection is closed + """ + if not self.open: + raise InterfaceError( + "Cannot get transaction isolation on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + return TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ + + def set_transaction_isolation(self, level: str) -> None: + """ + Set transaction isolation level. + + Extension to PEP 249. + + Databricks supports only REPEATABLE_READ isolation level (Snapshot Isolation). + This method validates that the requested level is supported but does not + execute any SQL, as REPEATABLE_READ is the default server behavior. + + Args: + level: Isolation level. Must be "REPEATABLE_READ" or "REPEATABLE READ" + (case-insensitive, underscores and spaces are interchangeable) + + Raises: + InterfaceError: If connection is closed + NotSupportedError: If isolation level not supported + """ + if not self.open: + raise InterfaceError( + "Cannot set transaction isolation on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + # Normalize and validate isolation level + normalized_level = level.upper().replace("_", " ") + + if normalized_level != TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ.replace( + "_", " " + ): + raise NotSupportedError( + f"Setting transaction isolation level '{level}' is not supported. " + f"Only {TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ} is supported.", + session_id_hex=self.get_session_id_hex(), + ) class Cursor: diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 4a772c49b..3a3a6b3c5 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -70,6 +70,23 @@ class NotSupportedError(DatabaseError): pass +class TransactionError(DatabaseError): + """ + Exception raised for transaction-specific errors. + + This exception is used when transaction control operations fail, such as: + - Setting autocommit mode (AUTOCOMMIT_SET_DURING_ACTIVE_TRANSACTION) + - Committing a transaction (MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION) + - Rolling back a transaction + - Setting transaction isolation level + + The exception includes context about which transaction operation failed + and preserves the underlying cause via exception chaining. 
+ """ + + pass + + ### Custom error classes ### class InvalidServerResponseError(OperationalError): """Thrown if the server does not set the initial namespace correctly""" diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index d8ba5d125..0f723d144 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -45,6 +45,9 @@ def __init__( self.schema = schema self.http_path = http_path + # Initialize autocommit state (JDBC default is True) + self._autocommit = True + user_agent_entry = kwargs.get("user_agent_entry") if user_agent_entry is None: user_agent_entry = kwargs.get("_user_agent_entry") @@ -168,6 +171,24 @@ def guid_hex(self) -> str: """Get the session ID in hex format""" return self._session_id.hex_guid + def get_autocommit(self) -> bool: + """ + Get the cached autocommit state for this session. + + Returns: + bool: True if autocommit is enabled, False otherwise + """ + return self._autocommit + + def set_autocommit(self, value: bool) -> None: + """ + Update the cached autocommit state for this session. + + Args: + value: True to cache autocommit as enabled, False as disabled + """ + self._autocommit = value + def close(self) -> None: """Close the underlying session.""" logger.info("Closing session %s", self.guid_hex) diff --git a/tests/e2e/test_transactions.py b/tests/e2e/test_transactions.py new file mode 100644 index 000000000..308a8c3d6 --- /dev/null +++ b/tests/e2e/test_transactions.py @@ -0,0 +1,594 @@ +""" +End-to-end integration tests for Multi-Statement Transaction (MST) APIs. + +These tests verify: +- autocommit property (getter/setter) +- commit() and rollback() methods +- get_transaction_isolation() and set_transaction_isolation() methods +- Transaction error handling + +Requirements: +- DBSQL warehouse that supports Multi-Statement Transactions (MST) +- Test environment configured via test.env file or environment variables + +Setup: +Set the following environment variables: +- DATABRICKS_SERVER_HOSTNAME +- DATABRICKS_HTTP_PATH +- DATABRICKS_ACCESS_TOKEN (or use OAuth) + +Usage: + pytest tests/e2e/test_transactions.py -v +""" + +import logging +import os +import pytest +from typing import Any, Dict + +import databricks.sql as sql +from databricks.sql import TransactionError, NotSupportedError, InterfaceError + +logger = logging.getLogger(__name__) + + +class TestTransactions: + """E2E tests for transaction control methods (MST support).""" + + # Test table name + TEST_TABLE_NAME = "transaction_test_table" + + @pytest.fixture(autouse=True) + def setup_and_teardown(self, connection_details): + """Setup test environment before each test and cleanup after.""" + self.connection_params = { + "server_hostname": connection_details["host"], + "http_path": connection_details["http_path"], + "access_token": connection_details.get("access_token"), + } + + # Get catalog and schema from environment or use defaults + self.catalog = os.getenv("DATABRICKS_CATALOG", "main") + self.schema = os.getenv("DATABRICKS_SCHEMA", "default") + + # Create connection for setup + self.connection = sql.connect(**self.connection_params) + + # Setup: Create test table + self._create_test_table() + + yield + + # Teardown: Cleanup + self._cleanup() + + def _get_fully_qualified_table_name(self) -> str: + """Get the fully qualified table name.""" + return f"{self.catalog}.{self.schema}.{self.TEST_TABLE_NAME}" + + def _create_test_table(self): + """Create the test table with Delta format and MST support.""" + fq_table_name = self._get_fully_qualified_table_name() + 
cursor = self.connection.cursor() + + try: + # Drop if exists + cursor.execute(f"DROP TABLE IF EXISTS {fq_table_name}") + + # Create table with Delta and catalog-owned feature for MST compatibility + cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {fq_table_name} + (id INT, value STRING) + USING DELTA + TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') + """ + ) + + logger.info(f"Created test table: {fq_table_name}") + finally: + cursor.close() + + def _cleanup(self): + """Cleanup after test: rollback pending transactions, drop table, close connection.""" + try: + # Try to rollback any pending transaction + if ( + self.connection + and self.connection.open + and not self.connection.autocommit + ): + try: + self.connection.rollback() + except Exception as e: + logger.debug( + f"Rollback during cleanup failed (may be expected): {e}" + ) + + # Reset to autocommit mode + try: + self.connection.autocommit = True + except Exception as e: + logger.debug(f"Reset autocommit during cleanup failed: {e}") + + # Drop test table + if self.connection and self.connection.open: + fq_table_name = self._get_fully_qualified_table_name() + cursor = self.connection.cursor() + try: + cursor.execute(f"DROP TABLE IF EXISTS {fq_table_name}") + logger.info(f"Dropped test table: {fq_table_name}") + except Exception as e: + logger.warning(f"Failed to drop test table: {e}") + finally: + cursor.close() + + finally: + # Close connection + if self.connection: + self.connection.close() + + # ==================== BASIC AUTOCOMMIT TESTS ==================== + + def test_default_autocommit_is_true(self): + """Test that new connection defaults to autocommit=true.""" + assert ( + self.connection.autocommit is True + ), "New connection should have autocommit=true by default" + + def test_set_autocommit_to_false(self): + """Test successfully setting autocommit to false.""" + self.connection.autocommit = False + assert ( + self.connection.autocommit is False + ), "autocommit should be false after setting to false" + + def test_set_autocommit_to_true(self): + """Test successfully setting autocommit back to true.""" + # First disable + self.connection.autocommit = False + assert self.connection.autocommit is False + + # Then enable + self.connection.autocommit = True + assert ( + self.connection.autocommit is True + ), "autocommit should be true after setting to true" + + # ==================== COMMIT TESTS ==================== + + def test_commit_single_insert(self): + """Test successfully committing a transaction with single INSERT.""" + fq_table_name = self._get_fully_qualified_table_name() + + # Start transaction + self.connection.autocommit = False + + # Insert data + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'test_value')" + ) + cursor.close() + + # Commit + self.connection.commit() + + # Verify data is persisted using a new connection + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT value FROM {fq_table_name} WHERE id = 1") + result = verify_cursor.fetchone() + verify_cursor.close() + + assert result is not None, "Should find inserted row after commit" + assert result[0] == "test_value", "Value should match inserted value" + finally: + verify_conn.close() + + def test_commit_multiple_inserts(self): + """Test successfully committing a transaction with multiple INSERTs.""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = 
False + + # Insert multiple rows + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'value1')") + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'value2')") + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (3, 'value3')") + cursor.close() + + self.connection.commit() + + # Verify all rows persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name}") + result = verify_cursor.fetchone() + verify_cursor.close() + + assert result[0] == 3, "Should have 3 rows after commit" + finally: + verify_conn.close() + + # ==================== ROLLBACK TESTS ==================== + + def test_rollback_single_insert(self): + """Test successfully rolling back a transaction.""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # Insert data + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table_name} (id, value) VALUES (100, 'rollback_test')" + ) + cursor.close() + + # Rollback + self.connection.rollback() + + # Verify data is NOT persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 100" + ) + result = verify_cursor.fetchone() + verify_cursor.close() + + assert result[0] == 0, "Rolled back data should not be persisted" + finally: + verify_conn.close() + + # ==================== SEQUENTIAL TRANSACTION TESTS ==================== + + def test_multiple_sequential_transactions(self): + """Test executing multiple sequential transactions (commit, commit, rollback).""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # First transaction - commit + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'txn1')") + cursor.close() + self.connection.commit() + + # Second transaction - commit + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'txn2')") + cursor.close() + self.connection.commit() + + # Third transaction - rollback + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (3, 'txn3')") + cursor.close() + self.connection.rollback() + + # Verify only first two transactions persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table_name} WHERE id IN (1, 2)" + ) + result = verify_cursor.fetchone() + assert result[0] == 2, "Should have 2 committed rows" + + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 3") + result = verify_cursor.fetchone() + assert result[0] == 0, "Rolled back row should not exist" + verify_cursor.close() + finally: + verify_conn.close() + + def test_auto_start_transaction_after_commit(self): + """Test that new transaction automatically starts after commit.""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # First transaction - commit + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'first')") + cursor.close() + self.connection.commit() + + # New transaction should start automatically - insert and rollback + cursor = self.connection.cursor() + 
cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'second')") + cursor.close() + self.connection.rollback() + + # Verify: first committed, second rolled back + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 1") + result = verify_cursor.fetchone() + assert result[0] == 1, "First insert should be committed" + + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 2") + result = verify_cursor.fetchone() + assert result[0] == 0, "Second insert should be rolled back" + verify_cursor.close() + finally: + verify_conn.close() + + def test_auto_start_transaction_after_rollback(self): + """Test that new transaction automatically starts after rollback.""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # First transaction - rollback + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'first')") + cursor.close() + self.connection.rollback() + + # New transaction should start automatically - insert and commit + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'second')") + cursor.close() + self.connection.commit() + + # Verify: first rolled back, second committed + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 1") + result = verify_cursor.fetchone() + assert result[0] == 0, "First insert should be rolled back" + + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 2") + result = verify_cursor.fetchone() + assert result[0] == 1, "Second insert should be committed" + verify_cursor.close() + finally: + verify_conn.close() + + # ==================== UPDATE/DELETE OPERATION TESTS ==================== + + def test_update_in_transaction(self): + """Test UPDATE operation in transaction.""" + fq_table_name = self._get_fully_qualified_table_name() + + # First insert a row with autocommit + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'original')" + ) + cursor.close() + + # Start transaction and update + self.connection.autocommit = False + cursor = self.connection.cursor() + cursor.execute(f"UPDATE {fq_table_name} SET value = 'updated' WHERE id = 1") + cursor.close() + self.connection.commit() + + # Verify update persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT value FROM {fq_table_name} WHERE id = 1") + result = verify_cursor.fetchone() + assert result[0] == "updated", "Value should be updated after commit" + verify_cursor.close() + finally: + verify_conn.close() + + # ==================== MULTI-TABLE TRANSACTION TESTS ==================== + + def test_multi_table_transaction_commit(self): + """Test atomic commit across multiple tables.""" + fq_table1_name = self._get_fully_qualified_table_name() + table2_name = self.TEST_TABLE_NAME + "_2" + fq_table2_name = f"{self.catalog}.{self.schema}.{table2_name}" + + # Create second table + cursor = self.connection.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") + cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {fq_table2_name} + (id INT, category STRING) + USING DELTA + TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 
'supported') + """ + ) + cursor.close() + + try: + # Start transaction and insert into both tables + self.connection.autocommit = False + + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table1_name} (id, value) VALUES (10, 'table1_data')" + ) + cursor.execute( + f"INSERT INTO {fq_table2_name} (id, category) VALUES (10, 'table2_data')" + ) + cursor.close() + + # Commit both atomically + self.connection.commit() + + # Verify both inserts persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table1_name} WHERE id = 10" + ) + result = verify_cursor.fetchone() + assert result[0] == 1, "Table1 insert should be committed" + + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table2_name} WHERE id = 10" + ) + result = verify_cursor.fetchone() + assert result[0] == 1, "Table2 insert should be committed" + + verify_cursor.close() + finally: + verify_conn.close() + + finally: + # Cleanup second table + self.connection.autocommit = True + cursor = self.connection.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") + cursor.close() + + def test_multi_table_transaction_rollback(self): + """Test atomic rollback across multiple tables.""" + fq_table1_name = self._get_fully_qualified_table_name() + table2_name = self.TEST_TABLE_NAME + "_2" + fq_table2_name = f"{self.catalog}.{self.schema}.{table2_name}" + + # Create second table + cursor = self.connection.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") + cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {fq_table2_name} + (id INT, category STRING) + USING DELTA + TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') + """ + ) + cursor.close() + + try: + # Start transaction and insert into both tables + self.connection.autocommit = False + + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table1_name} (id, value) VALUES (20, 'rollback1')" + ) + cursor.execute( + f"INSERT INTO {fq_table2_name} (id, category) VALUES (20, 'rollback2')" + ) + cursor.close() + + # Rollback both atomically + self.connection.rollback() + + # Verify both inserts were rolled back + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table1_name} WHERE id = 20" + ) + result = verify_cursor.fetchone() + assert result[0] == 0, "Table1 insert should be rolled back" + + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table2_name} WHERE id = 20" + ) + result = verify_cursor.fetchone() + assert result[0] == 0, "Table2 insert should be rolled back" + + verify_cursor.close() + finally: + verify_conn.close() + + finally: + # Cleanup second table + self.connection.autocommit = True + cursor = self.connection.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") + cursor.close() + + # ==================== ERROR HANDLING TESTS ==================== + + def test_set_autocommit_during_active_transaction(self): + """Test that setting autocommit during an active transaction throws error.""" + fq_table_name = self._get_fully_qualified_table_name() + + # Start transaction + self.connection.autocommit = False + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (99, 'test')") + cursor.close() + + # Try to set autocommit=True during active transaction + with pytest.raises(TransactionError) as exc_info: + 
self.connection.autocommit = True + + # Verify error message mentions autocommit or active transaction + error_msg = str(exc_info.value).lower() + assert ( + "autocommit" in error_msg or "active transaction" in error_msg + ), "Error should mention autocommit or active transaction" + + # Cleanup - rollback the transaction + self.connection.rollback() + + def test_commit_without_active_transaction_throws_error(self): + """Test that commit() throws error when autocommit=true (no active transaction).""" + # Ensure autocommit is true (default) + assert self.connection.autocommit is True + + # Attempt commit without active transaction should throw + with pytest.raises(TransactionError) as exc_info: + self.connection.commit() + + # Verify error message indicates no active transaction + error_message = str(exc_info.value) + assert ( + "MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION" in error_message + or "no active transaction" in error_message.lower() + ), "Error should indicate no active transaction" + + def test_rollback_without_active_transaction_is_safe(self): + """Test that rollback() without active transaction is a safe no-op.""" + # With autocommit=true (no active transaction) + assert self.connection.autocommit is True + + # ROLLBACK should be safe (no exception) + self.connection.rollback() + + # Verify connection is still usable + assert self.connection.autocommit is True + assert self.connection.open is True + + # ==================== TRANSACTION ISOLATION TESTS ==================== + + def test_get_transaction_isolation_returns_repeatable_read(self): + """Test that get_transaction_isolation() returns REPEATABLE_READ.""" + isolation_level = self.connection.get_transaction_isolation() + assert ( + isolation_level == "REPEATABLE_READ" + ), "Databricks MST should use REPEATABLE_READ (Snapshot Isolation)" + + def test_set_transaction_isolation_accepts_repeatable_read(self): + """Test that set_transaction_isolation() accepts REPEATABLE_READ.""" + # Should not raise - these are all valid formats + self.connection.set_transaction_isolation("REPEATABLE_READ") + self.connection.set_transaction_isolation("REPEATABLE READ") + self.connection.set_transaction_isolation("repeatable_read") + self.connection.set_transaction_isolation("repeatable read") + + def test_set_transaction_isolation_rejects_unsupported_level(self): + """Test that set_transaction_isolation() rejects unsupported levels.""" + with pytest.raises(NotSupportedError) as exc_info: + self.connection.set_transaction_isolation("READ_COMMITTED") + + error_message = str(exc_info.value) + assert "not supported" in error_message.lower() + assert "READ_COMMITTED" in error_message diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 19375cde3..cb810afbb 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -22,7 +22,13 @@ import databricks.sql import databricks.sql.client as client -from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError +from databricks.sql import ( + InterfaceError, + DatabaseError, + Error, + NotSupportedError, + TransactionError, +) from databricks.sql.types import Row from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.backend.types import CommandId, CommandState @@ -439,11 +445,6 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set(self): "last operation", ) - @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) - def test_commit_a_noop(self, mock_thrift_backend_class): - c = 
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - c.commit() - def test_setinputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setinputsizes(1) @@ -452,12 +453,6 @@ def test_setoutputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setoutputsize(1) - @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) - def test_rollback_not_supported(self, mock_thrift_backend_class): - c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - with self.assertRaises(NotSupportedError): - c.rollback() - @unittest.skip("JDW: skipping winter 2024 as we're about to rewrite this interface") @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_row_number_respected(self, mock_thrift_backend_class): @@ -639,11 +634,377 @@ def mock_close_normal(): ) +class TransactionTestSuite(unittest.TestCase): + """ + Unit tests for transaction control methods (MST support). + """ + + PACKAGE_NAME = "databricks.sql" + DUMMY_CONNECTION_ARGS = { + "server_hostname": "foo", + "http_path": "dummy_path", + "access_token": "tok", + } + + def _create_mock_connection(self, mock_session_class): + """Helper to create a mocked connection for transaction tests.""" + # Mock session + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session.get_autocommit.return_value = True + mock_session_class.return_value = mock_session + + # Create connection + conn = client.Connection(**self.DUMMY_CONNECTION_ARGS) + return conn + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_getter_returns_cached_value(self, mock_session_class): + """Test that autocommit property returns cached session value by default.""" + conn = self._create_mock_connection(mock_session_class) + + # Get autocommit (should use cached value) + result = conn.autocommit + + conn.session.get_autocommit.assert_called_once() + self.assertTrue(result) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_executes_sql(self, mock_session_class): + """Test that setting autocommit executes SET AUTOCOMMIT command.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + conn.autocommit = False + + # Verify SQL was executed + mock_cursor.execute.assert_called_once_with("SET AUTOCOMMIT = FALSE") + mock_cursor.close.assert_called_once() + + conn.session.set_autocommit.assert_called_once_with(False) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_with_true_value(self, mock_session_class): + """Test setting autocommit to True.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + conn.autocommit = True + + mock_cursor.execute.assert_called_once_with("SET AUTOCOMMIT = TRUE") + conn.session.set_autocommit.assert_called_once_with(True) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_wraps_database_error(self, mock_session_class): + """Test that autocommit setter wraps DatabaseError in TransactionError.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + server_error = DatabaseError( + "AUTOCOMMIT_SET_DURING_ACTIVE_TRANSACTION", + context={"sql_state": "25000"}, + session_id_hex="test-session-id", + ) + mock_cursor.execute.side_effect = server_error + + with patch.object(conn, "cursor", return_value=mock_cursor): + with 
self.assertRaises(TransactionError) as ctx: + conn.autocommit = False + + self.assertIn("Failed to set autocommit", str(ctx.exception)) + self.assertEqual(ctx.exception.context["operation"], "set_autocommit") + self.assertEqual(ctx.exception.context["autocommit_value"], False) + + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_preserves_exception_chain(self, mock_session_class): + """Test that exception chaining is preserved.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + original_error = DatabaseError( + "Original error", session_id_hex="test-session-id" + ) + mock_cursor.execute.side_effect = original_error + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + conn.autocommit = False + + self.assertEqual(ctx.exception.__cause__, original_error) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_commit_executes_sql(self, mock_session_class): + """Test that commit() executes COMMIT command.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + conn.commit() + + mock_cursor.execute.assert_called_once_with("COMMIT") + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_commit_wraps_database_error(self, mock_session_class): + """Test that commit() wraps DatabaseError in TransactionError.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + server_error = DatabaseError( + "MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION", + context={"sql_state": "25000"}, + session_id_hex="test-session-id", + ) + mock_cursor.execute.side_effect = server_error + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + conn.commit() + + self.assertIn("Failed to commit", str(ctx.exception)) + self.assertEqual(ctx.exception.context["operation"], "commit") + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_commit_on_closed_connection_raises_interface_error( + self, mock_session_class + ): + """Test that commit() on closed connection raises InterfaceError.""" + conn = self._create_mock_connection(mock_session_class) + conn.session.is_open = False + + with self.assertRaises(InterfaceError) as ctx: + conn.commit() + + self.assertIn("Cannot commit on closed connection", str(ctx.exception)) + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_rollback_executes_sql(self, mock_session_class): + """Test that rollback() executes ROLLBACK command.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + conn.rollback() + + mock_cursor.execute.assert_called_once_with("ROLLBACK") + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_rollback_wraps_database_error(self, mock_session_class): + """Test that rollback() wraps DatabaseError in TransactionError.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + server_error = DatabaseError( + "Unexpected rollback error", + context={"sql_state": "HY000"}, + session_id_hex="test-session-id", + ) + mock_cursor.execute.side_effect = server_error + + with 
patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + conn.rollback() + + self.assertIn("Failed to rollback", str(ctx.exception)) + self.assertEqual(ctx.exception.context["operation"], "rollback") + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_rollback_on_closed_connection_raises_interface_error( + self, mock_session_class + ): + """Test that rollback() on closed connection raises InterfaceError.""" + conn = self._create_mock_connection(mock_session_class) + conn.session.is_open = False + + with self.assertRaises(InterfaceError) as ctx: + conn.rollback() + + self.assertIn("Cannot rollback on closed connection", str(ctx.exception)) + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_get_transaction_isolation_returns_repeatable_read( + self, mock_session_class + ): + """Test that get_transaction_isolation() returns REPEATABLE_READ.""" + conn = self._create_mock_connection(mock_session_class) + + result = conn.get_transaction_isolation() + + self.assertEqual(result, "REPEATABLE_READ") + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_get_transaction_isolation_on_closed_connection_raises_interface_error( + self, mock_session_class + ): + """Test that get_transaction_isolation() on closed connection raises InterfaceError.""" + conn = self._create_mock_connection(mock_session_class) + conn.session.is_open = False + + with self.assertRaises(InterfaceError) as ctx: + conn.get_transaction_isolation() + + self.assertIn( + "Cannot get transaction isolation on closed connection", str(ctx.exception) + ) + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_set_transaction_isolation_accepts_repeatable_read( + self, mock_session_class + ): + """Test that set_transaction_isolation() accepts REPEATABLE_READ.""" + conn = self._create_mock_connection(mock_session_class) + + # Should not raise + conn.set_transaction_isolation("REPEATABLE_READ") + conn.set_transaction_isolation("REPEATABLE READ") # With space + conn.set_transaction_isolation("repeatable_read") # Lowercase with underscore + conn.set_transaction_isolation("repeatable read") # Lowercase with space + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_set_transaction_isolation_rejects_other_levels(self, mock_session_class): + """Test that set_transaction_isolation() rejects non-REPEATABLE_READ levels.""" + conn = self._create_mock_connection(mock_session_class) + + with self.assertRaises(NotSupportedError) as ctx: + conn.set_transaction_isolation("READ_COMMITTED") + + self.assertIn("not supported", str(ctx.exception)) + self.assertIn("READ_COMMITTED", str(ctx.exception)) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_set_transaction_isolation_on_closed_connection_raises_interface_error( + self, mock_session_class + ): + """Test that set_transaction_isolation() on closed connection raises InterfaceError.""" + conn = self._create_mock_connection(mock_session_class) + conn.session.is_open = False + + with self.assertRaises(InterfaceError) as ctx: + conn.set_transaction_isolation("REPEATABLE_READ") + + self.assertIn( + "Cannot set transaction isolation on closed connection", str(ctx.exception) + ) + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_fetch_autocommit_from_server_queries_server(self, mock_session_class): + """Test that fetch_autocommit_from_server=True queries server.""" + # Create connection with 
fetch_autocommit_from_server=True + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + conn = client.Connection( + fetch_autocommit_from_server=True, **self.DUMMY_CONNECTION_ARGS + ) + + mock_cursor = Mock() + mock_row = Mock() + mock_row.__getitem__ = Mock(return_value="true") + mock_cursor.fetchone.return_value = mock_row + + with patch.object(conn, "cursor", return_value=mock_cursor): + result = conn.autocommit + + mock_cursor.execute.assert_called_once_with("SET AUTOCOMMIT") + mock_cursor.fetchone.assert_called_once() + mock_cursor.close.assert_called_once() + + conn.session.set_autocommit.assert_called_once_with(True) + + self.assertTrue(result) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_fetch_autocommit_from_server_handles_false_value(self, mock_session_class): + """Test that fetch_autocommit_from_server correctly parses false value.""" + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + conn = client.Connection( + fetch_autocommit_from_server=True, **self.DUMMY_CONNECTION_ARGS + ) + + mock_cursor = Mock() + mock_row = Mock() + mock_row.__getitem__ = Mock(return_value="false") + mock_cursor.fetchone.return_value = mock_row + + with patch.object(conn, "cursor", return_value=mock_cursor): + result = conn.autocommit + + conn.session.set_autocommit.assert_called_once_with(False) + self.assertFalse(result) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_fetch_autocommit_from_server_raises_on_no_result(self, mock_session_class): + """Test that fetch_autocommit_from_server raises error when no result.""" + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + conn = client.Connection( + fetch_autocommit_from_server=True, **self.DUMMY_CONNECTION_ARGS + ) + + mock_cursor = Mock() + mock_cursor.fetchone.return_value = None + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + _ = conn.autocommit + + self.assertIn("No result returned", str(ctx.exception)) + mock_cursor.close.assert_called_once() + + conn.close() + + if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__]) loader = unittest.TestLoader() test_classes = [ ClientTestSuite, + TransactionTestSuite, FetchTests, ThriftBackendTestSuite, ArrowQueueSuite,