diff --git a/connection.go b/connection.go index 29a15a08..100dfe31 100644 --- a/connection.go +++ b/connection.go @@ -14,24 +14,25 @@ import ( "github.com/pkg/errors" ) -type conn struct { +type Conn struct { id string cfg *config.Config client cli_service.TCLIService session *cli_service.TOpenSessionResp + exc *Execution } // The driver does not really implement prepared statements. -func (c *conn) Prepare(query string) (driver.Stmt, error) { +func (c *Conn) Prepare(query string) (driver.Stmt, error) { return &stmt{conn: c, query: query}, nil } // The driver does not really implement prepared statements. -func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { +func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { return &stmt{conn: c, query: query}, nil } -func (c *conn) Close() error { +func (c *Conn) Close() error { log := logger.WithContext(c.id, "", "") ctx := driverctx.NewContextWithConnId(context.Background(), c.id) sentinel := sentinel.Sentinel{ @@ -50,16 +51,16 @@ func (c *conn) Close() error { } // Not supported in Databricks -func (c *conn) Begin() (driver.Tx, error) { +func (c *Conn) Begin() (driver.Tx, error) { return nil, errors.New(ErrTransactionsNotSupported) } // Not supported in Databricks -func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { +func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { return nil, errors.New(ErrTransactionsNotSupported) } -func (c *conn) Ping(ctx context.Context) error { +func (c *Conn) Ping(ctx context.Context) error { log := logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), "") ctx = driverctx.NewContextWithConnId(ctx, c.id) ctx1, cancel := context.WithTimeout(ctx, 15*time.Second) @@ -73,12 +74,13 @@ func (c *conn) Ping(ctx context.Context) error { } // Implementation of SessionResetter -func (c *conn) ResetSession(ctx context.Context) error { +func (c *Conn) ResetSession(ctx context.Context) error { // For now our session does not have any important state to reset before re-use + c.exc = nil return nil } -func (c *conn) IsValid() bool { +func (c *Conn) IsValid() bool { return c.session.GetStatus().StatusCode == cli_service.TStatusCode_SUCCESS_STATUS } @@ -87,13 +89,17 @@ func (c *conn) IsValid() bool { // // ExecContext honors the context timeout and return when it is canceled. // Statement ExecContext is the same as connection ExecContext -func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { +func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { log := logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), "") msg, start := logger.Track("ExecContext") ctx = driverctx.NewContextWithConnId(ctx, c.id) if len(args) > 0 { return nil, errors.New(ErrParametersNotSupported) } + if query == "" && c.exc != nil { + //TODO + return nil, errors.New(ErrNotImplemented) + } exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args) if exStmtResp != nil && exStmtResp.OperationHandle != nil { @@ -115,7 +121,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name // // QueryContext honors the context timeout and return when it is canceled. // Statement QueryContext is the same as connection QueryContext -func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { +func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { corrId := driverctx.CorrelationIdFromContext(ctx) log := logger.WithContext(c.id, corrId, "") msg, start := log.Track("QueryContext") @@ -124,42 +130,84 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam if len(args) > 0 { return nil, errors.New(ErrParametersNotSupported) } - // first we try to get the results synchronously. - // at any point in time that the context is done we must cancel and return - exStmtResp, _, err := c.runQuery(ctx, query, args) + if query == "" && c.exc != nil { + opHandle := toOperationHandle(c.exc) + rows := rows{ + connId: c.id, + correlationId: corrId, + client: c.client, + opHandle: opHandle, + pageSize: int64(c.cfg.MaxRows), + location: c.cfg.Location, + } + return &rows, nil + } else { - if exStmtResp != nil && exStmtResp.OperationHandle != nil { - log = logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)) - } - defer log.Duration(msg, start) + // first we try to get the results synchronously. + // at any point in time that the context is done we must cancel and return + exStmtResp, opStatus, err := c.runQuery(ctx, query, args) - if err != nil { - log.Err(err).Msgf("databricks: failed to run query: query %s", query) - return nil, wrapErrf(err, "failed to run query") - } - // hold on to the operation handle - opHandle := exStmtResp.OperationHandle + excId := "" + excStatus := ExecutionUnknown - rows := rows{ - connId: c.id, - correlationId: corrId, - client: c.client, - opHandle: opHandle, - pageSize: int64(c.cfg.MaxRows), - location: c.cfg.Location, - } + if opStatus != nil { + excStatus = toExecutionStatus(opStatus.GetOperationState()) + } + // hold on to the operation handle + opHandle := exStmtResp.OperationHandle - if exStmtResp.DirectResults != nil { - // return results - rows.fetchResults = exStmtResp.DirectResults.ResultSet - rows.fetchResultsMetadata = exStmtResp.DirectResults.ResultSetMetadata + excId = client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID) + log = logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), excId) + + defer log.Duration(msg, start) + + if c.cfg.RunAsync { + + excPtr := excFromContext(ctx) + *excPtr = Execution{ + Id: excId, + Status: excStatus, + Secret: opHandle.OperationId.Secret, + HasResultSet: opHandle.HasResultSet, + } + } + if err != nil { + log.Err(err).Msgf("databricks: failed to run query: query %s", query) + return nil, wrapErrf(err, "failed to run query") + } + + rows := rows{ + connId: c.id, + correlationId: corrId, + client: c.client, + opHandle: opHandle, + pageSize: int64(c.cfg.MaxRows), + location: c.cfg.Location, + } + + if exStmtResp.DirectResults != nil { + // return results + rows.fetchResults = exStmtResp.DirectResults.ResultSet + rows.fetchResultsMetadata = exStmtResp.DirectResults.ResultSetMetadata + } + + // if the direct results has all rows, the operation will be deleted, so + // set it to closed so clients won't ask for it + // excStatus = ExecutionClosed + + if c.cfg.RunAsync && excStatus != ExecutionFinished { + rows.opHandle = nil + + } + + return &rows, nil } - return &rows, nil } -func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, *cli_service.TGetOperationStatusResp, error) { +func (c *Conn) runQuery(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, *cli_service.TGetOperationStatusResp, error) { + log := logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), "") // first we try to get the results synchronously. // at any point in time that the context is done we must cancel and return @@ -190,6 +238,39 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa return exStmtResp, opStatus, errors.New(opStatus.GetDisplayMessage()) // live states case cli_service.TOperationState_INITIALIZED_STATE, cli_service.TOperationState_PENDING_STATE, cli_service.TOperationState_RUNNING_STATE: + if c.cfg.RunAsync { + return exStmtResp, opStatus, nil + } else { + statusResp, err := c.pollOperation(ctx, opHandle) + if err != nil { + return exStmtResp, statusResp, err + } + switch statusResp.GetOperationState() { + // terminal states + // good + case cli_service.TOperationState_FINISHED_STATE: + // return handle to fetch results later + return exStmtResp, opStatus, nil + // bad + case cli_service.TOperationState_CANCELED_STATE, cli_service.TOperationState_CLOSED_STATE, cli_service.TOperationState_ERROR_STATE, cli_service.TOperationState_TIMEDOUT_STATE: + logBadQueryState(log, statusResp) + return exStmtResp, opStatus, errors.New(statusResp.GetDisplayMessage()) + // live states + default: + logBadQueryState(log, statusResp) + return exStmtResp, opStatus, errors.New("invalid operation state. This should not have happened") + } + } + // weird states + default: + logBadQueryState(log, opStatus) + return exStmtResp, opStatus, errors.New("invalid operation state. This should not have happened") + } + + } else { + if c.cfg.RunAsync { + return exStmtResp, nil, nil + } else { statusResp, err := c.pollOperation(ctx, opHandle) if err != nil { return exStmtResp, statusResp, err @@ -199,41 +280,16 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa // good case cli_service.TOperationState_FINISHED_STATE: // return handle to fetch results later - return exStmtResp, opStatus, nil + return exStmtResp, statusResp, nil // bad case cli_service.TOperationState_CANCELED_STATE, cli_service.TOperationState_CLOSED_STATE, cli_service.TOperationState_ERROR_STATE, cli_service.TOperationState_TIMEDOUT_STATE: logBadQueryState(log, statusResp) - return exStmtResp, opStatus, errors.New(statusResp.GetDisplayMessage()) + return exStmtResp, statusResp, errors.New(statusResp.GetDisplayMessage()) // live states default: logBadQueryState(log, statusResp) - return exStmtResp, opStatus, errors.New("invalid operation state. This should not have happened") + return exStmtResp, statusResp, errors.New("invalid operation state. This should not have happened") } - // weird states - default: - logBadQueryState(log, opStatus) - return exStmtResp, opStatus, errors.New("invalid operation state. This should not have happened") - } - - } else { - statusResp, err := c.pollOperation(ctx, opHandle) - if err != nil { - return exStmtResp, statusResp, err - } - switch statusResp.GetOperationState() { - // terminal states - // good - case cli_service.TOperationState_FINISHED_STATE: - // return handle to fetch results later - return exStmtResp, statusResp, nil - // bad - case cli_service.TOperationState_CANCELED_STATE, cli_service.TOperationState_CLOSED_STATE, cli_service.TOperationState_ERROR_STATE, cli_service.TOperationState_TIMEDOUT_STATE: - logBadQueryState(log, statusResp) - return exStmtResp, statusResp, errors.New(statusResp.GetDisplayMessage()) - // live states - default: - logBadQueryState(log, statusResp) - return exStmtResp, statusResp, errors.New("invalid operation state. This should not have happened") } } } @@ -243,7 +299,7 @@ func logBadQueryState(log *logger.DBSQLLogger, opStatus *cli_service.TGetOperati log.Error().Msg(opStatus.GetErrorMessage()) } -func (c *conn) executeStatement(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, error) { +func (c *Conn) executeStatement(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, error) { corrId := driverctx.CorrelationIdFromContext(ctx) log := logger.WithContext(c.id, corrId, "") sentinel := sentinel.Sentinel{ @@ -251,9 +307,9 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver req := cli_service.TExecuteStatementReq{ SessionHandle: c.session.SessionHandle, Statement: query, - RunAsync: c.cfg.RunAsync, + RunAsync: true, QueryTimeout: int64(c.cfg.QueryTimeout / time.Second), - // this is specific for databricks. It shortcuts server roundtrips + // this is specific for databricks. It shortcuts server round-trips GetDirectResults: &cli_service.TSparkGetDirectResults{ MaxRows: int64(c.cfg.MaxRows), }, @@ -281,7 +337,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver return exStmtResp, err } -func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) { +func (c *Conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) { corrId := driverctx.CorrelationIdFromContext(ctx) log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID)) var statusResp *cli_service.TGetOperationStatusResp @@ -298,7 +354,7 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati OperationHandle: opHandle, }) if statusResp != nil && statusResp.OperationState != nil { - log.Debug().Msgf("databricks: status %s", statusResp.GetOperationState().String()) + log.Debug().Msgf("databricks: status %s", toExecutionStatus(statusResp.GetOperationState())) } return func() bool { // which other states? @@ -333,11 +389,81 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati return statusResp, nil } -var _ driver.Conn = (*conn)(nil) -var _ driver.Pinger = (*conn)(nil) -var _ driver.SessionResetter = (*conn)(nil) -var _ driver.Validator = (*conn)(nil) -var _ driver.ExecerContext = (*conn)(nil) -var _ driver.QueryerContext = (*conn)(nil) -var _ driver.ConnPrepareContext = (*conn)(nil) -var _ driver.ConnBeginTx = (*conn)(nil) +func (c *Conn) cancelOperation(ctx context.Context, exc Execution) error { + // TODO wrap in Sentinel + req := cli_service.TCancelOperationReq{ + OperationHandle: toOperationHandle(&exc), + } + _, err := c.client.CancelOperation(ctx, &req) + return err +} + +func (c *Conn) getOperationStatus(ctx context.Context, exc Execution) (Execution, error) { + // TODO wrap in Sentinel + statusResp, err := c.client.GetOperationStatus(ctx, &cli_service.TGetOperationStatusReq{ + OperationHandle: toOperationHandle(&exc), + }) + if err != nil { + return exc, err + } + exRet := Execution{ + Status: toExecutionStatus(statusResp.GetOperationState()), + Id: exc.Id, + Secret: exc.Secret, + HasResultSet: exc.HasResultSet, + } + return exRet, nil +} + +func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error { + ex, ok := nv.Value.(Execution) + if ok { + c.exc = &ex + return driver.ErrRemoveArgument + } + return nil +} + +func toExecutionStatus(state cli_service.TOperationState) ExecutionStatus { + switch state { + + case cli_service.TOperationState_INITIALIZED_STATE: + return ExecutionInitialized + case cli_service.TOperationState_RUNNING_STATE: + return ExecutionRunning + case cli_service.TOperationState_FINISHED_STATE: + return ExecutionFinished + case cli_service.TOperationState_CANCELED_STATE: + return ExecutionCanceled + case cli_service.TOperationState_CLOSED_STATE: + return ExecutionClosed + case cli_service.TOperationState_ERROR_STATE: + return ExecutionError + case cli_service.TOperationState_PENDING_STATE: + return ExecutionPending + case cli_service.TOperationState_TIMEDOUT_STATE: + return ExecutionTimedOut + default: + return ExecutionUnknown + } +} + +func toOperationHandle(ex *Execution) *cli_service.TOperationHandle { + return &cli_service.TOperationHandle{ + OperationId: &cli_service.THandleIdentifier{ + GUID: client.DecodeGuid(ex.Id), + Secret: ex.Secret, + }, + OperationType: cli_service.TOperationType_EXECUTE_STATEMENT, + HasResultSet: ex.HasResultSet, + } +} + +var _ driver.Conn = (*Conn)(nil) +var _ driver.Pinger = (*Conn)(nil) +var _ driver.SessionResetter = (*Conn)(nil) +var _ driver.Validator = (*Conn)(nil) +var _ driver.ExecerContext = (*Conn)(nil) +var _ driver.QueryerContext = (*Conn)(nil) +var _ driver.ConnPrepareContext = (*Conn)(nil) +var _ driver.ConnBeginTx = (*Conn)(nil) diff --git a/connection_test.go b/connection_test.go index ecb4b71e..45f8f585 100644 --- a/connection_test.go +++ b/connection_test.go @@ -26,7 +26,7 @@ func TestConn_executeStatement(t *testing.T) { testClient := &client.TestClient{ FnExecuteStatement: executeStatement, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -76,7 +76,7 @@ func TestConn_executeStatement(t *testing.T) { testClient := &client.TestClient{ FnExecuteStatement: executeStatement, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -102,7 +102,7 @@ func TestConn_pollOperation(t *testing.T) { testClient := &client.TestClient{ FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -132,7 +132,7 @@ func TestConn_pollOperation(t *testing.T) { testClient := &client.TestClient{ FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -162,7 +162,7 @@ func TestConn_pollOperation(t *testing.T) { testClient := &client.TestClient{ FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -192,7 +192,7 @@ func TestConn_pollOperation(t *testing.T) { testClient := &client.TestClient{ FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -222,7 +222,7 @@ func TestConn_pollOperation(t *testing.T) { testClient := &client.TestClient{ FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -254,7 +254,7 @@ func TestConn_pollOperation(t *testing.T) { testClient := &client.TestClient{ FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -296,7 +296,7 @@ func TestConn_pollOperation(t *testing.T) { FnGetOperationStatus: getOperationStatus, FnCancelOperation: cancelOperation, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -342,7 +342,7 @@ func TestConn_pollOperation(t *testing.T) { } cfg := config.WithDefaults() cfg.PollInterval = 100 * time.Millisecond - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: cfg, @@ -388,7 +388,7 @@ func TestConn_pollOperation(t *testing.T) { } cfg := config.WithDefaults() cfg.PollInterval = 100 * time.Millisecond - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: cfg, @@ -423,7 +423,7 @@ func TestConn_runQuery(t *testing.T) { testClient := &client.TestClient{ FnExecuteStatement: executeStatement, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -465,7 +465,7 @@ func TestConn_runQuery(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -509,7 +509,7 @@ func TestConn_runQuery(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -553,7 +553,7 @@ func TestConn_runQuery(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -605,7 +605,7 @@ func TestConn_runQuery(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -657,7 +657,7 @@ func TestConn_runQuery(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -709,7 +709,7 @@ func TestConn_runQuery(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -761,7 +761,7 @@ func TestConn_runQuery(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -782,7 +782,7 @@ func TestConn_ExecContext(t *testing.T) { var executeStatementCount int testClient := &client.TestClient{} - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -817,7 +817,7 @@ func TestConn_ExecContext(t *testing.T) { testClient := &client.TestClient{ FnExecuteStatement: executeStatement, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -860,7 +860,7 @@ func TestConn_ExecContext(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -881,7 +881,7 @@ func TestConn_QueryContext(t *testing.T) { var executeStatementCount int testClient := &client.TestClient{} - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -916,7 +916,7 @@ func TestConn_QueryContext(t *testing.T) { testClient := &client.TestClient{ FnExecuteStatement: executeStatement, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -959,7 +959,7 @@ func TestConn_QueryContext(t *testing.T) { FnExecuteStatement: executeStatement, FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -994,7 +994,7 @@ func TestConn_Ping(t *testing.T) { testClient := &client.TestClient{ FnExecuteStatement: executeStatement, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -1037,7 +1037,7 @@ func TestConn_Ping(t *testing.T) { FnGetOperationStatus: getOperationStatus, } - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: testClient, cfg: config.WithDefaults(), @@ -1051,7 +1051,7 @@ func TestConn_Ping(t *testing.T) { func TestConn_Begin(t *testing.T) { t.Run("Begin not supported", func(t *testing.T) { - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: &client.TestClient{}, cfg: config.WithDefaults(), @@ -1064,7 +1064,7 @@ func TestConn_Begin(t *testing.T) { func TestConn_BeginTx(t *testing.T) { t.Run("BeginTx not supported", func(t *testing.T) { - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: &client.TestClient{}, cfg: config.WithDefaults(), @@ -1077,7 +1077,7 @@ func TestConn_BeginTx(t *testing.T) { func TestConn_ResetSession(t *testing.T) { t.Run("ResetSession not currently supported", func(t *testing.T) { - testConn := &conn{ + testConn := &Conn{ session: getTestSession(), client: &client.TestClient{}, cfg: config.WithDefaults(), diff --git a/connector.go b/connector.go index 74f78d8a..b615263c 100644 --- a/connector.go +++ b/connector.go @@ -59,7 +59,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, errors.New("databricks: invalid open session response") } - conn := &conn{ + conn := &Conn{ id: client.SprintGuid(session.SessionHandle.GetSessionId().GUID), cfg: c.cfg, client: tclient, diff --git a/db.go b/db.go new file mode 100644 index 00000000..555e3460 --- /dev/null +++ b/db.go @@ -0,0 +1,149 @@ +package dbsql + +import ( + "context" + "database/sql" + "database/sql/driver" + + "github.com/databricks/databricks-sql-go/driverctx" + "github.com/pkg/errors" +) + +type DatabricksDB interface { + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, Execution, error) + CancelExecution(ctx context.Context, exc Execution) error + GetExecutionRows(ctx context.Context, exc Execution) (*sql.Rows, error) + CheckExecution(ctx context.Context, exc Execution) (Execution, error) + Close() error + Stats() sql.DBStats + SetMaxOpenConns(n int) +} + +type databricksDB struct { + sqldb *sql.DB +} + +func OpenDB(c driver.Connector) DatabricksDB { + cnnr := c.(*connector) + cnnr.cfg.RunAsync = true + db := sql.OpenDB(c) + return &databricksDB{db} +} + +func (db *databricksDB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, Execution, error) { + exc := Execution{} + ctx2 := newContextWithExecution(ctx, &exc) + rs, err := db.sqldb.QueryContext(ctx2, query, args...) + if exc.Status != ExecutionFinished && rs != nil { + rs.Close() + } + return rs, exc, err +} + +func (db *databricksDB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, string, error) { + // db.sqldb.ExecContext() + return nil, "", errors.New(ErrNotImplemented) +} + +func (db *databricksDB) Close() error { + return db.sqldb.Close() +} + +func (db *databricksDB) Stats() sql.DBStats { + return db.sqldb.Stats() +} + +func (db *databricksDB) SetMaxOpenConns(n int) { + db.sqldb.SetMaxOpenConns(n) +} + +func (db *databricksDB) CancelExecution(ctx context.Context, exc Execution) error { + con, err := db.sqldb.Conn(ctx) + if err != nil { + return err + } + defer con.Close() + return con.Raw(func(driverConn any) error { + dbsqlcon, ok := driverConn.(*Conn) + if !ok { + return errors.New("invalid connection type") + } + return dbsqlcon.cancelOperation(ctx, exc) + }) +} + +func (db *databricksDB) CheckExecution(ctx context.Context, exc Execution) (Execution, error) { + con, err := db.sqldb.Conn(ctx) + if err != nil { + return exc, err + } + defer con.Close() + exRet := exc + err = con.Raw(func(driverConn any) error { + dbsqlcon, ok := driverConn.(*Conn) + if !ok { + return errors.New("invalid connection type") + } + exRet, err = dbsqlcon.getOperationStatus(ctx, exc) + return err + }) + return exRet, err +} + +func (db *databricksDB) GetExecutionRows(ctx context.Context, exc Execution) (*sql.Rows, error) { + return db.sqldb.QueryContext(ctx, "", exc) +} + +func (db *databricksDB) GetExecutionResult(ctx context.Context, exc Execution) (sql.Result, error) { + return db.sqldb.ExecContext(ctx, "", exc) +} + +type Execution struct { + Status ExecutionStatus + Id string + Secret []byte + HasResultSet bool +} + +type ExecutionStatus string + +const ( + // live state Initialized + ExecutionInitialized ExecutionStatus = "Initialized" + // live state Running + ExecutionRunning ExecutionStatus = "Running" + // terminal state Finished + ExecutionFinished ExecutionStatus = "Finished" + // terminal state Canceled + ExecutionCanceled ExecutionStatus = "Canceled" + // terminal state Closed + ExecutionClosed ExecutionStatus = "Closed" + // terminal state Error + ExecutionError ExecutionStatus = "Error" + ExecutionUnknown ExecutionStatus = "Unknown" + // live state Pending + ExecutionPending ExecutionStatus = "Pending" + // terminal state TimedOut + ExecutionTimedOut ExecutionStatus = "TimedOut" +) + +func (e ExecutionStatus) Terminal() bool { + switch e { + case ExecutionInitialized, ExecutionPending, ExecutionRunning: + return false + default: + return true + } +} + +func newContextWithExecution(ctx context.Context, exc *Execution) context.Context { + return context.WithValue(ctx, driverctx.ExecutionContextKey, exc) +} + +func excFromContext(ctx context.Context) *Execution { + excId, ok := ctx.Value(driverctx.ExecutionContextKey).(*Execution) + if !ok { + return nil + } + return excId +} diff --git a/db_test.go b/db_test.go new file mode 100644 index 00000000..08a359d0 --- /dev/null +++ b/db_test.go @@ -0,0 +1 @@ +package dbsql diff --git a/driver.go b/driver.go index 80e54ef8..68d4eead 100644 --- a/driver.go +++ b/driver.go @@ -40,41 +40,3 @@ func (d *databricksDriver) OpenConnector(dsn string) (driver.Connector, error) { var _ driver.Driver = (*databricksDriver)(nil) var _ driver.DriverContext = (*databricksDriver)(nil) - -// type databricksDB struct { -// *sql.DB -// } - -// func OpenDB(c driver.Connector) *databricksDB { -// db := sql.OpenDB(c) -// return &databricksDB{db} -// } - -// func (db *databricksDB) QueryContextAsync(ctx context.Context, query string, args ...any) (rows *sql.Rows, queryId string, err error) { -// return nil, "", nil -// } - -// func (db *databricksDB) ExecContextAsync(ctx context.Context, query string, args ...any) (result sql.Result, queryId string) { -// //go do something -// return nil, "" -// } - -// func (db *databricksDB) CancelQuery(ctx context.Context, queryId string) error { -// //go do something -// return nil -// } - -// func (db *databricksDB) GetQueryStatus(ctx context.Context, queryId string) error { -// //go do something -// return nil -// } - -// func (db *databricksDB) FetchRows(ctx context.Context, queryId string) (rows *sql.Rows, err error) { -// //go do something -// return nil, nil -// } - -// func (db *databricksDB) FetchResult(ctx context.Context, queryId string) (rows sql.Result, err error) { -// //go do something -// return nil, nil -// } diff --git a/driverctx/ctx.go b/driverctx/ctx.go index 21397f5f..60975a66 100644 --- a/driverctx/ctx.go +++ b/driverctx/ctx.go @@ -11,6 +11,7 @@ type contextKey int const ( CorrelationIdContextKey contextKey = iota ConnIdContextKey + ExecutionContextKey ) // NewContextWithCorrelationId creates a new context with correlationId value. Used by Logger to populate field corrId. diff --git a/examples/asyncWorkflow/main.go b/examples/asyncWorkflow/main.go new file mode 100644 index 00000000..ddb960c5 --- /dev/null +++ b/examples/asyncWorkflow/main.go @@ -0,0 +1,151 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "strconv" + "time" + + dbsql "github.com/databricks/databricks-sql-go" + dbsqlctx "github.com/databricks/databricks-sql-go/driverctx" + dbsqllog "github.com/databricks/databricks-sql-go/logger" + "github.com/joho/godotenv" +) + +func main() { + // use this package to set up logging. By default logging level is `warn`. If you want to disable logging, use `disabled` + if err := dbsqllog.SetLogLevel("debug"); err != nil { + panic(err) + } + // sets the logging output. By default it will use os.Stderr. If running in terminal, it will use ConsoleWriter to make it pretty + // dbsqllog.SetLogOutput(os.Stdout) + + // this is just to make it easy to load all variables + if err := godotenv.Load(); err != nil { + panic(err) + } + port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT")) + if err != nil { + panic(err) + } + + // programmatically initializes the connector + // another way is to use a DNS. In this case the equivalent DNS would be: + // "token:@hostname:port/http_path?catalog=hive_metastore&schema=default&timeout=60&maxRows=10&&timezone=America/Sao_Paulo&ANSI_MODE=true" + connector, err := dbsql.NewConnector( + // minimum configuration + dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")), + dbsql.WithPort(port), + dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")), + dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")), + //optional configuration + dbsql.WithUserAgentEntry("workflow-example"), + dbsql.WithInitialNamespace("hive_metastore", "default"), + dbsql.WithTimeout(time.Minute), // defaults to no timeout. Global timeout. Any query will be canceled if taking more than this time. + dbsql.WithMaxRows(10), // defaults to 10000 + ) + if err != nil { + // This will not be a connection error, but a DSN parse error or + // another initialization error. + panic(err) + + } + // Opening a driver typically will not attempt to connect to the database. + db := dbsql.OpenDB(connector) + // make sure to close it later + defer db.Close() + + db.SetMaxOpenConns(2) + + // the "github.com/databricks/databricks-sql-go/driverctx" has some functions to help set the context for the driver + ogCtx := dbsqlctx.NewContextWithCorrelationId(context.Background(), "asyncWorkflow-example") + + // for _, v := range []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"} { + // i := v + // go func() { + // _, exc, err := db.QueryContext(ogCtx, fmt.Sprintf("select %s", i)) + rs, exc, err := db.QueryContext(ogCtx, `SELECT id FROM RANGE(100) ORDER BY RANDOM() + 2 asc`) + if err != nil { + log.Fatal(err) + } + defer rs.Close() + // can't do this. If direct results is done, the operation is gone + exc, err = db.CheckExecution(ogCtx, exc) + if err != nil { + log.Fatal(err) + } + + if exc.Status == dbsql.ExecutionFinished { + rs, err = db.GetExecutionRows(ogCtx, exc) + if err != nil { + log.Fatal(err) + } + var res string + i := 0 + for rs.Next() { + err := rs.Scan(&res) + if err != nil { + fmt.Println(err) + rs.Close() + return + } + fmt.Println(res) + if i < 10 { + i++ + } else { + return + } + } + } + for { + if exc.Status.Terminal() { + break + } else { + exc, err = db.CheckExecution(ogCtx, exc) + if err != nil { + log.Fatal(err) + } + } + fmt.Println(db.Stats()) + time.Sleep(time.Second) + } + + if exc.Status == dbsql.ExecutionFinished { + rs, err = db.GetExecutionRows(ogCtx, exc) + if err != nil { + log.Fatal(err) + } + var res string + i := 0 + for rs.Next() { + err := rs.Scan(&res) + if err != nil { + fmt.Println(err) + break + } + fmt.Println(res) + if i < 12 { + i++ + } else { + return + } + } + } else { + fmt.Println(exc.Status) + } + // } + // timezones are also supported + // var curTimestamp time.Time + // var curDate time.Time + // var curTimezone string + // if err := db.QueryRowContext(ogCtx, `select current_date(), current_timestamp(), current_timezone()`).Scan(&curDate, &curTimestamp, &curTimezone); err != nil { + // panic(err) + // } else { + // // this will print now at timezone America/Sao_Paulo is: 2022-11-16 20:25:15.282 -0300 -03 + // fmt.Printf("current timestamp at timezone %s is: %s\n", curTimezone, curTimestamp) + // fmt.Printf("current date at timezone %s is: %s\n", curTimezone, curDate) + // } + +} diff --git a/internal/client/client.go b/internal/client/client.go index e76441b7..271d87db 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -3,6 +3,7 @@ package client import ( "compress/zlib" "context" + "encoding/hex" "encoding/json" "fmt" "net/http" @@ -271,3 +272,21 @@ func SprintGuid(bts []byte) string { logger.Warn().Msgf("GUID not valid: %x", bts) return fmt.Sprintf("%x", bts) } + +func DecodeGuid(str string) []byte { + if len(str) == 36 { + bts, err := hex.DecodeString(str[0:8] + str[9:13] + str[14:18] + str[19:23] + str[24:36]) + if err != nil { + logger.Warn().Msgf("GUID not valid: %s", str) + return []byte{} + } + return bts + } + logger.Warn().Msgf("GUID not valid: %s", str) + bts, err := hex.DecodeString(str) + if err != nil { + logger.Warn().Msgf("GUID not valid: %s", str) + return []byte{} + } + return bts +} diff --git a/internal/config/config.go b/internal/config/config.go index 19c47a55..15d81027 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,7 +20,7 @@ type Config struct { TLSConfig *tls.Config // nil disables TLS Authenticator string //TODO for oauth - RunAsync bool // TODO + RunAsync bool PollInterval time.Duration ConnectTimeout time.Duration // max time to open session ClientTimeout time.Duration // max time the http request can last @@ -134,7 +134,7 @@ func WithDefaults() *Config { UserConfig: UserConfig{}.WithDefaults(), TLSConfig: &tls.Config{MinVersion: tls.VersionTLS12}, Authenticator: "", - RunAsync: true, + RunAsync: false, PollInterval: 1 * time.Second, ConnectTimeout: 60 * time.Second, ClientTimeout: 900 * time.Second, diff --git a/rows.go b/rows.go index 5c9da4ca..a8d2e671 100644 --- a/rows.go +++ b/rows.go @@ -51,6 +51,10 @@ func (r *rows) Columns() []string { return []string{} } + if r.opHandle == nil { + return []string{} + } + resultMetadata, err := r.getResultMetadata() if err != nil { return []string{} @@ -76,15 +80,17 @@ func (r *rows) Close() error { if err != nil { return err } + if r.opHandle != nil { - req := cli_service.TCloseOperationReq{ - OperationHandle: r.opHandle, - } - ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), r.connId), r.correlationId) + req := cli_service.TCloseOperationReq{ + OperationHandle: r.opHandle, + } + ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), r.connId), r.correlationId) - _, err1 := r.client.CloseOperation(ctx, &req) - if err1 != nil { - return err1 + _, err1 := r.client.CloseOperation(ctx, &req) + if err1 != nil { + return err1 + } } return nil } @@ -103,6 +109,9 @@ func (r *rows) Next(dest []driver.Value) error { if err != nil { return err } + if r.opHandle == nil { + return io.EOF + } // if the next row is not in the current result page // fetch the containing page @@ -334,6 +343,9 @@ func (r *rows) getResultMetadata() (*cli_service.TGetResultSetMetadataResp, erro if err != nil { return nil, err } + if r.opHandle == nil { + return nil, errors.New("metadata not available") + } req := cli_service.TGetResultSetMetadataReq{ OperationHandle: r.opHandle, diff --git a/statement.go b/statement.go index 940649a2..7ade39a9 100644 --- a/statement.go +++ b/statement.go @@ -7,7 +7,7 @@ import ( ) type stmt struct { - conn *conn + conn *Conn query string }