diff --git a/test/python/dsl_frontend/__init__.py b/test/python/dsl_frontend/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/python/dsl_frontend/custom_op.py b/test/python/dsl_frontend/custom_op.py
new file mode 100755
index 000000000..d3af63932
--- /dev/null
+++ b/test/python/dsl_frontend/custom_op.py
@@ -0,0 +1,135 @@
+import json
+import hashlib
+import importlib
+import torch
+import os
+from run_tuned_json_graph import get_device_source
+from export_json_graph import get_input_dict, construct_json_graph
+
+
+dtype_mapping = {
+    'float64': torch.float64,
+    'float32': torch.float32,
+    'float16': torch.float16,
+    'int64': torch.int64,
+    'int32': torch.int32,
+    'int16': torch.int16,
+    'int8': torch.int8,
+}
+
+# Cache directory for generated and tuned graphs.
+GRAPH_CACHE_DIR = os.path.expanduser('~/.cache/nnfusion/graph')
+
+
+def generate_welder_graph(ir, feed_list, extra_outputs, tags=""):
+    input_dict = {}
+    for k, i, shape, dtype in feed_list:
+        input_dict[k] = {
+            'dtype': str(dtype).split('.')[1],
+            'shape': list(shape)
+        }
+
+    ir = ir.replace('"', '`').replace('\n', ' ').strip()
+    input_dict = json.dumps(input_dict)
+    extra_outputs = ', '.join(['"%s"' % x for x in extra_outputs])
+    expression = f'- einstein_v2("{ir}", input_dict={input_dict}, extra_outputs=[{extra_outputs}]) ## @: {tags}'
+
+    nodes = []
+    edges = [[id, 0] for id in range(len(feed_list))]
+    node_id = len(feed_list)
+    nodes.append([node_id, expression, "fused_op", edges])
+    nodes.append([node_id + 1, "", "Result", [[node_id, 0]]])
+
+    return json.dumps(nodes, indent=2)
+
+
+def load_kernel(graph_path):
+    raw_model_path = graph_path
+    # str.strip() removes a *set of characters*, not a suffix, so slice it off.
+    tuned_model_path = graph_path[:-len('.json')] + ".kernel.json"
+    with open(raw_model_path) as f:
+        raw_json_graph = json.load(f)
+    with open(tuned_model_path) as f:
+        tuned_json_graph = json.load(f)
+
+    inputs_outputs_info = []
+    device_source = get_device_source(raw_json_graph, tuned_json_graph, inputs_outputs_info)
+
+    backend = 'c-cuda'
+    lib_name = 'antares_custom_torch_v2_%s' % backend.replace('-', '_')
+    try:
+        custom_lib = importlib.import_module(lib_name)
+    except ImportError:
+        print(f'Failed to import {lib_name}.\nPlease install Custom Plugin for backend in advance: BACKEND={backend} antares torch-setup')
+        raise
+    custom_key = custom_lib.inject(device_source)
+    return custom_lib, custom_key, inputs_outputs_info
+
+
+class CompiledKernel:
+    def __init__(self, custom_lib, custom_key, inout_info):
+        self.custom_lib = custom_lib
+        self.custom_key = custom_key
+        self.inout_info = inout_info
+
+
+KERNEL_CACHE = {}
+
+
+class CustomOp(torch.nn.Module):
+    @classmethod
+    def from_kernel_file(cls, kernel_file, device=None):
+        # Alternate constructor for an already-tuned kernel file; kept as a
+        # classmethod so it does not shadow the IR-based __init__ below.
+        op = cls.__new__(cls)
+        torch.nn.Module.__init__(op)
+        op.device = device if device is not None else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
+        op.custom_lib, op.custom_key, inout_info = load_kernel(kernel_file)
+        op.output_list = inout_info[1]
+        return op
+
+    def __init__(self, ir, input_orders, extra_outputs=[], tags="", steps=1, arch='g3090', device=None):
+        super(CustomOp, self).__init__()
+        if device is None:
+            self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        else:
+            self.device = device
+        ir = ir.replace('"', '`').replace('\n', ' ').strip()
+        self.hash_key = hashlib.sha256(ir.encode()).hexdigest()
+        if self.hash_key in KERNEL_CACHE:
+            cache = KERNEL_CACHE[self.hash_key]
+            self.custom_lib = cache.custom_lib
+            self.custom_key = cache.custom_key
+            self.output_list = cache.inout_info[1]
+            return
+
+        input_list, index = [], 0
+        for k in input_orders:
+            if isinstance(input_orders[k], tuple):
+                input_list += [(k, index, input_orders[k][2], input_orders[k][1])]
+            else:
+                input_list += [(k, index, input_orders[k].shape, input_orders[k].dtype)]
+            index += 1
+
+        self.input_list = input_list
+        self.input_orders = sorted(input_list, key=lambda x: x[0])
+        self.graph = generate_welder_graph(ir, input_list, extra_outputs, tags)
+
+        os.makedirs(GRAPH_CACHE_DIR, exist_ok=True)
+        graph_path = f'{GRAPH_CACHE_DIR}/{self.hash_key}.json'
+        tuned_graph_path = f'{GRAPH_CACHE_DIR}/{self.hash_key}.kernel.json'
+        if not os.path.exists(tuned_graph_path) or steps > 1:
+            with open(graph_path, 'w+') as fp:
+                fp.write(self.graph)
+
+            cmd = f'python3 -m run_compiler {graph_path} {tuned_graph_path} --device 0 --topk {steps} --arch {arch}'
+            print(cmd)
+            os.system(cmd)
+            assert os.path.exists(tuned_graph_path)
+        self.custom_lib, self.custom_key, inout_info = load_kernel(graph_path)
+        self.output_list = inout_info[1]
+        KERNEL_CACHE[self.hash_key] = CompiledKernel(self.custom_lib, self.custom_key, inout_info)
+
+    def input_info(self):
+        return self.input_list
+
+    def forward(self, inputs):
+        ordered_inputs = [inp.contiguous().to(self.device) for inp in inputs]
+
+        outputs = []
+        for info in self.output_list:
+            out = torch.empty(info[1]['shape'], device=self.device, dtype=dtype_mapping[info[1]['dtype']])
+            outputs.append(out)
+        self.custom_lib.forward(self.custom_key, ordered_inputs + outputs)
+        return outputs[0] if len(outputs) == 1 else tuple(outputs)
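+
+
+if __name__ == '__main__':
+    # Illustrative smoke test only: a plain fp16 matmul expressed in the DSL,
+    # mirroring the commented-out op in example.py below. Assumes a CUDA
+    # device, run_compiler on the module path, and the
+    # antares_custom_torch_v2_c_cuda plugin installed.
+    x = torch.randn(2048, 16384, dtype=torch.float16, device='cuda')
+    w = torch.randn(4096, 16384, dtype=torch.float16, device='cuda')
+    matmul = CustomOp(ir='output0[N0, N2] +=! input0[N0, N1] * input1[N2, N1];',
+                      input_orders={'input0': x, 'input1': w},
+                      tags="tensorCoreConfig=(0, 1)")
+    y = matmul([x, w])  # [2048, 4096]; served from KERNEL_CACHE on reuse
+    print(y.shape)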
diff --git a/test/python/dsl_frontend/example.py b/test/python/dsl_frontend/example.py
new file mode 100755
index 000000000..daac9e131
--- /dev/null
+++ b/test/python/dsl_frontend/example.py
@@ -0,0 +1,120 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import time
+from custom_op import CustomOp
+
+
+class CustomLinear(nn.Module):
+    def __init__(
+        self,
+        embed_dim,
+        ffn_dim,
+        activation_dropout,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
+        self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
+
+    def reset_parameters(self):
+        self.fc2.reset_parameters()
+
+    def forward(self, x):
+        x = F.gelu(x.float()).type_as(x)
+        x = self.activation_dropout_module(x)
+        x = self.fc2(x)
+        return x
+
+
+M_SQRT1_2 = 0.70710678118654752440
+M_2_SQRTPI = 1.12837916709551257390
+
+
+class FusedLinearFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, weight, bias, p):
+        # Keep the mask in the compute dtype; the tuned kernels expect float16.
+        mask = (torch.rand_like(x) >= p).to(x.dtype)
+# fused_op = CustomOp(ir=f'''
+# output0[N0, N2] +=! input0[N0, N1] * input1[N2, N1];
+# ''', input_orders={'input0': x, 'input1': weight}, tags="tensorCoreConfig=(0, 1)", device=x.device)
+
+        fused_op = CustomOp(ir=f'''
+m0[N0, N1] = input0[N0, N1].cast(`float32`);
+m1[N0, N1] = m0[N0, N1] * const(0.5).cast(`float32`) * (const(1.0).cast(`float32`) + (m0[N0, N1] * const({M_SQRT1_2}).cast(`float32`)).call(`erf`));
+m2[N0, N1] = m1[N0, N1].cast(`float16`);
+m3[N0, N1] = m2[N0, N1] * input3[N0, N1] / const({1-p}).cast(`float16`);
+m4[N0, N2] +=! m3[N0, N1] * input1[N2, N1];
+output0[N0, N2] = m4[N0, N2] + input2[N2];
+''', input_orders={'input0': x, 'input1': weight, 'input2': bias, 'input3': mask}, tags="tensorCoreConfig=(0, 1)", device=x.device)
+        y = fused_op([x, weight, bias, mask])
+        ctx.save_for_backward(x, weight, mask)
+        ctx.p = p
+        return y
+
+    @staticmethod
+    def backward(ctx, dy):
+        x, weight, mask = ctx.saved_tensors
+        p = ctx.p
+        dbias = torch.sum(dy, dim=0)
+        dw_op = CustomOp(ir=f'''
+m0[N0, N1] = input0[N0, N1].cast(`float32`);
+m1[N0, N1] = m0[N0, N1] * const(0.5).cast(`float32`) * (const(1.0).cast(`float32`) + (m0[N0, N1] * const({M_SQRT1_2}).cast(`float32`)).call(`erf`));
+m2[N0, N1] = m1[N0, N1].cast(`float16`);
+m3[N0, N1] = m2[N0, N1] * input2[N0, N1] / const({1-p}).cast(`float16`);
+output0[N2, N1] +=! input1[N0, N2] * m3[N0, N1];
+''', input_orders={'input0': x, 'input1': dy, 'input2': mask}, tags="tensorCoreConfig=(0, 1)", device=dy.device)
+        dw = dw_op([x, dy, mask])
+
+        dx_op = CustomOp(ir=f'''
+m0[N0, N1] +=! input3[N0, N2] * input1[N2, N1];
+m1[N0, N1] = m0[N0, N1] * input2[N0, N1] / const({1-p}).cast(`float16`);
+m2[N0, N1] = m1[N0, N1].cast(`float32`);
+m3[N0, N1] = const(0.5).cast(`float32`) * (const(1.0).cast(`float32`) + (input0[N0, N1] * const({M_SQRT1_2}).cast(`float32`)).call(`erf`));
+m4[N0, N1] = (const(-0.5).cast(`float32`) * input0[N0, N1] * input0[N0, N1]).call(`exp`) * const({M_2_SQRTPI * M_SQRT1_2 * 0.5}).cast(`float32`);
+output0[N0, N1] = m2[N0, N1] * (m3[N0, N1] + input0[N0, N1] * m4[N0, N1]);
+''', input_orders={'input0': x, 'input1': weight, 'input2': mask, 'input3': dy}, tags="tensorCoreConfig=(0, 1)", device=dy.device)
+        dx = dx_op([x, weight, mask, dy])
+        return dx, dw, dbias, None
+
+
+class FusedCustomLinear(nn.Module):
+    def __init__(
+        self,
+        embed_dim,
+        ffn_dim,
+        activation_dropout,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.activation_dropout = activation_dropout
+        self.fc2 = nn.Linear(ffn_dim, self.embed_dim, dtype=torch.float16)
+
+    def reset_parameters(self):
+        self.fc2.reset_parameters()
+
+    def forward(self, x):
+        return FusedLinearFunc.apply(x, self.fc2.weight, self.fc2.bias, self.activation_dropout)
+
+
+if __name__ == '__main__':
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    torch.set_default_dtype(torch.float16)
+    x = torch.randn(2048, 16384, requires_grad=True, device=device)
+    ref = CustomLinear(4096, 16384, 0).to(device)
+    fused = FusedCustomLinear(4096, 16384, 0).to(device)
+    # Share one set of parameters so the two paths are comparable.
+    with torch.no_grad():
+        ref.fc2.weight.copy_(fused.fc2.weight)
+        ref.fc2.bias.copy_(fused.fc2.bias)
+
+    y_ref = ref(x)
+    y_fused = fused(x)
+    # With dropout disabled (p = 0) the outputs should agree up to fp16 error.
+    print('max abs diff:', (y_fused - y_ref).abs().max().item())
+
+    y_grad = torch.ones_like(y_fused, device=device)
+    y_fused.backward(y_grad)
+
+    # start = time.time()
+    # for i in range(100):
+    #     y = layer.forward(x)
+    #     y.backward(y_grad)
+    #     #print(x, x.grad, layer.fc2.weight.grad, layer.fc2.bias.grad)
+    # end = time.time()
+    # print(end-start)
diff --git a/test/python/dsl_frontend/kernel_packer.py b/test/python/dsl_frontend/kernel_packer.py
index 72b522c4a..b5cf77f30 100644
--- a/test/python/dsl_frontend/kernel_packer.py
+++ b/test/python/dsl_frontend/kernel_packer.py
@@ -11,6 +11,54 @@
 #ifndef __CUDA_COMMON_MACRO__
 #define __CUDA_COMMON_MACRO__
 
+__device__ half max(half a, half b)
+{
+  return __hgt(__half(a), __half(b)) ? a : b;
+}
+__device__ half min(half a, half b)
+{
+  return __hlt(__half(a), __half(b)) ? a : b;
+}
+
+typedef long long _ll;
+#define int64_t _ll
+#define __int8_t_defined
+
+#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) inline __device__ half HALF_MATH_NAME(half x, half y) { float tmp_x = __half2float(x); float tmp_y = __half2float(y); float result = FP32_MATH_NAME(tmp_x, tmp_y); return __float2half(result); }
+
+#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) inline __device__ half HALF_MATH_NAME(half x) { float tmp_x = __half2float(x); float result = FP32_MATH_NAME(tmp_x); return __float2half(result); }
+
+CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)
+CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)
+CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)
+CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)
+CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erff)
+
+#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY
+#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY
+
+// Pack two half values into one 32-bit word.
+inline __device__ __host__ unsigned
+__pack_half2(const half x, const half y) {
+  unsigned v0 = *((unsigned short *)&x);
+  unsigned v1 = *((unsigned short *)&y);
+  return (v1 << 16) | v0;
+}
+
+// There is no make_int8 in CUDA, but TVM codegen seems to use it.
+inline __device__ longlong4 make_int8(int x0, int x1, int x2, int x3, int x4, int x5, int x6, int x7) {
+  int2 i0 = make_int2(x0, x1);
+  int2 i1 = make_int2(x2, x3);
+  int2 i2 = make_int2(x4, x5);
+  int2 i3 = make_int2(x6, x7);
+  long long l0 = *(long long*)&i0;
+  long long l1 = *(long long*)&i1;
+  long long l2 = *(long long*)&i2;
+  long long l3 = *(long long*)&i3;
+  return make_longlong4(l0, l1, l2, l3);
+}
+
 
 #if (__CUDA_ARCH__ >= 600)
 __forceinline__ __device__ __half hmax(const __half &a, const __half &b) { return a > b ? a : b; }
diff --git a/test/python/dsl_frontend/linear_fusion/README.md b/test/python/dsl_frontend/linear_fusion/README.md
new file mode 100644
index 000000000..f9e0d5010
--- /dev/null
+++ b/test/python/dsl_frontend/linear_fusion/README.md
@@ -0,0 +1,106 @@
+# Pattern: Cast + Gelu + Cast + Dropout + Matmul + BiasAdd
+
+https://github.com/microsoft/torchscale/blob/main/torchscale/component/feedforward_network.py#L133-L137
+
+## Forward:
+
+```
+[2048, 4096] = f([2048, 16384], [4096, 16384], [4096], [2048, 16384])
+y = f(x, w, b, mask):
+    x = x.cast(float32)
+    x = x * 0.5 * (1 + erf(x * M_SQRT1_2))  # GELU
+    x = x.cast(float16)
+    x = x * mask / (1 - p)  # dropout(x, mask), rescaled by 1/(1-p)
+    y = x @ w^T             # [2048, 4096] = [2048, 16384] @ [16384, 4096]
+    y = y + b               # bias over the 4096 axis
+```
+
+The corresponding DSL expression (p = 0.5, so 1 - p appears as `const(0.5)`):
+
+```
+- einstein_v2(\"
+m0[N0, N1] = input0[N0, N1].cast(`float32`);
+m1[N0, N1] = m0[N0, N1] * const(0.5).cast(`float32`) * (const(1.0).cast(`float32`) + (m0[N0, N1] * const(0.70710678118654752440).cast(`float32`)).call(`erf`));
+m2[N0, N1] = m1[N0, N1].cast(`float16`);
+m3[N0, N1] = m2[N0, N1] * input3[N0, N1] / const(0.5).cast(`float16`);
+m4[N0, N2] +=! m3[N0, N1] * input1[N2, N1];
+output0[N0, N2] = m4[N0, N2] + input2[N2];\",
+input_dict={
+  \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [2048, 16384]},
+  \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [4096, 16384]},
+  \"input2\" : { \"dtype\" : \"float16\", \"shape\" : [4096]},
+  \"input3\" : { \"dtype\" : \"float16\", \"shape\" : [2048, 16384]}})
+## @: tensorCoreConfig=(0, 1)
+```
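+
+Equivalently, in eager PyTorch (a minimal reference sketch of the same computation, assuming p = 0.5 and a precomputed float16 mask; the function name is illustrative):
+
+```python
+import torch
+
+def forward_reference(x, w, b, mask, p=0.5):
+    # x: [2048, 16384] fp16, w: [4096, 16384] fp16, b: [4096] fp16, mask: [2048, 16384] fp16
+    g = x.float()
+    g = g * 0.5 * (1.0 + torch.erf(g * 0.70710678118654752440))  # GELU (m0, m1)
+    g = g.half() * mask / (1 - p)                                # m2, m3: dropout scaling
+    return g @ w.t() + b                                         # m4, then bias over N2
+```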
+
+## Backward:
+
+### db
+
+```
+[4096] = f0([2048, 4096])
+db = f0(dy)
+db = reduce_sum(dy, 0)
+```
+
+### dw
+
+```
+[4096, 16384] = f1([2048, 16384], [2048, 4096], [2048, 16384])
+dw = f1(x, dy, mask)
+    x = x.cast(float32)
+    x = x * 0.5 * (1 + erf(x * M_SQRT1_2))  # GELU
+    x = x.cast(float16)
+    x = x * mask / (1 - p)  # dropout(x, mask), rescaled by 1/(1-p)
+    dw = dy^T @ x           # [4096, 16384] = [4096, 2048] @ [2048, 16384]
+```
+
+```
+- einstein_v2(\"
+m0[N0, N1] = input0[N0, N1].cast(`float32`);
+m1[N0, N1] = m0[N0, N1] * const(0.5).cast(`float32`) * (const(1.0).cast(`float32`) + (m0[N0, N1] * const(0.70710678118654752440).cast(`float32`)).call(`erf`));
+m2[N0, N1] = m1[N0, N1].cast(`float16`);
+m3[N0, N1] = m2[N0, N1] * input2[N0, N1] / const(0.5).cast(`float16`);
+output0[N2, N1] +=! input1[N0, N2] * m3[N0, N1];\",
+input_dict={
+  \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [2048, 16384]},
+  \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [2048, 4096]},
+  \"input2\" : { \"dtype\" : \"float16\", \"shape\" : [2048, 16384]}})
+## @: tensorCoreConfig=(0, 1)
+```
+
+### dx
+
+```
+dx = f2(x, w, mask, dy)
+    dx = dy @ w             # [2048, 16384] = [2048, 4096] @ [4096, 16384]
+    dx = dx * mask / (1 - p)
+    dx = dx.cast(float32)
+    # dx = dx * (cdf + x * pdf), with cdf = 0.5 * (1 + erf(x * M_SQRT1_2))
+    # and pdf = exp(-0.5 * x * x) * M_2_SQRTPI * M_SQRT1_2 * 0.5
+    dx = dx * (0.5 * (1 + erf(x * M_SQRT1_2)) + x * (exp(-0.5 * x * x) * M_2_SQRTPI * M_SQRT1_2 * 0.5))
+```
+
+```
+- einstein_v2(\"
+m0[N0, N1] +=! input3[N0, N2] * input1[N2, N1];
+m1[N0, N1] = m0[N0, N1] * input2[N0, N1] / const(0.5).cast(`float16`);
+m2[N0, N1] = m1[N0, N1].cast(`float32`);
+m3[N0, N1] = const(0.5).cast(`float32`) * (const(1.0).cast(`float32`) + (input0[N0, N1] * const(0.70710678118654752440).cast(`float32`)).call(`erf`));
+m4[N0, N1] = (const(-0.5).cast(`float32`) * input0[N0, N1] * input0[N0, N1]).call(`exp`) * const(0.3989422804014327).cast(`float32`);
+output0[N0, N1] = m2[N0, N1] * (m3[N0, N1] + input0[N0, N1] * m4[N0, N1]);
+\",
+input_dict={
+  \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [2048, 16384]},
+  \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [4096, 16384]},
+  \"input2\" : { \"dtype\" : \"float16\", \"shape\" : [2048, 16384]},
+  \"input3\" : { \"dtype\" : \"float16\", \"shape\" : [2048, 4096]}})
+## @: tensorCoreConfig=(0, 1)
+```
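+
+For cross-checking, the three gradients in eager PyTorch (a sketch under the same shapes and p = 0.5; `backward_reference` is an illustrative name, not part of the test code):
+
+```python
+import torch
+
+M_SQRT1_2 = 0.70710678118654752440
+M_2_SQRTPI = 1.12837916709551257390
+
+def backward_reference(x, w, mask, dy, p=0.5):
+    db = dy.sum(dim=0)                             # [4096]
+    g = x.float()
+    cdf = 0.5 * (1.0 + torch.erf(g * M_SQRT1_2))   # GELU cdf term
+    h = (g * cdf).half() * mask / (1 - p)          # recomputed dropout(GELU(x))
+    dw = dy.t() @ h                                # [4096, 16384]
+    dh = (dy @ w) * mask / (1 - p)                 # back through matmul and dropout
+    pdf = torch.exp(-0.5 * g * g) * (M_2_SQRTPI * M_SQRT1_2 * 0.5)
+    dx = dh.float() * (cdf + g * pdf)              # GELU backward, fp32 like output0
+    return dx, dw, db
+```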
+
+## Reference:
+
+M_SQRT1_2 = 0.70710678118654752440
+M_2_SQRTPI = 1.12837916709551257390
+
+https://github.com/pytorch/pytorch/blob/c24b61bc20f76c238e742b765a9efe9ae20c7c03/aten/src/ATen/native/cuda/ActivationGeluKernel.cu
\ No newline at end of file
diff --git a/test/python/dsl_frontend/linear_fusion/bwd.dw.json b/test/python/dsl_frontend/linear_fusion/bwd.dw.json
new file mode 100644
index 000000000..27647ab29
--- /dev/null
+++ b/test/python/dsl_frontend/linear_fusion/bwd.dw.json
@@ -0,0 +1,32 @@
+[
+  [
+    3,
+    "- einstein_v2(\" m0[N0, N1] = input0[N0, N1].cast(`float32`); m1[N0, N1] = m0[N0, N1] * const(0.5).cast(`float32`) * (const(1.0).cast(`float32`) + (m0[N0, N1] * const(0.70710678118654752440).cast(`float32`)).call(`erf`)); m2[N0, N1] = m1[N0, N1].cast(`float16`); m3[N0, N1] = m2[N0, N1] * input2[N0, N1] / const(0.5).cast(`float16`); output0[N2, N1] +=! input1[N0, N2] * m3[N0, N1];\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [2048, 16384]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [2048, 4096]}, \"input2\" : { \"dtype\" : \"float16\", \"shape\" : [2048, 16384]}})## @: tensorCoreConfig=(0, 1)",
+    "FusedLayer",
+    [
+      [
+        0,
+        0
+      ],
+      [
+        1,
+        0
+      ],
+      [
+        2,
+        0
+      ]
+    ]
+  ],
+  [
+    4,
+    "",
+    "Result",
+    [
+      [
+        3,
+        0
+      ]
+    ]
+  ]
+]
\ No newline at end of file
diff --git a/test/python/dsl_frontend/linear_fusion/bwd.dw.kernel.json b/test/python/dsl_frontend/linear_fusion/bwd.dw.kernel.json
new file mode 100644
index 000000000..cccaa7117
--- /dev/null
+++ b/test/python/dsl_frontend/linear_fusion/bwd.dw.kernel.json
@@ -0,0 +1,45 @@
+[
+  {
+    "nodes": [
+      3
+    ],
+    "node_names": [
+      "FusedLayer_3"
+    ],
+    "group_id": 0,
+    "input_desc": [
+      [
+        3,
+        0
+      ],
+      [
+        3,
+        1
+      ],
+      [
+        3,
+        2
+      ]
+    ],
+    "output_desc": [
+      [
+        3,
+        0
+      ]
+    ],
+    "code": "__global__ void __launch_bounds__(128) Group0(half* __restrict__ input0, half* __restrict__ input1, half* __restrict__ input2, half* __restrict__ output0) {\n nvcuda::wmma::fragment output0_wmma_accumulator[16];\n __shared__ half input1_shared[4352];\n __shared__ half m3_shared[4352];\n nvcuda::wmma::fragment input1_shared_wmma_matrix_a[4];\n nvcuda::wmma::fragment m3_shared_wmma_matrix_b[4];\n for (int N2_c_outer_init = 0; N2_c_outer_init < 4; ++N2_c_outer_init) {\n for (int N1_c_outer_init = 0; N1_c_outer_init < 4; ++N1_c_outer_init) {\n nvcuda::wmma::fill_fragment(output0_wmma_accumulator[((N2_c_outer_init * 4) + N1_c_outer_init)], 0.000000e+00f);\n }\n }\n for (int N0_outer = 0; N0_outer < 64; ++N0_outer) {\n __syncthreads();\n *(uint4*)(input1_shared + (((((int)threadIdx.y) * 272) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(input1 + 
(((((N0_outer * 131072) + (((int)threadIdx.y) * 8192)) + ((((int)threadIdx.x) >> 4) * 4096)) + ((((int)blockIdx.x) >> 7) * 128)) + ((((int)threadIdx.x) & 15) * 8)));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 272) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8)) + 1088)) = *(uint4*)(input1 + ((((((N0_outer * 131072) + (((int)threadIdx.y) * 8192)) + ((((int)threadIdx.x) >> 4) * 4096)) + ((((int)blockIdx.x) >> 7) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 32768));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 272) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8)) + 2176)) = *(uint4*)(input1 + ((((((N0_outer * 131072) + (((int)threadIdx.y) * 8192)) + ((((int)threadIdx.x) >> 4) * 4096)) + ((((int)blockIdx.x) >> 7) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 65536));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 272) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8)) + 3264)) = *(uint4*)(input1 + ((((((N0_outer * 131072) + (((int)threadIdx.y) * 8192)) + ((((int)threadIdx.x) >> 4) * 4096)) + ((((int)blockIdx.x) >> 7) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 98304));\n for (int ax0_ax1_fused_inner_s = 0; ax0_ax1_fused_inner_s < 4; ++ax0_ax1_fused_inner_s) {\n m3_shared[(((((int)threadIdx.y) * 136) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s)] = ((((half)((((float)input0[(((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s)]) * 5.000000e-01f) * (1.000000e+00f + erff((((float)input0[(((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s)]) * 7.071068e-01f))))) * input2[(((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s)]) * __float2half_rn(2.000000e+00f));\n }\n for (int ax0_ax1_fused_inner_s_1 = 0; ax0_ax1_fused_inner_s_1 < 4; ++ax0_ax1_fused_inner_s_1) {\n m3_shared[((((((int)threadIdx.y) * 136) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_1) + 544)] = ((((half)((((float)input0[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_1) + 65536)]) * 5.000000e-01f) * (1.000000e+00f + erff((((float)input0[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_1) + 65536)]) * 7.071068e-01f))))) * input2[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_1) + 65536)]) * __float2half_rn(2.000000e+00f));\n }\n for (int ax0_ax1_fused_inner_s_2 = 0; ax0_ax1_fused_inner_s_2 < 4; ++ax0_ax1_fused_inner_s_2) {\n m3_shared[((((((int)threadIdx.y) * 136) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_2) + 1088)] = ((((half)((((float)input0[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_2) + 131072)]) * 5.000000e-01f) * (1.000000e+00f + erff((((float)input0[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_2) + 131072)]) * 7.071068e-01f))))) * input2[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 
4)) + ax0_ax1_fused_inner_s_2) + 131072)]) * __float2half_rn(2.000000e+00f));\n }\n for (int ax0_ax1_fused_inner_s_3 = 0; ax0_ax1_fused_inner_s_3 < 4; ++ax0_ax1_fused_inner_s_3) {\n m3_shared[((((((int)threadIdx.y) * 136) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_3) + 1632)] = ((((half)((((float)input0[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_3) + 196608)]) * 5.000000e-01f) * (1.000000e+00f + erff((((float)input0[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_3) + 196608)]) * 7.071068e-01f))))) * input2[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_3) + 196608)]) * __float2half_rn(2.000000e+00f));\n }\n for (int ax0_ax1_fused_inner_s_4 = 0; ax0_ax1_fused_inner_s_4 < 4; ++ax0_ax1_fused_inner_s_4) {\n m3_shared[((((((int)threadIdx.y) * 136) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_4) + 2176)] = ((((half)((((float)input0[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_4) + 262144)]) * 5.000000e-01f) * (1.000000e+00f + erff((((float)input0[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_4) + 262144)]) * 7.071068e-01f))))) * input2[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_4) + 262144)]) * __float2half_rn(2.000000e+00f));\n }\n for (int ax0_ax1_fused_inner_s_5 = 0; ax0_ax1_fused_inner_s_5 < 4; ++ax0_ax1_fused_inner_s_5) {\n m3_shared[((((((int)threadIdx.y) * 136) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_5) + 2720)] = ((((half)((((float)input0[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_5) + 327680)]) * 5.000000e-01f) * (1.000000e+00f + erff((((float)input0[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_5) + 327680)]) * 7.071068e-01f))))) * input2[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_5) + 327680)]) * __float2half_rn(2.000000e+00f));\n }\n for (int ax0_ax1_fused_inner_s_6 = 0; ax0_ax1_fused_inner_s_6 < 4; ++ax0_ax1_fused_inner_s_6) {\n m3_shared[((((((int)threadIdx.y) * 136) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_6) + 3264)] = ((((half)((((float)input0[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_6) + 393216)]) * 5.000000e-01f) * (1.000000e+00f + erff((((float)input0[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_6) + 393216)]) * 7.071068e-01f))))) * input2[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_6) + 393216)]) * __float2half_rn(2.000000e+00f));\n }\n for (int ax0_ax1_fused_inner_s_7 = 0; ax0_ax1_fused_inner_s_7 < 4; ++ax0_ax1_fused_inner_s_7) {\n 
m3_shared[((((((int)threadIdx.y) * 136) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_7) + 3808)] = ((((half)((((float)input0[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_7) + 458752)]) * 5.000000e-01f) * (1.000000e+00f + erff((((float)input0[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_7) + 458752)]) * 7.071068e-01f))))) * input2[((((((N0_outer * 524288) + (((int)threadIdx.y) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + (((int)threadIdx.x) * 4)) + ax0_ax1_fused_inner_s_7) + 458752)]) * __float2half_rn(2.000000e+00f));\n }\n __syncthreads();\n for (int N0_inner_outer = 0; N0_inner_outer < 2; ++N0_inner_outer) {\n for (int ax1_outer = 0; ax1_outer < 4; ++ax1_outer) {\n nvcuda::wmma::load_matrix_sync(input1_shared_wmma_matrix_a[ax1_outer], (&(input1_shared[(((N0_inner_outer * 2176) + ((((int)threadIdx.y) >> 1) * 64)) + (ax1_outer * 16))])), 136);\n }\n for (int ax1_outer_1 = 0; ax1_outer_1 < 4; ++ax1_outer_1) {\n nvcuda::wmma::load_matrix_sync(m3_shared_wmma_matrix_b[ax1_outer_1], (&(m3_shared[(((N0_inner_outer * 2176) + ((((int)threadIdx.y) & 1) * 64)) + (ax1_outer_1 * 16))])), 136);\n }\n for (int N2_c_outer = 0; N2_c_outer < 4; ++N2_c_outer) {\n for (int N1_c_outer = 0; N1_c_outer < 4; ++N1_c_outer) {\n nvcuda::wmma::mma_sync(output0_wmma_accumulator[((N2_c_outer * 4) + N1_c_outer)], input1_shared_wmma_matrix_a[N2_c_outer], m3_shared_wmma_matrix_b[N1_c_outer], output0_wmma_accumulator[((N2_c_outer * 4) + N1_c_outer)]);\n }\n }\n }\n }\n __syncthreads();\n for (int N2_inner_inner_outer = 0; N2_inner_inner_outer < 4; ++N2_inner_inner_outer) {\n for (int N1_inner_inner_outer = 0; N1_inner_inner_outer < 4; ++N1_inner_inner_outer) {\n nvcuda::wmma::store_matrix_sync((&(output0[(((((((((int)blockIdx.x) >> 7) * 2097152) + ((((int)threadIdx.y) >> 1) * 1048576)) + (N2_inner_inner_outer * 262144)) + ((((int)blockIdx.x) & 127) * 128)) + ((((int)threadIdx.y) & 1) * 64)) + (N1_inner_inner_outer * 16))])), output0_wmma_accumulator[((N2_inner_inner_outer * 4) + N1_inner_inner_outer)], 16384, nvcuda::wmma::mem_row_major);\n }\n }\n __syncthreads();\n}\n\n", + "block_size": [ + 32, + 4, + 1 + ], + "grid_size": [ + 4096, + 1, + 1 + ], + "latency": 5.714944362640381, + "name": "Group0", + "gain": 0 + } +] \ No newline at end of file diff --git a/test/python/dsl_frontend/linear_fusion/bwd.dx.json b/test/python/dsl_frontend/linear_fusion/bwd.dx.json new file mode 100644 index 000000000..e8e106a78 --- /dev/null +++ b/test/python/dsl_frontend/linear_fusion/bwd.dx.json @@ -0,0 +1,36 @@ +[ + [ + 4, + "- einstein_v2(\"m0[N0, N1] +=! 
input3[N0, N2] * input1[N2, N1];m1[N0, N1] = m0[N0, N1] * input2[N0, N1] * const(0.5).cast(`float16`);m2[N0, N1] = m1[N0, N1].cast(`float32`); m3[N0, N1] = const(0.5).cast(`float32`) * (const(1.0).cast(`float32`) + (input0[N0, N1] * const(0.70710678118654752440).cast(`float32`)).call(`erf`));m4[N0, N1] = (const(-0.5).cast(`float32`) * input0[N0, N1] * input0[N0, N1]).call(`exp`) * const(0.3989422804014327).cast(`float32`);output0[N0, N1] = m2[N0, N1] * (m3[N0, N1] + input0[N0, N1] * m4[N0, N1]);\", input_dict={\"input0\" : { \"dtype\" : \"float16\", \"shape\" : [2048, 16384]}, \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [4096, 16384]}, \"input2\" : { \"dtype\" : \"float16\", \"shape\" : [2048, 16384]},\"input3\" : { \"dtype\" : \"float16\", \"shape\" : [2048, 4096]}})## @: tensorCoreConfig=(0, 1)", + "FusedLayer", + [ + [ + 0, + 0 + ], + [ + 1, + 0 + ], + [ + 2, + 0 + ], + [ + 3, + 0 + ] + ] + ], + [ + 5, + "", + "Result", + [ + [ + 4, + 0 + ] + ] + ] +] \ No newline at end of file diff --git a/test/python/dsl_frontend/linear_fusion/bwd.dx.kernel.json b/test/python/dsl_frontend/linear_fusion/bwd.dx.kernel.json new file mode 100644 index 000000000..589fb4886 --- /dev/null +++ b/test/python/dsl_frontend/linear_fusion/bwd.dx.kernel.json @@ -0,0 +1,49 @@ +[ + { + "nodes": [ + 4 + ], + "node_names": [ + "FusedLayer_4" + ], + "group_id": 0, + "input_desc": [ + [ + 4, + 0 + ], + [ + 4, + 1 + ], + [ + 4, + 2 + ], + [ + 4, + 3 + ] + ], + "output_desc": [ + [ + 4, + 0 + ] + ], + "code": "__global__ void __launch_bounds__(128) Group0(half* __restrict__ input0, half* __restrict__ input1, half* __restrict__ input2, half* __restrict__ input3, float* __restrict__ output0) {\n nvcuda::wmma::fragment m0_wmma_accumulator[8];\n __shared__ half input3_shared[4608];\n __shared__ half input1_shared[8704];\n nvcuda::wmma::fragment input3_shared_wmma_matrix_a[4];\n nvcuda::wmma::fragment input1_shared_wmma_matrix_b[2];\n for (int N0_c_outer_init = 0; N0_c_outer_init < 4; ++N0_c_outer_init) {\n for (int N1_c_outer_init = 0; N1_c_outer_init < 2; ++N1_c_outer_init) {\n nvcuda::wmma::fill_fragment(m0_wmma_accumulator[((N0_c_outer_init * 2) + N1_c_outer_init)], 0.000000e+00f);\n }\n }\n for (int N2_outer = 0; N2_outer < 64; ++N2_outer) {\n __syncthreads();\n *(uint4*)(input3_shared + (((((int)threadIdx.y) * 288) + ((((int)threadIdx.x) >> 3) * 72)) + ((((int)threadIdx.x) & 7) * 8))) = *(uint4*)(input3 + ((((((((int)blockIdx.x) >> 7) * 262144) + (((int)threadIdx.y) * 16384)) + ((((int)threadIdx.x) >> 3) * 4096)) + (N2_outer * 64)) + ((((int)threadIdx.x) & 7) * 8)));\n *(uint4*)(input3_shared + ((((((int)threadIdx.y) * 288) + ((((int)threadIdx.x) >> 3) * 72)) + ((((int)threadIdx.x) & 7) * 8)) + 1152)) = *(uint4*)(input3 + (((((((((int)blockIdx.x) >> 7) * 262144) + (((int)threadIdx.y) * 16384)) + ((((int)threadIdx.x) >> 3) * 4096)) + (N2_outer * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 65536));\n *(uint4*)(input3_shared + ((((((int)threadIdx.y) * 288) + ((((int)threadIdx.x) >> 3) * 72)) + ((((int)threadIdx.x) & 7) * 8)) + 2304)) = *(uint4*)(input3 + (((((((((int)blockIdx.x) >> 7) * 262144) + (((int)threadIdx.y) * 16384)) + ((((int)threadIdx.x) >> 3) * 4096)) + (N2_outer * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 131072));\n *(uint4*)(input3_shared + ((((((int)threadIdx.y) * 288) + ((((int)threadIdx.x) >> 3) * 72)) + ((((int)threadIdx.x) & 7) * 8)) + 3456)) = *(uint4*)(input3 + (((((((((int)blockIdx.x) >> 7) * 262144) + (((int)threadIdx.y) * 16384)) + ((((int)threadIdx.x) >> 3) * 4096)) + (N2_outer * 64)) + 
((((int)threadIdx.x) & 7) * 8)) + 196608));\n *(uint4*)(input1_shared + (((((int)threadIdx.y) * 272) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(input1 + (((((N2_outer * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + ((((int)threadIdx.x) & 15) * 8)));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 272) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8)) + 1088)) = *(uint4*)(input1 + ((((((N2_outer * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 131072));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 272) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8)) + 2176)) = *(uint4*)(input1 + ((((((N2_outer * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 262144));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 272) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8)) + 3264)) = *(uint4*)(input1 + ((((((N2_outer * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 393216));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 272) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8)) + 4352)) = *(uint4*)(input1 + ((((((N2_outer * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 524288));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 272) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8)) + 5440)) = *(uint4*)(input1 + ((((((N2_outer * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 655360));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 272) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8)) + 6528)) = *(uint4*)(input1 + ((((((N2_outer * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 786432));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 272) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8)) + 7616)) = *(uint4*)(input1 + ((((((N2_outer * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 917504));\n __syncthreads();\n for (int N2_inner_outer = 0; N2_inner_outer < 4; ++N2_inner_outer) {\n for (int ax0_outer = 0; ax0_outer < 4; ++ax0_outer) {\n nvcuda::wmma::load_matrix_sync(input3_shared_wmma_matrix_a[ax0_outer], (&(input3_shared[((((((int)threadIdx.y) >> 1) * 2304) + (ax0_outer * 576)) + (N2_inner_outer * 16))])), 72);\n }\n for (int ax1_outer = 0; ax1_outer < 2; ++ax1_outer) {\n nvcuda::wmma::load_matrix_sync(input1_shared_wmma_matrix_b[ax1_outer], (&(input1_shared[(((N2_inner_outer * 2176) + ((((int)threadIdx.y) & 1) * 64)) + (ax1_outer * 32))])), 136);\n }\n for (int N0_c_outer = 0; N0_c_outer < 4; ++N0_c_outer) {\n for (int N1_c_outer = 0; N1_c_outer < 2; ++N1_c_outer) {\n nvcuda::wmma::mma_sync(m0_wmma_accumulator[((N0_c_outer * 2) + N1_c_outer)], 
input3_shared_wmma_matrix_a[N0_c_outer], input1_shared_wmma_matrix_b[N1_c_outer], m0_wmma_accumulator[((N0_c_outer * 2) + N1_c_outer)]);\n }\n }\n }\n }\n __syncthreads();\n for (int ax0_inner_outer = 0; ax0_inner_outer < 4; ++ax0_inner_outer) {\n for (int ax1_inner_outer = 0; ax1_inner_outer < 2; ++ax1_inner_outer) {\n nvcuda::wmma::store_matrix_sync((&(input1_shared[(((((((int)threadIdx.y) >> 1) * 4352) + (ax0_inner_outer * 1088)) + ((((int)threadIdx.y) & 1) * 64)) + (ax1_inner_outer * 32))])), m0_wmma_accumulator[((ax0_inner_outer * 2) + ax1_inner_outer)], 136, nvcuda::wmma::mem_row_major);\n }\n }\n __syncthreads();\n for (int N0_inner_N1_inner_fused_outer_outer_outer = 0; N0_inner_N1_inner_fused_outer_outer_outer < 8; ++N0_inner_N1_inner_fused_outer_outer_outer) {\n for (int N0_inner_N1_inner_fused_inner_s = 0; N0_inner_N1_inner_fused_inner_s < 8; ++N0_inner_N1_inner_fused_inner_s) {\n output0[((((((((((int)blockIdx.x) >> 7) * 1048576) + (N0_inner_N1_inner_fused_outer_outer_outer * 131072)) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + N0_inner_N1_inner_fused_inner_s)] = (((float)((input1_shared[(((((N0_inner_N1_inner_fused_outer_outer_outer * 1088) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8)) + N0_inner_N1_inner_fused_inner_s)] * input2[((((((((((int)blockIdx.x) >> 7) * 1048576) + (N0_inner_N1_inner_fused_outer_outer_outer * 131072)) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + N0_inner_N1_inner_fused_inner_s)]) * __float2half_rn(5.000000e-01f))) * ((5.000000e-01f * (1.000000e+00f + erff((((float)input0[((((((((((int)blockIdx.x) >> 7) * 1048576) + (N0_inner_N1_inner_fused_outer_outer_outer * 131072)) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + N0_inner_N1_inner_fused_inner_s)]) * 7.071068e-01f)))) + (((float)input0[((((((((((int)blockIdx.x) >> 7) * 1048576) + (N0_inner_N1_inner_fused_outer_outer_outer * 131072)) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + N0_inner_N1_inner_fused_inner_s)]) * (__expf(((-5.000000e-01f * ((float)input0[((((((((((int)blockIdx.x) >> 7) * 1048576) + (N0_inner_N1_inner_fused_outer_outer_outer * 131072)) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + N0_inner_N1_inner_fused_inner_s)])) * ((float)input0[((((((((((int)blockIdx.x) >> 7) * 1048576) + (N0_inner_N1_inner_fused_outer_outer_outer * 131072)) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + ((((int)blockIdx.x) & 127) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + N0_inner_N1_inner_fused_inner_s)]))) * 3.989423e-01f))));\n }\n }\n}\n\n", + "block_size": [ + 32, + 4, + 1 + ], + "grid_size": [ + 4096, + 1, + 1 + ], + "latency": 5.310894966125488, + "name": "Group0", + "gain": 0 + } +] \ No newline at end of file diff --git a/test/python/dsl_frontend/linear_fusion/fwd.json b/test/python/dsl_frontend/linear_fusion/fwd.json new file mode 100644 index 000000000..968239e11 --- /dev/null +++ b/test/python/dsl_frontend/linear_fusion/fwd.json @@ -0,0 +1,36 @@ +[ + [ + 4, + "- einstein_v2(\"m0[N0, N1] = input0[N0, N1].cast(`float32`); m1[N0, N1] 
= m0[N0, N1] * const(0.5).cast(`float32`) * (const(1.0).cast(`float32`) + (m0[N0, N1] * const(0.70710678118654752440).cast(`float32`)).call(`erf`)); m2[N0, N1] = m1[N0, N1].cast(`float16`); m3[N0, N1] = m2[N0, N1] * input3[N0, N1] / const(0.5).cast(`float16`); m4[N0, N2] +=! m3[N0, N1] * input1[N2, N1];output0[N0, N2] = m4[N0, N2] + input2[N0];\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [2048, 16384]}, \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [4096, 16384]}, \"input2\" : { \"dtype\" : \"float16\", \"shape\" : [4096]}, \"input3\" : { \"dtype\" : \"float16\", \"shape\" : [2048, 16384]}})## @: tensorCoreConfig=(0, 1)", + "FusedLayer", + [ + [ + 0, + 0 + ], + [ + 1, + 0 + ], + [ + 2, + 0 + ], + [ + 3, + 0 + ] + ] + ], + [ + 5, + "", + "Result", + [ + [ + 4, + 0 + ] + ] + ] +] \ No newline at end of file diff --git a/test/python/dsl_frontend/linear_fusion/fwd.kernel.json b/test/python/dsl_frontend/linear_fusion/fwd.kernel.json new file mode 100644 index 000000000..4589f92ee --- /dev/null +++ b/test/python/dsl_frontend/linear_fusion/fwd.kernel.json @@ -0,0 +1,49 @@ +[ + { + "nodes": [ + 4 + ], + "node_names": [ + "FusedLayer_4" + ], + "group_id": 0, + "input_desc": [ + [ + 4, + 0 + ], + [ + 4, + 1 + ], + [ + 4, + 2 + ], + [ + 4, + 3 + ] + ], + "output_desc": [ + [ + 4, + 0 + ] + ], + "code": "__global__ void __launch_bounds__(128) Group0(half* __restrict__ input0, half* __restrict__ input1, half* __restrict__ input2, half* __restrict__ input3, half* __restrict__ output0) {\n nvcuda::wmma::fragment m4_wmma_accumulator[8];\n __shared__ half m3_shared[4608];\n __shared__ half input1_shared[9216];\n nvcuda::wmma::fragment m3_shared_wmma_matrix_a[8];\n nvcuda::wmma::fragment input1_shared_wmma_matrix_b[1];\n for (int N0_c_outer_init = 0; N0_c_outer_init < 8; ++N0_c_outer_init) {\n nvcuda::wmma::fill_fragment(m4_wmma_accumulator[N0_c_outer_init], 0.000000e+00f);\n }\n for (int N1_outer = 0; N1_outer < 256; ++N1_outer) {\n __syncthreads();\n for (int ax0_ax1_fused_inner_s = 0; ax0_ax1_fused_inner_s < 4; ++ax0_ax1_fused_inner_s) {\n m3_shared[((((((int)threadIdx.y) * 144) + ((((int)threadIdx.x) >> 4) * 72)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s)] = ((((half)((((float)input0[(((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s)]) * 5.000000e-01f) * (1.000000e+00f + erff((((float)input0[(((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s)]) * 7.071068e-01f))))) * input3[(((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s)]) * __float2half_rn(2.000000e+00f));\n }\n for (int ax0_ax1_fused_inner_s_1 = 0; ax0_ax1_fused_inner_s_1 < 4; ++ax0_ax1_fused_inner_s_1) {\n m3_shared[(((((((int)threadIdx.y) * 144) + ((((int)threadIdx.x) >> 4) * 72)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_1) + 576)] = ((((half)((((float)input0[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_1) + 131072)]) * 5.000000e-01f) * (1.000000e+00f + erff((((float)input0[((((((((((int)blockIdx.x) >> 5) * 1048576) + 
(((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_1) + 131072)]) * 7.071068e-01f))))) * input3[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_1) + 131072)]) * __float2half_rn(2.000000e+00f));\n }\n for (int ax0_ax1_fused_inner_s_2 = 0; ax0_ax1_fused_inner_s_2 < 4; ++ax0_ax1_fused_inner_s_2) {\n m3_shared[(((((((int)threadIdx.y) * 144) + ((((int)threadIdx.x) >> 4) * 72)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_2) + 1152)] = ((((half)((((float)input0[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_2) + 262144)]) * 5.000000e-01f) * (1.000000e+00f + erff((((float)input0[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_2) + 262144)]) * 7.071068e-01f))))) * input3[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_2) + 262144)]) * __float2half_rn(2.000000e+00f));\n }\n for (int ax0_ax1_fused_inner_s_3 = 0; ax0_ax1_fused_inner_s_3 < 4; ++ax0_ax1_fused_inner_s_3) {\n m3_shared[(((((((int)threadIdx.y) * 144) + ((((int)threadIdx.x) >> 4) * 72)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_3) + 1728)] = ((((half)((((float)input0[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_3) + 393216)]) * 5.000000e-01f) * (1.000000e+00f + erff((((float)input0[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_3) + 393216)]) * 7.071068e-01f))))) * input3[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_3) + 393216)]) * __float2half_rn(2.000000e+00f));\n }\n for (int ax0_ax1_fused_inner_s_4 = 0; ax0_ax1_fused_inner_s_4 < 4; ++ax0_ax1_fused_inner_s_4) {\n m3_shared[(((((((int)threadIdx.y) * 144) + ((((int)threadIdx.x) >> 4) * 72)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_4) + 2304)] = ((((half)((((float)input0[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_4) + 524288)]) * 5.000000e-01f) * (1.000000e+00f + erff((((float)input0[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_4) + 524288)]) * 7.071068e-01f))))) * input3[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_4) + 524288)]) * __float2half_rn(2.000000e+00f));\n }\n for (int ax0_ax1_fused_inner_s_5 = 0; 
ax0_ax1_fused_inner_s_5 < 4; ++ax0_ax1_fused_inner_s_5) {\n m3_shared[(((((((int)threadIdx.y) * 144) + ((((int)threadIdx.x) >> 4) * 72)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_5) + 2880)] = ((((half)((((float)input0[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_5) + 655360)]) * 5.000000e-01f) * (1.000000e+00f + erff((((float)input0[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_5) + 655360)]) * 7.071068e-01f))))) * input3[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_5) + 655360)]) * __float2half_rn(2.000000e+00f));\n }\n for (int ax0_ax1_fused_inner_s_6 = 0; ax0_ax1_fused_inner_s_6 < 4; ++ax0_ax1_fused_inner_s_6) {\n m3_shared[(((((((int)threadIdx.y) * 144) + ((((int)threadIdx.x) >> 4) * 72)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_6) + 3456)] = ((((half)((((float)input0[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_6) + 786432)]) * 5.000000e-01f) * (1.000000e+00f + erff((((float)input0[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_6) + 786432)]) * 7.071068e-01f))))) * input3[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_6) + 786432)]) * __float2half_rn(2.000000e+00f));\n }\n for (int ax0_ax1_fused_inner_s_7 = 0; ax0_ax1_fused_inner_s_7 < 4; ++ax0_ax1_fused_inner_s_7) {\n m3_shared[(((((((int)threadIdx.y) * 144) + ((((int)threadIdx.x) >> 4) * 72)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_7) + 4032)] = ((((half)((((float)input0[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_7) + 917504)]) * 5.000000e-01f) * (1.000000e+00f + erff((((float)input0[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_7) + 917504)]) * 7.071068e-01f))))) * input3[((((((((((int)blockIdx.x) >> 5) * 1048576) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 4) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 15) * 4)) + ax0_ax1_fused_inner_s_7) + 917504)]) * __float2half_rn(2.000000e+00f));\n }\n *(uint4*)(input1_shared + (((((int)threadIdx.y) * 288) + ((((int)threadIdx.x) >> 3) * 72)) + ((((int)threadIdx.x) & 7) * 8))) = *(uint4*)(input1 + ((((((((int)blockIdx.x) & 31) * 2097152) + (((int)threadIdx.y) * 65536)) + ((((int)threadIdx.x) >> 3) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 7) * 8)));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 288) + ((((int)threadIdx.x) >> 3) * 72)) + ((((int)threadIdx.x) & 7) * 8)) + 1152)) = *(uint4*)(input1 + (((((((((int)blockIdx.x) & 31) * 2097152) + 
(((int)threadIdx.y) * 65536)) + ((((int)threadIdx.x) >> 3) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 262144));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 288) + ((((int)threadIdx.x) >> 3) * 72)) + ((((int)threadIdx.x) & 7) * 8)) + 2304)) = *(uint4*)(input1 + (((((((((int)blockIdx.x) & 31) * 2097152) + (((int)threadIdx.y) * 65536)) + ((((int)threadIdx.x) >> 3) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 524288));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 288) + ((((int)threadIdx.x) >> 3) * 72)) + ((((int)threadIdx.x) & 7) * 8)) + 3456)) = *(uint4*)(input1 + (((((((((int)blockIdx.x) & 31) * 2097152) + (((int)threadIdx.y) * 65536)) + ((((int)threadIdx.x) >> 3) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 786432));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 288) + ((((int)threadIdx.x) >> 3) * 72)) + ((((int)threadIdx.x) & 7) * 8)) + 4608)) = *(uint4*)(input1 + (((((((((int)blockIdx.x) & 31) * 2097152) + (((int)threadIdx.y) * 65536)) + ((((int)threadIdx.x) >> 3) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 1048576));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 288) + ((((int)threadIdx.x) >> 3) * 72)) + ((((int)threadIdx.x) & 7) * 8)) + 5760)) = *(uint4*)(input1 + (((((((((int)blockIdx.x) & 31) * 2097152) + (((int)threadIdx.y) * 65536)) + ((((int)threadIdx.x) >> 3) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 1310720));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 288) + ((((int)threadIdx.x) >> 3) * 72)) + ((((int)threadIdx.x) & 7) * 8)) + 6912)) = *(uint4*)(input1 + (((((((((int)blockIdx.x) & 31) * 2097152) + (((int)threadIdx.y) * 65536)) + ((((int)threadIdx.x) >> 3) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 1572864));\n *(uint4*)(input1_shared + ((((((int)threadIdx.y) * 288) + ((((int)threadIdx.x) >> 3) * 72)) + ((((int)threadIdx.x) & 7) * 8)) + 8064)) = *(uint4*)(input1 + (((((((((int)blockIdx.x) & 31) * 2097152) + (((int)threadIdx.y) * 65536)) + ((((int)threadIdx.x) >> 3) * 16384)) + (N1_outer * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 1835008));\n __syncthreads();\n for (int N1_inner_outer = 0; N1_inner_outer < 4; ++N1_inner_outer) {\n for (int ax0_outer = 0; ax0_outer < 8; ++ax0_outer) {\n nvcuda::wmma::load_matrix_sync(m3_shared_wmma_matrix_a[ax0_outer], (&(m3_shared[((ax0_outer * 576) + (N1_inner_outer * 16))])), 72);\n }\n nvcuda::wmma::load_matrix_sync(input1_shared_wmma_matrix_b[0], (&(input1_shared[((((int)threadIdx.y) * 2304) + (N1_inner_outer * 16))])), 72);\n for (int N0_c_outer = 0; N0_c_outer < 8; ++N0_c_outer) {\n nvcuda::wmma::mma_sync(m4_wmma_accumulator[N0_c_outer], m3_shared_wmma_matrix_a[N0_c_outer], input1_shared_wmma_matrix_b[0], m4_wmma_accumulator[N0_c_outer]);\n }\n }\n }\n __syncthreads();\n for (int ax0_inner_outer = 0; ax0_inner_outer < 8; ++ax0_inner_outer) {\n nvcuda::wmma::store_matrix_sync((&(input1_shared[((ax0_inner_outer * 1088) + (((int)threadIdx.y) * 32))])), m4_wmma_accumulator[ax0_inner_outer], 136, nvcuda::wmma::mem_row_major);\n }\n __syncthreads();\n for (int N0_inner_N2_inner_fused_outer_outer_outer = 0; N0_inner_N2_inner_fused_outer_outer_outer < 8; ++N0_inner_N2_inner_fused_outer_outer_outer) {\n uint4 __1;\n uint4 __2 = *(uint4*)(input1_shared + ((((N0_inner_N2_inner_fused_outer_outer_outer * 1088) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8)));\n uint4 __3 = make_uint4(__pack_half2(input2[(((((((int)blockIdx.x) >> 5) * 
64) + (N0_inner_N2_inner_fused_outer_outer_outer * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4))], input2[(((((((int)blockIdx.x) >> 5) * 64) + (N0_inner_N2_inner_fused_outer_outer_outer * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4))]), __pack_half2(input2[(((((((int)blockIdx.x) >> 5) * 64) + (N0_inner_N2_inner_fused_outer_outer_outer * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4))], input2[(((((((int)blockIdx.x) >> 5) * 64) + (N0_inner_N2_inner_fused_outer_outer_outer * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4))]), __pack_half2(input2[(((((((int)blockIdx.x) >> 5) * 64) + (N0_inner_N2_inner_fused_outer_outer_outer * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4))], input2[(((((((int)blockIdx.x) >> 5) * 64) + (N0_inner_N2_inner_fused_outer_outer_outer * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4))]), __pack_half2(input2[(((((((int)blockIdx.x) >> 5) * 64) + (N0_inner_N2_inner_fused_outer_outer_outer * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4))], input2[(((((((int)blockIdx.x) >> 5) * 64) + (N0_inner_N2_inner_fused_outer_outer_outer * 8)) + (((int)threadIdx.y) * 2)) + (((int)threadIdx.x) >> 4))]));\n ((half2*)(&(__1.x)))->x = (((half2*)(&(__2.x)))->x+((half2*)(&(__3.x)))->x);\n ((half2*)(&(__1.x)))->y = (((half2*)(&(__2.x)))->y+((half2*)(&(__3.x)))->y);\n ((half2*)(&(__1.y)))->x = (((half2*)(&(__2.y)))->x+((half2*)(&(__3.y)))->x);\n ((half2*)(&(__1.y)))->y = (((half2*)(&(__2.y)))->y+((half2*)(&(__3.y)))->y);\n ((half2*)(&(__1.z)))->x = (((half2*)(&(__2.z)))->x+((half2*)(&(__3.z)))->x);\n ((half2*)(&(__1.z)))->y = (((half2*)(&(__2.z)))->y+((half2*)(&(__3.z)))->y);\n ((half2*)(&(__1.w)))->x = (((half2*)(&(__2.w)))->x+((half2*)(&(__3.w)))->x);\n ((half2*)(&(__1.w)))->y = (((half2*)(&(__2.w)))->y+((half2*)(&(__3.w)))->y);\n *(uint4*)(output0 + (((((((((int)blockIdx.x) >> 5) * 262144) + (N0_inner_N2_inner_fused_outer_outer_outer * 32768)) + (((int)threadIdx.y) * 8192)) + ((((int)threadIdx.x) >> 4) * 4096)) + ((((int)blockIdx.x) & 31) * 128)) + ((((int)threadIdx.x) & 15) * 8))) = __1;\n }\n}\n\n", + "block_size": [ + 32, + 4, + 1 + ], + "grid_size": [ + 1024, + 1, + 1 + ], + "latency": 5.927213191986084, + "name": "Group0", + "gain": 0 + } +] \ No newline at end of file