From 96f60c30bcf9ac7c11ce70cbd247104a46fcbac1 Mon Sep 17 00:00:00 2001 From: amorynan Date: Sat, 30 Nov 2024 10:56:25 +0800 Subject: [PATCH] fix ip functions --- be/src/vec/functions/function_ip.h | 62 ++++++------------- .../scalar_function/IP.groovy | 13 +++- .../test_ipv6_cidr_to_range_function.groovy | 12 ++-- 3 files changed, 36 insertions(+), 51 deletions(-) diff --git a/be/src/vec/functions/function_ip.h b/be/src/vec/functions/function_ip.h index 4fe24d44a4bad3..c01172270069d5 100644 --- a/be/src/vec/functions/function_ip.h +++ b/be/src/vec/functions/function_ip.h @@ -866,6 +866,11 @@ class FunctionIPv4CIDRToRange : public IFunction { } }; +/** + * this function accepts two arguments: an IPv6 address and a CIDR mask + * IPv6 address can be either ipv6 type or string type as ipv6 string address + * FE: PropagateNullable is used to handle nullable columns + */ class FunctionIPv6CIDRToRange : public IFunction { public: static constexpr auto name = "ipv6_cidr_to_range"; @@ -900,9 +905,11 @@ class FunctionIPv6CIDRToRange : public IFunction { col_res = execute_impl(*ipv6_addr_column, *cidr_col, input_rows_count, add_col_const, col_const); } else if (addr_type.is_string()) { - const auto* str_addr_column = assert_cast(addr_column.get()); - col_res = execute_impl(*str_addr_column, *cidr_col, input_rows_count, - add_col_const, col_const); + ColumnPtr col_ipv6 = + convert_to_ipv6(addr_column, nullptr); + const auto* ipv6_addr_column = assert_cast(col_ipv6.get()); + col_res = execute_impl(*ipv6_addr_column, *cidr_col, input_rows_count, + add_col_const, col_const); } else { return Status::RuntimeError( "Illegal column {} of argument of function {}, Expected IPv6 or String", @@ -923,19 +930,8 @@ class FunctionIPv6CIDRToRange : public IFunction { auto& vec_res_upper_range = col_res_upper_range->get_data(); static constexpr UInt8 max_cidr_mask = IPV6_BINARY_LENGTH * 8; - unsigned char ipv6_address_data[IPV6_BINARY_LENGTH]; if (is_addr_const) { - StringRef str_ref = from_column.get_data_at(0); - const char* value = str_ref.data; - size_t value_size = str_ref.size; - if (value_size > IPV6_BINARY_LENGTH || value == nullptr || value_size == 0 || - !IPv6Value::is_valid_string(value, value_size)) { - throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal ipv6 address '{}'", - std::string(value, value_size)); - } - memcpy(ipv6_address_data, value, value_size); - memset(ipv6_address_data + value_size, 0, IPV6_BINARY_LENGTH - value_size); for (size_t i = 0; i < input_rows_count; ++i) { auto cidr = cidr_column.get_int(i); if (cidr < 0 || cidr > max_cidr_mask) { @@ -945,9 +941,9 @@ class FunctionIPv6CIDRToRange : public IFunction { if constexpr (std::is_same_v) { // 16 bytes ipv6 string is stored in big-endian byte order // so transfer to little-endian firstly - std::reverse(ipv6_address_data, ipv6_address_data + IPV6_BINARY_LENGTH); - apply_cidr_mask(reinterpret_cast(&ipv6_address_data), - reinterpret_cast(&vec_res_lower_range[i]), + auto* src_data = const_cast(from_column.get_data_at(0).data); + std::reverse(src_data, src_data + IPV6_BINARY_LENGTH); + apply_cidr_mask(src_data, reinterpret_cast(&vec_res_lower_range[i]), reinterpret_cast(&vec_res_upper_range[i]), cast_set(cidr)); } else { @@ -967,19 +963,9 @@ class FunctionIPv6CIDRToRange : public IFunction { if constexpr (std::is_same_v) { // 16 bytes ipv6 string is stored in big-endian byte order // so transfer to little-endian firstly - StringRef str_ref = from_column.get_data_at(i); - const char* value = str_ref.data; - size_t value_size = str_ref.size; - if (value_size > IPV6_BINARY_LENGTH || value == nullptr || value_size == 0 || - !IPv6Value::is_valid_string(value, value_size)) { - throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal ipv6 address '{}'", - std::string(value, value_size)); - } - memcpy(ipv6_address_data, value, value_size); - memset(ipv6_address_data + value_size, 0, IPV6_BINARY_LENGTH - value_size); - std::reverse(ipv6_address_data, ipv6_address_data + IPV6_BINARY_LENGTH); - apply_cidr_mask(reinterpret_cast(&ipv6_address_data), - reinterpret_cast(&vec_res_lower_range[i]), + auto* src_data = const_cast(from_column.get_data_at(i).data); + std::reverse(src_data, src_data + IPV6_BINARY_LENGTH); + apply_cidr_mask(src_data, reinterpret_cast(&vec_res_lower_range[i]), reinterpret_cast(&vec_res_upper_range[i]), cast_set(cidr)); } else { @@ -999,19 +985,9 @@ class FunctionIPv6CIDRToRange : public IFunction { if constexpr (std::is_same_v) { // 16 bytes ipv6 string is stored in big-endian byte order // so transfer to little-endian firstly - StringRef str_ref = from_column.get_data_at(i); - const char* value = str_ref.data; - size_t value_size = str_ref.size; - if (value_size > IPV6_BINARY_LENGTH || value == nullptr || value_size == 0 || - !IPv6Value::is_valid_string(value, value_size)) { - throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal ipv6 address '{}'", - std::string(value, value_size)); - } - memcpy(ipv6_address_data, value, value_size); - memset(ipv6_address_data + value_size, 0, IPV6_BINARY_LENGTH - value_size); - std::reverse(ipv6_address_data, ipv6_address_data + IPV6_BINARY_LENGTH); - apply_cidr_mask(reinterpret_cast(&ipv6_address_data), - reinterpret_cast(&vec_res_lower_range[i]), + auto* src_data = const_cast(from_column.get_data_at(i).data); + std::reverse(src_data, src_data + IPV6_BINARY_LENGTH); + apply_cidr_mask(src_data, reinterpret_cast(&vec_res_lower_range[i]), reinterpret_cast(&vec_res_upper_range[i]), cast_set(cidr)); } else { diff --git a/regression-test/suites/nereids_function_p0/scalar_function/IP.groovy b/regression-test/suites/nereids_function_p0/scalar_function/IP.groovy index 71abccc9dae104..a1c1d00caa29fe 100644 --- a/regression-test/suites/nereids_function_p0/scalar_function/IP.groovy +++ b/regression-test/suites/nereids_function_p0/scalar_function/IP.groovy @@ -34,7 +34,11 @@ suite("nereids_scalar_fn_IP") { qt_sql_cidr_ipv6_nullable_ "select id, ipv6_cidr_to_range(to_ipv6('::'), 32) from fn_test_ip_nullable order by id;" test { sql "select id, ipv6_cidr_to_range(nullable(''), 32) from fn_test_ip_nullable order by id" - exception "Illegal ipv6 address" + exception "Invalid IPv6 value" + } + test { + sql "select id, ipv6_cidr_to_range(nullable('abc'), 32) from fn_test_ip_not_nullable order by id" + exception "Invalid IPv6 value" } // test IPV4_STRING_TO_NUM/IPV6_STRING_TO_NUM (we have null value in ip4 and ip6 column in fn_test_ip_nullable table) test { @@ -162,7 +166,12 @@ suite("nereids_scalar_fn_IP") { qt_sql_not_null_cidr_ipv6_nullable_ "select id, ipv6_cidr_to_range(to_ipv6('::'), 32) from fn_test_ip_nullable order by id;" test { sql "select id, ipv6_cidr_to_range(nullable(''), 32) from fn_test_ip_not_nullable order by id" - exception "Illegal ipv6 address" + exception "Invalid IPv6 value" + } + + test { + sql "select id, ipv6_cidr_to_range(nullable('abc'), 32) from fn_test_ip_not_nullable order by id" + exception "Invalid IPv6 value" } // test IPV4_STRING_TO_NUM/IPV6_STRING_TO_NUM qt_sql_not_null_ipv6_string_to_num 'select id, hex(ipv6_string_to_num(ip6)) from fn_test_ip_not_nullable order by id' diff --git a/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.groovy b/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.groovy index 41432c986fec49..0a8ba107013b4e 100644 --- a/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.groovy +++ b/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.groovy @@ -91,13 +91,13 @@ suite("test_ipv6_cidr_to_range_function") { (9, 'ffff:0000:0000:0000:0000:0000:0000:0000', NULL) """ - qt_sql "select id, struct_element(ipv6_cidr_to_range(ipv6_string_to_num_or_null(addr), cidr), 'min') as min_range, struct_element(ipv6_cidr_to_range(ipv6_string_to_num_or_null(addr), cidr), 'max') as max_range from test_str_cidr_to_range_function order by id" + qt_sql "select id, struct_element(ipv6_cidr_to_range(ipv6_num_to_string(ipv6_string_to_num_or_null(addr)), cidr), 'min') as min_range, struct_element(ipv6_cidr_to_range(ipv6_num_to_string(ipv6_string_to_num_or_null(addr)), cidr), 'max') as max_range from test_str_cidr_to_range_function order by id" sql """ DROP TABLE IF EXISTS test_str_cidr_to_range_function """ - qt_sql "select ipv6_cidr_to_range(ipv6_string_to_num('2001:0db8:0000:85a3:0000:0000:ac1f:8001'), 0)" - qt_sql "select ipv6_cidr_to_range(ipv6_string_to_num('2001:0db8:0000:85a3:0000:0000:ac1f:8001'), 128)" - qt_sql "select ipv6_cidr_to_range(ipv6_string_to_num('ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff'), 64)" - qt_sql "select ipv6_cidr_to_range(ipv6_string_to_num('0000:0000:0000:0000:0000:0000:0000:0000'), 8)" - qt_sql "select ipv6_cidr_to_range(ipv6_string_to_num('ffff:0000:0000:0000:0000:0000:0000:0000'), 4)" + qt_sql "select ipv6_cidr_to_range(ipv6_num_to_string(ipv6_string_to_num('2001:0db8:0000:85a3:0000:0000:ac1f:8001')), 0)" + qt_sql "select ipv6_cidr_to_range(ipv6_num_to_string(ipv6_string_to_num('2001:0db8:0000:85a3:0000:0000:ac1f:8001')), 128)" + qt_sql "select ipv6_cidr_to_range(ipv6_num_to_string(ipv6_string_to_num('ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff')), 64)" + qt_sql "select ipv6_cidr_to_range(ipv6_num_to_string(ipv6_string_to_num('0000:0000:0000:0000:0000:0000:0000:0000')), 8)" + qt_sql "select ipv6_cidr_to_range(ipv6_num_to_string(ipv6_string_to_num('ffff:0000:0000:0000:0000:0000:0000:0000')), 4)" }