Skip to content

Commit

Permalink
only close session if owner
Browse files Browse the repository at this point in the history
  • Loading branch information
ika2kki committed Jan 11, 2024
1 parent 8542332 commit 661f336
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
13 changes: 13 additions & 0 deletions mystbin/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@

if TYPE_CHECKING:
import datetime
from types import TracebackType

from aiohttp import ClientSession
from typing_extensions import Self

__all__ = ("Client",)

Expand All @@ -42,6 +44,17 @@ class Client:
def __init__(self, *, token: str | None = None, session: ClientSession | None = None) -> None:
self.http: HTTPClient = HTTPClient(token=token, session=session)

async def __aenter__(self) -> Self:
return self

async def __aexit__(
self,
exc_cls: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None
) -> None:
await self.close()

async def close(self) -> None:
"""|coro|
Expand Down
5 changes: 4 additions & 1 deletion mystbin/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def __init__(self, verb: SupportedHTTPVerb, path: str, **params: Any) -> None:
class HTTPClient:
__slots__ = (
"_session",
"_owns_session",
"_async",
"_token",
"_locks",
Expand All @@ -135,16 +136,18 @@ class HTTPClient:
def __init__(self, *, token: str | None, session: aiohttp.ClientSession | None = None) -> None:
self._token: str | None = token
self._session: aiohttp.ClientSession | None = session
self._owns_session: bool = False
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
user_agent = "mystbin.py (https://github.com/PythonistaGuild/mystbin.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}"
self.user_agent: str = user_agent.format(__version__, sys.version_info, aiohttp.__version__)

async def close(self) -> None:
if self._session:
if self._session and self._owns_session:
await self._session.close()

async def _generate_session(self) -> aiohttp.ClientSession:
self._session = aiohttp.ClientSession()
self._owns_session = True
return self._session

async def request(self, route: Route, **kwargs: Any) -> Any:
Expand Down

0 comments on commit 661f336

Please sign in to comment.