Skip to content

Commit

Permalink
Support cast between Durations + between Durations all numeric types (#…
Browse files Browse the repository at this point in the history
…6452)

* Support cast between Durations

Signed-off-by: tison <[email protected]>

* Support cast between Durations and all numeric type

Signed-off-by: tison <[email protected]>

* Impl cast between Durations

Signed-off-by: tison <[email protected]>

* Add test_cast_between_durations

Signed-off-by: tison <[email protected]>

* add test cases

Signed-off-by: tison <[email protected]>

* cargo clippy

Signed-off-by: tison <[email protected]>

---------

Signed-off-by: tison <[email protected]>
  • Loading branch information
tisonkun authored Sep 26, 2024
1 parent d48010e commit 2881dbe
Showing 1 changed file with 153 additions and 21 deletions.
174 changes: 153 additions & 21 deletions arrow-cast/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
| Time64(Microsecond)
| Time64(Nanosecond),
) => true,
(Int64, Duration(_)) => true,
(Duration(_), Int64) => true,
(_, Duration(_)) if from_type.is_numeric() => true,
(Duration(_), _) if to_type.is_numeric() => true,
(Duration(_), Duration(_)) => true,
(Interval(from_type), Int64) => {
match from_type {
YearMonth => true,
Expand Down Expand Up @@ -518,6 +519,15 @@ fn make_timestamp_array(
}
}

fn make_duration_array(array: &PrimitiveArray<Int64Type>, unit: TimeUnit) -> ArrayRef {
match unit {
TimeUnit::Second => Arc::new(array.reinterpret_cast::<DurationSecondType>()),
TimeUnit::Millisecond => Arc::new(array.reinterpret_cast::<DurationMillisecondType>()),
TimeUnit::Microsecond => Arc::new(array.reinterpret_cast::<DurationMicrosecondType>()),
TimeUnit::Nanosecond => Arc::new(array.reinterpret_cast::<DurationNanosecondType>()),
}
}

fn as_time_res_with_timezone<T: ArrowPrimitiveType>(
v: i64,
tz: Option<Tz>,
Expand Down Expand Up @@ -2074,31 +2084,53 @@ pub fn cast_with_options(
.as_primitive::<Date32Type>()
.unary::<_, TimestampNanosecondType>(|x| (x as i64) * NANOSECONDS_IN_DAY),
)),
(Int64, Duration(TimeUnit::Second)) => {
cast_reinterpret_arrays::<Int64Type, DurationSecondType>(array)
}
(Int64, Duration(TimeUnit::Millisecond)) => {
cast_reinterpret_arrays::<Int64Type, DurationMillisecondType>(array)
}
(Int64, Duration(TimeUnit::Microsecond)) => {
cast_reinterpret_arrays::<Int64Type, DurationMicrosecondType>(array)

(_, Duration(unit)) if from_type.is_numeric() => {
let array = cast_with_options(array, &Int64, cast_options)?;
Ok(make_duration_array(array.as_primitive(), *unit))
}
(Int64, Duration(TimeUnit::Nanosecond)) => {
cast_reinterpret_arrays::<Int64Type, DurationNanosecondType>(array)
(Duration(TimeUnit::Second), _) if to_type.is_numeric() => {
let array = cast_reinterpret_arrays::<DurationSecondType, Int64Type>(array)?;
cast_with_options(&array, to_type, cast_options)
}

(Duration(TimeUnit::Second), Int64) => {
cast_reinterpret_arrays::<DurationSecondType, Int64Type>(array)
(Duration(TimeUnit::Millisecond), _) if to_type.is_numeric() => {
let array = cast_reinterpret_arrays::<DurationMillisecondType, Int64Type>(array)?;
cast_with_options(&array, to_type, cast_options)
}
(Duration(TimeUnit::Millisecond), Int64) => {
cast_reinterpret_arrays::<DurationMillisecondType, Int64Type>(array)
(Duration(TimeUnit::Microsecond), _) if to_type.is_numeric() => {
let array = cast_reinterpret_arrays::<DurationMicrosecondType, Int64Type>(array)?;
cast_with_options(&array, to_type, cast_options)
}
(Duration(TimeUnit::Microsecond), Int64) => {
cast_reinterpret_arrays::<DurationMicrosecondType, Int64Type>(array)
(Duration(TimeUnit::Nanosecond), _) if to_type.is_numeric() => {
let array = cast_reinterpret_arrays::<DurationNanosecondType, Int64Type>(array)?;
cast_with_options(&array, to_type, cast_options)
}
(Duration(TimeUnit::Nanosecond), Int64) => {
cast_reinterpret_arrays::<DurationNanosecondType, Int64Type>(array)

(Duration(from_unit), Duration(to_unit)) => {
let array = cast_with_options(array, &Int64, cast_options)?;
let time_array = array.as_primitive::<Int64Type>();
let from_size = time_unit_multiple(from_unit);
let to_size = time_unit_multiple(to_unit);
// we either divide or multiply, depending on size of each unit
// units are never the same when the types are the same
let converted = match from_size.cmp(&to_size) {
Ordering::Greater => {
let divisor = from_size / to_size;
time_array.unary::<_, Int64Type>(|o| o / divisor)
}
Ordering::Equal => time_array.clone(),
Ordering::Less => {
let mul = to_size / from_size;
if cast_options.safe {
time_array.unary_opt::<_, Int64Type>(|o| o.checked_mul(mul))
} else {
time_array.try_unary::<_, Int64Type, _>(|o| o.mul_checked(mul))?
}
}
};
Ok(make_duration_array(&converted, *to_unit))
}

(Duration(TimeUnit::Second), Interval(IntervalUnit::MonthDayNano)) => {
cast_duration_to_interval::<DurationSecondType>(array, cast_options)
}
Expand Down Expand Up @@ -5254,6 +5286,106 @@ mod tests {
}
}

#[test]
fn test_cast_between_durations_and_numerics() {
fn test_cast_between_durations<FromType, ToType>()
where
FromType: ArrowPrimitiveType<Native = i64>,
ToType: ArrowPrimitiveType<Native = i64>,
PrimitiveArray<FromType>: From<Vec<Option<i64>>>,
{
let from_unit = match FromType::DATA_TYPE {
DataType::Duration(unit) => unit,
_ => panic!("Expected a duration type"),
};
let to_unit = match ToType::DATA_TYPE {
DataType::Duration(unit) => unit,
_ => panic!("Expected a duration type"),
};
let from_size = time_unit_multiple(&from_unit);
let to_size = time_unit_multiple(&to_unit);

let (v1_before, v2_before) = (8640003005, 1696002001);
let (v1_after, v2_after) = if from_size >= to_size {
(
v1_before / (from_size / to_size),
v2_before / (from_size / to_size),
)
} else {
(
v1_before * (to_size / from_size),
v2_before * (to_size / from_size),
)
};

let array =
PrimitiveArray::<FromType>::from(vec![Some(v1_before), Some(v2_before), None]);
let b = cast(&array, &ToType::DATA_TYPE).unwrap();
let c = b.as_primitive::<ToType>();
assert_eq!(v1_after, c.value(0));
assert_eq!(v2_after, c.value(1));
assert!(c.is_null(2));
}

// between each individual duration type
test_cast_between_durations::<DurationSecondType, DurationMillisecondType>();
test_cast_between_durations::<DurationSecondType, DurationMicrosecondType>();
test_cast_between_durations::<DurationSecondType, DurationNanosecondType>();
test_cast_between_durations::<DurationMillisecondType, DurationSecondType>();
test_cast_between_durations::<DurationMillisecondType, DurationMicrosecondType>();
test_cast_between_durations::<DurationMillisecondType, DurationNanosecondType>();
test_cast_between_durations::<DurationMicrosecondType, DurationSecondType>();
test_cast_between_durations::<DurationMicrosecondType, DurationMillisecondType>();
test_cast_between_durations::<DurationMicrosecondType, DurationNanosecondType>();
test_cast_between_durations::<DurationNanosecondType, DurationSecondType>();
test_cast_between_durations::<DurationNanosecondType, DurationMillisecondType>();
test_cast_between_durations::<DurationNanosecondType, DurationMicrosecondType>();

// cast failed
let array = DurationSecondArray::from(vec![
Some(i64::MAX),
Some(8640203410378005),
Some(10241096),
None,
]);
let b = cast(&array, &DataType::Duration(TimeUnit::Nanosecond)).unwrap();
let c = b.as_primitive::<DurationNanosecondType>();
assert!(c.is_null(0));
assert!(c.is_null(1));
assert_eq!(10241096000000000, c.value(2));
assert!(c.is_null(3));

// durations to numerics
let array = DurationSecondArray::from(vec![
Some(i64::MAX),
Some(8640203410378005),
Some(10241096),
None,
]);
let b = cast(&array, &DataType::Int64).unwrap();
let c = b.as_primitive::<Int64Type>();
assert_eq!(i64::MAX, c.value(0));
assert_eq!(8640203410378005, c.value(1));
assert_eq!(10241096, c.value(2));
assert!(c.is_null(3));

let b = cast(&array, &DataType::Int32).unwrap();
let c = b.as_primitive::<Int32Type>();
assert_eq!(0, c.value(0));
assert_eq!(0, c.value(1));
assert_eq!(10241096, c.value(2));
assert!(c.is_null(3));

// numerics to durations
let array = Int32Array::from(vec![Some(i32::MAX), Some(802034103), Some(10241096), None]);
let b = cast(&array, &DataType::Duration(TimeUnit::Second)).unwrap();
let c = b.as_any().downcast_ref::<DurationSecondArray>().unwrap();
assert_eq!(i32::MAX as i64, c.value(0));
assert_eq!(802034103, c.value(1));
assert_eq!(10241096, c.value(2));
assert!(c.is_null(3));
}

#[test]
fn test_cast_to_strings() {
let a = Int32Array::from(vec![1, 2, 3]);
Expand Down

0 comments on commit 2881dbe

Please sign in to comment.