Skip to content

Commit

Permalink
Support custom comparison in RowContainer (facebookincubator#11024)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#11024

Building on facebookincubator#11021, this change adds support in RowContainer
for the custom comparison functions provided by custom types.

As before, compare was already handled by the update to SimpleVector, so this change only
updates the hash function and adds tests.

Reviewed By: pansatadru

Differential Revision: D62905323

fbshipit-source-id: e839f615178f6aa0cafab3c1dd48b453ba4c5d63
  • Loading branch information
Kevin Wilfong authored and facebook-github-bot committed Sep 25, 2024
1 parent 50ee522 commit 367faf8
Show file tree
Hide file tree
Showing 4 changed files with 310 additions and 72 deletions.
41 changes: 31 additions & 10 deletions velox/exec/RowContainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,7 @@ int32_t RowContainer::compareComplexType(
return compareComplexType(left, right, type, offset, offset, flags);
}

template <TypeKind Kind>
template <bool typeProvidesCustomComparison, TypeKind Kind>
void RowContainer::hashTyped(
const Type* type,
RowColumn column,
Expand All @@ -881,6 +881,7 @@ void RowContainer::hashTyped(
auto offset = column.offset();
std::string storage;
auto numRows = rows.size();

for (int32_t i = 0; i < numRows; ++i) {
char* row = rows[i];
if (nullable && isNullAt(row, column)) {
Expand All @@ -897,6 +898,9 @@ void RowContainer::hashTyped(
Kind == TypeKind::MAP) {
auto in = prepareRead(row, offset);
hash = ContainerRowSerde::hash(*in, type);
} else if constexpr (typeProvidesCustomComparison) {
hash = static_cast<const CanProvideCustomComparisonType<Kind>*>(type)
->hash(valueAt<T>(row, offset));
} else if constexpr (std::is_floating_point_v<T>) {
hash = util::floating_point::NaNAwareHash<T>()(valueAt<T>(row, offset));
} else {
Expand All @@ -921,15 +925,32 @@ void RowContainer::hash(
}

bool nullable = column >= keyTypes_.size() || nullableKeys_;
VELOX_DYNAMIC_TYPE_DISPATCH(
hashTyped,
typeKinds_[column],
types_[column].get(),
columnAt(column),
nullable,
rows,
mix,
result);

const auto& type = types_[column];

if (type->providesCustomComparison()) {
VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH(
hashTyped,
true,
typeKinds_[column],
type.get(),
columnAt(column),
nullable,
rows,
mix,
result);
} else {
VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH(
hashTyped,
false,
typeKinds_[column],
type.get(),
columnAt(column),
nullable,
rows,
mix,
result);
}
}

void RowContainer::clear() {
Expand Down
205 changes: 143 additions & 62 deletions velox/exec/RowContainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,7 @@ class RowContainer {
const char* row,
int32_t offset);

template <TypeKind Kind>
template <bool typeProvidesCustomComparison, TypeKind Kind>
void hashTyped(
const Type* type,
RowColumn column,
Expand All @@ -1071,7 +1071,7 @@ class RowContainer {
bool mix,
uint64_t* result);

template <TypeKind Kind>
template <bool typeProvidesCustomComparison, TypeKind Kind>
inline bool equalsWithNulls(
const char* row,
int32_t offset,
Expand All @@ -1085,15 +1085,18 @@ class RowContainer {
return rowIsNull == indexIsNull;
}

return equalsNoNulls<Kind>(row, offset, decoded, index);
return equalsNoNulls<typeProvidesCustomComparison, Kind>(
row, offset, decoded, index);
}

template <TypeKind Kind>
template <bool typeProvidesCustomComparison, TypeKind Kind>
inline bool equalsNoNulls(
const char* row,
int32_t offset,
const DecodedVector& decoded,
vector_size_t index) {
using T = typename KindToFlatVector<Kind>::HashRowType;

if constexpr (
Kind == TypeKind::ROW || Kind == TypeKind::ARRAY ||
Kind == TypeKind::MAP) {
Expand All @@ -1102,20 +1105,22 @@ class RowContainer {
Kind == TypeKind::VARCHAR || Kind == TypeKind::VARBINARY) {
return compareStringAsc(
valueAt<StringView>(row, offset), decoded, index) == 0;
} else if constexpr (typeProvidesCustomComparison) {
return SimpleVector<T>::template comparePrimitiveAscWithCustomComparison<
Kind>(
decoded.base()->type().get(),
decoded.valueAt<T>(index),
valueAt<T>(row, offset)) == 0;
} else {
using T = typename KindToFlatVector<Kind>::HashRowType;
return decoded.base()->typeUsesCustomComparison()
? SimpleVector<T>::template comparePrimitiveAscWithCustomComparison<
Kind>(
decoded.base()->type().get(),
decoded.valueAt<T>(index),
valueAt<T>(row, offset))
: SimpleVector<T>::comparePrimitiveAsc(
decoded.valueAt<T>(index), valueAt<T>(row, offset)) == 0;
return SimpleVector<T>::comparePrimitiveAsc(
decoded.valueAt<T>(index), valueAt<T>(row, offset)) == 0;
}
}

template <TypeKind Kind>
template <
bool typeProvidesCustomComparison,
TypeKind Kind,
std::enable_if_t<Kind != TypeKind::OPAQUE, int32_t> = 0>
inline int compare(
const char* row,
RowColumn column,
Expand Down Expand Up @@ -1143,15 +1148,37 @@ class RowContainer {
} else {
auto left = valueAt<T>(row, column.offset());
auto right = decoded.valueAt<T>(index);
auto result = decoded.base()->typeUsesCustomComparison()
? SimpleVector<T>::template comparePrimitiveAscWithCustomComparison<
Kind>(decoded.base()->type().get(), left, right)
: SimpleVector<T>::comparePrimitiveAsc(left, right);

int result;
if constexpr (typeProvidesCustomComparison) {
result =
SimpleVector<T>::template comparePrimitiveAscWithCustomComparison<
Kind>(decoded.base()->type().get(), left, right);
} else {
result = SimpleVector<T>::comparePrimitiveAsc(left, right);
}

return flags.ascending ? result : result * -1;
}
}

template <TypeKind Kind>
/// Overload of the row-versus-decoded-vector compare() that participates in
/// overload resolution only when Kind is TypeKind::OPAQUE (selected through
/// the enable_if_t template parameter below).
///
/// Comparing opaque values is not supported, so all arguments are ignored
/// (their names are commented out) and the call is rejected via
/// VELOX_UNSUPPORTED. The typeProvidesCustomComparison template argument is
/// not consulted by this overload.
// NOTE(review): VELOX_UNSUPPORTED presumably throws, which is why no return
// statement follows it - confirm against the macro definition.
template <
bool typeProvidesCustomComparison,
TypeKind Kind,
std::enable_if_t<Kind == TypeKind::OPAQUE, int32_t> = 0>
inline int compare(
const char* /*row*/,
RowColumn /*column*/,
const DecodedVector& /*decoded*/,
vector_size_t /*index*/,
CompareFlags /*flags*/) {
VELOX_UNSUPPORTED("Comparing Opaque types is not supported.");
}

template <
bool typeProvidesCustomComparison,
TypeKind Kind,
std::enable_if_t<Kind != TypeKind::OPAQUE, int32_t> = 0>
inline int compare(
const char* left,
const char* right,
Expand Down Expand Up @@ -1187,22 +1214,43 @@ class RowContainer {
} else {
auto leftValue = valueAt<T>(left, leftOffset);
auto rightValue = valueAt<T>(right, rightOffset);
auto result = type->providesCustomComparison()
? SimpleVector<T>::template comparePrimitiveAscWithCustomComparison<
Kind>(type, leftValue, rightValue)
: SimpleVector<T>::comparePrimitiveAsc(leftValue, rightValue);

int result;
if constexpr (typeProvidesCustomComparison) {
result =
SimpleVector<T>::template comparePrimitiveAscWithCustomComparison<
Kind>(type, leftValue, rightValue);
} else {
result = SimpleVector<T>::comparePrimitiveAsc(leftValue, rightValue);
}

return flags.ascending ? result : result * -1;
}
}

template <TypeKind Kind>
/// Overload of the row-versus-row compare() (with separate left and right
/// columns) that participates in overload resolution only when Kind is
/// TypeKind::OPAQUE (selected through the enable_if_t template parameter
/// below).
///
/// Comparing opaque values is not supported, so all arguments are ignored
/// (their names are commented out) and the call is rejected via
/// VELOX_UNSUPPORTED. The typeProvidesCustomComparison template argument is
/// not consulted by this overload.
// NOTE(review): VELOX_UNSUPPORTED presumably throws, which is why no return
// statement follows it - confirm against the macro definition.
template <
bool typeProvidesCustomComparison,
TypeKind Kind,
std::enable_if_t<Kind == TypeKind::OPAQUE, int32_t> = 0>
inline int compare(
const char* /*left*/,
const char* /*right*/,
const Type* /*type*/,
RowColumn /*leftColumn*/,
RowColumn /*rightColumn*/,
CompareFlags /*flags*/) {
VELOX_UNSUPPORTED("Comparing Opaque types is not supported.");
}

template <bool typeProvidesCustomComparison, TypeKind Kind>
inline int compare(
const char* left,
const char* right,
const Type* type,
RowColumn column,
CompareFlags flags) {
return compare<Kind>(left, right, type, column, column, flags);
return compare<typeProvidesCustomComparison, Kind>(
left, right, type, column, column, flags);
}

void storeComplexType(
Expand Down Expand Up @@ -1589,11 +1637,12 @@ inline bool RowContainer::equals(
}

if constexpr (!mayHaveNulls) {
return VELOX_DYNAMIC_TYPE_DISPATCH(
equalsNoNulls, typeKind, row, column.offset(), decoded, index);
return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH(
equalsNoNulls, false, typeKind, row, column.offset(), decoded, index);
} else {
return VELOX_DYNAMIC_TYPE_DISPATCH(
return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH(
equalsWithNulls,
false,
typeKind,
row,
column.offset(),
Expand All @@ -1604,35 +1653,33 @@ inline bool RowContainer::equals(
}
}

template <>
inline int RowContainer::compare<TypeKind::OPAQUE>(
const char* /*row*/,
RowColumn /*column*/,
const DecodedVector& /*decoded*/,
vector_size_t /*index*/,
CompareFlags /*flags*/) {
VELOX_UNSUPPORTED("Comparing Opaque types is not supported.");
}

template <>
inline int RowContainer::compare<TypeKind::OPAQUE>(
const char* /*left*/,
const char* /*right*/,
const Type* /*type*/,
RowColumn /*leftColumn*/,
RowColumn /*rightColumn*/,
CompareFlags /*flags*/) {
VELOX_UNSUPPORTED("Comparing Opaque types is not supported.");
}

inline int RowContainer::compare(
const char* row,
RowColumn column,
const DecodedVector& decoded,
vector_size_t index,
CompareFlags flags) {
return VELOX_DYNAMIC_TYPE_DISPATCH_ALL(
compare, decoded.base()->typeKind(), row, column, decoded, index, flags);
if (decoded.base()->typeUsesCustomComparison()) {
return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL(
compare,
true,
decoded.base()->typeKind(),
row,
column,
decoded,
index,
flags);
} else {
return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL(
compare,
false,
decoded.base()->typeKind(),
row,
column,
decoded,
index,
flags);
}
}

inline int RowContainer::compare(
Expand All @@ -1641,8 +1688,27 @@ inline int RowContainer::compare(
int columnIndex,
CompareFlags flags) {
auto type = types_[columnIndex].get();
return VELOX_DYNAMIC_TYPE_DISPATCH_ALL(
compare, type->kind(), left, right, type, columnAt(columnIndex), flags);
if (type->providesCustomComparison()) {
return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL(
compare,
true,
type->kind(),
left,
right,
type,
columnAt(columnIndex),
flags);
} else {
return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL(
compare,
false,
type->kind(),
left,
right,
type,
columnAt(columnIndex),
flags);
}
}

inline int RowContainer::compare(
Expand All @@ -1654,15 +1720,30 @@ inline int RowContainer::compare(
auto leftType = types_[leftColumnIndex].get();
auto rightType = types_[rightColumnIndex].get();
VELOX_CHECK(leftType->equivalent(*rightType));
return VELOX_DYNAMIC_TYPE_DISPATCH_ALL(
compare,
leftType->kind(),
left,
right,
leftType,
columnAt(leftColumnIndex),
columnAt(rightColumnIndex),
flags);

if (leftType->providesCustomComparison()) {
return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL(
compare,
true,
leftType->kind(),
left,
right,
leftType,
columnAt(leftColumnIndex),
columnAt(rightColumnIndex),
flags);
} else {
return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL(
compare,
false,
leftType->kind(),
left,
right,
leftType,
columnAt(leftColumnIndex),
columnAt(rightColumnIndex),
flags);
}
}

/// A comparator of rows stored in the RowContainer compatible with
Expand Down
Loading

0 comments on commit 367faf8

Please sign in to comment.