Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix onnx.Pad constant #172

Merged
merged 1 commit into from
Jun 10, 2024
Merged

Conversation

josel-amd
Copy link
Collaborator

Looking at the GeneratedTorchOps.td. This leaves two options, either the operator specification is wrong or the code generating it.

I check a bit and it seems that indeed the op spec is right. References:

https://github.com/pytorch/pytorch/blob/543a870943120484db547382ed9ca9538a40f284/torch/csrc/api/include/torch/nn/functional/padding.h#L12

https://pytorch.org/docs/main/generated/torch.nn.functional.pad.html#torch.nn.functional.pad

def Torch_AtenPadOp : Torch_Op<"aten.pad", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::pad : (Tensor, int[], str, float?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchListOfTorchIntType:$pad,
    Torch_StringType:$mode,
    AnyTorchOptionalFloatType:$value
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenPadOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 4, 1);
    }
    void AtenPadOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 4, 1);
    }
  }];
}

@cferry-AMD
Copy link

To be sure I understood correctly, is the support of other data types than f32 what is wrong in TorchOnnxToTorch? Is there any test case where this appears?

@josel-amd
Copy link
Collaborator Author

To be sure I understood correctly, is the support of other data types than f32 what is wrong in TorchOnnxToTorch? Is there any test case where this appears?

Yes, only f32 is allowed in that specific argument. I don't think there is a test case for it. At least nothing started failing. I can add one, if there is general agreement about the fix.

Copy link

@cferry-AMD cferry-AMD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good, indeed if we look closely at what happens above, any constant scalar passed is turned into a float. Thanks!

@mgehre-amd mgehre-amd merged commit c4ccde3 into feature/backport_ea1_ops Jun 10, 2024
2 checks passed
@mgehre-amd mgehre-amd deleted the jose.fix_onnx_pad branch June 10, 2024 09:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants