diff --git a/src/aleph_vrf/executor/main.py b/src/aleph_vrf/executor/main.py index a88b2c9..4e9aa39 100644 --- a/src/aleph_vrf/executor/main.py +++ b/src/aleph_vrf/executor/main.py @@ -47,7 +47,7 @@ async def index(): @app.post("/generate/{vrf_request}") -async def receive_generate(vrf_request: str) -> APIResponse: +async def receive_generate(vrf_request: str) -> APIResponse[VRFResponseHash]: global SAVED_GENERATED_BYTES private_key = get_fallback_private_key() @@ -87,7 +87,7 @@ async def receive_generate(vrf_request: str) -> APIResponse: @app.post("/publish/{hash_message}") -async def receive_publish(hash_message: str) -> APIResponse: +async def receive_publish(hash_message: str) -> APIResponse[VRFRandomBytes]: global SAVED_GENERATED_BYTES private_key = get_fallback_private_key() diff --git a/src/aleph_vrf/models.py b/src/aleph_vrf/models.py index 7434de5..e2d57b5 100644 --- a/src/aleph_vrf/models.py +++ b/src/aleph_vrf/models.py @@ -1,8 +1,9 @@ -from typing import Any, List, Optional +from typing import Any, List, Optional, TypeVar, Generic from uuid import UUID, uuid4 from aleph_message.models import ItemHash, PostMessage from pydantic import BaseModel +from pydantic.generics import GenericModel class Node(BaseModel): @@ -93,5 +94,8 @@ class VRFResponse(BaseModel): message_hash: Optional[str] = None -class APIResponse(BaseModel): - data: Any +M = TypeVar("M", bound=BaseModel) + + +class APIResponse(GenericModel, Generic[M]): + data: M diff --git a/tests/executor/test_integration.py b/tests/executor/test_integration.py index cefd136..b5ca4ec 100644 --- a/tests/executor/test_integration.py +++ b/tests/executor/test_integration.py @@ -1,12 +1,12 @@ import datetime as dt from hashlib import sha256 -from typing import Dict, Any +from typing import Dict, Any, Tuple, Union import aiohttp import pytest -from aleph_message.models import ItemType, MessageType, PostContent, Chain +from aleph_message.models import ItemType, MessageType, PostContent, Chain, ItemHash -from aleph_vrf.models import VRFRequest +from aleph_vrf.models import VRFRequest, VRFResponseHash, VRFResponse @pytest.mark.asyncio @@ -20,24 +20,16 @@ async def test_mock(mock_ccn: str): assert messages -@pytest.fixture -def mock_coordinator_message(): - sender = "aleph_vrf_coordinator" - request_id = "513eb52c-cb74-463a-b40e-0e2adedafb8b" +MessageDict = Dict[str, Any] - vrf_request = VRFRequest( - nb_bytes=32, - nonce=42, - vrf_function="deca" * 16, - request_id=request_id, - node_list_hash="1234", - ) + +def make_post_message(vrf_object: Union[VRFRequest, VRFResponse], sender: str) -> MessageDict: content = PostContent( type="vrf_library_post", - ref=f"vrf_{vrf_request.request_id}_request", + ref=f"vrf_{vrf_object.request_id}_request", address=sender, time=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc).timestamp(), - content=vrf_request, + content=vrf_object, ) item_content = content.json() @@ -51,24 +43,71 @@ def mock_coordinator_message(): "signature": "fake-sig", "item_type": ItemType.inline.value, "item_content": item_content, - "channel": f"vrf_{request_id}", + "channel": f"vrf_{vrf_object.request_id}", "time": content.time, } +@pytest.fixture +def mock_vrf_request() -> VRFRequest: + request_id = "513eb52c-cb74-463a-b40e-0e2adedafb8b" + + vrf_request = VRFRequest( + nb_bytes=32, + nonce=42, + vrf_function="deca" * 16, + request_id=request_id, + node_list_hash="1234", + ) + return vrf_request + + +# @pytest.fixture +# def mock_vrf_response(mock_vrf_request: VRFRequest) -> VRFResponse: +# return VRFResponse(nb_bytes=mock_vrf_request.nb_bytes, nonce=mock_vrf_request.nonce, vrf_function=mock_vrf_request.vrf_function, nodes=[], random) + + @pytest.mark.asyncio -async def test_start_request( - mock_ccn: str, executor_server: str, mock_coordinator_message: Dict[str, Any] +async def test_normal_request_flow( + mock_ccn: str, + executor_server: str, + mock_vrf_request: VRFRequest, + # mock_vrf_response: VRFResponse, ): + """ + Test that the executor works under normal circumstances: + 1. The coordinator publishes a request message + 2. The coordinator calls /generate + 3. The coordinator publishes a message to request the generated random number + 4. The coordinator calls /publish. + """ + + sender = "aleph_vrf_coordinator" + message_dict = make_post_message(mock_vrf_request, sender=sender) + item_hash = ItemHash(message_dict["item_hash"]) + async with aiohttp.ClientSession(mock_ccn) as ccn_client: resp = await ccn_client.post( "/api/v0/messages", - json={"message": mock_coordinator_message, "sync": True}, + json={"message": message_dict, "sync": True}, ) assert resp.status == 200, await resp.text() async with aiohttp.ClientSession(executor_server) as executor_client: - resp = await executor_client.post( - f"/generate/{mock_coordinator_message['item_hash']}" - ) + resp = await executor_client.post(f"/generate/{item_hash}") + assert resp.status == 200, await resp.text() + response_json = await resp.json() + + response_hash = VRFResponseHash.parse_obj(response_json["data"]) + + assert response_hash.nb_bytes == mock_vrf_request.nb_bytes + assert response_hash.nonce == mock_vrf_request.nonce + assert response_hash.request_id == mock_vrf_request.request_id + assert response_hash.execution_id # This should be a UUID4 + assert response_hash.vrf_request == item_hash + assert response_hash.random_bytes_hash + assert response_hash.message_hash + + resp = await executor_client.post(f"/publish/{response_hash.message_hash}") + assert resp.status == 200, await resp.text() print(await resp.text())