diff --git a/app.py b/app.py index fc5cbfc..a555177 100644 --- a/app.py +++ b/app.py @@ -2,11 +2,14 @@ from fastapi import FastAPI from routes import router -app = FastAPI() +app = FastAPI( + docs_url='/api/docs', + openapi_url='/api/openapi.json' +) app.include_router(router) if __name__ == "__main__": import uvicorn num_workers = os.cpu_count() or 1 - uvicorn.run("app:app", host="0.0.0.0", port=5001, debug=False, workers=num_workers) + uvicorn.run("app:app", host="0.0.0.0", port=5001, debug=False, workers=num_workers) \ No newline at end of file diff --git a/routes.py b/routes.py index 2af272c..87504d0 100644 --- a/routes.py +++ b/routes.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException, Depends from pydantic import BaseModel from utils import denoise_with_rnnoise, get_error_arrays, get_pause_count, split_into_phonemes, processLP -from schemas import TextData, audioData, PhonemesRequest, PhonemesResponse, ErrorArraysResponse +from schemas import TextData, audioData, PhonemesRequest, PhonemesResponse, ErrorArraysResponse, AudioProcessingResponse from typing import List import jiwer import eng_to_ipa as p @@ -15,7 +15,40 @@ router = APIRouter() -@router.post('/getTextMatrices') +@router.post('/getTextMatrices', response_model=ErrorArraysResponse, summary="Compute Text Matrices", description="Computes WER, CER, insertion, deletion, substitution, confidence char list, missing char list, construct text", responses={ + 400: { + "description": "Bad Request", + "content": { + "application/json": { + "example": {"detail": "Reference text must be provided."} + } + } + }, + 422: { + "description": "Unprocessable Entity", + "content": { + "application/json": { + "example": { + "detail": [ + { + "loc": ["body", "text"], + "msg": "field required", + "type": "value_error.missing" + } + ] + } + } + } + }, + 500: { + "description": "Internal Server Error", + "content": { + "application/json": { + "example": {"detail": "Unexpected error: Error processing characters: "} + } + } + } +}) async def compute_errors(data: TextData): try: # Validate input data @@ -81,17 +114,88 @@ async def compute_errors(data: TextData): except Exception as e: logger.error(f"Unexpected error: {str(e)}") raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}") - -@router.post("/getPhonemes", response_model=dict) + +@router.post("/getPhonemes", response_model=PhonemesResponse, summary="Get Phonemes", description="Converts text into phonemes.", responses={ + 400: { + "description": "Bad Request", + "content": { + "application/json": { + "example": {"detail": "Input text cannot be empty."} + } + } + }, + 422: { + "description": "Unprocessable Entity", + "content": { + "application/json": { + "example": { + "detail": [ + { + "loc": ["body", "text"], + "msg": "field required", + "type": "value_error.missing" + } + ] + } + } + } + }, + 500: { + "description": "Internal Server Error", + "content": { + "application/json": { + "example": {"detail": "Unexpected error: Error getting phonemes: "} + } + } + } +}) async def get_phonemes(data: PhonemesRequest): try: + if not data.text.strip(): + raise HTTPException(status_code=400, detail="Input text cannot be empty.") + phonemesList = split_into_phonemes(p.convert(data.text)) return {"phonemes": phonemesList} + except HTTPException as e: + raise e except Exception as e: logger.error(f"Error getting phonemes: {str(e)}") raise HTTPException(status_code=500, detail=f"Error getting phonemes: {str(e)}") - -@router.post('/audio_processing') + +@router.post('/audio_processing', response_model=AudioProcessingResponse, summary="Process Audio", description="Processes audio by denoising and detecting pauses.", responses={ + 400: { + "description": "Bad Request", + "content": { + "application/json": { + "example": {"detail": "Base64 string of audio must be provided."} + } + } + }, + 422: { + "description": "Unprocessable Entity", + "content": { + "application/json": { + "example": { + "detail": [ + { + "loc": ["body", "text"], + "msg": "field required", + "type": "value_error.missing" + } + ] + } + } + } + }, + 500: { + "description": "Internal Server Error", + "content": { + "application/json": { + "example": {"detail": "Unexpected error: "} + } + } + } +}) async def audio_processing(data: audioData): try: # Validate input data diff --git a/schemas.py b/schemas.py index e67a621..7ef6543 100644 --- a/schemas.py +++ b/schemas.py @@ -1,33 +1,37 @@ -from pydantic import BaseModel -from typing import List, Optional +from pydantic import BaseModel,Field +from typing import List, Optional, Dict class TextData(BaseModel): - reference: str - hypothesis: str - language: str + reference: str = Field(..., example="frog jumps", description="The reference text to compare against.") + hypothesis: Optional[str] = Field(None, example="dog jumps", description="The hypothesis text to be compared.") + language: str = Field(..., example="en", description="The language of the text.") class audioData(BaseModel): - base64_string: str - enablePauseCount:bool - enableDenoiser:bool - contentType: str + base64_string: str = Field(..., example="UklGRiQAAABXQVZFZm10IBAAAAABAAEARKwAABCxAgAEABAAZGF0YUAA", description="Base64 encoded audio string.") + enablePauseCount: bool = Field(..., example=True, description="Flag to enable pause count detection.") + enableDenoiser: bool = Field(..., example=True, description="Flag to enable audio denoising.") + contentType: str = Field(..., example="Word", description="The type of content in the audio.") class PhonemesRequest(BaseModel): - text: str + text: str = Field(..., example="dog jumps", description="The text to convert into phonemes.") class PhonemesResponse(BaseModel): - phonemes: List[str] + phonemes: List[str] = Field(..., example=["d", "ɔ", "g", "ʤ", "ə", "m", "p", "s"], description="List of phonemes extracted from the text.") class ErrorArraysResponse(BaseModel): - wer: float - cer: float - insertion: List[str] - insertion_count: int - deletion: List[str] - deletion_count: int - substitution: List[dict] - substitution_count: int - pause_count: int - confidence_char_list: Optional[List[str]] - missing_char_list: Optional[List[str]] - construct_text: Optional[str] + wer: float = Field(..., example=0.5, description="Word Error Rate.") + cer: float = Field(..., example=0.2, description="Character Error Rate.") + insertion: List[str] = Field(..., example=[], description="List of insertions.") + insertion_count: int = Field(..., example=0, description="Count of insertions.") + deletion: List[str] = Field(..., example=["r"], description="List of deletions.") + deletion_count: int = Field(..., example=1, description="Count of deletions.") + substitution: List[Dict[str, str]] = Field(..., example=[{"removed": "d", "replaced": "f"}], description="List of substitutions.") + substitution_count: int = Field(..., example=1, description="Count of substitutions.") + pause_count: Optional[int] = Field(None, example=None, description="Count of pauses detected.") + confidence_char_list: Optional[List[str]] = Field(None, example=["p", "ʤ", "s", "ə", "m"], description="List of characters with confidence levels.") + missing_char_list: Optional[List[str]] = Field(None, example=["f", "g", "r", "ɑ"], description="List of missing characters.") + construct_text: Optional[str] = Field(None, example="jumps", description="Constructed text based on the hypothesis.") + +class AudioProcessingResponse(BaseModel): + denoised_audio_base64: str = Field(..., example="UkiGRV////wqgwbwrbw////AAAA", description="Base64 encoded denoised audio.") + pause_count: Optional[int] = Field(..., example=2, description="Count of pauses detected.") \ No newline at end of file