-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
zhoujingya
committed
Jan 15, 2025
1 parent
d0c54f6
commit 86bcf52
Showing
3 changed files
with
85 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
numba |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#include <iostream> | ||
|
||
// 这个算法实现了矩阵乘法。它接受三个矩阵A、B和C,以及它们的维度m、k和n。 | ||
// A是一个m行k列的矩阵,B是一个k行n列的矩阵,C是一个m行n列的矩阵。 | ||
// 算法通过三个嵌套的for循环来计算矩阵A和矩阵B的乘积,并将结果存储在矩阵C中。 | ||
// 外层循环遍历矩阵A的行,中间循环遍历矩阵B的列,内层循环计算矩阵A的行和矩阵B的列的点积。 | ||
// A[0][0] * B[0][0] + A[0][1] * B[1][0] + A[0][2] * B[2][0] + A[0][3] * B[3][0] | ||
// = C[0][0] A[0][0] * B[0][1] + A[0][1] * B[1][1] + A[0][2] * B[2][1] + A[0][3] | ||
// * B[3][1] = C[0][1] A[1][0] * B[0][0] + A[1][1] * B[1][0] + A[1][2] * B[2][0] | ||
// + A[1][3] * B[3][0] = C[1][0] 以此类推,计算出C矩阵的所有元素。 | ||
// 这个算法的时间复杂度是O(m*k*n),空间复杂度是O(m*n)。 | ||
// 这个算法是矩阵乘法的标准实现,可以用于任何矩阵乘法的实现。 | ||
|
||
void matMul(float *A, float *B, float *C, int m, int k, int n) { | ||
for (int i = 0; i < m; i++) { | ||
for (int j = 0; j < k; j++) { | ||
for (int l = 0; l < n; l++) { | ||
C[i * n + l] += A[i * k + j] * B[j * n + l]; | ||
} | ||
} | ||
} | ||
} | ||
|
||
int main() { | ||
// m = 3, k = 4, n = 3 | ||
float a[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; // 3 * 4 | ||
float b[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; // 4 * 3 | ||
float dst[9] = {70, 80, 90, 158, 184, 210, 246, 288, 330}; // 3 * 3 | ||
float *c = new float[9]; | ||
matMul(a, b, c, 3, 4, 3); | ||
int fail = 0; | ||
for (int i = 0; i < 9; i++) { | ||
if (c[i] != dst[i]) | ||
fail++; | ||
} | ||
if (fail) | ||
std::cout << "Matrix multiply error\n"; | ||
else | ||
std::cout << "Matrix multiply succeed\n"; | ||
delete[] c; | ||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from numba import cuda | ||
import numpy as np | ||
|
||
|
||
@cuda.jit | ||
def matrix_multiply(A, B, C): | ||
# 获取当前线程的索引 | ||
row = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x | ||
col = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y | ||
|
||
# 检查索引是否在矩阵范围内 | ||
if row < C.shape[0] and col < C.shape[1]: | ||
tmp = 0 | ||
# 计算矩阵乘法的一个元素 | ||
for k in range(A.shape[1]): | ||
tmp += A[row, k] * B[k, col] | ||
C[row, col] = tmp | ||
|
||
|
||
# 创建示例矩阵 | ||
A = np.array([[1, 2], [3, 4]], dtype=np.float32) | ||
B = np.array([[5, 6], [7, 8]], dtype=np.float32) | ||
C = np.zeros((2, 2), dtype=np.float32) | ||
|
||
# 将数据复制到GPU | ||
d_A = cuda.to_device(A) | ||
d_B = cuda.to_device(B) | ||
d_C = cuda.to_device(C) | ||
|
||
# 定义线程块和网格大小 | ||
threadsperblock = (2, 2) | ||
blockspergrid_x = int(np.ceil(A.shape[0] / threadsperblock[0])) | ||
blockspergrid_y = int(np.ceil(B.shape[1] / threadsperblock[1])) | ||
blockspergrid = (blockspergrid_x, blockspergrid_y) | ||
|
||
# 执行核函数 | ||
matrix_multiply[blockspergrid, threadsperblock](d_A, d_B, d_C) | ||
|
||
# 将结果复制回主机 | ||
result = d_C.copy_to_host() | ||
|
||
print(result) |