-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
50 lines (36 loc) · 1.49 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import unittest
from similarity_retrieval.model import LatentModel, get_pretrained_model
from similarity_retrieval.database import download_fashion_mnist
class TestDataset(unittest.TestCase):
    """Tests for ``similarity_retrieval.database``."""

    def test_download_fashion_mnist(self):
        """Check the array shapes returned by ``download_fashion_mnist``.

        The helper is unpacked as ``(x_train, y_train), (x_test, y_test)``.
        Each image split is expected to have shape ``(10000, 28, 28, 3)``
        and each label split shape ``(10000, 10)``.
        """
        (x_train, y_train), (x_test, y_test) = download_fashion_mnist()
        # NOTE(review): canonical Fashion-MNIST has 60000 training images;
        # 10000 here presumably reflects subsampling inside
        # download_fashion_mnist, and the trailing 10 presumably one-hot
        # labels -- confirm against its implementation.
        # assertEqual (not bare assert) survives `python -O` and reports
        # both shapes on failure.
        self.assertEqual(x_train.shape, (10000, 28, 28, 3))
        self.assertEqual(x_test.shape, (10000, 28, 28, 3))
        self.assertEqual(y_train.shape, (10000, 10))
        self.assertEqual(y_test.shape, (10000, 10))
class TestModel(unittest.TestCase):
    """Tests for ``similarity_retrieval.model``."""

    def _train_latent_model(self):
        """Download the data, build and train a ``LatentModel``.

        Returns:
            tuple: ``(latent_model, x_test)`` -- the trained model and the
            test images, so callers can query with held-out samples.
        """
        (x_train, y_train), (x_test, _y_test) = download_fashion_mnist()
        embedding_model = get_pretrained_model()
        latent_model = LatentModel(embedding_model)
        # train() is given only the (images, labels) training split, never
        # the nested (train, test) pair returned by download_fashion_mnist.
        latent_model.train((x_train, y_train))
        return latent_model, x_test

    def test_get_pretrained_model(self):
        """A VGG16 configuration should produce an object of class ``Sequential``."""
        model = get_pretrained_model(
            pretrained_model_name="Vgg16",
            model_name="Vgg16",
            IMAGE_SIZE=28,
            colorspace=3,
            use_pretrained=True,
        )
        self.assertEqual(model.__class__.__name__, "Sequential")

    def test_model_training(self):
        """Training on the training split should complete without raising."""
        self._train_latent_model()

    def test_model_query(self):
        """Querying a trained model with a test image should return a list."""
        latent_model, x_test = self._train_latent_model()
        index = latent_model.query(x_test[0])
        self.assertIsInstance(index, list)
if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script;
    # `module="__main__"` is unittest.main's default, spelled out here.
    unittest.main(module="__main__")