-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_mixmatch.py
31 lines (27 loc) · 966 Bytes
/
test_mixmatch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
"""Smoke test for the MixMatch loader and loss; mirrors the steps in the accompanying notebook."""
import unittest
import numpy as np
from .model import get_small_model
from .mixmatch import ArrayDataset, MixupLoader, MixMatchLoss
import torch.nn.functional as F
from torch import nn
import pickle
def pickle_load(path):
    """Open *path* in binary mode and return the object unpickled from it.

    ``encoding='latin1'`` is the setting the Python docs recommend for reading
    Python-2 pickles that contain NumPy arrays; presumably the fixture was
    written that way (not verifiable from here). Only use on trusted files:
    unpickling can execute arbitrary code.
    """
    with open(path, mode='rb') as handle:
        obj = pickle.load(handle, encoding='latin1')
    return obj
class TestMismatch(unittest.TestCase):
    """Smoke test for the MixMatch pipeline (same steps as the notebook).

    NOTE(review): the name reads "Mismatch" while the module under test is
    ``mixmatch``; kept unchanged so existing references/test selectors work.
    """

    def test_loader(self):
        """Push one MixupLoader batch through the model and MixMatchLoss.

        Loads the pickled (X_labeled, y_labeled, X_unlabeled) fixture, builds
        an ArrayDataset from the first 12 items of each, draws the first batch
        from a MixupLoader with batch size 4, and checks that MixMatchLoss on
        the softmaxed model output is a finite number.
        """
        # NOTE(review): this path is resolved against the current working
        # directory, so the test must be run from the directory containing
        # cifar_subset.pkl.
        (X_labeled, y_labeled, X_unlabeled) = pickle_load('cifar_subset.pkl')
        ds = ArrayDataset(X_labeled[:12], y_labeled[:12], X_unlabeled[:12])
        BS = 4
        model = get_small_model()
        loader = MixupLoader(ds, batch_size=BS)
        # The loader is handed the model; presumably it uses it while building
        # batches (e.g. guessing labels for unlabeled data) — TODO confirm.
        loader.model = model
        loss_fn = MixMatchLoss()

        # Only the first batch is exercised. next(iter(...)) makes an empty
        # loader error out instead of letting the test pass vacuously.
        xb, yb = next(iter(loader))
        # Call the module itself rather than .forward() so registered hooks run.
        preds = F.softmax(model(xb), dim=1)
        loss = loss_fn(preds, yb, BS // 2)
        # float() accepts a Python number or a single-element tensor; the test
        # previously only printed the loss and could never fail on its value.
        self.assertTrue(np.isfinite(float(loss)), f"loss is not finite: {loss}")