forked from NVIDIA/spark-rapids
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'shuffle-gpu-serde' into 0527-base
- Loading branch information
Showing
14 changed files
with
839 additions
and
173 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
175 changes: 175 additions & 0 deletions
175
sql-plugin/src/main/java/com/nvidia/spark/rapids/PackedTableHostColumnVector.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.nvidia.spark.rapids; | ||
|
||
import ai.rapids.cudf.ContiguousTable; | ||
import ai.rapids.cudf.DeviceMemoryBuffer; | ||
import ai.rapids.cudf.HostMemoryBuffer; | ||
import com.nvidia.spark.rapids.format.TableMeta; | ||
import org.apache.spark.sql.types.DataTypes; | ||
import org.apache.spark.sql.types.Decimal; | ||
import org.apache.spark.sql.vectorized.ColumnVector; | ||
import org.apache.spark.sql.vectorized.ColumnarArray; | ||
import org.apache.spark.sql.vectorized.ColumnarBatch; | ||
import org.apache.spark.sql.vectorized.ColumnarMap; | ||
import org.apache.spark.unsafe.types.UTF8String; | ||
|
||
/** | ||
* A column vector that tracks a packed (or compressed) table on host. Unlike a normal | ||
* host column vector, the columnar data within cannot be accessed directly. | ||
* This is intended to only be used during shuffle after the data is partitioned and | ||
* before it is serialized. | ||
*/ | ||
public final class PackedTableHostColumnVector extends ColumnVector { | ||
|
||
private static final String BAD_ACCESS_MSG = "Column is packed"; | ||
|
||
private final TableMeta tableMeta; | ||
private final HostMemoryBuffer tableBuffer; | ||
|
||
PackedTableHostColumnVector(TableMeta tableMeta, HostMemoryBuffer tableBuffer) { | ||
super(DataTypes.NullType); | ||
long rows = tableMeta.rowCount(); | ||
int batchRows = (int) rows; | ||
if (rows != batchRows) { | ||
throw new IllegalStateException("Cannot support a batch larger that MAX INT rows"); | ||
} | ||
this.tableMeta = tableMeta; | ||
this.tableBuffer = tableBuffer; | ||
} | ||
|
||
private static ColumnarBatch from(TableMeta meta, DeviceMemoryBuffer devBuf) { | ||
HostMemoryBuffer tableBuf; | ||
try(HostMemoryBuffer buf = HostMemoryBuffer.allocate(devBuf.getLength())) { | ||
buf.copyFromDeviceBuffer(devBuf); | ||
buf.incRefCount(); | ||
tableBuf = buf; | ||
} | ||
ColumnVector column = new PackedTableHostColumnVector(meta, tableBuf); | ||
return new ColumnarBatch(new ColumnVector[] { column }, (int) meta.rowCount()); | ||
} | ||
|
||
/** Both the input table and output batch should be closed. */ | ||
public static ColumnarBatch from(CompressedTable table) { | ||
return from(table.meta(), table.buffer()); | ||
} | ||
|
||
/** Both the input table and output batch should be closed. */ | ||
public static ColumnarBatch from(ContiguousTable table) { | ||
return from(MetaUtils.buildTableMeta(0, table), table.getBuffer()); | ||
} | ||
|
||
/** Returns true if this columnar batch uses a packed table on host */ | ||
public static boolean isBatchPackedOnHost(ColumnarBatch batch) { | ||
return batch.numCols() == 1 && batch.column(0) instanceof PackedTableHostColumnVector; | ||
} | ||
|
||
public TableMeta getTableMeta() { | ||
return tableMeta; | ||
} | ||
|
||
public HostMemoryBuffer getTableBuffer() { | ||
return tableBuffer; | ||
} | ||
|
||
@Override | ||
public void close() { | ||
if (tableBuffer != null) { | ||
tableBuffer.close(); | ||
} | ||
} | ||
|
||
@Override | ||
public boolean hasNull() { | ||
throw new IllegalStateException(BAD_ACCESS_MSG); | ||
} | ||
|
||
@Override | ||
public int numNulls() { | ||
throw new IllegalStateException(BAD_ACCESS_MSG); | ||
} | ||
|
||
@Override | ||
public boolean isNullAt(int rowId) { | ||
throw new IllegalStateException(BAD_ACCESS_MSG); | ||
} | ||
|
||
@Override | ||
public boolean getBoolean(int rowId) { | ||
throw new IllegalStateException(BAD_ACCESS_MSG); | ||
} | ||
|
||
@Override | ||
public byte getByte(int rowId) { | ||
throw new IllegalStateException(BAD_ACCESS_MSG); | ||
} | ||
|
||
@Override | ||
public short getShort(int rowId) { | ||
throw new IllegalStateException(BAD_ACCESS_MSG); | ||
} | ||
|
||
@Override | ||
public int getInt(int rowId) { | ||
throw new IllegalStateException(BAD_ACCESS_MSG); | ||
} | ||
|
||
@Override | ||
public long getLong(int rowId) { | ||
throw new IllegalStateException(BAD_ACCESS_MSG); | ||
} | ||
|
||
@Override | ||
public float getFloat(int rowId) { | ||
throw new IllegalStateException(BAD_ACCESS_MSG); | ||
} | ||
|
||
@Override | ||
public double getDouble(int rowId) { | ||
throw new IllegalStateException(BAD_ACCESS_MSG); | ||
} | ||
|
||
@Override | ||
public ColumnarArray getArray(int rowId) { | ||
throw new IllegalStateException(BAD_ACCESS_MSG); | ||
} | ||
|
||
@Override | ||
public ColumnarMap getMap(int rowId) { | ||
throw new IllegalStateException(BAD_ACCESS_MSG); | ||
} | ||
|
||
@Override | ||
public Decimal getDecimal(int rowId, int precision, int scale) { | ||
throw new IllegalStateException(BAD_ACCESS_MSG); | ||
} | ||
|
||
@Override | ||
public UTF8String getUTF8String(int rowId) { | ||
throw new IllegalStateException(BAD_ACCESS_MSG); | ||
} | ||
|
||
@Override | ||
public byte[] getBinary(int rowId) { | ||
throw new IllegalStateException(BAD_ACCESS_MSG); | ||
} | ||
|
||
@Override | ||
public ColumnVector getChild(int ordinal) { | ||
throw new IllegalStateException(BAD_ACCESS_MSG); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.