-
Notifications
You must be signed in to change notification settings - Fork 2
/
gemm_intrinsic.c
68 lines (60 loc) · 3.13 KB
/
gemm_intrinsic.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include "gemm_intrinsic.hpp"
#include "gemm_intrinsic.h"

#include <memory>
extern "C" {
/// Opaque state stored behind the `void *` handle that get_instance() returns.
///
/// The struct carries its own tag name instead of being an anonymous struct
/// named through `typedef`. A class that gets its name for linkage purposes
/// from a typedef must be C-compatible, and a default member initializer
/// (see `instance_ptr`) is not allowed in such a class (P1766R1; clang warns
/// with -Wnon-c-typedef-for-linkage).
struct InstanceHandle {
    /// Type-erased pointer to the heap-allocated IceSword::IntrinsicGemm
    /// instantiation chosen from the dtypes below. It stays nullptr when no
    /// kernel matches the requested combination.
    void *instance_ptr = nullptr;
    /// Data type requested for operand A.
    IceSword::DataType dtype_a{};
    /// Data type requested for operand B (primary selector of the kernel).
    IceSword::DataType dtype_b{};
    /// Data type requested for the result C (recorded only).
    IceSword::DataType dtype_c{};
};
/// Allocates an opaque GEMM handle for the given operand/result data types.
///
/// The kernel is selected on `type_b` first, then on `type_a`:
///   * type_b == DT_INT8 and type_a == DT_INT8 -> IntrinsicGemm<char, char, int>
///   * type_b == DT_FLOAT (any type_a)          -> IntrinsicGemm<float, float, float>
/// Any other combination still yields a handle, but its instance_ptr is
/// nullptr; instance_init()/instance_dispatch() then return S_UnImplError.
/// `type_c` is stored in the handle but does not affect kernel selection.
///
/// @param type_a data type of operand A
/// @param type_b data type of operand B
/// @param type_c data type of result C
/// @return heap-allocated InstanceHandle, type-erased to `void *`.
///         Allocation failures propagate as C++ exceptions. In that case no
///         partially built handle is leaked.
void * get_instance(IceSword::DataType type_a, IceSword::DataType type_b, IceSword::DataType type_c) {
    // Own the handle until it is fully built: if constructing the GEMM object
    // throws, the unique_ptr frees the handle instead of leaking it.
    std::unique_ptr<InstanceHandle> instance_handle(new InstanceHandle);
    if (type_b == IceSword::DT_INT8) {
        if (type_a == IceSword::DT_INT8) {
            instance_handle->instance_ptr =
                static_cast<void *>(new IceSword::IntrinsicGemm<char, char, int>);
        }
        // INT8 B with a non-INT8 A has no kernel; instance_ptr stays nullptr.
    } else if (type_b == IceSword::DT_FLOAT) {
        instance_handle->instance_ptr =
            static_cast<void *>(new IceSword::IntrinsicGemm<float, float, float>);
    }
    instance_handle->dtype_a = type_a;
    instance_handle->dtype_b = type_b;
    instance_handle->dtype_c = type_c;
    // Transfer ownership to the caller across the C interface.
    return static_cast<void *>(instance_handle.release());
}
/// Initializes the GEMM kernel held by a handle from get_instance().
///
/// The transpose flags and sizes are forwarded unchanged to
/// IntrinsicGemm::init. The meaning of m/n/k is presumably the usual GEMM
/// dimensions; confirm against gemm_intrinsic.hpp.
///
/// @param instance_handle handle returned by get_instance(); must not be null
///                        (it is dereferenced without a check)
/// @return the status returned by the kernel's init(), or S_UnImplError when
///         the handle's dtype combination has no kernel.
IceSword::S_Status instance_init(void* instance_handle, const bool trans_a, const bool trans_b,
                                 const int m, const int n, const int k) {
    InstanceHandle *const h = reinterpret_cast<InstanceHandle *>(instance_handle);

    const bool is_int8_gemm =
        h->dtype_b == IceSword::DT_INT8 && h->dtype_a == IceSword::DT_INT8;
    if (is_int8_gemm) {
        auto *gemm = reinterpret_cast<IceSword::IntrinsicGemm<char, char, int> *>(h->instance_ptr);
        return gemm->init(trans_a, trans_b, m, n, k);
    }

    if (h->dtype_b == IceSword::DT_FLOAT) {
        auto *gemm = reinterpret_cast<IceSword::IntrinsicGemm<float, float, float> *>(h->instance_ptr);
        return gemm->init(trans_a, trans_b, m, n, k);
    }

    // INT8 B paired with a non-INT8 A, or any other B type: no kernel exists.
    return IceSword::S_UnImplError;
}
/// Runs the GEMM kernel held by a handle from get_instance().
///
/// The untyped buffers are reinterpreted according to the handle's dtypes:
///   * FLOAT B          -> a, b as `const float *`, c as `float *`
///   * INT8 A and INT8 B -> a, b as `const char *`, c as `int *`
/// alpha and beta are forwarded unchanged to IntrinsicGemm::dispatch.
///
/// @param instance_handle handle returned by get_instance(); must not be null
///                        (it is dereferenced without a check)
/// @return the status returned by the kernel's dispatch(), or S_UnImplError
///         when the handle's dtype combination has no kernel.
IceSword::S_Status instance_dispatch(void* instance_handle, const float alpha,
                                     const float beta, const void* a,
                                     const void* b, void* c) {
    InstanceHandle *const h = reinterpret_cast<InstanceHandle *>(instance_handle);

    if (h->dtype_b == IceSword::DT_FLOAT) {
        auto *gemm = reinterpret_cast<IceSword::IntrinsicGemm<float, float, float> *>(h->instance_ptr);
        return gemm->dispatch(alpha, beta,
                              static_cast<const float *>(a),
                              static_cast<const float *>(b),
                              static_cast<float *>(c));
    }

    if (h->dtype_b == IceSword::DT_INT8 && h->dtype_a == IceSword::DT_INT8) {
        auto *gemm = reinterpret_cast<IceSword::IntrinsicGemm<char, char, int> *>(h->instance_ptr);
        return gemm->dispatch(alpha, beta,
                              static_cast<const char *>(a),
                              static_cast<const char *>(b),
                              static_cast<int *>(c));
    }

    // Unsupported dtype combination: nothing to run.
    return IceSword::S_UnImplError;
}
}