Skip to content

Commit

Permalink
Union encoding (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Oct 26, 2024
1 parent 21554f2 commit f650af0
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 5 deletions.
117 changes: 112 additions & 5 deletions src/common_union.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::sync::{Arc, OnceLock};

use datafusion::arrow::array::{
Array, ArrayRef, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray,
Array, ArrayRef, AsArray, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray,
};
use datafusion::arrow::buffer::Buffer;
use datafusion::arrow::buffer::{Buffer, ScalarBuffer};
use datafusion::arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
use datafusion::arrow::error::ArrowError;
use datafusion::common::ScalarValue;

pub(crate) fn is_json_union(data_type: &DataType) -> bool {
pub fn is_json_union(data_type: &DataType) -> bool {
match data_type {
DataType::Union(fields, UnionMode::Sparse) => fields == &union_fields(),
_ => false,
Expand Down Expand Up @@ -64,7 +65,7 @@ impl JsonUnion {
strings: vec![None; length],
arrays: vec![None; length],
objects: vec![None; length],
type_ids: vec![0; length],
type_ids: vec![TYPE_ID_NULL; length],
index: 0,
length,
}
Expand Down Expand Up @@ -114,7 +115,7 @@ impl FromIterator<Option<JsonUnionField>> for JsonUnion {
}

impl TryFrom<JsonUnion> for UnionArray {
type Error = datafusion::arrow::error::ArrowError;
type Error = ArrowError;

fn try_from(value: JsonUnion) -> Result<Self, Self::Error> {
let children: Vec<Arc<dyn Array>> = vec![
Expand Down Expand Up @@ -199,3 +200,109 @@ impl From<JsonUnionField> for ScalarValue {
}
}
}

pub struct JsonUnionEncoder {
boolean: BooleanArray,
int: Int64Array,
float: Float64Array,
string: StringArray,
array: StringArray,
object: StringArray,
type_ids: ScalarBuffer<i8>,
}

impl JsonUnionEncoder {
#[must_use]
pub fn from_union(union: UnionArray) -> Option<Self> {
if is_json_union(union.data_type()) {
let (_, type_ids, _, c) = union.into_parts();
Some(Self {
boolean: c[1].as_boolean().clone(),
int: c[2].as_primitive().clone(),
float: c[3].as_primitive().clone(),
string: c[4].as_string().clone(),
array: c[5].as_string().clone(),
object: c[6].as_string().clone(),
type_ids,
})
} else {
None
}
}

#[must_use]
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
self.type_ids.len()
}

/// Get the encodable value for a given index
///
/// # Panics
///
/// Panics if the idx is outside the union values or an invalid type id exists in the union.
#[must_use]
pub fn get_value(&self, idx: usize) -> JsonUnionValue {
let type_id = self.type_ids[idx];
match type_id {
TYPE_ID_NULL => JsonUnionValue::JsonNull,
TYPE_ID_BOOL => JsonUnionValue::Bool(self.boolean.value(idx)),
TYPE_ID_INT => JsonUnionValue::Int(self.int.value(idx)),
TYPE_ID_FLOAT => JsonUnionValue::Float(self.float.value(idx)),
TYPE_ID_STR => JsonUnionValue::Str(self.string.value(idx)),
TYPE_ID_ARRAY => JsonUnionValue::Array(self.array.value(idx)),
TYPE_ID_OBJECT => JsonUnionValue::Object(self.object.value(idx)),
_ => panic!("Invalid type_id: {type_id}, not a valid JSON type"),
}
}
}

#[derive(Debug, PartialEq)]
pub enum JsonUnionValue<'a> {
JsonNull,
Bool(bool),
Int(i64),
Float(f64),
Str(&'a str),
Array(&'a str),
Object(&'a str),
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_json_union() {
let json_union = JsonUnion::from_iter(vec![
Some(JsonUnionField::JsonNull),
Some(JsonUnionField::Bool(true)),
Some(JsonUnionField::Bool(false)),
Some(JsonUnionField::Int(42)),
Some(JsonUnionField::Float(42.0)),
Some(JsonUnionField::Str("foo".to_string())),
Some(JsonUnionField::Array("[42]".to_string())),
Some(JsonUnionField::Object(r#"{"foo": 42}"#.to_string())),
None,
]);

let union_array = UnionArray::try_from(json_union).unwrap();
let encoder = JsonUnionEncoder::from_union(union_array).unwrap();

let values_after: Vec<_> = (0..encoder.len()).map(|idx| encoder.get_value(idx)).collect();
assert_eq!(
values_after,
vec![
JsonUnionValue::JsonNull,
JsonUnionValue::Bool(true),
JsonUnionValue::Bool(false),
JsonUnionValue::Int(42),
JsonUnionValue::Float(42.0),
JsonUnionValue::Str("foo"),
JsonUnionValue::Array("[42]"),
JsonUnionValue::Object(r#"{"foo": 42}"#),
JsonUnionValue::JsonNull,
]
);
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ mod json_get_str;
mod json_length;
mod rewrite;

pub use common_union::{JsonUnionEncoder, JsonUnionValue};

pub mod functions {
pub use crate::json_as_text::json_as_text;
pub use crate::json_contains::json_contains;
Expand Down

0 comments on commit f650af0

Please sign in to comment.