minor updates
JinZr committed Nov 4, 2024
1 parent 53ec156 commit 6e5c3e4
Showing 6 changed files with 329 additions and 58 deletions.
2 changes: 1 addition & 1 deletion .github/scripts/ljspeech/TTS/run-matcha.sh
@@ -56,7 +56,7 @@ function infer() {

curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1

-./matcha/infer.py \
+./matcha/synth.py \
--epoch 1 \
--exp-dir ./matcha/exp \
--tokens data/tokens.txt \
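With the single-text options moved out of infer.py, the CI step above now goes through synth.py; the remaining flags are collapsed in this view. A plausible reconstruction of the full invocation is sketched below. The --vocoder, --input-text, and --output-wav flags are assumptions based on the generator_v1 download above and the options removed from infer.py in this commit, not the verbatim script:

# Hypothetical reconstruction of the collapsed CI step; flags beyond
# --tokens are assumptions, not taken from the actual script.
./matcha/synth.py \
  --epoch 1 \
  --exp-dir ./matcha/exp \
  --tokens data/tokens.txt \
  --vocoder ./generator_v1 \
  --input-text "How are you doing today?" \
  --output-wav ./generated.wav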
182 changes: 133 additions & 49 deletions egs/ljspeech/TTS/matcha/infer.py
@@ -9,14 +9,16 @@

import soundfile as sf
import torch
-from matcha.hifigan.config import v1, v2, v3
-from matcha.hifigan.denoiser import Denoiser
-from matcha.hifigan.models import Generator as HiFiGAN
+import torch.nn as nn
+from hifigan.config import v1, v2, v3
+from hifigan.denoiser import Denoiser
+from hifigan.models import Generator as HiFiGAN
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import LJSpeechTtsDataModule

from icefall.checkpoint import load_checkpoint
-from icefall.utils import AttributeDict
+from icefall.utils import AttributeDict, setup_logger


def get_parser():
@@ -63,24 +65,10 @@ def get_parser():
help="""Path to vocabulary.""",
)

-    parser.add_argument(
-        "--input-text",
-        type=str,
-        required=True,
-        help="The text to generate speech for",
-    )
-
-    parser.add_argument(
-        "--output-wav",
-        type=str,
-        required=True,
-        help="The filename of the wave to save the generated speech",
-    )

return parser


-def load_vocoder(checkpoint_path):
+def load_vocoder(checkpoint_path: Path) -> nn.Module:
+    checkpoint_path = str(checkpoint_path)
if checkpoint_path.endswith("v1"):
h = AttributeDict(v1)
@@ -100,22 +88,30 @@ def load_vocoder(checkpoint_path):
return hifigan


-def to_waveform(mel, vocoder, denoiser):
+def to_waveform(
+    mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module
+) -> torch.Tensor:
audio = vocoder(mel).clamp(-1, 1)
audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
return audio.cpu().squeeze()


-def process_text(text: str, tokenizer):
+def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
    x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
    # Create the inputs on the same device as the model to avoid a CPU/GPU mismatch.
    x = torch.tensor(x, dtype=torch.long, device=device)
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
    return {"x_orig": text, "x": x, "x_lengths": x_lengths}


def synthesise(
-    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
-):
+    model: nn.Module,
+    tokenizer: Tokenizer,
+    n_timesteps: int,
+    text: str,
+    length_scale: float,
+    temperature: float,
+    device: str = "cpu",
+    spks=None,
+) -> dict:
    text_processed = process_text(text, tokenizer, device)
start_t = dt.datetime.now()
output = model.synthesise(
@@ -131,14 +127,102 @@ def synthesise(
return output


+def infer_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    vocoder: nn.Module,
+    denoiser: nn.Module,
+    tokenizer: Tokenizer,
+) -> None:
+    """Decode the dataset.
+    The ground-truth and generated audio pairs are saved to `params.save_wav_dir`.
+    Args:
+      dl:
+        PyTorch dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      vocoder:
+        The vocoder used to turn mel-spectrograms into waveforms.
+      denoiser:
+        Used to reduce vocoder artifacts in the generated audio.
+      tokenizer:
+        Used to convert text to phonemes.
+    """

+    device = next(model.parameters()).device
+    num_cuts = 0
+    log_interval = 5
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    for batch_idx, batch in enumerate(dl):
+        batch_size = len(batch["tokens"])
+
+        texts = batch["supervisions"]["text"]
+
+        audio = batch["audio"]
+        audio_lens = batch["audio_lens"].tolist()
+        cut_ids = [cut.id for cut in batch["cut"]]
+
+        for i in range(batch_size):
+            output = synthesise(
+                model=model,
+                tokenizer=tokenizer,
+                n_timesteps=params.n_timesteps,
+                text=texts[i],
+                length_scale=params.length_scale,
+                temperature=params.temperature,
+                device=device,
+            )
+            output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
+            sf.write(
+                file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav",
+                data=output["waveform"],
+                samplerate=params.sampling_rate,
+                subtype="PCM_16",
+            )
+            sf.write(
+                file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav",
+                # Trim padding so the saved ground truth matches the cut length.
+                data=audio[i][: audio_lens[i]].numpy(),
+                samplerate=params.sampling_rate,
+                subtype="PCM_16",
+            )

+        num_cuts += batch_size
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")


@torch.inference_mode()
def main():
parser = get_parser()
+    LJSpeechTtsDataModule.add_arguments(parser)
args = parser.parse_args()
-    params = get_params()
    args.exp_dir = Path(args.exp_dir)

+    params = get_params()
params.update(vars(args))

+    params.suffix = f"epoch-{params.epoch}"
+
+    params.res_dir = params.exp_dir / "infer" / params.suffix
+    params.save_wav_dir = params.res_dir / "wav"
+    params.save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+    setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
+    logging.info("Infer started")

device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")

tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
@@ -151,49 +235,49 @@ def main():

params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
params.model_args.data_statistics.mel_std = stats["fbank_std"]

+    # Number of ODE Solver steps
+    params.n_timesteps = 2
+
+    # Changes to the speaking rate
+    params.length_scale = 1.0
+
+    # Sampling temperature
+    params.temperature = 0.667
logging.info(params)

logging.info("About to create model")
model = get_model(params)

if not Path(f"{params.exp_dir}/epoch-{params.epoch}.pt").is_file():
        raise ValueError(f"{params.exp_dir}/epoch-{params.epoch}.pt does not exist")

load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
model.eval()

+    # we need cut ids to organize tts results.
+    args.return_cuts = True
+    ljspeech = LJSpeechTtsDataModule(args)
+
+    test_cuts = ljspeech.test_cuts()
+    test_dl = ljspeech.test_dataloaders(test_cuts)

if not Path(params.vocoder).is_file():
raise ValueError(f"{params.vocoder} does not exist")

    vocoder = load_vocoder(params.vocoder)
-    denoiser = Denoiser(vocoder, mode="zeros")
-
-    # Number of ODE Solver steps
-    n_timesteps = 2
+    vocoder = vocoder.to(device)

-    # Changes to the speaking rate
-    length_scale = 1.0
-
-    # Sampling temperature
-    temperature = 0.667
+    denoiser = Denoiser(vocoder, mode="zeros")
+    denoiser = denoiser.to(device)

-    output = synthesise(
+    infer_dataset(
+        dl=test_dl,
+        params=params,
        model=model,
+        vocoder=vocoder,
+        denoiser=denoiser,
        tokenizer=tokenizer,
-        n_timesteps=n_timesteps,
-        text=params.input_text,
-        length_scale=length_scale,
-        temperature=temperature,
    )
-    output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
-
-    sf.write(params.output_wav, output["waveform"], 22050, "PCM_16")


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

logging.basicConfig(format=formatter, level=logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
main()
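
Since the --input-text / --output-wav path has been removed from infer.py, single-utterance synthesis can still be driven through this module's own functions. Below is a minimal sketch under stated assumptions: the checkpoint path, vocoder path, input text, and the 22050 Hz LJSpeech sample rate are illustrative placeholders, and `params` must be populated exactly as in main() (tokens, vocab size, data statistics) before get_model is called:

# Minimal single-utterance sketch using the functions defined above.
# Assumes this runs in the module namespace with `params` prepared as in
# main(); all paths and the input text are placeholders.
tokenizer = Tokenizer("data/tokens.txt")
model = get_model(params)
load_checkpoint("matcha/exp/epoch-1.pt", model)
model.eval()

vocoder = load_vocoder(Path("generator_v1"))
denoiser = Denoiser(vocoder, mode="zeros")

with torch.inference_mode():
    output = synthesise(
        model=model,
        tokenizer=tokenizer,
        n_timesteps=2,  # number of ODE solver steps
        text="How are you doing today?",
        length_scale=1.0,  # values > 1.0 slow the speech down
        temperature=0.667,  # sampling temperature
    )
    output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)

sf.write("generated.wav", output["waveform"], 22050, "PCM_16")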