Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
13MrBlackCat13 authored Nov 9, 2023
1 parent 9fbc720 commit 9738c72
Showing 1 changed file with 18 additions and 27 deletions.
45 changes: 18 additions & 27 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from sklearn.model_selection import train_test_split
import joblib
import numpy as np
import logging
from joblib import Parallel, delayed

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def load_data(data_dir):
categories = os.listdir(data_dir)
Expand All @@ -29,7 +32,7 @@ def load_data(data_dir):
messages.append(f.read())
categories_list.append(category)
except UnicodeError:
print(f"Error decoding file: {os.path.join(data_dir, category, message)}")
logging.error(f"Error decoding file: {os.path.join(data_dir, category, message)}")

data = pd.DataFrame({'text': messages, 'category': categories_list})
return data
Expand Down Expand Up @@ -79,40 +82,28 @@ def main():
X_train_vectorized_emails = vectorizer_emails.fit_transform(X_train_emails)
X_test_vectorized_emails = vectorizer_emails.transform(X_test_emails)

models_to_train = []
models_to_train = ['lr', 'svm', 'rf', 'gb'] if args.train == 'all' else [args.train]

if args.train == 'all':
models_to_train = ['lr', 'svm', 'rf', 'gb']
else:
models_to_train = [args.train]

for model_name in models_to_train:
if model_name == 'lr':
model_text = train_model('lr', X_train_vectorized_text, y_train_text)
model_emails = train_model('lr', X_train_vectorized_emails, y_train_emails)
elif model_name == 'svm':
model_text = train_model('svm', X_train_vectorized_text, y_train_text)
model_emails = train_model('svm', X_train_vectorized_emails, y_train_emails)
elif model_name == 'rf':
model_text = train_model('rf', X_train_vectorized_text, y_train_text)
model_emails = train_model('rf', X_train_vectorized_emails, y_train_emails)
elif model_name == 'gb':
model_text = train_model('gb', X_train_vectorized_text, y_train_text)
model_emails = train_model('gb', X_train_vectorized_emails, y_train_emails)
else:
print(f"Invalid model name: {model_name}")
continue
if not os.path.exists('models'):
os.makedirs('models')

def train_and_save(model_name):
    """Train, evaluate and persist the text and e-mail models for one algorithm.

    This helper is a closure: it reads the vectorized train/test matrices,
    the label collections and both fitted vectorizers from the enclosing
    scope, and calls ``train_model`` / ``test_model`` which are defined
    elsewhere in this file.

    Args:
        model_name: Identifier forwarded to ``train_model`` (the caller
            passes values such as 'lr', 'svm', 'rf' or 'gb'). It is also
            upper-cased for the log messages and embedded in the output
            file names.

    Side effects:
        Logs one accuracy line per data set and writes four pickle files
        under ``models/``. The directory is not created here; the caller
        is expected to have created it beforehand.
    """
    model_text = train_model(model_name, X_train_vectorized_text, y_train_text)
    model_emails = train_model(model_name, X_train_vectorized_emails, y_train_emails)

    accuracy_text = test_model(model_text, X_test_vectorized_text, y_test_text)
    accuracy_emails = test_model(model_emails, X_test_vectorized_emails, y_test_emails)

    # Report each accuracy exactly once, through logging only, so the output
    # carries the timestamp/level format configured at module import.
    # NOTE(review): when this runs inside joblib.Parallel with a
    # process-based backend, the parent's logging configuration may not be
    # applied in the workers -- confirm these INFO lines actually appear.
    logging.info(f'{model_name.upper()} Accuracy for Text: {accuracy_text:.2f}')
    logging.info(f'{model_name.upper()} Accuracy for Emails: {accuracy_emails:.2f}')

    # The vectorizers are stored next to every model (same objects, one copy
    # per model name) so each model file has a matching vectorizer file.
    joblib.dump(model_text, f'models/text-model-{model_name}.pkl')
    joblib.dump(model_emails, f'models/emails-model-{model_name}.pkl')
    joblib.dump(vectorizer_text, f'models/text-vectorizer-{model_name}.pkl')
    joblib.dump(vectorizer_emails, f'models/emails-vectorizer-{model_name}.pkl')

print("Models trained and saved successfully")
Parallel(n_jobs=-1)(delayed(train_and_save)(model_name) for model_name in models_to_train)
logging.info("Models trained and saved successfully")

if __name__ == '__main__':
    # Run the training pipeline exactly once, and only when this file is
    # executed as a script (not when it is imported).
    main()

0 comments on commit 9738c72

Please sign in to comment.