Skip to content

Commit

Permalink
Fix (jit): remove patcher
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 15, 2023
1 parent 966de12 commit 66d4679
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 36 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/export/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from brevitas.proxy.quant_proxy import QuantProxyProtocol
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.jit_utils import clear_class_registry
from brevitas.utils.jit_utils import jit_patches_generator
# from brevitas.utils.jit_utils import jit_patches_generator
from brevitas.utils.python_utils import patch


Expand Down Expand Up @@ -162,7 +162,7 @@ class BaseManager(ABC):

target_name = None
handlers = []
_base_trace_patches_generator = jit_patches_generator
_base_trace_patches_generator = None # jit_patches_generator
_fn_to_cache = []
_fn_cache = []
_cached_io_handler_map = {}
Expand Down
10 changes: 5 additions & 5 deletions src/brevitas/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from brevitas.config import JIT_ENABLED

IS_ABOVE_110 = version.parse(torch.__version__) > version.parse('1.1.0')
# IS_ABOVE_110 = version.parse(torch.__version__) > version.parse('1.1.0')


def _disabled(fn):
Expand All @@ -20,10 +20,10 @@ def _disabled(fn):
ScriptModule = torch.jit.ScriptModule
Attribute = torch.jit.Attribute

if not IS_ABOVE_110:
script_method_110_disabled = _disabled
else:
script_method_110_disabled = script_method
script_method_110_disabled = script_method
# script_method_110_disabled = _disabled
# if not IS_ABOVE_110:
# else:

else:

Expand Down
42 changes: 20 additions & 22 deletions src/brevitas/utils/jit_utils.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,36 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import inspect
# import inspect

import torch

try:
from torch._jit_internal import get_torchscript_modifier
except:
get_torchscript_modifier = None

from dependencies import Injector
# from dependencies import Injector
from packaging import version
import torch

from brevitas import torch_version
from brevitas.inject import ExtendedInjector
from brevitas.jit import IS_ABOVE_110

from .python_utils import patch
# try:
# from torch._jit_internal import get_torchscript_modifier
# except:
# get_torchscript_modifier = None

# from brevitas.inject import ExtendedInjector
# from brevitas.jit import IS_ABOVE_110

def _get_modifier_wrapper(fn):
if inspect.isclass(fn) and issubclass(fn, (Injector, ExtendedInjector)):
return None
else:
return get_torchscript_modifier(fn)
# from .python_utils import patch

# def _get_modifier_wrapper(fn):
# if inspect.isclass(fn) and issubclass(fn, (Injector, ExtendedInjector)):
# return None
# else:
# return get_torchscript_modifier(fn)

if IS_ABOVE_110:
# if IS_ABOVE_110:

def jit_patches_generator():
return [patch(torch._jit_internal, 'get_torchscript_modifier', _get_modifier_wrapper)]
else:
jit_patches_generator = None
# def jit_patches_generator():
# return [patch(torch._jit_internal, 'get_torchscript_modifier', _get_modifier_wrapper)]
# else:
# jit_patches_generator = None


def clear_class_registry():
Expand Down
7 changes: 0 additions & 7 deletions tests/brevitas_examples/test_jit_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest
import torch

from brevitas.utils.jit_utils import jit_patches_generator
from brevitas_examples.bnn_pynq.models import model_with_cfg

FC_INPUT_SIZE = (1, 1, 28, 28)
Expand All @@ -28,9 +27,6 @@ def test_brevitas_fc_jit_trace(size, wbits, abits):
fc, _ = model_with_cfg(nname.lower(), pretrained=False)
fc.train(False)
input_tensor = torch.randn(FC_INPUT_SIZE)
with ExitStack() as stack:
for mgr in jit_patches_generator():
stack.enter_context(mgr)
traced_model = torch.jit.trace(fc, input_tensor)
out_traced = traced_model(input_tensor)
out = fc(input_tensor)
Expand All @@ -46,9 +42,6 @@ def test_brevitas_cnv_jit_trace(wbits, abits):
cnv, _ = model_with_cfg(nname.lower(), pretrained=False)
cnv.train(False)
input_tensor = torch.randn(CNV_INPUT_SIZE)
with ExitStack() as stack:
for mgr in jit_patches_generator():
stack.enter_context(mgr)
traced_model = torch.jit.trace(cnv, input_tensor)
out_traced = traced_model(input_tensor)
out = cnv(input_tensor)
Expand Down

0 comments on commit 66d4679

Please sign in to comment.