Commit

Update fused_rotary_positional_embedding.cpp
StellaAthena authored Nov 27, 2023
1 parent 417f55c commit 5b1331d
Showing 1 changed file with 9 additions and 14 deletions.
23 changes: 9 additions & 14 deletions megatron/fused_kernels/fused_rotary_positional_embedding.cpp
@@ -19,16 +19,14 @@
 namespace fused_rope {

 torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &cos,
-                       const torch::Tensor &sin);
+                       const torch::Tensor &sin, const bool transpose_output);

 torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
-                       const torch::Tensor &cos, const torch::Tensor &sin);
+                       const torch::Tensor &cos, const torch::Tensor &sin,
+                       const bool transpose_output);

-torch::Tensor fwd(const at::Tensor &input_, const at::Tensor &cos_,
-                  const at::Tensor &sin_) {
-  auto input = input_.contiguous();
-  auto cos = cos_.contiguous();
-  auto sin = sin_.contiguous();
+torch::Tensor fwd(const at::Tensor &input, const at::Tensor &cos,
+                  const at::Tensor &sin, const bool transpose_output) {
   TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
   TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
   TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
@@ -47,14 +45,11 @@ torch::Tensor fwd(const at::Tensor &input_, const at::Tensor &cos_,
               "expected the last dim of the input tensor is greater than the "
               "sin tensor");

-  return fwd_cuda(input, cos, sin);
+  return fwd_cuda(input, cos, sin, transpose_output);
 }

-torch::Tensor bwd(const torch::Tensor &output_grads_, const at::Tensor &cos_,
-                  const at::Tensor &sin_) {
-  auto output_grads = output_grads_.contiguous();
-  auto cos = cos_.contiguous();
-  auto sin = sin_.contiguous();
+torch::Tensor bwd(const torch::Tensor &output_grads, const at::Tensor &cos,
+                  const at::Tensor &sin, const bool transpose_output) {
   TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
   TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
   TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
@@ -77,7 +72,7 @@ torch::Tensor bwd(const torch::Tensor &output_grads_, const at::Tensor &cos_,
               "expected the last dim of the output_grads tensor is greater than the "
               "sin tensor");

-  return bwd_cuda(output_grads, cos, sin);
+  return bwd_cuda(output_grads, cos, sin, transpose_output);
 }

 } // end namespace fused_rope
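Note: the diff threads a new transpose_output flag from the C++ wrappers fwd/bwd through to the fwd_cuda/bwd_cuda kernels and drops the host-side .contiguous() copies, passing the input, cos, and sin tensors to the kernels as given. The Python-facing binding block further down the file is collapsed in this view; the following is a minimal sketch, assuming a standard torch-extension pybind11 registration, of how wrappers with these updated signatures could be exposed. The function names "forward"/"backward" and the arg names in m.def are assumptions for illustration, not the actual binding code from this commit.

    // Hypothetical pybind11 registration sketch for the updated signatures.
    #include <torch/extension.h>

    namespace fused_rope {
    // Declarations matching the signatures introduced by this commit.
    torch::Tensor fwd(const at::Tensor &input, const at::Tensor &cos,
                      const at::Tensor &sin, const bool transpose_output);
    torch::Tensor bwd(const torch::Tensor &output_grads, const at::Tensor &cos,
                      const at::Tensor &sin, const bool transpose_output);
    }  // namespace fused_rope

    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
      // Each binding now carries the extra transpose_output flag.
      m.def("forward", &fused_rope::fwd,
            "Fused rotary positional embedding, forward (CUDA)",
            py::arg("input"), py::arg("cos"), py::arg("sin"),
            py::arg("transpose_output"));
      m.def("backward", &fused_rope::bwd,
            "Fused rotary positional embedding, backward (CUDA)",
            py::arg("output_grads"), py::arg("cos"), py::arg("sin"),
            py::arg("transpose_output"));
    }

With the .contiguous() calls removed from the wrappers, any required contiguity or layout handling would have to happen in the callers or in the CUDA kernels themselves.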
