Skip to content

Commit

Permalink
Feat: Po2 scale support
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 2, 2024
1 parent 0bcea97 commit ee157bc
Show file tree
Hide file tree
Showing 28 changed files with 140 additions and 64 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/base.yml.template
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/base_reduced.yml.template
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/develop_install.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/end_to_end.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/examples_llm_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/examples_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/finn_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/notebook.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ort_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reduced_develop_install.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reduced_end_to_end.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reduced_examples_llm_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reduced_examples_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reduced_finn_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reduced_notebook.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reduced_ort_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reduced_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
steps:

- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v2
Expand Down
1 change: 1 addition & 0 deletions src/brevitas/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ def env_to_bool(name, default):
_FULL_STATE_DICT = False
_IS_INSIDE_QUANT_LAYER = None
_ONGOING_EXPORT = None
_RETROCOMPATIBLE_SCALING = False
5 changes: 3 additions & 2 deletions src/brevitas/core/quant/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,13 @@ def __init__(

@brevitas.jit.script_method
def quantize(self, x: torch.Tensor):
scale = self.scaling_impl(x)

if self.float_scaling_impl is not None:
float_scaling_impl_value = self.float_scaling_impl(
self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
scale = scale / float_scaling_impl_value
else:
float_scaling_impl_value = None
scale = self.scaling_impl(x, float_scaling_impl_value)
x = self.input_view_impl(x)
scaled_x = x / scale
internal_scale = float_internal_scale(
Expand Down
9 changes: 3 additions & 6 deletions src/brevitas/core/quant/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,8 @@ def __init__(
@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
bit_width = self.msb_clamp_bit_width_impl()
threshold = self.scaling_impl(x)
int_threshold = self.int_scaling_impl(bit_width)
scale = threshold / int_threshold
scale = self.scaling_impl(x, int_threshold)
zero_point = self.zero_point_impl(x, scale, bit_width)
y = self.int_quant(scale, zero_point, bit_width, x)
return y, scale, zero_point, bit_width
Expand Down Expand Up @@ -184,8 +183,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Te
pre_threshold = self.pre_scaling_impl(x)
pre_scale = pre_threshold / int_threshold
pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
threshold = self.scaling_impl(x)
scale = threshold / int_threshold
scale = self.scaling_impl(x, int_threshold)
zero_point = self.zero_point_impl(x, scale, bit_width)
y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
Expand Down Expand Up @@ -250,8 +248,7 @@ def forward(self, x: Tensor, input_bit_width: Tensor,
pre_threshold = self.pre_scaling_impl(x, input_bit_width, input_is_signed)
pre_scale = pre_threshold / int_threshold
pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
threshold = self.scaling_impl(x)
scale = threshold / int_threshold
scale = self.scaling_impl(x, int_threshold)
zero_point = self.zero_point_impl(x, scale, bit_width)
y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
12 changes: 12 additions & 0 deletions src/brevitas/core/restrict_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def restrict_init_module(self):
def restrict_init_inplace_module(self):
return Identity()

def retrocompatibility_op(self, x):
return x

@brevitas.jit.script_method
def forward(self, x: torch.Tensor) -> Tensor:
return x
Expand All @@ -113,6 +116,9 @@ def restrict_init_module(self):
def restrict_init_inplace_module(self):
return InplaceLogTwo()

def retrocompatibility_op(self, x):
return self.power_of_two(x)

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
x = self.power_of_two(x)
Expand All @@ -137,6 +143,9 @@ def restrict_init_module(self):
def restrict_init_inplace_module(self):
return Identity()

def retrocompatibility_op(self, x):
return x

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
x = self.float_to_int_impl(x)
Expand All @@ -162,6 +171,9 @@ def restrict_init_module(self):
def restrict_init_inplace_module(self):
return InplaceLogTwo()

def retrocompatibility_op(self, x):
return self.power_of_two(x)

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
x = self.float_to_int_impl(x)
Expand Down
28 changes: 19 additions & 9 deletions src/brevitas/core/scaling/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,12 @@ def __init__(
dtype,
device)

@brevitas.jit.script_method
def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor:
def forward(
self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
stats = self.parameter_list_stats(x)
return self.stats_scaling_impl(stats)
if threshold is None:
threshold = torch.ones(1).type_as(stats)
return self.stats_scaling_impl(stats, threshold)


class _StatsScaling(brevitas.jit.ScriptModule):
Expand All @@ -80,8 +82,11 @@ def __init__(
self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()

@brevitas.jit.script_method
def forward(self, stats: torch.Tensor) -> torch.Tensor:
stats = self.restrict_scaling_pre(stats)
def forward(
self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
if threshold is None:
threshold = torch.ones(1).type_as(stats)
stats = self.restrict_scaling_pre(stats / threshold)
stats = self.affine_rescaling(stats)
stats = self.restrict_clamp_scaling(stats)
return stats
Expand Down Expand Up @@ -120,9 +125,9 @@ def __init__(
device)

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
stats = self.runtime_stats(x)
return self.stats_scaling_impl(stats)
return self.stats_scaling_impl(stats, threshold)


class _AffineRescaling(brevitas.jit.ScriptModule):
Expand Down Expand Up @@ -179,9 +184,14 @@ def __init__(
self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)

@brevitas.jit.script_method
def forward(self, stats_input) -> torch.Tensor:
def forward(
self,
stats_input: torch.Tensor,
threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
if threshold is None:
threshold = torch.ones(1).type_as(stats_input)
stats_input_reshaped = self.input_view_impl(stats_input)
out = self.scaling_stats_impl(stats_input_reshaped)
out = self.scaling_stats_impl(stats_input_reshaped) / threshold
# Scaling min val
out = self.restrict_clamp_scaling(out)
return out
Loading

0 comments on commit ee157bc

Please sign in to comment.