-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
44 lines (32 loc) · 1.04 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import re
import numpy as np
import torch
import torch.distributed as dist
import collections
import logging
def info(strings):
    """Print *strings* to standard output.

    Thin wrapper around the built-in ``print``; the argument is passed
    through unchanged, so any printable object is accepted.
    """
    print(strings)
class LossMeter(object):
    """Compute and store the running average of the most recent values.

    Values live in a bounded ``collections.deque``: once ``maxlen`` values
    have been recorded, every new value evicts the oldest one, so the
    average always covers at most the last ``maxlen`` updates.
    """

    def __init__(self, maxlen=100):
        """Create an empty meter that remembers at most *maxlen* values."""
        self.vals = collections.deque([], maxlen=maxlen)

    def __len__(self):
        """Return how many values are currently stored."""
        return len(self.vals)

    def update(self, new_val):
        """Record *new_val*, dropping the oldest value if the window is full."""
        self.vals.append(new_val)

    @property
    def val(self):
        """Return the mean of the stored values, or NaN if none are stored."""
        # An empty window has no meaningful average; returning NaN (instead
        # of dividing by zero) keeps logging/printing an unused meter safe.
        if not self.vals:
            return float("nan")
        return sum(self.vals) / len(self.vals)

    def __repr__(self):
        """Return the current average formatted as a string."""
        return str(self.val)
def count_parameters(model):
    """Return the total element count of the trainable parameters of *model*.

    Iterates ``model.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is truthy; other parameters are
    ignored.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def load_state_dict(state_dict_path, loc='cpu'):
    """Load a state dict from disk and strip any ``module.`` key prefix.

    The file at *state_dict_path* is read with ``torch.load`` using *loc* as
    ``map_location``. Every key that starts with ``"module."`` is renamed in
    place to the same key without that prefix (multi-GPU checkpoint ->
    single-GPU layout); all other keys are left as they are. The loaded
    mapping itself is returned.

    NOTE(review): ``torch.load`` unpickles the file, so only load checkpoints
    from sources you trust.
    """
    prefix = "module."
    state_dict = torch.load(state_dict_path, map_location=loc)

    # Iterate over a snapshot of the keys because the mapping is modified
    # while we walk it.
    for old_name in tuple(state_dict.keys()):
        if not old_name.startswith(prefix):
            continue
        weights = state_dict.pop(old_name)
        state_dict[old_name[len(prefix):]] = weights

    return state_dict