Skip to content

Commit

Permalink
Add an include and modify shape inference
Browse files Browse the repository at this point in the history
  • Loading branch information
zjgarvey committed Nov 1, 2024
1 parent 2291f25 commit bd7cfa1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1864,14 +1864,30 @@ def aten〇_weight_norm_interface〡shape(v: List[int], g: List[int], dim: int =
def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
start_val = start if start is not None else 0
end_val = end if end is not None else upstream_shape_functions.max_int()
if (start_val < 0):
start_val += self[dim]
if (end_val < 0):
end_val += self[dim]
nstart = start_val if (step > 0) else end_val + 1
nend = end_val if (step > 0) else start_val + 1
nstep = step if (step > 0) else -step
return upstream_shape_functions.slice(self,dim,nstart,nend,nstep)
if (step < 0):
# Convert to equivalent postive-step parameters, which will require swapping start and end.
# If the parameters are in the normal range (0 <= start < d and -1 <= end <= start), then
# swapped_end = start + 1 and swapped_begin = end + 1.
# The shift of inclusion can cause issues if these parameters are not already resolved on the left.
# e.g. start = -1, end = -3 . So valid start is actually d-1, and valid end is d-3. Therefore, we
# should have swapped_end = d, but adding 1 to start before making it valid would result in an
# incorrect, but "valid", swapped_end = 0 for forward slicing.
# Additionally, if adding d doesn't make these values positive, but adding twice would, we need
# to clamp after resolving, otherwise the upstream function will try to resolve a second time.
if start_val < 0:
start_val += self[dim]
if start_val < 0:
start_val = 0
if end_val < 0:
end_val += self[dim]
if end_val < 0:
end_val = -1

tmp = end_val + 1
end_val = start_val + 1
start_val = tmp
step = -step
return upstream_shape_functions.slice(self,dim,start_val,end_val,step)


def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]:
Expand Down

0 comments on commit bd7cfa1

Please sign in to comment.