diff --git a/arrow-cast/src/cast/dictionary.rs b/arrow-cast/src/cast/dictionary.rs index daaddc4915ef..fc4d99430151 100644 --- a/arrow-cast/src/cast/dictionary.rs +++ b/arrow-cast/src/cast/dictionary.rs @@ -202,11 +202,41 @@ pub(crate) fn cast_to_dictionary( UInt16 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), UInt32 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), UInt64 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), - Decimal128(_, _) => { - pack_numeric_to_dictionary::(array, dict_value_type, cast_options) + Decimal128(p, s) => { + let dict = pack_numeric_to_dictionary::( + array, + dict_value_type, + cast_options, + )?; + let dict = dict + .as_dictionary::() + .downcast_dict::() + .unwrap(); + let value = dict.values().clone(); + // Set correct precision/scale + let value = value.with_precision_and_scale(p, s)?; + Ok(Arc::new(DictionaryArray::::try_new( + dict.keys().clone(), + Arc::new(value), + )?)) } - Decimal256(_, _) => { - pack_numeric_to_dictionary::(array, dict_value_type, cast_options) + Decimal256(p, s) => { + let dict = pack_numeric_to_dictionary::( + array, + dict_value_type, + cast_options, + )?; + let dict = dict + .as_dictionary::() + .downcast_dict::() + .unwrap(); + let value = dict.values().clone(); + // Set correct precision/scale + let value = value.with_precision_and_scale(p, s)?; + Ok(Arc::new(DictionaryArray::::try_new( + dict.keys().clone(), + Arc::new(value), + )?)) } Float16 => { pack_numeric_to_dictionary::(array, dict_value_type, cast_options) diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index fe59a141cbe2..e80d497c8cba 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -2650,6 +2650,38 @@ mod tests { err.unwrap_err().to_string()); } + #[test] + fn test_cast_decimal128_to_decimal128_dict() { + let p = 20; + let s = 3; + let input_type = DataType::Decimal128(p, s); + let output_type = DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Decimal128(p, s)), + ); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal_array(array, p, s).unwrap(); + let cast_array = cast_with_options(&array, &output_type, &CastOptions::default()).unwrap(); + assert_eq!(cast_array.data_type(), &output_type); + } + + #[test] + fn test_cast_decimal256_to_decimal256_dict() { + let p = 20; + let s = 3; + let input_type = DataType::Decimal256(p, s); + let output_type = DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Decimal256(p, s)), + ); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal_array(array, p, s).unwrap(); + let cast_array = cast_with_options(&array, &output_type, &CastOptions::default()).unwrap(); + assert_eq!(cast_array.data_type(), &output_type); + } + #[test] fn test_cast_decimal128_to_decimal128_overflow() { let input_type = DataType::Decimal128(38, 3);