Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TK] Subgraph Tracing to support control flow #350

Merged
merged 10 commits into from
Jan 20, 2024
2 changes: 1 addition & 1 deletion python/shark_turbine/kernel/_support/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class NotSetType:
class ElementType(ABC):
@staticmethod
def cast(something) -> "ElementType":
if isinstance(something, torch.dtyp):
if isinstance(something, torch.dtype):
return TorchElementType(something)
else:
raise TypeError(
Expand Down
161 changes: 161 additions & 0 deletions python/shark_turbine/kernel/_support/regions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from typing import (
Optional,
TypeVar,
Callable,
Type,
assert_type,
cast,
List,
Dict,
Tuple,
)
import random
import contextlib

import torch.fx as fx
import torch.utils._pytree as pytree


class RegionGraph:
def __init__(self):
self.tracers: List["SubgraphTracer"] = []
self.subgraphs: Dict[str, fx.Graph] = dict()
self.inner_freevars: Dict[fx.Graph, List[fx.Proxy]] = dict()

@property
def root_tracer(self) -> "SubgraphTracer":
return self.tracers[0]

@property
def current_tracer(self) -> "SubgraphTracer":
return self.tracers[-1]

def create_proxy(self, *args, **kwargs):
return self.current_tracer.create_proxy(*args, **kwargs)

def create_node(self, *args, **kwargs):
return self.current_tracer.create_node(*args, **kwargs)

def create_arg(self, *args, **kwargs):
return self.current_tracer.create_arg(*args, **kwargs)

def new_subtracer(
self, region_graph: "RegionGraph", parent: Optional["SubgraphTracer"] = None
) -> "SubgraphTracer":
...

### ========================================================================
### Subgraph Tracing
### ========================================================================
def add_subgraph(
self, name: str, graph: fx.Graph, inner_freevars: List[fx.Proxy]
) -> str:
i = 0
while True:
candidate_name = f"{name}_{i}"
i += 1
if candidate_name not in self.subgraphs:
self.subgraphs[candidate_name] = graph
self.inner_freevars[graph] = inner_freevars
return candidate_name

@contextlib.contextmanager
def subtracer(self):
if self.tracers:
new_tracer = self.new_subtracer(self, self.current_tracer)
else:
new_tracer = self.new_subtracer(self)
self.tracers.append(new_tracer)
yield new_tracer
self.tracers.pop()

def __str__(self):
out = ""
for name, subgraph in self.subgraphs.items():
out += f"{name}:"
out += str(subgraph)
out += "\n"
return out


class SubgraphTracer(fx.Tracer):
def __init__(
self, region_graph: RegionGraph, parent: Optional["SubgraphTracer"] = None
):
super().__init__()
self.graph = fx.Graph()
self.region_graph = region_graph
self.parent = parent
self.lifted_freevars: Dict[fx.Proxy, fx.Proxy] = {}

def trace(self, *args, **kwargs) -> Tuple[str, List[fx.Proxy]]:
traced = super().trace(*args, **kwargs)
inner_freevars = list(self.lifted_freevars.values())
implicit_capture = list(self.lifted_freevars.keys())
subgraph_name = self.region_graph.add_subgraph("region", traced, inner_freevars)
return subgraph_name, implicit_capture

def _create_graph_input(self, name: str, type_expr=None) -> fx.Proxy:
proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
# Can use this to check where the freevar has been lifted from.
proxy.node.meta["lifted"] = None
return proxy

def _lift_tracked_freevar_to_input(self, proxy: fx.Proxy):
# It makes no sense for the root graph to have free variables
assert self.parent is not None, "Cannot lift freevars to input in root tracer"

# If the freevar has already been lifted, return the lifted version.
if proxy in self.lifted_freevars:
return self.lifted_freevars[proxy]

# Otherwise, create a new input and store it.
new_proxy = self._create_graph_input(proxy.node.name, proxy.node.type)
self.lifted_freevars[proxy] = new_proxy

# Propagate freevar usage upwards.
if self.parent is not None and proxy.tracer != self.parent:
self.parent._lift_tracked_freevar_to_input(proxy)
return new_proxy

def _maybe_lift_tracked_freevar_to_input(self, arg):
"""
If arg is a free variable, then lift it to be an input.
Returns the new lifted arg (if lifted), else the original arg.
"""
if not isinstance(arg, fx.Proxy):
return arg
elif arg.tracer == self:
return arg
else:
return self._lift_tracked_freevar_to_input(arg)

def create_proxy(
self,
kind,
target,
args,
kwargs,
name=None,
type_expr=None,
proxy_factor_fn=None,
):
if self.parent is not None:
flat_args, tree_spec = pytree.tree_flatten((args, kwargs))
new_flat_args = []
for arg in flat_args:
maybe_new_arg = self._maybe_lift_tracked_freevar_to_input(arg)
new_flat_args.append(maybe_new_arg)
args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec)

rv = super().create_proxy(
kind,
target,
args,
kwargs,
name,
type_expr,
proxy_factor_fn,
)

return rv
141 changes: 131 additions & 10 deletions python/shark_turbine/kernel/_support/tracing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
from abc import ABC, abstractmethod
from typing import Optional, TypeVar, Callable, Type, assert_type, cast
from typing import (
Optional,
TypeVar,
Callable,
Type,
assert_type,
cast,
List,
Dict,
Tuple,
Any,
)

import functools
import warnings
import contextlib
import torch.utils._pytree as pytree
import random

import torch.fx as fx

Expand All @@ -15,8 +29,12 @@

from ..lang.types import (
Index,
Vector,
)

from .regions import RegionGraph, SubgraphTracer


from .. import ops
from ..ops.base import (
OpDispatcher,
Expand All @@ -26,6 +44,20 @@

TCallable = TypeVar("TCallable", bound=Callable)

###############################################################################
# Kernel Region Graph
###############################################################################


class KernelRegionGraph(RegionGraph):
def new_subtracer(
self,
region_graph: "RegionGraph",
parent: Optional["SubgraphTracer"] = None,
) -> "KernelTracer":
return KernelTracer(region_graph, parent=parent)


###############################################################################
# Tracing machinery
###############################################################################
Expand All @@ -35,7 +67,10 @@ class KernelBufferProxy(fx.Proxy):
"""Custom proxy for KernelBuffer so that we can override special methods."""

def __init__(
self, node: fx.Node, tracer: "KernelTracer", orig_type: Type[KernelBuffer]
self,
node: fx.Node,
tracer: "KernelTracer",
orig_type: Type[KernelBuffer],
):
super().__init__(node, tracer)
self._orig_type = orig_type
Expand All @@ -50,9 +85,10 @@ def __setitem__(self, key, item):
ops.kernel_buffer_setitem(self, key, item)


class KernelTracer(fx.Tracer):
class KernelTracer(SubgraphTracer):
"""Custom Tracer for generating a trace of a kernel computation."""

# Register our custom proxies.
def proxy(self, node: fx.Node) -> fx.Proxy:
t = node.type
if t is not None and issubclass(t, KernelBuffer):
Expand All @@ -61,8 +97,15 @@ def proxy(self, node: fx.Node) -> fx.Proxy:


class CapturedTrace:
def __init__(self, gm: fx.GraphModule):
self.gm = gm
def __init__(self, region_graph: RegionGraph, root_graph: str):
self.region_graph = region_graph
self.root_graph = root_graph

def get_subgraph(self, name: str) -> fx.Graph:
return self.region_graph.subgraphs[name]

def get_root_graph(self) -> fx.Graph:
return self.get_subgraph(self.root_graph)


###############################################################################
Expand Down Expand Up @@ -109,18 +152,22 @@ def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, ite


class CompiledContext(BaseContext):
def __init__(self, tracer: KernelTracer, *, grid_type: Type[Grid]):
def __init__(self, region_graph: RegionGraph, *, grid_type: Type[Grid]):
super().__init__(eager=False)
self.tracer = tracer
self.region_graph = region_graph
self.grid_type = grid_type

### ========================================================================
### Core Operations
### ========================================================================

def handle_thread_program_id(self, op, axis: int) -> Index:
grid_shape = self.grid_type.symbolic_shape
if axis < 0 or axis >= len(grid_shape):
raise IndexError(
f"Illegal index into grid of rank {len(grid_shape)}: {axis}"
)
proxy = self.tracer.create_proxy(
proxy = self.region_graph.create_proxy(
"call_function",
op,
args=(axis,),
Expand All @@ -130,21 +177,95 @@ def handle_thread_program_id(self, op, axis: int) -> Index:
return proxy

def handle_kernel_buffer_getitem(self, op, kernel_buffer: KernelBuffer, key):
return self.tracer.create_proxy(
return self.region_graph.create_proxy(
"call_function",
op,
args=(kernel_buffer, key),
kwargs={},
)

def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, item):
self.tracer.create_proxy(
self.region_graph.create_proxy(
"call_function",
target=op,
args=(kernel_buffer, key, item),
kwargs={},
)

### ========================================================================
### Memory Operations
### ========================================================================
def handle_kernel_buffer_load(self, op, kernel_buffer, multi_index, shape):
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(kernel_buffer, multi_index, shape),
kwargs={},
)

def handle_kernel_buffer_store(self, op, kernel_buffer, multi_index, item):
self.region_graph.create_proxy(
"call_function",
target=op,
args=(kernel_buffer, multi_index, item),
kwargs={},
)

### ========================================================================
### Control Flow Operations
### ========================================================================

def handle_for_loop(self, op, start, stop=None, step=None, init_args=[]):
if stop is None:
stop = start
start = 0
if step is None:
step = 1

def wrapper(f):
with self.region_graph.subtracer() as subtracer:
subgraph_name, implicit_capture = subtracer.trace(f)
# Create a call to this subgraph
ret = self.region_graph.create_proxy(
"call_function",
target=op,
name="for_loop",
args=(start, stop, step, init_args),
kwargs={
"subgraph": subgraph_name,
"implicit_capture": implicit_capture,
},
)
return ret

return wrapper

### ========================================================================
### Math Operations
### ========================================================================

def handle_vector_constant(
self, op, shape: Tuple[int, ...], dtype, value: int | float
):
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(shape, dtype, value),
kwargs={},
)

### ========================================================================
### Reduction Operations
### ========================================================================

def handle_vector_dot(self, op, lhs, rhs, acc):
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(lhs, rhs, acc),
kwargs={},
)


###############################################################################
# Launch context
Expand Down
Loading
Loading