diff --git a/src/aleph_vrf/executor/main.py b/src/aleph_vrf/executor/main.py index 407a648..c45d3dd 100644 --- a/src/aleph_vrf/executor/main.py +++ b/src/aleph_vrf/executor/main.py @@ -1,5 +1,10 @@ import logging -from typing import Dict, Union +from contextlib import asynccontextmanager +from typing import Dict, Union, Set +from uuid import UUID + +import fastapi +from aleph.sdk.exceptions import MessageNotFoundError, MultipleMessagesError from aleph_vrf.settings import settings @@ -10,7 +15,7 @@ from aleph.sdk.chains.ethereum import ETHAccount from aleph.sdk.client import AlephClient, AuthenticatedAlephClient from aleph.sdk.vm.app import AlephApp -from aleph_message.models import ItemHash +from aleph_message.models import ItemHash, PostMessage from aleph_message.status import MessageStatus logger.debug("import fastapi") @@ -28,15 +33,27 @@ logger.debug("imports done") -http_app = FastAPI() -app = AlephApp(http_app=http_app) - GENERATE_MESSAGE_REF_PATH = "hash" # TODO: Use another method to save the data +ANSWERED_REQUESTS: Set[str] = set() SAVED_GENERATED_BYTES: Dict[str, bytes] = {} +@asynccontextmanager +async def lifespan(app: FastAPI): + global ANSWERED_REQUESTS, SAVED_GENERATED_BYTES + + print(f"ANSWERED_REQUESTS: {ANSWERED_REQUESTS}") + ANSWERED_REQUESTS.clear() + SAVED_GENERATED_BYTES.clear() + yield + + +http_app = FastAPI(lifespan=lifespan) +app = AlephApp(http_app=http_app) + + @app.get("/") async def index(): return { @@ -45,22 +62,46 @@ async def index(): } +async def _get_message(client: AlephClient, item_hash: ItemHash) -> PostMessage: + try: + return await client.get_message(item_hash=item_hash, message_type=PostMessage) + except MessageNotFoundError: + raise fastapi.HTTPException( + status_code=404, detail=f"Message {item_hash} not found" + ) + except MultipleMessagesError: + raise fastapi.HTTPException( + status_code=409, + detail=f"Multiple messages have the following hash: {item_hash}", + ) + except TypeError: + raise fastapi.HTTPException( + status_code=409, detail=f"Message {item_hash} is not a POST message" + ) + + @app.post("/generate/{vrf_request}") -async def receive_generate(vrf_request: str) -> APIResponse[VRFResponseHash]: - global SAVED_GENERATED_BYTES +async def receive_generate(vrf_request: ItemHash) -> APIResponse[VRFResponseHash]: + global SAVED_GENERATED_BYTES, ANSWERED_REQUESTS private_key = get_fallback_private_key() account = ETHAccount(private_key=private_key) - print(settings.API_HOST) async with AlephClient(api_server=settings.API_HOST) as client: - message = await client.get_message(item_hash=vrf_request) + message = await _get_message(client=client, item_hash=vrf_request) generation_request = generate_request_from_message(message) + if generation_request.request_id in ANSWERED_REQUESTS: + raise fastapi.HTTPException( + status_code=409, + detail=f"A random number has already been generated for request {vrf_request}", + ) + generated_bytes, hashed_bytes = generate( generation_request.nb_bytes, generation_request.nonce ) SAVED_GENERATED_BYTES[str(generation_request.execution_id)] = generated_bytes + ANSWERED_REQUESTS.add(generation_request.request_id) response_hash = VRFResponseHash( nb_bytes=generation_request.nb_bytes, @@ -86,19 +127,19 @@ async def receive_generate(vrf_request: str) -> APIResponse[VRFResponseHash]: @app.post("/publish/{hash_message}") -async def receive_publish(hash_message: str) -> APIResponse[VRFRandomBytes]: +async def receive_publish(hash_message: ItemHash) -> APIResponse[VRFRandomBytes]: global SAVED_GENERATED_BYTES private_key = get_fallback_private_key() account = ETHAccount(private_key=private_key) async with AlephClient(api_server=settings.API_HOST) as client: - message = await client.get_message(item_hash=hash_message) + message = await _get_message(client=client, item_hash=hash_message) response_hash = generate_response_hash_from_message(message) - if not SAVED_GENERATED_BYTES[str(response_hash.execution_id)]: - raise ValueError( - f"Random bytes not existing for execution {response_hash.execution_id}" + if response_hash.execution_id not in SAVED_GENERATED_BYTES: + raise fastapi.HTTPException( + status_code=404, detail="The random number has already been published" ) random_bytes: bytes = SAVED_GENERATED_BYTES.pop(str(response_hash.execution_id)) diff --git a/src/aleph_vrf/models.py b/src/aleph_vrf/models.py index 2c63006..417e1dc 100644 --- a/src/aleph_vrf/models.py +++ b/src/aleph_vrf/models.py @@ -1,8 +1,9 @@ from typing import List, Optional, TypeVar, Generic from uuid import uuid4 +import fastapi from aleph_message.models import ItemHash, PostMessage -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError, Field from pydantic.generics import GenericModel @@ -25,19 +26,19 @@ class VRFGenerationRequest(BaseModel): nb_bytes: int nonce: int request_id: str - execution_id: str + execution_id: str = Field(default_factory=lambda: str(uuid4())) vrf_function: ItemHash def generate_request_from_message(message: PostMessage) -> VRFGenerationRequest: content = message.content.content - return VRFGenerationRequest( - nb_bytes=content["nb_bytes"], - nonce=content["nonce"], - request_id=content["request_id"], - execution_id=str(uuid4()), - vrf_function=ItemHash(content["vrf_function"]), - ) + try: + return VRFGenerationRequest.parse_obj(content) + except ValidationError as e: + raise fastapi.HTTPException( + status_code=422, + detail=f"Could not parse content of {message.item_hash} as VRF request object: {e.json()}", + ) class VRFResponseHash(BaseModel): @@ -52,15 +53,16 @@ class VRFResponseHash(BaseModel): def generate_response_hash_from_message(message: PostMessage) -> VRFResponseHash: content = message.content.content - return VRFResponseHash( - nb_bytes=content["nb_bytes"], - nonce=content["nonce"], - request_id=content["request_id"], - execution_id=content["execution_id"], - vrf_request=ItemHash(content["vrf_request"]), - random_bytes_hash=content["random_bytes_hash"], - message_hash=message.item_hash, - ) + try: + response_hash = VRFResponseHash.parse_obj(content) + except ValidationError as e: + raise fastapi.HTTPException( + 422, + detail=f"Could not parse content of {message.item_hash} as VRF response hash object: {e.json()}", + ) + + response_hash.message_hash = message.item_hash + return response_hash class VRFRandomBytes(BaseModel): diff --git a/tests/executor/test_integration.py b/tests/executor/test_integration.py index 6c5c785..e37de4f 100644 --- a/tests/executor/test_integration.py +++ b/tests/executor/test_integration.py @@ -221,27 +221,8 @@ async def test_normal_request_flow( assert random_number_message.chain == random_hash_message.chain -@pytest.mark.asyncio -async def test_call_publish_before_generate( - executor_client: aiohttp.ClientSession, - published_vrf_request: Tuple[MessageDict, VRFRequest], -): - """ - Test that calling /publish before /generate does not leak any data. - """ - - # Create the coordinator request - message_dict, vrf_request = published_vrf_request - item_hash = message_dict["item_hash"] - - # Use the item_hash of an existing POST message, just for fun - resp = await executor_client.post(f"/publish/{item_hash}") - assert resp.status == 404 - - @pytest.mark.asyncio async def test_call_publish_twice( - mock_ccn_client: aiohttp.ClientSession, executor_client: aiohttp.ClientSession, published_vrf_request: Tuple[MessageDict, VRFRequest], ): @@ -297,11 +278,14 @@ async def test_call_generate_twice( @pytest.mark.asyncio async def test_call_generate_without_aleph_message( - mock_ccn_client: aiohttp.ClientSession, executor_client: aiohttp.ClientSession, mock_vrf_request: VRFRequest, ): """ Test that calling POST /generate without an aleph message fails. """ - ... + item_hash = "bad0" * 16 + + # Call POST /generate with a nonexistent item hash + resp = await executor_client.post(f"/generate/{item_hash}") + assert resp.status == 404, await resp.text()