Skip to content

Commit

Permalink
Fix column separator with NULLs
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE committed Jun 24, 2024
1 parent b007b1e commit ca1250c
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 44 deletions.
97 changes: 69 additions & 28 deletions velox/functions/sparksql/ConcatWs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,17 @@
#include "velox/expression/VectorFunction.h"

namespace facebook::velox::functions::sparksql {

namespace {
class ConcatWs : public exec::VectorFunction {
public:
explicit ConcatWs(const std::optional<std::string>& separator)
: separator_(separator) {}

// Returns true when a separator value was supplied at construction time,
// i.e. the optional member holding the constant separator is engaged.
bool isConstantSeparator() const {
  return separator_ != std::nullopt;
}

// Calculate the total number of bytes in the result.
size_t calculateTotalResultBytes(
const SelectivityVector& rows,
Expand All @@ -47,6 +53,10 @@ class ConcatWs : public exec::VectorFunction {

size_t totalResultBytes = 0;
rows.applyToSelected([&](auto row) {
// NULL separator produces NULL result.
if (!isConstantSeparator() && decodedSeparator->isNullAt(row)) {
return;
}
int32_t allElements = 0;
// Calculate size for array columns data.
for (int i = 0; i < arrayArgNum; i++) {
Expand Down Expand Up @@ -87,7 +97,7 @@ class ConcatWs : public exec::VectorFunction {
totalResultBytes += value.size();
}

int32_t separatorSize = separator_.has_value()
int32_t separatorSize = isConstantSeparator()
? separator_.value().size()
: decodedSeparator->valueAt<StringView>(row).size();

Expand All @@ -113,7 +123,7 @@ class ConcatWs : public exec::VectorFunction {
}
// Handles string arg.
argMapping.push_back(i);
if (!separator_.has_value()) {
if (!isConstantSeparator()) {
// Cannot concat consecutive constant string args in advance.
constantStrings.push_back("");
continue;
Expand Down Expand Up @@ -159,7 +169,8 @@ class ConcatWs : public exec::VectorFunction {
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
exec::EvalCtx& context,
FlatVector<StringView>& flatResult) const {
VectorPtr& result) const {
auto& flatResult = *result->asFlatVector<StringView>();
std::vector<column_index_t> argMapping;
std::vector<std::string> constantStrings;
auto numArgs = args.size();
Expand All @@ -182,7 +193,7 @@ class ConcatWs : public exec::VectorFunction {
constantStrings,
decodedStringArgs);
exec::LocalDecodedVector decodedSeparator(context);
if (!separator_.has_value()) {
if (!isConstantSeparator()) {
decodedSeparator = exec::LocalDecodedVector(context, *args[0], rows);
}

Expand Down Expand Up @@ -210,6 +221,11 @@ class ConcatWs : public exec::VectorFunction {
auto rawBuffer =
flatResult.getRawStringBufferWithSpace(totalResultBytes, true);
rows.applyToSelected([&](auto row) {
// NULL separator produces NULL result.
if (!isConstantSeparator() && decodedSeparator->isNullAt(row)) {
result->setNull(row, true);
return;
}
const char* start = rawBuffer;
auto isFirst = true;
// For array arg.
Expand Down Expand Up @@ -249,7 +265,7 @@ class ConcatWs : public exec::VectorFunction {
auto element = elementsDecoded->valueAt<StringView>(offset + k);
copyToBuffer(
element,
separator_.has_value()
isConstantSeparator()
? StringView(separator_.value())
: decodedSeparator->valueAt<StringView>(row));
}
Expand All @@ -275,9 +291,8 @@ class ConcatWs : public exec::VectorFunction {
}
copyToBuffer(
value,
separator_.has_value()
? StringView(separator_.value())
: decodedSeparator->valueAt<StringView>(row));
isConstantSeparator() ? StringView(separator_.value())
: decodedSeparator->valueAt<StringView>(row));
j++;
}
flatResult.setNoCopy(row, StringView(start, rawBuffer - start));
Expand All @@ -287,60 +302,86 @@ class ConcatWs : public exec::VectorFunction {
void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& /* outputType */,
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const override {
context.ensureWritable(rows, VARCHAR(), result);
auto flatResult = result->asFlatVector<StringView>();
auto numArgs = args.size();
// If separator is NULL, result is NULL.
if (args[0]->isNullAt(0)) {
rows.applyToSelected([&](auto row) { result->setNull(row, true); });
return;
if (isConstantSeparator()) {
auto constant = args[0]->as<ConstantVector<StringView>>();
if (constant->isNullAt(0)) {
auto localResult = BaseVector::createNullConstant(
outputType, rows.end(), context.pool());
context.moveOrCopyResult(localResult, rows, result);
return;
}
}
// If only separator (not a NULL) is provided, result is an empty string.
if (numArgs == 1) {
rows.applyToSelected(
[&](auto row) { flatResult->setNoCopy(row, StringView("")); });
auto decodedSeparator = exec::LocalDecodedVector(context, *args[0], rows);
// 1. Separator is constant and not a NULL.
// 2. Separator is a column and has no NULLs.
if (isConstantSeparator() || !decodedSeparator->mayHaveNulls()) {
rows.applyToSelected(
[&](auto row) { flatResult->setNoCopy(row, StringView("")); });
} else {
rows.applyToSelected([&](auto row) {
if (decodedSeparator->isNullAt(row)) {
result->setNull(row, true);
} else {
flatResult->setNoCopy(row, StringView(""));
}
});
}
return;
}
doApply(rows, args, context, *flatResult);
doApply(rows, args, context, result);
}

private:
// For holding constant separator.
const std::optional<std::string> separator_;
};
} // namespace

TypePtr ConcatWsCallToSpecialForm::resolveType(const std::vector<TypePtr>& /*argTypes*/) {
return VARCHAR();
TypePtr ConcatWsCallToSpecialForm::resolveType(
const std::vector<TypePtr>& /*argTypes*/) {
return VARCHAR();
}

exec::ExprPtr ConcatWsCallToSpecialForm::constructSpecialForm(
const TypePtr& type,
std::vector<exec::ExprPtr>&& args,
bool trackCpuUsage,
const core::QueryConfig& config) {
const TypePtr& type,
std::vector<exec::ExprPtr>&& args,
bool trackCpuUsage,
const core::QueryConfig& config) {
auto numArgs = args.size();
VELOX_USER_CHECK_GE(
numArgs,
1,
"concat_ws requires one arguments at least, but got {}.",
numArgs);
for (const auto& arg : args) {
VELOX_USER_CHECK(
args[0]->type()->isVarchar(),
"The first argument of concat_ws must be a varchar.");
for (size_t i = 1; i < args.size(); i++) {
VELOX_USER_CHECK(
arg->type()->isVarchar() ||
(arg->type()->isArray() &&
arg->type()->asArray().elementType()->isVarchar()),
"concat_ws requires varchar or array(varchar) arguments, but got {}.",
arg->type()->toString());
args[i]->type()->isVarchar() ||
(args[i]->type()->isArray() &&
args[i]->type()->asArray().elementType()->isVarchar()),
"The 2nd and following arguments for concat_ws should be varchar or array(varchar), but got {}.",
args[i]->type()->toString());
}

std::optional<std::string> separator = std::nullopt;
auto constantExpr = std::dynamic_pointer_cast<exec::ConstantExpr>(args[0]);

if (constantExpr != nullptr) {
separator = constantExpr->value()->asUnchecked<ConstantVector<StringView>>()->valueAt(0).str();
separator = constantExpr->value()
->asUnchecked<ConstantVector<StringView>>()
->valueAt(0)
.str();
}
auto concatWsFunction = std::make_shared<ConcatWs>(separator);
return std::make_shared<exec::Expr>(
Expand Down
1 change: 0 additions & 1 deletion velox/functions/sparksql/ConcatWs.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ namespace facebook::velox::functions::sparksql {

class ConcatWsCallToSpecialForm : public exec::FunctionCallToSpecialForm {
public:
// Throws not supported exception.
TypePtr resolveType(const std::vector<TypePtr>& argTypes) override;

exec::ExprPtr constructSpecialForm(
Expand Down
9 changes: 2 additions & 7 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ void registerAllSpecialFormGeneralFunctions() {
registerFunctionCallToSpecialForm(
"try_cast", std::make_unique<SparkTryCastCallToSpecialForm>());
registerFunctionCallToSpecialForm(
ConcatWsCallToSpecialForm::kConcatWs,
std::make_unique<ConcatWsCallToSpecialForm>());
ConcatWsCallToSpecialForm::kConcatWs,
std::make_unique<ConcatWsCallToSpecialForm>());
}

namespace {
Expand Down Expand Up @@ -232,11 +232,6 @@ void registerFunctions(const std::string& prefix) {
prefix + "length", lengthSignatures(), makeLength);
registerFunction<SubstringIndexFunction, Varchar, Varchar, Varchar, int32_t>(
{prefix + "substring_index"});
// exec::registerStatefulVectorFunction(
// prefix + "concat_ws",
// concatWsSignatures(),
// makeConcatWs,
// exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build());

registerFunction<Md5Function, Varchar, Varbinary>({prefix + "md5"});
registerFunction<Sha1HexStringFunction, Varchar, Varbinary>(
Expand Down
31 changes: 23 additions & 8 deletions velox/functions/sparksql/tests/ConcatWsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,26 +183,26 @@ TEST_F(ConcatWsTest, arrayArgs) {
// One array arg.
auto result = evaluate<SimpleVector<StringView>>(
"concat_ws('--', c0)", makeRowVector({arrayVector}));
auto expected1 = makeFlatVector<StringView>({
auto expected = makeFlatVector<StringView>({
"red--blue",
"blue--yellow--orange",
"",
"",
"red--purple--green",
});
velox::test::assertEqualVectors(expected1, result);
velox::test::assertEqualVectors(expected, result);

// Two array args.
result = evaluate<SimpleVector<StringView>>(
"concat_ws('--', c0, c1)", makeRowVector({arrayVector, arrayVector}));
auto expected2 = makeFlatVector<StringView>({
expected = makeFlatVector<StringView>({
"red--blue--red--blue",
"blue--yellow--orange--blue--yellow--orange",
"",
"",
"red--purple--green--red--purple--green",
});
velox::test::assertEqualVectors(expected2, result);
velox::test::assertEqualVectors(expected, result);
}

TEST_F(ConcatWsTest, mixedStringAndArrayArgs) {
Expand Down Expand Up @@ -234,8 +234,8 @@ TEST_F(ConcatWsTest, mixedStringAndArrayArgs) {
}

TEST_F(ConcatWsTest, nonconstantSeparator) {
auto separatorVector =
makeFlatVector<StringView>({"##", "--", "~~", "**", "++"});
auto separatorVector = makeNullableFlatVector<StringView>(
{"##", "--", "~~", "**", std::nullopt});
auto arrayVector = makeNullableArrayVector<StringView>({
{"red", "blue"},
{"blue", std::nullopt, "yellow", std::nullopt, "orange"},
Expand All @@ -246,12 +246,27 @@ TEST_F(ConcatWsTest, nonconstantSeparator) {

auto result = evaluate<SimpleVector<StringView>>(
"concat_ws(c0, c1, '|')", makeRowVector({separatorVector, arrayVector}));
auto expected = makeFlatVector<StringView>({
auto expected = makeNullableFlatVector<StringView>({
"red##blue##|",
"blue--yellow--orange--|",
"red~~blue~~|",
"blue**yellow**orange**|",
"red++purple++green++|",
std::nullopt,
});
velox::test::assertEqualVectors(expected, result);
}

// concat_ws called with nothing but a non-constant separator column: rows
// whose separator is non-NULL are expected to produce an empty string, rows
// whose separator is NULL are expected to produce NULL.
TEST_F(ConcatWsTest, separatorOnly) {
  auto separators = makeNullableFlatVector<StringView>(
      {"##", std::nullopt, "~~", "**", std::nullopt});
  auto input = makeRowVector({separators});

  auto expected = makeNullableFlatVector<StringView>(
      {"", std::nullopt, "", "", std::nullopt});

  auto actual = evaluate<SimpleVector<StringView>>("concat_ws(c0)", input);
  velox::test::assertEqualVectors(expected, actual);
}
Expand Down

0 comments on commit ca1250c

Please sign in to comment.