Skip to content

Commit

Permalink
Break TaskRunEngine into Sync and Async classes
Browse files Browse the repository at this point in the history
  • Loading branch information
bunchesofdonald committed Jul 16, 2024
1 parent eb738c3 commit 84126dc
Show file tree
Hide file tree
Showing 4 changed files with 874 additions and 362 deletions.
29 changes: 28 additions & 1 deletion src/prefect/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
import os
import sys
import warnings
from contextlib import ExitStack, contextmanager
from contextlib import ExitStack, asynccontextmanager, contextmanager
from contextvars import ContextVar, Token
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Dict,
Generator,
Optional,
Expand Down Expand Up @@ -209,6 +210,22 @@ def __init__(self, httpx_settings: Optional[dict[str, Any]] = None):
self._httpx_settings = httpx_settings
self._context_stack = 0

async def __aenter__(self):
self._context_stack += 1
if self._context_stack == 1:
self.sync_client.__enter__()
await self.async_client.__aenter__()
return super().__enter__()
else:
return self

async def __aexit__(self, *exc_info):
self._context_stack -= 1
if self._context_stack == 0:
self.sync_client.__exit__(*exc_info)
await self.async_client.__aexit__(*exc_info)
return super().__exit__(*exc_info)

def __enter__(self):
self._context_stack += 1
if self._context_stack == 1:
Expand All @@ -235,6 +252,16 @@ def get_or_create(cls) -> Generator["ClientContext", None, None]:
with ClientContext() as ctx:
yield ctx

@classmethod
@asynccontextmanager
async def async_get_or_create(cls) -> AsyncGenerator["ClientContext", None]:
ctx = ClientContext.get()
if ctx:
yield ctx
else:
async with ClientContext() as ctx:
yield ctx


class RunContext(ContextModel):
"""
Expand Down
Loading

0 comments on commit 84126dc

Please sign in to comment.