diff --git a/CHANGELOG.md b/CHANGELOG.md index c926d003a9..80b811e89d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -93,6 +93,10 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216). - Allow using [VK_GOOGLE_display_timing](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_GOOGLE_display_timing.html) unsafely with the `VULKAN_GOOGLE_DISPLAY_TIMING` feature. By @DJMcNab in [#6149](https://github.com/gfx-rs/wgpu/pull/6149) +#### Metal + +- Implement `atomicCompareExchangeWeak`. By @AsherJingkongChen in [#6265](https://github.com/gfx-rs/wgpu/pull/6265) + ### Bug Fixes - Fix incorrect hlsl image output type conversion. By @atlv24 in [#6123](https://github.com/gfx-rs/wgpu/pull/6123) diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 7ab97f491c..7fb30b4f7b 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -33,6 +33,7 @@ const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection"; const RAY_QUERY_FIELD_READY: &str = "ready"; const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type"; +pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit"; pub(crate) const MODF_FUNCTION: &str = "naga_modf"; pub(crate) const FREXP_FUNCTION: &str = "naga_frexp"; @@ -1279,42 +1280,6 @@ impl Writer { Ok(()) } - fn put_atomic_operation( - &mut self, - pointer: Handle, - key: &str, - value: Handle, - context: &ExpressionContext, - ) -> BackendResult { - // If the pointer we're passing to the atomic operation needs to be conditional - // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and - // the pointer operand should be unchecked. - let policy = context.choose_bounds_check_policy(pointer); - let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite - && self.put_bounds_checks(pointer, context, back::Level(0), "")?; - - // If requested and successfully put bounds checks, continue the ternary expression. - if checked { - write!(self.out, " ? ")?; - } - - write!( - self.out, - "{NAMESPACE}::atomic_{key}_explicit({ATOMIC_REFERENCE}" - )?; - self.put_access_chain(pointer, policy, context)?; - write!(self.out, ", ")?; - self.put_expression(value, context, true)?; - write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?; - - // Finish the ternary expression. - if checked { - write!(self.out, " : DefaultConstructible()")?; - } - - Ok(()) - } - /// Emit code for the arithmetic expression of the dot product. /// fn put_dot_product( @@ -3182,24 +3147,65 @@ impl Writer { value, result, } => { + let context = &context.expression; + // This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not // `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is // `Some`, we are not operating on a 64-bit value, and that if we are // operating on a 64-bit value, `result` is `None`. write!(self.out, "{level}")?; - let fun_str = if let Some(result) = result { + let fun_key = if let Some(result) = result { let res_name = Baked(result).to_string(); - self.start_baking_expression(result, &context.expression, &res_name)?; + self.start_baking_expression(result, context, &res_name)?; self.named_expressions.insert(result, res_name); - fun.to_msl()? - } else if context.expression.resolve_type(value).scalar_width() == Some(8) { + fun.to_msl() + } else if context.resolve_type(value).scalar_width() == Some(8) { fun.to_msl_64_bit()? } else { - fun.to_msl()? + fun.to_msl() }; - self.put_atomic_operation(pointer, fun_str, value, &context.expression)?; - // done + // If the pointer we're passing to the atomic operation needs to be conditional + // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and + // the pointer operand should be unchecked. + let policy = context.choose_bounds_check_policy(pointer); + let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite + && self.put_bounds_checks(pointer, context, back::Level(0), "")?; + + // If requested and successfully put bounds checks, continue the ternary expression. + if checked { + write!(self.out, " ? ")?; + } + + // Put the atomic function invocation. + match *fun { + crate::AtomicFunction::Exchange { compare: Some(cmp) } => { + write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?; + self.put_access_chain(pointer, policy, context)?; + write!(self.out, ", ")?; + self.put_expression(cmp, context, true)?; + write!(self.out, ", ")?; + self.put_expression(value, context, true)?; + write!(self.out, ")")?; + } + _ => { + write!( + self.out, + "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}" + )?; + self.put_access_chain(pointer, policy, context)?; + write!(self.out, ", ")?; + self.put_expression(value, context, true)?; + write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?; + } + } + + // Finish the ternary expression. + if checked { + write!(self.out, " : DefaultConstructible()")?; + } + + // Done writeln!(self.out, ";")?; } crate::Statement::WorkGroupUniformLoad { pointer, result } => { @@ -3827,7 +3833,33 @@ impl Writer { }}" )?; } - &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {} + &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => { + let arg_type_name = scalar.to_msl_name(); + let called_func_name = "atomic_compare_exchange_weak_explicit"; + let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION; + let struct_name = &self.names[&NameKey::Type(*struct_ty)]; + + writeln!(self.out)?; + + for address_space_name in ["device", "threadgroup"] { + writeln!( + self.out, + "\ +template +{struct_name} {defined_func_name}( + {address_space_name} A *atomic_ptr, + {arg_type_name} cmp, + {arg_type_name} v +) {{ + bool swapped = {NAMESPACE}::{called_func_name}( + atomic_ptr, &cmp, v, + metal::memory_order_relaxed, metal::memory_order_relaxed + ); + return {struct_name}{{cmp, swapped}}; +}}" + )?; + } + } } } @@ -6065,8 +6097,8 @@ fn test_stack_size() { } impl crate::AtomicFunction { - fn to_msl(self) -> Result<&'static str, Error> { - Ok(match self { + const fn to_msl(self) -> &'static str { + match self { Self::Add => "fetch_add", Self::Subtract => "fetch_sub", Self::And => "fetch_and", @@ -6075,10 +6107,8 @@ impl crate::AtomicFunction { Self::Min => "fetch_min", Self::Max => "fetch_max", Self::Exchange { compare: None } => "exchange", - Self::Exchange { compare: Some(_) } => Err(Error::FeatureNotImplemented( - "atomic CompareExchange".to_string(), - ))?, - }) + Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION, + } } fn to_msl_64_bit(self) -> Result<&'static str, Error> { diff --git a/naga/tests/out/msl/atomicCompareExchange.msl b/naga/tests/out/msl/atomicCompareExchange.msl new file mode 100644 index 0000000000..800b5b2012 --- /dev/null +++ b/naga/tests/out/msl/atomicCompareExchange.msl @@ -0,0 +1,161 @@ +// language: metal1.0 +#include +#include + +using metal::uint; + +struct type_2 { + metal::atomic_int inner[128]; +}; +struct type_4 { + metal::atomic_uint inner[128]; +}; +struct _atomic_compare_exchange_resultSint4_ { + int old_value; + bool exchanged; +}; +struct _atomic_compare_exchange_resultUint4_ { + uint old_value; + bool exchanged; +}; + +template +_atomic_compare_exchange_resultSint4_ naga_atomic_compare_exchange_weak_explicit( + device A *atomic_ptr, + int cmp, + int v +) { + bool swapped = metal::atomic_compare_exchange_weak_explicit( + atomic_ptr, &cmp, v, + metal::memory_order_relaxed, metal::memory_order_relaxed + ); + return _atomic_compare_exchange_resultSint4_{cmp, swapped}; +} +template +_atomic_compare_exchange_resultSint4_ naga_atomic_compare_exchange_weak_explicit( + threadgroup A *atomic_ptr, + int cmp, + int v +) { + bool swapped = metal::atomic_compare_exchange_weak_explicit( + atomic_ptr, &cmp, v, + metal::memory_order_relaxed, metal::memory_order_relaxed + ); + return _atomic_compare_exchange_resultSint4_{cmp, swapped}; +} + +template +_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit( + device A *atomic_ptr, + uint cmp, + uint v +) { + bool swapped = metal::atomic_compare_exchange_weak_explicit( + atomic_ptr, &cmp, v, + metal::memory_order_relaxed, metal::memory_order_relaxed + ); + return _atomic_compare_exchange_resultUint4_{cmp, swapped}; +} +template +_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit( + threadgroup A *atomic_ptr, + uint cmp, + uint v +) { + bool swapped = metal::atomic_compare_exchange_weak_explicit( + atomic_ptr, &cmp, v, + metal::memory_order_relaxed, metal::memory_order_relaxed + ); + return _atomic_compare_exchange_resultUint4_{cmp, swapped}; +} +constant uint SIZE = 128u; + +kernel void test_atomic_compare_exchange_i32_( + device type_2& arr_i32_ [[user(fake0)]] +) { + uint i = 0u; + int old = {}; + bool exchanged = {}; +#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop) + bool loop_init = true; + LOOP_IS_REACHABLE while(true) { + if (!loop_init) { + uint _e27 = i; + i = _e27 + 1u; + } + loop_init = false; + uint _e2 = i; + if (_e2 < SIZE) { + } else { + break; + } + { + uint _e6 = i; + int _e8 = metal::atomic_load_explicit(&arr_i32_.inner[_e6], metal::memory_order_relaxed); + old = _e8; + exchanged = false; + LOOP_IS_REACHABLE while(true) { + bool _e12 = exchanged; + if (!(_e12)) { + } else { + break; + } + { + int _e14 = old; + int new_ = as_type(as_type(_e14) + 1.0); + uint _e20 = i; + int _e22 = old; + _atomic_compare_exchange_resultSint4_ _e23 = naga_atomic_compare_exchange_weak_explicit(&arr_i32_.inner[_e20], _e22, new_); + old = _e23.old_value; + exchanged = _e23.exchanged; + } + } + } + } + return; +} + + +kernel void test_atomic_compare_exchange_u32_( + device type_4& arr_u32_ [[user(fake0)]] +) { + uint i_1 = 0u; + uint old_1 = {}; + bool exchanged_1 = {}; + bool loop_init_1 = true; + LOOP_IS_REACHABLE while(true) { + if (!loop_init_1) { + uint _e27 = i_1; + i_1 = _e27 + 1u; + } + loop_init_1 = false; + uint _e2 = i_1; + if (_e2 < SIZE) { + } else { + break; + } + { + uint _e6 = i_1; + uint _e8 = metal::atomic_load_explicit(&arr_u32_.inner[_e6], metal::memory_order_relaxed); + old_1 = _e8; + exchanged_1 = false; + LOOP_IS_REACHABLE while(true) { + bool _e12 = exchanged_1; + if (!(_e12)) { + } else { + break; + } + { + uint _e14 = old_1; + uint new_1 = as_type(as_type(_e14) + 1.0); + uint _e20 = i_1; + uint _e22 = old_1; + _atomic_compare_exchange_resultUint4_ _e23 = naga_atomic_compare_exchange_weak_explicit(&arr_u32_.inner[_e20], _e22, new_1); + old_1 = _e23.old_value; + exchanged_1 = _e23.exchanged; + } + } + } + } + return; +} diff --git a/naga/tests/out/msl/overrides-atomicCompareExchangeWeak.msl b/naga/tests/out/msl/overrides-atomicCompareExchangeWeak.msl new file mode 100644 index 0000000000..d87190c595 --- /dev/null +++ b/naga/tests/out/msl/overrides-atomicCompareExchangeWeak.msl @@ -0,0 +1,48 @@ +// language: metal1.0 +#include +#include + +using metal::uint; + +struct _atomic_compare_exchange_resultUint4_ { + uint old_value; + bool exchanged; +}; + +template +_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit( + device A *atomic_ptr, + uint cmp, + uint v +) { + bool swapped = metal::atomic_compare_exchange_weak_explicit( + atomic_ptr, &cmp, v, + metal::memory_order_relaxed, metal::memory_order_relaxed + ); + return _atomic_compare_exchange_resultUint4_{cmp, swapped}; +} +template +_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit( + threadgroup A *atomic_ptr, + uint cmp, + uint v +) { + bool swapped = metal::atomic_compare_exchange_weak_explicit( + atomic_ptr, &cmp, v, + metal::memory_order_relaxed, metal::memory_order_relaxed + ); + return _atomic_compare_exchange_resultUint4_{cmp, swapped}; +} +constant int o = 2; + +kernel void f( + metal::uint3 __local_invocation_id [[thread_position_in_threadgroup]] +, threadgroup metal::atomic_uint& a +) { + if (metal::all(__local_invocation_id == metal::uint3(0u))) { + metal::atomic_store_explicit(&a, 0, metal::memory_order_relaxed); + } + metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); + _atomic_compare_exchange_resultUint4_ _e5 = naga_atomic_compare_exchange_weak_explicit(&a, 2u, 1u); + return; +} diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index adf67f8333..78b2331dca 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -773,7 +773,10 @@ fn convert_wgsl() { "atomicOps", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), - ("atomicCompareExchange", Targets::SPIRV | Targets::WGSL), + ( + "atomicCompareExchange", + Targets::SPIRV | Targets::METAL | Targets::WGSL, + ), ( "padding", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, @@ -917,7 +920,7 @@ fn convert_wgsl() { ), ( "overrides-atomicCompareExchangeWeak", - Targets::IR | Targets::SPIRV, + Targets::IR | Targets::SPIRV | Targets::METAL, ), ( "overrides-ray-query",