Skip to content

Commit

Permalink
Feat (QuantTensor): QuantTensor x Tensor elementary ops dequantize to…
Browse files Browse the repository at this point in the history
… Tensor (#668)
  • Loading branch information
volcacius authored Nov 10, 2023
1 parent 32186be commit 98213ab
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/brevitas/quant_tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def cat(tensors, dim, out=None):
else:
tensors = [qt.value if isinstance(qt, QuantTensor) else qt for qt in tensors]
output_value = torch.cat(tensors, dim=dim)
return QuantTensor(output_value)
return output_value

# Reference: https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types

Expand Down Expand Up @@ -366,9 +366,9 @@ def __add__(self, other):
signed=output_signed,
training=output_training)
elif isinstance(other, QuantTensor):
output = QuantTensor(self.value + other.value)
output = self.value + other.value
else:
output = QuantTensor(self.value + other)
output = self.value + other
return output

def __radd__(self, other):
Expand Down Expand Up @@ -396,9 +396,9 @@ def __mul__(self, other):
signed=output_signed,
training=output_training)
elif isinstance(other, QuantTensor):
output = QuantTensor(self.value * other.value)
output = self.value * other.value
else:
output = QuantTensor(self.value * other)
output = self.value * other
return output

def __sub__(self, other):
Expand All @@ -423,9 +423,9 @@ def __truediv__(self, other):
signed=output_signed,
training=output_training)
elif isinstance(other, QuantTensor):
output = QuantTensor(self.value / other.value)
output = self.value / other.value
else:
output = QuantTensor(self.value / other)
output = self.value / other
return output

def __abs__(self):
Expand Down

0 comments on commit 98213ab

Please sign in to comment.