Commit
Merge pull request #2 from lhl/main
Inference code should be 10X faster now
cryptowooser authored Mar 26, 2024
2 parents 8d97c84 + 50f4dea commit 42a60d8
Showing 2 changed files with 160 additions and 189 deletions.
330 changes: 149 additions & 181 deletions inference.py
@@ -22,188 +22,151 @@
python inference.py --wav_file_name my_audio.wav --model_name my_model.pth
"""

import time
import argparse
from data_classes import LipSyncNet
from python_speech_features import mfcc
from scipy.signal import resample
import soundfile
from timebudget import timebudget
import torch
import librosa

TARGET_SR = 44100
HOP_LENGTH = 512


def preprocess(filename):
"""
Load an audio file and extract MFCC features for processing.
Parameters:
filename (str): The path to the audio file.
Returns:
numpy.ndarray: The MFCC features of the audio file.
"""
y, sr = librosa.load(filename, sr=None)
if sr != TARGET_SR:
y = librosa.resample(y=y, orig_sr=sr, target_sr=TARGET_SR)
sr = TARGET_SR  # keep sr in sync with the resampled signal
mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13, hop_length=HOP_LENGTH)
mfcc = mfcc.T
return mfcc


def postprocess(output_data):
"""
Get the most likely mouth shape from the tensor.
Parameters:
output_data (tensor): The data output by the model
(In the form of likelihoods of each possible mouth shape for each window.)
Returns:
max_likelihood_list: A list with the most likely mouth shape for each window.
"""
max_likelihood = torch.argmax(
output_data, dim=2) # Find the maximum likelihood for each prediction
max_likelihood_list = torch.flatten(
max_likelihood).tolist() # Convert the tensor to a list
return max_likelihood_list


def convert_list_to_RLE(input_list):
"""
Converts the list to a run-length-encoding-like format:
instead of storing the mouth shape for every window,
it only stores where the mouth shape changes
and what it changes to, which is all we need.
It also converts the integers generated by the model
into the standard Hanna-Barbera mouth-shape letters.
Parameters:
input_list (list) : A list with the mouth shape for each window.
Returns:
timestamps_value_pairs (list) : A list of (timestamp, mouth shape) tuples
giving the time of each mouth-shape change
and the corresponding new mouth shape.
"""

# The length of a single window in seconds is hop_length / sample_rate
window_length = 1 / TARGET_SR * HOP_LENGTH
timestamps = [
round(index * window_length, 2) for index in range(len(input_list))
]

timestamps_value_pairs = zip(timestamps, input_list)
timestamps_value_pairs = [
(t, v) for i, (t, v) in enumerate(timestamps_value_pairs)
if i == 0 or v != input_list[i - 1]
]

# Map the integer classes to Hanna-Barbera letters (0-6 -> 'A'-'G', 7 -> 'X').
timestamps_value_pairs = [
(t, 'X' if v == 7 else chr(ord('A') + v))
for (t, v) in timestamps_value_pairs
]

# Drop changes that occur within 0.05 s of the previous change
timestamps_value_pairs = [
(t, v) for i, (t, v) in enumerate(timestamps_value_pairs)
if i == 0 or t - timestamps_value_pairs[i - 1][0] > 0.05
]
return timestamps_value_pairs


def convert_to_aegisub_format(timestamps_value_pairs):
"""
Converts from timestamp_value_pairs format to aegisub format, which is useful for debugging.
Parameters:
timestamps_value_pairs : The times and mouth shapes in timestamp/value pair format.
Returns:
A string with all the lines in aegisub .ass format.
"""
aegisub_lines = []
for i in range(len(timestamps_value_pairs)):
timestamp, text = timestamps_value_pairs[i]
timestamp_str = '{:.2f}'.format(timestamp)
if i < len(timestamps_value_pairs) - 1:
end_timestamp = '{:.2f}'.format(timestamps_value_pairs[i + 1][0])
else:
end_timestamp = '{:.2f}'.format(
timestamp + 1.0) # Extend the last line for 1 second
aegisub_line = f'Dialogue: 0,{timestamp_str},{end_timestamp},Default,,0,0,0,,{text}'
aegisub_lines.append(aegisub_line)
return '\n'.join(aegisub_lines)


def infer(wav_file_name, model_name):
'''
Given a wav_file and a model, loads them both and computes the mouth shapes for the wav.
Parameters:
wav_file_name : Name of the wav file to load. It is resampled to 44.1 kHz if needed.
model_name : Name of the model you're using.
Returns:
A list of windows and their corresponding mouth shapes.
'''

# Due to the small size of the network and input data,
# this gets much better performance on plain CPU.
# If your application requires GPU for some reason, uncomment the code below
# (and remove the hard-coded device = "cpu" assignment).
'''
if torch.cuda.is_available():
device = 'cuda'
else:
device = 'cpu'
print("Warning: CUDA not found, using CPU mode.")
'''
device = "cpu"
# Load the trained model
model_with_params = torch.load(model_name)

# Extract the parameters
input_size = model_with_params.input_size
hidden_size = model_with_params.hidden_size
output_size = model_with_params.output_size
num_layers = model_with_params.num_layers

# Create the model
model = LipSyncNet(input_size, hidden_size, output_size, num_layers)

# Load the model weights
model.load_state_dict(model_with_params.state_dict())
# Preprocess the input data
input_data = preprocess(wav_file_name)

input_data = torch.from_numpy(input_data)
input_data = input_data.unsqueeze(0)
#Pass both model and input to device
model = model.to(device)
input_data = input_data.to(device)



# Feed the input data through the model
start_time = time.time()

with torch.no_grad():
output_data = model(input_data)

elapsed_time = time.time() - start_time
print(f"Time for inference: {elapsed_time}")
# Postprocess the output data
prediction = postprocess(output_data)
return prediction


def main(wav_file_name, model_name):
output = infer(wav_file_name=wav_file_name, model_name=model_name)
rle = convert_list_to_RLE(output)
for timestamp, mouth_shape in rle:
print(f"{timestamp}:{mouth_shape}")


class LipSyncInference:
def __init__(self, model_name):
self.target_sr = 44100
self.hop_length = 512
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = self.load_model(model_name)

# @timebudget
def load_model(self, model_name):
model_with_params = torch.load(model_name)
input_size = model_with_params.input_size
hidden_size = model_with_params.hidden_size
output_size = model_with_params.output_size
num_layers = model_with_params.num_layers
model = LipSyncNet(input_size, hidden_size, output_size, num_layers)
model.load_state_dict(model_with_params.state_dict())
model = model.to(self.device)

# Apply dynamic quantization - this doesn't make things much faster
# model = torch.quantization.quantize_dynamic(
# model, {torch.nn.Linear}, dtype=torch.qint8
# )

return model

# @timebudget
def preprocess(self, filename):
"""
Load an audio file and extract MFCC features for processing.
Parameters:
filename (str): The path to the audio file.
Returns:
numpy.ndarray: The MFCC features of the audio file.
"""
y, sr = soundfile.read(filename)
if y.ndim > 1:
y = y.mean(axis=1)  # mix multi-channel audio down to mono
if sr != self.target_sr:
y = resample(y, int(len(y) * self.target_sr / sr))
sr = self.target_sr  # keep sr in sync with the resampled signal
winstep = self.hop_length / sr
mfcc_features = mfcc(y, samplerate=sr, numcep=13, nfft=2048, winlen=0.025, winstep=winstep)
return mfcc_features

# @timebudget
def postprocess(self, output_data):
"""
Get the most likely mouth shape from the tensor.
Parameters:
output_data (tensor): The data output by the model
(In the form of likelihoods of each possible mouth shape for each window.)
Returns:
max_likelihood_list: A list with the most likely mouth shape for each window.
"""
max_likelihood = torch.argmax(output_data, dim=2)
max_likelihood_list = torch.flatten(max_likelihood).tolist()
return max_likelihood_list

# @timebudget
def convert_list_to_RLE(self, input_list):
"""
Converts the list to a run-length-encoding-like format:
instead of storing the mouth shape for every window,
it only stores where the mouth shape changes
and what it changes to, which is all we need.
It also converts the integers generated by the model
into the standard Hanna-Barbera mouth-shape letters.
Parameters:
input_list (list) : A list with the mouth shape for each window.
Returns:
timestamps_value_pairs (list) : A list of (timestamp, mouth shape) tuples
giving the time of each mouth-shape change
and the corresponding new mouth shape.
"""
window_length = 1 / self.target_sr * self.hop_length
timestamps = [round(index * window_length, 2) for index in range(len(input_list))]
timestamps_value_pairs = zip(timestamps, input_list)
timestamps_value_pairs = [(t, v) for i, (t, v) in enumerate(timestamps_value_pairs)
if i == 0 or v != input_list[i - 1]]
# Map the integer classes to Hanna-Barbera letters (0-6 -> 'A'-'G', 7 -> 'X').
timestamps_value_pairs = [
(t, 'X' if v == 7 else chr(ord('A') + v))
for (t, v) in timestamps_value_pairs
]
timestamps_value_pairs = [(t, v) for i, (t, v) in enumerate(timestamps_value_pairs)
if i == 0 or t - timestamps_value_pairs[i - 1][0] > 0.05]
return timestamps_value_pairs
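# Illustrative shape of the result (hypothetical values):
#   [(0.0, 'A'), (0.45, 'D'), (0.61, 'X')]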

# @timebudget
def convert_to_aegisub_format(self, timestamps_value_pairs):
"""
Converts from timestamp_value_pairs format to aegisub format, which is useful for debugging.
Parameters:
timestamps_value_pairs : The times and mouth shapes in timestamp/value pair format.
Returns:
A string with all the lines in aegisub .ass format.
"""
aegisub_lines = []
for i in range(len(timestamps_value_pairs)):
timestamp, text = timestamps_value_pairs[i]
timestamp_str = '{:.2f}'.format(timestamp)
if i < len(timestamps_value_pairs) - 1:
end_timestamp = '{:.2f}'.format(timestamps_value_pairs[i + 1][0])
else:
end_timestamp = '{:.2f}'.format(timestamp + 1.0)
aegisub_line = f'Dialogue: 0,{timestamp_str},{end_timestamp},Default,,0,0,0,,{text}'
aegisub_lines.append(aegisub_line)
return '\n'.join(aegisub_lines)
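# Example of a generated line (hypothetical values):
#   Dialogue: 0,0.23,0.29,Default,,0,0,0,,D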

# @timebudget
def infer(self, wav_file_name):
"""
Compute the mouth shapes for a wav file using the model loaded at construction time.
Parameters:
wav_file_name : Name of the wav file to load. It is resampled to 44.1 kHz if needed.
Returns:
A list of windows and their corresponding mouth shapes.
"""
input_data = self.preprocess(wav_file_name)
input_data = torch.from_numpy(input_data)
input_data = input_data.unsqueeze(0)
input_data = input_data.float() # Cast input data to float32
input_data = input_data.to(self.device)
with torch.no_grad():
output_data = self.model(input_data)
prediction = self.postprocess(output_data)
return prediction


if __name__ == '__main__':
@@ -220,4 +183,9 @@ def main(wav_file_name, model_name):
help='The name of the model file.')
args = parser.parse_args()

main(wav_file_name=args.wav_file_name, model_name=args.model_name)
inference = LipSyncInference(model_name=args.model_name)
output = inference.infer(wav_file_name=args.wav_file_name)
rle = inference.convert_list_to_RLE(output)
for timestamp, mouth_shape in rle:
print(f"{timestamp}:{mouth_shape}")
19 changes: 11 additions & 8 deletions requirements.txt
@@ -1,8 +1,11 @@
librosa==0.10.0.post2
matplotlib==3.7.2
numpy==1.24.4
scikit_learn==1.3.0
scipy==1.11.1
seaborn==0.12.2
torch==2.0.1+cu118
torchsummary==1.5.1
librosa
matplotlib
numpy
python_speech_features
scikit_learn
scipy
seaborn
soundfile
timebudget
torch
torchsummary
