diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go index 6567c0008b..c10b38c880 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go @@ -3047,6 +3047,150 @@ func (suite *BulkIngestTests) TestBulkIngestWithStream() { suite.Equal(int64(5), totalRows) } +func (suite *BulkIngestTests) TestBulkIngestBindStreamBeforeOptions() { + stmt, err := suite.cnxn.NewStatement() + suite.Require().NoError(err) + defer validation.CheckedClose(suite.T(), stmt) + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "batch_id", Type: arrow.PrimitiveTypes.Int32, Nullable: false}, + }, nil) + + bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema) + defer bldr.Release() + + bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil) + rec1 := bldr.NewRecordBatch() + bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{2, 3}, nil) + rec2 := bldr.NewRecordBatch() + defer rec1.Release() + defer rec2.Release() + + rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec1, rec2}) + suite.Require().NoError(err) + defer rdr.Release() + + suite.Require().NoError(stmt.BindStream(context.Background(), rdr)) + + suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bind_first")) + suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreate)) + + nRows, err := stmt.ExecuteUpdate(context.Background()) + suite.Require().NoError(err) + suite.Equal(int64(3), nRows) + + requests := suite.server.GetIngestRequests() + suite.Require().Len(requests, 1) + suite.Equal("bind_first", requests[0].GetTable()) +} + +func (suite *BulkIngestTests) TestBulkIngestBindBeforeOptions() { + stmt, err := suite.cnxn.NewStatement() + suite.Require().NoError(err) + defer validation.CheckedClose(suite.T(), stmt) + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "id", Type: 
arrow.PrimitiveTypes.Int32, Nullable: false}, + }, nil) + + bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema) + defer bldr.Release() + + bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{10, 20}, nil) + rec := bldr.NewRecordBatch() + defer rec.Release() + + suite.Require().NoError(stmt.Bind(context.Background(), rec)) + + suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bind_batch_first")) + suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreate)) + + nRows, err := stmt.ExecuteUpdate(context.Background()) + suite.Require().NoError(err) + suite.Equal(int64(2), nRows) + + requests := suite.server.GetIngestRequests() + suite.Require().Len(requests, 1) + suite.Equal("bind_batch_first", requests[0].GetTable()) +} + +func (suite *BulkIngestTests) TestBulkIngestBindStreamMissingTarget() { + stmt, err := suite.cnxn.NewStatement() + suite.Require().NoError(err) + defer validation.CheckedClose(suite.T(), stmt) + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "batch_id", Type: arrow.PrimitiveTypes.Int32, Nullable: false}, + }, nil) + + bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema) + defer bldr.Release() + + bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil) + rec := bldr.NewRecordBatch() + defer rec.Release() + + rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec}) + suite.Require().NoError(err) + defer rdr.Release() + + suite.Require().NoError(stmt.BindStream(context.Background(), rdr)) + + _, err = stmt.ExecuteUpdate(context.Background()) + suite.Require().Error(err) + suite.Contains(err.Error(), "must set IngestTargetTable before bulk ingestion") +} + +func (suite *BulkIngestTests) TestBulkIngestBindMissingTarget() { + stmt, err := suite.cnxn.NewStatement() + suite.Require().NoError(err) + defer validation.CheckedClose(suite.T(), stmt) + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "id", Type: arrow.PrimitiveTypes.Int32, 
Nullable: false}, + }, nil) + + bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema) + defer bldr.Release() + + bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil) + rec := bldr.NewRecordBatch() + defer rec.Release() + + suite.Require().NoError(stmt.Bind(context.Background(), rec)) + + _, err = stmt.ExecuteUpdate(context.Background()) + suite.Require().Error(err) + suite.Contains(err.Error(), "must set IngestTargetTable before bulk ingestion") +} + +func (suite *BulkIngestTests) TestBulkIngestBindStreamMissingTargetExecuteQuery() { + stmt, err := suite.cnxn.NewStatement() + suite.Require().NoError(err) + defer validation.CheckedClose(suite.T(), stmt) + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "batch_id", Type: arrow.PrimitiveTypes.Int32, Nullable: false}, + }, nil) + + bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema) + defer bldr.Release() + + bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil) + rec := bldr.NewRecordBatch() + defer rec.Release() + + rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec}) + suite.Require().NoError(err) + defer rdr.Release() + + suite.Require().NoError(stmt.BindStream(context.Background(), rdr)) + + _, _, err = stmt.ExecuteQuery(context.Background()) + suite.Require().Error(err) + suite.Contains(err.Error(), "must set IngestTargetTable before bulk ingestion") +} + func (suite *BulkIngestTests) TestBulkIngestWithoutBind() { stmt, err := suite.cnxn.NewStatement() suite.Require().NoError(err) diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go index eb034ec66f..08d92c44b0 100644 --- a/go/adbc/driver/flightsql/flightsql_statement.go +++ b/go/adbc/driver/flightsql/flightsql_statement.go @@ -503,6 +503,14 @@ func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, n return nil, -1, err } + // Reject staged binds if no ingest target was provided + if s.targetTable == "" && s.prepared 
== nil && (s.bound != nil || s.streamBind != nil) { + return nil, -1, adbc.Error{ + Msg: "[Flight SQL Statement] must set IngestTargetTable before bulk ingestion", + Code: adbc.StatusInvalidState, + } + } + // Handle bulk ingest if s.targetTable != "" { nrec, err = s.executeIngest(ctx) @@ -535,6 +543,14 @@ func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) { return -1, err } + // Reject staged binds if no ingest target was provided + if s.targetTable == "" && s.prepared == nil && (s.bound != nil || s.streamBind != nil) { + return -1, adbc.Error{ + Msg: "[Flight SQL Statement] must set IngestTargetTable before bulk ingestion", + Code: adbc.StatusInvalidState, + } + } + // Handle bulk ingest if s.targetTable != "" { return s.executeIngest(ctx) @@ -600,66 +616,58 @@ func (s *statement) SetSubstraitPlan(plan []byte) error { // but it may not do this until the statement is closed or another // record is bound. func (s *statement) Bind(_ context.Context, values arrow.RecordBatch) error { - // For bulk ingest, bind to the statement - if s.targetTable != "" { - if s.streamBind != nil { - s.streamBind.Release() - s.streamBind = nil - } - if s.bound != nil { - s.bound.Release() - } - s.bound = values - if s.bound != nil { - s.bound.Retain() - } + if s.targetTable != "" || s.prepared == nil { + s.setBound(values) return nil } - if s.prepared == nil { - return adbc.Error{ - Msg: "[Flight SQL Statement] must call Prepare or set IngestTargetTable before calling Bind", - Code: adbc.StatusInvalidState} - } - - // calls retain s.prepared.SetParameters(values) return nil } +func (s *statement) setBound(values arrow.RecordBatch) { + if s.streamBind != nil { + s.streamBind.Release() + s.streamBind = nil + } + if s.bound != nil { + s.bound.Release() + } + s.bound = values + if s.bound != nil { + s.bound.Retain() + } +} + // BindStream uses a record batch stream to bind parameters for this // query. This can be used for bulk inserts or prepared statements. 
// // The driver will call Release on the record reader, but may not do this // until Close is called. func (s *statement) BindStream(_ context.Context, stream array.RecordReader) error { - // For bulk ingest, bind to the statement - if s.targetTable != "" { - if s.bound != nil { - s.bound.Release() - s.bound = nil - } - if s.streamBind != nil { - s.streamBind.Release() - } - s.streamBind = stream - if s.streamBind != nil { - s.streamBind.Retain() - } + if s.targetTable != "" || s.prepared == nil { + s.setStreamBound(stream) return nil } - if s.prepared == nil { - return adbc.Error{ - Msg: "[Flight SQL Statement] must call Prepare or set IngestTargetTable before calling Bind", - Code: adbc.StatusInvalidState} - } - - // calls retain s.prepared.SetRecordReader(stream) return nil } +func (s *statement) setStreamBound(stream array.RecordReader) { + if s.bound != nil { + s.bound.Release() + s.bound = nil + } + if s.streamBind != nil { + s.streamBind.Release() + } + s.streamBind = stream + if s.streamBind != nil { + s.streamBind.Retain() + } +} + // GetParameterSchema returns an Arrow schema representation of // the expected parameters to be bound. 
// diff --git a/go/adbc/ext.go b/go/adbc/ext.go index dca0097097..4f0df04620 100644 --- a/go/adbc/ext.go +++ b/go/adbc/ext.go @@ -100,12 +100,7 @@ func IngestStream(ctx context.Context, cnxn Connection, reader array.RecordReade err = errors.Join(err, stmt.Close()) }() - // Bind the record batch stream - if err = stmt.BindStream(ctx, reader); err != nil { - return -1, fmt.Errorf("error during ingestion: BindStream: %w", err) - } - - // Set required options + // Set required options before binding (some drivers require target first) if err = stmt.SetOption(OptionKeyIngestTargetTable, targetTable); err != nil { return -1, fmt.Errorf("error during ingestion: SetOption(target_table=%s): %w", targetTable, err) } @@ -113,6 +108,11 @@ func IngestStream(ctx context.Context, cnxn Connection, reader array.RecordReade return -1, fmt.Errorf("error during ingestion: SetOption(mode=%s): %w", ingestMode, err) } + // Bind the record batch stream + if err = stmt.BindStream(ctx, reader); err != nil { + return -1, fmt.Errorf("error during ingestion: BindStream: %w", err) + } + // Set other options if provided if opt.Catalog != "" { if err = stmt.SetOption(OptionValueIngestTargetCatalog, opt.Catalog); err != nil { @@ -167,12 +167,7 @@ func IngestStreamContext(ctx context.Context, cnxn ConnectionWithContext, reader err = errors.Join(err, stmt.Close(ctx)) }() - // Bind the record batch stream - if err = stmt.BindStream(ctx, reader); err != nil { - return -1, fmt.Errorf("error during ingestion: BindStream: %w", err) - } - - // Set required options + // Set required options before binding (some drivers require target first) if err = stmt.SetOption(ctx, OptionKeyIngestTargetTable, targetTable); err != nil { return -1, fmt.Errorf("error during ingestion: SetOption(target_table=%s): %w", targetTable, err) } @@ -180,6 +175,11 @@ func IngestStreamContext(ctx context.Context, cnxn ConnectionWithContext, reader return -1, fmt.Errorf("error during ingestion: SetOption(mode=%s): %w", ingestMode, err) } + // Bind the record batch
stream + if err = stmt.BindStream(ctx, reader); err != nil { + return -1, fmt.Errorf("error during ingestion: BindStream: %w", err) + } + // Set other options if provided if opt.Catalog != "" { if err = stmt.SetOption(ctx, OptionValueIngestTargetCatalog, opt.Catalog); err != nil {