diff --git a/pypeman/channels.py b/pypeman/channels.py index 35ac3e3..1eda7bd 100644 --- a/pypeman/channels.py +++ b/pypeman/channels.py @@ -79,6 +79,7 @@ def __init__(self, name=None, parent_channel=None, loop=None, message_store_fact self.drop_nodes = None self.reject_nodes = None self.final_nodes = None + self.init_nodes = None self.wait_subchans = wait_subchans self.raise_dropped = False @@ -375,6 +376,8 @@ async def handle(self, msg): async with self.lock: self.status = BaseChannel.PROCESSING try: + if self.init_nodes: + msg = await self.init_nodes[0].handle(msg.copy()) result = await self.subhandle(msg.copy()) await self.message_store.change_message_state(msg_store_id, message.Message.PROCESSED) msg.chan_rslt = result @@ -593,6 +596,15 @@ def _init_end_nodes(self, *end_nodes): previous_node = node return end_nodes + def add_init_nodes(self, *nodes): + """ + Add nodes that will be launched at the start of the channel before all + processing nodes + """ + if self.init_nodes: + nodes = self.init_nodes.extend(nodes) + self.init_nodes = self._init_end_nodes(*nodes) + def add_join_nodes(self, *end_nodes): """ Add nodes that will be launched only after a successful channel process diff --git a/pypeman/nodes.py b/pypeman/nodes.py index f7240d3..aad7e35 100644 --- a/pypeman/nodes.py +++ b/pypeman/nodes.py @@ -168,7 +168,6 @@ async def handle(self, msg): result = await self.async_run(msg) else: result = self.run(msg) - self.processed += 1 if isinstance(result, asyncio.Future): diff --git a/pypeman/tests/test_channel.py b/pypeman/tests/test_channel.py index eb659ea..16da3f9 100644 --- a/pypeman/tests/test_channel.py +++ b/pypeman/tests/test_channel.py @@ -4,6 +4,8 @@ import tempfile import time +from functools import partial + from hl7.client import MLLPClient from pathlib import Path @@ -39,6 +41,11 @@ def raise_exc(msg): raise Exception() +def return_text(msg, text): + msg.payload = text + return msg + + class ChannelsTests(TestCase): def clean_loop(self): # Useful to execute future callbacks # TODO: remove ? @@ -234,6 +241,28 @@ def test_final_nodes(self): vars(self.clean_msg(msg1)), vars(self.clean_msg(endnode_input)), "Channel final_nodes don't takes event msg in input") + def test_init_nodes(self): + """ Whether BaseChannel init_nodes is working """ + chan1 = BaseChannel(name="test_channel_init_clbk", loop=self.loop) + initouttext = "inittxt" + + n1 = TstNode(name="n1") + initnode = TstNode(name="initnode") + n1._reset_test() + initnode._reset_test() + initnode.mock(output=partial(return_text, text=initouttext)) + chan1.add_init_nodes(initnode) + chan1.add(n1) + msg1 = generate_msg(message_content="startmsg") + self.start_channels() + + self.loop.run_until_complete(chan1.handle(msg1)) + + n1_input = n1.last_input() + self.assertEqual( + n1_input.payload, initouttext, + "Channel init_nodes doesn't seem to work") + def test_multiple_callbacks(self): """ Whether BaseChannel all endnodes are working at same time