Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3047,6 +3047,150 @@ func (suite *BulkIngestTests) TestBulkIngestWithStream() {
suite.Equal(int64(5), totalRows)
}

func (suite *BulkIngestTests) TestBulkIngestBindStreamBeforeOptions() {
stmt, err := suite.cnxn.NewStatement()
suite.Require().NoError(err)
defer validation.CheckedClose(suite.T(), stmt)

schema := arrow.NewSchema([]arrow.Field{
{Name: "batch_id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
}, nil)

bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
defer bldr.Release()

bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil)
rec1 := bldr.NewRecordBatch()
bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{2, 3}, nil)
rec2 := bldr.NewRecordBatch()
defer rec1.Release()
defer rec2.Release()

rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec1, rec2})
suite.Require().NoError(err)
defer rdr.Release()

suite.Require().NoError(stmt.BindStream(context.Background(), rdr))

suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bind_first"))
suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreate))

nRows, err := stmt.ExecuteUpdate(context.Background())
suite.Require().NoError(err)
suite.Equal(int64(3), nRows)

requests := suite.server.GetIngestRequests()
suite.Require().Len(requests, 1)
suite.Equal("bind_first", requests[0].GetTable())
}

func (suite *BulkIngestTests) TestBulkIngestBindBeforeOptions() {
stmt, err := suite.cnxn.NewStatement()
suite.Require().NoError(err)
defer validation.CheckedClose(suite.T(), stmt)

schema := arrow.NewSchema([]arrow.Field{
{Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
}, nil)

bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
defer bldr.Release()

bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{10, 20}, nil)
rec := bldr.NewRecordBatch()
defer rec.Release()

suite.Require().NoError(stmt.Bind(context.Background(), rec))

suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bind_batch_first"))
suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreate))

nRows, err := stmt.ExecuteUpdate(context.Background())
suite.Require().NoError(err)
suite.Equal(int64(2), nRows)

requests := suite.server.GetIngestRequests()
suite.Require().Len(requests, 1)
suite.Equal("bind_batch_first", requests[0].GetTable())
}

func (suite *BulkIngestTests) TestBulkIngestBindStreamMissingTarget() {
stmt, err := suite.cnxn.NewStatement()
suite.Require().NoError(err)
defer validation.CheckedClose(suite.T(), stmt)

schema := arrow.NewSchema([]arrow.Field{
{Name: "batch_id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
}, nil)

bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
defer bldr.Release()

bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil)
rec := bldr.NewRecordBatch()
defer rec.Release()

rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec})
suite.Require().NoError(err)
defer rdr.Release()

suite.Require().NoError(stmt.BindStream(context.Background(), rdr))

_, err = stmt.ExecuteUpdate(context.Background())
suite.Require().Error(err)
suite.Contains(err.Error(), "must set IngestTargetTable before bulk ingestion")
}

func (suite *BulkIngestTests) TestBulkIngestBindMissingTarget() {
stmt, err := suite.cnxn.NewStatement()
suite.Require().NoError(err)
defer validation.CheckedClose(suite.T(), stmt)

schema := arrow.NewSchema([]arrow.Field{
{Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
}, nil)

bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
defer bldr.Release()

bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil)
rec := bldr.NewRecordBatch()
defer rec.Release()

suite.Require().NoError(stmt.Bind(context.Background(), rec))

_, err = stmt.ExecuteUpdate(context.Background())
suite.Require().Error(err)
suite.Contains(err.Error(), "must set IngestTargetTable before bulk ingestion")
}

func (suite *BulkIngestTests) TestBulkIngestBindStreamMissingTargetExecuteQuery() {
stmt, err := suite.cnxn.NewStatement()
suite.Require().NoError(err)
defer validation.CheckedClose(suite.T(), stmt)

schema := arrow.NewSchema([]arrow.Field{
{Name: "batch_id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
}, nil)

bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
defer bldr.Release()

bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil)
rec := bldr.NewRecordBatch()
defer rec.Release()

rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec})
suite.Require().NoError(err)
defer rdr.Release()

suite.Require().NoError(stmt.BindStream(context.Background(), rdr))

_, _, err = stmt.ExecuteQuery(context.Background())
suite.Require().Error(err)
suite.Contains(err.Error(), "must set IngestTargetTable before bulk ingestion")
}

func (suite *BulkIngestTests) TestBulkIngestWithoutBind() {
stmt, err := suite.cnxn.NewStatement()
suite.Require().NoError(err)
Expand Down
88 changes: 48 additions & 40 deletions go/adbc/driver/flightsql/flightsql_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,14 @@ func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, n
return nil, -1, err
}

// Reject staged binds if no ingest target was provided
if s.targetTable == "" && s.prepared == nil && (s.bound != nil || s.streamBind != nil) {
return nil, -1, adbc.Error{
Msg: "[Flight SQL Statement] must set IngestTargetTable before bulk ingestion",
Code: adbc.StatusInvalidState,
}
}

// Handle bulk ingest
if s.targetTable != "" {
nrec, err = s.executeIngest(ctx)
Expand Down Expand Up @@ -535,6 +543,14 @@ func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) {
return -1, err
}

// Reject staged binds if no ingest target was provided
if s.targetTable == "" && s.prepared == nil && (s.bound != nil || s.streamBind != nil) {
return -1, adbc.Error{
Msg: "[Flight SQL Statement] must set IngestTargetTable before bulk ingestion",
Code: adbc.StatusInvalidState,
}
}

// Handle bulk ingest
if s.targetTable != "" {
return s.executeIngest(ctx)
Expand Down Expand Up @@ -600,66 +616,58 @@ func (s *statement) SetSubstraitPlan(plan []byte) error {
// but it may not do this until the statement is closed or another
// record is bound.
func (s *statement) Bind(_ context.Context, values arrow.RecordBatch) error {
// For bulk ingest, bind to the statement
if s.targetTable != "" {
if s.streamBind != nil {
s.streamBind.Release()
s.streamBind = nil
}
if s.bound != nil {
s.bound.Release()
}
s.bound = values
if s.bound != nil {
s.bound.Retain()
}
if s.targetTable != "" || s.prepared == nil {
s.setBound(values)
return nil
}

if s.prepared == nil {
return adbc.Error{
Msg: "[Flight SQL Statement] must call Prepare or set IngestTargetTable before calling Bind",
Code: adbc.StatusInvalidState}
}

// calls retain
s.prepared.SetParameters(values)
return nil
}

func (s *statement) setBound(values arrow.RecordBatch) {
if s.streamBind != nil {
s.streamBind.Release()
s.streamBind = nil
}
if s.bound != nil {
s.bound.Release()
}
s.bound = values
if s.bound != nil {
s.bound.Retain()
}
}

// BindStream uses a record batch stream to bind parameters for this
// query. This can be used for bulk inserts or prepared statements.
//
// The driver will call Release on the record reader, but may not do this
// until Close is called.
func (s *statement) BindStream(_ context.Context, stream array.RecordReader) error {
// For bulk ingest, bind to the statement
if s.targetTable != "" {
if s.bound != nil {
s.bound.Release()
s.bound = nil
}
if s.streamBind != nil {
s.streamBind.Release()
}
s.streamBind = stream
if s.streamBind != nil {
s.streamBind.Retain()
}
if s.targetTable != "" || s.prepared == nil {
s.setStreamBound(stream)
return nil
}

if s.prepared == nil {
return adbc.Error{
Msg: "[Flight SQL Statement] must call Prepare or set IngestTargetTable before calling Bind",
Code: adbc.StatusInvalidState}
}

// calls retain
s.prepared.SetRecordReader(stream)
return nil
}

func (s *statement) setStreamBound(stream array.RecordReader) {
if s.bound != nil {
s.bound.Release()
s.bound = nil
}
if s.streamBind != nil {
s.streamBind.Release()
}
s.streamBind = stream
if s.streamBind != nil {
s.streamBind.Retain()
}
}

// GetParameterSchema returns an Arrow schema representation of
// the expected parameters to be bound.
//
Expand Down
24 changes: 12 additions & 12 deletions go/adbc/ext.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,19 @@ func IngestStream(ctx context.Context, cnxn Connection, reader array.RecordReade
err = errors.Join(err, stmt.Close())
}()

// Bind the record batch stream
if err = stmt.BindStream(ctx, reader); err != nil {
return -1, fmt.Errorf("error during ingestion: BindStream: %w", err)
}

// Set required options
// Set required options before binding
if err = stmt.SetOption(OptionKeyIngestTargetTable, targetTable); err != nil {
return -1, fmt.Errorf("error during ingestion: SetOption(target_table=%s): %w", targetTable, err)
}
if err = stmt.SetOption(OptionKeyIngestMode, ingestMode); err != nil {
return -1, fmt.Errorf("error during ingestion: SetOption(mode=%s): %w", ingestMode, err)
}

// Bind the record batch stream
if err = stmt.BindStream(ctx, reader); err != nil {
return -1, fmt.Errorf("error during ingestion: BindStream: %w", err)
}

// Set other options if provided
if opt.Catalog != "" {
if err = stmt.SetOption(OptionValueIngestTargetCatalog, opt.Catalog); err != nil {
Expand Down Expand Up @@ -167,19 +167,19 @@ func IngestStreamContext(ctx context.Context, cnxn ConnectionWithContext, reader
err = errors.Join(err, stmt.Close(ctx))
}()

// Bind the record batch stream
if err = stmt.BindStream(ctx, reader); err != nil {
return -1, fmt.Errorf("error during ingestion: BindStream: %w", err)
}

// Set required options
// Set required options before binding (some drivers require target first)
if err = stmt.SetOption(ctx, OptionKeyIngestTargetTable, targetTable); err != nil {
return -1, fmt.Errorf("error during ingestion: SetOption(target_table=%s): %w", targetTable, err)
}
if err = stmt.SetOption(ctx, OptionKeyIngestMode, ingestMode); err != nil {
return -1, fmt.Errorf("error during ingestion: SetOption(mode=%s): %w", ingestMode, err)
}

// Bind the record batch stream
if err = stmt.BindStream(ctx, reader); err != nil {
return -1, fmt.Errorf("error during ingestion: BindStream: %w", err)
}

// Set other options if provided
if opt.Catalog != "" {
if err = stmt.SetOption(ctx, OptionValueIngestTargetCatalog, opt.Catalog); err != nil {
Expand Down