
Commit

Pre-commit fixes
i-colbert committed Sep 13, 2023
1 parent 8dde099 commit e9de10f
Showing 5 changed files with 13 additions and 28 deletions.
1 change: 1 addition & 0 deletions src/brevitas_examples/super_resolution/models/__init__.py
@@ -3,6 +3,7 @@

 from functools import partial
 from typing import Union
+
 from torch import hub
 import torch.nn as nn
2 changes: 1 addition & 1 deletion src/brevitas_examples/super_resolution/models/common.py
@@ -24,7 +24,7 @@ class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):
     scaling_per_output_channel = True


-class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant):
+class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant):
     pre_scaling_min_val = 1e-10
     scaling_min_val = 1e-10

(The removed and added lines render identically; the difference is whitespace-only, presumably trailing whitespace stripped by the pre-commit hooks.)
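For reference, quantizers like the one touched above plug into Brevitas layers through the weight_quant argument. A minimal sketch of that wiring, assuming Int8AccumulatorAwareWeightQuant is importable from brevitas.quant (in this repo the quantizer actually reaches the model via models/common.py):

import brevitas.nn as qnn
from brevitas.quant import Int8AccumulatorAwareWeightQuant  # import path is an assumption


class DemoA2QWeightQuant(Int8AccumulatorAwareWeightQuant):
    # Mirrors CommonIntAccumulatorAwareWeightQuant: keep learned scales away from zero.
    pre_scaling_min_val = 1e-10
    scaling_min_val = 1e-10


# Hypothetical layer, just to show where the quantizer is consumed.
conv = qnn.QuantConv2d(
    in_channels=64, out_channels=64, kernel_size=3, weight_quant=DemoA2QWeightQuant)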
25 changes: 5 additions & 20 deletions src/brevitas_examples/super_resolution/models/espcn.py
@@ -14,12 +14,7 @@

 from .common import ConstUint8ActQuant

 __all__ = [
-    "float_espcn",
-    "quant_espcn",
-    "quant_espcn_a2q",
-    "quant_espcn_base",
-    "FloatESPCN",
-    "QuantESPCN"]
+    "float_espcn", "quant_espcn", "quant_espcn_a2q", "quant_espcn_base", "FloatESPCN", "QuantESPCN"]

 IO_DATA_BIT_WIDTH = 8
 IO_ACC_BIT_WIDTH = 32

@@ -47,19 +42,9 @@ def __init__(self, upscale_factor: int = 3, num_channels: int = 3):
             padding=2,
             bias=True)
         self.conv2 = nn.Conv2d(
-            in_channels=64,
-            out_channels=64,
-            kernel_size=3,
-            stride=1,
-            padding=1,
-            bias=True)
+            in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
         self.conv3 = nn.Conv2d(
-            in_channels=64,
-            out_channels=32,
-            kernel_size=3,
-            stride=1,
-            padding=1,
-            bias=True)
+            in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, bias=True)
         self.conv4 = nn.Conv2d(
             in_channels=32,
             out_channels=num_channels * pow(upscale_factor, 2),

@@ -85,7 +70,7 @@ def forward(self, inp: Tensor):
         x = self.relu(self.bn2(self.conv2(x)))
         x = self.relu(self.bn3(self.conv3(x)))
         x = self.pixel_shuffle(self.conv4(x))
-        x = self.out(x) # To mirror quant version
+        x = self.out(x)  # To mirror quant version
         return x

@@ -145,7 +130,7 @@ def __init__(
             weight_quant=weight_quant)
         # We quantize the weights and input activations of the final layer
         # to 8-bit integers. We do not apply the accumulator constraint to
-        # the final convolution layer. FINN does not currently support
+        # the final convolution layer. FINN does not currently support
        # per-tensor quantization or biases for sub-pixel convolution layers.
         self.conv4 = qnn.QuantConv2d(
             in_channels=32,

(In the last hunk the removed and added comment lines render identically; as in common.py, the change is whitespace-only.)
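The "sub-pixel convolution" mentioned in the comment above is the conv4 + PixelShuffle pair: the convolution emits num_channels * upscale_factor**2 feature maps, and the shuffle trades those channels for spatial resolution. A float-only sketch with illustrative shapes (the input size here is made up):

import torch
import torch.nn as nn

upscale_factor, num_channels = 3, 3
# (N, C*r^2, H, W) -> (N, C, H*r, W*r) after the shuffle.
conv4 = nn.Conv2d(32, num_channels * upscale_factor ** 2, kernel_size=3, padding=1)
pixel_shuffle = nn.PixelShuffle(upscale_factor)

x = torch.randn(1, 32, 85, 85)
y = pixel_shuffle(conv4(x))
print(y.shape)  # torch.Size([1, 3, 255, 255])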
10 changes: 4 additions & 6 deletions src/brevitas_examples/super_resolution/utils/dataset.py
@@ -50,11 +50,11 @@
 import torch.utils.data as data
 from torchvision.transforms import CenterCrop
 from torchvision.transforms import Compose
-from torchvision.transforms import Resize
-from torchvision.transforms import ToTensor
 from torchvision.transforms import RandomCrop
-from torchvision.transforms import RandomVerticalFlip
 from torchvision.transforms import RandomHorizontalFlip
+from torchvision.transforms import RandomVerticalFlip
+from torchvision.transforms import Resize
+from torchvision.transforms import ToTensor

 __all__ = ["get_bsd300_dataloaders"]

@@ -127,9 +127,7 @@ def calculate_valid_crop_size(crop_size, upscale_factor):

 def train_transforms(crop_size):
     return Compose([
-        RandomCrop(crop_size, pad_if_needed=True),
-        RandomHorizontalFlip(),
-        RandomVerticalFlip()])
+        RandomCrop(crop_size, pad_if_needed=True), RandomHorizontalFlip(), RandomVerticalFlip()])


 def test_transforms(crop_size):
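The collapsed Compose call is behavior-preserving. For illustration, applying the same pipeline to a dummy PIL image (the image and crop sizes are chosen for this example, not taken from the repo):

from PIL import Image
from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip, RandomVerticalFlip

transforms = Compose([
    RandomCrop(255, pad_if_needed=True), RandomHorizontalFlip(), RandomVerticalFlip()])

img = Image.new('RGB', (481, 321))  # BSD300 images are 481x321 or 321x481
out = transforms(img)
print(out.size)  # (255, 255)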
3 changes: 2 additions & 1 deletion src/brevitas_examples/super_resolution/utils/train.py
@@ -3,8 +3,9 @@

 import torch
 from torch import Tensor
-from brevitas.function import abs_binary_sign_grad
+
 from brevitas.core.scaling.pre_scaling import AccumulatorAwareParameterPreScaling
+from brevitas.function import abs_binary_sign_grad

 device = 'cuda' if torch.cuda.is_available() else 'cpu'
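On the reordered imports: abs_binary_sign_grad, per its name, acts like torch.abs in the forward pass but propagates a binary +/-1 sign gradient, so learned parameters stay trainable even at zero; in these training utilities it is presumably applied to the pre-scaling of AccumulatorAwareParameterPreScaling modules. A small sketch of that presumed behavior (the gradient convention at zero is an assumption):

import torch
from brevitas.function import abs_binary_sign_grad  # import path taken from the diff above

x = torch.tensor([-2.0, 0.0, 3.0], requires_grad=True)
abs_binary_sign_grad(x).sum().backward()
# Assumed: gradient is +/-1 everywhere, with no dead point at x == 0.
print(x.grad)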
