Address assertion problem (check hasDtype and hasSizes) and handle different types for p and maxnorm args
Branko Trifkovic committed May 31, 2024
1 parent 2084c46 commit 4f55ebb
Showing 3 changed files with 45 additions and 24 deletions.
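For context: torch.renorm(input, p, dim, maxnorm) rescales each sub-tensor taken along dim whose p-norm exceeds maxnorm so that its norm becomes maxnorm, and PyTorch accepts both integer and floating-point values for p and maxnorm. That second point is what the verifier below now has to tolerate. A minimal eager-mode illustration (plain PyTorch, not part of this commit):

import torch

x = torch.rand(3, 3)

# Both calls are legal in eager PyTorch: p and maxnorm may be ints or floats,
# so the torch-mlir verifier must match either constant kind.
y_int = torch.renorm(x, p=2, dim=1, maxnorm=10)
y_float = torch.renorm(x, p=2.0, dim=1, maxnorm=10.0)
assert torch.allclose(y_int, y_float)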
34 changes: 22 additions & 12 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4656,14 +4656,15 @@ LogicalResult AtenNormScalarOp::verify() {
 LogicalResult AtenRenormOp::verify() {
 
   auto selfType = cast<BaseTensorType>(getSelf().getType());
-  auto selfDtype = selfType.getDtype();
-  auto selfRankedType = getSelf().getType().cast<RankedTensorType>();
-  int64_t selfRank = selfRankedType.getRank();
 
-  if (!selfType.hasDtype()) {
+  if (!selfType.hasDtype() || !selfType.hasSizes()) {
     return success();
   }
 
+  auto inShape = selfType.getSizes();
+  int64_t selfRank = inShape.size();
+  auto selfDtype = selfType.getDtype();
+
   // Check if dtype is one of those supported by renorm operation.
   // ComplexType will match any torch complex types, but each float must be
   // checked individually.
@@ -4687,34 +4688,43 @@ LogicalResult AtenRenormOp::verify() {
 
   // Check if argument p is valid
   auto pType = getP().getType();
-  double_t p;
 
   if (pType.isa<mlir::ComplexType>()) {
     return emitOpError("renorm: p must be real-valued");
   }
 
-  if (!matchPattern(getP(), m_TorchConstantFloat(&p))) {
+  // The argument 'p' can be either an integer or a floating-point number,
+  // so we need to consider both options and check if 'p' is within the correct
+  // range
+  int64_t pInt = 1;
+  double_t pDouble = 1;
+  if (!matchPattern(getP(), m_TorchConstantInt(&pInt)) &&
+      !matchPattern(getP(), m_TorchConstantFloat(&pDouble))) {
     return emitOpError("renorm: p must be real-valued");
   }
 
-  if (p < 0) {
+  if (pInt <= 0 || pDouble <= 0) {
     return emitOpError("renorm: non-positive norm not supported");
   }
 
   // Check if argument maxnorm is valid
   auto maxnormType = getMaxnorm().getType();
-  double_t maxnorm;
   if (maxnormType.isa<mlir::ComplexType>()) {
     return emitOpError("renorm: maxnorm must be real-valued");
   }
 
-  if (!matchPattern(getP(), m_TorchConstantFloat(&maxnorm))) {
+  // The argument 'maxnorm' can be either an integer or a floating-point number,
+  // so we need to consider both options and check if 'maxnorm' is within the
+  // correct range
+  int64_t maxnormInt = 0;
+  double_t maxnormDouble = 0;
+  if (!matchPattern(getMaxnorm(), m_TorchConstantInt(&maxnormInt)) &&
+      !matchPattern(getMaxnorm(), m_TorchConstantFloat(&maxnormDouble))) {
    return emitOpError("renorm: maxnorm must be real-valued");
   }
 
-  if (maxnorm < 0) {
-    return emitOpError("renorm: expected maxnorm to be >= 0 but got ")
-           << maxnorm;
+  if (maxnormInt < 0 || maxnormDouble < 0) {
+    return emitOpError("renorm: expected maxnorm to be >= 0");
   }
 
   // Get the dimension
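The value checks above mirror the validation eager PyTorch itself applies to renorm arguments. A rough sketch of the rejected cases (hedged: the exact eager error messages may differ from the verifier's):

import torch

x = torch.rand(3, 3)

# A non-positive p is rejected (cf. "renorm: non-positive norm not supported").
try:
    torch.renorm(x, p=0, dim=1, maxnorm=10)
except RuntimeError as e:
    print("p=0 rejected:", e)

# A negative maxnorm is rejected (cf. "renorm: expected maxnorm to be >= 0").
try:
    torch.renorm(x, p=2, dim=1, maxnorm=-1.0)
except RuntimeError as e:
    print("maxnorm=-1 rejected:", e)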
11 changes: 11 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2083,9 +2083,20 @@ class DecomposeAtenRenormOp : public OpRewritePattern<AtenRenormOp> {
     Value p = op.getP();
     Value maxnorm = op.getMaxnorm();
 
     // Prepare all necessary variables
     auto ndim = getTensorRank(self);
     auto resType = cast<BaseTensorType>(self.getType());
 
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
+
+    Type dtype = resType.getDtype();
+    if (isa<mlir::ComplexType>(dtype)) {
+      return rewriter.notifyMatchFailure(
+          op, "lowering of aten.renorm for complex inputs dtype is "
+              "currently unimplemented");
+    }
+
     ArrayRef<int64_t> inputSizeArrayRef = resType.getSizes();
     SmallVector<int64_t> inputSize;
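For intuition about what DecomposeAtenRenormOp is lowering, renorm can be written in terms of elementwise ops and reductions. The following is an illustrative eager-mode reference only; the helper name renorm_reference and the eps value are assumptions here, not the pass's actual aten op sequence:

import torch

def renorm_reference(x, p, dim, maxnorm, eps=1e-7):
    # p-norm of each sub-tensor taken along `dim`, reducing over all other dims.
    dim = dim % x.dim()
    reduce_dims = [d for d in range(x.dim()) if d != dim]
    norms = x.abs().pow(p).sum(dim=reduce_dims, keepdim=True).pow(1.0 / p)
    # Rescale only the slices whose norm exceeds maxnorm; leave the rest as-is.
    factor = torch.where(norms > maxnorm, maxnorm / (norms + eps), torch.ones_like(norms))
    return x * factor

x = torch.rand(3, 4, 5)
assert torch.allclose(renorm_reference(x, 2.0, 1, 1.0), torch.renorm(x, 2.0, 1, 1.0), atol=1e-5)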
24 changes: 12 additions & 12 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py
@@ -639,15 +639,15 @@ def AtenInstanceNormModule_basic(module, tu: TestUtils):
 class RenormModuleFloat32(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.p = 2.0
-        self.dim = 3
+        self.p = 2
+        self.dim = 1
         self.maxnorm = 10
 
     @export
     @annotate_args(
         [
             None,
-            ([3, 3, 4, 5], torch.float32, True),
+            ([3, 3], torch.float32, True),
         ]
     )
     def forward(self, x):
@@ -656,21 +656,21 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: RenormModuleFloat32())
 def RenormModuleFloat32_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3, 3, 4, 5))
+    module.forward(tu.rand(3, 3))
 
 
 class RenormModuleFloat16(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.p = 2.0
+        self.p = 2.1
         self.dim = 1
-        self.maxnorm = 10.5
+        self.maxnorm = 10
 
     @export
     @annotate_args(
         [
             None,
-            ([3, 3, 4, 5], torch.float16, True),
+            ([3, 4, 5], torch.float16, True),
         ]
     )
     def forward(self, x):
@@ -679,21 +679,21 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: RenormModuleFloat16())
 def RenormModuleFloat16_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3, 3, 4, 5).to(torch.float16))
+    module.forward(tu.rand(3, 4, 5).to(torch.float16))
 
 
 class RenormModuleFloat32NegativeDim(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.p = 2.0
+        self.p = 2.3
         self.dim = -1
-        self.maxnorm = 10
+        self.maxnorm = 5.2
 
     @export
     @annotate_args(
         [
             None,
-            ([3, 3, 4, 5], torch.float32, True),
+            ([1, 4, 5, 2], torch.float32, True),
         ]
     )
     def forward(self, x):
@@ -702,4 +702,4 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: RenormModuleFloat32NegativeDim())
 def RenormModuleFloat32NegativeDim_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3, 3, 4, 5).to(torch.float32))
+    module.forward(tu.rand(1, 4, 5, 2).to(torch.float32))
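These test changes exercise integer p with integer maxnorm, float p with integer maxnorm, and float p with float maxnorm on a negative dim. The negative-dim case can be spot-checked in eager PyTorch as follows (a quick sanity sketch, assuming a local PyTorch install; the scale factor and tolerance are arbitrary):

import torch

x = torch.rand(1, 4, 5, 2) * 10  # scale up so some slices actually exceed maxnorm
# Mirrors RenormModuleFloat32NegativeDim: float p, float maxnorm, dim = -1.
out = torch.renorm(x, p=2.3, dim=-1, maxnorm=5.2)
# After renorm, every slice along the last dim has 2.3-norm at most 5.2.
norms = out.abs().pow(2.3).sum(dim=(0, 1, 2)).pow(1 / 2.3)
assert torch.all(norms <= 5.2 + 1e-4)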
