Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove connection protocol #3184

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 0 additions & 23 deletions aries_cloudagent/config/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,15 +297,6 @@ def add_arguments(self, parser: ArgumentParser):
"invitation URL. Default: false."
),
)
parser.add_argument(
"--connections-invite",
action="store_true",
env_var="ACAPY_CONNECTIONS_INVITE",
help=(
"After startup, generate and print a new connections protocol "
"style invitation URL. Default: false."
),
)
parser.add_argument(
"--invite-label",
dest="invite_label",
Expand Down Expand Up @@ -445,8 +436,6 @@ def get_settings(self, args: Namespace) -> dict:
settings["debug.seed"] = args.debug_seed
if args.invite:
settings["debug.print_invitation"] = True
if args.connections_invite:
settings["debug.print_connections_invitation"] = True
if args.invite_label:
settings["debug.invite_label"] = args.invite_label
if args.invite_multi_use:
Expand Down Expand Up @@ -1467,24 +1456,12 @@ def add_arguments(self, parser: ArgumentParser):
"and send mediation request and set as default mediator."
),
)
parser.add_argument(
"--mediator-connections-invite",
action="store_true",
env_var="ACAPY_MEDIATION_CONNECTIONS_INVITE",
help=(
"Connect to mediator through a connection invitation. "
"If not specified, connect using an OOB invitation. "
"Default: false."
),
)

def get_settings(self, args: Namespace):
"""Extract mediation invitation settings."""
settings = {}
if args.mediator_invitation:
settings["mediation.invite"] = args.mediator_invitation
if args.mediator_connections_invite:
settings["mediation.connections_invite"] = True

return settings

Expand Down
12 changes: 11 additions & 1 deletion aries_cloudagent/config/default_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from ..anoncreds.registry import AnonCredsRegistry
from ..cache.base import BaseCache
from ..cache.in_memory import InMemoryCache
from ..connections.base_manager import BaseConnectionManager
from ..core.event_bus import EventBus
from ..core.goal_code_registry import GoalCodeRegistry
from ..core.plugin_registry import PluginRegistry
from ..core.profile import ProfileManager, ProfileManagerProvider
from ..core.profile import Profile, ProfileManager, ProfileManagerProvider
from ..core.protocol_registry import ProtocolRegistry
from ..protocols.actionmenu.v1_0.base_service import BaseMenuService
from ..protocols.actionmenu.v1_0.driver_service import DriverMenuService
Expand Down Expand Up @@ -117,6 +118,12 @@ async def bind_providers(self, context: InjectionContext):
context.injector.bind_instance(BaseMenuService, DriverMenuService(context))
context.injector.bind_instance(BaseIntroductionService, DemoIntroductionService())

# Allow BaseConnectionManager to be overridden
context.injector.bind_provider(
BaseConnectionManager,
ClassProvider(BaseConnectionManager, ClassProvider.Inject(Profile)),
)

async def load_plugins(self, context: InjectionContext):
"""Set up plugin registry and load plugins."""

Expand All @@ -126,6 +133,9 @@ async def load_plugins(self, context: InjectionContext):
wallet_type = self.settings.get("wallet.type")
context.injector.bind_instance(PluginRegistry, plugin_registry)

# Connection management endpoints
plugin_registry.register_plugin("aries_cloudagent.connections")

# Register standard protocol plugins
if not self.settings.get("transport.disabled"):
plugin_registry.register_package("aries_cloudagent.protocols")
Expand Down
60 changes: 22 additions & 38 deletions aries_cloudagent/connections/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import logging
from typing import Dict, List, Optional, Sequence, Text, Tuple, Union
import warnings

import pydid
from base58 import b58decode
Expand All @@ -28,10 +29,7 @@
from ..core.profile import Profile
from ..did.did_key import DIDKey
from ..multitenant.base import BaseMultitenantManager
from ..protocols.connections.v1_0.message_types import ARIES_PROTOCOL as CONN_PROTO
from ..protocols.connections.v1_0.messages.connection_invitation import (
ConnectionInvitation,
)
from ..protocols.didexchange.v1_0.message_types import ARIES_PROTOCOL as CONN_PROTO
from ..protocols.coordinate_mediation.v1_0.models.mediation_record import (
MediationRecord,
)
Expand Down Expand Up @@ -301,6 +299,8 @@ async def create_did_document(
) -> DIDDoc:
"""Create our DID doc for a given DID.

This method is deprecated and will be removed.

Args:
did_info (DIDInfo): The DID information (DID and verkey) used in the
connection.
Expand All @@ -313,6 +313,7 @@ async def create_did_document(
DIDDoc: The prepared `DIDDoc` instance.

"""
warnings.warn("create_did_document is deprecated and will be removed soon")
did_doc = DIDDoc(did=did_info.did)
did_controller = did_info.did
did_key = did_info.verkey
Expand Down Expand Up @@ -615,7 +616,7 @@ def _extract_key_material_in_base58_format(method: VerificationMethod) -> str:
async def _fetch_connection_targets_for_invitation(
self,
connection: ConnRecord,
invitation: Union[ConnectionInvitation, InvitationMessage],
invitation: InvitationMessage,
sender_verkey: str,
) -> Sequence[ConnectionTarget]:
"""Get a list of connection targets for an invitation.
Expand All @@ -625,48 +626,31 @@ async def _fetch_connection_targets_for_invitation(

Args:
connection (ConnRecord): The connection record associated with the invitation.
invitation (Union[ConnectionInvitation, InvitationMessage]): The connection
invitation (InvitationMessage): The connection
or OOB invitation retrieved from the connection record.
sender_verkey (str): The sender's verification key.

Returns:
Sequence[ConnectionTarget]: A list of `ConnectionTarget` objects
representing the connection targets for the invitation.
"""
if isinstance(invitation, ConnectionInvitation):
# conn protocol invitation
if invitation.did:
did = invitation.did
(
endpoint,
recipient_keys,
routing_keys,
) = await self.resolve_invitation(did)
# out-of-band invitation
oob_service_item = invitation.services[0]
if isinstance(oob_service_item, str):
(
endpoint,
recipient_keys,
routing_keys,
) = await self.resolve_invitation(oob_service_item)

else:
endpoint = invitation.endpoint
recipient_keys = invitation.recipient_keys
routing_keys = invitation.routing_keys
else:
# out-of-band invitation
oob_service_item = invitation.services[0]
if isinstance(oob_service_item, str):
(
endpoint,
recipient_keys,
routing_keys,
) = await self.resolve_invitation(oob_service_item)

else:
endpoint = oob_service_item.service_endpoint
recipient_keys = [
DIDKey.from_did(k).public_key_b58
for k in oob_service_item.recipient_keys
]
routing_keys = [
DIDKey.from_did(k).public_key_b58
for k in oob_service_item.routing_keys
]
endpoint = oob_service_item.service_endpoint
recipient_keys = [
DIDKey.from_did(k).public_key_b58 for k in oob_service_item.recipient_keys
]
routing_keys = [
DIDKey.from_did(k).public_key_b58 for k in oob_service_item.routing_keys
]

return [
ConnectionTarget(
Expand Down
36 changes: 8 additions & 28 deletions aries_cloudagent/connections/models/conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,6 @@
INDY_RAW_PUBLIC_KEY_VALIDATE,
UUID4_EXAMPLE,
)
from ...protocols.connections.v1_0.message_types import ARIES_PROTOCOL as CONN_PROTO
from ...protocols.connections.v1_0.message_types import (
CONNECTION_INVITATION,
CONNECTION_REQUEST,
)
from ...protocols.connections.v1_0.messages.connection_invitation import (
ConnectionInvitation,
)
from ...protocols.connections.v1_0.messages.connection_request import ConnectionRequest
from ...protocols.didcomm_prefix import DIDCommPrefix
from ...protocols.didexchange.v1_0.message_types import ARIES_PROTOCOL as DIDEX_1_1
from ...protocols.didexchange.v1_0.message_types import DIDEX_1_0
from ...protocols.didexchange.v1_0.messages.request import DIDXRequest
Expand All @@ -44,7 +34,7 @@ class Meta:

schema_class = "MaybeStoredConnRecordSchema"

SUPPORTED_PROTOCOLS = (CONN_PROTO, DIDEX_1_0, DIDEX_1_1)
SUPPORTED_PROTOCOLS = (DIDEX_1_0, DIDEX_1_1)

class Role(Enum):
"""RFC 160 (inviter, invitee) = RFC 23 (responder, requester)."""
Expand Down Expand Up @@ -430,7 +420,7 @@ async def retrieve_by_alias(cls, session: ProfileSession, alias: str) -> "ConnRe
async def attach_invitation(
self,
session: ProfileSession,
invitation: Union[ConnectionInvitation, OOBInvitation],
invitation: OOBInvitation,
):
"""Persist the related connection invitation to storage.

Expand All @@ -447,9 +437,7 @@ async def attach_invitation(
storage = session.inject(BaseStorage)
await storage.add_record(record)

async def retrieve_invitation(
self, session: ProfileSession
) -> Union[ConnectionInvitation, OOBInvitation]:
async def retrieve_invitation(self, session: ProfileSession) -> OOBInvitation:
"""Retrieve the related connection invitation.

Args:
Expand All @@ -462,16 +450,12 @@ async def retrieve_invitation(
{"connection_id": self.connection_id},
)
ser = json.loads(result.value)
return (
ConnectionInvitation
if DIDCommPrefix.unqualify(ser["@type"]) == CONNECTION_INVITATION
else OOBInvitation
).deserialize(ser)
return OOBInvitation.deserialize(ser)

async def attach_request(
self,
session: ProfileSession,
request: Union[ConnectionRequest, DIDXRequest],
request: DIDXRequest,
):
"""Persist the related connection request to storage.

Expand All @@ -491,7 +475,7 @@ async def attach_request(
async def retrieve_request(
self,
session: ProfileSession,
) -> Union[ConnectionRequest, DIDXRequest]:
) -> DIDXRequest:
"""Retrieve the related connection invitation.

Args:
Expand All @@ -503,11 +487,7 @@ async def retrieve_request(
self.RECORD_TYPE_REQUEST, {"connection_id": self.connection_id}
)
ser = json.loads(result.value)
return (
ConnectionRequest
if DIDCommPrefix.unqualify(ser["@type"]) == CONNECTION_REQUEST
else DIDXRequest
).deserialize(ser)
return DIDXRequest.deserialize(ser)

@property
def is_ready(self) -> str:
Expand Down Expand Up @@ -709,7 +689,7 @@ class Meta:
validate=validate.OneOf(ConnRecord.SUPPORTED_PROTOCOLS),
metadata={
"description": "Connection protocol used",
"example": "connections/1.0",
"example": "didexchange/1.1",
},
)
rfc23_state = fields.Str(
Expand Down
Loading
Loading