Skip to content

Commit

Permalink
Mark UNKNOWN type as orderable and comparable (facebookincubator#10213)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#10213

UNKNOWN type is a scalar type that supports only NULL values.
Like any other scalar type, UNKNOWN is comparable and orderable.

This change affects functions that use orderableTypeVariable in their
signatures: min, max, min_by, max_by, and array_sort.

Add tests for UNKNOWN inputs in these functions.

Reviewed By: amitkdutta

Differential Revision: D58635074

fbshipit-source-id: 0ad733f1ef2d3ccafc5d596db8b7b01febd443c6
  • Loading branch information
mbasmanova authored and facebook-github-bot committed Jun 16, 2024
1 parent f2ce4a5 commit bca2f1c
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 6 deletions.
6 changes: 6 additions & 0 deletions velox/functions/lib/aggregates/MinMaxByAggregatesBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,9 @@ std::unique_ptr<exec::Aggregate> create(
case TypeKind::ROW:
return std::make_unique<Aggregate<W, ComplexType, isMaxFunc, Comparator>>(
resultType, throwOnNestedNulls);
case TypeKind::UNKNOWN:
return std::make_unique<
Aggregate<W, UnknownValue, isMaxFunc, Comparator>>(resultType);
default:
VELOX_FAIL("{}", errorMessage);
return nullptr;
Expand Down Expand Up @@ -662,6 +665,9 @@ std::unique_ptr<exec::Aggregate> create(
case TypeKind::ROW:
return create<Aggregate, isMaxFunc, Comparator, ComplexType>(
resultType, compareType, errorMessage, throwOnNestedNulls);
case TypeKind::UNKNOWN:
return create<Aggregate, isMaxFunc, Comparator, UnknownValue>(
resultType, compareType, errorMessage, throwOnNestedNulls);
default:
VELOX_FAIL(errorMessage);
}
Expand Down
18 changes: 13 additions & 5 deletions velox/functions/prestosql/ArraySort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ void applyScalarType(
}

// See documentation at https://prestodb.io/docs/current/functions/array.html
template <TypeKind T>
template <TypeKind Kind>
class ArraySortFunction : public exec::VectorFunction {
public:
/// This class implements the array_sort query function. Takes an array as
Expand Down Expand Up @@ -237,7 +237,10 @@ class ArraySortFunction : public exec::VectorFunction {
VectorPtr localResult;

// Input can be constant or flat.
if (arg->isConstantEncoding()) {
if constexpr (Kind == TypeKind::UNKNOWN) {
// All elements are NULL. Hence, sorting doesn't change anything.
localResult = arg;
} else if (arg->isConstantEncoding()) {
auto* constantArray = arg->as<ConstantVector<ComplexType>>();
const auto& flatArray = constantArray->valueVector();
const auto flatIndex = constantArray->index();
Expand All @@ -262,10 +265,10 @@ class ArraySortFunction : public exec::VectorFunction {
auto inputArray = arg->as<ArrayVector>();
VectorPtr resultElements;

if (velox::TypeTraits<T>::isPrimitiveType) {
if constexpr (velox::TypeTraits<Kind>::isPrimitiveType) {
VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
applyScalarType,
T,
Kind,
rows,
inputArray,
ascending_,
Expand Down Expand Up @@ -399,7 +402,12 @@ std::shared_ptr<exec::VectorFunction> create(
ascending, throwOnNestedNull);
}

auto elementType = inputArgs.front().type->childAt(0);
const auto elementType = inputArgs.front().type->childAt(0);
if (elementType->isUnKnown()) {
return createTyped<TypeKind::UNKNOWN>(
inputArgs, ascending, throwOnNestedNull);
}

return VELOX_DYNAMIC_TYPE_DISPATCH(
createTyped,
elementType->kind(),
Expand Down
2 changes: 2 additions & 0 deletions velox/functions/prestosql/aggregates/MinMaxAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,8 @@ exec::AggregateRegistrationResult registerMinMax(
case TypeKind::ROW:
return std::make_unique<TNonNumeric>(
inputType, throwOnNestedNulls);
case TypeKind::UNKNOWN:
return std::make_unique<TNumeric<UnknownValue>>(resultType);
default:
VELOX_UNREACHABLE(
"Unknown input type for {} aggregation {}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,66 @@ TEST_F(MinMaxByComplexTypes, failOnUnorderableType) {
}
}

// Fixture for min_by / max_by tests over UNKNOWN-typed inputs. It adds
// nothing on top of AggregationTestBase; it only gives these tests their own
// suite name.
class MinMaxByUnknownTest : public AggregationTestBase {};

// Exercises min_by and max_by with UNKNOWN-typed (all-NULL) arguments in the
// value position, the comparison position, or both. Each combination runs as
// a global aggregation and as a group-by on 'k'; every aggregate is expected
// to come back NULL.
TEST_F(MinMaxByUnknownTest, unknown) {
  // k: grouping key; vn / cn: UNKNOWN value / comparison columns;
  // v / c: BIGINT value / comparison columns.
  auto keyColumn = makeFlatVector<int64_t>({1, 2, 1, 2, 1, 2});
  auto unknownValue = makeAllNullFlatVector<UnknownValue>(6);
  auto unknownCompare = makeAllNullFlatVector<UnknownValue>(6);
  auto bigintValue = makeFlatVector<int64_t>({1, 2, 3, 4, 5, 6});
  auto bigintCompare = makeFlatVector<int64_t>({1, 2, 3, 4, 5, 6});
  auto input = makeRowVector(
      {"k", "vn", "cn", "v", "c"},
      {keyColumn, unknownValue, unknownCompare, bigintValue, bigintCompare});

  // ---- Global aggregation: a single output row. ----

  // Result type follows the value argument, so an UNKNOWN value yields an
  // UNKNOWN result regardless of the comparison column's type.
  auto globalUnknown = makeRowVector({
      makeAllNullFlatVector<UnknownValue>(1),
      makeAllNullFlatVector<UnknownValue>(1),
  });

  // UNKNOWN value, UNKNOWN comparison.
  testAggregations(
      {input}, {}, {"min_by(vn, cn)", "max_by(vn, cn)"}, {globalUnknown});

  // UNKNOWN value, BIGINT comparison.
  testAggregations(
      {input}, {}, {"min_by(vn, c)", "max_by(vn, c)"}, {globalUnknown});

  // BIGINT value, UNKNOWN comparison: the result is a NULL BIGINT.
  auto globalBigint = makeRowVector({
      makeAllNullFlatVector<int64_t>(1),
      makeAllNullFlatVector<int64_t>(1),
  });

  testAggregations(
      {input}, {}, {"min_by(v, cn)", "max_by(v, cn)"}, {globalBigint});

  // ---- Group by 'k': one output row per distinct key (1 and 2). ----

  auto groupedUnknown = makeRowVector({
      makeFlatVector<int64_t>({1, 2}),
      makeAllNullFlatVector<UnknownValue>(2),
      makeAllNullFlatVector<UnknownValue>(2),
  });

  // UNKNOWN value, UNKNOWN comparison.
  testAggregations(
      {input}, {"k"}, {"min_by(vn, cn)", "max_by(vn, cn)"}, {groupedUnknown});

  // UNKNOWN value, BIGINT comparison.
  testAggregations(
      {input}, {"k"}, {"min_by(vn, c)", "max_by(vn, c)"}, {groupedUnknown});

  // BIGINT value, UNKNOWN comparison.
  auto groupedBigint = makeRowVector({
      makeFlatVector<int64_t>({1, 2}),
      makeAllNullFlatVector<int64_t>(2),
      makeAllNullFlatVector<int64_t>(2),
  });

  testAggregations(
      {input}, {"k"}, {"min_by(v, cn)", "max_by(v, cn)"}, {groupedBigint});
}

class MinMaxByNTest : public AggregationTestBase {
protected:
void SetUp() override {
Expand Down
22 changes: 22 additions & 0 deletions velox/functions/prestosql/aggregates/tests/MinMaxTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,28 @@ TEST_F(MinMaxTest, minMaxDate) {
"SELECT c0 % 17, min(c1), max(c1) FROM tmp GROUP BY 1");
}

// Checks that min and max accept an UNKNOWN-typed (all-NULL) column and
// return NULL of UNKNOWN type, both globally and per group.
TEST_F(MinMaxTest, minMaxUnknown) {
  // c0: BIGINT grouping key with two distinct values; c1: UNKNOWN column.
  auto input = makeRowVector({
      makeFlatVector<int64_t>({1, 2, 1, 2, 1, 2}),
      makeAllNullFlatVector<UnknownValue>(6),
  });

  // Global aggregation: a single row holding NULL for both min and max.
  auto globalExpected = makeRowVector({
      makeAllNullFlatVector<UnknownValue>(1),
      makeAllNullFlatVector<UnknownValue>(1),
  });
  testAggregations({input}, {}, {"min(c1)", "max(c1)"}, {globalExpected});

  // Grouped by c0: one row per key, again with NULL aggregates.
  auto groupedExpected = makeRowVector({
      makeFlatVector<int64_t>({1, 2}),
      makeAllNullFlatVector<UnknownValue>(2),
      makeAllNullFlatVector<UnknownValue>(2),
  });
  testAggregations({input}, {"c0"}, {"min(c1)", "max(c1)"}, {groupedExpected});
}

TEST_F(MinMaxTest, initialValue) {
// Ensures that no groups are default initialized (to 0) in
// aggregate::SimpleNumericAggregate.
Expand Down
19 changes: 19 additions & 0 deletions velox/functions/prestosql/tests/ArraySortTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"

#include <fmt/format.h>
#include <cstdint>

using namespace facebook::velox;
using namespace facebook::velox::test;
Expand Down Expand Up @@ -373,6 +374,24 @@ TEST_P(ArraySortTest, basic) {
runTest(GetParam());
}

// Covers array_sort when the sort keys are of UNKNOWN type: once because the
// array elements themselves are UNKNOWN, and once because the key-producing
// lambda returns NULL for every element. In both cases the output is expected
// to equal the input.
TEST_F(ArraySortTest, unknown) {
  // Arrays made up solely of NULL elements of UNKNOWN type.
  auto unknownArrays = makeNullableArrayVector<UnknownValue>({
      {std::nullopt, std::nullopt},
      {std::nullopt, std::nullopt, std::nullopt},
  });

  auto sortedUnknown =
      evaluate("array_sort(c0)", makeRowVector({unknownArrays}));
  assertEqualVectors(unknownArrays, sortedUnknown);

  // Integer arrays sorted by a lambda whose result is always NULL.
  auto intArrays = makeArrayVectorFromJson<int32_t>({
      "[1, 2, 3]",
      "[1, 2]",
  });

  auto sortedInts =
      evaluate("array_sort(c0, x -> null)", makeRowVector({intArrays}));
  assertEqualVectors(intArrays, sortedInts);
}

TEST_F(ArraySortTest, constant) {
vector_size_t size = 1'000;
auto data =
Expand Down
8 changes: 8 additions & 0 deletions velox/type/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,14 @@ class UnknownType : public TypeBase<TypeKind::UNKNOWN> {
return 0;
}

/// Reports UNKNOWN as orderable. UNKNOWN is a scalar type that supports only
/// NULL values and, like any other scalar type, is treated as orderable. Per
/// the commit summary, this affects functions whose signatures use
/// orderableTypeVariable: min, max, min_by, max_by and array_sort.
bool isOrderable() const override {
return true;
}

/// Reports UNKNOWN as comparable. UNKNOWN is a scalar type that supports only
/// NULL values and, like any other scalar type, is treated as comparable.
bool isComparable() const override {
return true;
}

bool equivalent(const Type& other) const override {
return Type::hasSameTypeId(other);
}
Expand Down
16 changes: 15 additions & 1 deletion velox/type/tests/TypeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,20 @@ TEST(TypeTest, intervalYearMonth) {
testTypeSerde(interval);
}

// Verifies the basic properties of the UNKNOWN scalar type: its name and
// kind, that it has no children, that it is comparable and orderable, and
// that it survives a serialization round trip.
TEST(TypeTest, unknown) {
  auto unknownType = UNKNOWN();

  // Name and kind.
  EXPECT_EQ(unknownType->toString(), "UNKNOWN");

  // No children: zero size, child access throws, empty iteration range.
  EXPECT_EQ(unknownType->size(), 0);
  EXPECT_THROW(unknownType->childAt(0), std::invalid_argument);

  EXPECT_EQ(unknownType->kind(), TypeKind::UNKNOWN);
  EXPECT_STREQ(unknownType->kindName(), "UNKNOWN");
  EXPECT_EQ(unknownType->begin(), unknownType->end());

  // UNKNOWN takes part in comparisons and ordering like other scalar types.
  EXPECT_TRUE(unknownType->isComparable());
  EXPECT_TRUE(unknownType->isOrderable());

  testTypeSerde(unknownType);
}

TEST(TypeTest, shortDecimal) {
auto shortDecimal = DECIMAL(10, 5);
EXPECT_EQ(shortDecimal->toString(), "DECIMAL(10, 5)");
Expand Down Expand Up @@ -820,7 +834,7 @@ TEST(TypeTest, follySformat) {
"{}", ROW({{"a", BOOLEAN()}, {"b", VARCHAR()}, {"c", BIGINT()}})));
}

TEST(TypeTest, unknown) {
TEST(TypeTest, unknownArray) {
auto unknownArray = ARRAY(UNKNOWN());
EXPECT_TRUE(unknownArray->containsUnknown());

Expand Down

0 comments on commit bca2f1c

Please sign in to comment.