add custom op and support gelu+dropout+linear fusion
jlxue committed Jul 7, 2023
1 parent 5730e93 commit 23e3c7a
Showing 11 changed files with 656 additions and 0 deletions.
Empty file.
135 changes: 135 additions & 0 deletions test/python/dsl_frontend/custom_op.py
@@ -0,0 +1,135 @@
import json
import hashlib
import importlib
import torch
import os
from run_tuned_json_graph import get_device_source
from export_json_graph import get_input_dict, construct_json_graph


dtype_mapping = {
    'float64': torch.float64,
    'float32': torch.float32,
    'float16': torch.float16,
    'int64': torch.int64,
    'int32': torch.int32,
    'int16': torch.int16,
    'int8': torch.int8,
}

# Serialize the fused einstein_v2 IR and its input metadata into a Welder JSON graph.
def generate_welder_graph(ir, feed_list, extra_outputs, tags=""):
    input_dict, kwargs = {}, {}
    for k, i, shape, dtype in feed_list:
        input_dict[k] = {
            'dtype': str(dtype).split('.')[1],
            'shape': list(shape)
        }

    ir = ir.replace('"', '`').replace('\n', ' ').strip()
    input_dict = json.dumps(input_dict)
    extra_outputs = ', '.join(['"%s"' % x for x in extra_outputs])
    expression = f'- einstein_v2("{ir}", input_dict={input_dict}, extra_outputs=[{extra_outputs}]) ## @: {tags}'

    nodes = []
    edges = [[idx, 0] for idx in range(len(feed_list))]
    node_id = len(feed_list)
    nodes.append([node_id, expression, "fused_op", edges])
    nodes.append([node_id + 1, "", "Result", [[node_id, 0]]])

    return json.dumps(nodes, indent=2)
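# For illustration only: with two hypothetical inputs, the serialized graph produced above
# looks roughly like (expression abbreviated):
#   [[2, '- einstein_v2("...", input_dict={...}, extra_outputs=[]) ## @: ', "fused_op", [[0, 0], [1, 0]]],
#    [3, "", "Result", [[2, 0]]]]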


def load_kernel(graph_path):
    raw_model_path = graph_path
    tuned_model_path = os.path.splitext(graph_path)[0] + ".kernel.json"
    with open(raw_model_path) as f:
        raw_json_graph = json.load(f)
    with open(tuned_model_path) as f:
        tuned_json_graph = json.load(f)

    inputs_outputs_info = []
    device_source = get_device_source(raw_json_graph, tuned_json_graph, inputs_outputs_info)

    backend = 'c-cuda'
    lib_name = 'antares_custom_torch_v2_%s' % backend.replace('-', '_')
    try:
        custom_lib = importlib.import_module(lib_name)
    except ImportError:
        print(f'Failed to import {lib_name}.\nPlease install the Custom Plugin for this backend in advance: BACKEND={backend} antares torch-setup')
        raise
    custom_key = custom_lib.inject(device_source)
    return custom_lib, custom_key, inputs_outputs_info

class CompiledKernel:
    def __init__(self, custom_lib, custom_key, inout_info):
        self.custom_lib = custom_lib
        self.custom_key = custom_key
        self.inout_info = inout_info

KERNEL_CACHE = {}

class CustomOp(torch.nn.Module):
    # Note: this constructor (loading an already tuned kernel file) is shadowed by the
    # IR-based __init__ defined below, since the later definition replaces it in the class.
    def __init__(self, kernel_file, device=None):
        super(CustomOp, self).__init__()
        if device is None:
            self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        else:
            self.device = device
        self.custom_lib, self.custom_key, inout_info = load_kernel(kernel_file)
        self.output_list = inout_info[1]


    def __init__(self, ir, input_orders, extra_outputs=[], tags="", steps=1, arch='g3090', device=None):
        super(CustomOp, self).__init__()
        if device is None:
            self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        else:
            self.device = device
        ir = ir.replace('"', '`').replace('\n', ' ').strip()
        self.hash_key = hashlib.sha256(ir.encode()).hexdigest()
        # Reuse an already compiled kernel when the same IR has been built before.
        if self.hash_key in KERNEL_CACHE:
            cache = KERNEL_CACHE[self.hash_key]
            self.custom_lib = cache.custom_lib
            self.custom_key = cache.custom_key
            self.output_list = cache.inout_info[1]
            return

        input_list, index = [], 0
        for k in input_orders:
            if isinstance(input_orders[k], tuple):
                input_list += [(k, index, input_orders[k][2], input_orders[k][1])]
            else:
                input_list += [(k, index, input_orders[k].shape, input_orders[k].dtype)]
            index += 1

        self.input_orders = sorted(input_list, key=lambda x: x[0])
        self.graph = generate_welder_graph(ir, input_list, extra_outputs, tags)

        graph_path = f'/home/jxue/.cache/nnfusion/graph/{self.hash_key}.json'
        tuned_graph_path = f'/home/jxue/.cache/nnfusion/graph/{self.hash_key}.kernel.json'
        if not os.path.exists(tuned_graph_path) or steps > 1:
            with open(graph_path, 'w+') as fp:
                fp.write(self.graph)

            # Tune and compile the graph offline, then load the generated kernel below.
            cmd = f'python3 -m run_compiler {graph_path} {tuned_graph_path} --device 0 --topk {steps} --arch {arch}'
            print(cmd)
            os.system(cmd)
        assert os.path.exists(tuned_graph_path)
        self.custom_lib, self.custom_key, inout_info = load_kernel(graph_path)
        self.output_list = inout_info[1]
        KERNEL_CACHE[self.hash_key] = CompiledKernel(self.custom_lib, self.custom_key, inout_info)

    def input_info(self):
        return self.input_orders

    def forward(self, inputs):
        ordered_inputs = []
        for i in range(len(inputs)):
            inp = inputs[i]
            ordered_inputs.append(inp.contiguous().to(self.device))

        # Allocate output tensors from the tuned kernel's metadata, then invoke the injected kernel.
        outputs = []
        for info in self.output_list:
            out = torch.empty(info[1]['shape'], device=self.device, dtype=dtype_mapping[info[1]['dtype']])
            outputs.append(out)
        self.custom_lib.forward(self.custom_key, ordered_inputs + outputs)
        outputs = outputs[0] if len(outputs) == 1 else tuple(outputs)
        return outputs
120 changes: 120 additions & 0 deletions test/python/dsl_frontend/example.py
@@ -0,0 +1,120 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from custom_op import CustomOp

class CustomLinear(nn.Module):
    def __init__(
        self,
        embed_dim,
        ffn_dim,
        activation_dropout,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
        self.fc2 = nn.Linear(ffn_dim, self.embed_dim)

    def reset_parameters(self):
        self.fc2.reset_parameters()

    def forward(self, x):
        x = F.gelu(x.float()).type_as(x)
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        return x

M_SQRT1_2 = 0.70710678118654752440
M_2_SQRTPI = 1.12837916709551257390
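# M_SQRT1_2 = 1/sqrt(2) and M_2_SQRTPI = 2/sqrt(pi), mirroring C's math.h constants; they are
# used below to express GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) and its derivative in the fused IR.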

class FusedLinearFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, p):
        # Inverted dropout: keep each element with probability (1 - p); the fused IR rescales by 1/(1 - p).
        mask = (torch.rand_like(x) >= p)
        # fused_op = CustomOp(ir=f'''
        #     output0[N0, N2] +=! input0[N0, N1] * input1[N2, N1];
        # ''', input_orders={'input0': x, 'input1': weight}, tags="tensorCoreConfig=(0, 1)", device=device)

        # Fused GELU -> dropout -> linear (+ bias) in a single Welder kernel.
        fused_op = CustomOp(ir=f'''
            m0[N0, N1] = input0[N0, N1].cast(`float32`);
            m1[N0, N1] = m0[N0, N1] * const(0.5).cast(`float32`) * (const(1.0).cast(`float32`) + (m0[N0, N1] * const({M_SQRT1_2}).cast(`float32`)).call(`erf`));
            m2[N0, N1] = m1[N0, N1].cast(`float16`);
            m3[N0, N1] = m2[N0, N1] * input3[N0, N1] / const({1-p}).cast(`float16`);
            m4[N0, N2] +=! m3[N0, N1] * input1[N2, N1];
            output0[N0, N2] = m4[N0, N2] + input2[N0];
        ''', input_orders={'input0': x, 'input1': weight, 'input2': bias, 'input3': mask}, tags="tensorCoreConfig=(0, 1)", device=device)
        y = fused_op([x, weight, bias, mask])
        ctx.save_for_backward(x, weight, mask)
        ctx.p = p
        return y

    @staticmethod
    def backward(ctx, dy):
        x, weight, mask = ctx.saved_tensors
        p = ctx.p
        dbias = torch.sum(dy, dim=0)
        # Weight gradient: recompute GELU(x) and the dropout scaling, then contract with dy.
        dw_op = CustomOp(ir=f'''
            m0[N0, N1] = input0[N0, N1].cast(`float32`);
            m1[N0, N1] = m0[N0, N1] * const(0.5).cast(`float32`) * (const(1.0).cast(`float32`) + (m0[N0, N1] * const({M_SQRT1_2}).cast(`float32`)).call(`erf`));
            m2[N0, N1] = m1[N0, N1].cast(`float16`);
            m3[N0, N1] = m2[N0, N1] * input2[N0, N1] / const({1-p}).cast(`float16`);
            output0[N2, N1] +=! input1[N0, N2] * m3[N0, N1];
        ''', input_orders={'input0': x, 'input1': dy, 'input2': mask}, tags="tensorCoreConfig=(0, 1)", device=device)
        dw = dw_op([x, dy, mask])

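        # For reference, dx_op below applies the GELU derivative to the back-propagated matmul gradient:
        #   d/dx [0.5 * x * (1 + erf(x / sqrt(2)))] = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2 * pi),
        # where the constant M_2_SQRTPI * M_SQRT1_2 * 0.5 equals 1 / sqrt(2 * pi).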
        dx_op = CustomOp(ir=f'''
            m0[N0, N1] +=! input3[N0, N2] * input1[N2, N1];
            m1[N0, N1] = m0[N0, N1] * input2[N0, N1] * const({1-p}).cast(`float16`);
            m2[N0, N1] = m1[N0, N1].cast(`float32`);
            m3[N0, N1] = const(0.5).cast(`float32`) * (const(1.0).cast(`float32`) + (input0[N0, N1] * const({M_SQRT1_2}).cast(`float32`)).call(`erf`));
            m4[N0, N1] = (const(-0.5).cast(`float32`) * input0[N0, N1] * input0[N0, N1]).call(`exp`) * const({M_2_SQRTPI * M_SQRT1_2 * 0.5}).cast(`float32`);
            output0[N0, N1] = m2[N0, N1] * (m3[N0, N1] + input0[N0, N1] * m4[N0, N1]);
        ''', input_orders={'input0': x, 'input1': weight, 'input2': mask, 'input3': dy}, tags="tensorCoreConfig=(0, 1)", device=device)
        dx = dx_op([x, weight, mask, dy])
        return dx, dw, dbias, None

class FusedCustomLinear(nn.Module):
    def __init__(
        self,
        embed_dim,
        ffn_dim,
        activation_dropout,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.activation_dropout = activation_dropout
        self.fc2 = nn.Linear(ffn_dim, self.embed_dim, dtype=torch.float16)

    def reset_parameters(self):
        self.fc2.reset_parameters()

    def forward(self, x):
        return FusedLinearFunc.apply(x, self.fc2.weight, self.fc2.bias, self.activation_dropout)


if __name__ == '__main__':
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    torch.set_default_dtype(torch.float16)
    x = torch.randn(2048, 16384, requires_grad=True, device=device)
    ref = CustomLinear(4096, 16384, 0).to(device)
    fused = FusedCustomLinear(4096, 16384, 0).to(device)

    y_ref = ref(x)
    y_fused = fused(x)

    y_grad = torch.ones_like(y_fused, device=device)
    y_fused.backward(y_grad)

    # start = time.time()
    # for i in range(100):
    #     y = layer.forward(x)
    #     y.backward(y_grad)
    #     #print(x, x.grad, layer.fc2.weight.grad, layer.fc2.bias.grad)
    # end = time.time()
    # print(end-start)





48 changes: 48 additions & 0 deletions test/python/dsl_frontend/kernel_packer.py
@@ -11,6 +11,54 @@
#ifndef __CUDA_COMMON_MACRO__
#define __CUDA_COMMON_MACRO__
__device__ half max(half a, half b)
{
  return __hgt(__half(a), __half(b)) ? a : b;
}
__device__ half min(half a, half b)
{
  return __hlt(__half(a), __half(b)) ? a : b;
}
typedef long long _ll;
#define int64_t _ll
#define __int8_t_defined
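// The wrappers below emulate half-precision math functions that the toolkit lacks by upcasting
// to float, calling the fp32 math routine, and converting the result back to half.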
#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) inline __device__ half HALF_MATH_NAME(half x, half y) { float tmp_x = __half2float(x); float tmp_y = __half2float(y); float result = FP32_MATH_NAME(tmp_x, tmp_y); return __float2half(result); }
#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) inline __device__ half HALF_MATH_NAME(half x) { float tmp_x = __half2float(x); float result = FP32_MATH_NAME(tmp_x); return __float2half(result); }
CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)
#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY
#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY
// Pack two half values.
inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}
// There is no make_int8 in CUDA, but TVM codegen seems to use it.
inline __device__ longlong4 make_int8(int x0, int x1, int x2, int x3, int x4, int x5, int x6, int x7) {
  int2 i0 = make_int2(x0, x1);
  int2 i1 = make_int2(x2, x3);
  int2 i2 = make_int2(x4, x5);
  int2 i3 = make_int2(x6, x7);
  long long l0 = *(long long*)&i0;
  long long l1 = *(long long*)&i1;
  long long l2 = *(long long*)&i2;
  long long l3 = *(long long*)&i3;
  return make_longlong4(l0, l1, l2, l3);
}
#if (__CUDA_ARCH__ >= 600)
__forceinline__ __device__ __half hmax(const __half &a, const __half &b) { return a > b ? a : b; }
