Update fused_rotary_positional_embedding.cpp
StellaAthena authored Nov 27, 2023
1 parent 5b1331d commit f20f1f5
Showing 1 changed file with 83 additions and 58 deletions.
megatron/fused_kernels/fused_rotary_positional_embedding.cpp: 141 changes (83 additions, 58 deletions)
@@ -14,72 +14,97 @@
* limitations under the License.
*/

-#include <torch/extension.h>
+#include <ATen/ATen.h>

#include "fused_rotary_positional_embedding.h"
#include "type_shim.h"

namespace fused_rope {

torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &cos,
-                       const torch::Tensor &sin, const bool transpose_output);
+                       const torch::Tensor &sin, const bool transpose_output) {
+  // input sizes: (s, b, h, d)
+  // s: sequence length
+  // b: batch size
+  // h: head num
+  // d: dim of each head
+  const int s = input.size(0);
+  const int b = input.size(1);
+  const int h = input.size(2);
+  const int d = input.size(3);
+  // input strides
+  const int stride_s = input.stride(0);
+  const int stride_b = input.stride(1);
+  const int stride_h = input.stride(2);
+  const int stride_d = input.stride(3);
+  // cos/sin's shape is always (s, 1, 1, d2), so the strides are same under
+  // different memory formats
+  const int d2 = cos.size(3);

-torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
-                       const torch::Tensor &cos, const torch::Tensor &sin,
-                       const bool transpose_output);
+  // output
+  auto act_options = input.options().requires_grad(false);
+  torch::Tensor output;
+  if (transpose_output) {
+    output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
+  } else {
+    output = torch::empty({s, b, h, d}, act_options);
+  }
+  // output strides
+  const int o_stride_s = output.stride(0);
+  const int o_stride_b = output.stride(1);
+  const int o_stride_h = output.stride(2);
+  const int o_stride_d = output.stride(3);

-torch::Tensor fwd(const at::Tensor &input, const at::Tensor &cos,
-                  const at::Tensor &sin, const bool transpose_output) {
-  TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(input.size(0) == cos.size(0),
-              "expected input and cos tensor have the same sequence length");
-  TORCH_CHECK(input.size(0) == sin.size(0),
-              "expected input and sin tensor have the same sequence length");
-  TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1,
-              "expected the second and third dims of the cos tensor equal 1");
-  TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1,
-              "expected the second and third dims of the sin tensor equal 1");
-  TORCH_CHECK(input.size(3) >= cos.size(3),
-              "expected the last dim of the input tensor is greater than the "
-              "cos tensor");
-  TORCH_CHECK(input.size(3) >= sin.size(3),
-              "expected the last dim of the input tensor is greater than the "
-              "sin tensor");
-
-  return fwd_cuda(input, cos, sin, transpose_output);
+  DISPATCH_FLOAT_HALF_AND_BFLOAT(
+      input.scalar_type(), 0, "dispatch_fused_rope_forward",
+      dispatch_fused_rope_forward(
+          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
+          o_stride_b, o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(),
+          cos.data_ptr<scalar_t_0>(), sin.data_ptr<scalar_t_0>(),
+          output.data_ptr<scalar_t_0>()););
+  return output;
}

-torch::Tensor bwd(const torch::Tensor &output_grads, const at::Tensor &cos,
-                  const at::Tensor &sin, const bool transpose_output) {
-  TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(
-      output_grads.size(0) == cos.size(0),
-      "expected output_grads and cos tensor have the same sequence length");
-  TORCH_CHECK(
-      output_grads.size(0) == sin.size(0),
-      "expected output_grads and sin tensor have the same sequence length");
-  TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1,
-              "expected the second and third dims of the cos tensor equal 1");
-  TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1,
-              "expected the second and third dims of the sin tensor equal 1");
-  TORCH_CHECK(
-      output_grads.size(3) >= cos.size(3),
-      "expected the last dim of the output_grads tensor is greater than the "
-      "cos tensor");
-  TORCH_CHECK(
-      output_grads.size(3) >= sin.size(3),
-      "expected the last dim of the output_grads tensor is greater than the "
-      "sin tensor");
-
-  return bwd_cuda(output_grads, cos, sin, transpose_output);
-}
+torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
+                       const torch::Tensor &cos, const torch::Tensor &sin,
+                       const bool transpose_output) {
+  // output_grads sizes: (s, b, h, d)
+  // s: sequence length
+  // b: batch size
+  // h: head num
+  // d: dim of each head
+  const int s = output_grads.size(0);
+  const int b = output_grads.size(1);
+  const int h = output_grads.size(2);
+  const int d = output_grads.size(3);
+  // output_grads strides
+  const int stride_s = output_grads.stride(0);
+  const int stride_b = output_grads.stride(1);
+  const int stride_h = output_grads.stride(2);
+  const int stride_d = output_grads.stride(3);
+  // cos/sin's shape is always (s, 1, 1, d2), so the strides are same under
+  // different memory formats
+  const int d2 = cos.size(3);

-} // end namespace fused_rope
+  auto act_options = output_grads.options().requires_grad(false);
+  torch::Tensor input_grads;
+  if (transpose_output) {
+    input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
+  } else {
+    input_grads = torch::empty({s, b, h, d}, act_options);
+  }
+  const int o_stride_s = input_grads.stride(0);
+  const int o_stride_b = input_grads.stride(1);
+  const int o_stride_h = input_grads.stride(2);
+  const int o_stride_d = input_grads.stride(3);

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &fused_rope::fwd,
-        "Fused Rotary Positional Embedding -- Forward.");
-  m.def("backward", &fused_rope::bwd,
-        "Fused Rotary Positional Embedding -- Backward.");
+  DISPATCH_FLOAT_HALF_AND_BFLOAT(
+      output_grads.scalar_type(), 0, "dispatch_fused_rope_backward",
+      dispatch_fused_rope_backward(
+          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
+          o_stride_b, o_stride_h, o_stride_d,
+          output_grads.data_ptr<scalar_t_0>(), cos.data_ptr<scalar_t_0>(),
+          sin.data_ptr<scalar_t_0>(), input_grads.data_ptr<scalar_t_0>());)
+  return input_grads;
}
+} // end namespace fused_rope
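A note on the transpose_output branch in the new fwd_cuda/bwd_cuda bodies above: torch::empty({b, s, h, d}, act_options).transpose(0, 1) returns a view whose logical shape is (s, b, h, d) but whose storage stays batch-major, which is why the o_stride_* values are queried after the transpose and handed to the kernel rather than derived from the logical shape. A minimal standalone libtorch sketch of that effect (not part of this commit):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Same layout trick as the transpose_output path: allocate (b, s, h, d),
  // then view it as (s, b, h, d) by swapping the first two dimensions.
  const int64_t s = 4, b = 2, h = 3, d = 8;
  auto t = torch::empty({b, s, h, d}).transpose(0, 1);
  std::cout << t.sizes() << "\n";    // [4, 2, 3, 8]  -> logical (s, b, h, d)
  std::cout << t.strides() << "\n";  // [24, 96, 8, 1] -> storage is still b-major
  return 0;
}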

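The scalar_t_0 type used inside the dispatch calls is introduced by the DISPATCH_FLOAT_HALF_AND_BFLOAT macro from type_shim.h, which this commit does not touch; the 0 passed as the second argument is the LEVEL suffix. Assuming the usual apex-style shim, such a macro expands to roughly the following scalar-type switch (an illustrative sketch, not the actual contents of type_shim.h):

#include <ATen/ATen.h>

// Sketch of an apex-style dispatch macro (assumed, for illustration only):
// selects a concrete scalar type at runtime and exposes it to the dispatched
// expression as scalar_t_<LEVEL>.
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...)        \
  switch (TYPE) {                                                     \
    case at::ScalarType::Float: {                                     \
      using scalar_t_##LEVEL = float;                                 \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    case at::ScalarType::Half: {                                      \
      using scalar_t_##LEVEL = at::Half;                              \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    case at::ScalarType::BFloat16: {                                  \
      using scalar_t_##LEVEL = at::BFloat16;                          \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    default:                                                          \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }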