[naga] Implement atomicCompareExchangeWeak for MSL backend #6265
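
Background for reviewers (not part of the diff): WGSL's `atomicCompareExchangeWeak` has to be lowered to Metal's `atomic_compare_exchange_weak_explicit`, which takes the expected value by pointer, overwrites it with the stored value on failure, and returns only a bool; the helper generated by this PR wraps that into a struct with `old_value` and `exchanged` fields. A minimal hand-written MSL sketch of the underlying Metal call (an illustration with a hypothetical `counter` binding, not code from this PR):

#include <metal_stdlib>
using namespace metal;

// Hand-written CAS-increment loop showing the Metal primitive that the
// generated naga helper wraps.
kernel void cas_increment(device atomic_int *counter [[buffer(0)]]) {
    int expected = atomic_load_explicit(counter, memory_order_relaxed);
    bool exchanged = false;
    while (!exchanged) {
        int desired = expected + 1;
        // On failure, `expected` is overwritten with the value currently
        // stored, so the next iteration retries against fresh data.
        exchanged = atomic_compare_exchange_weak_explicit(
            counter, &expected, desired,
            memory_order_relaxed, memory_order_relaxed);
    }
}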

Open — wants to merge 3 commits into base: trunk
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -85,6 +85,10 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).

- Allow using [VK_GOOGLE_display_timing](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_GOOGLE_display_timing.html) unsafely with the `VULKAN_GOOGLE_DISPLAY_TIMING` feature. By @DJMcNab in [#6149](https://github.com/gfx-rs/wgpu/pull/6149)

#### Metal

- Implement `atomicCompareExchangeWeak`. By @AsherJingkongChen in [#6265](https://github.com/gfx-rs/wgpu/pull/6265)

### Bug Fixes

- Fix incorrect hlsl image output type conversion. By @atlv24 in [#6123](https://github.com/gfx-rs/wgpu/pull/6123)
2 changes: 2 additions & 0 deletions naga/src/back/msl/mod.rs
@@ -136,6 +136,8 @@ pub enum Error {
UnsupportedAttribute(String),
#[error("function '{0}' is not supported for target MSL version")]
UnsupportedFunction(String),
#[error("scalar {0:?} is not supported for target MSL version")]
UnsupportedScalar(crate::Scalar),
#[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")]
UnsupportedWriteableStorageBuffer,
#[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")]
140 changes: 90 additions & 50 deletions naga/src/back/msl/writer.rs
@@ -33,6 +33,7 @@ const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

pub(crate) const ATOMIC_COMP_EXCH_FUNCTION_KEY: &str = "naga_atomic_compare_exchange_weak";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";

@@ -1151,42 +1152,6 @@ impl<W: Write> Writer<W> {
Ok(())
}

fn put_atomic_operation(
&mut self,
pointer: Handle<crate::Expression>,
key: &str,
value: Handle<crate::Expression>,
context: &ExpressionContext,
) -> BackendResult {
// If the pointer we're passing to the atomic operation needs to be conditional
// for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
// the pointer operand should be unchecked.
let policy = context.choose_bounds_check_policy(pointer);
let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
&& self.put_bounds_checks(pointer, context, back::Level(0), "")?;

// If requested and successfully put bounds checks, continue the ternary expression.
if checked {
write!(self.out, " ? ")?;
}

write!(
self.out,
"{NAMESPACE}::atomic_{key}_explicit({ATOMIC_REFERENCE}"
)?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;

// Finish the ternary expression.
if checked {
write!(self.out, " : DefaultConstructible()")?;
}

Ok(())
}

/// Emit code for the arithmetic expression of the dot product.
///
fn put_dot_product(
@@ -3045,24 +3010,61 @@ impl<W: Write> Writer<W> {
value,
result,
} => {
let context = &context.expression;

// This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
// `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
// `Some`, we are not operating on a 64-bit value, and that if we are
// operating on a 64-bit value, `result` is `None`.
write!(self.out, "{level}")?;
let fun_str = if let Some(result) = result {
let fun_key = if let Some(result) = result {
let res_name = Baked(result).to_string();
self.start_baking_expression(result, &context.expression, &res_name)?;
self.start_baking_expression(result, context, &res_name)?;
self.named_expressions.insert(result, res_name);
fun.to_msl()?
} else if context.expression.resolve_type(value).scalar_width() == Some(8) {
fun.to_msl()
} else if context.resolve_type(value).scalar_width() == Some(8) {
fun.to_msl_64_bit()?
} else {
fun.to_msl()?
fun.to_msl()
};

self.put_atomic_operation(pointer, fun_str, value, &context.expression)?;
// done
// If the pointer we're passing to the atomic operation needs to be conditional
// for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
// the pointer operand should be unchecked.
let policy = context.choose_bounds_check_policy(pointer);
let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
&& self.put_bounds_checks(pointer, context, back::Level(0), "")?;

// If requested and successfully put bounds checks, continue the ternary expression.
if checked {
write!(self.out, " ? ")?;
}

write!(
self.out,
"{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
)?;
self.put_access_chain(pointer, policy, context)?;

// Put the extra argument if provided.
if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
write!(self.out, ", ")?;
self.put_expression(cmp, context, true)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ")")?;
} else {
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
}

// Finish the ternary expression.
if checked {
write!(self.out, " : DefaultConstructible()")?;
}

// Done
writeln!(self.out, ";")?;
}
crate::Statement::WorkGroupUniformLoad { pointer, result } => {
@@ -3690,7 +3692,47 @@ impl<W: Write> Writer<W> {
struct_name, struct_name
)?;
}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
let crate::Scalar { kind, width } = scalar;
let arg_type_name = match width {
1 => "bool",
4 => match kind {
crate::ScalarKind::Sint => "int",
crate::ScalarKind::Uint => "uint",
crate::ScalarKind::Float => "float",
_ => return Err(Error::UnsupportedScalar(scalar)),
},
_ => return Err(Error::UnsupportedScalar(scalar)),
};

let called_func_name = "atomic_compare_exchange_weak_explicit";
let defined_func_key = ATOMIC_COMP_EXCH_FUNCTION_KEY;
let struct_name = &self.names[&NameKey::Type(*struct_ty)];

writeln!(self.out)?;
writeln!(self.out, "namespace {NAMESPACE} {{")?;

for address_space_name in ["device", "threadgroup"] {
writeln!(
self.out,
" \
template <typename A>
{struct_name} atomic_{defined_func_key}_explicit(
volatile {address_space_name} A *atomic_ptr,
{arg_type_name} cmp,
{arg_type_name} v
) {{
bool swapped = {NAMESPACE}::{called_func_name}(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return {struct_name}{{cmp, swapped}};
}}"
)?;
}

writeln!(self.out, "}}")?;
}
}
}

@@ -5928,8 +5970,8 @@ fn test_stack_size() {
}

impl crate::AtomicFunction {
fn to_msl(self) -> Result<&'static str, Error> {
Ok(match self {
const fn to_msl(self) -> &'static str {
match self {
Self::Add => "fetch_add",
Self::Subtract => "fetch_sub",
Self::And => "fetch_and",
@@ -5938,10 +5980,8 @@ impl crate::AtomicFunction {
Self::Min => "fetch_min",
Self::Max => "fetch_max",
Self::Exchange { compare: None } => "exchange",
Self::Exchange { compare: Some(_) } => Err(Error::FeatureNotImplemented(
"atomic CompareExchange".to_string(),
))?,
})
Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION_KEY,
}
}

fn to_msl_64_bit(self) -> Result<&'static str, Error> {
164 changes: 164 additions & 0 deletions naga/tests/out/msl/atomicCompareExchange.msl
@@ -0,0 +1,164 @@
// language: metal1.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;

struct type_2 {
metal::atomic_int inner[128];
};
struct type_4 {
metal::atomic_uint inner[128];
};
struct _atomic_compare_exchange_resultSint4_ {
int old_value;
bool exchanged;
};
struct _atomic_compare_exchange_resultUint4_ {
uint old_value;
bool exchanged;
};

namespace metal {
template <typename A>
_atomic_compare_exchange_resultSint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
volatile device A *atomic_ptr,
int cmp,
int v
) {
bool swapped = metal::atomic_compare_exchange_weak_explicit(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return _atomic_compare_exchange_resultSint4_{cmp, swapped};
}
template <typename A>
_atomic_compare_exchange_resultSint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
volatile threadgroup A *atomic_ptr,
int cmp,
int v
) {
bool swapped = metal::atomic_compare_exchange_weak_explicit(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return _atomic_compare_exchange_resultSint4_{cmp, swapped};
}
}

namespace metal {
template <typename A>
_atomic_compare_exchange_resultUint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
volatile device A *atomic_ptr,
uint cmp,
uint v
) {
bool swapped = metal::atomic_compare_exchange_weak_explicit(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return _atomic_compare_exchange_resultUint4_{cmp, swapped};
}
template <typename A>
_atomic_compare_exchange_resultUint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
volatile threadgroup A *atomic_ptr,
uint cmp,
uint v
) {
bool swapped = metal::atomic_compare_exchange_weak_explicit(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return _atomic_compare_exchange_resultUint4_{cmp, swapped};
}
}
constant uint SIZE = 128u;

kernel void test_atomic_compare_exchange_i32_(
device type_2& arr_i32_ [[user(fake0)]]
) {
uint i = 0u;
int old = {};
bool exchanged = {};
bool loop_init = true;
while(true) {
if (!loop_init) {
uint _e27 = i;
i = _e27 + 1u;
}
loop_init = false;
uint _e2 = i;
if (_e2 < SIZE) {
} else {
break;
}
{
uint _e6 = i;
int _e8 = metal::atomic_load_explicit(&arr_i32_.inner[_e6], metal::memory_order_relaxed);
old = _e8;
exchanged = false;
while(true) {
bool _e12 = exchanged;
if (!(_e12)) {
} else {
break;
}
{
int _e14 = old;
int new_ = as_type<int>(as_type<float>(_e14) + 1.0);
uint _e20 = i;
int _e22 = old;
_atomic_compare_exchange_resultSint4_ _e23 = metal::atomic_naga_atomic_compare_exchange_weak_explicit(&arr_i32_.inner[_e20], _e22, new_);
old = _e23.old_value;
exchanged = _e23.exchanged;
}
}
}
}
return;
}


kernel void test_atomic_compare_exchange_u32_(
device type_4& arr_u32_ [[user(fake0)]]
) {
uint i_1 = 0u;
uint old_1 = {};
bool exchanged_1 = {};
bool loop_init_1 = true;
while(true) {
if (!loop_init_1) {
uint _e27 = i_1;
i_1 = _e27 + 1u;
}
loop_init_1 = false;
uint _e2 = i_1;
if (_e2 < SIZE) {
} else {
break;
}
{
uint _e6 = i_1;
uint _e8 = metal::atomic_load_explicit(&arr_u32_.inner[_e6], metal::memory_order_relaxed);
old_1 = _e8;
exchanged_1 = false;
while(true) {
bool _e12 = exchanged_1;
if (!(_e12)) {
} else {
break;
}
{
uint _e14 = old_1;
uint new_1 = as_type<uint>(as_type<float>(_e14) + 1.0);
uint _e20 = i_1;
uint _e22 = old_1;
_atomic_compare_exchange_resultUint4_ _e23 = metal::atomic_naga_atomic_compare_exchange_weak_explicit(&arr_u32_.inner[_e20], _e22, new_1);
old_1 = _e23.old_value;
exchanged_1 = _e23.exchanged;
}
}
}
}
return;
}