diff --git a/megatron/fused_kernels/fused_rotary_positional_embedding.cpp b/megatron/fused_kernels/fused_rotary_positional_embedding.cpp
index c00ad8ead..ad6b26da0 100644
--- a/megatron/fused_kernels/fused_rotary_positional_embedding.cpp
+++ b/megatron/fused_kernels/fused_rotary_positional_embedding.cpp
@@ -14,72 +14,97 @@
  * limitations under the License.
  */
 
-#include <torch/extension.h>
+#include <ATen/ATen.h>
+
+#include "fused_rotary_positional_embedding.h"
+#include "type_shim.h"
 
 namespace fused_rope {
 
 torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &cos,
-                       const torch::Tensor &sin, const bool transpose_output);
+                       const torch::Tensor &sin, const bool transpose_output) {
+  // input sizes: (s, b, h, d)
+  // s: sequence length
+  // b: batch size
+  // h: head num
+  // d: dim of each head
+  const int s = input.size(0);
+  const int b = input.size(1);
+  const int h = input.size(2);
+  const int d = input.size(3);
+  // input strides
+  const int stride_s = input.stride(0);
+  const int stride_b = input.stride(1);
+  const int stride_h = input.stride(2);
+  const int stride_d = input.stride(3);
+  // cos/sin's shape is always (s, 1, 1, d2), so the strides are same under
+  // different memory formats
+  const int d2 = cos.size(3);
 
-torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
-                       const torch::Tensor &cos, const torch::Tensor &sin,
-                       const bool transpose_output);
+  // output
+  auto act_options = input.options().requires_grad(false);
+  torch::Tensor output;
+  if (transpose_output) {
+    output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
+  } else {
+    output = torch::empty({s, b, h, d}, act_options);
+  }
+  // output strides
+  const int o_stride_s = output.stride(0);
+  const int o_stride_b = output.stride(1);
+  const int o_stride_h = output.stride(2);
+  const int o_stride_d = output.stride(3);
 
-torch::Tensor fwd(const at::Tensor &input, const at::Tensor &cos,
-                  const at::Tensor &sin, const bool transpose_output) {
-  TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(input.size(0) == cos.size(0),
-              "expected input and cos tensor have the same sequence length");
-  TORCH_CHECK(input.size(0) == sin.size(0),
-              "expected input and sin tensor have the same sequence length");
-  TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1,
-              "expected the second and third dims of the cos tensor equal 1");
-  TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1,
-              "expected the second and third dims of the sin tensor equal 1");
-  TORCH_CHECK(input.size(3) >= cos.size(3),
-              "expected the last dim of the input tensor is greater than the "
-              "cos tensor");
-  TORCH_CHECK(input.size(3) >= sin.size(3),
-              "expected the last dim of the input tensor is greater than the "
-              "sin tensor");
-
-  return fwd_cuda(input, cos, sin, transpose_output);
+  DISPATCH_FLOAT_HALF_AND_BFLOAT(
+      input.scalar_type(), 0, "dispatch_fused_rope_forward",
+      dispatch_fused_rope_forward(
+          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
+          o_stride_b, o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(),
+          cos.data_ptr<scalar_t_0>(), sin.data_ptr<scalar_t_0>(),
+          output.data_ptr<scalar_t_0>()););
+  return output;
 }
 
-torch::Tensor bwd(const torch::Tensor &output_grads, const at::Tensor &cos,
-                  const at::Tensor &sin, const bool transpose_output) {
-  TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(
-      output_grads.size(0) == cos.size(0),
- "expected output_grads and cos tensor have the same sequence length"); - TORCH_CHECK( - output_grads.size(0) == sin.size(0), - "expected output_grads and sin tensor have the same sequence length"); - TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1, - "expected the second and third dims of the cos tensor equal 1"); - TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1, - "expected the second and third dims of the sin tensor equal 1"); - TORCH_CHECK( - output_grads.size(3) >= cos.size(3), - "expected the last dim of the output_grads tensor is greater than the " - "cos tensor"); - TORCH_CHECK( - output_grads.size(3) >= sin.size(3), - "expected the last dim of the output_grads tensor is greater than the " - "sin tensor"); - - return bwd_cuda(output_grads, cos, sin, transpose_output); -} +torch::Tensor bwd_cuda(const torch::Tensor &output_grads, + const torch::Tensor &cos, const torch::Tensor &sin, + const bool transpose_output) { + // output_grads sizes: (s, b, h, d) + // s: sequence length + // b: batch size + // h: head num + // d: dim of each head + const int s = output_grads.size(0); + const int b = output_grads.size(1); + const int h = output_grads.size(2); + const int d = output_grads.size(3); + // output_grads strides + const int stride_s = output_grads.stride(0); + const int stride_b = output_grads.stride(1); + const int stride_h = output_grads.stride(2); + const int stride_d = output_grads.stride(3); + // cos/sin's shape is always (s, 1, 1, d2), so the strides are same under + // different memory formats + const int d2 = cos.size(3); -} // end namespace fused_rope + auto act_options = output_grads.options().requires_grad(false); + torch::Tensor input_grads; + if (transpose_output) { + input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + } else { + input_grads = torch::empty({s, b, h, d}, act_options); + } + const int o_stride_s = input_grads.stride(0); + const int o_stride_b = input_grads.stride(1); + const int o_stride_h = input_grads.stride(2); + const int o_stride_d = input_grads.stride(3); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &fused_rope::fwd, - "Fused Rotary Positional Embedding -- Forward."); - m.def("backward", &fused_rope::bwd, - "Fused Rotary Positional Embedding -- Backward."); + DISPATCH_FLOAT_HALF_AND_BFLOAT( + output_grads.scalar_type(), 0, "dispatch_fused_rope_backward", + dispatch_fused_rope_backward( + s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, + o_stride_b, o_stride_h, o_stride_d, + output_grads.data_ptr(), cos.data_ptr(), + sin.data_ptr(), input_grads.data_ptr());) + return input_grads; } +} // end namespace fused_rope