From ad510db79b0c0375f0431a06d7122a078cc8590c Mon Sep 17 00:00:00 2001 From: Yenda Li Date: Wed, 16 Oct 2024 15:18:06 -0700 Subject: [PATCH] Add trail function (#11265) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/11265 Add trail function as per specs here: https://prestodb.io/docs/current/functions/string.html#trail-string-N-varchar We return the last N characters of the string up to at most the length of the input string. Reviewed By: amitkdutta Differential Revision: D64379575 fbshipit-source-id: dc0e766e54550b7b81bb72fad7ea25c3e3e77e36 --- velox/docs/functions/presto/string.rst | 4 ++ velox/functions/prestosql/StringFunctions.h | 54 +++++++++++++++++++ .../StringFunctionsRegistration.cpp | 3 ++ .../prestosql/tests/StringFunctionsTest.cpp | 16 ++++++ 4 files changed, 77 insertions(+) diff --git a/velox/docs/functions/presto/string.rst b/velox/docs/functions/presto/string.rst index 9040b382f31e..8b6e74184592 100644 --- a/velox/docs/functions/presto/string.rst +++ b/velox/docs/functions/presto/string.rst @@ -214,6 +214,10 @@ String Functions SELECT strrpos('aaa', 'aa', 2); -- 1 +.. function:: trail(string, N) -> varchar + + Returns the last ``N`` characters of the input ``string`` up to at most the length of ``string``. + .. function:: substr(string, start) -> varchar Returns the rest of ``string`` from the starting position ``start``. diff --git a/velox/functions/prestosql/StringFunctions.h b/velox/functions/prestosql/StringFunctions.h index 22d3ea26ccb8..34b67dbcc77b 100644 --- a/velox/functions/prestosql/StringFunctions.h +++ b/velox/functions/prestosql/StringFunctions.h @@ -47,6 +47,60 @@ struct CodePointFunction { } }; +/// trail(string, N) -> varchar +/// +/// Returns the last N characters of the input string. +template +struct TrailFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + // Results refer to strings in the first argument. + static constexpr int32_t reuse_strings_from_arg = 0; + + // ASCII input always produces ASCII result. + static constexpr bool is_default_ascii_behavior = true; + + template + FOLLY_ALWAYS_INLINE void callNullFree( + out_type& result, + const null_free_arg_type& input, + I N) { + doCall(result, input, N); + } + + template + FOLLY_ALWAYS_INLINE void + callAscii(out_type& result, const arg_type& input, I N) { + doCall(result, input, N); + } + + private: + template + FOLLY_ALWAYS_INLINE void + doCall(out_type& result, const arg_type& input, I N) { + if (N <= 0) { + result.setEmpty(); + return; + } + + I numCharacters = stringImpl::length(input); + + // Get the start position of the last N characters + // If N is greater than the number of characters, start at 1/ + I start = N > numCharacters ? 1 : numCharacters - N + 1; + + // Adjust length + I adjustedLength = std::min(N, numCharacters); + + auto byteRange = stringCore::getByteRange( + input.data(), input.size(), start, adjustedLength); + + // Generating output string + result.setNoCopy(StringView( + input.data() + byteRange.first, byteRange.second - byteRange.first)); + } +}; + /// substr(string, start) -> varchar /// /// Returns the rest of string from the starting position start. diff --git a/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp b/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp index 58389dc5f5bd..bb97c332b311 100644 --- a/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp @@ -52,6 +52,9 @@ void registerSimpleFunctions(const std::string& prefix) { registerFunction( {prefix + "ends_with"}); + registerFunction( + {prefix + "trail"}); + registerFunction( {prefix + "substr"}); registerFunction( diff --git a/velox/functions/prestosql/tests/StringFunctionsTest.cpp b/velox/functions/prestosql/tests/StringFunctionsTest.cpp index ae0709a420e3..73fcb0a50729 100644 --- a/velox/functions/prestosql/tests/StringFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/StringFunctionsTest.cpp @@ -2164,3 +2164,19 @@ TEST_F(StringFunctionsTest, normalize) { normalizeWithForm("sch\u00f6n", "NFKE"), "Normalization form must be one of [NFD, NFC, NFKD, NFKC]"); } + +TEST_F(StringFunctionsTest, trail) { + auto trail = [&](std::optional string, + std::optional N) { + return evaluateOnce("trail(c0, c1)", string, N); + }; + + // Basic Test + EXPECT_EQ("bar", trail("foobar", 3)); + EXPECT_EQ("foobar", trail("foobar", 7)); + EXPECT_EQ("", trail("foobar", 0)); + EXPECT_EQ("", trail("foobar", -1)); + + // Test empty + EXPECT_EQ("", trail("", 3)); +}