Commit

Corrected spelling
bpiow committed Oct 1, 2024
1 parent 9b15104 commit 6626bdd
Showing 7 changed files with 27 additions and 19 deletions.
15 changes: 11 additions & 4 deletions analyzer/auth.py
@@ -12,10 +12,6 @@
# Password hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Token model
class TokenData(BaseModel):
username: str | None = None

# Users and hashed passwords
fake_users_db = {
"testuser": {
@@ -24,23 +20,33 @@ class TokenData(BaseModel):
}
}


# Token model
class TokenData(BaseModel):
username: str | None = None


oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)


def get_user(db, username: str):
if username in db:
user_dict = db[username]
return user_dict
return None


def authenticate_user(username: str, password: str):
user = get_user(fake_users_db, username)
if not user or not verify_password(password, user["hashed_password"]):
return False
return user


def create_access_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
@@ -51,6 +57,7 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None):
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt


def get_current_user(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
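The auth.py diff is truncated at this point. For context, a minimal sketch of how get_current_user is typically completed with python-jose (pinned in requirements.txt below) is shown here; the decode logic and claim names are assumptions, not necessarily the repository's exact code.

# Hypothetical completion of get_current_user, assuming python-jose and the
# SECRET_KEY / ALGORITHM / TokenData / get_user names defined in this module.
from jose import JWTError, jwt

def get_current_user(token: str = Depends(oauth2_scheme)):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # Decode the JWT and read the username from the "sub" claim.
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username = payload.get("sub")
        if username is None:
            raise credentials_exception
        token_data = TokenData(username=username)
    except JWTError:
        raise credentials_exception
    # Look the user up in the in-memory user store.
    user = get_user(fake_users_db, username=token_data.username)
    if user is None:
        raise credentials_exception
    return user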
16 changes: 13 additions & 3 deletions analyzer/main.py
@@ -2,9 +2,13 @@
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from datetime import timedelta
from model import make_prediction
from mongodb.save_results import save_prediction, get_all_predictions
from auth import create_access_token, authenticate_user, get_current_user # Authentication functions
from model import make_prediction
from mongodb.save_results import save_prediction, get_all_predictions
from auth import (
create_access_token,
authenticate_user,
get_current_user,
)

app = FastAPI()

@@ -25,14 +29,17 @@
"last_trained": "2024-09-28",
}


# Define the request models
class TextInput(BaseModel):
text: str


class UserLogin(BaseModel):
username: str
password: str


@app.post("/token")
def login_for_access_token(form_data: UserLogin):
# Endpoint to authenticate a user and generate a JWT token.
@@ -48,13 +55,15 @@ def login_for_access_token(form_data: UserLogin):
)
return {"access_token": access_token, "token_type": "bearer"}


@app.post("/predict", dependencies=[Depends(get_current_user)])
def predict_text(input: TextInput):
# Endpoint to make a single prediction.
prediction = make_prediction(input.text)
save_prediction(input.text, prediction) # Save the prediction to MongoDB
return {"prediction": prediction}


@app.get("/predictions", dependencies=[Depends(get_current_user)])
def get_predictions():
# Endpoint to retrieve all past predictions from the database.
@@ -63,6 +72,7 @@ def get_predictions():
raise HTTPException(status_code=404, detail="No predictions found.")
return predictions


@app.get("/model_metadata")
def model_metadata():
# Endpoint to retrieve metadata about the model.
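As a usage note, the login-then-predict flow defined above can be exercised with a short client script. The base URL and password below are placeholders, not values taken from this repository.

# Hypothetical client for the endpoints above; URL and password are placeholders.
import httpx

BASE_URL = "http://localhost:8000"  # placeholder

# 1. Obtain a bearer token from /token (UserLogin is sent as a JSON body).
login = httpx.post(f"{BASE_URL}/token",
                   json={"username": "testuser", "password": "example-password"})
login.raise_for_status()
token = login.json()["access_token"]

# 2. Call the protected /predict endpoint with the token.
headers = {"Authorization": f"Bearer {token}"}
pred = httpx.post(f"{BASE_URL}/predict", json={"text": "sample text"}, headers=headers)
print(pred.json())  # {"prediction": ...}

# 3. Fetch all stored predictions.
print(httpx.get(f"{BASE_URL}/predictions", headers=headers).json())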
2 changes: 1 addition & 1 deletion analyzer/model.py
@@ -1,6 +1,6 @@
import joblib
import os
from preprocess import preprocess_text # Assuming you have preprocessing functions in preprocess.py
from preprocess import preprocess_text

MODEL_PATH = os.path.join("results", "text_classifier.joblib")
VECTORIZER_PATH = os.path.join("results", "tfidf_vectorizer.joblib")
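Only the top of model.py appears in this diff. A minimal sketch of a make_prediction built on the two joblib artifacts, assuming a scikit-learn TF-IDF vectorizer and classifier, might look like this; it is illustrative, not the repository's exact code.

# Hypothetical sketch of make_prediction, assuming scikit-learn artifacts
# saved with joblib at the paths defined above.
import os
import joblib
from preprocess import preprocess_text

MODEL_PATH = os.path.join("results", "text_classifier.joblib")
VECTORIZER_PATH = os.path.join("results", "tfidf_vectorizer.joblib")

_model = joblib.load(MODEL_PATH)
_vectorizer = joblib.load(VECTORIZER_PATH)

def make_prediction(text: str):
    # Clean the raw text, vectorize it, and return the classifier's label.
    cleaned = preprocess_text(text)
    features = _vectorizer.transform([cleaned])
    return _model.predict(features)[0]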
9 changes: 1 addition & 8 deletions analyzer/preprocess.py
@@ -3,15 +3,8 @@
nlp = spacy.load("en_core_web_sm")

def preprocess_text(text):
"""
Function to preprocess the input text by tokenizing,
removing stop words, and lemmatizing.
"""
# Tokenize and process the text us spacy
# Tokenize and process the text using spacy
doc = nlp(text)

# Filter tokens: Remove stopwords and puctiation, then lemmatize
clean_tokens = [token.lemma_ for token in doc if not token.is_stop and not token.is_punct]

return ' '.join(clean_tokens)

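preprocess_text can be sanity-checked on its own; the output below is illustrative and depends on the en_core_web_sm stop-word and lemma tables.

# Illustrative check of preprocess_text; exact output depends on the spaCy model version.
from preprocess import preprocess_text

print(preprocess_text("The cats were running quickly!"))
# e.g. "cat run quickly" -- stop words and punctuation removed, tokens lemmatized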
1 change: 0 additions & 1 deletion analyzer/requirements.txt
@@ -8,6 +8,5 @@ pymongo
pytest
httpx
python-jose
jose
passlib
bandit
Binary file modified analyzer/results/tfidf_vectorizer.joblib
Binary file not shown.
3 changes: 1 addition & 2 deletions analyzer/test_main.py
@@ -1,5 +1,3 @@
# tests/test_main.py
import pytest
from fastapi.testclient import TestClient
from main import app

@@ -15,3 +13,4 @@ def test_read_model_metadata():
def test_predict_unauthorized():
response = client.post("/predict", json={"text": "sample text"})
assert response.status_code == 401

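The tests shown here cover only the unauthorized path. A hedged sketch of a companion test for the authorized path follows; the plaintext password is a placeholder, since fake_users_db only exposes the hashed value.

# Hypothetical authorized-path test; the plaintext password is a placeholder.
def test_predict_authorized():
    login = client.post("/token", json={"username": "testuser", "password": "example-password"})
    assert login.status_code == 200
    token = login.json()["access_token"]
    response = client.post(
        "/predict",
        json={"text": "sample text"},
        headers={"Authorization": f"Bearer {token}"},
    )
    assert response.status_code == 200
    assert "prediction" in response.json()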