From d56c0ed77b0f00a1e2766e792cd099efb8fccc9b Mon Sep 17 00:00:00 2001 From: Firestar99 <4696087-firestar99@users.noreply.gitlab.com> Date: Mon, 23 Sep 2024 11:23:31 +0200 Subject: [PATCH] subgroup: make VectorOrScalar trait match discussions in https://github.com/EmbarkStudios/rust-gpu/pull/1030 --- crates/spirv-std/src/arch/subgroup.rs | 2 +- crates/spirv-std/src/float.rs | 4 +- crates/spirv-std/src/scalar.rs | 69 +++++---------------- crates/spirv-std/src/vector.rs | 89 +++++++++++---------------- tests/ui/arch/all.rs | 7 ++- tests/ui/arch/any.rs | 7 ++- 6 files changed, 66 insertions(+), 112 deletions(-) diff --git a/crates/spirv-std/src/arch/subgroup.rs b/crates/spirv-std/src/arch/subgroup.rs index a8eb38fc2a..b587f0e0ad 100644 --- a/crates/spirv-std/src/arch/subgroup.rs +++ b/crates/spirv-std/src/arch/subgroup.rs @@ -4,7 +4,7 @@ use crate::float::Float; use crate::integer::{Integer, SignedInteger, UnsignedInteger}; #[cfg(target_arch = "spirv")] use crate::memory::{Scope, Semantics}; -use crate::scalar::VectorOrScalar; +use crate::vector::VectorOrScalar; #[cfg(target_arch = "spirv")] use core::arch::asm; diff --git a/crates/spirv-std/src/float.rs b/crates/spirv-std/src/float.rs index be9133ee0a..ce441e99b5 100644 --- a/crates/spirv-std/src/float.rs +++ b/crates/spirv-std/src/float.rs @@ -1,9 +1,10 @@ //! Traits and helper functions related to floats. -use crate::scalar::VectorOrScalar; use crate::vector::Vector; +use crate::vector::{create_dim, VectorOrScalar}; #[cfg(target_arch = "spirv")] use core::arch::asm; +use core::num::NonZeroUsize; /// Abstract trait representing a SPIR-V floating point type. /// @@ -74,6 +75,7 @@ struct F32x2 { } unsafe impl VectorOrScalar for F32x2 { type Scalar = f32; + const DIM: NonZeroUsize = create_dim(2); } unsafe impl Vector for F32x2 {} diff --git a/crates/spirv-std/src/scalar.rs b/crates/spirv-std/src/scalar.rs index 9747cc995e..e9ab3ae758 100644 --- a/crates/spirv-std/src/scalar.rs +++ b/crates/spirv-std/src/scalar.rs @@ -1,48 +1,7 @@ //! Traits related to scalars. -/// Abstract trait representing either a vector or a scalar type. -/// -/// # Safety -/// Implementing this trait on non-scalar or non-vector types may break assumptions about other -/// unsafe code, and should not be done. -pub unsafe trait VectorOrScalar: Default { - /// Either the scalar component type of the vector or the scalar itself. - type Scalar: Scalar; -} - -unsafe impl VectorOrScalar for bool { - type Scalar = bool; -} -unsafe impl VectorOrScalar for f32 { - type Scalar = f32; -} -unsafe impl VectorOrScalar for f64 { - type Scalar = f64; -} -unsafe impl VectorOrScalar for u8 { - type Scalar = u8; -} -unsafe impl VectorOrScalar for u16 { - type Scalar = u16; -} -unsafe impl VectorOrScalar for u32 { - type Scalar = u32; -} -unsafe impl VectorOrScalar for u64 { - type Scalar = u64; -} -unsafe impl VectorOrScalar for i8 { - type Scalar = i8; -} -unsafe impl VectorOrScalar for i16 { - type Scalar = i16; -} -unsafe impl VectorOrScalar for i32 { - type Scalar = i32; -} -unsafe impl VectorOrScalar for i64 { - type Scalar = i64; -} +use crate::vector::{create_dim, VectorOrScalar}; +use core::num::NonZeroUsize; /// Abstract trait representing a SPIR-V scalar type. /// @@ -54,14 +13,16 @@ pub unsafe trait Scalar: { } -unsafe impl Scalar for bool {} -unsafe impl Scalar for f32 {} -unsafe impl Scalar for f64 {} -unsafe impl Scalar for u8 {} -unsafe impl Scalar for u16 {} -unsafe impl Scalar for u32 {} -unsafe impl Scalar for u64 {} -unsafe impl Scalar for i8 {} -unsafe impl Scalar for i16 {} -unsafe impl Scalar for i32 {} -unsafe impl Scalar for i64 {} +macro_rules! impl_scalar { + ($($ty:ty),+) => { + $( + unsafe impl VectorOrScalar for $ty { + type Scalar = Self; + const DIM: NonZeroUsize = create_dim(1); + } + unsafe impl Scalar for $ty {} + )+ + }; +} + +impl_scalar!(bool, f32, f64, u8, u16, u32, u64, i8, i16, i32, i64); diff --git a/crates/spirv-std/src/vector.rs b/crates/spirv-std/src/vector.rs index 7510953034..137d3f2151 100644 --- a/crates/spirv-std/src/vector.rs +++ b/crates/spirv-std/src/vector.rs @@ -1,49 +1,28 @@ //! Traits related to vectors. -use crate::scalar::{Scalar, VectorOrScalar}; +use crate::scalar::Scalar; +use core::num::NonZeroUsize; use glam::{Vec3Swizzles, Vec4Swizzles}; -unsafe impl VectorOrScalar for glam::Vec2 { - type Scalar = f32; -} -unsafe impl VectorOrScalar for glam::Vec3 { - type Scalar = f32; -} -unsafe impl VectorOrScalar for glam::Vec3A { - type Scalar = f32; -} -unsafe impl VectorOrScalar for glam::Vec4 { - type Scalar = f32; -} - -unsafe impl VectorOrScalar for glam::DVec2 { - type Scalar = f64; -} -unsafe impl VectorOrScalar for glam::DVec3 { - type Scalar = f64; -} -unsafe impl VectorOrScalar for glam::DVec4 { - type Scalar = f64; -} +/// Abstract trait representing either a vector or a scalar type. +/// +/// # Safety +/// Implementing this trait on non-scalar or non-vector types may break assumptions about other +/// unsafe code, and should not be done. +pub unsafe trait VectorOrScalar: Default { + /// Either the scalar component type of the vector or the scalar itself. + type Scalar: Scalar; -unsafe impl VectorOrScalar for glam::UVec2 { - type Scalar = u32; -} -unsafe impl VectorOrScalar for glam::UVec3 { - type Scalar = u32; -} -unsafe impl VectorOrScalar for glam::UVec4 { - type Scalar = u32; + /// The dimension of the vector, or 1 if it is a scalar + const DIM: NonZeroUsize; } -unsafe impl VectorOrScalar for glam::IVec2 { - type Scalar = i32; -} -unsafe impl VectorOrScalar for glam::IVec3 { - type Scalar = i32; -} -unsafe impl VectorOrScalar for glam::IVec4 { - type Scalar = i32; +/// replace with `NonZeroUsize::new(n).unwrap()` once `unwrap()` is const stabilized +pub(crate) const fn create_dim(n: usize) -> NonZeroUsize { + match NonZeroUsize::new(n) { + None => panic!("dim must not be 0"), + Some(n) => n, + } } /// Abstract trait representing a SPIR-V vector type. @@ -53,22 +32,24 @@ unsafe impl VectorOrScalar for glam::IVec4 { /// should not be done. pub unsafe trait Vector: VectorOrScalar {} -unsafe impl Vector for glam::Vec2 {} -unsafe impl Vector for glam::Vec3 {} -unsafe impl Vector for glam::Vec3A {} -unsafe impl Vector for glam::Vec4 {} - -unsafe impl Vector for glam::DVec2 {} -unsafe impl Vector for glam::DVec3 {} -unsafe impl Vector for glam::DVec4 {} - -unsafe impl Vector for glam::UVec2 {} -unsafe impl Vector for glam::UVec3 {} -unsafe impl Vector for glam::UVec4 {} +macro_rules! impl_vector { + ($($scalar:ty: $($vec:ty => $dim:literal),+;)+) => { + $($( + unsafe impl VectorOrScalar for $vec { + type Scalar = $scalar; + const DIM: NonZeroUsize = create_dim($dim); + } + unsafe impl Vector<$scalar, $dim> for $vec {} + )+)+ + }; +} -unsafe impl Vector for glam::IVec2 {} -unsafe impl Vector for glam::IVec3 {} -unsafe impl Vector for glam::IVec4 {} +impl_vector! { + f32: glam::Vec2 => 2, glam::Vec3 => 3, glam::Vec3A => 3, glam::Vec4 => 4; + f64: glam::DVec2 => 2, glam::DVec3 => 3, glam::DVec4 => 4; + u32: glam::UVec2 => 2, glam::UVec3 => 3, glam::UVec4 => 4; + i32: glam::IVec2 => 2, glam::IVec3 => 3, glam::IVec4 => 4; +} /// Trait that implements slicing of a vector into a scalar or vector of lower dimensions, by /// ignoring the higter dimensions diff --git a/tests/ui/arch/all.rs b/tests/ui/arch/all.rs index fbedae03c4..472a2d82a0 100644 --- a/tests/ui/arch/all.rs +++ b/tests/ui/arch/all.rs @@ -2,8 +2,9 @@ #![feature(repr_simd)] +use core::num::NonZeroUsize; use spirv_std::spirv; -use spirv_std::{scalar::Scalar, scalar::VectorOrScalar, vector::Vector}; +use spirv_std::{scalar::Scalar, vector::Vector, vector::VectorOrScalar}; /// HACK(shesp). Rust doesn't allow us to declare regular (tuple-)structs containing `bool` members /// as `#[repl(simd)]`. But we need this for `spirv_std::arch::any()` and `spirv_std::arch::all()` @@ -14,6 +15,10 @@ use spirv_std::{scalar::Scalar, scalar::VectorOrScalar, vector::Vector}; struct Vec2(T, T); unsafe impl VectorOrScalar for Vec2 { type Scalar = T; + const DIM: NonZeroUsize = match NonZeroUsize::new(2) { + None => panic!(), + Some(n) => n, + }; } unsafe impl Vector for Vec2 {} diff --git a/tests/ui/arch/any.rs b/tests/ui/arch/any.rs index 5f4caed88f..c61928fed9 100644 --- a/tests/ui/arch/any.rs +++ b/tests/ui/arch/any.rs @@ -2,8 +2,9 @@ #![feature(repr_simd)] +use core::num::NonZeroUsize; use spirv_std::spirv; -use spirv_std::{scalar::Scalar, scalar::VectorOrScalar, vector::Vector}; +use spirv_std::{scalar::Scalar, vector::Vector, vector::VectorOrScalar}; /// HACK(shesp). Rust doesn't allow us to declare regular (tuple-)structs containing `bool` members /// as `#[repl(simd)]`. But we need this for `spirv_std::arch::any()` and `spirv_std::arch::all()` @@ -14,6 +15,10 @@ use spirv_std::{scalar::Scalar, scalar::VectorOrScalar, vector::Vector}; struct Vec2(T, T); unsafe impl VectorOrScalar for Vec2 { type Scalar = T; + const DIM: NonZeroUsize = match NonZeroUsize::new(2) { + None => panic!(), + Some(n) => n, + }; } unsafe impl Vector for Vec2 {}