Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(mypy): enforce explicit overrides #2270

Merged
merged 2 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions antarest/core/cache/business/local_chache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import time
from typing import Dict, List, Optional

from typing_extensions import override

from antarest.core.config import CacheConfig
from antarest.core.interfaces.cache import ICache
from antarest.core.model import JSON
Expand All @@ -40,6 +42,7 @@ def __init__(self, config: CacheConfig = CacheConfig()):
daemon=True,
)

@override
def start(self) -> None:
self.checker_thread.start()

Expand All @@ -55,6 +58,7 @@ def checker(self) -> None:
for id in to_delete:
del self.cache[id]

@override
def put(self, id: str, data: JSON, duration: int = 3600) -> None: # Duration in second
with self.lock:
logger.info(f"Adding cache key {id}")
Expand All @@ -64,6 +68,7 @@ def put(self, id: str, data: JSON, duration: int = 3600) -> None: # Duration in
duration=duration,
)

@override
def get(self, id: str, refresh_duration: Optional[int] = None) -> Optional[JSON]:
res = None
with self.lock:
Expand All @@ -76,12 +81,14 @@ def get(self, id: str, refresh_duration: Optional[int] = None) -> Optional[JSON]
res = self.cache[id].data
return res

@override
def invalidate(self, id: str) -> None:
with self.lock:
logger.info(f"Removing cache key {id}")
if id in self.cache:
del self.cache[id]

@override
def invalidate_all(self, ids: List[str]) -> None:
with self.lock:
logger.info(f"Removing cache keys {ids}")
Expand Down
6 changes: 6 additions & 0 deletions antarest/core/cache/business/redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import List, Optional

from redis.client import Redis
from typing_extensions import override

from antarest.core.interfaces.cache import ICache
from antarest.core.model import JSON
Expand All @@ -31,17 +32,20 @@ class RedisCache(ICache):
def __init__(self, redis_client: Redis): # type: ignore
self.redis = redis_client

@override
def start(self) -> None:
# Assuming the Redis service is already running; no need to start it here.
pass

@override
def put(self, id: str, data: JSON, duration: int = 3600) -> None:
redis_element = RedisCacheElement(duration=duration, data=data)
redis_key = f"cache:{id}"
logger.info(f"Adding cache key {id}")
self.redis.set(redis_key, redis_element.model_dump_json())
self.redis.expire(redis_key, duration)

@override
def get(self, id: str, refresh_timeout: Optional[int] = None) -> Optional[JSON]:
redis_key = f"cache:{id}"
result = self.redis.get(redis_key)
Expand All @@ -58,10 +62,12 @@ def get(self, id: str, refresh_timeout: Optional[int] = None) -> Optional[JSON]:
logger.info(f"Cache key {id} not found")
return None

@override
def invalidate(self, id: str) -> None:
logger.info(f"Removing cache key {id}")
self.redis.delete(f"cache:{id}")

@override
def invalidate_all(self, ids: List[str]) -> None:
logger.info(f"Removing cache keys {ids}")
self.redis.delete(*[f"cache:{id}" for id in ids])
3 changes: 3 additions & 0 deletions antarest/core/configdata/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, Optional

from sqlalchemy import Column, Integer, String # type: ignore
from typing_extensions import override

from antarest.core.persistence import Base
from antarest.core.serialization import AntaresBaseModel
Expand All @@ -30,11 +31,13 @@ class ConfigData(Base): # type: ignore
key = Column(String(), primary_key=True)
value = Column(String(), nullable=True)

@override
def __eq__(self, other: Any) -> bool:
if not isinstance(other, ConfigData):
return False
return bool(other.key == self.key and other.value == self.value and other.owner == self.owner)

@override
def __repr__(self) -> str:
return f"key={self.key}, value={self.value}, owner={self.owner}"

Expand Down
11 changes: 11 additions & 0 deletions antarest/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from http import HTTPStatus

from fastapi.exceptions import HTTPException
from typing_extensions import override


class ShouldNotHappenException(Exception):
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(self, path: str, *area_ids: str):
detail = f"{self.object_name.title()} {detail}"
super().__init__(HTTPStatus.NOT_FOUND, detail)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand Down Expand Up @@ -127,6 +129,7 @@ def __init__(self, path: str, section_id: str):
detail = f"{object_name.title()} '{section_id}' not found in '{path}'"
super().__init__(HTTPStatus.NOT_FOUND, detail)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand Down Expand Up @@ -172,6 +175,7 @@ def __init__(self, path: str):
detail = f"{self.object_name.title()} {detail}"
super().__init__(HTTPStatus.NOT_FOUND, detail)

@override
def __str__(self) -> str:
return self.detail

Expand Down Expand Up @@ -227,6 +231,7 @@ def __init__(self, area_id: str, *duplicates: str):
detail = f"{self.object_name.title()} {detail}"
super().__init__(HTTPStatus.CONFLICT, detail)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand Down Expand Up @@ -397,6 +402,7 @@ def __init__(self, object_id: str, binding_ids: t.Sequence[str], *, object_type:
)
super().__init__(HTTPStatus.FORBIDDEN, message)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand Down Expand Up @@ -429,6 +435,7 @@ def __init__(self, output_id: str) -> None:
message = f"Output '{output_id}' not found"
super().__init__(HTTPStatus.NOT_FOUND, message)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand Down Expand Up @@ -463,6 +470,7 @@ def __init__(self, output_id: str, mc_root: str) -> None:
message = f"The output '{output_id}' sub-folder '{mc_root}' does not exist"
super().__init__(HTTPStatus.NOT_FOUND, message)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand Down Expand Up @@ -552,6 +560,7 @@ def __init__(self, binding_constraint_id: str, *ids: str) -> None:
}[min(count, 2)]
super().__init__(HTTPStatus.NOT_FOUND, message)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand All @@ -572,6 +581,7 @@ def __init__(self, binding_constraint_id: str, *ids: str) -> None:
}[min(count, 2)]
super().__init__(HTTPStatus.CONFLICT, message)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand All @@ -589,6 +599,7 @@ def __init__(self, binding_constraint_id: str, term_json: str) -> None:
)
super().__init__(HTTPStatus.UNPROCESSABLE_ENTITY, message)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand Down
2 changes: 2 additions & 0 deletions antarest/core/filetransfer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Optional

from sqlalchemy import Boolean, Column, DateTime, Integer, String # type: ignore
from typing_extensions import override

from antarest.core.persistence import Base
from antarest.core.serialization import AntaresBaseModel
Expand Down Expand Up @@ -81,6 +82,7 @@ def to_dto(self) -> FileDownloadDTO:
error_message=self.error_message or "",
)

@override
def __repr__(self) -> str:
return (
f"(id={self.id},"
Expand Down
9 changes: 9 additions & 0 deletions antarest/core/interfaces/eventbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from enum import StrEnum
from typing import Any, Awaitable, Callable, List, Optional

from typing_extensions import override

from antarest.core.model import PermissionInfo
from antarest.core.serialization import AntaresBaseModel

Expand Down Expand Up @@ -140,32 +142,39 @@ class DummyEventBusService(IEventBus):
def __init__(self) -> None:
self.events: List[Event] = []

@override
def queue(self, event: Event, queue: str) -> None:
# Noop
pass

@override
def add_queue_consumer(self, listener: Callable[[Event], Awaitable[None]], queue: str) -> str:
return ""

@override
def remove_queue_consumer(self, listener_id: str) -> None:
# Noop
pass

@override
def push(self, event: Event) -> None:
# Noop
self.events.append(event)

@override
def add_listener(
self,
listener: Callable[[Event], Awaitable[None]],
type_filter: Optional[List[EventType]] = None,
) -> str:
return ""

@override
def remove_listener(self, listener_id: str) -> None:
# Noop
pass

@override
def start(self, threaded: bool = True) -> None:
# Noop
pass
4 changes: 4 additions & 0 deletions antarest/core/logging/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
from typing_extensions import override

from antarest.core.config import Config

Expand All @@ -39,6 +40,7 @@ class CustomDefaultFormatter(logging.Formatter):
fields to the log record with a value of `None`.
"""

@override
def format(self, record: logging.LogRecord) -> str:
"""
Formats the specified log record using the custom formatter,
Expand Down Expand Up @@ -169,13 +171,15 @@ def configure_logger(config: Config, handler_cls: str = "logging.FileHandler") -


class LoggingMiddleware(BaseHTTPMiddleware):
@override
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
with RequestContext(request):
response = await call_next(request)
return response


class ContextFilter(logging.Filter):
@override
def filter(self, record: logging.LogRecord) -> bool:
request: Optional[Request] = _request.get()
request_id: Optional[str] = _request_id.get()
Expand Down
8 changes: 8 additions & 0 deletions antarest/core/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from fastapi import HTTPException
from markupsafe import escape
from ratelimit import Rule # type: ignore
from typing_extensions import override

from antarest.core.jwt import JWTUser

Expand All @@ -38,24 +39,30 @@ def __init__(self, data=None, **kwargs) -> None: # type: ignore
data = {}
self.update(data, **kwargs)

@override
def __setitem__(self, key: str, value: t.Any) -> None:
self._store[key.lower()] = (key, value)

@override
def __getitem__(self, key: str) -> t.Any:
return self._store[key.lower()][1]

@override
def __delitem__(self, key: str) -> None:
del self._store[key.lower()]

@override
def __iter__(self) -> t.Any:
return (casedkey for casedkey, mappedvalue in self._store.values())

@override
def __len__(self) -> int:
return len(self._store)

def lower_items(self) -> Generator[Tuple[Any, Any], Any, None]:
return ((lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items())

@override
def __eq__(self, other: t.Any) -> bool:
if isinstance(other, t.Mapping):
other = CaseInsensitiveDict(other)
Expand All @@ -66,6 +73,7 @@ def __eq__(self, other: t.Any) -> bool:
def copy(self) -> "CaseInsensitiveDict":
return CaseInsensitiveDict(self._store.values())

@override
def __repr__(self) -> str:
return str(dict(self.items()))

Expand Down
5 changes: 5 additions & 0 deletions antarest/core/tasks/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, Sequence, String # type: ignore
from sqlalchemy.engine.base import Engine # type: ignore
from sqlalchemy.orm import relationship, sessionmaker # type: ignore
from typing_extensions import override

from antarest.core.persistence import Base
from antarest.core.serialization import AntaresBaseModel
Expand Down Expand Up @@ -122,11 +123,13 @@ class TaskJobLog(Base): # type: ignore
# If the TaskJob is deleted, all attached logs must also be deleted in cascade.
job: "TaskJob" = relationship("TaskJob", back_populates="logs", uselist=False)

@override
def __eq__(self, other: t.Any) -> bool:
if not isinstance(other, TaskJobLog):
return False
return bool(other.id == self.id and other.message == self.message and other.task_id == self.task_id)

@override
def __repr__(self) -> str:
return f"id={self.id}, message={self.message}, task_id={self.task_id}"

Expand Down Expand Up @@ -198,6 +201,7 @@ def to_dto(self, with_logs: bool = False) -> TaskDTO:
progress=self.progress,
)

@override
def __eq__(self, other: t.Any) -> bool:
if not isinstance(other, TaskJob):
return False
Expand All @@ -213,6 +217,7 @@ def __eq__(self, other: t.Any) -> bool:
and other.logs == self.logs
)

@override
def __repr__(self) -> str:
return (
f"id={self.id},"
Expand Down
Loading
Loading