From 5be961565482669bb9e5de86c08a678afcbe5a95 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 20 Dec 2024 12:45:56 -0600 Subject: [PATCH 1/2] Layout convolution as NHWC or NCHW only --- src/layout_convolution.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/layout_convolution.cpp b/src/layout_convolution.cpp index 83acb839ce6..7cd629e8a1a 100644 --- a/src/layout_convolution.cpp +++ b/src/layout_convolution.cpp @@ -40,14 +40,18 @@ inline namespace MIGRAPHX_INLINE_NS { namespace { std::vector get_permutation(instruction_ref ins, const layout_convolution& lc) { + std::vector perm(ins->get_shape().ndim()); if(lc.channels_last) { - std::vector perm(ins->get_shape().ndim()); std::iota(perm.begin() + 1, perm.end() - 1, 2); perm.back() = 1; - return perm; } - return find_permutation(ins->inputs().front()->get_shape()); + else + { + std::iota(perm.begin(), perm.end(), 0); + } + return perm; + } bool skip_layout(const shape& s) From b62c7a8c367127fc54d31dedfe4413942533d1a0 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 20 Dec 2024 12:46:39 -0600 Subject: [PATCH 2/2] Format --- src/layout_convolution.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/layout_convolution.cpp b/src/layout_convolution.cpp index 7cd629e8a1a..8f6c85ead3c 100644 --- a/src/layout_convolution.cpp +++ b/src/layout_convolution.cpp @@ -51,7 +51,6 @@ std::vector get_permutation(instruction_ref ins, const layout_convoluti std::iota(perm.begin(), perm.end(), 0); } return perm; - } bool skip_layout(const shape& s)