Skip to content

Commit

Permalink
add fallback impls for some extension operations
Browse files Browse the repository at this point in the history
  • Loading branch information
a10y committed Oct 9, 2024
1 parent 44b49f6 commit a0f0ba1
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 36 deletions.
6 changes: 1 addition & 5 deletions vortex-array/src/array/chunked/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,7 @@ pub(crate) fn try_canonicalize_chunks(
// ExtensionArray, so we should canonicalize each chunk into ExtensionArray first.
.map(|chunk| chunk.clone().into_extension().map(|ext| ext.storage()))
.collect::<VortexResult<Vec<Array>>>()?;
let storage_dtype = storage_chunks
.first()
.ok_or_else(|| vortex_err!("Expected at least one chunk in ChunkedArray"))?
.dtype()
.clone();
let storage_dtype = ext_dtype.scalars_dtype().clone();
let chunked_storage =
ChunkedArray::try_new(storage_chunks, storage_dtype)?.into_array();

Expand Down
10 changes: 7 additions & 3 deletions vortex-array/src/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,14 @@ impl Canonical {
Canonical::Struct(a) => struct_to_arrow(a)?,
Canonical::VarBin(a) => varbin_to_arrow(a)?,
Canonical::Extension(a) => {
if !is_temporal_ext_type(a.id()) {
vortex_bail!("unsupported extension dtype with ID {}", a.id().as_ref())
if is_temporal_ext_type(a.id()) {
temporal_to_arrow(TemporalArray::try_from(&a.into_array())?)?
} else {
// Convert storage array directly into arrow.
// NOTE: this loses the extension type information and we lose the ability to
// round-trip back to Vortex.
a.storage().into_canonical()?.into_arrow()?
}
temporal_to_arrow(TemporalArray::try_from(&a.into_array())?)?
}
})
}
Expand Down
21 changes: 6 additions & 15 deletions vortex-datafusion/src/datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,14 @@ pub(crate) fn infer_data_type(dtype: &DType) -> DataType {
)))
}
DType::Extension(ext_dtype) => {
// Try and match against the known extension DTypes.
// Special case: the Vortex logical type system represents many temporal types from
// Arrow, and we want those to serialize properly.
if is_temporal_ext_type(ext_dtype.id()) {
make_arrow_temporal_dtype(ext_dtype)
} else {
// TODO(aduffy): allow extension type authors to plugin their own to/from Arrow
// conversions.
vortex_panic!("Unsupported extension type \"{}\"", ext_dtype.id())
// All other extension types, we rely on the scalar type to determine how it gets
// pushed to Arrow.
infer_data_type(ext_dtype.scalars_dtype())
}
}
}
Expand All @@ -109,7 +110,7 @@ mod test {

use arrow_schema::{DataType, Field, FieldRef, Fields, Schema};
use vortex_dtype::{
DType, ExtDType, ExtID, FieldName, FieldNames, Nullability, PType, StructDType,
DType, FieldName, FieldNames, Nullability, PType, StructDType,
};

use super::*;
Expand Down Expand Up @@ -165,16 +166,6 @@ mod test {
);
}

#[test]
#[should_panic]
fn test_dtype_conversion_panics() {
let _ = infer_data_type(&DType::Extension(ExtDType::new(
ExtID::from("my-fake-ext-dtype"),
Arc::new(PType::I32.into()),
None,
)));
}

#[test]
fn test_schema_conversion() {
let struct_dtype = the_struct();
Expand Down
5 changes: 5 additions & 0 deletions vortex-dtype/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ impl ExtDType {
/// }
/// ```
pub fn new(id: ExtID, scalars_dtype: Arc<DType>, metadata: Option<ExtMetadata>) -> Self {
assert!(
!matches!(scalars_dtype.as_ref(), &DType::Extension(_)),
"ExtDType cannot have Extension scalars_dtype"
);

Self {
id,
scalars_dtype,
Expand Down
27 changes: 14 additions & 13 deletions vortex-scalar/src/datafusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ impl TryFrom<Scalar> for ScalarValue {
type Error = VortexError;

fn try_from(value: Scalar) -> Result<Self, Self::Error> {
Ok(match value.dtype {
let (dtype, value) = value.into_parts();
Ok(match dtype {
DType::Null => ScalarValue::Null,
DType::Bool(_) => ScalarValue::Boolean(value.value.as_bool()?),
DType::Bool(_) => ScalarValue::Boolean(value.as_bool()?),
DType::Primitive(ptype, _) => {
let pvalue = value.value.as_pvalue()?;
let pvalue = value.as_pvalue()?;
match pvalue {
None => match ptype {
PType::U8 => ScalarValue::UInt8(None),
Expand Down Expand Up @@ -46,15 +47,11 @@ impl TryFrom<Scalar> for ScalarValue {
},
}
}
DType::Utf8(_) => ScalarValue::Utf8(
value
.value
.as_buffer_string()?
.map(|b| b.as_str().to_string()),
),
DType::Utf8(_) => {
ScalarValue::Utf8(value.as_buffer_string()?.map(|b| b.as_str().to_string()))
}
DType::Binary(_) => ScalarValue::Binary(
value
.value
.as_buffer()?
.map(|b| b.into_vec().unwrap_or_else(|buf| buf.as_slice().to_vec())),
),
Expand All @@ -65,9 +62,11 @@ impl TryFrom<Scalar> for ScalarValue {
todo!("list scalar conversion")
}
DType::Extension(ext) => {
// Special handling: temporal extension types in Vortex correspond to Arrow's
// temporal physical types.
if is_temporal_ext_type(ext.id()) {
let metadata = TemporalMetadata::try_from(&ext)?;
let pv = value.value.as_pvalue()?;
let pv = value.as_pvalue()?;
return Ok(match metadata {
TemporalMetadata::Time(u) => match u {
TimeUnit::Ns => {
Expand Down Expand Up @@ -111,9 +110,11 @@ impl TryFrom<Scalar> for ScalarValue {
}
},
});
} else {
// Unknown extension type: perform scalar conversion using the canonical
// scalar DType.
ScalarValue::try_from(Scalar::new(ext.scalars_dtype().clone(), value))?
}

todo!("Non temporal extension scalar conversion")
}
})
}
Expand Down
5 changes: 5 additions & 0 deletions vortex-scalar/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ impl Scalar {
&self.value
}

#[inline]
pub fn into_parts(self) -> (DType, ScalarValue) {
(self.dtype, self.value)
}

#[inline]
pub fn into_value(self) -> ScalarValue {
self.value
Expand Down

0 comments on commit a0f0ba1

Please sign in to comment.