From cb03b91df6fa84c36571729fe09d561317c9f0eb Mon Sep 17 00:00:00 2001 From: Frank J <53087374+crazy-JiangDongHua@users.noreply.github.com> Date: Thu, 29 Feb 2024 15:48:49 +0800 Subject: [PATCH] [bugfix]fix bug of oneflow backend be stuck (#10435) --- .../framework/infer_compiler/__init__.py | 39 ++------------ .../framework/infer_compiler/with_fx_graph.py | 2 +- .../infer_compiler/with_oneflow_backend.py | 53 +++++++++++++++++++ 3 files changed, 59 insertions(+), 35 deletions(-) create mode 100644 python/oneflow/framework/infer_compiler/with_oneflow_backend.py diff --git a/python/oneflow/framework/infer_compiler/__init__.py b/python/oneflow/framework/infer_compiler/__init__.py index 524c33b2854..b1e1bbfca88 100644 --- a/python/oneflow/framework/infer_compiler/__init__.py +++ b/python/oneflow/framework/infer_compiler/__init__.py @@ -14,43 +14,14 @@ limitations under the License. """ -import os - -import oneflow as flow -from oneflow.framework.args_tree import ArgsTree +try: + import torch +except ImportError: + print("You should install torch also when use `oneflow.framework.infer_compiler`.") from .transform.custom_transform import register from .utils.patch_for_compiler import * from .with_fx_graph import fx_node_tranform from .with_fx_interpreter import OneFlowInterpreter from .with_oneflow_compile import compile_from_torch - - -def oneflow_backend(gm, example_inputs, *args, **kwargs): - with_interp = os.getenv( - "ONEDIFF_INFER_COMPILER_USE_INTERPRETER", "False" - ).lower() in ("true", "1", "t",) - if not with_interp: - transformed_fn = fx_node_tranform(gm) - - def wrapped_forward(*args, **kwargs): - def input_fn(value): - if isinstance(value, torch.Tensor): - return flow.utils.tensor.from_torch(value.contiguous()) - else: - return value - - args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor) - out = args_tree.map_leaf(input_fn) - args = out[0] - if with_interp: - output = OneFlowInterpreter(gm, garbage_collect_values=False).run( - *args, **kwargs - ) - else: - output = transformed_fn(*args, **kwargs) - if isinstance(output, tuple): - return tuple(flow.utils.tensor.to_torch(i) for i in output) - return flow.utils.tensor.to_torch(output) - - return wrapped_forward +from .with_oneflow_backend import oneflow_backend diff --git a/python/oneflow/framework/infer_compiler/with_fx_graph.py b/python/oneflow/framework/infer_compiler/with_fx_graph.py index 36f92cb23eb..881b720793c 100644 --- a/python/oneflow/framework/infer_compiler/with_fx_graph.py +++ b/python/oneflow/framework/infer_compiler/with_fx_graph.py @@ -46,7 +46,7 @@ def fx_node_tranform(gm): os.environ.setdefault("ONEFLOW_MLIR_FUSE_FORWARD_OPS", "1") os.environ.setdefault("ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL", "1") os.environ.setdefault("ONEFLOW_MLIR_GROUP_MATMUL", "1") - os.environ.setdefault("ONEFLOW_MLIR_PREFER_NHWC", "1") + os.environ.setdefault("ONEFLOW_MLIR_PREFER_NHWC", "0") os.environ.setdefault("ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS", "1") os.environ.setdefault("ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR", "1") os.environ.setdefault( diff --git a/python/oneflow/framework/infer_compiler/with_oneflow_backend.py b/python/oneflow/framework/infer_compiler/with_oneflow_backend.py new file mode 100644 index 00000000000..23fcb5aa684 --- /dev/null +++ b/python/oneflow/framework/infer_compiler/with_oneflow_backend.py @@ -0,0 +1,53 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import torch + +import oneflow as flow +from oneflow.framework.args_tree import ArgsTree +from .with_fx_graph import fx_node_tranform +from .with_fx_interpreter import OneFlowInterpreter + + +def oneflow_backend(gm, example_inputs, *args, **kwargs): + with_interp = os.getenv( + "ONEDIFF_INFER_COMPILER_USE_INTERPRETER", "False" + ).lower() in ("true", "1", "t",) + if not with_interp: + transformed_fn = fx_node_tranform(gm) + + def wrapped_forward(*args, **kwargs): + def input_fn(value): + if isinstance(value, torch.Tensor): + return flow.utils.tensor.from_torch(value.contiguous()) + else: + return value + + args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor) + out = args_tree.map_leaf(input_fn) + args = out[0] + if with_interp: + output = OneFlowInterpreter(gm, garbage_collect_values=False).run( + *args, **kwargs + ) + else: + output = transformed_fn(*args, **kwargs) + if isinstance(output, tuple): + return tuple(flow.utils.tensor.to_torch(i) for i in output) + return flow.utils.tensor.to_torch(output) + + return wrapped_forward