diff --git a/src/common/memory_tracking.hpp b/src/common/memory_tracking.hpp index e193980e28b..46867487a5d 100644 --- a/src/common/memory_tracking.hpp +++ b/src/common/memory_tracking.hpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright 2018-2024 Intel Corporation +* Copyright 2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/cpu/aarch64/matmul/acl_matmul.cpp b/src/cpu/aarch64/matmul/acl_matmul.cpp index 91f342d44e3..51080921b64 100644 --- a/src/cpu/aarch64/matmul/acl_matmul.cpp +++ b/src/cpu/aarch64/matmul/acl_matmul.cpp @@ -44,24 +44,32 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const { // Run transpose kernel if (is_transA && !is_transB) { - acl_obj.src_tensor.allocator()->allocate(); + auto transA_scratch = scratchpad.get( + memory_tracking::names::key_matmul_src_trans); + acl_obj.src_tensor.allocator()->import_memory(transA_scratch); acl_obj.src_acc_tensor.allocator()->import_memory( const_cast(src_base)); acl_obj.transA.run(); acl_obj.wei_tensor.allocator()->import_memory( const_cast(wei_base)); } else if (is_transB && !is_transA) { - acl_obj.wei_tensor.allocator()->allocate(); + auto transB_scratch = scratchpad.get( + memory_tracking::names::key_matmul_wei_trans); + acl_obj.wei_tensor.allocator()->import_memory(transB_scratch); acl_obj.wei_acc_tensor.allocator()->import_memory( const_cast(wei_base)); acl_obj.transB.run(); acl_obj.src_tensor.allocator()->import_memory( const_cast(src_base)); } else if (is_transA && is_transB && !do_transC) { - acl_obj.src_tensor.allocator()->allocate(); + auto transA_scratch = scratchpad.get( + memory_tracking::names::key_matmul_src_trans); + auto transB_scratch = scratchpad.get( + memory_tracking::names::key_matmul_wei_trans); + acl_obj.src_tensor.allocator()->import_memory(transA_scratch); acl_obj.src_acc_tensor.allocator()->import_memory( const_cast(src_base)); - acl_obj.wei_tensor.allocator()->allocate(); + acl_obj.wei_tensor.allocator()->import_memory(transB_scratch); acl_obj.wei_acc_tensor.allocator()->import_memory( const_cast(wei_base)); acl_obj.transA.run(); @@ -71,7 +79,11 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const { const_cast(src_base)); acl_obj.wei_tensor.allocator()->import_memory( const_cast(wei_base)); - if (do_transC) { acl_obj.dst_acc_tensor.allocator()->allocate(); } + if (do_transC) { + auto transC_scratch = scratchpad.get( + memory_tracking::names::key_matmul_dst_trans); + acl_obj.dst_acc_tensor.allocator()->import_memory(transC_scratch); + } } // If we have an unfused sum post op, put the result in a scratchpad tensor. @@ -94,6 +106,7 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const { pd()->acl_post_ops.execute(ctx, dst); acl_obj.dst_tensor.allocator()->free(); + if (do_transC) acl_obj.dst_acc_tensor.allocator()->free(); return status; } diff --git a/src/cpu/aarch64/matmul/acl_matmul.hpp b/src/cpu/aarch64/matmul/acl_matmul.hpp index 1427a5bcb85..6385996f490 100644 --- a/src/cpu/aarch64/matmul/acl_matmul.hpp +++ b/src/cpu/aarch64/matmul/acl_matmul.hpp @@ -170,7 +170,8 @@ struct acl_matmul_t : public primitive_t { } auto scratchpad = scratchpad_registry().registrar(); - CHECK(acl_matmul_utils::init_scratchpad(scratchpad, amp_, dst_md_)); + CHECK(acl_matmul_utils::init_scratchpad( + scratchpad, amp_, src_md_, weights_md_, dst_md_)); return status::success; } diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp index 134ce94c905..cf5d2b600d6 100644 --- a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp +++ b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp @@ -174,12 +174,28 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, } status_t init_scratchpad(memory_tracking::registrar_t &scratchpad, - acl_matmul_conf_t &, memory_desc_t &dst_md) { + const acl_matmul_conf_t &, const memory_desc_t &src_md, + const memory_desc_t &weights_md, const memory_desc_t &dst_md) { if (amp.use_dst_acc_for_sum) { const memory_desc_wrapper dst_d(&dst_md); scratchpad.book(memory_tracking::names::key_matmul_dst_in_acc_dt, dst_d.nelems(), dst_d.data_type_size()); } + if (amp.is_transA) { + const memory_desc_wrapper src_d(&src_md); + scratchpad.book(memory_tracking::names::key_matmul_src_trans, + src_d.nelems(), src_d.data_type_size()); + } + if (amp.is_transB) { + const memory_desc_wrapper wei_d(&weights_md); + scratchpad.book(memory_tracking::names::key_matmul_wei_trans, + wei_d.nelems(), wei_d.data_type_size()); + } + if (amp.do_transC) { + const memory_desc_wrapper dst_d(&dst_md); + scratchpad.book(memory_tracking::names::key_matmul_dst_trans, + dst_d.nelems(), dst_d.data_type_size()); + } return status::success; } diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp index d3fa65d915a..2fc9e419ac8 100644 --- a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp +++ b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp @@ -63,7 +63,8 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, const primitive_attr_t &attr); status_t init_scratchpad(memory_tracking::registrar_t &scratchpad, - acl_matmul_conf_t &, memory_desc_t &dst_md); + const acl_matmul_conf_t &, const memory_desc_t &src_md, + const memory_desc_t &weights_md, const memory_desc_t &dst_md); } // namespace acl_matmul_utils