Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor asynchronous client #107

Merged
merged 11 commits into from
Jun 13, 2024
262 changes: 164 additions & 98 deletions firecrest/AsyncClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import firecrest.FirecrestException as fe
import firecrest.types as t
from firecrest.AsyncExternalStorage import AsyncExternalUpload, AsyncExternalDownload
from firecrest.utilities import time_block


if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -194,6 +195,10 @@ def __init__(
"tasks": 0.1,
"utilities": 0.1,
}
#: Merge GET requests to the same endpoint, when possible. This will
#: take effect only when the time_between_calls of the microservice
#: is greater than 0.
self.merge_get_requests: bool = False
self._next_request_ts: dict[str, float] = {
"compute": 0,
"reservations": 0,
Expand All @@ -203,21 +208,26 @@ def __init__(
"utilities": 0,
}
self._locks = {
"compute": asyncio.Lock(),
"reservations": asyncio.Lock(),
"status": asyncio.Lock(),
"storage": asyncio.Lock(),
"tasks": asyncio.Lock(),
"utilities": asyncio.Lock(),
"/compute/jobs": asyncio.Lock(),
"/compute/acct": asyncio.Lock(),
"/tasks": asyncio.Lock(),
}
# The following objects are used to "merge" requests in the same
# endpoints, for example requests to tasks or polling for jobs
self._polling_ids: dict[str, set] = {
"/compute/jobs": set(),
"/compute/acct": set(),
"/tasks": set()
}
self._polling_results: dict[str, List] = {
"/compute/jobs": [],
"/compute/acct": [],
"/tasks": []
}

# The following objects are used to "merge" requests in the same endpoints,
# for example requests to tasks or polling for jobs
self._polling_ids: dict[str, set] = {"compute": set(), "tasks": set()}
self._polling_results: dict[str, List] = {"compute": [], "tasks": []}
self._polling_events: dict[str, Optional[asyncio.Event]] = {
"compute": None,
"tasks": None,
"/compute/jobs": None,
"/compute/acct": None,
"/tasks": None,
}

def set_api_version(self, api_version: str) -> None:
Expand All @@ -242,20 +252,19 @@ def is_session_closed(self) -> bool:
"""Check if the httpx session is closed"""
return self._session.is_closed

@_retry_requests # type: ignore
async def _get_request(
async def _get_merge_request(
self, endpoint, additional_headers=None, params=None
) -> httpx.Response:
microservice = endpoint.split("/")[1]
url = f"{self._firecrest_url}{endpoint}"

async def _merged_get(event):
await self._stall_request(microservice)
async with self._locks[microservice]:
results = self._polling_results[microservice]
ids = self._polling_ids[microservice].copy()
self._polling_events[microservice] = None
self._polling_ids[microservice] = set()
async with self._locks[endpoint]:
results = self._polling_results[endpoint]
ids = self._polling_ids[endpoint].copy()
self._polling_events[endpoint] = None
self._polling_ids[endpoint] = set()
comma_sep_par = "tasks" if microservice == "tasks" else "jobs"
if ids == {"*"}:
if comma_sep_par in params:
Expand All @@ -270,9 +279,12 @@ async def _merged_get(event):
headers.update(additional_headers)

logger.info(f"Making GET request to {endpoint}")
resp = await self._session.get(
url=url, headers=headers, params=params, timeout=self.timeout
)
with time_block(f"GET request to {endpoint}", logger):
resp = await self._session.get(
url=url, headers=headers,
params=params,
timeout=self.timeout
)

self._next_request_ts[microservice] = (
time.time() + self.time_between_calls[microservice]
Expand All @@ -283,75 +295,127 @@ async def _merged_get(event):

return

if microservice == "tasks" or endpoint in ("/compute/jobs", "/compute/acct"):
async with self._locks[microservice]:
if self._polling_ids[microservice] != {"*"}:
comma_sep_par = "tasks" if microservice == "tasks" else "jobs"
if comma_sep_par not in params:
self._polling_ids[microservice] = {"*"}
else:
task_ids = params[comma_sep_par].split(",")
self._polling_ids[microservice].update(task_ids)

if self._polling_events[microservice] is None:
self._polling_events[microservice] = asyncio.Event()
my_event = self._polling_events[microservice]
self._polling_results[microservice] = []
my_result = self._polling_results[microservice]
waiter = True
task = asyncio.create_task(_merged_get(my_event))
async with self._locks[endpoint]:
if self._polling_ids[endpoint] != {"*"}:
comma_sep_par = "tasks" if endpoint == "/tasks" else "jobs"
if comma_sep_par not in params:
self._polling_ids[endpoint] = {"*"}
else:
waiter = False
my_event = self._polling_events[microservice]
my_result = self._polling_results[microservice]
new_ids = params[comma_sep_par].split(",")
self._polling_ids[endpoint].update(new_ids)

if self._polling_events[endpoint] is None:
self._polling_events[endpoint] = asyncio.Event()
my_event = self._polling_events[endpoint]
self._polling_results[endpoint] = []
my_result = self._polling_results[endpoint]
waiter = True
task = asyncio.create_task(_merged_get(my_event))
else:
waiter = False
my_event = self._polling_events[endpoint]
my_result = self._polling_results[endpoint]

if waiter:
await task
if waiter:
await task

await my_event.wait() # type: ignore
resp = my_result[0]
return resp
await my_event.wait() # type: ignore
resp = my_result[0]
return resp

# Otherwise just do what you usually do
async with self._locks[microservice]:
await self._stall_request(microservice)
headers = {
"Authorization": f"Bearer {self._authorization.get_access_token()}"
}
if additional_headers:
headers.update(additional_headers)
async def _get_simple_request(
self, endpoint, additional_headers=None, params=None
) -> httpx.Response:
microservice = endpoint.split("/")[1]
url = f"{self._firecrest_url}{endpoint}"
await self._stall_request(microservice)
self._next_request_ts[microservice] = (
time.time() + self.time_between_calls[microservice]
)
headers = {
"Authorization": f"Bearer {self._authorization.get_access_token()}"
}
if additional_headers:
headers.update(additional_headers)

logger.info(f"Making GET request to {endpoint}")
logger.info(f"Making GET request to {endpoint}")
with time_block(f"GET request to {endpoint}", logger):
resp = await self._session.get(
url=url, headers=headers, params=params, timeout=self.timeout
)
self._next_request_ts[microservice] = (
time.time() + self.time_between_calls[microservice]
)

return resp

@_retry_requests  # type: ignore
async def _get_request(
    self, endpoint, additional_headers=None, params=None
) -> httpx.Response:
    """Dispatch a GET request, merging it with concurrent ones when safe.

    When ``self.merge_get_requests`` is enabled and the microservice has
    a positive ``time_between_calls``, requests to ``/compute/jobs``,
    ``/compute/acct`` and ``/tasks`` are routed through
    ``_get_merge_request``. Every other request, and every request whose
    parameters make merging unsafe, goes through ``_get_simple_request``.

    :param endpoint: endpoint path, e.g. ``/compute/jobs``; the second
        path component is used as the microservice name
    :param additional_headers: extra headers forwarded to the request
    :param params: query parameters of the request (may be ``None``)
    :returns: the response returned by the selected request helper
    """
    microservice = endpoint.split("/")[1]
    # `params` defaults to None: use an empty mapping for the membership
    # checks below so that they cannot raise a TypeError.
    query = params if params is not None else {}

    if (
        self.merge_get_requests
        and self.time_between_calls[microservice] > 0
        and endpoint in ("/compute/jobs", "/compute/acct", "/tasks")
    ):
        # We can only merge requests with the additional restrictions:
        # - For `/compute/jobs` we can merge only if the pagination
        #   parameters are not set.
        # - For `/compute/acct` we can merge only if the starttime,
        #   endtime and pagination parameters are ALL not set.
        #   Moreover we cannot merge if no explicit job ids are given
        #   (i.e. `*` would be used as the job id), because the default
        #   `sacct` command will only return the jobs of the last day.
        # - Requests to `/tasks` can always be merged.
        # A merged request shares a single set of query parameters, so
        # every one of these conditions has to hold (`and`, not `or`),
        # otherwise the filters of one caller would be lost or applied
        # to the other callers.
        paginated = "pageSize" in query or "pageNumber" in query
        if endpoint == "/tasks":
            mergeable = True
        elif endpoint == "/compute/jobs":
            mergeable = not paginated
        else:  # "/compute/acct"
            mergeable = (
                not paginated
                and "starttime" not in query
                and "endtime" not in query
                and bool(query.get("jobs"))
            )

        if mergeable:
            return await self._get_merge_request(
                endpoint=endpoint,
                additional_headers=additional_headers,
                params=query
            )

    return await self._get_simple_request(
        endpoint=endpoint,
        additional_headers=additional_headers,
        params=params
    )

@_retry_requests # type: ignore
async def _post_request(
self, endpoint, additional_headers=None, data=None, files=None
) -> httpx.Response:
microservice = endpoint.split("/")[1]
url = f"{self._firecrest_url}{endpoint}"
async with self._locks[microservice]:
await self._stall_request(microservice)
headers = {
"Authorization": f"Bearer {self._authorization.get_access_token()}"
}
if additional_headers:
headers.update(additional_headers)
await self._stall_request(microservice)
self._next_request_ts[microservice] = (
time.time() + self.time_between_calls[microservice]
)
headers = {
"Authorization": f"Bearer {self._authorization.get_access_token()}"
}
if additional_headers:
headers.update(additional_headers)

logger.info(f"Making POST request to {endpoint}")
logger.info(f"Making POST request to {endpoint}")
with time_block(f"POST request to {endpoint}", logger):
resp = await self._session.post(
url=url, headers=headers, data=data, files=files, timeout=self.timeout
)
self._next_request_ts[microservice] = (
time.time() + self.time_between_calls[microservice]
)

return resp

Expand All @@ -361,21 +425,21 @@ async def _put_request(
) -> httpx.Response:
microservice = endpoint.split("/")[1]
url = f"{self._firecrest_url}{endpoint}"
async with self._locks[microservice]:
await self._stall_request(microservice)
headers = {
"Authorization": f"Bearer {self._authorization.get_access_token()}"
}
if additional_headers:
headers.update(additional_headers)
self._next_request_ts[microservice] = (
time.time() + self.time_between_calls[microservice]
)
await self._stall_request(microservice)
headers = {
"Authorization": f"Bearer {self._authorization.get_access_token()}"
}
if additional_headers:
headers.update(additional_headers)

logger.info(f"Making PUT request to {endpoint}")
logger.info(f"Making PUT request to {endpoint}")
with time_block(f"PUT request to {endpoint}", logger):
resp = await self._session.put(
url=url, headers=headers, data=data, timeout=self.timeout
)
self._next_request_ts[microservice] = (
time.time() + self.time_between_calls[microservice]
)

return resp

Expand All @@ -385,17 +449,20 @@ async def _delete_request(
) -> httpx.Response:
microservice = endpoint.split("/")[1]
url = f"{self._firecrest_url}{endpoint}"
async with self._locks[microservice]:
await self._stall_request(microservice)
headers = {
"Authorization": f"Bearer {self._authorization.get_access_token()}"
}
if additional_headers:
headers.update(additional_headers)
await self._stall_request(microservice)
self._next_request_ts[microservice] = (
time.time() + self.time_between_calls[microservice]
)
headers = {
"Authorization": f"Bearer {self._authorization.get_access_token()}"
}
if additional_headers:
headers.update(additional_headers)

logger.info(f"Making DELETE request to {endpoint}")
# httpx doesn't support data in the `delete` method so we will have to
# use the generic `request` method
logger.info(f"Making DELETE request to {endpoint}")
with time_block(f"DELETE request to {endpoint}", logger):
# httpx doesn't support data in the `delete` method so we will
# have to use the generic `request` method
# https://www.python-httpx.org/compatibility/#request-body-on-http-methods
resp = await self._session.request(
method="DELETE",
Expand All @@ -404,21 +471,20 @@ async def _delete_request(
data=data,
timeout=self.timeout,
)
self._next_request_ts[microservice] = (
time.time() + self.time_between_calls[microservice]
)

return resp

async def _stall_request(self, microservice: str) -> None:
if self._next_request_ts[microservice] is not None:
while time.time() <= self._next_request_ts[microservice]:
logger.debug(
f"`{microservice}` microservice has received too many requests. "
f"Going to sleep for "
f"`{microservice}` microservice has received too many "
f"requests. Going to sleep for "
f"~{self._next_request_ts[microservice] - time.time()} sec"
)
await asyncio.sleep(self._next_request_ts[microservice] - time.time())
await asyncio.sleep(
self._next_request_ts[microservice] - time.time()
)

@overload
def _json_response(
Expand Down
Loading