Feat: removed original_cat workaround (#902)
costigt-dev authored Mar 14, 2024
1 parent ce2c199 commit 0bcf34d
Showing 3 changed files with 7 additions and 31 deletions.
23 changes: 0 additions & 23 deletions src/brevitas/__init__.py
@@ -23,29 +23,6 @@
else:
    torch_version = version.parse(torch.__version__)

-original_cat = torch.cat
-if torch_version < version.parse('1.7.0'):
-    from torch._overrides import handle_torch_function
-    from torch._overrides import has_torch_function
-
-    @torch.jit.ignore
-    def unsupported_jit_cat(tensors, dim):
-        if not isinstance(tensors, (tuple, list)):
-            tensors = tuple(tensors)
-            return unsupported_jit_cat(tensors, dim)
-        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
-            return handle_torch_function(
-                original_cat, relevant_args=tensors, tensors=tensors, dim=dim)
-        else:
-            return original_cat(tensors=tensors, dim=dim)
-
-    def cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
-        if not torch.jit.is_scripting():
-            return unsupported_jit_cat(tensors, dim)
-        return original_cat(tensors, dim=dim)
-
-    torch.cat = cat
-
try:
    __version__ = get_distribution(__name__).version
except DistributionNotFound:
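
The block removed above monkey-patched torch.cat so that __torch_function__-style dispatch also worked on PyTorch older than 1.7.0, keeping the unpatched function around as brevitas.original_cat. From 1.7.0 onward, torch.cat participates in __torch_function__ dispatch natively, including over the elements of the tensor list, so neither the patch nor the original_cat alias is needed. A minimal sketch of the native dispatch path, using a hypothetical WrappedTensor rather than Brevitas's QuantTensor:

import torch

class WrappedTensor:
    """Hypothetical tensor-like wrapper, used only to illustrate dispatch."""

    def __init__(self, value: torch.Tensor):
        self.value = value

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.cat:
            # Unwrap the inputs, concatenate the underlying tensors, re-wrap.
            tensors = [t.value if isinstance(t, cls) else t for t in args[0]]
            return cls(torch.cat(tensors, *args[1:], **kwargs))
        return NotImplemented

# On PyTorch >= 1.7.0 this call reaches WrappedTensor.__torch_function__
# directly, without any monkey-patching of torch.cat.
out = torch.cat([WrappedTensor(torch.ones(2)), WrappedTensor(torch.zeros(2))])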
13 changes: 6 additions & 7 deletions src/brevitas/graph/quantize_impl.py
@@ -18,8 +18,6 @@

ADD_METHODS = ['add', 'add_']

-CAT = brevitas.original_cat

SIGN_PRESERVING_MODULES = (
nn.Dropout,
nn.Dropout2d,
@@ -87,7 +85,8 @@ def are_inputs_unsigned(model, node, is_unsigned_list, quant_act_map, unsigned_a
else:
is_unsigned_list.append(False)
elif inp_node.op == 'call_function':
-if inp_node.target in [torch.reshape, torch.flatten, torch.transpose, CAT] + ADD_FNS:
+if inp_node.target in [torch.reshape, torch.flatten, torch.transpose, torch.cat
+] + ADD_FNS:
are_inputs_unsigned(
model, inp_node, is_unsigned_list, quant_act_map, unsigned_act_tuple)
else:
@@ -141,7 +140,7 @@ def are_inputs_quantized_and_aligned(model, node, quantized_modules_list, quant_
if inp_node.target in [torch.reshape, torch.flatten, torch.transpose]:
are_inputs_quantized_and_aligned(
model, inp_node, quantized_modules_list, quant_act_map, same_sign)
-elif inp_node.target is CAT:
+elif inp_node.target is torch.cat:
are_inputs_quantized_and_aligned(
model, inp_node, quantized_modules_list, quant_act_map, True)
elif inp_node.target in ADD_FNS:
@@ -281,7 +280,7 @@ def recursive_input_handler(
quant_identity_map,
align_input_quant_fn,
align_sign)
-elif inp_node.op == 'call_function' and inp_node.target is CAT:
+elif inp_node.op == 'call_function' and inp_node.target is torch.cat:
recursive_input_handler(
model,
inp_node,
@@ -329,12 +328,12 @@ def residual_handler(
def is_converged(model):

for node in model.graph.nodes:
-if (node.op == 'call_function' and node.target in ADD_FNS + [CAT] or
+if (node.op == 'call_function' and node.target in ADD_FNS + [torch.cat] or
node.op == 'call_method' and node.target in ADD_METHODS):
rewriters = []
# If the op is CAT, check that inputs have same sign, and in recursive_input_handler
# force that the sign is aligned
-same_sign = node.target is CAT
+same_sign = node.target is torch.cat

# If input to the CAT or ADD node are quantized and aligned correctly, continue to
# the next node
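
With the monkey-patch gone, torch.cat is once again the plain function object that torch.fx records as the target of call_function nodes, so the identity checks above (node.target is torch.cat) can compare against it directly instead of against the saved brevitas.original_cat reference. A small, self-contained sketch of how such a node appears in an FX trace (illustrative code, not part of Brevitas):

import torch
import torch.fx as fx

class CatModel(torch.nn.Module):

    def forward(self, x, y):
        return torch.cat([x, y], dim=1)

graph_module = fx.symbolic_trace(CatModel())

for node in graph_module.graph.nodes:
    if node.op == 'call_function' and node.target is torch.cat:
        # FX stores the torch.cat function object itself as the target,
        # so an identity comparison is enough to find concatenation nodes.
        print('found cat node:', node.format_node())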
2 changes: 1 addition & 1 deletion src/brevitas/quant_tensor/torch_handler.py
@@ -41,7 +41,7 @@ def transpose_handler(inp, *args, **kwargs):
return inp.transpose(*args, **kwargs)


-@implements(brevitas.original_cat)
+@implements(torch.cat)
def cat_handler(*args, **kwargs):
from brevitas.quant_tensor import QuantTensor
return QuantTensor.cat(*args, **kwargs)
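
An @implements(...) decorator of this kind typically fills a dispatch table keyed by the torch function, which the wrapper's __torch_function__ then consults; registering cat_handler under torch.cat works now that torch.cat is no longer replaced at import time. A rough sketch of that registry pattern under those assumptions (hypothetical names, not the exact Brevitas implementation):

import torch

HANDLED_FUNCTIONS = {}

def implements(torch_function):
    """Register a handler for a torch function in the dispatch table."""

    def decorator(handler):
        HANDLED_FUNCTIONS[torch_function] = handler
        return handler

    return decorator

@implements(torch.cat)
def cat_handler(*args, **kwargs):
    # Placeholder body: a real handler would concatenate quantized values
    # while checking that scale and zero-point metadata are compatible.
    raise NotImplementedError

class DispatchingTensor:
    """Hypothetical wrapper that routes torch functions through the table."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **kwargs)
        return NotImplemented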