feat: Add support for require_full_compilation in Dynamo (pytorch#2138)
gs-olive authored Aug 29, 2023
1 parent b9c8578 commit 8efbee9
Showing 8 changed files with 150 additions and 15 deletions.
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -14,6 +14,7 @@
USE_PYTHON_RUNTIME = False
USE_FAST_PARTITIONER = True
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
REQUIRE_FULL_COMPILATION = False


def default_device() -> Device:
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -11,6 +11,7 @@
OPTIMIZATION_LEVEL,
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
REQUIRE_FULL_COMPILATION,
TRUNCATE_LONG_AND_DOUBLE,
USE_FAST_PARTITIONER,
USE_PYTHON_RUNTIME,
@@ -57,3 +58,4 @@ class CompilationSettings:
use_fast_partitioner: bool = USE_FAST_PARTITIONER
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
device: Device = field(default_factory=default_device)
require_full_compilation: bool = REQUIRE_FULL_COMPILATION
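
The new field rides along with the rest of the dataclass defaults. A minimal sketch of how it surfaces, assuming only the module path shown in this diff; everything beyond the default value is illustrative:

from torch_tensorrt.dynamo._settings import CompilationSettings

# Default: partial compilation is allowed (REQUIRE_FULL_COMPILATION = False).
settings = CompilationSettings()
print(settings.require_full_compilation)  # False

# Opting in is just a constructor argument; downstream compile/partitioning
# code reads the flag from this settings object rather than a literal.
settings = CompilationSettings(require_full_compilation=True)
print(settings.require_full_compilation)  # True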
10 changes: 7 additions & 3 deletions py/torch_tensorrt/dynamo/compile.py
@@ -20,6 +20,7 @@
OPTIMIZATION_LEVEL,
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
REQUIRE_FULL_COMPILATION,
TRUNCATE_LONG_AND_DOUBLE,
USE_FAST_PARTITIONER,
USE_PYTHON_RUNTIME,
@@ -57,7 +58,7 @@ def compile(
dla_global_dram_size: int = 536870912,
calibrator: object = None,
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
require_full_compilation: bool = False,
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
min_block_size: int = MIN_BLOCK_SIZE,
torch_executed_ops: Optional[List[str]] = None,
torch_executed_modules: Optional[List[str]] = None,
@@ -80,8 +81,10 @@
"The Dynamo backend is an experimental feature, for which only the "
"following arguments are supported: "
"{enabled_precisions, debug, workspace_size, min_block_size, "
"torch_executed_ops, pass_through_build_failures, use_fast_partitioner, "
"enable_experimental_decompositions}"
"max_aux_streams, version_compatible, optimization_level, "
"torch_executed_ops, pass_through_build_failures, "
"use_fast_partitioner, enable_experimental_decompositions, "
"require_full_compilation}"
)

if not isinstance(inputs, collections.abc.Sequence):
@@ -126,6 +129,7 @@
"truncate_long_and_double": truncate_long_and_double,
"use_fast_partitioner": use_fast_partitioner,
"enable_experimental_decompositions": enable_experimental_decompositions,
"require_full_compilation": require_full_compilation,
}

settings = CompilationSettings(**compilation_options)
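
For orientation, a hedged end-to-end sketch of passing the flag through the Dynamo path. The toy model, input shapes, and the use of the top-level torch_tensorrt.compile entry point with ir="dynamo" are illustrative assumptions rather than part of this commit:

import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(2, 16).cuda()]

# require_full_compilation now defaults to the shared REQUIRE_FULL_COMPILATION
# constant and is forwarded into CompilationSettings alongside the other options.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.float},
    require_full_compilation=True,  # error out instead of falling back to Torch
)
print(trt_model(*inputs))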
55 changes: 47 additions & 8 deletions py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
@@ -12,7 +12,11 @@
_SplitterSettingBase,
)
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, NodeSet
from torch_tensorrt.dynamo._defaults import DEBUG, MIN_BLOCK_SIZE
from torch_tensorrt.dynamo._defaults import (
DEBUG,
MIN_BLOCK_SIZE,
REQUIRE_FULL_COMPILATION,
)
from torch_tensorrt.dynamo.conversion.converter_registry import (
DYNAMO_CONVERTERS as CONVERTERS,
)
@@ -92,6 +96,7 @@ class TRTPartitioner(_SplitterBase): # type: ignore
allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
Generally useful for module-level exclusion ops which are intensive despite being single functions
min_block_size: Minimum number of computational operators per block
require_full_compilation: Require that all computational operators be run in TRT
Returns:
torch.fx.GraphModule
"""
@@ -104,6 +109,7 @@ def __init__(
Collection[str]
] = DEFAULT_SINGLE_NODE_PARTITIONS,
min_block_size: int = MIN_BLOCK_SIZE,
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
):
"""
Preprocesses graph before splitting:
@@ -142,6 +148,7 @@ def __init__(

self.num_trt_accelerated_subgraphs: Optional[int] = None
self.allowed_single_node_partition_ops = allowed_single_node_partition_ops
self.require_full_compilation = require_full_compilation

def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
"""
@@ -151,12 +158,16 @@ def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph
result: List[Subgraph] = []
for subgraph in subgraphs:
if subgraph.is_acc:
if len(subgraph.nodes) >= self.settings.min_acc_module_size or (
self.allowed_single_node_partition_ops is not None
and any(
ConverterRegistry.qualified_name_or_str(node.target)
in self.allowed_single_node_partition_ops
for node in subgraph.nodes
if (
len(subgraph.nodes) >= self.settings.min_acc_module_size
or self.require_full_compilation
or (
self.allowed_single_node_partition_ops is not None
and any(
ConverterRegistry.qualified_name_or_str(node.target)
in self.allowed_single_node_partition_ops
for node in subgraph.nodes
)
)
):
result.append(subgraph)
@@ -185,6 +196,27 @@ def partition_graph(self) -> torch.fx.GraphModule:
# Delegate nodes based on operator coverage
subgraphs = self.put_nodes_into_subgraphs()

# A graph is fully supported if there is a single partition and all operators are supported/convertible
full_support = len([s for s in subgraphs if s.is_acc]) == 1 and not getattr(
self.operator_support, "unsupported_operators", True
)

if not full_support and self.require_full_compilation:
raise AssertionError(
"require_full_compilation=True was specified, but model is not fully supported"
)

if (
full_support
and self.require_full_compilation
and self.settings.min_acc_module_size != MIN_BLOCK_SIZE
):
logger.warning(
"Detected both require_full_compilation and min_block_size compilation "
"arguments were specified. Disregarding min_block_size argument for "
"fully supported model."
)

# Remove segments smaller than the block size (with exceptions)
subgraphs = self.remove_small_acc_subgraphs(subgraphs)

@@ -217,6 +249,7 @@ def partition(
verbose: bool = DEBUG,
min_block_size: int = MIN_BLOCK_SIZE,
torch_executed_ops: Collection[Target] = set(),
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
) -> torch.fx.GraphModule:
"""Partition an FX GraphModule with aten ops into TRT engines
Partitioning is based on converter operator support
@@ -226,6 +259,7 @@
verbose: Bool representing whether to print operator support
min_block_size: Minimum number of operators per TRT-Engine Block
torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
require_full_compilation: Require that all computational operators be run in TRT
Returns:
torch.fx.GraphModule
"""
@@ -236,7 +270,12 @@

# Construct
supported_ops = OpSupportTester(torch_executed_ops=torch_executed_ops)
partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)
partitioner = TRTPartitioner(
gm,
supported_ops,
min_block_size=min_block_size,
require_full_compilation=require_full_compilation,
)

partitioned_graph = partitioner.partition_graph()

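
The effect on the adjacency (fast) partitioner mirrors the new test added below: with require_full_compilation=True, a fully supported single-op graph is kept in an accelerated subgraph even though it falls under the default min_block_size. A sketch, assuming the test-style import of the partitioning module:

from copy import deepcopy

import torch
from torch_tensorrt.dynamo import partitioning  # import path assumed from the tests below

class OneOp(torch.nn.Module):
    def forward(self, x, y):
        return torch.ops.aten.add.Tensor(x, y)

gm = torch.fx.symbolic_trace(OneOp())

# Without the flag, a lone supported op would be left in Torch (min_block_size);
# with it, the single accelerated subgraph is retained.
partitioned = partitioning.fast_partition(deepcopy(gm), require_full_compilation=True)
acc_submodules = [name for name, _ in partitioned.named_children() if "_run_on_acc" in name]
assert len(acc_submodules) == 1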
45 changes: 42 additions & 3 deletions py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
@@ -5,7 +5,11 @@
from torch.fx.graph_module import GraphModule
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.passes.operator_support import OperatorSupport, SupportDict
from torch_tensorrt.dynamo._defaults import DEBUG, MIN_BLOCK_SIZE
from torch_tensorrt.dynamo._defaults import (
DEBUG,
MIN_BLOCK_SIZE,
REQUIRE_FULL_COMPILATION,
)
from torch_tensorrt.dynamo.conversion.converter_registry import (
DYNAMO_CONVERTERS as CONVERTERS,
)
Expand All @@ -26,6 +30,7 @@ class TRTPartitioner(CapabilityBasedPartitioner): # type: ignore[misc]
allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
Generally useful for module-level exclusion ops which are intensive despite being single functions
min_block_size: Minimum number of computational operators per block
require_full_compilation: Require that all computational operators be run in TRT
Returns:
torch.fx.GraphModule
"""
@@ -40,6 +45,7 @@ def __init__(
Collection[str]
] = DEFAULT_SINGLE_NODE_PARTITIONS,
min_block_size: int = MIN_BLOCK_SIZE,
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
) -> None:
super().__init__(
graph_module,
@@ -50,12 +56,34 @@
)

self.min_block_size = min_block_size
self.require_full_compilation = require_full_compilation

def propose_partitions(self) -> List[Partition]:
# Propose partitions using the default, then refine the results
initial_proposed_partitions = super().propose_partitions()
partitions = dict(enumerate(initial_proposed_partitions))

# A graph is fully supported if there is a single partition and all operators are supported/convertible
full_support = len(partitions) == 1 and not getattr(
self.operator_support, "unsupported_operators", True
)

if not full_support and self.require_full_compilation:
raise AssertionError(
"require_full_compilation=True was specified, but model is not fully supported"
)

if (
full_support
and self.require_full_compilation
and self.min_block_size != MIN_BLOCK_SIZE
):
logger.warning(
"Detected both require_full_compilation and min_block_size compilation "
"arguments were specified. Disregarding min_block_size argument for "
"fully supported model."
)

# For each partition, determine whether or not the number of computational operators
# exceeds the threshold, and if not, remove that partition
partitions_to_remove = {}
@@ -81,7 +109,11 @@ def propose_partitions(self) -> List[Partition]:
):
compute_node_count += 1

if compute_node_count < self.min_block_size and not exempted_partition:
if (
compute_node_count < self.min_block_size
and not exempted_partition
and not (full_support and self.require_full_compilation)
):
partitions_to_remove[id] = compute_node_count

# Remove any nodes violating the criteria specified by the user
@@ -172,6 +204,7 @@ def partition(
verbose: bool = DEBUG,
min_block_size: int = MIN_BLOCK_SIZE,
torch_executed_ops: Optional[Set[str]] = None,
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
) -> torch.fx.GraphModule:
"""Partition an FX GraphModule with aten ops into TRT engines
Partitioning is based on converter operator support
@@ -181,6 +214,7 @@
verbose: Bool representing whether to print operator support
min_block_size: Minimum number of operators per TRT-Engine Block
torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
require_full_compilation: Whether to require that all operators be run in TRT
Returns:
torch.fx.GraphModule
"""
@@ -189,7 +223,12 @@
if torch_executed_ops is not None
else set()
)
partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)
partitioner = TRTPartitioner(
gm,
supported_ops,
min_block_size=min_block_size,
require_full_compilation=require_full_compilation,
)

# Determine partitions based on user specifications and operator support
# Then, fuse partitions and display overview of supported/unsupported operators
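
Conversely, when the graph is not fully supported, the new check raises instead of silently splitting. A hedged sketch against the global partitioner; aten.nonzero is used only as a stand-in for an operator without a registered converter, which is an assumption about converter coverage:

from copy import deepcopy

import torch
from torch_tensorrt.dynamo import partitioning  # import path assumed from the tests below

class PartiallySupported(torch.nn.Module):
    def forward(self, x, y):
        z = torch.ops.aten.add.Tensor(x, y)
        # Stand-in for an op with no TRT converter; swap in any uncovered operator.
        return torch.ops.aten.nonzero.default(z)

gm = torch.fx.symbolic_trace(PartiallySupported())

try:
    partitioning.global_partition(deepcopy(gm), require_full_compilation=True)
except AssertionError as err:
    print(f"Rejected as expected: {err}")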
10 changes: 9 additions & 1 deletion py/torch_tensorrt/dynamo/utils.py
@@ -216,7 +216,15 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
"If this is incorrect, please specify an input device, via the device keyword."
)

logger.info(f"Compiling with Settings:\n{settings}")
# Ignore and warn about require_full_compilation flag
if settings.require_full_compilation:
logger.warning(
"Detected require_full_compilation=True for a torch.compile run. "
"This option has no effect in torch.compile."
)
settings.require_full_compilation = False

logger.info("Compilation Settings: %s\n", settings)

return settings

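
In the torch.compile workflow the option is intentionally ignored, since graph breaks there are governed by Dynamo itself. A sketch of the behavior parse_dynamo_kwargs now enforces; the registered backend name "torch_tensorrt" and the options dict are assumptions about the torch.compile integration, not introduced by this commit:

import torch
import torch_tensorrt  # noqa: F401  (assumed to register the "torch_tensorrt" backend)

model = torch.nn.Linear(8, 8).eval().cuda()

# parse_dynamo_kwargs logs a warning and resets require_full_compilation to False,
# so the run proceeds exactly as if the flag had not been passed.
compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"require_full_compilation": True, "min_block_size": 1},
)
out = compiled(torch.randn(4, 8).cuda())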
24 changes: 24 additions & 0 deletions tests/py/dynamo/partitioning/test_fast_partitioning.py
@@ -31,6 +31,30 @@ def forward(self, x, y):
"Single operators should not be segmented",
)

def test_partition_fully_supported_one_op_require_full_compilation(self):
class FullySupportedOneOp(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x, y):
return torch.ops.aten.add.Tensor(x, y)

fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
partitioned_graph = partitioning.fast_partition(
deepcopy(fx_graph), require_full_compilation=True
)
self.assertEquals(
len(
[
1
for submod in list(partitioned_graph.named_children())
if "_run_on_acc" in submod[0]
]
),
1,
"Single operators can be segmented if full compilation is required",
)

def test_partition_fully_supported_multi_op(self):
class FullySupportedMultiOp(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
18 changes: 18 additions & 0 deletions tests/py/dynamo/partitioning/test_global_partitioning.py
@@ -25,6 +25,24 @@ def forward(self, x, y):
"Single operators should not be segmented",
)

def test_partition_fully_supported_one_op_require_full_compilation(self):
class FullySupportedOneOp(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x, y):
return torch.ops.aten.add.Tensor(x, y)

fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
partitioned_graph = partitioning.global_partition(
deepcopy(fx_graph), require_full_compilation=True
)
self.assertEquals(
len(list(partitioned_graph.named_children())),
1,
"Single operators can be segmented if full compilation is required",
)

def test_partition_fully_supported_multi_op(self):
class FullySupportedMultiOp(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
