Skip to content

Commit

Permalink
🩹 Fix error B904 of pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Felipedino committed Jan 11, 2025
1 parent facf576 commit 1b9a39e
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 28 deletions.
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"name":"Local: predict","url":"c:\\Users\\felip\\OneDrive\\Escritorio\\proyecto\\python-envs\\DashAI\\DashAI\\back\\api\\api_v1\\endpoints\\predict.py","tests":[{"id":1733963920211,"input":"","output":""}],"interactive":false,"memoryLimit":1024,"timeLimit":3000,"srcPath":"c:\\Users\\felip\\OneDrive\\Escritorio\\proyecto\\python-envs\\DashAI\\DashAI\\back\\api\\api_v1\\endpoints\\predict.py","group":"local","local":true}
{"name":"Local: predict","url":"c:\\Users\\felip\\OneDrive\\Escritorio\\proyecto\\python-envs\\DashAI\\DashAI\\back\\api\\api_v1\\endpoints\\predict.py","tests":[{"id":1733963920211,"input":"","output":""}],"interactive":false,"memoryLimit":1024,"timeLimit":3000,"srcPath":"c:\\Users\\felip\\OneDrive\\Escritorio\\proyecto\\python-envs\\DashAI\\DashAI\\back\\api\\api_v1\\endpoints\\predict.py","group":"local","local":true}
31 changes: 16 additions & 15 deletions DashAI/back/api/api_v1/endpoints/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@

from fastapi import APIRouter, Depends, status
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from kink import di, inject
from sqlalchemy.orm import sessionmaker

from DashAI.back.api.api_v1.schemas.predict_params import (
FilterDatasetParams,
RenameRequest,
filterDatasetParams,
)
from DashAI.back.dataloaders.classes.dashai_dataset import get_columns_spec
from DashAI.back.dependencies.database.models import Dataset, Experiment, Run
Expand Down Expand Up @@ -51,7 +50,7 @@ async def get_metadata_prediction_json(
path.mkdir(parents=True, exist_ok=True)
files = os.listdir(path)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e)) from e

json_files = [f for f in files if f.endswith(".json")]
if not json_files:
Expand Down Expand Up @@ -193,8 +192,10 @@ async def get_predict_summary(
with open(path, "r") as f:
try:
data = json.load(f)["prediction"]
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid JSON format")
except json.JSONDecodeError as e:
raise HTTPException(
status_code=400, detail="Invalid JSON format"
) from e
summary["total_data_points"] = len(data)
class_set = set(data)
classes = [str(item) for item in class_set]
Expand All @@ -204,10 +205,10 @@ async def get_predict_summary(
for class_name in classes:
try:
occurrences = data.count(int(class_name))
except ValueError:
except ValueError as e:
raise HTTPException(
status_code=400, detail=f"Invalid class value: {class_name}"
)
) from e
distribution = {
"id": id,
"Class": class_name,
Expand All @@ -221,16 +222,16 @@ async def get_predict_summary(
{"id": idx, "value": value} for idx, value in enumerate(data[:50], 1)
]
summary["sample_data"] = sample_data
except FileNotFoundError:
raise HTTPException(status_code=404, detail="Prediction not found")
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail="Prediction not found") from e
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e)) from e
return summary


@router.post("/filter_datasets")
async def filter_datasets_endpoint(
params: filterDatasetParams,
params: FilterDatasetParams,
session_factory: sessionmaker = Depends(lambda: di["session_factory"]),
):
"""
Expand Down Expand Up @@ -275,7 +276,7 @@ async def filter_datasets_endpoint(
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while filtering datasets",
)
) from e


@router.get("/download/{predict_name}")
Expand Down Expand Up @@ -314,7 +315,7 @@ async def download_prediction(
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while downloading the prediction file",
)
) from e


@router.delete("/{predict_name}")
Expand Down Expand Up @@ -354,7 +355,7 @@ async def delete_prediction(
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while deleting the prediction file",
)
) from e


@router.patch("/{predict_name}")
Expand Down Expand Up @@ -409,4 +410,4 @@ async def rename_prediction(
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while renaming the prediction file",
)
) from e
2 changes: 1 addition & 1 deletion DashAI/back/api/api_v1/schemas/predict_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ class RenameRequest(BaseModel):
new_name: str


class filterDatasetParams(BaseModel):
class FilterDatasetParams(BaseModel):
train_dataset_id: int
datasets: List[str]
2 changes: 1 addition & 1 deletion DashAI/back/dependencies/database/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime
from typing import List

from sqlalchemy import JSON, Boolean, DateTime, Enum, ForeignKey, String
from sqlalchemy import JSON, DateTime, Enum, ForeignKey, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column, relationship

Expand Down
3 changes: 1 addition & 2 deletions DashAI/back/job/predict_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def run(
session_factory: sessionmaker = Depends(lambda: di["session_factory"]),
config=lambda di: di["config"],
) -> List[Any]:

run_id: int = self.kwargs["run_id"]
id: int = self.kwargs["id"]
db: Session = self.kwargs["db"]
Expand Down Expand Up @@ -112,7 +111,7 @@ def run(
raise HTTPException(
status_code=400,
detail=f"Invalid columns selected: {str(ve)}",
)
) from ve
except Exception as e:
log.error(e)
raise JobError(
Expand Down
8 changes: 0 additions & 8 deletions tests/back/api/test_predict_api.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
import csv
import datetime
import json
import os
from pathlib import Path

import joblib
import numpy as np
import pytest
from fastapi.testclient import TestClient

from DashAI.back.api.api_v1.schemas.job_params import JobParams
from DashAI.back.dataloaders.classes.dashai_dataset import (
DashAIDataset,
get_dataset_info,
)
from DashAI.back.dataloaders.classes.json_dataloader import JSONDataLoader
from DashAI.back.dependencies.database.models import Dataset, Experiment, Run
from DashAI.back.dependencies.registry import ComponentRegistry
Expand Down

0 comments on commit 1b9a39e

Please sign in to comment.