From c32590eca27ef932e7f3e27877773c77d9ca9ca0 Mon Sep 17 00:00:00 2001 From: Roy Oursler Date: Mon, 28 Oct 2024 15:12:59 -0700 Subject: [PATCH 01/19] xe: jit: gemm: selector: db: add thin m kernel --- src/gpu/intel/jit/gemm/selector/db/kernel.db | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gpu/intel/jit/gemm/selector/db/kernel.db b/src/gpu/intel/jit/gemm/selector/db/kernel.db index 3e576846f7a..5b2b9a849db 100644 --- a/src/gpu/intel/jit/gemm/selector/db/kernel.db +++ b/src/gpu/intel/jit/gemm/selector/db/kernel.db @@ -104,6 +104,7 @@ auto _CATALOG_ = kcatalog::toFlatCatalog({ {{'C', "gemm", {"H", "H", "H"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, 8, -1}, {1, 1, 1}, ""}, "ab2x2 as16 ab l4 int", {8, (LoopType) 0, 128, {(LoopType) 0, (LoopType) 1, (LoopType) 255}, {4096, 4096, 2048}, {4096, 4096, 2048}, {32, 8, 16}, {2, 8, 1}, 1, (WGType) 0, 1, 0, 0, {2, 2, 2}, {true, true, true}}, {'W', 1, {256}}}, {{'C', "gemm", {"H", "H", "H"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "qxy"}, "sb2/1 su8x2 ab l4 ca1 wg 2x8 int", {8, (LoopType) 0, 128, {(LoopType) 0, (LoopType) 1, (LoopType) 255}, {4096, 4096, 2048}, {4096, 4096, 2048}, {64, 16, 16}, {2, 8, 1}, 1, (WGType) 1, 1, 4096, 0, {2, 2, 2}, {false, false, true}}, {'W', 1, {1024}}}, {{'C', "gemm", {"H", "H", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "ab4 as8 ab l4 ca1 wg 2x8 int", {8, (LoopType) 0, 128, {(LoopType) 0, (LoopType) 1, (LoopType) 255}, {4096, 4096, 2048}, {4096, 4096, 2048}, {32, 16, 8}, {2, 8, 1}, 1, (WGType) 1, 1, 2048, 0, {2, 2, 4}, {true, true, true}}, {'W', 1, {512}}}, +{{'C', "gemm", {"H", "H", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, 16, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "ab4x2 ab8 ab wg 2x1x8 ikr kc4 acb ar sb32 bk0 np", {8, (LoopType) 0, 128, {(LoopType) 0, (LoopType) 1, (LoopType) 2}, {4096, 4096, 16777216}, {4096, 4096, 16777216}, {32, 1, 8}, {2, 1, 8}, 1, (WGType) 0, 4101, 0, 256, {2, 2, 4}, {true, true, true}}, {'W', 1, {32}}}, {{'C', "gemm", {"H", "H", "H"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "ab2x2 ab2x2 ab k4 l4 vnc", {8, (LoopType) 0, 128, {(LoopType) 0, (LoopType) 1, (LoopType) 255}, {4096, 4096, 2048}, {4096, 4096, 2048}, {32, 32, 4}, {2, 8, 1}, 1, (WGType) 0, 1, 0, 0, {2, 2, 2}, {true, true, true}}, {'W', 1, {1024}}}, {{'C', "gemm", {"H", "H", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "ab2 ab8 ab l4 cab1 wg 4x4 int", {8, (LoopType) 0, 128, {(LoopType) 0, (LoopType) 1, (LoopType) 255}, {4096, 4096, 2048}, {4096, 4096, 2048}, {32, 16, 8}, {4, 4, 1}, 1, (WGType) 1, 1, 6144, 0, {2, 2, 4}, {true, true, true}}, {'W', 1, {512}}}, {{'C', "gemm", {"H", "H", "H"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "as4 as8 ab k8 l4 vnc", {8, (LoopType) 0, 128, {(LoopType) 0, (LoopType) 1, (LoopType) 255}, {4096, 4096, 1024}, {4096, 4096, 1024}, {32, 32, 8}, {2, 8, 1}, 1, (WGType) 0, 1, 0, 0, {2, 2, 2}, {true, true, true}}, {'W', 1, {1024}}}, From 11772a6828a72db70c976ac26848f841e33ba681 Mon Sep 17 00:00:00 2001 From: rupakroyintel Date: Tue, 29 Oct 2024 09:17:27 -0700 Subject: [PATCH 02/19] add example for int4 weight decompression --- .../int4_weight_decompression_cmnts.cpp | 267 ++++++++++++++++++ 1 file changed, 267 insertions(+) create mode 100644 
examples/tutorials/matmul/int4_weight_decompression_cmnts.cpp

diff --git a/examples/tutorials/matmul/int4_weight_decompression_cmnts.cpp b/examples/tutorials/matmul/int4_weight_decompression_cmnts.cpp
new file mode 100644
index 00000000000..6eeb7be57f0
--- /dev/null
+++ b/examples/tutorials/matmul/int4_weight_decompression_cmnts.cpp
@@ -0,0 +1,267 @@
+/*******************************************************************************
+ * Copyright 2023-2024 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+/// @example int4_weights_decompression.cpp
+/// > Annotated version: @ref int4_weights_decompression.cpp
+///
+/// @page int4_weights_decompression
+/// C++ API example demonstrating how one can use
+/// [MatMul](@ref dev_guide_matmul) with int4 compressed weights.
+///
+/// Concepts:
+/// - AWQ (activation-aware quantization)
+/// - Scales: dnnl::primitive_attr::set_scales()
+/// - Zero points: dnnl::primitive_attr::set_zero_points()
+/// - [Operation fusion](@ref dev_guide_attributes_post_ops)
+/// - Create primitive once, use multiple times
+/// - Weights pre-packing: use #dnnl::memory::format_tag::any
+///
+/// @page int4_weights_decompression_matmul_cpp MatMul Tutorial: weights
+/// decompression
+/// @copydetails int4_weights_decompression_matmul_cpp
+///
+/// Assumptions:
+/// 1. The shape of the weights (matrix \f$B(K, N)\f$) is known in advance, the
+/// data type is `int4` and shifted from 0 (i.e. the zero point is not 0).
+/// 2. The source matrix \f$A\f$ and destination matrix \f$C\f$ have floating
+/// point data type.
+/// 3. Scaling (re-quantization) factor specified at run-time only.
+///
+/// Since the shape of weights is known in advance, the MatMul weights can be
+/// created with format tag #dnnl::memory::format_tag::any to enable the library
+/// to choose the most appropriate layout for best performance.
+///
+/// @warning
+/// The format tag #dnnl::memory::format_tag::any doesn't work for memory
+/// descriptors that have one or more unknown dimensions and/or strides.
+///
+/// @include weights_decompression_matmul.cpp
+
+#include <cassert>
+#include <cctype>
+#include <cmath>
+#include <cstdio>
+#include <iostream>
+#include <random>
+#include <stdexcept>
+#include <vector>
+
+#include "oneapi/dnnl/dnnl.hpp"
+
+#include "example_utils.hpp"
+
+using namespace dnnl;
+
+namespace {
+
+void init_vector(std::vector<float> &v) {
+    std::mt19937 gen;
+    std::uniform_real_distribution<float> u(0, 1);
+    for (auto &e : v)
+        e = u(gen);
+}
+// Comparing two vectors by calculating their L2 norms and the L2 norm of their
+// difference Checking if the difference is within a calculated threshold The
+// function returns 0 if the vectors are considered similar, otherwise it
+// returns 1. --Rupak
+int compare_vectors(const std::vector<float> &v1, const std::vector<float> &v2,
+        int64_t K, const char *message) {
+    double v1_l2 = 0, diff_l2 = 0;
+    for (size_t n = 0; n < v1.size(); ++n) {
+        float diff = v1[n] - v2[n];
+        v1_l2 += v1[n] * v1[n];
+        diff_l2 += diff * diff;
+    }
+
+    v1_l2 = std::sqrt(v1_l2);
+    diff_l2 = std::sqrt(diff_l2);
+
+    // Finding the reasonable (tight and accurate) threshold is quite difficult
+    // problem.
+    // The implementation testing might also use special data filling to
+    // alleviate issues related to the finite precision arithmetic.
+    // However, in simple cases the machine epsilon multiplied by log(K) should
+    // work reasonably well.
+    const double threshold = std::numeric_limits<float>::epsilon()
+            * std::log(std::max(2., (double)K));
+    bool ok = diff_l2 <= threshold * v1_l2;
+
+    printf("%s\n\tL2 Norms"
+           "\n\t\tReference matrix:%g\n\t\tError:%g\n\t\tRelative_error:%g\n"
+           "\tAccuracy check: %s\n",
+            message, v1_l2, diff_l2, diff_l2 / v1_l2, ok ? "OK" : "FAILED");
+
+    return ok ? 0 : 1;
+}
+
+} // namespace
+
+// Floating point MatMul
+// Inputs:
+// - Shape: M, N, K
+// - Matrices A and B
+// Outputs:
+// - Matrix C
+void ref_compute_matmul_f32(int64_t M, int64_t N, int64_t K, int64_t G,
+        std::vector<float> &A_f32, std::vector<float> &B_f32,
+        std::vector<float> &zp_B_f32, std::vector<float> &sc_B,
+        std::vector<float> &C_f32) {
+    // Perform the GEMM operation
+    for (int m = 0; m < M; ++m) {
+        for (int n = 0; n < N; ++n) {
+            for (int k = 0; k < K; ++k) {
+                // Decompress the weight
+                int64_t idx1 = k * N + n;
+                int64_t idx2 = (k / G) * N + n;
+                float decompressed_B
+                        = (B_f32[idx1] - zp_B_f32[idx1]) * sc_B[idx2];
+                // Perform the multiplication and accumulation
+                C_f32[m * N + n] += A_f32[m * K + k] * decompressed_B;
+            }
+        }
+    }
+}
+
+// Create a MatMul primitive descriptor for the following op:
+// C_f32 = A_f32 * (B_s4 - zp_B) * sc_B[:] --Rupak
+matmul::primitive_desc matmul_pd_create(
+        int64_t M, int64_t N, int64_t K, int64_t G, const engine &eng) {
+
+    memory::desc a_md({M, K}, memory::data_type::f32, {K, 1}); // M x K layout
+    memory::desc b_md({K, N}, memory::data_type::s4,
+            memory::format_tag::any); // K x N layout
+    memory::desc c_md({M, N}, memory::data_type::f32, {N, 1}); // M x N layout
+
+    // Create attributes and indicate that the alpha and zero points are
+    // runtime parameters
+    primitive_attr attr;
+    // Set scales with multiple scales along K and N dimensions and with groups
+    // along K.
+    attr.set_scales(DNNL_ARG_WEIGHTS,
+            /* mask */ (1 << 0) + (1 << 1), {G, 1}, memory::data_type::f32);
+
+    // Set zero points with s4 data type.
+    // The mask determines which dimensions the zero points are applied to.
+    // Current mask value (1 << 0) + (1 << 1) means zero points are applied
+    // both along K and N dimensions.
+    // Changing the mask value would alter the dimensions along which the zero
+    // points are applied. For example:
+    // - mask = (1 << 0) would apply zero points only along the K dimension.
+    // - mask = (1 << 1) would apply zero points only along the N dimension.
+    int mask = (1 << 0)
+            + (1 << 1); // zero points both along K and N dimensions --Rupak
+    memory::dims groups = {};
+    attr.set_zero_points(DNNL_ARG_WEIGHTS, mask, groups, memory::data_type::s4);
+
+    // Set fpmath mode with `apply_to_int=true` to apply fpmath mode behavior to
+    // integral primitives (in this example, matmul).
+    attr.set_fpmath_mode(fpmath_mode::f16, true);
+
+    // Create a MatMul primitive descriptor
+    return matmul::primitive_desc(eng, a_md, b_md, c_md, attr);
+}
+
+// Function to perform matrix multiplication with int4 weights decompression
+// using oneDNN --Rupka
+void weights_decompression_matmul(int64_t M, int64_t N, int64_t K, int64_t G,
+        std::vector<float> &A_f32, std::vector<float> &B_f32,
+        std::vector<float> &zp_B_f32, std::vector<float> &sc_B,
+        std::vector<float> &C_f32, const engine &eng) {
+    auto matmul_pd = matmul_pd_create(M, N, K, G, eng);
+    stream s(eng);
+
+    // Pre-packed weights stored as int4
+    memory B_s4_mem(matmul_pd.weights_desc(), eng);
+    {
+        memory B_f32_mem(
+                {{K, N}, memory::data_type::f32, memory::format_tag::ab}, eng);
+        write_to_dnnl_memory(B_f32.data(), B_f32_mem);
+        reorder(B_f32_mem, B_s4_mem).execute(s, B_f32_mem, B_s4_mem);
+        s.wait();
+    }
+    matmul matmul_p(matmul_pd);
+
+    // input of the current layer / operation
+    memory A_f32_mem({{M, K}, memory::data_type::f32, {K, 1}}, eng);
+    // De-quantization parameters (eg. Scale and Shift)
+    const int64_t n_groups = K / G;
+    memory sc_B_mem({{N, n_groups}, memory::data_type::f32, {1, N}}, eng);
+
+    // Pre-packed zp stored as int4
+    // A unique zero point is used for each weight in this example
+    // Allocates memory for zp_B_s4_mem with specified dimensions and data type.
+    // --Rupak
+    memory zp_B_s4_mem({{K, N}, memory::data_type::s4, {1, K}}, eng);
+    {
+        memory zp_B_f32_mem({{K, N}, memory::data_type::f32, {1, K}}, eng);
+        write_to_dnnl_memory(zp_B_f32.data(), zp_B_f32_mem);
+        reorder(zp_B_f32_mem, zp_B_s4_mem)
+                .execute(s, zp_B_f32_mem, zp_B_s4_mem);
+        s.wait();
+    }
+
+    write_to_dnnl_memory(A_f32.data(), A_f32_mem);
+    write_to_dnnl_memory(sc_B.data(), sc_B_mem);
+
+    // output - no initialization required
+    memory C_f32_mem({{M, N}, memory::data_type::f32, {N, 1}}, eng);
+
+    matmul_p.execute(s,
+            {{DNNL_ARG_SRC, A_f32_mem}, {DNNL_ARG_WEIGHTS, B_s4_mem},
+                    {DNNL_ARG_DST, C_f32_mem},
+                    {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, sc_B_mem},
+                    {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS,
+                            zp_B_s4_mem}});
+    s.wait();
+}
+
+// Compares the results of reference matrix multiplication and oneDNN weights
+// decompression. --Rupak
+void compare_ref_and_weights_decompression(engine::kind engine_kind) {
+    engine eng(engine_kind, 0);
+
+    // MatMul parameters
+    const int64_t M = 1, N = 4096, K = 1024;
+    // Quantization Group size for scales
+    const int64_t G = 64;
+
+    // Prepare matrices
+    std::vector<float> A_f32(M * K), C_ref(M * N), sc_B(K * N / G);
+    std::vector<float> B_f32(K * N);
+    std::vector<float> zp_B_f32(K * N);
+    init_vector(A_f32);
+    init_vector(B_f32);
+    init_vector(sc_B);
+    init_vector(zp_B_f32);
+    init_vector(C_ref);
+    std::vector<float> C_onednn = C_ref;
+
+    // Compute _true_ C_ref result
+    ref_compute_matmul_f32(M, N, K, G, A_f32, B_f32, zp_B_f32, sc_B, C_ref);
+
+    // Compute _true_ C_onednn result
+    weights_decompression_matmul(
+            M, N, K, G, A_f32, B_f32, zp_B_f32, sc_B, C_onednn, eng);
+
+    int rc = 0;
+    rc |= compare_vectors(
+            C_ref, C_onednn, K, "Compare ref vs oneDNN weights decompression");
+    if (rc) throw std::logic_error("The resulting matrices diverged too much.");
+}
+
+int main(int argc, char **argv) {
+    engine::kind engine_kind = parse_engine_kind(argc, argv);
+    return handle_example_errors(
+            compare_ref_and_weights_decompression, engine_kind);
+}
\ No newline at end of file

From 8f269b22e7eee1998cfe5391da32deccf979da4d Mon Sep 17 00:00:00 2001
From: Daniel Youssif
Date: Fri, 18 Oct 2024 10:04:38 -0700
Subject: [PATCH 03/19] ocl: matmul: ref: allow f32 dst for f16 src

---
 src/gpu/intel/ocl/ref_matmul.hpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/gpu/intel/ocl/ref_matmul.hpp b/src/gpu/intel/ocl/ref_matmul.hpp
index 91c71a3af25..428f0289c46 100644
--- a/src/gpu/intel/ocl/ref_matmul.hpp
+++ b/src/gpu/intel/ocl/ref_matmul.hpp
@@ -79,7 +79,7 @@ struct ref_matmul_t : public gpu_primitive_t {
                     && dst_dt_ == f32;
             const bool is_f16 = src_dt_ == f16
                     && utils::one_of(wei_dt_, f16, s8, u8, s4, u4)
-                    && utils::one_of(dst_dt_, u8, s8, f16);
+                    && utils::one_of(dst_dt_, u8, s8, f16, f32);
             const bool is_f8
                     = (utils::one_of(src_dt_, f8_e5m2, f8_e4m3)
                               || utils::one_of(wei_dt_, f8_e5m2, f8_e4m3))

From 3109fd07110efe1d7d691365b02d5030abab2b56 Mon Sep 17 00:00:00 2001
From: Daniel Youssif
Date: Fri, 25 Oct 2024 13:08:11 -0700
Subject: [PATCH 04/19] xe: jit: gemm: fixup s4 scales

---
 src/gpu/intel/jit/gemm/generator/pieces/quantization.cxx | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/gpu/intel/jit/gemm/generator/pieces/quantization.cxx b/src/gpu/intel/jit/gemm/generator/pieces/quantization.cxx
index b080d43fd62..be4b8fc21ae 100644
--- a/src/gpu/intel/jit/gemm/generator/pieces/quantization.cxx
+++ b/src/gpu/intel/jit/gemm/generator/pieces/quantization.cxx
@@ -82,7 +82,7 @@ bool BLASKernelGenerator<hw>::gemmMake2DQuantizationLayouts(bool isA, const GEMM
     bool int4SpecialPath = Tx_ext.isInt4() && one_of(Tx, Type::f16, Type::bf16, Type::f32);
     if (int4SpecialPath)
-        Txo_int = Txs_int = Type::f16;
+        Txo_int = Txs_int = Tx_scaleOp = Type::f16;
 
     // Get tile sizes, depending on whether A/B are copied to SLM.
     // For late scaling (after compute), scales are always applied to the whole tile.
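The quantization attributes in the example above encode the following indexing convention: scales use mask (1 << 0) + (1 << 1) with groups {G, 1}, so one scale is shared by G consecutive weights along K for each column n, while the s4 zero points use the same mask with empty groups and are therefore per-element. A minimal standalone sketch of that index arithmetic (hypothetical helper names; it simply mirrors idx1/idx2 in ref_compute_matmul_f32 above):

#include <cstdint>

// For K x N weights with scale groups {G, 1} along {K, N}: the scales
// tensor holds (K / G) x N elements and the zero points hold K x N.
inline int64_t zp_index(int64_t k, int64_t n, int64_t N) {
    return k * N + n; // per-element zero point (idx1 in the example)
}
inline int64_t scale_index(int64_t k, int64_t n, int64_t N, int64_t G) {
    return (k / G) * N + n; // one scale per group of G k-values (idx2)
}
// Dequantizing one weight then reads:
//   w_f32 = (w_s4 - zp[zp_index(k, n, N)]) * sc[scale_index(k, n, N, G)];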
From dee6e7fef95835fea17cc7c5ec326749d0458ba7 Mon Sep 17 00:00:00 2001
From: Peter Caday
Date: Thu, 24 Oct 2024 13:16:50 -0700
Subject: [PATCH 05/19] xe: jit: gemm: legalize hf8 mov regioning

---
 .../jit/gemm/generator/pieces/copy_plan.cpp | 20 +++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/src/gpu/intel/jit/gemm/generator/pieces/copy_plan.cpp b/src/gpu/intel/jit/gemm/generator/pieces/copy_plan.cpp
index 180c8931751..e7bafef29e1 100644
--- a/src/gpu/intel/jit/gemm/generator/pieces/copy_plan.cpp
+++ b/src/gpu/intel/jit/gemm/generator/pieces/copy_plan.cpp
@@ -733,10 +733,12 @@ void CopyPlan::planTypeConversions()
             } else
                 planEmulatedHalveFloat(i);
         }
-    } else if (st == DataType::hf8 && dt == DataType::hf && hw < HW::Xe3) {
-        planEmulatedHF8ToHF(i);
-    } else if (st == DataType::hf && dt == DataType::hf8 && hw < HW::Xe3) {
-        planEmulatedHFToHF8(i);
+    } else if (st == DataType::hf8 && dt == DataType::hf) {
+        if (hw < HW::Xe3)
+            planEmulatedHF8ToHF(i);
+    } else if (st == DataType::hf && dt == DataType::hf8) {
+        if (hw < HW::Xe3)
+            planEmulatedHFToHF8(i);
     } else if (st != dt && (isFP8(st) || isFP8(dt))) {
         copyThrough(i, DataType::hf, 1);
         rerun = true;
@@ -1328,11 +1330,13 @@ void CopyPlan::legalizeRegions()
         if (!i.dst) continue;
 
         /* Check for special packed conversion cases */
-        if (i.op == Opcode::mov && s0t == DataType::hf && dt == DataType::bf8) {
-            // hf -> bf8: src0/dst must be packed unit stride, zero offset
-            if (i.src0.offset != 0 || i.src0.stride != 1)
+        if (i.op == Opcode::mov && ((s0t == DataType::hf && isFP8(dt))
+                    || (dt == DataType::hf && isFP8(s0t)))) {
+            // hf <-> bf8/hf8: src0/dst must be packed unit stride, zero offset
+            if (i.src0.offset != 0 || i.src0.stride != 1) {
                 repositionSrc(i, 0, 1, 0);
-            if (i.dst.offset != 0 || i.dst.stride != 1)
+                rerun = true;
+            } else if (i.dst.offset != 0 || i.dst.stride != 1)
                 repositionDst(i, 1, 0);
             if (i.simd == 1) hw_unsupported();
             continue;

From ba1740cbc6acc08c9d7d5c8cbf45c9c56adec7db Mon Sep 17 00:00:00 2001
From: Daniel Youssif
Date: Mon, 21 Oct 2024 10:20:38 -0700
Subject: [PATCH 06/19] build: add Xe3 ISA flag

---
 cmake/configuring_primitive_list.cmake | 2 +-
 cmake/options.cmake                    | 2 +-
 doc/build/build_options.md             | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/cmake/configuring_primitive_list.cmake b/cmake/configuring_primitive_list.cmake
index 3524f171070..dae7460336e 100644
--- a/cmake/configuring_primitive_list.cmake
+++ b/cmake/configuring_primitive_list.cmake
@@ -58,7 +58,7 @@ if (DNNL_ENABLE_PRIMITIVE_GPU_ISA STREQUAL "ALL")
 else()
     foreach(isa ${DNNL_ENABLE_PRIMITIVE_GPU_ISA})
         string(TOUPPER ${isa} uisa)
-        if(NOT "${uisa}" MATCHES "^(GEN9|GEN11|XELP|XEHP|XEHPG|XEHPC|XE2)$")
+        if(NOT "${uisa}" MATCHES "^(GEN9|GEN11|XELP|XEHP|XEHPG|XEHPC|XE2|XE3)$")
             message(FATAL_ERROR "Unsupported primitive GPU ISA: ${uisa}")
         endif()
         set(BUILD_${uisa} TRUE)

diff --git a/cmake/options.cmake b/cmake/options.cmake
index 0bb963ae24b..635cf5e8ae3 100644
--- a/cmake/options.cmake
+++ b/cmake/options.cmake
@@ -147,7 +147,7 @@ set(DNNL_ENABLE_PRIMITIVE_GPU_ISA "ALL" CACHE STRING
     implementations will always be available. Valid values:
       - ALL (the default). Includes all ISA to be enabled.
      - <ISA_NAME>;<ISA_NAME>;... Includes only selected ISA to be enabled.
- Possible values are: GEN9, GEN11, XELP, XEHP, XEHPG, XEHPC, XE2.") + Possible values are: GEN9, GEN11, XELP, XEHP, XEHPG, XEHPC, XE2, XE3.") set(ONEDNN_ENABLE_GEMM_KERNELS_ISA "ALL" CACHE STRING "Specifies an ISA set of GeMM kernels residing in x64/gemm folder to be diff --git a/doc/build/build_options.md b/doc/build/build_options.md index 2bcdede9ce2..a6c9f48eb9e 100644 --- a/doc/build/build_options.md +++ b/doc/build/build_options.md @@ -118,7 +118,7 @@ Example that enables SSE41 and AVX2 sets: #### ONEDNN_ENABLE_PRIMITIVE_GPU_ISA This option supports several values: `ALL` (the default) which enables all ISA implementations or any set of `GEN9`, `GEN11`, `XELP`, `XEHP`, `XEHPG`, -`XEHPC`, and `XE2`. Selected ISA will enable correspondent parts in +`XEHPC`, `XE2`, and `XE3`. Selected ISA will enable correspondent parts in just-in-time kernel generation based implementations. OpenCL based kernels and implementations will always be available. Example that enables XeLP and XeHP set: From ffa263c51d781f217c4fc8e03ecdf76f34d14a9f Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Sat, 26 Oct 2024 18:02:15 -0700 Subject: [PATCH 07/19] xe: jit: ngen: gcc13 workaround for indirect jmpi --- src/gpu/intel/jit/ngen/ngen_gen12.hpp | 56 ++++++++++++++------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/src/gpu/intel/jit/ngen/ngen_gen12.hpp b/src/gpu/intel/jit/ngen/ngen_gen12.hpp index 9b6b374debd..f06855ad7cd 100644 --- a/src/gpu/intel/jit/ngen/ngen_gen12.hpp +++ b/src/gpu/intel/jit/ngen/ngen_gen12.hpp @@ -635,36 +635,40 @@ static inline constexpr14 TernaryOperand12 encodeTernaryOperand12(const Extended static inline void encodeCommon12(Instruction12 &i, Opcode opcode, const InstructionModifier &mod, const RegData &dst, EncodingTag12 tag) { - i.common.opcode = static_cast(opcode) | (mod.parts.autoSWSB << 7); - i.common.swsb = SWSBInfo12(mod.getSWSB(), opcode).raw(); - i.common.execSize = mod.parts.eSizeField; - i.common.execOffset = mod.parts.chanOff; - i.common.flagReg = (mod.parts.flagRegNum << 1) | mod.parts.flagSubRegNum; - i.common.predCtrl = mod.parts.predCtrl; - i.common.predInv = mod.parts.predInv; - i.common.cmptCtrl = mod.parts.cmptCtrl; - i.common.debugCtrl = mod.parts.debugCtrl; - i.common.maskCtrl = mod.parts.maskCtrl; - i.common.atomicCtrl = mod.parts.threadCtrl; - i.common.accWrCtrl = mod.parts.accWrCtrl; - i.common.saturate = mod.parts.saturate; + Instruction12 i2; /* separate variable to avoid gcc13 bug */ + i2.common.opcode = static_cast(opcode) | (mod.parts.autoSWSB << 7); + i2.common.swsb = SWSBInfo12(mod.getSWSB(), opcode).raw(); + i2.common.execSize = mod.parts.eSizeField; + i2.common.execOffset = mod.parts.chanOff; + i2.common.flagReg = (mod.parts.flagRegNum << 1) | mod.parts.flagSubRegNum; + i2.common.predCtrl = mod.parts.predCtrl; + i2.common.predInv = mod.parts.predInv; + i2.common.cmptCtrl = mod.parts.cmptCtrl; + i2.common.debugCtrl = mod.parts.debugCtrl; + i2.common.maskCtrl = mod.parts.maskCtrl; + i2.common.atomicCtrl = mod.parts.threadCtrl; + i2.common.accWrCtrl = mod.parts.accWrCtrl; + i2.common.saturate = mod.parts.saturate; + i.common = i2.common; } static inline void encodeCommon12(Instruction12 &i, Opcode opcode, const InstructionModifier &mod, const RegData &dst, EncodingTagXeHPC tag) { - i.common.opcode = static_cast(opcode) | (mod.parts.autoSWSB << 7); - i.commonXeHPC.swsb = SWSBInfoXeHPC(mod.getSWSB(), opcode).raw(); - i.commonXeHPC.execSize = mod.parts.eSizeField; - i.commonXeHPC.flagReg = (mod.parts.flagRegNum1 << 2) | 
(mod.parts.flagRegNum << 1) | mod.parts.flagSubRegNum; - i.commonXeHPC.execOffset = mod.parts.chanOff >> 1; - i.commonXeHPC.predCtrl = mod.parts.predCtrl; - i.common.predInv = mod.parts.predInv; - i.common.cmptCtrl = mod.parts.cmptCtrl; - i.common.debugCtrl = mod.parts.debugCtrl; - i.common.maskCtrl = mod.parts.maskCtrl; - i.common.atomicCtrl = mod.parts.threadCtrl; - i.commonXeHPC.dstExt = (dst.isIndirect() ? dst.getOffset() : dst.getByteOffset()) & 1; - i.common.saturate = mod.parts.saturate; + Instruction12 i2; /* separate variable to avoid gcc13 bug */ + i2.common.opcode = static_cast(opcode) | (mod.parts.autoSWSB << 7); + i2.commonXeHPC.swsb = SWSBInfoXeHPC(mod.getSWSB(), opcode).raw(); + i2.commonXeHPC.execSize = mod.parts.eSizeField; + i2.commonXeHPC.flagReg = (mod.parts.flagRegNum1 << 2) | (mod.parts.flagRegNum << 1) | mod.parts.flagSubRegNum; + i2.commonXeHPC.execOffset = mod.parts.chanOff >> 1; + i2.commonXeHPC.predCtrl = mod.parts.predCtrl; + i2.common.predInv = mod.parts.predInv; + i2.common.cmptCtrl = mod.parts.cmptCtrl; + i2.common.debugCtrl = mod.parts.debugCtrl; + i2.common.maskCtrl = mod.parts.maskCtrl; + i2.common.atomicCtrl = mod.parts.threadCtrl; + i2.commonXeHPC.dstExt = (dst.isIndirect() ? dst.getOffset() : dst.getByteOffset()) & 1; + i2.common.saturate = mod.parts.saturate; + i.common = i2.common; } template From cd0d5ff081db43cdf302467847d74235c248b2ad Mon Sep 17 00:00:00 2001 From: rupakroyintel Date: Tue, 29 Oct 2024 11:10:26 -0700 Subject: [PATCH 08/19] minor changes --- .../matmul/int4_weight_decompression_cmnts.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/examples/tutorials/matmul/int4_weight_decompression_cmnts.cpp b/examples/tutorials/matmul/int4_weight_decompression_cmnts.cpp index 6eeb7be57f0..fd20ff13f92 100644 --- a/examples/tutorials/matmul/int4_weight_decompression_cmnts.cpp +++ b/examples/tutorials/matmul/int4_weight_decompression_cmnts.cpp @@ -74,7 +74,7 @@ void init_vector(std::vector &v) { // Comparing two vectors by calculating their L2 norms and the L2 norm of their // difference Checking if the difference is within a calculated threshold The // function returns 0 if the vectors are considered similar, otherwise it -// returns 1. --Rupak +// returns 1. int compare_vectors(const std::vector &v1, const std::vector &v2, int64_t K, const char *message) { double v1_l2 = 0, diff_l2 = 0; @@ -134,7 +134,7 @@ void ref_compute_matmul_f32(int64_t M, int64_t N, int64_t K, int64_t G, } // Create a MatMul primitive descriptor for the following op: -// C_f32 = A_f32 * (B_s4 - zp_B) * sc_B[:] --Rupak +// C_f32 = A_f32 * (B_s4 - zp_B) * sc_B[:] matmul::primitive_desc matmul_pd_create( int64_t M, int64_t N, int64_t K, int64_t G, const engine &eng) { @@ -159,8 +159,7 @@ matmul::primitive_desc matmul_pd_create( // points are applied. For example: // - mask = (1 << 0) would apply zero points only along the K dimension. // - mask = (1 << 1) would apply zero points only along the N dimension. - int mask = (1 << 0) - + (1 << 1); // zero points both along K and N dimensions --Rupak + int mask = (1 << 0) + (1 << 1); // zero points both along K and N dimensions memory::dims groups = {}; attr.set_zero_points(DNNL_ARG_WEIGHTS, mask, groups, memory::data_type::s4); @@ -201,7 +200,6 @@ void weights_decompression_matmul(int64_t M, int64_t N, int64_t K, int64_t G, // Pre-packed zp stored as int4 // A unique zero point is used for each weight in this example // Allocates memory for zp_B_s4_mem with specified dimensions and data type. 
- // --Rupak memory zp_B_s4_mem({{K, N}, memory::data_type::s4, {1, K}}, eng); { memory zp_B_f32_mem({{K, N}, memory::data_type::f32, {1, K}}, eng); @@ -227,7 +225,7 @@ void weights_decompression_matmul(int64_t M, int64_t N, int64_t K, int64_t G, } // Compares the results of reference matrix multiplication and oneDNN weights -// decompression. --Rupak +// decompression. void compare_ref_and_weights_decompression(engine::kind engine_kind) { engine eng(engine_kind, 0); @@ -264,4 +262,4 @@ int main(int argc, char **argv) { engine::kind engine_kind = parse_engine_kind(argc, argv); return handle_example_errors( compare_ref_and_weights_decompression, engine_kind); -} \ No newline at end of file +} From 3b22a818770d704d5a73ee0e40acf61d8d2394bc Mon Sep 17 00:00:00 2001 From: Simon Ewing Date: Tue, 29 Oct 2024 10:56:33 -0700 Subject: [PATCH 09/19] xe: gemm: skip locking flag register on no-load blocks --- src/gpu/intel/jit/gemm/generator/pieces/allocators.cpp | 7 +++++++ src/gpu/intel/jit/gemm/generator/pieces/allocators.hpp | 2 +- src/gpu/intel/jit/gemm/generator/pieces/copy.cxx | 2 +- src/gpu/intel/jit/gemm/generator/pieces/matrix_access.cxx | 4 +++- 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/gpu/intel/jit/gemm/generator/pieces/allocators.cpp b/src/gpu/intel/jit/gemm/generator/pieces/allocators.cpp index 1fd0819040c..58ba6e54831 100644 --- a/src/gpu/intel/jit/gemm/generator/pieces/allocators.cpp +++ b/src/gpu/intel/jit/gemm/generator/pieces/allocators.cpp @@ -99,6 +99,13 @@ FlagRegister VirtualFlagAllocator::assignPhysical(VirtualFlag vflag) return pflag.toPhysical(); } +bool VirtualFlagAllocator::lock(VirtualFlag vflag, bool allowAlreadyLocked) { + bool wasLocked = isLocked(vflag); + if (wasLocked && !allowAlreadyLocked) stub("Illegally locking an already-locked flag register"); + locked |= mask(vflag); + return wasLocked; +} + bool VirtualFlagAllocator::canLock(int n) const { uint8_t unlocked = ~locked & ((1 << nflag) - 1); diff --git a/src/gpu/intel/jit/gemm/generator/pieces/allocators.hpp b/src/gpu/intel/jit/gemm/generator/pieces/allocators.hpp index 783a1c680c5..0e3cf11e637 100644 --- a/src/gpu/intel/jit/gemm/generator/pieces/allocators.hpp +++ b/src/gpu/intel/jit/gemm/generator/pieces/allocators.hpp @@ -78,7 +78,7 @@ class VirtualFlagAllocator { bool isVirtual(VirtualFlag vflag) { return (vflag.idx >= nflag); } - bool lock(VirtualFlag vflag) { bool wasLocked = isLocked(vflag); locked |= mask(vflag); return wasLocked; } + bool lock(VirtualFlag vflag, bool allowAlreadyLocked = false); void unlock(VirtualFlag vflag) { locked &= ~mask(vflag); } bool isLocked(VirtualFlag vflag) const { return !(~locked & mask(vflag)); } bool canLock(int n = 1) const; diff --git a/src/gpu/intel/jit/gemm/generator/pieces/copy.cxx b/src/gpu/intel/jit/gemm/generator/pieces/copy.cxx index 2be36f87020..93c12505f44 100644 --- a/src/gpu/intel/jit/gemm/generator/pieces/copy.cxx +++ b/src/gpu/intel/jit/gemm/generator/pieces/copy.cxx @@ -172,7 +172,7 @@ void BLASKernelGenerator::copyExecute(CopyPlan &&plan, CommonState &state) if (!state.vflagsEnabled()) for (int i = 0; i < nflag; i++) if (!raVFlag0.isFree(VirtualFlag{i})) - raVFlag0.lock(VirtualFlag{i}); + raVFlag0.lock(VirtualFlag{i}, true); auto raVFlag = raVFlag0; // If we have enough free flags, use those. 
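The lock() change above lets copyExecute() re-lock flag registers that are already in use when virtual flags are disabled, while any other double-lock still hits stub(). A simplified standalone model of that contract (illustration only, not the library class):

#include <cstdint>
#include <stdexcept>

// Toy bitmask lock with the same shape as the new
// VirtualFlagAllocator::lock(vflag, allowAlreadyLocked): it returns whether
// the flag was already locked, and treats an unexpected re-lock as an error.
struct flag_lock_model {
    uint8_t locked = 0;
    bool lock(int idx, bool allow_already_locked = false) {
        const uint8_t m = uint8_t(1u << idx);
        const bool was_locked = (locked & m) != 0;
        if (was_locked && !allow_already_locked)
            throw std::logic_error("illegally locking an already-locked flag");
        locked |= m;
        return was_locked;
    }
    void unlock(int idx) { locked &= uint8_t(~(1u << idx)); }
};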
diff --git a/src/gpu/intel/jit/gemm/generator/pieces/matrix_access.cxx b/src/gpu/intel/jit/gemm/generator/pieces/matrix_access.cxx index 91c4ffcd535..cbee13ed0ff 100644 --- a/src/gpu/intel/jit/gemm/generator/pieces/matrix_access.cxx +++ b/src/gpu/intel/jit/gemm/generator/pieces/matrix_access.cxx @@ -621,6 +621,8 @@ void BLASKernelGenerator::prepareSeriesRegisterBlockMasking(const vector::prepareSeriesRegisterBlockMasking(const vector Date: Tue, 29 Oct 2024 21:34:05 -0700 Subject: [PATCH 10/19] make changes based on review --- .../matmul/int4_weight_decompression.cpp | 265 ++++++++++++++++++ 1 file changed, 265 insertions(+) create mode 100644 examples/tutorials/matmul/int4_weight_decompression.cpp diff --git a/examples/tutorials/matmul/int4_weight_decompression.cpp b/examples/tutorials/matmul/int4_weight_decompression.cpp new file mode 100644 index 00000000000..b879e7d3025 --- /dev/null +++ b/examples/tutorials/matmul/int4_weight_decompression.cpp @@ -0,0 +1,265 @@ +/******************************************************************************* + * Copyright 2023-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ +/// @example int4_weights_decompression.cpp +/// > Annotated version: @ref int4_weights_decompression.cpp +/// +/// @page int4_weights_decompression +/// C++ API example demonstrating how one can use +/// [MatMul](@ref dev_guide_matmul) with int4 compressed weights. +/// +/// Concepts: +/// - AWQ (activation-aware quantization) +/// - Scales: dnnl::primitive_attr::set_scales() +/// - Zero points: dnnl::primitive_attr::set_zero_points() +/// - [Operation fusion](@ref dev_guide_attributes_post_ops) +/// - Create primitive once, use multiple times +/// - Weights pre-packing: use #dnnl::memory::format_tag::any +/// +/// @page int4_weights_decompression_matmul_cpp MatMul Tutorial: weights +/// decompression +/// @copydetails int4_weights_decompression_matmul_cpp +/// +/// Assumptions: +/// 1. The shape of the weights (matrix \f$B(K, N)\f$) is known in advance, the +/// data type is `int4` and shifted from 0 (i.e. the zero point is not 0). +/// 2. The source matrix \f$A\f$ and destination matrix \f$C\f$ have floating +/// point data type. +/// 3. Scaling (re-quantization) factor specified at run-time only. +/// +/// Since the shape of weights is known in advance, the MatMul weights can be +/// created with format tag #dnnl::memory::format_tag::any to enable the library +/// to choose the most appropriate layout for best performance. +/// +/// @warning +/// The format tag #dnnl::memory::format_tag::any doesn't work for memory +/// descriptors that have one or more unknown dimensions and/or strides. 
+///
+/// @include int4_weight_decompression.cpp
+
+#include <cassert>
+#include <cctype>
+#include <cmath>
+#include <cstdio>
+#include <iostream>
+#include <random>
+#include <stdexcept>
+#include <vector>
+
+#include "oneapi/dnnl/dnnl.hpp"
+
+#include "example_utils.hpp"
+
+using namespace dnnl;
+
+namespace {
+
+void init_vector(std::vector<float> &v) {
+    std::mt19937 gen;
+    std::uniform_real_distribution<float> u(0, 1);
+    for (auto &e : v)
+        e = u(gen);
+}
+// Comparing two vectors by calculating their L2 norms and the L2 norm of their
+// difference Checking if the difference is within a calculated threshold The
+// function returns 0 if the vectors are considered similar, otherwise it
+// returns 1.
+int compare_vectors(const std::vector<float> &v1, const std::vector<float> &v2,
+        int64_t K, const char *message) {
+    double v1_l2 = 0, diff_l2 = 0;
+    for (size_t n = 0; n < v1.size(); ++n) {
+        float diff = v1[n] - v2[n];
+        v1_l2 += v1[n] * v1[n];
+        diff_l2 += diff * diff;
+    }
+
+    v1_l2 = std::sqrt(v1_l2);
+    diff_l2 = std::sqrt(diff_l2);
+
+    // Finding the reasonable (tight and accurate) threshold is quite difficult
+    // problem.
+    // The implementation testing might also use special data filling to
+    // alleviate issues related to the finite precision arithmetic.
+    // However, in simple cases the machine epsilon multiplied by log(K) should
+    // work reasonably well.
+    const double threshold = std::numeric_limits<float>::epsilon()
+            * std::log(std::max(2., (double)K));
+    bool ok = diff_l2 <= threshold * v1_l2;
+
+    printf("%s\n\tL2 Norms"
+           "\n\t\tReference matrix:%g\n\t\tError:%g\n\t\tRelative_error:%g\n"
+           "\tAccuracy check: %s\n",
+            message, v1_l2, diff_l2, diff_l2 / v1_l2, ok ? "OK" : "FAILED");
+
+    return ok ? 0 : 1;
+}
+
+} // namespace
+
+// Floating point MatMul
+// Inputs:
+// - Shape: M, N, K
+// - Matrices A and B
+// Outputs:
+// - Matrix C
+void ref_compute_matmul_f32(int64_t M, int64_t N, int64_t K, int64_t G,
+        std::vector<float> &A_f32, std::vector<float> &B_f32,
+        std::vector<float> &zp_B_f32, std::vector<float> &sc_B,
+        std::vector<float> &C_f32) {
+    // Perform the GEMM operation
+    for (int m = 0; m < M; ++m) {
+        for (int n = 0; n < N; ++n) {
+            for (int k = 0; k < K; ++k) {
+                // Decompress the weight
+                int64_t idx1 = k * N + n;
+                int64_t idx2 = (k / G) * N + n;
+                float decompressed_B
+                        = (B_f32[idx1] - zp_B_f32[idx1]) * sc_B[idx2];
+                // Perform the multiplication and accumulation
+                C_f32[m * N + n] += A_f32[m * K + k] * decompressed_B;
+            }
+        }
+    }
+}
+
+// Create a MatMul primitive descriptor for the following op:
+// C_f32 = A_f32 * (B_s4 - zp_B) * sc_B[:]
+matmul::primitive_desc matmul_pd_create(
+        int64_t M, int64_t N, int64_t K, int64_t G, const engine &eng) {
+
+    memory::desc a_md({M, K}, memory::data_type::f32, {K, 1}); // M x K layout
+    memory::desc b_md({K, N}, memory::data_type::s4,
+            memory::format_tag::any); // K x N layout
+    memory::desc c_md({M, N}, memory::data_type::f32, {N, 1}); // M x N layout
+
+    // Create attributes and indicate that the alpha and zero points are
+    // runtime parameters
+    primitive_attr attr;
+    // Set scales with multiple scales along K and N dimensions and with groups
+    // along K.
+    attr.set_scales(DNNL_ARG_WEIGHTS,
+            /* mask */ (1 << 0) + (1 << 1), {G, 1}, memory::data_type::f32);
+
+    // Set zero points with s4 data type.
+    // The mask determines which dimensions the zero points are applied to.
+    // Current mask value (1 << 0) + (1 << 1) means zero points are applied
+    // both along K and N dimensions.
+    // Changing the mask value would alter the dimensions along which the zero
+    // points are applied. For example:
+    // - mask = (1 << 0) would apply zero points only along the K dimension.
+    // - mask = (1 << 1) would apply zero points only along the N dimension.
+    int mask = (1 << 0) + (1 << 1); // zero points both along K and N dimensions
+    memory::dims groups = {};
+    attr.set_zero_points(DNNL_ARG_WEIGHTS, mask, groups, memory::data_type::s4);
+
+    // Set fpmath mode with `apply_to_int=true` to apply fpmath mode behavior to
+    // integral primitives (in this example, matmul).
+    attr.set_fpmath_mode(fpmath_mode::f16, true);
+
+    // Create a MatMul primitive descriptor
+    return matmul::primitive_desc(eng, a_md, b_md, c_md, attr);
+}
+
+// Function to perform matrix multiplication with int4 weights decompression
+// using oneDNN
+void weights_decompression_matmul(int64_t M, int64_t N, int64_t K, int64_t G,
+        std::vector<float> &A_f32, std::vector<float> &B_f32,
+        std::vector<float> &zp_B_f32, std::vector<float> &sc_B,
+        std::vector<float> &C_f32, const engine &eng) {
+    auto matmul_pd = matmul_pd_create(M, N, K, G, eng);
+    stream s(eng);
+
+    // Pre-packed weights stored as int4
+    memory B_s4_mem(matmul_pd.weights_desc(), eng);
+    {
+        memory B_f32_mem(
+                {{K, N}, memory::data_type::f32, memory::format_tag::ab}, eng);
+        write_to_dnnl_memory(B_f32.data(), B_f32_mem);
+        reorder(B_f32_mem, B_s4_mem).execute(s, B_f32_mem, B_s4_mem);
+        s.wait();
+    }
+    matmul matmul_p(matmul_pd);
+
+    // input of the current layer / operation
+    memory A_f32_mem({{M, K}, memory::data_type::f32, {K, 1}}, eng);
+    // De-quantization parameters (eg. Scale and Shift)
+    const int64_t n_groups = K / G;
+    memory sc_B_mem({{N, n_groups}, memory::data_type::f32, {1, N}}, eng);
+
+    // Pre-packed zp stored as int4
+    // A unique zero point is used for each weight in this example
+    // Allocates memory for zp_B_s4_mem with specified dimensions and data type.
+    memory zp_B_s4_mem({{K, N}, memory::data_type::s4, {1, K}}, eng);
+    {
+        memory zp_B_f32_mem({{K, N}, memory::data_type::f32, {1, K}}, eng);
+        write_to_dnnl_memory(zp_B_f32.data(), zp_B_f32_mem);
+        reorder(zp_B_f32_mem, zp_B_s4_mem)
+                .execute(s, zp_B_f32_mem, zp_B_s4_mem);
+        s.wait();
+    }
+
+    write_to_dnnl_memory(A_f32.data(), A_f32_mem);
+    write_to_dnnl_memory(sc_B.data(), sc_B_mem);
+
+    // output - no initialization required
+    memory C_f32_mem({{M, N}, memory::data_type::f32, {N, 1}}, eng);
+
+    matmul_p.execute(s,
+            {{DNNL_ARG_SRC, A_f32_mem}, {DNNL_ARG_WEIGHTS, B_s4_mem},
+                    {DNNL_ARG_DST, C_f32_mem},
+                    {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, sc_B_mem},
+                    {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS,
+                            zp_B_s4_mem}});
+    s.wait();
+}
+
+// Compares the results of reference matrix multiplication and oneDNN weights
+// decompression.
+void compare_ref_and_weights_decompression(engine::kind engine_kind) {
+    engine eng(engine_kind, 0);
+
+    // MatMul parameters
+    const int64_t M = 1, N = 4096, K = 1024;
+    // Quantization Group size for scales
+    const int64_t G = 64;
+
+    // Prepare matrices
+    std::vector<float> A_f32(M * K), C_ref(M * N), sc_B(K * N / G);
+    std::vector<float> B_f32(K * N);
+    std::vector<float> zp_B_f32(K * N);
+    init_vector(A_f32);
+    init_vector(B_f32);
+    init_vector(sc_B);
+    init_vector(zp_B_f32);
+    init_vector(C_ref);
+    std::vector<float> C_onednn = C_ref;
+
+    // Compute _true_ C_ref result
+    ref_compute_matmul_f32(M, N, K, G, A_f32, B_f32, zp_B_f32, sc_B, C_ref);
+
+    // Compute _true_ C_onednn result
+    weights_decompression_matmul(
+            M, N, K, G, A_f32, B_f32, zp_B_f32, sc_B, C_onednn, eng);
+
+    int rc = 0;
+    rc |= compare_vectors(
+            C_ref, C_onednn, K, "Compare ref vs oneDNN weights decompression");
+    if (rc) throw std::logic_error("The resulting matrices diverged too much.");
+}
+
+int main(int argc, char **argv) {
+    engine::kind engine_kind = parse_engine_kind(argc, argv);
+    return handle_example_errors(
+            compare_ref_and_weights_decompression, engine_kind);
+}

From 0e6591e3bb2d6ebbd3582c270ede8f101c168df4 Mon Sep 17 00:00:00 2001
From: Peter Caday
Date: Tue, 29 Oct 2024 15:15:30 -0700
Subject: [PATCH 11/19] xe: jit: gemm: fixup negative input handling after f7783a86

---
 src/gpu/intel/jit/gemm/include/type.hpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/gpu/intel/jit/gemm/include/type.hpp b/src/gpu/intel/jit/gemm/include/type.hpp
index 1965bb641f4..773ee6207a5 100644
--- a/src/gpu/intel/jit/gemm/include/type.hpp
+++ b/src/gpu/intel/jit/gemm/include/type.hpp
@@ -100,7 +100,7 @@ class Type {
     constexpr Type baseType() const { return *this; }
 
     template <typename U> constexpr friend int operator*(U a, Type t) {
-        return t.isInt4() ? int((a + 1) / 2) : int(a * (U(1) << t.log2Size()));
+        return t.isInt4() ? int((unsigned(a) + 1) >> 1) : int(a * (U(1) << t.log2Size()));
     }
     template <typename U> constexpr friend int operator*(Type t, U a) { return a * t; }
     template <typename U> friend int operator*=(U &a, Type t) { a = a * t; return a; }

From 343203c9916841b58b270843a319b1630f4627df Mon Sep 17 00:00:00 2001
From: rupakroyintel
Date: Tue, 29 Oct 2024 21:49:13 -0700
Subject: [PATCH 12/19] remove file int4_weight_decompression_cmnts.cpp

---
 .../int4_weight_decompression_cmnts.cpp | 265 ------------------
 1 file changed, 265 deletions(-)
 delete mode 100644 examples/tutorials/matmul/int4_weight_decompression_cmnts.cpp

diff --git a/examples/tutorials/matmul/int4_weight_decompression_cmnts.cpp b/examples/tutorials/matmul/int4_weight_decompression_cmnts.cpp
deleted file mode 100644
index fd20ff13f92..00000000000
--- a/examples/tutorials/matmul/int4_weight_decompression_cmnts.cpp
+++ /dev/null
@@ -1,265 +0,0 @@
-/*******************************************************************************
- * Copyright 2023-2024 Intel Corporation
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *******************************************************************************/ -/// @example int4_weights_decompression.cpp -/// > Annotated version: @ref int4_weights_decompression.cpp -/// -/// @page int4_weights_decompression -/// C++ API example demonstrating how one can use -/// [MatMul](@ref dev_guide_matmul) with int4 compressed weights. -/// -/// Concepts: -/// - AWQ (activation-aware quantization) -/// - Scales: dnnl::primitive_attr::set_scales() -/// - Zero points: dnnl::primitive_attr::set_zero_points() -/// - [Operation fusion](@ref dev_guide_attributes_post_ops) -/// - Create primitive once, use multiple times -/// - Weights pre-packing: use #dnnl::memory::format_tag::any -/// -/// @page int4_weights_decompression_matmul_cpp MatMul Tutorial: weights -/// decompression -/// @copydetails int4_weights_decompression_matmul_cpp -/// -/// Assumptions: -/// 1. The shape of the weights (matrix \f$B(K, N)\f$) is known in advance, the -/// data type is `int4` and shifted from 0 (i.e. the zero point is not 0). -/// 2. The source matrix \f$A\f$ and destination matrix \f$C\f$ have floating -/// point data type. -/// 3. Scaling (re-quantization) factor specified at run-time only. -/// -/// Since the shape of weights is known in advance, the MatMul weights can be -/// created with format tag #dnnl::memory::format_tag::any to enable the library -/// to choose the most appropriate layout for best performance. -/// -/// @warning -/// The format tag #dnnl::memory::format_tag::any doesn't work for memory -/// descriptors that have one or more unknown dimensions and/or strides. -/// -/// @include weights_decompression_matmul.cpp -#include -#include -#include -#include -#include -#include -#include -#include - -#include "oneapi/dnnl/dnnl.hpp" - -#include "example_utils.hpp" - -using namespace dnnl; - -namespace { - -void init_vector(std::vector &v) { - std::mt19937 gen; - std::uniform_real_distribution u(0, 1); - for (auto &e : v) - e = u(gen); -} -// Comparing two vectors by calculating their L2 norms and the L2 norm of their -// difference Checking if the difference is within a calculated threshold The -// function returns 0 if the vectors are considered similar, otherwise it -// returns 1. -int compare_vectors(const std::vector &v1, const std::vector &v2, - int64_t K, const char *message) { - double v1_l2 = 0, diff_l2 = 0; - for (size_t n = 0; n < v1.size(); ++n) { - float diff = v1[n] - v2[n]; - v1_l2 += v1[n] * v1[n]; - diff_l2 += diff * diff; - } - - v1_l2 = std::sqrt(v1_l2); - diff_l2 = std::sqrt(diff_l2); - - // Finding the reasonable (tight and accurate) threshold is quite difficult - // problem. - // The implementation testing might also use special data filling to - // alleviate issues related to the finite precision arithmetic. - // However, in simple cases the machine epsilon multiplied by log(K) should - // work reasonably well. - const double threshold = std::numeric_limits::epsilon() - * std::log(std::max(2., (double)K)); - bool ok = diff_l2 <= threshold * v1_l2; - - printf("%s\n\tL2 Norms" - "\n\t\tReference matrix:%g\n\t\tError:%g\n\t\tRelative_error:%g\n" - "\tAccuracy check: %s\n", - message, v1_l2, diff_l2, diff_l2 / v1_l2, ok ? "OK" : "FAILED"); - - return ok ? 
0 : 1; -} - -} // namespace - -// Floating point MatMul -// Inputs: -// - Shape: M, N, K -// - Matrices A and B -// Outputs: -// - Matrix C -void ref_compute_matmul_f32(int64_t M, int64_t N, int64_t K, int64_t G, - std::vector &A_f32, std::vector &B_f32, - std::vector &zp_B_f32, std::vector &sc_B, - std::vector &C_f32) { - // Perform the GEMM operation - for (int m = 0; m < M; ++m) { - for (int n = 0; n < N; ++n) { - for (int k = 0; k < K; ++k) { - // Decompress the weight - int64_t idx1 = k * N + n; - int64_t idx2 = (k / G) * N + n; - float decompressed_B - = (B_f32[idx1] - zp_B_f32[idx1]) * sc_B[idx2]; - // Perform the multiplication and accumulation - C_f32[m * N + n] += A_f32[m * K + k] * decompressed_B; - } - } - } -} - -// Create a MatMul primitive descriptor for the following op: -// C_f32 = A_f32 * (B_s4 - zp_B) * sc_B[:] -matmul::primitive_desc matmul_pd_create( - int64_t M, int64_t N, int64_t K, int64_t G, const engine &eng) { - - memory::desc a_md({M, K}, memory::data_type::f32, {K, 1}); // M x K layout - memory::desc b_md({K, N}, memory::data_type::s4, - memory::format_tag::any); // K x N layout - memory::desc c_md({M, N}, memory::data_type::f32, {N, 1}); // M x N layout - - // Create attributes and indicate that the alpha and zero points are - // runtime parameters - primitive_attr attr; - // Set scales with multiple scales along K and N dimensions and with groups - // along K. - attr.set_scales(DNNL_ARG_WEIGHTS, - /* mask */ (1 << 0) + (1 << 1), {G, 1}, memory::data_type::f32); - - // Set zero points with s4 data type. - // The mask determines which dimensions the zero points are applied to. - // Current mask value (1 << 0) + (1 << 1) means zero points are applied - // both along K and N dimensions. - // Changing the mask value would alter the dimensions along which the zero - // points are applied. For example: - // - mask = (1 << 0) would apply zero points only along the K dimension. - // - mask = (1 << 1) would apply zero points only along the N dimension. - int mask = (1 << 0) + (1 << 1); // zero points both along K and N dimensions - memory::dims groups = {}; - attr.set_zero_points(DNNL_ARG_WEIGHTS, mask, groups, memory::data_type::s4); - - // Set fpmath mode with `apply_to_int=true` to apply fpmath mode behavior to - // integral primitives (in this example, matmul). - attr.set_fpmath_mode(fpmath_mode::f16, true); - - // Create a MatMul primitive descriptor - return matmul::primitive_desc(eng, a_md, b_md, c_md, attr); -} - -// Function to perform matrix multiplication with int4 weights decompression -// using oneDNN --Rupka -void weights_decompression_matmul(int64_t M, int64_t N, int64_t K, int64_t G, - std::vector &A_f32, std::vector &B_f32, - std::vector &zp_B_f32, std::vector &sc_B, - std::vector &C_f32, const engine &eng) { - auto matmul_pd = matmul_pd_create(M, N, K, G, eng); - stream s(eng); - - // Pre-packed weights stored as int4 - memory B_s4_mem(matmul_pd.weights_desc(), eng); - { - memory B_f32_mem( - {{K, N}, memory::data_type::f32, memory::format_tag::ab}, eng); - write_to_dnnl_memory(B_f32.data(), B_f32_mem); - reorder(B_f32_mem, B_s4_mem).execute(s, B_f32_mem, B_s4_mem); - s.wait(); - } - matmul matmul_p(matmul_pd); - - // input of the current layer / operation - memory A_f32_mem({{M, K}, memory::data_type::f32, {K, 1}}, eng); - // De-quantization parameters (eg. 
Scale and Shift) - const int64_t n_groups = K / G; - memory sc_B_mem({{N, n_groups}, memory::data_type::f32, {1, N}}, eng); - - // Pre-packed zp stored as int4 - // A unique zero point is used for each weight in this example - // Allocates memory for zp_B_s4_mem with specified dimensions and data type. - memory zp_B_s4_mem({{K, N}, memory::data_type::s4, {1, K}}, eng); - { - memory zp_B_f32_mem({{K, N}, memory::data_type::f32, {1, K}}, eng); - write_to_dnnl_memory(zp_B_f32.data(), zp_B_f32_mem); - reorder(zp_B_f32_mem, zp_B_s4_mem) - .execute(s, zp_B_f32_mem, zp_B_s4_mem); - s.wait(); - } - - write_to_dnnl_memory(A_f32.data(), A_f32_mem); - write_to_dnnl_memory(sc_B.data(), sc_B_mem); - - // output - no initialization required - memory C_f32_mem({{M, N}, memory::data_type::f32, {N, 1}}, eng); - - matmul_p.execute(s, - {{DNNL_ARG_SRC, A_f32_mem}, {DNNL_ARG_WEIGHTS, B_s4_mem}, - {DNNL_ARG_DST, C_f32_mem}, - {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, sc_B_mem}, - {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, - zp_B_s4_mem}}); - s.wait(); -} - -// Compares the results of reference matrix multiplication and oneDNN weights -// decompression. -void compare_ref_and_weights_decompression(engine::kind engine_kind) { - engine eng(engine_kind, 0); - - // MatMul parameters - const int64_t M = 1, N = 4096, K = 1024; - // Quantization Group size for scales - const int64_t G = 64; - - // Prepare matrices - std::vector A_f32(M * K), C_ref(M * N), sc_B(K * N / G); - std::vector B_f32(K * N); - std::vector zp_B_f32(K * N); - init_vector(A_f32); - init_vector(B_f32); - init_vector(sc_B); - init_vector(zp_B_f32); - init_vector(C_ref); - std::vector C_onednn = C_ref; - - // Compute _true_ C_ref result - ref_compute_matmul_f32(M, N, K, G, A_f32, B_f32, zp_B_f32, sc_B, C_ref); - - // Compute _true_ C_onednn result - weights_decompression_matmul( - M, N, K, G, A_f32, B_f32, zp_B_f32, sc_B, C_onednn, eng); - - int rc = 0; - rc |= compare_vectors( - C_ref, C_onednn, K, "Compare ref vs oneDNN weights decompression"); - if (rc) throw std::logic_error("The resulting matrices diverged too much."); -} - -int main(int argc, char **argv) { - engine::kind engine_kind = parse_engine_kind(argc, argv); - return handle_example_errors( - compare_ref_and_weights_decompression, engine_kind); -} From 734148505a2da94c176c1649e120943a34fcea7a Mon Sep 17 00:00:00 2001 From: "Wang, Zhitao" Date: Fri, 30 Aug 2024 02:49:48 +0000 Subject: [PATCH 13/19] src: graph: extend setting fpmath mode with apply_to_int --- include/oneapi/dnnl/dnnl_graph.h | 22 ++++++++++ include/oneapi/dnnl/dnnl_graph.hpp | 35 +++++++++++++++ .../backend/dnnl/dnnl_partition_impl.cpp | 4 +- .../backend/dnnl/dnnl_partition_impl.hpp | 4 +- src/graph/backend/dnnl/fusion_info.hpp | 9 ++-- .../dnnl/kernels/sdp_primitive_config.cpp | 3 +- src/graph/backend/dnnl/op_executable.cpp | 21 ++++++--- src/graph/backend/dnnl/subgraph.cpp | 5 ++- src/graph/backend/dnnl/subgraph.hpp | 3 +- src/graph/interface/graph.cpp | 23 ++++++++++ src/graph/interface/graph.hpp | 27 +++++++++--- src/graph/interface/graph_attr.hpp | 44 +++++++++++++++++++ src/graph/interface/partition.cpp | 4 +- src/graph/interface/partition.hpp | 3 +- src/graph/interface/partition_impl.hpp | 17 ++++--- src/graph/utils/verbose.cpp | 4 +- tests/gtests/graph/api/test_cpp_api_graph.cpp | 18 +++++++- 17 files changed, 210 insertions(+), 36 deletions(-) create mode 100644 src/graph/interface/graph_attr.hpp diff --git a/include/oneapi/dnnl/dnnl_graph.h b/include/oneapi/dnnl/dnnl_graph.h index a0d465982ca..77f7b46b48f 
100644 --- a/include/oneapi/dnnl/dnnl_graph.h +++ b/include/oneapi/dnnl/dnnl_graph.h @@ -590,6 +590,28 @@ dnnl_status_t DNNL_API dnnl_graph_graph_create_with_fpmath_mode( /// otherwise. dnnl_status_t DNNL_API dnnl_graph_graph_destroy(dnnl_graph_graph_t graph); +/// Set the floating point math mode for a graph. +/// +/// @param graph The target graph. +/// @param mode The floating-point math mode. +/// @param apply_to_int The flag that controls whether to use floating-point +/// arithmetic for integral operations. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_graph_set_fpmath_mode( + dnnl_graph_graph_t graph, dnnl_fpmath_mode_t mode, int apply_to_int); + +/// Get the floating point math mode for a graph. +/// +/// @param graph The target graph. +/// @param mode The floating-point math mode. +/// @param apply_to_int The flag that controls whether to use floating-point +/// arithmetic for integral operations. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_graph_get_fpmath_mode( + dnnl_graph_graph_t graph, dnnl_fpmath_mode_t *mode, int *apply_to_int); + /// Adds an operation into a graph. The API will return failure if the operator /// has already been added to the graph or the operation cannot pass the schema /// check in the library (eg. input and output numbers and data types, the diff --git a/include/oneapi/dnnl/dnnl_graph.hpp b/include/oneapi/dnnl/dnnl_graph.hpp index 1d178e07973..2e124ce6feb 100644 --- a/include/oneapi/dnnl/dnnl_graph.hpp +++ b/include/oneapi/dnnl/dnnl_graph.hpp @@ -1373,6 +1373,10 @@ class graph : public graph_handle { /// mode. All partitions returned from the graph will inherit the engine /// kind and floating-point math mode. /// + /// Setting the floating-point math mode enables automatic down-conversion + /// of inputs for the given graph, promoting speedup by using + /// lower-precision data types when available. + /// /// @param engine_kind Engine kind. /// @param mode Floating-point math mode. graph(engine::kind engine_kind, fpmath_mode mode) { @@ -1384,6 +1388,37 @@ class graph : public graph_handle { reset(g); } + /// Set the floating point math mode for a graph. Users can enforce the + /// graph to comply with the mode by specifying a boolean flag with the + /// setter function. + /// + /// @param mode The floating-point math mode. + /// @param apply_to_int The flag that controls whether to use + /// floating-point arithmetic for integral operations. + void set_fpmath_mode(fpmath_mode mode, bool apply_to_int = false) { + error::wrap_c_api(dnnl_graph_graph_set_fpmath_mode( + get(), convert_to_c(mode), apply_to_int), + "could not set fpmath mode graph attribute"); + } + + /// Get the floating point math mode and the boolean flag that specifies + /// whether the graph will be enforced to comply the mode. + /// + /// @param mode The floating-point math mode. + /// @param apply_to_int The flag that controls whether to use + /// floating-point arithmetic for integral operations. + void get_fpmath_mode(fpmath_mode &mode, bool &apply_to_int) const { + dnnl_fpmath_mode_t c_mode; + int c_apply_to_int; + + error::wrap_c_api(dnnl_graph_graph_get_fpmath_mode( + get(), &c_mode, &c_apply_to_int), + "could not get fpmath mode graph attribute"); + + mode = fpmath_mode(c_mode); + apply_to_int = static_cast(c_apply_to_int); + } + /// Adds an op into the graph to construct a computational DAG. 
The API will /// return failure if the operator has already been added to the graph or /// the operation cannot pass the schema check in the library (eg. input and diff --git a/src/graph/backend/dnnl/dnnl_partition_impl.cpp b/src/graph/backend/dnnl/dnnl_partition_impl.cpp index 4b8148041e0..e8f0ce1cee9 100644 --- a/src/graph/backend/dnnl/dnnl_partition_impl.cpp +++ b/src/graph/backend/dnnl/dnnl_partition_impl.cpp @@ -130,7 +130,9 @@ status_t dnnl_partition_impl_t::compile( // Dispatch to fake kernel if one of the output dimensions is zero. const std::vector> &fused_op = part->get_ops(); - auto agraph = graph_t(fused_op, get_engine_kind(), get_fpmath_mode()); + auto fpm = get_fpmath_mode(); + auto agraph = graph_t(fused_op, get_engine_kind()); + agraph.set_fpmath_mode(fpm.mode_, fpm.apply_to_int_); agraph.set_user_inputs_outputs(inputs, outputs); agraph.infer_shape(); for (const auto &val : agraph.get_output_values()) { diff --git a/src/graph/backend/dnnl/dnnl_partition_impl.hpp b/src/graph/backend/dnnl/dnnl_partition_impl.hpp index 19d69d85438..41dabcd8f94 100644 --- a/src/graph/backend/dnnl/dnnl_partition_impl.hpp +++ b/src/graph/backend/dnnl/dnnl_partition_impl.hpp @@ -90,8 +90,8 @@ class dnnl_partition_impl_t : public partition_impl_t { friend class dnnl_backend_t; public: - dnnl_partition_impl_t(engine_kind_t engine_kind, fpmath_mode_t fpmath_mode, - partition_kind_t pkind) + dnnl_partition_impl_t(engine_kind_t engine_kind, + const fpmath_t &fpmath_mode, partition_kind_t pkind) : partition_impl_t(engine_kind, fpmath_mode, pkind) {} ~dnnl_partition_impl_t() override = default; diff --git a/src/graph/backend/dnnl/fusion_info.hpp b/src/graph/backend/dnnl/fusion_info.hpp index 09840769cd8..0ae2f68f570 100644 --- a/src/graph/backend/dnnl/fusion_info.hpp +++ b/src/graph/backend/dnnl/fusion_info.hpp @@ -26,6 +26,7 @@ #include #include "graph/interface/c_types_map.hpp" +#include "graph/interface/graph_attr.hpp" #include "graph/interface/op.hpp" #include "graph/interface/value.hpp" #include "graph/utils/utils.hpp" @@ -267,8 +268,8 @@ class fusion_info_t { // info key to query it out from the manager. 
class fusion_info_mgr_t { public: - fusion_info_mgr_t(fpmath_mode_t fpm_mode = fpmath_mode::strict, - bool can_use_blocked_layout = false) + fusion_info_mgr_t( + graph::fpmath_t fpm_mode = {}, bool can_use_blocked_layout = false) : fpmath_mode_(fpm_mode) , can_use_blocked_layout_(can_use_blocked_layout) {} @@ -298,13 +299,13 @@ class fusion_info_mgr_t { return data_[k]; } - fpmath_mode_t get_fpmath_mode() const { return fpmath_mode_; } + const fpmath_t &get_fpmath_mode() const { return fpmath_mode_; } bool get_use_blocked_layout() const { return can_use_blocked_layout_; } private: std::vector<fusion_info_t> data_; // specified floating-point math mode for all fusions - fpmath_mode_t fpmath_mode_ {}; + fpmath_t fpmath_mode_ {}; bool can_use_blocked_layout_; }; diff --git a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp index 89533df6bca..f20d4b57b4c 100644 --- a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp +++ b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp @@ -218,7 +218,8 @@ status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg, auto &mgr = sg->fusion_info_mgr_; attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); - attr.set_fpmath_mode(static_cast<dnnl::fpmath_mode>(mgr.get_fpmath_mode())); + attr.set_fpmath_mode( + static_cast<dnnl::fpmath_mode>(mgr.get_fpmath_mode().mode_)); CHECK(create_sdpa_pd(sdpa_pd_, p_engine.get(), md_q.get(), md_k.get(), md_v.get(), md_dst.get(), md_mask.get(), scale_dt, invert_scale_, diff --git a/src/graph/backend/dnnl/op_executable.cpp b/src/graph/backend/dnnl/op_executable.cpp index 73fc14f39f4..d371c7bc80d 100644 --- a/src/graph/backend/dnnl/op_executable.cpp +++ b/src/graph/backend/dnnl/op_executable.cpp @@ -66,8 +66,9 @@ conv_fwd_executable_t::desc_t conv_fwd_executable_t::create_desc( prm_attr = make_dnnl_primitive_attr(op, fusion_info); } prm_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + auto fpmath = mgr.get_fpmath_mode(); prm_attr.set_fpmath_mode( - static_cast<dnnl::fpmath_mode>(mgr.get_fpmath_mode())); + static_cast<dnnl::fpmath_mode>(fpmath.mode_), fpmath.apply_to_int_); const bool can_use_blocked_layout = mgr.get_use_blocked_layout(); auto src = make_dnnl_memory_desc( @@ -181,8 +182,9 @@ deconv_fwd_executable_t::desc_t deconv_fwd_executable_t::create_desc( prm_attr = make_dnnl_primitive_attr(op, mgr.get_info(key)); } prm_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + auto fpmath = mgr.get_fpmath_mode(); prm_attr.set_fpmath_mode( - static_cast<dnnl::fpmath_mode>(mgr.get_fpmath_mode())); + static_cast<dnnl::fpmath_mode>(fpmath.mode_), fpmath.apply_to_int_); auto src = make_dnnl_memory_desc( op->get_input_value(0)->get_logical_tensor()); @@ -241,8 +243,9 @@ deconv_bwd_data_executable_t::desc_t deconv_bwd_data_executable_t::create_desc( prm_attr = make_dnnl_primitive_attr(op, mgr.get_info(key)); } prm_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + auto fpmath = mgr.get_fpmath_mode(); prm_attr.set_fpmath_mode( - static_cast<dnnl::fpmath_mode>(mgr.get_fpmath_mode())); + static_cast<dnnl::fpmath_mode>(fpmath.mode_), fpmath.apply_to_int_); auto diff_dst = make_dnnl_memory_desc( op->get_input_value(0)->get_logical_tensor()); @@ -292,8 +295,9 @@ deconv_bwd_weights_executable_t::create_desc(std::shared_ptr<op_t> &op, int64_t key = op->get_attr<int64_t>(op_attr::fusion_info_key); prm_attr = make_dnnl_primitive_attr(op, mgr.get_info(key)); } + auto fpmath = mgr.get_fpmath_mode(); prm_attr.set_fpmath_mode( - static_cast<dnnl::fpmath_mode>(mgr.get_fpmath_mode())); + static_cast<dnnl::fpmath_mode>(fpmath.mode_), fpmath.apply_to_int_); auto src = make_dnnl_memory_desc( op->get_input_value(0)->get_logical_tensor()); @@ -336,8 +340,9 @@
matmul_executable_t::desc_t matmul_executable_t::create_desc( prm_attr = make_dnnl_primitive_attr(op, mgr.get_info(key)); } prm_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + auto fpmath = mgr.get_fpmath_mode(); prm_attr.set_fpmath_mode( - static_cast<dnnl::fpmath_mode>(mgr.get_fpmath_mode())); + static_cast<dnnl::fpmath_mode>(fpmath.mode_), fpmath.apply_to_int_); auto src = make_dnnl_memory_desc( op->get_input_value(0)->get_logical_tensor()); @@ -835,8 +840,9 @@ conv_bwd_data_executable_t::desc_t conv_bwd_data_executable_t::create_desc( prm_attr = make_dnnl_primitive_attr(op, mgr.get_info(key)); } prm_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + auto fpmath = mgr.get_fpmath_mode(); prm_attr.set_fpmath_mode( - static_cast<dnnl::fpmath_mode>(mgr.get_fpmath_mode())); + static_cast<dnnl::fpmath_mode>(fpmath.mode_), fpmath.apply_to_int_); const bool can_use_blocked_layout = mgr.get_use_blocked_layout(); auto diff_dst = make_dnnl_memory_desc( @@ -894,8 +900,9 @@ conv_bwd_weights_executable_t::create_desc(std::shared_ptr<op_t> &op, prm_attr = make_dnnl_primitive_attr(op, mgr.get_info(key)); } prm_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + auto fpmath = mgr.get_fpmath_mode(); prm_attr.set_fpmath_mode( - static_cast<dnnl::fpmath_mode>(mgr.get_fpmath_mode())); + static_cast<dnnl::fpmath_mode>(fpmath.mode_), fpmath.apply_to_int_); const bool can_use_blocked_layout = mgr.get_use_blocked_layout(); auto src = make_dnnl_memory_desc( diff --git a/src/graph/backend/dnnl/subgraph.cpp b/src/graph/backend/dnnl/subgraph.cpp index bf686051eb0..512aef1dd42 100644 --- a/src/graph/backend/dnnl/subgraph.cpp +++ b/src/graph/backend/dnnl/subgraph.cpp @@ -48,11 +48,12 @@ using value_ptr = std::shared_ptr<value_t>; using ltw = logical_tensor_wrapper_t; subgraph_t::subgraph_t(const std::vector<op_ptr> &ops, const dnnl::engine &eng, - impl::fpmath_mode_t fpm_mode, bool can_use_blocked_layout, + const graph::fpmath_t &fpm_mode, bool can_use_blocked_layout, bool reset_layout) - : graph_t(ops, static_cast<graph::engine_kind_t>(eng.get_kind()), fpm_mode) + : graph_t(ops, static_cast<graph::engine_kind_t>(eng.get_kind())) , p_engine_(&eng) , fusion_info_mgr_(fpm_mode, can_use_blocked_layout) { + set_fpmath_mode(fpm_mode.mode_, fpm_mode.apply_to_int_); if (reset_layout) { set_all_layout_to_any(get_mutable_ops()); } } diff --git a/src/graph/backend/dnnl/subgraph.hpp b/src/graph/backend/dnnl/subgraph.hpp index 5dcd2416841..efbcba8dc55 100644 --- a/src/graph/backend/dnnl/subgraph.hpp +++ b/src/graph/backend/dnnl/subgraph.hpp @@ -30,6 +30,7 @@ #include "graph/interface/c_types_map.hpp" #include "graph/interface/graph.hpp" +#include "graph/interface/graph_attr.hpp" #include "graph/interface/op.hpp" #include "graph/interface/value.hpp" #include "graph/utils/utils.hpp" @@ -65,7 +66,7 @@ class subgraph_t : public graph_t { public: subgraph_t(const std::vector<op_ptr> &ops, const dnnl::engine &eng, - impl::fpmath_mode_t fpm_mode, bool can_use_blocked_layout, + const graph::fpmath_t &fpm_mode, bool can_use_blocked_layout, bool reset_layout); subgraph_t(const std::vector<op_ptr> &ops, bool reset_layout = true); diff --git a/src/graph/interface/graph.cpp b/src/graph/interface/graph.cpp index b2a7307ec51..a888166cec1 100644 --- a/src/graph/interface/graph.cpp +++ b/src/graph/interface/graph.cpp @@ -280,6 +280,29 @@ status_t DNNL_API dnnl_graph_graph_destroy(graph_t *graph) { return status::success; } +status_t dnnl_graph_graph_set_fpmath_mode( + dnnl_graph_graph_t graph, dnnl_fpmath_mode_t mode, int apply_to_int) { + + if (graph == nullptr) { return status::invalid_arguments; } + + if (graph->is_finalized()) { return status::invalid_graph; } + + return graph->set_fpmath_mode(mode, apply_to_int);
+} + +status_t dnnl_graph_graph_get_fpmath_mode( + dnnl_graph_graph_t graph, dnnl_fpmath_mode_t *mode, int *apply_to_int) { + + if (graph == nullptr) { return status::invalid_arguments; } + if (graph->is_finalized()) { return status::invalid_graph; } + + const auto &fpmath = graph->get_fpmath_mode(); + if (mode) *mode = fpmath.mode_; + if (apply_to_int) *apply_to_int = fpmath.apply_to_int_; + + return status::success; +} + status_t DNNL_API dnnl_graph_add_op(graph_t *graph, op_t *op) { if (graph == nullptr || op == nullptr) { return status::invalid_arguments; } diff --git a/src/graph/interface/graph.hpp b/src/graph/interface/graph.hpp index 739a217ee84..88040917dd6 100644 --- a/src/graph/interface/graph.hpp +++ b/src/graph/interface/graph.hpp @@ -31,6 +31,7 @@ #include "oneapi/dnnl/dnnl_graph.h" #include "graph/interface/c_types_map.hpp" +#include "graph/interface/graph_attr.hpp" #include "graph/interface/logical_tensor.hpp" #include "graph/interface/op.hpp" #include "graph/interface/op_schema.hpp" @@ -74,7 +75,7 @@ struct dnnl_graph_graph : public graph::utils::id_t { graph::engine_kind_t engine_kind_ {}; /*! \brief The floating-point math mode */ - graph::fpmath_mode_t fpmath_mode_ {}; + graph::fpmath_t fpmath_ {}; std::vector> partition_impls_; @@ -85,24 +86,28 @@ struct dnnl_graph_graph : public graph::utils::id_t { public: dnnl_graph_graph(graph::engine_kind_t kind = graph::engine_kind::cpu) - : engine_kind_(kind), fpmath_mode_(dnnl::impl::get_fpmath_mode()) {} + : engine_kind_(kind) { + fpmath_.mode_ = dnnl::impl::get_fpmath_mode(); + } dnnl_graph_graph( graph::engine_kind_t kind, graph::fpmath_mode_t fpmath_mode) - : engine_kind_(kind), fpmath_mode_(fpmath_mode) {} + : engine_kind_(kind) { + fpmath_.mode_ = fpmath_mode; + } // deep copy (except that the partition_impls_ is shallow copy) dnnl_graph_graph(const dnnl_graph_graph &other) : id_t(other) , ops_(deep_copy(other.ops_)) , engine_kind_(other.engine_kind_) - , fpmath_mode_(other.fpmath_mode_) + , fpmath_(other.fpmath_) , partition_impls_(other.partition_impls_) {}; dnnl_graph_graph(const std::vector &ops, graph::engine_kind_t kind = graph::engine_kind::cpu, graph::fpmath_mode_t fpmath_mode = graph::fpmath_mode::strict) - : ops_(ops), engine_kind_(kind), fpmath_mode_(fpmath_mode) {} + : ops_(ops), engine_kind_(kind), fpmath_ {fpmath_mode, false} {}; dnnl_graph_graph &operator=(const dnnl_graph_graph &other) = delete; @@ -110,7 +115,7 @@ struct dnnl_graph_graph : public graph::utils::id_t { graph::engine_kind_t get_engine_kind() const { return engine_kind_; } - graph::fpmath_mode_t get_fpmath_mode() const { return fpmath_mode_; } + const graph::fpmath_t &get_fpmath_mode() const { return fpmath_; } /*! 
* \brief Check whether an operator can be added @@ -142,6 +147,13 @@ struct dnnl_graph_graph : public graph::utils::id_t { return graph::status::success; } + graph::status_t set_fpmath_mode( + graph::fpmath_mode_t mode, bool apply_to_int) { + fpmath_.mode_ = mode; + fpmath_.apply_to_int_ = apply_to_int; + return graph::status::success; + } + op_t *create_op(graph::op_kind_t kind, std::string name = "") { ops_.push_back(std::make_shared(kind, std::move(name))); return ops_.back().get(); @@ -410,7 +422,8 @@ struct dnnl_graph_graph : public graph::utils::id_t { writer.write_keyvalue("engine_kind", std::string(graph::utils::engine_kind2str(get_engine_kind()))); writer.write_keyvalue("fpmath_mode", - std::string(graph::utils::fpmath_mode2str(get_fpmath_mode()))); + std::string(graph::utils::fpmath_mode2str( + get_fpmath_mode().mode_))); std::vector inputs_id; inputs_id.reserve(get_input_values().size()); for (const auto &val : get_input_values()) { diff --git a/src/graph/interface/graph_attr.hpp b/src/graph/interface/graph_attr.hpp new file mode 100644 index 00000000000..d6d4ddfb9a3 --- /dev/null +++ b/src/graph/interface/graph_attr.hpp @@ -0,0 +1,44 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef GRAPH_INTERFACE_GRAPH_ATTR_HPP +#define GRAPH_INTERFACE_GRAPH_ATTR_HPP + +#include "graph/interface/c_types_map.hpp" + +namespace dnnl { +namespace impl { +namespace graph { + +struct fpmath_t { + + fpmath_t(dnnl_fpmath_mode_t mode = fpmath_mode::strict, + bool apply_to_int = false) + : mode_(mode), apply_to_int_(apply_to_int) {} + + bool operator==(const fpmath_t &rhs) const { + return mode_ == rhs.mode_ && apply_to_int_ == rhs.apply_to_int_; + } + + graph::fpmath_mode_t mode_; + bool apply_to_int_ = false; +}; + +} // namespace graph +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/graph/interface/partition.cpp b/src/graph/interface/partition.cpp index 5fccd685c10..dabb0a24537 100644 --- a/src/graph/interface/partition.cpp +++ b/src/graph/interface/partition.cpp @@ -607,7 +607,9 @@ status_t dnnl_graph_partition::compile(compiled_partition_t *cp, auto part = pimpl_->clone(); const std::vector> &fused_op = part->get_ops(); if (fused_op.empty()) return status::invalid_arguments; - auto agraph = graph_t(fused_op, get_engine_kind(), get_fpmath_mode()); + const auto &fpm = get_fpmath_mode(); + auto agraph = graph_t(fused_op, get_engine_kind()); + agraph.set_fpmath_mode(fpm.mode_, fpm.apply_to_int_); // set user given logical tensors and infer shape agraph.set_user_inputs_outputs(tmp_inputs, tmp_outputs); agraph.infer_shape(); diff --git a/src/graph/interface/partition.hpp b/src/graph/interface/partition.hpp index 5c34b50b348..6b569fb01b8 100644 --- a/src/graph/interface/partition.hpp +++ b/src/graph/interface/partition.hpp @@ -28,6 +28,7 @@ #include #include "graph/interface/c_types_map.hpp" +#include "graph/interface/graph_attr.hpp" #include "graph/interface/logical_tensor.hpp" #include "graph/interface/op.hpp" #include "graph/interface/partition_impl.hpp" @@ -86,7 +87,7 @@ struct dnnl_graph_partition : public dnnl::impl::graph::utils::id_t { return pimpl_->get_engine_kind(); } - graph::fpmath_mode_t get_fpmath_mode() const { + const graph::fpmath_t &get_fpmath_mode() const { return pimpl_->get_fpmath_mode(); } diff --git a/src/graph/interface/partition_impl.hpp b/src/graph/interface/partition_impl.hpp index 161ce3615c6..2569358963c 100644 --- a/src/graph/interface/partition_impl.hpp +++ b/src/graph/interface/partition_impl.hpp @@ -30,6 +30,7 @@ #include "common/engine.hpp" #include "graph/interface/c_types_map.hpp" +#include "graph/interface/graph_attr.hpp" #include "graph/interface/logical_tensor.hpp" #include "graph/interface/op.hpp" @@ -66,19 +67,21 @@ class backend_t; class partition_impl_t : public std::enable_shared_from_this { public: - explicit partition_impl_t(engine_kind_t engine_kind, - fpmath_mode_t fpmath_mode, partition_kind_t pkind) + explicit partition_impl_t(engine_kind_t engine_kind, fpmath_t fpmath_mode, + partition_kind_t pkind) : engine_kind_(engine_kind) , fpmath_mode_(fpmath_mode) , pkind_(pkind) , can_use_blocked_layout_(false) {} - explicit partition_impl_t(engine_kind_t engine_kind, - fpmath_mode_t fpmath_mode = fpmath_mode::strict) + explicit partition_impl_t( + engine_kind_t engine_kind, fpmath_t fpmath_mode = {}) : engine_kind_(engine_kind) , fpmath_mode_(fpmath_mode) , pkind_(partition_kind_t::undef) - , can_use_blocked_layout_(false) {} + , can_use_blocked_layout_(false) { + fpmath_mode_.mode_ = graph::fpmath_mode::strict; + } virtual ~partition_impl_t() = default; @@ -86,7 +89,7 @@ class partition_impl_t : public std::enable_shared_from_this { 
engine_kind_t get_engine_kind() const { return engine_kind_; } /// The getter for fpmath_mode_ - fpmath_mode_t get_fpmath_mode() const { return fpmath_mode_; } + const fpmath_t &get_fpmath_mode() const { return fpmath_mode_; } /// The getter for partition kind partition_kind_t get_kind() const { return pkind_; } @@ -193,7 +196,7 @@ class partition_impl_t : public std::enable_shared_from_this { engine_kind_t engine_kind_; // floating-point math mode - fpmath_mode_t fpmath_mode_; + fpmath_t fpmath_mode_; // Partition kind partition_kind_t pkind_; diff --git a/src/graph/utils/verbose.cpp b/src/graph/utils/verbose.cpp index 7f780198256..bc168c4af00 100644 --- a/src/graph/utils/verbose.cpp +++ b/src/graph/utils/verbose.cpp @@ -226,7 +226,9 @@ std::string init_info_partition(const engine_t *engine, } } - ss << ",fpm:" << fpmath_mode2str(partition.get_pimpl()->get_fpmath_mode()); + const auto &fpm = partition.get_pimpl()->get_fpmath_mode(); + ss << ",fpm:" << fpmath_mode2str(fpm.mode_); + if (fpm.apply_to_int_) ss << ":true"; ss << "," << compiled_partition->get_pimpl()->str(); diff --git a/tests/gtests/graph/api/test_cpp_api_graph.cpp b/tests/gtests/graph/api/test_cpp_api_graph.cpp index 765682aa07a..208a255495b 100644 --- a/tests/gtests/graph/api/test_cpp_api_graph.cpp +++ b/tests/gtests/graph/api/test_cpp_api_graph.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2023 Intel Corporation +* Copyright 2020-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,22 @@ #include "oneapi/dnnl/dnnl_graph.hpp" +TEST(APIGraph, SetAndGetFloatingPointMathMode) { + using namespace dnnl::graph; + using fpmath_mode = dnnl::fpmath_mode; + + dnnl::engine::kind engine_kind = dnnl::engine::kind::cpu; + graph g(engine_kind); + g.set_fpmath_mode(fpmath_mode::bf16, true); + + fpmath_mode mode; + bool apply_to_int; + g.get_fpmath_mode(mode, apply_to_int); + + EXPECT_EQ(mode, fpmath_mode::bf16); + EXPECT_EQ(apply_to_int, 1U); +} + TEST(APIGraph, GetPartitions) { using namespace dnnl::graph; dnnl::engine::kind engine_kind = dnnl::engine::kind::cpu; From 579b9e0d5ae0c8e06a46c0903845831a5178be44 Mon Sep 17 00:00:00 2001 From: "Wang, Zhitao" Date: Fri, 30 Aug 2024 02:50:33 +0000 Subject: [PATCH 14/19] tests: benchdnn: graph: accept apply_to_int for fpmath mode setting --- tests/benchdnn/graph/deserialize.cpp | 10 ++++-- tests/benchdnn/graph/deserialize.hpp | 14 ++++++--- tests/benchdnn/graph/flex_rewrite.cpp | 9 ++++-- tests/benchdnn/graph/flex_rewrite.hpp | 6 ++-- tests/benchdnn/graph/graph.cpp | 17 +++++----- tests/benchdnn/graph/graph.hpp | 16 +++++----- tests/benchdnn/graph/parser.cpp | 40 ++++++++++++++++++++++-- tests/benchdnn/graph/parser.hpp | 3 +- tests/benchdnn/graph/setting_handler.cpp | 3 +- tests/benchdnn/graph/utils.hpp | 20 ++++++++++++ 10 files changed, 107 insertions(+), 31 deletions(-) diff --git a/tests/benchdnn/graph/deserialize.cpp b/tests/benchdnn/graph/deserialize.cpp index d4540a653c5..d1bf7e12bd2 100644 --- a/tests/benchdnn/graph/deserialize.cpp +++ b/tests/benchdnn/graph/deserialize.cpp @@ -245,6 +245,8 @@ void deserialized_graph::load(const std::string &pass_config_json) { helper.declare_field("version", &version_); helper.declare_field("engine_kind", &engine_kind_); helper.declare_field("fpmath_mode", &fpmath_mode_); + helper.declare_field( + "fpmath_mode_apply_to_int", &fpmath_mode_apply_to_int_); 
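// (Editorial note, not part of the patch: with this field, the header of a
// serialized graph JSON now carries the new flag next to "fpmath_mode", e.g.
//   "version": "3.0.0",
//   "engine_kind": "cpu",
//   "fpmath_mode": "strict",
//   "fpmath_mode_apply_to_int": "false",
// as in the updated pattern JSON files later in this series.)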
helper.declare_field("input_ports", &input_ports_); helper.declare_field("output_ports", &output_ports_); helper.read_fields(&read); @@ -432,9 +434,13 @@ std::string deserialized_graph::get_string() const { } dnnl::graph::graph deserialized_graph::to_graph( - dnnl::fpmath_mode fpmath_mode) const { + const graph_fpmath_mode_t &fpmath_mode) const { const auto &engine = get_graph_engine(); - dnnl::graph::graph g(engine.get_kind(), fpmath_mode); + dnnl::graph::graph g(engine.get_kind()); + g.set_fpmath_mode(static_cast( + str2fpmath_mode(fpmath_mode.mode_.c_str())), + fpmath_mode.apply_to_int_); + for (const auto &aop : ops_) { try { g.add_op(aop.create()); diff --git a/tests/benchdnn/graph/deserialize.hpp b/tests/benchdnn/graph/deserialize.hpp index 074564957c9..6ddb28e74e9 100644 --- a/tests/benchdnn/graph/deserialize.hpp +++ b/tests/benchdnn/graph/deserialize.hpp @@ -68,6 +68,7 @@ struct deserialized_op { std::string name_; std::string kind_; std::string fpmath_mode_; + std::string fpmath_mode_apply_to_int_; std::unordered_map attrs_; std::vector in_lts_; @@ -108,7 +109,7 @@ using op_ref_list_t = std::list; struct deserialized_graph { void load(const std::string &pass_config_json); - dnnl::graph::graph to_graph(dnnl::fpmath_mode fpmath_mode) const; + dnnl::graph::graph to_graph(const graph_fpmath_mode_t &fpmath_mode) const; const std::vector &get_input_ports() const { return input_ports_; }; std::vector ops_; @@ -131,16 +132,21 @@ struct deserialized_graph { std::string get_string() const; // Return the fpmath mode attribute - const std::string &get_fpmath_mode() const { return fpmath_mode_; } + const std::pair get_fpmath_mode() const { + return std::make_pair( + fpmath_mode_, str2bool(fpmath_mode_apply_to_int_.c_str())); + } - void set_fpmath_mode(const std::string &fpmath_mode) { - fpmath_mode_ = fpmath_mode; + void set_fpmath_mode(const graph_fpmath_mode_t &fpmath_mode) { + fpmath_mode_ = fpmath_mode.mode_; + fpmath_mode_apply_to_int_ = bool2str(fpmath_mode.apply_to_int_); } private: std::string engine_kind_; std::string version_; std::string fpmath_mode_; + std::string fpmath_mode_apply_to_int_; std::vector input_ports_; std::vector output_ports_; diff --git a/tests/benchdnn/graph/flex_rewrite.cpp b/tests/benchdnn/graph/flex_rewrite.cpp index 8f79a1c4238..5a24c54d3cf 100644 --- a/tests/benchdnn/graph/flex_rewrite.cpp +++ b/tests/benchdnn/graph/flex_rewrite.cpp @@ -1065,12 +1065,15 @@ void flex_rewrite::dt_rewrite(deserialized_graph &dgraph) { void flex_rewrite::graph_attrs_rewrite(deserialized_graph &dgraph) { - // if the fpmath mode is specified by users through cml - if (fpmath_mode_ != "default") dgraph.set_fpmath_mode(fpmath_mode_); + // if the fpmath mode is specified by users through cml, replace the fpmath + // mode from JSON file with the value from cml. 
+ if (fpmath_mode_.override_json_value_) + dgraph.set_fpmath_mode(fpmath_mode_); for (auto &aop : dgraph.ops_) { // save the graph-level config for ops - aop.fpmath_mode_ = dgraph.get_fpmath_mode(); + const auto graph_fpmath = dgraph.get_fpmath_mode(); + aop.fpmath_mode_ = graph_fpmath.first; + aop.fpmath_mode_apply_to_int_ = bool2str(graph_fpmath.second); } } diff --git a/tests/benchdnn/graph/flex_rewrite.hpp b/tests/benchdnn/graph/flex_rewrite.hpp index 200f053b985..325bf6c2934 100644 --- a/tests/benchdnn/graph/flex_rewrite.hpp +++ b/tests/benchdnn/graph/flex_rewrite.hpp @@ -21,14 +21,16 @@ #include #include "deserialize.hpp" +#include "utils.hpp" namespace graph { struct flex_rewrite { flex_rewrite(const std::map &in_shapes, const std::map &op_attrs, - const std::string &fpmath_mode, const int64_t mb, + const graph_fpmath_mode_t &fpmath_mode, const int64_t mb, const dnnl_data_type_t dt) + : in_shapes_(in_shapes) , op_attrs_(op_attrs) , fpmath_mode_(fpmath_mode) @@ -42,7 +44,7 @@ struct flex_rewrite { std::map in_shapes_; // input attributes from CML std::map op_attrs_; - std::string fpmath_mode_; + graph_fpmath_mode_t fpmath_mode_; int64_t mb_; dnnl_data_type_t dt_; diff --git a/tests/benchdnn/graph/graph.cpp b/tests/benchdnn/graph/graph.cpp index 13f06ca9a53..ddc3acfbcd4 100644 --- a/tests/benchdnn/graph/graph.cpp +++ b/tests/benchdnn/graph/graph.cpp @@ -342,8 +342,9 @@ using namespace dnnl::graph; std::string case_to_str(const std::string &json_file, const std::map &in_shapes, const std::map &op_attrs, - const std::string &fpmath_mode, const size_t expected_n_partitions, - const int64_t mb, const dnnl_data_type_t dt) { + const graph_fpmath_mode_t &fpmath_mode, + const size_t expected_n_partitions, const int64_t mb, + const dnnl_data_type_t dt) { std::stringstream s; dump_global_params(s); @@ -374,8 +375,10 @@ std::string case_to_str(const std::string &json_file, s << " "; } - if (strcmp(fpmath_mode.c_str(), "default") != 0) { - s << "--attr-fpmath=" << fpmath_mode << " "; + if (fpmath_mode.override_json_value_) { + s << "--attr-fpmath=" << fpmath_mode.mode_.c_str(); + if (fpmath_mode.apply_to_int_) { s << ":true"; } + s << " "; } if (expected_n_partitions != 0) { @@ -423,12 +426,12 @@ void skip_unimplemented_ops(const dnnl::graph::partition &partition, } void skip_unimplemented_graph_attribute( - const dnnl::fpmath_mode &fpmath_mode, res_t *res) { + const graph_fpmath_mode_t &fpmath_mode, res_t *res) { // Compiler backend only supports strict and bf16 for floating-point math // mode if (is_gc_backend()) { - if (fpmath_mode != dnnl::fpmath_mode::strict - && fpmath_mode != dnnl::fpmath_mode::bf16) { + const auto &mode = fpmath_mode.mode_; + if (mode != "strict" && mode != "bf16") { res->state = SKIPPED; res->reason = skip_reason::case_not_supported; return; diff --git a/tests/benchdnn/graph/graph.hpp b/tests/benchdnn/graph/graph.hpp index 79e7ed8d8f9..6fffffd8a81 100644 --- a/tests/benchdnn/graph/graph.hpp +++ b/tests/benchdnn/graph/graph.hpp @@ -50,8 +50,7 @@ struct settings_t : public base_settings_t { // `0` means not specified by user with command line knob, will skip // the partition num check. std::vector expected_n_partition_vec {0}; - // `default` means not specified by user with command line knob.
- std::vector fpmath_mode_vec {"default"}; + std::vector fpmath_mode_vec {graph_fpmath_mode_t {}}; std::vector dt {dnnl_data_type_undef}; const char *perf_template_csv @@ -68,21 +67,22 @@ struct prb_t { prb_t(const deserialized_graph &dg, const size_t &expected_n_partition) : dg(dg), expected_n_partition(expected_n_partition) { - const std::string &fpmath_mode = dg.get_fpmath_mode(); - this->fpmath_mode = static_cast( - str2fpmath_mode(fpmath_mode.c_str())); + const auto &fpmath = dg.get_fpmath_mode(); + fpmath_mode.mode_ = fpmath.first; + fpmath_mode.apply_to_int_ = fpmath.second; } deserialized_graph dg; size_t expected_n_partition; - dnnl::fpmath_mode fpmath_mode; + graph_fpmath_mode_t fpmath_mode; }; std::string case_to_str(const std::string &json_file, const std::map &in_shapes, const std::map &op_attrs, - const std::string &fpmath_mode, const size_t expected_n_partitions, - const int64_t mb, const dnnl_data_type_t dt); + const graph_fpmath_mode_t &fpmath_mode, + const size_t expected_n_partitions, const int64_t mb, + const dnnl_data_type_t dt); struct perf_report_t : public base_perf_report_t { perf_report_t(const std::string case_str, const char *perf_template) diff --git a/tests/benchdnn/graph/parser.cpp b/tests/benchdnn/graph/parser.cpp index af042ac60e7..17cfff9e113 100644 --- a/tests/benchdnn/graph/parser.cpp +++ b/tests/benchdnn/graph/parser.cpp @@ -17,6 +17,7 @@ #include "utils/parser.hpp" #include "parser.hpp" +#include "utils.hpp" namespace graph { @@ -68,6 +69,15 @@ void parse_key_value(std::vector> &res_v, res_v.push_back(key_val_case); } } + +// Copy-pasted from utils::parser. Refer to documentation there. +std::string get_substr(const std::string &s, size_t &start_pos, char delim) { + auto end_pos = s.find_first_of(delim, start_pos); + auto sub = s.substr(start_pos, end_pos - start_pos); + start_pos = end_pos + (end_pos != eol); + return sub; +} + } // namespace bool parse_input_shapes( @@ -115,7 +125,7 @@ bool parse_graph_expected_n_partitions( } bool parse_graph_fpmath_mode( - std::vector &fpmath_mode_vec, const char *str) { + std::vector &fpmath_mode_vec, const char *str) { std::string graph_attrs_str; if (!parse_string(graph_attrs_str, str, "attr-fpmath")) return false; @@ -123,10 +133,34 @@ bool parse_graph_fpmath_mode( std::string mode; while (std::getline(ss, mode, ',')) { if (!mode.empty()) { + // override_json_value == false indicates that the fpmath mode is + // not from the cml knob. 
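// (Editorial illustration, not part of the patch: each comma-separated entry
// may now carry an optional ":BOOL" suffix, e.g.
//   --attr-fpmath=bf16        -> {mode_ = "bf16", apply_to_int_ = false}
//   --attr-fpmath=bf16:true   -> {mode_ = "bf16", apply_to_int_ = true}
//   --attr-fpmath=strict:false,tf32:false -> two separate test cases, as in
//   the updated harness_int8_all file later in this series.)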
if (fpmath_mode_vec.size() == 1 - && fpmath_mode_vec.front() == "default") + && !fpmath_mode_vec.front().override_json_value_) fpmath_mode_vec.pop_back(); - fpmath_mode_vec.emplace_back(mode); + + size_t start_pos = 0; + auto subs = get_substr(mode, start_pos, ':'); + if (start_pos != std::string::npos && start_pos >= mode.size()) { + BENCHDNN_PRINT(0, "%s \'%s\'\n", + "Error: dangling symbol at the end of input", + mode.c_str()); + SAFE_V(FAIL); + } + + bool apply_to_int = false; + if (start_pos != std::string::npos) { + subs = get_substr(mode, start_pos, '\0'); + if (start_pos != std::string::npos) { + BENCHDNN_PRINT(0, "%s \'%s\'\n", + "Error: dangling symbol at the end of input", + mode.c_str()); + SAFE_V(FAIL); + } + apply_to_int = str2bool(subs.c_str()); + } + fpmath_mode_vec.emplace_back( + mode, apply_to_int, /* override_json_value = */ true); } } return true; diff --git a/tests/benchdnn/graph/parser.hpp b/tests/benchdnn/graph/parser.hpp index 2417aae74f5..31355a8567f 100644 --- a/tests/benchdnn/graph/parser.hpp +++ b/tests/benchdnn/graph/parser.hpp @@ -23,6 +23,7 @@ #include "allocator.hpp" #include "dnnl_common.hpp" #include "oneapi/dnnl/dnnl_graph.hpp" +#include "utils.hpp" extern dnnl_engine_kind_t engine_tgt_kind; @@ -39,7 +40,7 @@ bool parse_graph_expected_n_partitions( std::vector &expected_n_partition_vec, const char *str); bool parse_graph_fpmath_mode( - std::vector &fpmath_mode_vec, const char *str); + std::vector &fpmath_mode_vec, const char *str); bool parse_input_file(std::string &json_file, const char *str); diff --git a/tests/benchdnn/graph/setting_handler.cpp b/tests/benchdnn/graph/setting_handler.cpp index ceffb45d5c5..93b483d20b8 100644 --- a/tests/benchdnn/graph/setting_handler.cpp +++ b/tests/benchdnn/graph/setting_handler.cpp @@ -77,7 +77,8 @@ bool get_graph_attr(const deserialized_op &base_op_ref, attr_t::fpmath_mode_t &arg_fpmath_mode) { const auto &fpmath_mode = base_op_ref.fpmath_mode_; - arg_fpmath_mode.set(str2fpmath_mode(fpmath_mode.c_str())); + arg_fpmath_mode.set(str2fpmath_mode(fpmath_mode.c_str()), + str2bool(base_op_ref.fpmath_mode_apply_to_int_.c_str())); return true; } diff --git a/tests/benchdnn/graph/utils.hpp b/tests/benchdnn/graph/utils.hpp index cb0a76caa90..25a31c6ed5e 100644 --- a/tests/benchdnn/graph/utils.hpp +++ b/tests/benchdnn/graph/utils.hpp @@ -220,5 +220,25 @@ inline double GB(double bytes) { return bytes / powf(2, 30); } +struct graph_fpmath_mode_t { + graph_fpmath_mode_t() = default; + graph_fpmath_mode_t(const std::string &mode, bool apply_to_int, + bool override_json_value) + : mode_(mode) + , apply_to_int_(apply_to_int) + , override_json_value_(override_json_value) {} + + bool operator==(const graph_fpmath_mode_t &rhs) const { + return mode_ == rhs.mode_ && apply_to_int_ == rhs.apply_to_int_ + && override_json_value_ == rhs.override_json_value_; + } + + std::string mode_ = "strict"; + bool apply_to_int_ = false; + // Since fpmath_mode doesn't provide an "undef" value that would indicate + // it was not set externally to the json case, need to maintain this flag. 
+ bool override_json_value_ = false; +}; } // namespace graph #endif From 218064840ca1636bfd5c71b9545c9b7c918de922 Mon Sep 17 00:00:00 2001 From: "Wang, Zhitao" Date: Thu, 26 Sep 2024 05:49:42 +0000 Subject: [PATCH 15/19] doc: graph: add description for set fpmath mode API --- .../programming_model/graph_basic_concepts.md | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/doc/graph/programming_model/graph_basic_concepts.md b/doc/graph/programming_model/graph_basic_concepts.md index 5ee5eaa558b..b2f75349d1f 100644 --- a/doc/graph/programming_model/graph_basic_concepts.md +++ b/doc/graph/programming_model/graph_basic_concepts.md @@ -41,13 +41,19 @@ tensor as the edge between them. ## Graph `Graph` (@ref dnnl::graph::graph) contains a set of operations. A graph object -is associated to a specific engine kind (@ref dnnl::engine::kind). Multiple -operations can be added (@ref dnnl::graph::graph::add_op) along with input and -output logical tensors to a graph. After finishing adding operations, -finalization API (@ref dnnl::graph::graph::finalize) can be called to indicate -that the graph is ready for partitioning. By calling partitioning API (@ref -dnnl::graph::graph::get_partitions), a group of partitions from the graph will -be returned . +is associated with a specific engine kind (@ref dnnl::engine::kind). In addition, +you can set the graph-level floating-point math mode through the setter API +(@ref dnnl::graph::graph::set_fpmath_mode) or in the constructor. The API +accepts two parameters: the floating-point math mode and an optional boolean +flag that indicates whether to use floating-point arithmetic for integral +operations. + +Multiple operations can be added (@ref dnnl::graph::graph::add_op) along with +input and output logical tensors to a graph. After all operations are added, +the finalization API (@ref dnnl::graph::graph::finalize) can be called to +indicate that the graph is ready for partitioning. By calling the partitioning +API (@ref dnnl::graph::graph::get_partitions), a group of partitions from the +graph will be returned.
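To make the new attribute concrete, a minimal sketch of the setter/getter round trip (illustrative only; it mirrors the unit test added later in this series):

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

void fpmath_mode_example() {
    using namespace dnnl::graph;
    graph g(dnnl::engine::kind::cpu);

    // Request bf16 down-conversion and extend it to integral operations.
    g.set_fpmath_mode(dnnl::fpmath_mode::bf16, /* apply_to_int = */ true);

    // Both values round-trip through the underlying C API.
    dnnl::fpmath_mode mode;
    bool apply_to_int;
    g.get_fpmath_mode(mode, apply_to_int);
    // mode == dnnl::fpmath_mode::bf16, apply_to_int == true
}
```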
## Partition From 76ee665a4382ca82bb9d9e81424c980cc1009e12 Mon Sep 17 00:00:00 2001 From: "Wang, Zhitao" Date: Tue, 22 Oct 2024 01:51:17 +0000 Subject: [PATCH 16/19] tests: benchdnn: graph: inputs: add test cases for fpmath mode with new flag --- tests/benchdnn/inputs/graph/pattern/harness_int8_all | 6 +++--- .../graph/pattern/int8/int8_bf16_conv_add_fusion.json | 1 + .../inputs/graph/pattern/int8/int8_bf16_matmul.json | 1 + .../graph/pattern/int8/int8_maxpool_add_mul_fusion.json | 1 + 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/benchdnn/inputs/graph/pattern/harness_int8_all b/tests/benchdnn/inputs/graph/pattern/harness_int8_all index d4fe03d5d0d..807d77c6681 100644 --- a/tests/benchdnn/inputs/graph/pattern/harness_int8_all +++ b/tests/benchdnn/inputs/graph/pattern/harness_int8_all @@ -1,8 +1,8 @@ ---reset --case=pattern/int8/int8_bf16_matmul.json +--reset --attr-fpmath=strict:false,bf16:false,tf32:false --case=pattern/int8/int8_bf16_matmul.json --reset --case=pattern/int8/int8_bf16_matmul_mul_add_fusion.json --reset --case=pattern/int8/int8_bf16_matmul_post_ops_fusion.json --reset --case=pattern/int8/int8_concat_fusion.json ---reset --case=pattern/int8/int8_conv_bias_fusion.json +--reset --attr-fpmath=strict:false,bf16:false,tf32:false --case=pattern/int8/int8_conv_bias_fusion.json --reset --case=pattern/int8/int8_conv_post_ops_fusion.json --reset --case=pattern/int8/int8_conv_post_ops_int8_add_fusion.json --reset --case=pattern/int8/int8_convtranspose_post_ops_fusion.json @@ -12,7 +12,7 @@ --reset --case=pattern/int8/int8_matmul_sum_add_mul_relu.json --reset --case=pattern/int8/int8_bf16_matmul_add_mul_relu.json --reset --case=pattern/int8/int8_bf16_matmul_sum_add_mul_relu.json ---reset --case=pattern/int8/int8_avgpool_reshape_fusion.json +--reset --attr-fpmath=strict:false,bf16:false,tf32:false --case=pattern/int8/int8_avgpool_reshape_fusion.json --reset --case=pattern/int8/int8_avgpool_transpose_fusion.json --reset --case=pattern/int8/int8_bf16_conv_add_relu_mul.json --reset --case=pattern/int8/int8_bf16_matmul_tc_add_quant_fusion.json diff --git a/tests/benchdnn/inputs/graph/pattern/int8/int8_bf16_conv_add_fusion.json b/tests/benchdnn/inputs/graph/pattern/int8/int8_bf16_conv_add_fusion.json index 1e732f6631f..2444a1966ce 100644 --- a/tests/benchdnn/inputs/graph/pattern/int8/int8_bf16_conv_add_fusion.json +++ b/tests/benchdnn/inputs/graph/pattern/int8/int8_bf16_conv_add_fusion.json @@ -2,6 +2,7 @@ "version": "3.0.0", "engine_kind": "cpu", "fpmath_mode": "strict", + "fpmath_mode_apply_to_int": "false", "graph": [ { "id": 2065, diff --git a/tests/benchdnn/inputs/graph/pattern/int8/int8_bf16_matmul.json b/tests/benchdnn/inputs/graph/pattern/int8/int8_bf16_matmul.json index 865a470f52b..101308454f1 100644 --- a/tests/benchdnn/inputs/graph/pattern/int8/int8_bf16_matmul.json +++ b/tests/benchdnn/inputs/graph/pattern/int8/int8_bf16_matmul.json @@ -2,6 +2,7 @@ "version": "3.0.0", "engine_kind": "cpu", "fpmath_mode": "strict", + "fpmath_mode_apply_to_int": "false", "graph": [ { "id": 0, diff --git a/tests/benchdnn/inputs/graph/pattern/int8/int8_maxpool_add_mul_fusion.json b/tests/benchdnn/inputs/graph/pattern/int8/int8_maxpool_add_mul_fusion.json index 56363610437..9a71813e57f 100644 --- a/tests/benchdnn/inputs/graph/pattern/int8/int8_maxpool_add_mul_fusion.json +++ b/tests/benchdnn/inputs/graph/pattern/int8/int8_maxpool_add_mul_fusion.json @@ -2,6 +2,7 @@ "version": "3.2.0", "engine_kind": "cpu", "fpmath_mode": "strict", + "fpmath_mode_apply_to_int": "false", "graph": [ 
{ "id": 2065, From 0222cd54fb048496045e00217268f5aa3377808f Mon Sep 17 00:00:00 2001 From: "Gu, Yonghao" Date: Tue, 29 Oct 2024 13:31:44 +0000 Subject: [PATCH 17/19] graph: backend: fix the micro_kernel check for GQA --- src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp index f20d4b57b4c..6d08675c8fb 100644 --- a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp +++ b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp @@ -169,8 +169,12 @@ status_t sdp_primitive_config_t::initial_check( if (in_lt.data_type != dnnl_data_type_t::dnnl_f16) return status::unimplemented; - auto find_graph_inport = [&inputs](const std::shared_ptr &val) { + auto find_graph_inport = [&inputs](std::shared_ptr val) { for (int i = 0; i < (int)inputs.size(); i++) { + // For GQA, it has producer such as static_reshape. + while (val->has_producer()) { + val = val->get_producer().get_input_value(0); + } if (val->get_logical_tensor().id == inputs[i].id) { return i; } } // If the corresponding input is not found, return an invalid value From 24058ecd4e7e6091a58a3e36bad1e3e4022a5c2d Mon Sep 17 00:00:00 2001 From: "Gu, Yonghao" Date: Wed, 30 Oct 2024 01:42:40 +0000 Subject: [PATCH 18/19] benchdnn: graph: fix reshape+matmul data filling --- tests/benchdnn/graph/input_displacer.cpp | 33 ++++++++++++------------ 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/tests/benchdnn/graph/input_displacer.cpp b/tests/benchdnn/graph/input_displacer.cpp index 4c6f5876ce8..f3f9d3aa708 100644 --- a/tests/benchdnn/graph/input_displacer.cpp +++ b/tests/benchdnn/graph/input_displacer.cpp @@ -109,24 +109,23 @@ partition_data_displacer_t::partition_data_displacer_t( filling_type_t::quantization)); break; } - - if (parent_op->kind_ == "StaticReshape") { - // StaticReshape is accepted when the pattern is - // "StaticReshape + Matmul" and it doesn't have any - // predecessors in the partition - const auto &parent_op_in_lt = parent_op->in_lts_[0]; - const auto &prev_parent_op - = dg_->get_op_by_out_lt(parent_op_in_lt.id_); - if (prev_parent_op.empty() - || op_ids_set_.find(prev_parent_op.id_) - == op_ids_set_.end()) { - if (aop.kind_ == "MatMul") { - quantize_displace_.emplace(parent_op_in_lt.id_, - std::make_tuple(aop, i, parent_op_in_lt, - filling_type_t::quantization)); - } - break; + } + if (parent_op->kind_ == "StaticReshape") { + // StaticReshape is accepted when the pattern is + // "StaticReshape + Matmul" and it doesn't have any + // predecessors in the partition + const auto &parent_op_in_lt = parent_op->in_lts_[0]; + const auto &prev_parent_op + = dg_->get_op_by_out_lt(parent_op_in_lt.id_); + if (prev_parent_op.empty() + || op_ids_set_.find(prev_parent_op.id_) + == op_ids_set_.end()) { + if (aop.kind_ == "MatMul") { + quantize_displace_.emplace(parent_op_in_lt.id_, + std::make_tuple(aop, i, parent_op_in_lt, + filling_type_t::quantization)); } + break; } } // Continue only on allowed ops. 
From 1fe8ee54b18c764d32932d21e776a86f46a6d0cf Mon Sep 17 00:00:00 2001 From: Ye Tao Date: Tue, 15 Oct 2024 15:18:08 +0000 Subject: [PATCH 19/19] cpu: aarch64: fix acl matmul dim guard for 4d tensor broadcast Signed-off-by: Ye Tao --- src/cpu/aarch64/matmul/acl_matmul_utils.cpp | 16 ++++++++++++++++ tests/benchdnn/inputs/matmul/shapes_4d | 1 + 2 files changed, 17 insertions(+) diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp index a921422ac0b..9b23f9a0c29 100644 --- a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp +++ b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp @@ -47,10 +47,26 @@ status_t init_conf_matmul(acl_matmul_conf_t &amp, memory_desc_t &src_md, // e.g., when ab in abcd is 1x1 bool batch_ok = IMPLICATION(src_batch > 1, wei_batch == 1) && IMPLICATION(wei_batch > 1, src_batch == 1); + ACL_CHECK_SUPPORT(src_d.ndims() == 4 && src_batch != wei_batch && !batch_ok, "matmul broadcast supported only for 3D shapes and 4D shapes when " "ab is 1x1"); + if (src_d.ndims() == 4 && src_batch == wei_batch + && src_d.dims()[0] != wei_d.dims()[0]) { // 4D broadcast occurred + if (src_d.dims()[0] == 1 && wei_d.dims()[0] != 1) { // Broadcast src + ACL_CHECK_SUPPORT( + IMPLICATION(src_d.dims()[1] != 1, wei_d.dims()[1] == 1), + "acl only broadcasts one of src or wei at once"); + } + + if (wei_d.dims()[0] == 1 && src_d.dims()[0] != 1) { // Broadcast wei + ACL_CHECK_SUPPORT( + IMPLICATION(src_d.dims()[1] == 1, wei_d.dims()[1] != 1), + "acl only broadcasts one of src or wei at once"); + } + } + // ACL does not support bias bool with_bias = md.bias_desc.format_kind != format_kind::undef; ACL_CHECK_SUPPORT(with_bias, "ACL does not support bias for matmul"); diff --git a/tests/benchdnn/inputs/matmul/shapes_4d b/tests/benchdnn/inputs/matmul/shapes_4d index 7a8aa33de14..924b6607d6c 100644 --- a/tests/benchdnn/inputs/matmul/shapes_4d +++ b/tests/benchdnn/inputs/matmul/shapes_4d @@ -18,5 +18,6 @@ 74x16x54x64:74x16x64x54 1x1x35x64:113x16x64x35 1x16x38x64:105x1x64x38 +1x3x35x64:3x1x64x35 74x16x54x64:1x1x64x54n"B_full_bcast" 74x6x1x253:1x1x253x1n"dot_prod_w_B_full_bcast"
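(Editorial worked example for the guard above, not part of the series: the added shape 1x3x35x64:3x1x64x35 has equal batch products (1*3 == 3*1) but different leading batch dims, so the new 4D-broadcast branch runs. Here src broadcasts along the first batch dimension (1 -> 3) while wei broadcasts along the second (1 -> 3); since ACL can broadcast only one of src or wei at a time, init_conf_matmul now returns unimplemented for this case — which previously slipped past the guard — and oneDNN falls back to a different matmul implementation.)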