Skip to content

Commit

Permalink
cpu: aarch64: softmax: fall back to ref if stag != dtag
Browse files Browse the repository at this point in the history
`acl_softmax` would return incorrect answer if stag did not match dtag,
for example
```sh
./benchdnn -v5 --softmax --stag=axb --dtag=abx 2x19x16x64
```
This patch fixes that by falling back to ref if stag != dtag.
  • Loading branch information
jondea authored and mgouicem committed Mar 12, 2024
1 parent e44b78f commit d3886c8
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/cpu/aarch64/acl_softmax.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2022 Arm Ltd. and affiliates
* Copyright 2021-2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -94,12 +94,13 @@ struct acl_softmax_fwd_t : public primitive_t {
status_t init(engine_t *engine) {

bool ok = is_fwd()
// ACL only supports matching src/dst data types
&& src_md()->data_type == dst_md()->data_type
&& set_default_formats() == status::success
// ACL only supports matching src/dst (this must come after
// set_default_formats() to handle format_kind::any)
&& *src_md() == *dst_md()
&& utils::one_of(
src_md()->data_type, data_type::f32, data_type::f16)
&& attr()->has_default_values()
&& set_default_formats() == status::success;
&& attr()->has_default_values();
if (!ok) return status::unimplemented;

// Get memory desc to find sizes and dims
Expand Down

0 comments on commit d3886c8

Please sign in to comment.