-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase.py
81 lines (59 loc) · 1.86 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# -*- coding: utf-8 -*-
# @Time : 2022/10/21 19:49
# @Author : 之落花--falling_flowers
# @File : base.py
# @Software: PyCharm
from typing import Callable
import torch
from torchviz import make_dot
def timer(func):
    """Decorate *func* so that every successful call prints its duration.

    The returned wrapper forwards all positional and keyword arguments to
    *func*, returns its result unchanged, and afterwards prints
    ``'<func name> used time: <seconds>s'`` to stdout. Nothing is printed
    when *func* raises, because the exception propagates before the print.

    :param func: the callable to time.
    :return: a wrapper with the same call signature and metadata as *func*.
    """
    import functools
    import time

    @functools.wraps(func)  # keep func's __name__/__doc__ on the wrapper
    def timerfunc(*args, **kwargs):
        # perf_counter() is monotonic and high resolution, so the measured
        # interval cannot go negative or jump when the system clock changes.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        print(f'{func.__name__} used time: {end - start}s')
        return result

    return timerfunc
def ringer(frequency=500, duration=500):
    """Build a decorator that plays a beep after the wrapped call returns.

    On Windows the beep is produced with ``winsound.Beep(frequency,
    duration)``. ``winsound`` does not exist on other platforms; there the
    wrapper falls back to writing the ASCII bell character to stdout, so the
    decorated function still runs and its result is still returned.
    No beep is made when the wrapped function raises.

    :param frequency: value passed to ``winsound.Beep`` as the frequency.
    :param duration: value passed to ``winsound.Beep`` as the duration.
    :return: a decorator taking the function to wrap.
    """
    import functools

    def wrapper(func: Callable):
        @functools.wraps(func)  # keep func's __name__/__doc__ on the wrapper
        def ringfunc(*args, **kwargs):
            # Run the wrapped function first so that a missing sound backend
            # can never prevent it from executing.
            result = func(*args, **kwargs)
            try:
                import winsound
            except ImportError:
                # Not on Windows: ring the terminal bell instead.
                print('\a', end='', flush=True)
            else:
                winsound.Beep(frequency, duration)
            return result

        return ringfunc

    return wrapper
def imgshow(img):
    """Show an image tensor in a matplotlib window.

    The tensor is rescaled with ``img / 2 + 0.5``, converted to a NumPy
    array, its axes reordered from ``(0, 1, 2)`` to ``(1, 2, 0)`` and then
    displayed. ``plt.show()`` is called with interactive mode switched off.

    :param img: a tensor with three axes; presumably channel-first
        ``(C, H, W)`` and normalised with mean 0.5 / std 0.5 -- confirm
        against the transform used by the caller.
    """
    import numpy as np
    from matplotlib import pyplot as plt

    # Presumably undoes a (x - 0.5) / 0.5 normalisation -- TODO confirm.
    img = img / 2 + 0.5
    # detach() and cpu() let the conversion succeed for tensors that track
    # gradients or live on an accelerator; a bare .numpy() raises for both.
    npimg = img.detach().cpu().numpy()
    # Move the first axis to the end, the layout plt.imshow expects.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.ioff()
    plt.show()
def test(net, path, dataloader):
    """Load weights from *path* into *net* and print its accuracy.

    The network is switched to evaluation mode, the state dict stored at
    *path* is loaded onto the CPU, and every ``(data, target)`` pair yielded
    by *dataloader* is run through the network. The predicted class of each
    sample is the argmax over that sample's outputs. Finally
    ``'Correct rate: <correct>/<total>'`` is printed, where ``<total>`` is
    the number of samples seen.

    :param net: the ``torch.nn.Module`` to evaluate; left in eval mode.
    :param path: location of the saved state dict.
    :param dataloader: iterable of ``(data, target)`` pairs; *target* holds
        one class index per sample.
    :raises SystemExit: with code ``-1`` (after printing the error) when
        *path* does not exist.
    """
    net.train(False)
    try:
        # NOTE(review): torch.load unpickles the file -- only load
        # checkpoints that come from a trusted source.
        net.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
    except FileNotFoundError as e:
        print(e)
        # The ``exit`` helper is injected by ``site`` and may be missing
        # (e.g. ``python -S``); raising SystemExit directly always works.
        raise SystemExit(-1)
    correct = 0
    total = 0
    # No gradients are needed for evaluation; skipping autograd bookkeeping
    # saves memory and time.
    with torch.no_grad():
        for data, target in dataloader:
            outcome = net(data)
            labels = torch.as_tensor(target, device=outcome.device).reshape(-1)
            # One row of scores per label, so batches of any size are
            # scored sample by sample instead of only the first sample.
            predicted = outcome.reshape(labels.numel(), -1).argmax(dim=1)
            correct += int((predicted == labels).sum())
            total += labels.numel()
    print(f'Correct rate: {correct}/{total}')
# def summary(input_size, model, _print=True, border=False):
# import torchsummary
# torchsummary.summary(input_size, model, _print, border)
def imshow(net: torch.nn.Module, input_, format_: str, name: str, directory: str = './image'):
    """Draw the computation graph of ``net(input_)`` with torchviz and view it.

    The network is run once on *input_*; the output and the network's named
    parameters are handed to ``make_dot`` (with ``show_attrs`` and
    ``show_saved`` enabled). The resulting graph object gets its ``format``
    set to *format_* and is then opened through its ``view`` method.

    :param net: the module whose forward pass is visualised.
    :param input_: the input passed to ``net``.
    :param format_: output format assigned to the graph (presumably a
        graphviz format name such as ``'png'`` -- confirm).
    :param name: file name passed to ``view``.
    :param directory: directory passed to ``view``; defaults to ``'./image'``.
    """
    output = net(input_)
    named_params = dict(net.named_parameters())
    graph = make_dot(
        output,
        params=named_params,
        show_attrs=True,
        show_saved=True,
    )
    graph.format = format_
    graph.view(filename=name, directory=directory, cleanup=True)
@timer
def main():
    """Placeholder entry point: does nothing and returns ``None``.

    It is decorated with ``timer``, so calling it goes through that wrapper.
    """
# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()