Skip to content

Commit

Permalink
Support non-constant separator
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE committed Apr 12, 2024
1 parent 8203b08 commit d26e33b
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 60 deletions.
11 changes: 6 additions & 5 deletions velox/docs/functions/spark/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,19 @@ Unless specified otherwise, all functions return NULL if at least one of the arg

.. spark:function:: concat_ws(separator, [string]/[array<string>], ...) -> varchar
Returns the concatenation for ``string`` & all elements in ``array<string>``, separated by
``separator``. Only accepts constant ``separator``. It takes variable number of remaining
arguments. And ``string`` & ``array<string>`` can be used in combination. If ``separator``
is NULL, returns NULL, regardless of the following inputs. If only ``separator`` (not a
NULL) is provided or all remaining inputs are NULL, returns an empty string. ::
Returns the concatenation for ``string`` and all elements in ``array<string>``, separated
by ``separator``. ``separator`` can be empty string. It takes variable number of remaining
arguments. And ``string`` & ``array<string>`` can be used in combination. If ``separator``
is NULL, returns NULL, regardless of the following inputs. For non-NULL ``separator``, if
only it is provided or all remaining inputs are NULL, returns an empty string. ::

SELECT concat_ws('~', 'a', 'b', 'c'); -- 'a~b~c'
SELECT concat_ws('~', ['a', 'b', 'c'], ['d']); -- 'a~b~c~d'
SELECT concat_ws('~', 'a', ['b', 'c']); -- 'a~b~c'
SELECT concat_ws(NULL, 'a'); -- NULL
SELECT concat_ws('~'); -- ''
SELECT concat_ws('~', NULL, NULL); -- ''
SELECT concat_ws('~', [NULL]); -- ''

.. spark:function:: contains(left, right) -> boolean
Expand Down
65 changes: 46 additions & 19 deletions velox/functions/sparksql/String.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ void doApply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
exec::EvalCtx& context,
const std::string& separator,
const std::optional<std::string>& separator,
FlatVector<StringView>& flatResult) {
std::vector<column_index_t> argMapping;
std::vector<std::string> constantStrings;
Expand Down Expand Up @@ -142,19 +142,24 @@ void doApply(
}
// Handles string arg.
argMapping.push_back(i);
// Cannot concat string args in advance.
if (!separator.has_value()) {
constantStrings.push_back("");
continue;
}
if (args[i] && args[i]->as<ConstantVector<StringView>>() &&
!args[i]->as<ConstantVector<StringView>>()->isNullAt(0)) {
std::ostringstream out;
out << args[i]->as<ConstantVector<StringView>>()->valueAt(0);
column_index_t j = i + 1;
// Concat constant string args.
// Concat constant string args in advance.
for (; j < numArgs; ++j) {
if (!args[j] || args[j]->typeKind() == TypeKind::ARRAY ||
!args[j]->as<ConstantVector<StringView>>() ||
args[j]->as<ConstantVector<StringView>>()->isNullAt(0)) {
break;
}
out << separator
out << separator.value()
<< args[j]->as<ConstantVector<StringView>>()->valueAt(0);
}
constantStrings.emplace_back(out.str());
Expand All @@ -169,18 +174,22 @@ void doApply(
// For column string arg decoding.
std::vector<exec::LocalDecodedVector> decodedStringArgs;
decodedStringArgs.reserve(numStringCols);

for (auto i = 0; i < numStringCols; ++i) {
if (constantStrings[i].empty()) {
auto index = argMapping[i];
decodedStringArgs.emplace_back(context, *args[index], rows);
}
}
exec::LocalDecodedVector separatorDecoded(context);
if (!separator.has_value()) {
separatorDecoded = exec::LocalDecodedVector(context, *args[0], rows);
}

// Calculate the total number of bytes in the result.
size_t totalResultBytes = 0;
rows.applyToSelected([&](auto row) {
int32_t allElements = 0;
// Array arg.
for (int i = 0; i < rawSizesVector.size(); i++) {
auto size = rawSizesVector[i][indicesVector[i][row]];
auto offset = rawOffsetsVector[i][indicesVector[i][row]];
Expand All @@ -194,6 +203,8 @@ void doApply(
}
}
}

// String arg.
auto it = decodedStringArgs.begin();
for (int i = 0; i < numStringCols; i++) {
auto value = constantStrings[i].empty()
Expand All @@ -204,8 +215,13 @@ void doApply(
totalResultBytes += value.size();
}
}
if (allElements > 1) {
totalResultBytes += (allElements - 1) * separator.size();

int32_t separatorSize = separator.has_value()
? separator.value().size()
: separatorDecoded->valueAt<StringView>(row).size();

if (allElements > 1 && separatorSize > 0) {
totalResultBytes += (allElements - 1) * separatorSize;
}
});

Expand All @@ -222,17 +238,19 @@ void doApply(
int32_t j = 0;
auto it = decodedStringArgs.begin();

auto copyToBuffer = [&](StringView value) {
auto copyToBuffer = [&](StringView value, StringView separator) {
if (value.empty()) {
return;
}
if (isFirst) {
isFirst = false;
} else {
// Add separator before the current value.
memcpy(rawBuffer + bufferOffset, separator.data(), separator.size());
bufferOffset += separator.size();
combinedSize += separator.size();
if (!separator.empty()) {
memcpy(rawBuffer + bufferOffset, separator.data(), separator.size());
bufferOffset += separator.size();
combinedSize += separator.size();
}
}
memcpy(rawBuffer + bufferOffset, value.data(), value.size());
combinedSize += value.size();
Expand All @@ -246,7 +264,11 @@ void doApply(
for (int k = 0; k < size; ++k) {
if (!decodedVectors[i].isNullAt(offset + k)) {
auto element = decodedVectors[i].valueAt<StringView>(offset + k);
copyToBuffer(element);
copyToBuffer(
element,
separator.has_value()
? StringView(separator.value())
: separatorDecoded->valueAt<StringView>(row));
}
}
i++;
Expand All @@ -261,7 +283,10 @@ void doApply(
} else {
value = StringView(constantStrings[j]);
}
copyToBuffer(value);
copyToBuffer(
value,
separator.has_value() ? StringView(separator.value())
: separatorDecoded->valueAt<StringView>(row));
j++;
}
flatResult.setNoCopy(row, StringView(start, combinedSize));
Expand All @@ -270,7 +295,8 @@ void doApply(

class ConcatWs : public exec::VectorFunction {
public:
explicit ConcatWs(const std::string& separator) : separator_(separator) {}
explicit ConcatWs(const std::optional<std::string>& separator)
: separator_(separator) {}

void apply(
const SelectivityVector& selected,
Expand All @@ -296,7 +322,8 @@ class ConcatWs : public exec::VectorFunction {
}

private:
const std::string separator_;
// If has no value, the separator is non-constant.
const std::optional<std::string> separator_;
};

} // namespace
Expand Down Expand Up @@ -371,12 +398,12 @@ std::shared_ptr<exec::VectorFunction> makeConcatWs(
}

BaseVector* constantPattern = inputArgs[0].constantValue.get();
VELOX_USER_CHECK(
nullptr != constantPattern,
"concat_ws requires constant separator arguments.");
std::optional<std::string> separator = std::nullopt;
if (constantPattern != nullptr) {
separator =
constantPattern->as<ConstantVector<StringView>>()->valueAt(0).str();
}

auto separator =
constantPattern->as<ConstantVector<StringView>>()->valueAt(0).str();
return std::make_shared<ConcatWs>(separator);
}

Expand Down
94 changes: 58 additions & 36 deletions velox/functions/sparksql/tests/ConcatWsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class ConcatWsTest : public SparkFunctionBaseTest {
}
};

TEST_F(ConcatWsTest, columnStringArgs) {
TEST_F(ConcatWsTest, stringArgs) {
// Test with constant args.
auto rows = makeRowVector(makeRowType({VARCHAR(), VARCHAR()}), 10);
auto c0 = generateRandomString(20);
Expand Down Expand Up @@ -118,10 +118,12 @@ TEST_F(ConcatWsTest, columnStringArgs) {

SCOPED_TRACE(fmt::format("Number of arguments: {}", argsCount));
testConcatWsFlatVector(inputTable, argsCount, "--testSep--");
// Test with empty separator.
testConcatWsFlatVector(inputTable, argsCount, "");
}
}

TEST_F(ConcatWsTest, mixedConstantAndColumnStringArgs) {
TEST_F(ConcatWsTest, mixedConstantAndNonconstantStringArgs) {
size_t maxStringLength = 100;
std::string value;
auto data = makeRowVector({
Expand Down Expand Up @@ -183,63 +185,83 @@ TEST_F(ConcatWsTest, mixedConstantAndColumnStringArgs) {
}

TEST_F(ConcatWsTest, arrayArgs) {
using S = StringView;
auto arrayVector = makeNullableArrayVector<StringView>({
{S("red"), S("blue")},
{S("blue"), std::nullopt, S("yellow"), std::nullopt, S("orange")},
{"red", "blue"},
{"blue", std::nullopt, "yellow", std::nullopt, "orange"},
{},
{std::nullopt},
{S("red"), S("purple"), S("green")},
{"red", "purple", "green"},
});

// One array arg.
auto result = evaluate<SimpleVector<StringView>>(
"concat_ws('----', c0)", makeRowVector({arrayVector}));
auto expected1 = {
S("red----blue"),
S("blue----yellow----orange"),
S(""),
S(""),
S("red----purple----green"),
};
velox::test::assertEqualVectors(
makeFlatVector<StringView>(expected1), result);
auto expected1 = makeFlatVector<StringView>({
"red----blue",
"blue----yellow----orange",
"",
"",
"red----purple----green",
});
velox::test::assertEqualVectors(expected1, result);

// Two array args.
result = evaluate<SimpleVector<StringView>>(
"concat_ws('----', c0, c1)", makeRowVector({arrayVector, arrayVector}));
auto expected2 = {
S("red----blue----red----blue"),
S("blue----yellow----orange----blue----yellow----orange"),
S(""),
S(""),
S("red----purple----green----red----purple----green"),
};
velox::test::assertEqualVectors(
makeFlatVector<StringView>(expected2), result);
auto expected2 = makeFlatVector<StringView>({
"red----blue----red----blue",
"blue----yellow----orange----blue----yellow----orange",
"",
"",
"red----purple----green----red----purple----green",
});
velox::test::assertEqualVectors(expected2, result);
}

TEST_F(ConcatWsTest, mixedStringArrayArgs) {
TEST_F(ConcatWsTest, mixedStringAndArrayArgs) {
using S = StringView;
auto arrayVector = makeNullableArrayVector<StringView>({
{S("red"), S("blue")},
{S("blue"), std::nullopt, S("yellow"), std::nullopt, S("orange")},
{"red", "blue"},
{"blue", std::nullopt, "yellow", std::nullopt, "orange"},
{},
{std::nullopt},
{S("red"), S("purple"), S("green")},
{"red", "purple", "green"},
});

auto result = evaluate<SimpleVector<StringView>>(
"concat_ws('----', c0, 'foo', c1, 'bar', 'end')",
makeRowVector({arrayVector, arrayVector}));
auto expected = {
S("red----blue----foo----red----blue----bar----end"),
S("blue----yellow----orange----foo----blue----yellow----orange----bar----end"),
S("foo----bar----end"),
S("foo----bar----end"),
S("red----purple----green----foo----red----purple----green----bar----end"),
};
velox::test::assertEqualVectors(makeFlatVector<StringView>(expected), result);
auto expected = makeFlatVector<StringView>({
"red----blue----foo----red----blue----bar----end",
"blue----yellow----orange----foo----blue----yellow----orange----bar----end",
"foo----bar----end",
"foo----bar----end",
"red----purple----green----foo----red----purple----green----bar----end",
});
velox::test::assertEqualVectors(expected, result);
}

TEST_F(ConcatWsTest, nonconstantSeparator) {
auto separatorVector =
makeFlatVector<StringView>({"##", "--", "~~", "**", "++"});
auto arrayVector = makeNullableArrayVector<StringView>({
{"red", "blue"},
{"blue", std::nullopt, "yellow", std::nullopt, "orange"},
{"red", "blue"},
{"blue", std::nullopt, "yellow", std::nullopt, "orange"},
{"red", "purple", "green"},
});

auto result = evaluate<SimpleVector<StringView>>(
"concat_ws(c0, c1, '|')", makeRowVector({separatorVector, arrayVector}));
auto expected = makeFlatVector<StringView>({
"red##blue##|",
"blue--yellow--orange--|",
"red~~blue~~|",
"blue**yellow**orange**|",
"red++purple++green++|",
});
velox::test::assertEqualVectors(expected, result);
}

} // namespace
Expand Down

0 comments on commit d26e33b

Please sign in to comment.