diff --git a/velox/docs/functions/spark/string.rst b/velox/docs/functions/spark/string.rst index be4f5d06fab0..acbcea70bf25 100644 --- a/velox/docs/functions/spark/string.rst +++ b/velox/docs/functions/spark/string.rst @@ -22,11 +22,11 @@ Unless specified otherwise, all functions return NULL if at least one of the arg .. spark:function:: concat_ws(separator, [string]/[array], ...) -> varchar - Returns the concatenation for ``string`` & all elements in ``array``, separated by - ``separator``. Only accepts constant ``separator``. It takes variable number of remaining - arguments. And ``string`` & ``array`` can be used in combination. If ``separator`` - is NULL, returns NULL, regardless of the following inputs. If only ``separator`` (not a - NULL) is provided or all remaining inputs are NULL, returns an empty string. :: + Returns the concatenation for ``string`` and all elements in ``array``, separated + by ``separator``. ``separator`` can be an empty string. It takes a variable number of remaining + arguments. And ``string`` & ``array`` can be used in combination. If ``separator`` + is NULL, returns NULL, regardless of the following inputs. For non-NULL ``separator``, if + it is the only argument provided or all remaining inputs are NULL, returns an empty string. :: SELECT concat_ws('~', 'a', 'b', 'c'); -- 'a~b~c' SELECT concat_ws('~', ['a', 'b', 'c'], ['d']); -- 'a~b~c~d' @@ -34,6 +34,7 @@ Unless specified otherwise, all functions return NULL if at least one of the arg SELECT concat_ws(NULL, 'a'); -- NULL SELECT concat_ws('~'); -- '' SELECT concat_ws('~', NULL, NULL); -- '' + SELECT concat_ws('~', [NULL]); -- '' .. 
spark:function:: contains(left, right) -> boolean diff --git a/velox/functions/sparksql/String.cpp b/velox/functions/sparksql/String.cpp index 6f1e59060d32..c8b487c3ad34 100644 --- a/velox/functions/sparksql/String.cpp +++ b/velox/functions/sparksql/String.cpp @@ -107,7 +107,7 @@ void doApply( const SelectivityVector& rows, std::vector& args, exec::EvalCtx& context, - const std::string& separator, + const std::optional& separator, FlatVector& flatResult) { std::vector argMapping; std::vector constantStrings; @@ -142,19 +142,24 @@ void doApply( } // Handles string arg. argMapping.push_back(i); + // Cannot concat string args in advance. + if (!separator.has_value()) { + constantStrings.push_back(""); + continue; + } if (args[i] && args[i]->as>() && !args[i]->as>()->isNullAt(0)) { std::ostringstream out; out << args[i]->as>()->valueAt(0); column_index_t j = i + 1; - // Concat constant string args. + // Concat constant string args in advance. for (; j < numArgs; ++j) { if (!args[j] || args[j]->typeKind() == TypeKind::ARRAY || !args[j]->as>() || args[j]->as>()->isNullAt(0)) { break; } - out << separator + out << separator.value() << args[j]->as>()->valueAt(0); } constantStrings.emplace_back(out.str()); @@ -169,18 +174,22 @@ void doApply( // For column string arg decoding. std::vector decodedStringArgs; decodedStringArgs.reserve(numStringCols); - for (auto i = 0; i < numStringCols; ++i) { if (constantStrings[i].empty()) { auto index = argMapping[i]; decodedStringArgs.emplace_back(context, *args[index], rows); } } + exec::LocalDecodedVector separatorDecoded(context); + if (!separator.has_value()) { + separatorDecoded = exec::LocalDecodedVector(context, *args[0], rows); + } // Calculate the total number of bytes in the result. size_t totalResultBytes = 0; rows.applyToSelected([&](auto row) { int32_t allElements = 0; + // Array arg. 
for (int i = 0; i < rawSizesVector.size(); i++) { auto size = rawSizesVector[i][indicesVector[i][row]]; auto offset = rawOffsetsVector[i][indicesVector[i][row]]; @@ -194,6 +203,8 @@ void doApply( } } } + + // String arg. auto it = decodedStringArgs.begin(); for (int i = 0; i < numStringCols; i++) { auto value = constantStrings[i].empty() @@ -204,8 +215,13 @@ void doApply( totalResultBytes += value.size(); } } - if (allElements > 1) { - totalResultBytes += (allElements - 1) * separator.size(); + + int32_t separatorSize = separator.has_value() + ? separator.value().size() + : separatorDecoded->valueAt(row).size(); + + if (allElements > 1 && separatorSize > 0) { + totalResultBytes += (allElements - 1) * separatorSize; } }); @@ -222,7 +238,7 @@ void doApply( int32_t j = 0; auto it = decodedStringArgs.begin(); - auto copyToBuffer = [&](StringView value) { + auto copyToBuffer = [&](StringView value, StringView separator) { if (value.empty()) { return; } @@ -230,9 +246,11 @@ void doApply( isFirst = false; } else { // Add separator before the current value. - memcpy(rawBuffer + bufferOffset, separator.data(), separator.size()); - bufferOffset += separator.size(); - combinedSize += separator.size(); + if (!separator.empty()) { + memcpy(rawBuffer + bufferOffset, separator.data(), separator.size()); + bufferOffset += separator.size(); + combinedSize += separator.size(); + } } memcpy(rawBuffer + bufferOffset, value.data(), value.size()); combinedSize += value.size(); @@ -246,7 +264,11 @@ void doApply( for (int k = 0; k < size; ++k) { if (!decodedVectors[i].isNullAt(offset + k)) { auto element = decodedVectors[i].valueAt(offset + k); - copyToBuffer(element); + copyToBuffer( + element, + separator.has_value() + ? StringView(separator.value()) + : separatorDecoded->valueAt(row)); } } i++; @@ -261,7 +283,10 @@ void doApply( } else { value = StringView(constantStrings[j]); } - copyToBuffer(value); + copyToBuffer( + value, + separator.has_value() ? 
StringView(separator.value()) + : separatorDecoded->valueAt(row)); j++; } flatResult.setNoCopy(row, StringView(start, combinedSize)); @@ -270,7 +295,8 @@ void doApply( class ConcatWs : public exec::VectorFunction { public: - explicit ConcatWs(const std::string& separator) : separator_(separator) {} + explicit ConcatWs(const std::optional& separator) + : separator_(separator) {} void apply( const SelectivityVector& selected, @@ -296,7 +322,8 @@ class ConcatWs : public exec::VectorFunction { } private: - const std::string separator_; + // If has no value, the separator is non-constant. + const std::optional separator_; }; } // namespace @@ -371,12 +398,12 @@ std::shared_ptr makeConcatWs( } BaseVector* constantPattern = inputArgs[0].constantValue.get(); - VELOX_USER_CHECK( - nullptr != constantPattern, - "concat_ws requires constant separator arguments."); + std::optional separator = std::nullopt; + if (constantPattern != nullptr) { + separator = + constantPattern->as>()->valueAt(0).str(); + } - auto separator = - constantPattern->as>()->valueAt(0).str(); return std::make_shared(separator); } diff --git a/velox/functions/sparksql/tests/ConcatWsTest.cpp b/velox/functions/sparksql/tests/ConcatWsTest.cpp index 5084a903d56e..1c40c3fa9205 100644 --- a/velox/functions/sparksql/tests/ConcatWsTest.cpp +++ b/velox/functions/sparksql/tests/ConcatWsTest.cpp @@ -88,7 +88,7 @@ class ConcatWsTest : public SparkFunctionBaseTest { } }; -TEST_F(ConcatWsTest, columnStringArgs) { +TEST_F(ConcatWsTest, stringArgs) { // Test with constant args. auto rows = makeRowVector(makeRowType({VARCHAR(), VARCHAR()}), 10); auto c0 = generateRandomString(20); @@ -118,10 +118,12 @@ TEST_F(ConcatWsTest, columnStringArgs) { SCOPED_TRACE(fmt::format("Number of arguments: {}", argsCount)); testConcatWsFlatVector(inputTable, argsCount, "--testSep--"); + // Test with empty separator. 
+ testConcatWsFlatVector(inputTable, argsCount, ""); } } -TEST_F(ConcatWsTest, mixedConstantAndColumnStringArgs) { +TEST_F(ConcatWsTest, mixedConstantAndNonconstantStringArgs) { size_t maxStringLength = 100; std::string value; auto data = makeRowVector({ @@ -183,63 +185,83 @@ TEST_F(ConcatWsTest, mixedConstantAndColumnStringArgs) { } TEST_F(ConcatWsTest, arrayArgs) { - using S = StringView; auto arrayVector = makeNullableArrayVector({ - {S("red"), S("blue")}, - {S("blue"), std::nullopt, S("yellow"), std::nullopt, S("orange")}, + {"red", "blue"}, + {"blue", std::nullopt, "yellow", std::nullopt, "orange"}, {}, {std::nullopt}, - {S("red"), S("purple"), S("green")}, + {"red", "purple", "green"}, }); // One array arg. auto result = evaluate>( "concat_ws('----', c0)", makeRowVector({arrayVector})); - auto expected1 = { - S("red----blue"), - S("blue----yellow----orange"), - S(""), - S(""), - S("red----purple----green"), - }; - velox::test::assertEqualVectors( - makeFlatVector(expected1), result); + auto expected1 = makeFlatVector({ + "red----blue", + "blue----yellow----orange", + "", + "", + "red----purple----green", + }); + velox::test::assertEqualVectors(expected1, result); // Two array args. 
result = evaluate>( "concat_ws('----', c0, c1)", makeRowVector({arrayVector, arrayVector})); - auto expected2 = { - S("red----blue----red----blue"), - S("blue----yellow----orange----blue----yellow----orange"), - S(""), - S(""), - S("red----purple----green----red----purple----green"), - }; - velox::test::assertEqualVectors( - makeFlatVector(expected2), result); + auto expected2 = makeFlatVector({ + "red----blue----red----blue", + "blue----yellow----orange----blue----yellow----orange", + "", + "", + "red----purple----green----red----purple----green", + }); + velox::test::assertEqualVectors(expected2, result); } -TEST_F(ConcatWsTest, mixedStringArrayArgs) { +TEST_F(ConcatWsTest, mixedStringAndArrayArgs) { using S = StringView; auto arrayVector = makeNullableArrayVector({ - {S("red"), S("blue")}, - {S("blue"), std::nullopt, S("yellow"), std::nullopt, S("orange")}, + {"red", "blue"}, + {"blue", std::nullopt, "yellow", std::nullopt, "orange"}, {}, {std::nullopt}, - {S("red"), S("purple"), S("green")}, + {"red", "purple", "green"}, }); auto result = evaluate>( "concat_ws('----', c0, 'foo', c1, 'bar', 'end')", makeRowVector({arrayVector, arrayVector})); - auto expected = { - S("red----blue----foo----red----blue----bar----end"), - S("blue----yellow----orange----foo----blue----yellow----orange----bar----end"), - S("foo----bar----end"), - S("foo----bar----end"), - S("red----purple----green----foo----red----purple----green----bar----end"), - }; - velox::test::assertEqualVectors(makeFlatVector(expected), result); + auto expected = makeFlatVector({ + "red----blue----foo----red----blue----bar----end", + "blue----yellow----orange----foo----blue----yellow----orange----bar----end", + "foo----bar----end", + "foo----bar----end", + "red----purple----green----foo----red----purple----green----bar----end", + }); + velox::test::assertEqualVectors(expected, result); +} + +TEST_F(ConcatWsTest, nonconstantSeparator) { + auto separatorVector = + makeFlatVector({"##", "--", "~~", "**", "++"}); 
+ auto arrayVector = makeNullableArrayVector({ + {"red", "blue"}, + {"blue", std::nullopt, "yellow", std::nullopt, "orange"}, + {"red", "blue"}, + {"blue", std::nullopt, "yellow", std::nullopt, "orange"}, + {"red", "purple", "green"}, + }); + + auto result = evaluate>( + "concat_ws(c0, c1, '|')", makeRowVector({separatorVector, arrayVector})); + auto expected = makeFlatVector({ + "red##blue##|", + "blue--yellow--orange--|", + "red~~blue~~|", + "blue**yellow**orange**|", + "red++purple++green++|", + }); + velox::test::assertEqualVectors(expected, result); } } // namespace