From 960dee2f9797d56c7851c434f539db680ebb6a59 Mon Sep 17 00:00:00 2001
From: AlpinDale <52078762+AlpinDale@users.noreply.github.com>
Date: Thu, 26 Dec 2024 19:25:33 -0800
Subject: [PATCH] torch.compile: fix functionalization (#1045)

---
 aphrodite/compilation/backends.py | 153 ++++++++++++++++++++++++++++++
 aphrodite/worker/model_runner.py  |   3 +-
 tests/compile/test_full_graph.py  |  11 ++-
 3 files changed, 162 insertions(+), 5 deletions(-)
 create mode 100644 aphrodite/compilation/backends.py

diff --git a/aphrodite/compilation/backends.py b/aphrodite/compilation/backends.py
new file mode 100644
index 000000000..bc8bd2dfb
--- /dev/null
+++ b/aphrodite/compilation/backends.py
@@ -0,0 +1,153 @@
+import operator
+
+import torch
+import torch.fx as fx
+
+
+def fix_functionalization(graph: fx.Graph):
+    """
+    Rewrite the graph module to replace the pattern involving
+    torch._higher_order_ops.auto_functionalize.auto_functionalized
+    with a direct call to the inplace custom op.
+    # TODO: check if PyTorch nightly has fixed this issue
+    """
+    # debug code, if we want to see the graph before the transformation
+    # with open("before.py", "w") as f:
+    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)
+    nodes_to_remove = []
+    for node in graph.nodes:
+        # Identify the auto_functionalized node
+        if (
+            node.op == "call_function"
+            and node.target
+            == torch._higher_order_ops.auto_functionalize.auto_functionalized
+        ):  # noqa
+            if node.args[0] == torch.ops._C.rotary_embedding.default:
+                # manual replace for rotary_embedding
+                # Now, collect the arguments
+                kwargs = node.kwargs
+                query = kwargs["query"]
+                mm_node = query.args[0].args[0]
+                # Create a new call to torch.ops._C.rotary_embedding.default
+                with graph.inserting_before(node):
+                    # just insert the call to the custom op
+                    # NOTE: don't run dead code elimination,
+                    # otherwise this op will be removed
+                    graph.call_function(
+                        torch.ops._C.rotary_embedding.default, kwargs=kwargs
+                    )
+                # Remove the auto_functionalized node
+                # Since the node may have outputs, we need to handle its users
+                # Replace uses of the outputs (getitem nodes) with mm_node
+                for user in list(node.users):
+                    if (
+                        user.op == "call_function"
+                        and user.target == operator.getitem
+                    ):  # noqa
+                        # Remove the getitem node
+                        for getitem_user in list(user.users):
+                            if (
+                                getitem_user.op == "call_function"
+                                and getitem_user.target
+                                == torch.ops.aten.slice_scatter.default
+                            ):
+                                # Replace the uses of slice_scatter node
+                                # with mm_node
+                                getitem_user.replace_all_uses_with(mm_node)
+                                nodes_to_remove.append(getitem_user)
+                        nodes_to_remove.append(user)
+                nodes_to_remove.append(node)
+            elif node.args[0] == torch.ops._C.fused_add_rms_norm.default:
+                # manual replace for fused_add_rms_norm
+                # this is the most effective optimization for llama
+                # failing to do this will result in many unnecessary copies
+                kwargs = node.kwargs
+                input = kwargs["input"]
+                residual = kwargs["residual"]
+                # Create a new call to torch.ops._C.fused_add_rms_norm.default
+                with graph.inserting_before(node):
+                    # just insert the call to the custom op
+                    # NOTE: don't run dead code elimination,
+                    # otherwise this op will be removed
+                    graph.call_function(
+                        torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs
+                    )
+                for user in list(node.users):
+                    if (
+                        user.op == "call_function"
+                        and user.target == operator.getitem
+                    ):  # noqa
+                        # Remove the getitem node (1 -> input, 2 -> residual)
+                        if user.args[1] == 1:
+                            replace_node = input
+                        elif user.args[1] == 2:
+                            replace_node = residual
+                        user.replace_all_uses_with(replace_node)
+                        nodes_to_remove.append(user)
+                nodes_to_remove.append(node)
+            elif node.args[0] == torch.ops._C.rms_norm.default:
+                # manual replace for rms_norm
+                kwargs = node.kwargs
+                input = kwargs["input"]
+                out = kwargs["out"]
+                weight = kwargs["weight"]
+                epsilon = kwargs["epsilon"]
+                # Create a new call to torch.ops._C.rms_norm.default
+                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
+                with graph.inserting_before(node):
+                    # just insert the call to the custom op
+                    # NOTE: don't run dead code elimination,
+                    # otherwise this op will be removed
+                    graph.call_function(
+                        torch.ops._C.rms_norm.default,
+                        args=(out, input, weight, epsilon),
+                    )
+                replace_node = out
+                for user in list(node.users):
+                    if (
+                        user.op == "call_function"
+                        and user.target == operator.getitem
+                    ):  # noqa
+                        user.replace_all_uses_with(replace_node)
+                        nodes_to_remove.append(user)
+                nodes_to_remove.append(node)
+            elif node.args[0] == torch.ops._C.silu_and_mul.default:
+                # manual replace for silu_and_mul
+                kwargs = node.kwargs
+                input = kwargs["input"]
+                out = kwargs["out"]
+                # Create a new call to torch.ops._C.silu_and_mul.default
+                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
+                with graph.inserting_before(node):
+                    # just insert the call to the custom op
+                    # NOTE: don't run dead code elimination,
+                    # otherwise this op will be removed
+                    graph.call_function(
+                        torch.ops._C.silu_and_mul.default,
+                        args=(out, input),
+                    )
+                replace_node = out
+                for user in list(node.users):
+                    if (
+                        user.op == "call_function"
+                        and user.target == operator.getitem
+                    ):  # noqa
+                        user.replace_all_uses_with(replace_node)
+                        nodes_to_remove.append(user)
+                nodes_to_remove.append(node)
+    # Remove the nodes all at once
+    for node in nodes_to_remove:
+        graph.erase_node(node)
+    # debug code, if we want to see the graph after the transformation
+    # with open("after.py", "w") as f:
+    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)
+
+
+def aphrodite_backend(graph, example_inputs):
+    from torch._inductor import config
+
+    current_config = config.shallow_copy_dict()
+    from torch._inductor.compile_fx import compile_fx
+
+    current_config["post_grad_custom_post_pass"] = fix_functionalization
+    return compile_fx(graph, example_inputs, config_patches=current_config)
diff --git a/aphrodite/worker/model_runner.py b/aphrodite/worker/model_runner.py
index 142ec5e81..fc65b2126 100644
--- a/aphrodite/worker/model_runner.py
+++ b/aphrodite/worker/model_runner.py
@@ -1071,8 +1071,9 @@ def load_model(self) -> None:
 
         if envs.APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
             logger.info("Compiling the model using torch.compile...")
+            from aphrodite.compilation.backends import aphrodite_backend
             from aphrodite.plugins import get_torch_compile_backend
-            backend = get_torch_compile_backend() or "eager"
+            backend = get_torch_compile_backend() or aphrodite_backend
             start_time = time.time()
             self.model = torch.compile(
                 self.model,
diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py
index b0dd2b001..9b6a3a693 100644
--- a/tests/compile/test_full_graph.py
+++ b/tests/compile/test_full_graph.py
@@ -16,7 +16,10 @@ def test_full_graph(model):
         "The future of AI is",
     ]
     sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model="meta-llama/Meta-Llama-3-8B",
-              enforce_eager=True,
-              load_format="dummy")
-    llm.generate(prompts, sampling_params)
+    llm = LLM(model=model, enforce_eager=True)
+    outputs = llm.generate(prompts, sampling_params)
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
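
Usage sketch (not part of the patch): once applied, the new backend is an ordinary callable backend for torch.compile, which is exactly the fallback the model_runner change selects when get_torch_compile_backend() returns None. The snippet below is a minimal illustration, assuming the aphrodite.compilation.backends module from this diff is importable and Inductor is available; ToyModel is a hypothetical stand-in that uses no _C custom ops, so fix_functionalization runs as a no-op post-grad pass here.

import torch

from aphrodite.compilation.backends import aphrodite_backend


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in; real Aphrodite models exercise the
    rotary_embedding / fused_add_rms_norm / rms_norm / silu_and_mul branches."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.nn.functional.silu(self.linear(x))


model = ToyModel()
# aphrodite_backend wraps Inductor's compile_fx and registers
# fix_functionalization as the post-grad custom pass.
compiled = torch.compile(model, backend=aphrodite_backend, fullgraph=True)
out = compiled(torch.randn(4, 8))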