diff --git a/src/brevitas/__init__.py b/src/brevitas/__init__.py index eddc35a02..fe46102a7 100644 --- a/src/brevitas/__init__.py +++ b/src/brevitas/__init__.py @@ -23,6 +23,12 @@ else: torch_version = version.parse(torch.__version__) +try: + # Attempt _dynamo import + is_dynamo_compiling = torch._dynamo.is_compiling +except: + is_dynamo_compiling = lambda: False + try: __version__ = get_distribution(__name__).version except DistributionNotFound: diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 59b559787..a5c4407fd 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -16,6 +16,7 @@ from torch.nn.utils.rnn import PackedSequence from brevitas import config +from brevitas import is_dynamo_compiling from brevitas import torch_version from brevitas.common import ExportMixin from brevitas.inject import ExtendedInjector @@ -29,11 +30,6 @@ from .utils import filter_kwargs -if torch_version < packaging.version.parse('2.0'): - is_dynamo_compiling = lambda: False -else: - is_dynamo_compiling = torch._dynamo.is_compiling - class QuantProxyMixin(object): __metaclass__ = ABCMeta diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index dc7c704c3..f28233aed 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -7,23 +7,15 @@ from typing import Any, List, Optional, Tuple, Union from warnings import warn -import packaging.version import torch - -from brevitas import torch_version -from brevitas.core.function_wrapper.misc import Identity - -if torch_version < packaging.version.parse('2.0'): - is_dynamo_compiling = lambda: False -else: - is_dynamo_compiling = torch._dynamo.is_compiling - from torch import Tensor import torch.nn as nn from typing_extensions import Protocol from typing_extensions import runtime_checkable from brevitas import config +from brevitas import is_dynamo_compiling +from brevitas.core.function_wrapper.misc import Identity from brevitas.function import max_int from brevitas.inject import BaseInjector as Injector from brevitas.quant_tensor import _unpack_quant_tensor diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index b2ded7239..9feb593b4 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -5,15 +5,7 @@ from abc import abstractmethod from typing import Any, Optional, Tuple, Union -import packaging.version import torch - -from brevitas import torch_version - -if torch_version < packaging.version.parse('2.0'): - is_dynamo_compiling = lambda: False -else: - is_dynamo_compiling = torch._dynamo.is_compiling from torch import nn from torch import Tensor from torch.nn import Identity @@ -21,6 +13,7 @@ from typing_extensions import runtime_checkable import brevitas +from brevitas import is_dynamo_compiling from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO