Skip to content

Commit

Permalink
Merge pull request #293 from TUMFARSynchrony/frontend-gk-asymmetric-f…
Browse files Browse the repository at this point in the history
…ilter

Asymmetric Filters Feature
  • Loading branch information
gulbikeimge authored Oct 7, 2024
2 parents da746bc + ba1bc8f commit ac82c9e
Show file tree
Hide file tree
Showing 25 changed files with 1,109 additions and 848 deletions.
179 changes: 124 additions & 55 deletions backend/connection/connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
"""Provide the `Connection` and `SubConnection` classes."""

from __future__ import annotations

import asyncio
import json
import logging
from typing import Any, Callable, Coroutine, Tuple

import shortuuid
from aiortc import (
RTCPeerConnection,
RTCDataChannel,
Expand All @@ -10,35 +16,30 @@
RTCRtpSender,
RTCIceCandidate,
)
from aiortc.contrib.media import MediaRelay
from aiortc.sdp import candidate_from_sdp

import shortuuid
import asyncio
import logging
import json

from connection.connection_interface import ConnectionInterface
from connection.connection_state import ConnectionState, parse_connection_state
from connection.messages import (
ConnectionAnswerDict,
ConnectionOfferDict,
ConnectionProposalDict,
AddIceCandidateDict,
)
from connection.sub_connection import SubConnection
from hub.track_handler import TrackHandler
from hub.exceptions import ErrorDictException
from connection.connection_interface import ConnectionInterface
from connection.connection_state import ConnectionState, parse_connection_state
from hub.record_handler import RecordHandler

from connection.util import find_participant_asymmetric_filter
from custom_types.error import ErrorDict
from custom_types.message import MessageDict, is_valid_messagedict
from filter_api import FilterAPIInterface
from filters import FilterDict
from filters.filter_data_dict import FilterDataDict
from filter_api import FilterAPIInterface
from custom_types.message import MessageDict, is_valid_messagedict
from hub.exceptions import ErrorDictException
from hub.record_handler import RecordHandler
from hub.track_handler import TrackHandler
from session.data.participant import ParticipantDict
from session.data.participant.participant_summary import ParticipantSummaryDict

from aiortc.contrib.media import MediaRelay


class Connection(ConnectionInterface):
"""Connection with a single client using multiple sub-connections.
Expand Down Expand Up @@ -78,14 +79,20 @@ class Connection(ConnectionInterface):
_audio_record_handler: RecordHandler
_video_record_handler: RecordHandler
_raw_video_record_handler: RecordHandler
_filter_api: FilterAPIInterface
_video_track: MediaStreamTrack
_audio_track: MediaStreamTrack
_relay: MediaRelay
_sub_connection_video_track_handlers: dict[str, TrackHandler]
_sub_connection_audio_track_handlers: dict[str, TrackHandler]

def __init__(
self,
pc: RTCPeerConnection,
message_handler: Callable[[MessageDict], Coroutine[Any, Any, None]],
log_name_suffix: str,
filter_api: FilterAPIInterface,
record_data: tuple,
self,
pc: RTCPeerConnection,
message_handler: Callable[[MessageDict], Coroutine[Any, Any, None]],
log_name_suffix: str,
filter_api: FilterAPIInterface,
record_data: tuple,
) -> None:
"""Create new Connection based on a aiortc.RTCPeerConnection.
Expand Down Expand Up @@ -121,6 +128,9 @@ def __init__(
self._message_handler = message_handler
self._incoming_audio = TrackHandler("audio", self, filter_api)
self._incoming_video = TrackHandler("video", self, filter_api)
self._relay = MediaRelay()
self._sub_connection_audio_track_handlers = dict()
self._sub_connection_video_track_handlers = dict()

(record, record_to) = record_data
self._audio_record_handler = RecordHandler(
Expand All @@ -139,18 +149,24 @@ def __init__(

self._dc = None
self._tasks = []

self._filter_api = filter_api
# Register event handlers
pc.add_listener("datachannel", self._on_datachannel)
pc.add_listener("connectionstatechange", self._on_connection_state_change)
pc.add_listener("track", self._on_track)

def _set_video_track(self, track: MediaStreamTrack):
self._video_track = track

def _set_audio_track(self, track: MediaStreamTrack):
self._audio_track = track

async def complete_setup(
self,
audio_filters: list[FilterDict],
video_filters: list[FilterDict],
audio_group_filters: list[FilterDict],
video_group_filters: list[FilterDict],
self,
audio_filters: list[FilterDict],
video_filters: list[FilterDict],
audio_group_filters: list[FilterDict],
video_group_filters: list[FilterDict],
) -> None:
"""Complete Connection setup.
Expand Down Expand Up @@ -230,21 +246,59 @@ def state(self) -> ConnectionState:
# For docstring see ConnectionInterface or hover over function declaration
return self._state

async def create_subscriber_proposal(
self, participant_summary: ParticipantSummaryDict | str | None
) -> ConnectionProposalDict:
# For docstring see ConnectionInterface or hover over function declaration

subconnection_id = shortuuid.uuid()
async def _create_sub_connection_media_tracks(
self,
sub_connection_id: str,
participant_summary: ParticipantSummaryDict | str | None,
subscriber: ParticipantDict | None = None
) -> tuple[TrackHandler, TrackHandler] | tuple[MediaStreamTrack, MediaStreamTrack]:
if subscriber is not None:
asymmetric_filter = find_participant_asymmetric_filter(participant_summary, subscriber)
if asymmetric_filter is not None:
audio_track_handler = TrackHandler("audio", self, self._filter_api)
video_track_handler = TrackHandler("video", self, self._filter_api)
await asyncio.gather(
audio_track_handler.complete_setup(asymmetric_filter["audio_filters"],
subscriber["audio_group_filters"]),
audio_track_handler.set_track(self._relay.subscribe(self._audio_track, False)),
video_track_handler.complete_setup(asymmetric_filter["video_filters"],
subscriber["video_group_filters"]),
video_track_handler.set_track(self._relay.subscribe(self._video_track, False))

)
self._sub_connection_audio_track_handlers[sub_connection_id] = audio_track_handler
self._sub_connection_video_track_handlers[sub_connection_id] = video_track_handler
return audio_track_handler, video_track_handler

return self._incoming_audio.subscribe(), self._incoming_video.subscribe()

async def _create_sub_connection(
self,
participant_summary: ParticipantSummaryDict | str | None,
subscriber: ParticipantDict | None = None
) -> SubConnection:
sub_connection_id = shortuuid.uuid()
audio_track, video_track = await self._create_sub_connection_media_tracks(
sub_connection_id,
participant_summary,
subscriber
)
sc = SubConnection(
subconnection_id,
self._incoming_video.subscribe(),
self._incoming_audio.subscribe(),
sub_connection_id,
video_track,
audio_track,
participant_summary,
self._log_name_suffix,
)
sc.add_listener("connection_closed", self._handle_closed_subconnection)
self._sub_connections[subconnection_id] = sc
self._sub_connections[sub_connection_id] = sc
return sc

async def create_subscriber_proposal(
self, participant_summary: ParticipantSummaryDict | str | None,
subscriber: ParticipantDict | None = None
) -> ConnectionProposalDict:
sc: SubConnection = await self._create_sub_connection(participant_summary, subscriber)
return sc.proposal

async def handle_add_ice_candidate(self, candidate: RTCIceCandidate):
Expand All @@ -265,7 +319,7 @@ async def handle_add_ice_candidate(self, candidate: RTCIceCandidate):
await self._main_pc.addIceCandidate(rtc_candidate)

async def handle_subscriber_offer(
self, offer: ConnectionOfferDict
self, offer: ConnectionOfferDict
) -> ConnectionAnswerDict:
# For docstring see ConnectionInterface or hover over function declaration
subconnection_id = offer["id"]
Expand All @@ -284,7 +338,7 @@ async def handle_subscriber_offer(
return await sc.handle_offer(offer_description)

async def handle_subscriber_add_ice_candidate(
self, candidate: AddIceCandidateDict
self, candidate: AddIceCandidateDict
):
# For docstring see ConnectionInterface or hover over function declaration
subconnection_id = candidate["id"]
Expand Down Expand Up @@ -328,11 +382,16 @@ async def set_muted(self, video: bool, audio: bool) -> None:
self._incoming_video.muted = video
if self._incoming_audio is not None:
self._incoming_audio.muted = audio
for sub_connection_id in self._sub_connection_video_track_handlers:
self._sub_connection_video_track_handlers[sub_connection_id].muted = video
for sub_connection_id in self._sub_connection_audio_track_handlers:
self._sub_connection_audio_track_handlers[sub_connection_id].muted = audio

async def start_recording(self) -> None:
# For docstring see ConnectionInterface or hover over function declaration
await asyncio.gather(
self._video_record_handler.start(), self._raw_video_record_handler.start(), self._audio_record_handler.start()
self._video_record_handler.start(), self._raw_video_record_handler.start(),
self._audio_record_handler.start()
)

async def stop_recording(self) -> None:
Expand All @@ -350,13 +409,13 @@ async def set_audio_filters(self, filters: list[FilterDict]) -> None:
await self._incoming_audio.set_filters(filters)

async def set_video_group_filters(
self, group_filters: list[FilterDict], ports: list[tuple[int, int]]
self, group_filters: list[FilterDict], ports: list[tuple[int, int]]
) -> None:
# For docstring see ConnectionInterface or hover over function declaration
await self._incoming_video.set_group_filters(group_filters, ports)

async def set_audio_group_filters(
self, group_filters: list[FilterDict], ports: list[tuple[int, int]]
self, group_filters: list[FilterDict], ports: list[tuple[int, int]]
) -> None:
# For docstring see ConnectionInterface or hover over function declaration
await self._incoming_audio.set_group_filters(group_filters, ports)
Expand All @@ -373,6 +432,14 @@ async def _handle_closed_subconnection(self, subconnection_id: str) -> None:
self._logger.debug(f"Remove sub connection {subconnection_id}")
self._sub_connections.pop(subconnection_id)
self._logger.debug(f"SubConnections after removing: {self._sub_connections}")
if subconnection_id in self._sub_connection_video_track_handlers:
self._sub_connection_video_track_handlers.pop(subconnection_id)
self._logger.debug(
f"Sub-connection video track handler after removing: {self._sub_connection_video_track_handlers}")
if subconnection_id in self._sub_connection_audio_track_handlers:
self._sub_connection_audio_track_handlers.pop(subconnection_id)
self._logger.debug(
f"Sub-connection audio track handler after removing: {self._sub_connection_audio_track_handlers}")

def _on_datachannel(self, channel: RTCDataChannel) -> None:
"""Handle new incoming datachannel.
Expand All @@ -396,10 +463,10 @@ async def _handle_datachannel_close(self) -> None:
async def _check_if_closed(self) -> None:
"""Check if this Connection is closed and should stop."""
if (
self._incoming_audio is not None
and self._incoming_video is not None
and self._incoming_audio.readyState == "ended"
and self._incoming_video.readyState == "ended"
self._incoming_audio is not None
and self._incoming_video is not None
and self._incoming_audio.readyState == "ended"
and self._incoming_video.readyState == "ended"
):
self._logger.debug("All incoming tracks are closed -> stop connection")
await self.stop()
Expand Down Expand Up @@ -490,13 +557,15 @@ def _on_track(self, track: MediaStreamTrack):
self._logger.debug(f"{track.kind} track received")
if track.kind == "audio":
task = asyncio.create_task(self._incoming_audio.set_track(track))
self._set_audio_track(track)
sender = self._main_pc.addTrack(self._incoming_audio.subscribe())
self._audio_record_handler.add_track(self._incoming_audio.subscribe())
self._listen_to_track_close(self._incoming_audio, sender)
elif track.kind == "video":
# Use relay to be able to access track multiple times
relay = MediaRelay()
task = asyncio.create_task(self._incoming_video.set_track(relay.subscribe(track, False)))
self._set_video_track(track)
sender = self._main_pc.addTrack(self._incoming_video.subscribe())
self._video_record_handler.add_track(self._incoming_video.subscribe())
self._raw_video_record_handler.add_track(relay.subscribe(track, False))
Expand Down Expand Up @@ -538,15 +607,15 @@ async def _close_transceiver():


async def connection_factory(
offer: RTCSessionDescription,
message_handler: Callable[[MessageDict], Coroutine[Any, Any, None]],
log_name_suffix: str,
audio_filters: list[FilterDict],
video_filters: list[FilterDict],
audio_group_filters: list[FilterDict],
video_group_filters: list[FilterDict],
filter_api: FilterAPIInterface,
record_data: list,
offer: RTCSessionDescription,
message_handler: Callable[[MessageDict], Coroutine[Any, Any, None]],
log_name_suffix: str,
audio_filters: list[FilterDict],
video_filters: list[FilterDict],
audio_group_filters: list[FilterDict],
video_group_filters: list[FilterDict],
filter_api: FilterAPIInterface,
record_data: list,
) -> Tuple[RTCSessionDescription, Connection]:
"""Instantiate Connection.
Expand Down
3 changes: 2 additions & 1 deletion backend/connection/connection_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def state(self) -> ConnectionState:

@abstractmethod
async def create_subscriber_proposal(
self, participant_summary: ParticipantSummaryDict | str | None
self, participant_summary: ParticipantSummaryDict | str | None,
subscriber=None
) -> ConnectionProposalDict:
"""Create a SubConnection proposal.
Expand Down
2 changes: 1 addition & 1 deletion backend/connection/connection_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ async def _handle_message(self, msg: dict):
case "SEND":
await self._connection.send(data)
case "CREATE_PROPOSAL":
proposal = await self._connection.create_subscriber_proposal(data)
proposal = await self._connection.create_subscriber_proposal(data["participant_summary"], data["subscriber"] if "subscriber" in data else None)
self._send_command("CONNECTION_PROPOSAL", proposal, command_nr)
case "HANDLE_OFFER":
try:
Expand Down
Loading

0 comments on commit ac82c9e

Please sign in to comment.