Skip to content

Commit

Permalink
perf: Faster decimal precision overflow checks (#6419)
Browse files Browse the repository at this point in the history
* add benchmark

* add optimization

* fix

* fix

* cargo fmt

* clippy

* Update arrow-data/src/decimal.rs

Co-authored-by: Liang-Chi Hsieh <[email protected]>

* optimize to avoid allocating an idx variable

* revert change to public api

* fix error in rustdoc

---------

Co-authored-by: Liang-Chi Hsieh <[email protected]>
  • Loading branch information
andygrove and viirya committed Sep 21, 2024
1 parent d727503 commit c90713b
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 42 deletions.
4 changes: 4 additions & 0 deletions arrow-array/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,7 @@ harness = false
[[bench]]
name = "fixed_size_list_array"
harness = false

[[bench]]
name = "decimal_overflow"
harness = false
53 changes: 53 additions & 0 deletions arrow-array/benches/decimal_overflow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_array::builder::{Decimal128Builder, Decimal256Builder};
use arrow_buffer::i256;
use criterion::*;

fn criterion_benchmark(c: &mut Criterion) {
let len = 8192;
let mut builder_128 = Decimal128Builder::with_capacity(len);
let mut builder_256 = Decimal256Builder::with_capacity(len);
for i in 0..len {
if i % 10 == 0 {
builder_128.append_value(i128::MAX);
builder_256.append_value(i256::from_i128(i128::MAX));
} else {
builder_128.append_value(i as i128);
builder_256.append_value(i256::from_i128(i as i128));
}
}
let array_128 = builder_128.finish();
let array_256 = builder_256.finish();

c.bench_function("validate_decimal_precision_128", |b| {
b.iter(|| black_box(array_128.validate_decimal_precision(8)));
});
c.bench_function("null_if_overflow_precision_128", |b| {
b.iter(|| black_box(array_128.null_if_overflow_precision(8)));
});
c.bench_function("validate_decimal_precision_256", |b| {
b.iter(|| black_box(array_256.validate_decimal_precision(8)));
});
c.bench_function("null_if_overflow_precision_256", |b| {
b.iter(|| black_box(array_256.null_if_overflow_precision(8)));
});
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
4 changes: 1 addition & 3 deletions arrow-array/src/array/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1570,9 +1570,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
/// Validates the Decimal Array, if the value of slot is overflow for the specified precision, and
/// will be casted to Null
pub fn null_if_overflow_precision(&self, precision: u8) -> Self {
self.unary_opt::<_, T>(|v| {
(T::validate_decimal_precision(v, precision).is_ok()).then_some(v)
})
self.unary_opt::<_, T>(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
}

/// Returns [`Self::value`] formatted as a string
Expand Down
16 changes: 15 additions & 1 deletion arrow-array/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ use crate::temporal_conversions::as_datetime_with_timezone;
use crate::timezone::Tz;
use crate::{ArrowNativeTypeOp, OffsetSizeTrait};
use arrow_buffer::{i256, Buffer, OffsetBuffer};
use arrow_data::decimal::{validate_decimal256_precision, validate_decimal_precision};
use arrow_data::decimal::{
is_validate_decimal256_precision, is_validate_decimal_precision, validate_decimal256_precision,
validate_decimal_precision,
};
use arrow_data::{validate_binary_view, validate_string_view};
use arrow_schema::{
ArrowError, DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
Expand Down Expand Up @@ -1194,6 +1197,9 @@ pub trait DecimalType:

/// Validates that `value` contains no more than `precision` decimal digits
fn validate_decimal_precision(value: Self::Native, precision: u8) -> Result<(), ArrowError>;

/// Determines whether `value` contains no more than `precision` decimal digits
fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool;
}

/// Validate that `precision` and `scale` are valid for `T`
Expand Down Expand Up @@ -1256,6 +1262,10 @@ impl DecimalType for Decimal128Type {
fn validate_decimal_precision(num: i128, precision: u8) -> Result<(), ArrowError> {
validate_decimal_precision(num, precision)
}

fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool {
is_validate_decimal_precision(value, precision)
}
}

impl ArrowPrimitiveType for Decimal128Type {
Expand Down Expand Up @@ -1286,6 +1296,10 @@ impl DecimalType for Decimal256Type {
fn validate_decimal_precision(num: i256, precision: u8) -> Result<(), ArrowError> {
validate_decimal256_precision(num, precision)
}

fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool {
is_validate_decimal256_precision(value, precision)
}
}

impl ArrowPrimitiveType for Decimal256Type {
Expand Down
10 changes: 3 additions & 7 deletions arrow-cast/src/cast/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,7 @@ where
if cast_options.safe {
let iter = from.iter().map(|v| {
v.and_then(|v| parse_string_to_decimal_native::<T>(v, scale as usize).ok())
.and_then(|v| {
T::validate_decimal_precision(v, precision)
.is_ok()
.then_some(v)
})
.and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
});
// Benefit:
// 20% performance improvement
Expand Down Expand Up @@ -430,7 +426,7 @@ where
(mul * v.as_())
.round()
.to_i128()
.filter(|v| Decimal128Type::validate_decimal_precision(*v, precision).is_ok())
.filter(|v| Decimal128Type::is_valid_decimal_precision(*v, precision))
})
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
Expand Down Expand Up @@ -473,7 +469,7 @@ where
array
.unary_opt::<_, Decimal256Type>(|v| {
i256::from_f64((v.as_() * mul).round())
.filter(|v| Decimal256Type::validate_decimal_precision(*v, precision).is_ok())
.filter(|v| Decimal256Type::is_valid_decimal_precision(*v, precision))
})
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
Expand Down
14 changes: 8 additions & 6 deletions arrow-cast/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,10 @@ where
let array = if scale < 0 {
match cast_options.safe {
true => array.unary_opt::<_, D>(|v| {
v.as_().div_checked(scale_factor).ok().and_then(|v| {
(D::validate_decimal_precision(v, precision).is_ok()).then_some(v)
})
v.as_()
.div_checked(scale_factor)
.ok()
.and_then(|v| (D::is_valid_decimal_precision(v, precision)).then_some(v))
}),
false => array.try_unary::<_, D, _>(|v| {
v.as_()
Expand All @@ -340,9 +341,10 @@ where
} else {
match cast_options.safe {
true => array.unary_opt::<_, D>(|v| {
v.as_().mul_checked(scale_factor).ok().and_then(|v| {
(D::validate_decimal_precision(v, precision).is_ok()).then_some(v)
})
v.as_()
.mul_checked(scale_factor)
.ok()
.and_then(|v| (D::is_valid_decimal_precision(v, precision)).then_some(v))
}),
false => array.try_unary::<_, D, _>(|v| {
v.as_()
Expand Down
Loading

0 comments on commit c90713b

Please sign in to comment.