From 5567b7e210be12b07164e923abb26a2644921cc0 Mon Sep 17 00:00:00 2001 From: Jonathan Deakin Date: Wed, 7 Feb 2024 16:00:35 +0000 Subject: [PATCH] cpu: aarch64: softmax: fall back to ref if stag != dtag `acl_softmax` would return incorrect answer if stag did not match dtag, for example ```sh ./benchdnn -v5 --softmax --stag=axb --dtag=abx 2x19x16x64 ``` This patch fixes that by falling back to ref if stag != dtag. --- src/cpu/aarch64/acl_softmax.hpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/cpu/aarch64/acl_softmax.hpp b/src/cpu/aarch64/acl_softmax.hpp index 2a4db7c9d4e..020e6ca5ab0 100644 --- a/src/cpu/aarch64/acl_softmax.hpp +++ b/src/cpu/aarch64/acl_softmax.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2022 Arm Ltd. and affiliates +* Copyright 2021-2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -94,12 +94,13 @@ struct acl_softmax_fwd_t : public primitive_t { status_t init(engine_t *engine) { bool ok = is_fwd() - // ACL only supports matching src/dst data types - && src_md()->data_type == dst_md()->data_type + && set_default_formats() == status::success + // ACL only supports matching src/dst (this must come after + // set_default_formats() to handle format_kind::any) + && *src_md() == *dst_md() && utils::one_of( src_md()->data_type, data_type::f32, data_type::f16) - && attr()->has_default_values() - && set_default_formats() == status::success; + && attr()->has_default_values(); if (!ok) return status::unimplemented; // Get memory desc to find sizes and dims