Merge pull request pulp-platform#40 from pulp-platform/pr/Norms
Fix performance of the FP32 and FP16 InstanceNorm input-gradient primitives.
dnadalini authored Apr 19, 2024
2 parents 2161873 + 73b1a61 commit 759b56a
Showing 4 changed files with 36 additions and 17 deletions.
1 change: 1 addition & 0 deletions README.md

@@ -196,6 +196,7 @@ PULP-TrainLib's repository is organized with these branches:
 - Manuele Rusci ([email protected])
 - Francesco Conti ([email protected])
 - Cristian Cioflan ([email protected])
+- Luca Bompani ([email protected])
 
 ## Past Contributors
 
26 changes: 18 additions & 8 deletions lib/sources/pulp_instnorm_fp16.c

@@ -15,7 +15,7 @@
  */
 
 /**
- * Authors: Giacomo Saporetti, Davide Nadalini
+ * Authors: Giacomo Saporetti, Davide Nadalini, Luca Bompani
  */
 
 #include "pmsis.h"

@@ -125,6 +125,10 @@ void pulp_instnorm_parallelized_fp16_bw_input_grads_cl( void * InstNorm_args_fp1
     fp16 * running_stdev = args->running_stdev;
     int freeze_running_params = args->freeze_running_params;
 
+    // Stabilize numerically
+    fp16 grad_scaling = 1e6;
+    fp16 grad_scaling_inv = 1 / grad_scaling;
+
     int N = in->dim;
     int C = in->C;
     int H = in->H;

@@ -149,17 +153,23 @@
         std = running_stdev[c];
         var = running_var[c];
 
-        for (int d=0; d<D; d++)
+        fp16 grad_i_sum = 0;
+        fp16 grad_i_prod = 0;
+        for(int i=0; i<D; i++)
         {
-            fp16 grad = 0;
+            grad_i_sum -= out_diff[i];
+            grad_i_prod -= (in_data[i] - mean) * out_diff[i];
+        }
+
+        for(int d=0; d<D; d++)
+        {
+            fp16 grad = grad_i_sum;
             fp16 mean_d = (in_data[d] - mean) / var;
 
-            for (int i=0; i<D; i++)
-            {
-                grad -= out_diff[i] * (1 + (in_data[i] - mean) * mean_d);
-            }
+            grad += grad_i_prod*mean_d;
+
             grad += D*out_diff[d];
             grad = grad*gamma/(D*std);
 
             in_diff[d] = grad;
         }
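
This hunk is the actual performance fix. The removed inner loop evaluated two channel-wide reductions once per output element d, so each channel cost O(D^2) multiply-accumulate operations for D = H*W. Neither reduction depends on d, so the new code hoists them out of the d-loop and computes them once per channel, dropping the cost to O(D). In my own notation (delta_i = out_diff[i], mu = mean, sigma = std, sigma^2 = var), the quantity being computed is the standard instance-norm input gradient

    \frac{\partial L}{\partial x_d}
      = \frac{\gamma}{D\,\sigma}\left( D\,\delta_d
        - \sum_{i=1}^{D}\delta_i
        - \frac{x_d-\mu}{\sigma^{2}}\sum_{i=1}^{D}(x_i-\mu)\,\delta_i \right)

and the identity the refactor exploits is

    \sum_{i=1}^{D}\delta_i\,\big(1+(x_i-\mu)\,m_d\big)
      = \sum_{i=1}^{D}\delta_i + m_d\sum_{i=1}^{D}(x_i-\mu)\,\delta_i,
      \qquad m_d = \frac{x_d-\mu}{\sigma^{2}}.

The code accumulates both sums negated, as grad_i_sum and grad_i_prod. Two caveats about the fp16 version that this excerpt cannot confirm: the grad_scaling / grad_scaling_inv pair declared under "Stabilize numerically" is presumably applied further down in the function to scale gradients up during accumulation and back down afterwards (its use site is not shown here), and the constant 1e6 exceeds IEEE binary16's largest finite value of 65504, so it only stays finite if PULP-TrainLib's fp16 type is configured as the bfloat16-style alternative format rather than standard half precision.
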
21 changes: 14 additions & 7 deletions lib/sources/pulp_instnorm_fp32.c

@@ -150,20 +150,27 @@ void pulp_instnorm_parallelized_fp32_bw_input_grads_cl( void * InstNorm_args )
         std = running_stdev[c];
         var = running_var[c];
 
-        for (int d=0; d<D; d++)
+        float grad_i_sum = 0;
+        float grad_i_prod = 0;
+        for(int i=0; i<D; i++)
         {
-            float grad = 0;
+            grad_i_sum -= out_diff[i];
+            grad_i_prod -= (in_data[i] - mean) * out_diff[i];
+        }
+
+        for(int d=0; d<D; d++)
+        {
+            float grad = grad_i_sum;
             float mean_d = (in_data[d] - mean) / var;
 
-            for (int i=0; i<D; i++)
-            {
-                grad -= out_diff[i] * (1 + (in_data[i] - mean) * mean_d);
-            }
+            grad += grad_i_prod*mean_d;
+
             grad += D*out_diff[d];
             grad = grad*gamma/(D*std);
 
             in_diff[d] = grad;
         }
 
     }
 }
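
The fp32 hunk applies the same algebraic rewrite as the fp16 one. A quick way to convince yourself the refactor is behavior-preserving is to compare the two formulations directly; the harness below is a minimal sketch of my own (none of its names come from the repo, and plain float stands in for the PULP fp16/fp32 types, the cluster parallelism, and the running-statistics plumbing):

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#define D 64  /* stands in for H*W of one channel */

int main(void)
{
    float in_data[D], out_diff[D], ref[D], fused[D];
    float gamma = 1.5f;

    /* random channel data and upstream gradients */
    srand(42);
    float mean = 0.0f;
    for (int i = 0; i < D; i++) {
        in_data[i]  = (float)rand() / RAND_MAX;
        out_diff[i] = (float)rand() / RAND_MAX - 0.5f;
        mean += in_data[i];
    }
    mean /= D;

    float var = 0.0f;
    for (int i = 0; i < D; i++)
        var += (in_data[i] - mean) * (in_data[i] - mean);
    var /= D;
    float std = sqrtf(var);

    /* old O(D^2) formulation: reductions recomputed for every d */
    for (int d = 0; d < D; d++) {
        float grad = 0.0f;
        float mean_d = (in_data[d] - mean) / var;
        for (int i = 0; i < D; i++)
            grad -= out_diff[i] * (1 + (in_data[i] - mean) * mean_d);
        grad += D * out_diff[d];
        ref[d] = grad * gamma / (D * std);
    }

    /* new O(D) formulation: both reductions hoisted out of the d-loop */
    float grad_i_sum = 0.0f, grad_i_prod = 0.0f;
    for (int i = 0; i < D; i++) {
        grad_i_sum  -= out_diff[i];
        grad_i_prod -= (in_data[i] - mean) * out_diff[i];
    }
    for (int d = 0; d < D; d++) {
        float grad = grad_i_sum + grad_i_prod * (in_data[d] - mean) / var;
        grad += D * out_diff[d];
        fused[d] = grad * gamma / (D * std);
    }

    /* the two forms agree up to float rounding */
    float max_err = 0.0f;
    for (int d = 0; d < D; d++) {
        float e = fabsf(ref[d] - fused[d]);
        if (e > max_err) max_err = e;
    }
    printf("max |ref - fused| = %g\n", max_err);
    return (max_err < 1e-4f) ? 0 : 1;
}

In exact arithmetic the two forms are identical; in float they differ only by rounding from the reordered accumulation, and in fp16 that reordering is more visible, which is presumably what the extra scaling machinery in the fp16 variant is for.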

5 changes: 3 additions & 2 deletions tests/test_instnorm_fp16/utils/GM.py

@@ -33,7 +33,8 @@
 STEP = args.STEP
 NUM_CORES = args.NUM_CORES
 
-test_data = 100*torch.rand(CI, HI, WI)
+#test_data = 100*torch.rand(CI, HI, WI)
+test_data = torch.rand(CI, HI, WI)
 test_data.requires_grad = True
 test_labels = torch.rand(CI, HI, WI)

@@ -71,7 +72,7 @@
 
 # Simple input data
-inp = torch.torch.div(torch.randint(1000, [1, l1_in_ch, l1_hin, l1_win]), 1000).half().to(device)
+inp = torch.torch.div(torch.randint(1000, [1, l1_in_ch, l1_hin, l1_win]), 1e6).half().to(device)
 inp.requires_grad = True
 
 class DNN(nn.Module):
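
The diff does not say why the test inputs were shrunk, but the motivation is presumably fp16 dynamic range: IEEE binary16 tops out at 65504 with roughly three decimal digits of precision, so golden-model inputs of magnitude around 100 (the old 100*torch.rand) push the backward pass's D-scaled sums toward overflow and swamp the small input gradients under test. Drawing test_data from [0, 1) and inp from [0, 1e-3) keeps the fp16 comparison against the PyTorch reference meaningful. Whether the 1e6 divisor was chosen to mirror the grad_scaling = 1e6 constant in the C primitive is my speculation, not something the commit states.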
