diff --git a/velox/functions/sparksql/ConcatWs.cpp b/velox/functions/sparksql/ConcatWs.cpp index 4cd38394559a..5efc7f3ba979 100644 --- a/velox/functions/sparksql/ConcatWs.cpp +++ b/velox/functions/sparksql/ConcatWs.cpp @@ -18,11 +18,17 @@ #include "velox/expression/VectorFunction.h" namespace facebook::velox::functions::sparksql { + +namespace { class ConcatWs : public exec::VectorFunction { public: explicit ConcatWs(const std::optional& separator) : separator_(separator) {} + bool isConstantSeparator() const { + return separator_.has_value(); + } + // Calculate the total number of bytes in the result. size_t calculateTotalResultBytes( const SelectivityVector& rows, @@ -47,6 +53,10 @@ class ConcatWs : public exec::VectorFunction { size_t totalResultBytes = 0; rows.applyToSelected([&](auto row) { + // NULL separator produces NULL result. + if (!isConstantSeparator() && decodedSeparator->isNullAt(row)) { + return; + } int32_t allElements = 0; // Calculate size for array columns data. for (int i = 0; i < arrayArgNum; i++) { @@ -87,7 +97,7 @@ class ConcatWs : public exec::VectorFunction { totalResultBytes += value.size(); } - int32_t separatorSize = separator_.has_value() + int32_t separatorSize = isConstantSeparator() ? separator_.value().size() : decodedSeparator->valueAt(row).size(); @@ -113,7 +123,7 @@ class ConcatWs : public exec::VectorFunction { } // Handles string arg. argMapping.push_back(i); - if (!separator_.has_value()) { + if (!isConstantSeparator()) { // Cannot concat consecutive constant string args in advance. 
constantStrings.push_back(""); continue; @@ -159,7 +169,8 @@ class ConcatWs : public exec::VectorFunction { const SelectivityVector& rows, std::vector& args, exec::EvalCtx& context, - FlatVector& flatResult) const { + VectorPtr& result) const { + auto& flatResult = *result->asFlatVector(); std::vector argMapping; std::vector constantStrings; auto numArgs = args.size(); @@ -182,7 +193,7 @@ class ConcatWs : public exec::VectorFunction { constantStrings, decodedStringArgs); exec::LocalDecodedVector decodedSeparator(context); - if (!separator_.has_value()) { + if (!isConstantSeparator()) { decodedSeparator = exec::LocalDecodedVector(context, *args[0], rows); } @@ -210,6 +221,11 @@ class ConcatWs : public exec::VectorFunction { auto rawBuffer = flatResult.getRawStringBufferWithSpace(totalResultBytes, true); rows.applyToSelected([&](auto row) { + // NULL separator produces NULL result. + if (!isConstantSeparator() && decodedSeparator->isNullAt(row)) { + result->setNull(row, true); + return; + } const char* start = rawBuffer; auto isFirst = true; // For array arg. @@ -249,7 +265,7 @@ class ConcatWs : public exec::VectorFunction { auto element = elementsDecoded->valueAt(offset + k); copyToBuffer( element, - separator_.has_value() + isConstantSeparator() ? StringView(separator_.value()) : decodedSeparator->valueAt(row)); } @@ -275,9 +291,8 @@ class ConcatWs : public exec::VectorFunction { } copyToBuffer( value, - separator_.has_value() - ? StringView(separator_.value()) - : decodedSeparator->valueAt(row)); + isConstantSeparator() ? 
StringView(separator_.value()) + : decodedSeparator->valueAt(row)); j++; } flatResult.setNoCopy(row, StringView(start, rawBuffer - start)); @@ -287,60 +302,86 @@ class ConcatWs : public exec::VectorFunction { void apply( const SelectivityVector& rows, std::vector& args, - const TypePtr& /* outputType */, + const TypePtr& outputType, exec::EvalCtx& context, VectorPtr& result) const override { context.ensureWritable(rows, VARCHAR(), result); auto flatResult = result->asFlatVector(); auto numArgs = args.size(); // If separator is NULL, result is NULL. - if (args[0]->isNullAt(0)) { - rows.applyToSelected([&](auto row) { result->setNull(row, true); }); - return; + if (isConstantSeparator()) { + auto constant = args[0]->as>(); + if (constant->isNullAt(0)) { + auto localResult = BaseVector::createNullConstant( + outputType, rows.end(), context.pool()); + context.moveOrCopyResult(localResult, rows, result); + return; + } } // If only separator (not a NULL) is provided, result is an empty string. if (numArgs == 1) { - rows.applyToSelected( - [&](auto row) { flatResult->setNoCopy(row, StringView("")); }); + auto decodedSeparator = exec::LocalDecodedVector(context, *args[0], rows); + // 1. Separator is a constant and is not NULL. + // 2. Separator is a column and has no NULLs. + if (isConstantSeparator() || !decodedSeparator->mayHaveNulls()) { + rows.applyToSelected( + [&](auto row) { flatResult->setNoCopy(row, StringView("")); }); + } else { + rows.applyToSelected([&](auto row) { + if (decodedSeparator->isNullAt(row)) { + result->setNull(row, true); + } else { + flatResult->setNoCopy(row, StringView("")); + } + }); + } return; } - doApply(rows, args, context, *flatResult); + doApply(rows, args, context, result); } private: // For holding constant separator. 
const std::optional separator_; }; +} // namespace -TypePtr ConcatWsCallToSpecialForm::resolveType(const std::vector& /*argTypes*/) { - return VARCHAR(); +TypePtr ConcatWsCallToSpecialForm::resolveType( + const std::vector& /*argTypes*/) { + return VARCHAR(); } exec::ExprPtr ConcatWsCallToSpecialForm::constructSpecialForm( - const TypePtr& type, - std::vector&& args, - bool trackCpuUsage, - const core::QueryConfig& config) { + const TypePtr& type, + std::vector&& args, + bool trackCpuUsage, + const core::QueryConfig& config) { auto numArgs = args.size(); VELOX_USER_CHECK_GE( numArgs, 1, "concat_ws requires one arguments at least, but got {}.", numArgs); - for (const auto& arg : args) { + VELOX_USER_CHECK( + args[0]->type()->isVarchar(), + "The first argument of concat_ws must be a varchar."); + for (size_t i = 1; i < args.size(); i++) { VELOX_USER_CHECK( - arg->type()->isVarchar() || - (arg->type()->isArray() && - arg->type()->asArray().elementType()->isVarchar()), - "concat_ws requires varchar or array(varchar) arguments, but got {}.", - arg->type()->toString()); + args[i]->type()->isVarchar() || + (args[i]->type()->isArray() && + args[i]->type()->asArray().elementType()->isVarchar()), + "The 2nd and following arguments for concat_ws should be varchar or array(varchar), but got {}.", + args[i]->type()->toString()); } std::optional separator = std::nullopt; auto constantExpr = std::dynamic_pointer_cast(args[0]); if (constantExpr != nullptr) { - separator = constantExpr->value()->asUnchecked>()->valueAt(0).str(); + separator = constantExpr->value() + ->asUnchecked>() + ->valueAt(0) + .str(); } auto concatWsFunction = std::make_shared(separator); return std::make_shared( diff --git a/velox/functions/sparksql/ConcatWs.h b/velox/functions/sparksql/ConcatWs.h index 8f86108ec930..a0b6d38dd22c 100644 --- a/velox/functions/sparksql/ConcatWs.h +++ b/velox/functions/sparksql/ConcatWs.h @@ -22,7 +22,6 @@ namespace facebook::velox::functions::sparksql { class 
ConcatWsCallToSpecialForm : public exec::FunctionCallToSpecialForm { public: - // Throws not supported exception. TypePtr resolveType(const std::vector& argTypes) override; exec::ExprPtr constructSpecialForm( diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 7162c2db9b81..348b46f0f7a1 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -140,8 +140,8 @@ void registerAllSpecialFormGeneralFunctions() { registerFunctionCallToSpecialForm( "try_cast", std::make_unique()); registerFunctionCallToSpecialForm( - ConcatWsCallToSpecialForm::kConcatWs, - std::make_unique()); + ConcatWsCallToSpecialForm::kConcatWs, + std::make_unique()); } namespace { @@ -232,11 +232,6 @@ void registerFunctions(const std::string& prefix) { prefix + "length", lengthSignatures(), makeLength); registerFunction( {prefix + "substring_index"}); -// exec::registerStatefulVectorFunction( -// prefix + "concat_ws", -// concatWsSignatures(), -// makeConcatWs, -// exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build()); registerFunction({prefix + "md5"}); registerFunction( diff --git a/velox/functions/sparksql/tests/ConcatWsTest.cpp b/velox/functions/sparksql/tests/ConcatWsTest.cpp index 7c8db02e3d4c..671789a2c9bb 100644 --- a/velox/functions/sparksql/tests/ConcatWsTest.cpp +++ b/velox/functions/sparksql/tests/ConcatWsTest.cpp @@ -183,26 +183,26 @@ TEST_F(ConcatWsTest, arrayArgs) { // One array arg. auto result = evaluate>( "concat_ws('--', c0)", makeRowVector({arrayVector})); - auto expected1 = makeFlatVector({ + auto expected = makeFlatVector({ "red--blue", "blue--yellow--orange", "", "", "red--purple--green", }); - velox::test::assertEqualVectors(expected1, result); + velox::test::assertEqualVectors(expected, result); // Two array args. 
result = evaluate>( "concat_ws('--', c0, c1)", makeRowVector({arrayVector, arrayVector})); - auto expected2 = makeFlatVector({ + expected = makeFlatVector({ "red--blue--red--blue", "blue--yellow--orange--blue--yellow--orange", "", "", "red--purple--green--red--purple--green", }); - velox::test::assertEqualVectors(expected2, result); + velox::test::assertEqualVectors(expected, result); } TEST_F(ConcatWsTest, mixedStringAndArrayArgs) { @@ -234,8 +234,8 @@ TEST_F(ConcatWsTest, mixedStringAndArrayArgs) { } TEST_F(ConcatWsTest, nonconstantSeparator) { - auto separatorVector = - makeFlatVector({"##", "--", "~~", "**", "++"}); + auto separatorVector = makeNullableFlatVector( + {"##", "--", "~~", "**", std::nullopt}); auto arrayVector = makeNullableArrayVector({ {"red", "blue"}, {"blue", std::nullopt, "yellow", std::nullopt, "orange"}, @@ -246,12 +246,27 @@ TEST_F(ConcatWsTest, nonconstantSeparator) { auto result = evaluate>( "concat_ws(c0, c1, '|')", makeRowVector({separatorVector, arrayVector})); - auto expected = makeFlatVector({ + auto expected = makeNullableFlatVector({ "red##blue##|", "blue--yellow--orange--|", "red~~blue~~|", "blue**yellow**orange**|", - "red++purple++green++|", + std::nullopt, + }); + velox::test::assertEqualVectors(expected, result); +} + +TEST_F(ConcatWsTest, separatorOnly) { + auto separatorVector = makeNullableFlatVector( + {"##", std::nullopt, "~~", "**", std::nullopt}); + auto result = evaluate>( + "concat_ws(c0)", makeRowVector({separatorVector})); + auto expected = makeNullableFlatVector({ + "", + std::nullopt, + "", + "", + std::nullopt, }); velox::test::assertEqualVectors(expected, result); }