Skip to content

Commit

Permalink
[PECO-1532] Arrow and CloudFetch result handlers: return row count with raw batch data
Browse files Browse the repository at this point in the history

Signed-off-by: Levko Kravets <[email protected]>
  • Loading branch information
kravets-levko committed Mar 21, 2024
1 parent 6673660 commit 85343e9
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 17 deletions.
8 changes: 4 additions & 4 deletions lib/result/ArrowResultConverter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {
import { TGetResultSetMetadataResp, TColumnDesc } from '../../thrift/TCLIService_types';
import IClientContext from '../contracts/IClientContext';
import IResultsProvider, { ResultsProviderFetchNextOptions } from './IResultsProvider';
import { getSchemaColumns, convertThriftValue } from './utils';
import { ArrowBatch, getSchemaColumns, convertThriftValue } from './utils';

const { isArrowBigNumSymbol, bigNumToBigInt } = arrowUtils;

Expand All @@ -26,15 +26,15 @@ type ArrowSchemaField = Field<DataType<Type, TypeMap>>;
export default class ArrowResultConverter implements IResultsProvider<Array<any>> {
protected readonly context: IClientContext;

private readonly source: IResultsProvider<Array<Buffer>>;
private readonly source: IResultsProvider<ArrowBatch>;

private readonly schema: Array<TColumnDesc>;

private reader?: IterableIterator<RecordBatch<TypeMap>>;

private pendingRecordBatch?: RecordBatch<TypeMap>;

constructor(context: IClientContext, source: IResultsProvider<Array<Buffer>>, { schema }: TGetResultSetMetadataResp) {
constructor(context: IClientContext, source: IResultsProvider<ArrowBatch>, { schema }: TGetResultSetMetadataResp) {
this.context = context;
this.source = source;
this.schema = getSchemaColumns(schema);
Expand Down Expand Up @@ -73,7 +73,7 @@ export default class ArrowResultConverter implements IResultsProvider<Array<any>
}

// eslint-disable-next-line no-await-in-loop
const batches = await this.source.fetchNext(options);
const { batches } = await this.source.fetchNext(options);
if (batches.length === 0) {
this.reader = undefined;
break;
Expand Down
23 changes: 17 additions & 6 deletions lib/result/ArrowResultHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ import LZ4 from 'lz4';
import { TGetResultSetMetadataResp, TRowSet } from '../../thrift/TCLIService_types';
import IClientContext from '../contracts/IClientContext';
import IResultsProvider, { ResultsProviderFetchNextOptions } from './IResultsProvider';
import { hiveSchemaToArrowSchema } from './utils';
import { ArrowBatch, hiveSchemaToArrowSchema } from './utils';

export default class ArrowResultHandler implements IResultsProvider<Array<Buffer>> {
export default class ArrowResultHandler implements IResultsProvider<ArrowBatch> {
protected readonly context: IClientContext;

private readonly source: IResultsProvider<TRowSet | undefined>;
Expand Down Expand Up @@ -35,22 +35,33 @@ export default class ArrowResultHandler implements IResultsProvider<Array<Buffer

/**
 * Fetches the next row set from `this.source` and returns the Arrow buffers it
 * carries, together with the number of rows they hold.
 *
 * NOTE(review): this span is a rendered diff hunk without +/- markers. Each
 * pre-change line (`return [];`, the `forEach(({ batch }) => {` opener,
 * `return [this.arrowSchema, ...batches];`) is shown immediately before the
 * lines that replace it; read every such pair as old/new, not as sequential code.
 *
 * @param options - passed through unchanged to `this.source.fetchNext`
 * @returns after the change: `{ batches, rowCount }`; before the change: a plain
 *          array of buffers
 */
public async fetchNext(options: ResultsProviderFetchNextOptions) {
// No Arrow schema available: report an empty result without touching the source.
if (!this.arrowSchema) {
return [];
return {
batches: [],
rowCount: 0,
};
}

const rowSet = await this.source.fetchNext(options);

const batches: Array<Buffer> = [];
rowSet?.arrowBatches?.forEach(({ batch }) => {
let totalRowCount = 0;
rowSet?.arrowBatches?.forEach(({ batch, rowCount }) => {
// Entries without a `batch` payload are skipped and do not contribute to the row count.
if (batch) {
// Decode the payload with LZ4 only when `isLZ4Compressed` is set.
batches.push(this.isLZ4Compressed ? LZ4.decode(batch) : batch);
// presumably `rowCount` is a node-int64 value and `true` permits imprecise conversion — TODO confirm
totalRowCount += rowCount.toNumber(true);
}
});

// Nothing usable in this row set: return the same empty shape as above.
if (batches.length === 0) {
return [];
return {
batches: [],
rowCount: 0,
};
}

// The schema buffer goes first, followed by the data batches.
return [this.arrowSchema, ...batches];
return {
batches: [this.arrowSchema, ...batches],
rowCount: totalRowCount,
};
}
}
23 changes: 16 additions & 7 deletions lib/result/CloudFetchResultHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ import fetch, { RequestInfo, RequestInit, Request } from 'node-fetch';
import { TGetResultSetMetadataResp, TRowSet, TSparkArrowResultLink } from '../../thrift/TCLIService_types';
import IClientContext from '../contracts/IClientContext';
import IResultsProvider, { ResultsProviderFetchNextOptions } from './IResultsProvider';
import { ArrowBatch } from './utils';

export default class CloudFetchResultHandler implements IResultsProvider<Array<Buffer>> {
export default class CloudFetchResultHandler implements IResultsProvider<ArrowBatch> {
protected readonly context: IClientContext;

private readonly source: IResultsProvider<TRowSet | undefined>;
Expand All @@ -13,7 +14,7 @@ export default class CloudFetchResultHandler implements IResultsProvider<Array<B

private pendingLinks: Array<TSparkArrowResultLink> = [];

private downloadTasks: Array<Promise<Buffer>> = [];
private downloadTasks: Array<Promise<ArrowBatch>> = [];

constructor(
context: IClientContext,
Expand Down Expand Up @@ -49,15 +50,20 @@ export default class CloudFetchResultHandler implements IResultsProvider<Array<B
}

const batch = await this.downloadTasks.shift();
const batches = batch ? [batch] : [];
if (!batch) {
return {
batches: [],
rowCount: 0,
};
}

if (this.isLZ4Compressed) {
return batches.map((buffer) => LZ4.decode(buffer));
batch.batches = batch.batches.map((buffer) => LZ4.decode(buffer));
}
return batches;
return batch;
}

private async downloadLink(link: TSparkArrowResultLink): Promise<Buffer> {
private async downloadLink(link: TSparkArrowResultLink): Promise<ArrowBatch> {
if (Date.now() >= link.expiryTime.toNumber()) {
throw new Error('CloudFetch link has expired');
}
Expand All @@ -68,7 +74,10 @@ export default class CloudFetchResultHandler implements IResultsProvider<Array<B
}

const result = await response.arrayBuffer();
return Buffer.from(result);
return {
batches: [Buffer.from(result)],
rowCount: link.rowCount.toNumber(true),
};
}

private async fetch(url: RequestInfo, init?: RequestInit) {
Expand Down
5 changes: 5 additions & 0 deletions lib/result/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ import {
import { TTableSchema, TColumnDesc, TPrimitiveTypeEntry, TTypeId } from '../../thrift/TCLIService_types';
import HiveDriverError from '../errors/HiveDriverError';

/**
 * Raw Arrow result data paired with the number of rows it holds.
 *
 * An empty result is expressed as `{ batches: [], rowCount: 0 }`.
 */
export interface ArrowBatch {
/**
 * Raw buffers. ArrowResultHandler places the serialized schema buffer first,
 * followed by the data batches; CloudFetchResultHandler supplies a single
 * downloaded buffer.
 */
batches: Array<Buffer>;
/** Total number of rows reported for the data in `batches`. */
rowCount: number;
}

export function getSchemaColumns(schema?: TTableSchema): Array<TColumnDesc> {
if (!schema) {
return [];
Expand Down

0 comments on commit 85343e9

Please sign in to comment.