-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactored train_model.py and updated readme.md, made requirements file and included folders.
- Loading branch information
cryptowooser
committed
Jul 16, 2023
1 parent
0fc0aa0
commit 2f49923
Showing
13 changed files
with
292 additions
and
263 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import os | ||
|
||
def count_letters(directory):
    """Count how often the letters A-I and a-i occur in text files under a directory.

    The directory tree is walked recursively and every file whose name ends
    with one of the recognised text extensions (currently only ``.txt``) is
    read.

    Args:
        directory: Root directory to search.

    Returns:
        A dict whose keys are ``'A'``-``'I'`` followed by ``'a'``-``'i'`` and
        whose values are the number of occurrences across all matching files.
        Letters that never occur map to 0.
    """
    # Upper-case keys first, then lower-case, so iteration order is A-I, a-i.
    count = dict.fromkeys("ABCDEFGHI" + "abcdefghi", 0)

    # str.endswith accepts a tuple, so a single call covers every extension.
    text_extensions = ('.txt',)

    for root, _dirs, files in os.walk(directory):
        for file in files:
            if not file.endswith(text_extensions):
                continue
            # Only ASCII letters are counted, so bytes that are not valid
            # UTF-8 are replaced instead of aborting the whole count; the
            # explicit encoding also keeps results independent of the locale.
            with open(os.path.join(root, file), 'r',
                      encoding='utf-8', errors='replace') as f:
                for line in f:
                    for char in line:
                        if char in count:
                            count[char] += 1

    return count
|
||
def main():
    """Tally the letters A-I in the ``texts`` directory and print a report."""
    directory = 'texts'  # Insert your directory here.
    tallies = count_letters(directory)
    total_count = sum(tallies.values())

    print(f"The letters A through I appear {total_count} times in the text files.")
    print("Detailed count per letter:")
    for letter in tallies:
        count = tallies[letter]
        print(f"{letter}: {count}")


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import torch | ||
from data_classes import LipSyncNet | ||
|
||
# Hyper-parameters the network is rebuilt with. They are defined once so the
# constructor call and the attributes stored on the model cannot drift apart.
model_params = {
    'input_size': 13,
    'hidden_size': 256,
    'output_size': 7,
    'num_layers': 2,
}

# Rebuild the network and restore its weights from the backup state dict.
# map_location='cpu' lets the checkpoint load even if its tensors were saved
# from a device that is not available on this machine.
model = LipSyncNet(**model_params)
model.load_state_dict(
    torch.load('model_full_dataset_2layers_backup.pth', map_location='cpu')
)

# Add the parameters as attributes so they travel with the pickled model.
for param_name, param_value in model_params.items():
    setattr(model, param_name, param_value)

# Save the entire model, not just the state dictionary
torch.save(model, 'model_full_dataset_2layers.pth')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import os | ||
import librosa | ||
from scipy.io.wavfile import write | ||
|
||
input_dir = 'newflacs'
output_dir = 'newwavs'
target_sr = 44100  # every output file is written at 44.1 kHz

os.makedirs(output_dir, exist_ok=True)

# Walk the input directory tree and convert every .flac file found.
# NOTE(review): outputs are written flat into output_dir using only the file
# name, so same-named files from different sub-directories overwrite each other.
for subdir, dirs, files in os.walk(input_dir):
    for file in files:
        if not file.endswith('.flac'):
            continue

        filepath = os.path.join(subdir, file)

        # sr=None keeps the file's native sample rate so it can be checked.
        y, sr = librosa.load(filepath, sr=None)

        # Resample only when the source is not already at the target rate.
        if sr != target_sr:
            y_resampled = librosa.resample(y=y, orig_sr=sr, target_sr=target_sr)
        else:
            y_resampled = y

        # Swap only the trailing extension; str.replace would also rewrite
        # any '.flac' that appears earlier in the file name.
        wav_name = file[:-len('.flac')] + '.wav'
        output_file_path = os.path.join(output_dir, wav_name)
        write(output_file_path, target_sr, y_resampled)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from torchsummary import summary | ||
import torch | ||
from torch.utils.data import Dataset, DataLoader | ||
from scipy.io import wavfile | ||
import numpy as np | ||
import os | ||
import librosa | ||
import matplotlib.pyplot as plt | ||
import glob | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from sklearn.metrics import confusion_matrix | ||
import seaborn as sns | ||
import matplotlib.pyplot as plt | ||
|
||
# Hyper-parameters used to build the network.
input_size = 13
hidden_size = 256
output_size = 7
print(f"Output Size: {output_size}")
num_layers = 3
# Use the GPU when one is present, otherwise fall back to the CPU so the
# script still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||
class LipSyncNet(nn.Module):
    """Conv1d front end, bidirectional LSTM, then a per-time-step linear layer.

    Args:
        input_size: Number of input features per time step (the channels
            seen by the convolution).
        hidden_size: Number of convolution output channels and the LSTM
            hidden size.
        output_size: Number of values predicted for each time step.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super().__init__()
        # kernel_size=3 with padding=1 keeps the sequence length unchanged.
        self.conv1 = nn.Conv1d(input_size, hidden_size, kernel_size=3, padding=1)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers,
                            batch_first=True, bidirectional=True, dropout=0.2)
        # The LSTM is bidirectional, so it emits 2 * hidden_size features.
        self.fc = nn.Linear(hidden_size * 2, output_size)
        self.num_layers = num_layers
        self.hidden_size = hidden_size

    def forward(self, x):
        """Return predictions for every time step.

        Args:
            x: Tensor of shape (batch, seq_len, input_size).

        Returns:
            Tensor of shape (batch, seq_len, output_size).
        """
        # Conv1d wants (batch, channels, seq_len).
        x = x.permute(0, 2, 1)
        x = F.relu(self.conv1(x))

        # Back to (batch, seq_len, channels) for the batch_first LSTM.
        x = x.permute(0, 2, 1)

        # No initial state is passed: nn.LSTM then starts from zeros that
        # match the input's dtype and device. Building float32 zeros by hand
        # breaks as soon as the model runs in another precision.
        out, _ = self.lstm(x)

        # Pass through FC layer to get predictions for each time step
        return self.fc(out)
|
||
# Build the network and move it to the configured device.
model = LipSyncNet(input_size, hidden_size, output_size, num_layers).to(device)

# torchsummary takes the per-sample input shape as a tuple without the batch
# dimension. The model consumes (seq_len, features) per sample, so use a
# sequence length of 1000 with `input_size` MFCC features per step. The device
# is passed too, so the dummy input is created where the model lives.
# NOTE(review): torchsummary may not cope with nn.LSTM's nested tuple output;
# confirm this runs end to end.
summary(model, input_size=(1000, input_size), device=device)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import cProfile

import inference  # Import your inference.py module

# Run inference.main() under the profiler and write the raw statistics to
# profile_data.prof (readable with the standard-library pstats module).
cProfile.runctx('inference.main()', globals(), locals(), filename='profile_data.prof')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import torch | ||
from data_classes import LipSyncNet | ||
|
||
# Load the fully pickled model (architecture and weights).
model = torch.load('model_full_dataset_2layers.pth')

# Report the hyper-parameters stored on the model, one per line.
for attr_name in ('input_size', 'hidden_size', 'output_size', 'num_layers'):
    print(f'{attr_name}: {getattr(model, attr_name)}')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import wave | ||
|
||
def get_sample_rate(filename):
    """Return the sampling frequency, in Hz, stored in a WAV file's header."""
    with wave.open(filename, mode='rb') as reader:
        frame_rate = reader.getframerate()
    return frame_rate
|
||
# Quick manual check against one known recording.
sample_rate = get_sample_rate('wavs/003.wav')
print(f'Sample rate: {sample_rate} Hz')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import os | ||
import random | ||
from datetime import datetime | ||
|
||
def _read_words(path):
    """Return the non-blank lines of ``path`` with surrounding whitespace removed."""
    with open(path, 'r') as f:
        stripped = (line.strip() for line in f)
        return [word for word in stripped if word]


def create_run_name(nouns_file, adjectives_file):
    """Build a timestamped run name of the form ``lipsync_<adjective>_<noun>_<time>``.

    Args:
        nouns_file: Path to a text file with one noun per line.
        adjectives_file: Path to a text file with one adjective per line.

    Returns:
        A string such as ``lipsync_quick_fox_20230716-142530``. The timestamp
        is the current local time formatted as ``%Y%m%d-%H%M%S``.

    Raises:
        IndexError: If either file contains no words.
    """
    # Blank lines are dropped so they can never be picked as a word, which
    # would otherwise yield names with an empty segment.
    nouns = _read_words(nouns_file)
    adjectives = _read_words(adjectives_file)

    # Randomly select a noun and adjective (noun first, as before, so seeded
    # runs keep drawing in the same order).
    noun = random.choice(nouns)
    adjective = random.choice(adjectives)

    # Current date and time, to second resolution.
    now_str = datetime.now().strftime("%Y%m%d-%H%M%S")

    return f"lipsync_{adjective}_{noun}_{now_str}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
librosa==0.10.0.post2 | ||
matplotlib==3.7.2 | ||
numpy==1.24.4 | ||
scikit_learn==1.3.0 | ||
scipy==1.11.1 | ||
seaborn==0.12.2 | ||
torch==2.0.1+cu118 | ||
torchsummary==1.5.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Ignore everything in this directory | ||
* | ||
# Except this file | ||
!.gitignore |
Oops, something went wrong.