diff --git a/pydra/engine/core.py b/pydra/engine/core.py index 666f2b0255..df969448cd 100644 --- a/pydra/engine/core.py +++ b/pydra/engine/core.py @@ -952,16 +952,21 @@ def create_connections(self, task, detailed=False): (task.name, field.name, val.name, val.field) ) logger.debug("Connecting %s to %s", val.name, task.name) - + # adding a state from the previous task to other_states if ( getattr(self, val.name).state and getattr(self, val.name).state.splitter_rpn_final ): - # adding a state from the previous task to other_states - other_states[val.name] = ( - getattr(self, val.name).state, - field.name, - ) + # adding task_name: (task.state, [a field from the connection] + if val.name not in other_states: + other_states[val.name] = ( + getattr(self, val.name).state, + [field.name], + ) + else: + # if the task already exist in other_state, + # additional field name should be added to the list of fields + other_states[val.name][1].append(field.name) else: # LazyField with the wf input # connections with wf input should be added to the detailed graph description if detailed: diff --git a/pydra/engine/state.py b/pydra/engine/state.py index 94049b88f6..edf8401529 100644 --- a/pydra/engine/state.py +++ b/pydra/engine/state.py @@ -304,9 +304,10 @@ def inner_inputs(self): """ if self.other_states: _inner_inputs = {} - for name, (st, inp) in self.other_states.items(): + for name, (st, inp_l) in self.other_states.items(): if f"_{st.name}" in self.splitter_rpn_compact: - _inner_inputs[f"{self.name}.{inp}"] = st + for inp in inp_l: + _inner_inputs[f"{self.name}.{inp}"] = st return _inner_inputs else: return {} @@ -323,6 +324,10 @@ def update_connections(self, new_other_states=None, new_combiner=None): """ if new_other_states: self.other_states = new_other_states + # ensuring that the connected fields are set as a list + self.other_states = { + nm: (st, ensure_list(flds)) for nm, (st, flds) in self.other_states.items() + } self._connect_splitters() if new_combiner: self.combiner = new_combiner @@ -388,8 +393,93 @@ def _complete_prev_state(self, prev_state=None): prev_state = [f"_{name}" for name in self.other_states] if len(prev_state) == 1: prev_state = prev_state[0] + + if isinstance(prev_state, list): + prev_state = self._removed_repeated(prev_state) return prev_state + def _removed_repeated(self, previous_splitters): + """removing states from previous tasks that are repeated either directly or indirectly""" + for el in previous_splitters: + if el[1:] not in self.other_states: + raise hlpst.PydraStateError( + f"can't ask for splitter from {el[1:]}, other nodes that are connected: {self.other_states}" + ) + + repeated = set( + [ + (el, previous_splitters.count(el)) + for el in previous_splitters + if previous_splitters.count(el) > 1 + ] + ) + if repeated: + # assuming that I want to remove fro right + previous_splitters.reverse() + for el, cnt in repeated: + for ii in range(cnt): + previous_splitters.remove(el) + previous_splitters.reverse() + + el_state = [] + el_connect = [] + el_state_connect = [] + for el in previous_splitters: + nm = el[1:] + st = self.other_states[nm][0] + if not st.other_states: + # states that has no other connections + el_state.append(el) + else: # element has previous_connection + if st.current_splitter: # final? + # states that has previous connections and it's own splitter + el_state_connect.append((el, st.prev_state_splitter)) + else: + # states with previous connections but no additional splitter + el_connect.append((el, st.prev_state_splitter)) + + for el in el_connect: + nm = el[0][1:] + repeated_prev = set(ensure_list(el[1])).intersection(el_state) + if repeated_prev: + for r_el in repeated_prev: + r_nm = r_el[1:] + self.other_states[r_nm] = ( + self.other_states[r_nm][0], + self.other_states[r_nm][1] + self.other_states[nm][1], + ) + new_st = set(ensure_list(el[1])) - set(el_state) + if not new_st: + previous_splitters.remove(el[0]) + else: + for n_el in new_st: + n_nm = n_el[1:] + self.other_states[n_nm] = ( + self.other_states[nm][0].other_states[n_nm][0], + self.other_states[nm][1], + ) + # removing el of the splitter and adding new_st instead + ind = previous_splitters.index(el[0]) + if ind == len(previous_splitters) - 1: + previous_splitters = previous_splitters[:-1] + list(new_st) + else: + previous_splitters = ( + previous_splitters[:ind] + + list(new_st) + + previous_splitters[ind + 1 :] + ) + # TODO: this part is not tested, needs more work + for el in el_state_connect: + repeated_prev = set(ensure_list(el[1])).intersection(el_state) + if repeated_prev: + for r_el in repeated_prev: + previous_splitters.remove(r_el) + + if len(previous_splitters) == 1: + return previous_splitters[0] + else: + return previous_splitters + def _prevst_current_check(self, splitter_part, check_nested=True): """ Check if splitter_part is purely prev-state part, the current part, @@ -640,14 +730,15 @@ def prepare_states_ind(self): # TODO: need tests in test_Workflow.py elements_to_remove = [] elements_to_remove_comb = [] - for name, (st, inp) in self.other_states.items(): - if ( - f"{self.name}.{inp}" in self.splitter_rpn - and f"_{name}" in self.splitter_rpn_compact - ): - elements_to_remove.append(f"_{name}") - if f"{self.name}.{inp}" not in self.combiner: - elements_to_remove_comb.append(f"_{name}") + for name, (st, inp_l) in self.other_states.items(): + for inp in inp_l: + if ( + f"{self.name}.{inp}" in self.splitter_rpn + and f"_{name}" in self.splitter_rpn_compact + ): + elements_to_remove.append(f"_{name}") + if f"{self.name}.{inp}" not in self.combiner: + elements_to_remove_comb.append(f"_{name}") partial_rpn = hlpst.remove_inp_from_splitter_rpn( deepcopy(self.splitter_rpn_compact), elements_to_remove @@ -770,8 +861,9 @@ def prepare_inputs(self): for ii, el in enumerate(self.prev_state_splitter_rpn_compact): if el in ["*", "."]: continue - st, inp = self.other_states[el[1:]] - if f"{self.name}.{inp}" in self.splitter_rpn: # inner splitter + st, inp_l = self.other_states[el[1:]] + inp_l = [f"{self.name}.{inp}" for inp in inp_l] + if set(inp_l).intersection(self.splitter_rpn): # inner splitter connected_to_inner += [ el for el in st.splitter_rpn_final if el not in [".", "*"] ] @@ -784,8 +876,9 @@ def prepare_inputs(self): else: inputs_ind_prev = hlpst.op["*"](inputs_ind_prev, st_ind) else: - inputs_ind_prev = hlpst.op["*"](st_ind) - keys_inp_prev += [f"{self.name}.{inp}"] + # TODO: more tests needed + inputs_ind_prev = hlpst.op["."](*[st_ind] * len(inp_l)) + keys_inp_prev += inp_l keys_inp = keys_inp_prev + keys_inp if inputs_ind and inputs_ind_prev: diff --git a/pydra/engine/tests/test_state.py b/pydra/engine/tests/test_state.py index dadd8c3381..cca4de6a4a 100644 --- a/pydra/engine/tests/test_state.py +++ b/pydra/engine/tests/test_state.py @@ -586,6 +586,98 @@ def test_state_connect_6a(): ] +def test_state_connect_7(): + """two 'connected' states with multiple fields that are connected + no explicit splitter for the second state + """ + st1 = State(name="NA", splitter="a") + st2 = State(name="NB", other_states={"NA": (st1, ["x", "y"])}) + # should take into account that x, y come from the same task + assert st2.splitter == "_NA" + assert st2.splitter_rpn == ["NA.a"] + assert st2.prev_state_splitter == st2.splitter + assert st2.prev_state_splitter_rpn == st2.splitter_rpn + assert st2.current_splitter is None + assert st2.current_splitter_rpn == [] + + st2.prepare_states(inputs={"NA.a": [3, 5]}) + assert st2.group_for_inputs_final == {"NA.a": 0} + assert st2.groups_stack_final == [[0]] + assert st2.states_ind == [{"NA.a": 0}, {"NA.a": 1}] + assert st2.states_val == [{"NA.a": 3}, {"NA.a": 5}] + + st2.prepare_inputs() + # since x,y come from the same state, they should have the same index + assert st2.inputs_ind == [{"NB.x": 0, "NB.y": 0}, {"NB.x": 1, "NB.y": 1}] + + +def test_state_connect_8(): + """three 'connected' states: NA -> NB -> NC; NA -> NC (only NA has its own splitter) + pydra should recognize, that there is only one splitter - NA + and it should give the same as the previous test + """ + st1 = State(name="NA", splitter="a") + st2 = State(name="NB", other_states={"NA": (st1, "b")}) + st3 = State(name="NC", other_states={"NA": (st1, "x"), "NB": (st2, "y")}) + # x comes from NA and y comes from NB, but NB has only NA's splitter, + # so it should be treated as both inputs are from NA state + assert st3.splitter == "_NA" + assert st3.splitter_rpn == ["NA.a"] + assert st3.prev_state_splitter == st3.splitter + assert st3.prev_state_splitter_rpn == st3.splitter_rpn + assert st3.current_splitter is None + assert st3.current_splitter_rpn == [] + + st3.prepare_states(inputs={"NA.a": [3, 5]}) + assert st3.group_for_inputs_final == {"NA.a": 0} + assert st3.groups_stack_final == [[0]] + assert st3.states_ind == [{"NA.a": 0}, {"NA.a": 1}] + assert st3.states_val == [{"NA.a": 3}, {"NA.a": 5}] + + st3.prepare_inputs() + # since x,y come from the same state (although y indirectly), they should have the same index + assert st3.inputs_ind == [{"NC.x": 0, "NC.y": 0}, {"NC.x": 1, "NC.y": 1}] + + +@pytest.mark.xfail( + reason="doesn't recognize that NC.y has 4 elements (not independend on NC.x)" +) +def test_state_connect_9(): + """four 'connected' states: NA1 -> NB; NA2 -> NB, NA1 -> NC; NB -> NC + pydra should recognize, that there is only one splitter - NA_1 and NA_2 + + """ + st1 = State(name="NA_1", splitter="a") + st1a = State(name="NA_2", splitter="a") + st2 = State(name="NB", other_states={"NA_1": (st1, "b"), "NA_2": (st1a, "c")}) + st3 = State(name="NC", other_states={"NA_1": (st1, "x"), "NB": (st2, "y")}) + # x comes from NA_1 and y comes from NB, but NB has only NA_1/2's splitters, + assert st3.splitter == ["_NA_1", "_NA_2"] + assert st3.splitter_rpn == ["NA_1.a", "NA_2.a", "*"] + assert st3.prev_state_splitter == st3.splitter + assert st3.prev_state_splitter_rpn == st3.splitter_rpn + assert st3.current_splitter is None + assert st3.current_splitter_rpn == [] + + st3.prepare_states(inputs={"NA_1.a": [3, 5], "NA_2.a": [11, 12]}) + assert st3.group_for_inputs_final == {"NA_1.a": 0, "NA_2.a": 1} + assert st3.groups_stack_final == [[0, 1]] + assert st3.states_ind == [ + {"NA_1.a": 0, "NA_2.a": 0}, + {"NA_1.a": 0, "NA_2.a": 1}, + {"NA_1.a": 1, "NA_2.a": 0}, + {"NA_1.a": 1, "NA_2.a": 1}, + ] + + st3.prepare_inputs() + assert st3.inputs_ind == [ + {"NC.x": 0, "NC.y": 0}, + {"NC.x": 0, "NC.y": 1}, + {"NC.x": 1, "NC.y": 2}, + {"NC.x": 1, "NC.y": 3}, + ] + + def test_state_connect_innerspl_1(): """two 'connected' states: testing groups, prepare_states and prepare_inputs, the second state has an inner splitter, full splitter provided @@ -605,7 +697,7 @@ def test_state_connect_innerspl_1(): inputs={"NA.a": [3, 5], "NB.b": [[1, 10, 100], [2, 20, 200]]}, cont_dim={"NB.b": 2}, # will be treated as 2d container ) - assert st2.other_states["NA"][1] == "b" + assert st2.other_states["NA"][1] == ["b"] assert st2.group_for_inputs_final == {"NA.a": 0, "NB.b": 1} assert st2.groups_stack_final == [[0], [1]] @@ -653,7 +745,7 @@ def test_state_connect_innerspl_1a(): assert st2.current_splitter == "NB.b" assert st2.current_splitter_rpn == ["NB.b"] - assert st2.other_states["NA"][1] == "b" + assert st2.other_states["NA"][1] == ["b"] st2.prepare_states( inputs={"NA.a": [3, 5], "NB.b": [[1, 10, 100], [2, 20, 200]]}, @@ -717,7 +809,7 @@ def test_state_connect_innerspl_2(): inputs={"NA.a": [3, 5], "NB.b": [[1, 10, 100], [2, 20, 200]], "NB.c": [13, 17]}, cont_dim={"NB.b": 2}, # will be treated as 2d container ) - assert st2.other_states["NA"][1] == "b" + assert st2.other_states["NA"][1] == ["b"] assert st2.group_for_inputs_final == {"NA.a": 0, "NB.c": 1, "NB.b": 2} assert st2.groups_stack_final == [[0], [1, 2]] @@ -778,7 +870,7 @@ def test_state_connect_innerspl_2a(): assert st2.splitter == ["_NA", ["NB.b", "NB.c"]] assert st2.splitter_rpn == ["NA.a", "NB.b", "NB.c", "*", "*"] - assert st2.other_states["NA"][1] == "b" + assert st2.other_states["NA"][1] == ["b"] st2.prepare_states( inputs={"NA.a": [3, 5], "NB.b": [[1, 10, 100], [2, 20, 200]], "NB.c": [13, 17]}, @@ -839,6 +931,7 @@ def test_state_connect_innerspl_3(): the second state has one inner splitter and one 'normal' splitter the prev-state parts of the splitter have to be added """ + st1 = State(name="NA", splitter="a") st2 = State(name="NB", splitter=["c", "b"], other_states={"NA": (st1, "b")}) st3 = State(name="NC", splitter="d", other_states={"NB": (st2, "a")}) @@ -986,8 +1079,8 @@ def test_state_connect_innerspl_4(): assert st3.splitter == [["_NA", "_NB"], "NC.d"] assert st3.splitter_rpn == ["NA.a", "NB.b", "NB.c", "*", "*", "NC.d", "*"] - assert st3.other_states["NA"][1] == "e" - assert st3.other_states["NB"][1] == "f" + assert st3.other_states["NA"][1] == ["e"] + assert st3.other_states["NB"][1] == ["f"] st3.prepare_states( inputs={ @@ -1736,12 +1829,12 @@ def test_connect_splitters_exception_1(splitter, other_states): def test_connect_splitters_exception_2(): - st = State( - name="CN", - splitter="_NB", - other_states={"NA": (State(name="NA", splitter="a"), "b")}, - ) with pytest.raises(PydraStateError) as excinfo: + st = State( + name="CN", + splitter="_NB", + other_states={"NA": (State(name="NA", splitter="a"), "b")}, + ) st.set_input_groups() assert "can't ask for splitter from NB" in str(excinfo.value) diff --git a/pydra/engine/tests/test_workflow.py b/pydra/engine/tests/test_workflow.py index 59358d8949..ac95a0d3f9 100644 --- a/pydra/engine/tests/test_workflow.py +++ b/pydra/engine/tests/test_workflow.py @@ -11,6 +11,7 @@ power, ten, identity, + identity_2flds, list_output, fun_addsubvar, fun_addvar3, @@ -1401,6 +1402,54 @@ def test_wf_3nd_ndst_6(plugin, tmpdir): assert wf.output_dir.exists() +# workflows with structures A -> B -> C with multiple connections + + +def test_wf_3nd_7(tmpdir): + """workflow with three tasks A->B->C vs two tasks A->C with multiple connections""" + wf = Workflow(name="wf", input_spec=["zip"], cache_dir=tmpdir) + wf.inputs.zip = [["test1", "test3", "test5"], ["test2", "test4", "test6"]] + + wf.add(identity_2flds(name="iden2flds_1", x1=wf.lzin.zip, x2="Hoi").split("x1")) + + wf.add(identity(name="identity", x=wf.iden2flds_1.lzout.out1)) + + wf.add( + identity_2flds( + name="iden2flds_2", x1=wf.identity.lzout.out, x2=wf.iden2flds_1.lzout.out2 + ) + ) + + wf.add( + identity_2flds( + name="iden2flds_2a", + x1=wf.iden2flds_1.lzout.out1, + x2=wf.iden2flds_1.lzout.out2, + ) + ) + + wf.set_output( + [ + ("out1", wf.iden2flds_2.lzout.out1), + ("out2", wf.iden2flds_2.lzout.out2), + ("out1a", wf.iden2flds_2a.lzout.out1), + ("out2a", wf.iden2flds_2a.lzout.out2), + ] + ) + + with Submitter(plugin="cf") as sub: + sub(wf) + + res = wf.result() + + assert ( + res.output.out1 + == res.output.out1a + == [["test1", "test3", "test5"], ["test2", "test4", "test6"]] + ) + assert res.output.out2 == res.output.out2a == ["Hoi", "Hoi"] + + # workflows with Left and Right part in splitters A -> B (L&R parts of the splitter) diff --git a/pydra/engine/tests/utils.py b/pydra/engine/tests/utils.py index 857292ba1c..99b429f32c 100644 --- a/pydra/engine/tests/utils.py +++ b/pydra/engine/tests/utils.py @@ -141,6 +141,13 @@ def identity(x): return x +@mark.task +def identity_2flds( + x1, x2 +) -> ty.NamedTuple("Output", [("out1", ty.Any), ("out2", ty.Any)]): + return x1, x2 + + @mark.task def ten(x): return 10 diff --git a/setup.cfg b/setup.cfg index 1c055a0804..09aae1637d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -82,7 +82,7 @@ tests = %(test)s dev = %(test)s - black + black==21.4b2 pre-commit dask = %(test)s