Fix for JIT, typing, default arg
Giuseppe5 committed Oct 8, 2024
1 parent d3b520f commit f54192c
Showing 3 changed files with 43 additions and 51 deletions.
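Two signature changes below are worth calling out before the diffs. First, combine_stats_threshold is renamed to combine_scale_threshold and gains Tensor annotations. Second, restrict_scaling_impl now defaults to FloatRestrictValue() and moves after scaling_shape in several constructors, which silently changes the positional argument order. A toy sketch of the breakage mode (make_scaling_old and make_scaling_new are hypothetical stand-ins, not Brevitas APIs):

from torch.nn import Identity, Module


def make_scaling_old(restrict_scaling_impl: Module, scaling_shape: tuple):
    return restrict_scaling_impl, scaling_shape


def make_scaling_new(scaling_shape: tuple, restrict_scaling_impl: Module = Identity()):
    return restrict_scaling_impl, scaling_shape


impl, shape = make_scaling_old(Identity(), (1,))
# The same positional call against the new signature would bind the module to
# scaling_shape, so after the reorder the restrictor is passed by keyword
# (or omitted to take the new default):
impl, shape = make_scaling_new((1,), restrict_scaling_impl=Identity())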
28 changes: 14 additions & 14 deletions src/brevitas/core/restrict_val.py
@@ -36,7 +36,7 @@ def __init__(self, scaling_min_val: Optional[float], restrict_value_impl: Optional[Module]):
             self.restrict_value_impl = Identity()

     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.restrict_value_impl(x)
         x = self.clamp_min_ste(x)
         return x
@@ -52,7 +52,7 @@ def __init__(self, restrict_value_impl: Optional[Module]):
             self.restrict_value_impl = Identity()

     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.restrict_value_impl(x)
         return x

@@ -68,7 +68,7 @@ def __init__(self, scaling_min_val: Optional[float]):
         self.min_val = scaling_min_val

     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.clamp_min_ste(x)
         return x

@@ -90,11 +90,11 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()

-    def combine_stats_threshold(self, x, threshold):
+    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
         return x / threshold

     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor) -> Tensor:
+    def forward(self, x: Tensor) -> Tensor:
         return x


@@ -107,7 +107,7 @@ def __init__(self):
     def restrict_init_float(self, x: float):
         return math.log2(x)

-    def restrict_init_tensor(self, x: torch.Tensor):
+    def restrict_init_tensor(self, x: Tensor):
         return torch.log2(x)

     def restrict_init_module(self):
@@ -116,11 +116,11 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()

-    def combine_stats_threshold(self, x, threshold):
+    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
         return x - threshold

     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.power_of_two(x)
         return x

@@ -134,7 +134,7 @@ def __init__(self, restrict_value_float_to_int_impl: Module = RoundSte()):
     def restrict_init_float(self, x: float):
         return x

-    def restrict_init_tensor(self, x: torch.Tensor):
+    def restrict_init_tensor(self, x: Tensor):
         return x

     def restrict_init_module(self):
@@ -143,11 +143,11 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()

-    def combine_stats_threshold(self, x, threshold):
+    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
         return x / threshold

     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.float_to_int_impl(x)
         return x

@@ -162,7 +162,7 @@ def __init__(self, restrict_value_float_to_int_impl: Module = RoundSte()):
     def restrict_init_float(self, x: float):
         return math.log2(x)

-    def restrict_init_tensor(self, x: torch.Tensor):
+    def restrict_init_tensor(self, x: Tensor):
         return torch.log2(x)

     def restrict_init_module(self):
@@ -171,11 +171,11 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()

-    def combine_stats_threshold(self, x, threshold):
+    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
         return x - threshold

     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.float_to_int_impl(x)
         x = self.power_of_two(x)
         return x
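The renamed hook makes the domain split explicit: the linear-domain restrictors divide the scale by the threshold, while the log2-domain restrictors subtract, since log2(s / t) = log2(s) - log2(t). A sketch exercising both paths, assuming FloatRestrictValue and LogFloatRestrictValue are importable from brevitas.core.restrict_val and constructible with no arguments, as their definitions in this file suggest:

import torch

from brevitas.core.restrict_val import FloatRestrictValue
from brevitas.core.restrict_val import LogFloatRestrictValue

scale = torch.tensor(0.5)
threshold = torch.tensor(4.0)

# Linear domain: combine by division.
linear = FloatRestrictValue().combine_scale_threshold(scale, threshold)

# Log2 domain: operands are expected to be log2-preprocessed already,
# so the combination is a subtraction.
log2_combined = LogFloatRestrictValue().combine_scale_threshold(
    torch.log2(scale), torch.log2(threshold))

# Both routes agree once mapped back to the linear domain.
assert torch.allclose(linear, 2.0 ** log2_combined)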
21 changes: 11 additions & 10 deletions src/brevitas/core/scaling/runtime.py
@@ -11,6 +11,7 @@
 import brevitas.config as config
 from brevitas.core.function_wrapper import Identity
 from brevitas.core.restrict_val import _RestrictClampValue
+from brevitas.core.restrict_val import FloatRestrictValue
 from brevitas.core.stats import _ParameterListStats
 from brevitas.core.stats import _RuntimeStats
 from brevitas.core.stats import DEFAULT_MOMENTUM
@@ -27,8 +28,8 @@ def __init__(
             scaling_stats_input_view_shape_impl: Module,
             scaling_stats_input_concat_dim: int,
             tracked_parameter_list: List[torch.nn.Parameter],
-            restrict_scaling_impl: Module,
             scaling_shape: Tuple[int, ...],
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             affine_rescaling: bool = False,
             affine_shift_scale: bool = False,
             scaling_min_val: Optional[float] = None,
@@ -90,7 +91,7 @@ def forward(
             threshold = torch.ones(1).type_as(stats)
         threshold = self.restrict_scaling_pre(threshold)
         stats = self.restrict_scaling_pre(stats)
-        stats = self.restrict_scaling_impl.combine_stats_threshold(stats, threshold)
+        stats = self.restrict_scaling_impl.combine_scale_threshold(stats, threshold)
         stats = self.affine_rescaling(stats)
         stats = self.restrict_clamp_scaling(stats)
         return stats
@@ -102,10 +103,10 @@ def __init__(
             self,
             scaling_stats_impl: Module,
             scaling_stats_input_view_shape_impl: Module,
-            restrict_scaling_impl: Module,
             scaling_shape: Tuple[int, ...],
             affine_rescaling: bool = False,
             affine_shift_scale: bool = False,
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             scaling_stats_momentum: float = DEFAULT_MOMENTUM,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
@@ -172,13 +173,13 @@ def _load_from_state_dict(
 class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule):

     def __init__(
-        self,
-        group_size: int,
-        group_dim: int,
-        input_view_impl: torch.nn.Module,
-        scaling_stats_impl: torch.nn.Module,
-        scaling_min_val: Optional[float],
-        restrict_scaling_impl: Optional[torch.nn.Module]) -> None:
+            self,
+            group_size: int,
+            group_dim: int,
+            input_view_impl: Module,
+            scaling_stats_impl: Module,
+            scaling_min_val: Optional[float],
+            restrict_scaling_impl: Module = FloatRestrictValue()) -> None:
         super(RuntimeDynamicGroupStatsScaling, self).__init__()
         self.group_size = group_size
         self.group_dim = group_dim
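A Python detail of the new default worth noting: a default such as restrict_scaling_impl: Module = FloatRestrictValue() is evaluated once, when the def statement runs, so every construction that omits the argument receives the same FloatRestrictValue instance. That looks safe here because the float restrictor carries no state, as its restrict_val.py definition above shows. A toy demonstration (Scaler is hypothetical):

from torch.nn import Module

from brevitas.core.restrict_val import FloatRestrictValue


class Scaler:  # hypothetical stand-in for the scaling modules in this diff

    def __init__(self, restrict_scaling_impl: Module = FloatRestrictValue()) -> None:
        self.restrict_scaling_impl = restrict_scaling_impl


a = Scaler()
b = Scaler()
assert a.restrict_scaling_impl is b.restrict_scaling_impl  # one shared default instance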
45 changes: 18 additions & 27 deletions src/brevitas/core/scaling/standalone.py
@@ -15,6 +15,7 @@
 from brevitas.core.restrict_val import _ClampValue
 from brevitas.core.restrict_val import _RestrictClampValue
 from brevitas.core.restrict_val import _RestrictValue
+from brevitas.core.restrict_val import FloatRestrictValue
 from brevitas.core.scaling.runtime import _StatsScaling
 from brevitas.core.stats import _ParameterListStats
 from brevitas.core.stats import _Stats
@@ -83,7 +84,7 @@ def __init__(
         self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device))

     @brevitas.jit.script_method
-    def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+    def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(placeholder)
         # We first apply any restriction to scaling
@@ -163,7 +164,7 @@ def __init__(
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)

     @brevitas.jit.script_method
-    def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+    def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(placeholder)
         # We first apply any restriction to scaling
@@ -197,8 +198,8 @@ def __init__(
             scaling_stats_input_view_shape_impl: Module,
             scaling_stats_input_concat_dim: int,
             tracked_parameter_list: List[torch.nn.Parameter],
-            restrict_scaling_impl: Module,
             scaling_shape: Tuple[int, ...],
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
@@ -214,26 +215,20 @@ def __init__(
             restrict_scaling_impl, scaling_shape, scaling_min_val, False, False, dtype, device)
         self.init_done: bool = brevitas.jit.Attribute(False, bool)
         self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool)
-        if restrict_scaling_impl is not None:
-            self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
-            self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
-        else:
-            self.restrict_inplace_preprocess = Identity()
-            self.restrict_preprocess = Identity()
-
+        self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
+        self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
         self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))

     @brevitas.jit.script_method
-    def forward(
-            self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(ignored)
         # Threshold division must happen after we update self.value, but before we apply restrict_preprocess
         # This is because we don't want to store a parameter dependent on a runtime value (threshold)
         # And because restrict needs to happen after we divide by threshold
         if self.init_done:
-            value = self.restrict_scaling_impl.combine_stats_threshold(
-                self.value, self.restrict_inplace_preprocess(threshold))
+            threshold = self.restrict_inplace_preprocess(threshold)
+            value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
             value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
             return value
         else:
@@ -245,7 +240,7 @@ def forward(
             stats = self.restrict_inplace_preprocess(stats)
             threshold = self.restrict_inplace_preprocess(threshold)
             inplace_tensor_mul(self.value.detach(), stats)
-            value = self.restrict_scaling_impl.combine_stats_threshold(stats, threshold)
+            value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
             value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
             self.init_done = True
             return value
@@ -323,7 +318,7 @@ def __init__(
             scaling_stats_impl: Module,
             scaling_stats_input_view_shape_impl: Module = OverBatchOverTensorView(),
             scaling_shape: Tuple[int, ...] = SCALAR_SHAPE,
-            restrict_scaling_impl: Optional[Module] = None,
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             scaling_stats_momentum: Optional[float] = DEFAULT_MOMENTUM,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
@@ -343,15 +338,11 @@ def __init__(
         self.clamp_scaling = _ClampValue(scaling_min_val)
         self.local_loss_mode: bool = brevitas.jit.Attribute(
             False, bool)  # required to support MSE eval or variants
-        if restrict_scaling_impl is not None:
-            self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
-            self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
-        else:
-            self.restrict_inplace_preprocess = Identity()
-            self.restrict_preprocess = Identity()
+        self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
+        self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()

     @brevitas.jit.script_method
-    def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor:
+    def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
         # Threshold division must happen after we update self.value, but before we apply restrict_preprocess
         # This is because we don't want to store a parameter dependent on a runtime value (threshold)
         # And because restrict needs to happen after we divide by threshold
@@ -377,16 +368,16 @@ def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor:
                 self.restrict_inplace_preprocess(self.buffer)
                 inplace_tensor_mul(self.value.detach(), self.buffer)
                 threshold = self.restrict_preprocess(threshold)
-                value = self.restrict_scaling_impl.combine_stats_threshold(self.value, threshold)
+                value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
                 self.counter = self.counter + 1
                 return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
             else:
                 threshold = self.restrict_preprocess(threshold)
-                value = self.restrict_scaling_impl.combine_stats_threshold(self.value, threshold)
+                value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
                 return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))

     @brevitas.jit.script_method
-    def forward(self, stats_input: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+    def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(stats_input)
         if self.training:
@@ -398,7 +389,7 @@ def forward(self, stats_input: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
                 out = self.restrict_preprocess(out)
             else:
                 threshold = self.restrict_preprocess(threshold)
-                out = self.restrict_scaling_impl.combine_stats_threshold(self.value, threshold)
+                out = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
             out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out)))
         return out
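All of the reworked forward signatures follow the same TorchScript-friendly pattern for the optional threshold: annotate it as Optional[Tensor], then narrow it with an explicit None check so the rest of the method sees a plain Tensor, with type_as keeping dtype and device aligned with the input. A standalone sketch of the pattern (ThresholdedScale is hypothetical, scripted with stock torch.jit rather than brevitas.jit):

from typing import Optional

import torch
from torch import Tensor


class ThresholdedScale(torch.nn.Module):

    def forward(self, x: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
        if threshold is None:
            # The None check lets TorchScript refine Optional[Tensor] to Tensor;
            # type_as matches the dtype and device of the input.
            threshold = torch.ones(1).type_as(x)
        return x / threshold


scripted = torch.jit.script(ThresholdedScale())
assert torch.allclose(scripted(torch.tensor([3.0])), torch.tensor([3.0]))
assert torch.allclose(scripted(torch.tensor([3.0]), torch.tensor(2.0)), torch.tensor([1.5]))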
