Skip to content

Commit

Permalink
subgroup: make VectorOrScalar trait match discussions in EmbarkStudio…
Browse files Browse the repository at this point in the history
  • Loading branch information
Firestar99 committed Sep 23, 2024
1 parent 589af48 commit d56c0ed
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 112 deletions.
2 changes: 1 addition & 1 deletion crates/spirv-std/src/arch/subgroup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::float::Float;
use crate::integer::{Integer, SignedInteger, UnsignedInteger};
#[cfg(target_arch = "spirv")]
use crate::memory::{Scope, Semantics};
use crate::scalar::VectorOrScalar;
use crate::vector::VectorOrScalar;
#[cfg(target_arch = "spirv")]
use core::arch::asm;

Expand Down
4 changes: 3 additions & 1 deletion crates/spirv-std/src/float.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
//! Traits and helper functions related to floats.
use crate::scalar::VectorOrScalar;
use crate::vector::Vector;
use crate::vector::{create_dim, VectorOrScalar};
#[cfg(target_arch = "spirv")]
use core::arch::asm;
use core::num::NonZeroUsize;

/// Abstract trait representing a SPIR-V floating point type.
///
Expand Down Expand Up @@ -74,6 +75,7 @@ struct F32x2 {
}
unsafe impl VectorOrScalar for F32x2 {
type Scalar = f32;
const DIM: NonZeroUsize = create_dim(2);
}
unsafe impl Vector<f32, 2> for F32x2 {}

Expand Down
69 changes: 15 additions & 54 deletions crates/spirv-std/src/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,48 +1,7 @@
//! Traits related to scalars.
/// Abstract trait representing either a vector or a scalar type.
///
/// # Safety
/// Implementing this trait on non-scalar or non-vector types may break assumptions about other
/// unsafe code, and should not be done.
pub unsafe trait VectorOrScalar: Default {
/// Either the scalar component type of the vector or the scalar itself.
type Scalar: Scalar;
}

unsafe impl VectorOrScalar for bool {
type Scalar = bool;
}
unsafe impl VectorOrScalar for f32 {
type Scalar = f32;
}
unsafe impl VectorOrScalar for f64 {
type Scalar = f64;
}
unsafe impl VectorOrScalar for u8 {
type Scalar = u8;
}
unsafe impl VectorOrScalar for u16 {
type Scalar = u16;
}
unsafe impl VectorOrScalar for u32 {
type Scalar = u32;
}
unsafe impl VectorOrScalar for u64 {
type Scalar = u64;
}
unsafe impl VectorOrScalar for i8 {
type Scalar = i8;
}
unsafe impl VectorOrScalar for i16 {
type Scalar = i16;
}
unsafe impl VectorOrScalar for i32 {
type Scalar = i32;
}
unsafe impl VectorOrScalar for i64 {
type Scalar = i64;
}
use crate::vector::{create_dim, VectorOrScalar};
use core::num::NonZeroUsize;

/// Abstract trait representing a SPIR-V scalar type.
///
Expand All @@ -54,14 +13,16 @@ pub unsafe trait Scalar:
{
}

unsafe impl Scalar for bool {}
unsafe impl Scalar for f32 {}
unsafe impl Scalar for f64 {}
unsafe impl Scalar for u8 {}
unsafe impl Scalar for u16 {}
unsafe impl Scalar for u32 {}
unsafe impl Scalar for u64 {}
unsafe impl Scalar for i8 {}
unsafe impl Scalar for i16 {}
unsafe impl Scalar for i32 {}
unsafe impl Scalar for i64 {}
macro_rules! impl_scalar {
($($ty:ty),+) => {
$(
unsafe impl VectorOrScalar for $ty {
type Scalar = Self;
const DIM: NonZeroUsize = create_dim(1);
}
unsafe impl Scalar for $ty {}
)+
};
}

impl_scalar!(bool, f32, f64, u8, u16, u32, u64, i8, i16, i32, i64);
89 changes: 35 additions & 54 deletions crates/spirv-std/src/vector.rs
Original file line number Diff line number Diff line change
@@ -1,49 +1,28 @@
//! Traits related to vectors.
use crate::scalar::{Scalar, VectorOrScalar};
use crate::scalar::Scalar;
use core::num::NonZeroUsize;
use glam::{Vec3Swizzles, Vec4Swizzles};

unsafe impl VectorOrScalar for glam::Vec2 {
type Scalar = f32;
}
unsafe impl VectorOrScalar for glam::Vec3 {
type Scalar = f32;
}
unsafe impl VectorOrScalar for glam::Vec3A {
type Scalar = f32;
}
unsafe impl VectorOrScalar for glam::Vec4 {
type Scalar = f32;
}

unsafe impl VectorOrScalar for glam::DVec2 {
type Scalar = f64;
}
unsafe impl VectorOrScalar for glam::DVec3 {
type Scalar = f64;
}
unsafe impl VectorOrScalar for glam::DVec4 {
type Scalar = f64;
}
/// Abstract trait representing either a vector or a scalar type.
///
/// # Safety
/// Implementing this trait on non-scalar or non-vector types may break assumptions about other
/// unsafe code, and should not be done.
pub unsafe trait VectorOrScalar: Default {
/// Either the scalar component type of the vector or the scalar itself.
type Scalar: Scalar;

unsafe impl VectorOrScalar for glam::UVec2 {
type Scalar = u32;
}
unsafe impl VectorOrScalar for glam::UVec3 {
type Scalar = u32;
}
unsafe impl VectorOrScalar for glam::UVec4 {
type Scalar = u32;
/// The dimension of the vector, or 1 if it is a scalar
const DIM: NonZeroUsize;
}

unsafe impl VectorOrScalar for glam::IVec2 {
type Scalar = i32;
}
unsafe impl VectorOrScalar for glam::IVec3 {
type Scalar = i32;
}
unsafe impl VectorOrScalar for glam::IVec4 {
type Scalar = i32;
/// replace with `NonZeroUsize::new(n).unwrap()` once `unwrap()` is const stabilized
pub(crate) const fn create_dim(n: usize) -> NonZeroUsize {
match NonZeroUsize::new(n) {
None => panic!("dim must not be 0"),
Some(n) => n,
}
}

/// Abstract trait representing a SPIR-V vector type.
Expand All @@ -53,22 +32,24 @@ unsafe impl VectorOrScalar for glam::IVec4 {
/// should not be done.
pub unsafe trait Vector<T: Scalar, const N: usize>: VectorOrScalar<Scalar = T> {}

unsafe impl Vector<f32, 2> for glam::Vec2 {}
unsafe impl Vector<f32, 3> for glam::Vec3 {}
unsafe impl Vector<f32, 3> for glam::Vec3A {}
unsafe impl Vector<f32, 4> for glam::Vec4 {}

unsafe impl Vector<f64, 2> for glam::DVec2 {}
unsafe impl Vector<f64, 3> for glam::DVec3 {}
unsafe impl Vector<f64, 4> for glam::DVec4 {}

unsafe impl Vector<u32, 2> for glam::UVec2 {}
unsafe impl Vector<u32, 3> for glam::UVec3 {}
unsafe impl Vector<u32, 4> for glam::UVec4 {}
macro_rules! impl_vector {
($($scalar:ty: $($vec:ty => $dim:literal),+;)+) => {
$($(
unsafe impl VectorOrScalar for $vec {
type Scalar = $scalar;
const DIM: NonZeroUsize = create_dim($dim);
}
unsafe impl Vector<$scalar, $dim> for $vec {}
)+)+
};
}

unsafe impl Vector<i32, 2> for glam::IVec2 {}
unsafe impl Vector<i32, 3> for glam::IVec3 {}
unsafe impl Vector<i32, 4> for glam::IVec4 {}
impl_vector! {
f32: glam::Vec2 => 2, glam::Vec3 => 3, glam::Vec3A => 3, glam::Vec4 => 4;
f64: glam::DVec2 => 2, glam::DVec3 => 3, glam::DVec4 => 4;
u32: glam::UVec2 => 2, glam::UVec3 => 3, glam::UVec4 => 4;
i32: glam::IVec2 => 2, glam::IVec3 => 3, glam::IVec4 => 4;
}

/// Trait that implements slicing of a vector into a scalar or vector of lower dimensions, by
/// ignoring the higter dimensions
Expand Down
7 changes: 6 additions & 1 deletion tests/ui/arch/all.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

#![feature(repr_simd)]

use core::num::NonZeroUsize;
use spirv_std::spirv;
use spirv_std::{scalar::Scalar, scalar::VectorOrScalar, vector::Vector};
use spirv_std::{scalar::Scalar, vector::Vector, vector::VectorOrScalar};

/// HACK(shesp). Rust doesn't allow us to declare regular (tuple-)structs containing `bool` members
/// as `#[repl(simd)]`. But we need this for `spirv_std::arch::any()` and `spirv_std::arch::all()`
Expand All @@ -14,6 +15,10 @@ use spirv_std::{scalar::Scalar, scalar::VectorOrScalar, vector::Vector};
struct Vec2<T>(T, T);
unsafe impl<T: Scalar> VectorOrScalar for Vec2<T> {
type Scalar = T;
const DIM: NonZeroUsize = match NonZeroUsize::new(2) {
None => panic!(),
Some(n) => n,
};
}
unsafe impl<T: Scalar> Vector<T, 2> for Vec2<T> {}

Expand Down
7 changes: 6 additions & 1 deletion tests/ui/arch/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

#![feature(repr_simd)]

use core::num::NonZeroUsize;
use spirv_std::spirv;
use spirv_std::{scalar::Scalar, scalar::VectorOrScalar, vector::Vector};
use spirv_std::{scalar::Scalar, vector::Vector, vector::VectorOrScalar};

/// HACK(shesp). Rust doesn't allow us to declare regular (tuple-)structs containing `bool` members
/// as `#[repl(simd)]`. But we need this for `spirv_std::arch::any()` and `spirv_std::arch::all()`
Expand All @@ -14,6 +15,10 @@ use spirv_std::{scalar::Scalar, scalar::VectorOrScalar, vector::Vector};
struct Vec2<T>(T, T);
unsafe impl<T: Scalar> VectorOrScalar for Vec2<T> {
type Scalar = T;
const DIM: NonZeroUsize = match NonZeroUsize::new(2) {
None => panic!(),
Some(n) => n,
};
}
unsafe impl<T: Scalar> Vector<T, 2> for Vec2<T> {}

Expand Down

0 comments on commit d56c0ed

Please sign in to comment.