Skip to content

Commit

Permalink
refactor: Add AlignedBytes types (#19308)
Browse files Browse the repository at this point in the history
Co-authored-by: Orson Peters <[email protected]>
  • Loading branch information
coastalwhite and orlp authored Oct 23, 2024
1 parent 7aefcc8 commit fc8eec2
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 16 deletions.
2 changes: 1 addition & 1 deletion crates/polars-arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ description = "Minimal implementation of the Arrow specification forked from arr

[dependencies]
atoi = { workspace = true, optional = true }
bytemuck = { workspace = true }
bytemuck = { workspace = true, features = ["must_cast"] }
chrono = { workspace = true }
# for timezone support
chrono-tz = { workspace = true, optional = true }
Expand Down
4 changes: 3 additions & 1 deletion crates/polars-arrow/src/array/binview/view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use polars_utils::total_ord::{TotalEq, TotalOrd};

use crate::buffer::Buffer;
use crate::datatypes::PrimitiveType;
use crate::types::NativeType;
use crate::types::{Bytes16Alignment4, NativeType};

// We use this instead of u128 because we want alignment of <= 8 bytes.
/// A reference to a set of bytes.
Expand Down Expand Up @@ -346,7 +346,9 @@ impl MinMax for View {

impl NativeType for View {
const PRIMITIVE: PrimitiveType = PrimitiveType::UInt128;

type Bytes = [u8; 16];
type AlignedBytes = Bytes16Alignment4;

#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
Expand Down
112 changes: 112 additions & 0 deletions crates/polars-arrow/src/types/aligned_bytes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
use bytemuck::{Pod, Zeroable};

use super::{days_ms, f16, i256, months_days_ns};
use crate::array::View;

/// Marker trait stating that the implementing type has exactly the size and alignment of `B`.
///
/// # Safety
///
/// Implementors must guarantee that `size_of::<Self>() == size_of::<B>()` and that
/// `align_of::<Self>() == align_of::<B>()`.
pub unsafe trait AlignedBytesCast<B>: Pod
where
    B: AlignedBytes,
{
}

/// A representation of a type as raw bytes with the same alignment as the original type.
pub trait AlignedBytes: Pod + Zeroable + Copy + Default + Eq {
    /// Alignment of the implementing type, in bytes.
    const ALIGNMENT: usize;
    /// Size of the implementing type, in bytes.
    const SIZE: usize;

    /// The same bytes held in a plain byte container.
    ///
    /// The implementations generated by `impl_aligned_bytes!` use `[u8; SIZE]` here.
    type Unaligned: AsRef<[u8]>
        + AsMut<[u8]>
        + std::ops::Index<usize, Output = u8>
        + std::ops::IndexMut<usize, Output = u8>
        + for<'a> TryFrom<&'a [u8]>
        + std::fmt::Debug
        + Default
        + IntoIterator<Item = u8>
        + Pod;

    /// Return the bytes of `self` as [`Self::Unaligned`].
    fn to_unaligned(&self) -> Self::Unaligned;

    /// Build `Self` from its [`Self::Unaligned`] representation.
    fn from_unaligned(unaligned: Self::Unaligned) -> Self;

    /// Reinterpret a mutable reference to a [`Vec`] of `T` as a mutable reference to a [`Vec`]
    /// of `Self`.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics if `T` and `Self` differ in size or
    /// alignment. That can only happen if an `AlignedBytesCast` implementation is incorrect.
    fn cast_vec_ref_mut<T: AlignedBytesCast<Self>>(vec: &mut Vec<T>) -> &mut Vec<Self> {
        // Fully qualified so that this does not depend on `size_of` / `align_of` being in the
        // prelude.
        debug_assert_eq!(std::mem::size_of::<T>(), std::mem::size_of::<Self>());
        debug_assert_eq!(std::mem::align_of::<T>(), std::mem::align_of::<Self>());

        // SAFETY: `AlignedBytesCast<Self>` guarantees that T:
        // 1. has the same size
        // 2. has the same alignment
        // 3. is Pod (therefore has no life-time issues)
        //
        // NOTE(review): this additionally relies on `Vec<T>` and `Vec<Self>` having the same
        // in-memory layout, which the standard library does not formally guarantee - confirm
        // this is an accepted assumption.
        unsafe { std::mem::transmute::<&mut Vec<T>, &mut Vec<Self>>(vec) }
    }
}

// Generates one aligned byte-array type per entry.
//
// Each entry has the form `(Name, size, alignment, [Type, ...])`. For every entry the macro
// emits the struct `Name`, its `AlignedBytes` and `AsRef<[u8; size]>` impls, and for every listed
// `Type` the `From` conversions in both directions plus an `AlignedBytesCast<Name>` impl.
macro_rules! impl_aligned_bytes {
    (
        $(($bytes:ident, $len:literal, $align:literal, [$($native:ty),*]),)+
    ) => {
        $(
            /// A fixed-size byte array with an explicit alignment.
            ///
            /// Routines that only depend on the size and alignment of their element type can be
            /// written against this type, which reduces the number of monomorphizations.
            #[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Pod, Zeroable)]
            #[repr(C, align($align))]
            pub struct $bytes([u8; $len]);

            impl AsRef<[u8; $len]> for $bytes {
                #[inline(always)]
                fn as_ref(&self) -> &[u8; $len] {
                    let Self(raw) = self;
                    raw
                }
            }

            impl AlignedBytes for $bytes {
                type Unaligned = [u8; $len];

                const SIZE: usize = $len;
                const ALIGNMENT: usize = $align;

                #[inline(always)]
                fn from_unaligned(unaligned: Self::Unaligned) -> Self {
                    Self(unaligned)
                }

                #[inline(always)]
                fn to_unaligned(&self) -> Self::Unaligned {
                    let Self(raw) = *self;
                    raw
                }
            }

            $(
                // The macro invocation declares `$native` to have the size and alignment of
                // `$bytes`; that declaration is what this impl relies on.
                unsafe impl AlignedBytesCast<$bytes> for $native {}

                impl From<$bytes> for $native {
                    #[inline(always)]
                    fn from(value: $bytes) -> Self {
                        bytemuck::must_cast::<$bytes, $native>(value)
                    }
                }

                impl From<$native> for $bytes {
                    #[inline(always)]
                    fn from(value: $native) -> Self {
                        bytemuck::must_cast::<$native, $bytes>(value)
                    }
                }
            )*
        )+
    }
}

// Table of the aligned byte types generated by `impl_aligned_bytes!`.
//
// Each entry reads `(type name, size in bytes, alignment in bytes, [types sharing that layout])`.
// Every type in the bracketed list gets `From` conversions to and from the byte type and an
// `AlignedBytesCast` impl for it, so the listed size and alignment must match that type exactly.
impl_aligned_bytes! {
(Bytes1Alignment1, 1, 1, [u8, i8]),
(Bytes2Alignment2, 2, 2, [u16, i16, f16]),
(Bytes4Alignment4, 4, 4, [u32, i32, f32]),
(Bytes8Alignment8, 8, 8, [u64, i64, f64]),
(Bytes8Alignment4, 8, 4, [days_ms]),
(Bytes12Alignment4, 12, 4, [[u32; 3]]),
(Bytes16Alignment4, 16, 4, [View]),
(Bytes16Alignment8, 16, 8, [months_days_ns]),
(Bytes16Alignment16, 16, 16, [u128, i128]),
(Bytes32Alignment16, 32, 16, [i256]),
}
2 changes: 2 additions & 0 deletions crates/polars-arrow/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
//! Finally, this module contains traits used to compile code based on [`NativeType`] optimized
//! for SIMD, at [`mod@simd`].

mod aligned_bytes;
pub use aligned_bytes::*;
mod bit_chunk;
pub use bit_chunk::{BitChunk, BitChunkIter, BitChunkOnes};
mod index;
Expand Down
47 changes: 33 additions & 14 deletions crates/polars-arrow/src/types/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use polars_utils::min_max::MinMax;
use polars_utils::nulls::IsNull;
use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash, TotalOrd, TotalOrdWrap};

use super::aligned_bytes::*;
use super::PrimitiveType;

/// Sealed trait implemented by all physical types that can be allocated,
Expand All @@ -27,6 +28,7 @@ pub trait NativeType:
+ TotalOrd
+ IsNull
+ MinMax
+ AlignedBytesCast<Self::AlignedBytes>
{
/// The corresponding variant of [`PrimitiveType`].
const PRIMITIVE: PrimitiveType;
Expand All @@ -42,6 +44,11 @@ pub trait NativeType:
+ Default
+ IntoIterator<Item = u8>;

/// Type denoting its representation as aligned bytes.
///
/// This is a wrapper around `[u8; N]`, where `N = size_of::<Self>()`, with alignment `align_of::<Self>()`.
type AlignedBytes: AlignedBytes<Unaligned = Self::Bytes> + From<Self> + Into<Self>;

/// To bytes in little endian
fn to_le_bytes(&self) -> Self::Bytes;

Expand All @@ -56,11 +63,13 @@ pub trait NativeType:
}

macro_rules! native_type {
($type:ty, $primitive_type:expr) => {
($type:ty, $aligned:ty, $primitive_type:expr) => {
impl NativeType for $type {
const PRIMITIVE: PrimitiveType = $primitive_type;

type Bytes = [u8; size_of::<Self>()];
type Bytes = [u8; std::mem::size_of::<Self>()];
type AlignedBytes = $aligned;

#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
Self::to_le_bytes(*self)
Expand All @@ -84,18 +93,18 @@ macro_rules! native_type {
};
}

native_type!(u8, PrimitiveType::UInt8);
native_type!(u16, PrimitiveType::UInt16);
native_type!(u32, PrimitiveType::UInt32);
native_type!(u64, PrimitiveType::UInt64);
native_type!(i8, PrimitiveType::Int8);
native_type!(i16, PrimitiveType::Int16);
native_type!(i32, PrimitiveType::Int32);
native_type!(i64, PrimitiveType::Int64);
native_type!(f32, PrimitiveType::Float32);
native_type!(f64, PrimitiveType::Float64);
native_type!(i128, PrimitiveType::Int128);
native_type!(u128, PrimitiveType::UInt128);
native_type!(u8, Bytes1Alignment1, PrimitiveType::UInt8);
native_type!(u16, Bytes2Alignment2, PrimitiveType::UInt16);
native_type!(u32, Bytes4Alignment4, PrimitiveType::UInt32);
native_type!(u64, Bytes8Alignment8, PrimitiveType::UInt64);
native_type!(i8, Bytes1Alignment1, PrimitiveType::Int8);
native_type!(i16, Bytes2Alignment2, PrimitiveType::Int16);
native_type!(i32, Bytes4Alignment4, PrimitiveType::Int32);
native_type!(i64, Bytes8Alignment8, PrimitiveType::Int64);
native_type!(f32, Bytes4Alignment4, PrimitiveType::Float32);
native_type!(f64, Bytes8Alignment8, PrimitiveType::Float64);
native_type!(i128, Bytes16Alignment16, PrimitiveType::Int128);
native_type!(u128, Bytes16Alignment16, PrimitiveType::UInt128);

/// The in-memory representation of the DayMillisecond variant of arrow's "Interval" logical type.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Zeroable, Pod)]
Expand Down Expand Up @@ -151,7 +160,10 @@ impl MinMax for days_ms {

impl NativeType for days_ms {
const PRIMITIVE: PrimitiveType = PrimitiveType::DaysMs;

type Bytes = [u8; 8];
type AlignedBytes = Bytes8Alignment4;

#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
let days = self.0.to_le_bytes();
Expand Down Expand Up @@ -289,7 +301,10 @@ impl MinMax for months_days_ns {

impl NativeType for months_days_ns {
const PRIMITIVE: PrimitiveType = PrimitiveType::MonthDayNano;

type Bytes = [u8; 16];
type AlignedBytes = Bytes16Alignment8;

#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
let months = self.months().to_le_bytes();
Expand Down Expand Up @@ -658,7 +673,10 @@ impl MinMax for f16 {

impl NativeType for f16 {
const PRIMITIVE: PrimitiveType = PrimitiveType::Float16;

type Bytes = [u8; 2];
type AlignedBytes = Bytes2Alignment2;

#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
self.0.to_le_bytes()
Expand Down Expand Up @@ -758,6 +776,7 @@ impl NativeType for i256 {
const PRIMITIVE: PrimitiveType = PrimitiveType::Int256;

type Bytes = [u8; 32];
type AlignedBytes = Bytes32Alignment16;

#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
Expand Down

0 comments on commit fc8eec2

Please sign in to comment.