From 367faf8e7492927dfd6c4d2891b2dcd9309da338 Mon Sep 17 00:00:00 2001 From: Kevin Wilfong Date: Wed, 25 Sep 2024 11:14:45 -0700 Subject: [PATCH] Support custom comparison in RowContainer (#11024) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/11024 Building on https://github.com/facebookincubator/velox/pull/11021 this adds support for custom comparison functions provided by custom types in RowContainer. Again, compare was already handled by updating SimpleVector, so this just updates the hash function and adds tests. Reviewed By: pansatadru Differential Revision: D62905323 fbshipit-source-id: e839f615178f6aa0cafab3c1dd48b453ba4c5d63 --- velox/exec/RowContainer.cpp | 41 ++++-- velox/exec/RowContainer.h | 205 ++++++++++++++++++-------- velox/exec/tests/RowContainerTest.cpp | 121 +++++++++++++++ velox/type/Type.h | 15 ++ 4 files changed, 310 insertions(+), 72 deletions(-) diff --git a/velox/exec/RowContainer.cpp b/velox/exec/RowContainer.cpp index a158c31cdbee..8566f5fb1212 100644 --- a/velox/exec/RowContainer.cpp +++ b/velox/exec/RowContainer.cpp @@ -868,7 +868,7 @@ int32_t RowContainer::compareComplexType( return compareComplexType(left, right, type, offset, offset, flags); } -template +template void RowContainer::hashTyped( const Type* type, RowColumn column, @@ -881,6 +881,7 @@ void RowContainer::hashTyped( auto offset = column.offset(); std::string storage; auto numRows = rows.size(); + for (int32_t i = 0; i < numRows; ++i) { char* row = rows[i]; if (nullable && isNullAt(row, column)) { @@ -897,6 +898,9 @@ void RowContainer::hashTyped( Kind == TypeKind::MAP) { auto in = prepareRead(row, offset); hash = ContainerRowSerde::hash(*in, type); + } else if constexpr (typeProvidesCustomComparison) { + hash = static_cast*>(type) + ->hash(valueAt(row, offset)); } else if constexpr (std::is_floating_point_v) { hash = util::floating_point::NaNAwareHash()(valueAt(row, offset)); } else { @@ -921,15 +925,32 @@ void RowContainer::hash( } bool nullable = column >= keyTypes_.size() || nullableKeys_; - VELOX_DYNAMIC_TYPE_DISPATCH( - hashTyped, - typeKinds_[column], - types_[column].get(), - columnAt(column), - nullable, - rows, - mix, - result); + + const auto& type = types_[column]; + + if (type->providesCustomComparison()) { + VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( + hashTyped, + true, + typeKinds_[column], + type.get(), + columnAt(column), + nullable, + rows, + mix, + result); + } else { + VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( + hashTyped, + false, + typeKinds_[column], + type.get(), + columnAt(column), + nullable, + rows, + mix, + result); + } } void RowContainer::clear() { diff --git a/velox/exec/RowContainer.h b/velox/exec/RowContainer.h index ca5bceda9af1..39ae2aa568c4 100644 --- a/velox/exec/RowContainer.h +++ b/velox/exec/RowContainer.h @@ -1062,7 +1062,7 @@ class RowContainer { const char* row, int32_t offset); - template + template void hashTyped( const Type* type, RowColumn column, @@ -1071,7 +1071,7 @@ class RowContainer { bool mix, uint64_t* result); - template + template inline bool equalsWithNulls( const char* row, int32_t offset, @@ -1085,15 +1085,18 @@ class RowContainer { return rowIsNull == indexIsNull; } - return equalsNoNulls(row, offset, decoded, index); + return equalsNoNulls( + row, offset, decoded, index); } - template + template inline bool equalsNoNulls( const char* row, int32_t offset, const DecodedVector& decoded, vector_size_t index) { + using T = typename KindToFlatVector::HashRowType; + if constexpr ( Kind == TypeKind::ROW || Kind == TypeKind::ARRAY || Kind == TypeKind::MAP) { @@ -1102,20 +1105,22 @@ class RowContainer { Kind == TypeKind::VARCHAR || Kind == TypeKind::VARBINARY) { return compareStringAsc( valueAt(row, offset), decoded, index) == 0; + } else if constexpr (typeProvidesCustomComparison) { + return SimpleVector::template comparePrimitiveAscWithCustomComparison< + Kind>( + decoded.base()->type().get(), + decoded.valueAt(index), + valueAt(row, offset)) == 0; } else { - using T = typename KindToFlatVector::HashRowType; - return decoded.base()->typeUsesCustomComparison() - ? SimpleVector::template comparePrimitiveAscWithCustomComparison< - Kind>( - decoded.base()->type().get(), - decoded.valueAt(index), - valueAt(row, offset)) - : SimpleVector::comparePrimitiveAsc( - decoded.valueAt(index), valueAt(row, offset)) == 0; + return SimpleVector::comparePrimitiveAsc( + decoded.valueAt(index), valueAt(row, offset)) == 0; } } - template + template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> inline int compare( const char* row, RowColumn column, @@ -1143,15 +1148,37 @@ class RowContainer { } else { auto left = valueAt(row, column.offset()); auto right = decoded.valueAt(index); - auto result = decoded.base()->typeUsesCustomComparison() - ? SimpleVector::template comparePrimitiveAscWithCustomComparison< - Kind>(decoded.base()->type().get(), left, right) - : SimpleVector::comparePrimitiveAsc(left, right); + + int result; + if constexpr (typeProvidesCustomComparison) { + result = + SimpleVector::template comparePrimitiveAscWithCustomComparison< + Kind>(decoded.base()->type().get(), left, right); + } else { + result = SimpleVector::comparePrimitiveAsc(left, right); + } + return flags.ascending ? result : result * -1; } } - template + template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> + inline int compare( + const char* /*row*/, + RowColumn /*column*/, + const DecodedVector& /*decoded*/, + vector_size_t /*index*/, + CompareFlags /*flags*/) { + VELOX_UNSUPPORTED("Comparing Opaque types is not supported."); + } + + template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> inline int compare( const char* left, const char* right, @@ -1187,22 +1214,43 @@ class RowContainer { } else { auto leftValue = valueAt(left, leftOffset); auto rightValue = valueAt(right, rightOffset); - auto result = type->providesCustomComparison() - ? SimpleVector::template comparePrimitiveAscWithCustomComparison< - Kind>(type, leftValue, rightValue) - : SimpleVector::comparePrimitiveAsc(leftValue, rightValue); + + int result; + if constexpr (typeProvidesCustomComparison) { + result = + SimpleVector::template comparePrimitiveAscWithCustomComparison< + Kind>(type, leftValue, rightValue); + } else { + result = SimpleVector::comparePrimitiveAsc(leftValue, rightValue); + } + return flags.ascending ? result : result * -1; } } - template + template < + bool typeProvidesCustomComparison, + TypeKind Kind, + std::enable_if_t = 0> + inline int compare( + const char* /*left*/, + const char* /*right*/, + const Type* /*type*/, + RowColumn /*leftColumn*/, + RowColumn /*rightColumn*/, + CompareFlags /*flags*/) { + VELOX_UNSUPPORTED("Comparing Opaque types is not supported."); + } + + template inline int compare( const char* left, const char* right, const Type* type, RowColumn column, CompareFlags flags) { - return compare(left, right, type, column, column, flags); + return compare( + left, right, type, column, column, flags); } void storeComplexType( @@ -1589,11 +1637,12 @@ inline bool RowContainer::equals( } if constexpr (!mayHaveNulls) { - return VELOX_DYNAMIC_TYPE_DISPATCH( - equalsNoNulls, typeKind, row, column.offset(), decoded, index); + return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( + equalsNoNulls, false, typeKind, row, column.offset(), decoded, index); } else { - return VELOX_DYNAMIC_TYPE_DISPATCH( + return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( equalsWithNulls, + false, typeKind, row, column.offset(), @@ -1604,35 +1653,33 @@ inline bool RowContainer::equals( } } -template <> -inline int RowContainer::compare( - const char* /*row*/, - RowColumn /*column*/, - const DecodedVector& /*decoded*/, - vector_size_t /*index*/, - CompareFlags /*flags*/) { - VELOX_UNSUPPORTED("Comparing Opaque types is not supported."); -} - -template <> -inline int RowContainer::compare( - const char* /*left*/, - const char* /*right*/, - const Type* /*type*/, - RowColumn /*leftColumn*/, - RowColumn /*rightColumn*/, - CompareFlags /*flags*/) { - VELOX_UNSUPPORTED("Comparing Opaque types is not supported."); -} - inline int RowContainer::compare( const char* row, RowColumn column, const DecodedVector& decoded, vector_size_t index, CompareFlags flags) { - return VELOX_DYNAMIC_TYPE_DISPATCH_ALL( - compare, decoded.base()->typeKind(), row, column, decoded, index, flags); + if (decoded.base()->typeUsesCustomComparison()) { + return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL( + compare, + true, + decoded.base()->typeKind(), + row, + column, + decoded, + index, + flags); + } else { + return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL( + compare, + false, + decoded.base()->typeKind(), + row, + column, + decoded, + index, + flags); + } } inline int RowContainer::compare( @@ -1641,8 +1688,27 @@ inline int RowContainer::compare( int columnIndex, CompareFlags flags) { auto type = types_[columnIndex].get(); - return VELOX_DYNAMIC_TYPE_DISPATCH_ALL( - compare, type->kind(), left, right, type, columnAt(columnIndex), flags); + if (type->providesCustomComparison()) { + return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL( + compare, + true, + type->kind(), + left, + right, + type, + columnAt(columnIndex), + flags); + } else { + return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL( + compare, + false, + type->kind(), + left, + right, + type, + columnAt(columnIndex), + flags); + } } inline int RowContainer::compare( @@ -1654,15 +1720,30 @@ inline int RowContainer::compare( auto leftType = types_[leftColumnIndex].get(); auto rightType = types_[rightColumnIndex].get(); VELOX_CHECK(leftType->equivalent(*rightType)); - return VELOX_DYNAMIC_TYPE_DISPATCH_ALL( - compare, - leftType->kind(), - left, - right, - leftType, - columnAt(leftColumnIndex), - columnAt(rightColumnIndex), - flags); + + if (leftType->providesCustomComparison()) { + return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL( + compare, + true, + leftType->kind(), + left, + right, + leftType, + columnAt(leftColumnIndex), + columnAt(rightColumnIndex), + flags); + } else { + return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL( + compare, + false, + leftType->kind(), + left, + right, + leftType, + columnAt(leftColumnIndex), + columnAt(rightColumnIndex), + flags); + } } /// A comparator of rows stored in the RowContainer compatible with diff --git a/velox/exec/tests/RowContainerTest.cpp b/velox/exec/tests/RowContainerTest.cpp index 129efad3411e..eb510c3c9ba2 100644 --- a/velox/exec/tests/RowContainerTest.cpp +++ b/velox/exec/tests/RowContainerTest.cpp @@ -19,6 +19,7 @@ #include "velox/exec/VectorHasher.h" #include "velox/exec/tests/utils/RowContainerTestBase.h" #include "velox/expression/VectorReaders.h" +#include "velox/type/tests/utils/CustomTypesForTesting.h" #include "velox/vector/fuzzer/VectorFuzzer.h" using namespace facebook::velox; @@ -2169,3 +2170,123 @@ TEST_F(RowContainerTest, store) { } } } + +TEST_F(RowContainerTest, customComparison) { + auto values = makeNullableFlatVector( + {std::nullopt, + 256 * 4 + 3, + 256 * 2 + 4, + 256 + 2, + 5, + 256 * 5, + 256 * 2 + 1}, + BIGINT_TYPE_WITH_CUSTOM_COMPARISON()); + + // The custom comparison compares the values mod 256. + std::vector> ascNullsFirstOrder = { + std::nullopt, 256 * 5, 256 * 2 + 1, 256 + 2, 256 * 4 + 3, 256 * 2 + 4, 5}; + + testOrderAndEqualsWithNullsFirstVariations( + BIGINT_TYPE_WITH_CUSTOM_COMPARISON(), + values, + ascNullsFirstOrder, + [&](const auto& expectedOrder) { + return makeNullableFlatVector( + expectedOrder, BIGINT_TYPE_WITH_CUSTOM_COMPARISON()); + }); +} + +TEST_F(RowContainerTest, customComparisonArray) { + auto values = makeNullableArrayVector( + {std::nullopt, + {{std::nullopt}}, + {{256 * 4 + 3}}, + {{256 * 2 + 4}}, + {{256 + 2}}, + {{5}}, + {{256 * 5}}, + {{256 * 2 + 1}}}, + ARRAY(BIGINT_TYPE_WITH_CUSTOM_COMPARISON())); + + // The custom comparison compares the values mod 256. + std::vector>>> + ascNullsFirstOrder = { + std::nullopt, + {{std::nullopt}}, + {{256 * 5}}, + {{256 * 2 + 1}}, + {{256 + 2}}, + {{256 * 4 + 3}}, + {{256 * 2 + 4}}, + {{5}}}; + + testOrderAndEqualsWithNullsFirstVariations< + std::vector>>( + ARRAY(BIGINT_TYPE_WITH_CUSTOM_COMPARISON()), + values, + ascNullsFirstOrder, + [&](const auto& expectedOrder) { + return makeNullableArrayVector( + expectedOrder, ARRAY(BIGINT_TYPE_WITH_CUSTOM_COMPARISON())); + }); +} + +TEST_F(RowContainerTest, customComparisonMap) { + auto values = makeNullableMapVector( + {{{std::nullopt}}, + {{{256 * 4 + 3, 1}}}, + {{{256 * 2 + 4, 2}}}, + {{{256 + 2, 3}}}, + {{{5, 4}}}, + {{{256 * 5, 5}}}, + {{{256 * 2 + 1, 6}}}}, + MAP(BIGINT_TYPE_WITH_CUSTOM_COMPARISON(), INTEGER())); + + // The custom comparison compares the values mod 256. + std::vector< + std::optional>>>> + ascNullsFirstOrder = { + {{std::nullopt}}, + {{{256 * 5, 5}}}, + {{{256 * 2 + 1, 6}}}, + {{{256 + 2, 3}}}, + {{{256 * 4 + 3, 1}}}, + {{{256 * 2 + 4, 2}}}, + {{{5, 4}}}}; + + testOrderAndEqualsWithNullsFirstVariations< + std::vector>>>( + MAP(BIGINT_TYPE_WITH_CUSTOM_COMPARISON(), INTEGER()), + values, + ascNullsFirstOrder, + [&](const auto& expectedOrder) { + return makeNullableMapVector( + expectedOrder, + MAP(BIGINT_TYPE_WITH_CUSTOM_COMPARISON(), INTEGER())); + }); +} + +TEST_F(RowContainerTest, customComparisonRow) { + auto values = makeRowVector({makeNullableFlatVector( + {std::nullopt, + 256 * 4 + 3, + 256 * 2 + 4, + 256 + 2, + 5, + 256 * 5, + 256 * 2 + 1}, + BIGINT_TYPE_WITH_CUSTOM_COMPARISON())}); + + // The custom comparison compares the values mod 256. + std::vector> ascNullsFirstOrder = { + std::nullopt, 256 * 5, 256 * 2 + 1, 256 + 2, 256 * 4 + 3, 256 * 2 + 4, 5}; + + testOrderAndEqualsWithNullsFirstVariations( + ROW({BIGINT_TYPE_WITH_CUSTOM_COMPARISON()}), + values, + ascNullsFirstOrder, + [&](const auto& expectedOrder) { + return makeRowVector({makeNullableFlatVector( + expectedOrder, BIGINT_TYPE_WITH_CUSTOM_COMPARISON())}); + }); +} diff --git a/velox/type/Type.h b/velox/type/Type.h index 535bdd1e8ce0..891130d26227 100644 --- a/velox/type/Type.h +++ b/velox/type/Type.h @@ -1612,6 +1612,21 @@ std::shared_ptr OPAQUE() { } \ }() +#define VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL( \ + TEMPLATE_FUNC, T, typeKind, ...) \ + [&]() { \ + if ((typeKind) == ::facebook::velox::TypeKind::UNKNOWN) { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } else if ((typeKind) == ::facebook::velox::TypeKind::OPAQUE) { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } else { \ + return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( \ + TEMPLATE_FUNC, T, typeKind, __VA_ARGS__); \ + } \ + }() + #define VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH_ALL(TEMPLATE_FUNC, typeKind, ...) \ [&]() { \ if ((typeKind) == ::facebook::velox::TypeKind::UNKNOWN) { \