Merge pull request #1085 from EleutherAI/avoid-.contiguous()
Avoid .contiguous()
StellaAthena authored Nov 27, 2023
2 parents d86c399 + f20f1f5 commit af387b0
Showing 2 changed files with 151 additions and 110 deletions.
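
What the change does: the old fwd/bwd wrappers called .contiguous() on the input, cos, and sin tensors, which silently copies any non-contiguous view (for example, the result of a transpose) before the kernel runs. The new fwd_cuda/bwd_cuda instead pass each tensor's per-dimension strides into the kernels and index through them, so any memory format is handled in place. They also take a transpose_output flag that allocates the result batch-major as (b, s, h, d) and hands back a strided (s, b, h, d) view, again without a copy.

As a minimal host-side sketch of the stride-based addressing (illustration only, not code from this commit): a logical (s, b, h, d) tensor stored batch-major is read through one stride per dimension, the same offset arithmetic the kernels do with stride_s, stride_b, stride_h, and stride_d.

#include <cstdio>
#include <vector>

int main() {
  const int s = 2, b = 3, h = 4, d = 8;
  // Batch-major storage: the layout that torch::empty({b, s, h, d})
  // .transpose(0, 1) exposes as a non-contiguous (s, b, h, d) view.
  const int stride_b = s * h * d, stride_s = h * d, stride_h = d, stride_d = 1;
  std::vector<float> buf(static_cast<size_t>(s) * b * h * d);
  for (size_t i = 0; i < buf.size(); ++i) buf[i] = static_cast<float>(i);

  // Address element (s_id, b_id, h_id, d_id) through the strides; this is
  // the same arithmetic as offset_block + h_id * stride_h + d_id * stride_d
  // in the kernels, with no .contiguous() copy required.
  const int s_id = 1, b_id = 2, h_id = 3, d_id = 5;
  const int offset =
      s_id * stride_s + b_id * stride_b + h_id * stride_h + d_id * stride_d;
  std::printf("value at (%d,%d,%d,%d) = %g\n", s_id, b_id, h_id, d_id,
              buf[offset]);
  return 0;
}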
146 changes: 83 additions & 63 deletions megatron/fused_kernels/fused_rotary_positional_embedding.cpp
@@ -14,77 +14,97 @@
* limitations under the License.
*/

#include <torch/extension.h>
#include <ATen/ATen.h>

#include "fused_rotary_positional_embedding.h"
#include "type_shim.h"

namespace fused_rope {

torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &cos,
const torch::Tensor &sin);

torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
const torch::Tensor &cos, const torch::Tensor &sin);
const torch::Tensor &sin, const bool transpose_output) {
// input sizes: (s, b, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s = input.size(0);
const int b = input.size(1);
const int h = input.size(2);
const int d = input.size(3);
// input strides
const int stride_s = input.stride(0);
const int stride_b = input.stride(1);
const int stride_h = input.stride(2);
const int stride_d = input.stride(3);
// cos/sin's shape is always (s, 1, 1, d2), so the strides are the same under
// different memory formats
const int d2 = cos.size(3);

torch::Tensor fwd(const at::Tensor &input_, const at::Tensor &cos_,
const at::Tensor &sin_) {
auto input = input_.contiguous();
auto cos = cos_.contiguous();
auto sin = sin_.contiguous();
TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
TORCH_CHECK(input.size(0) == cos.size(0),
"expected input and cos tensors to have the same sequence length");
TORCH_CHECK(input.size(0) == sin.size(0),
"expected input and sin tensors to have the same sequence length");
TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1,
"expected the second and third dims of the cos tensor to equal 1");
TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1,
"expected the second and third dims of the sin tensor to equal 1");
TORCH_CHECK(input.size(3) >= cos.size(3),
"expected the last dim of the input tensor to be greater than or "
"equal to the last dim of the cos tensor");
TORCH_CHECK(input.size(3) >= sin.size(3),
"expected the last dim of the input tensor to be greater than or "
"equal to the last dim of the sin tensor");
// output
auto act_options = input.options().requires_grad(false);
torch::Tensor output;
if (transpose_output) {
output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
} else {
output = torch::empty({s, b, h, d}, act_options);
}
// output strides
const int o_stride_s = output.stride(0);
const int o_stride_b = output.stride(1);
const int o_stride_h = output.stride(2);
const int o_stride_d = output.stride(3);

return fwd_cuda(input, cos, sin);
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(), 0, "dispatch_fused_rope_forward",
dispatch_fused_rope_forward(
s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
o_stride_b, o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(),
cos.data_ptr<scalar_t_0>(), sin.data_ptr<scalar_t_0>(),
output.data_ptr<scalar_t_0>()););
return output;
}

torch::Tensor bwd(const torch::Tensor &output_grads_, const at::Tensor &cos_,
const at::Tensor &sin_) {
auto output_grads = output_grads_.contiguous();
auto cos = cos_.contiguous();
auto sin = sin_.contiguous();
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
TORCH_CHECK(
output_grads.size(0) == cos.size(0),
"expected output_grads and cos tensors to have the same sequence length");
TORCH_CHECK(
output_grads.size(0) == sin.size(0),
"expected output_grads and sin tensors to have the same sequence length");
TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1,
"expected the second and third dims of the cos tensor to equal 1");
TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1,
"expected the second and third dims of the sin tensor to equal 1");
TORCH_CHECK(
output_grads.size(3) >= cos.size(3),
"expected the last dim of the output_grads tensor to be greater than "
"or equal to the last dim of the cos tensor");
TORCH_CHECK(
output_grads.size(3) >= sin.size(3),
"expected the last dim of the output_grads tensor to be greater than "
"or equal to the last dim of the sin tensor");

return bwd_cuda(output_grads, cos, sin);
}
torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
const torch::Tensor &cos, const torch::Tensor &sin,
const bool transpose_output) {
// output_grads sizes: (s, b, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s = output_grads.size(0);
const int b = output_grads.size(1);
const int h = output_grads.size(2);
const int d = output_grads.size(3);
// output_grads strides
const int stride_s = output_grads.stride(0);
const int stride_b = output_grads.stride(1);
const int stride_h = output_grads.stride(2);
const int stride_d = output_grads.stride(3);
// cos/sin's shape is always (s, 1, 1, d2), so the strides are the same under
// different memory formats
const int d2 = cos.size(3);

} // end namespace fused_rope
auto act_options = output_grads.options().requires_grad(false);
torch::Tensor input_grads;
if (transpose_output) {
input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
} else {
input_grads = torch::empty({s, b, h, d}, act_options);
}
const int o_stride_s = input_grads.stride(0);
const int o_stride_b = input_grads.stride(1);
const int o_stride_h = input_grads.stride(2);
const int o_stride_d = input_grads.stride(3);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &fused_rope::fwd,
"Fused Rotary Positional Embedding -- Forward.");
m.def("backward", &fused_rope::bwd,
"Fused Rotary Positional Embedding -- Backward.");
DISPATCH_FLOAT_HALF_AND_BFLOAT(
output_grads.scalar_type(), 0, "dispatch_fused_rope_backward",
dispatch_fused_rope_backward(
s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
o_stride_b, o_stride_h, o_stride_d,
output_grads.data_ptr<scalar_t_0>(), cos.data_ptr<scalar_t_0>(),
sin.data_ptr<scalar_t_0>(), input_grads.data_ptr<scalar_t_0>());)
return input_grads;
}
} // end namespace fused_rope
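
The transpose_output branch in both functions above relies on torch::empty({b, s, h, d}, act_options).transpose(0, 1) producing a tensor whose sizes are (s, b, h, d) but whose storage stays batch-major; because the kernels write through o_stride_*, the caller gets a batch-major result with no second copy. A small libtorch sketch (illustration only, not part of the commit) of what that allocation looks like:

#include <torch/torch.h>
#include <iostream>

int main() {
  const int64_t s = 4, b = 2, h = 8, d = 64;
  // Allocate (b, s, h, d) contiguously, then view it as (s, b, h, d).
  auto out = torch::empty({b, s, h, d}).transpose(0, 1);
  std::cout << "sizes:      " << out.sizes() << "\n";    // [4, 2, 8, 64]
  std::cout << "strides:    " << out.strides() << "\n";  // [512, 2048, 64, 1]
  std::cout << "contiguous: " << out.is_contiguous() << "\n";  // 0 (false)
  return 0;
}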
115 changes: 68 additions & 47 deletions megatron/fused_kernels/fused_rotary_positional_embedding.h
@@ -25,70 +25,83 @@
namespace {

template <typename scalar_t>
__global__ void fused_rope_forward(int sq, int b, int np, int hn, int hn2,
__global__ void fused_rope_forward(int h, int d, int d2, int stride_s,
int stride_b, int stride_h, int stride_d,
int o_stride_s, int o_stride_b,
int o_stride_h, int o_stride_d,
const scalar_t* src, const scalar_t* cos,
const scalar_t* sin, scalar_t* dst) {
int sq_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = sq_id * b * np * hn + b_id * np * hn;
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
#pragma unroll
for (int hn_id = threadIdx.x; hn_id < hn2; hn_id += blockDim.x) {
scalar_t v_cos = cos[sq_id * hn2 + hn_id];
scalar_t v_sin = sin[sq_id * hn2 + hn_id];
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
scalar_t v_cos = cos[s_id * d2 + d_id];
scalar_t v_sin = sin[s_id * d2 + d_id];
#pragma unroll
for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
int offset_src_dst = offset_block + head_id * hn + hn_id;
scalar_t v_src = src[offset_src_dst];
scalar_t v_src_rotate = (hn_id + hn2 / 2 < hn2)
? -src[offset_src_dst + hn2 / 2]
: src[offset_src_dst + hn2 / 2 - hn2];
dst[offset_src_dst] = v_src * v_cos + v_src_rotate * v_sin;
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
scalar_t v_src = src[offset_src];
scalar_t v_src_rotate = (d_id + d2 / 2 < d2)
? -src[offset_src + (d2 / 2) * stride_d]
: src[offset_src + (d2 / 2 - d2) * stride_d];
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
}
}

// copy the rest
if (hn > hn2) {
if (d > d2) {
#pragma unroll
for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
int offset_head = offset_block + head_id * hn;
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_head = offset_block + h_id * stride_h;
int offset_head_dst = offset_block_dst + h_id * o_stride_h;
#pragma unroll
for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) {
dst[offset_head + hn_id] = src[offset_head + hn_id];
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
dst[offset_head_dst + d_id * o_stride_d] =
src[offset_head + d_id * stride_d];
}
}
}
}

template <typename scalar_t>
__global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2,
__global__ void fused_rope_backward(int h, int d, int d2, int stride_s,
int stride_b, int stride_h, int stride_d,
int o_stride_s, int o_stride_b,
int o_stride_h, int o_stride_d,
const scalar_t* src, const scalar_t* cos,
const scalar_t* sin, scalar_t* dst) {
int sq_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = sq_id * b * np * hn + b_id * np * hn;
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
#pragma unroll
for (int hn_id = threadIdx.x; hn_id < hn2; hn_id += blockDim.x) {
scalar_t v_cos = cos[sq_id * hn2 + hn_id];
scalar_t v_sin = (hn_id + hn2 / 2 < hn2)
? sin[sq_id * hn2 + hn_id + hn2 / 2]
: -sin[sq_id * hn2 + hn_id + hn2 / 2 - hn2];
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
scalar_t v_cos = cos[s_id * d2 + d_id];
scalar_t v_sin = (d_id + d2 / 2 < d2)
? sin[s_id * d2 + d_id + d2 / 2]
: -sin[s_id * d2 + d_id + d2 / 2 - d2];
#pragma unroll
for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
int offset_src_dst = offset_block + head_id * hn + hn_id;
scalar_t v_src = src[offset_src_dst];
scalar_t v_src_rotate = (hn_id + hn2 / 2 < hn2)
? src[offset_src_dst + hn2 / 2]
: src[offset_src_dst + hn2 / 2 - hn2];
dst[offset_src_dst] = v_src * v_cos + v_src_rotate * v_sin;
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
scalar_t v_src = src[offset_src];
scalar_t v_src_rotate = (d_id + d2 / 2 < d2)
? src[offset_src + (d2 / 2) * stride_d]
: src[offset_src + (d2 / 2 - d2) * stride_d];
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
}
}

// handle the tail
if (hn > hn2) {
if (d > d2) {
#pragma unroll
for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
int offset_head = offset_block + head_id * hn;
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_head = offset_block + h_id * stride_h;
int offset_head_dst = offset_block_dst + h_id * o_stride_h;
#pragma unroll
for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) {
dst[offset_head + hn_id] = src[offset_head + hn_id];
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];
}
}
}
@@ -97,32 +110,40 @@ __global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2,
} // end of anonymous namespace

template <typename scalar_t>
void dispatch_fused_rope_forward(int sq, int b, int np, int hn, int hn2,
void dispatch_fused_rope_forward(int s, int b, int h, int d, int d2,
int stride_s, int stride_b, int stride_h,
int stride_d, int o_stride_s, int o_stride_b,
int o_stride_h, int o_stride_d,
const scalar_t* input, const scalar_t* cos,
const scalar_t* sin, scalar_t* output) {
auto stream = at::cuda::getCurrentCUDAStream();

int warps_per_block = np < 16 ? 4 : 8;
dim3 blocks(sq, b);
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(C10_WARP_SIZE, warps_per_block);

fused_rope_forward<<<blocks, threads, 0, stream>>>(sq, b, np, hn, hn2, input,
cos, sin, output);
fused_rope_forward<<<blocks, threads, 0, stream>>>(
h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, input, cos, sin, output);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename scalar_t>
void dispatch_fused_rope_backward(int sq, int b, int np, int hn, int hn2,
void dispatch_fused_rope_backward(int s, int b, int h, int d, int d2,
int stride_s, int stride_b, int stride_h,
int stride_d, int o_stride_s, int o_stride_b,
int o_stride_h, int o_stride_d,
const scalar_t* output_grads,
const scalar_t* cos, const scalar_t* sin,
scalar_t* input_grads) {
auto stream = at::cuda::getCurrentCUDAStream();

int warps_per_block = np < 16 ? 4 : 8;
dim3 blocks(sq, b);
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(C10_WARP_SIZE, warps_per_block);

fused_rope_backward<<<blocks, threads, 0, stream>>>(
sq, b, np, hn, hn2, output_grads, cos, sin, input_grads);
h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, output_grads, cos, sin, input_grads);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
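
For reference, both kernels above implement the rotate-half form of rotary embeddings: within the first d2 dims of each head, out[i] = x[i] * cos[i] + rotate_half(x)[i] * sin[i], where rotate_half pairs dim i with dim i + d2/2 (negating the first half), and dims past d2 are copied through unchanged (the "copy the rest" loop); the backward kernel applies the transpose of the same rotation by folding the sign flip into sin. A plain-C++ reference of the forward math (illustration only, not code from this commit):

#include <cmath>
#include <cstdio>
#include <vector>

// Reference rotate-half RoPE forward for one (position, head): mirrors the
// v_src / v_src_rotate logic of fused_rope_forward on contiguous data.
void rope_forward_ref(const std::vector<float>& x, const std::vector<float>& cs,
                      const std::vector<float>& sn, std::vector<float>& out,
                      int d, int d2) {
  for (int i = 0; i < d2; ++i) {
    // First half pairs with -x[i + d2/2]; second half with x[i - d2/2].
    const float rot = (i + d2 / 2 < d2) ? -x[i + d2 / 2] : x[i + d2 / 2 - d2];
    out[i] = x[i] * cs[i] + rot * sn[i];
  }
  for (int i = d2; i < d; ++i) out[i] = x[i];  // pass-through tail when d > d2
}

int main() {
  const int d = 8, d2 = 4;  // rotate only the first 4 of 8 head dims
  std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8}, cs(d2), sn(d2), out(d);
  for (int i = 0; i < d2; ++i) {
    const float theta = 0.1f * static_cast<float>(i);
    cs[i] = std::cos(theta);
    sn[i] = std::sin(theta);
  }
  rope_forward_ref(x, cs, sn, out, d, d2);
  for (float v : out) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}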
