Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Add AlignedBytes types #19308

Merged
merged 3 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/polars-arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ description = "Minimal implementation of the Arrow specification forked from arr

[dependencies]
atoi = { workspace = true, optional = true }
bytemuck = { workspace = true }
bytemuck = { workspace = true, features = ["must_cast"] }
chrono = { workspace = true }
# for timezone support
chrono-tz = { workspace = true, optional = true }
Expand Down
4 changes: 3 additions & 1 deletion crates/polars-arrow/src/array/binview/view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use polars_utils::total_ord::{TotalEq, TotalOrd};

use crate::buffer::Buffer;
use crate::datatypes::PrimitiveType;
use crate::types::NativeType;
use crate::types::{Bytes16Alignment4, NativeType};

// We use this instead of u128 because we want alignment of <= 8 bytes.
/// A reference to a set of bytes.
Expand Down Expand Up @@ -346,7 +346,9 @@ impl MinMax for View {

impl NativeType for View {
const PRIMITIVE: PrimitiveType = PrimitiveType::UInt128;

type Bytes = [u8; 16];
type AlignedBytes = Bytes16Alignment4;

#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
Expand Down
112 changes: 112 additions & 0 deletions crates/polars-arrow/src/types/aligned_bytes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
use bytemuck::{Pod, Zeroable};

use super::{days_ms, f16, i256, months_days_ns};
use crate::array::View;

/// Marker trait stating that `Self` is layout-compatible with the aligned byte type `B`:
/// both have the same size and the same alignment.
///
/// [`AlignedBytes::cast_vec_ref_mut`] relies on this marker to reinterpret a `Vec<Self>` as a
/// `Vec<B>` without copying.
///
/// # Safety
///
/// Implementors must guarantee that `size_of::<Self>() == size_of::<B>()` and
/// `align_of::<Self>() == align_of::<B>()`.
pub unsafe trait AlignedBytesCast<B>: Pod
where
    B: AlignedBytes,
{
}

/// A representation of a type as raw bytes with the same alignment as the original type.
///
/// Implementors expose their size and alignment as associated constants and can be converted
/// to and from [`AlignedBytes::Unaligned`].
pub trait AlignedBytes: Pod + Zeroable + Copy + Default + Eq {
    /// Alignment of the implementing type.
    const ALIGNMENT: usize;
    /// Size of the implementing type.
    const SIZE: usize;

    /// Byte container used by [`AlignedBytes::to_unaligned`] and
    /// [`AlignedBytes::from_unaligned`].
    type Unaligned: AsRef<[u8]>
        + AsMut<[u8]>
        + std::ops::Index<usize, Output = u8>
        + std::ops::IndexMut<usize, Output = u8>
        + for<'a> TryFrom<&'a [u8]>
        + std::fmt::Debug
        + Default
        + IntoIterator<Item = u8>
        + Pod;

    /// Return the bytes of `self` as [`AlignedBytes::Unaligned`].
    fn to_unaligned(&self) -> Self::Unaligned;

    /// Build `Self` from its [`AlignedBytes::Unaligned`] representation.
    fn from_unaligned(unaligned: Self::Unaligned) -> Self;

    /// Reinterpret a mutable reference to a [`Vec`] of `T` as a mutable reference to a [`Vec`]
    /// of `Self`, without copying.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics if `T` and `Self` differ in size or
    /// alignment, i.e. if the `AlignedBytesCast` implementation for `T` is wrong.
    fn cast_vec_ref_mut<T: AlignedBytesCast<Self>>(vec: &mut Vec<T>) -> &mut Vec<Self> {
        // Debug-only sanity check of the contract promised by the `unsafe impl`.
        debug_assert_eq!(size_of::<T>(), size_of::<Self>());
        debug_assert_eq!(align_of::<T>(), align_of::<Self>());

        // SAFETY: `AlignedBytesCast<Self>` guarantees that `T`:
        // 1. has the same size as `Self`
        // 2. has the same alignment as `Self`
        // 3. is `Pod` (therefore has no life-time issues)
        //
        // NOTE(review): this additionally relies on `Vec<T>` and `Vec<Self>` sharing the same
        // in-memory layout, which the standard library does not formally guarantee.
        unsafe { std::mem::transmute::<&mut Vec<T>, &mut Vec<Self>>(vec) }
    }
}

// Generates one aligned byte-array newtype per `(Name, size, alignment, [types...])` entry.
//
// For every entry this emits:
// * `pub struct Name([u8; size])` with `#[repr(C, align(alignment))]`,
// * the `AlignedBytes` implementation for `Name`,
// * `AsRef<[u8; size]>` for `Name`,
// * for every listed type: `From` conversions in both directions (via `bytemuck::must_cast`)
//   and the `AlignedBytesCast<Name>` marker.
macro_rules! impl_aligned_bytes {
    (
        $(($bytes:ident, $len:literal, $align:literal, [$($native:ty),*]),)+
    ) => {
        $(
            /// A fixed number of bytes with a fixed alignment.
            ///
            /// Routines that only depend on the size and alignment of their element type can be
            /// written against this type, which reduces the number of monomorphizations.
            #[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Pod, Zeroable)]
            #[repr(C, align($align))]
            pub struct $bytes([u8; $len]);

            impl AlignedBytes for $bytes {
                const ALIGNMENT: usize = $align;
                const SIZE: usize = $len;

                type Unaligned = [u8; $len];

                #[inline(always)]
                fn to_unaligned(&self) -> Self::Unaligned {
                    let Self(raw) = *self;
                    raw
                }

                #[inline(always)]
                fn from_unaligned(unaligned: Self::Unaligned) -> Self {
                    $bytes(unaligned)
                }
            }

            impl AsRef<[u8; $len]> for $bytes {
                #[inline(always)]
                fn as_ref(&self) -> &[u8; $len] {
                    let Self(raw) = self;
                    raw
                }
            }

            $(
                // `$native` -> bytes.
                impl From<$native> for $bytes {
                    #[inline(always)]
                    fn from(value: $native) -> Self {
                        bytemuck::must_cast::<$native, $bytes>(value)
                    }
                }

                // bytes -> `$native`.
                impl From<$bytes> for $native {
                    #[inline(always)]
                    fn from(value: $bytes) -> Self {
                        bytemuck::must_cast::<$bytes, $native>(value)
                    }
                }

                unsafe impl AlignedBytesCast<$bytes> for $native {}
            )*
        )+
    }
}

// Instantiate the aligned byte types.
//
// Each entry reads `(Name, size, alignment, [types])`: `Name` becomes a struct wrapping
// `[u8; size]` with `#[repr(C, align(alignment))]`, and every type in the list gets `From`
// conversions to/from `Name` plus an `AlignedBytesCast<Name>` implementation.
//
// Entries with equal size but different alignment (e.g. 8/8 vs. 8/4) are distinct types.
impl_aligned_bytes! {
(Bytes1Alignment1, 1, 1, [u8, i8]),
(Bytes2Alignment2, 2, 2, [u16, i16, f16]),
(Bytes4Alignment4, 4, 4, [u32, i32, f32]),
(Bytes8Alignment8, 8, 8, [u64, i64, f64]),
(Bytes8Alignment4, 8, 4, [days_ms]),
(Bytes12Alignment4, 12, 4, [[u32; 3]]),
(Bytes16Alignment4, 16, 4, [View]),
(Bytes16Alignment8, 16, 8, [months_days_ns]),
(Bytes16Alignment16, 16, 16, [u128, i128]),
(Bytes32Alignment16, 32, 16, [i256]),
}
2 changes: 2 additions & 0 deletions crates/polars-arrow/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
//! Finally, this module contains traits used to compile code based on [`NativeType`] optimized
//! for SIMD, at [`mod@simd`].

mod aligned_bytes;
pub use aligned_bytes::*;
mod bit_chunk;
pub use bit_chunk::{BitChunk, BitChunkIter, BitChunkOnes};
mod index;
Expand Down
47 changes: 33 additions & 14 deletions crates/polars-arrow/src/types/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use polars_utils::min_max::MinMax;
use polars_utils::nulls::IsNull;
use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash, TotalOrd, TotalOrdWrap};

use super::aligned_bytes::*;
use super::PrimitiveType;

/// Sealed trait implemented by all physical types that can be allocated,
Expand All @@ -27,6 +28,7 @@ pub trait NativeType:
+ TotalOrd
+ IsNull
+ MinMax
+ AlignedBytesCast<Self::AlignedBytes>
{
/// The corresponding variant of [`PrimitiveType`].
const PRIMITIVE: PrimitiveType;
Expand All @@ -42,6 +44,11 @@ pub trait NativeType:
+ Default
+ IntoIterator<Item = u8>;

/// Type denoting its representation as aligned bytes.
///
/// This is `[u8; N]` where `N = size_of::<Self>` and has alignment `align_of::<Self>`.
type AlignedBytes: AlignedBytes<Unaligned = Self::Bytes> + From<Self> + Into<Self>;

/// To bytes in little endian
fn to_le_bytes(&self) -> Self::Bytes;

Expand All @@ -56,11 +63,13 @@ pub trait NativeType:
}

macro_rules! native_type {
($type:ty, $primitive_type:expr) => {
($type:ty, $aligned:ty, $primitive_type:expr) => {
impl NativeType for $type {
const PRIMITIVE: PrimitiveType = $primitive_type;

type Bytes = [u8; size_of::<Self>()];
type Bytes = [u8; std::mem::size_of::<Self>()];
type AlignedBytes = $aligned;

#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
Self::to_le_bytes(*self)
Expand All @@ -84,18 +93,18 @@ macro_rules! native_type {
};
}

native_type!(u8, PrimitiveType::UInt8);
native_type!(u16, PrimitiveType::UInt16);
native_type!(u32, PrimitiveType::UInt32);
native_type!(u64, PrimitiveType::UInt64);
native_type!(i8, PrimitiveType::Int8);
native_type!(i16, PrimitiveType::Int16);
native_type!(i32, PrimitiveType::Int32);
native_type!(i64, PrimitiveType::Int64);
native_type!(f32, PrimitiveType::Float32);
native_type!(f64, PrimitiveType::Float64);
native_type!(i128, PrimitiveType::Int128);
native_type!(u128, PrimitiveType::UInt128);
native_type!(u8, Bytes1Alignment1, PrimitiveType::UInt8);
native_type!(u16, Bytes2Alignment2, PrimitiveType::UInt16);
native_type!(u32, Bytes4Alignment4, PrimitiveType::UInt32);
native_type!(u64, Bytes8Alignment8, PrimitiveType::UInt64);
native_type!(i8, Bytes1Alignment1, PrimitiveType::Int8);
native_type!(i16, Bytes2Alignment2, PrimitiveType::Int16);
native_type!(i32, Bytes4Alignment4, PrimitiveType::Int32);
native_type!(i64, Bytes8Alignment8, PrimitiveType::Int64);
native_type!(f32, Bytes4Alignment4, PrimitiveType::Float32);
native_type!(f64, Bytes8Alignment8, PrimitiveType::Float64);
native_type!(i128, Bytes16Alignment16, PrimitiveType::Int128);
native_type!(u128, Bytes16Alignment16, PrimitiveType::UInt128);

/// The in-memory representation of the DayMillisecond variant of arrow's "Interval" logical type.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Zeroable, Pod)]
Expand Down Expand Up @@ -151,7 +160,10 @@ impl MinMax for days_ms {

impl NativeType for days_ms {
const PRIMITIVE: PrimitiveType = PrimitiveType::DaysMs;

type Bytes = [u8; 8];
type AlignedBytes = Bytes8Alignment4;

#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
let days = self.0.to_le_bytes();
Expand Down Expand Up @@ -289,7 +301,10 @@ impl MinMax for months_days_ns {

impl NativeType for months_days_ns {
const PRIMITIVE: PrimitiveType = PrimitiveType::MonthDayNano;

type Bytes = [u8; 16];
type AlignedBytes = Bytes16Alignment8;

#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
let months = self.months().to_le_bytes();
Expand Down Expand Up @@ -658,7 +673,10 @@ impl MinMax for f16 {

impl NativeType for f16 {
const PRIMITIVE: PrimitiveType = PrimitiveType::Float16;

type Bytes = [u8; 2];
type AlignedBytes = Bytes2Alignment2;

#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
self.0.to_le_bytes()
Expand Down Expand Up @@ -758,6 +776,7 @@ impl NativeType for i256 {
const PRIMITIVE: PrimitiveType = PrimitiveType::Int256;

type Bytes = [u8; 32];
type AlignedBytes = Bytes32Alignment16;

#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
Expand Down
Loading