diff --git a/analyzer/auth.py b/analyzer/auth.py
index a632791..9d32993 100644
--- a/analyzer/auth.py
+++ b/analyzer/auth.py
@@ -12,10 +12,6 @@
 # Password hashing
 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
 
-# Token model
-class TokenData(BaseModel):
-    username: str | None = None
-
 # Users and hashed passwords
 fake_users_db = {
     "testuser": {
@@ -24,23 +20,33 @@ class TokenData(BaseModel):
     }
 }
 
+
+# Token model
+class TokenData(BaseModel):
+    username: str | None = None
+
+
 oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
 
+
 def verify_password(plain_password, hashed_password):
     return pwd_context.verify(plain_password, hashed_password)
 
+
 def get_user(db, username: str):
     if username in db:
         user_dict = db[username]
         return user_dict
     return None
 
+
 def authenticate_user(username: str, password: str):
     user = get_user(fake_users_db, username)
     if not user or not verify_password(password, user["hashed_password"]):
         return False
     return user
 
+
 def create_access_token(data: dict, expires_delta: timedelta | None = None):
     to_encode = data.copy()
     if expires_delta:
@@ -51,6 +57,7 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None):
     encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
     return encoded_jwt
 
+
 def get_current_user(token: str = Depends(oauth2_scheme)):
     credentials_exception = HTTPException(
         status_code=status.HTTP_401_UNAUTHORIZED,
diff --git a/analyzer/main.py b/analyzer/main.py
index 945ee5c..b5784a9 100644
--- a/analyzer/main.py
+++ b/analyzer/main.py
@@ -2,9 +2,13 @@
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 from datetime import timedelta
-from model import make_prediction
-from mongodb.save_results import save_prediction, get_all_predictions
-from auth import create_access_token, authenticate_user, get_current_user  # Authentication functions
+from model import make_prediction
+from mongodb.save_results import save_prediction, get_all_predictions
+from auth import (
+    create_access_token,
+    authenticate_user,
+    get_current_user,
+)
 
 app = FastAPI()
 
@@ -25,14 +29,17 @@
     "last_trained": "2024-09-28",
 }
 
+
 # Define the request models
 class TextInput(BaseModel):
     text: str
 
+
 class UserLogin(BaseModel):
     username: str
     password: str
 
+
 @app.post("/token")
 def login_for_access_token(form_data: UserLogin):
     # Endpoint to authenticate a user and generate a JWT token.
@@ -48,6 +55,7 @@ def login_for_access_token(form_data: UserLogin):
     )
     return {"access_token": access_token, "token_type": "bearer"}
 
+
 @app.post("/predict", dependencies=[Depends(get_current_user)])
 def predict_text(input: TextInput):
     # Endpoint to make a single prediction.
@@ -55,6 +63,7 @@ def predict_text(input: TextInput):
     prediction = make_prediction(input.text)
     save_prediction(input.text, prediction)  # Save the prediction to MongoDB
     return {"prediction": prediction}
 
+
 @app.get("/predictions", dependencies=[Depends(get_current_user)])
 def get_predictions():
     # Endpoint to retrieve all past predictions from the database.
@@ -63,6 +72,7 @@ def get_predictions():
     predictions = get_all_predictions()
     if not predictions:
         raise HTTPException(status_code=404, detail="No predictions found.")
     return predictions
 
+
 @app.get("/model_metadata")
 def model_metadata():
     # Endpoint to retrieve metadata about the model.
diff --git a/analyzer/model.py b/analyzer/model.py
index ba87cd2..532f7c7 100644
--- a/analyzer/model.py
+++ b/analyzer/model.py
@@ -1,6 +1,6 @@
 import joblib
 import os
-from preprocess import preprocess_text  # Assuming you have preprocessing functions in preprocess.py
+from preprocess import preprocess_text
 
 MODEL_PATH = os.path.join("results", "text_classifier.joblib")
 VECTORIZER_PATH = os.path.join("results", "tfidf_vectorizer.joblib")
diff --git a/analyzer/preprocess.py b/analyzer/preprocess.py
index 10bb83f..fa144da 100644
--- a/analyzer/preprocess.py
+++ b/analyzer/preprocess.py
@@ -3,15 +3,8 @@
 nlp = spacy.load("en_core_web_sm")
 
 def preprocess_text(text):
-    """
-    Function to preprocess the input text by tokenizing,
-    removing stop words, and lemmatizing.
-    """
-    # Tokenize and process the text us spacy
+    # Tokenize and process the text using spacy
     doc = nlp(text)
 
-    # Filter tokens: Remove stopwords and puctiation, then lemmatize
     clean_tokens = [token.lemma_ for token in doc if not token.is_stop and not token.is_punct]
-
     return ' '.join(clean_tokens)
-
diff --git a/analyzer/requirements.txt b/analyzer/requirements.txt
index d708883..aaa3de9 100644
--- a/analyzer/requirements.txt
+++ b/analyzer/requirements.txt
@@ -8,6 +8,5 @@ pymongo
 pytest
 httpx
 python-jose
-jose
 passlib
 bandit
diff --git a/analyzer/results/tfidf_vectorizer.joblib b/analyzer/results/tfidf_vectorizer.joblib
index a244e49..5c5250a 100644
Binary files a/analyzer/results/tfidf_vectorizer.joblib and b/analyzer/results/tfidf_vectorizer.joblib differ
diff --git a/analyzer/test_main.py b/analyzer/test_main.py
index e247cdf..292eafa 100644
--- a/analyzer/test_main.py
+++ b/analyzer/test_main.py
@@ -1,5 +1,3 @@
-# tests/test_main.py
-import pytest
 from fastapi.testclient import TestClient
 from main import app
 
@@ -15,3 +13,4 @@ def test_read_model_metadata():
 def test_predict_unauthorized():
     response = client.post("/predict", json={"text": "sample text"})
     assert response.status_code == 401
+
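Reviewer note, not part of the patch: a minimal sketch of how the new JWT flow introduced above can be exercised end to end, using the same TestClient that test_main.py already imports. The plaintext password for "testuser" is an assumption (it must match whatever value was bcrypt-hashed into fake_users_db), and a real run also assumes the model artifacts and MongoDB are available.

# check_auth_flow.py -- hypothetical local smoke test, not included in this diff
from fastapi.testclient import TestClient

from main import app

client = TestClient(app)

# Obtain a token from the /token endpoint; UserLogin is a JSON body model,
# so credentials go in the request body rather than form fields.
resp = client.post(
    "/token",
    json={"username": "testuser", "password": "secret"},  # password is assumed
)
token = resp.json()["access_token"]

# Call the protected /predict endpoint with the bearer token that
# get_current_user (via OAuth2PasswordBearer) expects in the header.
resp = client.post(
    "/predict",
    json={"text": "sample text"},
    headers={"Authorization": f"Bearer {token}"},
)
print(resp.status_code, resp.json())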