diff --git a/src/common_union.rs b/src/common_union.rs
index 8fcbeba..93b3ead 100644
--- a/src/common_union.rs
+++ b/src/common_union.rs
@@ -1,13 +1,14 @@
 use std::sync::{Arc, OnceLock};
 
 use datafusion::arrow::array::{
-    Array, ArrayRef, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray,
+    Array, ArrayRef, AsArray, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray,
 };
-use datafusion::arrow::buffer::Buffer;
+use datafusion::arrow::buffer::{Buffer, ScalarBuffer};
 use datafusion::arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
+use datafusion::arrow::error::ArrowError;
 use datafusion::common::ScalarValue;
 
-pub(crate) fn is_json_union(data_type: &DataType) -> bool {
+pub fn is_json_union(data_type: &DataType) -> bool {
     match data_type {
         DataType::Union(fields, UnionMode::Sparse) => fields == &union_fields(),
         _ => false,
@@ -64,7 +65,7 @@ impl JsonUnion {
             strings: vec![None; length],
             arrays: vec![None; length],
             objects: vec![None; length],
-            type_ids: vec![0; length],
+            type_ids: vec![TYPE_ID_NULL; length],
             index: 0,
             length,
         }
@@ -114,7 +115,7 @@ impl FromIterator<Option<JsonUnionField>> for JsonUnion {
 }
 
 impl TryFrom<JsonUnion> for UnionArray {
-    type Error = datafusion::arrow::error::ArrowError;
+    type Error = ArrowError;
 
     fn try_from(value: JsonUnion) -> Result<Self, Self::Error> {
         let children: Vec<Arc<dyn Array>> = vec![
@@ -199,3 +200,128 @@ impl From<JsonUnionField> for ScalarValue {
         }
     }
 }
+
+/// Read-only accessor over the children of a JSON union array.
+///
+/// Built with [`JsonUnionEncoder::from_union`]; individual rows are read with
+/// [`JsonUnionEncoder::get_value`].
+pub struct JsonUnionEncoder {
+    boolean: BooleanArray,
+    int: Int64Array,
+    float: Float64Array,
+    string: StringArray,
+    array: StringArray,
+    object: StringArray,
+    type_ids: ScalarBuffer<i8>,
+}
+
+impl JsonUnionEncoder {
+    /// Take a union array apart into an encoder.
+    ///
+    /// Returns `None` when `union` does not have the JSON union data type
+    /// (as decided by [`is_json_union`]).
+    #[must_use]
+    pub fn from_union(union: UnionArray) -> Option<Self> {
+        if is_json_union(union.data_type()) {
+            // Child 0 is never read by `get_value`, so it is not kept.
+            let (_, type_ids, _, c) = union.into_parts();
+            Some(Self {
+                boolean: c[1].as_boolean().clone(),
+                int: c[2].as_primitive().clone(),
+                float: c[3].as_primitive().clone(),
+                string: c[4].as_string().clone(),
+                array: c[5].as_string().clone(),
+                object: c[6].as_string().clone(),
+                type_ids,
+            })
+        } else {
+            None
+        }
+    }
+
+    /// Number of rows in the union (the length of the type id buffer).
+    #[must_use]
+    pub fn len(&self) -> usize {
+        self.type_ids.len()
+    }
+
+    /// Whether the union has no rows.
+    #[must_use]
+    pub fn is_empty(&self) -> bool {
+        self.type_ids.is_empty()
+    }
+
+    /// Get the encodable value for a given index
+    ///
+    /// # Panics
+    ///
+    /// Panics if the idx is outside the union values or an invalid type id exists in the union.
+    #[must_use]
+    pub fn get_value(&self, idx: usize) -> JsonUnionValue {
+        let type_id = self.type_ids[idx];
+        match type_id {
+            TYPE_ID_NULL => JsonUnionValue::JsonNull,
+            TYPE_ID_BOOL => JsonUnionValue::Bool(self.boolean.value(idx)),
+            TYPE_ID_INT => JsonUnionValue::Int(self.int.value(idx)),
+            TYPE_ID_FLOAT => JsonUnionValue::Float(self.float.value(idx)),
+            TYPE_ID_STR => JsonUnionValue::Str(self.string.value(idx)),
+            TYPE_ID_ARRAY => JsonUnionValue::Array(self.array.value(idx)),
+            TYPE_ID_OBJECT => JsonUnionValue::Object(self.object.value(idx)),
+            _ => panic!("Invalid type_id: {type_id}, not a valid JSON type"),
+        }
+    }
+}
+
+/// A single value read out of a JSON union by [`JsonUnionEncoder::get_value`].
+///
+/// String, array and object payloads borrow from the encoder's arrays.
+#[derive(Debug, PartialEq)]
+pub enum JsonUnionValue<'a> {
+    JsonNull,
+    Bool(bool),
+    Int(i64),
+    Float(f64),
+    Str(&'a str),
+    Array(&'a str),
+    Object(&'a str),
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_json_union() {
+        let json_union = JsonUnion::from_iter(vec![
+            Some(JsonUnionField::JsonNull),
+            Some(JsonUnionField::Bool(true)),
+            Some(JsonUnionField::Bool(false)),
+            Some(JsonUnionField::Int(42)),
+            Some(JsonUnionField::Float(42.0)),
+            Some(JsonUnionField::Str("foo".to_string())),
+            Some(JsonUnionField::Array("[42]".to_string())),
+            Some(JsonUnionField::Object(r#"{"foo": 42}"#.to_string())),
+            None,
+        ]);
+
+        let union_array = UnionArray::try_from(json_union).unwrap();
+        let encoder = JsonUnionEncoder::from_union(union_array).unwrap();
+        assert!(!encoder.is_empty());
+
+        let values_after: Vec<_> = (0..encoder.len()).map(|idx| encoder.get_value(idx)).collect();
+        assert_eq!(
+            values_after,
+            vec![
+                JsonUnionValue::JsonNull,
+                JsonUnionValue::Bool(true),
+                JsonUnionValue::Bool(false),
+                JsonUnionValue::Int(42),
+                JsonUnionValue::Float(42.0),
+                JsonUnionValue::Str("foo"),
+                JsonUnionValue::Array("[42]"),
+                JsonUnionValue::Object(r#"{"foo": 42}"#),
+                JsonUnionValue::JsonNull,
+            ]
+        );
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 75b18f6..692478e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -19,6 +19,8 @@ mod json_get_str;
 mod json_length;
 mod rewrite;
 
+pub use common_union::{JsonUnionEncoder, JsonUnionValue};
+
 pub mod functions {
     pub use crate::json_as_text::json_as_text;
     pub use crate::json_contains::json_contains;