Skip to content

Commit

Permalink
Merge pull request #465 from djarecka/fix/states_connection
Browse files Browse the repository at this point in the history
[wip] fixing connections for multiple inputs
  • Loading branch information
djarecka authored May 6, 2021
2 parents a9224dd + c2d63d7 commit bc71076
Show file tree
Hide file tree
Showing 6 changed files with 279 additions and 32 deletions.
17 changes: 11 additions & 6 deletions pydra/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,16 +952,21 @@ def create_connections(self, task, detailed=False):
(task.name, field.name, val.name, val.field)
)
logger.debug("Connecting %s to %s", val.name, task.name)

# adding a state from the previous task to other_states
if (
getattr(self, val.name).state
and getattr(self, val.name).state.splitter_rpn_final
):
# adding a state from the previous task to other_states
other_states[val.name] = (
getattr(self, val.name).state,
field.name,
)
# adding task_name: (task.state, [a field from the connection])
if val.name not in other_states:
other_states[val.name] = (
getattr(self, val.name).state,
[field.name],
)
else:
# if the task already exists in other_states,
# the additional field name should be appended to its list of fields
other_states[val.name][1].append(field.name)
else: # LazyField with the wf input
# connections with wf input should be added to the detailed graph description
if detailed:
Expand Down
121 changes: 107 additions & 14 deletions pydra/engine/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,10 @@ def inner_inputs(self):
"""
if self.other_states:
_inner_inputs = {}
for name, (st, inp) in self.other_states.items():
for name, (st, inp_l) in self.other_states.items():
if f"_{st.name}" in self.splitter_rpn_compact:
_inner_inputs[f"{self.name}.{inp}"] = st
for inp in inp_l:
_inner_inputs[f"{self.name}.{inp}"] = st
return _inner_inputs
else:
return {}
Expand All @@ -323,6 +324,10 @@ def update_connections(self, new_other_states=None, new_combiner=None):
"""
if new_other_states:
self.other_states = new_other_states
# ensuring that the connected fields are set as a list
self.other_states = {
nm: (st, ensure_list(flds)) for nm, (st, flds) in self.other_states.items()
}
self._connect_splitters()
if new_combiner:
self.combiner = new_combiner
Expand Down Expand Up @@ -388,8 +393,93 @@ def _complete_prev_state(self, prev_state=None):
prev_state = [f"_{name}" for name in self.other_states]
if len(prev_state) == 1:
prev_state = prev_state[0]

if isinstance(prev_state, list):
prev_state = self._removed_repeated(prev_state)
return prev_state

def _removed_repeated(self, previous_splitters):
    """Remove upstream states that are repeated, directly or indirectly.

    Every element of ``previous_splitters`` is a reference to an upstream
    state; dropping its first character gives the key used in
    ``self.other_states`` (e.g. ``"_NA"`` -> ``"NA"``).

    Two kinds of repetition are collapsed:

    * direct - the same reference is listed more than once; only the
      leftmost occurrence is kept;
    * indirect - a referenced state has upstream connections of its own.
      If it has no ``current_splitter``, its reference is replaced by the
      references from its ``prev_state_splitter`` that are not already
      listed, and its connected fields are merged into the entries of the
      ones that are.  If it has a ``current_splitter``, the references
      it already covers are dropped from the list.

    ``self.other_states`` is updated in place while doing so.  The list
    passed in by the caller is left untouched.

    Parameters
    ----------
    previous_splitters : list of str
        references to the upstream states (one leading character + name)

    Returns
    -------
    str or list of str
        the single remaining reference if exactly one is left,
        otherwise the list of remaining references

    Raises
    ------
    hlpst.PydraStateError
        if a reference points to a state that is not in ``self.other_states``
    """
    # work on a copy, so the caller's list is never modified
    previous_splitters = list(previous_splitters)

    for el in previous_splitters:
        if el[1:] not in self.other_states:
            raise hlpst.PydraStateError(
                f"can't ask for splitter from {el[1:]}, other nodes that are connected: {self.other_states}"
            )

    # direct repetition: keep the leftmost occurrence of every reference
    # (dict preserves insertion order, so the relative order is unchanged)
    previous_splitters = list(dict.fromkeys(previous_splitters))

    el_state = []  # states without any upstream connection
    el_connect = []  # states with upstream connections, no current_splitter
    el_state_connect = []  # states with upstream connections and current_splitter
    for el in previous_splitters:
        st = self.other_states[el[1:]][0]
        if not st.other_states:
            el_state.append(el)
        elif st.current_splitter:  # final?
            el_state_connect.append((el, st.prev_state_splitter))
        else:
            el_connect.append((el, st.prev_state_splitter))

    for el, upstream in el_connect:
        nm = el[1:]
        # de-duplicated, order-preserving list of the upstream references
        upstream = list(dict.fromkeys(ensure_list(upstream)))

        # upstream states that are already listed directly:
        # the fields connected through `el` are added to their entries
        for r_el in upstream:
            if r_el in el_state:
                r_nm = r_el[1:]
                self.other_states[r_nm] = (
                    self.other_states[r_nm][0],
                    self.other_states[r_nm][1] + self.other_states[nm][1],
                )

        # upstream states that are not listed yet: registered in other_states
        # with the fields that were connected through `el`
        # NOTE(review): an existing entry under the same name is overwritten,
        # as before - confirm that this is intended
        new_st = [u_el for u_el in upstream if u_el not in el_state]
        for n_el in new_st:
            n_nm = n_el[1:]
            self.other_states[n_nm] = (
                self.other_states[nm][0].other_states[n_nm][0],
                self.other_states[nm][1],
            )

        # replacing `el` with the new references, keeping its position
        # (an empty new_st simply removes `el`)
        ind = previous_splitters.index(el)
        previous_splitters[ind : ind + 1] = new_st

    # TODO: this part is not tested, needs more work
    for _, upstream in el_state_connect:
        covered = set(ensure_list(upstream)).intersection(el_state)
        if covered:
            # filtering instead of list.remove, so a reference covered by
            # more than one state doesn't raise ValueError
            previous_splitters = [
                el for el in previous_splitters if el not in covered
            ]

    if len(previous_splitters) == 1:
        return previous_splitters[0]
    return previous_splitters

def _prevst_current_check(self, splitter_part, check_nested=True):
"""
Check if splitter_part is purely prev-state part, the current part,
Expand Down Expand Up @@ -640,14 +730,15 @@ def prepare_states_ind(self):
# TODO: need tests in test_Workflow.py
elements_to_remove = []
elements_to_remove_comb = []
for name, (st, inp) in self.other_states.items():
if (
f"{self.name}.{inp}" in self.splitter_rpn
and f"_{name}" in self.splitter_rpn_compact
):
elements_to_remove.append(f"_{name}")
if f"{self.name}.{inp}" not in self.combiner:
elements_to_remove_comb.append(f"_{name}")
for name, (st, inp_l) in self.other_states.items():
for inp in inp_l:
if (
f"{self.name}.{inp}" in self.splitter_rpn
and f"_{name}" in self.splitter_rpn_compact
):
elements_to_remove.append(f"_{name}")
if f"{self.name}.{inp}" not in self.combiner:
elements_to_remove_comb.append(f"_{name}")

partial_rpn = hlpst.remove_inp_from_splitter_rpn(
deepcopy(self.splitter_rpn_compact), elements_to_remove
Expand Down Expand Up @@ -770,8 +861,9 @@ def prepare_inputs(self):
for ii, el in enumerate(self.prev_state_splitter_rpn_compact):
if el in ["*", "."]:
continue
st, inp = self.other_states[el[1:]]
if f"{self.name}.{inp}" in self.splitter_rpn: # inner splitter
st, inp_l = self.other_states[el[1:]]
inp_l = [f"{self.name}.{inp}" for inp in inp_l]
if set(inp_l).intersection(self.splitter_rpn): # inner splitter
connected_to_inner += [
el for el in st.splitter_rpn_final if el not in [".", "*"]
]
Expand All @@ -784,8 +876,9 @@ def prepare_inputs(self):
else:
inputs_ind_prev = hlpst.op["*"](inputs_ind_prev, st_ind)
else:
inputs_ind_prev = hlpst.op["*"](st_ind)
keys_inp_prev += [f"{self.name}.{inp}"]
# TODO: more tests needed
inputs_ind_prev = hlpst.op["."](*[st_ind] * len(inp_l))
keys_inp_prev += inp_l
keys_inp = keys_inp_prev + keys_inp

if inputs_ind and inputs_ind_prev:
Expand Down
115 changes: 104 additions & 11 deletions pydra/engine/tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,98 @@ def test_state_connect_6a():
]


def test_state_connect_7():
    """Two connected states, with two fields of NB (x, y) both coming from NA.

    NB has no explicit splitter.
    """
    upstream = State(name="NA", splitter="a")
    downstream = State(name="NB", other_states={"NA": (upstream, ["x", "y"])})

    # x and y come from the same task, so NA should be used only once
    assert downstream.splitter == "_NA"
    assert downstream.splitter_rpn == ["NA.a"]
    assert downstream.prev_state_splitter == downstream.splitter
    assert downstream.prev_state_splitter_rpn == downstream.splitter_rpn
    assert downstream.current_splitter is None
    assert downstream.current_splitter_rpn == []

    downstream.prepare_states(inputs={"NA.a": [3, 5]})
    expected_ind = [{"NA.a": 0}, {"NA.a": 1}]
    expected_val = [{"NA.a": 3}, {"NA.a": 5}]
    assert downstream.group_for_inputs_final == {"NA.a": 0}
    assert downstream.groups_stack_final == [[0]]
    assert downstream.states_ind == expected_ind
    assert downstream.states_val == expected_val

    downstream.prepare_inputs()
    # both fields are fed by the same state, so their indices have to match
    expected_inputs_ind = [
        {"NB.x": 0, "NB.y": 0},
        {"NB.x": 1, "NB.y": 1},
    ]
    assert downstream.inputs_ind == expected_inputs_ind


def test_state_connect_8():
    """Three connected states: NA -> NB -> NC and NA -> NC.

    Only NA has its own splitter, so NC should end up with a single
    splitter (NA) and give the same indices as a direct connection.
    """
    state_a = State(name="NA", splitter="a")
    state_b = State(name="NB", other_states={"NA": (state_a, "b")})
    state_c = State(
        name="NC", other_states={"NA": (state_a, "x"), "NB": (state_b, "y")}
    )

    # x is taken from NA and y from NB, but NB only carries NA's splitter,
    # so both inputs should be treated as coming from the NA state
    assert state_c.splitter == "_NA"
    assert state_c.splitter_rpn == ["NA.a"]
    assert state_c.prev_state_splitter == state_c.splitter
    assert state_c.prev_state_splitter_rpn == state_c.splitter_rpn
    assert state_c.current_splitter is None
    assert state_c.current_splitter_rpn == []

    state_c.prepare_states(inputs={"NA.a": [3, 5]})
    expected_ind = [{"NA.a": 0}, {"NA.a": 1}]
    expected_val = [{"NA.a": 3}, {"NA.a": 5}]
    assert state_c.group_for_inputs_final == {"NA.a": 0}
    assert state_c.groups_stack_final == [[0]]
    assert state_c.states_ind == expected_ind
    assert state_c.states_val == expected_val

    state_c.prepare_inputs()
    # x and y share the state (y only indirectly), so their indices have to match
    expected_inputs_ind = [
        {"NC.x": 0, "NC.y": 0},
        {"NC.x": 1, "NC.y": 1},
    ]
    assert state_c.inputs_ind == expected_inputs_ind


@pytest.mark.xfail(
    reason="doesn't recognize that NC.y has 4 elements (not independend on NC.x)"
)
def test_state_connect_9():
    """Four connected states: NA_1 -> NB, NA_2 -> NB, NA_1 -> NC, NB -> NC.

    The only splitters NC should end up with are the ones of NA_1 and NA_2.
    """
    state_a1 = State(name="NA_1", splitter="a")
    state_a2 = State(name="NA_2", splitter="a")
    state_b = State(
        name="NB", other_states={"NA_1": (state_a1, "b"), "NA_2": (state_a2, "c")}
    )
    state_c = State(
        name="NC", other_states={"NA_1": (state_a1, "x"), "NB": (state_b, "y")}
    )

    # x is taken from NA_1 and y from NB; NB itself only has NA_1/NA_2 splitters
    assert state_c.splitter == ["_NA_1", "_NA_2"]
    assert state_c.splitter_rpn == ["NA_1.a", "NA_2.a", "*"]
    assert state_c.prev_state_splitter == state_c.splitter
    assert state_c.prev_state_splitter_rpn == state_c.splitter_rpn
    assert state_c.current_splitter is None
    assert state_c.current_splitter_rpn == []

    state_c.prepare_states(inputs={"NA_1.a": [3, 5], "NA_2.a": [11, 12]})
    assert state_c.group_for_inputs_final == {"NA_1.a": 0, "NA_2.a": 1}
    assert state_c.groups_stack_final == [[0, 1]]
    expected_ind = [
        {"NA_1.a": 0, "NA_2.a": 0},
        {"NA_1.a": 0, "NA_2.a": 1},
        {"NA_1.a": 1, "NA_2.a": 0},
        {"NA_1.a": 1, "NA_2.a": 1},
    ]
    assert state_c.states_ind == expected_ind

    state_c.prepare_inputs()
    expected_inputs_ind = [
        {"NC.x": 0, "NC.y": 0},
        {"NC.x": 0, "NC.y": 1},
        {"NC.x": 1, "NC.y": 2},
        {"NC.x": 1, "NC.y": 3},
    ]
    assert state_c.inputs_ind == expected_inputs_ind


def test_state_connect_innerspl_1():
"""two 'connected' states: testing groups, prepare_states and prepare_inputs,
the second state has an inner splitter, full splitter provided
Expand All @@ -605,7 +697,7 @@ def test_state_connect_innerspl_1():
inputs={"NA.a": [3, 5], "NB.b": [[1, 10, 100], [2, 20, 200]]},
cont_dim={"NB.b": 2}, # will be treated as 2d container
)
assert st2.other_states["NA"][1] == "b"
assert st2.other_states["NA"][1] == ["b"]
assert st2.group_for_inputs_final == {"NA.a": 0, "NB.b": 1}
assert st2.groups_stack_final == [[0], [1]]

Expand Down Expand Up @@ -653,7 +745,7 @@ def test_state_connect_innerspl_1a():
assert st2.current_splitter == "NB.b"
assert st2.current_splitter_rpn == ["NB.b"]

assert st2.other_states["NA"][1] == "b"
assert st2.other_states["NA"][1] == ["b"]

st2.prepare_states(
inputs={"NA.a": [3, 5], "NB.b": [[1, 10, 100], [2, 20, 200]]},
Expand Down Expand Up @@ -717,7 +809,7 @@ def test_state_connect_innerspl_2():
inputs={"NA.a": [3, 5], "NB.b": [[1, 10, 100], [2, 20, 200]], "NB.c": [13, 17]},
cont_dim={"NB.b": 2}, # will be treated as 2d container
)
assert st2.other_states["NA"][1] == "b"
assert st2.other_states["NA"][1] == ["b"]
assert st2.group_for_inputs_final == {"NA.a": 0, "NB.c": 1, "NB.b": 2}
assert st2.groups_stack_final == [[0], [1, 2]]

Expand Down Expand Up @@ -778,7 +870,7 @@ def test_state_connect_innerspl_2a():

assert st2.splitter == ["_NA", ["NB.b", "NB.c"]]
assert st2.splitter_rpn == ["NA.a", "NB.b", "NB.c", "*", "*"]
assert st2.other_states["NA"][1] == "b"
assert st2.other_states["NA"][1] == ["b"]

st2.prepare_states(
inputs={"NA.a": [3, 5], "NB.b": [[1, 10, 100], [2, 20, 200]], "NB.c": [13, 17]},
Expand Down Expand Up @@ -839,6 +931,7 @@ def test_state_connect_innerspl_3():
the second state has one inner splitter and one 'normal' splitter
the prev-state parts of the splitter have to be added
"""

st1 = State(name="NA", splitter="a")
st2 = State(name="NB", splitter=["c", "b"], other_states={"NA": (st1, "b")})
st3 = State(name="NC", splitter="d", other_states={"NB": (st2, "a")})
Expand Down Expand Up @@ -986,8 +1079,8 @@ def test_state_connect_innerspl_4():

assert st3.splitter == [["_NA", "_NB"], "NC.d"]
assert st3.splitter_rpn == ["NA.a", "NB.b", "NB.c", "*", "*", "NC.d", "*"]
assert st3.other_states["NA"][1] == "e"
assert st3.other_states["NB"][1] == "f"
assert st3.other_states["NA"][1] == ["e"]
assert st3.other_states["NB"][1] == ["f"]

st3.prepare_states(
inputs={
Expand Down Expand Up @@ -1736,12 +1829,12 @@ def test_connect_splitters_exception_1(splitter, other_states):


def test_connect_splitters_exception_2():
st = State(
name="CN",
splitter="_NB",
other_states={"NA": (State(name="NA", splitter="a"), "b")},
)
with pytest.raises(PydraStateError) as excinfo:
st = State(
name="CN",
splitter="_NB",
other_states={"NA": (State(name="NA", splitter="a"), "b")},
)
st.set_input_groups()
assert "can't ask for splitter from NB" in str(excinfo.value)

Expand Down
Loading

0 comments on commit bc71076

Please sign in to comment.