From 50f4dea183ecebb0f79a61f43d2891926a5bdb9c Mon Sep 17 00:00:00 2001
From: lhl
Date: Tue, 26 Mar 2024 09:27:28 +0900
Subject: [PATCH] Restructure inference.py around a `LipSyncInference` class

We restructure inference.py to have a class `LipSyncInference` so we can
separate loading and inference. This didn't actually make inference faster on
its own, so we added `timebudget` decorators to spot the bottleneck (in
`preprocess`). From there we shaved off roughly 100ms by using `soundfile` to
load the file and `scipy.signal` to resample, but it was switching to
`python_speech_features` for the MFCC that cut about 600ms compared to
`librosa` (in the end, we ended up removing all librosa dependencies from
inference). A rough sketch of the benchmark setup is appended after the patch.

Before:

```
load_model took 16.14ms
preprocess took 719.24ms
postprocess took 0.086ms
infer took 813.41ms
```

After:

```
load_model took 16.98ms
preprocess took 11.03ms
postprocess took 0.154ms
infer took 84.22ms
```

We also now use CUDA if available; it is slightly faster for inference.

Added the new libraries to requirements.txt and removed the version pins,
since the old ones were outdated.
---
 inference.py     | 330 +++++++++++++++++++++--------------------
 requirements.txt |  19 +--
 2 files changed, 160 insertions(+), 189 deletions(-)

diff --git a/inference.py b/inference.py
index 5611179..56f7adc 100644
--- a/inference.py
+++ b/inference.py
@@ -22,188 +22,151 @@ python inference.py --wav_file_name my_audio.wav --model_name my_model.pth
 """
-import time
 import argparse
+from data_classes import LipSyncNet
+from python_speech_features import mfcc
+from scipy.signal import resample
+import soundfile
+from timebudget import timebudget
 import torch
-import librosa
-from data_classes import LipSyncNet
-
-TARGET_SR = 44100
-HOP_LENGTH = 512
-
-
-def preprocess(filename):
-    """
-    Load an audio file and extract MFCC features for processing.
-
-    Parameters:
-    filename (str): The path to the audio file.
-
-    Returns:
-    numpy.ndarray: The MFCC features of the audio file.
-    """
-    y, sr = librosa.load(filename, sr=None)
-    if sr != TARGET_SR:
-        y = librosa.resample(y=y, orig_sr=sr, target_sr=TARGET_SR)
-    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13, hop_length=HOP_LENGTH)
-    mfcc = mfcc.T
-    return mfcc
-
-
-def postprocess(output_data):
-    """
-    Get the most likely mouth shape from the tensor.
-
-    Parameters:
-    output_data (tensor): The data output by the model
-    (In the form of likelihoods of each possible mouth shape for each window.)
-
-    Returns:
-    max_likelihood_list: A list with the most likely mouth shape for each window.
-    """
-    max_likelihood = torch.argmax(
-        output_data, dim=2)  # Find the maximum likelihood for each prediction
-    max_likelihood_list = torch.flatten(
-        max_likelihood).tolist()  # Convert the tensor to a list
-    return max_likelihood_list
-
-
-def convert_list_to_RLE(input_list):
-    """
-    Converts the list to a format similar to RunLengthEncoding,
-    where instead of a list that gives the mouth shape for every window,
-    it only stores where the mouth shape changes
-    and what it changes to, which is what we need
-    This also converts the integers generated by the model
-    into the standard Hannah Barbera mouth shapes.
-
-
-    Parameters:
-    input_list (list) : A list of windows and mouthshapes for each window.
-
-    Returns:
-    timestamps_value_pairs (list) : A list of tuples with timestamps
-                                    of when the mouth shape changes,
-                                    and the corresponding new mouth shape.
-    """
-
-    #The length of a single window is 1 second / sample_rate * hop_length
-    window_length = 1 / TARGET_SR * HOP_LENGTH
-    timestamps = [
-        round(index * window_length, 2) for index in range(len(input_list))
-    ]
-
-    timestamps_value_pairs = zip(timestamps, input_list)
-    timestamps_value_pairs = [
-        (t, v) for i, (t, v) in enumerate(timestamps_value_pairs)
-        if i == 0 or v != input_list[i - 1]
-    ]
-
-    for _, mouth_shape in timestamps_value_pairs:
-        if mouth_shape == 7:
-            mouth_shape = 'X'
-        else:
-            mouth_shape = chr(ord('A') + mouth_shape)
-
-    #Now we filter the timestamps
-    timestamps_value_pairs = [
-        (t, v) for i, (t, v) in enumerate(timestamps_value_pairs)
-        if i == 0 or t - timestamps_value_pairs[i - 1][0] > 0.05
-    ]
-    return timestamps_value_pairs
-
-
-def convert_to_aegisub_format(timestamps_value_pairs):
-    """
-    Converts from timestamp_value_pairs format to aegisub format, which is useful for debugging.
-
-    Parameters:
-    timestamps_value_pairs : The times and mouth shapes in timestamp/value pair format.
-
-    Returns:
-    A string with all the lines in aegisub .ass format.
-    """
-    aegisub_lines = []
-    for i in range(len(timestamps_value_pairs)):
-        timestamp, text = timestamps_value_pairs[i]
-        timestamp_str = '{:.2f}'.format(timestamp)
-        if i < len(timestamps_value_pairs) - 1:
-            end_timestamp = '{:.2f}'.format(timestamps_value_pairs[i + 1][0])
-        else:
-            end_timestamp = '{:.2f}'.format(
-                timestamp + 1.0)  # Extend the last line for 1 second
-        aegisub_line = f'Dialogue: 0,{timestamp_str},{end_timestamp},Default,,0,0,0,,{text}'
-        aegisub_lines.append(aegisub_line)
-    return '\n'.join(aegisub_lines)
-
-
-def infer(wav_file_name, model_name):
-    '''
-    Given a wav_file and a model, loads them both and computes the mouth shapes for the wav.
-
-    Parameters:
-    wav_file_name : Name of a wav file that should be loaded. It should be 41khz.
-    model_name : Name of the model you're using.
-
-    Returns:
-    A list of windows and their corresponding mouth shapes.
-    '''
-
-    #Due to the small size of the network and input data,
-    #this gets much better performance on straight CPU.
-    #If your application requires GPU for some reason, simply uncomment the below code.
-    '''
-    if torch.cuda.is_available():
-        device = 'cuda'
-    else:
-        device = 'cpu'
-        print("Warning: CUDA not found, using CPU mode.")
-    '''
-    device = "cpu"
-    # Load the trained model
-    model_with_params = torch.load(model_name)
-
-    # Extract the parameters
-    input_size = model_with_params.input_size
-    hidden_size = model_with_params.hidden_size
-    output_size = model_with_params.output_size
-    num_layers = model_with_params.num_layers
-
-    # Create the model
-    model = LipSyncNet(input_size, hidden_size, output_size, num_layers)
-
-    # Load the model weights
-    model.load_state_dict(model_with_params.state_dict())
-    # Preprocess the input data
-    input_data = preprocess(wav_file_name)
-
-    input_data = torch.from_numpy(input_data)
-    input_data = input_data.unsqueeze(0)
-    #Pass both model and input to device
-    model = model.to(device)
-    input_data = input_data.to(device)
-
-
-
-    # Feed the input data through the model
-    start_time = time.time()
-
-    with torch.no_grad():
-        output_data = model(input_data)
-
-    elapsed_time = time.time() - start_time
-    print(f"Time for inference: {elapsed_time}")
-    # Postprocess the output data
-    prediction = postprocess(output_data)
-    return prediction
-
-
-def main(wav_file_name, model_name):
-    output = infer(wav_file_name=wav_file_name, model_name=model_name)
-    rle = convert_list_to_RLE(output)
-    int_to_letter = {i: chr(i + 65) for i in range(7)}
-    for item1, item2 in rle:
-        print(f"{item1}:{int_to_letter[item2]}")
+
+
+class LipSyncInference:
+    def __init__(self, model_name):
+        self.target_sr = 44100
+        self.hop_length = 512
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model = self.load_model(model_name)
+
+    # @timebudget
+    def load_model(self, model_name):
+        model_with_params = torch.load(model_name)
+        input_size = model_with_params.input_size
+        hidden_size = model_with_params.hidden_size
+        output_size = model_with_params.output_size
+        num_layers = model_with_params.num_layers
+        model = LipSyncNet(input_size, hidden_size, output_size, num_layers)
+        model.load_state_dict(model_with_params.state_dict())
+        model = model.to(self.device)
+
+        # Apply dynamic quantization - this doesn't make things much faster
+        # model = torch.quantization.quantize_dynamic(
+        #     model, {torch.nn.Linear}, dtype=torch.qint8
+        # )
+
+        return model

+    # @timebudget
+    def preprocess(self, filename):
+        """
+        Load an audio file and extract MFCC features for processing.
+
+        Parameters:
+        filename (str): The path to the audio file.
+
+        Returns:
+        numpy.ndarray: The MFCC features of the audio file.
+        """
+        y, sr = soundfile.read(filename)
+        if sr != self.target_sr:
+            y = resample(y, int(len(y) * self.target_sr / sr))
+        winstep = self.hop_length / self.target_sr
+        mfcc_features = mfcc(y, samplerate=self.target_sr, numcep=13, nfft=2048, winlen=0.025, winstep=winstep)
+        return mfcc_features
+
+    # @timebudget
+    def postprocess(self, output_data):
+        """
+        Get the most likely mouth shape from the tensor.
+
+        Parameters:
+        output_data (tensor): The data output by the model
+        (In the form of likelihoods of each possible mouth shape for each window.)
+
+        Returns:
+        max_likelihood_list: A list with the most likely mouth shape for each window.
+        """
+        max_likelihood = torch.argmax(output_data, dim=2)
+        max_likelihood_list = torch.flatten(max_likelihood).tolist()
+        return max_likelihood_list
+
+    # @timebudget
+    def convert_list_to_RLE(self, input_list):
+        """
+        Converts the list to a format similar to RunLengthEncoding,
+        where instead of a list that gives the mouth shape for every window,
+        it only stores where the mouth shape changes
+        and what it changes to, which is what we need.
+        This also converts the integers generated by the model
+        into the standard Hanna-Barbera mouth shapes.
+
+
+        Parameters:
+        input_list (list) : A list of windows and mouth shapes for each window.
+
+        Returns:
+        timestamps_value_pairs (list) : A list of tuples with timestamps
+                                        of when the mouth shape changes,
+                                        and the corresponding new mouth shape.
+        """
+        window_length = 1 / self.target_sr * self.hop_length
+        timestamps = [round(index * window_length, 2) for index in range(len(input_list))]
+        timestamps_value_pairs = zip(timestamps, input_list)
+        timestamps_value_pairs = [(t, v) for i, (t, v) in enumerate(timestamps_value_pairs)
+                                  if i == 0 or v != input_list[i - 1]]
+        for _, mouth_shape in timestamps_value_pairs:
+            if mouth_shape == 7:
+                mouth_shape = 'X'
+            else:
+                mouth_shape = chr(ord('A') + mouth_shape)
+        timestamps_value_pairs = [(t, v) for i, (t, v) in enumerate(timestamps_value_pairs)
+                                  if i == 0 or t - timestamps_value_pairs[i - 1][0] > 0.05]
+        return timestamps_value_pairs
+
+    # @timebudget
+    def convert_to_aegisub_format(self, timestamps_value_pairs):
+        """
+        Converts from timestamp_value_pairs format to aegisub format, which is useful for debugging.
+
+        Parameters:
+        timestamps_value_pairs : The times and mouth shapes in timestamp/value pair format.
+
+        Returns:
+        A string with all the lines in aegisub .ass format.
+        """
+        aegisub_lines = []
+        for i in range(len(timestamps_value_pairs)):
+            timestamp, text = timestamps_value_pairs[i]
+            timestamp_str = '{:.2f}'.format(timestamp)
+            if i < len(timestamps_value_pairs) - 1:
+                end_timestamp = '{:.2f}'.format(timestamps_value_pairs[i + 1][0])
+            else:
+                end_timestamp = '{:.2f}'.format(timestamp + 1.0)
+            aegisub_line = f'Dialogue: 0,{timestamp_str},{end_timestamp},Default,,0,0,0,,{text}'
+            aegisub_lines.append(aegisub_line)
+        return '\n'.join(aegisub_lines)
+
+    # @timebudget
+    def infer(self, wav_file_name):
+        """
+        Given a wav file, computes the mouth shapes for it using the loaded model.
+
+        Parameters:
+        wav_file_name : Name of the wav file to load. It will be resampled
+                        to 44.1kHz if necessary.
+
+        Returns:
+        A list of windows and their corresponding mouth shapes.
+        """
+        input_data = self.preprocess(wav_file_name)
+        input_data = torch.from_numpy(input_data)
+        input_data = input_data.unsqueeze(0)
+        input_data = input_data.float()  # Cast input data to float32
+        input_data = input_data.to(self.device)
+        with torch.no_grad():
+            output_data = self.model(input_data)
+        prediction = self.postprocess(output_data)
+        return prediction
 
 
 if __name__ == '__main__':
@@ -220,4 +183,9 @@ def main(wav_file_name, model_name):
                         help='The name of the model file.')
     args = parser.parse_args()
 
-    main(wav_file_name=args.wav_file_name, model_name=args.model_name)
+    inference = LipSyncInference(model_name=args.model_name)
+    output = inference.infer(wav_file_name=args.wav_file_name)
+    rle = inference.convert_list_to_RLE(output)
+    int_to_letter = {i: chr(i + 65) for i in range(7)}
+    for item1, item2 in rle:
+        print(f"{item1}:{int_to_letter[item2]}")
diff --git a/requirements.txt b/requirements.txt
index 83d4415..ceda192 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,11 @@
-librosa==0.10.0.post2
-matplotlib==3.7.2
-numpy==1.24.4
-scikit_learn==1.3.0
-scipy==1.11.1
-seaborn==0.12.2
-torch==2.0.1+cu118
-torchsummary==1.5.1
+librosa
+matplotlib
+numpy
+python_speech_features
+scikit_learn
+scipy
+seaborn
+soundfile
+timebudget
+torch
+torchsummary
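For reference (this note is not part of the patch itself): the timings quoted in the commit message come from `timebudget` decorators around each stage. Below is a minimal, illustrative sketch of how the old librosa MFCC path can be compared against the new soundfile + scipy.signal + python_speech_features path on a single file; `my_audio.wav` is a placeholder filename, and the MFCC parameters simply mirror the values used in inference.py.

```
# Rough benchmark sketch: librosa MFCC vs. python_speech_features MFCC.
# Assumes "my_audio.wav" exists; parameters mirror inference.py.
from python_speech_features import mfcc
from scipy.signal import resample
from timebudget import timebudget
import librosa
import soundfile

TARGET_SR = 44100
HOP_LENGTH = 512


@timebudget
def mfcc_librosa(filename):
    y, sr = librosa.load(filename, sr=None)
    if sr != TARGET_SR:
        y = librosa.resample(y=y, orig_sr=sr, target_sr=TARGET_SR)
    # (13, n_frames) -> (n_frames, 13), matching the model's input layout
    return librosa.feature.mfcc(y=y, sr=TARGET_SR, n_mfcc=13, hop_length=HOP_LENGTH).T


@timebudget
def mfcc_psf(filename):
    y, sr = soundfile.read(filename)
    if sr != TARGET_SR:
        y = resample(y, int(len(y) * TARGET_SR / sr))
    # winstep is in seconds, so HOP_LENGTH samples at 44.1kHz keeps the frame
    # hop identical to the librosa version above
    return mfcc(y, samplerate=TARGET_SR, numcep=13, nfft=2048,
                winlen=0.025, winstep=HOP_LENGTH / TARGET_SR)


if __name__ == '__main__':
    for _ in range(3):  # timebudget prints a "took ...ms" line per call
        mfcc_librosa('my_audio.wav')
        mfcc_psf('my_audio.wav')
```

Expressing `winstep` in seconds keeps the python_speech_features frame hop aligned with the 512-sample hop the librosa version uses, so both paths produce a comparable number of MFCC frames per file.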