-
Notifications
You must be signed in to change notification settings - Fork 505
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add More Scalarize Shapes Patterns (#3810)
### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
- Loading branch information
Showing
4 changed files
with
330 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.