Fix BatchNorm import with the pytorch interoperability.
cmoineau committed Jan 9, 2023
Merge commit e1bed92 (parents a118a0b + d7a410e)
1 changed file with 83 additions and 46 deletions: python/pytorch_to_n2d2/pytorch_interface.py
@@ -258,6 +258,75 @@ def backward(ctx, grad_output):
outputs = outputs.view(self.current_batch_size, -1)
return outputs

class ContextNoBatchNormFuse:
"""
PyTorch fuse batchnorm if train = False.
If train = True batchnorm stats are updated with the dummy tensor and are invalidated.
Solution : Temporary replacement of the forward method for every batchnorm.
"""
def __init__(self, model):
self.model = model
        self.forwards = []

def __enter__(self):
"""Change batchnorm forward behavior when entering the block.
"""
for module in self.model.modules():
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                def fake_forward(inputs,
                                 current_bn=module):
                    """New batchnorm forward which saves the statistics and restores them after propagation.
                    """
                    assert current_bn.num_batches_tracked is not None
                    current_bn.num_batches_tracked.add_(1)
                    if current_bn.momentum is None:  # use cumulative moving average
                        eaf = 1.0 / current_bn.num_batches_tracked.item()
                    else:  # use exponential moving average
                        eaf = current_bn.momentum

# Save copy of values before propagation
saved_run_mean = current_bn.running_mean.detach().clone()
saved_run_var = current_bn.running_var.detach().clone()
saved_bias = current_bn.bias.detach().clone()
saved_weight = current_bn.weight.detach().clone()

# Real batchnorm forward
output_tensor = torch.nn.functional.batch_norm(
inputs,
current_bn.running_mean, # running_mean
current_bn.running_var, # running_var
current_bn.weight, # weight
current_bn.bias, # bias
True, # bn training
eaf, # exponential_average_factor
current_bn.eps, # epsilon
)

                    current_bn.running_mean.copy_(saved_run_mean)
                    current_bn.running_var.copy_(saved_run_var)
                    current_bn.bias = torch.nn.Parameter(saved_bias)
                    current_bn.weight = torch.nn.Parameter(saved_weight)

return output_tensor
                # Replace the batchnorm forward with the fake forward
                self.forwards.append(module.forward)
                module.forward = fake_forward
        return self

def __exit__(self, exc_type, exc_value, traceback):
"""Restore batchnorm forward behavior when exiting the block.
"""
cpt = 0
for module in self.model.modules():
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                # Restore the original batchnorm forward
                module.forward = self.forwards[cpt]
                cpt += 1
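For illustration, a minimal usage sketch of the context manager; the torchvision model and the input shape below are assumptions for the example, not part of this file:

# Sketch: protect batchnorm statistics while running a dummy forward pass in train mode.
import torch
import torchvision

model = torchvision.models.resnet18()  # assumed example model
dummy = torch.zeros(1, 3, 224, 224)    # assumed example input shape

with ContextNoBatchNormFuse(model):
    # Inside the block every _BatchNorm.forward is patched, so running_mean,
    # running_var, weight and bias are identical before and after this call.
    model(dummy)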

@n2d2.check_types
def wrap(torch_model:torch.nn.Module,
input_size: Union[list, tuple],
@@ -281,38 +350,22 @@ def wrap(torch_model:torch.nn.Module,
model_path = f'./{torch_model.__class__.__name__}.onnx'
print("Exporting torch module to ONNX ...")

dummy_in = torch.zeros(input_size).to(next(torch_model.parameters()).device)

    # Note: to keep the batchnorm layers we export the model in train mode.
    # However, we cannot freeze the batchnorm stats in pytorch < 1.12 (see: https://github.com/pytorch/pytorch/issues/75252).
    # To deal with this issue we save the stats before the export and update the N2D2 BatchNorm afterwards.
batchnorm_stats = [] # Queue of batchnorm stats (means, vars, biases, weights)
for module in torch_model.modules():
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
batchnorm_stats.append((
module.running_mean.detach().clone(),
module.running_var.detach().clone(),
module.bias.detach().clone(),
module.weight.detach().clone()))

dummy_in = torch.zeros(input_size).to(next(torch_model.parameters()).device)
    # And even in pytorch >= 1.12, when the stats are frozen the ONNX graph changes drastically ...
    # To deal with this issue we use a context manager which changes the forward behavior of the batchnorm layers to protect their stats.
with ContextNoBatchNormFuse(torch_model) as ctx:
torch.onnx.export(torch_model,
dummy_in,
raw_model_path,
verbose=verbose,
export_params=True,
opset_version=opset_version,
training=torch.onnx.TrainingMode.TRAINING,
do_constant_folding=False)

torch.onnx.export(torch_model,
dummy_in,
raw_model_path,
verbose=verbose,
export_params=True,
opset_version=opset_version,
training=torch.onnx.TrainingMode.TRAINING,
do_constant_folding=False
)
tmp_bn_idx = 0
for module in torch_model.modules():
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
means, variances, biases, weights = batchnorm_stats[tmp_bn_idx]
module.running_mean.copy_(torch.nn.Parameter(means).requires_grad_(False))
module.running_var.copy_(torch.nn.Parameter(variances).requires_grad_(False))
module.bias = (torch.nn.Parameter(biases))
module.weight = (torch.nn.Parameter(weights))
tmp_bn_idx +=1

print("Simplifying the ONNX model ...")
onnx_model = onnx.load(raw_model_path)
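As a quick sanity check (a sketch, not part of this commit), one can verify on the `onnx_model` loaded just above that the BatchNormalization nodes survived the training-mode export instead of being fused:

    # Sketch: count the BatchNormalization nodes left in the exported graph.
    bn_nodes = [node for node in onnx_model.graph.node if node.op_type == "BatchNormalization"]
    print(f"{len(bn_nodes)} BatchNormalization node(s) kept in the ONNX graph")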
@@ -339,30 +392,14 @@ def wrap(torch_model:torch.nn.Module,
need_to_flatten = False

for cell in deepNet:
if isinstance(cell, n2d2.cells.BatchNorm2d):
            # Note: we make the assumption that the PyTorch and N2D2 graphs are browsed in the same way.
            # If the accuracy drops drastically after export, this part of the code may be the problem!
            # PyTorch and ONNX names are not the same.
means, variances, biases, weights = batchnorm_stats.pop(0)
for idx, mean in enumerate(means):
cell.N2D2().setMean(idx, n2d2.Tensor([1], value=mean.item()).N2D2())
for idx, variance in enumerate(variances):
cell.N2D2().setVariance(idx, n2d2.Tensor([1], value=variance.item()).N2D2())
for idx, bias in enumerate(biases):
cell.N2D2().setBias(idx, n2d2.Tensor([1], value=bias.item()).N2D2())
for idx, weight in enumerate(weights):
cell.N2D2().setScale(idx, n2d2.Tensor([1], value=weight.item()).N2D2())

elif isinstance(cell, n2d2.cells.Softmax):
if isinstance(cell, n2d2.cells.Softmax):
            # The ONNX import sets Softmax with_loss = True, supposing we are using a CrossEntropy loss.
cell.with_loss = False
elif isinstance(cell, n2d2.cells.Fc):
            # We suppose that the fully connected layers are at the end of the network.
need_to_flatten = True
else:
pass
if len(batchnorm_stats) != 0:
        raise RuntimeError("Something went wrong when converting the torch model to N2D2: not the same number of BatchNorm layers!")

deepNet._embedded_deepnet.N2D2().initialize()
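For context, a minimal usage sketch of `wrap`; the import path, the example model, the input size, and the name of the returned object are all assumptions, since only part of the signature is visible in this diff:

# Sketch: convert a torch model to N2D2 through the ONNX round-trip shown above.
import torchvision
from pytorch_to_n2d2 import pytorch_interface  # import path assumed from the file location

torch_model = torchvision.models.resnet18()
wrapped = pytorch_interface.wrap(torch_model, input_size=(1, 3, 224, 224))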

