# -*- coding: utf-8 -*-
# File: dorefa.py
# Author: Yuxin Wu

import tensorflow as tf

from tensorpack.utils.argtools import graph_memoized


@graph_memoized
def get_dorefa(bitW, bitA, bitG):
    """
    Return the three quantization functions fw, fa, fg, for weights,
    activations and gradients respectively.

    It's unsafe to call this function multiple times with different parameters.
    """
    def quantize(x, k):
        # Quantize x in [0, 1] to k bits, i.e. onto a uniform grid with
        # 2^k - 1 steps, using a straight-through estimator for the gradient.
        n = float(2 ** k - 1)

        @tf.custom_gradient
        def _quantize(x):
            return tf.round(x * n) / n, lambda dy: dy

        return _quantize(x)
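
    # Worked example with k = 2 bits: n = 3, so quantize maps [0, 1] onto the
    # grid {0, 1/3, 2/3, 1}; e.g. 0.4 -> round(0.4 * 3) / 3 = 1/3. The custom
    # gradient (dy -> dy) is the straight-through estimator: round() is
    # treated as the identity in the backward pass.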

    def fw(x):
        if bitW == 32:
            return x

        if bitW == 1:   # BWN
            # Binary-Weight-Network: binarize to {-E, +E}, where E is the
            # mean absolute value of the weights, again with a
            # straight-through gradient.
            E = tf.stop_gradient(tf.reduce_mean(tf.abs(x)))

            @tf.custom_gradient
            def _sign(x):
                return tf.sign(x / E) * E, lambda dy: dy

            return _sign(x)

        # General case: squash with tanh, rescale into [0, 1] by the maximum
        # magnitude, quantize, then map back to [-1, 1].
        x = tf.tanh(x)
        x = x / tf.reduce_max(tf.abs(x)) * 0.5 + 0.5
        return 2 * quantize(x, bitW) - 1
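
    # For example, with bitW = 2 the pipeline above snaps the rescaled
    # weights to {0, 1/3, 2/3, 1}, and the final 2 * q - 1 yields weight
    # values in {-1, -1/3, 1/3, 1}.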

    def fa(x):
        # Assumes the activation is already in [0, 1], e.g. after a clipped
        # nonlinearity.
        if bitA == 32:
            return x
        return quantize(x, bitA)

    def fg(x):
        if bitG == 32:
            return x

        @tf.custom_gradient
        def _identity(input):
            def grad_fg(x):
                # Quantize the incoming gradient: normalize each sample by
                # its maximum magnitude, add uniform noise of up to half a
                # quantization step (stochastic rounding), quantize, then
                # restore the original scale.
                rank = x.get_shape().ndims
                assert rank is not None
                maxx = tf.reduce_max(tf.abs(x), list(range(1, rank)), keepdims=True)
                x = x / maxx
                n = float(2 ** bitG - 1)
                x = x * 0.5 + 0.5 + tf.random_uniform(
                    tf.shape(x), minval=-0.5 / n, maxval=0.5 / n)
                x = tf.clip_by_value(x, 0.0, 1.0)
                x = quantize(x, bitG) - 0.5
                return x * maxx * 2

            # The forward pass is the identity; only the gradient is quantized.
            return input, grad_fg

        return _identity(x)

    return fw, fa, fg
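

# Minimal usage sketch (illustrative, not part of the original file): obtain
# the quantizers for a W1A2G4 setting and apply them to a weight and an
# activation tensor. The variable and tensor names here are made up.
def _dorefa_example():
    fw, fa, fg = get_dorefa(bitW=1, bitA=2, bitG=4)
    w = tf.get_variable('demo_W', shape=[3, 3, 16, 32])   # hypothetical conv kernel
    w_q = fw(w)                               # binarized to {-E, +E}
    a = tf.random_uniform([4, 8, 8, 16])      # stand-in activation
    a_q = fa(tf.clip_by_value(a, 0.0, 1.0))   # fa assumes input in [0, 1]
    a_q = fg(a_q)                             # identity forward, 4-bit gradient
    return w_q, a_q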


def ternarize(x, thresh=0.05):
    """
    Implementation of Trained Ternary Quantization:
    https://arxiv.org/abs/1612.01064

    Code modified from the authors' at:
    https://github.com/czhu95/ternarynet/blob/master/examples/Ternary-Net/ternary.py
    """
    shape = x.get_shape()

    # Per-tensor threshold: a fixed fraction of the largest magnitude.
    thre_x = tf.stop_gradient(tf.reduce_max(tf.abs(x)) * thresh)

    # Learned positive and negative scaling factors.
    w_p = tf.get_variable('Wp', initializer=1.0, dtype=tf.float32)
    w_n = tf.get_variable('Wn', initializer=1.0, dtype=tf.float32)

    tf.summary.scalar(w_p.op.name + '-summary', w_p)
    tf.summary.scalar(w_n.op.name + '-summary', w_n)

    # mask_np scales entries above the threshold by w_p and those below the
    # negative threshold by w_n; mask_z zeroes everything in between.
    mask = tf.ones(shape)
    mask_p = tf.where(x > thre_x, tf.ones(shape) * w_p, mask)
    mask_np = tf.where(x < -thre_x, tf.ones(shape) * w_n, mask_p)
    mask_z = tf.where((x < thre_x) & (x > - thre_x), tf.zeros(shape), mask)

    @tf.custom_gradient
    def _sign_mask(x):
        # Ternary forward values with a straight-through gradient.
        return tf.sign(x) * mask_z, lambda dy: dy

    w = _sign_mask(x)
    w = w * mask_np

    tf.summary.histogram(w.name, w)
    return w
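

# Minimal usage sketch (illustrative, not part of the original file):
# ternarize creates the variables 'Wp' and 'Wn' via tf.get_variable, so each
# layer must call it inside its own variable scope. The scope and variable
# names below are made up.
def _ternarize_example():
    with tf.variable_scope('conv1'):
        w = tf.get_variable('W', shape=[3, 3, 16, 32])
        w_t = ternarize(w, thresh=0.05)   # forward values in {-Wn, 0, +Wp}
    return w_t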