Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve rpad udf by using a GenericStringBuilder #12070

Merged
merged 3 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions datafusion/functions/benches/pad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,12 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| {
b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap()))
});
//
// let args = create_args::<i32>(size, 32, true);
// group.bench_function(BenchmarkId::new("stringview type", size), |b| {
// b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap()))
// });

// rpad for stringview type
let args = create_args::<i32>(size, 32, true);
group.bench_function(BenchmarkId::new("stringview type", size), |b| {
b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap()))
});

group.finish();
}
Expand Down
333 changes: 174 additions & 159 deletions datafusion/functions/src/unicode/rpad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,23 @@
// specific language governing permissions and limitations
// under the License.

use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
use arrow::datatypes::DataType;
use datafusion_common::cast::{
as_generic_string_array, as_int64_array, as_string_view_array,
};
use unicode_segmentation::UnicodeSegmentation;

use crate::string::common::StringArrayType;
use crate::utils::{make_scalar_function, utf8_to_str_type};
use arrow::array::{
ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
OffsetSizeTrait, StringViewArray,
};
use arrow::datatypes::DataType;
use datafusion_common::cast::as_int64_array;
use datafusion_common::DataFusionError;
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::fmt::Write;
use std::sync::Arc;
use unicode_segmentation::UnicodeSegmentation;
use DataType::{LargeUtf8, Utf8, Utf8View};

#[derive(Debug)]
pub struct RPadFunc {
Expand Down Expand Up @@ -84,170 +87,182 @@ impl ScalarUDFImpl for RPadFunc {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args.len() {
2 => match args[0].data_type() {
DataType::Utf8 | DataType::Utf8View => {
make_scalar_function(rpad::<i32, i32>, vec![])(args)
}
DataType::LargeUtf8 => {
make_scalar_function(rpad::<i64, i64>, vec![])(args)
}
other => exec_err!("Unsupported data type {other:?} for function rpad"),
},
3 => match (args[0].data_type(), args[2].data_type()) {
(
DataType::Utf8 | DataType::Utf8View,
DataType::Utf8 | DataType::Utf8View,
) => make_scalar_function(rpad::<i32, i32>, vec![])(args),
(DataType::LargeUtf8, DataType::LargeUtf8) => {
make_scalar_function(rpad::<i64, i64>, vec![])(args)
}
(DataType::LargeUtf8, DataType::Utf8View | DataType::Utf8) => {
make_scalar_function(rpad::<i64, i32>, vec![])(args)
}
(DataType::Utf8View | DataType::Utf8, DataType::LargeUtf8) => {
make_scalar_function(rpad::<i32, i64>, vec![])(args)
}
(first_type, last_type) => {
exec_err!("unsupported arguments type for rpad, first argument type is {}, last argument type is {}", first_type, last_type)
}
},
number => {
exec_err!("unsupported arguments number {} for rpad", number)
match (
args.len(),
args[0].data_type(),
args.get(2).map(|arg| arg.data_type()),
) {
(2, Utf8 | Utf8View, _) => {
make_scalar_function(rpad::<i32, i32>, vec![])(args)
}
(2, LargeUtf8, _) => make_scalar_function(rpad::<i64, i64>, vec![])(args),
(3, Utf8 | Utf8View, Some(Utf8 | Utf8View)) => {
make_scalar_function(rpad::<i32, i32>, vec![])(args)
}
(3, LargeUtf8, Some(LargeUtf8)) => {
make_scalar_function(rpad::<i64, i64>, vec![])(args)
}
(3, Utf8 | Utf8View, Some(LargeUtf8)) => {
make_scalar_function(rpad::<i32, i64>, vec![])(args)
}
(3, LargeUtf8, Some(Utf8 | Utf8View)) => {
make_scalar_function(rpad::<i64, i32>, vec![])(args)
}
(_, _, _) => {
exec_err!("Unsupported combination of data types for function rpad")
}
}
}
}

macro_rules! process_rpad {
// For the two-argument case
($string_array:expr, $length_array:expr) => {{
$string_array
.iter()
.zip($length_array.iter())
.map(|(string, length)| match (string, length) {
(Some(string), Some(length)) => {
if length > i32::MAX as i64 {
return exec_err!("rpad requested length {} too large", length);
}

let length = if length < 0 { 0 } else { length as usize };
if length == 0 {
Ok(Some("".to_string()))
} else {
let graphemes = string.graphemes(true).collect::<Vec<&str>>();
if length < graphemes.len() {
Ok(Some(graphemes[..length].concat()))
} else {
let mut s = string.to_string();
s.push_str(" ".repeat(length - graphemes.len()).as_str());
Ok(Some(s))
}
}
}
_ => Ok(None),
})
.collect::<Result<GenericStringArray<StringArrayLen>>>()
}};

// For the three-argument case
($string_array:expr, $length_array:expr, $fill_array:expr) => {{
$string_array
.iter()
.zip($length_array.iter())
.zip($fill_array.iter())
.map(|((string, length), fill)| match (string, length, fill) {
(Some(string), Some(length), Some(fill)) => {
if length > i32::MAX as i64 {
return exec_err!("rpad requested length {} too large", length);
}

let length = if length < 0 { 0 } else { length as usize };
let graphemes = string.graphemes(true).collect::<Vec<&str>>();
let fill_chars = fill.chars().collect::<Vec<char>>();
pub fn rpad<StringArrayLen: OffsetSizeTrait, FillArrayLen: OffsetSizeTrait>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😍 -- 100% for no more macros!

args: &[ArrayRef],
) -> Result<ArrayRef> {
if args.len() < 2 || args.len() > 3 {
return exec_err!(
"rpad was called with {} arguments. It requires 2 or 3 arguments.",
args.len()
);
}

if length < graphemes.len() {
Ok(Some(graphemes[..length].concat()))
} else if fill_chars.is_empty() {
Ok(Some(string.to_string()))
} else {
let mut s = string.to_string();
let char_vector: Vec<char> = (0..length - graphemes.len())
.map(|l| fill_chars[l % fill_chars.len()])
.collect();
s.push_str(&char_vector.iter().collect::<String>());
Ok(Some(s))
}
}
_ => Ok(None),
})
.collect::<Result<GenericStringArray<StringArrayLen>>>()
}};
let length_array = as_int64_array(&args[1])?;
match (
args.len(),
args[0].data_type(),
args.get(2).map(|arg| arg.data_type()),
) {
(2, Utf8View, _) => {
rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>(
args[0].as_string_view(),
length_array,
None,
)
}
(3, Utf8View, Some(Utf8View)) => {
rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>(
args[0].as_string_view(),
length_array,
Some(args[2].as_string_view()),
)
}
(3, Utf8View, Some(Utf8 | LargeUtf8)) => {
rpad_impl::<&StringViewArray, &GenericStringArray<FillArrayLen>, StringArrayLen>(
args[0].as_string_view(),
length_array,
Some(args[2].as_string::<FillArrayLen>()),
)
}
(3, Utf8 | LargeUtf8, Some(Utf8View)) => rpad_impl::<
&GenericStringArray<StringArrayLen>,
&StringViewArray,
StringArrayLen,
>(
args[0].as_string::<StringArrayLen>(),
length_array,
Some(args[2].as_string_view()),
),
(_, _, _) => rpad_impl::<
&GenericStringArray<StringArrayLen>,
&GenericStringArray<FillArrayLen>,
StringArrayLen,
>(
args[0].as_string::<StringArrayLen>(),
length_array,
args.get(2).map(|arg| arg.as_string::<FillArrayLen>()),
),
}
}

/// Extends the string to length `length` by appending the characters of `fill` (a space by default), repeating `fill` as many times as needed. If the string is already longer than `length`, it is truncated on the right.
/// rpad('hi', 5, 'xy') = 'hixyx'
pub fn rpad<StringArrayLen: OffsetSizeTrait, FillArrayLen: OffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<ArrayRef> {
match (args.len(), args[0].data_type()) {
(2, DataType::Utf8View) => {
let string_array = as_string_view_array(&args[0])?;
let length_array = as_int64_array(&args[1])?;
pub fn rpad_impl<'a, StringArrType, FillArrType, StringArrayLen>(
string_array: StringArrType,
length_array: &Int64Array,
fill_array: Option<FillArrType>,
) -> Result<ArrayRef>
where
StringArrType: StringArrayType<'a>,
FillArrType: StringArrayType<'a>,
StringArrayLen: OffsetSizeTrait,
{
let mut builder: GenericStringBuilder<StringArrayLen> = GenericStringBuilder::new();

let result = process_rpad!(string_array, length_array)?;
Ok(Arc::new(result) as ArrayRef)
match fill_array {
None => {
string_array.iter().zip(length_array.iter()).try_for_each(
|(string, length)| -> Result<(), DataFusionError> {
match (string, length) {
(Some(string), Some(length)) => {
if length > i32::MAX as i64 {
return exec_err!(
"rpad requested length {} too large",
length
);
}
let length = if length < 0 { 0 } else { length as usize };
if length == 0 {
builder.append_value("");
} else {
let graphemes =
string.graphemes(true).collect::<Vec<&str>>();
if length < graphemes.len() {
builder.append_value(graphemes[..length].concat());
} else {
builder.write_str(string)?;
builder.write_str(
&" ".repeat(length - graphemes.len()),
)?;
builder.append_value("");
}
}
}
_ => builder.append_null(),
}
Ok(())
},
)?;
}
(2, _) => {
let string_array = as_generic_string_array::<StringArrayLen>(&args[0])?;
let length_array = as_int64_array(&args[1])?;
Some(fill_array) => {
string_array
.iter()
.zip(length_array.iter())
.zip(fill_array.iter())
.try_for_each(
|((string, length), fill)| -> Result<(), DataFusionError> {
match (string, length, fill) {
(Some(string), Some(length), Some(fill)) => {
if length > i32::MAX as i64 {
return exec_err!(
"rpad requested length {} too large",
length
);
}
let length = if length < 0 { 0 } else { length as usize };
let graphemes =
string.graphemes(true).collect::<Vec<&str>>();

let result = process_rpad!(string_array, length_array)?;
Ok(Arc::new(result) as ArrayRef)
}
(3, DataType::Utf8View) => {
let string_array = as_string_view_array(&args[0])?;
let length_array = as_int64_array(&args[1])?;
match args[2].data_type() {
DataType::Utf8View => {
let fill_array = as_string_view_array(&args[2])?;
let result = process_rpad!(string_array, length_array, fill_array)?;
Ok(Arc::new(result) as ArrayRef)
}
DataType::Utf8 | DataType::LargeUtf8 => {
let fill_array = as_generic_string_array::<FillArrayLen>(&args[2])?;
let result = process_rpad!(string_array, length_array, fill_array)?;
Ok(Arc::new(result) as ArrayRef)
}
other_type => {
exec_err!("unsupported type for rpad's third operator: {}", other_type)
}
}
}
(3, _) => {
let string_array = as_generic_string_array::<StringArrayLen>(&args[0])?;
let length_array = as_int64_array(&args[1])?;
match args[2].data_type() {
DataType::Utf8View => {
let fill_array = as_string_view_array(&args[2])?;
let result = process_rpad!(string_array, length_array, fill_array)?;
Ok(Arc::new(result) as ArrayRef)
}
DataType::Utf8 | DataType::LargeUtf8 => {
let fill_array = as_generic_string_array::<FillArrayLen>(&args[2])?;
let result = process_rpad!(string_array, length_array, fill_array)?;
Ok(Arc::new(result) as ArrayRef)
}
other_type => {
exec_err!("unsupported type for rpad's third operator: {}", other_type)
}
}
if length < graphemes.len() {
builder.append_value(graphemes[..length].concat());
} else if fill.is_empty() {
builder.append_value(string);
} else {
builder.write_str(string)?;
fill.chars()
.cycle()
.take(length - graphemes.len())
.for_each(|ch| builder.write_char(ch).unwrap());
builder.append_value("");
}
}
_ => builder.append_null(),
}
Ok(())
},
)?;
}
(other, other_type) => exec_err!(
"rpad requires 2 or 3 arguments with corresponding types, but got {}. number of arguments with {}",
other, other_type
),
}

Ok(Arc::new(builder.finish()) as ArrayRef)
}

#[cfg(test)]
Expand Down