Skip to content

Commit

Permalink
Add more tests & fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE committed Apr 9, 2024
1 parent e3dcb05 commit a36c8a1
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 106 deletions.
8 changes: 4 additions & 4 deletions velox/functions/sparksql/String.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ void doApply(
int32_t i = 0;
// For string arg.
int32_t j = 0;
for (auto& arg : args) {
if (arg->typeKind() == TypeKind::ARRAY) {
for (auto it = args.begin() + 1; it != args.end(); ++it) {
if ((*it)->typeKind() == TypeKind::ARRAY) {
auto size = rawSizesVector[i][indicesVector[i][row]];
auto offset = rawOffsetsVector[i][indicesVector[i][row]];
for (int k = 0; k < size; ++k) {
Expand Down Expand Up @@ -361,10 +361,9 @@ std::vector<std::shared_ptr<exec::FunctionSignature>> concatWsSignatures() {
// The argument type will be checked in makeConcatWs.
// varchar, [varchar], [array(varchar)], ... -> varchar.
exec::FunctionSignatureBuilder()
.typeVariable("T")
.returnType("varchar")
.constantArgumentType("varchar")
.argumentType("T")
.argumentType("any")
.variableArity()
.build()};
}
Expand All @@ -380,6 +379,7 @@ std::shared_ptr<exec::VectorFunction> makeConcatWs(
numArgs);
for (auto& arg : inputArgs) {
VELOX_USER_CHECK(
// TODO: check array's element type.
arg.type->isVarchar() || arg.type->isArray(),
"concat_ws requires varchar or array(varchar) arguments, but got {}.",
arg.type->toString());
Expand Down
226 changes: 125 additions & 101 deletions velox/functions/sparksql/tests/ConcatWsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,19 @@ class ConcatWsTest : public SparkFunctionBaseTest {
}
};

// Test concat_ws vector function
TEST_F(ConcatWsTest, columnStringArgs) {
// Test constant input vector with 2 args
{
auto rows = makeRowVector(makeRowType({VARCHAR(), VARCHAR()}), 10);
auto c0 = generateRandomString(20);
auto c1 = generateRandomString(20);
auto result = evaluate<SimpleVector<StringView>>(
fmt::format("concat_ws('-', '{}', '{}')", c0, c1), rows);
for (int i = 0; i < 10; ++i) {
EXPECT_EQ(result->valueAt(i), c0 + "-" + c1);
}
auto rows = makeRowVector(makeRowType({VARCHAR(), VARCHAR()}), 10);
auto c0 = generateRandomString(20);
auto c1 = generateRandomString(20);
auto result = evaluate<SimpleVector<StringView>>(
fmt::format("concat_ws('-', '{}', '{}')", c0, c1), rows);
for (int i = 0; i < 10; ++i) {
EXPECT_EQ(result->valueAt(i), c0 + "-" + c1);
}

// test concat_ws variable arguments
size_t maxArgsCount = 10; // cols
// Test concat_ws variable arguments
size_t maxArgsCount = 10;
size_t rowCount = 100;
size_t maxStringLength = 100;

Expand All @@ -127,99 +124,126 @@ TEST_F(ConcatWsTest, columnStringArgs) {
}
}

TEST_F(ConcatWsTest, constantStringArgs) {
// Multiple consecutive constant inputs.
{
size_t maxStringLength = 100;
std::string value;
auto data = makeRowVector({
makeFlatVector<StringView>(
1'000,
[&](auto /* row */) {
value = generateRandomString(
folly::Random::rand32() % maxStringLength);
return StringView(value);
}),
makeFlatVector<StringView>(
1'000,
[&](auto /* row */) {
value = generateRandomString(
folly::Random::rand32() % maxStringLength);
return StringView(value);
}),
});

auto c0 = data->childAt(0)->as<FlatVector<StringView>>()->rawValues();
auto c1 = data->childAt(1)->as<FlatVector<StringView>>()->rawValues();

auto result = evaluate<SimpleVector<StringView>>(
"concat_ws('--', c0, c1, 'foo', 'bar')", data);

auto expected = makeFlatVector<StringView>(1'000, [&](auto row) {
value = "";
const std::string& s0 = c0[row].str();
const std::string& s1 = c1[row].str();

if (s0.empty() && s1.empty()) {
value = "foo--bar";
} else if (!s0.empty() && !s1.empty()) {
value = s0 + "--" + s1 + "--foo--bar";
} else {
value = s0 + s1 + "--foo--bar";
}
return StringView(value);
});

velox::test::assertEqualVectors(expected, result);

result = evaluate<SimpleVector<StringView>>(
"concat_ws('$*@', 'aaa', '测试', c0, 'eee', 'ddd', c1, '\u82f9\u679c', 'fff')",
data);

expected = makeFlatVector<StringView>(1'000, [&](auto row) {
value = "";
std::string delim = "$*@";
const std::string& s0 =
c0[row].str().empty() ? c0[row].str() : delim + c0[row].str();
const std::string& s1 =
c1[row].str().empty() ? c1[row].str() : delim + c1[row].str();

value = "aaa" + delim + "测试" + s0 + delim + "eee" + delim + "ddd" + s1 +
delim + "\u82f9\u679c" + delim + "fff";
return StringView(value);
});
velox::test::assertEqualVectors(expected, result);
}
// Exercises concat_ws with two randomly filled varchar columns combined with
// constant string literals, first with consecutive trailing constants and then
// with non-ASCII constants interleaved between the columns.
TEST_F(ConcatWsTest, mixedConstantAndColumnStringArgs) {
  const int numRows = 1'000;
  const size_t maxStringLength = 100;

  // Scratch buffer that backs the StringView handed back by every generator
  // below; it is overwritten on each call.
  std::string scratch;

  // Produces a random string whose length is in [0, maxStringLength).
  auto randomString = [&](auto /* row */) {
    scratch = generateRandomString(folly::Random::rand32() % maxStringLength);
    return StringView(scratch);
  };

  auto data = makeRowVector({
      makeFlatVector<StringView>(numRows, randomString),
      makeFlatVector<StringView>(numRows, randomString),
  });

  auto col0 = data->childAt(0)->as<FlatVector<StringView>>()->rawValues();
  auto col1 = data->childAt(1)->as<FlatVector<StringView>>()->rawValues();

  // Test with consecutive constant inputs.
  auto consecutiveResult = evaluate<SimpleVector<StringView>>(
      "concat_ws('--', c0, c1, 'foo', 'bar')", data);
  auto consecutiveExpected =
      makeFlatVector<StringView>(numRows, [&](auto row) {
        const std::string first = col0[row].str();
        const std::string second = col1[row].str();

        if (first.empty() && second.empty()) {
          scratch = "foo--bar";
        } else {
          // The separator between the two columns is only expected when both
          // of them are non-empty.
          const bool bothPresent = !first.empty() && !second.empty();
          scratch = first + (bothPresent ? "--" : "") + second + "--foo--bar";
        }
        return StringView(scratch);
      });
  velox::test::assertEqualVectors(consecutiveExpected, consecutiveResult);

  // Test with non-ASCII characters.
  auto unicodeResult = evaluate<SimpleVector<StringView>>(
      "concat_ws('$*@', 'aaa', '测试', c0, 'eee', 'ddd', c1, '\u82f9\u679c', 'fff')",
      data);

  const std::string separator = "$*@";
  // Returns the column value preceded by the separator, or an empty string
  // when the column value itself is empty.
  auto withSeparator = [&](const StringView& columnValue) {
    const std::string text = columnValue.str();
    return text.empty() ? text : separator + text;
  };

  auto unicodeExpected = makeFlatVector<StringView>(numRows, [&](auto row) {
    scratch = "aaa";
    scratch += separator;
    scratch += "测试";
    scratch += withSeparator(col0[row]);
    scratch += separator;
    scratch += "eee";
    scratch += separator;
    scratch += "ddd";
    scratch += withSeparator(col1[row]);
    scratch += separator;
    scratch += "\u82f9\u679c";
    scratch += separator;
    scratch += "fff";
    return StringView(scratch);
  });
  velox::test::assertEqualVectors(unicodeExpected, unicodeResult);
}

// Verifies concat_ws over array(varchar) column arguments, with one array
// argument and with the same array passed twice.
//
// The input rows cover: a plain array, an array containing null elements, an
// empty array, and an array holding only a null. Per the expectations below,
// null elements contribute nothing to the output, and rows with no non-null
// elements yield an empty string.
TEST_F(ConcatWsTest, arrayArgs) {
  using S = StringView;
  auto arrayVector = makeNullableArrayVector<StringView>({
      {S("red"), S("blue")},
      {S("blue"), std::nullopt, S("yellow"), std::nullopt, S("orange")},
      {},
      {std::nullopt},
      {S("red"), S("purple"), S("green")},
  });

  // One array arg.
  auto result = evaluate<SimpleVector<StringView>>(
      "concat_ws('----', c0)", makeRowVector({arrayVector}));
  auto expected1 = {
      S("red----blue"),
      S("blue----yellow----orange"),
      S(""),
      S(""),
      S("red----purple----green"),
  };
  velox::test::assertEqualVectors(
      makeFlatVector<StringView>(expected1), result);

  // Two array args: the elements of both arrays are joined in argument order.
  result = evaluate<SimpleVector<StringView>>(
      "concat_ws('----', c0, c1)", makeRowVector({arrayVector, arrayVector}));
  auto expected2 = {
      S("red----blue----red----blue"),
      S("blue----yellow----orange----blue----yellow----orange"),
      S(""),
      S(""),
      S("red----purple----green----red----purple----green"),
  };
  velox::test::assertEqualVectors(
      makeFlatVector<StringView>(expected2), result);
}
// Checks concat_ws when array(varchar) columns are interleaved with constant
// string literals within a single call.
TEST_F(ConcatWsTest, mixedStringArrayArgs) {
  using S = StringView;

  // Rows: plain array, array with null elements, empty array, all-null array,
  // plain array.
  const auto colors = makeNullableArrayVector<StringView>({
      {S("red"), S("blue")},
      {S("blue"), std::nullopt, S("yellow"), std::nullopt, S("orange")},
      {},
      {std::nullopt},
      {S("red"), S("purple"), S("green")},
  });

  // The same array vector is used for both c0 and c1.
  const auto input = makeRowVector({colors, colors});

  const auto actual = evaluate<SimpleVector<StringView>>(
      "concat_ws('----', c0, 'foo', c1, 'bar', 'end')", input);

  // For rows whose arrays have no non-null elements, only the constants
  // remain in the expected output.
  auto expectedRows = {
      S("red----blue----foo----red----blue----bar----end"),
      S("blue----yellow----orange----foo----blue----yellow----orange----bar----end"),
      S("foo----bar----end"),
      S("foo----bar----end"),
      S("red----purple----green----foo----red----purple----green----bar----end"),
  };
  velox::test::assertEqualVectors(
      makeFlatVector<StringView>(expectedRows), actual);
}
// TODO: add a test that combines constant strings, column strings, and array
// arguments in a single concat_ws call.

} // namespace
} // namespace facebook::velox::functions::sparksql::test
} // namespace facebook::velox::functions::sparksql::test
1 change: 0 additions & 1 deletion velox/functions/sparksql/tests/StringTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,5 @@ TEST_F(StringTest, trim) {
trimWithTrimStr("\u6570", "\u6574\u6570 \u6570\u636E!"),
"\u6574\u6570 \u6570\u636E!");
}

} // namespace
} // namespace facebook::velox::functions::sparksql::test

0 comments on commit a36c8a1

Please sign in to comment.