diff --git a/firecrest/AsyncClient.py b/firecrest/AsyncClient.py
index 2d63b27..f5c2615 100644
--- a/firecrest/AsyncClient.py
+++ b/firecrest/AsyncClient.py
@@ -27,6 +27,7 @@
 import firecrest.FirecrestException as fe
 import firecrest.types as t
 from firecrest.AsyncExternalStorage import AsyncExternalUpload, AsyncExternalDownload
+from firecrest.utilities import time_block


 if sys.version_info >= (3, 8):
@@ -194,6 +195,10 @@ def __init__(
             "tasks": 0.1,
             "utilities": 0.1,
         }
+        #: Merge GET requests to the same endpoint, when possible. This will
+        #: take effect only when the time_between_calls of the microservice
+        #: is greater than 0.
+        self.merge_get_requests: bool = False
         self._next_request_ts: dict[str, float] = {
             "compute": 0,
             "reservations": 0,
@@ -203,21 +208,26 @@ def __init__(
             "utilities": 0,
         }
         self._locks = {
-            "compute": asyncio.Lock(),
-            "reservations": asyncio.Lock(),
-            "status": asyncio.Lock(),
-            "storage": asyncio.Lock(),
-            "tasks": asyncio.Lock(),
-            "utilities": asyncio.Lock(),
+            "/compute/jobs": asyncio.Lock(),
+            "/compute/acct": asyncio.Lock(),
+            "/tasks": asyncio.Lock(),
+        }
+        # The following objects are used to "merge" requests in the same
+        # endpoints, for example requests to tasks or polling for jobs
+        self._polling_ids: dict[str, set] = {
+            "/compute/jobs": set(),
+            "/compute/acct": set(),
+            "/tasks": set()
+        }
+        self._polling_results: dict[str, List] = {
+            "/compute/jobs": [],
+            "/compute/acct": [],
+            "/tasks": []
         }
-
-        # The following objects are used to "merge" requests in the same endpoints,
-        # for example requests to tasks or polling for jobs
-        self._polling_ids: dict[str, set] = {"compute": set(), "tasks": set()}
-        self._polling_results: dict[str, List] = {"compute": [], "tasks": []}
         self._polling_events: dict[str, Optional[asyncio.Event]] = {
-            "compute": None,
-            "tasks": None,
+            "/compute/jobs": None,
+            "/compute/acct": None,
+            "/tasks": None,
         }

     def set_api_version(self, api_version: str) -> None:
@@ -242,8 +252,7 @@ def is_session_closed(self) -> bool:
         """Check if the httpx session is closed"""
         return self._session.is_closed

-    @_retry_requests  # type: ignore
-    async def _get_request(
+    async def _get_merge_request(
         self, endpoint, additional_headers=None, params=None
     ) -> httpx.Response:
         microservice = endpoint.split("/")[1]
@@ -251,11 +260,11 @@ async def _get_request(

         async def _merged_get(event):
             await self._stall_request(microservice)
-            async with self._locks[microservice]:
-                results = self._polling_results[microservice]
-                ids = self._polling_ids[microservice].copy()
-                self._polling_events[microservice] = None
-                self._polling_ids[microservice] = set()
+            async with self._locks[endpoint]:
+                results = self._polling_results[endpoint]
+                ids = self._polling_ids[endpoint].copy()
+                self._polling_events[endpoint] = None
+                self._polling_ids[endpoint] = set()
                 comma_sep_par = "tasks" if microservice == "tasks" else "jobs"
                 if ids == {"*"}:
                     if comma_sep_par in params:
@@ -270,9 +279,12 @@ async def _merged_get(event):
                 headers.update(additional_headers)

             logger.info(f"Making GET request to {endpoint}")
-            resp = await self._session.get(
-                url=url, headers=headers, params=params, timeout=self.timeout
-            )
+            with time_block(f"GET request to {endpoint}", logger):
+                resp = await self._session.get(
+                    url=url, headers=headers,
+                    params=params,
+                    timeout=self.timeout
+                )

             self._next_request_ts[microservice] = (
                 time.time() + self.time_between_calls[microservice]
@@ -283,75 +295,127 @@ async def _merged_get(event):
             return

-        if microservice == "tasks" or endpoint in ("/compute/jobs", "/compute/acct"):
"/compute/acct"): - async with self._locks[microservice]: - if self._polling_ids[microservice] != {"*"}: - comma_sep_par = "tasks" if microservice == "tasks" else "jobs" - if comma_sep_par not in params: - self._polling_ids[microservice] = {"*"} - else: - task_ids = params[comma_sep_par].split(",") - self._polling_ids[microservice].update(task_ids) - - if self._polling_events[microservice] is None: - self._polling_events[microservice] = asyncio.Event() - my_event = self._polling_events[microservice] - self._polling_results[microservice] = [] - my_result = self._polling_results[microservice] - waiter = True - task = asyncio.create_task(_merged_get(my_event)) + async with self._locks[endpoint]: + if self._polling_ids[endpoint] != {"*"}: + comma_sep_par = "tasks" if endpoint == "/tasks" else "jobs" + if comma_sep_par not in params: + self._polling_ids[endpoint] = {"*"} else: - waiter = False - my_event = self._polling_events[microservice] - my_result = self._polling_results[microservice] + new_ids = params[comma_sep_par].split(",") + self._polling_ids[endpoint].update(new_ids) + + if self._polling_events[endpoint] is None: + self._polling_events[endpoint] = asyncio.Event() + my_event = self._polling_events[endpoint] + self._polling_results[endpoint] = [] + my_result = self._polling_results[endpoint] + waiter = True + task = asyncio.create_task(_merged_get(my_event)) + else: + waiter = False + my_event = self._polling_events[endpoint] + my_result = self._polling_results[endpoint] - if waiter: - await task + if waiter: + await task - await my_event.wait() # type: ignore - resp = my_result[0] - return resp + await my_event.wait() # type: ignore + resp = my_result[0] + return resp - # Otherwise just do what you usually do - async with self._locks[microservice]: - await self._stall_request(microservice) - headers = { - "Authorization": f"Bearer {self._authorization.get_access_token()}" - } - if additional_headers: - headers.update(additional_headers) + async def _get_simple_request( + self, endpoint, additional_headers=None, params=None + ) -> httpx.Response: + microservice = endpoint.split("/")[1] + url = f"{self._firecrest_url}{endpoint}" + await self._stall_request(microservice) + self._next_request_ts[microservice] = ( + time.time() + self.time_between_calls[microservice] + ) + headers = { + "Authorization": f"Bearer {self._authorization.get_access_token()}" + } + if additional_headers: + headers.update(additional_headers) - logger.info(f"Making GET request to {endpoint}") + logger.info(f"Making GET request to {endpoint}") + with time_block(f"GET request to {endpoint}", logger): resp = await self._session.get( url=url, headers=headers, params=params, timeout=self.timeout ) - self._next_request_ts[microservice] = ( - time.time() + self.time_between_calls[microservice] - ) return resp + @_retry_requests # type: ignore + async def _get_request( + self, endpoint, additional_headers=None, params=None + ) -> httpx.Response: + microservice = endpoint.split("/")[1] + if ( + self.merge_get_requests and + self.time_between_calls[microservice] > 0 and + endpoint in ("/compute/jobs", "/compute/acct", "/tasks") + ): + # We can only merge requests with the additional restrictions: + # - For `/compute/acct` we can merge only if the start_time, + # end_time, and pagination parameters are not set. + # Moreover we cannot merge if the `*` is used as a task id, + # because the default `sacct` command will only return the + # jobs of the last day. 
+            # - For `/compute/jobs` we can merge only if the pagination
+            #   parameters are not set.
+            if (
+                endpoint == "/compute/acct"
+                and (
+                    "starttime" not in params
+                    and "endtime" not in params
+                    and "pageSize" not in params
+                    and "pageNumber" not in params
+                    and params.get("jobs")
+                )
+            ) or (
+                endpoint == "/compute/jobs"
+                and (
+                    "pageSize" not in params
+                    and "pageNumber" not in params
+                )
+            ) or (
+                endpoint == "/tasks"
+            ):
+                return await self._get_merge_request(
+                    endpoint=endpoint,
+                    additional_headers=additional_headers,
+                    params=params
+                )
+
+        return await self._get_simple_request(
+            endpoint=endpoint,
+            additional_headers=additional_headers,
+            params=params
+        )
+
     @_retry_requests  # type: ignore
     async def _post_request(
         self, endpoint, additional_headers=None, data=None, files=None
     ) -> httpx.Response:
         microservice = endpoint.split("/")[1]
         url = f"{self._firecrest_url}{endpoint}"
-        async with self._locks[microservice]:
-            await self._stall_request(microservice)
-            headers = {
-                "Authorization": f"Bearer {self._authorization.get_access_token()}"
-            }
-            if additional_headers:
-                headers.update(additional_headers)
+        await self._stall_request(microservice)
+        self._next_request_ts[microservice] = (
+            time.time() + self.time_between_calls[microservice]
+        )
+        headers = {
+            "Authorization": f"Bearer {self._authorization.get_access_token()}"
+        }
+        if additional_headers:
+            headers.update(additional_headers)

-            logger.info(f"Making POST request to {endpoint}")
+        logger.info(f"Making POST request to {endpoint}")
+        with time_block(f"POST request to {endpoint}", logger):
             resp = await self._session.post(
                 url=url, headers=headers, data=data, files=files, timeout=self.timeout
             )
-            self._next_request_ts[microservice] = (
-                time.time() + self.time_between_calls[microservice]
-            )

         return resp

@@ -361,21 +425,21 @@ async def _put_request(
     ) -> httpx.Response:
         microservice = endpoint.split("/")[1]
         url = f"{self._firecrest_url}{endpoint}"
-        async with self._locks[microservice]:
-            await self._stall_request(microservice)
-            headers = {
-                "Authorization": f"Bearer {self._authorization.get_access_token()}"
-            }
-            if additional_headers:
-                headers.update(additional_headers)
+        await self._stall_request(microservice)
+        self._next_request_ts[microservice] = (
+            time.time() + self.time_between_calls[microservice]
+        )
+        headers = {
+            "Authorization": f"Bearer {self._authorization.get_access_token()}"
+        }
+        if additional_headers:
+            headers.update(additional_headers)

-            logger.info(f"Making PUT request to {endpoint}")
+        logger.info(f"Making PUT request to {endpoint}")
+        with time_block(f"PUT request to {endpoint}", logger):
             resp = await self._session.put(
                 url=url, headers=headers, data=data, timeout=self.timeout
             )
-            self._next_request_ts[microservice] = (
-                time.time() + self.time_between_calls[microservice]
-            )

         return resp

@@ -385,17 +449,20 @@ async def _delete_request(
     ) -> httpx.Response:
         microservice = endpoint.split("/")[1]
         url = f"{self._firecrest_url}{endpoint}"
-        async with self._locks[microservice]:
-            await self._stall_request(microservice)
-            headers = {
-                "Authorization": f"Bearer {self._authorization.get_access_token()}"
-            }
-            if additional_headers:
-                headers.update(additional_headers)
+        await self._stall_request(microservice)
+        self._next_request_ts[microservice] = (
+            time.time() + self.time_between_calls[microservice]
+        )
+        headers = {
+            "Authorization": f"Bearer {self._authorization.get_access_token()}"
+        }
+        if additional_headers:
+            headers.update(additional_headers)

-            logger.info(f"Making DELETE request to {endpoint}")
-            # httpx doesn't support data in the `delete` method so we will have to
-            # use the generic `request` method
+        logger.info(f"Making DELETE request to {endpoint}")
+        with time_block(f"DELETE request to {endpoint}", logger):
+            # httpx doesn't support data in the `delete` method so we will
+            # have to use the generic `request` method
             # https://www.python-httpx.org/compatibility/#request-body-on-http-methods
             resp = await self._session.request(
                 method="DELETE",
@@ -404,9 +471,6 @@ async def _delete_request(
                 url=url,
                 headers=headers,
                 data=data,
                 timeout=self.timeout,
             )
-            self._next_request_ts[microservice] = (
-                time.time() + self.time_between_calls[microservice]
-            )

         return resp
@@ -414,11 +478,13 @@ async def _stall_request(self, microservice: str) -> None:
         if self._next_request_ts[microservice] is not None:
             while time.time() <= self._next_request_ts[microservice]:
                 logger.debug(
-                    f"`{microservice}` microservice has received too many requests. "
-                    f"Going to sleep for "
+                    f"`{microservice}` microservice has received too many "
+                    f"requests. Going to sleep for "
                     f"~{self._next_request_ts[microservice] - time.time()} sec"
                 )
-                await asyncio.sleep(self._next_request_ts[microservice] - time.time())
+                await asyncio.sleep(
+                    self._next_request_ts[microservice] - time.time()
+                )

     @overload
     def _json_response(
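For reviewers, a minimal usage sketch of the new behaviour, not part of the diff: it assumes the public `AsyncFirecrest` client and `ClientCredentialsAuth` helper from current pyfirecrest releases, and the URL, credentials, machine name and job IDs are placeholders. It also assumes `time_block` (imported from `firecrest.utilities` above) is a small context manager that logs the elapsed time of the wrapped block.

```python
# Sketch only: exercises merge_get_requests under the stated assumptions.
import asyncio
import firecrest as fc


async def main():
    # Placeholder credentials and URL.
    auth = fc.ClientCredentialsAuth(
        "<client_id>", "<client_secret>", "<token_uri>"
    )
    client = fc.AsyncFirecrest(
        firecrest_url="<firecrest_url>", authorization=auth
    )

    # Rate limiting must be active for the microservice, otherwise merging
    # is skipped (see the time_between_calls check in _get_request).
    client.time_between_calls["compute"] = 5
    client.merge_get_requests = True

    # Three concurrent calls that target /compute/jobs; with merging enabled
    # they are expected to be coalesced into a single GET carrying the union
    # of the requested job IDs.
    results = await asyncio.gather(
        client.poll_active("cluster", jobs=["1000"]),
        client.poll_active("cluster", jobs=["1001"]),
        client.poll_active("cluster", jobs=["1002"]),
    )
    print(results)


asyncio.run(main())
```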