Skip to content

Commit

Permalink
Move to separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE committed Jun 23, 2024
1 parent 15d4bab commit 821cfd8
Show file tree
Hide file tree
Showing 6 changed files with 387 additions and 342 deletions.
1 change: 1 addition & 0 deletions velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_library(
ArraySort.cpp
Bitwise.cpp
Comparisons.cpp
ConcatWs.cpp
DecimalArithmetic.cpp
DecimalCompare.cpp
Hash.cpp
Expand Down
357 changes: 357 additions & 0 deletions velox/functions/sparksql/ConcatWs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,357 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/sparksql/ConcatWs.h"
#include "velox/expression/VectorFunction.h"

namespace facebook::velox::functions::sparksql {
class ConcatWs : public exec::VectorFunction {
public:
explicit ConcatWs(const std::optional<std::string>& separator)
: separator_(separator) {}

// Calculate the total number of bytes in the result.
size_t calculateTotalResultBytes(
const SelectivityVector& rows,
exec::EvalCtx& context,
std::vector<exec::LocalDecodedVector>& decodedArrays,
const std::vector<std::string>& constantStrings,
const std::vector<exec::LocalDecodedVector>& decodedStringArgs,
const exec::LocalDecodedVector& decodedSeparator) const {
auto arrayArgNum = decodedArrays.size();
std::vector<const ArrayVector*> arrayVectors;
std::vector<DecodedVector*> elementsDecodedVectors;
for (auto i = 0; i < arrayArgNum; ++i) {
auto arrayVector = decodedArrays[i].get()->base()->as<ArrayVector>();
arrayVectors.push_back(arrayVector);
auto elements = arrayVector->elements();
exec::LocalSelectivityVector nestedRows(context, elements->size());
nestedRows.get()->setAll();
exec::LocalDecodedVector elementsHolder(
context, *elements, *nestedRows.get());
elementsDecodedVectors.push_back(elementsHolder.get());
}

size_t totalResultBytes = 0;
rows.applyToSelected([&](auto row) {
int32_t allElements = 0;
// Calculate size for array columns data.
for (int i = 0; i < arrayArgNum; i++) {
auto arrayVector = arrayVectors[i];
auto rawSizes = arrayVector->rawSizes();
auto rawOffsets = arrayVector->rawOffsets();
auto indices = decodedArrays[i].get()->indices();
auto elementsDecoded = elementsDecodedVectors[i];

auto size = rawSizes[indices[row]];
auto offset = rawOffsets[indices[row]];
for (int j = 0; j < size; ++j) {
if (!elementsDecoded->isNullAt(offset + j)) {
auto element = elementsDecoded->valueAt<StringView>(offset + j);
// No matter empty string or not.
allElements++;
totalResultBytes += element.size();
}
}
}

// Calculate size for string arg.
auto it = decodedStringArgs.begin();
for (int i = 0; i < constantStrings.size(); i++) {
StringView value;
if (!constantStrings[i].empty()) {
value = StringView(constantStrings[i]);
} else {
// Skip NULL.
if ((*it)->isNullAt(row)) {
++it;
continue;
}
value = (*it++)->valueAt<StringView>(row);
}
// No matter empty string or not.
allElements++;
totalResultBytes += value.size();
}

int32_t separatorSize = separator_.has_value()
? separator_.value().size()
: decodedSeparator->valueAt<StringView>(row).size();

if (allElements > 1 && separatorSize > 0) {
totalResultBytes += (allElements - 1) * separatorSize;
}
});
return totalResultBytes;
}

void initVectors(
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
const exec::EvalCtx& context,
std::vector<exec::LocalDecodedVector>& decodedArrays,
std::vector<column_index_t>& argMapping,
std::vector<std::string>& constantStrings,
std::vector<exec::LocalDecodedVector>& decodedStringArgs) const {
for (auto i = 1; i < args.size(); ++i) {
if (args[i] && args[i]->typeKind() == TypeKind::ARRAY) {
decodedArrays.emplace_back(context, *args[i], rows);
continue;
}
// Handles string arg.
argMapping.push_back(i);
if (!separator_.has_value()) {
// Cannot concat consecutive constant string args in advance.
constantStrings.push_back("");
continue;
}
if (args[i] && args[i]->as<ConstantVector<StringView>>() &&
!args[i]->as<ConstantVector<StringView>>()->isNullAt(0)) {
std::ostringstream out;
out << args[i]->as<ConstantVector<StringView>>()->valueAt(0);
column_index_t j = i + 1;
// Concat constant string args in advance.
for (; j < args.size(); ++j) {
if (!args[j] || args[j]->typeKind() == TypeKind::ARRAY ||
!args[j]->as<ConstantVector<StringView>>() ||
args[j]->as<ConstantVector<StringView>>()->isNullAt(0)) {
break;
}
out << separator_.value()
<< args[j]->as<ConstantVector<StringView>>()->valueAt(0);
}
constantStrings.emplace_back(out.str());
i = j - 1;
} else {
constantStrings.push_back("");
}
}

// Number of string columns after combined consecutive constant ones.
auto numStringCols = constantStrings.size();
for (auto i = 0; i < numStringCols; ++i) {
if (constantStrings[i].empty()) {
auto index = argMapping[i];
decodedStringArgs.emplace_back(context, *args[index], rows);
}
}
}

// ConcatWs implementation. It concatenates the arguments with the separator.
// Mixed using of VARCHAR & ARRAY<VARCHAR> is considered. If separator is
// constant, concatenate consecutive constant string args in advance. Then,
// concatenate the intemediate result with neighboring array args or
// non-constant string args.
void doApply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
exec::EvalCtx& context,
FlatVector<StringView>& flatResult) const {
std::vector<column_index_t> argMapping;
std::vector<std::string> constantStrings;
auto numArgs = args.size();
argMapping.reserve(numArgs - 1);
// Save intermediate result for consecutive constant string args.
// They will be concatenated in advance.
constantStrings.reserve(numArgs - 1);
std::vector<exec::LocalDecodedVector> decodedArrays;
decodedArrays.reserve(numArgs - 1);
// For column string arg decoding.
std::vector<exec::LocalDecodedVector> decodedStringArgs;
decodedStringArgs.reserve(numArgs);

initVectors(
rows,
args,
context,
decodedArrays,
argMapping,
constantStrings,
decodedStringArgs);
exec::LocalDecodedVector decodedSeparator(context);
if (!separator_.has_value()) {
decodedSeparator = exec::LocalDecodedVector(context, *args[0], rows);
}

auto totalResultBytes = calculateTotalResultBytes(
rows,
context,
decodedArrays,
constantStrings,
decodedStringArgs,
decodedSeparator);

std::vector<const ArrayVector*> arrayVectors;
std::vector<DecodedVector*> elementsDecodedVectors;
for (auto i = 0; i < decodedArrays.size(); ++i) {
auto arrayVector = decodedArrays[i].get()->base()->as<ArrayVector>();
arrayVectors.push_back(arrayVector);
auto elements = arrayVector->elements();
exec::LocalSelectivityVector nestedRows(context, elements->size());
nestedRows.get()->setAll();
exec::LocalDecodedVector elementsHolder(
context, *elements, *nestedRows.get());
elementsDecodedVectors.push_back(elementsHolder.get());
}
// Allocate a string buffer.
auto rawBuffer =
flatResult.getRawStringBufferWithSpace(totalResultBytes, true);
rows.applyToSelected([&](auto row) {
const char* start = rawBuffer;
auto isFirst = true;
// For array arg.
int32_t i = 0;
// For string arg.
int32_t j = 0;
auto it = decodedStringArgs.begin();

auto copyToBuffer = [&](StringView value, StringView separator) {
if (isFirst) {
isFirst = false;
} else {
// Add separator before the current value.
if (!separator.empty()) {
memcpy(rawBuffer, separator.data(), separator.size());
rawBuffer += separator.size();
}
}
if (!value.empty()) {
memcpy(rawBuffer, value.data(), value.size());
rawBuffer += value.size();
}
};

for (auto itArgs = args.begin() + 1; itArgs != args.end(); ++itArgs) {
if ((*itArgs)->typeKind() == TypeKind::ARRAY) {
auto arrayVector = arrayVectors[i];
auto rawSizes = arrayVector->rawSizes();
auto rawOffsets = arrayVector->rawOffsets();
auto indices = decodedArrays[i].get()->indices();
auto elementsDecoded = elementsDecodedVectors[i];

auto size = rawSizes[indices[row]];
auto offset = rawOffsets[indices[row]];
for (int k = 0; k < size; ++k) {
if (!elementsDecoded->isNullAt(offset + k)) {
auto element = elementsDecoded->valueAt<StringView>(offset + k);
copyToBuffer(
element,
separator_.has_value()
? StringView(separator_.value())
: decodedSeparator->valueAt<StringView>(row));
}
}
i++;
continue;
}

if (j >= constantStrings.size()) {
continue;
}

StringView value;
if (!constantStrings[j].empty()) {
value = StringView(constantStrings[j]);
} else {
// Skip NULL.
if ((*it)->isNullAt(row)) {
++it;
continue;
}
value = (*it++)->valueAt<StringView>(row);
}
copyToBuffer(
value,
separator_.has_value()
? StringView(separator_.value())
: decodedSeparator->valueAt<StringView>(row));
j++;
}
flatResult.setNoCopy(row, StringView(start, rawBuffer - start));
});
}

void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& /* outputType */,
exec::EvalCtx& context,
VectorPtr& result) const override {
context.ensureWritable(rows, VARCHAR(), result);
auto flatResult = result->asFlatVector<StringView>();
auto numArgs = args.size();
// If separator is NULL, result is NULL.
if (args[0]->isNullAt(0)) {
rows.applyToSelected([&](auto row) { result->setNull(row, true); });
return;
}
// If only separator (not a NULL) is provided, result is an empty string.
if (numArgs == 1) {
rows.applyToSelected(
[&](auto row) { flatResult->setNoCopy(row, StringView("")); });
return;
}
doApply(rows, args, context, *flatResult);
}

private:
// For holding constant separator.
const std::optional<std::string> separator_;
};

std::vector<std::shared_ptr<exec::FunctionSignature>> concatWsSignatures() {
return {// The second and folowing arguments are varchar or array(varchar).
// Use "any" to allow the mixed using of these two types. The
// argument type will be checked in makeConcatWs.
//
// varchar, [varchar], [array(varchar)], ... -> varchar.
exec::FunctionSignatureBuilder()
.returnType("varchar")
.argumentType("varchar")
.argumentType("any")
.variableArity()
.build()};
}

std::shared_ptr<exec::VectorFunction> makeConcatWs(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
auto numArgs = inputArgs.size();
VELOX_USER_CHECK_GE(
numArgs,
1,
"concat_ws requires one arguments at least, but got {}.",
numArgs);
for (const auto& arg : inputArgs) {
VELOX_USER_CHECK(
arg.type->isVarchar() ||
(arg.type->isArray() &&
arg.type->asArray().elementType()->isVarchar()),
"concat_ws requires varchar or array(varchar) arguments, but got {}.",
arg.type->toString());
}

BaseVector* constantSeparator = inputArgs[0].constantValue.get();
std::optional<std::string> separator = std::nullopt;
if (constantSeparator != nullptr) {
separator =
constantSeparator->as<ConstantVector<StringView>>()->valueAt(0).str();
}

return std::make_shared<ConcatWs>(separator);
}

}
28 changes: 28 additions & 0 deletions velox/functions/sparksql/ConcatWs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <string>
#include "velox/expression/VectorFunction.h"

namespace facebook::velox::functions::sparksql {
std::vector<std::shared_ptr<exec::FunctionSignature>> concatWsSignatures();

std::shared_ptr<exec::VectorFunction> makeConcatWs(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& config);
} // namespace facebook::velox::functions::sparksql
Loading

0 comments on commit 821cfd8

Please sign in to comment.