-
Notifications
You must be signed in to change notification settings - Fork 15
/
hubconf.py
37 lines (25 loc) · 1.24 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""torch.hub configuration."""
dependencies = ["torch", "torchaudio"]
import torch # pylint: disable=wrong-import-position
from speechmos.utmos22.strong.model import UTMOS22Strong # pylint: disable=wrong-import-position
# Checkpoint download URLs, keyed by the name of the hub entry point that loads them.
URLS = {
    "utmos22_strong": "https://github.com/tarepan/SpeechMOS/releases/download/v1.0.0/utmos22_strong_step7459_v1.pt",
}
# [Origin]
# "utmos22_strong" is derived from official sarulab-speech/UTMOS22 'UTMOS strong learner' checkpoint, under MIT license (Copyright 2022 Saruwatari&Koyama laboratory, The University of Tokyo, https://github.com/sarulab-speech/UTMOS22/blob/master/LICENSE).
# Weight transfer code is in my fork (`/demo/utmos_strong_alt`).
def utmos22_strong(progress: bool = True, pretrained: bool = True) -> UTMOS22Strong:
    """
    Build the `UTMOS strong learner` speech naturalness MOS predictor.

    Args:
        progress - Whether to show model checkpoint load progress
        pretrained - Whether to load the released checkpoint weights into the model
    Returns:
        The model instance, switched to evaluation mode
    """
    predictor = UTMOS22Strong()

    if pretrained:
        # Checkpoint is fetched through torch.hub and mapped onto CPU while loading.
        predictor.load_state_dict(
            torch.hub.load_state_dict_from_url(
                url=URLS["utmos22_strong"],
                map_location="cpu",
                progress=progress,
            )
        )

    # Always hand back the model in eval mode, with or without pretrained weights.
    predictor.eval()
    return predictor