Skip to content

Commit

Permalink
fix error when first argument of replace function is const (#9615) (#…
Browse files Browse the repository at this point in the history
…9646)

close #9522

1. remove some useless arguments(pos, occ, match_type)
2. support first argument as ColumnConst

Signed-off-by: guo-shaoge <[email protected]>

Co-authored-by: guo-shaoge <[email protected]>
  • Loading branch information
ti-chi-bot and guo-shaoge authored Nov 18, 2024
1 parent ea55cfe commit 7bee6e8
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 160 deletions.
208 changes: 84 additions & 124 deletions dbms/src/Functions/FunctionsStringReplace.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/GatherUtils/Algorithms.h>
#include <Functions/GatherUtils/Sources.h>
#include <Functions/IFunction.h>

namespace DB
Expand All @@ -41,30 +43,10 @@ class FunctionStringReplace : public IFunction

String getName() const override { return name; }

size_t getNumberOfArguments() const override { return 0; }
size_t getNumberOfArguments() const override { return 3; }

bool isVariadic() const override { return true; }
bool isVariadic() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override
{
if constexpr (Impl::support_non_const_needle && Impl::support_non_const_replacement)
{
return {3, 4, 5};
}
else if constexpr (Impl::support_non_const_needle)
{
return {2, 3, 4, 5};
}
else if constexpr (Impl::support_non_const_replacement)
{
return {1, 3, 4, 5};
}
else
{
return {1, 2, 3, 4, 5};
}
}
void setCollator(const TiDB::TiDBCollatorPtr & collator_) override { collator = collator_; }

DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
Expand All @@ -83,84 +65,45 @@ class FunctionStringReplace : public IFunction
"Illegal type " + arguments[2]->getName() + " of third argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

if (arguments.size() > 3 && !arguments[3]->isInteger())
throw Exception(
"Illegal type " + arguments[2]->getName() + " of forth argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

if (arguments.size() > 4 && !arguments[4]->isInteger())
throw Exception(
"Illegal type " + arguments[2]->getName() + " of fifth argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

if (arguments.size() > 5 && !arguments[5]->isStringOrFixedString())
throw Exception(
"Illegal type " + arguments[2]->getName() + " of sixth argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

return std::make_shared<DataTypeString>();
}

void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override
{
const ColumnPtr & column_src = block.getByPosition(arguments[0]).column;
const ColumnPtr & column_needle = block.getByPosition(arguments[1]).column;
const ColumnPtr & column_replacement = block.getByPosition(arguments[2]).column;
const ColumnPtr column_pos = arguments.size() > 3 ? block.getByPosition(arguments[3]).column : nullptr;
const ColumnPtr column_occ = arguments.size() > 4 ? block.getByPosition(arguments[4]).column : nullptr;
const ColumnPtr column_match_type = arguments.size() > 5 ? block.getByPosition(arguments[5]).column : nullptr;

if ((column_pos != nullptr && !column_pos->isColumnConst())
|| (column_occ != nullptr && !column_occ->isColumnConst())
|| (column_match_type != nullptr && !column_match_type->isColumnConst()))
throw Exception("4th, 5th, 6th arguments of function " + getName() + " must be constants.");
Int64 pos = column_pos == nullptr ? 1 : typeid_cast<const ColumnConst *>(column_pos.get())->getInt(0);
Int64 occ = column_occ == nullptr ? 0 : typeid_cast<const ColumnConst *>(column_occ.get())->getInt(0);
String match_type = column_match_type == nullptr
? ""
: typeid_cast<const ColumnConst *>(column_match_type.get())->getValue<String>();
ColumnPtr column_src = block.getByPosition(arguments[0]).column;
ColumnPtr column_needle = block.getByPosition(arguments[1]).column;
ColumnPtr column_replacement = block.getByPosition(arguments[2]).column;

ColumnWithTypeAndName & column_result = block.getByPosition(result);

bool needle_const = column_needle->isColumnConst();
bool replacement_const = column_replacement->isColumnConst();

if (needle_const && replacement_const)
{
executeImpl(column_src, column_needle, column_replacement, pos, occ, match_type, column_result);
}
else if (needle_const)
if (column_src->isColumnConst())
{
executeImplNonConstReplacement(
executeImplConstHaystack(
column_src,
column_needle,
column_replacement,
pos,
occ,
match_type,
needle_const,
replacement_const,
column_result);
}
else if (needle_const && replacement_const)
{
executeImpl(column_src, column_needle, column_replacement, column_result);
}
else if (needle_const)
{
executeImplNonConstReplacement(column_src, column_needle, column_replacement, column_result);
}
else if (replacement_const)
{
executeImplNonConstNeedle(
column_src,
column_needle,
column_replacement,
pos,
occ,
match_type,
column_result);
executeImplNonConstNeedle(column_src, column_needle, column_replacement, column_result);
}
else
{
executeImplNonConstNeedleReplacement(
column_src,
column_needle,
column_replacement,
pos,
occ,
match_type,
column_result);
executeImplNonConstNeedleReplacement(column_src, column_needle, column_replacement, column_result);
}
}

Expand All @@ -169,9 +112,6 @@ class FunctionStringReplace : public IFunction
const ColumnPtr & column_src,
const ColumnPtr & column_needle,
const ColumnPtr & column_replacement,
Int64 pos,
Int64 occ,
const String & match_type,
ColumnWithTypeAndName & column_result) const
{
const auto * c1_const = typeid_cast<const ColumnConst *>(column_needle.get());
Expand All @@ -187,10 +127,6 @@ class FunctionStringReplace : public IFunction
col->getOffsets(),
needle,
replacement,
pos,
occ,
match_type,
collator,
col_res->getChars(),
col_res->getOffsets());
column_result.column = std::move(col_res);
Expand All @@ -203,10 +139,6 @@ class FunctionStringReplace : public IFunction
col->getN(),
needle,
replacement,
pos,
occ,
match_type,
collator,
col_res->getChars(),
col_res->getOffsets());
column_result.column = std::move(col_res);
Expand All @@ -217,13 +149,73 @@ class FunctionStringReplace : public IFunction
ErrorCodes::ILLEGAL_COLUMN);
}

void executeImplConstHaystack(
const ColumnPtr & column_src,
const ColumnPtr & column_needle,
const ColumnPtr & column_replacement,
bool needle_const,
bool replacement_const,
ColumnWithTypeAndName & column_result) const
{
auto res_col = ColumnString::create();
res_col->reserve(column_src->size());

RUNTIME_CHECK_MSG(
!needle_const || !replacement_const,
"should not got here when all argments of replace are constant");

const auto * column_src_const = checkAndGetColumnConst<ColumnString>(column_src.get());
RUNTIME_CHECK(column_src_const);

using GatherUtils::ConstSource;
using GatherUtils::StringSource;
if (!needle_const && !replacement_const)
{
const auto * column_needle_string = checkAndGetColumn<ColumnString>(column_needle.get());
const auto * column_replacement_string = checkAndGetColumn<ColumnString>(column_replacement.get());
RUNTIME_CHECK(column_needle_string);
RUNTIME_CHECK(column_replacement_string);

GatherUtils::replace<Impl>(
ConstSource<StringSource>(*column_src_const),
StringSource(*column_needle_string),
StringSource(*column_replacement_string),
res_col);
}
else if (needle_const && !replacement_const)
{
const auto * column_needle_const = checkAndGetColumnConst<ColumnString>(column_needle.get());
const auto * column_replacement_string = checkAndGetColumn<ColumnString>(column_replacement.get());
RUNTIME_CHECK(column_needle_const);
RUNTIME_CHECK(column_replacement_string);

GatherUtils::replace<Impl>(
ConstSource<StringSource>(*column_src_const),
ConstSource<StringSource>(*column_needle_const),
StringSource(*column_replacement_string),
res_col);
}
else if (!needle_const && replacement_const)
{
const auto * column_needle_string = checkAndGetColumn<ColumnString>(column_needle.get());
const auto * column_replacement_const = checkAndGetColumnConst<ColumnString>(column_replacement.get());
RUNTIME_CHECK(column_needle_string);
RUNTIME_CHECK(column_replacement_const);

GatherUtils::replace<Impl>(
ConstSource<StringSource>(*column_src_const),
StringSource(*column_needle_string),
ConstSource<StringSource>(*column_replacement_const),
res_col);
}

column_result.column = std::move(res_col);
}

void executeImplNonConstNeedle(
const ColumnPtr & column_src,
const ColumnPtr & column_needle,
const ColumnPtr & column_replacement,
Int64 pos [[maybe_unused]],
Int64 occ [[maybe_unused]],
const String & match_type,
ColumnWithTypeAndName & column_result) const
{
if constexpr (Impl::support_non_const_needle)
Expand All @@ -241,10 +233,6 @@ class FunctionStringReplace : public IFunction
col_needle->getChars(),
col_needle->getOffsets(),
replacement,
pos,
occ,
match_type,
collator,
col_res->getChars(),
col_res->getOffsets());
column_result.column = std::move(col_res);
Expand All @@ -258,10 +246,6 @@ class FunctionStringReplace : public IFunction
col_needle->getChars(),
col_needle->getOffsets(),
replacement,
pos,
occ,
match_type,
collator,
col_res->getChars(),
col_res->getOffsets());
column_result.column = std::move(col_res);
Expand All @@ -281,9 +265,6 @@ class FunctionStringReplace : public IFunction
const ColumnPtr & column_src,
const ColumnPtr & column_needle,
const ColumnPtr & column_replacement,
Int64 pos [[maybe_unused]],
Int64 occ [[maybe_unused]],
const String & match_type,
ColumnWithTypeAndName & column_result) const
{
if constexpr (Impl::support_non_const_replacement)
Expand All @@ -301,10 +282,6 @@ class FunctionStringReplace : public IFunction
needle,
col_replacement->getChars(),
col_replacement->getOffsets(),
pos,
occ,
match_type,
collator,
col_res->getChars(),
col_res->getOffsets());
column_result.column = std::move(col_res);
Expand All @@ -318,10 +295,6 @@ class FunctionStringReplace : public IFunction
needle,
col_replacement->getChars(),
col_replacement->getOffsets(),
pos,
occ,
match_type,
collator,
col_res->getChars(),
col_res->getOffsets());
column_result.column = std::move(col_res);
Expand All @@ -341,9 +314,6 @@ class FunctionStringReplace : public IFunction
const ColumnPtr & column_src,
const ColumnPtr & column_needle,
const ColumnPtr & column_replacement,
Int64 pos [[maybe_unused]],
Int64 occ [[maybe_unused]],
const String & match_type,
ColumnWithTypeAndName & column_result) const
{
if constexpr (Impl::support_non_const_needle && Impl::support_non_const_replacement)
Expand All @@ -361,10 +331,6 @@ class FunctionStringReplace : public IFunction
col_needle->getOffsets(),
col_replacement->getChars(),
col_replacement->getOffsets(),
pos,
occ,
match_type,
collator,
col_res->getChars(),
col_res->getOffsets());
column_result.column = std::move(col_res);
Expand All @@ -379,10 +345,6 @@ class FunctionStringReplace : public IFunction
col_needle->getOffsets(),
col_replacement->getChars(),
col_replacement->getOffsets(),
pos,
occ,
match_type,
collator,
col_res->getChars(),
col_res->getOffsets());
column_result.column = std::move(col_res);
Expand All @@ -399,7 +361,5 @@ class FunctionStringReplace : public IFunction
ErrorCodes::ILLEGAL_COLUMN);
}
}

TiDB::TiDBCollatorPtr collator{};
};
} // namespace DB
Loading

0 comments on commit 7bee6e8

Please sign in to comment.