From 8ee360c8aefc3d8505a9d1b1716d92b27049f6a9 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Sat, 7 Jan 2023 06:26:31 +0200 Subject: [PATCH] [WIP] add `#[spirv(typed_buffer)]` for explicit `SpirvType::InterfaceBlock`s. --- crates/rustc_codegen_spirv/src/abi.rs | 38 +++- crates/rustc_codegen_spirv/src/attr.rs | 1 + .../src/builder/spirv_asm.rs | 6 +- .../src/codegen_cx/entry.rs | 171 ++++++++++-------- crates/rustc_codegen_spirv/src/spirv_type.rs | 10 +- crates/rustc_codegen_spirv/src/symbols.rs | 4 + crates/spirv-std/src/lib.rs | 2 + crates/spirv-std/src/typed_buffer.rs | 100 ++++++++++ .../ui/arch/debug_printf_type_checking.stderr | 20 +- .../runtime_descriptor_array_error.stderr | 2 +- 10 files changed, 259 insertions(+), 95 deletions(-) create mode 100644 crates/spirv-std/src/typed_buffer.rs diff --git a/crates/rustc_codegen_spirv/src/abi.rs b/crates/rustc_codegen_spirv/src/abi.rs index 0d3b8de664..7e8cf0a2e2 100644 --- a/crates/rustc_codegen_spirv/src/abi.rs +++ b/crates/rustc_codegen_spirv/src/abi.rs @@ -936,11 +936,13 @@ fn trans_intrinsic_type<'tcx>( .err("#[spirv(runtime_array)] type must have size 4")); } - // We use a generic to indicate the underlying element type. - // The spirv type of it will be generated by querying the type of the first generic. + // We use a generic param to indicate the underlying element type. + // The SPIR-V element type will be generated from the first generic param. if let Some(elem_ty) = args.types().next() { - let element = cx.layout_of(elem_ty).spirv_type(span, cx); - Ok(SpirvType::RuntimeArray { element }.def(span, cx)) + Ok(SpirvType::RuntimeArray { + element: cx.layout_of(elem_ty).spirv_type(span, cx), + } + .def(span, cx)) } else { Err(cx .tcx @@ -948,6 +950,34 @@ fn trans_intrinsic_type<'tcx>( .err("#[spirv(runtime_array)] type must have a generic element type")) } } + IntrinsicType::TypedBuffer => { + if ty.size != Size::from_bytes(4) { + return Err(cx + .tcx + .sess + .dcx() + .err("#[spirv(typed_buffer)] type must have size 4")); + } + + // We use a generic param to indicate the underlying data type. + // The SPIR-V data type will be generated from the first generic param. + if let Some(data_ty) = args.types().next() { + // HACK(eddyb) this should be a *pointer* to an "interface block", + // but SPIR-V screwed up and used no explicit indirection for the + // descriptor indexing case, and instead made a `RuntimeArray` of + // `InterfaceBlock`s be an "array of typed buffer resources". + Ok(SpirvType::InterfaceBlock { + inner_type: cx.layout_of(data_ty).spirv_type(span, cx), + } + .def(span, cx)) + } else { + Err(cx + .tcx + .sess + .dcx() + .err("#[spirv(typed_buffer)] type must have a generic data type")) + } + } IntrinsicType::Matrix => { let span = def_id_for_spirv_type_adt(ty) .map(|did| cx.tcx.def_span(did)) diff --git a/crates/rustc_codegen_spirv/src/attr.rs b/crates/rustc_codegen_spirv/src/attr.rs index f47062f31f..d3976747ba 100644 --- a/crates/rustc_codegen_spirv/src/attr.rs +++ b/crates/rustc_codegen_spirv/src/attr.rs @@ -65,6 +65,7 @@ pub enum IntrinsicType { SampledImage, RayQueryKhr, RuntimeArray, + TypedBuffer, Matrix, } diff --git a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs index 623893e465..4c0c92cb58 100644 --- a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs +++ b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs @@ -700,7 +700,11 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { }; ty = match cx.lookup_type(ty) { SpirvType::Array { element, .. } - | SpirvType::RuntimeArray { element } => element, + | SpirvType::RuntimeArray { element } + // HACK(eddyb) this is pretty bad because it's not + // checking that the index is an `OpConstant 0`, but + // there's no other valid choice anyway. + | SpirvType::InterfaceBlock { inner_type: element } => element, SpirvType::Adt { field_types, .. } => *index_to_usize() .and_then(|i| field_types.get(i)) diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index 8979e9d9d5..f04080b88b 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -495,71 +495,98 @@ impl<'tcx> CodegenCx<'tcx> { .dcx() .span_fatal(hir_param.ty_span, "pair type not supported yet") } + // FIXME(eddyb) should this talk about "typed buffers" instead of "interface blocks"? + // FIXME(eddyb) should we talk about "descriptor indexing" or + // actually use more reasonable terms like "resource arrays"? + let needs_interface_block_and_supports_descriptor_indexing = matches!( + storage_class, + Ok(StorageClass::Uniform | StorageClass::StorageBuffer) + ); + let needs_interface_block = needs_interface_block_and_supports_descriptor_indexing + || storage_class == Ok(StorageClass::PushConstant); + // NOTE(eddyb) `#[spirv(typed_buffer)]` adds `SpirvType::InterfaceBlock`s + // which must bypass the automated ones (i.e. the user is taking control). + let has_explicit_interface_block = needs_interface_block_and_supports_descriptor_indexing + && { + // Peel off arrays first (used for "descriptor indexing"). + let outermost_or_array_element = match self.lookup_type(value_spirv_type) { + SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element } => { + element + } + _ => value_spirv_type, + }; + matches!( + self.lookup_type(outermost_or_array_element), + SpirvType::InterfaceBlock { .. } + ) + }; let var_ptr_spirv_type; - let (value_ptr, value_len) = match storage_class { - Ok( - StorageClass::PushConstant | StorageClass::Uniform | StorageClass::StorageBuffer, - ) => { - let var_spirv_type = SpirvType::InterfaceBlock { - inner_type: value_spirv_type, - } - .def(hir_param.span, self); - var_ptr_spirv_type = self.type_ptr_to(var_spirv_type); - - let zero_u32 = self.constant_u32(hir_param.span, 0).def_cx(self); - let value_ptr_spirv_type = self.type_ptr_to(value_spirv_type); - let value_ptr = bx - .emit() - .in_bounds_access_chain( - value_ptr_spirv_type, - None, - var_id.unwrap(), - [zero_u32].iter().cloned(), - ) - .unwrap() - .with_type(value_ptr_spirv_type); + let (value_ptr, value_len) = if needs_interface_block && !has_explicit_interface_block { + let var_spirv_type = SpirvType::InterfaceBlock { + inner_type: value_spirv_type, + } + .def(hir_param.span, self); + var_ptr_spirv_type = self.type_ptr_to(var_spirv_type); + + let zero_u32 = self.constant_u32(hir_param.span, 0).def_cx(self); + let value_ptr_spirv_type = self.type_ptr_to(value_spirv_type); + let value_ptr = bx + .emit() + .in_bounds_access_chain( + value_ptr_spirv_type, + None, + var_id.unwrap(), + [zero_u32].iter().cloned(), + ) + .unwrap() + .with_type(value_ptr_spirv_type); - let value_len = if is_unsized_with_len { - match self.lookup_type(value_spirv_type) { - SpirvType::RuntimeArray { .. } => {} - _ => { - self.tcx.dcx().span_err( - hir_param.ty_span, - "only plain slices are supported as unsized types", - ); - } + let value_len = if is_unsized_with_len { + match self.lookup_type(value_spirv_type) { + SpirvType::RuntimeArray { .. } => {} + _ => { + self.tcx.dcx().span_err( + hir_param.ty_span, + "only plain slices are supported as unsized types", + ); } + } - // FIXME(eddyb) shouldn't this be `usize`? - let len_spirv_type = self.type_isize(); - let len = bx - .emit() - .array_length(len_spirv_type, None, var_id.unwrap(), 0) - .unwrap(); - - Some(len.with_type(len_spirv_type)) - } else { - if is_unsized { - // It's OK to use a RuntimeArray and not have a length parameter, but - // it's just nicer ergonomics to use a slice. - self.tcx - .dcx() - .span_warn(hir_param.ty_span, "use &[T] instead of &RuntimeArray"); - } - None - }; + // FIXME(eddyb) shouldn't this be `usize`? + let len_spirv_type = self.type_isize(); + let len = bx + .emit() + .array_length(len_spirv_type, None, var_id.unwrap(), 0) + .unwrap(); - (Ok(value_ptr), value_len) - } - Ok(StorageClass::UniformConstant) => { - var_ptr_spirv_type = self.type_ptr_to(value_spirv_type); + Some(len.with_type(len_spirv_type)) + } else { + if is_unsized { + // It's OK to use a RuntimeArray and not have a length parameter, but + // it's just nicer ergonomics to use a slice. + self.tcx + .dcx() + .span_warn(hir_param.ty_span, "use &[T] instead of &RuntimeArray"); + } + None + }; + (Ok(value_ptr), value_len) + } else { + var_ptr_spirv_type = self.type_ptr_to(value_spirv_type); + + // FIXME(eddyb) should we talk about "descriptor indexing" or + // actually use more reasonable terms like "resource arrays"? + let unsized_is_descriptor_indexing = + needs_interface_block_and_supports_descriptor_indexing + || storage_class == Ok(StorageClass::UniformConstant); + if unsized_is_descriptor_indexing { match self.lookup_type(value_spirv_type) { SpirvType::RuntimeArray { .. } => { if is_unsized_with_len { self.tcx.dcx().span_err( hir_param.ty_span, - "uniform_constant must use &RuntimeArray, not &[T]", + "descriptor indexing must use &RuntimeArray, not &[T]", ); } } @@ -567,24 +594,15 @@ impl<'tcx> CodegenCx<'tcx> { if is_unsized { self.tcx.dcx().span_err( hir_param.ty_span, - "only plain slices are supported as unsized types", + "only RuntimeArray is supported, not other unsized types", ); } } } - - let value_len = if is_pair { - // We've already emitted an error, fill in a placeholder value - Some(bx.undef(self.type_isize())) - } else { - None - }; - - (Ok(var_id.unwrap().with_type(var_ptr_spirv_type)), value_len) - } - _ => { - var_ptr_spirv_type = self.type_ptr_to(value_spirv_type); - + } else { + // FIXME(eddyb) determine, based on the type, what kind of type + // this is, to narrow it further to e.g. "buffer in a non-buffer + // storage class" or "storage class expects fixed data sizes". if is_unsized { self.tcx.dcx().span_fatal( hir_param.ty_span, @@ -597,12 +615,19 @@ impl<'tcx> CodegenCx<'tcx> { ), ); } - - ( - var_id.map(|var_id| var_id.with_type(var_ptr_spirv_type)), - None, - ) } + + let value_len = if is_pair { + // We've already emitted an error, fill in a placeholder value + Some(bx.undef(self.type_isize())) + } else { + None + }; + + ( + var_id.map(|var_id| var_id.with_type(var_ptr_spirv_type)), + value_len, + ) }; // Compute call argument(s) to match what the Rust entry `fn` expects, diff --git a/crates/rustc_codegen_spirv/src/spirv_type.rs b/crates/rustc_codegen_spirv/src/spirv_type.rs index 074783116e..ad97e78c52 100644 --- a/crates/rustc_codegen_spirv/src/spirv_type.rs +++ b/crates/rustc_codegen_spirv/src/spirv_type.rs @@ -347,9 +347,8 @@ impl SpirvType<'_> { | Self::AccelerationStructureKhr | Self::RayQueryKhr | Self::Sampler - | Self::SampledImage { .. } => Size::from_bytes(4), - - Self::InterfaceBlock { inner_type } => cx.lookup_type(inner_type).sizeof(cx)?, + | Self::SampledImage { .. } + | Self::InterfaceBlock { .. } => Size::from_bytes(4), }; Some(result) } @@ -377,9 +376,8 @@ impl SpirvType<'_> { | Self::AccelerationStructureKhr | Self::RayQueryKhr | Self::Sampler - | Self::SampledImage { .. } => Align::from_bytes(4).unwrap(), - - Self::InterfaceBlock { inner_type } => cx.lookup_type(inner_type).alignof(cx), + | Self::SampledImage { .. } + | Self::InterfaceBlock { .. } => Align::from_bytes(4).unwrap(), } } diff --git a/crates/rustc_codegen_spirv/src/symbols.rs b/crates/rustc_codegen_spirv/src/symbols.rs index b0dae13875..f7038a0472 100644 --- a/crates/rustc_codegen_spirv/src/symbols.rs +++ b/crates/rustc_codegen_spirv/src/symbols.rs @@ -340,6 +340,10 @@ impl Symbols { "runtime_array", SpirvAttribute::IntrinsicType(IntrinsicType::RuntimeArray), ), + ( + "typed_buffer", + SpirvAttribute::IntrinsicType(IntrinsicType::TypedBuffer), + ), ( "matrix", SpirvAttribute::IntrinsicType(IntrinsicType::Matrix), diff --git a/crates/spirv-std/src/lib.rs b/crates/spirv-std/src/lib.rs index 3f4920c3e2..175099537a 100644 --- a/crates/spirv-std/src/lib.rs +++ b/crates/spirv-std/src/lib.rs @@ -107,6 +107,7 @@ mod runtime_array; mod sampler; pub mod scalar; pub(crate) mod sealed; +mod typed_buffer; pub mod vector; pub use self::sampler::Sampler; @@ -114,6 +115,7 @@ pub use crate::macros::Image; pub use byte_addressable_buffer::ByteAddressableBuffer; pub use num_traits; pub use runtime_array::*; +pub use typed_buffer::*; pub use glam; diff --git a/crates/spirv-std/src/typed_buffer.rs b/crates/spirv-std/src/typed_buffer.rs new file mode 100644 index 0000000000..9b5bff0747 --- /dev/null +++ b/crates/spirv-std/src/typed_buffer.rs @@ -0,0 +1,100 @@ +#[cfg(target_arch = "spirv")] +use core::arch::asm; +use core::marker::PhantomData; +use core::ops::{Deref, DerefMut}; + +/// Explicit (uniform/storage) buffer handle for descriptor indexing. +/// +/// Examples (for an `#[spirv(storage_buffer)]`-annotated entry-point parameter): +/// - `buffer: &[u32]` (implicit, 1 buffer) +/// - `buffer: &TypedBuffer<[u32]>` (explicit, one buffer) +/// - `buffers: &RuntimeArray>` (explicit, many buffers) +// +// TODO(eddyb) fully document! +#[spirv(typed_buffer)] +// HACK(eddyb) avoids "transparent newtype of `_anti_zst_padding`" misinterpretation. +#[repr(C)] +pub struct TypedBuffer { + // HACK(eddyb) avoids the layout becoming ZST (and being elided in one way + // or another, before `#[spirv(runtime_array)]` can special-case it). + _anti_zst_padding: core::mem::MaybeUninit, + _phantom: PhantomData, +} + +impl Deref for TypedBuffer { + type Target = T; + #[spirv_std_macros::gpu_only] + fn deref(&self) -> &T { + unsafe { + let mut result_slot = core::mem::MaybeUninit::uninit(); + asm! { + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", + "%result = OpAccessChain _ {buffer} %uint_0", + "OpStore {result_slot} %result", + buffer = in(reg) self, + result_slot = in(reg) result_slot.as_mut_ptr(), + } + result_slot.assume_init() + } + } +} + +impl DerefMut for TypedBuffer { + #[spirv_std_macros::gpu_only] + fn deref_mut(&mut self) -> &mut T { + unsafe { + let mut result_slot = core::mem::MaybeUninit::uninit(); + asm! { + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", + "%result = OpAccessChain _ {buffer} %uint_0", + "OpStore {result_slot} %result", + buffer = in(reg) self, + result_slot = in(reg) result_slot.as_mut_ptr(), + } + result_slot.assume_init() + } + } +} + +impl Deref for TypedBuffer<[T]> { + type Target = [T]; + #[spirv_std_macros::gpu_only] + fn deref(&self) -> &[T] { + unsafe { + let mut result_slot = core::mem::MaybeUninit::uninit(); + asm! { + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", + "%inner_ptr = OpAccessChain _ {buffer} %uint_0", + "%inner_len = OpArrayLength %uint {buffer} 0", + "%result = OpCompositeConstruct typeof*{result_slot} %inner_ptr %inner_len", + "OpStore {result_slot} %result", + buffer = in(reg) self, + result_slot = in(reg) result_slot.as_mut_ptr(), + } + result_slot.assume_init() + } + } +} + +impl DerefMut for TypedBuffer<[T]> { + #[spirv_std_macros::gpu_only] + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { + let mut result_slot = core::mem::MaybeUninit::uninit(); + asm! { + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", + "%inner_ptr = OpAccessChain _ {buffer} %uint_0", + "%inner_len = OpArrayLength %uint {buffer} 0", + "%result = OpCompositeConstruct typeof*{result_slot} %inner_ptr %inner_len", + "OpStore {result_slot} %result", + buffer = in(reg) self, + result_slot = in(reg) result_slot.as_mut_ptr(), + } + result_slot.assume_init() + } + } +} diff --git a/tests/ui/arch/debug_printf_type_checking.stderr b/tests/ui/arch/debug_printf_type_checking.stderr index fcd2013a15..b027cc7b34 100644 --- a/tests/ui/arch/debug_printf_type_checking.stderr +++ b/tests/ui/arch/debug_printf_type_checking.stderr @@ -75,9 +75,9 @@ help: the return type of this call is `u32` due to the type of the argument pass | | | this argument influences the return type of `spirv_std` note: function defined here - --> $SPIRV_STD_SRC/lib.rs:136:8 + --> $SPIRV_STD_SRC/lib.rs:138:8 | -136 | pub fn debug_printf_assert_is_type(ty: T) -> T { +138 | pub fn debug_printf_assert_is_type(ty: T) -> T { | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ = note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info) help: change the type of the numeric literal from `u32` to `f32` @@ -102,9 +102,9 @@ help: the return type of this call is `f32` due to the type of the argument pass | | | this argument influences the return type of `spirv_std` note: function defined here - --> $SPIRV_STD_SRC/lib.rs:136:8 + --> $SPIRV_STD_SRC/lib.rs:138:8 | -136 | pub fn debug_printf_assert_is_type(ty: T) -> T { +138 | pub fn debug_printf_assert_is_type(ty: T) -> T { | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ = note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info) help: change the type of the numeric literal from `f32` to `u32` @@ -129,12 +129,12 @@ error[E0277]: the trait bound `{float}: Vector` is not satisfied > and 5 others note: required by a bound in `debug_printf_assert_is_vector` - --> $SPIRV_STD_SRC/lib.rs:143:8 + --> $SPIRV_STD_SRC/lib.rs:145:8 | -141 | pub fn debug_printf_assert_is_vector< +143 | pub fn debug_printf_assert_is_vector< | ----------------------------- required by a bound in this function -142 | TY: crate::scalar::Scalar, -143 | V: crate::vector::Vector, +144 | TY: crate::scalar::Scalar, +145 | V: crate::vector::Vector, | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `debug_printf_assert_is_vector` = note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info) @@ -155,9 +155,9 @@ help: the return type of this call is `Vec2` due to the type of the argument pas | | | this argument influences the return type of `spirv_std` note: function defined here - --> $SPIRV_STD_SRC/lib.rs:136:8 + --> $SPIRV_STD_SRC/lib.rs:138:8 | -136 | pub fn debug_printf_assert_is_type(ty: T) -> T { +138 | pub fn debug_printf_assert_is_type(ty: T) -> T { | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ = note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tests/ui/storage_class/runtime_descriptor_array_error.stderr b/tests/ui/storage_class/runtime_descriptor_array_error.stderr index a6025a5fd0..a3c0d1b0b6 100644 --- a/tests/ui/storage_class/runtime_descriptor_array_error.stderr +++ b/tests/ui/storage_class/runtime_descriptor_array_error.stderr @@ -1,4 +1,4 @@ -error: uniform_constant must use &RuntimeArray, not &[T] +error: descriptor indexing must use &RuntimeArray, not &[T] --> $DIR/runtime_descriptor_array_error.rs:7:52 | 7 | #[spirv(descriptor_set = 0, binding = 0)] one: &[Image!(2D, type=f32, sampled)],