diff --git a/ggml/include/ggml-cann.h b/ggml/include/ggml-cann.h deleted file mode 100644 index ca73211fe..000000000 --- a/ggml/include/ggml-cann.h +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright (c) 2023-2024 The ggml authors - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS - * IN THE SOFTWARE. - */ - -#pragma once - -#include "ggml-backend.h" -#include "ggml.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/** - * @brief Maximum number of CANN devices supported. - */ -#define GGML_CANN_MAX_DEVICES 16 - -/** - * @brief Initializes the CANN backend for a specified device. - * - * This function initializes the CANN backend for the given device. - * It verifies the device index, allocates a context, and creates a backend - * instance. - * - * @param device The index of the device to initialize. - * @return A pointer to the initialized backend instance, or nullptr on failure. - */ -GGML_API GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device); - -/** - * @brief Checks if a given backend is a CANN backend. - * - * This function verifies if the provided backend is a CANN backend by comparing - * its GUID with the CANN backend's GUID. - * - * @param backend The backend instance to check. - * @return True if the backend is a CANN backend, false otherwise. - */ -GGML_API GGML_CALL bool ggml_backend_is_cann(ggml_backend_t backend); - -/** - * @brief Retrieves the CANN buffer type for a specified device. - * - * This function initializes and returns the buffer type interface associated - * with the given device. It ensures thread-safe access using a mutex. - * - * @param device The device index for which to retrieve the buffer type. - * @return A pointer to the buffer type interface for the specified device, or - * nullptr if the device index is out of range. - */ -GGML_API GGML_CALL ggml_backend_buffer_type_t -ggml_backend_cann_buffer_type(int32_t device); - -/** - * @brief Retrieves the number of CANN devices available. - * - * This function returns the number of CANN devices available based on - * information obtained from `ggml_cann_info()`. - * - * @return The number of CANN devices available. - */ -GGML_API GGML_CALL int32_t ggml_backend_cann_get_device_count(void); - -/** - * @brief Retrieves the description of a specific CANN device. - * - * This function sets the specified device, retrieves the SoC name, - * and writes it into the provided description buffer. - * - * @param device The device index to retrieve the description for. - * @param description Pointer to a buffer where the description will be written. - * @param description_size Size of the description buffer. - */ -GGML_API GGML_CALL void ggml_backend_cann_get_device_description( - int32_t device, char* description, size_t description_size); - -/** - * @brief Retrieves the memory information of a specific CANN device. - * - * This function sets the specified device, retrieves the free and total - * memory information of the specified type (ACL_HBM_MEM), and stores them - * in the provided pointers. - * - * @param device The device index to retrieve memory information for. - * @param free Pointer to a variable where the free memory size will be stored. - * @param total Pointer to a variable where the total memory size will be - * stored. - */ -GGML_API GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device, - size_t* free, - size_t* total); - -/** - * @brief Set the logging callback for GGML. - * - * This function sets the logging callback and user data for logging. - * - * @param log_callback The logging callback to set. - * @param user_data User data to pass to the logging callback. - */ -GGML_API void ggml_backend_cann_log_set_callback(ggml_log_callback log_callback, - void* user_data); - -#ifdef __cplusplus -} -#endif diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c deleted file mode 100644 index af53dea17..000000000 --- a/ggml/src/ggml-aarch64.c +++ /dev/null @@ -1,2193 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd. -#define GGML_COMMON_IMPL_C -#include "ggml-common.h" - -#include "ggml-quants.h" -#include "ggml-impl.h" - -#include -#include -#include -#include -#include // for qsort -#include // for GGML_ASSERT - -#include "ggml-aarch64.h" - -#if defined(__GNUC__) -#pragma GCC diagnostic ignored "-Woverlength-strings" -#endif - -#define UNUSED GGML_UNUSED - -// Functions to create the interleaved data layout formats - -// interleave 4 block_q4_0s in blocks of blck_size_interleave -// returns an interleaved block_q4_0x4 -// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks -// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave -// -// - in : an array of block_q4_0 pointers -// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of -// blck_size_interleave bytes -// - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes -// from bias offset form to pure sign form (this saves subtract -// operations durin unpacking) -// -static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { - block_q4_0x4 out; - - for (int i = 0; i < 4; i++) { - out.d[i] = in[i].d; - } - - for (int i = 0; i < QK4_0 * 2; i++) { - int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave; - int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave; - src_offset += (i % blck_size_interleave); - - out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; - } - - return out; -} - -// interleave 8 block_q4_0s in blocks of blck_size_interleave -// returns an interleaved block_q4_0x8 -// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks -// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave -static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { - block_q4_0x8 out; - - for (int i = 0; i < 8; i++) { - out.d[i] = in[i].d; - } - - for (int i = 0; i < QK4_0 * 4; i++) { - int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave; - int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave; - src_offset += (i % blck_size_interleave); - - out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; - } - - return out; -} - -void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - block_q8_0x4 * restrict y = (block_q8_0x4 *) vy; - -#if defined(__ARM_NEON) - float32x4_t srcv[4][8]; - float id[4]; - - for (int i = 0; i < nb; i++) { - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int row_iter = 0; row_iter < 4; row_iter++) { - for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); - - for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); - for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); - for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < 8; j++) { - float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]); - int32x4_t vi = vcvtnq_s32_f32(v); - y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[1][j], id[1]); - vi = vcvtnq_s32_f32(v); - y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0); - y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1); - y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2); - y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[2][j], id[2]); - vi = vcvtnq_s32_f32(v); - y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0); - y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1); - y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2); - y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[3][j], id[3]); - vi = vcvtnq_s32_f32(v); - y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0); - y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1); - y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2); - y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3); - } - } -#else - // scalar - const int blck_size_interleave = 4; - float srcv[4][QK8_0]; - float id[4]; - - for (int i = 0; i < nb; i++) { - for (int row_iter = 0; row_iter < 4; row_iter++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; - amax = MAX(amax, fabsf(srcv[row_iter][j])); - } - - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < QK8_0 * 4; j++) { - int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; - int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; - src_offset += (j % blck_size_interleave); - - float x0 = srcv[src_id][src_offset] * id[src_id]; - y[i].qs[j] = roundf(x0); - } - } -#endif -} - -void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - block_q8_0x4 * restrict y = (block_q8_0x4 *) vy; - -#if defined(__ARM_NEON) - float32x4_t srcv[4][8]; - float id[4]; - - for (int i = 0; i < nb; i++) { - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int row_iter = 0; row_iter < 4; row_iter++) { - for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); - - for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); - for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); - for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < 4; j++) { - float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]); - int32x4_t vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3); - v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[1][2 * j], id[1]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3); - v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[2][2 * j], id[2]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3); - v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[3][2 * j], id[3]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3); - v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3); - } - } -#else - // scalar - const int blck_size_interleave = 8; - float srcv[4][QK8_0]; - float id[4]; - - for (int i = 0; i < nb; i++) { - for (int row_iter = 0; row_iter < 4; row_iter++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; - amax = MAX(amax, fabsf(srcv[row_iter][j])); - } - - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < QK8_0 * 4; j++) { - int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; - int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; - src_offset += (j % blck_size_interleave); - - float x0 = srcv[src_id][src_offset] * id[src_id]; - y[i].qs[j] = roundf(x0); - } - } -#endif -} - -void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) { - assert(nrow == 4); - UNUSED(nrow); - if (blck_size_interleave == 4) { - quantize_q8_0_4x4(x, vy, n_per_row); - } else if (blck_size_interleave == 8) { - quantize_q8_0_4x8(x, vy, n_per_row); - } else { - assert(false); - } -} - -static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) { - assert(n_per_row % QK4_0 == 0); - const int nb = n_per_row / QK4_0; - - void * out_ptr = NULL; - if (nrows_interleaved == 8) { - out_ptr = (block_q4_0x8 *) dst; - } - else if (nrows_interleaved == 4) { - out_ptr = (block_q4_0x4 *) dst; - } - assert(nrows_interleaved <= 8); - block_q4_0 dst_tmp[8]; - - for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) { - - for (int64_t x = 0; x < nb; x++) { - - for (int i = 0; i < nrows_interleaved; i++ ) { - quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0); - } - - if (nrows_interleaved == 8) { - *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88); - out_ptr = (block_q4_0x8 *) out_ptr + 1; - } - else if (nrows_interleaved == 4) { - *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88); - out_ptr = (block_q4_0x4 *) out_ptr + 1; - } - } - } - - return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0)); -} - -size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - if (!quant_weights) { - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4); - } - else { - assert(false); - return 0; - } -} - -size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - if (!quant_weights) { - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8); - } - else { - assert(false); - return 0; - } -} - -size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - if (!quant_weights) { - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8); - } - else { - assert(false); - return 0; - } -} - -void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined(__ARM_FEATURE_SVE) - if (svcntw() == 8) { - GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) && - "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); - } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) && - "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance"); -#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - - __asm__ __volatile__( - "movi v31.16b, #0x4\n" - "movi v30.16b, #0xf0\n" - "add %x[b_ptr], %x[b_ptr], #0x8\n" - "1:" // Column loop - "add x22, %x[a_ptr], #0x2\n" - "movi v29.16b, #0x0\n" - "mov x21, %x[nb]\n" - "2:" // Block loop - "ldr q28, [%x[b_ptr], #0x0]\n" - "ldr q27, [x22, #0x0]\n" - "movi v26.4s, #0x0\n" - "sub x20, x22, #0x2\n" - "ldr q25, [x22, #0x10]\n" - "ldr q24, [%x[b_ptr], #0x10]\n" - "sub x21, x21, #0x1\n" - "add x22, x22, #0x22\n" - "ldr q23, [%x[b_ptr], #0x20]\n" - "ldr q22, [%x[b_ptr], #0x30]\n" - "ld1r { v21.8h }, [x20]\n" - "ldr q20, [%x[b_ptr], #-0x8]\n" - "sshl v16.16b, v28.16b, v31.16b\n" - "and v28.16b, v28.16b, v30.16b\n" - "sshl v19.16b, v24.16b, v31.16b\n" - "and v24.16b, v24.16b, v30.16b\n" - "add %x[b_ptr], %x[b_ptr], #0x48\n" - "sshl v18.16b, v23.16b, v31.16b\n" - "and v23.16b, v23.16b, v30.16b\n" - ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n" - "sshl v17.16b, v22.16b, v31.16b\n" - "and v22.16b, v22.16b, v30.16b\n" - "fcvtl v21.4s, v21.4h\n" - "fcvtl v16.4s, v20.4h\n" - ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n" - "fmul v16.4s, v16.4s, v21.4s\n" - ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n" - ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n" - ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n" - ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n" - ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n" - ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "fmla v29.4s, v26.4s, v16.4s\n" - "cbnz x21, 2b\n" - "sub %x[nc], %x[nc], #0x4\n" - "str q29, [%x[res_ptr], #0x0]\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "cbnz %x[nc], 1b\n" - : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) - : [a_ptr] "r" (a_ptr), [nb] "r" (nb) - : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22" - ); -#else - float sumf[4]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; - } - sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -#endif -} - -void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined(__ARM_FEATURE_SVE) - if (svcntw() == 8) { - GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) && - "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); - } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - - __asm__ __volatile__( - "movi v2.16b, #0x4\n" - "movi v1.16b, #0xf0\n" - "add %x[b_ptr], %x[b_ptr], #0x8\n" - "1:" // Column loop - "add x23, %x[a_ptr], #0x2\n" - "movi v0.16b, #0x0\n" - "mov x22, %x[nb]\n" - "2:" // Block loop - "ldr q31, [%x[b_ptr], #0x0]\n" - "ldr q30, [%x[b_ptr], #0x10]\n" - "mov x21, x23\n" - "movi v29.4s, #0x0\n" - "ldr q28, [%x[b_ptr], #0x20]\n" - "ldr q27, [%x[b_ptr], #0x30]\n" - "movi v26.4s, #0x0\n" - "sub x20, x23, #0x2\n" - "ld1r { v25.8h }, [x20]\n" - "ldr q24, [%x[b_ptr], #-0x8]\n" - "sub x22, x22, #0x1\n" - "add x23, x23, #0x22\n" - "ld1r { v23.2d }, [x21], #0x8\n" - "sshl v22.16b, v31.16b, v2.16b\n" - "sshl v16.16b, v30.16b, v2.16b\n" - "add %x[b_ptr], %x[b_ptr], #0x48\n" - "ld1r { v21.2d }, [x21], #0x8\n" - "sshl v20.16b, v28.16b, v2.16b\n" - "sshl v19.16b, v27.16b, v2.16b\n" - "ld1r { v18.2d }, [x21], #0x8\n" - "ld1r { v17.2d }, [x21], #0x8\n" - "and v31.16b, v31.16b, v1.16b\n" - "and v30.16b, v30.16b, v1.16b\n" - ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n" - ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n" - "and v28.16b, v28.16b, v1.16b\n" - "and v27.16b, v27.16b, v1.16b\n" - "fcvtl v25.4s, v25.4h\n" - "fcvtl v16.4s, v24.4h\n" - ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n" - ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n" - "fmul v16.4s, v16.4s, v25.4s\n" - ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n" - ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n" - ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n" - ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n" - "addp v29.4s, v29.4s, v26.4s\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "fmla v0.4s, v29.4s, v16.4s\n" - "cbnz x22, 2b\n" - "sub %x[nc], %x[nc], #0x4\n" - "str q0, [%x[res_ptr], #0x0]\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "cbnz %x[nc], 1b\n" - : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) - : [a_ptr] "r" (a_ptr), [nb] "r" (nb) - : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23" - ); -#elif defined(__ARM_NEON) && defined(__aarch64__) - GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " - "performance"); -#else - float sumf[4]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; - } - sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -#endif -} - -void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined(__ARM_FEATURE_SVE) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - if (svcntw() == 8) { - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - - __asm__ __volatile__( - "ptrue p0.b\n" - "add %x[b_ptr], %x[b_ptr], #0x10\n" - "1:" // Column loop - "add x22, %x[a_ptr], #0x2\n" - "mov z31.b, #0x0\n" - "mov x21, %x[nb]\n" - "2:" // Block loop - "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n" - "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n" - "mov z28.s, #0x0\n" - "mov z27.s, #0x0\n" - "ld1rd { z26.d }, p0/Z, [x22]\n" - "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n" - "sub x20, x22, #0x2\n" - "sub x21, x21, #0x1\n" - "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n" - "ld1rd { z23.d }, p0/Z, [x22, #8]\n" - "lsl z22.b, z30.b, #0x4\n" - "lsl z16.b, z29.b, #0x4\n" - "and z30.b, z30.b, #0xf0\n" - "and z29.b, z29.b, #0xf0\n" - "ld1rd { z21.d }, p0/Z, [x22, #16]\n" - "ld1rd { z20.d }, p0/Z, [x22, #24]\n" - "lsl z19.b, z25.b, #0x4\n" - "and z25.b, z25.b, #0xf0\n" - "ld1rh { z17.h }, p0/Z, [x20]\n" - "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n" - "sdot z28.s, z22.b, z26.b\n" - "sdot z27.s, z16.b, z26.b\n" - "lsl z16.b, z24.b, #0x4\n" - "add x22, x22, #0x22\n" - "and z24.b, z24.b, #0xf0\n" - "add %x[b_ptr], %x[b_ptr], #0x90\n" - "fcvt z17.s, p0/m, z17.h\n" - "fcvt z18.s, p0/m, z18.h\n" - "sdot z28.s, z19.b, z23.b\n" - "sdot z27.s, z16.b, z23.b\n" - "fmul z18.s, z18.s, z17.s\n" - "sdot z28.s, z30.b, z21.b\n" - "sdot z27.s, z29.b, z21.b\n" - "sdot z28.s, z25.b, z20.b\n" - "sdot z27.s, z24.b, z20.b\n" - "uzp1 z17.s, z28.s, z27.s\n" - "uzp2 z16.s, z28.s, z27.s\n" - "add z17.s, z17.s, z16.s\n" - "asr z17.s, z17.s, #0x4\n" - "scvtf z17.s, p0/m, z17.s\n" - "fmla z31.s, p0/M, z17.s, z18.s\n" - "cbnz x21, 2b\n" - "sub %x[nc], %x[nc], #0x8\n" - "st1w { z31.s }, p0, [%x[res_ptr]]\n" - "add %x[res_ptr], %x[res_ptr], #0x20\n" - "cbnz %x[nc], 1b\n" - : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) - : [a_ptr] "r" (a_ptr), [nb] "r" (nb) - : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" - ); - return; - } - else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { - GGML_ASSERT((ggml_cpu_has_sve() && (svcntw() == 8)) && - "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal " - "performance"); - } - else if (ggml_cpu_has_neon()) { - GGML_ASSERT(((ggml_cpu_has_sve() && (svcntw() == 8)) || ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 " - "quantization format for optimal performance"); - } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - GGML_ASSERT(ggml_cpu_has_sve() && - "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance"); -#elif defined(__ARM_NEON) && defined(__aarch64__) - GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " - "performance"); -#else - float sumf[8]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; - } - sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -#endif -} - -void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) - if (svcntw() == 8) { - GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) && - "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); - } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) && - "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance"); -#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - size_t res_stride = bs * sizeof(float); - - __asm__ __volatile__( - "mov x10, %x[nr]\n" - "mov x9, #0x88\n" - "cmp x10, #0x10\n" - "mul x9, %x[nb], x9\n" - "blt 4f\n" - "1:" // Row loop - "add x28, %x[b_ptr], #0x8\n" - "mov x27, %x[nc]\n" - "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" - "2:" // Column loop - "add x25, %x[a_ptr], #0x8\n" - "movi v15.16b, #0x0\n" - "movi v19.16b, #0x0\n" - "mov x24, %x[nb]\n" - "add x23, x25, x9\n" - "movi v18.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "add x22, x23, x9\n" - "movi v11.16b, #0x0\n" - "movi v13.16b, #0x0\n" - "add x21, x22, x9\n" - "movi v23.16b, #0x0\n" - "movi v16.16b, #0x0\n" - "movi v25.16b, #0x0\n" - "movi v7.16b, #0x0\n" - "movi v0.16b, #0x0\n" - "movi v4.16b, #0x0\n" - "movi v5.16b, #0x0\n" - "movi v21.16b, #0x0\n" - "movi v8.16b, #0x0\n" - "movi v1.16b, #0x0\n" - "3:" // Block loop - "ldr q3, [x28, #0x0]\n" - "ldr q31, [x25, #0x0]\n" - "movi v28.16b, #0x4\n" - "movi v10.4s, #0x0\n" - "ldr q22, [x28, #0x10]\n" - "ldr q6, [x25, #0x10]\n" - "movi v29.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "ldr q27, [x28, #0x20]\n" - "ldr q30, [x28, #0x30]\n" - "movi v20.4s, #0x0\n" - "movi v24.16b, #0xf0\n" - "ldr d2, [x25, #-0x8]\n" - "ldr d26, [x23, #-0x8]\n" - "sshl v12.16b, v3.16b, v28.16b\n" - "sub x20, x28, #0x8\n" - "ldr d17, [x20, #0x0]\n" - "and v3.16b, v3.16b, v24.16b\n" - "subs x24, x24, #0x1\n" - "add x28, x28, #0x48\n" - ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n" - ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n" - ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n" - ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n" - "sshl v31.16b, v22.16b, v28.16b\n" - "and v22.16b, v22.16b, v24.16b\n" - "fcvtl v17.4s, v17.4h\n" - "fcvtl v2.4s, v2.4h\n" - "fcvtl v26.4s, v26.4h\n" - ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n" - ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n" - ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n" - ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n" - "sshl v6.16b, v27.16b, v28.16b\n" - "sshl v28.16b, v30.16b, v28.16b\n" - "and v27.16b, v27.16b, v24.16b\n" - "and v30.16b, v30.16b, v24.16b\n" - "ldr q24, [x25, #0x20]\n" - ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x30]\n" - ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n" - ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n" - ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n" - ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x40]\n" - ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x50]\n" - ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n" - ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n" - ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n" - ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x60]\n" - ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x70]\n" - "add x25, x25, #0x88\n" - ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n" - ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n" - ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n" - ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n" - "fmul v24.4s, v17.4s, v2.s[0]\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v15.4s, v10.4s, v24.4s\n" - "ldr q24, [x23, #0x0]\n" - "fmul v10.4s, v17.4s, v2.s[1]\n" - "fmla v19.4s, v29.4s, v10.4s\n" - "ldr q10, [x23, #0x10]\n" - "fmul v29.4s, v17.4s, v2.s[2]\n" - "fmul v2.4s, v17.4s, v2.s[3]\n" - "fmla v18.4s, v9.4s, v29.4s\n" - "movi v9.4s, #0x0\n" - "movi v29.4s, #0x0\n" - ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n" - "fmla v14.4s, v20.4s, v2.4s\n" - "movi v20.4s, #0x0\n" - "movi v2.4s, #0x0\n" - ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x20]\n" - ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n" - ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n" - ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n" - ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x30]\n" - ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x40]\n" - ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n" - ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n" - ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n" - ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x50]\n" - ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x60]\n" - ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n" - ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n" - ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n" - ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x70]\n" - "add x23, x23, #0x88\n" - ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x0]\n" - ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n" - ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n" - ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n" - ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n" - "fmul v10.4s, v17.4s, v26.s[0]\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "fmla v11.4s, v9.4s, v10.4s\n" - "ldr q9, [x22, #0x10]\n" - "fmul v10.4s, v17.4s, v26.s[1]\n" - "fmla v13.4s, v29.4s, v10.4s\n" - "ldr d29, [x22, #-0x8]\n" - "fmul v10.4s, v17.4s, v26.s[2]\n" - "fmul v26.4s, v17.4s, v26.s[3]\n" - "fcvtl v29.4s, v29.4h\n" - "fmla v23.4s, v20.4s, v10.4s\n" - "movi v20.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "fmla v16.4s, v2.4s, v26.4s\n" - "movi v26.4s, #0x0\n" - "movi v2.4s, #0x0\n" - ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" - ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x20]\n" - ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n" - ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" - ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n" - ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n" - "ldr q9, [x22, #0x30]\n" - ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x40]\n" - ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n" - ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" - ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n" - ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n" - "ldr q9, [x22, #0x50]\n" - ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x60]\n" - ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n" - ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" - ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n" - ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n" - "ldr q9, [x22, #0x70]\n" - "add x22, x22, #0x88\n" - ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x21, #0x0]\n" - ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n" - ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n" - ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n" - ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n" - "fmul v9.4s, v17.4s, v29.s[0]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "fmla v25.4s, v20.4s, v9.4s\n" - "ldr q9, [x21, #0x10]\n" - "fmul v20.4s, v17.4s, v29.s[1]\n" - "fmla v7.4s, v10.4s, v20.4s\n" - "ldr d20, [x21, #-0x8]\n" - "fmul v10.4s, v17.4s, v29.s[2]\n" - "fmul v29.4s, v17.4s, v29.s[3]\n" - "fcvtl v20.4s, v20.4h\n" - "fmla v0.4s, v26.4s, v10.4s\n" - "movi v26.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "fmla v4.4s, v2.4s, v29.4s\n" - "movi v2.4s, #0x0\n" - "movi v29.4s, #0x0\n" - ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" - ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n" - "ldr q12, [x21, #0x20]\n" - "fmul v24.4s, v17.4s, v20.s[0]\n" - ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n" - ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" - ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n" - ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n" - "ldr q9, [x21, #0x30]\n" - "fmul v31.4s, v17.4s, v20.s[1]\n" - ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n" - ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n" - ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n" - ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n" - "ldr q12, [x21, #0x40]\n" - "fmul v6.4s, v17.4s, v20.s[2]\n" - "fmul v20.4s, v17.4s, v20.s[3]\n" - ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n" - ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" - ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n" - ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n" - "ldr q9, [x21, #0x50]\n" - ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n" - ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n" - ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n" - ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n" - "ldr q12, [x21, #0x60]\n" - ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n" - ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" - ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n" - ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n" - "ldr q17, [x21, #0x70]\n" - "add x21, x21, #0x88\n" - ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n" - ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n" - ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n" - ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n" - ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n" - ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n" - ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n" - ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "fmla v5.4s, v26.4s, v24.4s\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "fmla v21.4s, v10.4s, v31.4s\n" - "fmla v8.4s, v2.4s, v6.4s\n" - "fmla v1.4s, v29.4s, v20.4s\n" - "bgt 3b\n" - "mov x20, %x[res_ptr]\n" - "subs x27, x27, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "str q15, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q19, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q18, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q14, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q11, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q13, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q16, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q7, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q0, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q4, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q5, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q21, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q8, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q1, [x20, #0x0]\n" - "bne 2b\n" - "mov x20, #0x4\n" - "sub x10, x10, #0x10\n" - "cmp x10, #0x10\n" - "mov %x[res_ptr], x26\n" - "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" - "bge 1b\n" - "4:" // Row loop skip - "cbz x10, 9f\n" - "5:" // Row tail: Row loop - "add x24, %x[b_ptr], #0x8\n" - "mov x23, %x[nc]\n" - "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" - "6:" // Row tail: Column loop - "movi v15.16b, #0x0\n" - "movi v19.16b, #0x0\n" - "add x25, %x[a_ptr], #0x8\n" - "mov x21, %x[nb]\n" - "movi v18.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "7:" // Row tail: Block loop - "ldr q7, [x24, #0x0]\n" - "ldr q5, [x25, #0x0]\n" - "movi v9.16b, #0x4\n" - "movi v4.4s, #0x0\n" - "ldr q3, [x24, #0x10]\n" - "ldr q2, [x25, #0x10]\n" - "movi v1.4s, #0x0\n" - "movi v0.4s, #0x0\n" - "ldr q13, [x24, #0x20]\n" - "ldr q31, [x25, #0x20]\n" - "movi v30.4s, #0x0\n" - "movi v29.16b, #0xf0\n" - "ldr q28, [x24, #0x30]\n" - "ldr q27, [x25, #0x30]\n" - "sshl v20.16b, v7.16b, v9.16b\n" - "sub x20, x24, #0x8\n" - "ldr q26, [x25, #0x40]\n" - "ldr q25, [x25, #0x50]\n" - "sshl v17.16b, v3.16b, v9.16b\n" - "and v7.16b, v7.16b, v29.16b\n" - "ldr q24, [x25, #0x60]\n" - "ldr q16, [x25, #0x70]\n" - "sshl v22.16b, v13.16b, v9.16b\n" - "and v3.16b, v3.16b, v29.16b\n" - "ldr d21, [x20, #0x0]\n" - "ldr d12, [x25, #-0x8]\n" - ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n" - ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n" - ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n" - ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n" - "sshl v9.16b, v28.16b, v9.16b\n" - "subs x21, x21, #0x1\n" - "and v13.16b, v13.16b, v29.16b\n" - "and v28.16b, v28.16b, v29.16b\n" - "add x25, x25, #0x88\n" - "add x24, x24, #0x48\n" - "fcvtl v21.4s, v21.4h\n" - "fcvtl v12.4s, v12.4h\n" - ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n" - ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n" - ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n" - ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n" - "fmul v11.4s, v21.4s, v12.s[0]\n" - "fmul v23.4s, v21.4s, v12.s[1]\n" - "fmul v17.4s, v21.4s, v12.s[2]\n" - ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n" - "fmul v6.4s, v21.4s, v12.s[3]\n" - ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n" - ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n" - ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n" - ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n" - ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n" - ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n" - ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n" - ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n" - ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n" - ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n" - ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n" - ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n" - ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n" - ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n" - ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n" - ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n" - ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n" - ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n" - ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n" - ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n" - ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n" - ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n" - ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n" - "scvtf v4.4s, v4.4s, #0x4\n" - "scvtf v1.4s, v1.4s, #0x4\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "fmla v15.4s, v4.4s, v11.4s\n" - "scvtf v30.4s, v30.4s, #0x4\n" - "fmla v19.4s, v1.4s, v23.4s\n" - "fmla v18.4s, v0.4s, v17.4s\n" - "fmla v14.4s, v30.4s, v6.4s\n" - "bgt 7b\n" - "mov x20, %x[res_ptr]\n" - "cmp x10, #0x1\n" - "str q15, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x2\n" - "str q19, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x3\n" - "str q18, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "str q14, [x20, #0x0]\n" - "8:" // Row tail: Accumulator store skip - "subs x23, x23, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "bne 6b\n" - "subs x10, x10, #0x4\n" - "add %x[a_ptr], %x[a_ptr], x9\n" - "mov %x[res_ptr], x22\n" - "bgt 5b\n" - "9:" // Row tail: Row loop skip - : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) - : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" - ); -#else - float sumf[4][4]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; - } - sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } -#endif -} - -void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) - if (svcntw() == 8) { - GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) && - "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); - } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - size_t res_stride = bs * sizeof(float); - - __asm__ __volatile__( - "mov x10, %x[nr]\n" - "mov x9, #0x88\n" - "cmp x10, #0x10\n" - "mul x9, %x[nb], x9\n" - "blt 4f\n" - "1:" // Row loop - "add x28, %x[b_ptr], #0x8\n" - "mov x27, %x[nc]\n" - "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" - "2:" // Column loop - "add x25, %x[a_ptr], #0x8\n" - "movi v2.16b, #0x0\n" - "movi v10.16b, #0x0\n" - "mov x24, %x[nb]\n" - "add x23, x25, x9\n" - "movi v12.16b, #0x0\n" - "movi v28.16b, #0x0\n" - "add x22, x23, x9\n" - "movi v11.16b, #0x0\n" - "movi v13.16b, #0x0\n" - "add x21, x22, x9\n" - "movi v22.16b, #0x0\n" - "movi v23.16b, #0x0\n" - "movi v25.16b, #0x0\n" - "movi v5.16b, #0x0\n" - "movi v7.16b, #0x0\n" - "movi v4.16b, #0x0\n" - "movi v6.16b, #0x0\n" - "movi v30.16b, #0x0\n" - "movi v24.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "3:" // Block loop - "ldr q21, [x28, #0x0]\n" - "ldr q16, [x28, #0x10]\n" - "movi v1.16b, #0x4\n" - "movi v19.4s, #0x0\n" - "ldr q27, [x25, #0x0]\n" - "ldr q15, [x25, #0x10]\n" - "movi v26.4s, #0x0\n" - "movi v18.4s, #0x0\n" - "ldr q29, [x28, #0x20]\n" - "ldr q3, [x28, #0x30]\n" - "movi v17.4s, #0x0\n" - "movi v0.16b, #0xf0\n" - "ldr d20, [x25, #-0x8]\n" - "ldr d9, [x23, #-0x8]\n" - "sshl v8.16b, v21.16b, v1.16b\n" - "sshl v31.16b, v16.16b, v1.16b\n" - "and v21.16b, v21.16b, v0.16b\n" - "and v16.16b, v16.16b, v0.16b\n" - "sub x20, x28, #0x8\n" - "subs x24, x24, #0x1\n" - "add x28, x28, #0x48\n" - ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n" - ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n" - "ldr q27, [x25, #0x20]\n" - ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n" - ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n" - "sshl v15.16b, v29.16b, v1.16b\n" - "sshl v1.16b, v3.16b, v1.16b\n" - "and v29.16b, v29.16b, v0.16b\n" - "and v3.16b, v3.16b, v0.16b\n" - "ldr q0, [x25, #0x30]\n" - "fcvtl v20.4s, v20.4h\n" - ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n" - "fcvtl v9.4s, v9.4h\n" - ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n" - "ldr q27, [x25, #0x40]\n" - ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n" - ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n" - "ldr q0, [x25, #0x50]\n" - ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n" - ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n" - "ldr q27, [x25, #0x60]\n" - ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n" - ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n" - "ldr q0, [x25, #0x70]\n" - "add x25, x25, #0x88\n" - ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n" - ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n" - "ldr d27, [x20, #0x0]\n" - ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n" - ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n" - "fcvtl v27.4s, v27.4h\n" - "uzp1 v0.2d, v19.2d, v26.2d\n" - "uzp2 v26.2d, v19.2d, v26.2d\n" - "fmul v19.4s, v27.4s, v20.s[0]\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "fmla v2.4s, v0.4s, v19.4s\n" - "ldr q19, [x23, #0x0]\n" - "uzp1 v0.2d, v18.2d, v17.2d\n" - "uzp2 v18.2d, v18.2d, v17.2d\n" - "fmul v17.4s, v27.4s, v20.s[1]\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "fmla v10.4s, v26.4s, v17.4s\n" - "ldr q17, [x23, #0x10]\n" - "fmul v26.4s, v27.4s, v20.s[2]\n" - "fmul v20.4s, v27.4s, v20.s[3]\n" - "fmla v12.4s, v0.4s, v26.4s\n" - "ldr d0, [x22, #-0x8]\n" - "ldr d26, [x21, #-0x8]\n" - "fcvtl v0.4s, v0.4h\n" - "fmla v28.4s, v18.4s, v20.4s\n" - "movi v20.4s, #0x0\n" - "movi v18.4s, #0x0\n" - ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" - ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" - "ldr q19, [x23, #0x20]\n" - "fcvtl v26.4s, v26.4h\n" - ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" - ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" - "ldr q19, [x23, #0x40]\n" - ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" - ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" - "ldr q19, [x23, #0x60]\n" - ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n" - ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n" - "uzp1 v19.2d, v20.2d, v18.2d\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "uzp2 v20.2d, v20.2d, v18.2d\n" - "fmul v18.4s, v27.4s, v9.s[0]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v11.4s, v19.4s, v18.4s\n" - "ldr q18, [x22, #0x0]\n" - "fmul v19.4s, v27.4s, v9.s[1]\n" - "fmla v13.4s, v20.4s, v19.4s\n" - "movi v19.4s, #0x0\n" - "movi v20.4s, #0x0\n" - ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n" - ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n" - "ldr q17, [x23, #0x30]\n" - ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n" - ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n" - "ldr q17, [x23, #0x50]\n" - ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n" - ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n" - "ldr q17, [x23, #0x70]\n" - "add x23, x23, #0x88\n" - ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n" - ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n" - "uzp1 v17.2d, v19.2d, v20.2d\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "uzp2 v20.2d, v19.2d, v20.2d\n" - "fmul v19.4s, v27.4s, v9.s[2]\n" - "fmul v9.4s, v27.4s, v9.s[3]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v22.4s, v17.4s, v19.4s\n" - "ldr q17, [x22, #0x10]\n" - "movi v19.4s, #0x0\n" - ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n" - "fmla v23.4s, v20.4s, v9.4s\n" - "movi v20.4s, #0x0\n" - "movi v9.4s, #0x0\n" - ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n" - "ldr q18, [x22, #0x20]\n" - ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" - ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n" - ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n" - "ldr q18, [x22, #0x40]\n" - ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n" - ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n" - "ldr q18, [x22, #0x60]\n" - ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n" - ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n" - "movi v18.4s, #0x0\n" - ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n" - "ldr q17, [x22, #0x30]\n" - ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" - ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n" - "ldr q17, [x22, #0x50]\n" - ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n" - ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n" - "ldr q17, [x22, #0x70]\n" - "add x22, x22, #0x88\n" - ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n" - ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n" - "uzp1 v17.2d, v19.2d, v20.2d\n" - "uzp2 v20.2d, v19.2d, v20.2d\n" - "fmul v19.4s, v27.4s, v0.s[0]\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v25.4s, v17.4s, v19.4s\n" - "ldr q19, [x21, #0x0]\n" - "fmul v17.4s, v27.4s, v0.s[1]\n" - "fmla v5.4s, v20.4s, v17.4s\n" - "ldr q17, [x21, #0x10]\n" - "uzp1 v20.2d, v9.2d, v18.2d\n" - "uzp2 v9.2d, v9.2d, v18.2d\n" - "fmul v18.4s, v27.4s, v0.s[2]\n" - "fmul v0.4s, v27.4s, v0.s[3]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "fmla v7.4s, v20.4s, v18.4s\n" - "movi v20.4s, #0x0\n" - "movi v18.4s, #0x0\n" - ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" - ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" - "ldr q19, [x21, #0x20]\n" - "fmla v4.4s, v9.4s, v0.4s\n" - "movi v9.4s, #0x0\n" - "movi v0.4s, #0x0\n" - ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" - "fmul v8.4s, v27.4s, v26.s[0]\n" - ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n" - "ldr q17, [x21, #0x30]\n" - ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" - "fmul v31.4s, v27.4s, v26.s[1]\n" - ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" - "ldr q19, [x21, #0x40]\n" - ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" - "fmul v15.4s, v27.4s, v26.s[2]\n" - "fmul v27.4s, v27.4s, v26.s[3]\n" - ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n" - "ldr q1, [x21, #0x50]\n" - ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" - ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" - "ldr q26, [x21, #0x60]\n" - ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n" - ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n" - "ldr q21, [x21, #0x70]\n" - "add x21, x21, #0x88\n" - ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n" - ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n" - ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n" - ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n" - "uzp1 v29.2d, v20.2d, v18.2d\n" - "uzp2 v21.2d, v20.2d, v18.2d\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "uzp1 v18.2d, v9.2d, v0.2d\n" - "uzp2 v16.2d, v9.2d, v0.2d\n" - "scvtf v21.4s, v21.4s, #0x4\n" - "fmla v6.4s, v29.4s, v8.4s\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "fmla v30.4s, v21.4s, v31.4s\n" - "fmla v24.4s, v18.4s, v15.4s\n" - "fmla v14.4s, v16.4s, v27.4s\n" - "bgt 3b\n" - "mov x20, %x[res_ptr]\n" - "subs x27, x27, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "str q2, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q10, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q12, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q28, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q11, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q13, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q22, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q5, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q7, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q4, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q6, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q30, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q24, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q14, [x20, #0x0]\n" - "bne 2b\n" - "mov x20, #0x4\n" - "sub x10, x10, #0x10\n" - "cmp x10, #0x10\n" - "mov %x[res_ptr], x26\n" - "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" - "bge 1b\n" - "4:" // Row loop skip - "cbz x10, 9f\n" - "5:" // Row tail: Row loop - "add x24, %x[b_ptr], #0x8\n" - "mov x23, %x[nc]\n" - "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" - "6:" // Row tail: Column loop - "movi v2.16b, #0x0\n" - "movi v10.16b, #0x0\n" - "add x25, %x[a_ptr], #0x8\n" - "mov x21, %x[nb]\n" - "movi v12.16b, #0x0\n" - "movi v28.16b, #0x0\n" - "7:" // Row tail: Block loop - "ldr q6, [x24, #0x0]\n" - "ldr q5, [x24, #0x10]\n" - "movi v17.16b, #0x4\n" - "movi v8.4s, #0x0\n" - "ldr q4, [x25, #0x0]\n" - "ldr q13, [x25, #0x10]\n" - "movi v27.4s, #0x0\n" - "movi v0.4s, #0x0\n" - "ldr q31, [x24, #0x20]\n" - "ldr q14, [x24, #0x30]\n" - "movi v29.4s, #0x0\n" - "movi v22.16b, #0xf0\n" - "ldr q11, [x25, #0x20]\n" - "ldr q23, [x25, #0x30]\n" - "sshl v21.16b, v6.16b, v17.16b\n" - "sshl v16.16b, v5.16b, v17.16b\n" - "ldr q20, [x25, #0x40]\n" - "ldr q26, [x25, #0x50]\n" - "and v6.16b, v6.16b, v22.16b\n" - "and v5.16b, v5.16b, v22.16b\n" - "ldr q25, [x25, #0x60]\n" - "ldr q3, [x25, #0x70]\n" - "sshl v19.16b, v31.16b, v17.16b\n" - "sshl v18.16b, v14.16b, v17.16b\n" - "ldr d17, [x25, #-0x8]\n" - ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n" - ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n" - "and v31.16b, v31.16b, v22.16b\n" - ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n" - ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n" - "and v14.16b, v14.16b, v22.16b\n" - "sub x20, x24, #0x8\n" - "ldr d16, [x20, #0x0]\n" - "subs x21, x21, #0x1\n" - "add x25, x25, #0x88\n" - "fcvtl v17.4s, v17.4h\n" - "add x24, x24, #0x48\n" - ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n" - ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n" - ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n" - ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n" - "fcvtl v16.4s, v16.4h\n" - ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n" - ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n" - "fmul v23.4s, v16.4s, v17.s[0]\n" - "fmul v21.4s, v16.4s, v17.s[1]\n" - "fmul v1.4s, v16.4s, v17.s[2]\n" - "fmul v20.4s, v16.4s, v17.s[3]\n" - ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n" - ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n" - ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n" - ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n" - ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n" - ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n" - "uzp1 v19.2d, v8.2d, v27.2d\n" - "uzp2 v18.2d, v8.2d, v27.2d\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "uzp1 v17.2d, v0.2d, v29.2d\n" - "uzp2 v16.2d, v0.2d, v29.2d\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "fmla v2.4s, v19.4s, v23.4s\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "fmla v10.4s, v18.4s, v21.4s\n" - "fmla v12.4s, v17.4s, v1.4s\n" - "fmla v28.4s, v16.4s, v20.4s\n" - "bgt 7b\n" - "mov x20, %x[res_ptr]\n" - "cmp x10, #0x1\n" - "str q2, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x2\n" - "str q10, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x3\n" - "str q12, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "str q28, [x20, #0x0]\n" - "8:" // Row tail: Accumulator store skip - "subs x23, x23, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "bne 6b\n" - "subs x10, x10, #0x4\n" - "add %x[a_ptr], %x[a_ptr], x9\n" - "mov %x[res_ptr], x22\n" - "bgt 5b\n" - "9:" // Row tail: Row loop skip - : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) - : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" - ); -#elif defined(__ARM_NEON) && defined(__aarch64__) - GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " - "performance"); -#else - float sumf[4][4]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; - } - sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } -#endif -} - -void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - if (svcntw() == 8) { - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - size_t res_stride = bs * sizeof(float); - - __asm__ __volatile__( - "mov x20, #0x4\n" - "mov x13, %x[nr]\n" - "mov z28.s, #-0x4\n" - "mov x12, #0x88\n" - "ptrue p1.b\n" - "whilelt p0.s, XZR, x20\n" - "cmp x13, #0x10\n" - "mul x12, %x[nb], x12\n" - "blt 4f\n" - "1:" // Row loop - "add x11, %x[b_ptr], #0x10\n" - "mov x10, %x[nc]\n" - "add x9, %x[res_ptr], %x[res_stride], LSL #4\n" - "2:" // Column loop - "add x28, %x[a_ptr], #0x8\n" - "mov z24.b, #0x0\n" - "mov z15.b, #0x0\n" - "mov x27, %x[nb]\n" - "add x26, x28, x12\n" - "mov z12.b, #0x0\n" - "mov z0.b, #0x0\n" - "add x25, x26, x12\n" - "mov z13.b, #0x0\n" - "mov z1.b, #0x0\n" - "add x24, x25, x12\n" - "mov z20.b, #0x0\n" - "mov z25.b, #0x0\n" - "mov z11.b, #0x0\n" - "mov z16.b, #0x0\n" - "mov z19.b, #0x0\n" - "mov z26.b, #0x0\n" - "mov z8.b, #0x0\n" - "mov z29.b, #0x0\n" - "mov z27.b, #0x0\n" - "mov z10.b, #0x0\n" - "3:" // Block loop - "ld1b { z30.b }, p1/Z, [x11]\n" - "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n" - "mov z18.s, #0x0\n" - "mov z7.s, #0x0\n" - "ld1rqb { z3.b }, p1/Z, [x28]\n" - "ld1rqb { z5.b }, p1/Z, [x28, #16]\n" - "mov z9.s, #0x0\n" - "mov z22.s, #0x0\n" - "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n" - "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n" - "sub x20, x11, #0x10\n" - "sub x23, x28, #0x8\n" - "lsl z31.b, z30.b, #0x4\n" - "lsl z6.b, z21.b, #0x4\n" - "ld1h { z23.s }, p1/Z, [x20]\n" - "sub x22, x26, #0x8\n" - "and z30.b, z30.b, #0xf0\n" - "and z21.b, z21.b, #0xf0\n" - "sub x21, x25, #0x8\n" - "sub x20, x24, #0x8\n" - "lsl z14.b, z4.b, #0x4\n" - "lsl z2.b, z17.b, #0x4\n" - "subs x27, x27, #0x1\n" - "add x11, x11, #0x90\n" - ".inst 0x451f9872 // smmla z18.s, z3.b, z31.b\n" - ".inst 0x45069867 // smmla z7.s, z3.b, z6.b\n" - "ld1rqb { z3.b }, p1/Z, [x28, #32]\n" - "and z4.b, z4.b, #0xf0\n" - ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" - ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" - "ld1rqb { z5.b }, p1/Z, [x28, #48]\n" - "and z17.b, z17.b, #0xf0\n" - "fcvt z23.s, p1/m, z23.h\n" - ".inst 0x450e9872 // smmla z18.s, z3.b, z14.b\n" - ".inst 0x45029867 // smmla z7.s, z3.b, z2.b\n" - "ld1rqb { z3.b }, p1/Z, [x28, #64]\n" - ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" - ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" - "ld1rqb { z5.b }, p1/Z, [x28, #80]\n" - "fscale z23.s, p1/m, z23.s, z28.s\n" - ".inst 0x451e9872 // smmla z18.s, z3.b, z30.b\n" - ".inst 0x45159867 // smmla z7.s, z3.b, z21.b\n" - "ld1rqb { z3.b }, p1/Z, [x28, #96]\n" - ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" - ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" - "ld1rqb { z5.b }, p1/Z, [x28, #112]\n" - "add x28, x28, #0x88\n" - ".inst 0x45049872 // smmla z18.s, z3.b, z4.b\n" - ".inst 0x45119867 // smmla z7.s, z3.b, z17.b\n" - "ld1h { z3.s }, p0/Z, [x23]\n" - ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" - ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" - "fcvt z3.s, p1/m, z3.h\n" - "uzp1 z5.d, z18.d, z7.d\n" - "uzp2 z18.d, z18.d, z7.d\n" - "mov z3.q, z3.q[0]\n" - "uzp1 z7.d, z9.d, z22.d\n" - "uzp2 z22.d, z9.d, z22.d\n" - "fmul z9.s, z23.s, z3.s[0]\n" - "scvtf z5.s, p1/m, z5.s\n" - "scvtf z18.s, p1/m, z18.s\n" - "scvtf z7.s, p1/m, z7.s\n" - "scvtf z22.s, p1/m, z22.s\n" - "fmla z24.s, p1/M, z5.s, z9.s\n" - "ld1rqb { z5.b }, p1/Z, [x26]\n" - "fmul z9.s, z23.s, z3.s[1]\n" - "fmla z15.s, p1/M, z18.s, z9.s\n" - "ld1rqb { z18.b }, p1/Z, [x26, #16]\n" - "fmul z9.s, z23.s, z3.s[2]\n" - "fmul z3.s, z23.s, z3.s[3]\n" - "fmla z12.s, p1/M, z7.s, z9.s\n" - "mov z9.s, #0x0\n" - "ld1h { z7.s }, p0/Z, [x22]\n" - ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" - "fmla z0.s, p1/M, z22.s, z3.s\n" - "mov z22.s, #0x0\n" - "ld1h { z3.s }, p0/Z, [x21]\n" - ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" - "ld1rqb { z5.b }, p1/Z, [x26, #32]\n" - "fcvt z7.s, p1/m, z7.h\n" - "fcvt z3.s, p1/m, z3.h\n" - ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" - ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" - "ld1rqb { z5.b }, p1/Z, [x26, #64]\n" - "mov z7.q, z7.q[0]\n" - "mov z3.q, z3.q[0]\n" - ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" - ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" - "ld1rqb { z5.b }, p1/Z, [x26, #96]\n" - ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" - ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" - "uzp1 z5.d, z9.d, z22.d\n" - "scvtf z5.s, p1/m, z5.s\n" - "uzp2 z22.d, z9.d, z22.d\n" - "fmul z9.s, z23.s, z7.s[0]\n" - "scvtf z22.s, p1/m, z22.s\n" - "fmla z13.s, p1/M, z5.s, z9.s\n" - "ld1rqb { z9.b }, p1/Z, [x25]\n" - "fmul z5.s, z23.s, z7.s[1]\n" - "fmla z1.s, p1/M, z22.s, z5.s\n" - "mov z5.s, #0x0\n" - "mov z22.s, #0x0\n" - ".inst 0x451f9a45 // smmla z5.s, z18.b, z31.b\n" - ".inst 0x45069a56 // smmla z22.s, z18.b, z6.b\n" - "ld1rqb { z18.b }, p1/Z, [x26, #48]\n" - ".inst 0x450e9a45 // smmla z5.s, z18.b, z14.b\n" - ".inst 0x45029a56 // smmla z22.s, z18.b, z2.b\n" - "ld1rqb { z18.b }, p1/Z, [x26, #80]\n" - ".inst 0x451e9a45 // smmla z5.s, z18.b, z30.b\n" - ".inst 0x45159a56 // smmla z22.s, z18.b, z21.b\n" - "ld1rqb { z18.b }, p1/Z, [x26, #112]\n" - "add x26, x26, #0x88\n" - ".inst 0x45049a45 // smmla z5.s, z18.b, z4.b\n" - ".inst 0x45119a56 // smmla z22.s, z18.b, z17.b\n" - "uzp1 z18.d, z5.d, z22.d\n" - "scvtf z18.s, p1/m, z18.s\n" - "uzp2 z22.d, z5.d, z22.d\n" - "fmul z5.s, z23.s, z7.s[2]\n" - "fmul z7.s, z23.s, z7.s[3]\n" - "scvtf z22.s, p1/m, z22.s\n" - "fmla z20.s, p1/M, z18.s, z5.s\n" - "ld1rqb { z18.b }, p1/Z, [x25, #16]\n" - "ld1h { z5.s }, p0/Z, [x20]\n" - "fcvt z5.s, p1/m, z5.h\n" - "fmla z25.s, p1/M, z22.s, z7.s\n" - "mov z22.s, #0x0\n" - "mov z7.s, #0x0\n" - ".inst 0x451f9936 // smmla z22.s, z9.b, z31.b\n" - ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" - "ld1rqb { z9.b }, p1/Z, [x25, #32]\n" - "mov z5.q, z5.q[0]\n" - ".inst 0x450e9936 // smmla z22.s, z9.b, z14.b\n" - ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" - "ld1rqb { z9.b }, p1/Z, [x25, #64]\n" - ".inst 0x451e9936 // smmla z22.s, z9.b, z30.b\n" - ".inst 0x45159927 // smmla z7.s, z9.b, z21.b\n" - "ld1rqb { z9.b }, p1/Z, [x25, #96]\n" - ".inst 0x45049936 // smmla z22.s, z9.b, z4.b\n" - ".inst 0x45119927 // smmla z7.s, z9.b, z17.b\n" - "uzp1 z9.d, z22.d, z7.d\n" - "scvtf z9.s, p1/m, z9.s\n" - "uzp2 z22.d, z22.d, z7.d\n" - "fmul z7.s, z23.s, z3.s[0]\n" - "scvtf z22.s, p1/m, z22.s\n" - "fmla z11.s, p1/M, z9.s, z7.s\n" - "ld1rqb { z9.b }, p1/Z, [x24]\n" - "fmul z7.s, z23.s, z3.s[1]\n" - "fmla z16.s, p1/M, z22.s, z7.s\n" - "mov z22.s, #0x0\n" - "mov z7.s, #0x0\n" - ".inst 0x451f9a56 // smmla z22.s, z18.b, z31.b\n" - ".inst 0x45069a47 // smmla z7.s, z18.b, z6.b\n" - "ld1rqb { z18.b }, p1/Z, [x25, #48]\n" - ".inst 0x450e9a56 // smmla z22.s, z18.b, z14.b\n" - ".inst 0x45029a47 // smmla z7.s, z18.b, z2.b\n" - "ld1rqb { z18.b }, p1/Z, [x25, #80]\n" - ".inst 0x451e9a56 // smmla z22.s, z18.b, z30.b\n" - ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" - "ld1rqb { z18.b }, p1/Z, [x25, #112]\n" - "add x25, x25, #0x88\n" - ".inst 0x45049a56 // smmla z22.s, z18.b, z4.b\n" - ".inst 0x45119a47 // smmla z7.s, z18.b, z17.b\n" - "uzp1 z18.d, z22.d, z7.d\n" - "scvtf z18.s, p1/m, z18.s\n" - "uzp2 z7.d, z22.d, z7.d\n" - "fmul z22.s, z23.s, z3.s[2]\n" - "fmul z3.s, z23.s, z3.s[3]\n" - "scvtf z7.s, p1/m, z7.s\n" - "fmla z19.s, p1/M, z18.s, z22.s\n" - "ld1rqb { z18.b }, p1/Z, [x24, #16]\n" - "fmul z22.s, z23.s, z5.s[0]\n" - "fmla z26.s, p1/M, z7.s, z3.s\n" - "mov z3.s, #0x0\n" - "mov z7.s, #0x0\n" - ".inst 0x451f9923 // smmla z3.s, z9.b, z31.b\n" - ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" - "ld1rqb { z9.b }, p1/Z, [x24, #32]\n" - ".inst 0x450e9923 // smmla z3.s, z9.b, z14.b\n" - ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" - "mov z9.s, #0x0\n" - ".inst 0x451f9a49 // smmla z9.s, z18.b, z31.b\n" - "mov z31.s, #0x0\n" - ".inst 0x45069a5f // smmla z31.s, z18.b, z6.b\n" - "ld1rqb { z6.b }, p1/Z, [x24, #48]\n" - "ld1rqb { z18.b }, p1/Z, [x24, #64]\n" - ".inst 0x450e98c9 // smmla z9.s, z6.b, z14.b\n" - "fmul z14.s, z23.s, z5.s[1]\n" - ".inst 0x450298df // smmla z31.s, z6.b, z2.b\n" - "ld1rqb { z6.b }, p1/Z, [x24, #80]\n" - "fmul z2.s, z23.s, z5.s[2]\n" - "fmul z23.s, z23.s, z5.s[3]\n" - ".inst 0x451e9a43 // smmla z3.s, z18.b, z30.b\n" - ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" - "ld1rqb { z5.b }, p1/Z, [x24, #96]\n" - ".inst 0x451e98c9 // smmla z9.s, z6.b, z30.b\n" - ".inst 0x451598df // smmla z31.s, z6.b, z21.b\n" - "ld1rqb { z18.b }, p1/Z, [x24, #112]\n" - "add x24, x24, #0x88\n" - ".inst 0x450498a3 // smmla z3.s, z5.b, z4.b\n" - ".inst 0x451198a7 // smmla z7.s, z5.b, z17.b\n" - ".inst 0x45049a49 // smmla z9.s, z18.b, z4.b\n" - ".inst 0x45119a5f // smmla z31.s, z18.b, z17.b\n" - "uzp1 z18.d, z3.d, z7.d\n" - "uzp2 z5.d, z3.d, z7.d\n" - "scvtf z18.s, p1/m, z18.s\n" - "uzp1 z6.d, z9.d, z31.d\n" - "uzp2 z9.d, z9.d, z31.d\n" - "scvtf z5.s, p1/m, z5.s\n" - "fmla z8.s, p1/M, z18.s, z22.s\n" - "scvtf z6.s, p1/m, z6.s\n" - "scvtf z9.s, p1/m, z9.s\n" - "fmla z29.s, p1/M, z5.s, z14.s\n" - "fmla z27.s, p1/M, z6.s, z2.s\n" - "fmla z10.s, p1/M, z9.s, z23.s\n" - "bgt 3b\n" - "mov x20, %x[res_ptr]\n" - "subs x10, x10, #0x8\n" - "add %x[res_ptr], %x[res_ptr], #0x20\n" - "st1w { z24.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z15.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z12.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z0.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z13.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z1.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z20.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z25.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z11.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z16.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z19.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z26.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z8.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z29.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z27.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z10.s }, p1, [x20]\n" - "bne 2b\n" - "mov x20, #0x4\n" - "sub x13, x13, #0x10\n" - "cmp x13, #0x10\n" - "mov %x[res_ptr], x9\n" - "madd %x[a_ptr], x20, x12, %x[a_ptr]\n" - "bge 1b\n" - "4:" // Row loop skip - "cbz x13, 9f\n" - "5:" // Row tail: Row loop - "add x25, %x[b_ptr], #0x10\n" - "mov x24, %x[nc]\n" - "add x23, %x[res_ptr], %x[res_stride], LSL #2\n" - "6:" // Row tail: Column loop - "mov z24.b, #0x0\n" - "mov z15.b, #0x0\n" - "add x28, %x[a_ptr], #0x8\n" - "mov x22, %x[nb]\n" - "mov z12.b, #0x0\n" - "mov z0.b, #0x0\n" - "7:" // Row tail: Block loop - "ld1b { z3.b }, p1/Z, [x25]\n" - "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n" - "mov z2.s, #0x0\n" - "mov z25.s, #0x0\n" - "ld1rqb { z26.b }, p1/Z, [x28]\n" - "ld1rqb { z21.b }, p1/Z, [x28, #16]\n" - "mov z27.s, #0x0\n" - "mov z19.s, #0x0\n" - "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n" - "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n" - "sub x21, x25, #0x10\n" - "sub x20, x28, #0x8\n" - "lsl z20.b, z3.b, #0x4\n" - "lsl z4.b, z6.b, #0x4\n" - "ld1rqb { z10.b }, p1/Z, [x28, #32]\n" - "ld1rqb { z23.b }, p1/Z, [x28, #48]\n" - "and z3.b, z3.b, #0xf0\n" - "and z6.b, z6.b, #0xf0\n" - "ld1rqb { z11.b }, p1/Z, [x28, #64]\n" - "ld1rqb { z7.b }, p1/Z, [x28, #80]\n" - "lsl z8.b, z29.b, #0x4\n" - "lsl z14.b, z16.b, #0x4\n" - "ld1rqb { z18.b }, p1/Z, [x28, #96]\n" - "ld1rqb { z30.b }, p1/Z, [x28, #112]\n" - ".inst 0x45149b42 // smmla z2.s, z26.b, z20.b\n" - ".inst 0x45049b59 // smmla z25.s, z26.b, z4.b\n" - "and z29.b, z29.b, #0xf0\n" - "ld1h { z17.s }, p1/Z, [x21]\n" - ".inst 0x45149abb // smmla z27.s, z21.b, z20.b\n" - ".inst 0x45049ab3 // smmla z19.s, z21.b, z4.b\n" - "and z16.b, z16.b, #0xf0\n" - "ld1h { z4.s }, p0/Z, [x20]\n" - "subs x22, x22, #0x1\n" - "add x28, x28, #0x88\n" - "fcvt z17.s, p1/m, z17.h\n" - "add x25, x25, #0x90\n" - ".inst 0x45089942 // smmla z2.s, z10.b, z8.b\n" - ".inst 0x450e9959 // smmla z25.s, z10.b, z14.b\n" - "fcvt z4.s, p1/m, z4.h\n" - ".inst 0x45089afb // smmla z27.s, z23.b, z8.b\n" - ".inst 0x450e9af3 // smmla z19.s, z23.b, z14.b\n" - "fscale z17.s, p1/m, z17.s, z28.s\n" - "mov z4.q, z4.q[0]\n" - ".inst 0x45039962 // smmla z2.s, z11.b, z3.b\n" - ".inst 0x45069979 // smmla z25.s, z11.b, z6.b\n" - "fmul z23.s, z17.s, z4.s[0]\n" - "fmul z9.s, z17.s, z4.s[1]\n" - "fmul z21.s, z17.s, z4.s[2]\n" - "fmul z4.s, z17.s, z4.s[3]\n" - ".inst 0x450398fb // smmla z27.s, z7.b, z3.b\n" - ".inst 0x450698f3 // smmla z19.s, z7.b, z6.b\n" - ".inst 0x451d9a42 // smmla z2.s, z18.b, z29.b\n" - ".inst 0x45109a59 // smmla z25.s, z18.b, z16.b\n" - ".inst 0x451d9bdb // smmla z27.s, z30.b, z29.b\n" - ".inst 0x45109bd3 // smmla z19.s, z30.b, z16.b\n" - "uzp1 z31.d, z2.d, z25.d\n" - "uzp2 z13.d, z2.d, z25.d\n" - "scvtf z31.s, p1/m, z31.s\n" - "uzp1 z17.d, z27.d, z19.d\n" - "uzp2 z18.d, z27.d, z19.d\n" - "scvtf z13.s, p1/m, z13.s\n" - "fmla z24.s, p1/M, z31.s, z23.s\n" - "scvtf z17.s, p1/m, z17.s\n" - "scvtf z18.s, p1/m, z18.s\n" - "fmla z15.s, p1/M, z13.s, z9.s\n" - "fmla z12.s, p1/M, z17.s, z21.s\n" - "fmla z0.s, p1/M, z18.s, z4.s\n" - "bgt 7b\n" - "mov x20, %x[res_ptr]\n" - "cmp x13, #0x1\n" - "st1w { z24.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x13, #0x2\n" - "st1w { z15.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x13, #0x3\n" - "st1w { z12.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "st1w { z0.s }, p1, [x20]\n" - "8:" // Row tail: Accumulator store skip - "subs x24, x24, #0x8\n" - "add %x[res_ptr], %x[res_ptr], #0x20\n" - "bne 6b\n" - "subs x13, x13, #0x4\n" - "add %x[a_ptr], %x[a_ptr], x12\n" - "mov %x[res_ptr], x23\n" - "bgt 5b\n" - "9:" // Row tail: Row loop skip - : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) - : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) - : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" - ); - return; - } - else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { - GGML_ASSERT((ggml_cpu_has_sve() && (svcntw() == 8)) && - "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal " - "performance"); - } - else if (ggml_cpu_has_neon()) { - GGML_ASSERT(((ggml_cpu_has_sve() && (svcntw() == 8)) || ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 " - "quantization format for optimal performance"); - } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - GGML_ASSERT(ggml_cpu_has_sve() && - "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance"); -#elif defined(__ARM_NEON) && defined(__aarch64__) - GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " - "performance"); -#else - float sumf[4][8]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; - } - sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } -#endif -} diff --git a/ggml/src/ggml-aarch64.h b/ggml/src/ggml-aarch64.h deleted file mode 100644 index 517babaf1..000000000 --- a/ggml/src/ggml-aarch64.h +++ /dev/null @@ -1,39 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd. -#pragma once - -#define GGML_COMMON_DECL_C -#include "ggml-common.h" - -#include "ggml.h" - -// GGML internal header - -#ifdef __cplusplus -extern "C" { -#endif - -// Quantization -void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); - -void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave); - -// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") -size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); - -// GEMV -void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); - -// GEMM -void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); - -#ifdef __cplusplus -} -#endif - diff --git a/ggml/src/ggml-cann.cpp b/ggml/src/ggml-cann.cpp deleted file mode 100644 index 06930ba2e..000000000 --- a/ggml/src/ggml-cann.cpp +++ /dev/null @@ -1,2020 +0,0 @@ -/* - * Copyright (c) 2023-2024 The ggml authors - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS - * IN THE SOFTWARE. - */ - -#include "ggml-cann.h" - -#include -#include - -#include -#include -#include -#include - -#include "ggml-backend-impl.h" -#include "ggml-cann/aclnn_ops.h" -#include "ggml-cann/common.h" - -#define GGML_COMMON_DECL_C - -#include "ggml-common.h" - -/** - * @brief Default logging callback for GGML. - * - * This function is the default logging callback that logs messages to stderr. - * - * @param level The log level. - * @param msg The log message. - * @param user_data User data passed to the callback. - */ -static void ggml_cann_default_log_callback(enum ggml_log_level level, - const char* msg, void* user_data) { - GGML_UNUSED(level); - GGML_UNUSED(user_data); - fprintf(stderr, "%s", msg); -} - -ggml_log_callback ggml_cann_log_callback = ggml_cann_default_log_callback; -void* ggml_cann_log_user_data = NULL; - -GGML_API void ggml_backend_cann_log_set_callback(ggml_log_callback log_callback, - void* user_data) { - ggml_cann_log_callback = log_callback; - ggml_cann_log_user_data = user_data; -} - -#define GGML_CANN_LOG_INFO(...) ggml_cann_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__) -#define GGML_CANN_LOG_WARN(...) ggml_cann_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__) -#define GGML_CANN_LOG_ERROR(...) \ - ggml_cann_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) - -GGML_ATTRIBUTE_FORMAT(2, 3) - -/** - * @brief Log a message using the current logging callback. - * - * This function formats a log message and passes it to the current logging - * callback. - * - * @param level The log level. - * @param format The format string for the log message. - * @param ... The arguments for the format string. - */ -static void ggml_cann_log(enum ggml_log_level level, const char* format, ...) { - if (ggml_cann_log_callback != NULL) { - va_list args; - va_start(args, format); - char buffer[128]; - int len = vsnprintf(buffer, 128, format, args); - if (len < 128) { - ggml_cann_log_callback(level, buffer, ggml_cann_log_user_data); - } else { - // vsnprintf adds a null terminator - std::vector buffer2(len + 1); - va_end(args); - va_start(args, format); - vsnprintf(&buffer2[0], buffer2.size(), format, args); - ggml_cann_log_callback(level, buffer2.data(), - ggml_cann_log_user_data); - } - va_end(args); - } -} - -/** - * @brief Handles CANN errors by printing an error message and aborting. - * - * @param stmt The statement that caused the error. - * @param func The function in which the error occurred. - * @param file The file in which the error occurred. - * @param line The line number where the error occurred. - * @param msg The error message. - */ -[[noreturn]] void ggml_cann_error(const char* stmt, const char* func, - const char* file, int line, const char* msg) { - int32_t id = -1; - aclrtGetDevice(&id); - - GGML_CANN_LOG_ERROR("CANN error: %s\n", msg); - GGML_CANN_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, - file, line); - GGML_CANN_LOG_ERROR(" %s\n", stmt); - // abort with GGML_ASSERT to get a stack trace - GGML_ABORT("CANN error"); -} - -/** - * @brief Sets the device to be used by CANN. - * - * @param device The device ID to set. - */ -void ggml_cann_set_device(const int32_t device) { - // TODO: uncomment these lines after empty context has fixed. - // int current_device; - // ACL_CHECK(aclrtGetDevice(¤t_device)); - - // if (device == current_device) { - // return; - // } - ACL_CHECK(aclrtSetDevice(device)); -} - -/** - * @brief Retrieves the current device ID. - * - * @return The current device ID. - */ -int32_t ggml_cann_get_device() { - int32_t id; - ACL_CHECK(aclrtGetDevice(&id)); - return id; -} - -/** - * @brief Initialize the CANN device information. - * - * This function initializes the CANN device information by obtaining the - * device count and setting the memory allocation granularity for each device. - * - * @return A structure containing the device information. - */ -static ggml_cann_device_info ggml_cann_init() { - ggml_cann_device_info info = {}; - - aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count); - - if (err != ACL_SUCCESS) { - GGML_CANN_LOG_ERROR("%s: failed to initialize CANN: %s\n", - __func__, aclGetRecentErrMsg()); - return info; - } - - GGML_ASSERT(info.device_count <= GGML_CANN_MAX_DEVICES); - - for (int id = 0; id < info.device_count; ++id) { - aclrtPhysicalMemProp prop = {}; - prop.handleType = ACL_MEM_HANDLE_TYPE_NONE; - prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED; - prop.memAttr = ACL_HBM_MEM_HUGE; - prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE; - prop.location.id = id; - prop.reserve = 0; - ACL_CHECK(aclrtMemGetAllocationGranularity( - &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED, - &info.devices[id].vmm_granularity)); - } - - // TODO: add more device info later. - return info; -} - -/** - * @brief Retrieve the CANN device information. - * - * This function returns a reference to a structure containing the CANN device - * information. The device information is initialized once and reused on - * subsequent calls. - * - * @return A reference to the structure containing the device information. - */ -const ggml_cann_device_info& ggml_cann_info() { - static ggml_cann_device_info info = ggml_cann_init(); - return info; -} - -//#define DEBUG_CANN_MALLOC -/** - * @brief A pool of CANN buffers(legacy). - * - * This class manages a pool of CANN buffers for a specific device. - */ -struct ggml_cann_pool_leg : public ggml_cann_pool { - /** - * @brief The maximum number of buffers in the pool. - */ - static const int MAX_BUFFERS = 256; - - /** - * @brief The device ID associated with this buffer pool. - */ - int device; - - /** - * @brief Structure representing a CANN buffer. - */ - struct ggml_cann_buffer { - void* ptr = nullptr; ///< Pointer to the buffer memory. - size_t size = 0; ///< Size of the buffer. - }; - - /** - * @brief Array of CANN buffers in the pool. - */ - ggml_cann_buffer buffer_pool[MAX_BUFFERS] = {}; - - /** - * @brief Total size of all buffers in the pool. - */ - size_t pool_size = 0; - - /** - * @brief Constructor to initialize the buffer pool for a specific device. - * - * @param device The device ID to associate with this buffer pool. - */ - explicit ggml_cann_pool_leg(int device) : device(device) {} - - /** - * @brief Destructor to free all buffers in the pool. - */ - ~ggml_cann_pool_leg() { - ggml_cann_set_device(device); - for (int i = 0; i < MAX_BUFFERS; ++i) { - ggml_cann_buffer& b = buffer_pool[i]; - if (b.ptr != nullptr) { - ACL_CHECK(aclrtFree(b.ptr)); - pool_size -= b.size; - } - } - GGML_ASSERT(pool_size == 0); - } - - /** - * @brief Allocate a buffer of the given size. - * - * @param size The size of the buffer to allocate. - * @param actual_size A pointer to a variable to receive the actual size of - * the allocated buffer. - * @return A pointer to the allocated buffer. - */ - void* alloc(size_t size, size_t* actual_size) override { -#ifdef DEBUG_CANN_MALLOC - int nnz = 0; - size_t max_size = 0; -#endif - size_t best_diff = 1ull << 36; - int ibest = -1; - for (int i = 0; i < MAX_BUFFERS; ++i) { - ggml_cann_buffer& b = buffer_pool[i]; - if (b.ptr != nullptr) { -#ifdef DEBUG_CANN_MALLOC - ++nnz; - if (b.size > max_size) max_size = b.size; -#endif - if (b.size >= size) { - size_t diff = b.size - size; - if (diff < best_diff) { - best_diff = diff; - ibest = i; - if (!best_diff) { - void* ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; - return ptr; - } - } - } - } - } - if (ibest >= 0) { - ggml_cann_buffer& b = buffer_pool[ibest]; - void* ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; - return ptr; - } - void* ptr; - size_t look_ahead_size = (size_t)(1.05 * size); - look_ahead_size = 256 * ((look_ahead_size + 255) / 256); - ggml_cann_set_device(device); - ACL_CHECK( - aclrtMalloc(&ptr, look_ahead_size, ACL_MEM_MALLOC_HUGE_FIRST)); - *actual_size = look_ahead_size; - pool_size += look_ahead_size; -#ifdef DEBUG_CANN_MALLOC - GGML_CANN_LOG_INFO( - "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, " - "requested %u MB\n", - __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024), - (uint32_t)(pool_size / 1024 / 1024), - (uint32_t)(size / 1024 / 1024)); -#endif - return ptr; - } - - /** - * @brief Free a buffer and return it to the pool. - * - * @param ptr Pointer to the buffer to free. - * @param size Size of the buffer to free. - */ - void free(void* ptr, size_t size) override { - for (int i = 0; i < MAX_BUFFERS; ++i) { - ggml_cann_buffer& b = buffer_pool[i]; - if (b.ptr == nullptr) { - b.ptr = ptr; - b.size = size; - return; - } - } - // memory should always buffered. these memory may still needed by - // tasks in stream. - // TODO, fix me. - GGML_ABORT("Cann buffer pool full, increase MAX_CANN_BUFFERS\n"); - } -}; - -/** - * @brief A pool of CANN buffers with virtual memory. - * - * This class manages a pool of CANN buffers with virtual memory for a specific - * device. - */ -struct ggml_cann_pool_vmm : public ggml_cann_pool { - /** - * @brief The maximum size of the virtual memory pool (32 GB). - */ - static const size_t CANN_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB - - /** - * @brief The device ID associated with this buffer pool. - */ - int device; - - /** - * @brief Pointer to the start of the virtual memory pool. - */ - void* pool_addr = 0; - - /** - * @brief Amount of virtual memory used in the pool. - */ - size_t pool_used = 0; - - /** - * @brief Total size of the virtual memory pool. - */ - size_t pool_size = 0; - - /** - * @brief Allocation granularity for the virtual memory pool. - */ - size_t granularity; - - /** - * @brief Handles for the physical memory allocated. - */ - std::vector handles; - - /** - * @brief Offsets for the mapped memory regions. - */ - std::vector map_offsets; - - /** - * @brief Constructor to initialize the buffer pool with virtual memory for - * a specific device. - * - * @param device The device ID to associate with this buffer pool. - */ - explicit ggml_cann_pool_vmm(int device) - : device(device), - granularity(ggml_cann_info().devices[device].vmm_granularity) {} - - /** - * @brief Destructor to free all buffers in the virtual memory pool. - */ - ~ggml_cann_pool_vmm() { - if (pool_addr != 0) { - for (auto& offset : map_offsets) { - ACL_CHECK(aclrtUnmapMem(offset)); - } - for (auto& handle : handles) { - ACL_CHECK(aclrtFreePhysical(handle)); - } - ACL_CHECK(aclrtReleaseMemAddress(pool_addr)); - } - } - - /** - * @brief Allocate a buffer of the given size in the virtual memory pool. - * - * @param size The size of the buffer to allocate. - * @param actual_size A pointer to a variable to receive the actual size of - * the allocated buffer. - * @return A pointer to the allocated buffer. - */ - void* alloc(size_t size, size_t* actual_size) override { - // round up the allocation size to the alignment to ensure that all - // allocations are aligned for all data types - const size_t alignment = 128; - size = alignment * ((size + alignment - 1) / alignment); - - size_t avail = pool_size - pool_used; - - if (size > avail) { - // round up to the next multiple of the granularity - size_t reserve_size = size - avail; - reserve_size = - granularity * ((reserve_size + granularity - 1) / granularity); - - GGML_ASSERT(pool_size + reserve_size <= CANN_POOL_VMM_MAX_SIZE); - - // allocate more physical memory - aclrtPhysicalMemProp prop = {}; - prop.handleType = ACL_MEM_HANDLE_TYPE_NONE; - prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED; - prop.memAttr = ACL_HBM_MEM_HUGE; - prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE; - prop.location.id = device; - prop.reserve = 0; - aclrtDrvMemHandle handle; - ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0)); - - // reserve virtual address space (if not already reserved) - if (pool_addr == 0) { - ACL_CHECK(aclrtReserveMemAddress( - &pool_addr, CANN_POOL_VMM_MAX_SIZE, 0, NULL, 1)); - } - - // map at the end of the pool - ACL_CHECK(aclrtMapMem((char*)pool_addr + pool_size, reserve_size, 0, - handle, 0)); - - handles.push_back(handle); - map_offsets.push_back((char*)pool_addr + pool_size); - - // add to the pool - pool_size += reserve_size; - - // GGML_CANN_LOG_INFO("cann pool[%d]: size increased to %llu MB ( - // reserved %llu MB)\n", - // device, (unsigned long long) (pool_size/1024/1024), - // (unsigned long long) (reserve_size/1024/1024)); - } - - GGML_ASSERT(pool_addr != 0); - - void* ptr = (void*)((char*)pool_addr + pool_used); - *actual_size = size; - pool_used += size; - -#ifdef DEBUG_CANN_MALLOC - GGML_CANN_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device, - (unsigned long long)size, (unsigned long long)ptr); -#endif - return ptr; - } - - /** - * @brief Free a buffer and return it to the virtual memory pool. - * - * @param ptr Pointer to the buffer to free. - * @param size Size of the buffer to free. - */ - void free(void* ptr, size_t size) override { -#ifdef DEBUG_CANN_MALLOC - GGML_CANN_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device, - (unsigned long long)size, (unsigned long long)ptr); -#endif - - pool_used -= size; - - // all deallocations must be in reverse order of the allocations - GGML_ASSERT(ptr == (void*)((char*)pool_addr + pool_used)); - } -}; - -/** - * @brief Create a new CANN pool for a specific device. - * - * Factory method to create a new CANN pool object based on the device type. - * - * @param device The device ID for which to create the pool. - * @return A unique pointer to the created CANN pool. - */ -std::unique_ptr ggml_backend_cann_context::new_pool_for_device( - int device) { - // return std::unique_ptr(new ggml_cann_pool_leg(device)); - return std::unique_ptr(new ggml_cann_pool_vmm(device)); -} - -// cann buffer -/** - * @brief Context for managing a CANN buffer associated with a specific device. - * - * This structure holds information about a CANN buffer, including the device - * ID, device pointer, and a name derived from GGML_CANN_NAME and the device ID. - */ -struct ggml_backend_cann_buffer_context { - int32_t device; ///< The device ID associated with this buffer context. - void* dev_ptr = - nullptr; ///< Pointer to the device memory allocated for the buffer. - - /** - * @brief Constructor to initialize the CANN buffer context. - * - * @param device The device ID associated with this buffer context. - * @param dev_ptr Pointer to the device memory allocated for the buffer. - */ - ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr) - : device(device), - dev_ptr(dev_ptr) {} - - /** - * @brief Destructor to free the device memory allocated for the buffer. - */ - ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); } -}; - -/** - * @brief Retrieve the name associated with a CANN buffer. - * - * This function returns the name of a CANN buffer, which is stored in the - * context of the buffer. - * - * @param buffer The CANN buffer whose name is to be retrieved. - * @return A pointer to a C-string containing the name of the buffer. - */ - -GGML_CALL static const char* ggml_backend_cann_buffer_get_name( - ggml_backend_buffer_t buffer) { - return "CANN"; - - GGML_UNUSED(buffer); -} - -/** - * @brief Check if a buffer is a CANN buffer. - * - * This function checks if a given buffer is a CANN buffer by comparing its - * `get_name` function pointer to `ggml_backend_cann_buffer_get_name`. - * - * @param buffer The buffer to check. - * @return true if the buffer is a CANN buffer, false otherwise. - */ -GGML_CALL static bool ggml_backend_buffer_is_cann( - ggml_backend_buffer_t buffer) { - return buffer->iface.get_name == ggml_backend_cann_buffer_get_name; -} - -/** - * @brief Free resources associated with a CANN buffer. - * - * This function frees the resources associated with a CANN buffer, including - * its context. - * - * @param buffer The CANN buffer to free. - */ -GGML_CALL static void ggml_backend_cann_buffer_free_buffer( - ggml_backend_buffer_t buffer) { - ggml_backend_cann_buffer_context* ctx = - (ggml_backend_cann_buffer_context*)buffer->context; - delete ctx; -} - -/** - * @brief Retrieve the base pointer of a CANN buffer. - * - * This function returns the base pointer of a CANN buffer, which points to the - * device memory allocated for the buffer. - * - * @param buffer The CANN buffer whose base pointer is to be retrieved. - * @return A pointer to the base of the device memory allocated for the buffer. - */ -GGML_CALL static void* ggml_backend_cann_buffer_get_base( - ggml_backend_buffer_t buffer) { - ggml_backend_cann_buffer_context* ctx = - (ggml_backend_cann_buffer_context*)buffer->context; - return ctx->dev_ptr; -} - -/** - * @brief Transform quantized Q4.0 tensor data into a format suitable for CANN - * processing. - * - * This function transforms quantized Q4.0 tensor data into a format suitable - * for CANN processing. It extracts quantization values and scales from the - * source data and prepares them in a format expected by CANN operations. - * - * @param tensor Pointer to the tensor information. - * @param src Pointer to the source data in Q4.0 format. - * @param dst Pointer to the destination buffer where transformed data will be - * stored. - */ -GGML_CALL static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor, - const void* src, - void* dst) { - - int64_t n_elems = ggml_nelements(tensor); - int64_t groups = n_elems / QK4_0; - size_t quant_bytes = n_elems * sizeof(uint8_t) / 2; - - uint8_t* quant_offset = (uint8_t*)dst; - uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes); - - for (int i = 0; i < groups; i++) { - const block_q4_0* group = - (const block_q4_0*)((const char*)src + i * sizeof(block_q4_0)); - *scale_offset = group->d; - scale_offset++; - - // 0-15 - for (int j = 0; j < QK4_0 / 2; j += 2) { - (*quant_offset) = (group->qs[j] & 0x0F); - (*quant_offset) |= ((group->qs[j + 1] << 4)); - quant_offset++; - } - - // 16-31 - for (int j = 0; j < QK4_0 / 2; j += 2) { - (*quant_offset) = (group->qs[j] >> 4); - (*quant_offset) |= (group->qs[j + 1] & 0xF0); - quant_offset++; - } - } - - // put (uint4b_t -8) into int4b_t - for (quant_offset = (uint8_t*)dst; - quant_offset < (uint8_t*)dst + quant_bytes; quant_offset++) { - (*quant_offset) ^= 0x88; - } -} - -/** - * @brief Transform CANN processed data back into quantized Q4.0 format. - * - * This function transforms CANN processed data back into quantized Q4.0 format. - * It reverses the transformation performed by - * ggml_backend_cann_transform_q4_0(), converting the data back into its - * original quantized form. - * - * @param tensor Pointer to the tensor information. - * @param src Pointer to the source buffer containing transformed data. - * @param dst Pointer to the destination buffer where the Q4.0 formatted data - * will be stored. - */ -GGML_CALL static void ggml_backend_cann_transform_back_q4_0( - const ggml_tensor* tensor, void* src, void* dst) { - - int64_t n_elems = ggml_nelements(tensor); - int64_t groups = n_elems / QK4_0; - size_t quant_bytes = n_elems * sizeof(uint8_t) / 2; - - uint8_t* quant_offset = (uint8_t*)src; - uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes); - - for (; quant_offset < (uint8_t*)src + quant_bytes; quant_offset++) { - (*quant_offset) ^= 0x88; - } - quant_offset = (uint8_t*)src; - - for (int i = 0; i < groups; i++) { - block_q4_0* group = (block_q4_0*)((char*)dst + i * sizeof(block_q4_0)); - group->d = *scale_offset; - scale_offset++; - - // 0-15 - for (int j = 0; j < QK4_0 / 2; j += 2) { - group->qs[j] = ((*quant_offset) & 0x0F); - group->qs[j + 1] = ((*quant_offset) >> 4); - quant_offset++; - } - - // 16-31 - for (int j = 0; j < QK4_0 / 2; j += 2) { - group->qs[j] |= ((*quant_offset) << 4); - group->qs[j + 1] |= ((*quant_offset) & 0xF0); - quant_offset++; - } - } -} - -/** - * @brief Transform quantized Q8.0 tensor data into a format suitable for CANN - * processing. - * - * This function transforms quantized Q8.0 tensor data into a format suitable - * for CANN processing. It extracts quantization values and scales from the - * source data and prepares them in a format expected by CANN operations. - * - * @param tensor Pointer to the tensor information. - * @param src Pointer to the source data in Q8.0 format. - * @param dst Pointer to the destination buffer where transformed data will be - * stored. - */ -GGML_CALL static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor, - const void* src, - void* dst) { - int64_t n_elems = ggml_nelements(tensor); - int64_t groups = n_elems / QK8_0; - size_t quant_bytes = n_elems * sizeof(uint8_t); - - uint8_t* quant_offset = (uint8_t*)dst; - uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes); - - for (int i = 0; i < groups; i++) { - const block_q8_0* group = - (const block_q8_0*)((const char*)src + i * sizeof(block_q8_0)); - *scale_offset = group->d; - scale_offset++; - size_t group_quant_size = QK8_0 * sizeof(uint8_t); - memcpy(quant_offset, group->qs, group_quant_size); - quant_offset += group_quant_size; - } -} - -/** - * @brief Transform CANN processed data back into quantized Q8.0 format. - * - * This function transforms CANN processed data back into quantized Q8.0 format. - * It reverses the transformation performed by - * ggml_backend_cann_transform_q8_0(), converting the data back into its - * original quantized form. - * - * @param tensor Pointer to the tensor information. - * @param src Pointer to the source buffer containing transformed data. - * @param dst Pointer to the destination buffer where the Q8.0 formatted data - * will be stored. - */ -GGML_CALL static void ggml_backend_cann_transform_back_q8_0( - const ggml_tensor* tensor, const void* src, void* dst) { - int64_t n_elems = ggml_nelements(tensor); - int64_t groups = n_elems / QK8_0; - size_t quant_bytes = n_elems * sizeof(uint8_t); - - const uint8_t* quant_offset = (const uint8_t*)src; - const uint16_t* scale_offset = - (const uint16_t*)((const char*)src + quant_bytes); - - for (int i = 0; i < groups; i++) { - block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0)); - group->d = *scale_offset; - scale_offset++; - size_t group_quant_size = QK8_0 * sizeof(uint8_t); - memcpy(group->qs, quant_offset, group_quant_size); - quant_offset += group_quant_size; - } -} - -/** - * @brief Transform tensor data based on its type for CANN processing. - * - * This function transforms tensor data based on its quantization type for CANN - * processing. It dispatches the transformation based on the tensor's type to - * specialized functions handling Q4.0 and Q8.0 formats. - * - * @param tensor Pointer to the tensor information. - * @param src Pointer to the source data to be transformed. - * @param dst Pointer to the destination buffer where transformed data will be - * stored. - */ -GGML_CALL static void ggml_backend_cann_transform(ggml_tensor* tensor, - const void* src, void* dst) { - switch (tensor->type) { - case GGML_TYPE_Q4_0: - ggml_backend_cann_transform_q4_0(tensor, src, dst); - break; - case GGML_TYPE_Q8_0: - ggml_backend_cann_transform_q8_0(tensor, src, dst); - break; - default: - break; - } -} - -/** - * @brief Transform CANN processed data back into tensor data based on its type. - * - * This function transforms CANN processed data back into tensor data based on - * its quantization type for Q4.0 and Q8.0 formats. It dispatches the - * transformation based on the tensor's type to specialized functions. - * - * @param tensor Pointer to the tensor information. - * @param src Pointer to the source data containing CANN processed data. - * @param dst Pointer to the destination buffer where transformed tensor data - * will be stored. - */ -GGML_CALL static void ggml_backend_cann_transform_back( - const ggml_tensor* tensor, void* src, void* dst) { - switch (tensor->type) { - case GGML_TYPE_Q4_0: - ggml_backend_cann_transform_back_q4_0(tensor, src, dst); - break; - case GGML_TYPE_Q8_0: - ggml_backend_cann_transform_back_q8_0(tensor, src, dst); - break; - default: - break; - } -} - -/** - * @brief Check if transformation is needed for a given tensor type. - * - * This function checks if transformation is needed for a given tensor type - * to prepare data for CANN processing. - * - * @param type The tensor type to check. - * @return true if transformation is needed, false otherwise. - */ -GGML_CALL static bool need_transform(ggml_type type) { - switch (type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q8_0: - return true; - default: - return false; - } -} - -/** - * @brief Initialize a tensor using data from a CANN buffer. - * - * This function initializes a tensor using data from a CANN buffer. - * It handles special cases such as views and quantization. - * - * @param buffer The CANN buffer from which to initialize the tensor. - * @param tensor Pointer to the tensor to be initialized. - */ -GGML_CALL static void ggml_backend_cann_buffer_init_tensor( - ggml_backend_buffer_t buffer, ggml_tensor* tensor) { - if (tensor->view_src != NULL && tensor->view_offs == 0) { - GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); - return; - } - - // TODO: can backend doesn't support quantized yet. Just leave the code - // here. - if (ggml_is_quantized(tensor->type)) { - // Initialize padding to 0 to avoid possible NaN values - size_t original_size = ggml_nbytes(tensor); - size_t padded_size = - ggml_backend_buft_get_alloc_size(buffer->buft, tensor); - - if (padded_size > original_size && tensor->view_src == nullptr) { - size_t memset_size = padded_size - original_size; - ACL_CHECK(aclrtMemset((char*)tensor->data + original_size, - memset_size, 0, memset_size)); - } - } -} - -// TODO: need handle tensor which has paddings. -/** - * @brief Set tensor data in a CANN buffer. - * - * This function sets tensor data in a CANN buffer, handling transformations - * if needed based on the tensor's type. - * - * @param buffer The CANN buffer where the tensor data will be set. - * @param tensor Pointer to the tensor whose data will be set. - * @param data Pointer to the source data to be copied into the tensor. - * @param offset Offset in the source data from where to start copying. - * @param size Size of the data to be copied, in bytes. - */ -GGML_CALL static void ggml_backend_cann_buffer_set_tensor( - ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data, - size_t offset, size_t size) { - ggml_backend_cann_buffer_context *ctx = - (ggml_backend_cann_buffer_context *)buffer->context; - - ggml_cann_set_device(ctx->device); - // TODO: refer to cann(#6017), it use thread's default stream. - // For acl, synchronous functions use this default stream. - // Why aclrtSynchronizeDevice? - - if (!need_transform(tensor->type)) { - ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size, - ACL_MEMCPY_HOST_TO_DEVICE)); - } else { - void *transform_buffer = malloc(size); - ggml_backend_cann_transform(tensor, data, transform_buffer); - -#ifndef NDEBUG - void *check_buffer = malloc(size); - ggml_backend_cann_transform_back(tensor, transform_buffer, - check_buffer); - GGML_ASSERT(memcmp(data, check_buffer, size) == 0); - free(check_buffer); -#endif - ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, - transform_buffer, size, - ACL_MEMCPY_HOST_TO_DEVICE)); - free(transform_buffer); - } -} - -/** - * @brief Get tensor data from a CANN buffer. - * - * This function retrieves tensor data from a CANN buffer, handling - * transformations if needed based on the tensor's type. - * - * @param buffer The CANN buffer from which to retrieve tensor data. - * @param tensor Pointer to the tensor whose data will be retrieved. - * @param data Pointer to the destination buffer where the tensor data will be - * copied. - * @param offset Offset in the destination buffer where to start copying. - * @param size Size of the data to be copied, in bytes. - */ -GGML_CALL static void ggml_backend_cann_buffer_get_tensor( - ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data, - size_t offset, size_t size) { - ggml_backend_cann_buffer_context* ctx = - (ggml_backend_cann_buffer_context*)buffer->context; - - ggml_cann_set_device(ctx->device); - - if (!need_transform(tensor->type)) { - ACL_CHECK(aclrtMemcpy(data, size, (char*)tensor->data + offset, size, - ACL_MEMCPY_DEVICE_TO_HOST)); - } else { - void* transform_buffer = malloc(size); - ACL_CHECK(aclrtMemcpy(transform_buffer, size, - (char*)tensor->data + offset, size, - ACL_MEMCPY_DEVICE_TO_HOST)); - ggml_backend_cann_transform_back(tensor, transform_buffer, data); - free(transform_buffer); - } -} - -/** - * @brief Copy tensor data between CANN buffers if possible. - * - * This function copies tensor data between CANN buffers if the source and - * destination buffers are CANN buffers and they meet the necessary conditions - * (same device or devices can access each other). - * - * @param buffer The destination CANN buffer where the tensor data will be - * copied. - * @param src Pointer to the source tensor whose data will be copied. - * @param dst Pointer to the destination tensor where the data will be copied. - * @return true if the copy operation succeeded, false otherwise. - */ -GGML_CALL static bool ggml_backend_cann_buffer_cpy_tensor( - ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) { - if (ggml_backend_buffer_is_cann(src->buffer)) { - ggml_backend_cann_buffer_context* src_ctx = - (ggml_backend_cann_buffer_context*)src->buffer->context; - ggml_backend_cann_buffer_context* dst_ctx = - (ggml_backend_cann_buffer_context*)buffer->context; - - size_t memcpy_size = ggml_nbytes(src); - // Same device. - if (src_ctx->device == dst_ctx->device) { - ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size, - (const char*)src->data, memcpy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE)); - return true; - } else { - // Different device but can access by peer. - int32_t canAccessPeer = 0; - ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device, - dst_ctx->device)); - if (canAccessPeer) { - ggml_cann_set_device(src_ctx->device); - ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0)); - ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size, - (const char*)src->data, memcpy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE)); - return true; - } - } - } - return false; -} - -/** - * @brief Clear a CANN buffer by setting all its memory to a specified value. - * - * This function clears a CANN buffer by setting all its memory to a specified - * value. - * - * @param buffer The CANN buffer to be cleared. - * @param value The value to which each byte in the buffer will be set. - */ -GGML_CALL static void ggml_backend_cann_buffer_clear( - ggml_backend_buffer_t buffer, uint8_t value) { - ggml_backend_cann_buffer_context* ctx = - (ggml_backend_cann_buffer_context*)buffer->context; - - ggml_cann_set_device(ctx->device); - ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size)); -} - -/** - * @brief Interface for a CANN buffer in the backend. - * - * This structure defines function pointers to operations that can be performed - * on a CANN buffer within the backend. - */ -static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { - /* .get_name = */ ggml_backend_cann_buffer_get_name, - /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer, - /* .get_base = */ ggml_backend_cann_buffer_get_base, - /* .init_tensor = */ ggml_backend_cann_buffer_init_tensor, - /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor, - /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor, - /* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor, - /* .clear = */ ggml_backend_cann_buffer_clear, - /* .reset = */ NULL, -}; - -// cann buffer type -/** - * @brief Structure representing context information for a specific backend - * buffer type. - */ -struct ggml_backend_cann_buffer_type_context { - int32_t - device; /**< Device identifier associated with the buffer context. */ - std::string name; /**< Name associated with the buffer context. */ -}; - -/** - * @brief Retrieves the name associated with a CANN buffer type. - * - * This function returns the descriptive name associated with the specified - * CANN buffer type context. - * - * @param buft Pointer to the buffer type context. - * @return Const pointer to the C-style string containing the name. - */ -GGML_CALL static const char* ggml_backend_cann_buffer_type_name( - ggml_backend_buffer_type_t buft) { - return "CANN"; - - GGML_UNUSED(buft); -} - -/** - * @brief Allocates a new CANN buffer of the specified type and size. - * - * This function allocates a new CANN buffer on the specified device with the - * given size. - * - * @param buft Pointer to the buffer type context. - * @param size Size in bytes of the buffer to allocate. - * @return Pointer to the allocated buffer, or nullptr if allocation fails. - */ -GGML_CALL static ggml_backend_buffer_t -ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, - size_t size) { - ggml_backend_cann_buffer_type_context* buft_ctx = - (ggml_backend_cann_buffer_type_context*)buft->context; - - ggml_cann_set_device(buft_ctx->device); - - size = std::max(size, (size_t)1); - - void* dev_ptr; - aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST); - if (err != ACL_SUCCESS) { - GGML_CANN_LOG_ERROR( - "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n", - __func__, size / 1024.0 / 1024.0, buft_ctx->device, - aclGetRecentErrMsg()); - return nullptr; - } - - ggml_backend_cann_buffer_context* ctx = - new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr); - - return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface, - ctx, size); -} - -/** - * @brief Retrieves the memory alignment requirement for CANN buffers of this - * type. - * - * This function returns the alignment requirement in bytes for memory allocated - * by the CANN buffer type. - * - * @param buft Pointer to the buffer type context (unused in this - * implementation). - * @return The alignment requirement in bytes (fixed at 128 bytes for CANN - * buffers). - */ -GGML_CALL static size_t ggml_backend_cann_buffer_type_get_alignment( - ggml_backend_buffer_type_t buft) { - return 128; - - GGML_UNUSED(buft); -} - -/** - * @brief Calculates the allocation size required for a tensor in a CANN buffer. - * - * Computes the total allocation size needed for storing the tensor's data in a - * CANN buffer, considering any necessary padding or adjustments for quantized - * types. - * - * @param buft Pointer to the buffer type context (unused in this - * implementation). - * @param tensor Pointer to the tensor for which the allocation size is - * calculated. - * @return The total allocation size in bytes required for the tensor in the - * CANN buffer. - */ -GGML_CALL static size_t ggml_backend_cann_buffer_type_get_alloc_size( - ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) { - size_t size = ggml_nbytes(tensor); - int64_t ne0 = tensor->ne[0]; - - // last line must bigger than 32, because every single op deal at - // least 32 bytes. - // TODO: quantized type? - // int64_t line_size = ne0 * ggml_element_size(tensor); - // int64_t line_size_align_32 = (line_size + 31) & ~31; - // size += (line_size_align_32 - line_size); - - // TODO: not support quantized yet. - // TODO: consider un-continue tensor. - if (ggml_is_quantized(tensor->type)) { - if (ne0 % MATRIX_ROW_PADDING != 0) { - size += ggml_row_size( - tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); - } - } - - return size; - - GGML_UNUSED(buft); -} - -/** - * @brief Interface for managing CANN buffer types in the GGML backend. - * - * Provides function pointers for allocating, querying properties, and managing - * memory for CANN buffer types in the GGML backend. - */ -static ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = { - /* .get_name = */ ggml_backend_cann_buffer_type_name, - /* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment, - /* .get_max_size = */ NULL, // defaults to SIZE_MAX - /* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size, - /* .is_host = */ NULL, -}; - -/** - * @brief Retrieves the CANN buffer type for a specified device. - * - * This function initializes and returns the buffer type interface associated - * with the given device. It ensures thread-safe access using a mutex. - * - * @param device The device index for which to retrieve the buffer type. - * @return A pointer to the buffer type interface for the specified device, or - * nullptr if the device index is out of range. - */ -GGML_CALL ggml_backend_buffer_type_t -ggml_backend_cann_buffer_type(int32_t device) { - static std::mutex mutex; - std::lock_guard lock(mutex); - - if (device >= ggml_backend_cann_get_device_count()) { - return nullptr; - } - - static ggml_backend_buffer_type - ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES]; - - static bool ggml_backend_cann_buffer_type_initialized = false; - - if (!ggml_backend_cann_buffer_type_initialized) { - for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) { - ggml_backend_cann_buffer_types[i] = { - /* .iface = */ ggml_backend_cann_buffer_type_interface, - /* .context = */ - new ggml_backend_cann_buffer_type_context{ - i, "CANN" + std::to_string(i)}, - }; - } - ggml_backend_cann_buffer_type_initialized = true; - } - - return &ggml_backend_cann_buffer_types[device]; -} - -/** - * @brief Computes the forward operation for a given tensor using CANN - * operations. - * - * This function selects the appropriate CANN operation based on the type of - * operation specified in the tensor and performs the computation. - * - * @param ctx The CANN context containing necessary resources and - * configurations. - * @param dst The destination tensor where the result of the computation will be - * stored. - * @return true if the computation was successful; false otherwise. - */ -static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, - struct ggml_tensor* dst) { - switch (dst->op) { - case GGML_OP_REPEAT: - ggml_cann_repeat(ctx, dst); - break; - case GGML_OP_GET_ROWS: - ggml_cann_get_rows(ctx, dst); - break; - case GGML_OP_DUP: - ggml_cann_dup(ctx, dst); - break; - case GGML_OP_ADD: - ggml_cann_add(ctx, dst); - break; - case GGML_OP_ACC: - ggml_cann_acc(ctx, dst); - break; - case GGML_OP_MUL: - ggml_cann_mul_div(ctx, dst); - break; - case GGML_OP_DIV: - ggml_cann_mul_div(ctx, dst); - break; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(dst)) { - case GGML_UNARY_OP_GELU: - ggml_cann_activation( - ctx, dst); - break; - case GGML_UNARY_OP_SILU: - ggml_cann_activation( - ctx, dst); - break; - // TODO: Use faster gelu?? - case GGML_UNARY_OP_GELU_QUICK: - ggml_cann_activation( - ctx, dst); - break; - case GGML_UNARY_OP_TANH: - ggml_cann_activation( - ctx, dst); - break; - case GGML_UNARY_OP_RELU: - ggml_cann_activation( - ctx, dst); - break; - case GGML_UNARY_OP_HARDSIGMOID: - ggml_cann_activation(ctx, dst); - break; - case GGML_UNARY_OP_HARDSWISH: - ggml_cann_activation(ctx, dst); - break; - default: - return false; - } - break; - case GGML_OP_NORM: - ggml_cann_norm(ctx, dst); - break; - case GGML_OP_GROUP_NORM: - ggml_cann_group_norm(ctx, dst); - break; - case GGML_OP_CONCAT: - ggml_cann_concat(ctx, dst); - break; - case GGML_OP_UPSCALE: - ggml_cann_upsample_nearest2d(ctx, dst); - break; - case GGML_OP_PAD: - ggml_cann_pad(ctx, dst); - break; - case GGML_OP_ARANGE: - ggml_cann_arange(ctx, dst); - break; - case GGML_OP_TIMESTEP_EMBEDDING: - ggml_cann_timestep_embedding(ctx, dst); - break; - case GGML_OP_LEAKY_RELU: - ggml_cann_leaky_relu(ctx, dst); - break; - case GGML_OP_RMS_NORM: - ggml_cann_rms_norm(ctx, dst); - break; - case GGML_OP_MUL_MAT: - ggml_cann_mul_mat(ctx, dst); - break; - case GGML_OP_MUL_MAT_ID: - return false; - case GGML_OP_SCALE: - ggml_cann_scale(ctx, dst); - break; - case GGML_OP_SQR: - ggml_cann_sqr(ctx, dst); - break; - case GGML_OP_CLAMP: - ggml_cann_clamp(ctx, dst); - break; - case GGML_OP_CPY: - ggml_cann_cpy(ctx, dst); - break; - case GGML_OP_CONT: - ggml_cann_dup(ctx, dst); - break; - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - break; - case GGML_OP_DIAG_MASK_INF: - ggml_cann_diag_mask(ctx, dst, -INFINITY); - break; - case GGML_OP_SOFT_MAX: - ggml_cann_softmax(ctx, dst); - break; - case GGML_OP_ROPE: - ggml_cann_rope(ctx, dst); - break; - case GGML_OP_IM2COL: - ggml_cann_im2col(ctx, dst); - break; - case GGML_OP_POOL_2D: - ggml_cann_pool2d(ctx, dst); - break; - case GGML_OP_SUM_ROWS: - ggml_cann_sum_rows(ctx, dst); - break; - case GGML_OP_ARGSORT: - ggml_cann_argsort(ctx, dst); - break; - default: - return false; - } - - return true; -} - -// backend -/** - * @brief Retrieves the name associated with the CANN backend. - * - * This function returns the name assigned to the CANN backend, which is stored - * in the context of the provided backend structure. - * - * @param backend Pointer to the CANN backend structure. - * @return A pointer to a constant string representing the backend name. - */ -GGML_CALL static const char* ggml_backend_cann_name(ggml_backend_t backend) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; - - return cann_ctx->name.c_str(); -} - -/** - * @brief Frees resources associated with the CANN backend. - * - * This function releases resources associated with the CANN backend context - * and resets the device associated with the backend to its initial state. - * - * @param backend Pointer to the CANN backend structure to be freed. - */ -GGML_CALL static void ggml_backend_cann_free(ggml_backend_t backend) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; - ACL_CHECK(aclrtSynchronizeDevice()); - ACL_CHECK(aclrtResetDevice(cann_ctx->device)); - - // finalize when last backend freed. - if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) { - ACL_CHECK(aclFinalize()); - } - - delete cann_ctx; - delete backend; -} - -/** - * @brief Retrieves the default buffer type associated with the CANN backend. - * - * This function returns the buffer type specific to the device associated - * with the CANN backend. It is used to allocate buffers for computations - * performed by the backend. - * - * @param backend Pointer to the CANN backend structure. - * @return Pointer to the buffer type structure for the CANN backend. - */ -GGML_CALL static ggml_backend_buffer_type_t -ggml_backend_cann_get_default_buffer_type(ggml_backend_t backend) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; - - return ggml_backend_cann_buffer_type(cann_ctx->device); -} - -/** - * @brief Sets tensor data asynchronously in the CANN backend. - * - * This function asynchronously sets tensor data in the CANN backend. Depending - * on the tensor type, it may perform data transformations before copying data - * to the device. - * - * @param backend Pointer to the CANN backend structure. - * @param tensor Pointer to the tensor structure to set data for. - * @param data Pointer to the host data to copy to the tensor. - * @param offset Offset in bytes within the host data. - * @param size Size of the data to copy in bytes. - */ -GGML_CALL static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend, - ggml_tensor *tensor, - const void *data, - size_t offset, - size_t size) { - ggml_backend_cann_context *cann_ctx = - (ggml_backend_cann_context *)backend->context; - - if (!need_transform(tensor->type)) { - ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data, - size, ACL_MEMCPY_HOST_TO_DEVICE, - cann_ctx->stream())); - } else { - void *transform_buffer = malloc(size); - ggml_backend_cann_transform(tensor, data, transform_buffer); - -#ifndef NDEBUG - void *check_buffer = malloc(size); - ggml_backend_cann_transform_back(tensor, transform_buffer, - check_buffer); - GGML_ASSERT(memcmp(data, check_buffer, size)); - free(check_buffer); -#endif - ACL_CHECK(aclrtMemcpyAsync( - (char *)tensor->data + offset, size, transform_buffer, size, - ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream())); - ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream())); - free(transform_buffer); - } -} - -GGML_CALL static void ggml_backend_cann_get_tensor_async( - ggml_backend_t backend, const ggml_tensor *tensor, void *data, - size_t offset, size_t size) { - ggml_backend_cann_context *cann_ctx = - (ggml_backend_cann_context *)backend->context; - ggml_backend_buffer_t buf = - tensor->view_src ? tensor->view_src->buffer : tensor->buffer; - - GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && - "unsupported buffer type"); - - if (!need_transform(tensor->type)) { - ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset, - size, ACL_MEMCPY_DEVICE_TO_HOST, - cann_ctx->stream())); - } else { - void *transform_buffer = malloc(size); - ACL_CHECK(aclrtMemcpyAsync( - transform_buffer, size, (char *)tensor->data + offset, size, - ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream())); - ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream())); - ggml_backend_cann_transform_back(tensor, transform_buffer, data); - free(transform_buffer); - } -} - -/** - * @brief Asynchronously copies tensor data between CANN backends. - * - * This function copies tensor data asynchronously between two CANN backends. It - * checks if both tensors reside in CANN buffers and whether the devices support - * peer-to-peer access for direct copying. If not, it returns false. - * - * @param backend_src Pointer to the source CANN backend structure. - * @param backend_dst Pointer to the destination CANN backend structure. - * @param src Pointer to the source tensor to copy data from. - * @param dst Pointer to the destination tensor to copy data to. - * @return true if the copy operation succeeds, false otherwise. - */ -GGML_CALL static bool ggml_backend_cann_cpy_tensor_async( - ggml_backend_t backend_src, ggml_backend_t backend_dst, - const ggml_tensor* src, ggml_tensor* dst) { - GGML_ASSERT(ggml_backend_is_cann(backend_src) || - ggml_backend_is_cann(backend_dst)); - - if (!ggml_backend_buffer_is_cann(src->buffer) || - !ggml_backend_buffer_is_cann(dst->buffer)) { - return false; - } - - ggml_backend_buffer_t buf_src = - src->view_src ? src->view_src->buffer : src->buffer; - ggml_backend_buffer_t buf_dst = - dst->view_src ? dst->view_src->buffer : dst->buffer; - - ggml_backend_cann_context* cann_ctx_src = - (ggml_backend_cann_context*)backend_src->context; - ggml_backend_cann_context* cann_ctx_dst = - (ggml_backend_cann_context*)backend_dst->context; - - size_t copy_size = ggml_nbytes(dst); - if (backend_src != backend_dst) { - ggml_backend_cann_buffer_context* buf_ctx_src = - (ggml_backend_cann_buffer_context*)buf_src->context; - ggml_backend_cann_buffer_context* buf_ctx_dst = - (ggml_backend_cann_buffer_context*)buf_dst->context; - - GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device); - GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device); - - int32_t canAccessPeer = 0; - ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device, - cann_ctx_dst->device)); - if (!canAccessPeer) { - return false; - } - - // need open both directions for memcpyasync between devices. - ggml_cann_set_device(cann_ctx_dst->device); - ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0)); - ggml_cann_set_device(cann_ctx_src->device); - ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0)); - - ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE, - cann_ctx_src->stream())); - - //TODO: workaround for Event didn`t work here. - aclrtSynchronizeStream(cann_ctx_src->stream()); - } else { - // src and dst are on the same backend - ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE, - cann_ctx_dst->stream())); - } - - return true; -} - -/** - * @brief Synchronizes a CANN backend. - * - * This function synchronizes the specified CANN backend by waiting for all - * operations in its associated stream to complete. - * - * @param backend Pointer to the CANN backend structure to synchronize. - */ -GGML_CALL static void ggml_backend_cann_synchronize(ggml_backend_t backend) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; - - ggml_cann_set_device(cann_ctx->device); - - ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream())); -} - -/** - * @brief Computes a computational graph using a CANN backend. - * - * This function computes the operations defined in the computational graph - * using the specified CANN backend. - * - * @param backend Pointer to the CANN backend structure to use for computation. - * @param cgraph Pointer to the computational graph structure containing nodes - * representing operations to be computed. - * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation - * completes successfully, otherwise an appropriate error status. - */ -GGML_CALL static enum ggml_status ggml_backend_cann_graph_compute( - ggml_backend_t backend, ggml_cgraph* cgraph) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; - - ggml_cann_set_device(cann_ctx->device); - - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor* node = cgraph->nodes[i]; - - if (ggml_is_empty(node) || node->op == GGML_OP_NONE) { - continue; - } - - bool ok = ggml_cann_compute_forward(*cann_ctx, node); - - if (!ok) { - GGML_CANN_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, - node->name, ggml_op_name(node->op)); - } - GGML_ASSERT(ok); - } - - return GGML_STATUS_SUCCESS; -} - -/** - * @brief Checks if the CANN backend supports a specific operation. - * - * This function checks whether the specified operation is supported by the - * CANN backend. - * - * @param backend Pointer to the CANN backend structure to check support for - * the operation. - * @param op Pointer to the tensor representing the operation to check. - * @return bool Returns true if the operation is supported by the backend, - * otherwise false. - */ -GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend, - const ggml_tensor* op) { - switch (op->op) { - case GGML_OP_UNARY: - switch (ggml_get_unary_op(op)) { - case GGML_UNARY_OP_GELU: - case GGML_UNARY_OP_SILU: - case GGML_UNARY_OP_RELU: - case GGML_UNARY_OP_HARDSIGMOID: - case GGML_UNARY_OP_HARDSWISH: - case GGML_UNARY_OP_GELU_QUICK: - case GGML_UNARY_OP_TANH: - return true; - default: - return false; - } - case GGML_OP_MUL_MAT: { - switch (op->src[0]->type) { - case GGML_TYPE_F16: - case GGML_TYPE_F32: - case GGML_TYPE_Q8_0: - // TODO: fix me - // Current groupsize should not be greater than k-1 in - // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(). - case GGML_TYPE_Q4_0: - return true; - default: - return false; - } - } - case GGML_OP_MUL_MAT_ID: - return false; - // embedding - case GGML_OP_GET_ROWS: { - switch (op->src[0]->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q8_0: - return true; - default: - return false; - } - } break; - case GGML_OP_CPY: { - switch (op->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q4_0: - return true; - default: - return false; - } - } - case GGML_OP_DUP: - case GGML_OP_REPEAT: - case GGML_OP_CONCAT: - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - case GGML_OP_NORM: - case GGML_OP_ADD: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_RMS_NORM: - case GGML_OP_SCALE: - case GGML_OP_SQR: - case GGML_OP_CLAMP: - case GGML_OP_CONT: - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_SOFT_MAX: - case GGML_OP_ROPE: - case GGML_OP_IM2COL: - case GGML_OP_POOL_2D: - case GGML_OP_SUM_ROWS: - case GGML_OP_ARGSORT: - case GGML_OP_ACC: - case GGML_OP_GROUP_NORM: - case GGML_OP_UPSCALE: - case GGML_OP_PAD: - case GGML_OP_ARANGE: - case GGML_OP_TIMESTEP_EMBEDDING: - case GGML_OP_LEAKY_RELU: - return true; - default: - return false; - } - - GGML_UNUSED(backend); -} - -/** - * @brief Checks if the backend buffer type is associated with the CANN backend. - * - * This function checks whether the provided backend buffer type is associated - * with the CANN backend based on the comparison of its name retrieval function - * pointer. - * - * @param buft Pointer to the backend buffer type to check. - * @return bool Returns true if the buffer type is associated with the CANN - * backend, otherwise false. - */ -static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { - return buft->iface.get_name == ggml_backend_cann_buffer_type_name; -} - -/** - * @brief Checks if the CANN backend supports a specific backend buffer type. - * - * This function determines whether the CANN backend supports the given backend - * buffer type by comparing the device context of the backend and buffer type. - * It returns true if the devices are same between the backend context and - * buffer type context. - * - * @param backend Pointer to the CANN backend. - * @param buft Pointer to the backend buffer type to check. - * @return bool Returns true if the CANN backend supports the buffer type, - * otherwise false. - */ -GGML_CALL static bool ggml_backend_cann_supports_buft( - ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - if (ggml_backend_buft_is_cann(buft)) { - ggml_backend_cann_context * cann_ctx = - (ggml_backend_cann_context *)backend->context; - ggml_backend_cann_buffer_type_context * buft_ctx = - (ggml_backend_cann_buffer_type_context *)buft->context; - return buft_ctx->device == cann_ctx->device; - } - return false; -} - -/** - * @brief Determines if a tensor operation should be offloaded to the CANN - * backend. - * - * This function checks if a given tensor operation should be offloaded to the - * CANN backend based on the operation type and the size of the tensor. It - * returns true if the second dimension (ne[1]) of the tensor is greater than or - * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS. - * - * @param backend Pointer to the CANN backend. - * @param op Pointer to the tensor operation to check. - * @return bool Returns true if the operation should be offloaded, otherwise - * false. - */ -GGML_CALL static bool ggml_backend_cann_offload_op(ggml_backend_t backend, - const ggml_tensor* op) { - const int min_batch_size = 32; - GGML_UNUSED(backend); - - return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS; -} - -/** - * @brief Creates a new event for the CANN backend. - * - * This function initializes a new event for the CANN backend by setting the - * device and creating an ACL runtime event. The created event is then wrapped - * in a ggml_backend_event structure and returned. - * - * @param backend Pointer to the CANN backend. - * @return ggml_backend_event_t Returns a pointer to the new event structure. - */ -static ggml_backend_event_t ggml_backend_cann_event_new( - ggml_backend_t backend) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; - - ggml_cann_set_device(cann_ctx->device); - - aclrtEvent event; - ACL_CHECK(aclrtCreateEvent(&event)); - - return new ggml_backend_event{ - /* .backend = */ backend, - /* .context = */ event, - }; -} - -/** - * @brief Frees a CANN backend event. - * - * This function destroys the ACL runtime event associated with the given CANN - * backend event and then deletes the event structure itself. - * - * @param event Pointer to the event structure to be freed. - */ -static void ggml_backend_cann_event_free(ggml_backend_event_t event) { - ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context)); - - delete event; -} - -/** - * @brief Records an event on the CANN backend stream. - * - * This function records the given event on the ACL runtime stream associated - * with the backend context. - * - * @param event Pointer to the event structure to be recorded. - */ -static void ggml_backend_cann_event_record(ggml_backend_event_t event) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)event->backend->context; - - ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream())); -} - -/** - * @brief Waits for a recorded event to complete on the CANN backend stream. - * - * This function makes the given backend wait for the event to complete on its - * ACL runtime stream. - * - * @param backend Pointer to the backend structure. - * @param event Pointer to the event structure that the backend needs to wait - * for. - */ -static void ggml_backend_cann_event_wait(ggml_backend_t backend, - ggml_backend_event_t event) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; - - if (ggml_backend_is_cann(event->backend)) { - ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(), - (aclrtEvent)event->context)); - } else { - GGML_ABORT("fatal error"); - } -} - -/** - * @brief Synchronizes the given event on the CANN backend. - * - * This function waits for the specified event to complete on the ACL runtime. - * - * @param event Pointer to the event structure to be synchronized. - */ -static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) { - ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context)); -} - -/** - * @brief Structure defining the interface for the CANN backend. - * - * This structure contains function pointers for various operations - * supported by the CANN backend, including name retrieval, memory - * management, tensor operations, synchronization, and event handling. - */ -static ggml_backend_i ggml_backend_cann_interface = { - /* .get_name = */ ggml_backend_cann_name, - /* .free = */ ggml_backend_cann_free, - /* .get_default_buffer_type = */ ggml_backend_cann_get_default_buffer_type, - /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async, - /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async, - /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async, - /* .synchronize = */ ggml_backend_cann_synchronize, - /* .graph_plan_create = */ NULL, - /* .graph_plan_free = */ NULL, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ NULL, - /* .graph_compute = */ ggml_backend_cann_graph_compute, - /* .supports_op = */ ggml_backend_cann_supports_op, - /* .supports_buft = */ ggml_backend_cann_supports_buft, - /* .offload_op = */ ggml_backend_cann_offload_op, - /* .event_new = */ ggml_backend_cann_event_new, - /* .event_free = */ ggml_backend_cann_event_free, - /* .event_record = */ ggml_backend_cann_event_record, - /* .event_wait = */ ggml_backend_cann_event_wait, - /* .event_synchronize = */ ggml_backend_cann_event_synchronize, -}; - -/** - * @brief Return the hardcoded GUID for the CANN backend. - * - * This function returns a static GUID which uniquely identifies the CANN - * backend. - * - * @return A pointer to the static GUID. - */ -static ggml_guid_t ggml_backend_cann_guid() { - static ggml_guid guid = {0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34, - 0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64}; - return &guid; -} - -GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device) { - aclInit(nullptr); - if (device < 0 || device >= ggml_backend_cann_get_device_count()) { - GGML_CANN_LOG_ERROR("%s: error: invalid device %d\n", __func__, device); - return nullptr; - } - - ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device); - if (ctx == nullptr) { - GGML_CANN_LOG_ERROR("%s: error: failed to allocate context\n", __func__); - return nullptr; - } - - ggml_backend_t cann_backend = - new ggml_backend{/* .guid = */ ggml_backend_cann_guid(), - /* .interface = */ ggml_backend_cann_interface, - /* .context = */ ctx}; - - return cann_backend; -} - -GGML_CALL bool ggml_backend_is_cann(ggml_backend_t backend) { - return backend != NULL && - ggml_guid_matches(backend->guid, ggml_backend_cann_guid()); -} - -GGML_CALL int32_t ggml_backend_cann_get_device_count() { - return ggml_cann_info().device_count; -} - -GGML_CALL void ggml_backend_cann_get_device_description( - int32_t device, char* description, size_t description_size) { - ggml_cann_set_device(device); - const char* soc_name = aclrtGetSocName(); - snprintf(description, description_size, "%s", soc_name); -} - -GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device, size_t* free, - size_t* total) { - ggml_cann_set_device(device); - ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total)); -} - -// backend registry -/** - * @brief Initializes a CANN backend based on the provided parameters. - * - * This function initializes a CANN backend using the device index and then - * initializes the backend using `ggml_backend_cann_init`. - * - * @param params Parameters for initialization (unused in this implementation). - * @param user_data User data containing the device index to initialize the - * backend. - * @return ggml_backend_t The initialized CANN backend. - */ -GGML_CALL static ggml_backend_t ggml_backend_reg_cann_init(const char* params, - void* user_data) { - ggml_backend_t cann_backend = - ggml_backend_cann_init((int)(intptr_t)user_data); - return cann_backend; - - GGML_UNUSED(params); -} - -extern "C" GGML_CALL int ggml_backend_cann_reg_devices(); - -/** - * @brief Registers CANN (Ascend) devices as backend options. - * - * This function initializes ACL, retrieves the number of available CANN - * devices, and registers each device as a backend option using - * `ggml_backend_register`. Each device is given a unique name based on - * `GGML_CANN_NAME` followed by its index. - * - * @return int The number of CANN devices registered. - */ -GGML_CALL int ggml_backend_cann_reg_devices() { - uint32_t device_count = ggml_backend_cann_get_device_count(); - // initialization - for (uint32_t i = 0; i < device_count; i++) { - char name[128]; - snprintf(name, sizeof(name), "CANN%d", i); - ggml_backend_register(name, ggml_backend_reg_cann_init, - ggml_backend_cann_buffer_type(i), - (void*)(intptr_t)i); - } - return device_count; -} diff --git a/ggml/include/ggml-amx.h b/include/ggml-amx.h similarity index 100% rename from ggml/include/ggml-amx.h rename to include/ggml-amx.h diff --git a/scripts/sync-llama-am.sh b/scripts/sync-llama-am.sh index bea88da5d..b0feb0af3 100755 --- a/scripts/sync-llama-am.sh +++ b/scripts/sync-llama-am.sh @@ -76,6 +76,7 @@ while read c; do ggml/src/ggml*.m \ ggml/src/ggml*.metal \ ggml/src/ggml*.cu \ + ggml/src/ggml-amx/* \ ggml/src/ggml-cann/* \ ggml/src/ggml-cuda/* \ ggml/src/ggml-sycl/* \ @@ -121,6 +122,8 @@ if [ -f $SRC_GGML/llama-src.patch ]; then # ggml/src/ggml-aarch64.c -> src/ggml-aarch64.c # ggml/src/ggml-aarch64.h -> src/ggml-aarch64.h # ggml/src/ggml-alloc.c -> src/ggml-alloc.c + # ggml/src/ggml-amx/* -> src/ggml-amx/* + # ggml/src/ggml-amx.cpp -> src/ggml-amx.cpp # ggml/src/ggml-backend-impl.h -> src/ggml-backend-impl.h # ggml/src/ggml-backend.cpp -> src/ggml-backend.cpp # ggml/src/ggml-blas.cpp -> src/ggml-blas.cpp @@ -142,6 +145,7 @@ if [ -f $SRC_GGML/llama-src.patch ]; then # # ggml/include/ggml.h -> include/ggml.h # ggml/include/ggml-alloc.h -> include/ggml-alloc.h + # ggml/include/ggml-amx.h -> include/ggml-amx.h # ggml/include/ggml-backend.h -> include/ggml-backend.h # ggml/include/ggml-blas.h -> include/ggml-blas.h # ggml/include/ggml-cann.h -> include/ggml-cann.h @@ -169,6 +173,8 @@ if [ -f $SRC_GGML/llama-src.patch ]; then -e 's/\/ggml\/src\/ggml-aarch64\.c/\/src\/ggml-aarch64.c/g' \ -e 's/\/ggml\/src\/ggml-aarch64\.h/\/src\/ggml-aarch64.h/g' \ -e 's/\/ggml\/src\/ggml-alloc\.c/\/src\/ggml-alloc.c/g' \ + -e 's/\/ggml\/src\/ggml-amx\//\/src\/ggml-amx\//g' \ + -e 's/\/ggml\/src\/ggml-amx\.cpp/\/src\/ggml-amx.cpp/g' \ -e 's/\/ggml\/src\/ggml-backend-impl\.h/\/src\/ggml-backend-impl.h/g' \ -e 's/\/ggml\/src\/ggml-backend\.cpp/\/src\/ggml-backend.cpp/g' \ -e 's/\/ggml\/src\/ggml-blas\.cpp/\/src\/ggml-blas.cpp/g' \ @@ -189,6 +195,7 @@ if [ -f $SRC_GGML/llama-src.patch ]; then -e 's/\/ggml\/src\/vulkan-shaders\//\/src\/vulkan-shaders\//g' \ -e 's/\/ggml\/include\/ggml\.h/\/include\/ggml.h/g' \ -e 's/\/ggml\/include\/ggml-alloc\.h/\/include\/ggml-alloc.h/g' \ + -e 's/\/ggml\/include\/ggml-amx\.h/\/include\/ggml-amx.h/g' \ -e 's/\/ggml\/include\/ggml-backend\.h/\/include\/ggml-backend.h/g' \ -e 's/\/ggml\/include\/ggml-blas\.h/\/include\/ggml-blas.h/g' \ -e 's/\/ggml\/include\/ggml-cann\.h/\/include\/ggml-cann.h/g' \ diff --git a/scripts/sync-llama.last b/scripts/sync-llama.last index 808abe009..9381d6b0a 100644 --- a/scripts/sync-llama.last +++ b/scripts/sync-llama.last @@ -1 +1 @@ -190a37d7977eb5bd6a729299bd1e371208c87149 +668750357e66bfa3d1504b65699f5a0dfe3cb7cb diff --git a/scripts/sync-llama.sh b/scripts/sync-llama.sh index 289dc3118..b3648142e 100755 --- a/scripts/sync-llama.sh +++ b/scripts/sync-llama.sh @@ -8,6 +8,8 @@ cp -rpv ../llama.cpp/ggml/src/ggml.c src/ggml.c cp -rpv ../llama.cpp/ggml/src/ggml-aarch64.c src/ggml-aarch64.c cp -rpv ../llama.cpp/ggml/src/ggml-aarch64.h src/ggml-aarch64.h cp -rpv ../llama.cpp/ggml/src/ggml-alloc.c src/ggml-alloc.c +cp -rpv ../llama.cpp/ggml/src/ggml-amx/* src/ggml-amx/ +cp -rpv ../llama.cpp/ggml/src/ggml-amx.cpp src/ggml-amx.cpp cp -rpv ../llama.cpp/ggml/src/ggml-backend-impl.h src/ggml-backend-impl.h cp -rpv ../llama.cpp/ggml/src/ggml-backend.cpp src/ggml-backend.cpp cp -rpv ../llama.cpp/ggml/src/ggml-blas.cpp src/ggml-blas.cpp @@ -30,6 +32,7 @@ cp -rpv ../llama.cpp/ggml/src/vulkan-shaders/* src/vulkan-shaders/ cp -rpv ../llama.cpp/ggml/include/ggml.h include/ggml.h cp -rpv ../llama.cpp/ggml/include/ggml-alloc.h include/ggml-alloc.h +cp -rpv ../llama.cpp/ggml/include/ggml-amx.h include/ggml-amx.h cp -rpv ../llama.cpp/ggml/include/ggml-backend.h include/ggml-backend.h cp -rpv ../llama.cpp/ggml/include/ggml-blas.h include/ggml-blas.h cp -rpv ../llama.cpp/ggml/include/ggml-cann.h include/ggml-cann.h diff --git a/scripts/sync-whisper-am.sh b/scripts/sync-whisper-am.sh index 2fbf3c667..2ede7af0b 100755 --- a/scripts/sync-whisper-am.sh +++ b/scripts/sync-whisper-am.sh @@ -62,6 +62,7 @@ while read c; do ggml/src/ggml*.m \ ggml/src/ggml*.metal \ ggml/src/ggml*.cu \ + ggml/src/ggml-amx/* \ ggml/src/ggml-cann/* \ ggml/src/ggml-cuda/* \ ggml/src/ggml-sycl/* \ @@ -106,6 +107,8 @@ if [ -f $SRC_GGML/whisper-src.patch ]; then # ggml/src/ggml-aarch64.c -> src/ggml-aarch64.c # ggml/src/ggml-aarch64.h -> src/ggml-aarch64.h # ggml/src/ggml-alloc.c -> src/ggml-alloc.c + # ggml/src/ggml-amx/* -> src/ggml-amx/* + # ggml/src/ggml-amx.cpp -> src/ggml-amx.cpp # ggml/src/ggml-backend-impl.h -> src/ggml-backend-impl.h # ggml/src/ggml-backend.cpp -> src/ggml-backend.cpp # ggml/src/ggml-blas.cpp -> src/ggml-blas.cpp @@ -127,6 +130,7 @@ if [ -f $SRC_GGML/whisper-src.patch ]; then # # ggml/include/ggml.h -> include/ggml.h # ggml/include/ggml-alloc.h -> include/ggml-alloc.h + # ggml/include/ggml-amx.h -> include/ggml-amx.h # ggml/include/ggml-backend.h -> include/ggml-backend.h # ggml/include/ggml-blas.h -> include/ggml-blas.h # ggml/include/ggml-cann.h -> include/ggml-cann.h @@ -153,6 +157,8 @@ if [ -f $SRC_GGML/whisper-src.patch ]; then -e 's/\/ggml\/src\/ggml-aarch64\.c/\/src\/ggml-aarch64.c/g' \ -e 's/\/ggml\/src\/ggml-aarch64\.h/\/src\/ggml-aarch64.h/g' \ -e 's/\/ggml\/src\/ggml-alloc\.c/\/src\/ggml-alloc.c/g' \ + -e 's/\/ggml\/src\/ggml-amx\//\/src\/ggml-amx\//g' \ + -e 's/\/ggml\/src\/ggml-amx\.cpp/\/src\/ggml-amx.cpp/g' \ -e 's/\/ggml\/src\/ggml-backend-impl\.h/\/src\/ggml-backend-impl.h/g' \ -e 's/\/ggml\/src\/ggml-backend\.cpp/\/src\/ggml-backend.cpp/g' \ -e 's/\/ggml\/src\/ggml-blas\.cpp/\/src\/ggml-blas.cpp/g' \ @@ -173,6 +179,7 @@ if [ -f $SRC_GGML/whisper-src.patch ]; then -e 's/\/ggml\/src\/vulkan-shaders\//\/src\/vulkan-shaders\//g' \ -e 's/\/ggml\/include\/ggml\.h/\/include\/ggml.h/g' \ -e 's/\/ggml\/include\/ggml-alloc\.h/\/include\/ggml-alloc.h/g' \ + -e 's/\/ggml\/include\/ggml-amx\.h/\/include\/ggml-amx.h/g' \ -e 's/\/ggml\/include\/ggml-backend\.h/\/include\/ggml-backend.h/g' \ -e 's/\/ggml\/include\/ggml-blas\.h/\/include\/ggml-blas.h/g' \ -e 's/\/ggml\/include\/ggml-cann\.h/\/include\/ggml-cann.h/g' \ diff --git a/scripts/sync-whisper.sh b/scripts/sync-whisper.sh index 123358c6a..4dd5f2c38 100755 --- a/scripts/sync-whisper.sh +++ b/scripts/sync-whisper.sh @@ -8,6 +8,8 @@ cp -rpv ../whisper.cpp/ggml/src/ggml.c src/ggml.c cp -rpv ../whisper.cpp/ggml/src/ggml-aarch64.c src/ggml-aarch64.c cp -rpv ../whisper.cpp/ggml/src/ggml-aarch64.h src/ggml-aarch64.h cp -rpv ../whisper.cpp/ggml/src/ggml-alloc.c src/ggml-alloc.c +cp -rpv ../whisper.cpp/ggml/src/ggml-amx/* src/ggml-amx/ +cp -rpv ../whisper.cpp/ggml/src/ggml-amx.cpp src/ggml-amx.cpp cp -rpv ../whisper.cpp/ggml/src/ggml-backend-impl.h src/ggml-backend-impl.h cp -rpv ../whisper.cpp/ggml/src/ggml-backend.cpp src/ggml-backend.cpp cp -rpv ../whisper.cpp/ggml/src/ggml-blas.cpp src/ggml-blas.cpp @@ -30,6 +32,7 @@ cp -rpv ../whisper.cpp/ggml/src/vulkan-shaders/* src/vulkan-shaders/ cp -rpv ../whisper.cpp/ggml/include/ggml.h include/ggml.h cp -rpv ../whisper.cpp/ggml/include/ggml-alloc.h include/ggml-alloc.h +cp -rpv ../whisper.cpp/ggml/include/ggml-amx.h include/ggml-amx.h cp -rpv ../whisper.cpp/ggml/include/ggml-backend.h include/ggml-backend.h cp -rpv ../whisper.cpp/ggml/include/ggml-blas.h include/ggml-blas.h cp -rpv ../whisper.cpp/ggml/include/ggml-cann.h include/ggml-cann.h diff --git a/ggml/src/ggml-amx.cpp b/src/ggml-amx.cpp similarity index 100% rename from ggml/src/ggml-amx.cpp rename to src/ggml-amx.cpp diff --git a/src/ggml-amx/common.h b/src/ggml-amx/common.h new file mode 100644 index 000000000..2b6c63527 --- /dev/null +++ b/src/ggml-amx/common.h @@ -0,0 +1,93 @@ +#pragma once + +#include "ggml.h" +#include "ggml-cpu-impl.h" // + +#include +#include +#include + +#if defined(_OPENMP) +#include +#endif + +#define TILE_M 16 +#define TILE_N 16 +#define TILE_K 32 +#define VNNI_BLK 4 + +#define AMX_BLK_SIZE 32 + +#define TMM0 0 +#define TMM1 1 +#define TMM2 2 +#define TMM3 3 +#define TMM4 4 +#define TMM5 5 +#define TMM6 6 +#define TMM7 7 + +// parallel routines +template ::value, int>::type = 0> +inline T div_up(T x, T y) { return (x + y - 1) / y; } + +template +inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) { +#if 0 + // onednn partition pattern + T& n_my = n_end; + if (nth <= 1 || n == 0) { + n_start = 0; + n_my = n; + } else { + T n1 = div_up(n, nth); + T n2 = n1 - 1; + T T1 = n - n2 * nth; + n_my = ith < T1 ? n1 : n2; + n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2; + } + n_end += n_start; +#else + // pytorch aten partition pattern + T n_my = div_up(n, nth); + n_start = ith * n_my; + n_end = std::min(n_start + n_my, n); +#endif +} + +template +inline void parallel_for(int nth, int n, const func_t& f) { +#if defined(_OPENMP) +#pragma omp parallel num_threads(nth) +{ + //int nth = omp_get_num_threads(); + int ith = omp_get_thread_num(); + int tbegin, tend; + balance211(n, nth, ith, tbegin, tend); + f(tbegin, tend); +} +#else + f(0, n); + + GGML_UNUSED(nth); +#endif +} + +// quantized types that have AMX support +inline bool qtype_has_amx_kernels(const enum ggml_type type) { + // TODO: fix padding for vnni format + return (type == GGML_TYPE_Q4_0) || + (type == GGML_TYPE_Q4_1); + //(type == GGML_TYPE_Q8_0) || + //(type == GGML_TYPE_Q4_K) || + //(type == GGML_TYPE_Q5_K) || + //(type == GGML_TYPE_Q6_K) || + //(type == GGML_TYPE_IQ4_XS); +} + +// ggml backend context +struct ggml_backend_amx_context { + int n_threads = GGML_DEFAULT_N_THREADS; + std::unique_ptr work_data; + size_t work_size = 0; +}; diff --git a/src/ggml-amx/mmq.cpp b/src/ggml-amx/mmq.cpp new file mode 100644 index 000000000..239d15121 --- /dev/null +++ b/src/ggml-amx/mmq.cpp @@ -0,0 +1,2509 @@ + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Wpedantic" +#pragma GCC diagnostic ignored "-Wunused-local-typedefs" +#endif + +#include "mmq.h" +#include "ggml-impl.h" +#include "ggml-quants.h" +#include +#include + +#if defined(__gnu_linux__) +#include +#include +#endif + +#if defined(_OPENMP) +#include +#endif + +#if (defined(_WIN32) || defined(_WIN64)) +#define RESTRICT __restrict +#else +#define RESTRICT __restrict__ +#endif + +#if (defined(_WIN32) || defined(_WIN64)) +#define ALWAYS_INLINE __forceinline +#elif __has_attribute(always_inline) || defined(__GNUC__) +#define ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define ALWAYS_INLINE inline +#endif + +#if defined(__AMX_INT8__) + +namespace { + +// Forced unrolling +template +struct Unroll { + template + ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + Unroll{}(f, args...); + f(std::integral_constant{}, args...); + } +}; + +template <> +struct Unroll<1> { + template + ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + f(std::integral_constant{}, args...); + } +}; + +// type traits +template struct PackedTypes {}; +template <> struct PackedTypes { using type = int8_t; }; +template <> struct PackedTypes { using type = uint8_t; }; +template <> struct PackedTypes { using type = int8_t; }; +template using packed_B_type = typename PackedTypes::type; + +template +struct do_compensate : std::integral_constant::value> {}; + +template +struct do_unpack : std::integral_constant::value || + std::is_same::value> {}; + +template +struct is_type_qkk : std::integral_constant::value || + std::is_same::value || + std::is_same::value || + std::is_same::value> {}; + +#define GGML_DISPATCH_FLOATING_TYPES(TYPE, ...) \ + [&] { \ + switch (TYPE) { \ + case GGML_TYPE_F16: { \ + using type = ggml_fp16_t; \ + constexpr int blck_size = 16; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_BF16: { \ + using type = ggml_bf16_t; \ + constexpr int blck_size = 32; \ + return __VA_ARGS__(); \ + } \ + default: \ + fprintf(stderr, "Unsupported floating data type\n"); \ + } \ + }() + +#define GGML_DISPATCH_QTYPES(QT, ...) \ + [&] { \ + switch (QT) { \ + case GGML_TYPE_Q4_0: { \ + using type = block_q4_0; \ + using vec_dot_type = block_q8_0; \ + constexpr int blck_size = QK4_0; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_Q4_1: { \ + using type = block_q4_1; \ + using vec_dot_type = block_q8_1; \ + constexpr int blck_size = QK4_1; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_Q8_0: { \ + using type = block_q8_0; \ + using vec_dot_type = block_q8_0; \ + constexpr int blck_size = QK8_0; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_Q4_K: { \ + using type = block_q4_K; \ + using vec_dot_type = block_q8_K; \ + constexpr int blck_size = QK_K; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_Q5_K: { \ + using type = block_q5_K; \ + using vec_dot_type = block_q8_K; \ + constexpr int blck_size = QK_K; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_Q6_K: { \ + using type = block_q6_K; \ + using vec_dot_type = block_q8_K; \ + constexpr int blck_size = QK_K; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_IQ4_XS: { \ + using type = block_iq4_xs; \ + using vec_dot_type = block_q8_K; \ + constexpr int blck_size = QK_K; \ + return __VA_ARGS__(); \ + } \ + default: \ + fprintf(stderr, "Unsupported quantized data type: %d\n", int(TYPE)); \ + } \ + }() + +#define GGML_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \ + [&] { \ + if (BOOL_V) { \ + constexpr bool BOOL_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool BOOL_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +// define amx tile config data structure +struct tile_config_t{ + uint8_t palette_id = 0; + uint8_t start_row = 0; + uint8_t reserved_0[14] = {0}; + uint16_t colsb[16] = {0}; + uint8_t rows[16] = {0}; +}; + +// Notes: amx tile config +// +// Typically, TMUL calculates A and B of size 16 x 64 containing INT8 values, +// and accumulate the result to a 16 x 16 matrix C containing INT32 values, +// +// As many GGUF quantized types as `block_size` of 32, so a 16-16-32 config is used +// instead of the normally used 16-16-64 config. +// +// Block A: {16, 32}, dtype = int8_t +// Block B: {16, 32}, dtype = uint8_t/int8_t +// Block C: {16, 16}, dtype = int32_t +// +// Block B needs to be prepacked to vnni format before feeding into TMUL: +// packed_B: from {n, k} to {k/vnni_blk, n, vnni_blck}, viewed in 2d, we get {8, 64} +// +// Therefore, we get tileconfig: +// A B C +// rows 16 8 16 +// colsb 32 64 16 +// +// For tile distribution, follow a 2-2-4 pattern, e.g. A used TMM2-TMM3, B used TMM0-TMM1, +// C used TMM4-TMM7: +// B TMM0 B TMM1 +// A TMM2 C TMM4 C TMM6 +// A TMM3 C TMM5 C TMM7 +// +// Each `amx` kernel handles 4 blocks at a time: 2MB * 2NB, when m < 2 * BLOCK_M, unpack A +// will be needed. +// +// Here another commonly used pattern 1-3-3 is skipped, as it is mostly used when m <=16; +// and the sinlge batch gemm (m=1) has a special fast path with `avx512-vnni`. +// +// ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/ +// advanced-matrix-extensions-intrinsics-functions.html +// + +#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb +void ggml_tile_config_init(void) { + static thread_local bool is_first_time = true; + + if (!is_first_time) { + return; + } + + static thread_local tile_config_t tc; + tile_config_t current_tc; + _tile_storeconfig(¤t_tc); + + // load only when config changes + if (tc.palette_id == 0 || (memcmp(¤t_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 && + memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) { + tc.palette_id = 1; + tc.start_row = 0; + TC_CONFIG_TILE(TMM0, 8, 64); + TC_CONFIG_TILE(TMM1, 8, 64); + TC_CONFIG_TILE(TMM2, 16, 32); + TC_CONFIG_TILE(TMM3, 16, 32); + TC_CONFIG_TILE(TMM4, 16, 64); + TC_CONFIG_TILE(TMM5, 16, 64); + TC_CONFIG_TILE(TMM6, 16, 64); + TC_CONFIG_TILE(TMM7, 16, 64); + _tile_loadconfig(&tc); + } + + is_first_time = false; +} + +// we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation. +// See the notes `s8s8 igemm compensation in avx512-vnni` for detail. +template +int get_tile_size() { + int tile_size = TILE_N * sizeof(TB); + if (do_compensate::value) { + tile_size += TILE_N * sizeof(int32_t); + } + if (std::is_same::value || + std::is_same::value) { + tile_size += TILE_N * 4; + } + if (std::is_same::value) { + tile_size += TILE_N * 2; + } + return tile_size; +} + +template +int get_row_size(int K) { + int KB = K / BLOCK_K; + int row_size = KB * sizeof(TB); + if (do_compensate::value) { + row_size += KB * sizeof(int32_t); + } + if (std::is_same::value || + std::is_same::value) { + row_size += KB * 4; + } + if (std::is_same::value) { + row_size += KB * 2; + } + return row_size; +} + +// vectorized dtype conversion +inline float FP16_TO_FP32(ggml_half val) { + __m256i v = _mm256_setr_epi16( + val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + __m512 o = _mm512_cvtph_ps(v); + return _mm512_cvtss_f32(o); +} + +inline __m512 FP16_TO_FP32_VEC(ggml_half val) { + __m256i v = _mm256_set1_epi16(val); + return _mm512_cvtph_ps(v); +} + +// horizontal reduce +inline float _mm512_reduce_max_ps(const __m512 x) { + __m512 v = x; + __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E); + v = _mm512_max_ps(v, v1); + v1 = _mm512_shuffle_f32x4(v, v, 0xB1); + v = _mm512_max_ps(v, v1); + v1 = _mm512_shuffle_ps(v, v, 0x4E); + v = _mm512_max_ps(v, v1); + v1 = _mm512_shuffle_ps(v, v, 0xB1); + v = _mm512_max_ps(v, v1); + return _mm512_cvtss_f32(v); +} + +// transpose utils +#define SHUFFLE_EPI32(a, b, mask) \ + _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask)) +inline void transpose_8x8_32bit(__m256i * v, __m256i * v1) { + // unpacking and 32-bit elements + v1[0] = _mm256_unpacklo_epi32(v[0], v[1]); + v1[1] = _mm256_unpackhi_epi32(v[0], v[1]); + v1[2] = _mm256_unpacklo_epi32(v[2], v[3]); + v1[3] = _mm256_unpackhi_epi32(v[2], v[3]); + v1[4] = _mm256_unpacklo_epi32(v[4], v[5]); + v1[5] = _mm256_unpackhi_epi32(v[4], v[5]); + v1[6] = _mm256_unpacklo_epi32(v[6], v[7]); + v1[7] = _mm256_unpackhi_epi32(v[6], v[7]); + + // shuffling the 32-bit elements + v[0] = SHUFFLE_EPI32(v1[0], v1[2], 0x44); + v[1] = SHUFFLE_EPI32(v1[0], v1[2], 0xee); + v[2] = SHUFFLE_EPI32(v1[4], v1[6], 0x44); + v[3] = SHUFFLE_EPI32(v1[4], v1[6], 0xee); + v[4] = SHUFFLE_EPI32(v1[1], v1[3], 0x44); + v[5] = SHUFFLE_EPI32(v1[1], v1[3], 0xee); + v[6] = SHUFFLE_EPI32(v1[5], v1[7], 0x44); + v[7] = SHUFFLE_EPI32(v1[5], v1[7], 0xee); + + // shuffling 128-bit elements + v1[0] = _mm256_permute2f128_si256(v[2], v[0], 0x02); + v1[1] = _mm256_permute2f128_si256(v[3], v[1], 0x02); + v1[2] = _mm256_permute2f128_si256(v[6], v[4], 0x02); + v1[3] = _mm256_permute2f128_si256(v[7], v[5], 0x02); + v1[4] = _mm256_permute2f128_si256(v[2], v[0], 0x13); + v1[5] = _mm256_permute2f128_si256(v[3], v[1], 0x13); + v1[6] = _mm256_permute2f128_si256(v[6], v[4], 0x13); + v1[7] = _mm256_permute2f128_si256(v[7], v[5], 0x13); +} + +inline void transpose_16x4_32bit(__m512i * r, __m512i * d) { + + static const __m512i index1 = _mm512_set_epi32( + 0x0f, 0x0b, 0x07, 0x03, + 0x0e, 0x0a, 0x06, 0x02, + 0x0d, 0x09, 0x05, 0x01, + 0x0c, 0x08, 0x04, 0x00); + + d[0] = _mm512_permutexvar_epi32(index1, r[0]); + d[1] = _mm512_permutexvar_epi32(index1, r[1]); + d[2] = _mm512_permutexvar_epi32(index1, r[2]); + d[3] = _mm512_permutexvar_epi32(index1, r[3]); + + r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44); + r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee); + r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44); + r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee); + + d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88); + d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd); + d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88); + d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd); +} + +inline void transpose_16x16_32bit(__m512i * v) { + __m512i v1[16]; + v1[0] = _mm512_unpacklo_epi32(v[0], v[1]); + v1[1] = _mm512_unpackhi_epi32(v[0], v[1]); + v1[2] = _mm512_unpacklo_epi32(v[2], v[3]); + v1[3] = _mm512_unpackhi_epi32(v[2], v[3]); + v1[4] = _mm512_unpacklo_epi32(v[4], v[5]); + v1[5] = _mm512_unpackhi_epi32(v[4], v[5]); + v1[6] = _mm512_unpacklo_epi32(v[6], v[7]); + v1[7] = _mm512_unpackhi_epi32(v[6], v[7]); + v1[8] = _mm512_unpacklo_epi32(v[8], v[9]); + v1[9] = _mm512_unpackhi_epi32(v[8], v[9]); + v1[10] = _mm512_unpacklo_epi32(v[10], v[11]); + v1[11] = _mm512_unpackhi_epi32(v[10], v[11]); + v1[12] = _mm512_unpacklo_epi32(v[12], v[13]); + v1[13] = _mm512_unpackhi_epi32(v[12], v[13]); + v1[14] = _mm512_unpacklo_epi32(v[14], v[15]); + v1[15] = _mm512_unpackhi_epi32(v[14], v[15]); + + v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]); + v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]); + v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]); + v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]); + v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]); + v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]); + v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]); + v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]); + v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]); + v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]); + v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]); + v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]); + v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]); + v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]); + v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]); + v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]); + + v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88); + v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88); + v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88); + v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88); + v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd); + v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd); + v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd); + v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd); + v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88); + v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88); + v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88); + v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88); + v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd); + v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd); + v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd); + v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd); + + v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88); + v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88); + v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88); + v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88); + v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88); + v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88); + v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88); + v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88); + v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd); + v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd); + v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd); + v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd); + v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd); + v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd); + v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd); + v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd); +} + +void quantize_row_q8_K_vnni(const float * RESTRICT x, void * RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + const int KB = k / QK_K; + constexpr int kVecs = QK_K / 16; + + block_q8_K * y = reinterpret_cast(vy); + + // hold 16 float vecs from x + __m512 v[kVecs]; + + // hold the quants vecs + __m512i vq[kVecs / 4]; + + // hold the packed quants vecs + __m512i vq_packed[kVecs / 4]; + + const __m512 signBit = _mm512_set1_ps(-0.f); + + for (int i = 0; i < KB; ++i) { + // Compute max(abs(e)) for the block + __m512 vamax = _mm512_set1_ps(0.f); + for (int j = 0; j < kVecs; ++j) { + v[j] = _mm512_loadu_ps(x); x += 16; + vamax = _mm512_max_ps(vamax, _mm512_andnot_ps(signBit, v[j])); + } + const float amax = _mm512_reduce_max_ps(vamax); + + // Quantize these floats + const float iscale = 127.f / amax; + y[i].d = GGML_FP32_TO_FP16(1 / iscale); + const float id = ( amax != 0.0f ) ? iscale : 0.f; + const __m512 vscale = _mm512_set1_ps(id); + + // Apply multiplier and round to nearest integer + for (int j = 0; j < kVecs; ++j) { + v[j] = _mm512_mul_ps(v[j], vscale); + v[j] = _mm512_roundscale_ps(v[j], (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + + // Pack to epi8 vecs + for (int j = 0; j < kVecs / 4; ++j) { + __m128i q8_0 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 0])); + __m128i q8_1 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 1])); + __m128i q8_2 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 2])); + __m128i q8_3 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 3])); + + __m256i q8_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_0), (q8_1), 1); + __m256i q8_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_2), (q8_3), 1); + + vq[j] = _mm512_inserti32x8(_mm512_castsi256_si512(q8_01), q8_23, 1); + _mm512_storeu_si512((__m512i *)(y[i].qs + j * 64), vq[j]); + } + + // Compute the bsums with vnni + transpose_16x4_32bit(vq, vq_packed); + + const __m512i one = _mm512_set1_epi8(1); + __m512i sum = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + sum = _mm512_dpbusd_epi32(sum, one, vq_packed[k]); + } + _mm256_storeu_si256((__m256i *)(y[i].bsums), _mm512_cvtepi32_epi16(sum)); + } +} + +// quantize A from float to `vec_dot_type` +template +inline void from_float(const float * x, char * vy, int64_t k); + +template <> +inline void from_float(const float * x, char * vy, int64_t k) { + quantize_row_q8_0(x, vy, k); +} + +template <> +inline void from_float(const float * x, char * vy, int64_t k) { + quantize_row_q8_1(x, vy, k); +} + +template <> +inline void from_float(const float * x, char * vy, int64_t k) { +#if 1 + // TODO: this is reference impl! + quantize_row_q8_K(x, vy, k); +#else + quantize_row_q8_K_vnni(x, vy, k); +#endif +} + +// load A from memory to array when nrows can not fill in whole tile +void unpack_A(int8_t * RESTRICT tile, const block_q8_0 * RESTRICT A, int lda, int nr) { + assert(nr != TILE_M); + for (int m = 0; m < nr; ++m) { + const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs)); + _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); + } +} + +void unpack_A(int8_t * RESTRICT tile, const block_q8_1 * RESTRICT A, int lda, int nr) { + assert(nr != TILE_M); + for (int m = 0; m < nr; ++m) { + const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs)); + _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); + } +} + +template +void unpack_A(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) { + assert(nr <= TILE_M); + for (int m = 0; m < nr; ++m) { + const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs + k * 32)); + _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); + } +} + +template <> +void unpack_A(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) { + assert(nr <= TILE_M); + // zero padding k from 16 to 32, so that we don't have to re-config amx + const __m128i zero = _mm_setzero_si128(); + for (int m = 0; m < nr; ++m) { + const __m128i v = _mm_loadu_si128((const __m128i *)(A[m * lda].qs + k * 16)); + const __m256i r = _mm256_insertf128_si256(_mm256_castsi128_si256(v), zero, 1); + _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), r); + } +} + +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) +inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { + const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); + const __m256i lowMask = _mm256_set1_epi8(0xF); + return _mm256_and_si256(lowMask, bytes); +} + +// used for block_q4_K +inline __m512i bytes_from_nibbles_64(const uint8_t * rsi) { + const __m256i tmp = _mm256_loadu_si256((const __m256i *)rsi); + const __m256i lowMask = _mm256_set1_epi8(0xF); + const __m256i q4l = _mm256_and_si256(tmp, lowMask); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(tmp, 4), lowMask); + return _mm512_inserti32x8(_mm512_castsi256_si512(q4l), q4h, 1); +} + +// used for block_q5_K +inline __m512i bytes_from_nibbles_64(const uint8_t * qs, const uint8_t * qh, int k) { + const __m256i lowMask = _mm256_set1_epi8(0xF); + __m256i hmask = _mm256_set1_epi8(1); + hmask = _mm256_slli_epi16(hmask, k); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i *)qs); + const __m256i hbits = _mm256_loadu_si256((const __m256i *)qh); + + const __m256i q5l_0 = _mm256_and_si256(q5bits, lowMask); + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 0), 4); + const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), lowMask); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 1), 4); + const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + + return _mm512_inserti32x8(_mm512_castsi256_si512(q5_0), q5_1, 1); +} + +// used for block_q6_K +inline void bytes_from_nibbles_128(__m512i& r0, __m512i& r1, const uint8_t * qs, const uint8_t * qh) { + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(0x3); + + const __m256i q6bits1 = _mm256_loadu_si256((const __m256i *)qs); + const __m256i q6bits2 = _mm256_loadu_si256((const __m256i *)(qs + 32)); + const __m256i q6bitsH = _mm256_loadu_si256((const __m256i *)qh); + + const __m256i q6h_0 = _mm256_slli_epi16(_mm256_and_si256( q6bitsH, m2), 4); + const __m256i q6h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 2), m2), 4); + const __m256i q6h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 4), m2), 4); + const __m256i q6h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 6), m2), 4); + + const __m256i q6_0 = _mm256_or_si256(_mm256_and_si256(q6bits1, m4), q6h_0); + const __m256i q6_1 = _mm256_or_si256(_mm256_and_si256(q6bits2, m4), q6h_1); + const __m256i q6_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits1, 4), m4), q6h_2); + const __m256i q6_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits2, 4), m4), q6h_3); + + r0 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_0), q6_1, 1); + r1 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_2), q6_3, 1); +} + +inline __m512i packNibbles(__m512i r0, __m512i r1) { + return _mm512_or_si512(r0, _mm512_slli_epi16(r1, 4)); +} + +template +inline void pack_qs(void * RESTRICT packed_B, const TB * RESTRICT B, int KB) { + int8_t tmp[8 * 64]; + __m256i v[8], v2[8]; + for (int n = 0; n < 8; ++n) { + v[n] = bytes_from_nibbles_32(B[n * KB].qs); + } + transpose_8x8_32bit(v, v2); + for (int n = 0; n < 8; ++n) { + _mm256_storeu_si256((__m256i *)(tmp + n * 64), v2[n]); + } + for (int n = 0; n < 8; ++n) { + v[n] = bytes_from_nibbles_32(B[(n + 8) * KB].qs); + } + transpose_8x8_32bit(v, v2); + for (int n = 0; n < 8; ++n) { + _mm256_storeu_si256((__m256i *)(tmp + n * 64 + 32), v2[n]); + } + + // pack again with 128 to fully utilize vector length + for (int n = 0; n < 8; n += 2) { + __m512i r0 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64)); + __m512i r1 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64 + 64)); + __m512i r1r0 = packNibbles(r0, r1); + _mm512_storeu_si512((__m512i *)((char *)packed_B + n * 32), r1r0); + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) { + __m256i v[8], v2[8]; + for (int n = 0; n < 8; ++n) { + v[n] = _mm256_loadu_si256((const __m256i *)(B[n * KB].qs)); + } + transpose_8x8_32bit(v, v2); + for (int n = 0; n < 8; ++n) { + _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64), v2[n]); + } + for (int n = 0; n < 8; ++n) { + v[n] = _mm256_loadu_si256((const __m256i *)(B[(n + 8) * KB].qs)); + } + transpose_8x8_32bit(v, v2); + for (int n = 0; n < 8; ++n) { + _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64 + 32), v2[n]); + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) { + __m512i v[16]; + // QK_K 256 with 8 groups, handle 2 groups at a time + char * pb = (char *)packed_B; + for (int k = 0; k < QK_K / 64; ++k) { + // pack 2 groups { n, g, k} to {g, k/4, 4n} + // e.g. {16, 2, 32} to {2, 8, 64} + for (int n = 0; n < TILE_N; ++n) { + v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32); + } + + transpose_16x16_32bit(v); + + // pack again with 128 to fully utilize vector length + for (int n = 0; n < TILE_N; n += 2) { + _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1])); + pb += 64; + } + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) { + __m512i v[16]; + const __m512i lowMask = _mm512_set1_epi8(0xF); + // QK_K 256 with 8 groups, handle 2 groups at a time + char * pb = (char *)packed_B; + char * ph = (char *)packed_B + (QK_K / 2) * TILE_N; + for (int k = 0; k < QK_K / 64; ++k) { + // pack 2 groups { n, g, k} to {g, k/4, 4n} + // e.g. {16, 2, 32} to {2, 8, 64} + for (int n = 0; n < TILE_N; ++n) { + v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32, B[n * KB].qh, /* group */2 * k); + } + + transpose_16x16_32bit(v); + + // 1. pack lower 4bits with 2 groups + for (int n = 0; n < TILE_N; n += 2) { + // get lower 4 bits + const __m512i r0 = _mm512_and_si512(v[n], lowMask); + const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask); + _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64; + } + + // 2. pack higher 1bit with 2 groups + const __m512i hmask = _mm512_set1_epi8(0x10); + for (int g = 0; g < 2; ++g) { + __m512i hbits = _mm512_setzero_si512(); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 0], hmask), 4)); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 1], hmask), 3)); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 2], hmask), 2)); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 3], hmask), 1)); + hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 8 + 4], hmask) ); + hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 5], hmask), 1)); + hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 6], hmask), 2)); + hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 7], hmask), 3)); + _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64; + } + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) { + __m512i v[32]; + const __m512i lowMask = _mm512_set1_epi8(0xF); + // QK_K 256 with 8 groups, handle 4 groups at a time + char * pb = (char *)packed_B; + char * ph = (char *)packed_B + (QK_K / 2) * TILE_N; + for (int k = 0; k < QK_K / 128; ++k) { + for (int n = 0; n < TILE_N; ++n) { + bytes_from_nibbles_128(v[n], v[n + 16], B[n * KB].ql + k * 64, B[n * KB].qh + k * 32); + } + + // top half: group 0,1 or 4,5; bottom half: group 2,3 or 6,7 + transpose_16x16_32bit(v); + transpose_16x16_32bit(v + 16); + + // 1. pack lower 4bits with 4 groups + for (int n = 0; n < 32; n += 2) { + const __m512i r0 = _mm512_and_si512(v[n], lowMask); + const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask); + _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64; + } + + // 2. pack higher 2bit with 4 groups + const __m512i hmask = _mm512_set1_epi8(0x30); + for (int g = 0; g < 8; ++g) { + __m512i hbits = _mm512_setzero_si512(); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 0], hmask), 4)); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 1], hmask), 2)); + hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 4 + 2], hmask) ); + hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 4 + 3], hmask), 2)); + _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64; + } + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) { + __m512i v[16]; + char * pb = (char *)packed_B; + for (int k = 0; k < QK_K / 64; ++k) { + for (int n = 0; n < TILE_N; ++n) { + __m256i r0 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 0); + __m256i r1 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 16); + v[n] = _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); + } + + transpose_16x16_32bit(v); + + // pack again with 128 to fully utilize vector length + for (int n = 0; n < TILE_N; n += 2) { + _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1])); + pb += 64; + } + } +} + +// pack B to vnni formats in 4bits or 8 bits +void pack_B(void * RESTRICT packed_B, const block_q4_0 * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + ggml_half * d0 = reinterpret_cast((char *)packed_B + TILE_N * TILE_K / 2); + for (int n = 0; n < TILE_N; ++n) { + d0[n] = B[n * KB].d; + } +} + +void pack_B(void * RESTRICT packed_B, const block_q4_1 * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + ggml_half * d0 = reinterpret_cast((char *)packed_B + TILE_N * TILE_K / 2); + ggml_half * m0 = d0 + TILE_N; + for (int n = 0; n < TILE_N; ++n) { + d0[n] = B[n * KB].d; + m0[n] = B[n * KB].m; + } +} + +inline void s8s8_compensation(void * RESTRICT packed_B) { + // packed_B layout: + // quants {TILE_N, TILEK} int8_t + // d0 {TILE_N} ggml_half + // comp {TILE_N} int32_t + const int offset = TILE_N * TILE_K + TILE_N * sizeof(ggml_half); + __m512i vcomp = _mm512_setzero_si512(); + const __m512i off = _mm512_set1_epi8(static_cast(0x80)); + for (int k = 0; k < 8; ++k) { + __m512i vb = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + k * 64)); + vcomp = _mm512_dpbusd_epi32(vcomp, off, vb); + } + _mm512_storeu_si512((__m512i *)((char *)(packed_B) + offset), vcomp); +} + +void pack_B(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + ggml_half * d0 = reinterpret_cast((char *)packed_B + TILE_N * TILE_K); + for (int n = 0; n < TILE_N; ++n) { + d0[n] = B[n * KB].d; + } + s8s8_compensation(packed_B); +} + +// convert 8 * {min, scale} from int6 to int8 +inline void unpack_mins_and_scales(const uint8_t * scales, uint32_t * utmp) { + const uint32_t kmask1 = 0x3f3f3f3f; + const uint32_t kmask2 = 0x0f0f0f0f; + const uint32_t kmask3 = 0x03030303; + + memcpy(utmp, scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; +} + +// packed_B layout: +// quants {8, TILE_N, 16} uint8 +// scales {8, TILE_N} uint8 +// mins {8, TILE_N} uint8 +// d {TILE_N} ggml_half +// dmin {TILE_N} ggml_half +void pack_B(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + + uint8_t * scales = reinterpret_cast((char *)packed_B + (QK_K / 2) * TILE_N); + uint8_t * mins = scales + 8 * TILE_N; + ggml_half * d = reinterpret_cast(mins + 8 * TILE_N); + ggml_half * dmin = d + TILE_N; + + union { + uint32_t u32[4]; + uint8_t u8[16]; + } s; + + for (int n = 0; n < TILE_N; ++n) { + unpack_mins_and_scales(B[n * KB].scales, s.u32); + for (int k = 0; k < 8; ++k) { + scales[k * TILE_N + n] = s.u8[k]; + mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8]; + } + d[n] = B[n * KB].d; + dmin[n] = B[n * KB].dmin; + } +} + +// packed_B layout: +// quants {8, TILE_N, 16} uint8 +// qh {8, TILE_N, 4} uint8 +// scales {8, TILE_N} uint8 +// mins {8, TILE_N} uint8 +// d {TILE_N} ggml_half +// dmin {TILE_N} ggml_half +void pack_B(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + + uint8_t * scales = reinterpret_cast((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N); + uint8_t * mins = scales + 8 * TILE_N; + ggml_half * d = reinterpret_cast(mins + 8 * TILE_N); + ggml_half * dmin = d + TILE_N; + + union { + uint32_t u32[4]; + uint8_t u8[16]; + } s; + + for (int n = 0; n < TILE_N; ++n) { + unpack_mins_and_scales(B[n * KB].scales, s.u32); + for (int k = 0; k < 8; ++k) { + scales[k * TILE_N + n] = s.u8[k]; + mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8]; + } + d[n] = B[n * KB].d; + dmin[n] = B[n * KB].dmin; + } +} + +// packed_B layout: +// quants {16, TILE_N, 8} uint8 +// qh {16, TILE_N, 4} uint8 +// scales {16, TILE_N} uint8 +// d {TILE_N} ggml_half +void pack_B(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + + uint8_t * scales = reinterpret_cast((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N); + ggml_half * d = reinterpret_cast(scales + 16 * TILE_N); + for (int n = 0; n < TILE_N; ++n) { + const int8_t * ps = B[n * KB].scales; + for (int k = 0; k < 16; ++k) { + scales[k * TILE_N + n] = ps[k]; + } + d[n] = B[n * KB].d; + } +} + +// packed_B layout: +// quants {8, TILE_N, 16} uint8 +// scales {8, TILE_N} int8 +// d {TILE_N} ggml_half +void pack_B(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + + int8_t * scales = reinterpret_cast((char *)packed_B + (QK_K / 2) * TILE_N); + ggml_half * d = reinterpret_cast(scales + 8 * TILE_N); + + // pack the scales + for (int n = 0; n < TILE_N; ++n) { + uint16_t sh = B[n * KB].scales_h; + for (int k = 0; k < 8; k += 2) { + const int16_t ls1 = ((B[n * KB].scales_l[k / 2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((B[n * KB].scales_l[k / 2] >> 4) | ((sh << 2) & 0x30)) - 32; + scales[(k + 0) * TILE_N + n] = ls1; + scales[(k + 1) * TILE_N + n] = ls2; + sh >>= 4; + } + d[n] = B[n * KB].d; + } +} + +template> +void unpack_B(packed_B_t * RESTRICT tile, const void * RESTRICT packed_B) { + GGML_UNUSED(tile); + GGML_UNUSED(packed_B); +}; + +template <> +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B) { + const __m512i off = _mm512_set1_epi8(8); + const __m512i lowMask = _mm512_set1_epi8(0xF); + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32)); + const __m512i r0 = _mm512_sub_epi8(_mm512_and_si512(bytes, lowMask), off); + const __m512i r1 = _mm512_sub_epi8(_mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask), off); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +template <> +void unpack_B(uint8_t * RESTRICT tile, const void * RESTRICT packed_B) { + const __m512i lowMask = _mm512_set1_epi8(0xF); + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32)); + const __m512i r0 = _mm512_and_si512(bytes, lowMask); + const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +// packed_B_t for QKK is int8_t +template +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { + const int packed_B_group_size = QK_K / 2 * TILE_N / 8; + const char * packed_B_group = (const char *)packed_B + k * packed_B_group_size; + const __m512i lowMask = _mm512_set1_epi8(0xF); + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512(packed_B_group + n * 32); + const __m512i r0 = _mm512_and_si512(bytes, lowMask); + const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +template <> +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { + // lower 4bits, stride 256 bytes + const int packed_l4_group_size = QK_K / 2 * TILE_N / 8; + const char * pb = (const char *)packed_B + k * packed_l4_group_size; + + // higher 1bit, stride 64 bytes + const int packed_h1_group_size = QK_K / 8 * TILE_N / 8; + const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h1_group_size; + const __m512i hbits = _mm512_loadu_si512(ph); + + const __m512i lowMask = _mm512_set1_epi8(0xF); + __m512i hmask0 = _mm512_set1_epi8(0x1); + __m512i hmask1 = _mm512_set1_epi8(0x2); + + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512(pb + n * 32); + __m512i r0 = _mm512_and_si512(bytes, lowMask); + __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + __m512i h0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), n), 4); + __m512i h1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), n + 1), 4); + + hmask0 = _mm512_slli_epi16(hmask0, 2); + hmask1 = _mm512_slli_epi16(hmask1, 2); + r0 = _mm512_add_epi8(r0, h0); + r1 = _mm512_add_epi8(r1, h1); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +template <> +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { + // lower 4bits, stride 128 bytes + const int packed_l4_group_size = QK_K / 2 * TILE_N / 16; + const char * pb = (const char *)packed_B + k * packed_l4_group_size; + + // higher 2bits, stride 64 bytes + const int packed_h2_group_size = QK_K / 4 * TILE_N / 16; + const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h2_group_size; + const __m512i hbits = _mm512_loadu_si512(ph); + + const __m512i off = _mm512_set1_epi8(32); + const __m512i lowMask = _mm512_set1_epi8(0xF); + __m512i hmask0 = _mm512_set1_epi8(0x3); // 0011 + __m512i hmask1 = _mm512_set1_epi8(0xC); // 1100 + + // notes: skip zero padding from row4 to row7 as we have done so in `unpack_A` + __m512i bytes = _mm512_loadu_si512(pb); + __m512i r0 = _mm512_and_si512(bytes, lowMask); + __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + __m512i h0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask0), 4); + __m512i h1 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask1), 2); + _mm512_storeu_si512((__m512i *)(tile + 0), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off)); + _mm512_storeu_si512((__m512i *)(tile + 64), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off)); + + hmask0 = _mm512_slli_epi16(hmask0, 4); + hmask1 = _mm512_slli_epi16(hmask1, 4); + + bytes = _mm512_loadu_si512(pb + 64); + r0 = _mm512_and_si512(bytes, lowMask); + r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + h0 = _mm512_and_si512(hbits, hmask0); + h1 = _mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), 2); + _mm512_storeu_si512((__m512i *)(tile + 128), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off)); + _mm512_storeu_si512((__m512i *)(tile + 192), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off)); +} + +template <> +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { + static const __m512i values128 = _mm512_set_epi8( + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127 + ); + + const int packed_B_group_size = QK_K / 2 * TILE_N / 8; + const char * pb = (const char *)packed_B + k * packed_B_group_size; + const __m512i lowMask = _mm512_set1_epi8(0xF); + + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512(pb + n * 32); + const __m512i r0 = _mm512_shuffle_epi8(values128, _mm512_and_si512(bytes, lowMask)); + const __m512i r1 = _mm512_shuffle_epi8(values128, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask)); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +template +struct acc_C {}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) { + const int offset = TILE_N * TILE_K / 2; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); + + for (int m = 0; m < nr; ++m) { + const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d)); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_1 * A, int lda, const void * packed_B, int nr) { + const int offset = TILE_N * TILE_K / 2; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); + const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset + TILE_N * sizeof(ggml_half)))); + + for (int m = 0; m < nr; ++m) { + const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d)); + const __m512 vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].s)); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); + vsum = _mm512_fmadd_ps(vm0, vs1, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) { + const int offset = TILE_N * TILE_K; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); + + for (int m = 0; m < nr; ++m) { + const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d)); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { + const uint8_t * scales = reinterpret_cast((const char *)packed_B + (QK_K / 2) * TILE_N); + const uint8_t * mins = scales + 8 * TILE_N; + const ggml_half * d0 = reinterpret_cast(mins + 8 * TILE_N); + const ggml_half * dmin = d0 + TILE_N; + + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); + const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin)); + + for (int m = 0; m < nr; ++m) { + const float d1 = A[m * lda].d; + const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); + const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + + __m512i acc_m = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + __m512i vmask = _mm512_set1_epi32(k); + __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s)); + __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32))); + acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); + } + + vsum = _mm512_fmadd_ps(vtile, vd, vsum); + vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { + const uint8_t * scales = reinterpret_cast((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N); + const uint8_t * mins = scales + 8 * TILE_N; + const ggml_half * d0 = reinterpret_cast(mins + 8 * TILE_N); + const ggml_half * dmin = d0 + TILE_N; + + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); + const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin)); + + for (int m = 0; m < nr; ++m) { + const float d1 = A[m * lda].d; + const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); + const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + + __m512i acc_m = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + __m512i vmask = _mm512_set1_epi32(k); + __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s)); + __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32))); + acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); + } + + vsum = _mm512_fmadd_ps(vtile, vd, vsum); + vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { + const uint8_t * scales = reinterpret_cast((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N); + const ggml_half * d0 = reinterpret_cast(scales + 16 * TILE_N); + + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); + + for (int m = 0; m < nr; ++m) { + const float d1 = A[m * lda].d; + const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + + vsum = _mm512_fmadd_ps(vtile, vd, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { + const int8_t * scales = reinterpret_cast((const char *)packed_B + (QK_K / 2) * TILE_N); + const ggml_half * d0 = reinterpret_cast(scales + 8 * TILE_N); + + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); + + for (int m = 0; m < nr; ++m) { + const float d1 = A[m * lda].d; + const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + + vsum = _mm512_fmadd_ps(vtile, vd, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template constexpr int get_quants_size(); +template <> constexpr int get_quants_size() { return (QK_K / 2) * TILE_N; } +template <> constexpr int get_quants_size() { return (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; } +template <> constexpr int get_quants_size() { return (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; } +template <> constexpr int get_quants_size() { return (QK_K / 2) * TILE_N; } + +// used for QKK format +template ::value, int>::type = 0> +inline void scale_C(const int32_t * RESTRICT tile, int32_t * RESTRICT sumi, const void * packed_B, int k, int nr) { + const uint8_t * scales = reinterpret_cast((const char *)packed_B + get_quants_size()); + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(scales + k * TILE_N))); + + for (int m = 0; m < nr; ++m) { + __m512i vsumi; + if (is_acc) { + vsumi = _mm512_loadu_si512(sumi + m * TILE_N); + } else { + vsumi = _mm512_setzero_si512(); + } + __m512i vtile = _mm512_loadu_si512(tile + m * TILE_N); + vsumi = _mm512_add_epi32(vsumi, _mm512_mullo_epi32(vtile, vscale)); + _mm512_storeu_si512((__m512i *)(sumi + m * TILE_N), vsumi); + } +} + +template +struct tinygemm_kernel_avx { + static void apply(int K, const TA * RESTRICT A, const TB * RESTRICT B, TC * RESTRICT C, int ldc) { + GGML_UNUSED(K); + GGML_UNUSED(A); + GGML_UNUSED(B); + GGML_UNUSED(C); + GGML_UNUSED(ldc); + } +}; + +template +struct tinygemm_kernel_avx { + static void apply(int K, const float * RESTRICT A, const ggml_fp16_t * RESTRICT B, float * RESTRICT C, int ldc) { + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N; + assert(BLOCK_K == 16); + + __m512 va; + __m512 vb[COLS]; + __m512 vc[ROWS * COLS]; + + auto loadc = [&](int idx) { + vc[idx] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](int idx, int k) { + // TODO: use `constexpr` here to get rid of interger div + // when upgraded to C++17 + const int row = idx / COLS; + const int col = idx % COLS; + + if (col == 0) { + va = _mm512_loadu_ps(A + row * K + k); + } + if (row == 0) { + vb[col] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(B + col * K + k))); + } + vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]); + }; + + for (int k = 0; k < K; k += 16) { + Unroll{}(compute, k); + } + + auto storec = [&](int idx) { + const int row = idx / COLS; + const int col = idx % COLS; + C[row * ldc + col] = _mm512_reduce_add_ps(vc[idx]); + }; + Unroll{}(storec); + } +}; + +#define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_avx::apply( \ + K, (const float *)src1->data + mb_start * K, \ + (const type *)src0->data + nb_start * K, \ + (float *)dst->data + mb_start * ldc + nb_start, ldc); + + +// re-organize in the format {NB, KB, TILE_SIZE}: +#define PACKED_INDEX(n, k, KB, tile_size) (n * KB + k) * tile_size + +template +void convert_B_packed_format(void * RESTRICT packed_B, const TB * RESTRICT B, int N, int K, int n_threads) { + const int NB = N / TILE_N; + const int KB = K / BLOCK_K; + const int TILE_SIZE = get_tile_size(); + + // parallel on NB should be enough + parallel_for(n_threads, NB, [&](int begin, int end) { + for (int n = begin; n < end; ++n) { + for (int k = 0; k < KB; ++k) { + int n0 = n * TILE_N; + pack_B((char *)packed_B + PACKED_INDEX(n, k, KB, TILE_SIZE), &B[n0 * KB + k], KB); + } + } + }); +} + +template +struct tinygemm_kernel_vnni {}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q4_0); + + const block_q8_0 * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + __m512i va[8]; + __m512 vc[COLS]; + __m512 vd1; + + // sum of offsets, shared across COLS + // + // avx512-vnni does not have `_mm512_dpbssd_epi32`, + // need to transfrom ss to us: + // a * (b - 8) is equavilent to b * a - 8 * a + // s u u u s u s + // + __m512i vcomp; + + const __m512i off = _mm512_set1_epi8(8); + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](int col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](int col, int i) { + // load a and compute compensation + if (col == 0) { + const int32_t * a_ptr = reinterpret_cast(A[0 * KB + i].qs); + vcomp = _mm512_setzero_si512(); + for (int k = 0; k < 8; ++k) { + va[k] = _mm512_set1_epi32(a_ptr[k]); + vcomp = _mm512_dpbusd_epi32(vcomp, off, va[k]); + } + vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d)); + } + + // load b + __m512i vsum = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + for (int k = 0; k < 8; k += 2) { + __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32)); + __m512i vb0 = _mm512_and_si512(bytes, lowMask); + vsum = _mm512_dpbusd_epi32(vsum, vb0, va[k + 0]); + __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va[k + 1]); + } + const int offset = TILE_N * TILE_K / 2; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); + vsum = _mm512_sub_epi32(vsum, vcomp); + + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](int col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q4_1); + + const block_q8_1 * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + __m512i va[8]; + __m512i vb[8]; + __m512 vc[COLS]; + __m512 vd1, vs1; + + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](int col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](int col, int i) { + // load a + if (col == 0) { + const int32_t * a_ptr = reinterpret_cast(A[0 * KB + i].qs); + for (int k = 0; k < 8; ++k) { + va[k] = _mm512_set1_epi32(a_ptr[k]); + } + vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d)); + vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].s)); + } + + // load b + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + for (int k = 0; k < 8; k += 2) { + __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32)); + vb[k + 0] = _mm512_and_si512(bytes, lowMask); + vb[k + 1] = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + } + const int offset = TILE_N * TILE_K / 2; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); + const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset + TILE_N * sizeof(ggml_half)))); + + __m512i vsum = _mm512_setzero_si512(); + for (int k = 0; k < 8; ++k) { + vsum = _mm512_dpbusd_epi32(vsum, vb[k], va[k]); + } + + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); + vc[col] = _mm512_fmadd_ps(vm0, vs1, vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](int col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q8_0) + TILE_N * sizeof(int32_t); + + const block_q8_0 * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + __m512i va[8]; + __m512i vb[8]; + __m512 vc[COLS]; + __m512 vd1; + + // Notes: s8s8 igemm compensation in avx512-vnni + // change s8s8 to u8s8 with compensate + // a * b = (a + 128) * b - 128 * b + // s s u s u s + // + // (128 * b is pre-computed when packing B to vnni formats) + // + const __m512i off = _mm512_set1_epi8(static_cast(0x80)); + + auto loadc = [&](int col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](int col, int i) { + // load a and add offset 128 + if (col == 0) { + const int32_t * a_ptr = reinterpret_cast(A[0 * KB + i].qs); + for (int k = 0; k < 8; ++k) { + va[k] = _mm512_set1_epi32(a_ptr[k]); + va[k] = _mm512_add_epi8(va[k], off); + } + vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d)); + } + + // load b + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + for (int k = 0; k < 8; ++k) { + vb[k] = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 64)); + } + const int offset = TILE_N * TILE_K; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); + const int offset2 = TILE_N * TILE_K + TILE_N * sizeof(ggml_half); + const __m512i vcomp = _mm512_loadu_si512((const __m512i *)(b_ptr + offset2)); + + __m512i vsum = _mm512_setzero_si512(); + for (int k = 0; k < 8; ++k) { + vsum = _mm512_dpbusd_epi32(vsum, va[k], vb[k]); + } + vsum = _mm512_sub_epi32(vsum, vcomp); + + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](int col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q4_K) + TILE_N * 4; + + const block_q8_K * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + // a.qs: 8 groups, 32 bytes each group (m256i) + __m512i va[8]; + // a.bsum: 8 groups, 2 bytes each group (m128i) + __m512i va_bsum; + __m512 vc[COLS]; + __m512 vd1; + + // packed_B: + const int offset_scales = (QK_K / 2) * TILE_N; + const int offset_mins = (QK_K / 2) * TILE_N + 8 * TILE_N; + const int offset_d0 = (QK_K / 2) * TILE_N + 16 * TILE_N; + const int offset_dmin = (QK_K / 2) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half); + + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](int col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + // Notes: vnni formats in QK_K + // a) quants vnni format + // int8 {k/4, n, 4}, viewed as 2d {k/4, 4n}, k = 32 + // from {16, 32} to {8, 64} + // + // b) min vnni format + // int16 {k/2, n, 2}, viewed as 2d {k/2, 2n}, k = 8 + // from {16, 8} to {4, 32} + // + auto compute = [&](int col, int i) { + // load a + if (col == 0) { + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32))); + } + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + va_bsum = _mm512_castsi128_si512(q8s); + vd1 = _mm512_set1_ps(A[0 * KB + i].d); + } + + // step 1: accumultate the quants + __m512i acc = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + const char * b_qs = b_ptr; + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + __m512i vsum = _mm512_setzero_si512(); + for (int k = 0; k < 8; k += 2) { + __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]); + __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]); + + __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs); + __m512i vb0 = _mm512_and_si512(bytes, lowMask); + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + + b_qs += 64; + } + // vacc += scale * (q8 @ q4) + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); + acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); + } + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); + + // step 2: accumulate the mins + __m512i acc_m = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + __m512i vmask = _mm512_set1_epi32(k); + __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum); + __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32))); + acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); + } + const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin))); + vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](int col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q5_K) + TILE_N * 4; + + const block_q8_K * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + // a.qs: 8 groups, 32 bytes each group (m256i) + __m512i va[8]; + // a.bsum: 8 groups, 2 bytes each group (m128i) + __m512i va_bsum; + __m512 vc[COLS]; + __m512 vd1; + + // packed_B: + const int offset_qh = (QK_K / 2) * TILE_N; + const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; + const int offset_mins = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 8 * TILE_N; + const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N; + const int offset_dmin = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half); + + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](int col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + // Q5_K and Q4_K shares the same vnni formats, refer to notes above. + auto compute = [&](int col, int i) { + // load a + if (col == 0) { + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32))); + } + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + va_bsum = _mm512_castsi128_si512(q8s); + vd1 = _mm512_set1_ps(A[0 * KB + i].d); + } + + // step 1: accumultate the quants + __m512i acc = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + const char * b_qs = b_ptr; + const char * b_qh = b_ptr + offset_qh; + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + __m512i vsum = _mm512_setzero_si512(); + __m512i hmask0 = _mm512_set1_epi8(0x1); + __m512i hmask1 = _mm512_set1_epi8(0x2); + __m512i hbits = _mm512_loadu_si512((const __m512i *)(b_qh + k_group * 64)); + for (int k = 0; k < 8; k += 2) { + __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]); + __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]); + + __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs); + __m512i vb0 = _mm512_and_si512(bytes, lowMask); + __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + + __m512i vh0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), k), 4); + __m512i vh1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), k + 1), 4); + + hmask0 = _mm512_slli_epi16(hmask0, 2); + hmask1 = _mm512_slli_epi16(hmask1, 2); + vb0 = _mm512_add_epi8(vb0, vh0); + vb1 = _mm512_add_epi8(vb1, vh1); + + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + + b_qs += 64; + } + // vacc += scale * (q8 @ q5) + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); + acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); + } + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); + + // step 2: accumulate the mins + __m512i acc_m = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + __m512i vmask = _mm512_set1_epi32(k); + __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum); + __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32))); + acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); + } + const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin))); + vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](int col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q6_K); + + const block_q8_K * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + // load the 256 bytes from A to 4 avx512 vectors + __m512i va[4]; + __m512 vc[COLS]; + __m512 vd1; + + // packed_B: + const int offset_qh = (QK_K / 2) * TILE_N; + const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; + const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N + 16 * TILE_N; + + // compensation + __m512i vcomp; + + const __m512i m32s = _mm512_set1_epi32(32); + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](int col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](int col, int i) { + if (col == 0) { + // load a + va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0)); + va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64)); + va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128)); + va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192)); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); + vcomp = _mm512_mullo_epi32(_mm512_cvtepi16_epi32(q8sums), m32s); + vd1 = _mm512_set1_ps(A[0 * KB + i].d); + } + + // accmulate the quants + __m512i acc = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + const char * b_qs = b_ptr; + const char * b_qh = b_ptr + offset_qh; + int mask = 0; + for (int k_group = 0; k_group < QK_K / 16; ++k_group) { + int r = k_group >> 2; + __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + + __m512i vsum = _mm512_setzero_si512(); + __m512i hmask = _mm512_set1_epi8(0x3); + + __m512i bytes = _mm512_loadu_si512(b_qs); + __m512i hbits = _mm512_loadu_si512(b_qh); + __m512i vb0 = _mm512_and_si512(bytes, lowMask); + __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + __m512i vh0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask), 4); + __m512i vh1 = _mm512_slli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 2)), 2); + + vb0 = _mm512_add_epi8(vb0, vh0); + vb1 = _mm512_add_epi8(vb1, vh1); + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + b_qs += 64; + + va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + + bytes = _mm512_loadu_si512(b_qs); + vb0 = _mm512_and_si512(bytes, lowMask); + vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + vh0 = _mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 4)); + vh1 = _mm512_srli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 6)), 2); + vb0 = _mm512_add_epi8(vb0, vh0); + vb1 = _mm512_add_epi8(vb1, vh1); + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + b_qs += 64; + b_qh += 64; + + // B * A - 32 * A + __m512i vmask = _mm512_set1_epi32(k_group); + vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp)); + + // vacc += scale * (q8 @ q6) + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); + acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); + } + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](int col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_iq4_xs) + TILE_N * 2; + + const block_q8_K * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + // load the 256 bytes from A to 4 avx512 vectors + __m512i va[4]; + __m512 vc[COLS]; + __m512 vd1; + + // packed_B: + const int offset_scales = (QK_K / 2) * TILE_N ; + const int offset_d0 = (QK_K / 2) * TILE_N + 8 * TILE_N; + + // compensation + __m512i vcomp; + + const __m256i m128s = _mm256_set1_epi16(128); + const __m512i lowMask = _mm512_set1_epi8(0xF); + + const __m512i values128 = _mm512_set_epi8( + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127 + ); + const __m512i off = _mm512_set1_epi8(static_cast(0x80)); + const __m512i values256 = _mm512_add_epi8(values128, off); + + auto loadc = [&](int col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](int col, int i) { + if (col == 0) { + // load a + va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0)); + va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64)); + va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128)); + va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192)); + + // compensation: 128 * A + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); + vcomp = _mm512_castsi256_si512(_mm256_madd_epi16(q8sums, m128s)); + vd1 = _mm512_set1_ps(A[0 * KB + i].d); + } + + // accmulate the quants + __m512i acc = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + const char * b_qs = b_ptr; + int mask = 0; + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + int r = k_group >> 1; + __m512i vmask = _mm512_set1_epi32(k_group); + __m512i vsum = _mm512_setzero_si512(); + for (int k = 0; k < 8; k += 2) { + __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + + __m512i bytes = _mm512_loadu_si512(b_qs); + __m512i vb0 = _mm512_shuffle_epi8(values256, _mm512_and_si512(bytes, lowMask)); + __m512i vb1 = _mm512_shuffle_epi8(values256, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask)); + + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + b_qs += 64; + } + // (B + 128) * A - 128 * A + vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp)); + + // vacc += scale * (q8 @ q4) + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); + acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); + } + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](int col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \ + tinygemm_kernel_vnni::apply( \ + KB, (const char *)wdata + 0 * row_size_A, \ + (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \ + (float *) dst->data + 0 * N + nb_start, ldc) + +template ::value, int>::type = 0> +void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, TC * RESTRICT C, int ldc) { + using packed_B_t = packed_B_type; + const int TILE_SIZE = get_tile_size(); + const bool need_unpack = do_unpack::value; + + GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N); + const TA * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + const int m0 = std::min(M, TILE_M); + const int m1 = std::max(M - TILE_M, 0); + const int lda = KB * sizeof(TA); + //const int ldb = KB * sizeof(TB); + + static thread_local packed_B_t Tile0[TILE_N * TILE_K]; + static thread_local packed_B_t Tile1[TILE_N * TILE_K]; + static thread_local int8_t Tile23[TILE_M * TILE_K]; + + static thread_local int32_t TileC0[TILE_M * TILE_N * 4]; + static thread_local int32_t TileC1[TILE_M * TILE_N * 4]; + + // double buffering C to interleave avx512 and amx + int32_t * C_cur = TileC0; + int32_t * C_pre = TileC1; + + auto Tile4 = [&](int32_t * base) { return base; }; + auto Tile5 = [&](int32_t * base) { return base + TILE_M * TILE_N; }; + auto Tile6 = [&](int32_t * base) { return base + 2 * TILE_M * TILE_N; }; + auto Tile7 = [&](int32_t * base) { return base + 3 * TILE_M * TILE_N; }; + + if (M == 2 * TILE_M) { + // i = 0 + const char * B_blk0 = B + PACKED_INDEX(0, 0, KB, TILE_SIZE); + const char * B_blk1 = B + PACKED_INDEX(1, 0, KB, TILE_SIZE); + if (need_unpack) { + unpack_B(Tile0, B_blk0); + _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); + } + + _tile_zero(TMM4); + _tile_loadd(TMM2, A[0].qs, lda); + _tile_dpbssd(TMM4, TMM2, TMM0); + _tile_stored(TMM4, Tile4(C_pre), TILE_N * sizeof(int32_t)); + + _tile_zero(TMM5); + _tile_loadd(TMM3, A[TILE_M * KB + 0].qs, lda); + _tile_dpbssd(TMM5, TMM3, TMM0); + _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t)); + + if (need_unpack) { + unpack_B(Tile1, B_blk0); + _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); + } + + _tile_zero(TMM6); + _tile_dpbssd(TMM6, TMM2, TMM1); + _tile_stored(TMM6, Tile6(C_pre), TILE_N * sizeof(int32_t)); + + _tile_zero(TMM7); + _tile_dpbssd(TMM7, TMM3, TMM1); + _tile_stored(TMM7, Tile7(C_pre), TILE_N * sizeof(int32_t)); + + for (int i = 1; i < KB; ++i) { + // index of previous iter + const int ii = i - 1; + const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE); + const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE); + GGML_DISPATCH_BOOL(ii > 0, is_acc, [&] { + if (need_unpack) { + unpack_B(Tile0, B_blk0); + _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); + } + _tile_zero(TMM4); + _tile_loadd(TMM2, A[i].qs, lda); + acc_C::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); + + _tile_dpbssd(TMM4, TMM2, TMM0); + _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t)); + + _tile_zero(TMM5); + _tile_loadd(TMM3, A[TILE_M * KB + i].qs, lda); + acc_C::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); + + _tile_dpbssd(TMM5, TMM3, TMM0); + _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t)); + + if (need_unpack) { + unpack_B(Tile1, B_blk1); + _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); + } + _tile_zero(TMM6); + acc_C::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); + + _tile_dpbssd(TMM6, TMM2, TMM1); + _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t)); + + _tile_zero(TMM7); + acc_C::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); + + _tile_dpbssd(TMM7, TMM3, TMM1); + _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t)); + + std::swap(C_cur, C_pre); + }); + } + // final accumulation + { + int ii = KB - 1; + acc_C::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); + acc_C::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); + acc_C::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); + acc_C::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); + } + } else { + for (int i = 0; i < KB; ++i) { + _tile_zero(TMM4); + _tile_zero(TMM6); + if (m1 != 0) { + _tile_zero(TMM5); + _tile_zero(TMM7); + } + + const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE); + const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE); + if (need_unpack) { + unpack_B(Tile0, B_blk0); + _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); + } + + if (need_unpack) { + unpack_B(Tile1, B_blk1); + _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); + } + + if (m0 == TILE_M) { + _tile_loadd(TMM2, A[i].qs, lda); + } else { + unpack_A(Tile23, &A[i], KB, m0); + _tile_loadd(TMM2, Tile23, TILE_K); + } + + _tile_dpbssd(TMM4, TMM2, TMM0); + _tile_dpbssd(TMM6, TMM2, TMM1); + + _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t)); + _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t)); + + GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { + acc_C::apply(C, ldc, Tile4(C_cur), &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0); + acc_C::apply(C + TILE_N, ldc, Tile6(C_cur), &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0); + }); + + if (m1 != 0) { + unpack_A(Tile23, &A[TILE_M * KB + i], KB, m1); + _tile_loadd(TMM3, Tile23, TILE_K); + + _tile_dpbssd(TMM5, TMM3, TMM0); + _tile_dpbssd(TMM7, TMM3, TMM1); + _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t)); + _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t)); + GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { + acc_C::apply(C + TILE_M * ldc, ldc, Tile5(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1); + acc_C::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1); + }); + } + } + } + return; +} + +template ::value, int>::type = 0> +void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + static_assert(std::is_same::value); + const int TILE_SIZE = get_tile_size(); + + GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N); + const TA * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + const int m0 = std::min(M, TILE_M); + const int m1 = std::max(M - TILE_M, 0); + //const int lda = KB * sizeof(TA); + + static thread_local int8_t Tile0[TILE_N * TILE_K]; + static thread_local int8_t Tile1[TILE_N * TILE_K]; + static thread_local int8_t Tile23[TILE_M * TILE_K]; + + // mat mul result for each group + static thread_local int32_t Tile4[TILE_M * TILE_N]; + static thread_local int32_t Tile5[TILE_M * TILE_N]; + static thread_local int32_t Tile6[TILE_M * TILE_N]; + static thread_local int32_t Tile7[TILE_M * TILE_N]; + + // sum of each QK_K block, contains 8 groups, int32 + static thread_local int32_t Sumi4[TILE_M * TILE_N]; + static thread_local int32_t Sumi5[TILE_M * TILE_N]; + static thread_local int32_t Sumi6[TILE_M * TILE_N]; + static thread_local int32_t Sumi7[TILE_M * TILE_N]; + + const int k_group_size = std::is_same::value ? 16 : 32; + for (int i = 0; i < KB; ++i) { + // step 1: accumulate the quants across 8 groups, each group with 32 + for (int k = 0; k < QK_K / k_group_size; ++k) { + GGML_DISPATCH_BOOL(k > 0, is_acc, [&] { + _tile_zero(TMM4); + _tile_zero(TMM6); + + unpack_B(Tile0, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k); + _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); + + unpack_B(Tile1, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k); + _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); + + unpack_A(Tile23, &A[i], KB, k, m0); + _tile_loadd(TMM2, Tile23, TILE_K); + + _tile_dpbssd(TMM4, TMM2, TMM0); + _tile_dpbssd(TMM6, TMM2, TMM1); + + _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); + _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + + scale_C(Tile4, Sumi4, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m0); + scale_C(Tile6, Sumi6, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m0); + + if (m1 != 0) { + _tile_zero(TMM5); + _tile_zero(TMM7); + + unpack_A(Tile23, &A[TILE_M * KB + i], KB, k, m1); + _tile_loadd(TMM3, Tile23, TILE_K); + + _tile_dpbssd(TMM5, TMM3, TMM0); + _tile_dpbssd(TMM7, TMM3, TMM1); + + _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); + _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + + scale_C(Tile5, Sumi5, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m1); + scale_C(Tile7, Sumi7, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m1); + } + }); + } + + // step 2: accmulate the mins + GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { + acc_C::apply(C, ldc, Sumi4, &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0); + acc_C::apply(C + TILE_N, ldc, Sumi6, &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0); + if (m1 != 0) { + acc_C::apply(C + TILE_M * ldc, ldc, Sumi5, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1); + acc_C::apply(C + TILE_M * ldc + TILE_N, ldc, Sumi7, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1); + } + }); + } + return; +} + +} // anonymous namespace + +// get the packed tensor size for quantized weights +size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor) { + const enum ggml_type TYPE = tensor->type; + + const int K = tensor->ne[0]; // ne0: in_features + const int N = tensor->ne[1]; // ne1: out_features + + auto get_tensor_size = [&] { + size_t row_size_B{0}; + GGML_DISPATCH_QTYPES(TYPE, [&] { + row_size_B = get_row_size(K); + }); + return N * row_size_B; + }; + + if (qtype_has_amx_kernels(TYPE)) { + return get_tensor_size(); + } else { + // for f16, bf16 we don't do packing + return ggml_nbytes(tensor); + } +} + +// pack weight to vnni format +void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + + size_t alloc_size = ggml_backend_amx_get_alloc_size(tensor); + GGML_ASSERT(alloc_size == size); + + const enum ggml_type TYPE = tensor->type; + + const int K = tensor->ne[0]; // ne0: in_features + const int N = tensor->ne[1]; // ne1: out_features + +#if defined(_OPENMP) + // the buffer ctx is not initialized when .set_tensor is called + int n_threads = omp_get_num_threads(); +#else + int n_threads = 1; +#endif + + GGML_DISPATCH_QTYPES(TYPE, [&] { + convert_B_packed_format((void *)((char *)tensor->data + offset), (const type *)data, N, K, n_threads); + }); +} + +// NB: mixed dtype gemm with Advanced Matrix Extensions (Intel AMX) +// +// src0: weight in shape of {N, K}, quantized +// src1: input in shape of {M, K}, float32 +// dst: output in shape of {M, N}, float32 +// +// the function performs: dst = src1 @ src0.T +// +void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst) { + struct ggml_tensor * src0 = dst->src[0]; + struct ggml_tensor * src1 = dst->src[1]; + + const enum ggml_type TYPE = src0->type; + + const int n_threads = ctx->n_threads; + + // f16 only has avx512 kernels for now, + // amx kernels will be added once 6th gen xeon is released. + const bool is_floating_type = TYPE == GGML_TYPE_F16; + + const int M = dst->ne[1]; + const int N = dst->ne[0]; + const int K = src0->ne[0]; + const int ldc = dst->nb[1] / dst->nb[0]; + + if (is_floating_type) { + constexpr int BLOCK_M = 4; + constexpr int BLOCK_N = 6; + const int MB = div_up(M, BLOCK_M); + const int NB = div_up(N, BLOCK_N); + + parallel_for(n_threads, MB * NB, [&](int begin, int end) { + GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] { + for (int i = begin; i < end; ++i) { + int mb = i / NB; + int nb = i % NB; + + int mb_start = mb * BLOCK_M; + int mb_size = std::min(BLOCK_M, M - mb_start); + int nb_start = nb * BLOCK_N; + int nb_size = std::min(BLOCK_N, N - nb_start); + + switch (mb_size << 4 | nb_size) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_AVX(1, 2); break; + case 0x14: LAUNCH_TINYGEMM_KERNEL_AVX(1, 4); break; + case 0x16: LAUNCH_TINYGEMM_KERNEL_AVX(1, 6); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_AVX(2, 2); break; + case 0x24: LAUNCH_TINYGEMM_KERNEL_AVX(2, 4); break; + case 0x26: LAUNCH_TINYGEMM_KERNEL_AVX(2, 6); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_AVX(3, 2); break; + case 0x34: LAUNCH_TINYGEMM_KERNEL_AVX(3, 4); break; + case 0x36: LAUNCH_TINYGEMM_KERNEL_AVX(3, 6); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_AVX(4, 2); break; + case 0x44: LAUNCH_TINYGEMM_KERNEL_AVX(4, 4); break; + case 0x46: LAUNCH_TINYGEMM_KERNEL_AVX(4, 6); break; + default: fprintf(stderr, "Unexpected block size!\n"); + } + } + }); + }); + return; + } + + // pointer to work space, used convert A from float to quantized type + void * wdata = nullptr; + + //TODO: performance improvement: merge quant A + GGML_DISPATCH_QTYPES(TYPE, [&] { + const size_t row_size_A = K / blck_size * sizeof(vec_dot_type); + const size_t desired_wsize = M * row_size_A; + if (ctx->work_size < desired_wsize) { + ctx->work_data.reset(new char[desired_wsize]); + ctx->work_size = desired_wsize; + } + wdata = ctx->work_data.get(); + + // Q4_0, Q4_1, Q8_0 handles 1 TILE_K per blck_size + // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size + GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size); + + const float * A_data = static_cast(src1->data); + for (int m = 0; m < M; ++m) { + from_float(A_data + m * K, (char *)wdata + m * row_size_A, K); + } + }); + + if (M == 1) { + // MB = 1 and handle 8 tiles in each block + constexpr int kTilesN = 4; + constexpr int BLOCK_N = TILE_N * kTilesN; + const int NB = div_up(N, BLOCK_N); + + parallel_for(n_threads, NB, [&](int begin, int end) { + GGML_DISPATCH_QTYPES(TYPE, [&] { + const int KB = K / blck_size; + const int TILE_SIZE = get_tile_size(); + const int row_size_A = KB * sizeof(vec_dot_type); + for (int i = begin; i < end; ++i) { + int nb = i; + int nb_start = nb * BLOCK_N; + int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96 + + switch (nb_size) { + //case 160: LAUNCH_TINYGEMM_KERNEL_VNNI(160); break; + case 128: LAUNCH_TINYGEMM_KERNEL_VNNI(128); break; + case 96: LAUNCH_TINYGEMM_KERNEL_VNNI(96); break; + case 64: LAUNCH_TINYGEMM_KERNEL_VNNI(64); break; + case 32: LAUNCH_TINYGEMM_KERNEL_VNNI(32); break; + default: fprintf(stderr, "Unexpected n block size!\n"); + } + } + }); + }); + return; + } + + // handle 4 tiles at a tile + constexpr int BLOCK_M = TILE_M * 2; + constexpr int BLOCK_N = TILE_N * 2; + const int MB = div_up(M, BLOCK_M); + const int NB = div_up(N, BLOCK_N); + + parallel_for(n_threads, MB * NB, [&](int begin, int end) { + // init tile config for each thread + ggml_tile_config_init(); + + GGML_DISPATCH_QTYPES(TYPE, [&] { + const int KB = K / blck_size; + const int TILE_SIZE = get_tile_size(); + const int row_size_A = KB * sizeof(vec_dot_type); + + for (int i = begin; i < end; ++i) { + int mb = i / NB; + int nb = i % NB; + + int mb_start = mb * BLOCK_M; + int mb_size = std::min(BLOCK_M, M - mb_start); + int nb_start = nb * BLOCK_N; + int nb_size = BLOCK_N; + + tinygemm_kernel_amx( + mb_size, nb_size, KB, + (const char *)wdata + mb_start * row_size_A, + (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE), + (float *) dst->data + mb_start * N + nb_start, ldc); + } + }); + }); +} + +#else // if defined(__AMX_INT8__) + +void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst) { + fprintf(stderr, "GGML is not compiled with AMX support!\n"); + + GGML_UNUSED(ctx); + GGML_UNUSED(dst); +} + +#endif // if defined(__AMX_INT8__) diff --git a/src/ggml-amx/mmq.h b/src/ggml-amx/mmq.h new file mode 100644 index 000000000..cf0920620 --- /dev/null +++ b/src/ggml-amx/mmq.h @@ -0,0 +1,17 @@ +#pragma once +#include "common.h" +#include + +#ifdef __cplusplus +extern "C" { +#endif + +size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor); + +void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + +void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst); + +#ifdef __cplusplus +} +#endif diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index fa280b529..21c9f5e38 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -1151,8 +1151,8 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer)); - char * src_ptr = (char *) src->data; - char * dst_ptr = (char *) dst; + const char * src_ptr = (const char *) src->data; + char * dst_ptr = (char *) dst; const int64_t ne0 = src->ne[0]; const int64_t nb0 = src->nb[0]; @@ -1162,7 +1162,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( const enum ggml_type type = src->type; const int64_t ts = ggml_type_size(type); const int64_t bs = ggml_blck_size(type); - int64_t i1_diff = i1_high - i1_low; + const int64_t i1_diff = i1_high - i1_low; const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; if (nb0 == ts && nb1 == ts*ne0/bs) { @@ -1479,13 +1479,18 @@ static void ggml_cuda_op_mul_mat( if (src0_is_contiguous) { dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data; } else { - dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0)); + // If src0 is not contiguous it will be copied to a temporary buffer. + // This buffer needs to be cleared entirely because multiple regions will function as padding. + const size_t nbytes_data = ggml_nbytes(src0); + const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); + dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding); + CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream)); } - // If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared: + // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared: if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) { - const int64_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00); - const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); + const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00); + const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream)); } diff --git a/src/ggml-cuda/mmq.cu b/src/ggml-cuda/mmq.cu index 4935f8818..ae5c68ab3 100644 --- a/src/ggml-cuda/mmq.cu +++ b/src/ggml-cuda/mmq.cu @@ -8,8 +8,6 @@ void ggml_cuda_op_mul_mat_q( const int64_t ne00 = src0->ne[0]; - const int64_t nb01 = src0->nb[1]; - const int64_t ne10 = src1->ne[0]; const int64_t ne11 = src1->ne[1]; GGML_ASSERT(ne10 % QK8_1 == 0); @@ -17,7 +15,7 @@ void ggml_cuda_op_mul_mat_q( const int64_t ne0 = dst->ne[0]; const int64_t row_diff = row_high - row_low; - const int64_t stride00 = nb01 / ggml_type_size(src0->type); + const int64_t stride00 = ne00 / ggml_blck_size(src0->type); int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; diff --git a/src/ggml-metal.m b/src/ggml-metal.m index e9541441c..80c08f15b 100644 --- a/src/ggml-metal.m +++ b/src/ggml-metal.m @@ -1015,19 +1015,21 @@ static void ggml_metal_encode_node( id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; - //GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); - //if (src0) { - // GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, - // ggml_is_contiguous(src0), src0->name); - //} - //if (src1) { - // GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, - // ggml_is_contiguous(src1), src1->name); - //} - //if (dst) { - // GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, - // dst->name); - //} +#if 0 + GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); + if (src0) { + GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, + ggml_is_contiguous(src0), src0->name); + } + if (src1) { + GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, + ggml_is_contiguous(src1), src1->name); + } + if (dst) { + GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, + dst->name); + } +#endif id device = ctx_dev->mtl_device; @@ -1810,14 +1812,16 @@ static void ggml_metal_encode_node( [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:15]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:16]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { @@ -1986,20 +1990,22 @@ static void ggml_metal_encode_node( [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:19]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:20]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { @@ -2048,6 +2054,9 @@ static void ggml_metal_encode_node( GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne13 == 1); + // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel // ne20 = n_used_experts diff --git a/src/ggml-metal.metal b/src/ggml-metal.metal index 71b58be1f..defde6246 100644 --- a/src/ggml-metal.metal +++ b/src/ggml-metal.metal @@ -777,10 +777,10 @@ kernel void kernel_ssm_conv_f32( const int64_t i3 = tgpig.z; const int64_t nc = ne10; - const int64_t ncs = ne00; - const int64_t nr = ne01; - const int64_t n_t = ne1; - const int64_t n_s = ne2; + //const int64_t ncs = ne00; + //const int64_t nr = ne01; + //const int64_t n_t = ne1; + //const int64_t n_s = ne2; device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02); device const float * c = (device const float *) ((device const char *) src1 + ir*nb11); @@ -834,9 +834,9 @@ kernel void kernel_ssm_scan_f32( const int64_t i3 = tgpig.y; const int64_t nc = d_state; - const int64_t nr = d_inner; + //const int64_t nr = d_inner; const int64_t n_t = n_seq_tokens; - const int64_t n_s = n_seqs; + //const int64_t n_s = n_seqs; for (int64_t i2 = 0; i2 < n_t; ++i2) { device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02); @@ -1064,17 +1064,18 @@ kernel void kernel_group_norm( inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; - float2 acc = 0.f; + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2); - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); + for (int i = 0; i < 8; i += 2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0); + acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000); } - return d * (sumy * -8.f + acc[0] + acc[1]); + + return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]); } // function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) @@ -1085,17 +1086,18 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre float d = qb_curr->d; float m = qb_curr->m; - float2 acc = 0.f; + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2); for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0); + acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000); } - return d * (acc[0] + acc[1]) + sumy * m; + + return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; } // function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) @@ -1105,18 +1107,19 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; - float2 acc = 0.f; + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); const uint32_t qh = *((device const uint32_t *)qb_curr->qh); for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) - + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); - acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) - + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)); + acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)); + acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); } - return d * (sumy * -16.f + acc[0] + acc[1]); + + return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]); } // function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) @@ -1127,18 +1130,19 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre float d = qb_curr->d; float m = qb_curr->m; - float2 acc = 0.f; + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); const uint32_t qh = *((device const uint32_t *)qb_curr->qh); for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) - + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); - acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) - + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)); + acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)); + acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); } - return d * (acc[0] + acc[1]) + sumy * m; + + return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; } // putting them in the kernel cause a significant performance penalty @@ -1156,14 +1160,22 @@ void mul_vec_q_n_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, uint r3, threadgroup int8_t * shared_values, - uint3 tgpig, uint tiisg, uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK4_0; const int r0 = tgpig.x; @@ -1175,10 +1187,19 @@ void mul_vec_q_n_f32_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q_type * x = (device const block_q_type *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); + + // pointers to src0 rows + device const block_q_type * ax[nr]; + for (int row = 0; row < nr; ++row) { + const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); + } float yl[16]; // src1 vector cache float sumf[nr] = {0.f}; @@ -1190,19 +1211,22 @@ void mul_vec_q_n_f32_impl( // each thread in a SIMD group deals with half a block. for (int ib = ix; ib < nb; ib += nw/2) { - float sumy = 0; + float sumy[2] = { 0.f, 0.f }; + +#pragma unroll for (int i = 0; i < 8; i += 2) { - sumy += yb[i] + yb[i+1]; - yl[i+0] = yb[i+ 0]; - yl[i+1] = yb[i+ 1]/256.f; + sumy[0] += yb[i + 0] + yb[i + 1]; + yl[i + 0] = yb[i + 0]; + yl[i + 1] = yb[i + 1]/256.f; - sumy += yb[i+16] + yb[i+17]; - yl[i+8] = yb[i+16]/16.f; - yl[i+9] = yb[i+17]/4096.f; + sumy[1] += yb[i + 16] + yb[i + 17]; + yl[i + 8] = yb[i + 16]/16.f; + yl[i + 9] = yb[i + 17]/4096.f; } +#pragma unroll for (int row = 0; row < nr; row++) { - sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); + sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); } yb += QK4_0 * 16; @@ -1226,12 +1250,14 @@ kernel void kernel_mul_mv_q4_0_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1239,7 +1265,7 @@ kernel void kernel_mul_mv_q4_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -1252,12 +1278,14 @@ kernel void kernel_mul_mv_q4_1_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1265,7 +1293,7 @@ kernel void kernel_mul_mv_q4_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -1278,12 +1306,14 @@ kernel void kernel_mul_mv_q5_0_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1291,7 +1321,7 @@ kernel void kernel_mul_mv_q5_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -1304,12 +1334,14 @@ kernel void kernel_mul_mv_q5_1_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1317,7 +1349,7 @@ kernel void kernel_mul_mv_q5_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } @@ -1330,8 +1362,14 @@ void kernel_mul_mv_q8_0_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -1354,10 +1392,19 @@ void kernel_mul_mv_q8_0_f32_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); + + // pointers to src0 rows + device const block_q8_0 * ax[nr]; + for (int row = 0; row < nr; ++row) { + const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); + } float yl[NB_Q8_0]; float sumf[nr]={0.f}; @@ -1374,12 +1421,12 @@ void kernel_mul_mv_q8_0_f32_impl( } for (int row = 0; row < nr; row++) { - device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; + device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il; float sumq = 0.f; for (int iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; } - sumf[row] += sumq*x[ib+row*nb].d; + sumf[row] += sumq*ax[row][ib].d; } yb += NB_Q8_0 * nw; @@ -1404,12 +1451,14 @@ kernel void kernel_mul_mv_q8_0_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1417,7 +1466,7 @@ kernel void kernel_mul_mv_q8_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } #define N_MV_T_T 4 @@ -1433,12 +1482,14 @@ void kernel_mul_mv_impl( uint64_t nb00, uint64_t nb01, uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne11, int64_t ne12, uint64_t nb10, uint64_t nb11, uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -1452,7 +1503,7 @@ void kernel_mul_mv_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; device const T0 * x = (device const T0 *) (src0 + offset0); @@ -1463,7 +1514,9 @@ void kernel_mul_mv_impl( break; } - device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const T1 * y = (device const T1 *) (src1 + offset1); float sumf = 0; for (int i = tiisg; i < ne00; i += 32) { @@ -1483,7 +1536,9 @@ void kernel_mul_mv_impl( break; } - device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const T1 * y = (device const T1 *) (src1 + offset1); device const T14 * y4 = (device const T14 *) y; float sumf = 0; @@ -1511,12 +1566,14 @@ kernel void kernel_mul_mv( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1533,12 +1590,14 @@ kernel void kernel_mul_mv( nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, + nb13, ne0, ne1, r2, @@ -1564,12 +1623,14 @@ kernel void kernel_mul_mv_1row( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1584,10 +1645,11 @@ kernel void kernel_mul_mv_1row( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; device const T * x = (device const T *) (src0 + offset0); - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float * y = (device const float *) (src1 + offset1); float sumf = 0; if (ne00 < 128) { @@ -1631,12 +1693,14 @@ kernel void kernel_mul_mv_l4( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1651,12 +1715,14 @@ kernel void kernel_mul_mv_l4( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; device const T4 * x4 = (device const T4 *) (src0 + offset0); for (int r1 = 0; r1 < nrows; ++r1) { - device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const float4 * y4 = (device const float4 *) (src1 + offset1); float sumf = 0; for (int i = tiisg; i < ne00/4; i += 32) { @@ -3416,8 +3482,14 @@ void kernel_mul_mv_q2_K_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -3433,21 +3505,19 @@ void kernel_mul_mv_q2_K_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; - const int step = sizeof(block_q2_K) * nb; - const int ix = tiisg/8; // 0...3 const int it = tiisg%8; // 0...7 const int iq = it/4; // 0 or 1 @@ -3492,9 +3562,9 @@ void kernel_mul_mv_q2_K_f32_impl( (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); - qs += step/2; - sc += step; - dh += step/2; + qs += nb01/2; + sc += nb01; + dh += nb01/2; } y4 += 4 * QK_K; @@ -3519,12 +3589,14 @@ kernel void kernel_mul_mv_q2_K_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -3533,7 +3605,7 @@ kernel void kernel_mul_mv_q2_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q3_K_f32_impl( @@ -3543,8 +3615,14 @@ void kernel_mul_mv_q3_K_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -3565,10 +3643,11 @@ void kernel_mul_mv_q3_K_f32_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0); + device const float * yy = (device const float *) ((device char *) src1 + offset1); float yl[32]; @@ -3608,8 +3687,6 @@ void kernel_mul_mv_q3_K_f32_impl( const int q_offset = 32*ip + l0; const int y_offset = 128*ip + 32*il + l0; - const int step = sizeof(block_q3_K) * nb / 2; - device const float * y1 = yy + ix*QK_K + y_offset; uint32_t scales32, aux32; @@ -3619,7 +3696,6 @@ void kernel_mul_mv_q3_K_f32_impl( float sumf1[2] = {0.f}; float sumf2[2] = {0.f}; for (int i = ix; i < nb; i += 4) { - for (int l = 0; l < 8; ++l) { yl[l+ 0] = y1[l+ 0]; yl[l+ 8] = y1[l+16]; @@ -3633,7 +3709,6 @@ void kernel_mul_mv_q3_K_f32_impl( device const half * dh = &x[i].d; for (int row = 0; row < 2; ++row) { - const float d_all = (float)dh[0]; scales16[0] = a[4]; @@ -3673,15 +3748,13 @@ void kernel_mul_mv_q3_K_f32_impl( sumf1[row] += d1 * (scales[1] - 32); sumf2[row] += d2 * (scales[3] - 32); - q += step; - h += step; - a += step; - dh += step; - + q += nb01/2; + h += nb01/2; + a += nb01/2; + dh += nb01/2; } y1 += 4 * QK_K; - } for (int row = 0; row < 2; ++row) { @@ -3706,12 +3779,14 @@ kernel void kernel_mul_mv_q3_K_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -3720,7 +3795,7 @@ kernel void kernel_mul_mv_q3_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q4_K_f32_impl( @@ -3730,8 +3805,14 @@ void kernel_mul_mv_q4_K_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -3756,29 +3837,26 @@ void kernel_mul_mv_q4_K_f32_impl( const int im = tgpig.z; //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int first_row = r0 * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[16]; float yh[16]; float sumf[N_DST]={0.f}, all_sum; - const int step = sizeof(block_q4_K) * nb / 2; - device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; uint16_t sc16[4]; thread const uint8_t * sc8 = (thread const uint8_t *)sc16; for (int ib = ix; ib < nb; ib += 4) { - float4 sumy = {0.f, 0.f, 0.f, 0.f}; for (int i = 0; i < 8; ++i) { yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; @@ -3792,7 +3870,6 @@ void kernel_mul_mv_q4_K_f32_impl( device const half * dh = &x[ib].d; for (int row = 0; row < N_DST; row++) { - sc16[0] = sc[0] & kmask1; sc16[1] = sc[2] & kmask1; sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); @@ -3821,9 +3898,9 @@ void kernel_mul_mv_q4_K_f32_impl( (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - q1 += step; - sc += step; - dh += step; + q1 += nb01/2; + sc += nb01/2; + dh += nb01/2; } y4 += 4 * QK_K; @@ -3848,12 +3925,14 @@ kernel void kernel_mul_mv_q4_K_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -3862,7 +3941,7 @@ kernel void kernel_mul_mv_q4_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q5_K_f32_impl( @@ -3872,8 +3951,14 @@ void kernel_mul_mv_q5_K_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -3894,15 +3979,14 @@ void kernel_mul_mv_q5_K_f32_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0); + device const float * yy = (device const float *) ((device char *) src1 + offset1); float sumf[2]={0.f}; - const int step = sizeof(block_q5_K) * nb; - float yl[16], yh[16]; const uint16_t kmask1 = 0x3f3f; @@ -3930,7 +4014,6 @@ void kernel_mul_mv_q5_K_f32_impl( device const float * y1 = yy + ix*QK_K + y_offset; for (int i = ix; i < nb; i += 4) { - device const uint8_t * q1 = x[i].qs + q_offset; device const uint8_t * qh = x[i].qh + l0; device const half * dh = &x[i].d; @@ -3946,7 +4029,6 @@ void kernel_mul_mv_q5_K_f32_impl( } for (int row = 0; row < 2; ++row) { - device const uint8_t * q2 = q1 + 64; sc16[0] = a[0] & kmask1; @@ -3975,15 +4057,13 @@ void kernel_mul_mv_q5_K_f32_impl( sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - q1 += step; - qh += step; - dh += step/2; - a += step/2; - + q1 += nb01; + qh += nb01; + dh += nb01/2; + a += nb01/2; } y1 += 4 * QK_K; - } for (int row = 0; row < 2; ++row) { @@ -4005,12 +4085,14 @@ kernel void kernel_mul_mv_q5_K_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4019,7 +4101,7 @@ kernel void kernel_mul_mv_q5_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q6_K_f32_impl( @@ -4029,8 +4111,14 @@ void kernel_mul_mv_q6_K_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4056,10 +4144,11 @@ void kernel_mul_mv_q6_K_f32_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0); + device const float * yy = (device const float *) ((device char *) src1 + offset1); float sumf = 0; @@ -4115,12 +4204,14 @@ kernel void kernel_mul_mv_q6_K_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4129,7 +4220,7 @@ kernel void kernel_mul_mv_q6_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit @@ -4141,8 +4232,14 @@ void kernel_mul_mv_iq2_xxs_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4158,15 +4255,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4219,8 +4316,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl( } sumf[row] += d * sum; - dh += nb*sizeof(block_iq2_xxs)/2; - q2 += nb*sizeof(block_iq2_xxs)/2; + dh += nb01/2; + q2 += nb01/2; } y4 += 32 * 32; @@ -4245,12 +4342,14 @@ kernel void kernel_mul_mv_iq2_xxs_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4260,7 +4359,7 @@ kernel void kernel_mul_mv_iq2_xxs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq2_xs_f32_impl( @@ -4270,8 +4369,14 @@ void kernel_mul_mv_iq2_xs_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4287,15 +4392,15 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4357,9 +4462,9 @@ void kernel_mul_mv_iq2_xs_f32_impl( } sumf[row] += d1 * sum1 + d2 * sum2; - dh += nb*sizeof(block_iq2_xs)/2; - q2 += nb*sizeof(block_iq2_xs)/2; - sc += nb*sizeof(block_iq2_xs); + dh += nb01/2; + q2 += nb01/2; + sc += nb01; } y4 += 32 * 32; @@ -4384,12 +4489,14 @@ kernel void kernel_mul_mv_iq2_xs_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4399,7 +4506,7 @@ kernel void kernel_mul_mv_iq2_xs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq3_xxs_f32_impl( @@ -4409,8 +4516,14 @@ void kernel_mul_mv_iq3_xxs_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4426,15 +4539,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4489,9 +4602,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl( } sumf[row] += d * (sum[0] + sum[1]); - dh += nb*sizeof(block_iq3_xxs)/2; - q3 += nb*sizeof(block_iq3_xxs); - gas += nb*sizeof(block_iq3_xxs)/2; + dh += nb01/2; + q3 += nb01; + gas += nb01/2; } y4 += 32 * 32; @@ -4516,12 +4629,14 @@ kernel void kernel_mul_mv_iq3_xxs_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4531,7 +4646,7 @@ kernel void kernel_mul_mv_iq3_xxs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq3_s_f32_impl( @@ -4541,8 +4656,14 @@ void kernel_mul_mv_iq3_s_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4558,15 +4679,15 @@ void kernel_mul_mv_iq3_s_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4619,11 +4740,11 @@ void kernel_mul_mv_iq3_s_f32_impl( } sumf[row] += d * (sum[0] + sum[1]); - dh += nb*sizeof(block_iq3_s)/2; - qs += nb*sizeof(block_iq3_s); - qh += nb*sizeof(block_iq3_s); - sc += nb*sizeof(block_iq3_s); - signs += nb*sizeof(block_iq3_s); + dh += nb01/2; + qs += nb01; + qh += nb01; + sc += nb01; + signs += nb01; } y4 += 32 * 32; @@ -4648,12 +4769,14 @@ kernel void kernel_mul_mv_iq3_s_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4663,7 +4786,7 @@ kernel void kernel_mul_mv_iq3_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq2_s_f32_impl( @@ -4673,8 +4796,14 @@ void kernel_mul_mv_iq2_s_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4690,15 +4819,15 @@ void kernel_mul_mv_iq2_s_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4752,11 +4881,11 @@ void kernel_mul_mv_iq2_s_f32_impl( } sumf[row] += d1 * sum[0] + d2 * sum[1]; - dh += nb*sizeof(block_iq2_s)/2; - qs += nb*sizeof(block_iq2_s); - qh += nb*sizeof(block_iq2_s); - sc += nb*sizeof(block_iq2_s); - signs += nb*sizeof(block_iq2_s); + dh += nb01/2; + qs += nb01; + qh += nb01; + sc += nb01; + signs += nb01; } y4 += 32 * 32; @@ -4781,12 +4910,14 @@ kernel void kernel_mul_mv_iq2_s_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4796,7 +4927,7 @@ kernel void kernel_mul_mv_iq2_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq1_s_f32_impl( @@ -4806,8 +4937,14 @@ void kernel_mul_mv_iq1_s_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4823,14 +4960,15 @@ void kernel_mul_mv_iq1_s_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4873,9 +5011,9 @@ void kernel_mul_mv_iq1_s_f32_impl( } sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); - dh += nb*sizeof(block_iq1_s)/2; - qs += nb*sizeof(block_iq1_s); - qh += nb*sizeof(block_iq1_s)/2; + dh += nb01/2; + qs += nb01; + qh += nb01/2; } y4 += 32 * 32; @@ -4896,8 +5034,14 @@ void kernel_mul_mv_iq1_m_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4913,14 +5057,15 @@ void kernel_mul_mv_iq1_m_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4972,9 +5117,9 @@ void kernel_mul_mv_iq1_m_f32_impl( sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); - sc += nb*sizeof(block_iq1_m)/2; - qs += nb*sizeof(block_iq1_m); - qh += nb*sizeof(block_iq1_m); + sc += nb01/2; + qs += nb01; + qh += nb01; } y4 += 32 * 32; @@ -4995,8 +5140,14 @@ void kernel_mul_mv_iq4_nl_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -5012,14 +5163,15 @@ void kernel_mul_mv_iq4_nl_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * 2 + sgitg) * 2; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); const int ix = tiisg/2; // 0...15 const int it = tiisg%2; // 0 or 1 @@ -5089,8 +5241,14 @@ void kernel_mul_mv_iq4_xs_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -5106,14 +5264,15 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * 2 + sgitg) * 2; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); const int ix = tiisg/16; // 0 or 1 const int it = tiisg%16; // 0...15 @@ -5188,12 +5347,14 @@ kernel void kernel_mul_mv_iq1_s_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -5202,7 +5363,7 @@ kernel void kernel_mul_mv_iq1_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq1_m_f32")]] @@ -5216,12 +5377,14 @@ kernel void kernel_mul_mv_iq1_m_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -5230,7 +5393,7 @@ kernel void kernel_mul_mv_iq1_m_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_nl_f32")]] @@ -5244,12 +5407,14 @@ kernel void kernel_mul_mv_iq4_nl_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -5259,7 +5424,7 @@ kernel void kernel_mul_mv_iq4_nl_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_xs_f32")]] @@ -5273,12 +5438,14 @@ kernel void kernel_mul_mv_iq4_xs_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -5288,7 +5455,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } //============================= templates and their specializations ============================= @@ -5833,10 +6000,12 @@ kernel void kernel_mul_mm(device const uchar * src0, constant int64_t & ne02, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -5873,12 +6042,13 @@ kernel void kernel_mul_mm(device const uchar * src0, const uint i12 = im%ne12; const uint i13 = im/ne12; - uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); + uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03; ushort offset1 = il/nl; device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; device const float * y = (device const float *)(src1 - + nb12 * im + + nb13 * i13 + + nb12 * i12 + nb11 * (r1 * BLOCK_SIZE_N + thread_col) + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); @@ -6257,12 +6427,14 @@ typedef void (kernel_mul_mv_impl_t)( uint64_t nb00, uint64_t nb01, uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne11, int64_t ne12, uint64_t nb10, uint64_t nb11, uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -6277,8 +6449,14 @@ typedef void (kernel_mul_mv2_impl_t)( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -6299,6 +6477,7 @@ void mmv_fn( uint64_t nb00, uint64_t nb01, uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne11, int64_t ne12, @@ -6306,6 +6485,7 @@ void mmv_fn( uint64_t nb10, uint64_t nb11, uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint64_t nb1, @@ -6316,7 +6496,7 @@ void mmv_fn( uint tiitg, uint tiisg, uint sgitg) { - impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg); + impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,nb03,ne10,ne11,ne12,nb10,nb11,nb12,nb13,ne0,ne1,r2,r3,tgpig,tiisg); } template @@ -6330,6 +6510,7 @@ void mmv_fn( uint64_t nb00, uint64_t nb01, uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne11, int64_t ne12, @@ -6337,6 +6518,7 @@ void mmv_fn( uint64_t nb10, uint64_t nb11, uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint64_t nb1, @@ -6347,7 +6529,7 @@ void mmv_fn( uint tiitg, uint tiisg, uint sgitg) { - impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); + impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); } typedef decltype(mmv_fn>) mul_mv_impl_fn_t; @@ -6396,8 +6578,8 @@ kernel void kernel_mul_mv_id( const int64_t i2 = i12; device const char * src0_cur = src0s + i02*nb02; - device const char * src1_cur = src1 + i11*nb11 + i12*nb12; - device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; + device const char * src1_cur = src1 + i11*nb11 + i12*nb12; + device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; impl_fn( /* src0 */ src0_cur, @@ -6405,19 +6587,21 @@ kernel void kernel_mul_mv_id( /* dst */ dst_cur, /* ne00 */ ne00, /* ne01 */ ne01, - /* ne02 */ 1,//ne02, + /* ne02 */ 1, // ne02, /* nb00 */ nb00, /* nb01 */ nb01, /* nb02 */ nb02, + /* nb03 */ nb02, // ne02 == 1 /* ne10 */ ne10, - /* ne11 */ 1,//ne11, - /* ne12 */ 1,//ne12, - /* ne13 */ 1,//ne13, + /* ne11 */ 1, // ne11, + /* ne12 */ 1, // ne12, + /* ne13 */ 1, // ne13, /* nb10 */ nb10, /* nb11 */ nb11, /* nb12 */ nb12, + /* ne13 */ nb12, // ne12 == 1 /* ne0 */ ne0, - /* ne1 */ 1,//ne1, + /* ne1 */ 1, // ne1, /* nb1 */ nb1, /* r2 */ 1, /* r3 */ 1, diff --git a/src/ggml.c b/src/ggml.c index 1741d3338..66df9a9c1 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -3464,7 +3464,7 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) { size_t ggml_nbytes(const struct ggml_tensor * tensor) { size_t nbytes; - size_t blck_size = ggml_blck_size(tensor->type); + const size_t blck_size = ggml_blck_size(tensor->type); if (blck_size == 1) { nbytes = ggml_type_size(tensor->type); for (int i = 0; i < GGML_MAX_DIMS; ++i) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 7e769a91a..2e3ad79f0 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1650,11 +1650,12 @@ struct test_mul_mat : public test_case { const int64_t m; const int64_t n; const int64_t k; - const std::array bs; // dims 3 and 4 - const std::array nr; // repeat in dims 3 and 4 + const std::array bs; // dims 3 and 4 + const std::array nr; // repeat in dims 3 and 4 + const std::array per; // permutation of dimensions std::string vars() override { - return VARS_TO_STR7(type_a, type_b, m, n, k, bs, nr); + return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, per); } double max_nmse_err() override { @@ -1669,17 +1670,44 @@ struct test_mul_mat : public test_case { test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, int64_t m = 32, int64_t n = 32, int64_t k = 32, std::array bs = {10, 10}, - std::array nr = {2, 2}) - : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr) {} + std::array nr = {2, 2}, + std::array per = {0, 1, 2, 3}) + : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {} ggml_tensor * build_graph(ggml_context * ctx) override { // C^T = A * B^T: (k, m) * (k, n) => (m, n) - ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0] , bs[1]); - ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); - ggml_set_param(ctx, a); - ggml_set_param(ctx, b); - ggml_set_name(a, "a"); - ggml_set_name(b, "b"); + ggml_tensor * a; + ggml_tensor * b; + + const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3); + if (npermuted > 0) { + GGML_ASSERT(npermuted == 2); + GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0); + GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0); + + // Create tensors with the permuted dimensions, then permute them back to the dimensions given by m,n,k. + const int64_t ne_a[4] = {k, m, bs[0], bs[1]}; + const int64_t ne_b[4] = {k, n, bs[0]*nr[0], bs[1]*nr[1]}; + + a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]); + b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]); + ggml_set_param(ctx, a); + ggml_set_param(ctx, b); + ggml_set_name(a, "a"); + ggml_set_name(b, "b"); + + a = ggml_permute(ctx, a, per[0], per[1], per[2], per[3]); + b = ggml_permute(ctx, b, per[0], per[1], per[2], per[3]); + ggml_set_name(a, "a_permuted"); + ggml_set_name(b, "b_permuted"); + } else { + a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]); + b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); + ggml_set_param(ctx, a); + ggml_set_param(ctx, b); + ggml_set_name(a, "a"); + ggml_set_name(b, "b"); + } ggml_tensor * out = ggml_mul_mat(ctx, a, b); ggml_set_name(out, "out"); @@ -3478,13 +3506,14 @@ static std::vector> make_test_cases_eval() { #if 1 for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2})); + // test cases without permutation + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1})); @@ -3493,6 +3522,19 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2})); + + // test cases with permutation + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 1, 3, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 3, 2, 1})); + + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 1, 3, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 3, 2, 1})); + + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1})); } } for (ggml_type type_a : other_types) {