Skip to content
This repository has been archived by the owner on Nov 22, 2024. It is now read-only.

Commit

Permalink
Sending to outbox working again
Browse files Browse the repository at this point in the history
Signed-off-by: John Andersen <[email protected]>
  • Loading branch information
pdxjohnny committed Oct 23, 2023
1 parent 035beb7 commit 82cb35c
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 107 deletions.
11 changes: 0 additions & 11 deletions scitt_emulator/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,9 @@ def __init__(self, app, signals: SCITTSignals, config_path: Path):
self.app = app
self.asgi_app = app.asgi_app
self.signals = signals
self.connect_signals()
self.config = {}
if config_path and config_path.exists():
self.config = json.loads(config_path.read_text())

async def __call__(self, scope, receive, send):
return await self.asgi_app(scope, receive, send)

def connect_signals(self):
self.created_entry = self.signals.federation.created_entry.connect(self.created_entry)

@abstractmethod
def created_entry(
self,
created_entry: SCITTSignalsFederationCreatedEntry,
):
raise NotImplementedError
118 changes: 34 additions & 84 deletions scitt_emulator/federation_activitypub_bovine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import json
import types
import atexit
import base64
import socket
Expand All @@ -8,6 +9,7 @@
import asyncio
import pathlib
import tempfile
import functools
import traceback
import contextlib
import subprocess
Expand All @@ -30,7 +32,7 @@
from scitt_emulator.scitt import SCITTServiceEmulator
from scitt_emulator.federation import SCITTFederation
from scitt_emulator.tree_algs import TREE_ALGS
from scitt_emulator.signals import SCITTSignalsFederationCreatedEntry
from scitt_emulator.signals import SCITTSignals, SCITTSignalsFederationCreatedEntry

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -80,17 +82,14 @@ def __init__(self, app, signals, config_path):
BovinePubSub(app)
BovineHerd(app)

@app.before_serving
async def initialize_service():
await self.initialize_service()
# asyncio.create_task(self.initialize_service())
# app.add_background_task(self.initialize_service)
app.before_serving(self.initialize_service)

async def initialize_service(self):
# TODO Better domain / fqdn building
self.domain = f'http://127.0.0.1:{self.app.config["port"]}'

config_toml_path = pathlib.Path(self.workspace, "config.toml")
config_toml_path.unlink()
if not config_toml_path.exists():
logger.info("Actor client config does not exist, creating...")
cmd = [
Expand All @@ -112,8 +111,6 @@ async def initialize_service(self):
config_toml_obj[self.handle_name]["handlers"][
inspect.getmodule(sys.modules[__name__]).__spec__.name
] = {
# TODO Sending signal to submit federated claim
# signals.federation.submit_claim.send(self, claim=created_entry.claim)
"signals": self.signals,
"following": self.config.get("following", {}),
}
Expand All @@ -139,94 +136,50 @@ async def initialize_service(self):
logger.info("Actor key added in database")

# Run client handlers
"""
cmd = [
sys.executable,
"-um",
"mechanical_bull.run",
]
self.mechanical_bull_proc = subprocess.Popen(
cmd,
cwd=self.workspace,
)
atexit.register(self.mechanical_bull_proc.terminate)
"""

def build_handler(handler, value):
import importlib
from functools import partial

func = importlib.import_module(handler).handle

if isinstance(value, dict):
return partial(func, **value)
return func

def load_handlers(handlers):
return [build_handler(handler, value) for handler, value in handlers.items()]

async def mechanical_bull_loop(config):
from mechanical_bull.event_loop import loop
from mechanical_bull.handlers import load_handlers, build_handler

async with asyncio.TaskGroup() as taskgroup:
for client_name, value in config.items():
if isinstance(value, dict):
handlers = load_handlers(value["handlers"])
taskgroup.create_task(loop(client_name, value, handlers))

# self.app.add_background_task(mechanical_bull_loop, config_toml_obj)

def created_entry(
self,
scitt_service: SCITTServiceEmulator,
created_entry: SCITTSignalsFederationCreatedEntry,
):
return
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client:
client.connect(str(self.federate_created_entries_socket_path.resolve()))
client.send(
json.dumps(
{
"treeAlgorithm": created_entry.tree_alg,
"service_parameters": base64.b64encode(
created_entry.public_service_parameters
).decode(),
"entry_id": created_entry.entry_id,
"receipt": base64.b64encode(created_entry.receipt).decode(),
"claim": base64.b64encode(created_entry.claim).decode(),
}
).encode()
)
client.close()
self.app.add_background_task(mechanical_bull_loop, config_toml_obj)


async def handle(
client: bovine.BovineClient,
data: dict,
# config.toml arguments
signals: SCITTSignals = None,
following: dict[str, Follow] = None,
federate_created_entries_socket_path: Path = None,
raise_on_follow_failure: bool = False,
# handler arguments
handler_event: HandlerEvent = None,
handler_api_version: HandlerAPIVersion = HandlerAPIVersion.unstable,
):
try:
logging.info(f"{__file__}:handle(handler_event={handler_event})")
logger.info(f"{__file__}:handle(handler_event={handler_event})")
match handler_event:
case HandlerEvent.OPENED:
# Listen for events from SCITT
asyncio.create_task(
federate_created_entries(
client, federate_created_entries_socket_path
)
)
# TODO Do this without using a client, server side
async def federate_created_entries_pass_client(
sender: SCITTServiceEmulator,
created_entry: SCITTSignalsFederationCreatedEntry = None,
):
nonlocal client
await federate_created_entries(client, sender, created_entry)
client.federate_created_entries = types.MethodType(signals.federation.created_entry.connect(federate_created_entries_pass_client), client)
# print(signals.federation.created_entry.connect(federate_created_entries))
# Preform ActivityPub related init
if following:
try:
async with asyncio.TaskGroup() as tg:
for key, value in following.items():
logging.info("Following... %r", value)
logger.info("Following... %r", value)
tg.create_task(init_follow(client, **value))
except (ExceptionGroup, BaseExceptionGroup) as error:
if raise_on_follow_failure:
Expand Down Expand Up @@ -275,19 +228,10 @@ async def handle(

logger.info("Receipt verified")

return
# Send signal to submit federated claim
# TODO Announce that this entry ID was created via
# federation to avoid an infinate loop
scitt_emulator.client.submit_claim(
home_scitt_url,
claim,
str(Path(tempdir, "home_receipt").resolve()),
str(Path(tempdir, "home_entry_id").resolve()),
scitt_emulator.client.HttpClient(
home_scitt_token,
home_scitt_cacert,
),
)
await signals.federation.submit_claim.send_async(client, claim=claim)
except Exception as ex:
logger.error(ex)
logger.exception(ex)
Expand Down Expand Up @@ -325,14 +269,23 @@ async def init_follow(client, retry: int = 5, **kwargs):
async def federate_created_entries(
client: bovine.BovineClient,
sender: SCITTServiceEmulator,
created_entry: SCITTSignalsFederationCreatedEntry = None,
):
try:
logger.info("federate_created_entry() Reading... %r", reader)
content_bytes = await reader.read()
logger.info("federate_created_entry() Read: %r", content_bytes)
logger.info("federate_created_entry() created_entry: %r", created_entry)
note = (
client.object_factory.note(
content=content_bytes.decode(),
content=json.dumps(
{
"treeAlgorithm": created_entry.tree_alg,
"service_parameters": base64.b64encode(
created_entry.public_service_parameters
).decode(),
"entry_id": created_entry.entry_id,
"receipt": base64.b64encode(created_entry.receipt).decode(),
"claim": base64.b64encode(created_entry.claim).decode(),
}
)
)
.as_public()
.build()
Expand All @@ -341,9 +294,6 @@ async def federate_created_entries(
logger.info("Sending... %r", activity)
await client.send_to_outbox(activity)

writer.close()
await writer.wait_closed()

# DEBUG NOTE Dumping outbox
print("client:", client)
outbox = client.outbox()
Expand Down
21 changes: 11 additions & 10 deletions scitt_emulator/scitt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from pathlib import Path
import contextlib
import asyncio
import time
import json
import uuid
Expand Down Expand Up @@ -75,8 +76,8 @@ def connect_signals(self):
self.signal_receiver_submit_claim,
)

def signal_receiver_submit_claim(self, _sender, claim: bytes) -> None:
self.submit_claim(claim, long_running=True)
async def signal_receiver_submit_claim(self, _sender, claim: bytes) -> None:
await self.submit_claim(claim, long_running=True)

@abstractmethod
def initialize_service(self):
Expand All @@ -90,7 +91,7 @@ def create_receipt_contents(self, countersign_tbi: bytes, entry_id: str):
def verify_receipt_contents(receipt_contents: list, countersign_tbi: bytes):
raise NotImplementedError

def get_operation(self, operation_id: str) -> dict:
async def get_operation(self, operation_id: str) -> dict:
operation_path = self.operations_path / f"{operation_id}.json"
try:
with open(operation_path, "r") as f:
Expand All @@ -101,7 +102,7 @@ def get_operation(self, operation_id: str) -> dict:
if operation["status"] == "running":
# Pretend that the service finishes the operation after
# the client having checked the operation status once.
operation = self._finish_operation(operation)
operation = await self._finish_operation(operation)
return operation

def get_entry(self, entry_id: str) -> dict:
Expand All @@ -121,7 +122,7 @@ def get_claim(self, entry_id: str) -> bytes:
raise EntryNotFoundError(f"Entry {entry_id} not found")
return claim

def submit_claim(self, claim: bytes, long_running=True) -> dict:
async def submit_claim(self, claim: bytes, long_running=True) -> dict:
insert_policy = self.service_parameters.get("insertPolicy", DEFAULT_INSERT_POLICY)

try:
Expand All @@ -137,7 +138,7 @@ def submit_claim(self, claim: bytes, long_running=True) -> dict:
f"non-* insertPolicy only works with long_running=True: {insert_policy!r}"
)
else:
return self._create_entry(claim)
return await self._create_entry(claim)

def public_service_parameters(self) -> bytes:
# TODO Only export public portion of cert
Expand All @@ -153,7 +154,7 @@ def get_entry_id(self, claim: bytes) -> str:
entry_id = f"{entry_id_hash_alg}:{entry_id_hash.hexdigest()}"
return entry_id

def _create_entry(self, claim: bytes) -> dict:
async def _create_entry(self, claim: bytes) -> dict:
entry_id = self.get_entry_id(claim)

receipt = self._create_receipt(claim, entry_id)
Expand All @@ -165,7 +166,7 @@ def _create_entry(self, claim: bytes) -> dict:

entry = {"entryId": entry_id}

self.signals.federation.created_entry.send(
await self.signals.federation.created_entry.send_async(
self,
created_entry=SCITTSignalsFederationCreatedEntry(
tree_alg=self.tree_alg,
Expand Down Expand Up @@ -234,7 +235,7 @@ def _sync_policy_result(self, operation: dict):

return policy_result

def _finish_operation(self, operation: dict):
async def _finish_operation(self, operation: dict):
operation_id = operation["operationId"]
operation_path = self.operations_path / f"{operation_id}.json"
claim_src_path = self.operations_path / f"{operation_id}.cose"
Expand All @@ -251,7 +252,7 @@ def _finish_operation(self, operation: dict):
return operation

claim = claim_src_path.read_bytes()
entry = self._create_entry(claim)
entry = await self._create_entry(claim)
claim_src_path.unlink()

operation["status"] = "succeeded"
Expand Down
4 changes: 2 additions & 2 deletions scitt_emulator/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ async def submit_claim():
return await make_unavailable_error()
try:
if use_lro:
result = app.scitt_service.submit_claim(await request.get_data(), long_running=True)
result = await app.scitt_service.submit_claim(await request.get_data(), long_running=True)
headers = {
"Location": f"{request.host_url}/operations/{result['operationId']}",
"Retry-After": "1"
}
status_code = 202
else:
result = app.scitt_service.submit_claim(await request.get_data(), long_running=False)
result = await app.scitt_service.submit_claim(await request.get_data(), long_running=False)
headers = {
"Location": f"{request.host_url}/entries/{result['entryId']}",
}
Expand Down

0 comments on commit 82cb35c

Please sign in to comment.