From 5b14955c3a6ced9abe13be804f6a3164937e944c Mon Sep 17 00:00:00 2001 From: ro99 Date: Tue, 10 Sep 2024 18:49:06 -0300 Subject: [PATCH 01/15] ext2 kernels --- .gitignore | 3 +- mistralrs-quant/kernels/exl2/compat.cuh | 59 ++ mistralrs-quant/kernels/exl2/matrix_view.cuh | 124 ++++ mistralrs-quant/kernels/exl2/q_gemm_exl2.cu | 93 +++ .../kernels/exl2/q_gemm_kernel.cuh | 556 ++++++++++++++++++ mistralrs-quant/kernels/exl2/q_matrix.cu | 435 ++++++++++++++ mistralrs-quant/kernels/exl2/q_matrix.cuh | 72 +++ mistralrs-quant/kernels/exl2/quant/qdq_2.cuh | 78 +++ mistralrs-quant/kernels/exl2/quant/qdq_3.cuh | 138 +++++ mistralrs-quant/kernels/exl2/quant/qdq_4.cuh | 141 +++++ mistralrs-quant/kernels/exl2/quant/qdq_5.cuh | 170 ++++++ mistralrs-quant/kernels/exl2/quant/qdq_6.cuh | 36 ++ mistralrs-quant/kernels/exl2/quant/qdq_8.cuh | 32 + .../kernels/exl2/quant/qdq_util.cuh | 56 ++ 14 files changed, 1992 insertions(+), 1 deletion(-) create mode 100644 mistralrs-quant/kernels/exl2/compat.cuh create mode 100644 mistralrs-quant/kernels/exl2/matrix_view.cuh create mode 100644 mistralrs-quant/kernels/exl2/q_gemm_exl2.cu create mode 100644 mistralrs-quant/kernels/exl2/q_gemm_kernel.cuh create mode 100644 mistralrs-quant/kernels/exl2/q_matrix.cu create mode 100644 mistralrs-quant/kernels/exl2/q_matrix.cuh create mode 100644 mistralrs-quant/kernels/exl2/quant/qdq_2.cuh create mode 100644 mistralrs-quant/kernels/exl2/quant/qdq_3.cuh create mode 100644 mistralrs-quant/kernels/exl2/quant/qdq_4.cuh create mode 100644 mistralrs-quant/kernels/exl2/quant/qdq_5.cuh create mode 100644 mistralrs-quant/kernels/exl2/quant/qdq_6.cuh create mode 100644 mistralrs-quant/kernels/exl2/quant/qdq_8.cuh create mode 100644 mistralrs-quant/kernels/exl2/quant/qdq_util.cuh diff --git a/.gitignore b/.gitignore index 9a2aada80..83b45cf26 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /target .ruff_cache .vscode -*.a \ No newline at end of file +*.a +.DS_Store diff --git 
a/mistralrs-quant/kernels/exl2/compat.cuh b/mistralrs-quant/kernels/exl2/compat.cuh new file mode 100644 index 000000000..9e7851c5c --- /dev/null +++ b/mistralrs-quant/kernels/exl2/compat.cuh @@ -0,0 +1,59 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _compat_cuh +#define _compat_cuh + +// atomicAdd emulation for a single half, built on a 32-bit atomicCAS loop, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) +{ + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); /* step back 2 bytes when bit 1 of the address is set, so the CAS targets the unsigned int that contains the half */ + unsigned int old = *address_as_ui; + unsigned int assumed; + + do + { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); /* current half: upper 16 bits if bit 1 of the address is set, else the lower 16 */ + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; /* splice the sum back in, keeping the other 16 bits as they were read */ + old = atomicCAS(address_as_ui, assumed, old); + } + while (assumed != old); /* repeat while the word returned by atomicCAS differs from the snapshot the sum was computed from */ +} + +// atomicAdd emulation for half2: the pair is reinterpreted as one unsigned int and updated with an atomicCAS loop + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) +{ + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do + { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } + while (assumed != old); /* same retry rule as atomicAdd_half */ +} + +// Expose the emulations as atomicAdd overloads: half when __CUDA_ARCH__ < 700, half2 when __CUDA_ARCH__ < 600, and both when USE_ROCM is defined + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) +#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } + +#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } +#endif + +#endif +#endif + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/matrix_view.cuh b/mistralrs-quant/kernels/exl2/matrix_view.cuh new file mode 100644 index 000000000..dd0aebf52 --- /dev/null +++
b/mistralrs-quant/kernels/exl2/matrix_view.cuh @@ -0,0 +1,124 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _matrix_view_cuh +#define _matrix_view_cuh + +#include +#include + +#include "quant/qdq_util.cuh" + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } + + __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const + { + half2* ptr = (half2*) item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const + { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const + { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + 
} +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) + { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2* ptr = (half2*) item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const + { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const + { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 
4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } +}; + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu b/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu new file mode 100644 index 000000000..10ed88271 --- /dev/null +++ b/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu @@ -0,0 +1,93 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#include + +#include + +#include "q_matrix.cuh" +#include "matrix_view.cuh" +#include "quant/qdq_2.cuh" +#include "quant/qdq_3.cuh" +#include "quant/qdq_4.cuh" +#include "quant/qdq_5.cuh" +#include "quant/qdq_6.cuh" +#include "quant/qdq_8.cuh" +#include "q_gemm_kernel.cuh" + +#define MAX_Q_GEMM_ROWS 32 +#define EXL2_BLOCK_KN_SIZE 64 +#define EXL2_BLOCK_M_SIZE_MAX 8 +#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32) +#if defined(USE_ROCM) +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm( + hipblasHandle_t handle, hipblasOperation_t transA, + hipblasOperation_t transB, int m, int n, int k, const half* alpha, + const half* AP, int lda, const half* BP, int ldb, const half* beta, + half* CP, int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); +} + #define hipblasHgemm __compat_hipblasHgemm +#endif +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +void gemm_half_q_half_cuda_part(const half* a, QMatrix* b, half* c, int size_m, + int size_n, int size_k, int m_count, + bool clear) { + { + dim3 blockDim, gridDim; + blockDim.x = EXL2_BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(b->height, EXL2_BLOCK_KN_SIZE); + + fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count); + + kernel<<>>( + a, b->cuda_q_weight, b->cuda_q_scale, b->cuda_q_scale_max, c, 
size_m, + size_n, size_k, b->height, b->groups, b->cuda_q_group_map, + b->cuda_q_perm, b->rows_8, b->rows_6, b->rows_5, b->rows_4, b->rows_3, + b->rows_2, clear); + } +} + +void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, + QMatrix* b, half* c, int size_m, int size_n, + int size_k, bool clear, half* temp_dq) { + if (size_m > MAX_Q_GEMM_ROWS) { + // Reconstruct FP16 matrix, then cuBLAS + b->reconstruct(temp_dq); + + // cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH); + + const half alpha = __float2half(1.0f); + const half beta = clear ? __float2half(0.0f) : __float2half(1.0f); + cublasHgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, size_n, size_m, size_k, + &alpha, temp_dq, size_n, a, size_k, &beta, c, size_n); + } else { + // Quantized matmul + + int block_m_size_max = EXL2_BLOCK_M_SIZE_MAX; + int max_chunks = size_m / block_m_size_max; + int last_chunk = max_chunks * block_m_size_max; + int last_chunk_size = size_m - last_chunk; + + if (max_chunks) { + gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, + block_m_size_max, clear); + } + + if (last_chunk_size) { + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, + c + last_chunk * size_n, last_chunk_size, + size_n, size_k, last_chunk_size, clear); + } + } +} \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/q_gemm_kernel.cuh b/mistralrs-quant/kernels/exl2/q_gemm_kernel.cuh new file mode 100644 index 000000000..6612dabd1 --- /dev/null +++ b/mistralrs-quant/kernels/exl2/q_gemm_kernel.cuh @@ -0,0 +1,556 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#include "compat.cuh" + +#define MAX_Q_GEMM_WEIGHTS 4 +#define EXL2_BLOCK_KN_SIZE 64 +#define EXL2_BLOCK_M_SIZE_MAX 8 +#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32) + +__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int 
i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + 
__half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h) +{ + // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127 + + float result = {}; + #pragma unroll + for (int i = 0; i < 4; i++) + { + half2 w01 = dq[i]; + float w0 = __low2float(w01); + float w1 = __high2float(w01); + float x0 = __half2float(*a_ptr++); + float x1 = __half2float(*a_ptr++); + result = fma(w0, x0, result); + result = fma(w1, x1, result); + } + float qs = __half2float(qs_h); + result *= qs; + half result_h = __float2half_rn(result); + return __hadd(result_h, g_result); +} + +__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + + +typedef void (*fp_gemm_half_q_half_kernel) +( + const half*, + const uint32_t*, + const uint32_t*, + const half*, + half*, + const int, + const int, + const int, + const int, + const int, + const uint16_t*, + const uint16_t*, + const int, + const int, + const int, + const int, + const int, + const int, + const bool +); + +template +__global__ void gemm_half_q_half_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* 
__restrict__ b_q_scale, + const half* __restrict__ b_q_scale_max, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int height, + const int groups, + const uint16_t* __restrict__ b_q_group_map, + const uint16_t* __restrict__ b_q_perm, + const int rows_8, + const int rows_6, + const int rows_5, + const int rows_4, + const int rows_3, + const int rows_2, + const bool clear +) +{ + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n); + + int t = threadIdx.x; + + // Block + + int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE; + + int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, height); + int n = offset_n + t * 4; + + // Read weights + + half_uint16 weights[MAX_Q_GEMM_WEIGHTS]; + + // Preload block_a + + __shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) + { + for (int m = 0; m < m_count; ++m) + { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + half a0 = a_ptr[b_q_perm[offset_k + t]]; +// half a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Clear + + if (n >= size_n) return; + + if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0) + { + for (int m = 0; m < m_count; m++) + *((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + + //int group = offset_k / groupsize; + int group = b_q_group_map[offset_k * 2]; + +// if (offset_m == 0 && t == 0) +// DBGI2(offset_k, group); + + // Preload scales + + half scales[EXL2_MAX_GROUPS_IN_BLOCK][4]; + + //int groups_in_block = DIVIDE((end_k - offset_k), groupsize); + int temp_k = offset_k; + for (int g = 0; temp_k < end_k; g++) + { + int qscales[4]; + 
b_q_scale_.item4(qscales, group + g, n); + qscales[0]++; + qscales[1]++; + qscales[2]++; + qscales[3]++; + half maxscale = b_q_scale_max[group + g]; + scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale); + scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale); + scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale); + scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale); + temp_k += b_q_group_map[temp_k * 2 + 1]; + } + + // a, b offset + + int pre_rows_8 = min(rows_8, offset_k); + int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0; + int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0; + int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0; + int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0; + int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0; + int qk = 0; + qk += pre_rows_8 / 32 * 8; + qk += pre_rows_6 / 32 * 6; + qk += pre_rows_5 / 32 * 5; + qk += pre_rows_4 / 32 * 4; + qk += pre_rows_3 / 32 * 3; + qk += pre_rows_2 / 32 * 2; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = EXL2_BLOCK_KN_SIZE; + + // Initial group + + int scales_idx = 0; + half qs_h0 = scales[scales_idx][0]; + half qs_h1 = scales[scales_idx][1]; + half qs_h2 = scales[scales_idx][2]; + half qs_h3 = scales[scales_idx][3]; + int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1]; + + // Column result + + half block_c[m_count][4] = {}; + + // Dequantize groups + + int k = offset_k; + + while (k < rows_8 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + int4 load_int4[2]; + load_int4[0] = *((int4*) b_ptr); 
b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 8; + } + k += 32; + } + + while (k < rows_6 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 2; j++) + { + int4 load_int4[3]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][8]; + dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n); + dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n); + dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n); + dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 16; + } + k += 32; + } 
+ + while (k < rows_5 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 1; j++) + { + int4 load_int4[5]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[3] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[4] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][16]; + dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n); + dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n); + dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n); + dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 32; + } + + k += 32; + } + + while (k < rows_4 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + int4 load_int4[1]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][4]; + dequant_4bit_8(load_int4[0].x, dq[0], size_n); + dequant_4bit_8(load_int4[0].y, dq[1], 
size_n); + dequant_4bit_8(load_int4[0].z, dq[2], size_n); + dequant_4bit_8(load_int4[0].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 8; + } + k += 32; + } + + while (k < rows_3 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 1; j++) + { + int4 load_int4[3]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 32; + } + k += 32; + } + + while (k < rows_2 && k < end_k) + { + if (k == nextgroup) + { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += 
b_q_group_map[k * 2 + 1]; + } + + #pragma unroll + for (int j = 0; j < 1; j++) + { + int4 load_int4[1]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][8]; + dequant_2bit_16(load_int4[0].x, dq[0], size_n); + dequant_2bit_16(load_int4[0].y, dq[1], size_n); + dequant_2bit_16(load_int4[0].z, dq[2], size_n); + dequant_2bit_16(load_int4[0].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + + a_ptr += 16; + } + k += 16; + } + + // Accumulate column sums in c + + for (int m = 0; m < m_count; m++) + { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + + atomicAdd(out , result01); + atomicAdd(out + 1, result23); +// *out = result01; +// *(out + 1) = result23; + } +} + +struct map_m_count_exl2 { + static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count) + { + #if EXL2_BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_kernel<1>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_kernel<2>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_kernel<3>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_kernel<4>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_kernel<5>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_kernel<6>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_kernel<7>; + #endif + #if EXL2_BLOCK_M_SIZE_MAX >= 8 
+ if (m_count == 8) return gemm_half_q_half_kernel<8>; + #endif + return NULL; + } +}; + +fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count) +{ + return map_m_count_exl2::pick_gemm_half_q_half_kernel(m_count); +} \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/q_matrix.cu b/mistralrs-quant/kernels/exl2/q_matrix.cu new file mode 100644 index 000000000..cab969a8e --- /dev/null +++ b/mistralrs-quant/kernels/exl2/q_matrix.cu @@ -0,0 +1,435 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ + +#include + +#include "q_matrix.cuh" +#include "matrix_view.cuh" + +#include "quant/qdq_2.cuh" +#include "quant/qdq_3.cuh" +#include "quant/qdq_4.cuh" +#include "quant/qdq_5.cuh" +#include "quant/qdq_6.cuh" +#include "quant/qdq_8.cuh" + +#define BLOCK_KN_SIZE 128 + +#define THREADS_X 32 +#define THREADS_Y 32 + +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +// Shuffle quantized data on load + +__global__ void shuffle_kernel( + uint32_t *__restrict__ b_q_weight, + const int size_k, + const int size_n, + const int rows_8, + const int rows_6, + const int rows_5, + const int rows_4, + const int rows_3, + const int rows_2) +{ + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) + return; + int k = 0; + uint32_t *b_ptr = b_q_weight + n; + while (k < rows_8) + { + shuffle_8bit_4(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 4; + } + while (k < rows_6) + { + shuffle_6bit_16(b_ptr, size_n); + b_ptr += 3 * size_n; + k += 16; + } + while (k < rows_5) + { + shuffle_5bit_32(b_ptr, size_n); + b_ptr += 5 * size_n; + k += 32; + } + while (k < rows_4) + { + shuffle_4bit_8(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 8; + } + while (k < rows_3) + { + shuffle_3bit_32(b_ptr, size_n); + b_ptr += 3 * size_n; + k += 32; + } + while (k < rows_2) + { + shuffle_2bit_16(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 16; + } +} + +// QMatrix constructor + +QMatrix::QMatrix( + const int _device, + const int _height, + const int 
_width, + const int _groups, + + uint32_t *_q_weight, + uint16_t *_q_perm, + uint16_t *_q_invperm, + uint32_t *_q_scale, + half *_q_scale_max, + uint16_t *_q_groups, + uint16_t *_q_group_map) : device(_device), + height(_height), + width(_width), + groups(_groups) +{ + cudaSetDevice(device); + + failed = false; + + cuda_q_weight = _q_weight; + cuda_q_perm = _q_perm; + cuda_q_invperm = _q_invperm; + cuda_q_scale = _q_scale; + cuda_q_scale_max = _q_scale_max; + cuda_q_groups = _q_groups; + cuda_q_group_map = _q_group_map; + + // Create group map + + rows_8 = 0; + rows_6 = 0; + rows_5 = 0; + rows_4 = 0; + rows_3 = 0; + rows_2 = 0; + + { + uint16_t *cpu_q_groups = (uint16_t *)calloc(groups * 2, sizeof(uint16_t)); + cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost); + + int row = 0; + for (int i = 0; i < groups; i++) + { + int bits = cpu_q_groups[i * 2]; + + int rows; + if (i < groups - 1) + { + int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1]; + rows = qrows * 32 / bits; + } + else + rows = height - row; + + if (bits == 8) + rows_8 += rows; + if (bits == 6) + rows_6 += rows; + if (bits == 5) + rows_5 += rows; + if (bits == 4) + rows_4 += rows; + if (bits == 3) + rows_3 += rows; + if (bits == 2) + rows_2 += rows; + row += rows; + } + + free(cpu_q_groups); + + rows_6 += rows_8; + rows_5 += rows_6; + rows_4 += rows_5; + rows_3 += rows_4; + rows_2 += rows_3; + } + + // Shuffle quantized data + + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = 1; + + shuffle_kernel<<>>(cuda_q_weight, height, width,rows_8, rows_6, rows_5,rows_4, rows_3, rows_2); +} + +QMatrix::~QMatrix() {} + +// Reconstruct b[k,n] + +__global__ void reconstruct_kernel( + const uint32_t *__restrict__ b_q_weight, + const uint16_t *__restrict__ b_q_perm, + const uint32_t *__restrict__ b_q_scale, + const half *__restrict__ b_q_scale_max, + const uint16_t *__restrict__ 
b_q_group_map, + const int size_k, + const int size_n, + // const int groupsize, + const int groups, + half *__restrict__ b, + const int rows_8, + const int rows_6, + const int rows_5, + const int rows_4, + const int rows_3, + const int rows_2) +{ + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x; + + // Preload remapping table + + int t = threadIdx.x; + __shared__ uint16_t perm[BLOCK_KN_SIZE]; + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; + + // Column + + int n = offset_n + t; + if (n >= size_n) + return; + + // Find initial group + + // int group = offset_k / groupsize; + int group = b_q_group_map[offset_k * 2]; + + int pre_rows_8 = min(rows_8, offset_k); + int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0; + int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0; + int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0; + int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0; + int pre_rows_2 = offset_k > rows_3 ? 
min(rows_2, offset_k) - rows_3 : 0; + int qk = 0; + qk += pre_rows_8 / 32 * 8; + qk += pre_rows_6 / 32 * 6; + qk += pre_rows_5 / 32 * 5; + qk += pre_rows_4 / 32 * 4; + qk += pre_rows_3 / 32 * 3; + qk += pre_rows_2 / 32 * 2; + + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + + half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + half2 qs_h2 = __halves2half2(qs_h, qs_h); + int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1]; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + int k = offset_k; + int lk = 0; + + __syncthreads(); + + while (k < rows_8 && k < end_k) + { + if (k == nextgroup) + { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 4; p++) + { + half2 dq[4]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + uint32_t q_1 = *b_ptr; + b_ptr += size_n; + dequant_8bit_8(q_0, q_1, dq, size_n); + for (int j = 0; j < 4; j++) + dq[j] = __hmul2(dq[j], qs_h2); + half *dqh = (half *)dq; + for (int j = 0; j < 8; j++) + b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_6 && k < end_k) + { + if (k == nextgroup) + { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 2; p++) + { + half2 dq[8]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + uint32_t q_1 = *b_ptr; + b_ptr += size_n; + uint32_t q_2 = *b_ptr; + b_ptr += size_n; + dequant_6bit_16(q_0, q_1, q_2, dq, size_n); + for (int j = 0; j < 8; j++) + dq[j] = __hmul2(dq[j], qs_h2); + half *dqh = (half *)dq; + for (int j = 0; j < 16; j++) + b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_5 && k < end_k) + { + if (k == nextgroup) + { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); 
+ } + for (int p = 0; p < 1; p++) + { + half2 dq[16]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + uint32_t q_1 = *b_ptr; + b_ptr += size_n; + uint32_t q_2 = *b_ptr; + b_ptr += size_n; + uint32_t q_3 = *b_ptr; + b_ptr += size_n; + uint32_t q_4 = *b_ptr; + b_ptr += size_n; + dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n); + for (int j = 0; j < 16; j++) + dq[j] = __hmul2(dq[j], qs_h2); + half *dqh = (half *)dq; + for (int j = 0; j < 32; j++) + b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_4 && k < end_k) + { + if (k == nextgroup) + { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 4; p++) + { + half2 dq[4]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + dequant_4bit_8(q_0, dq, size_n); + for (int j = 0; j < 4; j++) + dq[j] = __hmul2(dq[j], qs_h2); + half *dqh = (half *)dq; + for (int j = 0; j < 8; j++) + b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_3 && k < end_k) + { + if (k == nextgroup) + { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 1; p++) + { + half2 dq[16]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + uint32_t q_1 = *b_ptr; + b_ptr += size_n; + uint32_t q_2 = *b_ptr; + b_ptr += size_n; + dequant_3bit_32(q_0, q_1, q_2, dq, size_n); + for (int j = 0; j < 16; j++) + dq[j] = __hmul2(dq[j], qs_h2); + half *dqh = (half *)dq; + for (int j = 0; j < 32; j++) + b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_2 && k < end_k) + { + if (k == nextgroup) + { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 1; p++) + { + half2 dq[8]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + 
dequant_2bit_16(q_0, dq, size_n); + for (int j = 0; j < 8; j++) + dq[j] = __hmul2(dq[j], qs_h2); + half *dqh = (half *)dq; + for (int j = 0; j < 16; j++) + b_.set(perm[lk++], n, dqh[j]); + } + k += 16; + } +} + +void QMatrix::reconstruct(half *out) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + + { + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + reconstruct_kernel<<>>( + cuda_q_weight, + cuda_q_perm, + cuda_q_scale, + cuda_q_scale_max, + cuda_q_group_map, + height, + width, + // groupsize, + groups, + out, + rows_8, + rows_6, + rows_5, + rows_4, + rows_3, + rows_2); + } +} \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/q_matrix.cuh b/mistralrs-quant/kernels/exl2/q_matrix.cuh new file mode 100644 index 000000000..6eba6284e --- /dev/null +++ b/mistralrs-quant/kernels/exl2/q_matrix.cuh @@ -0,0 +1,72 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _q_matrix_cuh +#define _q_matrix_cuh + +#include +#include +#include +#include + +#define MAX_SUPERGROUPS 16 + +class QMatrix +{ +public: + + int device; + bool is_gptq; + + int height; + int width; + int groups; + int gptq_groupsize; + + int rows_8; + int rows_6; + int rows_5; + int rows_4; + int rows_3; + int rows_2; + + uint32_t* cuda_q_weight = NULL; + uint16_t* cuda_q_perm = NULL; + uint16_t* cuda_q_invperm = NULL; + uint32_t* cuda_q_scale = NULL; + half* cuda_q_scale_max = NULL; + uint16_t* cuda_q_groups = NULL; + uint16_t* cuda_q_group_map = NULL; + uint32_t* cuda_gptq_qzeros = NULL; + half* cuda_gptq_scales = NULL; + + half* temp_dq; + + bool failed; + + QMatrix + ( + const int _device, + const int _height, + const int _width, + const int _groups, + + uint32_t* _q_weight, + uint16_t* _q_perm, + uint16_t* _q_invperm, + uint32_t* _q_scale, + half* _q_scale_max, + uint16_t* _q_groups, + uint16_t* _q_group_map + ); + + ~QMatrix(); + + void reconstruct(half* out); + bool make_sequential(const uint32_t* 
cpu_g_idx); + +private: + +}; + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/quant/qdq_2.cuh b/mistralrs-quant/kernels/exl2/quant/qdq_2.cuh new file mode 100644 index 000000000..d4fdc337a --- /dev/null +++ b/mistralrs-quant/kernels/exl2/quant/qdq_2.cuh @@ -0,0 +1,78 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _qdq_2_cuh +#define _qdq_2_cuh + +#include "qdq_util.cuh" + +// Permutation: +// +// ffddbb99 77553311 eeccaa88 66442200 + +__forceinline__ __device__ void shuffle_2bit_16 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0]; + uint32_t qb = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + uint32_t qa0 = qa & 0x03; + uint32_t qa1 = (qa & 0x0c) >> 2; + qa >>= 4; + qb |= (qa1 << (i * 2 + 16)); + qb |= (qa0 << (i * 2)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_2bit_16 +( + const uint32_t q_0, + half2 (&dq)[8], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y4_ = __float2half_rn(1.0f / 4.0f); + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y4 = __halves2half2(y4_, y4_); + const half2 y16 = __halves2half2(y16_, y16_); + const half2 y64 = __halves2half2(y64_, y64_); + const half z1_ = __float2half_rn(-1024.0f - 2.0f); + const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f); + const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f); + const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z4 = __halves2half2(z4_, z4_); + const half2 z16 = __halves2half2(z16_, z16_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 + half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 + half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], 
q[ 7]) * 64 + 1024 + qa >>= 8; + half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 9]) + 1024 + half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 + half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 + half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y4, z4); + dq[2] = __hfma2(q2.as_half2, y16, z16); + dq[3] = __hfma2(q3.as_half2, y64, z64); + dq[4] = __hadd2(q4.as_half2, z1); + dq[5] = __hfma2(q5.as_half2, y4, z4); + dq[6] = __hfma2(q6.as_half2, y16, z16); + dq[7] = __hfma2(q7.as_half2, y64, z64); +} + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/quant/qdq_3.cuh b/mistralrs-quant/kernels/exl2/quant/qdq_3.cuh new file mode 100644 index 000000000..b357e020a --- /dev/null +++ b/mistralrs-quant/kernels/exl2/quant/qdq_3.cuh @@ -0,0 +1,138 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _qdq_3_cuh +#define _qdq_3_cuh + +#include "qdq_util.cuh" + +// Permutation: +// +// v9997775 55333111 u8886664 44222000 (u, v lsb) +// vjjjhhhf ffdddbbb uiiiggge eecccaaa +// vtttrrrp ppnnnlll usssqqqo oommmkkk + +__forceinline__ __device__ void shuffle_3bit_32 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + + // qa: aa999888 77766655 54443332 22111000 + // qb: lkkkjjji iihhhggg fffeeedd dcccbbba + // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll + + uint32_t qd = qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: ..999888 77766655 54443332 22111000 + // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa + // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk + // qd: vvvuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + + for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); } + for 
(int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); } + for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); } + + // za: 9997775 55333111 8886664 44222000 + // zb: jjjhhhf ffdddbbb iiiggge eecccaaa + // zc: tttrrrp ppnnnlll sssqqqo oommmkkk + // qd: vvvuuu + + za |= ((qd & 0x01) >> 0) << 15; + zb |= ((qd & 0x02) >> 1) << 15; + zc |= ((qd & 0x04) >> 2) << 15; + za |= ((qd & 0x08) >> 3) << 31; + zb |= ((qd & 0x10) >> 4) << 31; + zc |= ((qd & 0x20) >> 5) << 31; + + // za: v9997775 55333111 u8886664 44222000 (u, v lsb) + // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa + // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; +} + +__forceinline__ __device__ void dequant_3bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y8_ = __float2half_rn(1.0f / 8.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y8 = __halves2half2(y8_, y8_); + const half2 y64 = __halves2half2(y64_, y64_); + const half z1_ = __float2half_rn(-1024.0f - 4.0f); + const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f); + const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z8 = __halves2half2(z8_, z8_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + + half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 + qa >>= 6; + half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 + half2_uint32 q4((qa & 0x01c001c0) | c0); // 
half2(q[ 8], q[ 9]) * 64 + 1024 + qa >>= 9; + qa &= 0x00010001; + half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 + half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + qb >>= 6; + half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 + half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 + half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 + qb >>= 8; + qb &= 0x00020002; + half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 + half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + qc >>= 6; + half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 + half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 + qc >>= 7; + qc &= 0x00040004; + half2_uint32 q15((qa | qb | qc) | c0); + + dq[ 0] = __hadd2( q0.as_half2, z1); + dq[ 1] = __hfma2( q1.as_half2, y8, z8); + dq[ 2] = __hadd2( q2.as_half2, z1); + dq[ 3] = __hfma2( q3.as_half2, y8, z8); + dq[ 4] = __hfma2( q4.as_half2, y64, z64); + dq[ 5] = __hadd2( q5.as_half2, z1); + dq[ 6] = __hfma2( q6.as_half2, y8, z8); + dq[ 7] = __hadd2( q7.as_half2, z1); + dq[ 8] = __hfma2( q8.as_half2, y8, z8); + dq[ 9] = __hfma2( q9.as_half2, y64, z64); + dq[10] = __hadd2(q10.as_half2, z1); + dq[11] = __hfma2(q11.as_half2, y8, z8); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y8, z8); + dq[14] = __hfma2(q14.as_half2, y64, z64); + dq[15] = __hadd2(q15.as_half2, z1); +} + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/quant/qdq_4.cuh b/mistralrs-quant/kernels/exl2/quant/qdq_4.cuh new file mode 100644 index 000000000..cf1d52d60 --- /dev/null +++ b/mistralrs-quant/kernels/exl2/quant/qdq_4.cuh @@ -0,0 +1,141 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _qdq_4_cuh +#define _qdq_4_cuh + +#include 
"qdq_util.cuh" + +// Permutation: +// +// 77775555 33331111 66664444 22220000 + +__forceinline__ __device__ void shuffle_4bit_8 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0]; + uint32_t qb = 0; + + #pragma unroll + for (int i = 0; i < 4; i++) + { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; + qa >>= 8; + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_4bit_8 +( + const uint32_t q_0, + half2 (&dq)[4], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half z1_ = __float2half_rn(-1024.0f - 8.0f); + const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z16 = __halves2half2(z16_, z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale +( + const uint32_t zero, + const half scale, + half2 (&z1z16)[2], + half2 (&y1y16)[2] +) +{ + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + half2 scale2 = __half2half2(scale); + + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); +} + 
+__forceinline__ __device__ void dequant_4bit_8_prep_zero +( + const uint32_t zero, + half2(&z1z16)[2], + half2(&y1y16)[2] +) +{ + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); +} + + +__forceinline__ __device__ void dequant_4bit_8_gptq +( + const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1z16)[2], + half2 (&y1y16)[2], + int stride, + bool scaled +) +{ + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) + { + dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } + else + { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z ) + } +} + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/quant/qdq_5.cuh b/mistralrs-quant/kernels/exl2/quant/qdq_5.cuh new file mode 100644 index 000000000..9866fc9b9 --- /dev/null +++ b/mistralrs-quant/kernels/exl2/quant/qdq_5.cuh @@ -0,0 +1,170 @@ 
+/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _qdq_5_cuh +#define _qdq_5_cuh + +#include "qdq_util.cuh" + +// Permutation: +// +// v5555533 33311111 u4444422 22200000 (u, v lsb) +// vbbbbb99 99977777 uaaaaa88 88866666 +// vhhhhhff fffddddd ugggggee eeeccccc +// vnnnnnll llljjjjj ummmmmkk kkkiiiii +// vtttttrr rrrppppp usssssqq qqqooooo + +__forceinline__ __device__ void shuffle_5bit_32 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + uint32_t qd = q[3 * stride]; + uint32_t qe = q[4 * stride]; + + // qa: 66555554 44443333 32222211 11100000 + // qb: ccccbbbb baaaaa99 99988888 77777666 + // qc: jiiiiihh hhhggggg fffffeee eedddddc + // qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj + // qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp + + uint32_t qf = qe >> 22; + qe <<= 8; + qe |= qd >> 24; + qd <<= 6; + qd |= qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: 555554 44443333 32222211 11100000 + // qb: bbbbba aaaa9999 98888877 77766666 + // qc: hhhhhg ggggffff feeeeedd dddccccc + // qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii + // qe: ttttts ssssrrrr rqqqqqpp pppooooo + // qf: vv vvvuuuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + uint32_t zd = 0; + uint32_t ze = 0; + + for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); } + for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t 
t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); } + + // za: 5555533 33311111 4444422 22200000 + // zb: bbbbb99 99977777 aaaaa88 88866666 + // zc: hhhhhff fffddddd gggggee eeeccccc + // zd: nnnnnll llljjjjj mmmmmkk kkkiiiii + // ze: tttttrr rrrppppp sssssqq qqqooooo + // qf: vv vvvuuuuu + + za |= ((qf & 0x001) >> 0) << 15; + zb |= ((qf & 0x002) >> 1) << 15; + zc |= ((qf & 0x004) >> 2) << 15; + zd |= ((qf & 0x008) >> 3) << 15; + ze |= ((qf & 0x010) >> 4) << 15; + za |= ((qf & 0x020) >> 5) << 31; + zb |= ((qf & 0x040) >> 6) << 31; + zc |= ((qf & 0x080) >> 7) << 31; + zd |= ((qf & 0x100) >> 8) << 31; + ze |= ((qf & 0x200) >> 9) << 31; + + // za: v5555533 33311111 u4444422 22200000 (u, v lsb) + // zb: vbbbbb99 99977777 uaaaaa88 88866666 + // zc: vhhhhhff fffddddd ugggggee eeeccccc + // zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii + // ze: vtttttrr rrrppppp usssssqq qqqooooo + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; + q[3 * stride] = zd; + q[4 * stride] = ze; +} + +__forceinline__ __device__ void dequant_5bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + const uint32_t q_3, + const uint32_t q_4, + half2 (&dq)[16], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y32_ = __float2half_rn(1.0f / 32.0f); + const half2 y32 = __halves2half2(y32_, y32_); + const half z1_ = __float2half_rn(-1024.0f - 16.0f); + const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z32 = __halves2half2(z32_, z32_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + uint32_t qd = q_3; + uint32_t qe = q_4; + + half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024 + qa >>= 10; + half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024 + qa >>= 5; + qa &= 0x00010001; + half2_uint32 q3 ((qb & 0x001f001f) | 
c0); // half2(q[ 6], q[ 7]) + 1024 + half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024 + qb >>= 10; + half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024 + qb >>= 4; + qb &= 0x00020002; + half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024 + half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024 + qc >>= 10; + half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024 + qc >>= 3; + qc &= 0x00040004; + half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024 + half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024 + qd >>= 10; + half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024 + qd >>= 2; + qd &= 0x00080008; + half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024 + qe >>= 10; + half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024 + qe >>= 1; + qe &= 0x00100010; + half2_uint32 q15((qa | qb | qc | qd | qe) | c0); + + dq[ 0] = __hadd2( q0.as_half2, z1); + dq[ 1] = __hfma2( q1.as_half2, y32, z32); + dq[ 2] = __hadd2( q2.as_half2, z1); + dq[ 3] = __hadd2( q3.as_half2, z1); + dq[ 4] = __hfma2( q4.as_half2, y32, z32); + dq[ 5] = __hadd2( q5.as_half2, z1); + dq[ 6] = __hadd2( q6.as_half2, z1); + dq[ 7] = __hfma2( q7.as_half2, y32, z32); + dq[ 8] = __hadd2( q8.as_half2, z1); + dq[ 9] = __hadd2( q9.as_half2, z1); + dq[10] = __hfma2(q10.as_half2, y32, z32); + dq[11] = __hadd2(q11.as_half2, z1); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y32, z32); + dq[14] = __hadd2(q14.as_half2, z1); + dq[15] = __hadd2(q15.as_half2, z1); +} + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/quant/qdq_6.cuh b/mistralrs-quant/kernels/exl2/quant/qdq_6.cuh new file mode 100644 index 000000000..43b2659a1 --- /dev/null +++ b/mistralrs-quant/kernels/exl2/quant/qdq_6.cuh 
@@ -0,0 +1,36 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _qdq_6_cuh +#define _qdq_6_cuh + +#include "qdq_util.cuh" + +__forceinline__ __device__ void shuffle_6bit_16 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_6bit_16 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[8], + int stride +) +{ + half dqh[16]; + for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32); + dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32); + for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32); + dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32); + for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32); + + for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/quant/qdq_8.cuh b/mistralrs-quant/kernels/exl2/quant/qdq_8.cuh new file mode 100644 index 000000000..807f7fb96 --- /dev/null +++ b/mistralrs-quant/kernels/exl2/quant/qdq_8.cuh @@ -0,0 +1,32 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _qdq_8_cuh +#define _qdq_8_cuh + +#include "qdq_util.cuh" + +__forceinline__ __device__ void shuffle_8bit_4 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_8bit_8 +( + const uint32_t q_0, + const uint32_t q_1, + half2 (&dq)[4], + int stride +) +{ + half dqh[8]; + for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128); + for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128); + + for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/exl2/quant/qdq_util.cuh b/mistralrs-quant/kernels/exl2/quant/qdq_util.cuh new file mode 100644 index 000000000..79a4bf365 --- /dev/null +++ b/mistralrs-quant/kernels/exl2/quant/qdq_util.cuh @@ -0,0 +1,56 @@ 
+/* +Adapted from https://github.com/turboderp/exllamav2 +*/ +#ifndef _qdq_util_cuh +#define _qdq_util_cuh + +union half2_uint32 +{ + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} + __device__ half2_uint32() : as_uint32(0) {} +}; + +union half_uint16 +{ + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} + __device__ half_uint16() : as_uint16(0) {} +}; + +// Max_scale premultiplied by 1/256 + +__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) +{ + int qs_i = qs + 1; + half qs_h = __int2half_rn(qs_i * qs_i); + qs_h = __hmul(qs_h, max_scale); + return qs_h; +} + +__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) +{ + return __hmul(__int2half_rn(q - qzero), scale); +} + +__forceinline__ __device__ half dq_ns(const int q, const int qzero) +{ + //return __hsub(__int2half_rn(q), __int2half_rn(qzero)); + return __int2half_rn(q - qzero); +} + +__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) +{ + return (int)((q >> shift) & mask); +} + +__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) +{ + return (int)(__funnelshift_rc(q0, q1, shift) & mask); +} + +#endif \ No newline at end of file From 03a2b72afd3c4c91b977f1d97154051951b68c03 Mon Sep 17 00:00:00 2001 From: ro99 Date: Wed, 11 Sep 2024 09:26:22 -0300 Subject: [PATCH 02/15] initial structure for exl2 --- mistralrs-quant/build.rs | 1 + mistralrs-quant/kernels/exl2/q_gemm_exl2.cu | 50 ++----- mistralrs-quant/src/exl2/exl2_cuda.rs | 139 ++++++++++++++++++++ mistralrs-quant/src/exl2/ffi.rs | 23 ++++ mistralrs-quant/src/exl2/mod.rs | 2 + mistralrs-quant/src/lib.rs | 11 ++ 6 files changed, 186 insertions(+), 40 deletions(-) create mode 100644 
mistralrs-quant/src/exl2/exl2_cuda.rs create mode 100644 mistralrs-quant/src/exl2/ffi.rs create mode 100644 mistralrs-quant/src/exl2/mod.rs diff --git a/mistralrs-quant/build.rs b/mistralrs-quant/build.rs index d9e09f1c6..9b52da293 100644 --- a/mistralrs-quant/build.rs +++ b/mistralrs-quant/build.rs @@ -8,6 +8,7 @@ fn main() { println!("cargo:rerun-if-changed=build.rs"); let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap()); let lib_files = vec![ + "kernels/exl2/q_gemm_exl2.cu", "kernels/gptq/q_gemm.cu", "kernels/hqq/hqq.cu", "kernels/ops/ops.cu", diff --git a/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu b/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu index 10ed88271..21a8360b8 100644 --- a/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu +++ b/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu @@ -1,7 +1,6 @@ /* Adapted from https://github.com/turboderp/exllamav2 */ -#include #include @@ -36,10 +35,16 @@ __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm( #endif #define DIVIDE(x, size) (((x) + (size) - 1) / (size)) -void gemm_half_q_half_cuda_part(const half* a, QMatrix* b, half* c, int size_m, - int size_n, int size_k, int m_count, - bool clear) { - { +extern "C" void gemm_half_q_half_cuda_part_exl2( + const half* a, + QMatrix* b, + half* c, + int size_m, + int size_n, + int size_k, + int m_count, + bool clear +) { dim3 blockDim, gridDim; blockDim.x = EXL2_BLOCK_KN_SIZE; blockDim.y = 1; @@ -55,39 +60,4 @@ void gemm_half_q_half_cuda_part(const half* a, QMatrix* b, half* c, int size_m, size_n, size_k, b->height, b->groups, b->cuda_q_group_map, b->cuda_q_perm, b->rows_8, b->rows_6, b->rows_5, b->rows_4, b->rows_3, b->rows_2, clear); - } -} - -void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, - QMatrix* b, half* c, int size_m, int size_n, - int size_k, bool clear, half* temp_dq) { - if (size_m > MAX_Q_GEMM_ROWS) { - // Reconstruct FP16 matrix, then cuBLAS - b->reconstruct(temp_dq); - - // cublasSetMathMode(cublas_handle, 
CUBLAS_TENSOR_OP_MATH); - - const half alpha = __float2half(1.0f); - const half beta = clear ? __float2half(0.0f) : __float2half(1.0f); - cublasHgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, size_n, size_m, size_k, - &alpha, temp_dq, size_n, a, size_k, &beta, c, size_n); - } else { - // Quantized matmul - - int block_m_size_max = EXL2_BLOCK_M_SIZE_MAX; - int max_chunks = size_m / block_m_size_max; - int last_chunk = max_chunks * block_m_size_max; - int last_chunk_size = size_m - last_chunk; - - if (max_chunks) { - gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, - block_m_size_max, clear); - } - - if (last_chunk_size) { - gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, - c + last_chunk * size_n, last_chunk_size, - size_n, size_k, last_chunk_size, clear); - } - } } \ No newline at end of file diff --git a/mistralrs-quant/src/exl2/exl2_cuda.rs b/mistralrs-quant/src/exl2/exl2_cuda.rs new file mode 100644 index 000000000..f2a9aee8e --- /dev/null +++ b/mistralrs-quant/src/exl2/exl2_cuda.rs @@ -0,0 +1,139 @@ +use candle_core::{DType, Device, Result, Shape, Tensor}; +use std::sync::Arc; + +pub struct Exl2Layer { + q_weight: Tensor, + q_scale: Tensor, + q_scale_max: Tensor, + q_groups: Tensor, + q_perm: Tensor, + q_invperm: Tensor, + q_group_map: Tensor, + bias: Option, + bits: i32, + exllama_state: i32, + q_matrix: *mut std::ffi::c_void, +} + +impl QuantMethod for Exl2Layer { + fn new(method: QuantMethodConfig) -> Result { + match method { + QuantMethodConfig::Exl2 { + bits, + q_weight, + q_scale, + q_scale_max, + q_groups, + q_invperm, + bias, + } => { + let q_perm = q_invperm.argsort()?.to_dtype(DType::U16)?; + let q_group_map = make_group_map(&q_groups, q_weight.dim(0)?)?; + Ok(Self { + q_weight, + q_scale, + q_scale_max, + q_groups, + q_perm, + q_invperm, + q_group_map, + bias, + bits, + exllama_state: 0, + q_matrix: std::ptr::null_mut(), + }) + } + _ => candle_core::bail!("Expected Exl2 config"), + } + } + + fn forward(&self, x: &Tensor) -> 
Result { + let out_shape = Shape::from_dims( + &[ + &x.dims()[..x.dims().len() - 1], + &[self.q_weight.dim(candle_core::D::Minus1)?], + ] + .concat(), + ); + let reshaped_x = x.reshape(((), x.dim(candle_core::D::Minus1)?))?; + + if self.exllama_state == 0 { + self.prepare_weights()?; + } + + let mut output = self.exl2_gemm(reshaped_x)?; + if let Some(bias) = &self.bias { + output = output.broadcast_add(bias)?; + } + output.reshape(out_shape) + } + + // Implement other required methods... +} + +impl Exl2Layer { + fn prepare_weights(&mut self) -> Result<()> { + self.q_scale_max = &self.q_scale_max / 256.0; + self.q_invperm = self.q_invperm.to_dtype(DType::U16)?; + self.q_matrix = unsafe { + exl2_make_q_matrix( + self.q_weight.as_ptr()?, + self.q_perm.as_ptr()?, + self.q_invperm.as_ptr()?, + self.q_scale.as_ptr()?, + self.q_scale_max.as_ptr()?, + self.q_groups.as_ptr()?, + self.q_group_map.as_ptr()?, + ) + }; + self.exllama_state = 1; + Ok(()) + } + + fn exl2_gemm(&self, x: Tensor) -> Result { + let (m, k) = (x.dim(0)?, x.dim(1)?); + let n = self.q_weight.dim(1)?; + let c_shape = Shape::from_dims(&[m, n]); + + let c = unsafe { + let dev = get_cuda_device(&x)?; + let c = dev.alloc::(c_shape.elem_count())?; + exl2_gemm( + x.as_ptr()?, + self.q_matrix, + c.device_ptr() as *mut f16, + m as i32, + n as i32, + k as i32, + ); + c + }; + + Ok(Tensor::from_cuda_slice(&c, c_shape, x.device())?) 
+ } +} + +fn make_group_map(q_groups: &Tensor, num_qrows: usize) -> Result { + let gr = q_groups.to_vec1::()?; + let mut group_map = Vec::new(); + let num_groups = gr.len() / 2; + + let mut row = 0; + for i in 0..num_groups { + let bits = gr[i * 2] as usize; + let rows = if i < num_groups - 1 { + let qrows = gr[i * 2 + 3] as usize - gr[i * 2 + 1] as usize; + qrows * 32 / bits + } else { + num_qrows - gr[i * 2 + 1] as usize + }; + + for _ in 0..rows { + group_map.push(i as u16); + group_map.push(rows as u16); + } + row += rows; + } + + Tensor::from_vec(group_map, (group_map.len(),), q_groups.device()) +} \ No newline at end of file diff --git a/mistralrs-quant/src/exl2/ffi.rs b/mistralrs-quant/src/exl2/ffi.rs new file mode 100644 index 000000000..1c3233f82 --- /dev/null +++ b/mistralrs-quant/src/exl2/ffi.rs @@ -0,0 +1,23 @@ +use half::f16; + +#[allow(dead_code)] +extern "C" { + pub fn exl2_make_q_matrix( + q_weight: *const u32, + q_perm: *const u16, + q_invperm: *const u16, + q_scale: *const u32, + q_scale_max: *const f16, + q_groups: *const u16, + q_group_map: *const u16, + ) -> *mut std::ffi::c_void; + + pub fn exl2_gemm( + a: *const f16, + b: *const std::ffi::c_void, + c: *mut f16, + m: i32, + n: i32, + k: i32, + ); +} \ No newline at end of file diff --git a/mistralrs-quant/src/exl2/mod.rs b/mistralrs-quant/src/exl2/mod.rs new file mode 100644 index 000000000..c381dbcaa --- /dev/null +++ b/mistralrs-quant/src/exl2/mod.rs @@ -0,0 +1,2 @@ +mod ffi; +mod exl2_cuda; \ No newline at end of file diff --git a/mistralrs-quant/src/lib.rs b/mistralrs-quant/src/lib.rs index 684039a5a..24d01e85b 100644 --- a/mistralrs-quant/src/lib.rs +++ b/mistralrs-quant/src/lib.rs @@ -9,6 +9,7 @@ use candle_core::{ DType, Device, Result, Tensor, }; +mod exl2; mod gguf; mod gptq; mod hqq; @@ -47,6 +48,16 @@ pub struct QuantizedConfig { #[derive(Debug, Clone)] pub enum QuantMethodConfig { + Exl2 { + bits: i32, + q_weight: Tensor, + q_scale: Tensor, + q_scale_max: Tensor, + q_groups: 
Tensor, + q_perm: Tensor, + q_invperm: Tensor, + bias: Option, + }, Gptq { bits: i32, use_exllama: bool, From 4e485a07159dab085d1dd2148a60697c88fa0260 Mon Sep 17 00:00:00 2001 From: ro99 Date: Wed, 11 Sep 2024 17:14:10 -0300 Subject: [PATCH 03/15] mess --- mistralrs-quant/kernels/exl2/q_gemm_exl2.cu | 40 +++++++ mistralrs-quant/src/exl2/exl2_cuda.rs | 124 ++++++++++++++++---- mistralrs-quant/src/exl2/ffi.rs | 18 ++- 3 files changed, 154 insertions(+), 28 deletions(-) diff --git a/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu b/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu index 21a8360b8..406310a5c 100644 --- a/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu +++ b/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu @@ -60,4 +60,44 @@ extern "C" void gemm_half_q_half_cuda_part_exl2( size_n, size_k, b->height, b->groups, b->cuda_q_group_map, b->cuda_q_perm, b->rows_8, b->rows_6, b->rows_5, b->rows_4, b->rows_3, b->rows_2, clear); +} + +extern "C" uintptr_t exl2_make_q_matrix( + const int device, + const int height, + const int width, + const int groups, + uint32_t q_weight, + uint16_t q_perm, + uint16_t q_invperm, + uint32_t q_scale, + half q_scale_max, + uint16_t q_groups, + uint16_t q_group_map +) { + QMatrix* m = new QMatrix + ( + device, + height, + width, + groups, + (uint32_t*)q_weight.data_ptr(), + (uint16_t*)q_perm.data_ptr(), + (uint16_t*)q_invperm.data_ptr(), + (uint32_t*)q_scale.data_ptr(), + (half*)q_scale_max.data_ptr(), + (uint16_t*)q_groups.data_ptr(), + (uint16_t*)q_group_map.data_ptr() + ); + return reinterpret_cast(m); +} + +extern "C" void exl2_reconstruct_q_matrix(uintptr_t q_matrix) { + QMatrix* m = reinterpret_cast(q_matrix); + m->reconstruct(); +} + +extern "C" void exl2_destroy_q_matrix(uintptr_t q_matrix) { + QMatrix* m = reinterpret_cast(q_matrix); + delete m; } \ No newline at end of file diff --git a/mistralrs-quant/src/exl2/exl2_cuda.rs b/mistralrs-quant/src/exl2/exl2_cuda.rs index f2a9aee8e..0dbdaebc2 100644 --- a/mistralrs-quant/src/exl2/exl2_cuda.rs 
+++ b/mistralrs-quant/src/exl2/exl2_cuda.rs @@ -1,6 +1,15 @@ use candle_core::{DType, Device, Result, Shape, Tensor}; +use candle_core::cudarc::{ + cublas::{result::hgemm, sys::cublasOperation_t}, + driver::{CudaSlice, DevicePtr}, +}; use std::sync::Arc; +const MAX_Q_GEMM_ROWS: i32 = 32; +const BLOCK_M_SIZE_MAX: i32 = 8; + + + pub struct Exl2Layer { q_weight: Tensor, q_scale: Tensor, @@ -19,16 +28,17 @@ impl QuantMethod for Exl2Layer { fn new(method: QuantMethodConfig) -> Result { match method { QuantMethodConfig::Exl2 { - bits, q_weight, q_scale, q_scale_max, q_groups, + q_perm, q_invperm, + q_group_map, bias, + bits, } => { - let q_perm = q_invperm.argsort()?.to_dtype(DType::U16)?; - let q_group_map = make_group_map(&q_groups, q_weight.dim(0)?)?; + Ok(Self { q_weight, q_scale, @@ -58,7 +68,8 @@ impl QuantMethod for Exl2Layer { let reshaped_x = x.reshape(((), x.dim(candle_core::D::Minus1)?))?; if self.exllama_state == 0 { - self.prepare_weights()?; + let dev = get_cuda_device(&x)?; + self.prepare_weights(dev.id())?; } let mut output = self.exl2_gemm(reshaped_x)?; @@ -72,11 +83,21 @@ impl QuantMethod for Exl2Layer { } impl Exl2Layer { - fn prepare_weights(&mut self) -> Result<()> { + fn prepare_weights(&mut self, device_id: i32) -> Result<()> { self.q_scale_max = &self.q_scale_max / 256.0; self.q_invperm = self.q_invperm.to_dtype(DType::U16)?; + + let q_perm = self.q_invperm.argsort()?.to_dtype(DType::U16)?; + let q_group_map = make_group_map(&q_groups, q_weight.dim(0)?)?; + self.q_matrix = unsafe { - exl2_make_q_matrix( + exl2_create_q_matrix( + device_id, + + self.q_perm.dims(0)? as i32, + self.q_weight.dim(1)? as i32, + self.q_scale.dim(0)? 
as i32, + self.q_weight.as_ptr()?, self.q_perm.as_ptr()?, self.q_invperm.as_ptr()?, @@ -90,26 +111,78 @@ impl Exl2Layer { Ok(()) } - fn exl2_gemm(&self, x: Tensor) -> Result { - let (m, k) = (x.dim(0)?, x.dim(1)?); - let n = self.q_weight.dim(1)?; - let c_shape = Shape::from_dims(&[m, n]); - - let c = unsafe { - let dev = get_cuda_device(&x)?; - let c = dev.alloc::(c_shape.elem_count())?; - exl2_gemm( - x.as_ptr()?, - self.q_matrix, - c.device_ptr() as *mut f16, - m as i32, - n as i32, - k as i32, - ); - c + fn exl2_gemm(&self, a: Tensor) -> Result { + + let dev = get_cuda_device(&a)?; + let qm_width = self.q_weight.dims()[1]?; + let c_shape = Shape::from_dims(&[a.dims()[0], qm_width]); + + let (m, n, k) = ( + c_shape.dims()[0] as i32, + c_shape.dims()[1] as i32, + a.dims()[1] as i32, + ); + + let c = unsafe { dev.alloc::(c_shape.elem_count()).w()? }; + let c_ptr = *c.device_ptr() as *mut f16; + + // Create temp_dq as a Tensor, using a zero-sized tensor when not needed + // (TODO: review if this is the best solution here) + let temp_dq = if c_shape.dims()[0] > MAX_Q_GEMM_ROWS as usize { + Tensor::zeros(&[a.dims()[1], qm_width], DType::F16, &dev)? + } else { + Tensor::zeros(&[0, 0], DType::F16, &dev)? }; + + let a_ptr = get_cuda_slice::(a)?; + let temp_dq_ptr = temp_dq.device_ptr() as *const f16; + + if m > MAX_Q_GEMM_ROWS { + // Reconstruct FP16 matrix, then cuBLAS + unsafe { + super::ffi::exl2_reconstruct_q_matrix(self.q_matrix); + } + + let alpha = f16::from_f32(1.0); + let beta = if clear { f16::from_f32(0.0) } else { f16::from_f32(1.0) }; + + unsafe { + hgemm( + *cublas_handle.handle(), + cublasOperation_t::CUBLAS_OP_N, + cublasOperation_t::CUBLAS_OP_N, + n, + m, + k, + &alpha, + temp_dq_ptr as *const _, + n, + a_ptr as *const _, + k, + &beta, + c_ptr, + n, + ) + .w()? + }; + - Ok(Tensor::from_cuda_slice(&c, c_shape, x.device())?) 
+ + + } else { + // Quantized matmul + } + + } +} + +impl Drop for Exl2Layer { + fn drop(&mut self) { + if !self.q_matrix.is_null() { + unsafe { + exl2_destroy_q_matrix(self.q_matrix); + } + } } } @@ -136,4 +209,5 @@ fn make_group_map(q_groups: &Tensor, num_qrows: usize) -> Result { } Tensor::from_vec(group_map, (group_map.len(),), q_groups.device()) -} \ No newline at end of file +} + diff --git a/mistralrs-quant/src/exl2/ffi.rs b/mistralrs-quant/src/exl2/ffi.rs index 1c3233f82..5afbd7306 100644 --- a/mistralrs-quant/src/exl2/ffi.rs +++ b/mistralrs-quant/src/exl2/ffi.rs @@ -1,8 +1,16 @@ use half::f16; +use std::ffi::c_void; + +// Opaque pointer type for QMatrix +type QMatrixPtr = *mut c_void; #[allow(dead_code)] extern "C" { - pub fn exl2_make_q_matrix( + pub fn exl2_create_q_matrix( + device: i32, + height: i32, // q_perm.size(0); + width: i32, // q_weight.size(1); + groups: i32, // q_scale.size(0); q_weight: *const u32, q_perm: *const u16, q_invperm: *const u16, @@ -10,11 +18,15 @@ extern "C" { q_scale_max: *const f16, q_groups: *const u16, q_group_map: *const u16, - ) -> *mut std::ffi::c_void; + ) -> QMatrixPtr; + + pub fn exl2_destroy_q_matrix(q_matrix: QMatrixPtr); + + pub fn exl2_reconstruct_q_matrix(q_matrix: QMatrixPtr); pub fn exl2_gemm( a: *const f16, - b: *const std::ffi::c_void, + b: *const c_void, c: *mut f16, m: i32, n: i32, From 811dae35a298bb5dcd8e32288215870f720e7e4c Mon Sep 17 00:00:00 2001 From: ro99 Date: Thu, 12 Sep 2024 15:03:33 -0300 Subject: [PATCH 04/15] . 
--- Cargo.lock | 58 +++-- Cargo.toml | 4 +- mistralrs-paged-attn/src/backend/mod.rs | 1 + mistralrs-quant/build.rs | 2 +- mistralrs-quant/src/exl2/exl2_cuda.rs | 302 ++++++++++++++---------- mistralrs-quant/src/exl2/ffi.rs | 2 +- mistralrs-quant/src/gguf/mod.rs | 3 +- mistralrs-quant/src/gptq/gptq_cpu.rs | 3 +- mistralrs-quant/src/gptq/gptq_cuda.rs | 6 +- mistralrs-quant/src/hqq/mod.rs | 3 +- mistralrs-quant/src/lib.rs | 1 + mistralrs-quant/src/unquantized/mod.rs | 3 +- mistralrs-quant/src/utils/ops.rs | 8 + 13 files changed, 251 insertions(+), 145 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c7b214530..ebd034970 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -387,11 +387,11 @@ checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" [[package]] name = "candle-core" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=7f5a470#7f5a47040e798f0076014c9d9e82cc6cb25708a0" +source = "git+https://github.com/ro99/candle.git?rev=2ecc6cc#2ecc6cc071b6b6c68062f327aa8343ad08dbde83" dependencies = [ "accelerate-src", "byteorder", - "candle-kernels", + "candle-kernels 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", "candle-metal-kernels", "cudarc", "gemm", @@ -411,6 +411,28 @@ dependencies = [ "zip", ] +[[package]] +name = "candle-core" +version = "0.6.1" +source = "git+https://github.com/EricLBuehler/candle.git?rev=7f5a470#7f5a47040e798f0076014c9d9e82cc6cb25708a0" +dependencies = [ + "byteorder", + "candle-kernels 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=7f5a470)", + "cudarc", + "gemm", + "half", + "memmap2", + "num-traits", + "num_cpus", + "rand", + "rand_distr", + "rayon", + "safetensors", + "thiserror", + "yoke", + "zip", +] + [[package]] name = "candle-flash-attn" version = "0.6.1" @@ -418,10 +440,18 @@ source = "git+https://github.com/EricLBuehler/candle.git?rev=7f5a470#7f5a47040e7 dependencies = [ "anyhow", "bindgen_cuda 0.1.5", - "candle-core", + "candle-core 0.6.1 
(git+https://github.com/EricLBuehler/candle.git?rev=7f5a470)", "half", ] +[[package]] +name = "candle-kernels" +version = "0.6.1" +source = "git+https://github.com/ro99/candle.git?rev=2ecc6cc#2ecc6cc071b6b6c68062f327aa8343ad08dbde83" +dependencies = [ + "bindgen_cuda 0.1.5", +] + [[package]] name = "candle-kernels" version = "0.6.1" @@ -433,7 +463,7 @@ dependencies = [ [[package]] name = "candle-metal-kernels" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=7f5a470#7f5a47040e798f0076014c9d9e82cc6cb25708a0" +source = "git+https://github.com/ro99/candle.git?rev=2ecc6cc#2ecc6cc071b6b6c68062f327aa8343ad08dbde83" dependencies = [ "metal", "once_cell", @@ -444,10 +474,10 @@ dependencies = [ [[package]] name = "candle-nn" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=7f5a470#7f5a47040e798f0076014c9d9e82cc6cb25708a0" +source = "git+https://github.com/ro99/candle.git?rev=2ecc6cc#2ecc6cc071b6b6c68062f327aa8343ad08dbde83" dependencies = [ "accelerate-src", - "candle-core", + "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", "candle-metal-kernels", "half", "intel-mkl-src", @@ -2097,7 +2127,7 @@ name = "mistralrs" version = "0.3.0" dependencies = [ "anyhow", - "candle-core", + "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", "either", "futures", "image", @@ -2115,7 +2145,7 @@ name = "mistralrs-bench" version = "0.3.0" dependencies = [ "anyhow", - "candle-core", + "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", "clap", "cli-table", "mistralrs-core", @@ -2138,7 +2168,7 @@ dependencies = [ "buildstructor", "bytemuck", "bytemuck_derive", - "candle-core", + "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", "candle-flash-attn", "candle-nn", "cfgrammar", @@ -2199,7 +2229,7 @@ version = "0.3.0" dependencies = [ "anyhow", "bindgen_cuda 0.1.6", - "candle-core", + "candle-core 0.6.1 
(git+https://github.com/ro99/candle.git?rev=2ecc6cc)", "half", ] @@ -2210,7 +2240,7 @@ dependencies = [ "accelerate-src", "anyhow", "base64 0.22.1", - "candle-core", + "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", "data-url", "either", "futures", @@ -2232,7 +2262,7 @@ name = "mistralrs-quant" version = "0.3.0" dependencies = [ "bindgen_cuda 0.1.5", - "candle-core", + "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", "candle-nn", "half", "lazy_static", @@ -2249,7 +2279,7 @@ dependencies = [ "accelerate-src", "anyhow", "axum", - "candle-core", + "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", "clap", "ctrlc", "data-url", @@ -2275,7 +2305,7 @@ dependencies = [ name = "mistralrs-vision" version = "0.3.0" dependencies = [ - "candle-core", + "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", "image", ] diff --git a/Cargo.toml b/Cargo.toml index 4343649c4..45e000244 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,8 +25,8 @@ license = "MIT" [workspace.dependencies] anyhow = "1.0.80" -candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "7f5a470" } -candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "7f5a470" } +candle-core = { git = "https://github.com/ro99/candle.git", version = "0.6.0", rev = "2ecc6cc" } +candle-nn = { git = "https://github.com/ro99/candle.git", version = "0.6.0", rev = "2ecc6cc" } serde = "1.0.197" serde_json = "1.0.114" indexmap = { version = "2.2.5", features = ["serde"] } diff --git a/mistralrs-paged-attn/src/backend/mod.rs b/mistralrs-paged-attn/src/backend/mod.rs index ffedc0122..ad40a237c 100644 --- a/mistralrs-paged-attn/src/backend/mod.rs +++ b/mistralrs-paged-attn/src/backend/mod.rs @@ -26,6 +26,7 @@ pub fn get_or_load_func( let spec = match dtype { DType::U8 => "_u8", DType::U32 => "_u32", + DType::I16 => "_i16", DType::I32 => "_i32", DType::I64 => "_i64", 
DType::BF16 => "_bf16", diff --git a/mistralrs-quant/build.rs b/mistralrs-quant/build.rs index 9b52da293..107bf7bc0 100644 --- a/mistralrs-quant/build.rs +++ b/mistralrs-quant/build.rs @@ -8,7 +8,7 @@ fn main() { println!("cargo:rerun-if-changed=build.rs"); let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap()); let lib_files = vec![ - "kernels/exl2/q_gemm_exl2.cu", + //"kernels/exl2/q_gemm_exl2.cu", "kernels/gptq/q_gemm.cu", "kernels/hqq/hqq.cu", "kernels/ops/ops.cu", diff --git a/mistralrs-quant/src/exl2/exl2_cuda.rs b/mistralrs-quant/src/exl2/exl2_cuda.rs index 0dbdaebc2..744ce62d0 100644 --- a/mistralrs-quant/src/exl2/exl2_cuda.rs +++ b/mistralrs-quant/src/exl2/exl2_cuda.rs @@ -1,15 +1,37 @@ -use candle_core::{DType, Device, Result, Shape, Tensor}; -use candle_core::cudarc::{ - cublas::{result::hgemm, sys::cublasOperation_t}, - driver::{CudaSlice, DevicePtr}, +use std::{ + collections::HashMap, + num::NonZeroUsize, + sync::{atomic::AtomicUsize, Arc, Mutex}, +}; + +use candle_core::{ + cuda::{ + cudarc::{ + cublas::{result::hgemm, sys::cublasOperation_t}, + driver::{CudaSlice, DevicePtr}, + }, + CudaStorageSlice, WrapErr, + }, + from_storage_no_op, CudaStorage, DType, Device, Result, Shape, Storage, Tensor, D, +}; +use half::f16; + +use crate::{ + utils::{get_cuda_device, get_cuda_slice}, + IsqType, QuantMethod, QuantMethodConfig, +}; + +use super::ffi::{ + exl2_reconstruct_q_matrix, + exl2_create_q_matrix, + exl2_destroy_q_matrix }; -use std::sync::Arc; const MAX_Q_GEMM_ROWS: i32 = 32; const BLOCK_M_SIZE_MAX: i32 = 8; - +#[derive(Debug)] pub struct Exl2Layer { q_weight: Tensor, q_scale: Tensor, @@ -24,97 +46,52 @@ pub struct Exl2Layer { q_matrix: *mut std::ffi::c_void, } -impl QuantMethod for Exl2Layer { - fn new(method: QuantMethodConfig) -> Result { - match method { - QuantMethodConfig::Exl2 { - q_weight, - q_scale, - q_scale_max, - q_groups, - q_perm, - q_invperm, - q_group_map, - bias, - bits, - } => { +impl Exl2Layer { + fn exl2_gemm(&self, a: 
Tensor) -> Result { + let dev = get_cuda_device(&a)?; + let a_ptr = get_cuda_slice::(&a)?; - Ok(Self { - q_weight, - q_scale, - q_scale_max, - q_groups, - q_perm, - q_invperm, - q_group_map, - bias, - bits, - exllama_state: 0, - q_matrix: std::ptr::null_mut(), - }) - } - _ => candle_core::bail!("Expected Exl2 config"), - } - } + if self.exllama_state == 0 { - fn forward(&self, x: &Tensor) -> Result { - let out_shape = Shape::from_dims( - &[ - &x.dims()[..x.dims().len() - 1], - &[self.q_weight.dim(candle_core::D::Minus1)?], - ] - .concat(), - ); - let reshaped_x = x.reshape(((), x.dim(candle_core::D::Minus1)?))?; + self.q_scale_max = (self.q_scale_max / 256.0)?; + self.q_invperm = self.q_invperm.to_dtype(DType::I16)?; + self.q_perm = self.q_invperm.arg_sort_last_dim(false)?.to_dtype(DType::I16)?; + self.q_group_map = make_group_map(&self.q_groups, self.q_weight.dim(0)?)?; - if self.exllama_state == 0 { - let dev = get_cuda_device(&x)?; - self.prepare_weights(dev.id())?; - } + // QMatrix entries + let dev_ord = dev.ordinal() as i32; + let b_width = self.q_weight.dims()[1] as i32; + let b_height = self.q_perm.dims()[0] as i32; + let b_groups = self.q_scale.dims()[0] as i32; + let b_q_weight = get_cuda_slice::(&self.q_weight)? as *const u32; + let b_q_perm = get_cuda_slice::(&self.q_perm)? as *const u16; + let b_q_invperm = get_cuda_slice::(&self.q_invperm)? as *const u16; + let b_q_scale = get_cuda_slice::(&self.q_scale)? as *const u32; + let b_q_scale_max = get_cuda_slice::(&self.q_scale_max)?; + let b_q_groups = get_cuda_slice::(&self.q_groups)? as *const u16; + let b_q_group_map = get_cuda_slice::(&self.q_group_map)? 
as *const u16; - let mut output = self.exl2_gemm(reshaped_x)?; - if let Some(bias) = &self.bias { - output = output.broadcast_add(bias)?; + self.q_matrix = unsafe { + exl2_create_q_matrix( + dev_ord, + b_height, + b_width, + b_groups, + b_q_weight, + b_q_perm, + b_q_invperm, + b_q_scale, + b_q_scale_max, + b_q_groups, + b_q_group_map, + ) + }; + self.exllama_state = 1; } - output.reshape(out_shape) - } - - // Implement other required methods... -} -impl Exl2Layer { - fn prepare_weights(&mut self, device_id: i32) -> Result<()> { - self.q_scale_max = &self.q_scale_max / 256.0; - self.q_invperm = self.q_invperm.to_dtype(DType::U16)?; - - let q_perm = self.q_invperm.argsort()?.to_dtype(DType::U16)?; - let q_group_map = make_group_map(&q_groups, q_weight.dim(0)?)?; - - self.q_matrix = unsafe { - exl2_create_q_matrix( - device_id, - - self.q_perm.dims(0)? as i32, - self.q_weight.dim(1)? as i32, - self.q_scale.dim(0)? as i32, - - self.q_weight.as_ptr()?, - self.q_perm.as_ptr()?, - self.q_invperm.as_ptr()?, - self.q_scale.as_ptr()?, - self.q_scale_max.as_ptr()?, - self.q_groups.as_ptr()?, - self.q_group_map.as_ptr()?, - ) - }; - self.exllama_state = 1; - Ok(()) - } - fn exl2_gemm(&self, a: Tensor) -> Result { - let dev = get_cuda_device(&a)?; - let qm_width = self.q_weight.dims()[1]?; + let qm_width = self.q_weight.dim(1)?; let c_shape = Shape::from_dims(&[a.dims()[0], qm_width]); let (m, n, k) = ( @@ -129,22 +106,26 @@ impl Exl2Layer { // Create temp_dq as a Tensor, using a zero-sized tensor when not needed // (TODO: review if this is the best solution here) let temp_dq = if c_shape.dims()[0] > MAX_Q_GEMM_ROWS as usize { - Tensor::zeros(&[a.dims()[1], qm_width], DType::F16, &dev)? + Tensor::zeros(&[a.dims()[1], qm_width], DType::F16, a.device())? } else { - Tensor::zeros(&[0, 0], DType::F16, &dev)? + Tensor::zeros(&[0, 0], DType::F16, a.device())? 
}; - let a_ptr = get_cuda_slice::(a)?; - let temp_dq_ptr = temp_dq.device_ptr() as *const f16; + + let temp_dq_ptr = get_cuda_slice::(&temp_dq)?; if m > MAX_Q_GEMM_ROWS { // Reconstruct FP16 matrix, then cuBLAS unsafe { - super::ffi::exl2_reconstruct_q_matrix(self.q_matrix); + exl2_reconstruct_q_matrix(self.q_matrix); } let alpha = f16::from_f32(1.0); - let beta = if clear { f16::from_f32(0.0) } else { f16::from_f32(1.0) }; + let beta = f16::from_f32(0.0); + let cublas_handle = match a.device() { + Device::Cuda(dev) => dev.cublas_handle(), + _ => unreachable!(), // invariant enforced earlier + }; unsafe { hgemm( @@ -165,49 +146,130 @@ impl Exl2Layer { ) .w()? }; - - - } else { - // Quantized matmul + todo!() } - + todo!() } } -impl Drop for Exl2Layer { - fn drop(&mut self) { - if !self.q_matrix.is_null() { - unsafe { - exl2_destroy_q_matrix(self.q_matrix); - } - } - } -} + fn make_group_map(q_groups: &Tensor, num_qrows: usize) -> Result { - let gr = q_groups.to_vec1::()?; + let gr = q_groups.to_vec1::()?; let mut group_map = Vec::new(); let num_groups = gr.len() / 2; - let mut row = 0; for i in 0..num_groups { let bits = gr[i * 2] as usize; - let rows = if i < num_groups - 1 { - let qrows = gr[i * 2 + 3] as usize - gr[i * 2 + 1] as usize; - qrows * 32 / bits + let qrows = if i < num_groups - 1 { + gr[i * 2 + 3] as usize - gr[i * 2 + 1] as usize } else { num_qrows - gr[i * 2 + 1] as usize }; - - for _ in 0..rows { - group_map.push(i as u16); - group_map.push(rows as u16); + let rows = qrows * 32 / bits; + for j in 0..rows { + group_map.push(i as i16); + group_map.push((rows - j) as i16); } - row += rows; } - Tensor::from_vec(group_map, (group_map.len(),), q_groups.device()) + Tensor::from_vec(group_map.clone(), (group_map.len(),), q_groups.device()) } + +impl QuantMethod for Exl2Layer { + fn new(method: QuantMethodConfig) -> Result { + match method { + QuantMethodConfig::Exl2 { + q_weight, + q_scale, + q_scale_max, + q_groups, + q_perm, + q_invperm, + q_group_map, + 
bias, + bits, + } => { + + Ok(Self { + q_weight, + q_scale, + q_scale_max, + q_groups, + q_perm, + q_invperm, + q_group_map, + bias, + bits, + exllama_state: 0, + q_matrix: std::ptr::null_mut(), + }) + } + QuantMethodConfig::Gptq { .. } + | QuantMethodConfig::Gguf { .. } + | QuantMethodConfig::Unquantized(_) + | QuantMethodConfig::Hqq { .. } => { + unreachable!() + } + } + } + + fn forward(&self, x: &Tensor) -> Result { + let out_shape = Shape::from_dims( + &[ + &x.dims()[..x.dims().len() - 1], + &[self.q_weight.dim(candle_core::D::Minus1)?], + ] + .concat(), + ); + let reshaped_x = x.reshape(((), x.dim(candle_core::D::Minus1)?))?; + let mut output = self.exl2_gemm(reshaped_x)?; + if let Some(bias) = &self.bias { + output = output.broadcast_add(bias)?; + } + output.reshape(out_shape) + } + + fn quantized_act_type(&self) -> Option { + Some(DType::F16) + } + + fn add_delta_w(&self, _delta: &Tensor) -> Result> { + candle_core::bail!("EXL2 quantization does not support adding weight delta.") + } + + fn dtype_and_device(&self) -> (DType, Device) { + todo!() + } + + fn get_bias_mut(&mut self) -> Option<&mut Tensor> { + None + } + + fn apply_isq( + self: Arc, + _dtype: Option, + _device: Device, + _n_quantized: &AtomicUsize, + ) -> Result> { + candle_core::bail!("EXL2 quantization does not support ISQ.") + } + + fn get_max_isq_cpu_threads(&self, _dtype: IsqType) -> Option { + None + } +} + + +impl Drop for Exl2Layer { + fn drop(&mut self) { + if !self.q_matrix.is_null() { + unsafe { + exl2_destroy_q_matrix(self.q_matrix); + } + } + } +} \ No newline at end of file diff --git a/mistralrs-quant/src/exl2/ffi.rs b/mistralrs-quant/src/exl2/ffi.rs index 5afbd7306..76e852c94 100644 --- a/mistralrs-quant/src/exl2/ffi.rs +++ b/mistralrs-quant/src/exl2/ffi.rs @@ -24,7 +24,7 @@ extern "C" { pub fn exl2_reconstruct_q_matrix(q_matrix: QMatrixPtr); - pub fn exl2_gemm( + pub fn exl2_gemm_cuda( a: *const f16, b: *const c_void, c: *mut f16, diff --git a/mistralrs-quant/src/gguf/mod.rs 
b/mistralrs-quant/src/gguf/mod.rs index 6cc411095..9f05087ab 100644 --- a/mistralrs-quant/src/gguf/mod.rs +++ b/mistralrs-quant/src/gguf/mod.rs @@ -27,7 +27,8 @@ impl QuantMethod for GgufMatMul { w: QMatMul::from_arc(q_weight)?, b, }), - QuantMethodConfig::Gptq { .. } + QuantMethodConfig::Exl2 { .. } + |QuantMethodConfig::Gptq { .. } | QuantMethodConfig::Unquantized(_) | QuantMethodConfig::Hqq { .. } => unreachable!(), } diff --git a/mistralrs-quant/src/gptq/gptq_cpu.rs b/mistralrs-quant/src/gptq/gptq_cpu.rs index cd2ff61a8..f2b16765e 100644 --- a/mistralrs-quant/src/gptq/gptq_cpu.rs +++ b/mistralrs-quant/src/gptq/gptq_cpu.rs @@ -23,7 +23,8 @@ impl QuantMethod for GptqLayer { g_idx: _, bias: _, } => candle_core::bail!("GPTQ is only supported on CUDA."), - QuantMethodConfig::Gguf { .. } + QuantMethodConfig::Exl2 { .. } + | QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Unquantized(_) | QuantMethodConfig::Hqq { .. } => { unreachable!() diff --git a/mistralrs-quant/src/gptq/gptq_cuda.rs b/mistralrs-quant/src/gptq/gptq_cuda.rs index c9a06a360..45417c557 100644 --- a/mistralrs-quant/src/gptq/gptq_cuda.rs +++ b/mistralrs-quant/src/gptq/gptq_cuda.rs @@ -55,7 +55,7 @@ impl GptqLayer { "Expected `a` to be contiguous, got strides {:?}", a.layout().stride() ) - } + } let a_ptr = get_cuda_slice::(&a)?; let b_q_weight = get_cuda_slice::(&self.q_weight)? as *const u32; let b_gptq_qzeros = get_cuda_slice::(&self.gptq_qzeros)? as *const u32; @@ -236,8 +236,8 @@ impl QuantMethod for GptqLayer { use_exllama, bias, }) - } - QuantMethodConfig::Gguf { .. } + } QuantMethodConfig::Exl2 { .. } + | QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Unquantized(_) | QuantMethodConfig::Hqq { .. 
} => { unreachable!() diff --git a/mistralrs-quant/src/hqq/mod.rs b/mistralrs-quant/src/hqq/mod.rs index 72220bddc..1a8bec23a 100644 --- a/mistralrs-quant/src/hqq/mod.rs +++ b/mistralrs-quant/src/hqq/mod.rs @@ -494,7 +494,8 @@ impl QuantMethod for HqqLayer { Self: Sized, { match method { - QuantMethodConfig::Gguf { .. } + QuantMethodConfig::Exl2 { .. } + | QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Unquantized(_) | QuantMethodConfig::Gptq { .. } => { unreachable!() diff --git a/mistralrs-quant/src/lib.rs b/mistralrs-quant/src/lib.rs index 24d01e85b..fb8b8adc9 100644 --- a/mistralrs-quant/src/lib.rs +++ b/mistralrs-quant/src/lib.rs @@ -56,6 +56,7 @@ pub enum QuantMethodConfig { q_groups: Tensor, q_perm: Tensor, q_invperm: Tensor, + q_group_map: Tensor, bias: Option, }, Gptq { diff --git a/mistralrs-quant/src/unquantized/mod.rs b/mistralrs-quant/src/unquantized/mod.rs index 980918479..2c58b4536 100644 --- a/mistralrs-quant/src/unquantized/mod.rs +++ b/mistralrs-quant/src/unquantized/mod.rs @@ -21,7 +21,8 @@ impl QuantMethod for UnquantLinear { Self: Sized, { match method { - QuantMethodConfig::Gguf { .. } + QuantMethodConfig::Exl2 { .. } + | QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Gptq { .. } | QuantMethodConfig::Hqq { .. 
} => unreachable!(), QuantMethodConfig::Unquantized(l) => Ok(Self(l)), diff --git a/mistralrs-quant/src/utils/ops.rs b/mistralrs-quant/src/utils/ops.rs index 7021f6e93..8a80df3d5 100644 --- a/mistralrs-quant/src/utils/ops.rs +++ b/mistralrs-quant/src/utils/ops.rs @@ -65,6 +65,7 @@ impl CustomOp2 for BitWiseOr { let result = CpuStorage::I32(result); Ok((result, l1.shape().clone())) } + CpuStorage::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "bitwise-or")), CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise-or")), CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise-or")), CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise-or")), @@ -125,6 +126,9 @@ impl CustomOp2 for BitWiseOr { let elem_count = l1.shape().elem_count(); (d_in1_ptr, d_in2_ptr, elem_count) } + DType::I16 => { + return Err(Error::UnsupportedDTypeForOp(DType::I16, "bitwise-or")); + } DType::BF16 => { return Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise-or")); } @@ -217,6 +221,7 @@ impl CustomOp1 for Leftshift { let result = CpuStorage::I32(result); Ok((result, l1.shape().clone())) } + CpuStorage::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "leftshifr")), CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshifr")), CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshifr")), CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshifr")), @@ -249,6 +254,9 @@ impl CustomOp1 for Leftshift { let elem_count = l1.shape().elem_count(); (d_in1_ptr, elem_count) } + DType::I16 => { + return Err(Error::UnsupportedDTypeForOp(DType::I16, "leftshift")); + } DType::BF16 => { return Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshift")); } From 40ed615dda344aa2ebc19ad39145ecdeba724b9f Mon Sep 17 00:00:00 2001 From: ro99 Date: Thu, 12 Sep 2024 18:36:52 -0300 Subject: [PATCH 05/15] some fixes --- mistralrs-core/src/cuda/ffi.rs | 28 ++++ 
mistralrs-core/src/ops.rs | 18 +++ mistralrs-quant/src/exl2/exl2_cuda.rs | 182 +++++++++++++++++--------- mistralrs-quant/src/exl2/mod.rs | 7 +- 4 files changed, 170 insertions(+), 65 deletions(-) diff --git a/mistralrs-core/src/cuda/ffi.rs b/mistralrs-core/src/cuda/ffi.rs index 1a0ad9bca..21a26f20e 100644 --- a/mistralrs-core/src/cuda/ffi.rs +++ b/mistralrs-core/src/cuda/ffi.rs @@ -10,6 +10,7 @@ extern "C" { pub(crate) fn count_nonzero_u32(d_in: *const c_void, N: u32) -> u32; pub(crate) fn count_nonzero_i64(d_in: *const c_void, N: u32) -> u32; pub(crate) fn count_nonzero_i32(d_in: *const c_void, N: u32) -> u32; + pub(crate) fn count_nonzero_i16(d_in: *const c_void, N: u32) -> u32; pub(crate) fn nonzero_bf16( d_in: *const c_void, N: u32, @@ -74,6 +75,14 @@ extern "C" { num_dims: u32, d_out: *mut c_void, ); + pub(crate) fn nonzero_i16( + d_in: *const c_void, + N: u32, + num_nonzero: u32, + dims: *const c_void, + num_dims: u32, + d_out: *mut c_void, + ); pub(crate) fn bitwise_and_u8( d_in1: *const c_void, @@ -99,6 +108,12 @@ extern "C" { d_out: *mut c_void, N: u32, ); + pub(crate) fn bitwise_and_i16( + d_in1: *const c_void, + d_in2: *const c_void, + d_out: *mut c_void, + N: u32, + ); pub(crate) fn bitwise_or_u8( d_in1: *const c_void, d_in2: *const c_void, @@ -123,6 +138,12 @@ extern "C" { d_out: *mut c_void, N: u32, ); + pub(crate) fn bitwise_or_i16( + d_in1: *const c_void, + d_in2: *const c_void, + d_out: *mut c_void, + N: u32, + ); pub(crate) fn bitwise_xor_u8( d_in1: *const c_void, d_in2: *const c_void, @@ -147,9 +168,16 @@ extern "C" { d_out: *mut c_void, N: u32, ); + pub(crate) fn bitwise_xor_i16( + d_in1: *const c_void, + d_in2: *const c_void, + d_out: *mut c_void, + N: u32, + ); // Linked to in mistralrs-quant pub(crate) fn leftshift_u8(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: i32); pub(crate) fn leftshift_u32(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: i32); pub(crate) fn leftshift_i64(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: 
i32); pub(crate) fn leftshift_i32(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: i32); + pub(crate) fn leftshift_i16(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: i32); } diff --git a/mistralrs-core/src/ops.rs b/mistralrs-core/src/ops.rs index 3d4633265..c477fb3f2 100644 --- a/mistralrs-core/src/ops.rs +++ b/mistralrs-core/src/ops.rs @@ -113,6 +113,12 @@ impl CustomOp2 for BitWise { let result = CpuStorage::I32(result); Ok((result, l1.shape().clone())) } + CpuStorage::I16(vs1) => { + let vs2 = s2.as_slice::().unwrap(); + let result = self.bitwise(vs1, vs2); + let result = CpuStorage::I16(result); + Ok((result, l1.shape().clone())) + } CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise")), CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise")), CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise")), @@ -167,6 +173,12 @@ impl CustomOp2 for BitWise { let elem_count = l1.shape().elem_count(); (d_in1_ptr, d_in2_ptr, elem_count) } + DType::I16 => { + let d_in1_ptr = *s1.as_cuda_slice::()?.device_ptr() as *const c_void; + let d_in2_ptr = *s2.as_cuda_slice::()?.device_ptr() as *const c_void; + let elem_count = l1.shape().elem_count(); + (d_in1_ptr, d_in2_ptr, elem_count) + } DType::BF16 => { return Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise")); } @@ -380,6 +392,7 @@ fn count_nonzero_cuda(dtype: candle_core::DType, d_in: *const c_void, n: u32) -> candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n), candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n), candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n), + candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n), candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n), candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n), candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n), @@ -410,6 +423,9 @@ fn nonzero_cuda( candle_core::DType::I32 => { ffi::nonzero_i64(d_in, n, num_nonzero, dims, 
num_dims, d_out) } + candle_core::DType::I16 => { + ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out) + } candle_core::DType::BF16 => { ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out) } @@ -438,6 +454,7 @@ impl CustomOp1 for NonZero { let result = match storage { candle_core::CpuStorage::U8(vs) => self.nonzero(vs, layout), candle_core::CpuStorage::U32(vs) => self.nonzero(vs, layout), + candle_core::CpuStorage::I16(vs) => self.nonzero(vs, layout), candle_core::CpuStorage::I32(vs) => self.nonzero(vs, layout), candle_core::CpuStorage::I64(vs) => self.nonzero(vs, layout), candle_core::CpuStorage::BF16(vs) => self.nonzero(vs, layout), @@ -464,6 +481,7 @@ impl CustomOp1 for NonZero { let d_in = match storage.dtype() { candle_core::DType::U8 => *storage.as_cuda_slice::()?.device_ptr(), candle_core::DType::U32 => *storage.as_cuda_slice::()?.device_ptr(), + candle_core::DType::I16 => *storage.as_cuda_slice::()?.device_ptr(), candle_core::DType::I32 => *storage.as_cuda_slice::()?.device_ptr(), candle_core::DType::I64 => *storage.as_cuda_slice::()?.device_ptr(), candle_core::DType::BF16 => *storage.as_cuda_slice::()?.device_ptr(), diff --git a/mistralrs-quant/src/exl2/exl2_cuda.rs b/mistralrs-quant/src/exl2/exl2_cuda.rs index 744ce62d0..366601889 100644 --- a/mistralrs-quant/src/exl2/exl2_cuda.rs +++ b/mistralrs-quant/src/exl2/exl2_cuda.rs @@ -1,5 +1,4 @@ use std::{ - collections::HashMap, num::NonZeroUsize, sync::{atomic::AtomicUsize, Arc, Mutex}, }; @@ -8,11 +7,11 @@ use candle_core::{ cuda::{ cudarc::{ cublas::{result::hgemm, sys::cublasOperation_t}, - driver::{CudaSlice, DevicePtr}, + driver::DevicePtr, }, - CudaStorageSlice, WrapErr, + WrapErr, }, - from_storage_no_op, CudaStorage, DType, Device, Result, Shape, Storage, Tensor, D, + DType, Device, Result, Shape, Tensor, D, }; use half::f16; @@ -35,62 +34,114 @@ const BLOCK_M_SIZE_MAX: i32 = 8; pub struct Exl2Layer { q_weight: Tensor, q_scale: Tensor, - q_scale_max: Tensor, q_groups: Tensor, - 
q_perm: Tensor, q_invperm: Tensor, - q_group_map: Tensor, bias: Option, bits: i32, - exllama_state: i32, + exllama_state: Arc>, +} + +#[derive(Debug)] +struct ExllamaState { + initialized: bool, + q_scale_max: Tensor, + q_perm: Tensor, + q_invperm_short: Tensor, + q_group_map: Tensor, q_matrix: *mut std::ffi::c_void, } +unsafe impl Send for ExllamaState {} +unsafe impl Sync for ExllamaState {} + impl Exl2Layer { - fn exl2_gemm(&self, a: Tensor) -> Result { - let dev = get_cuda_device(&a)?; - let a_ptr = get_cuda_slice::(&a)?; + fn new( + q_weight: Tensor, + q_scale: Tensor, + q_scale_max: Tensor, + q_groups: Tensor, + q_perm: Tensor, + q_group_map: Tensor, + q_invperm: Tensor, + bias: Option, + bits: i32, + ) -> Result { + let exllama_state = Arc::new(Mutex::new(ExllamaState { + initialized: false, + q_scale_max, + q_perm, + q_group_map, + q_invperm_short: Tensor::zeros(q_invperm.shape(), DType::I16, q_invperm.device())?, + q_matrix: std::ptr::null_mut(), + })); + + Ok(Self { + q_weight, + q_scale, + q_groups, + q_invperm, + bias, + bits, + exllama_state, + }) + } - if self.exllama_state == 0 { - - self.q_scale_max = (self.q_scale_max / 256.0)?; - self.q_invperm = self.q_invperm.to_dtype(DType::I16)?; - self.q_perm = self.q_invperm.arg_sort_last_dim(false)?.to_dtype(DType::I16)?; - self.q_group_map = make_group_map(&self.q_groups, self.q_weight.dim(0)?)?; - - // QMatrix entries - let dev_ord = dev.ordinal() as i32; - let b_width = self.q_weight.dims()[1] as i32; - let b_height = self.q_perm.dims()[0] as i32; - let b_groups = self.q_scale.dims()[0] as i32; - let b_q_weight = get_cuda_slice::(&self.q_weight)? as *const u32; - let b_q_perm = get_cuda_slice::(&self.q_perm)? as *const u16; - let b_q_invperm = get_cuda_slice::(&self.q_invperm)? as *const u16; - let b_q_scale = get_cuda_slice::(&self.q_scale)? as *const u32; - let b_q_scale_max = get_cuda_slice::(&self.q_scale_max)?; - let b_q_groups = get_cuda_slice::(&self.q_groups)? 
as *const u16; - let b_q_group_map = get_cuda_slice::(&self.q_group_map)? as *const u16; - - self.q_matrix = unsafe { - exl2_create_q_matrix( - dev_ord, - b_height, - b_width, - b_groups, - b_q_weight, - b_q_perm, - b_q_invperm, - b_q_scale, - b_q_scale_max, - b_q_groups, - b_q_group_map, - ) - }; - self.exllama_state = 1; + pub fn post_init(&self) -> Result<()> { + self.initialize_exllama() + } + + fn initialize_exllama(&self) -> Result<()> { + let mut state = self.exllama_state.lock().unwrap(); + if state.initialized { + return Ok(()); } + let dev = get_cuda_device(&self.q_weight)?; + + state.q_scale_max = (state.q_scale_max.clone() / 256.0)?; + state.q_invperm_short = self.q_invperm.to_dtype(DType::I16)?; + state.q_perm = state.q_invperm_short.arg_sort_last_dim(false)?.to_dtype(DType::I16)?; + state.q_group_map = make_group_map(&self.q_groups, self.q_weight.dim(0)?)?; + + let dev_ord = dev.ordinal() as i32; + let b_width = self.q_weight.dims()[1] as i32; + let b_height = state.q_perm.dims()[0] as i32; + let b_groups = self.q_scale.dims()[0] as i32; + let b_q_weight = get_cuda_slice::(&self.q_weight)? as *const u32; + let b_q_perm = get_cuda_slice::(&state.q_perm)? as *const u16; + let b_q_invperm = get_cuda_slice::(&self.q_invperm)? as *const u16; + let b_q_scale = get_cuda_slice::(&self.q_scale)? as *const u32; + let b_q_scale_max = get_cuda_slice::(&state.q_scale_max)?; + let b_q_groups = get_cuda_slice::(&self.q_groups)? as *const u16; + let b_q_group_map = get_cuda_slice::(&state.q_group_map)? 
as *const u16; + + state.q_matrix = unsafe { + exl2_create_q_matrix( + dev_ord, + b_height, + b_width, + b_groups, + b_q_weight, + b_q_perm, + b_q_invperm, + b_q_scale, + b_q_scale_max, + b_q_groups, + b_q_group_map, + ) + }; + + state.initialized = true; + Ok(()) + } + fn exl2_gemm(&self, a: Tensor) -> Result { + self.initialize_exllama()?; + + let dev = get_cuda_device(&a)?; + let a_ptr = get_cuda_slice::(&a)?; + let qm_width = self.q_weight.dim(1)?; let c_shape = Shape::from_dims(&[a.dims()[0], qm_width]); @@ -103,21 +154,17 @@ impl Exl2Layer { let c = unsafe { dev.alloc::(c_shape.elem_count()).w()? }; let c_ptr = *c.device_ptr() as *mut f16; - // Create temp_dq as a Tensor, using a zero-sized tensor when not needed - // (TODO: review if this is the best solution here) - let temp_dq = if c_shape.dims()[0] > MAX_Q_GEMM_ROWS as usize { - Tensor::zeros(&[a.dims()[1], qm_width], DType::F16, a.device())? + let temp_dq = if m > MAX_Q_GEMM_ROWS { + Tensor::zeros(&[k as usize, n as usize], DType::F16, a.device())? } else { Tensor::zeros(&[0, 0], DType::F16, a.device())? }; - - let temp_dq_ptr = get_cuda_slice::(&temp_dq)?; if m > MAX_Q_GEMM_ROWS { // Reconstruct FP16 matrix, then cuBLAS unsafe { - exl2_reconstruct_q_matrix(self.q_matrix); + exl2_reconstruct_q_matrix(self.exllama_state.lock().unwrap().q_matrix); } let alpha = f16::from_f32(1.0); @@ -193,19 +240,23 @@ impl QuantMethod for Exl2Layer { bias, bits, } => { + let exllama_state = Arc::new(Mutex::new(ExllamaState { + initialized: false, + q_scale_max, + q_perm, + q_group_map, + q_invperm_short: Tensor::zeros(q_invperm.shape(), DType::I16, q_invperm.device())?, + q_matrix: std::ptr::null_mut(), + })); Ok(Self { q_weight, q_scale, - q_scale_max, q_groups, - q_perm, q_invperm, - q_group_map, bias, bits, - exllama_state: 0, - q_matrix: std::ptr::null_mut(), + exllama_state, }) } QuantMethodConfig::Gptq { .. 
} @@ -221,11 +272,11 @@ impl QuantMethod for Exl2Layer { let out_shape = Shape::from_dims( &[ &x.dims()[..x.dims().len() - 1], - &[self.q_weight.dim(candle_core::D::Minus1)?], + &[self.q_weight.dim(D::Minus1)?], ] .concat(), ); - let reshaped_x = x.reshape(((), x.dim(candle_core::D::Minus1)?))?; + let reshaped_x = x.reshape(((), x.dim(D::Minus1)?))?; let mut output = self.exl2_gemm(reshaped_x)?; if let Some(bias) = &self.bias { output = output.broadcast_add(bias)?; @@ -266,9 +317,12 @@ impl QuantMethod for Exl2Layer { impl Drop for Exl2Layer { fn drop(&mut self) { - if !self.q_matrix.is_null() { - unsafe { - exl2_destroy_q_matrix(self.q_matrix); + if let Ok(mut state) = self.exllama_state.lock() { + if !state.q_matrix.is_null() { + unsafe { + exl2_destroy_q_matrix(state.q_matrix); + } + state.q_matrix = std::ptr::null_mut(); } } } diff --git a/mistralrs-quant/src/exl2/mod.rs b/mistralrs-quant/src/exl2/mod.rs index c381dbcaa..9b19e2502 100644 --- a/mistralrs-quant/src/exl2/mod.rs +++ b/mistralrs-quant/src/exl2/mod.rs @@ -1,2 +1,7 @@ +#[cfg(feature = "cuda")] mod ffi; -mod exl2_cuda; \ No newline at end of file +#[cfg(feature = "cuda")] +mod exl2_cuda; + +#[cfg(feature = "cuda")] +pub use exl2_cuda::Exl2Layer; \ No newline at end of file From 02caefa89320209122c069c6885348df48715932 Mon Sep 17 00:00:00 2001 From: ro99 Date: Fri, 13 Sep 2024 12:11:30 -0300 Subject: [PATCH 06/15] Small fix on nonzero logic --- Cargo.lock | 47 ---------------------- Cargo.toml | 3 +- mistralrs-core/build.rs | 1 + mistralrs-core/src/cuda/nonzero_bitwise.cu | 1 + mistralrs-core/src/ops.rs | 2 +- mistralrs-paged-attn/build.rs | 1 + mistralrs-quant/build.rs | 1 + 7 files changed, 7 insertions(+), 49 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ebd034970..701109202 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -600,29 +600,6 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" -[[package]] -name = "cli-table" -version = "0.4.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b53f9241f288a7b12c56565f04aaeaeeab6b8923d42d99255d4ca428b4d97f89" -dependencies = [ - "cli-table-derive", - "csv", - "termcolor", - "unicode-width", -] - -[[package]] -name = "cli-table-derive" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e83a93253aaae7c74eb7428ce4faa6e219ba94886908048888701819f82fb94" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "color_quant" version = "1.1.0" @@ -2140,21 +2117,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "mistralrs-bench" -version = "0.3.0" -dependencies = [ - "anyhow", - "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", - "clap", - "cli-table", - "mistralrs-core", - "serde", - "serde_json", - "tokio", - "tracing", -] - [[package]] name = "mistralrs-core" version = "0.3.0" @@ -3855,15 +3817,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "termcolor" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" -dependencies = [ - "winapi-util", -] - [[package]] name = "thiserror" version = "1.0.63" diff --git a/Cargo.toml b/Cargo.toml index 45e000244..0c91f0828 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,12 +4,13 @@ members = [ "mistralrs-core", "mistralrs-pyo3", "mistralrs", - "mistralrs-bench", + #"mistralrs-bench", "mistralrs-vision", "mistralrs-quant", ] exclude = [ "mistralrs-paged_attn", + "mistralrs-bench", ] resolver = "2" diff --git a/mistralrs-core/build.rs b/mistralrs-core/build.rs index 1fae6e92a..5dc2e6ae8 100644 --- a/mistralrs-core/build.rs +++ b/mistralrs-core/build.rs @@ -28,6 +28,7 @@ fn main() { // https://github.com/EricLBuehler/mistral.rs/issues/286 if let 
Some(cuda_nvcc_flags_env) = CUDA_NVCC_FLAGS { builder = builder.arg("--compiler-options"); + //builder = builder.arg("-fPIC -fPIE"); builder = builder.arg(cuda_nvcc_flags_env); } diff --git a/mistralrs-core/src/cuda/nonzero_bitwise.cu b/mistralrs-core/src/cuda/nonzero_bitwise.cu index 5a012dfb7..bd10bea8a 100644 --- a/mistralrs-core/src/cuda/nonzero_bitwise.cu +++ b/mistralrs-core/src/cuda/nonzero_bitwise.cu @@ -57,6 +57,7 @@ COUNT_NONZERO_OP(uint8_t, u8) COUNT_NONZERO_OP(uint32_t, u32) COUNT_NONZERO_OP(int64_t, i64) COUNT_NONZERO_OP(int32_t, i32) +COUNT_NONZERO_OP(int16_t, i16) __global__ void transform_indices(const uint32_t *temp_indices, const uint32_t num_nonzero, diff --git a/mistralrs-core/src/ops.rs b/mistralrs-core/src/ops.rs index c477fb3f2..38688be47 100644 --- a/mistralrs-core/src/ops.rs +++ b/mistralrs-core/src/ops.rs @@ -424,7 +424,7 @@ fn nonzero_cuda( ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out) } candle_core::DType::I16 => { - ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out) + ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out) } candle_core::DType::BF16 => { ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out) diff --git a/mistralrs-paged-attn/build.rs b/mistralrs-paged-attn/build.rs index 1f640bcdd..02f839e58 100644 --- a/mistralrs-paged-attn/build.rs +++ b/mistralrs-paged-attn/build.rs @@ -36,6 +36,7 @@ pub use backend::{copy_blocks, paged_attention, reshape_and_cache, swap_blocks}; // https://github.com/EricLBuehler/mistral.rs/issues/286 if let Some(cuda_nvcc_flags_env) = CUDA_NVCC_FLAGS { builder = builder.arg("--compiler-options"); + //builder = builder.arg("-fPIC -fPIE"); builder = builder.arg(cuda_nvcc_flags_env); } println!("cargo:info={builder:?}"); diff --git a/mistralrs-quant/build.rs b/mistralrs-quant/build.rs index 107bf7bc0..8dac2a614 100644 --- a/mistralrs-quant/build.rs +++ b/mistralrs-quant/build.rs @@ -33,6 +33,7 @@ fn main() { // 
https://github.com/EricLBuehler/mistral.rs/issues/286 if let Some(cuda_nvcc_flags_env) = CUDA_NVCC_FLAGS { builder = builder.arg("--compiler-options"); + //builder = builder.arg("-fPIC -fPIE"); builder = builder.arg(cuda_nvcc_flags_env); } From 5dff2447f0669fdc06aa338fd0f43f32c65c967a Mon Sep 17 00:00:00 2001 From: ro99 Date: Sat, 14 Sep 2024 14:04:16 -0300 Subject: [PATCH 07/15] EXL2 loader --- .gitignore | 2 + .idea/mistral.rs.iml | 19 +++++ .idea/modules.xml | 8 ++ .idea/vcs.xml | 6 ++ mistralrs-core/src/cuda/ffi.rs | 4 +- mistralrs-core/src/exl2/mod.rs | 87 ++++++++++++++++++++++ mistralrs-core/src/lib.rs | 1 + mistralrs-core/src/model_loader.rs | 50 +++++++++++++ mistralrs-core/src/model_selected.rs | 76 +++++++++++++++++++ mistralrs-core/src/pipeline/exl2.rs | 86 +++++++++++++++++++++ mistralrs-core/src/pipeline/loaders/mod.rs | 2 + mistralrs-core/src/pipeline/mod.rs | 2 + mistralrs-quant/src/exl2/exl2_cuda.rs | 36 ++++----- mistralrs-quant/src/exl2/ffi.rs | 17 ++--- mistralrs-quant/src/exl2/mod.rs | 6 +- mistralrs-quant/src/gguf/mod.rs | 2 +- mistralrs-quant/src/gptq/gptq_cuda.rs | 5 +- mistralrs-quant/src/utils/ops.rs | 2 +- 18 files changed, 370 insertions(+), 41 deletions(-) create mode 100644 .idea/mistral.rs.iml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml create mode 100644 mistralrs-core/src/exl2/mod.rs create mode 100644 mistralrs-core/src/pipeline/exl2.rs diff --git a/.gitignore b/.gitignore index 83b45cf26..8b9bf1609 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ .vscode *.a .DS_Store +architecture.md +.gitignore diff --git a/.idea/mistral.rs.iml b/.idea/mistral.rs.iml new file mode 100644 index 000000000..29c033249 --- /dev/null +++ b/.idea/mistral.rs.iml @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 000000000..20815d2dd --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No 
newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 000000000..35eb1ddfb --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/mistralrs-core/src/cuda/ffi.rs b/mistralrs-core/src/cuda/ffi.rs index 21a26f20e..7167dc672 100644 --- a/mistralrs-core/src/cuda/ffi.rs +++ b/mistralrs-core/src/cuda/ffi.rs @@ -10,7 +10,7 @@ extern "C" { pub(crate) fn count_nonzero_u32(d_in: *const c_void, N: u32) -> u32; pub(crate) fn count_nonzero_i64(d_in: *const c_void, N: u32) -> u32; pub(crate) fn count_nonzero_i32(d_in: *const c_void, N: u32) -> u32; - pub(crate) fn count_nonzero_i16(d_in: *const c_void, N: u32) -> u32; + pub(crate) fn count_nonzero_i16(d_in: *const c_void, N: u32) -> u32; pub(crate) fn nonzero_bf16( d_in: *const c_void, N: u32, @@ -82,7 +82,7 @@ extern "C" { dims: *const c_void, num_dims: u32, d_out: *mut c_void, - ); + ); pub(crate) fn bitwise_and_u8( d_in1: *const c_void, diff --git a/mistralrs-core/src/exl2/mod.rs b/mistralrs-core/src/exl2/mod.rs new file mode 100644 index 000000000..f1b165ed5 --- /dev/null +++ b/mistralrs-core/src/exl2/mod.rs @@ -0,0 +1,87 @@ +use anyhow::{Context, Result}; +use std::{num::NonZeroUsize, str::FromStr}; +use strum::EnumString; + +use crate::{pipeline::QuantizationKind, Loader, ModelDType, ModelKind, Topology}; + +pub const EXL2_MULTI_FILE_DELIMITER: &str = " "; + +#[derive(Debug, EnumString, Clone, Copy)] +#[strum(serialize_all = "kebab-case")] +pub enum EXL2Architecture { + Llama, + Mpt, + Gptneox, + Gptj, + Gpt2, + Bloom, + Falcon, + Mamba, + Rwkv, + Phi2, + Phi3, + Starcoder2, +} + +// Wraps from_str() for some convenience: +// - Case-insensitive variant matching (TODO: is this desirable?) 
+// - Customized error until potential upstream support: https://github.com/Peternator7/strum/issues/332 +impl EXL2Architecture { + pub fn from_value + std::fmt::Display>(value: T) -> Result { + Self::from_str(&value.as_ref().to_ascii_lowercase()) + .with_context(|| format!("Unknown EXL2 architecture `{value}`")) + .map_err(anyhow::Error::msg) + } +} + +pub struct EXL2LoaderBuilder { + model_id: Option, + quantized_model_id: String, + quantized_filenames: Vec, + kind: ModelKind, + config: EXL2SpecificConfig, +} + +pub struct EXL2SpecificConfig { + pub topology: Option, + pub gpu_split: Option, + pub length: Option, + pub rope_scale: Option, + pub rope_alpha: Option, + pub no_flash_attn: bool, + pub no_xformers: bool, + pub no_sdpa: bool, + pub low_mem: bool, + pub experts_per_token: Option, + pub load_q4: bool, + pub fast_safetensors: bool, + pub ignore_compatibility: bool, + pub chunk_size: Option, +} + +impl EXL2LoaderBuilder { + pub fn new( + chat_template: Option, + tok_model_id: Option, + quantized_model_id: String, + quantized_filenames: Vec, + config: EXL2SpecificConfig, + ) -> Self { + let kind = ModelKind::Quantized { + quant: QuantizationKind::Exl2, + }; + + Self { + model_id: tok_model_id, + quantized_model_id, + quantized_filenames, + kind, + config, + } + } + + pub fn build(self) -> Result> { + // Implement the loading logic for EXL2 models here + todo!("Implement EXL2 model loading") + } +} diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index 3a2553a3c..eda167708 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -39,6 +39,7 @@ mod amoe; mod cublaslt; #[cfg(not(all(feature = "cuda", target_family = "unix")))] mod dummy_paged_attention; +mod exl2; mod gguf; pub mod layers; mod layers_masker; diff --git a/mistralrs-core/src/model_loader.rs b/mistralrs-core/src/model_loader.rs index e4bfcc025..8808d9eff 100644 --- a/mistralrs-core/src/model_loader.rs +++ b/mistralrs-core/src/model_loader.rs @@ -3,7 +3,9 @@ use 
std::{ num::NonZeroUsize, }; +use crate::exl2::{EXL2LoaderBuilder, EXL2SpecificConfig}; use crate::{ + exl2::EXL2_MULTI_FILE_DELIMITER, get_toml_selected_model_dtype, pipeline::{GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, NormalSpecificConfig}, GGUFSpecificConfig, Loader, ModelDType, ModelSelected, NormalLoaderBuilder, TomlLoaderArgs, @@ -56,6 +58,7 @@ pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option { match model { ModelSelected::Plain { .. } | ModelSelected::Lora { .. } + | ModelSelected::EXL2 { .. } | ModelSelected::GGUF { .. } | ModelSelected::LoraGGUF { .. } | ModelSelected::GGML { .. } @@ -84,6 +87,7 @@ pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result { | ModelSelected::XLora { dtype, .. } | ModelSelected::VisionPlain { dtype, .. } => Ok(*dtype), ModelSelected::GGUF { .. } + | ModelSelected::EXL2 { .. } | ModelSelected::LoraGGUF { .. } | ModelSelected::GGML { .. } | ModelSelected::LoraGGML { .. } @@ -210,6 +214,52 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result EXL2LoaderBuilder::new( + args.chat_template, + tok_model_id, + quantized_model_id, + quantized_filename + .split(EXL2_MULTI_FILE_DELIMITER) + .map(ToOwned::to_owned) + .collect::>(), + EXL2SpecificConfig { + topology: Topology::from_option_path(topology)?, + gpu_split, + length, + rope_scale, + rope_alpha, + no_flash_attn, + no_xformers, + no_sdpa, + low_mem, + experts_per_token, + load_q4, + fast_safetensors, + ignore_compatibility, + chunk_size, + }, + ) + .build()?, ModelSelected::XLoraGGUF { tok_model_id, quantized_model_id, diff --git a/mistralrs-core/src/model_selected.rs b/mistralrs-core/src/model_selected.rs index d3cf3f806..10a5b9c3a 100644 --- a/mistralrs-core/src/model_selected.rs +++ b/mistralrs-core/src/model_selected.rs @@ -337,4 +337,80 @@ pub enum ModelSelected { #[arg(long)] topology: Option, }, + + /// Select an EXL2 model + EXL2 { + /// `tok_model_id` is the local or remote model ID where you can find a 
`tokenizer_config.json` file. + /// If the `chat_template` is specified, then it will be treated as a path and used over remote files, + /// removing all remote accesses. + #[arg(short, long)] + tok_model_id: Option, + + /// Quantized model ID to find the `quantized_filename`. + /// This may be a HF hub repo or a local path. + #[arg(short = 'm', long)] + quantized_model_id: String, + + /// Quantized filename(s). + /// May be a single filename, or use a delimiter of " " (a single space) for multiple files. + #[arg(short = 'f', long)] + quantized_filename: String, + + /// Path to a topology YAML file. + #[arg(long)] + topology: Option, + + // Specific EXL2 args + /// "auto", or VRAM allocation per GPU in GB + #[arg(short, long)] + gpu_split: Option, + + /// Maximum sequence length + #[arg(short, long)] + length: Option, + + /// RoPE scaling factor + #[arg(short, long)] + rope_scale: Option, + + /// RoPE alpha value (NTK) + #[arg(short, long)] + rope_alpha: Option, + + /// Disable Flash Attention + #[arg(long, action)] + no_flash_attn: bool, + + /// Disable xformers, an alternative plan of flash attn for older devices + #[arg(long, action)] + no_xformers: bool, + + /// Disable Torch SDPA + #[arg(long, action)] + no_sdpa: bool, + + /// Enable VRAM optimizations, potentially trading off speed + #[arg(short, long, action)] + low_mem: bool, + + /// Override MoE model's default number of experts per token + #[arg(short, long)] + experts_per_token: Option, + + /// Load weights in Q4 mode + #[arg(long, action)] + load_q4: bool, + + /// Use alternative safetensors loader (with direct I/O when available) + #[arg(long, action)] + fast_safetensors: bool, + + /// Do not override model config options in case of compatibility issues + #[arg(short, long, action)] + ignore_compatibility: bool, + + /// Chunk size ('input length') + #[arg(long)] + chunk_size: Option, + }, } diff --git a/mistralrs-core/src/pipeline/exl2.rs b/mistralrs-core/src/pipeline/exl2.rs new file mode 100644 index 
000000000..2cef958e7 --- /dev/null +++ b/mistralrs-core/src/pipeline/exl2.rs @@ -0,0 +1,86 @@ +use crate::{ChatTemplate, Loader}; +use anyhow::Result; +use std::sync::Arc; +use tokenizers::Tokenizer; + +use crate::{ + models::quantized_llama::ModelWeights as QLlama, + models::quantized_phi2::ModelWeights as QPhi, + models::quantized_phi3::ModelWeights as QPhi3, + models::quantized_starcoder2::ModelWeights as QStarcoder2, + xlora_models::{XLoraQLlama, XLoraQPhi3}, +}; + +use super::GeneralMetadata; + +enum Model { + Llama(QLlama), + Phi2(QPhi), + XLoraLlama(XLoraQLlama), + XLoraPhi3(XLoraQPhi3), + Phi3(QPhi3), + Starcoder2(QStarcoder2), +} + +pub struct EXL2Pipeline { + model: Model, + tokenizer: Arc, + chat_template: Arc, + model_id: String, + metadata: Arc, +} + +pub struct EXL2Loader { + tok_model_id: Option, + quantized_model_id: String, + quantized_filename: String, + config: EXL2SpecificConfig, +} + +pub struct EXL2SpecificConfig { + pub gpu_split: Option, + pub length: Option, + pub rope_scale: Option, + pub rope_alpha: Option, + pub no_flash_attn: bool, + pub no_xformers: bool, + pub no_sdpa: bool, + pub low_mem: bool, + pub experts_per_token: Option, + pub load_q4: bool, + pub fast_safetensors: bool, + pub ignore_compatibility: bool, + pub chunk_size: Option, +} + +pub struct EXL2LoaderBuilder { + tok_model_id: Option, + quantized_model_id: String, + quantized_filename: String, + topology: Option, + config: EXL2SpecificConfig, +} + +impl EXL2LoaderBuilder { + pub fn new( + tok_model_id: Option, + quantized_model_id: String, + quantized_filename: String, + topology: Option, + config: EXL2SpecificConfig, + ) -> Self { + Self { + tok_model_id, + quantized_model_id, + quantized_filename, + topology, + config: EXL2SpecificConfig { ..config }, + } + } + + pub fn build(self) -> Result> { + // Implementation details for building the EXL2 loader would go here + // This is a placeholder and would need to be filled in with the actual implementation + todo!("Implement 
EXL2 loader building") + } +} diff --git a/mistralrs-core/src/pipeline/loaders/mod.rs b/mistralrs-core/src/pipeline/loaders/mod.rs index 11b69aaf5..3fec299bd 100644 --- a/mistralrs-core/src/pipeline/loaders/mod.rs +++ b/mistralrs-core/src/pipeline/loaders/mod.rs @@ -254,6 +254,8 @@ pub enum QuantizationKind { Gguf, /// GPTQ Gptq, + /// EXL2 + Exl2, } #[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)] diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index b4ee8c39c..764e7d00b 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -1,6 +1,7 @@ mod amoe; mod cache_manager; pub mod chat_template; +mod exl2; mod ggml; mod gguf; mod inputs_processor; @@ -13,6 +14,7 @@ mod processing; mod sampling; mod speculative; mod vision; + use crate::aici::toktree::TokTrie; use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult}; use crate::paged_attention::{CacheConfig, CacheEngine}; diff --git a/mistralrs-quant/src/exl2/exl2_cuda.rs b/mistralrs-quant/src/exl2/exl2_cuda.rs index 366601889..9bff23c41 100644 --- a/mistralrs-quant/src/exl2/exl2_cuda.rs +++ b/mistralrs-quant/src/exl2/exl2_cuda.rs @@ -20,16 +20,11 @@ use crate::{ IsqType, QuantMethod, QuantMethodConfig, }; -use super::ffi::{ - exl2_reconstruct_q_matrix, - exl2_create_q_matrix, - exl2_destroy_q_matrix -}; +use super::ffi::{exl2_create_q_matrix, exl2_destroy_q_matrix, exl2_reconstruct_q_matrix}; const MAX_Q_GEMM_ROWS: i32 = 32; const BLOCK_M_SIZE_MAX: i32 = 8; - #[derive(Debug)] pub struct Exl2Layer { q_weight: Tensor, @@ -98,9 +93,12 @@ impl Exl2Layer { let dev = get_cuda_device(&self.q_weight)?; - state.q_scale_max = (state.q_scale_max.clone() / 256.0)?; + state.q_scale_max = (state.q_scale_max.clone() / 256.0)?; state.q_invperm_short = self.q_invperm.to_dtype(DType::I16)?; - state.q_perm = state.q_invperm_short.arg_sort_last_dim(false)?.to_dtype(DType::I16)?; + state.q_perm = state + 
.q_invperm_short + .arg_sort_last_dim(false)? + .to_dtype(DType::I16)?; state.q_group_map = make_group_map(&self.q_groups, self.q_weight.dim(0)?)?; let dev_ord = dev.ordinal() as i32; @@ -135,14 +133,13 @@ impl Exl2Layer { Ok(()) } - fn exl2_gemm(&self, a: Tensor) -> Result { self.initialize_exllama()?; - + let dev = get_cuda_device(&a)?; let a_ptr = get_cuda_slice::(&a)?; - let qm_width = self.q_weight.dim(1)?; + let qm_width = self.q_weight.dim(1)?; let c_shape = Shape::from_dims(&[a.dims()[0], qm_width]); let (m, n, k) = ( @@ -166,7 +163,7 @@ impl Exl2Layer { unsafe { exl2_reconstruct_q_matrix(self.exllama_state.lock().unwrap().q_matrix); } - + let alpha = f16::from_f32(1.0); let beta = f16::from_f32(0.0); let cublas_handle = match a.device() { @@ -193,7 +190,6 @@ impl Exl2Layer { ) .w()? }; - } else { todo!() } @@ -201,8 +197,6 @@ impl Exl2Layer { } } - - fn make_group_map(q_groups: &Tensor, num_qrows: usize) -> Result { let gr = q_groups.to_vec1::()?; let mut group_map = Vec::new(); @@ -225,7 +219,6 @@ fn make_group_map(q_groups: &Tensor, num_qrows: usize) -> Result { Tensor::from_vec(group_map.clone(), (group_map.len(),), q_groups.device()) } - impl QuantMethod for Exl2Layer { fn new(method: QuantMethodConfig) -> Result { match method { @@ -245,7 +238,11 @@ impl QuantMethod for Exl2Layer { q_scale_max, q_perm, q_group_map, - q_invperm_short: Tensor::zeros(q_invperm.shape(), DType::I16, q_invperm.device())?, + q_invperm_short: Tensor::zeros( + q_invperm.shape(), + DType::I16, + q_invperm.device(), + )?, q_matrix: std::ptr::null_mut(), })); @@ -283,7 +280,7 @@ impl QuantMethod for Exl2Layer { } output.reshape(out_shape) } - + fn quantized_act_type(&self) -> Option { Some(DType::F16) } @@ -314,7 +311,6 @@ impl QuantMethod for Exl2Layer { } } - impl Drop for Exl2Layer { fn drop(&mut self) { if let Ok(mut state) = self.exllama_state.lock() { @@ -326,4 +322,4 @@ impl Drop for Exl2Layer { } } } -} \ No newline at end of file +} diff --git 
a/mistralrs-quant/src/exl2/ffi.rs b/mistralrs-quant/src/exl2/ffi.rs index 76e852c94..6ee7e59a6 100644 --- a/mistralrs-quant/src/exl2/ffi.rs +++ b/mistralrs-quant/src/exl2/ffi.rs @@ -8,9 +8,9 @@ type QMatrixPtr = *mut c_void; extern "C" { pub fn exl2_create_q_matrix( device: i32, - height: i32, // q_perm.size(0); - width: i32, // q_weight.size(1); - groups: i32, // q_scale.size(0); + height: i32, // q_perm.size(0); + width: i32, // q_weight.size(1); + groups: i32, // q_scale.size(0); q_weight: *const u32, q_perm: *const u16, q_invperm: *const u16, @@ -24,12 +24,5 @@ extern "C" { pub fn exl2_reconstruct_q_matrix(q_matrix: QMatrixPtr); - pub fn exl2_gemm_cuda( - a: *const f16, - b: *const c_void, - c: *mut f16, - m: i32, - n: i32, - k: i32, - ); -} \ No newline at end of file + pub fn exl2_gemm_cuda(a: *const f16, b: *const c_void, c: *mut f16, m: i32, n: i32, k: i32); +} diff --git a/mistralrs-quant/src/exl2/mod.rs b/mistralrs-quant/src/exl2/mod.rs index 9b19e2502..309bcebf2 100644 --- a/mistralrs-quant/src/exl2/mod.rs +++ b/mistralrs-quant/src/exl2/mod.rs @@ -1,7 +1,7 @@ #[cfg(feature = "cuda")] -mod ffi; -#[cfg(feature = "cuda")] mod exl2_cuda; +#[cfg(feature = "cuda")] +mod ffi; #[cfg(feature = "cuda")] -pub use exl2_cuda::Exl2Layer; \ No newline at end of file +pub use exl2_cuda::Exl2Layer; diff --git a/mistralrs-quant/src/gguf/mod.rs b/mistralrs-quant/src/gguf/mod.rs index 9f05087ab..de08b9677 100644 --- a/mistralrs-quant/src/gguf/mod.rs +++ b/mistralrs-quant/src/gguf/mod.rs @@ -28,7 +28,7 @@ impl QuantMethod for GgufMatMul { b, }), QuantMethodConfig::Exl2 { .. } - |QuantMethodConfig::Gptq { .. } + | QuantMethodConfig::Gptq { .. } | QuantMethodConfig::Unquantized(_) | QuantMethodConfig::Hqq { .. 
} => unreachable!(), } diff --git a/mistralrs-quant/src/gptq/gptq_cuda.rs b/mistralrs-quant/src/gptq/gptq_cuda.rs index 45417c557..03baa664c 100644 --- a/mistralrs-quant/src/gptq/gptq_cuda.rs +++ b/mistralrs-quant/src/gptq/gptq_cuda.rs @@ -55,7 +55,7 @@ impl GptqLayer { "Expected `a` to be contiguous, got strides {:?}", a.layout().stride() ) - } + } let a_ptr = get_cuda_slice::(&a)?; let b_q_weight = get_cuda_slice::(&self.q_weight)? as *const u32; let b_gptq_qzeros = get_cuda_slice::(&self.gptq_qzeros)? as *const u32; @@ -236,7 +236,8 @@ impl QuantMethod for GptqLayer { use_exllama, bias, }) - } QuantMethodConfig::Exl2 { .. } + } + QuantMethodConfig::Exl2 { .. } | QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Unquantized(_) | QuantMethodConfig::Hqq { .. } => { diff --git a/mistralrs-quant/src/utils/ops.rs b/mistralrs-quant/src/utils/ops.rs index 8a80df3d5..64aaf7a00 100644 --- a/mistralrs-quant/src/utils/ops.rs +++ b/mistralrs-quant/src/utils/ops.rs @@ -256,7 +256,7 @@ impl CustomOp1 for Leftshift { } DType::I16 => { return Err(Error::UnsupportedDTypeForOp(DType::I16, "leftshift")); - } + } DType::BF16 => { return Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshift")); } From aa2a308eaf3a181bb4e7d0311731c407240019d6 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 19 Sep 2024 09:55:46 -0400 Subject: [PATCH 08/15] Cleanup --- .idea/mistral.rs.iml | 19 -------- .idea/modules.xml | 8 --- .idea/vcs.xml | 6 --- Cargo.lock | 114 ++++++++++++++++++++++++------------------- 4 files changed, 65 insertions(+), 82 deletions(-) delete mode 100644 .idea/mistral.rs.iml delete mode 100644 .idea/modules.xml delete mode 100644 .idea/vcs.xml diff --git a/.idea/mistral.rs.iml b/.idea/mistral.rs.iml deleted file mode 100644 index 29c033249..000000000 --- a/.idea/mistral.rs.iml +++ /dev/null @@ -1,19 +0,0 @@ - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml deleted file mode 100644 index 
20815d2dd..000000000 --- a/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 35eb1ddfb..000000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 701109202..8b1c75404 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,9 +20,9 @@ checksum = "415ed64958754dbe991900f3940677e6a7eefb4d7367afd70d642677b0c7d19d" [[package]] name = "addr2line" -version = "0.22.0" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" +checksum = "f5fb1d8e4442bd405fdfd1dacb42792696b0cf9cb15882e5d097b742a676d375" dependencies = [ "gimli", ] @@ -138,9 +138,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.86" +version = "1.0.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +checksum = "4e1496f8fb1fbf272686b8d37f523dab3e4a7443300055e74cdaa449f3114356" dependencies = [ "backtrace", ] @@ -240,17 +240,17 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.73" +version = "0.3.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" dependencies = [ "addr2line", - "cc", "cfg-if", "libc", - "miniz_oxide 0.7.4", + "miniz_oxide 0.8.0", "object", "rustc-demangle", + "windows-targets 0.52.6", ] [[package]] @@ -491,9 +491,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.16" +version = "1.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9d013ecb737093c0e86b151a7b837993cf9ec6c502946cfb44bedc392421e0b" +checksum = 
"b62ac837cdb5cb22e10a256099b4fc502b1dfe560cb282963a974d7abd80e476" dependencies = [ "shlex", ] @@ -588,7 +588,7 @@ version = "4.5.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn 2.0.77", @@ -1022,11 +1022,11 @@ checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" [[package]] name = "enum-as-inner" -version = "0.6.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ffccbb6966c05b32ef8fbac435df276c4ae4d3dc55a8cd0eb9745e6c12f546a" +checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" dependencies = [ - "heck 0.4.1", + "heck", "proc-macro2", "quote", "syn 2.0.77", @@ -1426,14 +1426,14 @@ dependencies = [ [[package]] name = "getset" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e45727250e75cc04ff2846a66397da8ef2b3db8e40e0cef4df67950a07621eb9" +checksum = "f636605b743120a8d32ed92fc27b6cde1a769f8f936c065151eb66f88ded513c" dependencies = [ - "proc-macro-error", + "proc-macro-error2", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.77", ] [[package]] @@ -1448,9 +1448,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.29.0" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" +checksum = "32085ea23f3234fc7846555e85283ba4de91e21016dc0455a16286d87a292d64" [[package]] name = "glob" @@ -1501,12 +1501,6 @@ dependencies = [ "allocator-api2", ] -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" @@ -1639,9 +1633,9 @@ dependencies = [ 
[[package]] name = "hyper-util" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" +checksum = "da62f120a8a37763efb0cf8fdf264b884c7b8b9ac8660b900c8661030c00e6ba" dependencies = [ "bytes", "futures-channel", @@ -1796,9 +1790,9 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.9.0" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4" [[package]] name = "is_terminal_polyfill" @@ -1982,9 +1976,9 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "memmap2" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe751422e4a8caa417e13c3ea66452215d7d63e19e604f4980461212f3ae1322" +checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f" dependencies = [ "libc", "stable_deref_trait", @@ -2793,6 +2787,28 @@ dependencies = [ "version_check", ] +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn 2.0.77", +] + [[package]] name = "proc-macro2" version = "1.0.86" @@ -2883,7 +2899,7 @@ version = "0.22.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "pyo3-build-config", "quote", @@ -3069,9 +3085,9 @@ checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" [[package]] name = "redox_syscall" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" +checksum = "0884ad60e090bf1345b93da0a5de8923c93884cd03f40dfcfddd3b4bee661853" dependencies = [ "bitflags 2.6.0", ] @@ -3262,9 +3278,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.36" +version = "0.38.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f55e80d50763938498dd5ebb18647174e0c76dc38c5505294bb224624f30f36" +checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" dependencies = [ "bitflags 2.6.0", "errno", @@ -3275,9 +3291,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.12" +version = "0.23.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044" +checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8" dependencies = [ "log", "once_cell", @@ -3306,9 +3322,9 @@ checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" [[package]] name = "rustls-webpki" -version = "0.102.7" +version = "0.102.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84678086bd54edf2b415183ed7a94d0efb049f1b646a33e22a36f3794be6ae56" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" dependencies = [ "ring", "rustls-pki-types", @@ -3348,11 +3364,11 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.23" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -3676,7 +3692,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "rustversion", @@ -4184,9 +4200,9 @@ checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" [[package]] name = "unicode-ident" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" [[package]] name = "unicode-normalization" @@ -4208,9 +4224,9 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] name = "unicode-width" From 13ee35ef52d30b6bdbd470f570b8bd7bc24a9999 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 19 Sep 2024 10:06:53 -0400 Subject: [PATCH 09/15] Use my candle --- Cargo.lock | 98 +++++++++++++++++++++++++++--------------------------- Cargo.toml | 4 +-- 2 files changed, 51 insertions(+), 51 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8b1c75404..65848fccf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,9 +138,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.88" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e1496f8fb1fbf272686b8d37f523dab3e4a7443300055e74cdaa449f3114356" +checksum = 
"86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" dependencies = [ "backtrace", ] @@ -380,26 +380,21 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.7.1" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" +checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" [[package]] name = "candle-core" version = "0.6.1" -source = "git+https://github.com/ro99/candle.git?rev=2ecc6cc#2ecc6cc071b6b6c68062f327aa8343ad08dbde83" +source = "git+https://github.com/EricLBuehler/candle.git?rev=7f5a470#7f5a47040e798f0076014c9d9e82cc6cb25708a0" dependencies = [ - "accelerate-src", "byteorder", - "candle-kernels 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", - "candle-metal-kernels", + "candle-kernels 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=7f5a470)", "cudarc", "gemm", "half", - "intel-mkl-src", - "libc", "memmap2", - "metal", "num-traits", "num_cpus", "rand", @@ -414,14 +409,19 @@ dependencies = [ [[package]] name = "candle-core" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=7f5a470#7f5a47040e798f0076014c9d9e82cc6cb25708a0" +source = "git+https://github.com/EricLBuehler/candle.git?rev=9e31a19#9e31a192642b4048e3df75173efddebaf663fef2" dependencies = [ + "accelerate-src", "byteorder", - "candle-kernels 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=7f5a470)", + "candle-kernels 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "candle-metal-kernels", "cudarc", "gemm", "half", + "intel-mkl-src", + "libc", "memmap2", + "metal", "num-traits", "num_cpus", "rand", @@ -447,7 +447,7 @@ dependencies = [ [[package]] name = "candle-kernels" version = "0.6.1" -source = "git+https://github.com/ro99/candle.git?rev=2ecc6cc#2ecc6cc071b6b6c68062f327aa8343ad08dbde83" +source = 
"git+https://github.com/EricLBuehler/candle.git?rev=7f5a470#7f5a47040e798f0076014c9d9e82cc6cb25708a0" dependencies = [ "bindgen_cuda 0.1.5", ] @@ -455,7 +455,7 @@ dependencies = [ [[package]] name = "candle-kernels" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=7f5a470#7f5a47040e798f0076014c9d9e82cc6cb25708a0" +source = "git+https://github.com/EricLBuehler/candle.git?rev=9e31a19#9e31a192642b4048e3df75173efddebaf663fef2" dependencies = [ "bindgen_cuda 0.1.5", ] @@ -463,7 +463,7 @@ dependencies = [ [[package]] name = "candle-metal-kernels" version = "0.6.1" -source = "git+https://github.com/ro99/candle.git?rev=2ecc6cc#2ecc6cc071b6b6c68062f327aa8343ad08dbde83" +source = "git+https://github.com/EricLBuehler/candle.git?rev=9e31a19#9e31a192642b4048e3df75173efddebaf663fef2" dependencies = [ "metal", "once_cell", @@ -474,10 +474,10 @@ dependencies = [ [[package]] name = "candle-nn" version = "0.6.1" -source = "git+https://github.com/ro99/candle.git?rev=2ecc6cc#2ecc6cc071b6b6c68062f327aa8343ad08dbde83" +source = "git+https://github.com/EricLBuehler/candle.git?rev=9e31a19#9e31a192642b4048e3df75173efddebaf663fef2" dependencies = [ "accelerate-src", - "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", + "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", "candle-metal-kernels", "half", "intel-mkl-src", @@ -491,9 +491,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.18" +version = "1.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62ac837cdb5cb22e10a256099b4fc502b1dfe560cb282963a974d7abd80e476" +checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0" dependencies = [ "shlex", ] @@ -1653,9 +1653,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.60" +version = "0.1.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -2026,9 +2026,9 @@ dependencies = [ [[package]] name = "minijinja" -version = "2.2.0" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d7d3e3a3eece1fa4618237ad41e1de855ced47eab705cec1c9a920e1d1c5aad" +checksum = "1028b628753a7e1a88fc59c9ba4b02ecc3bc0bd3c7af23df667bc28df9b3310e" dependencies = [ "serde", "serde_json", @@ -2036,9 +2036,9 @@ dependencies = [ [[package]] name = "minijinja-contrib" -version = "2.2.0" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "744a2b84dbd22398e347594ed2aef9d3f1b948934e3e6e94ef69ecd39d597f4b" +checksum = "39ffd46ee854be23604a20efd6c9655374fefbe4d44b949dc0f907305d92873a" dependencies = [ "minijinja", "serde", @@ -2098,7 +2098,7 @@ name = "mistralrs" version = "0.3.0" dependencies = [ "anyhow", - "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", + "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", "either", "futures", "image", @@ -2124,7 +2124,7 @@ dependencies = [ "buildstructor", "bytemuck", "bytemuck_derive", - "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", + "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", "candle-flash-attn", "candle-nn", "cfgrammar", @@ -2185,7 +2185,7 @@ version = "0.3.0" dependencies = [ "anyhow", "bindgen_cuda 0.1.6", - "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", + "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", "half", ] @@ -2196,7 +2196,7 @@ dependencies = [ "accelerate-src", "anyhow", "base64 0.22.1", - "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", + "candle-core 0.6.1 
(git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", "data-url", "either", "futures", @@ -2218,7 +2218,7 @@ name = "mistralrs-quant" version = "0.3.0" dependencies = [ "bindgen_cuda 0.1.5", - "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", + "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", "candle-nn", "half", "lazy_static", @@ -2235,7 +2235,7 @@ dependencies = [ "accelerate-src", "anyhow", "axum", - "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", + "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", "clap", "ctrlc", "data-url", @@ -2261,7 +2261,7 @@ dependencies = [ name = "mistralrs-vision" version = "0.3.0" dependencies = [ - "candle-core 0.6.1 (git+https://github.com/ro99/candle.git?rev=2ecc6cc)", + "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", "image", ] @@ -2832,9 +2832,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433" +checksum = "15ee168e30649f7f234c3d49ef5a7a6cbf5134289bc46c29ff3155fa3221c225" dependencies = [ "anyhow", "cfg-if", @@ -2863,9 +2863,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8" +checksum = "e61cef80755fe9e46bb8a0b8f20752ca7676dcc07a5277d8b7768c6172e529b3" dependencies = [ "once_cell", "target-lexicon", @@ -2873,9 +2873,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6" +checksum = 
"67ce096073ec5405f5ee2b8b31f03a68e02aa10d5d4f565eca04acc41931fa1c" dependencies = [ "libc", "pyo3-build-config", @@ -2883,9 +2883,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206" +checksum = "2440c6d12bc8f3ae39f1e775266fa5122fd0c8891ce7520fa6048e683ad3de28" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -2895,9 +2895,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.22.2" +version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372" +checksum = "1be962f0e06da8f8465729ea2cb71a416d2257dff56cbe40a70d3e62a93ae5d1" dependencies = [ "heck", "proc-macro2", @@ -4017,9 +4017,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.20" +version = "0.22.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "583c44c02ad26b0c3f3066fe629275e50627026c51ac2e595cca4c230ce1ce1d" +checksum = "3b072cee73c449a636ffd6f32bd8de3a9f7119139aff882f44943ce2986dc5cf" dependencies = [ "indexmap", "serde", @@ -4206,9 +4206,9 @@ checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" [[package]] name = "unicode-normalization" -version = "0.1.23" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" dependencies = [ "tinyvec", ] @@ -4502,9 +4502,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.5" +version = "0.26.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bd24728e5af82c6c4ec1b66ac4844bdf8156257fccda846ec58b42cd0cdbe6a" +checksum = 
"841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958" dependencies = [ "rustls-pki-types", ] diff --git a/Cargo.toml b/Cargo.toml index 0c91f0828..ffc34bba0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,8 +26,8 @@ license = "MIT" [workspace.dependencies] anyhow = "1.0.80" -candle-core = { git = "https://github.com/ro99/candle.git", version = "0.6.0", rev = "2ecc6cc" } -candle-nn = { git = "https://github.com/ro99/candle.git", version = "0.6.0", rev = "2ecc6cc" } +candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9e31a19" } +candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9e31a19" } serde = "1.0.197" serde_json = "1.0.114" indexmap = { version = "2.2.5", features = ["serde"] } From e11eed63f8d03dfb7f5ce50dbdaada9b4bae61a2 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 19 Sep 2024 10:08:26 -0400 Subject: [PATCH 10/15] Complete merge --- Cargo.lock | 4878 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 4878 insertions(+) create mode 100644 Cargo.lock diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 000000000..20ad3e48e --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,4878 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 3 + +[[package]] +name = "Inflector" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe438c63458706e03479442743baae6c88256498e6431708f6dfc520a26515d3" +dependencies = [ + "lazy_static", + "regex", +] + +[[package]] +name = "accelerate-src" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "415ed64958754dbe991900f3940677e6a7eefb4d7367afd70d642677b0c7d19d" + +[[package]] +name = "addr2line" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5fb1d8e4442bd405fdfd1dacb42792696b0cf9cb15882e5d097b742a676d375" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "akin" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1763692fc1416554cf051efc56a3de5595eca47299d731cc5c2b583adf8b4d2f" + +[[package]] +name = "allocator-api2" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" 
+ +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "anstream" +version = "0.6.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" + +[[package]] +name = "anstyle-parse" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" +dependencies = [ + "anstyle", + "windows-sys 0.52.0", +] + +[[package]] +name = "anyhow" +version = "1.0.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" +dependencies = [ + "backtrace", +] + +[[package]] +name 
= "arbitrary" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110" +dependencies = [ + "derive_arbitrary", +] + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + +[[package]] +name = "async-trait" +version = "0.1.82" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a27b8a3a6e1a44fa4c8baf1f653e4172e81486d4941f2237e20dc2d0cf4ddff1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "autocfg" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + +[[package]] +name = "axum" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper 1.0.1", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + 
"http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper 0.1.2", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "backtrace" +version = "0.3.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +dependencies = [ + "addr2line", + "cfg-if", + "libc", + "miniz_oxide 0.8.0", + "object", + "rustc-demangle", + "windows-targets 0.52.6", +] + +[[package]] +name = "base16ct" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" + +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bindgen_cuda" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f8489af5b7d17a81bffe37e0f4d6e1e4de87c87329d05447f22c35d95a1227d" +dependencies = [ + "glob", + "num_cpus", + "rayon", +] + +[[package]] +name = "bindgen_cuda" +version = "0.1.6" +source = "git+https://github.com/guoqingbao/bindgen_cuda.git#fb7ed75f3901b146aa1ba460baaeed5b494f2e0d" +dependencies = [ + "glob", + "num_cpus", + "rayon", +] + +[[package]] +name = "bit_field" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc827186963e592360843fb5ba4b973e145841266c1357f7180c43526f2e5b61" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = 
"2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + +[[package]] +name = "block" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "buildstructor" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3907aac66c65520545ae3cb3c195306e20d5ed5c90bfbb992e061cf12a104d0" +dependencies = [ + "lazy_static", + "proc-macro2", + "quote", + "str_inflector", + "syn 2.0.77", + "thiserror", + "try_match", +] + +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + +[[package]] +name = "bytemuck" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "byteorder-lite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" + +[[package]] +name = "bytes" +version = "1.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" + +[[package]] +name = "candle-core" +version = "0.6.1" +source = "git+https://github.com/EricLBuehler/candle.git?rev=8a99f7c#8a99f7cf31a1d8f175281492eaa7026730067d08" +dependencies = [ + "byteorder", + "candle-kernels 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=8a99f7c)", + "cudarc", + "gemm", + "half", + "memmap2", + "num-traits", + "num_cpus", + "rand", + "rand_distr", + "rayon", + "safetensors", + "thiserror", + "yoke", + "zip", +] + +[[package]] +name = "candle-core" +version = "0.6.1" +source = "git+https://github.com/EricLBuehler/candle.git?rev=9e31a19#9e31a192642b4048e3df75173efddebaf663fef2" +dependencies = [ + "accelerate-src", + "byteorder", + "candle-kernels 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "candle-metal-kernels", + "cudarc", + "gemm", + "half", + "intel-mkl-src", + "libc", + "memmap2", + "metal", + "num-traits", + "num_cpus", + "rand", + "rand_distr", + "rayon", + "safetensors", + "thiserror", + "yoke", + "zip", +] + +[[package]] +name = "candle-flash-attn" +version = "0.6.1" +source = "git+https://github.com/EricLBuehler/candle.git?rev=8a99f7c#8a99f7cf31a1d8f175281492eaa7026730067d08" +dependencies = [ + "anyhow", + "bindgen_cuda 0.1.5", + "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=8a99f7c)", + "half", +] + +[[package]] +name = "candle-kernels" +version = "0.6.1" +source = "git+https://github.com/EricLBuehler/candle.git?rev=8a99f7c#8a99f7cf31a1d8f175281492eaa7026730067d08" +dependencies = [ + "bindgen_cuda 0.1.5", +] + +[[package]] +name = "candle-kernels" +version = "0.6.1" +source = "git+https://github.com/EricLBuehler/candle.git?rev=9e31a19#9e31a192642b4048e3df75173efddebaf663fef2" +dependencies = [ + 
"bindgen_cuda 0.1.5", +] + +[[package]] +name = "candle-metal-kernels" +version = "0.6.1" +source = "git+https://github.com/EricLBuehler/candle.git?rev=9e31a19#9e31a192642b4048e3df75173efddebaf663fef2" +dependencies = [ + "metal", + "once_cell", + "thiserror", + "tracing", +] + +[[package]] +name = "candle-nn" +version = "0.6.1" +source = "git+https://github.com/EricLBuehler/candle.git?rev=9e31a19#9e31a192642b4048e3df75173efddebaf663fef2" +dependencies = [ + "accelerate-src", + "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "candle-metal-kernels", + "half", + "intel-mkl-src", + "metal", + "num-traits", + "rayon", + "safetensors", + "serde", + "thiserror", +] + +[[package]] +name = "cc" +version = "1.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0" +dependencies = [ + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + +[[package]] +name = "cfgrammar" +version = "0.13.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6026d8cd82ada8bbcfe337805dd1eb6afdc9e80fa4d57e977b3a36315e0c5525" +dependencies = [ + "indexmap", + "lazy_static", + "num-traits", + "regex", + "serde", + "vob", +] + +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + "windows-targets 0.52.6", +] + +[[package]] +name = "chrono-tz" 
+version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93698b29de5e97ad0ae26447b344c482a7284c737d9ddc5f9e52b74a336671bb" +dependencies = [ + "chrono", + "chrono-tz-build", + "phf", +] + +[[package]] +name = "chrono-tz-build" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c088aee841df9c3041febbb73934cfc39708749bf96dc827e3359cd39ef11b1" +dependencies = [ + "parse-zoneinfo", + "phf", + "phf_codegen", +] + +[[package]] +name = "clap" +version = "4.5.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e5a21b8495e732f1b3c364c9949b201ca7bae518c502c80256c96ad79eaf6ac" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim 0.11.1", +] + +[[package]] +name = "clap_derive" +version = "4.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "clap_lex" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" + +[[package]] +name = "color_quant" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" + +[[package]] +name = "colorchoice" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" + +[[package]] +name = "console" +version = "0.15.8" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width", + "windows-sys 0.52.0", +] + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "core-graphics-types" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "libc", +] + +[[package]] +name = "cpufeatures" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" +dependencies = [ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name 
= "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + +[[package]] +name = "crossterm" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64e6c0fbe2c17357405f7c758c1ef960fce08bdfb2c03d88d2a18d7e09c4b67" +dependencies = [ + "bitflags 1.3.2", + "crossterm_winapi", + "libc", + "mio 0.8.11", + "parking_lot", + "signal-hook", + "signal-hook-mio", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "csv" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + +[[package]] +name = "ctrlc" +version = "3.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90eeab0aa92f3f9b4e87f258c72b139c207d251f9cbc1080a0086b86a8870dd3" +dependencies = [ + "nix", + "windows-sys 0.59.0", +] + +[[package]] +name = 
"cudarc" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38cd60a9a42ec83a2ed7effb0b1f073270264ea99da7acfc44f7e8d74dee0384" +dependencies = [ + "half", + "libloading", +] + +[[package]] +name = "darling" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbffa8f8e38810422f320ca457a93cf1cd0056dc9c06c556b867558e0d471463" +dependencies = [ + "darling_core 0.11.0", + "darling_macro 0.11.0", +] + +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +dependencies = [ + "darling_core 0.20.10", + "darling_macro 0.20.10", +] + +[[package]] +name = "darling_core" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06e172685d94b7b83800e3256a63261537b9d6129e10f21c8e13ddf9dba8c64d" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.10.0", + "syn 1.0.109", +] + +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.11.1", + "syn 2.0.77", +] + +[[package]] +name = "darling_macro" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0618ac802792cebd1918ac6042a6ea1eeab92db34b35656afaa577929820788" +dependencies = [ + "darling_core 0.11.0", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core 0.20.10", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "data-url" +version = "0.3.1" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c297a1c74b71ae29df00c3e22dd9534821d60eb9af5a0192823fa2acea70c2a" + +[[package]] +name = "defmac" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aafbece59594ed57696a1a69e8bb3ca1683fbc9cdb41d5c02726070b2cd8f19d" + +[[package]] +name = "derive-new" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "derive_arbitrary" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "derive_builder" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd33f37ee6a119146a1781d3356a7c26028f83d779b2e04ecd45fdc75c76877b" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7431fa049613920234f22c47fdc33e6cf3ee83067091ea4277a3f8c4587aae38" +dependencies = [ + "darling 0.20.10", + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4abae7035bf79b9877b779505d8cf3749285b80c43941eda66604841889451dc" +dependencies = [ + "derive_builder_core", + "syn 2.0.77", +] + +[[package]] +name = "derive_more" +version = "0.99.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "digest" 
+version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + +[[package]] +name = "directories" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a49173b84e034382284f27f1af4dcbbd231ffa358c0fe316541a7337f376a35" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.48.0", +] + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "dyn-clone" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" + +[[package]] +name = "dyn-stack" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e53799688f5632f364f8fb387488dd05db9fe45db7011be066fc20e7027f8b" +dependencies = [ + "bytemuck", + "reborrow", +] + +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +dependencies = [ + "serde", +] + +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + +[[package]] +name = "encoding_rs" +version = "0.8.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "endian-type" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" + +[[package]] +name = "enum-as-inner" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "esaxx-rs" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" +dependencies = [ + "cc", +] + +[[package]] +name = "exr" +version = "1.72.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "887d93f60543e9a9362ef8a21beedd0a833c5d9610e18c67abe15a5963dcb1a4" +dependencies = [ + "bit_field", + "flume", + "half", + "lebe", + "miniz_oxide 0.7.4", + "rayon-core", + "smallvec", + "zune-inflate", +] + +[[package]] +name = "eyre" +version = "0.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7cd915d99f24784cdc19fd37ef22b97e3ff0ae756c7e492e9fbfe897d61e2aec" +dependencies = [ + "indenter", + "once_cell", +] + +[[package]] +name = "fastrand" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" + +[[package]] +name = "fdeflate" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645" +dependencies = [ + "simd-adler32", +] + +[[package]] +name = "filetime" +version = "0.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35c0522e981e68cbfa8c3f978441a5f34b30b96e146b33cd3359176b50fe8586" +dependencies = [ + "cfg-if", + "libc", + "libredox", + "windows-sys 0.59.0", +] + +[[package]] +name = "flate2" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253" +dependencies = [ + "crc32fast", + "miniz_oxide 0.8.0", +] + +[[package]] +name = "flume" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" +dependencies = [ + "spin", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared 0.1.1", +] + +[[package]] +name = "foreign-types" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" +dependencies = [ + 
"foreign-types-macros", + "foreign-types-shared 0.3.1", +] + +[[package]] +name = "foreign-types-macros" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "foreign-types-shared" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" + +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "futures-sink" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" + +[[package]] +name = "futures-task" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" + +[[package]] +name = "futures-util" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "galil-seiferas" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "794ac25cfda3fa11d2b07ff8c65889c6c03411646df54e59e606878d899e1d5a" +dependencies = [ + "defmac", + "unchecked-index", +] + +[[package]] +name = "gemm" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ab24cc62135b40090e31a76a9b2766a501979f3070fa27f689c27ec04377d32" +dependencies = [ + "dyn-stack", + "gemm-c32", + "gemm-c64", + "gemm-common", + "gemm-f16", + "gemm-f32", + "gemm-f64", + 
"num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c32" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9c030d0b983d1e34a546b86e08f600c11696fde16199f971cd46c12e67512c0" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c64" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbb5f2e79fefb9693d18e1066a557b4546cd334b226beadc68b11a8f9431852a" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-common" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8" +dependencies = [ + "bytemuck", + "dyn-stack", + "half", + "num-complex", + "num-traits", + "once_cell", + "paste", + "pulp", + "raw-cpuid", + "rayon", + "seq-macro", + "sysctl", +] + +[[package]] +name = "gemm-f16" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ca4c06b9b11952071d317604acb332e924e817bd891bec8dfb494168c7cedd4" +dependencies = [ + "dyn-stack", + "gemm-common", + "gemm-f32", + "half", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f32" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9a69f51aaefbd9cf12d18faf273d3e982d9d711f60775645ed5c8047b4ae113" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-f64" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa397a48544fadf0b81ec8741e5c0fba0043008113f71f2034def1935645d2b0" 
+dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getset" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f636605b743120a8d32ed92fc27b6cde1a769f8f936c065151eb66f88ded513c" +dependencies = [ + "proc-macro-error2", + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "gif" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fb2d69b19215e18bb912fa30f7ce15846e301408695e44e0ef719f1da9e19f2" +dependencies = [ + "color_quant", + "weezl", +] + +[[package]] +name = "gimli" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32085ea23f3234fc7846555e85283ba4de91e21016dc0455a16286d87a292d64" + +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + +[[package]] +name = "h2" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "half" +version = "2.4.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "bytemuck", + "cfg-if", + "crunchy", + "num-traits", + "rand", + "rand_distr", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", + "allocator-api2", +] + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "hf-hub" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732" +dependencies = [ + "dirs", + "indicatif", + "log", + "native-tls", + "rand", + "serde", + "serde_json", + "thiserror", + "ureq", +] + +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +dependencies = [ + "bytes", + "futures-util", + "http", + "http-body", + "pin-project-lite", +] 
+ +[[package]] +name = "httparse" +version = "1.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fcc0b4a115bf80b728eb8ea024ad5bd707b615bfed49e0665b6e0f86fd082d9" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" +dependencies = [ + "futures-util", + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", + "webpki-roots", +] + +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da62f120a8a37763efb0cf8fdf264b884c7b8b9ac8660b900c8661030c00e6ba" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "hyper", + "pin-project-lite", + "socket2", + "tokio", + "tower", + "tower-service", + "tracing", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.61" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "image" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99314c8a2152b8ddb211f924cdae532d8c5e4c8bb54728e12fff1b0cd5963a10" +dependencies = [ + "bytemuck", + "byteorder-lite", + "color_quant", + "exr", + "gif", + "image-webp", + "num-traits", + "png", + "qoi", + "tiff", + "zune-core", + "zune-jpeg", +] + +[[package]] +name = "image-webp" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f79afb8cbee2ef20f59ccd477a218c12a93943d075b492015ecb1bb81f8ee904" +dependencies = [ + "byteorder-lite", + "quick-error", +] + +[[package]] +name = "indenter" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" + +[[package]] +name = "indexmap" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" +dependencies = [ + "equivalent", + "hashbrown", + "serde", +] + +[[package]] +name = "indicatif" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" +dependencies = [ + "console", + "instant", + "number_prefix", + "portable-atomic", + "rayon", + "unicode-width", +] + +[[package]] +name = "indoc" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" + +[[package]] +name = "instant" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "intel-mkl-src" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ee70586cd5b3e772a8739a1bd43eaa90d4f4bf0fb2a4edc202e979937ee7f5e" +dependencies = [ + "anyhow", + "intel-mkl-tool", + "ocipkg", +] + +[[package]] +name = "intel-mkl-tool" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "887a16b4537d82227af54d3372971cfa5e0cde53322e60f57584056c16ada1b4" +dependencies = [ + "anyhow", + "log", + "walkdir", +] + +[[package]] +name = "ipnet" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4" + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "jpeg-decoder" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" + +[[package]] +name = "js-sys" +version = "0.3.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "lebe" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" + +[[package]] +name = "libc" +version = "0.2.158" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" + +[[package]] +name = "libloading" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" 
+dependencies = [ + "cfg-if", + "windows-targets 0.52.6", +] + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "libredox" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +dependencies = [ + "bitflags 2.6.0", + "libc", + "redox_syscall", +] + +[[package]] +name = "linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" + +[[package]] +name = "lrtable" +version = "0.13.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d42d2752cb50a171efadda0cb6fa97432e8bf05accfff3eed320b87e80a2f69e" +dependencies = [ + "cfgrammar", + "fnv", + "num-traits", + "sparsevec", + "vob", +] + +[[package]] +name = "macro_rules_attribute" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a82271f7bc033d84bbca59a3ce3e4159938cb08a9c3aebbe54d215131518a13" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dd856d451cc0da70e2ef2ce95a18e39a93b7558bedf10201ad28503f918568" + +[[package]] +name = "malloc_buf" 
+version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62bb907fe88d54d8d9ce32a3cceab4218ed2f6b7d35617cafe9adf84e43919cb" +dependencies = [ + "libc", +] + +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "memmap2" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f" +dependencies = [ + "libc", + "stable_deref_trait", +] + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "metal" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c43f73953f8cbe511f021b58f18c3ce1c3d1ae13fe953293e13345bf83217f25" +dependencies = [ + "bitflags 2.6.0", + "block", + "core-graphics-types", + "foreign-types 0.5.0", + "log", + "objc", + "paste", +] + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + +[[package]] +name = "minijinja" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1028b628753a7e1a88fc59c9ba4b02ecc3bc0bd3c7af23df667bc28df9b3310e" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "minijinja-contrib" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ffd46ee854be23604a20efd6c9655374fefbe4d44b949dc0f907305d92873a" +dependencies = [ + "minijinja", + "serde", +] + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "miniz_oxide" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" +dependencies = [ + "adler", + "simd-adler32", +] + +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", +] + +[[package]] +name = "mio" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +dependencies = [ + "libc", + "log", + "wasi", + "windows-sys 0.48.0", +] + +[[package]] +name = "mio" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" +dependencies = [ + "hermit-abi", + "libc", + "wasi", + "windows-sys 0.52.0", +] + +[[package]] +name = "mistralrs" +version = "0.3.0" +dependencies = [ + "anyhow", + "candle-core 0.6.1 
(git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "either", + "futures", + "image", + "indexmap", + "mistralrs-core", + "rand", + "reqwest", + "serde", + "serde_json", + "tokio", +] + +[[package]] +name = "mistralrs-core" +version = "0.3.0" +dependencies = [ + "accelerate-src", + "akin", + "anyhow", + "async-trait", + "base64 0.22.1", + "bindgen_cuda 0.1.5", + "buildstructor", + "bytemuck", + "bytemuck_derive", + "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "candle-flash-attn", + "candle-nn", + "cfgrammar", + "chrono", + "clap", + "csv", + "derive-new", + "derive_more", + "dirs", + "either", + "futures", + "galil-seiferas", + "half", + "hf-hub", + "image", + "indexmap", + "indicatif", + "intel-mkl-src", + "itertools 0.13.0", + "lrtable", + "minijinja", + "minijinja-contrib", + "mistralrs-paged-attn", + "mistralrs-quant", + "mistralrs-vision", + "once_cell", + "pyo3", + "radix_trie", + "rand", + "rand_isaac", + "rayon", + "regex", + "regex-automata 0.4.7", + "reqwest", + "rustc-hash", + "safetensors", + "schemars", + "serde", + "serde_json", + "serde_yaml", + "strum", + "sysinfo", + "thiserror", + "tokenizers", + "tokio", + "tokio-rayon", + "toml", + "tqdm", + "tracing", + "tracing-subscriber", + "uuid 1.10.0", + "variantly", + "vob", +] + +[[package]] +name = "mistralrs-paged-attn" +version = "0.3.0" +dependencies = [ + "anyhow", + "bindgen_cuda 0.1.6", + "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "half", +] + +[[package]] +name = "mistralrs-pyo3" +version = "0.3.0" +dependencies = [ + "accelerate-src", + "anyhow", + "base64 0.22.1", + "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "data-url", + "either", + "futures", + "image", + "indexmap", + "intel-mkl-src", + "mistralrs-core", + "pyo3", + "pyo3-build-config", + "reqwest", + "serde", + "serde_json", + "tokio", + "url", +] + +[[package]] +name = "mistralrs-quant" +version = "0.3.0" 
+dependencies = [ + "bindgen_cuda 0.1.5", + "byteorder", + "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "candle-nn", + "half", + "lazy_static", + "paste", + "rayon", + "serde", + "tracing", +] + +[[package]] +name = "mistralrs-server" +version = "0.3.0" +dependencies = [ + "accelerate-src", + "anyhow", + "axum", + "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "clap", + "ctrlc", + "data-url", + "either", + "futures", + "image", + "indexmap", + "intel-mkl-src", + "mistralrs-core", + "once_cell", + "reqwest", + "serde", + "serde_json", + "tokio", + "tower-http", + "tracing", + "url", + "utoipa", + "utoipa-swagger-ui", +] + +[[package]] +name = "mistralrs-vision" +version = "0.3.0" +dependencies = [ + "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "image", +] + +[[package]] +name = "monostate" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d208407d7552cd041d8cdb69a1bc3303e029c598738177a3d87082004dc0e1e" +dependencies = [ + "monostate-impl", + "serde", +] + +[[package]] +name = "monostate-impl" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "native-tls" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "nibble_vec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" +dependencies = [ + 
"smallvec", +] + +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "cfg_aliases", + "libc", +] + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + +[[package]] +name = "ntapi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +dependencies = [ + "winapi", +] + +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "bytemuck", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + 
"num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "num_enum" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179" +dependencies = [ + "num_enum_derive", +] + +[[package]] +name = "num_enum_derive" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + +[[package]] +name = "objc" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" +dependencies = [ + "malloc_buf", + "objc_exception", +] + +[[package]] +name = "objc_exception" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad970fb455818ad6cba4c122ad012fae53ae8b4795f86378bce65e4f6bab2ca4" +dependencies = [ + "cc", +] + +[[package]] +name = "object" +version = "0.36.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "084f1a5821ac4c651660a94a7153d27ac9d8a53736203f58b31945ded098070a" +dependencies = [ + "memchr", +] + +[[package]] +name = 
"oci-spec" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f5a3fe998d50101ae009351fec56d88a69f4ed182e11000e711068c2f5abf72" +dependencies = [ + "derive_builder", + "getset", + "once_cell", + "regex", + "serde", + "serde_json", + "strum", + "strum_macros", + "thiserror", +] + +[[package]] +name = "ocipkg" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9bb3293021f06540803301af45e7ab81693d50e89a7398a3420bdab139e7ba5e" +dependencies = [ + "base16ct", + "base64 0.22.1", + "chrono", + "directories", + "flate2", + "lazy_static", + "log", + "oci-spec", + "regex", + "serde", + "serde_json", + "sha2", + "tar", + "thiserror", + "toml", + "ureq", + "url", + "uuid 1.10.0", + "walkdir", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "onig" +version = "6.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c4b31c8722ad9171c6d77d3557db078cab2bd50afcc9d09c8b315c59df8ca4f" +dependencies = [ + "bitflags 1.3.2", + "libc", + "once_cell", + "onig_sys", +] + +[[package]] +name = "onig_sys" +version = "69.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b829e3d7e9cc74c7e315ee8edb185bf4190da5acde74afd7fc59c35b1f086e7" +dependencies = [ + "cc", + "pkg-config", +] + +[[package]] +name = "openssl" +version = "0.10.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "foreign-types 0.3.2", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + +[[package]] +name = "openssl-sys" +version = "0.9.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + +[[package]] +name = "packedvec" +version = "1.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bde3c690ec20e4a2b4fb46f0289a451181eb50011a1e2acc8d85e2fde9062a45" +dependencies = [ + "num-traits", + "serde", +] + +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.52.6", +] + +[[package]] +name = "parse-zoneinfo" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1f2a05b18d44e2957b88f96ba460715e295bc1d7510468a2f3d3b44535d26c24" +dependencies = [ + "regex", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + +[[package]] +name = "phf" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +dependencies = [ + "phf_generator", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +dependencies = [ + "phf_shared", + "rand", +] + +[[package]] +name = "phf_shared" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +dependencies = [ + "siphasher", +] + +[[package]] +name = "pin-project" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", 
+] + +[[package]] +name = "pin-project-lite" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + +[[package]] +name = "png" +version = "0.17.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06e4b0d3d1312775e782c86c91a111aa1f910cbb65e1337f9975b5f9a554b5e1" +dependencies = [ + "bitflags 1.3.2", + "crc32fast", + "fdeflate", + "flate2", + "miniz_oxide 0.7.4", +] + +[[package]] +name = "portable-atomic" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" + +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro-crate" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" +dependencies = [ + "toml_edit", +] + +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "proc-macro2" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pulp" +version = "0.18.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0a01a0dc67cf4558d279f0c25b0962bd08fc6dec0137699eae304103e882fe6" +dependencies = [ + "bytemuck", + "libm", + "num-complex", + "reborrow", +] + +[[package]] +name = "pyo3" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15ee168e30649f7f234c3d49ef5a7a6cbf5134289bc46c29ff3155fa3221c225" +dependencies = [ + "anyhow", + "cfg-if", + "chrono", + "chrono-tz", + "either", + "eyre", + "hashbrown", + "indexmap", + "indoc", + "libc", + "memoffset", + "num-bigint", + "num-complex", + "num-rational", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "rust_decimal", + "serde", + "smallvec", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e61cef80755fe9e46bb8a0b8f20752ca7676dcc07a5277d8b7768c6172e529b3" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67ce096073ec5405f5ee2b8b31f03a68e02aa10d5d4f565eca04acc41931fa1c" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2440c6d12bc8f3ae39f1e775266fa5122fd0c8891ce7520fa6048e683ad3de28" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1be962f0e06da8f8465729ea2cb71a416d2257dff56cbe40a70d3e62a93ae5d1" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "qoi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f6d64c71eb498fe9eae14ce4ec935c555749aef511cca85b5568910d6e48001" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "quick-error" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" + +[[package]] +name = "quinn" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684" +dependencies = [ + "bytes", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2", + "thiserror", + "tokio", + "tracing", +] + +[[package]] +name = "quinn-proto" +version = "0.11.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6" +dependencies = [ + "bytes", + 
"rand", + "ring", + "rustc-hash", + "rustls", + "slab", + "thiserror", + "tinyvec", + "tracing", +] + +[[package]] +name = "quinn-udp" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fe68c2e9e1a1234e218683dbdf9f9dfcb094113c5ac2b938dfcb9bab4c4140b" +dependencies = [ + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.59.0", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "radix_trie" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd" +dependencies = [ + "endian-type", + "nibble_vec", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "rand_isaac" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "fac4373cd91b4f55722c553fb0f286edbb81ef3ff6eec7b99d1898a4110a0b28" +dependencies = [ + "rand_core", +] + +[[package]] +name = "raw-cpuid" +version = "10.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +dependencies = [ + "bitflags 1.3.2", +] + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-cond" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9" +dependencies = [ + "either", + "itertools 0.11.0", + "rayon", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "reborrow" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" + +[[package]] +name = "redox_syscall" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0884ad60e090bf1345b93da0a5de8923c93884cd03f40dfcfddd3b4bee661853" +dependencies = [ + "bitflags 2.6.0", +] + +[[package]] +name = "redox_users" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" +dependencies = [ + "getrandom", + "libredox", + "thiserror", +] + +[[package]] +name = "regex" +version = "1.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata 0.4.7", + "regex-syntax 0.8.4", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", +] + +[[package]] +name = "regex-automata" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.8.4", +] + +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + +[[package]] +name = "regex-syntax" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" + +[[package]] +name = "reqwest" +version = "0.12.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63" +dependencies = [ + "base64 0.22.1", + "bytes", + "encoding_rs", + "futures-channel", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-tls", + "hyper-util", + "ipnet", + "js-sys", + "log", + "mime", + "native-tls", + "once_cell", + "percent-encoding", + "pin-project-lite", + "quinn", + "rustls", + "rustls-pemfile", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 1.0.1", + "system-configuration", + "tokio", + "tokio-native-tls", + "tokio-rustls", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "webpki-roots", + "windows-registry", +] + 
+[[package]] +name = "ring" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom", + "libc", + "spin", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rust-embed" +version = "8.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa66af4a4fdd5e7ebc276f115e895611a34739a9c1c01028383d612d550953c0" +dependencies = [ + "rust-embed-impl", + "rust-embed-utils", + "walkdir", +] + +[[package]] +name = "rust-embed-impl" +version = "8.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6125dbc8867951125eec87294137f4e9c2c96566e61bf72c45095a7c77761478" +dependencies = [ + "proc-macro2", + "quote", + "rust-embed-utils", + "syn 2.0.77", + "walkdir", +] + +[[package]] +name = "rust-embed-utils" +version = "8.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e5347777e9aacb56039b0e1f28785929a8a3b709e87482e7442c72e7c12529d" +dependencies = [ + "sha2", + "walkdir", +] + +[[package]] +name = "rust_decimal" +version = "1.36.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b082d80e3e3cc52b2ed634388d436fe1f4de6af5786cc2de9ba9737527bdf555" +dependencies = [ + "arrayvec", + "num-traits", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + +[[package]] +name = "rustc-hash" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" + +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + +[[package]] +name = "rustix" +version = "0.38.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +dependencies = [ + "bitflags 2.6.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustls" +version = "0.23.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8" +dependencies = [ + "log", + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pemfile" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" +dependencies = [ + "base64 0.22.1", + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" + +[[package]] +name = "rustls-webpki" +version = "0.102.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "rustversion" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "safetensors" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"44560c11236a6130a46ce36c836a62936dc81ebf8c36a37947423571be0e55b6" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "schannel" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "schemars" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09c024468a378b7e36765cd36702b7a90cc3cba11654f6685c8f233408e89e92" +dependencies = [ + "dyn-clone", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1eee588578aff73f856ab961cd2f79e36bc45d7ded33a7562adba4667aecc0e" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn 2.0.77", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.6.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] 
+name = "semver" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" + +[[package]] +name = "seq-macro" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" + +[[package]] +name = "serde" +version = "1.0.210" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.210" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "serde_json" +version = "1.0.128" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + +[[package]] +name = "serde_path_to_error" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6" +dependencies = [ + "itoa", + "serde", +] + +[[package]] +name = "serde_spanned" +version = "0.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d" +dependencies = [ + "serde", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "serde_yaml" +version = "0.9.34+deprecated" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" +dependencies = [ + "indexmap", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" +dependencies = [ + "libc", + "signal-hook-registry", +] + +[[package]] +name = "signal-hook-mio" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd" +dependencies = [ + "libc", + "mio 0.8.11", + "signal-hook", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + +[[package]] +name = 
"simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + +[[package]] +name = "siphasher" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" + +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + +[[package]] +name = "socket2" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "sparsevec" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35df5d2e580b29f3f7ec5b4ed49b0ab3acf7f3624122b3e823cafb9630f293b8" +dependencies = [ + "num-traits", + "packedvec", + "serde", + "vob", +] + +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom", + "serde", + "unicode-segmentation", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + +[[package]] +name = "str_inflector" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0b848d5a7695b33ad1be00f84a3c079fe85c9278a325ff9159e6c99cef4ef7" +dependencies = [ + "lazy_static", + "regex", +] + +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "strum" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.77", +] + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = 
"sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + +[[package]] +name = "sync_wrapper" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +dependencies = [ + "futures-core", +] + +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "sysctl" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea" +dependencies = [ + "bitflags 2.6.0", + "byteorder", + "enum-as-inner", + "libc", + "thiserror", + "walkdir", +] + +[[package]] +name = "sysinfo" +version = "0.30.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a5b4ddaee55fb2bea2bf0e5000747e5f5c0de765e5a5ff87f4cd106439f4bb3" +dependencies = [ + "cfg-if", + "core-foundation-sys", + "libc", + "ntapi", + "once_cell", + "rayon", + "windows", +] + +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags 2.6.0", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "tar" +version = "0.4.41" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb797dad5fb5b76fcf519e702f4a589483b5ef06567f160c392832c1f5e44909" +dependencies = [ + "filetime", + "libc", + "xattr", +] + +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + +[[package]] +name = "tempfile" +version = "3.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" +dependencies = [ + "cfg-if", + "fastrand", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + +[[package]] +name = "thiserror" +version = "1.0.63" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.63" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if", + "once_cell", +] + +[[package]] +name = "tiff" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba1310fcea54c6a9a4fd1aad794ecc02c31682f6bfbecdf460bf19533eed1e3e" +dependencies = [ + "flate2", + "jpeg-decoder", + "weezl", +] + +[[package]] +name = "tinyvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = 
"0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "tokenizers" +version = "0.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd" +dependencies = [ + "aho-corasick", + "derive_builder", + "esaxx-rs", + "getrandom", + "indicatif", + "itertools 0.12.1", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand", + "rayon", + "rayon-cond", + "regex", + "regex-syntax 0.8.4", + "serde", + "serde_json", + "spm_precompiled", + "thiserror", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + +[[package]] +name = "tokio" +version = "1.40.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio 1.0.2", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.52.0", +] + +[[package]] +name = "tokio-macros" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + +[[package]] +name = "tokio-rayon" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cf33a76e0b1dd03b778f83244137bd59887abf25c0e87bc3e7071105f457693" +dependencies = [ + "rayon", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = 
"0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +dependencies = [ + "rustls", + "rustls-pki-types", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "toml" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b072cee73c449a636ffd6f32bd8de3a9f7119139aff882f44943ce2986dc5cf" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" +dependencies = [ + "bitflags 2.6.0", + "bytes", + "http", + "http-body", + "http-body-util", + 
"pin-project-lite", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tqdm" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2d2932240205a99b65f15d9861992c95fbb8c9fb280b3a1f17a92db6dc611f" +dependencies = [ + "anyhow", + "crossterm", + "once_cell", +] + +[[package]] +name = "tracing" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +dependencies = [ + "log", + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "try_match" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b065c869a3f832418e279aa4c1d7088f9d5d323bde15a60a08e20c2cd4549082" +dependencies = [ + "try_match_inner", +] + +[[package]] +name = "try_match_inner" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9c81686f7ab4065ccac3df7a910c4249f8c0f3fb70421d6ddec19b9311f63f9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] +name = "unchecked-index" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c" + +[[package]] +name = "unicase" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89" +dependencies = [ + "version_check", +] + +[[package]] +name = "unicode-bidi" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" + +[[package]] +name = "unicode-ident" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" 
+ +[[package]] +name = "unicode-normalization" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec", +] + +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + +[[package]] +name = "unicode-width" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" + +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + +[[package]] +name = "unsafe-libyaml" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "ureq" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b74fc6b57825be3373f7054754755f03ac3a8f5d70015ccad699ba2029956f4a" +dependencies = [ + "base64 0.22.1", + "flate2", + "log", + "native-tls", + 
"once_cell", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "url", + "webpki-roots", +] + +[[package]] +name = "url" +version = "2.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "utoipa" +version = "4.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23" +dependencies = [ + "indexmap", + "serde", + "serde_json", + "utoipa-gen", +] + +[[package]] +name = "utoipa-gen" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bf0e16c02bc4bf5322ab65f10ab1149bdbcaa782cba66dc7057370a3f8190be" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "regex", + "syn 2.0.77", +] + +[[package]] +name = "utoipa-swagger-ui" +version = "7.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "943e0ff606c6d57d410fd5663a4d7c074ab2c5f14ab903b9514565e59fa1189e" +dependencies = [ + "axum", + "mime_guess", + "regex", + "reqwest", + "rust-embed", + "serde", + "serde_json", + "url", + "utoipa", + "zip", +] + +[[package]] +name = "uuid" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" +dependencies = [ + "getrandom", +] + +[[package]] +name = "uuid" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +dependencies = [ + "getrandom", +] + +[[package]] +name = "valuable" +version = 
"0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + +[[package]] +name = "variantly" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72a332341ba79a179d9e9b33c0d72fbf3dc2c80e1be79416401a08d2b820ef56" +dependencies = [ + "Inflector", + "darling 0.11.0", + "proc-macro2", + "quote", + "syn 1.0.109", + "uuid 0.8.2", +] + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "vob" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c058f4c41e71a043c67744cb76dcc1ae63ece328c1732a72489ccccc2dec23e6" +dependencies = [ + "num-traits", + "rustc_version", + "serde", +] + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +dependencies = [ + "cfg-if", + "once_cell", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.77", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" + +[[package]] +name = "web-sys" +version = "0.3.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki-roots" +version = "0.26.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "weezl" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" +dependencies = [ + "windows-core", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-registry" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" 
+dependencies = [ + "windows-result", + "windows-strings", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "winnow" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68a9bda4691f099d435ad181000724da8e5899daa10713c2d432552b9ccd3a6f" +dependencies = [ + "memchr", +] + +[[package]] +name = "xattr" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f" +dependencies = [ + "libc", + 
"linux-raw-sys", + "rustix", +] + +[[package]] +name = "yoke" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", + "synstructure", +] + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "zerofrom" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" + +[[package]] +name = "zip" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9cc23c04387f4da0374be4533ad1208cbb091d5c11d070dfef13676ad6497164" +dependencies = [ + "arbitrary", + "crc32fast", + "crossbeam-utils", + "displaydoc", + "flate2", + "indexmap", + "num_enum", + "thiserror", +] + +[[package]] +name = "zune-core" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f423a2c17029964870cfaabb1f13dfab7d092a62a29a89264f4d36990ca414a" + +[[package]] +name = "zune-inflate" +version = "0.2.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73ab332fe2f6680068f3582b16a24f90ad7096d5d39b974d1c0aff0125116f02" +dependencies = [ + "simd-adler32", +] + +[[package]] +name = "zune-jpeg" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16099418600b4d8f028622f73ff6e3deaabdff330fb9a2a131dea781ee8b0768" +dependencies = [ + "zune-core", +] From 5d2b9de752e3aec11acd65ff65cbd67fcd87f6f5 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 19 Sep 2024 10:12:29 -0400 Subject: [PATCH 11/15] Compiles now --- mistralrs-quant/src/exl2/exl2_cuda.rs | 11 ++++++++++- mistralrs-quant/src/utils/uqff.rs | 6 +++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/mistralrs-quant/src/exl2/exl2_cuda.rs b/mistralrs-quant/src/exl2/exl2_cuda.rs index 9bff23c41..7b373faf2 100644 --- a/mistralrs-quant/src/exl2/exl2_cuda.rs +++ b/mistralrs-quant/src/exl2/exl2_cuda.rs @@ -17,7 +17,7 @@ use half::f16; use crate::{ utils::{get_cuda_device, get_cuda_slice}, - IsqType, QuantMethod, QuantMethodConfig, + IsqType, QuantMethod, QuantMethodConfig, QuantizedSerde, }; use super::ffi::{exl2_create_q_matrix, exl2_destroy_q_matrix, exl2_reconstruct_q_matrix}; @@ -323,3 +323,12 @@ impl Drop for Exl2Layer { } } } + +impl QuantizedSerde for Exl2Layer { + fn isq_serde_supported(&self) -> bool { + false + } + fn name(&self) -> &'static str { + "exl2" + } +} diff --git a/mistralrs-quant/src/utils/uqff.rs b/mistralrs-quant/src/utils/uqff.rs index 
4365596c3..e9ca1e573 100644 --- a/mistralrs-quant/src/utils/uqff.rs +++ b/mistralrs-quant/src/utils/uqff.rs @@ -5,7 +5,7 @@ use half::{bf16, f16}; const HQFF_VERSION_MAJOR: u32 = 0; const HQFF_VERSION_MINOR: u32 = 1; -const HQFF_VERSION_PATCH: u32 = 0; +const HQFF_VERSION_PATCH: u32 = 1; /// Format 4 bytes, little endian: [ UNSPECIFIED ] [ MAJOR ] [ MINOR ] [ PATCH ] pub(crate) const HQFF_VERSION: u32 = @@ -47,6 +47,7 @@ pub(crate) fn serialize_tensor(buffer: &mut Vec, tensor: &Tensor) -> Result< let bias = match tensor.dtype() { DType::U8 => data_to_bytes::(tensor.to_vec1()?), DType::U32 => data_to_bytes::(tensor.to_vec1()?), + DType::I16 => data_to_bytes::(tensor.to_vec1()?), DType::I32 => data_to_bytes::(tensor.to_vec1()?), DType::I64 => data_to_bytes::(tensor.to_vec1()?), DType::F16 => data_to_bytes::(tensor.to_vec1()?), @@ -65,6 +66,7 @@ pub(crate) fn serialize_tensor(buffer: &mut Vec, tensor: &Tensor) -> Result< DType::BF16 => 5, DType::F32 => 6, DType::F64 => 7, + DType::I16 => 8, }; buffer.extend(&dtype.to_le_bytes()); @@ -95,6 +97,7 @@ pub(crate) fn deserialize_tensor( 5 => DType::BF16, 6 => DType::F32, 7 => DType::F64, + 8 => DType::I16, _ => candle_core::bail!("unknown dtype for quantized bias tensor {dtype}"), }; @@ -113,6 +116,7 @@ pub(crate) fn deserialize_tensor( DType::BF16 => bytes_to_data::(&tensor_data, &dims, device), DType::F32 => bytes_to_data::(&tensor_data, &dims, device), DType::F64 => bytes_to_data::(&tensor_data, &dims, device), + DType::I16 => bytes_to_data::(&tensor_data, &dims, device), DType::I32 => bytes_to_data::(&tensor_data, &dims, device), DType::I64 => bytes_to_data::(&tensor_data, &dims, device), DType::U32 => bytes_to_data::(&tensor_data, &dims, device), From 591d55f1eeda5de7a6afb03640194c367413dea4 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 19 Sep 2024 10:34:52 -0400 Subject: [PATCH 12/15] Builds and links now --- mistralrs-quant/build.rs | 3 +- mistralrs-quant/kernels/exl2/q_gemm_exl2.cu | 32 
++++++++++----------- mistralrs-quant/src/exl2/exl2_cuda.rs | 16 +++++------ mistralrs-quant/src/exl2/ffi.rs | 2 +- 4 files changed, 27 insertions(+), 26 deletions(-) diff --git a/mistralrs-quant/build.rs b/mistralrs-quant/build.rs index 8dac2a614..9af6b4832 100644 --- a/mistralrs-quant/build.rs +++ b/mistralrs-quant/build.rs @@ -8,7 +8,8 @@ fn main() { println!("cargo:rerun-if-changed=build.rs"); let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap()); let lib_files = vec![ - //"kernels/exl2/q_gemm_exl2.cu", + "kernels/exl2/q_gemm_exl2.cu", + "kernels/exl2/q_matrix.cu", "kernels/gptq/q_gemm.cu", "kernels/hqq/hqq.cu", "kernels/ops/ops.cu", diff --git a/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu b/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu index 406310a5c..a867ac017 100644 --- a/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu +++ b/mistralrs-quant/kernels/exl2/q_gemm_exl2.cu @@ -67,13 +67,13 @@ extern "C" uintptr_t exl2_make_q_matrix( const int height, const int width, const int groups, - uint32_t q_weight, - uint16_t q_perm, - uint16_t q_invperm, - uint32_t q_scale, - half q_scale_max, - uint16_t q_groups, - uint16_t q_group_map + uint32_t* q_weight, + uint16_t* q_perm, + uint16_t* q_invperm, + uint32_t* q_scale, + half* q_scale_max, + uint16_t* q_groups, + uint16_t* q_group_map ) { QMatrix* m = new QMatrix ( @@ -81,20 +81,20 @@ extern "C" uintptr_t exl2_make_q_matrix( height, width, groups, - (uint32_t*)q_weight.data_ptr(), - (uint16_t*)q_perm.data_ptr(), - (uint16_t*)q_invperm.data_ptr(), - (uint32_t*)q_scale.data_ptr(), - (half*)q_scale_max.data_ptr(), - (uint16_t*)q_groups.data_ptr(), - (uint16_t*)q_group_map.data_ptr() + q_weight, + q_perm, + q_invperm, + q_scale, + q_scale_max, + q_groups, + q_group_map ); return reinterpret_cast(m); } -extern "C" void exl2_reconstruct_q_matrix(uintptr_t q_matrix) { +extern "C" void exl2_reconstruct_q_matrix(uintptr_t q_matrix, half* out) { QMatrix* m = reinterpret_cast(q_matrix); - m->reconstruct(); + 
m->reconstruct(out); } extern "C" void exl2_destroy_q_matrix(uintptr_t q_matrix) { diff --git a/mistralrs-quant/src/exl2/exl2_cuda.rs b/mistralrs-quant/src/exl2/exl2_cuda.rs index 7b373faf2..20cdc1b48 100644 --- a/mistralrs-quant/src/exl2/exl2_cuda.rs +++ b/mistralrs-quant/src/exl2/exl2_cuda.rs @@ -151,17 +151,17 @@ impl Exl2Layer { let c = unsafe { dev.alloc::(c_shape.elem_count()).w()? }; let c_ptr = *c.device_ptr() as *mut f16; - let temp_dq = if m > MAX_Q_GEMM_ROWS { - Tensor::zeros(&[k as usize, n as usize], DType::F16, a.device())? - } else { - Tensor::zeros(&[0, 0], DType::F16, a.device())? - }; - let temp_dq_ptr = get_cuda_slice::(&temp_dq)?; - if m > MAX_Q_GEMM_ROWS { + let temp_dq = if m > MAX_Q_GEMM_ROWS { + Tensor::zeros(&[k as usize, n as usize], DType::F16, a.device())? + } else { + Tensor::zeros(&[0, 0], DType::F16, a.device())? + }; + let temp_dq_ptr = get_cuda_slice::(&temp_dq)? as *mut f16; + // Reconstruct FP16 matrix, then cuBLAS unsafe { - exl2_reconstruct_q_matrix(self.exllama_state.lock().unwrap().q_matrix); + exl2_reconstruct_q_matrix(self.exllama_state.lock().unwrap().q_matrix, temp_dq_ptr); } let alpha = f16::from_f32(1.0); diff --git a/mistralrs-quant/src/exl2/ffi.rs b/mistralrs-quant/src/exl2/ffi.rs index 6ee7e59a6..3f9f874eb 100644 --- a/mistralrs-quant/src/exl2/ffi.rs +++ b/mistralrs-quant/src/exl2/ffi.rs @@ -22,7 +22,7 @@ extern "C" { pub fn exl2_destroy_q_matrix(q_matrix: QMatrixPtr); - pub fn exl2_reconstruct_q_matrix(q_matrix: QMatrixPtr); + pub fn exl2_reconstruct_q_matrix(q_matrix: QMatrixPtr, out: *mut f16); pub fn exl2_gemm_cuda(a: *const f16, b: *const c_void, c: *mut f16, m: i32, n: i32, k: i32); } From 56aca12cda29a105a750bfec515c54a72ef08851 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 19 Sep 2024 14:28:20 -0400 Subject: [PATCH 13/15] More progress --- mistralrs-quant/src/exl2/exl2_cuda.rs | 50 +++++++++------------------ mistralrs-quant/src/exl2/ffi.rs | 2 +- mistralrs-quant/src/lib.rs | 10 ++++-- 3 files 
changed, 25 insertions(+), 37 deletions(-) diff --git a/mistralrs-quant/src/exl2/exl2_cuda.rs b/mistralrs-quant/src/exl2/exl2_cuda.rs index 20cdc1b48..2bf312cbd 100644 --- a/mistralrs-quant/src/exl2/exl2_cuda.rs +++ b/mistralrs-quant/src/exl2/exl2_cuda.rs @@ -20,10 +20,9 @@ use crate::{ IsqType, QuantMethod, QuantMethodConfig, QuantizedSerde, }; -use super::ffi::{exl2_create_q_matrix, exl2_destroy_q_matrix, exl2_reconstruct_q_matrix}; +use super::ffi::{exl2_destroy_q_matrix, exl2_make_q_matrix, exl2_reconstruct_q_matrix}; const MAX_Q_GEMM_ROWS: i32 = 32; -const BLOCK_M_SIZE_MAX: i32 = 8; #[derive(Debug)] pub struct Exl2Layer { @@ -70,7 +69,7 @@ impl Exl2Layer { q_matrix: std::ptr::null_mut(), })); - Ok(Self { + let this = Self { q_weight, q_scale, q_groups, @@ -78,11 +77,9 @@ impl Exl2Layer { bias, bits, exllama_state, - }) - } - - pub fn post_init(&self) -> Result<()> { - self.initialize_exllama() + }; + this.initialize_exllama()?; + Ok(this) } fn initialize_exllama(&self) -> Result<()> { @@ -114,7 +111,7 @@ impl Exl2Layer { let b_q_group_map = get_cuda_slice::(&state.q_group_map)? as *const u16; state.q_matrix = unsafe { - exl2_create_q_matrix( + exl2_make_q_matrix( dev_ord, b_height, b_width, @@ -232,30 +229,17 @@ impl QuantMethod for Exl2Layer { q_group_map, bias, bits, - } => { - let exllama_state = Arc::new(Mutex::new(ExllamaState { - initialized: false, - q_scale_max, - q_perm, - q_group_map, - q_invperm_short: Tensor::zeros( - q_invperm.shape(), - DType::I16, - q_invperm.device(), - )?, - q_matrix: std::ptr::null_mut(), - })); - - Ok(Self { - q_weight, - q_scale, - q_groups, - q_invperm, - bias, - bits, - exllama_state, - }) - } + } => Self::new( + q_weight, + q_scale, + q_scale_max, + q_groups, + q_perm, + q_group_map, + q_invperm, + bias, + bits, + ), QuantMethodConfig::Gptq { .. } | QuantMethodConfig::Gguf { .. 
} | QuantMethodConfig::Unquantized(_) diff --git a/mistralrs-quant/src/exl2/ffi.rs b/mistralrs-quant/src/exl2/ffi.rs index 3f9f874eb..ec431d930 100644 --- a/mistralrs-quant/src/exl2/ffi.rs +++ b/mistralrs-quant/src/exl2/ffi.rs @@ -6,7 +6,7 @@ type QMatrixPtr = *mut c_void; #[allow(dead_code)] extern "C" { - pub fn exl2_create_q_matrix( + pub fn exl2_make_q_matrix( device: i32, height: i32, // q_perm.size(0); width: i32, // q_weight.size(1); diff --git a/mistralrs-quant/src/lib.rs b/mistralrs-quant/src/lib.rs index a3efc4f54..39c84433f 100644 --- a/mistralrs-quant/src/lib.rs +++ b/mistralrs-quant/src/lib.rs @@ -25,22 +25,24 @@ pub use unquantized::UnquantLinear; use candle_nn::{Linear, VarBuilder}; use serde::Deserialize; -#[derive(Debug, Clone, Deserialize, Default)] +#[derive(Debug, Clone, Deserialize)] pub enum QuantMethodType { - #[default] #[serde(rename = "gptq")] Gptq, + #[serde(rename = "exl2")] + Exl2, } impl Display for QuantMethodType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Gptq => write!(f, "GPTQ"), + Self::Exl2 => write!(f, "EXL2"), } } } -#[derive(Debug, Clone, Deserialize, Default)] +#[derive(Debug, Clone, Deserialize)] pub struct QuantizedConfig { pub bits: usize, pub quant_method: QuantMethodType, @@ -235,6 +237,7 @@ pub fn linear_no_bias( let layer = if let Some(quant_conf) = &config { match quant_conf.quant_method { QuantMethodType::Gptq => gptq_linear(in_dim, out_dim, quant_conf, vb)?, + QuantMethodType::Exl2 => todo!(), } } else { let layer = candle_nn::linear_no_bias(in_dim, out_dim, vb)?; @@ -254,6 +257,7 @@ pub fn linear( let layer = if let Some(quant_conf) = &config { match quant_conf.quant_method { QuantMethodType::Gptq => gptq_linear(in_dim, out_dim, quant_conf, vb)?, + QuantMethodType::Exl2 => todo!(), } } else { let layer = candle_nn::linear(in_dim, out_dim, vb)?; From f59a4f4350df075ebac5a84d2358d47eb3b75525 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 19 Sep 2024 20:43:11 
-0400 Subject: [PATCH 14/15] Support loading --- Cargo.lock | 64 +++++++-------------------- Cargo.toml | 4 +- mistralrs-core/Cargo.toml | 2 +- mistralrs-pyo3/Cargo_template.toml | 2 +- mistralrs-quant/src/exl2/exl2_cuda.rs | 24 ++-------- mistralrs-quant/src/lib.rs | 34 +++++++++++--- 6 files changed, 52 insertions(+), 78 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 20ad3e48e..374a199b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -387,33 +387,11 @@ checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" [[package]] name = "candle-core" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=8a99f7c#8a99f7cf31a1d8f175281492eaa7026730067d08" -dependencies = [ - "byteorder", - "candle-kernels 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=8a99f7c)", - "cudarc", - "gemm", - "half", - "memmap2", - "num-traits", - "num_cpus", - "rand", - "rand_distr", - "rayon", - "safetensors", - "thiserror", - "yoke", - "zip", -] - -[[package]] -name = "candle-core" -version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9e31a19#9e31a192642b4048e3df75173efddebaf663fef2" +source = "git+https://github.com/EricLBuehler/candle.git?rev=61d19cc#61d19cc4e8930c6e25b281fe48615f4c1d8479a7" dependencies = [ "accelerate-src", "byteorder", - "candle-kernels 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "candle-kernels", "candle-metal-kernels", "cudarc", "gemm", @@ -436,26 +414,18 @@ dependencies = [ [[package]] name = "candle-flash-attn" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=8a99f7c#8a99f7cf31a1d8f175281492eaa7026730067d08" +source = "git+https://github.com/EricLBuehler/candle.git?rev=61d19cc#61d19cc4e8930c6e25b281fe48615f4c1d8479a7" dependencies = [ "anyhow", "bindgen_cuda 0.1.5", - "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=8a99f7c)", + "candle-core", "half", ] [[package]] name = "candle-kernels" version = 
"0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=8a99f7c#8a99f7cf31a1d8f175281492eaa7026730067d08" -dependencies = [ - "bindgen_cuda 0.1.5", -] - -[[package]] -name = "candle-kernels" -version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9e31a19#9e31a192642b4048e3df75173efddebaf663fef2" +source = "git+https://github.com/EricLBuehler/candle.git?rev=61d19cc#61d19cc4e8930c6e25b281fe48615f4c1d8479a7" dependencies = [ "bindgen_cuda 0.1.5", ] @@ -463,7 +433,7 @@ dependencies = [ [[package]] name = "candle-metal-kernels" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9e31a19#9e31a192642b4048e3df75173efddebaf663fef2" +source = "git+https://github.com/EricLBuehler/candle.git?rev=61d19cc#61d19cc4e8930c6e25b281fe48615f4c1d8479a7" dependencies = [ "metal", "once_cell", @@ -474,10 +444,10 @@ dependencies = [ [[package]] name = "candle-nn" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9e31a19#9e31a192642b4048e3df75173efddebaf663fef2" +source = "git+https://github.com/EricLBuehler/candle.git?rev=61d19cc#61d19cc4e8930c6e25b281fe48615f4c1d8479a7" dependencies = [ "accelerate-src", - "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "candle-core", "candle-metal-kernels", "half", "intel-mkl-src", @@ -2098,7 +2068,7 @@ name = "mistralrs" version = "0.3.0" dependencies = [ "anyhow", - "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "candle-core", "either", "futures", "image", @@ -2124,7 +2094,7 @@ dependencies = [ "buildstructor", "bytemuck", "bytemuck_derive", - "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "candle-core", "candle-flash-attn", "candle-nn", "cfgrammar", @@ -2186,7 +2156,7 @@ version = "0.3.0" dependencies = [ "anyhow", "bindgen_cuda 0.1.6", - "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "candle-core", "half", ] 
@@ -2197,7 +2167,7 @@ dependencies = [ "accelerate-src", "anyhow", "base64 0.22.1", - "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "candle-core", "data-url", "either", "futures", @@ -2220,7 +2190,7 @@ version = "0.3.0" dependencies = [ "bindgen_cuda 0.1.5", "byteorder", - "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "candle-core", "candle-nn", "half", "lazy_static", @@ -2237,7 +2207,7 @@ dependencies = [ "accelerate-src", "anyhow", "axum", - "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "candle-core", "clap", "ctrlc", "data-url", @@ -2263,7 +2233,7 @@ dependencies = [ name = "mistralrs-vision" version = "0.3.0" dependencies = [ - "candle-core 0.6.1 (git+https://github.com/EricLBuehler/candle.git?rev=9e31a19)", + "candle-core", "image", ] @@ -4232,9 +4202,9 @@ checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] name = "unicode-width" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" [[package]] name = "unicode_categories" diff --git a/Cargo.toml b/Cargo.toml index ffc34bba0..6a0d0a233 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,8 +26,8 @@ license = "MIT" [workspace.dependencies] anyhow = "1.0.80" -candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9e31a19" } -candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9e31a19" } +candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "61d19cc" } +candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "61d19cc" } serde = "1.0.197" serde_json = "1.0.114" indexmap = { version = "2.2.5", features = ["serde"] } 
diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml index a48232996..dde0dd30a 100644 --- a/mistralrs-core/Cargo.toml +++ b/mistralrs-core/Cargo.toml @@ -17,7 +17,7 @@ candle-core.workspace = true candle-nn.workspace = true serde.workspace = true serde_json.workspace = true -candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "8a99f7c", optional = true } +candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "61d19cc", optional = true } dirs = "5.0.1" hf-hub = "0.3.2" thiserror = "1.0.57" diff --git a/mistralrs-pyo3/Cargo_template.toml b/mistralrs-pyo3/Cargo_template.toml index 7946eee0d..4111dbdb2 100644 --- a/mistralrs-pyo3/Cargo_template.toml +++ b/mistralrs-pyo3/Cargo_template.toml @@ -20,7 +20,7 @@ pyo3.workspace = true mistralrs-core = { version = "0.3.0", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] } serde.workspace = true serde_json.workspace = true -candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "8a99f7c", features=["$feature_name"] } +candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "61d19cc", features=["$feature_name"] } indexmap.workspace = true accelerate-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true } diff --git a/mistralrs-quant/src/exl2/exl2_cuda.rs b/mistralrs-quant/src/exl2/exl2_cuda.rs index 2bf312cbd..6611a27f8 100644 --- a/mistralrs-quant/src/exl2/exl2_cuda.rs +++ b/mistralrs-quant/src/exl2/exl2_cuda.rs @@ -31,7 +31,6 @@ pub struct Exl2Layer { q_groups: Tensor, q_invperm: Tensor, bias: Option, - bits: i32, exllama_state: Arc>, } @@ -54,17 +53,14 @@ impl Exl2Layer { q_scale: Tensor, q_scale_max: Tensor, q_groups: Tensor, - q_perm: Tensor, - q_group_map: Tensor, q_invperm: Tensor, bias: Option, - bits: i32, ) -> Result { let exllama_state = Arc::new(Mutex::new(ExllamaState { 
initialized: false, q_scale_max, - q_perm, - q_group_map, + q_perm: Tensor::zeros((1,), DType::I16, q_invperm.device())?, + q_group_map: Tensor::zeros((1,), DType::I16, q_invperm.device())?, q_invperm_short: Tensor::zeros(q_invperm.shape(), DType::I16, q_invperm.device())?, q_matrix: std::ptr::null_mut(), })); @@ -75,7 +71,6 @@ impl Exl2Layer { q_groups, q_invperm, bias, - bits, exllama_state, }; this.initialize_exllama()?; @@ -224,22 +219,9 @@ impl QuantMethod for Exl2Layer { q_scale, q_scale_max, q_groups, - q_perm, q_invperm, - q_group_map, bias, - bits, - } => Self::new( - q_weight, - q_scale, - q_scale_max, - q_groups, - q_perm, - q_group_map, - q_invperm, - bias, - bits, - ), + } => Self::new(q_weight, q_scale, q_scale_max, q_groups, q_invperm, bias), QuantMethodConfig::Gptq { .. } | QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Unquantized(_) diff --git a/mistralrs-quant/src/lib.rs b/mistralrs-quant/src/lib.rs index 39c84433f..5544975dc 100644 --- a/mistralrs-quant/src/lib.rs +++ b/mistralrs-quant/src/lib.rs @@ -7,7 +7,7 @@ use std::{ use candle_core::{ quantized::{GgmlDType, QTensor}, - DType, Device, Result, Tensor, + Context, DType, Device, Result, Tensor, }; mod exl2; @@ -17,6 +17,7 @@ mod hqq; mod unquantized; mod utils; +use exl2::Exl2Layer; pub use gguf::GgufMatMul; pub use gptq::GptqLayer; pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer}; @@ -46,20 +47,17 @@ impl Display for QuantMethodType { pub struct QuantizedConfig { pub bits: usize, pub quant_method: QuantMethodType, - pub group_size: usize, + pub group_size: Option, } #[derive(Debug, Clone)] pub enum QuantMethodConfig { Exl2 { - bits: i32, q_weight: Tensor, q_scale: Tensor, q_scale_max: Tensor, q_groups: Tensor, - q_perm: Tensor, q_invperm: Tensor, - q_group_map: Tensor, bias: Option, }, Gptq { @@ -294,7 +292,10 @@ pub fn gptq_linear( Default::default(), DType::I32, )?; - let scale_and_zero_size = in_dim / config.group_size; + let scale_and_zero_size = in_dim + / config + .group_size 
+ .context("GPTQ requires group size in QuantizedConfig")?; let qzeros = vb.get_with_hints_dtype( (scale_and_zero_size, out_dim / pack_factor!(config.bits)), "qzeros", @@ -321,3 +322,24 @@ pub fn gptq_linear( }; Ok(Arc::new(GptqLayer::new(config)?)) } + +pub fn exl2_linear( + _in_dim: usize, + _out_dim: usize, + _config: &QuantizedConfig, + vb: VarBuilder, +) -> Result> { + let q_weight = vb.get_unchecked_dtype("q_weight", DType::I32)?; + let q_scale_max = vb.get_unchecked_dtype("q_scale_max", DType::F16)?; + let q_scale = vb.get_unchecked_dtype("q_scale", DType::I32)?; + let q_invperm = vb.get_unchecked_dtype("q_invperm", DType::I32)?; + let q_groups = vb.get_unchecked_dtype("q_groups", DType::I16)?; + Ok(Arc::new(Exl2Layer::new(QuantMethodConfig::Exl2 { + q_weight, + q_scale, + q_scale_max, + q_groups, + q_invperm, + bias: None, + })?)) +} From 81f3bdc24f037f15fb6fff61b9d49fd517d44fdd Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 25 Sep 2024 16:44:42 -0400 Subject: [PATCH 15/15] Cleanup --- Cargo.lock | 10 +++++----- Cargo.toml | 4 ++-- mistralrs-core/Cargo.toml | 2 +- mistralrs-pyo3/Cargo_template.toml | 2 +- mistralrs-quant/src/utils/ops.rs | 8 -------- mistralrs-quant/src/utils/uqff.rs | 1 - 6 files changed, 9 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 00cc991fe..dc83283c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -393,7 +393,7 @@ checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" [[package]] name = "candle-core" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9c62368#9c62368211a29f5c2cf94deb811ecca7a9475c38" +source = "git+https://github.com/EricLBuehler/candle.git?rev=56f00b8#56f00b828d48b0a6313745b6c794def93cee73ba" dependencies = [ "accelerate-src", "byteorder", @@ -420,7 +420,7 @@ dependencies = [ [[package]] name = "candle-flash-attn" version = "0.6.1" -source = 
"git+https://github.com/EricLBuehler/candle.git?rev=9c62368#9c62368211a29f5c2cf94deb811ecca7a9475c38" +source = "git+https://github.com/EricLBuehler/candle.git?rev=56f00b8#56f00b828d48b0a6313745b6c794def93cee73ba" dependencies = [ "anyhow", "bindgen_cuda 0.1.5", @@ -431,7 +431,7 @@ dependencies = [ [[package]] name = "candle-kernels" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9c62368#9c62368211a29f5c2cf94deb811ecca7a9475c38" +source = "git+https://github.com/EricLBuehler/candle.git?rev=56f00b8#56f00b828d48b0a6313745b6c794def93cee73ba" dependencies = [ "bindgen_cuda 0.1.5", ] @@ -439,7 +439,7 @@ dependencies = [ [[package]] name = "candle-metal-kernels" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9c62368#9c62368211a29f5c2cf94deb811ecca7a9475c38" +source = "git+https://github.com/EricLBuehler/candle.git?rev=56f00b8#56f00b828d48b0a6313745b6c794def93cee73ba" dependencies = [ "metal", "once_cell", @@ -450,7 +450,7 @@ dependencies = [ [[package]] name = "candle-nn" version = "0.6.1" -source = "git+https://github.com/EricLBuehler/candle.git?rev=9c62368#9c62368211a29f5c2cf94deb811ecca7a9475c38" +source = "git+https://github.com/EricLBuehler/candle.git?rev=56f00b8#56f00b828d48b0a6313745b6c794def93cee73ba" dependencies = [ "accelerate-src", "candle-core", diff --git a/Cargo.toml b/Cargo.toml index 12b11cc3e..b1b70326e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,8 +26,8 @@ license = "MIT" [workspace.dependencies] anyhow = "1.0.80" -candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9c62368" } -candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9c62368" } +candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "56f00b8" } +candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "56f00b8" } serde = "1.0.197" serde_json = "1.0.114" indexmap = { 
version = "2.2.5", features = ["serde"] } diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml index 321c1d1fa..1a9fdddc3 100644 --- a/mistralrs-core/Cargo.toml +++ b/mistralrs-core/Cargo.toml @@ -17,7 +17,7 @@ candle-core.workspace = true candle-nn.workspace = true serde.workspace = true serde_json.workspace = true -candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9c62368", optional = true } +candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "56f00b8", optional = true } dirs = "5.0.1" hf-hub = "0.3.2" thiserror = "1.0.57" diff --git a/mistralrs-pyo3/Cargo_template.toml b/mistralrs-pyo3/Cargo_template.toml index 8890cbfca..3059d62d0 100644 --- a/mistralrs-pyo3/Cargo_template.toml +++ b/mistralrs-pyo3/Cargo_template.toml @@ -20,7 +20,7 @@ pyo3.workspace = true mistralrs-core = { version = "0.3.0", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] } serde.workspace = true serde_json.workspace = true -candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "9c62368", features=["$feature_name"] } +candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "56f00b8", features=["$feature_name"] } indexmap.workspace = true accelerate-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true } diff --git a/mistralrs-quant/src/utils/ops.rs b/mistralrs-quant/src/utils/ops.rs index 53327f839..de59be5d0 100644 --- a/mistralrs-quant/src/utils/ops.rs +++ b/mistralrs-quant/src/utils/ops.rs @@ -66,7 +66,6 @@ impl CustomOp2 for BitWiseOr { let result = CpuStorage::I32(result); Ok((result, l1.shape().clone())) } - CpuStorage::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "bitwise-or")), CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise-or")), CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, 
"bitwise-or")), CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise-or")), @@ -130,9 +129,6 @@ impl CustomOp2 for BitWiseOr { let elem_count = l1.shape().elem_count(); (d_in1_ptr, d_in2_ptr, elem_count) } - DType::I16 => { - return Err(Error::UnsupportedDTypeForOp(DType::I16, "bitwise-or")); - } DType::BF16 => { return Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise-or")); } @@ -226,7 +222,6 @@ impl CustomOp1 for Leftshift { let result = CpuStorage::I32(result); Ok((result, l1.shape().clone())) } - CpuStorage::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "leftshifr")), CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshifr")), CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshifr")), CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshifr")), @@ -262,9 +257,6 @@ impl CustomOp1 for Leftshift { let elem_count = l1.shape().elem_count(); (d_in1_ptr, elem_count) } - DType::I16 => { - return Err(Error::UnsupportedDTypeForOp(DType::I16, "leftshift")); - } DType::BF16 => { return Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshift")); } diff --git a/mistralrs-quant/src/utils/uqff.rs b/mistralrs-quant/src/utils/uqff.rs index 1520f8e67..e9ca1e573 100644 --- a/mistralrs-quant/src/utils/uqff.rs +++ b/mistralrs-quant/src/utils/uqff.rs @@ -119,7 +119,6 @@ pub(crate) fn deserialize_tensor( DType::I16 => bytes_to_data::(&tensor_data, &dims, device), DType::I32 => bytes_to_data::(&tensor_data, &dims, device), DType::I64 => bytes_to_data::(&tensor_data, &dims, device), - DType::I16 => bytes_to_data::(&tensor_data, &dims, device), DType::U32 => bytes_to_data::(&tensor_data, &dims, device), DType::U8 => bytes_to_data::(&tensor_data, &dims, device), }