[FxImporter] Fix sympy_int_to_int utility (llvm#3657)
A new sympy type has been introduced in the upstream PyTorch repo to represent
integer infinity. As a result, `sympy.oo` is no longer used for the upper bound
of dynamic dimensions whose upper bound is unknown; `int_oo` is used to
represent integer infinity instead. This commit updates the
`_sympy_int_to_int` utility in light of that change.
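The gist of the change, as a minimal standalone sketch (the helper name and the explicit int64 constants are illustrative; the real logic lives in `_sympy_int_to_int` in the diff below):

import sympy

try:
    # Newer PyTorch exposes a dedicated integer-infinity singleton.
    from torch.utils._sympy.numbers import int_oo
except ModuleNotFoundError:
    # Stable PyTorch releases may not ship int_oo yet.
    int_oo = None

INT64_MAX, INT64_MIN = 2**63 - 1, -(2**63)
# Treat both sympy.oo and, when available, int_oo as "unbounded".
infs = (sympy.oo, int_oo) if int_oo is not None else (sympy.oo,)

def sympy_int_to_int(val):
    # Unbounded bounds are clamped to the int64 range instead of math.inf.
    if val in infs:
        return INT64_MAX
    if val in tuple(-inf for inf in infs):
        return INT64_MIN
    return int(val)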
patel-vimal authored Aug 26, 2024
1 parent f9766c8 commit fa39d91
Showing 3 changed files with 41 additions and 13 deletions.
4 changes: 4 additions & 0 deletions python/TorchMLIRModule.cpp
@@ -28,4 +28,8 @@ PYBIND11_MODULE(_torchMlir, m) {
        }
      },
      py::arg("context"), py::arg("load") = true);

  m.def("get_int64_max", []() { return INT64_MAX; });

  m.def("get_int64_min", []() { return INT64_MIN; });
}
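The two new bindings simply expose the C int64 limits to Python. A rough usage sketch, assuming a torch-mlir build that contains this commit and the usual `torch_mlir._mlir_libs` package layout (the relative import in fx_importer.py below resolves to the same module):

from torch_mlir._mlir_libs._torchMlir import get_int64_max, get_int64_min

# Sanity check: these mirror INT64_MAX / INT64_MIN from the C++ side.
assert get_int64_max() == 2**63 - 1
assert get_int64_min() == -(2**63)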
34 changes: 28 additions & 6 deletions python/torch_mlir/extras/fx_importer.py
@@ -78,6 +78,16 @@
    # conditional.
    ml_dtypes = None

try:
    from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
except ModuleNotFoundError:
    # This commit on PyTorch repo introduced IntInfinity and NegativeIntInfinity:
    # https://github.com/pytorch/pytorch/commit/2229884102ac95c9dda0aeadbded1b04295d892e
    # Required module may not be present in the stable version of PyTorch.
    int_oo = None
    IntInfinity = None
    NegativeIntInfinity = None

from torch.fx.node import (
    Argument as NodeArgument,
)
@@ -125,6 +135,8 @@
    func as func_dialect,
)

from .._mlir_libs._torchMlir import get_int64_max, get_int64_min

__all__ = [
    "FxImporter",
]
@@ -1165,22 +1177,32 @@ def set_symbolic_guards(
        self, prog: torch.export.ExportedProgram
    ) -> Dict[str, RangeConstraint]:

        # Recent PyTorch versions use `int_oo` to represent integer infinity.
        # Older PyTorch versions like PyTorch stable version may not have
        # `int_oo` defined just yet.
        infs = (sympy.oo, int_oo) if int_oo is not None else (sympy.oo,)

        def _sympy_int_to_int(val: sympy.Expr, adjust_func: Callable):
            # Convert simple sympy Integers into concrete int
            if val == sympy.oo:
                return math.inf
            if val == -sympy.oo:
                return -math.inf
            if val in infs:
                return get_int64_max()
            if val in tuple(-inf for inf in infs):
                return get_int64_min()
            if isinstance(val, sympy.Integer):
                return int(val)
            # TODO: Remove this adjustment when fractional ranges are removed
            return adjust_func(val)

        contains_symbolic_ints = False
        sym_int_types = (
            (sympy.Integer, IntInfinity, NegativeIntInfinity)
            if IntInfinity is not None
            else sympy.Integer
        )
        for val in prog.range_constraints.values():
            if (
                isinstance(val.lower, sympy.Integer)
                and isinstance(val.upper, sympy.Integer)
                isinstance(val.lower, sym_int_types)
                and isinstance(val.upper, sym_int_types)
                and not val.is_bool
            ):
                contains_symbolic_ints = True
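For context, a small sketch of the guards this code receives from torch.export (hypothetical module and shapes; on recent PyTorch an unspecified upper bound is reported as `int_oo`, which the updated `_sympy_int_to_int` now clamps to `get_int64_max()`):

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x):
        return torch.tanh(x)

batch = Dim("batch")  # no explicit max, so the upper bound is unbounded
prog = export(M(), (torch.randn(3, 4),), dynamic_shapes={"x": {0: batch}})
for sym, rng in prog.range_constraints.items():
    # On newer PyTorch, rng.upper is int_oo (previously sympy.oo).
    print(sym, rng.lower, rng.upper)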
16 changes: 9 additions & 7 deletions test/python/fx_importer/basic_test.py
@@ -88,12 +88,13 @@ def forward(self, x):

@run
# CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32>
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,5],f32>) -> !torch.vtensor<[?,?,5],f32>
# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,4],f32> -> !torch.vtensor<[?,4],f32>
# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
# CHECK: return %[[TANH]] : !torch.vtensor<[?,4],f32>
# CHECK: %[[S1:.*]] = torch.symbolic_int "s1" {min_val = 2, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 5)> : !torch.vtensor<[?,?,5],f32>
# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,5],f32> -> !torch.vtensor<[?,?,5],f32>
# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 5)> : !torch.vtensor<[?,?,5],f32>
# CHECK: return %[[TANH]] : !torch.vtensor<[?,?,5],f32>
def test_import_frozen_exported_program_with_dynamic_shapes():
    class Basic(nn.Module):
        def __init__(self):
@@ -103,10 +104,11 @@ def forward(self, x):
            return torch.tanh(x)

    batch = Dim("batch", max=10)
    dynamic_shapes = {"x": {0: batch}}
    channel = Dim("channel", min=2)
    dynamic_shapes = {"x": {0: batch, 1: channel}}
    m = fx.export_and_import(
        Basic(),
        torch.randn(3, 4),
        torch.randn(3, 4, 5),
        dynamic_shapes=dynamic_shapes,
        func_name="test_net",
        import_symbolic_shape_expressions=True,
