forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
batch_moments_op.cc
122 lines (107 loc) · 3.52 KB
/
batch_moments_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include "caffe2/operators/batch_moments_op.h"

#include <cstdint>
#include <string>
#include <vector>

#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
// Computes per-channel raw moments of an NCHW float tensor on CPU.
//
// X holds N consecutive slabs of C * HxW floats. Within a slab, each channel
// occupies HxW contiguous values. For every channel c this writes
//   mu[c]  = (1 / (N * HxW)) * sum_{n, hw} X[n, c, hw]
//   var[c] = (1 / (N * HxW)) * sum_{n, hw} X[n, c, hw]^2
// i.e. `var` is the raw second moment E[x^2], NOT the centered variance.
//
// mu and var must each point to at least C floats; both are overwritten.
// Always returns true.
template <>
bool BatchMomentsOp<float, CPUContext>::ComputeBatchMomentsNCHW(
    const int N,
    const int C,
    const int HxW,
    const float* X,
    float* mu,
    float* var) {
  // Accumulators start at zero; each sample's sums are added below.
  math::Set<float, CPUContext>(C, 0.0f, mu, &context_);
  math::Set<float, CPUContext>(C, 0.0f, var, &context_);
  EigenVectorArrayMap<float> mu_arr(mu, C);
  EigenVectorArrayMap<float> var_arr(var, C);

  // Widen before multiplying so neither the per-sample stride nor the total
  // reduction count can overflow int (signed overflow is undefined behavior).
  const std::int64_t stride = static_cast<std::int64_t>(C) * HxW;
  const std::int64_t count = static_cast<std::int64_t>(N) * HxW;

  const float* X_ptr = X;
  for (int i = 0; i < N; ++i) {
    // View one sample as an HxW x C array so column c is channel c
    // (ConstEigenArrayMap assumed column-major, Eigen's default).
    ConstEigenArrayMap<float> X_arr(X_ptr, HxW, C);
    mu_arr += X_arr.colwise().sum();
    var_arr += X_arr.square().colwise().sum();
    X_ptr += stride;
  }

  // The int64 count converts to float with a single rounding, matching the
  // previous float(N * HxW) exactly whenever that product fit in int.
  const float scale = 1.0f / static_cast<float>(count);
  math::Scale<float, float, CPUContext>(C, scale, mu, mu, &context_);
  math::Scale<float, float, CPUContext>(C, scale, var, var, &context_);
  return true;
}
// Computes per-channel raw moments of an NHWC float tensor on CPU.
//
// X is viewed as a C x (N * HxW) array: each of the N * HxW spatial positions
// contributes one contiguous run of C channel values. For every channel c:
//   mu[c]  = mean over all positions of X[.., c]
//   var[c] = mean over all positions of X[.., c]^2   (raw E[x^2])
//
// mu and var must each point to at least C floats; both are overwritten.
// Always returns true.
template <>
bool BatchMomentsOp<float, CPUContext>::ComputeBatchMomentsNHWC(
    const int N,
    const int C,
    const int HxW,
    const float* X,
    float* mu,
    float* var) {
  // Widen before multiplying: N * HxW can exceed INT_MAX for large inputs,
  // and the int product would overflow before reaching Eigen's Index type.
  const std::int64_t outer_size = static_cast<std::int64_t>(N) * HxW;
  ConstEigenArrayMap<float> X_arr(X, C, outer_size);
  EigenVectorMap<float>(mu, C) = X_arr.rowwise().mean();
  EigenVectorMap<float>(var, C) = X_arr.square().rowwise().mean();
  return true;
}
// Backward pass of BatchMoments for NCHW input on CPU.
//
// Given upstream gradients dmu[c] and dvar[c] (C floats each) for the forward
// outputs mu = E[x] and var = E[x^2], writes for every element
//   dX[n, c, hw] = (X[n, c, hw] * dvar[c] * 2 + dmu[c]) / (N * HxW)
// X and dX each hold N slabs of C * HxW floats (NCHW). Always returns true.
template <>
bool BatchMomentsGradientOp<float, CPUContext>::ComputeBatchMomentsGradientNCHW(
    const int N,
    const int C,
    const int HxW,
    const float* dmu,
    const float* dvar,
    const float* X,
    float* dX) {
  ConstEigenVectorArrayMap<float> dmu_arr(dmu, C);
  ConstEigenVectorArrayMap<float> dvar_arr(dvar, C);

  // Widen before multiplying; the old int products N * HxW and N * C * HxW
  // could overflow (undefined behavior) for large tensors.
  const std::int64_t stride = static_cast<std::int64_t>(C) * HxW;
  const std::int64_t count = static_cast<std::int64_t>(N) * HxW;
  const float scale = 1.0f / static_cast<float>(count);

  const float* X_ptr = X;
  float* dX_ptr = dX;
  for (int i = 0; i < N; ++i) {
    // One sample as an HxW x C array: column c is channel c.
    ConstEigenArrayMap<float> X_arr(X_ptr, HxW, C);
    EigenArrayMap<float> dX_arr(dX_ptr, HxW, C);
    dX_arr = X_arr.rowwise() * dvar_arr.transpose() * 2.0f;
    dX_arr.rowwise() += dmu_arr.transpose();
    // Normalize while the slab is still hot in cache instead of making a
    // second pass over all N * C * HxW outputs; per-element operation order
    // (multiply, add, then scale) is unchanged.
    dX_arr *= scale;
    X_ptr += stride;
    dX_ptr += stride;
  }
  return true;
}
// Backward pass of BatchMoments for NHWC input on CPU.
//
// X and dX are viewed as C x (N * HxW) arrays (one run of C channel values per
// spatial position). For every element:
//   dX[.., c] = (X[.., c] * dvar[c] * 2 + dmu[c]) / (N * HxW)
// dmu and dvar hold C floats each. Always returns true.
template <>
bool BatchMomentsGradientOp<float, CPUContext>::ComputeBatchMomentsGradientNHWC(
    const int N,
    const int C,
    const int HxW,
    const float* dmu,
    const float* dvar,
    const float* X,
    float* dX) {
  // Widen before multiplying; N * HxW (and the former N * C * HxW element
  // count) could overflow int for large tensors.
  const std::int64_t outer_size = static_cast<std::int64_t>(N) * HxW;
  const float scale = 1.0f / static_cast<float>(outer_size);
  EigenArrayMap<float> dX_arr(dX, C, outer_size);
  dX_arr = ConstEigenArrayMap<float>(X, C, outer_size).colwise() *
      ConstEigenVectorArrayMap<float>(dvar, C) * 2.0f;
  dX_arr.colwise() += ConstEigenVectorArrayMap<float>(dmu, C);
  // Scale in place through the same map; same per-element result as the
  // previous math::Scale call, without an int element count.
  dX_arr *= scale;
  return true;
}
// Forward op: float CPU kernel; one input, two outputs (the two moment
// tensors produced by ComputeBatchMoments*).
REGISTER_CPU_OPERATOR(BatchMoments, BatchMomentsOp<float, CPUContext>);
OPERATOR_SCHEMA(BatchMoments)
    .NumInputs(1)
    .NumOutputs(2);

// Gradient op: three inputs (presumably the two output gradients followed by
// the forward input, matching GetBatchMomentsGradient) and one output.
REGISTER_CPU_OPERATOR(
    BatchMomentsGradient,
    BatchMomentsGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(BatchMomentsGradient)
    .NumInputs(3)
    .NumOutputs(1);
namespace {
class GetBatchMomentsGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
std::vector<OperatorDef> GetGradientDefs() override {
return SingleGradientDef(
"BatchMomentsGradient",
"",
std::vector<std::string>{GO(0), GO(1), I(0)},
std::vector<std::string>{GI(0)});
}
};
} // namespace
REGISTER_GRADIENT(BatchMoments, GetBatchMomentsGradient);
} // namespace caffe2