-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtests.py
311 lines (234 loc) · 9.68 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
from itertools import product
import haiku as hk
import jax
import jax.numpy as jnp
import jmp
import numpy as np
import pytest
from jax_gptq import use_quantized
from jax_gptq.quantize_interpreter import quantize
from jax_gptq.gptq import (
pack_colwise,
unpack_colwise,
pack_matrix,
unpack_matrix,
get_col_quantize_params,
gptq,
QuantizedMatrix,
accumulate_H
)
@pytest.fixture
def simple_model():
    """Build a small matmul model together with its GPTQ-quantized weight.

    Returns:
        A tuple ``(f, w, quantized, xs)`` where ``f(w, x) = x @ w``, ``w`` is a
        (256, 1024) standard-normal weight, ``xs`` is a list of four float16
        (32, 256) calibration batches, and ``quantized`` is
        ``quantize(f, w, xs, block_size=16)`` moved to the first GPU device
        (so a GPU is required).
    """
    def f(w, x):
        return x @ w

    w = jax.random.normal(jax.random.PRNGKey(0), (256, 1024))
    # One key per batch so the calibration set is four distinct batches rather
    # than four copies of one; xs[0] still comes from PRNGKey(1).
    xs = [
        jax.random.normal(jax.random.PRNGKey(1 + i), (32, 256), dtype=jnp.float16)
        for i in range(4)
    ]
    quantized = jax.device_put(quantize(f, w, xs, block_size=16), jax.devices('gpu')[0])
    return f, w, quantized, xs
def test_get_quantize_params():
    """Column-wise quantization params hold one entry per column.

    The scale of each column is its (row 0 - row 1) span, i.e. max - min for
    this input, divided into 15 steps.
    """
    top_row = [0.1, 0.2, 0.3, 0.4]
    bottom_row = [-0.1] * 4
    w = jnp.asarray([top_row, bottom_row], dtype=jnp.float32)

    params = get_col_quantize_params(w)
    scale, zero = params['scale'], params['zero']

    assert zero.shape == scale.shape == (4,)
    expected_scale = (w[0, :] - w[1, :]) / 15
    assert np.allclose(scale, expected_scale)
def test_single():
    """Quantize a single vmapped matmul and check the dequantized weight still
    reproduces the original outputs on a calibration batch.

    Error statistics are printed for inspection. A loose bound on the mean
    output error catches gross failures such as wrong packing or orientation.
    """
    def f(w, x):
        return x @ w

    w = jax.random.normal(jax.random.PRNGKey(0), (256, 64))
    # Host copy of the weights, taken before quantization.
    orig_w = np.asarray(w)
    # Distinct key per calibration batch (xs[0] still comes from PRNGKey(1)).
    xs = [jax.random.normal(jax.random.PRNGKey(1 + i), (32, 256)) for i in range(4)]
    fn = jax.vmap(f, (None, 0))

    orig_output = fn(w, xs[0])
    quantized = quantize(fn, w, xs, block_size=16)
    de_quantized = unpack_matrix(quantized)
    new_output = fn(de_quantized, xs[0])

    print(f'Scale: {quantized.scale} Zero: {quantized.zero}')
    diff = jnp.max(jnp.abs(orig_w - de_quantized))
    print(f'Max quantization err: {diff}')

    abs_error = jnp.abs(orig_output - new_output)
    relative_error = jnp.abs(abs_error / orig_output)
    mean_abs_output = jnp.mean(jnp.abs(orig_output))
    max_abs_error = jnp.max(abs_error)
    print(f'Max relative error: {jnp.max(relative_error)}')
    print(f'Max absolute error: {max_abs_error}')
    # Normalize the absolute error (not the already-relative error) by the
    # typical / largest output magnitude.
    print(f'Max error relative to mean: {max_abs_error / mean_abs_output}')
    print(f'Max error relative to max: {max_abs_error / jnp.max(jnp.abs(orig_output))}')

    # Same metric as test_weird_dots, with a deliberately loose sanity bound.
    assert jnp.mean(abs_error) / mean_abs_output < 0.2
def test_hk():
    """Quantize a three-layer float16 Haiku model and check use_quantized matches
    a manual forward pass through the unpacked weights.

    The hk.Linear mixed-precision policy is process-global Haiku state, so the
    previous policy is restored in ``finally`` to avoid leaking into other tests.
    """
    policy = jmp.Policy(
        compute_dtype=jnp.float16,
        output_dtype=jnp.float16,
        param_dtype=jnp.float16,
    )
    previous_policy = hk.mixed_precision.get_policy(hk.Linear)
    hk.mixed_precision.set_policy(hk.Linear, policy)
    try:
        def f(x):
            for _ in range(3):
                x = hk.Linear(1024, with_bias=False)(x)
            return x

        in_dim = 256
        model = hk.without_apply_rng(hk.transform(f))
        params = model.init(jax.random.PRNGKey(0), jnp.ones(in_dim))
        xs = [jax.random.normal(jax.random.PRNGKey(i), (32, in_dim)) for i in range(64)]
        fn = jax.vmap(model.apply, (None, 0))

        orig_output = fn(params, xs[0])
        quantized_params = quantize(fn, params, xs)

        # Reference: chain the three layers by hand with the unpacked weights.
        # The first layer's params live under 'linear', later ones under 'linear_<i>'.
        manual_result = xs[0]
        for i in range(3):
            layer_params = quantized_params[f'linear_{i}' if i > 0 else 'linear']
            manual_result = manual_result @ unpack_matrix(layer_params['w'])

        gpu = jax.devices('gpu')[0]
        manual_result = jax.device_put(manual_result, gpu)
        gpu_args = jax.device_put((quantized_params, xs[0]), gpu)
        new_output = use_quantized(fn)(*gpu_args)

        abs_error = jnp.abs(orig_output - new_output)
        relative_error = jnp.abs(abs_error / orig_output)
        print(f'Max absolute error from manual calculation: {jnp.max(jnp.abs(manual_result - new_output))}')
        print(f'Max relative error: {jnp.max(relative_error)}')
        print(f'Max absolute error: {jnp.max(jnp.abs(orig_output - new_output))}')
        assert np.allclose(manual_result, new_output, atol=3e-3, rtol=0)
    finally:
        if previous_policy is None:
            hk.mixed_precision.clear_policy(hk.Linear)
        else:
            hk.mixed_precision.set_policy(hk.Linear, previous_policy)
def test_transform():
    """use_quantized(fn) on a large vmapped matmul must match multiplying by the
    explicitly unpacked quantized matrix (runs on the first GPU)."""
    gpu = jax.devices('gpu')[0]

    def f(w, x):
        return x @ w

    dim = 2048
    w = jax.random.normal(jax.random.PRNGKey(0), (dim, dim))
    xs = [
        jax.random.normal(jax.random.PRNGKey(seed), (32, dim), dtype=jnp.float32)
        for seed in range(64)
    ]
    fn = jax.vmap(f, (None, 0))
    batch = xs[0]

    orig_result = jax.device_put(fn(w, batch), gpu)
    quantized_params = quantize(fn, w, xs, block_size=16)
    unpacked_matrix = unpack_matrix(quantized_params)

    gpu_args = jax.device_put((quantized_params, batch), gpu)
    manual_result = jax.device_put(batch @ unpacked_matrix, gpu)
    transform_result = use_quantized(fn)(*gpu_args)

    manual_gap = np.max(np.abs(manual_result - transform_result))
    original_gap = np.max(np.abs(orig_result - transform_result))
    print(f'Gap to manual calculation: {manual_gap}')
    print(f'Gap to original: {original_gap}')
    assert np.allclose(manual_result, transform_result, atol=1e-4, rtol=0)
def test_pack():
    """Column-wise packing of integers in [0, 16) round-trips losslessly."""
    values = jax.random.randint(jax.random.PRNGKey(0), (256, 64), 0, 16)
    packed = pack_colwise(values)
    print(packed.shape, packed.dtype)
    assert jnp.all(unpack_colwise(packed) == values)
def test_pack_matrix():
    """pack_matrix followed by unpack_matrix reproduces gptq's quantized
    weights exactly."""
    w = jax.random.normal(jax.random.PRNGKey(0), (256, 64))

    # NOTE(review): calibration inputs are integer-valued (randint); presumably
    # fine for a pure round-trip check — confirm this is intended.
    xs = []
    for seed in range(4):
        xs.append(jax.random.randint(jax.random.PRNGKey(seed), (32, 256), 0, 16))

    quantized, qparams = gptq(w, xs, block_size=4)
    round_tripped = unpack_matrix(pack_matrix(quantized, qparams))
    assert jnp.all(round_tripped == quantized)
def test_remat(simple_model):
    """use_quantized must work through jax.checkpoint under jit and match a
    matmul with the unpacked weight."""
    f, _, w_q, xs = simple_model
    x = xs[0]

    expected = x @ unpack_matrix(w_q)
    quantized_fn = jax.jit(use_quantized(jax.checkpoint(f)))
    result = quantized_fn(w_q, x)

    max_err = jnp.max(jnp.abs(result - expected))
    print(f'Max error: {max_err}')
    assert np.allclose(result, expected, atol=0.03, rtol=0)
def test_grad(simple_model):
    """The gradient w.r.t. the activations of a quantized, checkpointed, jitted
    function matches the gradient computed with the unpacked weight."""
    _, _, w_q, xs = simple_model
    x = xs[0]

    def f(x, w):
        # Scalar loss: sum of row 3 of the product.
        return jnp.sum((x @ w)[3])

    dense_w = unpack_matrix(w_q)
    quantized_grad_fn = jax.jit(jax.grad(use_quantized(jax.checkpoint(f))))
    quantized_grad = quantized_grad_fn(x, w_q)
    expected_grad = jax.grad(f)(x, dense_w)

    print(f'Grad: {quantized_grad}')
    print(f'Expected grad: {expected_grad}')
    print(f'Max error: {jnp.max(jnp.abs(quantized_grad - expected_grad))}')
    assert np.allclose(quantized_grad, expected_grad, atol=1e-1, rtol=0)
@pytest.mark.parametrize(
    ['w_contract', 'x_contract'],
    list(product(range(3), range(2))),
)
def test_weird_dots(w_contract, x_contract):
    """Quantization handles dot_general with the contracting dimension at any
    position of the weight (3 choices) and the input (2 choices)."""
    contract_dim_size = 256

    w_shape = [47, 768]
    w_shape.insert(w_contract, contract_dim_size)
    x_shape = [32]
    x_shape.insert(x_contract, contract_dim_size)

    w = jax.random.normal(jax.random.PRNGKey(9999), w_shape)
    xs = [jax.random.normal(jax.random.PRNGKey(10 * i), x_shape) for i in range(10)]
    dimension_numbers = (((w_contract,), (x_contract,)), ((), ()))

    def f(w, x):
        return jax.lax.dot_general(w, x, dimension_numbers)

    expected_output = f(w, xs[0])
    quantized_params = quantize(f, w, xs, block_size=16)
    quantized_output = use_quantized(f)(quantized_params, xs[0])

    error = jnp.abs(expected_output - quantized_output)
    print(f'Error: {jnp.max(error):.3e}')
    mean_error = jnp.mean(error)
    print(f'Mean error: {mean_error:.3e}')
    assert (mean_error / jnp.mean(jnp.abs(expected_output))) < 0.1
def test_conv():
    """Quantizing a width-1 convolution gives (almost) the same weights and
    outputs as quantizing the equivalent dot_general."""
    in_channels, out_channels = 32, 79
    batch_size, spatial_size = 11, 13

    w = jax.random.normal(jax.random.PRNGKey(0), (out_channels, in_channels))
    xs = [
        jax.random.normal(jax.random.PRNGKey(seed), (batch_size, in_channels, spatial_size))
        for seed in range(1, 20)
    ]

    def matmul_f(w, x):
        return jax.lax.dot_general(w, x, (((1,), (1,)), ((), ())))

    def conv_f(w, x):
        return jax.lax.conv(x, w, (1,), 'VALID')

    quantized_params = quantize(matmul_f, w, xs, block_size=16)
    quantized_output = use_quantized(matmul_f)(quantized_params, xs[0])
    conv_quantized = quantize(conv_f, w[:, :, None], xs, block_size=16)

    # Compare weights, dropping the conv kernel's trailing size-1 window axis.
    unpacked_w = unpack_matrix(quantized_params)
    unpacked_kernel = unpack_matrix(conv_quantized)
    num_diff = jnp.sum(unpacked_w != unpacked_kernel[:, :, 0])
    assert num_diff < 10  # Can have some differences from non-determinism in quantization

    # dot_general yields (out, batch, spatial); conv yields (batch, out, spatial).
    quantized_conv_output = use_quantized(conv_f)(conv_quantized, xs[0]).transpose(1, 0, 2)
    max_diff = jnp.max(jnp.abs(quantized_output - quantized_conv_output))
    print(f'Conv vs Matmul quantization diff: {max_diff:.3e}')
    assert max_diff < 1e-5
def test_transpose():
    """Quantizing a weight used through ``w.T`` yields a QuantizedMatrix that
    stays close to the original weights. Its use_quantized output must match
    the manually unpacked transpose product."""
    w = jax.random.normal(jax.random.PRNGKey(9999), (256, 1024))
    # Host copy of the weights, taken before quantization.
    w_copy = np.asarray(w)
    xs = [jax.random.normal(jax.random.PRNGKey(i), (256, 32)) for i in range(10)]

    def f(w, x):
        return w.T @ x

    quantized_params = quantize(f, w, xs, block_size=16)
    assert isinstance(quantized_params, QuantizedMatrix)
    quantized_output = use_quantized(f)(quantized_params, xs[0])

    # Unpack once and reuse it for both checks.
    unpacked = unpack_matrix(quantized_params)
    manual_result = unpacked.T @ xs[0]
    param_diff = jnp.mean(jnp.abs(unpacked - w_copy))
    assert param_diff < 0.15
    assert np.allclose(manual_result, quantized_output, atol=1e-4)
def test_reuse(simple_model):
    """A weight used by several operations (two transposes plus a full sum)
    can still be quantized into a single QuantizedMatrix. Running it through
    use_quantized must produce an output of the expected shape."""
    _, w, _, xs = simple_model

    def f(w, x):
        # NOTE(review): `intermediate` is never used; it appears intended to add
        # an extra use of `w` to the traced computation.
        intermediate = w.T  # noqa: F841
        answer = w.T @ x.T
        answer += jnp.sum(w)
        return answer

    # Compute the reference shape before quantizing, from the unquantized weight.
    expected_shape = jax.eval_shape(f, w, xs[0]).shape

    quantized_params = quantize(f, w, xs, block_size=16)
    assert isinstance(quantized_params, QuantizedMatrix)
    quantized_output = use_quantized(f)(quantized_params, xs[0])
    assert quantized_output.shape == expected_shape
def test_accumulate_H():
    """accumulate_H(xs, use_fp64=True) equals sum(x.T @ x) / total_rows over
    all batches, computed manually in float64."""
    # Explicit submodule import: `import jax` alone does not guarantee that
    # `jax.experimental` is loaded.
    from jax.experimental import enable_x64

    bsize = 37
    n = 256
    num_batches = 20
    with enable_x64():
        # Split the key so every batch is distinct. Reusing one key made all
        # batches identical, which cannot detect batch-accumulation bugs.
        keys = jax.random.split(jax.random.PRNGKey(9876), num_batches)
        xs = [jax.random.normal(k, (bsize, n), jnp.float64) for k in keys]

        n_ex = sum(x.shape[0] for x in xs)
        manual_calculation = jnp.sum(
            jnp.stack([x.T @ x for x in xs]),
            axis=0,
        ) / n_ex
        gptq_result = accumulate_H(xs, use_fp64=True)
        assert np.allclose(manual_calculation, gptq_result, atol=1e-15)