diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index 05f0a4558c52..b56d6bf0bf58 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index e712ebe2e72a..5528f884c1d4 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1864,14 +1864,30 @@ def aten〇_weight_norm_interface〡shape(v: List[int], g: List[int], dim: int = def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]: start_val = start if start is not None else 0 end_val = end if end is not None else upstream_shape_functions.max_int() - if (start_val < 0): - start_val += self[dim] - if (end_val < 0): - end_val += self[dim] - nstart = start_val if (step > 0) else end_val + 1 - nend = end_val if (step > 0) else start_val + 1 - nstep = step if (step > 0) else -step - return upstream_shape_functions.slice(self,dim,nstart,nend,nstep) + if (step < 0): + # Convert to equivalent positive-step parameters, which will require swapping start and end. + # If the parameters are in the normal range (0 <= start < d and -1 <= end <= start), then + # swapped_end = start + 1 and swapped_begin = end + 1. 
+ # The shift of inclusion can cause issues if these parameters are not already resolved on the left. + # e.g. start = -1, end = -3. So valid start is actually d-1, and valid end is d-3. Therefore, we + # should have swapped_end = d, but adding 1 to start before making it valid would result in an + # incorrect, but "valid", swapped_end = 0 for forward slicing. + # Additionally, if adding d doesn't make these values positive, but adding twice would, we need + # to clamp after resolving, otherwise the upstream function will try to resolve a second time. + if start_val < 0: + start_val += self[dim] + if start_val < 0: + start_val = 0 + if end_val < 0: + end_val += self[dim] + if end_val < 0: + end_val = -1 + + tmp = end_val + 1 + end_val = start_val + 1 + start_val = tmp + step = -step + return upstream_shape_functions.slice(self,dim,start_val,end_val,step) def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]: