update conv1d to new layout specifier gen, axis mapping, and use non-singlethreaded local workgroup (pytorch#5504)

Summary:
Pull Request resolved: pytorch#5504

Uses the new load_texel_lpos helper for a simpler update.

Reviewed By: SS-JIA

Differential Revision: D62990822

fbshipit-source-id: 9163b807d9095ebdb089f08aa6ea20fbbb563d02
nathanaelsee authored and facebook-github-bot committed Sep 21, 2024
1 parent 0eee42a commit d5fdbd4
Showing 3 changed files with 39 additions and 46 deletions.
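
For context on the helpers that appear in the updated shader below: each tensor now carries an axis map that translates a logical texel position ("lpos") into a physical texture position. The following is a minimal sketch of the plausible semantics, inferred only from the call sites in this diff; the helper bodies here are assumptions, not the repository's actual implementation (which presumably lives in the shared shader indexing utilities).

```glsl
// Sketch (assumption): axis_map.xyz names the physical texture axis that each
// logical axis maps onto, so a logical position is permuted component-wise.
ivec3 lpos_to_pos(const ivec3 lpos, const ivec4 axis_map) {
  ivec3 pos;
  pos[axis_map.x] = lpos.x;
  pos[axis_map.y] = lpos.y;
  pos[axis_map.z] = lpos.z;
  return pos;
}

// Under that assumption, load_texel_lpos(t, lpos, axis_map) behaves like
// load_texel(t, lpos_to_pos(lpos, axis_map)), and write_texel_lpos is the
// analogous store.
```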
76 changes: 33 additions & 43 deletions backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl
@@ -18,32 +18,22 @@

 layout(std430) buffer;

-layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
-layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
-layout(set = 0, binding = 2) uniform PRECISION sampler3D kernel_in;
-layout(set = 0, binding = 3) uniform PRECISION sampler3D bias_in;
-
-layout(set = 0, binding = 4) uniform PRECISION restrict OutLimits {
-  ivec3 out_limits;
-};
-
-layout(set = 0, binding = 5) uniform PRECISION restrict InSizes {
-  ivec4 in_sizes;
-};
-
-layout(set = 0, binding = 6) uniform PRECISION restrict Params {
-  int kernel_size;
-  int stride;
-  int padding;
-  int dilation;
-  int in_group_size;
-  int out_group_size;
-};
-
-layout(set = 0, binding = 7) uniform PRECISION restrict OutputParams {
-  float out_min;
-  float out_max;
-};
+${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
+${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
+${layout_declare_tensor(B, "r", "kernel_in", DTYPE, STORAGE)}
+${layout_declare_tensor(B, "r", "bias_in", DTYPE, STORAGE)}
+
+${layout_declare_ubo(B, "ivec3", "out_limits")}
+${layout_declare_ubo(B, "ivec4", "in_sizes")}
+
+${layout_declare_ubo(B, "ivec4", "out_axis_map")}
+${layout_declare_ubo(B, "ivec4", "in_axis_map")}
+${layout_declare_ubo(B, "ivec4", "kernel_axis_map")}
+${layout_declare_ubo(B, "ivec4", "bias_axis_map")}
+
+${layout_declare_ubo(B, "int", "kernel_size", "int", "stride", "int", "padding", "int", "dilation", "int", "in_group_size", "int", "out_group_size")}
+
+${layout_declare_ubo(B, "float", "out_min", "float", "out_max")}

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

@@ -67,9 +57,9 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 // shader invocations, where each invocation computes 1 result. But that
 // performs worse.
 void main() {
-  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+  const ivec3 lpos = ivec3(gl_GlobalInvocationID);

-  if (any(greaterThanEqual(pos, out_limits))) {
+  if (any(greaterThanEqual(lpos, out_limits))) {
     return;
   }

@@ -78,8 +68,8 @@ void main() {

   // "out_c" is the output's channel index where we write our result.
   // Across shader invocations, this is the only value that varies.
-  int out_c = pos.y;
-  vec4 bias = texelFetch(bias_in, ivec3(out_c, 0, 0), 0);
+  int out_c = lpos.y;
+  VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map);

   // "in_c" tracks the input's channel start index.
   // We iterate over the input group that corresponds to the output group.
@@ -98,7 +88,7 @@ void main() {
     int out_l = 0;

     for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) {
-      vec4 sum = vec4(0);
+      VEC4_T sum = VEC4_T(0);

       for (int in_c = c_start; in_c < c_end; ++in_c) {
         // "k" tracks the kernel's index for our input-kernel computation.
@@ -107,25 +97,25 @@ void main() {
         for (int k = 0; k < kernel_size; k += 4) {
           // Since the weight tensor is width-packed, which is along the length
           // dimension, we can batch-read four elements at a time.
-          const ivec3 w_pos = ivec3(k / 4, in_c % in_group_size, out_c);
-          const vec4 weight = texelFetch(kernel_in, w_pos, 0);
+          const ivec3 w_lpos = ivec3(k / 4, in_c % in_group_size, out_c);
+          const VEC4_T weight = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);

-          const ivec3 in_pos_0 = ivec3(in_l + k * dilation, in_c, n / 4);
-          sum = fma(weight.xxxx, texelFetch(image_in, in_pos_0, 0), sum);
+          ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, n / 4), in_axis_map);
+          sum = fma(weight.xxxx, load_texel(t_in, in_pos), sum);

-          const ivec3 in_pos_1 = ivec3(in_l + (k+1) * dilation, in_c, n / 4);
-          sum = fma(weight.yyyy, texelFetch(image_in, in_pos_1, 0), sum);
+          in_pos[in_axis_map.x] += dilation;
+          sum = fma(weight.yyyy, load_texel(t_in, in_pos), sum);

-          const ivec3 in_pos_2 = ivec3(in_l + (k+2) * dilation, in_c, n / 4);
-          sum = fma(weight.zzzz, texelFetch(image_in, in_pos_2, 0), sum);
+          in_pos[in_axis_map.x] += dilation;
+          sum = fma(weight.zzzz, load_texel(t_in, in_pos), sum);

-          const ivec3 in_pos_3 = ivec3(in_l + (k+3) * dilation, in_c, n / 4);
-          sum = fma(weight.wwww, texelFetch(image_in, in_pos_3, 0), sum);
+          in_pos[in_axis_map.x] += dilation;
+          sum = fma(weight.wwww, load_texel(t_in, in_pos), sum);
         }
       }

-      ivec3 out_pos = ivec3(out_l, out_c, n / 4);
-      imageStore(image_out, out_pos, op(sum + bias.x, out_min, out_max));
+      const ivec3 out_lpos = ivec3(out_l, out_c, n / 4);
+      write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map);
     }
   }
 }
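
One detail worth noting in the last hunk: the four explicit in_pos_0…in_pos_3 computations collapse into a single lpos_to_pos call followed by increments of one physical component. Since the logical x (length) axis lands on physical axis in_axis_map.x, adding dilation to that component visits the same texels as recomputing the position for k+1, k+2, and k+3. A hypothetical standalone form of the pattern (dot4_width_packed is invented for illustration; t_in, load_texel, and VEC4_T are from the diff):

```glsl
// Hypothetical helper: accumulate four width-packed kernel taps against four
// input texels spaced `dilation` apart along the physical axis that holds the
// logical length dimension. Mirrors the unrolled loop body above.
VEC4_T dot4_width_packed(
    const VEC4_T w,      // four kernel taps, batch-read from one texel
    ivec3 in_pos,        // physical position of the first tap's input texel
    const int x_axis,    // in_axis_map.x
    const int dilation) {
  VEC4_T s = fma(w.xxxx, load_texel(t_in, in_pos), VEC4_T(0));
  in_pos[x_axis] += dilation;  // same texel lpos_to_pos would give for k+1
  s = fma(w.yyyy, load_texel(t_in, in_pos), s);
  in_pos[x_axis] += dilation;  // k+2
  s = fma(w.zzzz, load_texel(t_in, in_pos), s);
  in_pos[x_axis] += dilation;  // k+3
  return fma(w.wwww, load_texel(t_in, in_pos), s);
}
```

The payoff is one logical-to-physical permutation per group of four taps instead of four.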
3 changes: 1 addition & 2 deletions backends/vulkan/runtime/graph/ops/glsl/conv1d.yaml
@@ -7,9 +7,8 @@
 conv1d:
   parameter_names_with_default_values:
     OPERATOR: X
-    NDIM: 3
     DTYPE: float
-    PACKING: C_packed
+    STORAGE: texture3d
   generate_variant_forall:
     DTYPE:
       - VALUE: half
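
In the parameters above, STORAGE: texture3d replaces the removed NDIM and PACKING defaults; it is what the layout_declare_tensor(..., STORAGE) calls in the updated shader consume, presumably so the storage type becomes a per-variant choice rather than a hard-coded channel-packed 3-D image.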
6 changes: 5 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
@@ -444,7 +444,7 @@ void add_conv1d_node(
   int32_t out_group_size = static_cast<int64_t>(out_channels / groups_val);

   utils::uvec3 global_size = {1, static_cast<uint32_t>(out_channels), 1};
-  utils::uvec3 local_size = {1, 1, 1};
+  utils::uvec3 local_size = {1, 64, 1};

   Kernel1dParams kernel_params = {
       kernel_size,
@@ -476,6 +476,10 @@ void add_conv1d_node(
       {
           t_out->logical_limits_ubo(),
           t_in->sizes_ubo(),
+          t_out->axis_map_ubo(),
+          t_in->axis_map_ubo(),
+          t_weight->axis_map_ubo(),
+          t_bias->axis_map_ubo(),
           graph.create_params_buffer(kernel_params),
           graph.create_params_buffer(out_params),
       },
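
On the local_size change above: global_size stays {1, out_channels, 1}, so with a {1, 64, 1} workgroup the dispatch presumably rounds the y extent up to a multiple of 64, and the surplus invocations are culled by the out_limits early return at the top of conv1d.glsl. A sketch of the assumed arithmetic:

```glsl
// Assumed dispatch math (not the repository's code): workgroups along y are a
// ceiling division, e.g. out_channels = 100 -> 2 workgroups -> 128 invocations,
// of which those with lpos.y >= out_limits.y return immediately.
uint num_workgroups_y(const uint out_channels, const uint local_y) {
  return (out_channels + local_y - 1u) / local_y;
}
```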
