From f70e5a5686a5648d14aa9e57279b5b0349e97542 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Fri, 23 Aug 2024 13:27:27 +0300 Subject: [PATCH] [PECO-1752] Refactor CloudFetch downloader (#234) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [PECO-1752] This PR is an attempt to fix a CloudFetch error "row number N is not contained in any arrow batch". See also databricks/databricks-sql-python#405 - basically, the same issue, the same root cause, similar solution. #### The problem In current implementation, all the links from a single TRowSet are added to a concurrent thread (goroutine) pool. The pool downloads them in a random order (all the tasks have the same priority and as a result - same chance to be executed first). To maintain the order of results, `startRowOffset`/`rowCount` fields from each CloudFetch link are used: library keeps track of the current row number, and use it to pick the right CloudFetch link (looking for the file where the current row is within [startRowOffset; startRowOffset + rowCount]). This solution has several caveats. First of all, library allows to fetch data only from beginning to the end. With a concurrent thread pool, you never know which file will be downloaded first. In the worst case, while the user is waiting for the very first file, the library may download all the other ones and keep them in memory because the user may need them in future. This increases the latency (on average it will be okay, but we have no control over it), and also memory consumption. Another problem with this approach is that if any of the files cannot be downloaded - there is no need to download the remaining files, the user won’t be able to process them anyway. But because files are downloaded in arbitrary order - nobody knows how many files will be downloaded before the user reaches the failed one. Also, seems that error handling wasn't done quite right, but that part of code was a bit unclear to me. Anyway, with this fix all the errors are properly handled and propagated to user when needed. #### The solution This PR changes CloudFetch downloader to use a queue. Downloader keeps a list of pending links (not scheduled), and current tasks. Number of tasks is limited, so new files are scheduled only when previous task is completed and extracted from queue. As user requests next files, downloader will pick the first task from the queue, and schedule the new one to run in background - to keep the queue full. Then, downloader will wait for the task it picked from the queue, and then return it to user. Tasks are still running in in parallel in background. Also, each task itself is reponsible for handling errors (e.g. retry failed downloads), so when task completes - it is either eventually successfull, or failed after all possible retries. With this approach, the proper order of files is automatically assured. All errors are either handled in downloader or propagated to user. If some file cannot be downloaded due to error - library will not download the remaining ones (like it did previously). Because new files are downloaded only when user consumes previous ones - library will not keep the whole dataset in memory. [PECO-1752]: https://databricks.atlassian.net/browse/PECO-1752?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ --------- Signed-off-by: Levko Kravets --- internal/fetcher/fetcher.go | 164 ------- internal/fetcher/fetcher_test.go | 123 ------ .../rows/arrowbased/arrowRecordIterator.go | 23 +- .../arrowbased/arrowRecordIterator_test.go | 36 +- internal/rows/arrowbased/arrowRows.go | 12 +- internal/rows/arrowbased/arrowRows_test.go | 177 ++++---- internal/rows/arrowbased/batchloader.go | 414 +++++++++--------- internal/rows/arrowbased/batchloader_test.go | 287 ++++++++---- internal/rows/arrowbased/queue.go | 51 +++ 9 files changed, 573 insertions(+), 714 deletions(-) delete mode 100644 internal/fetcher/fetcher.go delete mode 100644 internal/fetcher/fetcher_test.go create mode 100644 internal/rows/arrowbased/queue.go diff --git a/internal/fetcher/fetcher.go b/internal/fetcher/fetcher.go deleted file mode 100644 index 8430ff0d..00000000 --- a/internal/fetcher/fetcher.go +++ /dev/null @@ -1,164 +0,0 @@ -package fetcher - -import ( - "context" - "sync" - - "github.com/databricks/databricks-sql-go/driverctx" - dbsqllog "github.com/databricks/databricks-sql-go/logger" -) - -type FetchableItems[OutputType any] interface { - Fetch(ctx context.Context) (OutputType, error) -} - -type Fetcher[OutputType any] interface { - Err() error - Start() (<-chan OutputType, context.CancelFunc, error) -} - -type concurrentFetcher[I FetchableItems[O], O any] struct { - cancelChan chan bool - inputChan <-chan FetchableItems[O] - outChan chan O - err error - nWorkers int - mu sync.Mutex - start sync.Once - ctx context.Context - cancelFunc context.CancelFunc - *dbsqllog.DBSQLLogger -} - -func (rf *concurrentFetcher[I, O]) Err() error { - rf.mu.Lock() - defer rf.mu.Unlock() - return rf.err -} - -func (f *concurrentFetcher[I, O]) Start() (<-chan O, context.CancelFunc, error) { - f.start.Do(func() { - // wait group for the worker routines - var wg sync.WaitGroup - - for i := 0; i < f.nWorkers; i++ { - - // increment wait group - wg.Add(1) - - f.logger().Trace().Msgf("concurrent fetcher starting worker %d", i) - go func(x int) { - // when work function remove one from the wait group - defer wg.Done() - // do the actual work - work(f, x) - f.logger().Trace().Msgf("concurrent fetcher worker %d done", x) - }(i) - - } - - // We want to close the output channel when all - // the workers are finished. This way the client won't - // be stuck waiting on the output channel. - go func() { - wg.Wait() - f.logger().Trace().Msg("concurrent fetcher closing output channel") - close(f.outChan) - }() - - // We return a cancel function so that the client can - // cancel fetching. - var cancelOnce sync.Once = sync.Once{} - f.cancelFunc = func() { - f.logger().Trace().Msg("concurrent fetcher cancel func") - cancelOnce.Do(func() { - f.logger().Trace().Msg("concurrent fetcher closing cancel channel") - close(f.cancelChan) - }) - } - }) - - return f.outChan, f.cancelFunc, nil -} - -func (f *concurrentFetcher[I, O]) setErr(err error) { - f.mu.Lock() - if f.err == nil { - f.err = err - } - f.mu.Unlock() -} - -func (f *concurrentFetcher[I, O]) logger() *dbsqllog.DBSQLLogger { - if f.DBSQLLogger == nil { - - f.DBSQLLogger = dbsqllog.WithContext(driverctx.ConnIdFromContext(f.ctx), driverctx.CorrelationIdFromContext(f.ctx), "") - - } - return f.DBSQLLogger -} - -func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWorkers, maxItemsInMemory int, inputChan <-chan FetchableItems[O]) (Fetcher[O], error) { - if nWorkers < 1 { - nWorkers = 1 - } - if maxItemsInMemory < 1 { - maxItemsInMemory = 1 - } - - // channel for loaded items - // TODO: pass buffer size - outputChannel := make(chan O, maxItemsInMemory) - - // channel to signal a cancel - stopChannel := make(chan bool) - - if ctx == nil { - ctx = context.Background() - } - - fetcher := &concurrentFetcher[I, O]{ - inputChan: inputChan, - outChan: outputChannel, - cancelChan: stopChannel, - ctx: ctx, - nWorkers: nWorkers, - } - - return fetcher, nil -} - -func work[I FetchableItems[O], O any](f *concurrentFetcher[I, O], workerIndex int) { - - for { - select { - case <-f.cancelChan: - f.logger().Debug().Msgf("concurrent fetcher worker %d received cancel signal", workerIndex) - return - - case <-f.ctx.Done(): - f.logger().Debug().Msgf("concurrent fetcher worker %d context done", workerIndex) - return - - case input, ok := <-f.inputChan: - if ok { - f.logger().Trace().Msgf("concurrent fetcher worker %d loading item", workerIndex) - result, err := input.Fetch(f.ctx) - if err != nil { - f.logger().Trace().Msgf("concurrent fetcher worker %d received error", workerIndex) - f.setErr(err) - f.cancelFunc() - return - } else { - f.logger().Trace().Msgf("concurrent fetcher worker %d item loaded", workerIndex) - f.outChan <- result - } - } else { - f.logger().Trace().Msgf("concurrent fetcher ending %d", workerIndex) - return - } - - } - } - -} diff --git a/internal/fetcher/fetcher_test.go b/internal/fetcher/fetcher_test.go deleted file mode 100644 index dbe6ced0..00000000 --- a/internal/fetcher/fetcher_test.go +++ /dev/null @@ -1,123 +0,0 @@ -package fetcher - -import ( - "context" - "math" - "testing" - "time" - - "github.com/pkg/errors" -) - -// Create a mock struct for FetchableItems -type mockFetchableItem struct { - item int - wait time.Duration -} - -type mockOutput struct { - item int -} - -// Implement the Fetch method -func (m *mockFetchableItem) Fetch(ctx context.Context) ([]*mockOutput, error) { - time.Sleep(m.wait) - outputs := make([]*mockOutput, 5) - for i := range outputs { - sampleOutput := mockOutput{item: m.item} - outputs[i] = &sampleOutput - } - return outputs, nil -} - -var _ FetchableItems[[]*mockOutput] = (*mockFetchableItem)(nil) - -func TestConcurrentFetcher(t *testing.T) { - t.Run("Comprehensively tests the concurrent fetcher", func(t *testing.T) { - ctx := context.Background() - - inputChan := make(chan FetchableItems[[]*mockOutput], 10) - for i := 0; i < 10; i++ { - item := mockFetchableItem{item: i, wait: 1 * time.Second} - inputChan <- &item - } - close(inputChan) - - // Create a fetcher - fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 3, 3, inputChan) - if err != nil { - t.Fatalf("Error creating fetcher: %v", err) - } - - start := time.Now() - outChan, _, err := fetcher.Start() - if err != nil { - t.Fatalf("Error starting fetcher: %v", err) - } - - var results []*mockOutput - for result := range outChan { - results = append(results, result...) - } - - // Check if the fetcher returned the expected results - expectedLen := 50 - if len(results) != expectedLen { - t.Errorf("Expected %d results, got %d", expectedLen, len(results)) - } - - // Check if the fetcher returned an error - if fetcher.Err() != nil { - t.Errorf("Fetcher returned an error: %v", fetcher.Err()) - } - - // Check if the fetcher took around the estimated amount of time - timeElapsed := time.Since(start) - rounds := int(math.Ceil(float64(10) / 3)) - expectedTime := time.Duration(rounds) * time.Second - buffer := 100 * time.Millisecond - if timeElapsed-expectedTime > buffer { - t.Errorf("Expected fetcher to take around %d ms, took %d ms", int64(expectedTime/time.Millisecond), int64(timeElapsed/time.Millisecond)) - } - }) - - t.Run("Cancel the concurrent fetcher", func(t *testing.T) { - // Create a context with a timeout - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - - // Create an input channel - inputChan := make(chan FetchableItems[[]*mockOutput], 3) - for i := 0; i < 3; i++ { - item := mockFetchableItem{item: i, wait: 1 * time.Second} - inputChan <- &item - } - close(inputChan) - - // Create a new fetcher - fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, 2, inputChan) - if err != nil { - t.Fatalf("Error creating fetcher: %v", err) - } - - // Start the fetcher - outChan, cancelFunc, err := fetcher.Start() - if err != nil { - t.Fatal(err) - } - - // Ensure that the fetcher is cancelled successfully - go func() { - cancelFunc() - }() - - for range outChan { - // Just drain the channel - } - - // Check if an error occurred - if err := fetcher.Err(); err != nil && !errors.Is(err, context.DeadlineExceeded) { - t.Fatalf("unexpected error: %v", err) - } - }) -} diff --git a/internal/rows/arrowbased/arrowRecordIterator.go b/internal/rows/arrowbased/arrowRecordIterator.go index 898d0a45..583cbd04 100644 --- a/internal/rows/arrowbased/arrowRecordIterator.go +++ b/internal/rows/arrowbased/arrowRecordIterator.go @@ -163,29 +163,10 @@ func (ri *arrowRecordIterator) getBatchIterator() error { // Create a new batch iterator from a page of the result set func (ri *arrowRecordIterator) newBatchIterator(fr *cli_service.TFetchResultsResp) (BatchIterator, error) { - bl, err := ri.newBatchLoader(fr) - if err != nil { - return nil, err - } - - bi, err := NewBatchIterator(bl) - - return bi, err -} - -// Create a new batch loader from a page of the result set -func (ri *arrowRecordIterator) newBatchLoader(fr *cli_service.TFetchResultsResp) (BatchLoader, error) { rowSet := fr.Results - var bl BatchLoader - var err error if len(rowSet.ResultLinks) > 0 { - bl, err = NewCloudBatchLoader(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg) + return NewCloudBatchIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg) } else { - bl, err = NewLocalBatchLoader(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg) + return NewLocalBatchIterator(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg) } - if err != nil { - return nil, err - } - - return bl, nil } diff --git a/internal/rows/arrowbased/arrowRecordIterator_test.go b/internal/rows/arrowbased/arrowRecordIterator_test.go index a3e4040c..a3b67687 100644 --- a/internal/rows/arrowbased/arrowRecordIterator_test.go +++ b/internal/rows/arrowbased/arrowRecordIterator_test.go @@ -19,6 +19,8 @@ import ( func TestArrowRecordIterator(t *testing.T) { t.Run("with direct results", func(t *testing.T) { + logger := dbsqllog.WithContext("connectionId", "correlationId", "") + executeStatementResp := cli_service.TExecuteStatementResp{} loadTestData2(t, "directResultsMultipleFetch/ExecuteStatement.json", &executeStatementResp) @@ -30,32 +32,37 @@ func TestArrowRecordIterator(t *testing.T) { var fetchesInfo []fetchResultsInfo - client := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2}) - logger := dbsqllog.WithContext("connectionId", "correlationId", "") + simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2}) rpi := rowscanner.NewResultPageIterator( rowscanner.NewDelimiter(0, 7311), 5000, nil, false, - client, + simpleClient, "connectionId", "correlationId", - logger) + logger, + ) - bl, err := NewLocalBatchLoader( + cfg := *config.WithDefaults() + + bi, err := NewLocalBatchIterator( context.Background(), executeStatementResp.DirectResults.ResultSet.Results.ArrowBatches, 0, executeStatementResp.DirectResults.ResultSetMetadata.ArrowSchema, - nil, + &cfg, ) - assert.Nil(t, err) - bi, err := NewBatchIterator(bl) assert.Nil(t, err) - cfg := *config.WithDefaults() - rs := NewArrowRecordIterator(context.Background(), rpi, bi, executeStatementResp.DirectResults.ResultSetMetadata.ArrowSchema, cfg) + rs := NewArrowRecordIterator( + context.Background(), + rpi, + bi, + executeStatementResp.DirectResults.ResultSetMetadata.ArrowSchema, + cfg, + ) defer rs.Close() hasNext := rs.HasNext() @@ -108,6 +115,7 @@ func TestArrowRecordIterator(t *testing.T) { }) t.Run("no direct results", func(t *testing.T) { + logger := dbsqllog.WithContext("connectionId", "correlationId", "") fetchResp1 := cli_service.TFetchResultsResp{} loadTestData2(t, "multipleFetch/FetchResults1.json", &fetchResp1) @@ -120,17 +128,17 @@ func TestArrowRecordIterator(t *testing.T) { var fetchesInfo []fetchResultsInfo - client := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2, fetchResp3}) - logger := dbsqllog.WithContext("connectionId", "correlationId", "") + simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2, fetchResp3}) rpi := rowscanner.NewResultPageIterator( rowscanner.NewDelimiter(0, 0), 5000, nil, false, - client, + simpleClient, "connectionId", "correlationId", - logger) + logger, + ) cfg := *config.WithDefaults() rs := NewArrowRecordIterator(context.Background(), rpi, nil, nil, cfg) diff --git a/internal/rows/arrowbased/arrowRows.go b/internal/rows/arrowbased/arrowRows.go index 89fe9b94..f6a60c58 100644 --- a/internal/rows/arrowbased/arrowRows.go +++ b/internal/rows/arrowbased/arrowRows.go @@ -112,27 +112,21 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp return nil, dbsqlerrint.NewDriverError(ctx, errArrowRowsToTimestampFn, err) } - var bl BatchLoader + var bi BatchIterator var err2 dbsqlerr.DBError if len(rowSet.ResultLinks) > 0 { logger.Debug().Msgf("Initialize CloudFetch loader, row set start offset: %d, file list:", rowSet.StartRowOffset) for _, resultLink := range rowSet.ResultLinks { logger.Debug().Msgf("- start row offset: %d, row count: %d", resultLink.StartRowOffset, resultLink.RowCount) } - bl, err2 = NewCloudBatchLoader(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, cfg) - logger.Debug().Msgf("Created CloudFetch concurrent loader, rows range [%d..%d]", bl.Start(), bl.End()) + bi, err2 = NewCloudBatchIterator(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, cfg) } else { - bl, err2 = NewLocalBatchLoader(context.Background(), rowSet.ArrowBatches, rowSet.StartRowOffset, schemaBytes, cfg) + bi, err2 = NewLocalBatchIterator(context.Background(), rowSet.ArrowBatches, rowSet.StartRowOffset, schemaBytes, cfg) } if err2 != nil { return nil, err2 } - bi, err := NewBatchIterator(bl) - if err != nil { - return nil, err2 - } - var location *time.Location = time.UTC if cfg != nil { if cfg.Location != nil { diff --git a/internal/rows/arrowbased/arrowRows_test.go b/internal/rows/arrowbased/arrowRows_test.go index c693ff56..c43674eb 100644 --- a/internal/rows/arrowbased/arrowRows_test.go +++ b/internal/rows/arrowbased/arrowRows_test.go @@ -469,7 +469,6 @@ func TestArrowRowScanner(t *testing.T) { }) t.Run("Create column value holders on first batch load", func(t *testing.T) { - rowSet := &cli_service.TRowSet{ ArrowBatches: []*cli_service.TSparkArrowBatch{ {RowCount: 5}, @@ -494,13 +493,13 @@ func TestArrowRowScanner(t *testing.T) { &sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(2, 3), Record: &fakeRecord{}}}} b2 := &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}} b3 := &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}} - fbl := &fakeBatchLoader{ - Delimiter: rowscanner.NewDelimiter(0, 15), + + fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{b1, b2, b3}, + index: -1, + callCount: 0, } - var e dbsqlerr.DBError - ars.batchIterator, e = NewBatchIterator(fbl) - assert.Nil(t, e) + ars.batchIterator = fbi var callCount int ars.valueContainerMaker = &fakeValueContainerMaker{fnMakeColumnValuesContainers: func(ars *arrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError { @@ -517,25 +516,25 @@ func TestArrowRowScanner(t *testing.T) { assert.Nil(t, err) assert.Equal(t, len(metadataResp.Schema.Columns), ars.rowValues.NColumns()) assert.Equal(t, 1, callCount) - assert.Equal(t, 1, fbl.callCount) + assert.Equal(t, 1, fbi.callCount) err = ars.loadBatchFor(1) assert.Nil(t, err) assert.Equal(t, len(metadataResp.Schema.Columns), ars.rowValues.NColumns()) assert.Equal(t, 1, callCount) - assert.Equal(t, 1, fbl.callCount) + assert.Equal(t, 1, fbi.callCount) err = ars.loadBatchFor(2) assert.Nil(t, err) assert.Equal(t, len(metadataResp.Schema.Columns), ars.rowValues.NColumns()) assert.Equal(t, 1, callCount) - assert.Equal(t, 1, fbl.callCount) + assert.Equal(t, 1, fbi.callCount) err = ars.loadBatchFor(5) assert.Nil(t, err) assert.Equal(t, len(metadataResp.Schema.Columns), ars.rowValues.NColumns()) assert.Equal(t, 1, callCount) - assert.Equal(t, 2, fbl.callCount) + assert.Equal(t, 2, fbi.callCount) }) @@ -557,25 +556,24 @@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - fbl := &fakeBatchLoader{ - Delimiter: rowscanner.NewDelimiter(0, 15), + fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, }, + index: -1, + callCount: 0, } - var e dbsqlerr.DBError - ars.batchIterator, e = NewBatchIterator(fbl) - assert.Nil(t, e) + ars.batchIterator = fbi err := ars.loadBatchFor(0) assert.Nil(t, err) - assert.Equal(t, 1, fbl.callCount) + assert.Equal(t, 1, fbi.callCount) err = ars.loadBatchFor(0) assert.Nil(t, err) - assert.Equal(t, 1, fbl.callCount) + assert.Equal(t, 1, fbi.callCount) }) t.Run("loadBatch index out of bounds", func(t *testing.T) { @@ -596,17 +594,16 @@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - fbl := &fakeBatchLoader{ - Delimiter: rowscanner.NewDelimiter(0, 15), + fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, }, + index: -1, + callCount: 0, } - var e dbsqlerr.DBError - ars.batchIterator, e = NewBatchIterator(fbl) - assert.Nil(t, e) + ars.batchIterator = fbi err := ars.loadBatchFor(-1) assert.NotNil(t, err) @@ -636,17 +633,16 @@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - fbl := &fakeBatchLoader{ - Delimiter: rowscanner.NewDelimiter(0, 15), + fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, }, + index: -1, + callCount: 0, } - var e dbsqlerr.DBError - ars.batchIterator, e = NewBatchIterator(fbl) - assert.Nil(t, e) + ars.batchIterator = fbi ars.valueContainerMaker = &fakeValueContainerMaker{ fnMakeColumnValuesContainers: func(ars *arrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError { @@ -657,7 +653,6 @@ func TestArrowRowScanner(t *testing.T) { err := ars.loadBatchFor(0) assert.NotNil(t, err) assert.ErrorContains(t, err, "error making containers") - }) t.Run("loadBatch record read failure", func(t *testing.T) { @@ -679,18 +674,17 @@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - fbl := &fakeBatchLoader{ - Delimiter: rowscanner.NewDelimiter(0, 15), + fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, }, - err: dbsqlerrint.NewDriverError(context.TODO(), "error reading record", nil), + index: -1, + callCount: 0, + err: dbsqlerrint.NewDriverError(context.TODO(), "error reading record", nil), } - var e dbsqlerr.DBError - ars.batchIterator, e = NewBatchIterator(fbl) - assert.Nil(t, e) + ars.batchIterator = fbi err := ars.loadBatchFor(0) assert.NotNil(t, err) @@ -716,40 +710,39 @@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - fbl := &fakeBatchLoader{ - Delimiter: rowscanner.NewDelimiter(0, 15), + fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, }, + index: -1, + callCount: 0, } - var e dbsqlerr.DBError - ars.batchIterator, e = NewBatchIterator(fbl) - assert.Nil(t, e) + ars.batchIterator = fbi for _, i := range []int64{0, 1, 2, 3, 4} { err := ars.loadBatchFor(i) assert.Nil(t, err) - assert.NotNil(t, fbl.lastReadBatch) - assert.Equal(t, 1, fbl.callCount) - assert.Equal(t, int64(0), fbl.lastReadBatch.Start()) + assert.NotNil(t, fbi.lastReadBatch) + assert.Equal(t, 1, fbi.callCount) + assert.Equal(t, int64(0), fbi.lastReadBatch.Start()) } for _, i := range []int64{5, 6, 7} { err := ars.loadBatchFor(i) assert.Nil(t, err) - assert.NotNil(t, fbl.lastReadBatch) - assert.Equal(t, 2, fbl.callCount) - assert.Equal(t, int64(5), fbl.lastReadBatch.Start()) + assert.NotNil(t, fbi.lastReadBatch) + assert.Equal(t, 2, fbi.callCount) + assert.Equal(t, int64(5), fbi.lastReadBatch.Start()) } for _, i := range []int64{8, 9, 10, 11, 12, 13, 14} { err := ars.loadBatchFor(i) assert.Nil(t, err) - assert.NotNil(t, fbl.lastReadBatch) - assert.Equal(t, 3, fbl.callCount) - assert.Equal(t, int64(8), fbl.lastReadBatch.Start()) + assert.NotNil(t, fbi.lastReadBatch) + assert.Equal(t, 3, fbi.callCount) + assert.Equal(t, int64(8), fbi.lastReadBatch.Start()) } err := ars.loadBatchFor(-1) @@ -869,17 +862,16 @@ func TestArrowRowScanner(t *testing.T) { ars.UseArrowNativeDecimal = true ars.UseArrowNativeIntervalTypes = true - fbl := &fakeBatchLoader{ - Delimiter: rowscanner.NewDelimiter(0, 15), + fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, }, + index: -1, + callCount: 0, } - var e dbsqlerr.DBError - ars.batchIterator, e = NewBatchIterator(fbl) - assert.Nil(t, e) + ars.batchIterator = fbi ars.valueContainerMaker = &fakeValueContainerMaker{fnMakeColumnValuesContainers: func(ars *arrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError { columnValueHolders := make([]columnValues, len(ars.arrowSchema.Fields())) @@ -1049,17 +1041,13 @@ func TestArrowRowScanner(t *testing.T) { ars := d.(*arrowRowScanner) assert.Equal(t, int64(53940), ars.NRows()) - bi, ok := ars.batchIterator.(*batchIterator) + bi, ok := ars.batchIterator.(*localBatchIterator) assert.True(t, ok) - bl := bi.batchLoader - fbl := &batchLoaderWrapper{ - Delimiter: rowscanner.NewDelimiter(bl.Start(), bl.Count()), - bl: bl, + fbi := &batchIteratorWrapper{ + bi: bi, } - var e dbsqlerr.DBError - ars.batchIterator, e = NewBatchIterator(fbl) - assert.Nil(t, e) + ars.batchIterator = fbi dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) for i := int64(0); i < ars.NRows(); i = i + 1 { @@ -1079,7 +1067,7 @@ func TestArrowRowScanner(t *testing.T) { } } - assert.Equal(t, 54, fbl.callCount) + assert.Equal(t, 54, fbi.callCount) }) t.Run("Retrieve values - native arrow schema", func(t *testing.T) { @@ -1647,49 +1635,68 @@ func (cv *fakeColumnValues) SetValueArray(colData arrow.ArrayData) error { return nil } -type fakeBatchLoader struct { - rowscanner.Delimiter +type fakeBatchIterator struct { batches []SparkArrowBatch + index int callCount int err dbsqlerr.DBError lastReadBatch SparkArrowBatch } -var _ BatchLoader = (*fakeBatchLoader)(nil) +var _ BatchIterator = (*fakeBatchIterator)(nil) + +func (fbi *fakeBatchIterator) Next() (SparkArrowBatch, error) { + fbi.callCount += 1 -func (fbl *fakeBatchLoader) Close() {} -func (fbl *fakeBatchLoader) GetBatchFor(recordNum int64) (SparkArrowBatch, dbsqlerr.DBError) { - fbl.callCount += 1 - if fbl.err != nil { - return nil, fbl.err + if fbi.err != nil { + return nil, fbi.err } - for i := range fbl.batches { - if fbl.batches[i].Contains(recordNum) { - fbl.lastReadBatch = fbl.batches[i] - return fbl.batches[i], nil - } + cnt := len(fbi.batches) + fbi.index++ + if fbi.index < cnt { + fbi.lastReadBatch = fbi.batches[fbi.index] + return fbi.lastReadBatch, nil } - return nil, dbsqlerrint.NewDriverError(context.Background(), errArrowRowsInvalidRowNumber(recordNum), nil) + + fbi.lastReadBatch = nil + return nil, io.EOF +} + +func (fbi *fakeBatchIterator) HasNext() bool { + // `Next()` will first increment an index, and only then return a batch + // So `HasNext` should check if index can be incremented and still be within array + return fbi.index+1 < len(fbi.batches) +} + +func (fbi *fakeBatchIterator) Close() { + fbi.index = len(fbi.batches) + fbi.lastReadBatch = nil } -type batchLoaderWrapper struct { - rowscanner.Delimiter - bl BatchLoader +type batchIteratorWrapper struct { + bi BatchIterator callCount int lastLoadedBatch SparkArrowBatch } -var _ BatchLoader = (*batchLoaderWrapper)(nil) +var _ BatchIterator = (*batchIteratorWrapper)(nil) -func (fbl *batchLoaderWrapper) Close() { fbl.bl.Close() } -func (fbl *batchLoaderWrapper) GetBatchFor(recordNum int64) (SparkArrowBatch, dbsqlerr.DBError) { - fbl.callCount += 1 - batch, err := fbl.bl.GetBatchFor(recordNum) - fbl.lastLoadedBatch = batch +func (biw *batchIteratorWrapper) Next() (SparkArrowBatch, error) { + biw.callCount += 1 + batch, err := biw.bi.Next() + biw.lastLoadedBatch = batch return batch, err } +func (biw *batchIteratorWrapper) HasNext() bool { + return biw.bi.HasNext() +} + +func (biw *batchIteratorWrapper) Close() { + biw.bi.Close() +} + type fakeRecord struct { fnRelease func() fnRetain func() diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index e36153ab..45b067dd 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -3,6 +3,7 @@ package arrowbased import ( "bytes" "context" + "fmt" "io" "time" @@ -17,7 +18,6 @@ import ( dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/databricks/databricks-sql-go/internal/cli_service" dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors" - "github.com/databricks/databricks-sql-go/internal/fetcher" "github.com/databricks/databricks-sql-go/logger" ) @@ -27,210 +27,274 @@ type BatchIterator interface { Close() } -type BatchLoader interface { - rowscanner.Delimiter - GetBatchFor(recordNum int64) (SparkArrowBatch, dbsqlerr.DBError) - Close() -} +func NewCloudBatchIterator( + ctx context.Context, + files []*cli_service.TSparkArrowResultLink, + startRowOffset int64, + cfg *config.Config, +) (BatchIterator, dbsqlerr.DBError) { + bi := &cloudBatchIterator{ + ctx: ctx, + cfg: cfg, + startRowOffset: startRowOffset, + pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](), + downloadTasks: NewQueue[cloudFetchDownloadTask](), + } -func NewBatchIterator(batchLoader BatchLoader) (BatchIterator, dbsqlerr.DBError) { - bi := &batchIterator{ - nextBatchStart: batchLoader.Start(), - batchLoader: batchLoader, + for _, link := range files { + bi.pendingLinks.Enqueue(link) } return bi, nil } -func NewCloudBatchLoader(ctx context.Context, files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config) (*batchLoader[*cloudURL], dbsqlerr.DBError) { - - if cfg == nil { - cfg = config.WithDefaults() +func NewLocalBatchIterator( + ctx context.Context, + batches []*cli_service.TSparkArrowBatch, + startRowOffset int64, + arrowSchemaBytes []byte, + cfg *config.Config, +) (BatchIterator, dbsqlerr.DBError) { + bi := &localBatchIterator{ + cfg: cfg, + startRowOffset: startRowOffset, + arrowSchemaBytes: arrowSchemaBytes, + batches: batches, + index: -1, } - inputChan := make(chan fetcher.FetchableItems[SparkArrowBatch], len(files)) - - var rowCount int64 - for i := range files { - f := files[i] - li := &cloudURL{ - // TSparkArrowResultLink: f, - Delimiter: rowscanner.NewDelimiter(f.StartRowOffset, f.RowCount), - fileLink: f.FileLink, - expiryTime: f.ExpiryTime, - minTimeToExpiry: cfg.MinTimeToExpiry, - compressibleBatch: compressibleBatch{useLz4Compression: cfg.UseLz4Compression}, - } - inputChan <- li + return bi, nil +} - rowCount += f.RowCount - } +type localBatchIterator struct { + cfg *config.Config + startRowOffset int64 + arrowSchemaBytes []byte + batches []*cli_service.TSparkArrowBatch + index int +} - // make sure to close input channel or fetcher will block waiting for more inputs - close(inputChan) +var _ BatchIterator = (*localBatchIterator)(nil) - f, _ := fetcher.NewConcurrentFetcher[*cloudURL](ctx, cfg.MaxDownloadThreads, cfg.MaxFilesInMemory, inputChan) - cbl := &batchLoader[*cloudURL]{ - Delimiter: rowscanner.NewDelimiter(startRowOffset, rowCount), - fetcher: f, - ctx: ctx, - } +func (bi *localBatchIterator) Next() (SparkArrowBatch, error) { + cnt := len(bi.batches) + bi.index++ + if bi.index < cnt { + ab := bi.batches[bi.index] - return cbl, nil -} - -func NewLocalBatchLoader(ctx context.Context, batches []*cli_service.TSparkArrowBatch, startRowOffset int64, arrowSchemaBytes []byte, cfg *config.Config) (*batchLoader[*localBatch], dbsqlerr.DBError) { + reader := io.MultiReader( + bytes.NewReader(bi.arrowSchemaBytes), + getReader(bytes.NewReader(ab.Batch), bi.cfg.UseLz4Compression), + ) - if cfg == nil { - cfg = config.WithDefaults() - } + records, err := getArrowRecords(reader, bi.startRowOffset) + if err != nil { + return &sparkArrowBatch{}, err + } - var startRow int64 = startRowOffset - var rowCount int64 - inputChan := make(chan fetcher.FetchableItems[SparkArrowBatch], len(batches)) - for i := range batches { - b := batches[i] - if b != nil { - li := &localBatch{ - Delimiter: rowscanner.NewDelimiter(startRow, b.RowCount), - batchBytes: b.Batch, - arrowSchemaBytes: arrowSchemaBytes, - compressibleBatch: compressibleBatch{useLz4Compression: cfg.UseLz4Compression}, - } - inputChan <- li - startRow = startRow + b.RowCount - rowCount += b.RowCount + batch := sparkArrowBatch{ + Delimiter: rowscanner.NewDelimiter(bi.startRowOffset, ab.RowCount), + arrowRecords: records, } - } - close(inputChan) - f, _ := fetcher.NewConcurrentFetcher[*localBatch](ctx, cfg.MaxDownloadThreads, cfg.MaxFilesInMemory, inputChan) - cbl := &batchLoader[*localBatch]{ - Delimiter: rowscanner.NewDelimiter(startRowOffset, rowCount), - fetcher: f, - ctx: ctx, + bi.startRowOffset += ab.RowCount // advance to beginning of the next batch + + return &batch, nil } - return cbl, nil + bi.index = cnt + return nil, io.EOF } -type batchLoader[T interface { - Fetch(ctx context.Context) (SparkArrowBatch, error) -}] struct { - rowscanner.Delimiter - fetcher fetcher.Fetcher[SparkArrowBatch] - arrowBatches []SparkArrowBatch - ctx context.Context +func (bi *localBatchIterator) HasNext() bool { + // `Next()` will first increment an index, and only then return a batch + // So `HasNext` should check if index can be incremented and still be within array + return bi.index+1 < len(bi.batches) } -var _ BatchLoader = (*batchLoader[*localBatch])(nil) - -func (cbl *batchLoader[T]) GetBatchFor(rowNumber int64) (SparkArrowBatch, dbsqlerr.DBError) { +func (bi *localBatchIterator) Close() { + bi.index = len(bi.batches) +} - logger.Debug().Msgf("batchLoader.GetBatchFor(%d)", rowNumber) +type cloudBatchIterator struct { + ctx context.Context + cfg *config.Config + startRowOffset int64 + pendingLinks Queue[cli_service.TSparkArrowResultLink] + downloadTasks Queue[cloudFetchDownloadTask] +} - for i := range cbl.arrowBatches { - logger.Debug().Msgf(" trying batch for range [%d..%d]", cbl.arrowBatches[i].Start(), cbl.arrowBatches[i].End()) - if cbl.arrowBatches[i].Contains(rowNumber) { - logger.Debug().Msgf(" found batch containing the requested row %d", rowNumber) - return cbl.arrowBatches[i], nil +var _ BatchIterator = (*cloudBatchIterator)(nil) + +func (bi *cloudBatchIterator) Next() (SparkArrowBatch, error) { + for (bi.downloadTasks.Len() < bi.cfg.MaxDownloadThreads) && (bi.pendingLinks.Len() > 0) { + link := bi.pendingLinks.Dequeue() + logger.Debug().Msgf( + "CloudFetch: schedule link at offset %d row count %d", + link.StartRowOffset, + link.RowCount, + ) + + cancelCtx, cancelFn := context.WithCancel(bi.ctx) + task := &cloudFetchDownloadTask{ + ctx: cancelCtx, + cancel: cancelFn, + useLz4Compression: bi.cfg.UseLz4Compression, + link: link, + resultChan: make(chan cloudFetchDownloadTaskResult), + minTimeToExpiry: bi.cfg.MinTimeToExpiry, } + task.Run() + bi.downloadTasks.Enqueue(task) } - logger.Debug().Msgf(" batch not found, trying to download more") - - batchChan, _, err := cbl.fetcher.Start() - var emptyBatch SparkArrowBatch - if err != nil { - logger.Debug().Msgf(" no batch found for row %d", rowNumber) - return emptyBatch, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowNumber(rowNumber), err) + task := bi.downloadTasks.Dequeue() + if task == nil { + return nil, io.EOF } - for { - batch, ok := <-batchChan - if !ok { - err := cbl.fetcher.Err() - if err != nil { - logger.Debug().Msgf(" no batch found for row %d", rowNumber) - return emptyBatch, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowNumber(rowNumber), err) - } - break - } + batch, err := task.GetResult() - cbl.arrowBatches = append(cbl.arrowBatches, batch) - logger.Debug().Msgf(" trying newly downloaded batch for range [%d..%d]", batch.Start(), batch.End()) - if batch.Contains(rowNumber) { - logger.Debug().Msgf(" found batch containing the requested row %d", rowNumber) - return batch, nil - } + // once we've got an errored out task - cancel the remaining ones + if err != nil { + bi.Close() + return nil, err } - logger.Debug().Msgf(" no batch found for row %d", rowNumber) + // explicitly call cancel function on successfully completed task to avoid context leak + task.cancel() + return batch, nil +} - return emptyBatch, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowNumber(rowNumber), err) +func (bi *cloudBatchIterator) HasNext() bool { + return (bi.pendingLinks.Len() > 0) || (bi.downloadTasks.Len() > 0) } -func (cbl *batchLoader[T]) Close() { - for i := range cbl.arrowBatches { - cbl.arrowBatches[i].Close() +func (bi *cloudBatchIterator) Close() { + bi.pendingLinks.Clear() + for bi.downloadTasks.Len() > 0 { + task := bi.downloadTasks.Dequeue() + task.cancel() } } -type compressibleBatch struct { +type cloudFetchDownloadTaskResult struct { + batch SparkArrowBatch + err error +} + +type cloudFetchDownloadTask struct { + ctx context.Context + cancel context.CancelFunc useLz4Compression bool + minTimeToExpiry time.Duration + link *cli_service.TSparkArrowResultLink + resultChan chan cloudFetchDownloadTaskResult } -func (cb compressibleBatch) getReader(r io.Reader) io.Reader { - if cb.useLz4Compression { - return lz4.NewReader(r) +func (cft *cloudFetchDownloadTask) GetResult() (SparkArrowBatch, error) { + link := cft.link + + result, ok := <-cft.resultChan + if ok { + if result.err != nil { + logger.Debug().Msgf( + "CloudFetch: failed to download link at offset %d row count %d, reason: %s", + link.StartRowOffset, + link.RowCount, + result.err.Error(), + ) + return nil, result.err + } + logger.Debug().Msgf( + "CloudFetch: received data for link at offset %d row count %d", + link.StartRowOffset, + link.RowCount, + ) + return result.batch, nil } - return r -} -type cloudURL struct { - compressibleBatch - rowscanner.Delimiter - fileLink string - expiryTime int64 - minTimeToExpiry time.Duration + // This branch should never be reached. If you see this message - something got really wrong + logger.Debug().Msgf( + "CloudFetch: channel was closed before result was received; link at offset %d row count %d", + link.StartRowOffset, + link.RowCount, + ) + return nil, nil } -func (cu *cloudURL) Fetch(ctx context.Context) (SparkArrowBatch, error) { - var sab SparkArrowBatch +func (cft *cloudFetchDownloadTask) Run() { + go func() { + defer close(cft.resultChan) + + logger.Debug().Msgf( + "CloudFetch: start downloading link at offset %d row count %d", + cft.link.StartRowOffset, + cft.link.RowCount, + ) + data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry) + if err != nil { + cft.resultChan <- cloudFetchDownloadTaskResult{batch: nil, err: err} + return + } + + // io.ReadCloser.Close() may return an error, but in this case it should be safe to ignore (I hope so) + defer data.Close() + + logger.Debug().Msgf( + "CloudFetch: reading records for link at offset %d row count %d", + cft.link.StartRowOffset, + cft.link.RowCount, + ) + reader := getReader(data, cft.useLz4Compression) + + records, err := getArrowRecords(reader, cft.link.StartRowOffset) + if err != nil { + cft.resultChan <- cloudFetchDownloadTaskResult{batch: nil, err: err} + return + } + + batch := sparkArrowBatch{ + Delimiter: rowscanner.NewDelimiter(cft.link.StartRowOffset, cft.link.RowCount), + arrowRecords: records, + } + cft.resultChan <- cloudFetchDownloadTaskResult{batch: &batch, err: nil} + }() +} - if isLinkExpired(cu.expiryTime, cu.minTimeToExpiry) { - return sab, errors.New(dbsqlerr.ErrLinkExpired) +func fetchBatchBytes( + ctx context.Context, + link *cli_service.TSparkArrowResultLink, + minTimeToExpiry time.Duration, +) (io.ReadCloser, error) { + if isLinkExpired(link.ExpiryTime, minTimeToExpiry) { + return nil, errors.New(dbsqlerr.ErrLinkExpired) } - req, err := http.NewRequestWithContext(ctx, "GET", cu.fileLink, nil) + // TODO: Retry on HTTP errors + req, err := http.NewRequestWithContext(ctx, "GET", link.FileLink, nil) if err != nil { - return sab, err + return nil, err } client := http.DefaultClient res, err := client.Do(req) if err != nil { - return sab, err + return nil, err } if res.StatusCode != http.StatusOK { - return sab, dbsqlerrint.NewDriverError(ctx, errArrowRowsCloudFetchDownloadFailure, err) + msg := fmt.Sprintf("%s: %s %d", errArrowRowsCloudFetchDownloadFailure, "HTTP error", res.StatusCode) + return nil, dbsqlerrint.NewDriverError(ctx, msg, err) } - defer res.Body.Close() - - r := cu.compressibleBatch.getReader(res.Body) - - records, err := getArrowRecords(r, cu.Start()) - if err != nil { - return nil, err - } + return res.Body, nil +} - arrowBatch := sparkArrowBatch{ - Delimiter: rowscanner.NewDelimiter(cu.Start(), cu.Count()), - arrowRecords: records, +func getReader(r io.Reader, useLz4Compression bool) io.Reader { + if useLz4Compression { + return lz4.NewReader(r) } - - return &arrowBatch, nil + return r } func isLinkExpired(expiryTime int64, linkExpiryBuffer time.Duration) bool { @@ -238,35 +302,6 @@ func isLinkExpired(expiryTime int64, linkExpiryBuffer time.Duration) bool { return expiryTime-bufferSecs < time.Now().Unix() } -var _ fetcher.FetchableItems[SparkArrowBatch] = (*cloudURL)(nil) - -type localBatch struct { - compressibleBatch - rowscanner.Delimiter - batchBytes []byte - arrowSchemaBytes []byte -} - -var _ fetcher.FetchableItems[SparkArrowBatch] = (*localBatch)(nil) - -func (lb *localBatch) Fetch(ctx context.Context) (SparkArrowBatch, error) { - r := lb.compressibleBatch.getReader(bytes.NewReader(lb.batchBytes)) - r = io.MultiReader(bytes.NewReader(lb.arrowSchemaBytes), r) - - records, err := getArrowRecords(r, lb.Start()) - if err != nil { - return &sparkArrowBatch{}, err - } - - lb.batchBytes = nil - batch := sparkArrowBatch{ - Delimiter: rowscanner.NewDelimiter(lb.Start(), lb.Count()), - arrowRecords: records, - } - - return &batch, nil -} - func getArrowRecords(r io.Reader, startRowOffset int64) ([]SparkArrowRecord, error) { ipcReader, err := ipc.NewReader(r) if err != nil { @@ -300,34 +335,3 @@ func getArrowRecords(r io.Reader, startRowOffset int64) ([]SparkArrowRecord, err return records, nil } - -type batchIterator struct { - nextBatchStart int64 - batchLoader BatchLoader -} - -var _ BatchIterator = (*batchIterator)(nil) - -func (bi *batchIterator) Next() (SparkArrowBatch, error) { - if !bi.HasNext() { - return nil, io.EOF - } - if bi != nil && bi.batchLoader != nil { - batch, err := bi.batchLoader.GetBatchFor(bi.nextBatchStart) - if batch != nil && err == nil { - bi.nextBatchStart = batch.Start() + batch.Count() - } - return batch, err - } - return nil, nil -} - -func (bi *batchIterator) HasNext() bool { - return bi != nil && bi.batchLoader != nil && bi.batchLoader.Contains(bi.nextBatchStart) -} - -func (bi *batchIterator) Close() { - if bi != nil && bi.batchLoader != nil { - bi.batchLoader.Close() - } -} diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index 35bad337..e47eef08 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -4,9 +4,11 @@ import ( "bytes" "context" "fmt" + dbsqlerr "github.com/databricks/databricks-sql-go/errors" + "github.com/databricks/databricks-sql-go/internal/cli_service" + "github.com/databricks/databricks-sql-go/internal/config" "net/http" "net/http/httptest" - "reflect" "testing" "time" @@ -14,117 +16,216 @@ import ( "github.com/apache/arrow/go/v12/arrow/array" "github.com/apache/arrow/go/v12/arrow/ipc" "github.com/apache/arrow/go/v12/arrow/memory" - dbsqlerr "github.com/databricks/databricks-sql-go/errors" - dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors" - "github.com/databricks/databricks-sql-go/internal/rows/rowscanner" - "github.com/pkg/errors" "github.com/stretchr/testify/assert" ) -func TestCloudURLFetch(t *testing.T) { +func TestCloudFetchIterator(t *testing.T) { var handler func(w http.ResponseWriter, r *http.Request) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler(w, r) })) defer server.Close() - testTable := []struct { - name string - response func(w http.ResponseWriter, r *http.Request) - linkExpired bool - expectedResponse SparkArrowBatch - expectedErr error - }{ - { - name: "cloud-fetch-happy-case", - response: func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) - if err != nil { - panic(err) - } + + t.Run("should fetch all the links", func(t *testing.T) { + handler = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) + } + } + + startRowOffset := int64(100) + + links := []*cli_service.TSparkArrowResultLink{ + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, }, - linkExpired: false, - expectedResponse: &sparkArrowBatch{ - Delimiter: rowscanner.NewDelimiter(0, 3), - arrowRecords: []SparkArrowRecord{ - &sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 3), Record: generateArrowRecord()}, - &sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(3, 3), Record: generateArrowRecord()}, - }, + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset + 1, + RowCount: 1, }, - expectedErr: nil, - }, - { - name: "cloud-fetch-expired_link", - response: func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) - if err != nil { - panic(err) - } + } + + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + + bi, err := NewCloudBatchIterator( + context.Background(), + links, + startRowOffset, + cfg, + ) + if err != nil { + panic(err) + } + + cbi := bi.(*cloudBatchIterator) + + assert.True(t, bi.HasNext()) + assert.Equal(t, cbi.pendingLinks.Len(), len(links)) + assert.Equal(t, cbi.downloadTasks.Len(), 0) + + // get first link - should succeed + sab1, err2 := bi.Next() + if err2 != nil { + panic(err2) + } + + assert.Equal(t, cbi.pendingLinks.Len(), len(links)-1) + assert.Equal(t, cbi.downloadTasks.Len(), 0) + assert.Equal(t, sab1.Start(), startRowOffset) + + // get second link - should succeed + sab2, err3 := bi.Next() + if err3 != nil { + panic(err3) + } + + assert.Equal(t, cbi.pendingLinks.Len(), len(links)-2) + assert.Equal(t, cbi.downloadTasks.Len(), 0) + assert.Equal(t, sab2.Start(), startRowOffset+sab1.Count()) + + // all links downloaded, should be no more data + assert.False(t, bi.HasNext()) + }) + + t.Run("should fail on expired link", func(t *testing.T) { + handler = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) + } + } + + startRowOffset := int64(100) + + links := []*cli_service.TSparkArrowResultLink{ + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, }, - linkExpired: true, - expectedResponse: nil, - expectedErr: errors.New(dbsqlerr.ErrLinkExpired), - }, - { - name: "cloud-fetch-http-error", - response: func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(-10 * time.Minute).Unix(), // expired link + StartRowOffset: startRowOffset + 1, + RowCount: 1, }, - linkExpired: false, - expectedResponse: nil, - expectedErr: dbsqlerrint.NewDriverError(context.TODO(), errArrowRowsCloudFetchDownloadFailure, nil), - }, - } + } - for _, tc := range testTable { - t.Run(tc.name, func(t *testing.T) { - handler = tc.response + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 - expiryTime := time.Now() - // If link expired, subtract 1 sec from current time to get expiration time - if tc.linkExpired { - expiryTime = expiryTime.Add(-1 * time.Second) - } else { - expiryTime = expiryTime.Add(10 * time.Second) - } + bi, err := NewCloudBatchIterator( + context.Background(), + links, + startRowOffset, + cfg, + ) + if err != nil { + panic(err) + } - cu := &cloudURL{ - Delimiter: rowscanner.NewDelimiter(0, 3), - fileLink: server.URL, - expiryTime: expiryTime.Unix(), - } + cbi := bi.(*cloudBatchIterator) - ctx := context.Background() - - resp, err := cu.Fetch(ctx) - - if tc.expectedResponse != nil { - assert.NotNil(t, resp) - esab, ok := tc.expectedResponse.(*sparkArrowBatch) - assert.True(t, ok) - asab, ok2 := resp.(*sparkArrowBatch) - assert.True(t, ok2) - if !reflect.DeepEqual(esab.Delimiter, asab.Delimiter) { - t.Errorf("expected (%v), got (%v)", esab.Delimiter, asab.Delimiter) - } - assert.Equal(t, len(esab.arrowRecords), len(asab.arrowRecords)) - for i := range esab.arrowRecords { - er := esab.arrowRecords[i] - ar := asab.arrowRecords[i] - - eb := generateMockArrowBytes(er) - ab := generateMockArrowBytes(ar) - assert.Equal(t, eb, ab) - } - } + assert.True(t, bi.HasNext()) + assert.Equal(t, cbi.pendingLinks.Len(), len(links)) + assert.Equal(t, cbi.downloadTasks.Len(), 0) + + // get first link - should succeed + sab1, err2 := bi.Next() + if err2 != nil { + panic(err2) + } + + assert.Equal(t, cbi.pendingLinks.Len(), len(links)-1) + assert.Equal(t, cbi.downloadTasks.Len(), 0) + assert.Equal(t, sab1.Start(), startRowOffset) + + // get second link - should fail + _, err3 := bi.Next() + assert.NotNil(t, err3) + assert.ErrorContains(t, err3, dbsqlerr.ErrLinkExpired) + }) + + t.Run("should fail on HTTP errors", func(t *testing.T) { + startRowOffset := int64(100) + + links := []*cli_service.TSparkArrowResultLink{ + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }, + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset + 1, + RowCount: 1, + }, + } - if !errors.Is(err, tc.expectedErr) { - assert.EqualErrorf(t, err, fmt.Sprintf("%v", tc.expectedErr), "expected (%v), got (%v)", tc.expectedErr, err) + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + + bi, err := NewCloudBatchIterator( + context.Background(), + links, + startRowOffset, + cfg, + ) + if err != nil { + panic(err) + } + + cbi := bi.(*cloudBatchIterator) + + assert.True(t, bi.HasNext()) + assert.Equal(t, cbi.pendingLinks.Len(), len(links)) + assert.Equal(t, cbi.downloadTasks.Len(), 0) + + // set handler for the first link, which returns some data + handler = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) } - }) - } + } + + // get first link - should succeed + sab1, err2 := bi.Next() + if err2 != nil { + panic(err2) + } + + assert.Equal(t, cbi.pendingLinks.Len(), len(links)-1) + assert.Equal(t, cbi.downloadTasks.Len(), 0) + assert.Equal(t, sab1.Start(), startRowOffset) + + // set handler for the first link, which fails with some non-retryable HTTP error + handler = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + } + + // get second link - should fail + _, err3 := bi.Next() + assert.NotNil(t, err3) + assert.ErrorContains(t, err3, fmt.Sprintf("%s %d", "HTTP error", http.StatusNotFound)) + }) } func generateArrowRecord() arrow.Record { diff --git a/internal/rows/arrowbased/queue.go b/internal/rows/arrowbased/queue.go new file mode 100644 index 00000000..ed1d16f5 --- /dev/null +++ b/internal/rows/arrowbased/queue.go @@ -0,0 +1,51 @@ +package arrowbased + +import ( + "container/list" +) + +type Queue[ItemType any] interface { + Enqueue(item *ItemType) + Dequeue() *ItemType + Clear() + Len() int +} + +func NewQueue[ItemType any]() Queue[ItemType] { + return &queue[ItemType]{ + items: list.New(), + } +} + +type queue[ItemType any] struct { + items *list.List +} + +var _ Queue[any] = (*queue[any])(nil) + +func (q *queue[ItemType]) Enqueue(item *ItemType) { + q.items.PushBack(item) +} + +func (q *queue[ItemType]) Dequeue() *ItemType { + el := q.items.Front() + if el == nil { + return nil + } + q.items.Remove(el) + + value, ok := el.Value.(*ItemType) + if !ok { + return nil + } + + return value +} + +func (q *queue[ItemType]) Clear() { + q.items.Init() +} + +func (q *queue[ItemType]) Len() int { + return q.items.Len() +}