diff --git a/trio/__init__.py b/trio/__init__.py index d66ffceea9..8c56b8b5b8 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -60,6 +60,7 @@ open_memory_channel, MemorySendChannel, MemoryReceiveChannel, + StapledMemoryChannel, ) from ._signals import open_signal_receiver diff --git a/trio/_channel.py b/trio/_channel.py index dac7935c0c..26e0ac3483 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -5,7 +5,7 @@ from outcome import Error, Value from .abc import SendChannel, ReceiveChannel, Channel -from ._util import generic_function, NoPublicConstructor +from ._util import generic_function, NoPublicConstructor, Final import trio from ._core import enable_ki_protection @@ -344,3 +344,65 @@ async def aclose(self): self._state.send_tasks.clear() self._state.data.clear() await trio.lowlevel.checkpoint() + + +@attr.s(auto_attribs=True, eq=False, hash=False) +class StapledMemoryChannel(Channel, metaclass=Final): + """This class `staples `__ + together two memory channel halves to make a bidirectional channel. + + Args: + send_channel (~trio.MemorySendChannel): The channel to use for sending. + receive_channel (~trio.MemoryReceiveChannel): The channel to use for + receiving. + + Example: + + The channel halves from :function:~trio.open_memory_channel can + be bound together and accessed by a simple API:: + + channel = StapledMemoryChannel(*open_memory_channel(1)) + await channel.send("x") + assert await channel.receive() == "x" + + :class:`StapledMemoryChannel` objects implement the methods in the + :class:`~trio.abc.Channel` interface, as well as the "nowait" variants + of send and receive. They also have two additional public attributes: + + .. attribute:: send_channel + + The underlying :class:`~trio.MemorySendChannel`. :meth:`send` and + :meth:`send_nowait` are delegated to this object. + + .. attribute:: receive_channel + + The underlying :class:`~trio.MemoryReceiveChannel`. :meth:`receive()` + and :meth:`receive_nowait` are delegated to this object. + + """ + + send_channel: MemorySendChannel + receive_channel: MemoryReceiveChannel + + async def send(self, value): + """Calls ``self.send_channel.send``.""" + await self.send_channel.send(value) + + def send_nowait(self, value): + """Calls ``self.send_channel.send_nowait``.""" + self.send_channel.send_nowait(value) + + async def receive(self): + """Calls ``self.receive_channel.receive``.""" + return await self.receive_channel.receive() + + def receive_nowait(self): + """Calls ``self.receive_channel.receive_nowait``.""" + return self.receive_channel.receive_nowait() + + async def aclose(self): + """Calls ``aclose`` on both underlying channels.""" + try: + await self.send_channel.aclose() + finally: + await self.receive_channel.aclose() diff --git a/trio/tests/test_channel.py b/trio/tests/test_channel.py index b43466dd7d..01f8366cbf 100644 --- a/trio/tests/test_channel.py +++ b/trio/tests/test_channel.py @@ -1,8 +1,9 @@ import pytest +from .._abc import Channel from ..testing import wait_all_tasks_blocked, assert_checkpoints import trio -from trio import open_memory_channel, EndOfChannel +from trio import open_memory_channel, EndOfChannel, StapledMemoryChannel async def test_channel(): @@ -350,3 +351,26 @@ async def do_send(s, v): assert await r.receive() == 1 with pytest.raises(trio.WouldBlock): r.receive_nowait() + + +async def test_stapled_memory_channel(): + assert issubclass(StapledMemoryChannel, Channel) + stapled = StapledMemoryChannel(*open_memory_channel(0)) + + async with trio.open_nursery() as nursery: + + @nursery.start_soon + async def _listener(): + assert await stapled.receive() == 10 + + await wait_all_tasks_blocked() + await stapled.send(10) + + with pytest.raises(trio.WouldBlock): + stapled.send_nowait(10) + with pytest.raises(trio.WouldBlock): + stapled.receive_nowait() + + assert not (stapled.send_channel._closed or stapled.receive_channel._closed) + await stapled.aclose() + assert stapled.send_channel._closed and stapled.receive_channel._closed