From 98213ab5c7854cfa6b9ec9479776e498fec3bf1b Mon Sep 17 00:00:00 2001 From: Alessandro Pappalardo <1934033+volcacius@users.noreply.github.com> Date: Fri, 10 Nov 2023 15:38:38 +0000 Subject: [PATCH] Feat (QuantTensor): QuantTensor x Tensor elementary ops dequantize to Tensor (#668) --- src/brevitas/quant_tensor/__init__.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index 92bfbe39e..05593f5a3 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -295,7 +295,7 @@ def cat(tensors, dim, out=None): else: tensors = [qt.value if isinstance(qt, QuantTensor) else qt for qt in tensors] output_value = torch.cat(tensors, dim=dim) - return QuantTensor(output_value) + return output_value # Reference: https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types @@ -366,9 +366,9 @@ def __add__(self, other): signed=output_signed, training=output_training) elif isinstance(other, QuantTensor): - output = QuantTensor(self.value + other.value) + output = self.value + other.value else: - output = QuantTensor(self.value + other) + output = self.value + other return output def __radd__(self, other): @@ -396,9 +396,9 @@ def __mul__(self, other): signed=output_signed, training=output_training) elif isinstance(other, QuantTensor): - output = QuantTensor(self.value * other.value) + output = self.value * other.value else: - output = QuantTensor(self.value * other) + output = self.value * other return output def __sub__(self, other): @@ -423,9 +423,9 @@ def __truediv__(self, other): signed=output_signed, training=output_training) elif isinstance(other, QuantTensor): - output = QuantTensor(self.value / other.value) + output = self.value / other.value else: - output = QuantTensor(self.value / other) + output = self.value / other return output def __abs__(self):