From 42821b3e500dba79f7db20065607aae911842fc2 Mon Sep 17 00:00:00 2001 From: dan Date: Tue, 6 Feb 2024 00:26:50 -0600 Subject: [PATCH] add tribal knowledge decomps for newer torch versions Once we migrate default torch version these will be hit by existing tests. --- core/shark_turbine/dynamo/passes.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/shark_turbine/dynamo/passes.py b/core/shark_turbine/dynamo/passes.py index 88c08f6ad..80be06ae8 100644 --- a/core/shark_turbine/dynamo/passes.py +++ b/core/shark_turbine/dynamo/passes.py @@ -52,6 +52,11 @@ torch.ops.aten._scaled_dot_product_flash_attention.default, ] +# These decompositions either didnt exist or weren't required for 2.1.0 +if torch.__version__ > "2.1.0": + DEFAULT_DECOMPOSITIONS.append(torch.ops.aten._scaled_dot_product_flash_attention_for_cpu) + DEFAULT_DECOMPOSITIONS.append(torch.ops.aten.unbind_int) + def apply_decompositions( gm: torch.fx.GraphModule,