Address change requests
Branko Trifkovic committed May 31, 2024
1 parent 4f55ebb commit 2838f84
Showing 2 changed files with 23 additions and 46 deletions.
46 changes: 11 additions & 35 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4657,92 +4657,68 @@ LogicalResult AtenRenormOp::verify() {

  auto selfType = cast<BaseTensorType>(getSelf().getType());

-  if (!selfType.hasDtype() || !selfType.hasSizes()) {
+  if (!selfType.hasDtype() || !selfType.hasSizes())
     return success();
-  }

  auto inShape = selfType.getSizes();
  int64_t selfRank = inShape.size();
  auto selfDtype = selfType.getDtype();

-  // Check if dtype is one of those supported by renorm operation.
-  // ComplexType will match any torch complex types, but each float must be
-  // checked individually.
-  if (selfDtype.isa<mlir::ComplexType>()) {
-    return emitOpError(
-        "lowering for complex type input tensor is currently unsuporrted");
-  }
-
  if (!selfDtype.isa<mlir::Float16Type, mlir::BFloat16Type, mlir::Float32Type,
-                     mlir::Float64Type, mlir::ComplexType>()) {
+                     mlir::Float64Type, mlir::ComplexType>())
     return emitOpError(
                "expected a float or complex type for input tensor, but got ")
            << selfDtype;
-  }

  // According to the PyTorch documentation, the tensor needs to be at least
  // rank 2
-  if (selfRank <= 1) {
+  if (selfRank <= 1)
     return emitOpError("renorm: input needs at least 2 dimensions, got ")
            << selfRank << " dimensions";
-  }

  // Check if argument p is valid
  auto pType = getP().getType();

-  if (pType.isa<mlir::ComplexType>()) {
+  if (pType.isa<mlir::ComplexType>())
     return emitOpError("renorm: p must be real-valued");
-  }

  // The argument 'p' can be either an integer or a floating-point number,
  // so we need to consider both options and check if 'p' is within the correct
  // range
  int64_t pInt = 1;
  double_t pDouble = 1;
  if (!matchPattern(getP(), m_TorchConstantInt(&pInt)) &&
-      !matchPattern(getP(), m_TorchConstantFloat(&pDouble))) {
+      !matchPattern(getP(), m_TorchConstantFloat(&pDouble)))
     return emitOpError("renorm: p must be real-valued");
-  }

-  if (pInt <= 0 || pDouble <= 0) {
+  if (pInt <= 0 || pDouble <= 0)
     return emitOpError("renorm: non-positive norm not supported");
-  }

  // Check if argument maxnorm is valid
  auto maxnormType = getMaxnorm().getType();
-  if (maxnormType.isa<mlir::ComplexType>()) {
+  if (maxnormType.isa<mlir::ComplexType>())
     return emitOpError("renorm: maxnorm must be real-valued");
-  }

  // The argument 'maxnorm' can be either an integer or a floating-point number,
  // so we need to consider both options and check if 'maxnorm' is within the
  // correct range
  int64_t maxnormInt = 0;
  double_t maxnormDouble = 0;
  if (!matchPattern(getMaxnorm(), m_TorchConstantInt(&maxnormInt)) &&
-      !matchPattern(getMaxnorm(), m_TorchConstantFloat(&maxnormDouble))) {
+      !matchPattern(getMaxnorm(), m_TorchConstantFloat(&maxnormDouble)))
     return emitOpError("renorm: maxnorm must be real-valued");
-  }

-  if (maxnormInt < 0 || maxnormDouble < 0) {
+  if (maxnormInt < 0 || maxnormDouble < 0)
     return emitOpError("renorm: expected maxnorm to be >= 0");
-  }

  // Get the dimension
  int64_t dim;
-  if (!matchPattern(getDim(), m_TorchConstantInt(&dim))) {
+  if (!matchPattern(getDim(), m_TorchConstantInt(&dim)))
     return emitOpError("dim must be a constant int");
-  }

  // Check if dim is in the correct range
-  if (dim >= selfRank || dim < -selfRank) {
+  if (dim >= selfRank || dim < -selfRank)
     return emitOpError("Dimension out of range (expected to be in range of [")
            << -selfRank << ", " << selfRank - 1 << "], but got " << dim;
-  }

-  // Canonicalize dimension if necessary
-  if (dim < 0) {
-    dim += selfRank;
-  }
-
  return success();
}
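
Note: the dimension handling kept in the verifier follows PyTorch's usual indexing rule: for a rank-r input, dim must lie in [-r, r - 1], and negative values wrap around. A minimal standalone sketch of that rule (hypothetical helper names for illustration, not torch-mlir code):

#include <cstdint>

// A valid dim for a rank-`rank` tensor lies in [-rank, rank - 1].
bool isValidDim(int64_t dim, int64_t rank) {
  return dim >= -rank && dim < rank;
}

// Negative dims wrap around, e.g. dim = -1 on a rank-3 tensor means dim = 2.
int64_t canonicalizeDim(int64_t dim, int64_t rank) {
  return dim < 0 ? dim + rank : dim;
}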
23 changes: 12 additions & 11 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2087,8 +2087,9 @@ class DecomposeAtenRenormOp : public OpRewritePattern<AtenRenormOp> {
    auto ndim = getTensorRank(self);
    auto resType = cast<BaseTensorType>(self.getType());

-    if (!resType.hasDtype()) {
-      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    if (!resType.hasDtype() || !resType.hasSizes()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "result should have dtype and sizes");
    }

    Type dtype = resType.getDtype();
@@ -2098,9 +2099,10 @@ class DecomposeAtenRenormOp : public OpRewritePattern<AtenRenormOp> {
"currently unimplemented");
}

ArrayRef<int64_t> inputSizeArrayRef = resType.getSizes();
SmallVector<int64_t> inputSize;
inputSize.assign(inputSizeArrayRef.begin(), inputSizeArrayRef.end());
SmallVector<int64_t> inputSize(resType.getSizes());

if (inputSize[0] == Torch::kUnknownSize)
return rewriter.notifyMatchFailure(op, "size of input tensor is unknown");

// Convert dim from Value to int
int64_t dimInt;
@@ -2118,13 +2120,12 @@ class DecomposeAtenRenormOp : public OpRewritePattern<AtenRenormOp> {
    // ndim-1]
    llvm::SmallVector<Value> reduceDimsVector;
    for (u_int64_t i = 0; i < ndim; i++) {
+      if (i == (u_int64_t)dimInt)
+        continue;
+
      Value constI = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(i));

-      if (i == (u_int64_t)dimInt) {
-        continue;
-      }
-
      reduceDimsVector.push_back(constI);
    }

@@ -2136,9 +2137,9 @@ class DecomposeAtenRenormOp : public OpRewritePattern<AtenRenormOp> {
    // Make output shape for linalg.vector_norm operation
    SmallVector<Value> inputSizeValue;
    for (u_int64_t i = 0; i < inputSize.size(); i++) {
-      if (i != (u_int64_t)dimInt) {
+      if (i != (u_int64_t)dimInt)
        inputSize[i] = 1;
-      }
+
      inputSizeValue.push_back(
          rewriter.create<Torch::ConstantIntOp>(loc, inputSize[i]));
    }
