-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[custom ops] Begin the scaffolding for dispatch of PyTorch custom ops. (
#270) This makes it possible for us to directly define regular torch ops in terms of generated MLIR. The resulting ops will be specialized and cached per requirements in their definition and will be compiled for any device that Turbine supports when dispatched against tensors on that device. It is left to a follow-up to also wire this mechanism in on the AOT side so that compiling programs that contain our own custom ops transparently includes them with no further glue. The scaffolding for this is in place, but this patch is big enough without touching AOT. This allows users to say something like: ``` @CustomOp.register class identity(CustomOp): name = "test_identity" signature = "(Tensor self) -> Tensor" def select(self, ksel: KernelSelection): x = ksel.arg_tensor(0) ksel.return_tensor(x.t) def generate(self, ksel: KernelSelection, kb: KernelBuilder): # This just yields the IR value of kernel input as the output. # Effectively in eager mode, this is a `return` from the kernel # function. kb.yield_results(kb.arg_bindings[0]) t = torch.tensor([[1, 2, 3]], dtype=torch.int32) result = identity(t) print("CPU result:", result) torch.testing.assert_close(result, t) ``` There will be dedicated `CustomOp` subclasses for our various DSLs that can be used for such things (for more sugar'd use than just open coding IR).
- Loading branch information
1 parent
5d9d08b
commit 68df316
Showing
28 changed files
with
1,490 additions
and
319 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,7 +46,7 @@ jobs: | |
- name: Run tests | ||
run: | | ||
pytest tests/ | ||
pytest -n 4 tests/ | ||
black: | ||
strategy: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
import functools | ||
import sys | ||
|
||
from ..device import ( | ||
from ...runtime.device import ( | ||
DeviceState, | ||
) | ||
|
||
|
Oops, something went wrong.