generate_user_case_hide.py
(forked from hanqingguo/infocom2021)
from datasets.eva_dataloader_hide import create_dataloader
import wavio
import argparse
from utils.hparams import HParam
import os
import glob
import torch
import librosa
from utils.audio import Audio
from model.model import VoiceFilter
from model.embedder import SpeechEmbedder
from utils.evaluation import tensor_normalize

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-b', '--base_dir', type=str, default='.',
                        help="Root directory of run.")
    parser.add_argument('-c', '--config', type=str, required=True,
                        help="yaml file for configuration")
    parser.add_argument('-e', '--embedder_path', type=str, required=True,
                        help="path of embedder model pt file")
    parser.add_argument('--checkpoint_path', type=str, default=None,
                        help="path of checkpoint pt file of focus model")
    parser.add_argument('-m', '--model', type=str, required=True,
                        help="Name of the model. Used for both logging and saving checkpoints.")
    parser.add_argument('-g', '--gpu', type=int, required=True, default=1,
                        help="ID of the selected gpu. Used for gpu selection.")
    parser.add_argument('-o', '--out_dir', type=str, required=True,
                        help="out directory of result.wav")
    args = parser.parse_args()
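
    # Example invocation (illustrative only; the config, checkpoint, and output
    # paths below are hypothetical placeholders, not files shipped with the repo):
    #   python generate_user_case_hide.py -c config/default.yaml \
    #       -e embedder.pt --checkpoint_path chkpt/model.pt \
    #       -m voicefilter -g 0 -o out/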
    hp = HParam(args.config)

    root_dir_test = hp.data.test_dir
    # walk the test root and keep only directories nested deeply enough to be
    # per-speaker condition folders
    alldirs = [x[0] for x in os.walk(root_dir_test)]
    dirs = [leaf for leaf in alldirs if len(leaf.split('/')) > 5]
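    # For example, with the layout suggested by the commented path below,
    # '/data/our_dataset/test/3/joint'.split('/') has 6 components, so it passes
    # the len(...) > 5 filter, while '/data/our_dataset/test' (4 components) does not.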
    speaker_count = 0
    # dir = '/data/our_dataset/test/3/joint'
    for dir in dirs:
        speaker_count = speaker_count + 1
        print("Speaker : {}/56\n".format(speaker_count))
        tree = dir.split('/')
        speaker_id = tree[-2]  # speaker folder sits one level above the condition folder
        hp.data.test_dir = dir
        testloader = create_dataloader(hp, args, train=False)
        for batch in testloader:
            # length of batch is 1, set in dataloader
            ref_mel, eliminated_wav, mixed_wav, expected_hidden_wav, eliminated_mag, \
                expected_hidden_mag, mixed_mag, mixed_phase, dvec_path, \
                eliminated_wav_path, mixed_wav_path = batch[0]
            # print("expected_focused: {}".format(expected_focused_wav_path))
            print("Mixed: {}".format(mixed_wav_path))

            # load the separation model and the speaker embedder for every batch
            model = VoiceFilter(hp).cuda()
            chkpt_model = torch.load(args.checkpoint_path, map_location='cuda:0')['model']
            model.load_state_dict(chkpt_model)
            model.eval()

            embedder = SpeechEmbedder(hp).cuda()
            chkpt_embed = torch.load(args.embedder_path)
            embedder.load_state_dict(chkpt_embed)
            embedder.eval()

            audio = Audio(hp)
            # the reference mel unpacked from the batch is recomputed here from dvec_path
            dvec_wav, _ = librosa.load(dvec_path, sr=16000)
            ref_mel = audio.get_mel(dvec_wav)
            ref_mel = torch.from_numpy(ref_mel).float().cuda()
            dvec = embedder(ref_mel)
            dvec = dvec.unsqueeze(0)  # (1, 256)
            mixed_wav, _ = librosa.load(mixed_wav_path, sr=16000)
            mixed_mag, mixed_phase = audio.wav2spec(mixed_wav)
            mixed_mag = torch.from_numpy(mixed_mag).float().cuda()
            mixed_mag = mixed_mag.unsqueeze(0)

            # run the model to estimate the "shadow" spectrogram conditioned on the d-vector
            shadow_mag = model(mixed_mag, dvec)
            # shadow_mag.size() = [1, 301, 601]
            recorded_mag = tensor_normalize(mixed_mag + shadow_mag)

            recorded_mag = recorded_mag[0].cpu().detach().numpy()
            mixed_mag = mixed_mag[0].cpu().detach().numpy()
            shadow_mag = shadow_mag[0].cpu().detach().numpy()
            shadow_wav = audio.spec2wav(shadow_mag, mixed_phase)
            # frequency domain passed back to the time domain (reusing the mixed phase);
            # used for wav signal normalization
            recorded_wav1 = audio.spec2wav(recorded_mag, mixed_phase)  # path 1
            # mixed_wav_path = '/data/our_dataset/test/13/babble/000001-mixed.wav'
            hide1 = mixed_wav_path[:-9] + 'hide1.wav'
            hide2 = mixed_wav_path[:-9] + 'hide2.wav'
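            # For the example path above, mixed_wav_path[:-9] drops the trailing
            # 'mixed.wav', so the outputs would be (illustrative, assuming that layout):
            #   hide1 = '/data/our_dataset/test/13/babble/000001-hide1.wav'
            #   hide2 = '/data/our_dataset/test/13/babble/000001-hide2.wav'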
            # purified3 = os.path.join(args.out_dir, 'result3.wav')
            # the original mixed wav and expected_focused wav are not PCM, so they
            # cannot be read by Google Cloud
            wavio.write(hide1, recorded_wav1, 16000, sampwidth=2)  # mixture + estimated shadow (frequency-domain sum)
            wavio.write(hide2, shadow_wav, 16000, sampwidth=2)     # estimated noise only
            # wavio.write(purified3, enhanced_wav, 16000, sampwidth=2)  # mix + est noise
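
            # A minimal sanity check (a sketch, not part of the original script):
            # wavio can read the file back to confirm it was written as 16-bit PCM
            # at 16 kHz, the format the Google Cloud comment above refers to.
            #   w = wavio.read(hide1)
            #   assert w.rate == 16000 and w.sampwidth == 2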