[TK] Subgraph Tracing to support control flow (#350)
This PR adds support for tracing a function within a separate tracer as a
"subgraph", which allows us to trace loop bodies instead of unrolling
them.
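
As a minimal sketch of the mechanism (assuming the RegionGraph/SubgraphTracer
machinery added in regions.py below, the KernelRegionGraph subclass from
tracing.py, and a purely illustrative toy body function):

from shark_turbine.kernel._support.tracing import KernelRegionGraph

def body(i, acc):
    # Toy loop body; in a real kernel this would contain tkl ops.
    return acc + i

region_graph = KernelRegionGraph()
with region_graph.subtracer() as tracer:
    # `body` is traced into its own fx.Graph and registered under a fresh
    # name (e.g. "region_0") instead of being inlined into a parent graph.
    subgraph_name, implicit_captures = tracer.trace(body)

print(region_graph)  # prints every registered subgraph

Any proxies the body pulls in from an enclosing tracer are lifted to
placeholder inputs of the subgraph and reported back via the second return
value (`implicit_captures` here).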

Along with subgraph tracing, this patch adds several features needed to
trace a gemm kernel (a rough usage sketch follows this list):

- Add explicit tkl.load/tkl.store ops instead of slicing. Slicing does not
seem like the way forward: since we always need a constant-sized result from
a slice, we would have to analyze whether each slice is constant-sized, which
is not worth the effort.
- Add support for tkl.constant, tkl.dot, tkl.for_loop.
- Add a new "Vector" class, a tensor-like class that supports computations
over it. Using PyTorch ops directly is not ideal here, since it gives us no
control over the op signature.
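
As a rough, hypothetical sketch of the kind of compiled-mode gemm these ops
are meant to enable (the launch decorator is omitted; the tkl import path,
tkl.program_id, and the exact op signatures are assumptions inferred from the
compiled-context handlers below, and the tile sizes are arbitrary):

import torch
import shark_turbine.kernel.lang as tkl  # assumed import path/alias

BLOCK_M, BLOCK_N, BLOCK_K = 32, 32, 32
K_TILES = 8  # number of reduction steps over the K dimension

def gemm(A, B, C):
    # Tile coordinates of this thread program (assumed op name).
    i = tkl.program_id(0)
    j = tkl.program_id(1)

    # Zero-initialized accumulator vector.
    acc = tkl.constant((BLOCK_M, BLOCK_N), torch.float32, 0.0)

    # The loop body is traced once as a subgraph instead of being unrolled.
    @tkl.for_loop(K_TILES, init_args=[acc])
    def body(k, acc):
        a = tkl.load(A, (i, k), (BLOCK_M, BLOCK_K))
        b = tkl.load(B, (k, j), (BLOCK_K, BLOCK_N))
        return tkl.dot(a, b, acc)

    tkl.store(C, (i, j), body)

In compiled mode, tkl.for_loop returns a decorator whose result is the proxy
of the for_loop op, so `body` above ends up bound to the final accumulator
rather than to a callable.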

These operations are only supported in compiled mode for now, not eager
execution. All of the newly added ops can be executed eagerly; that support
just hasn't been added in this patch.

---------

Co-authored-by: Stella Laurenzo <[email protected]>
Groverkss and stellaraccident authored Jan 20, 2024
1 parent 6b21267 commit 86653a4
Showing 18 changed files with 896 additions and 65 deletions.
2 changes: 1 addition & 1 deletion python/shark_turbine/kernel/_support/indexing.py
@@ -41,7 +41,7 @@ class NotSetType:
class ElementType(ABC):
    @staticmethod
    def cast(something) -> "ElementType":
        if isinstance(something, torch.dtyp):
        if isinstance(something, torch.dtype):
            return TorchElementType(something)
        else:
            raise TypeError(
161 changes: 161 additions & 0 deletions python/shark_turbine/kernel/_support/regions.py
@@ -0,0 +1,161 @@
from typing import (
    Optional,
    TypeVar,
    Callable,
    Type,
    assert_type,
    cast,
    List,
    Dict,
    Tuple,
)
import random
import contextlib

import torch.fx as fx
import torch.utils._pytree as pytree


class RegionGraph:
    def __init__(self):
        self.tracers: List["SubgraphTracer"] = []
        self.subgraphs: Dict[str, fx.Graph] = dict()
        self.inner_freevars: Dict[fx.Graph, List[fx.Proxy]] = dict()

    @property
    def root_tracer(self) -> "SubgraphTracer":
        return self.tracers[0]

    @property
    def current_tracer(self) -> "SubgraphTracer":
        return self.tracers[-1]

    def create_proxy(self, *args, **kwargs):
        return self.current_tracer.create_proxy(*args, **kwargs)

    def create_node(self, *args, **kwargs):
        return self.current_tracer.create_node(*args, **kwargs)

    def create_arg(self, *args, **kwargs):
        return self.current_tracer.create_arg(*args, **kwargs)

    def new_subtracer(
        self, region_graph: "RegionGraph", parent: Optional["SubgraphTracer"] = None
    ) -> "SubgraphTracer":
        ...

    ### ========================================================================
    ### Subgraph Tracing
    ### ========================================================================
    def add_subgraph(
        self, name: str, graph: fx.Graph, inner_freevars: List[fx.Proxy]
    ) -> str:
        i = 0
        while True:
            candidate_name = f"{name}_{i}"
            i += 1
            if candidate_name not in self.subgraphs:
                self.subgraphs[candidate_name] = graph
                self.inner_freevars[graph] = inner_freevars
                return candidate_name

    @contextlib.contextmanager
    def subtracer(self):
        if self.tracers:
            new_tracer = self.new_subtracer(self, self.current_tracer)
        else:
            new_tracer = self.new_subtracer(self)
        self.tracers.append(new_tracer)
        yield new_tracer
        self.tracers.pop()

    def __str__(self):
        out = ""
        for name, subgraph in self.subgraphs.items():
            out += f"{name}:"
            out += str(subgraph)
            out += "\n"
        return out


class SubgraphTracer(fx.Tracer):
    def __init__(
        self, region_graph: RegionGraph, parent: Optional["SubgraphTracer"] = None
    ):
        super().__init__()
        self.graph = fx.Graph()
        self.region_graph = region_graph
        self.parent = parent
        self.lifted_freevars: Dict[fx.Proxy, fx.Proxy] = {}

    def trace(self, *args, **kwargs) -> Tuple[str, List[fx.Proxy]]:
        traced = super().trace(*args, **kwargs)
        inner_freevars = list(self.lifted_freevars.values())
        implicit_capture = list(self.lifted_freevars.keys())
        subgraph_name = self.region_graph.add_subgraph("region", traced, inner_freevars)
        return subgraph_name, implicit_capture

    def _create_graph_input(self, name: str, type_expr=None) -> fx.Proxy:
        proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
        # Can use this to check where the freevar has been lifted from.
        proxy.node.meta["lifted"] = None
        return proxy

    def _lift_tracked_freevar_to_input(self, proxy: fx.Proxy):
        # It makes no sense for the root graph to have free variables
        assert self.parent is not None, "Cannot lift freevars to input in root tracer"

        # If the freevar has already been lifted, return the lifted version.
        if proxy in self.lifted_freevars:
            return self.lifted_freevars[proxy]

        # Otherwise, create a new input and store it.
        new_proxy = self._create_graph_input(proxy.node.name, proxy.node.type)
        self.lifted_freevars[proxy] = new_proxy

        # Propagate freevar usage upwards.
        if self.parent is not None and proxy.tracer != self.parent:
            self.parent._lift_tracked_freevar_to_input(proxy)
        return new_proxy

    def _maybe_lift_tracked_freevar_to_input(self, arg):
        """
        If arg is a free variable, then lift it to be an input.
        Returns the new lifted arg (if lifted), else the original arg.
        """
        if not isinstance(arg, fx.Proxy):
            return arg
        elif arg.tracer == self:
            return arg
        else:
            return self._lift_tracked_freevar_to_input(arg)

    def create_proxy(
        self,
        kind,
        target,
        args,
        kwargs,
        name=None,
        type_expr=None,
        proxy_factory_fn=None,
    ):
        if self.parent is not None:
            flat_args, tree_spec = pytree.tree_flatten((args, kwargs))
            new_flat_args = []
            for arg in flat_args:
                maybe_new_arg = self._maybe_lift_tracked_freevar_to_input(arg)
                new_flat_args.append(maybe_new_arg)
            args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec)

        rv = super().create_proxy(
            kind,
            target,
            args,
            kwargs,
            name,
            type_expr,
            proxy_factory_fn,
        )

        return rv
141 changes: 131 additions & 10 deletions python/shark_turbine/kernel/_support/tracing.py
@@ -1,8 +1,22 @@
from abc import ABC, abstractmethod
from typing import Optional, TypeVar, Callable, Type, assert_type, cast
from typing import (
    Optional,
    TypeVar,
    Callable,
    Type,
    assert_type,
    cast,
    List,
    Dict,
    Tuple,
    Any,
)

import functools
import warnings
import contextlib
import torch.utils._pytree as pytree
import random

import torch.fx as fx

@@ -15,8 +29,12 @@

from ..lang.types import (
    Index,
    Vector,
)

from .regions import RegionGraph, SubgraphTracer


from .. import ops
from ..ops.base import (
    OpDispatcher,
@@ -26,6 +44,20 @@

TCallable = TypeVar("TCallable", bound=Callable)

###############################################################################
# Kernel Region Graph
###############################################################################


class KernelRegionGraph(RegionGraph):
    def new_subtracer(
        self,
        region_graph: "RegionGraph",
        parent: Optional["SubgraphTracer"] = None,
    ) -> "KernelTracer":
        return KernelTracer(region_graph, parent=parent)


###############################################################################
# Tracing machinery
###############################################################################
@@ -35,7 +67,10 @@ class KernelBufferProxy(fx.Proxy):
"""Custom proxy for KernelBuffer so that we can override special methods."""

def __init__(
self, node: fx.Node, tracer: "KernelTracer", orig_type: Type[KernelBuffer]
self,
node: fx.Node,
tracer: "KernelTracer",
orig_type: Type[KernelBuffer],
):
super().__init__(node, tracer)
self._orig_type = orig_type
@@ -50,9 +85,10 @@ def __setitem__(self, key, item):
        ops.kernel_buffer_setitem(self, key, item)


class KernelTracer(fx.Tracer):
class KernelTracer(SubgraphTracer):
    """Custom Tracer for generating a trace of a kernel computation."""

    # Register our custom proxies.
    def proxy(self, node: fx.Node) -> fx.Proxy:
        t = node.type
        if t is not None and issubclass(t, KernelBuffer):
@@ -61,8 +97,15 @@ def proxy(self, node: fx.Node) -> fx.Proxy:


class CapturedTrace:
    def __init__(self, gm: fx.GraphModule):
        self.gm = gm
    def __init__(self, region_graph: RegionGraph, root_graph: str):
        self.region_graph = region_graph
        self.root_graph = root_graph

    def get_subgraph(self, name: str) -> fx.Graph:
        return self.region_graph.subgraphs[name]

    def get_root_graph(self) -> fx.Graph:
        return self.get_subgraph(self.root_graph)


###############################################################################
@@ -109,18 +152,22 @@ def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, item):


class CompiledContext(BaseContext):
    def __init__(self, tracer: KernelTracer, *, grid_type: Type[Grid]):
    def __init__(self, region_graph: RegionGraph, *, grid_type: Type[Grid]):
        super().__init__(eager=False)
        self.tracer = tracer
        self.region_graph = region_graph
        self.grid_type = grid_type

    ### ========================================================================
    ### Core Operations
    ### ========================================================================

    def handle_thread_program_id(self, op, axis: int) -> Index:
        grid_shape = self.grid_type.symbolic_shape
        if axis < 0 or axis >= len(grid_shape):
            raise IndexError(
                f"Illegal index into grid of rank {len(grid_shape)}: {axis}"
            )
        proxy = self.tracer.create_proxy(
        proxy = self.region_graph.create_proxy(
            "call_function",
            op,
            args=(axis,),
@@ -130,21 +177,95 @@ def handle_thread_program_id(self, op, axis: int) -> Index:
        return proxy

    def handle_kernel_buffer_getitem(self, op, kernel_buffer: KernelBuffer, key):
        return self.tracer.create_proxy(
        return self.region_graph.create_proxy(
            "call_function",
            op,
            args=(kernel_buffer, key),
            kwargs={},
        )

    def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, item):
        self.tracer.create_proxy(
        self.region_graph.create_proxy(
            "call_function",
            target=op,
            args=(kernel_buffer, key, item),
            kwargs={},
        )

    ### ========================================================================
    ### Memory Operations
    ### ========================================================================
    def handle_kernel_buffer_load(self, op, kernel_buffer, multi_index, shape):
        return self.region_graph.create_proxy(
            "call_function",
            target=op,
            args=(kernel_buffer, multi_index, shape),
            kwargs={},
        )

    def handle_kernel_buffer_store(self, op, kernel_buffer, multi_index, item):
        self.region_graph.create_proxy(
            "call_function",
            target=op,
            args=(kernel_buffer, multi_index, item),
            kwargs={},
        )

    ### ========================================================================
    ### Control Flow Operations
    ### ========================================================================

    def handle_for_loop(self, op, start, stop=None, step=None, init_args=[]):
        if stop is None:
            stop = start
            start = 0
        if step is None:
            step = 1

        def wrapper(f):
            with self.region_graph.subtracer() as subtracer:
                subgraph_name, implicit_capture = subtracer.trace(f)
            # Create a call to this subgraph
            ret = self.region_graph.create_proxy(
                "call_function",
                target=op,
                name="for_loop",
                args=(start, stop, step, init_args),
                kwargs={
                    "subgraph": subgraph_name,
                    "implicit_capture": implicit_capture,
                },
            )
            return ret

        return wrapper

    ### ========================================================================
    ### Math Operations
    ### ========================================================================

    def handle_vector_constant(
        self, op, shape: Tuple[int, ...], dtype, value: int | float
    ):
        return self.region_graph.create_proxy(
            "call_function",
            target=op,
            args=(shape, dtype, value),
            kwargs={},
        )

    ### ========================================================================
    ### Reduction Operations
    ### ========================================================================

    def handle_vector_dot(self, op, lhs, rhs, acc):
        return self.region_graph.create_proxy(
            "call_function",
            target=op,
            args=(lhs, rhs, acc),
            kwargs={},
        )


###############################################################################
# Launch context