Skip to content

Commit

Permalink
Remove non compatible JIT modules
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Feb 1, 2024
1 parent c3faec7 commit be0465f
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/brevitas_examples/common/generative/quant_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor):
return zero_point


class RuntimeDynamicStatsScaling(brevitas.jit.ScriptModule):
# TODO: restore JIT compatibility
class RuntimeDynamicStatsScaling(nn.Module):

def __init__(
self,
Expand All @@ -101,7 +102,6 @@ def __init__(
self.stats_impl = scaling_stats_impl
self.dynamic_scaling_broadcastable_fn = dynamic_scaling_broadcastable_fn

@brevitas.jit.script_method
def forward(self, x) -> Tensor:
shape = x.shape
x = self.scaling_stats_input_view_shape_impl(x)
Expand All @@ -110,7 +110,8 @@ def forward(self, x) -> Tensor:
return x


class RuntimeDynamicStatsZeroPoint(brevitas.jit.ScriptModule):
# TODO: restore JIT compatibility
class RuntimeDynamicStatsZeroPoint(nn.Module):

def __init__(
self,
Expand All @@ -125,7 +126,6 @@ def __init__(
self.dynamic_scaling_broadcastable_fn = dynamic_scaling_broadcastable_fn
self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)

@brevitas.jit.script_method
def forward(self, x, scale, bit_width) -> Tensor:
shape = x.shape
x = self.zero_point_stats_input_view_shape_impl(x)
Expand Down

0 comments on commit be0465f

Please sign in to comment.