Skip to content

Commit

Permalink
Fix import
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Sep 13, 2024
1 parent 3dc7dfa commit 04231e4
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 23 deletions.
6 changes: 6 additions & 0 deletions src/brevitas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
else:
torch_version = version.parse(torch.__version__)

try:
# Attempt _dynamo import
is_dynamo_compiling = torch._dynamo.is_compiling
except:
is_dynamo_compiling = lambda: False

try:
__version__ = get_distribution(__name__).version
except DistributionNotFound:
Expand Down
6 changes: 1 addition & 5 deletions src/brevitas/nn/mixin/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch.nn.utils.rnn import PackedSequence

from brevitas import config
from brevitas import is_dynamo_compiling
from brevitas import torch_version
from brevitas.common import ExportMixin
from brevitas.inject import ExtendedInjector
Expand All @@ -29,11 +30,6 @@

from .utils import filter_kwargs

if torch_version < packaging.version.parse('2.0'):
is_dynamo_compiling = lambda: False
else:
is_dynamo_compiling = torch._dynamo.is_compiling


class QuantProxyMixin(object):
__metaclass__ = ABCMeta
Expand Down
12 changes: 2 additions & 10 deletions src/brevitas/proxy/parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,15 @@
from typing import Any, List, Optional, Tuple, Union
from warnings import warn

import packaging.version
import torch

from brevitas import torch_version
from brevitas.core.function_wrapper.misc import Identity

if torch_version < packaging.version.parse('2.0'):
is_dynamo_compiling = lambda: False
else:
is_dynamo_compiling = torch._dynamo.is_compiling

from torch import Tensor
import torch.nn as nn
from typing_extensions import Protocol
from typing_extensions import runtime_checkable

from brevitas import config
from brevitas import is_dynamo_compiling
from brevitas.core.function_wrapper.misc import Identity
from brevitas.function import max_int
from brevitas.inject import BaseInjector as Injector
from brevitas.quant_tensor import _unpack_quant_tensor
Expand Down
9 changes: 1 addition & 8 deletions src/brevitas/proxy/runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,15 @@
from abc import abstractmethod
from typing import Any, Optional, Tuple, Union

import packaging.version
import torch

from brevitas import torch_version

if torch_version < packaging.version.parse('2.0'):
is_dynamo_compiling = lambda: False
else:
is_dynamo_compiling = torch._dynamo.is_compiling
from torch import nn
from torch import Tensor
from torch.nn import Identity
from typing_extensions import Protocol
from typing_extensions import runtime_checkable

import brevitas
from brevitas import is_dynamo_compiling
from brevitas.quant_tensor import IntQuantTensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIO
Expand Down

0 comments on commit 04231e4

Please sign in to comment.