Skip to content

Commit

Permalink
Fix handling of !torch.number in abstract interpretation library (#…
Browse files Browse the repository at this point in the history
…2309)

In PyTorch, the `NumberType` is equal to `Union[int, float,
complex]`. However, the abstract interpretation library was treating
the `NumberType` as `Union[int, float]`, resulting in type mismatches
when reifying certain dtype functions. This commit fixes the type
inconsistency by having the abstract interpretation functions take as
an input a `Union[int, float, complex]` for the ops that take
`!torch.number` inputs.
  • Loading branch information
ramiro050 authored Jul 17, 2023
1 parent 5706697 commit 718f53f
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 137 deletions.
158 changes: 79 additions & 79 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,17 @@ FailureOr<Value> Torch::adjustFunctionArg(
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}

// !torch.union<int, float> or !torch.union<int, float, none> is the type used
// for (optional) `Scalar` inputs. At compile time, such inputs will usually
// be resolved to an `int` or a `float` so we need to derefine to match the
// library function signature.
// The type `!torch.number` can be an `int`, `float`, or `complex`.
// TODO: Add a new type `Torch::ComplexType` to handle the complex case.
if (desiredType.isa<Torch::NumberType>() &&
operandType.isa<Torch::IntType, Torch::FloatType>()) {
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}

// !torch.union<int, float, none> is the type used for optional
// `Scalar` inputs. At compile time, such inputs will usually be
// resolved to an `int`, `float`, or `None` so we need to derefine
// to match the library function signature.
if (auto unionType = desiredType.dyn_cast<Torch::UnionType>()) {
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
return containedType
Expand Down
Loading

0 comments on commit 718f53f

Please sign in to comment.