torch.compile: fix functionalization (#1045)
AlpinDale authored Dec 27, 2024
1 parent ce7b602 commit 960dee2
Showing 3 changed files with 162 additions and 5 deletions.
153 changes: 153 additions & 0 deletions aphrodite/compilation/backends.py
@@ -0,0 +1,153 @@
import operator

import torch
import torch.fx as fx


def fix_functionalization(graph: fx.Graph):
"""
Rewrite the graph module to replace the pattern involving
torch._higher_order_ops.auto_functionalize.auto_functionalized
with a direct call to the inplace custom op.
# TODO: check if PyTorch nightly has fixed this issue
"""
# debug code, if we want to see the graph before the transformation
# with open("before.py", "w") as f:
# print(graph.python_code(root_module="self", verbose=True).src, file=f)
nodes_to_remove = []
    for node in graph.nodes:
        # Identify the auto_functionalized node
        if (
            node.op == "call_function"
            and node.target
            == torch._higher_order_ops.auto_functionalize.auto_functionalized
        ):  # noqa
            if node.args[0] == torch.ops._C.rotary_embedding.default:
                # manual replace for rotary_embedding
                # Now, collect the arguments
                kwargs = node.kwargs
                query = kwargs["query"]
                mm_node = query.args[0].args[0]
                # Create a new call to torch.ops._C.rotary_embedding.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.rotary_embedding.default, kwargs=kwargs
                    )
                # Remove the auto_functionalized node
                # Since the node may have outputs, we need to handle its users
                # Replace uses of the outputs (getitem nodes) with mm_node
                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):  # noqa
                        # Remove the getitem node
                        for getitem_user in list(user.users):
                            if (
                                getitem_user.op == "call_function"
                                and getitem_user.target
                                == torch.ops.aten.slice_scatter.default
                            ):
                                # Replace the uses of slice_scatter node
                                # with mm_node
                                getitem_user.replace_all_uses_with(mm_node)
                                nodes_to_remove.append(getitem_user)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)
            elif node.args[0] == torch.ops._C.fused_add_rms_norm.default:
                # manual replace for fused_add_rms_norm
                # this is the most effective optimization for llama
                # failing to do this will result in many unnecessary copies
                kwargs = node.kwargs
                input = kwargs["input"]
                residual = kwargs["residual"]
                # Create a new call to torch.ops._C.fused_add_rms_norm.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs
                    )
                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):  # noqa
                        # Remove the getitem node
                        if user.args[1] == 1:
                            replace_node = input
                        elif user.args[1] == 2:
                            replace_node = residual
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)
            elif node.args[0] == torch.ops._C.rms_norm.default:
                # manual replace for rms_norm
                kwargs = node.kwargs
                input = kwargs["input"]
                out = kwargs["out"]
                weight = kwargs["weight"]
                epsilon = kwargs["epsilon"]
                # Create a new call to torch.ops._C.rms_norm.default
                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.rms_norm.default,
                        args=(out, input, weight, epsilon),
                    )
                replace_node = out
                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):  # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)
            elif node.args[0] == torch.ops._C.silu_and_mul.default:
                # manual replace for silu_and_mul
                kwargs = node.kwargs
                input = kwargs["input"]
                out = kwargs["out"]
                # Create a new call to torch.ops._C.silu_and_mul.default
                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.silu_and_mul.default,
                        args=(out, input),
                    )
                replace_node = out
                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):  # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)
    # Remove the nodes all at once
    for node in nodes_to_remove:
        graph.erase_node(node)
    # debug code, if we want to see the graph after the transformation
    # with open("after.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)


def aphrodite_backend(graph, example_inputs):
    from torch._inductor import config

    current_config = config.shallow_copy_dict()
    from torch._inductor.compile_fx import compile_fx

    current_config["post_grad_custom_post_pass"] = fix_functionalization
    return compile_fx(graph, example_inputs, config_patches=current_config)
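
The commented-out dumps inside fix_functionalization are the intended way to inspect the FX graph around the rewrite. Below is a small sketch (not part of the commit) that wraps the pass with those dumps enabled; the wrapper name and output file names are arbitrary.

import torch.fx as fx

from aphrodite.compilation.backends import fix_functionalization


def fix_functionalization_with_dumps(graph: fx.Graph):
    # Dump the post-grad FX graph before the rewrite.
    with open("before.py", "w") as f:
        print(graph.python_code(root_module="self", verbose=True).src, file=f)
    fix_functionalization(graph)
    # Dump it again after auto_functionalized nodes have been replaced
    # with direct calls to the inplace custom ops.
    with open("after.py", "w") as f:
        print(graph.python_code(root_module="self", verbose=True).src, file=f)

Assigning this wrapper to current_config["post_grad_custom_post_pass"] in place of fix_functionalization would produce the before/after dumps on every compilation.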
3 changes: 2 additions & 1 deletion aphrodite/worker/model_runner.py
@@ -1071,8 +1071,9 @@ def load_model(self) -> None:
 
         if envs.APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
             logger.info("Compiling the model using torch.compile...")
+            from aphrodite.compilation.backends import aphrodite_backend
             from aphrodite.plugins import get_torch_compile_backend
-            backend = get_torch_compile_backend() or "eager"
+            backend = get_torch_compile_backend() or aphrodite_backend
             start_time = time.time()
             self.model = torch.compile(
                 self.model,
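The same wiring can be reproduced outside the model runner. A minimal sketch, assuming only that aphrodite_backend is importable as added in this commit; the toy module is a stand-in, so none of the custom-op rewrites will actually fire on its graph.

import torch
import torch.nn as nn

from aphrodite.compilation.backends import aphrodite_backend


class Toy(nn.Module):

    def forward(self, x):
        return torch.relu(x) + 1


# Dynamo captures the module, Inductor compiles it, and
# fix_functionalization runs as the post-grad custom post pass.
compiled = torch.compile(Toy(), backend=aphrodite_backend)
print(compiled(torch.randn(4, 4)))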
11 changes: 7 additions & 4 deletions tests/compile/test_full_graph.py
@@ -16,7 +16,10 @@ def test_full_graph(model):
         "The future of AI is",
     ]
     sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model="meta-llama/Meta-Llama-3-8B",
-              enforce_eager=True,
-              load_format="dummy")
-    llm.generate(prompts, sampling_params)
+    llm = LLM(model=model, enforce_eager=True)
+    outputs = llm.generate(prompts, sampling_params)
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
