diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index 7474ae41c526..ab8e86dc60f6 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -68,6 +68,44 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { } match (from_type, to_type) { + ( + Null, + Boolean + | Int8 + | UInt8 + | Int16 + | UInt16 + | Int32 + | UInt32 + | Float32 + | Date32 + | Time32(_) + | Int64 + | UInt64 + | Float64 + | Date64 + | List(_) + | Dictionary(_, _), + ) + | ( + Boolean + | Int8 + | UInt8 + | Int16 + | UInt16 + | Int32 + | UInt32 + | Float32 + | Date32 + | Time32(_) + | Int64 + | UInt64 + | Float64 + | Date64 + | List(_) + | Dictionary(_, _), + Null, + ) => true, (Struct(_), _) => false, (_, Struct(_)) => false, (LargeList(list_from), LargeList(list_to)) => { @@ -306,7 +344,6 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Timestamp(_, _), Date64) => true, // date64 to timestamp might not make sense, (Int64, Duration(_)) => true, - (Null, Int32) => true, (_, _) => false, } } @@ -867,6 +904,44 @@ pub fn cast_with_options( return Ok(array.clone()); } match (from_type, to_type) { + ( + Null, + Boolean + | Int8 + | UInt8 + | Int16 + | UInt16 + | Int32 + | UInt32 + | Float32 + | Date32 + | Time32(_) + | Int64 + | UInt64 + | Float64 + | Date64 + | List(_) + | Dictionary(_, _), + ) + | ( + Boolean + | Int8 + | UInt8 + | Int16 + | UInt16 + | Int32 + | UInt32 + | Float32 + | Date32 + | Time32(_) + | Int64 + | UInt64 + | Float64 + | Date64 + | List(_) + | Dictionary(_, _), + Null, + ) => Ok(new_null_array(to_type, array.len())), (Struct(_), _) => Err(ArrowError::CastError( "Cannot cast from struct to other types".to_string(), )), @@ -1706,10 +1781,6 @@ pub fn cast_with_options( } } } - - // null to primitive/flat types - (Null, Int32) => Ok(Arc::new(Int32Array::from(vec![None; array.len()]))), - (_, _) => Err(ArrowError::CastError(format!( "Casting from {:?} to {:?} not 
supported", from_type, to_type, @@ -4268,17 +4339,39 @@ mod tests { } #[test] - fn test_cast_null_array_to_int32() { - let array = Arc::new(NullArray::new(6)) as ArrayRef; + fn test_cast_null_array_from_and_to_others() { + macro_rules! typed_test { + ($ARR_TYPE:ident, $DATATYPE:ident, $TYPE:tt) => {{ + { + let array = Arc::new(NullArray::new(6)) as ArrayRef; + let expected = $ARR_TYPE::from(vec![None; 6]); + let cast_type = DataType::$DATATYPE; + let cast_array = cast(&array, &cast_type).expect("cast failed"); + let cast_array = as_primitive_array::<$TYPE>(&cast_array); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(cast_array, &expected); + } + { + let array = Arc::new($ARR_TYPE::from(vec![None; 4])) as ArrayRef; + let expected = NullArray::new(4); + let cast_array = cast(&array, &DataType::Null).expect("cast failed"); + let cast_array = as_null_array(&cast_array); + assert_eq!(cast_array.data_type(), &DataType::Null); + assert_eq!(cast_array, &expected); + } + }}; + } - let expected = Int32Array::from(vec![None; 6]); + typed_test!(Int16Array, Int16, Int16Type); + typed_test!(Int32Array, Int32, Int32Type); + typed_test!(Int64Array, Int64, Int64Type); - // Cast to a dictionary (same value type, Utf8) - let cast_type = DataType::Int32; - let cast_array = cast(&array, &cast_type).expect("cast failed"); - let cast_array = as_primitive_array::<Int32Type>(&cast_array); - assert_eq!(cast_array.data_type(), &cast_type); - assert_eq!(cast_array, &expected); + typed_test!(UInt16Array, UInt16, UInt16Type); + typed_test!(UInt32Array, UInt32, UInt32Type); + typed_test!(UInt64Array, UInt64, UInt64Type); + + typed_test!(Float32Array, Float32, Float32Type); + typed_test!(Float64Array, Float64, Float64Type); } /// Print the `DictionaryArray` `array` as a vector of strings