forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
channel_shuffle_op.cc
169 lines (146 loc) · 4.33 KB
/
channel_shuffle_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#include "channel_shuffle_op.h"
#include <array>
#include <string>
#include <vector>
#ifdef CAFFE2_USE_MKL
#include <mkl.h>
#endif // CAFFE2_USE_MKL
#include "caffe2/utils/math.h"
namespace caffe2 {
namespace {
template <typename T>
void RunChannelShuffleNCHW(
    const int N,
    const int G,
    const int K,
    const int HxW,
    const T* X,
    T* Y,
    CPUContext* context) {
  // Per sample, X holds G * K planes of HxW contiguous elements ordered as a
  // G x K grid; Y receives the same planes ordered as a K x G grid, i.e. plane
  // g * K + k of X is written to plane k * G + g of Y.
  //
  // The copy is issued as strided matrix copies. The outer loop runs over the
  // smaller of G and K so that each CopyMatrix call moves the larger number of
  // rows.
  const int sample_size = G * K * HxW;
  const T* src = X;
  T* dst = Y;
  for (int n = 0; n < N; ++n, src += sample_size, dst += sample_size) {
    if (G >= K) {
      // For a fixed k, gather the G planes X[g * K + k] into Y[k * G + g].
      for (int k = 0; k < K; ++k) {
        math::CopyMatrix<T, CPUContext>(
            G, HxW, src + k * HxW, K * HxW, dst + k * G * HxW, HxW, context);
      }
    } else {
      // For a fixed g, scatter the K planes X[g * K + k] into Y[k * G + g].
      for (int g = 0; g < G; ++g) {
        math::CopyMatrix<T, CPUContext>(
            K, HxW, src + g * K * HxW, HxW, dst + g * HxW, G * HxW, context);
      }
    }
  }
}
template <typename T>
void RunChannelShuffleNHWC(
    const int N,
    const int G,
    const int K,
    const int HxW,
    const T* X,
    T* Y,
    CPUContext* context) {
  // At every one of the N * HxW positions the C = G * K channels are stored
  // contiguously. Viewed as a row-major G x K matrix, the shuffle is a plain
  // 2-D transpose into a K x G matrix.
  const int channels = G * K;
  const int num_positions = N * HxW;
  const std::array<std::int64_t, 2> matrix_dims = {G, K};
  const std::array<std::int32_t, 2> swap_axes = {1, 0};
  const T* src = X;
  T* dst = Y;
  for (int pos = 0; pos < num_positions;
       ++pos, src += channels, dst += channels) {
    math::Transpose<std::int64_t, T, CPUContext>(
        2, matrix_dims.data(), swap_axes.data(), src, dst, context);
  }
}
} // namespace
template <>
bool ChannelShuffleOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
  // Forward channel shuffle for an N x C x (spatial...) tensor. The C channels
  // are viewed as a G x K grid (G = group_, K = C / G) and transposed, so input
  // channel g * K + k is written to output channel k * G + g.
  const auto& X = Input(0);
  // dim32(1) and size_from_dim(2) below require at least (N, C).
  CAFFE_ENFORCE_GE(
      X.dim(), 2, "ChannelShuffle expects an input of rank >= 2.");
  const int N = X.dim32(0);
  const int C = X.dim32(1);
  const int G = group_;
  // G == 0 would divide by zero in `C % G`; a negative G yields a
  // non-positive K, for which the shuffle helper copies nothing and would
  // leave the output unwritten.
  CAFFE_ENFORCE_GT(G, 0, "ChannelShuffle group must be positive.");
  CAFFE_ENFORCE_EQ(C % G, 0);
  const int K = C / G;
  const int HxW = X.size_from_dim(2);
  // Allocate the output only after the input has been validated.
  auto* Y = Output(0, X.sizes(), at::dtype<float>());
  const float* X_data = X.data<float>();
  float* Y_data = Y->mutable_data<float>();
  RunChannelShuffleNCHW<float>(N, G, K, HxW, X_data, Y_data, &context_);
  return true;
}
template <>
bool ChannelShuffleOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
  // Forward channel shuffle for an N x (spatial...) x C tensor: at every
  // position the C = G * K channels are transposed from G x K to K x G.
  const auto& X = Input(0);
  const int ndim = X.dim();
  // With rank 1, dim 0 would serve as both the batch dim N and the channel
  // dim C below, so the shuffle would walk far past the end of the buffer.
  CAFFE_ENFORCE_GE(ndim, 2, "ChannelShuffle expects an input of rank >= 2.");
  const int N = X.dim32(0);
  const int C = X.dim32(ndim - 1);
  const int G = group_;
  // G == 0 would divide by zero in `C % G`; a negative G would produce
  // negative transpose dimensions.
  CAFFE_ENFORCE_GT(G, 0, "ChannelShuffle group must be positive.");
  CAFFE_ENFORCE_EQ(C % G, 0);
  const int K = C / G;
  // Product of all dims strictly between the batch and channel dims.
  const int HxW = X.size_between_dim(0, ndim - 1);
  // Allocate the output only after the input has been validated.
  auto* Y = Output(0, X.sizes(), at::dtype<float>());
  const float* X_data = X.data<float>();
  float* Y_data = Y->mutable_data<float>();
  RunChannelShuffleNHWC<float>(N, G, K, HxW, X_data, Y_data, &context_);
  return true;
}
template <>
bool ChannelShuffleGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
  // Backward of the NCHW channel shuffle. The forward op sent input channel
  // g * K + k to output channel k * G + g; the inverse permutation is the same
  // shuffle with the roles of G and K swapped, hence (N, K, G, ...) below.
  const auto& dY = Input(0);
  // dim32(1) and size_from_dim(2) below require at least (N, C).
  CAFFE_ENFORCE_GE(
      dY.dim(), 2, "ChannelShuffleGradient expects an input of rank >= 2.");
  const int N = dY.dim32(0);
  const int C = dY.dim32(1);
  const int G = group_;
  // G == 0 would divide by zero in `C % G`; a negative G yields a
  // non-positive K, for which the shuffle helper copies nothing.
  CAFFE_ENFORCE_GT(G, 0, "ChannelShuffleGradient group must be positive.");
  CAFFE_ENFORCE_EQ(C % G, 0);
  const int K = C / G;
  const int HxW = dY.size_from_dim(2);
  // Allocate the output only after the input has been validated.
  auto* dX = Output(0, dY.sizes(), at::dtype<float>());
  const float* dY_data = dY.data<float>();
  float* dX_data = dX->mutable_data<float>();
  RunChannelShuffleNCHW<float>(N, K, G, HxW, dY_data, dX_data, &context_);
  return true;
}
template <>
bool ChannelShuffleGradientOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
  // Backward of the NHWC channel shuffle: per position, transpose the K x G
  // channel grid back to G x K (the forward shuffle with G and K swapped).
  const auto& dY = Input(0);
  const int ndim = dY.dim();
  // With rank 1, dim 0 would serve as both the batch dim N and the channel
  // dim C below, so the shuffle would walk far past the end of the buffer.
  CAFFE_ENFORCE_GE(
      ndim, 2, "ChannelShuffleGradient expects an input of rank >= 2.");
  const int N = dY.dim32(0);
  const int C = dY.dim32(ndim - 1);
  const int G = group_;
  // G == 0 would divide by zero in `C % G`; a negative G would produce
  // negative transpose dimensions.
  CAFFE_ENFORCE_GT(G, 0, "ChannelShuffleGradient group must be positive.");
  CAFFE_ENFORCE_EQ(C % G, 0);
  const int K = C / G;
  // Product of all dims strictly between the batch and channel dims.
  const int HxW = dY.size_between_dim(0, ndim - 1);
  // Allocate the output only after the input has been validated.
  auto* dX = Output(0, dY.sizes(), at::dtype<float>());
  const float* dY_data = dY.data<float>();
  float* dX_data = dX->mutable_data<float>();
  RunChannelShuffleNHWC<float>(N, K, G, HxW, dY_data, dX_data, &context_);
  return true;
}
// Register the CPU float implementations of the forward op and its gradient.
REGISTER_CPU_OPERATOR(ChannelShuffle, ChannelShuffleOp<float, CPUContext>);
REGISTER_CPU_GRADIENT_OPERATOR(
ChannelShuffleGradient,
ChannelShuffleGradientOp<float, CPUContext>);
// ChannelShuffle: one input X, one output Y. The op only permutes channels,
// so Y has the same type and shape as X.
// NOTE(review): InheritOnnxSchema() presumably maps this op onto the ONNX
// operator of the same name — confirm against the caffe2 schema API.
OPERATOR_SCHEMA(ChannelShuffle)
.IdenticalTypeAndShape()
.NumInputs(1)
.NumOutputs(1)
.InheritOnnxSchema();
// ChannelShuffleGradient: takes the output gradient dY and produces the input
// gradient dX, which has the same type and shape as dY.
GRADIENT_OPERATOR_SCHEMA(ChannelShuffleGradient)
.IdenticalTypeAndShape()
.NumInputs(1)
.NumOutputs(1);
namespace {
class GetChannelShuffleGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
std::vector<OperatorDef> GetGradientDefs() override {
return SingleGradientDef(
"ChannelShuffleGradient",
"",
std::vector<std::string>{GO(0)},
std::vector<std::string>{GI(0)});
}
};
} // namespace
REGISTER_GRADIENT(ChannelShuffle, GetChannelShuffleGradient);
} // namespace caffe2