diff --git a/dbms/src/Functions/FunctionsStringReplace.h b/dbms/src/Functions/FunctionsStringReplace.h index 604c2479bb0..5583239c027 100644 --- a/dbms/src/Functions/FunctionsStringReplace.h +++ b/dbms/src/Functions/FunctionsStringReplace.h @@ -22,6 +22,8 @@ #include #include #include +#include +#include #include namespace DB @@ -41,30 +43,10 @@ class FunctionStringReplace : public IFunction String getName() const override { return name; } - size_t getNumberOfArguments() const override { return 0; } + size_t getNumberOfArguments() const override { return 3; } - bool isVariadic() const override { return true; } + bool isVariadic() const override { return false; } bool useDefaultImplementationForConstants() const override { return true; } - ColumnNumbers getArgumentsThatAreAlwaysConstant() const override - { - if constexpr (Impl::support_non_const_needle && Impl::support_non_const_replacement) - { - return {3, 4, 5}; - } - else if constexpr (Impl::support_non_const_needle) - { - return {2, 3, 4, 5}; - } - else if constexpr (Impl::support_non_const_replacement) - { - return {1, 3, 4, 5}; - } - else - { - return {1, 2, 3, 4, 5}; - } - } - void setCollator(const TiDB::TiDBCollatorPtr & collator_) override { collator = collator_; } DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { @@ -83,84 +65,45 @@ class FunctionStringReplace : public IFunction "Illegal type " + arguments[2]->getName() + " of third argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - if (arguments.size() > 3 && !arguments[3]->isInteger()) - throw Exception( - "Illegal type " + arguments[2]->getName() + " of forth argument of function " + getName(), - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - - if (arguments.size() > 4 && !arguments[4]->isInteger()) - throw Exception( - "Illegal type " + arguments[2]->getName() + " of fifth argument of function " + getName(), - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - - if (arguments.size() > 5 && !arguments[5]->isStringOrFixedString()) - throw Exception( - "Illegal type " + arguments[2]->getName() + " of sixth argument of function " + getName(), - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - return std::make_shared(); } void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override { - const ColumnPtr & column_src = block.getByPosition(arguments[0]).column; - const ColumnPtr & column_needle = block.getByPosition(arguments[1]).column; - const ColumnPtr & column_replacement = block.getByPosition(arguments[2]).column; - const ColumnPtr column_pos = arguments.size() > 3 ? block.getByPosition(arguments[3]).column : nullptr; - const ColumnPtr column_occ = arguments.size() > 4 ? block.getByPosition(arguments[4]).column : nullptr; - const ColumnPtr column_match_type = arguments.size() > 5 ? block.getByPosition(arguments[5]).column : nullptr; - - if ((column_pos != nullptr && !column_pos->isColumnConst()) - || (column_occ != nullptr && !column_occ->isColumnConst()) - || (column_match_type != nullptr && !column_match_type->isColumnConst())) - throw Exception("4th, 5th, 6th arguments of function " + getName() + " must be constants."); - Int64 pos = column_pos == nullptr ? 1 : typeid_cast(column_pos.get())->getInt(0); - Int64 occ = column_occ == nullptr ? 0 : typeid_cast(column_occ.get())->getInt(0); - String match_type = column_match_type == nullptr - ? "" - : typeid_cast(column_match_type.get())->getValue(); + ColumnPtr column_src = block.getByPosition(arguments[0]).column; + ColumnPtr column_needle = block.getByPosition(arguments[1]).column; + ColumnPtr column_replacement = block.getByPosition(arguments[2]).column; ColumnWithTypeAndName & column_result = block.getByPosition(result); bool needle_const = column_needle->isColumnConst(); bool replacement_const = column_replacement->isColumnConst(); - if (needle_const && replacement_const) - { - executeImpl(column_src, column_needle, column_replacement, pos, occ, match_type, column_result); - } - else if (needle_const) + if (column_src->isColumnConst()) { - executeImplNonConstReplacement( + executeImplConstHaystack( column_src, column_needle, column_replacement, - pos, - occ, - match_type, + needle_const, + replacement_const, column_result); } + else if (needle_const && replacement_const) + { + executeImpl(column_src, column_needle, column_replacement, column_result); + } + else if (needle_const) + { + executeImplNonConstReplacement(column_src, column_needle, column_replacement, column_result); + } else if (replacement_const) { - executeImplNonConstNeedle( - column_src, - column_needle, - column_replacement, - pos, - occ, - match_type, - column_result); + executeImplNonConstNeedle(column_src, column_needle, column_replacement, column_result); } else { - executeImplNonConstNeedleReplacement( - column_src, - column_needle, - column_replacement, - pos, - occ, - match_type, - column_result); + executeImplNonConstNeedleReplacement(column_src, column_needle, column_replacement, column_result); } } @@ -169,9 +112,6 @@ class FunctionStringReplace : public IFunction const ColumnPtr & column_src, const ColumnPtr & column_needle, const ColumnPtr & column_replacement, - Int64 pos, - Int64 occ, - const String & match_type, ColumnWithTypeAndName & column_result) const { const auto * c1_const = typeid_cast(column_needle.get()); @@ -187,10 +127,6 @@ class FunctionStringReplace : public IFunction col->getOffsets(), needle, replacement, - pos, - occ, - match_type, - collator, col_res->getChars(), col_res->getOffsets()); column_result.column = std::move(col_res); @@ -203,10 +139,6 @@ class FunctionStringReplace : public IFunction col->getN(), needle, replacement, - pos, - occ, - match_type, - collator, col_res->getChars(), col_res->getOffsets()); column_result.column = std::move(col_res); @@ -217,13 +149,73 @@ class FunctionStringReplace : public IFunction ErrorCodes::ILLEGAL_COLUMN); } + void executeImplConstHaystack( + const ColumnPtr & column_src, + const ColumnPtr & column_needle, + const ColumnPtr & column_replacement, + bool needle_const, + bool replacement_const, + ColumnWithTypeAndName & column_result) const + { + auto res_col = ColumnString::create(); + res_col->reserve(column_src->size()); + + RUNTIME_CHECK_MSG( + !needle_const || !replacement_const, + "should not got here when all argments of replace are constant"); + + const auto * column_src_const = checkAndGetColumnConst(column_src.get()); + RUNTIME_CHECK(column_src_const); + + using GatherUtils::ConstSource; + using GatherUtils::StringSource; + if (!needle_const && !replacement_const) + { + const auto * column_needle_string = checkAndGetColumn(column_needle.get()); + const auto * column_replacement_string = checkAndGetColumn(column_replacement.get()); + RUNTIME_CHECK(column_needle_string); + RUNTIME_CHECK(column_replacement_string); + + GatherUtils::replace( + ConstSource(*column_src_const), + StringSource(*column_needle_string), + StringSource(*column_replacement_string), + res_col); + } + else if (needle_const && !replacement_const) + { + const auto * column_needle_const = checkAndGetColumnConst(column_needle.get()); + const auto * column_replacement_string = checkAndGetColumn(column_replacement.get()); + RUNTIME_CHECK(column_needle_const); + RUNTIME_CHECK(column_replacement_string); + + GatherUtils::replace( + ConstSource(*column_src_const), + ConstSource(*column_needle_const), + StringSource(*column_replacement_string), + res_col); + } + else if (!needle_const && replacement_const) + { + const auto * column_needle_string = checkAndGetColumn(column_needle.get()); + const auto * column_replacement_const = checkAndGetColumnConst(column_replacement.get()); + RUNTIME_CHECK(column_needle_string); + RUNTIME_CHECK(column_replacement_const); + + GatherUtils::replace( + ConstSource(*column_src_const), + StringSource(*column_needle_string), + ConstSource(*column_replacement_const), + res_col); + } + + column_result.column = std::move(res_col); + } + void executeImplNonConstNeedle( const ColumnPtr & column_src, const ColumnPtr & column_needle, const ColumnPtr & column_replacement, - Int64 pos [[maybe_unused]], - Int64 occ [[maybe_unused]], - const String & match_type, ColumnWithTypeAndName & column_result) const { if constexpr (Impl::support_non_const_needle) @@ -241,10 +233,6 @@ class FunctionStringReplace : public IFunction col_needle->getChars(), col_needle->getOffsets(), replacement, - pos, - occ, - match_type, - collator, col_res->getChars(), col_res->getOffsets()); column_result.column = std::move(col_res); @@ -258,10 +246,6 @@ class FunctionStringReplace : public IFunction col_needle->getChars(), col_needle->getOffsets(), replacement, - pos, - occ, - match_type, - collator, col_res->getChars(), col_res->getOffsets()); column_result.column = std::move(col_res); @@ -281,9 +265,6 @@ class FunctionStringReplace : public IFunction const ColumnPtr & column_src, const ColumnPtr & column_needle, const ColumnPtr & column_replacement, - Int64 pos [[maybe_unused]], - Int64 occ [[maybe_unused]], - const String & match_type, ColumnWithTypeAndName & column_result) const { if constexpr (Impl::support_non_const_replacement) @@ -301,10 +282,6 @@ class FunctionStringReplace : public IFunction needle, col_replacement->getChars(), col_replacement->getOffsets(), - pos, - occ, - match_type, - collator, col_res->getChars(), col_res->getOffsets()); column_result.column = std::move(col_res); @@ -318,10 +295,6 @@ class FunctionStringReplace : public IFunction needle, col_replacement->getChars(), col_replacement->getOffsets(), - pos, - occ, - match_type, - collator, col_res->getChars(), col_res->getOffsets()); column_result.column = std::move(col_res); @@ -341,9 +314,6 @@ class FunctionStringReplace : public IFunction const ColumnPtr & column_src, const ColumnPtr & column_needle, const ColumnPtr & column_replacement, - Int64 pos [[maybe_unused]], - Int64 occ [[maybe_unused]], - const String & match_type, ColumnWithTypeAndName & column_result) const { if constexpr (Impl::support_non_const_needle && Impl::support_non_const_replacement) @@ -361,10 +331,6 @@ class FunctionStringReplace : public IFunction col_needle->getOffsets(), col_replacement->getChars(), col_replacement->getOffsets(), - pos, - occ, - match_type, - collator, col_res->getChars(), col_res->getOffsets()); column_result.column = std::move(col_res); @@ -379,10 +345,6 @@ class FunctionStringReplace : public IFunction col_needle->getOffsets(), col_replacement->getChars(), col_replacement->getOffsets(), - pos, - occ, - match_type, - collator, col_res->getChars(), col_res->getOffsets()); column_result.column = std::move(col_res); @@ -399,7 +361,5 @@ class FunctionStringReplace : public IFunction ErrorCodes::ILLEGAL_COLUMN); } } - - TiDB::TiDBCollatorPtr collator{}; }; } // namespace DB diff --git a/dbms/src/Functions/FunctionsStringSearch.cpp b/dbms/src/Functions/FunctionsStringSearch.cpp index 5b5318bcc30..bae0e9103cb 100644 --- a/dbms/src/Functions/FunctionsStringSearch.cpp +++ b/dbms/src/Functions/FunctionsStringSearch.cpp @@ -816,10 +816,6 @@ struct ReplaceStringImpl const ColumnString::Offsets & offsets, const std::string & needle, const std::string & replacement, - const Int64 & /* pos */, - const Int64 & /* occ */, - const std::string & /* match_type */, - TiDB::TiDBCollatorPtr /* collator */, ColumnString::Chars_t & res_data, ColumnString::Offsets & res_offsets) { @@ -904,10 +900,6 @@ struct ReplaceStringImpl const ColumnString::Chars_t & needle_chars, const ColumnString::Offsets & needle_offsets, const std::string & replacement, - const Int64 & /* pos */, - const Int64 & /* occ */, - const std::string & /* match_type */, - TiDB::TiDBCollatorPtr /* collator */, ColumnString::Chars_t & res_data, ColumnString::Offsets & res_offsets) { @@ -978,10 +970,6 @@ struct ReplaceStringImpl const std::string & needle, const ColumnString::Chars_t & replacement_chars, const ColumnString::Offsets & replacement_offsets, - const Int64 & /* pos */, - const Int64 & /* occ */, - const std::string & /* match_type */, - TiDB::TiDBCollatorPtr /* collator */, ColumnString::Chars_t & res_data, ColumnString::Offsets & res_offsets) { @@ -1070,10 +1058,6 @@ struct ReplaceStringImpl const ColumnString::Offsets & needle_offsets, const ColumnString::Chars_t & replacement_chars, const ColumnString::Offsets & replacement_offsets, - const Int64 & /* pos */, - const Int64 & /* occ */, - const std::string & /* match_type */, - TiDB::TiDBCollatorPtr /* collator */, ColumnString::Chars_t & res_data, ColumnString::Offsets & res_offsets) { @@ -1146,10 +1130,6 @@ struct ReplaceStringImpl size_t n, const std::string & needle, const std::string & replacement, - const Int64 & /* pos */, - const Int64 & /* occ */, - const std::string & /* match_type */, - TiDB::TiDBCollatorPtr /* collator */, ColumnString::Chars_t & res_data, ColumnString::Offsets & res_offsets) { @@ -1244,10 +1224,6 @@ struct ReplaceStringImpl const ColumnString::Chars_t & needle_chars, const ColumnString::Offsets & needle_offsets, const std::string & replacement, - const Int64 & /* pos */, - const Int64 & /* occ */, - const std::string & /* match_type */, - TiDB::TiDBCollatorPtr /* collator */, ColumnString::Chars_t & res_data, ColumnString::Offsets & res_offsets) { @@ -1319,10 +1295,6 @@ struct ReplaceStringImpl const std::string & needle, const ColumnString::Chars_t & replacement_chars, const ColumnString::Offsets & replacement_offsets, - const Int64 & /* pos */, - const Int64 & /* occ */, - const std::string & /* match_type */, - TiDB::TiDBCollatorPtr /* collator */, ColumnString::Chars_t & res_data, ColumnString::Offsets & res_offsets) { @@ -1421,10 +1393,6 @@ struct ReplaceStringImpl const ColumnString::Offsets & needle_offsets, const ColumnString::Chars_t & replacement_chars, const ColumnString::Offsets & replacement_offsets, - const Int64 & /* pos */, - const Int64 & /* occ */, - const std::string & /* match_type */, - TiDB::TiDBCollatorPtr /* collator */, ColumnString::Chars_t & res_data, ColumnString::Offsets & res_offsets) { @@ -1498,10 +1466,6 @@ struct ReplaceStringImpl const std::string & data, const std::string & needle, const std::string & replacement, - const Int64 & /* pos */, - const Int64 & /* occ */, - const std::string & /* match_type */, - TiDB::TiDBCollatorPtr /* collator */, std::string & res_data) { if (needle.empty()) diff --git a/dbms/src/Functions/GatherUtils/Algorithms.h b/dbms/src/Functions/GatherUtils/Algorithms.h index f08dbf8fead..baa6e6c27db 100644 --- a/dbms/src/Functions/GatherUtils/Algorithms.h +++ b/dbms/src/Functions/GatherUtils/Algorithms.h @@ -813,4 +813,30 @@ void resizeConstantSize(ArraySource && array_source, ValueSource && value_source } } +template +void replace( + HaystackSource && src_h, + NeedleSource && src_n, + ReplacementSource && src_r, + ColumnString::MutablePtr & res_col) +{ + while (!src_h.isEnd()) + { + const auto slice_h = src_h.getWhole(); + const auto slice_n = src_n.getWhole(); + const auto slice_r = src_r.getWhole(); + + const String str_h(reinterpret_cast(slice_h.data), slice_h.size); + const String str_n(reinterpret_cast(slice_n.data), slice_n.size); + const String str_r(reinterpret_cast(slice_r.data), slice_r.size); + String res; + Impl::constant(str_h, str_n, str_r, res); + res_col->insertData(res.data(), res.size()); + + src_h.next(); + src_n.next(); + src_r.next(); + } +} + } // namespace DB::GatherUtils diff --git a/dbms/src/Functions/tests/gtest_strings_replace.cpp b/dbms/src/Functions/tests/gtest_strings_replace.cpp index 4615d634e5f..ad14cbc7be7 100644 --- a/dbms/src/Functions/tests/gtest_strings_replace.cpp +++ b/dbms/src/Functions/tests/gtest_strings_replace.cpp @@ -70,6 +70,13 @@ try toVec({"", "w", "ww", " www ", "w w w"}), executeFunction("replaceAll", toVec({"", "w", "ww", " www ", "w w w"}), toConst(""), toConst(" "))); + ASSERT_COLUMN_EQ( + createConstColumn(1, {" bc"}), + executeFunction("replaceAll", toConst("abc"), toConst("a"), toConst(" "))); + ASSERT_COLUMN_EQ( + createConstColumn(1, {""}), + executeFunction("replaceAll", toConst(""), toConst(""), toConst(" "))); + /// non-const needle and const replacement ASSERT_COLUMN_EQ( toVec({"hello", " e llo", "hello ", " ", "hello world"}), @@ -87,6 +94,14 @@ try toVec({" ", "w", "w", "www", " w"}), toConst("ww"))); + ASSERT_COLUMN_EQ( + toVec({" bc", "a c", "ab "}), + executeFunction( + "replaceAll", + createConstColumn(3, "abc"), + toVec({"a", "b", "c"}), + createConstColumn(3, " "))); + /// const needle and non-const replacement ASSERT_COLUMN_EQ( toVec({"hello", "xxxhxexllo", "helloxxxxxxxx", " ", "hello,,world"}), @@ -96,6 +111,14 @@ try toConst(" "), toVec({"", "x", "xx", " ", ","}))); + ASSERT_COLUMN_EQ( + toVec({"123", "456", "789"}), + executeFunction( + "replaceAll", + createConstColumn(3, "abc"), + createConstColumn(3, "abc"), + toVec({"123", "456", "789"}))); + /// non-const needle and non-const replacement ASSERT_COLUMN_EQ( toVec({"hello", " x e llo", "hello ", " ", "hello, world"}), @@ -104,6 +127,14 @@ try toVec({" hello ", " h e llo", "hello ", " ", "hello, world"}), toVec({" ", "h", "", "h", ","}), toVec({"", "x", "xx", " ", ","}))); + + ASSERT_COLUMN_EQ( + toVec({"1bc", "a2c", "ab3"}), + executeFunction( + "replaceAll", + createConstColumn(3, "abc"), + toVec({"a", "b", "c"}), + toVec({"1", "2", "3"}))); } CATCH @@ -127,6 +158,13 @@ try toConst("你"), toConst("您"))); + ASSERT_COLUMN_EQ( + createConstColumn(1, {"你你世界"}), + executeFunction("replaceAll", toConst("你好世界"), toConst("好"), toConst("你"))); + ASSERT_COLUMN_EQ( + createConstColumn(1, {" "}), + executeFunction("replaceAll", toConst("你好世界"), toConst("你好世界"), toConst(" "))); + /// non-const needle and const replacement ASSERT_COLUMN_EQ( toVec({" 你好 ", "你好", " ", "你 好 ", "你不好"}), @@ -144,6 +182,14 @@ try toVec({" ", " 你", "你好", " 你", "你好"}), toConst("x"))); + ASSERT_COLUMN_EQ( + toVec({" 好世界", "你好 界", "你 世界"}), + executeFunction( + "replaceAll", + createConstColumn(3, "你好世界"), + toVec({"你", "世", "好"}), + createConstColumn(3, " "))); + /// const needle and non-const replacement ASSERT_COLUMN_EQ( toVec({" 好 ", " 你 好", "你好好 你好好", " 你 好 ", "你好不好"}), @@ -153,6 +199,14 @@ try toConst("你"), toVec({"", " 你", "你好", " 你", "你好"}))); + ASSERT_COLUMN_EQ( + toVec({"你一二世界", "你天天世界", "你向上世界"}), + executeFunction( + "replaceAll", + createConstColumn(3, "你好世界"), + createConstColumn(3, "好"), + toVec({"一二", "天天", "向上"}))); + /// non-const needle and non-const replacement ASSERT_COLUMN_EQ( toVec({" 你好 ", " 你 你 你你 你好", "好 好", " 你好 ", "你不好"}), @@ -161,6 +215,14 @@ try toVec({" 你好 ", " 你 好", "你好 你好", "你 好 ", "你不好"}), toVec({"", " ", "你好", "你 ", "你好"}), toVec({" ", " 你", "好", " 你", "你好"}))); + + ASSERT_COLUMN_EQ( + toVec({"你好世好", "你好好界", "你学世界", "习好世界"}), + executeFunction( + "replaceAll", + createConstColumn(3, "你好世界"), + toVec({"界", "世", "好", "你"}), + toVec({"好", "好", "学", "习"}))); } CATCH diff --git a/tests/fullstack-test/expr/replace.test b/tests/fullstack-test/expr/replace.test new file mode 100644 index 00000000000..1f5c7f8a9c8 --- /dev/null +++ b/tests/fullstack-test/expr/replace.test @@ -0,0 +1,40 @@ +# Copyright 2024 PingCAP, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +mysql> drop table if exists test.t +mysql> create table test.t(c1 varchar(100), c2 varchar(100), c3 varchar(100)) +mysql> insert into test.t values('hello world', 'hello', '???') +mysql> alter table test.t set tiflash replica 1 +func> wait_table test t +mysql> set tidb_isolation_read_engines = 'tiflash'; set tidb_enforce_mpp=1; select replace(c1, c2, c3) from test.t +replace(c1, c2, c3) +??? world + +mysql> set tidb_isolation_read_engines = 'tiflash'; set tidb_enforce_mpp=1; select replace('hello world', c2, c3) from test.t +replace('hello world', c2, c3) +??? world + +mysql> set tidb_isolation_read_engines = 'tiflash'; set tidb_enforce_mpp=1; select replace('hello world', 'hello', '???') from test.t +replace('hello world', 'hello', '???') +??? world + +mysql> set tidb_isolation_read_engines = 'tiflash'; set tidb_enforce_mpp=1; select replace('hello world', c2, '???') from test.t +replace('hello world', c2, '???') +??? world + +mysql> set tidb_isolation_read_engines = 'tiflash'; set tidb_enforce_mpp=1; select replace('hello world', 'hello', c3) from test.t +replace('hello world', 'hello', c3) +??? world + +mysql> drop table if exists test.t