diff --git a/src/gpu/intel/gpu_post_ops.hpp b/src/gpu/intel/gpu_post_ops.hpp index 310c9df4344..f6b3b0848a1 100644 --- a/src/gpu/intel/gpu_post_ops.hpp +++ b/src/gpu/intel/gpu_post_ops.hpp @@ -108,18 +108,23 @@ struct ndim_normalizer_t { int ndims(const memory_desc_t &md) const { return md.ndims + bcast_ndims; } int dim_idx(int md_idx) const { - if (bcast_ndims == 0) return 0; return (md_idx < insert_idx) ? md_idx : md_idx + bcast_ndims; } dim_t dim(int idx, const memory_desc_t &md) const { auto &dims = md.dims; - return (idx < insert_idx) ? dims[idx] : dims[idx - bcast_ndims]; + return (idx < insert_idx) + ? dims[idx] + : (idx < insert_idx + bcast_ndims ? 1 + : dims[idx - bcast_ndims]); } dim_t stride(int idx, const memory_desc_t &md) const { auto &strides = md.format_desc.blocking.strides; - return (idx < insert_idx) ? strides[idx] : strides[idx - bcast_ndims]; + return (idx < insert_idx) + ? strides[idx] + : (idx < insert_idx + bcast_ndims ? 0 + : strides[idx - bcast_ndims]); } // Position to insert broadcast dimensions, dimensions