diff --git a/src/cpu/aarch64/acl_softmax.hpp b/src/cpu/aarch64/acl_softmax.hpp index 2a4db7c9d4e..020e6ca5ab0 100644 --- a/src/cpu/aarch64/acl_softmax.hpp +++ b/src/cpu/aarch64/acl_softmax.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2022 Arm Ltd. and affiliates +* Copyright 2021-2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -94,12 +94,13 @@ struct acl_softmax_fwd_t : public primitive_t { status_t init(engine_t *engine) { bool ok = is_fwd() - // ACL only supports matching src/dst data types - && src_md()->data_type == dst_md()->data_type + && set_default_formats() == status::success + // ACL only supports matching src/dst (this must come after + // set_default_formats() to handle format_kind::any) + && *src_md() == *dst_md() && utils::one_of( src_md()->data_type, data_type::f32, data_type::f16) - && attr()->has_default_values() - && set_default_formats() == status::success; + && attr()->has_default_values(); if (!ok) return status::unimplemented; // Get memory desc to find sizes and dims