Skip to content

Commit

Permalink
fix for complex conditional links in query pipeline (#12805)
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-markewich authored Apr 14, 2024
1 parent b4271ec commit b430d17
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 4 deletions.
60 changes: 56 additions & 4 deletions llama-index-core/llama_index/core/query_pipeline/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,11 +477,16 @@ def _process_component_output(
) -> List[str]:
"""Process component output."""
new_queue = queue.copy()
# if there's no more edges, add result to output

nodes_to_keep = set()
nodes_to_remove = set()

# if there's no more edges, clear queue
if module_key in self._get_leaf_keys():
result_outputs[module_key] = output_dict
new_queue = []
else:
edge_list = list(self.dag.edges(module_key, data=True))

# everything not in conditional_edge_list is regular
for _, dest, attr in edge_list:
output = get_output(attr.get("src_key"), output_dict)
Expand All @@ -505,9 +510,56 @@ def _process_component_output(
self.module_dict[dest],
all_module_inputs[dest],
)
nodes_to_keep.add(dest)
else:
# remove dest from queue
new_queue.remove(dest)
nodes_to_remove.add(dest)

# remove nodes from the queue, as well as any nodes that depend on dest
# be sure to not remove any remaining dependencies of the current path
available_paths = []
for node in nodes_to_keep:
for leaf_node in self._get_leaf_keys():
if leaf_node == node:
available_paths.append([node])
else:
available_paths.extend(
list(
networkx.all_simple_paths(
self.dag, source=node, target=leaf_node
)
)
)

# this is a list of all nodes between the current node(s) and the leaf nodes
nodes_to_never_remove = set(x for path in available_paths for x in path) # noqa

removal_paths = []
for node in nodes_to_remove:
for leaf_node in self._get_leaf_keys():
if leaf_node == node:
removal_paths.append([node])
else:
removal_paths.extend(
list(
networkx.all_simple_paths(
self.dag, source=node, target=leaf_node
)
)
)

# this is a list of all nodes between the current node(s) to remove and the leaf nodes
nodes_to_probably_remove = set( # noqa
x for path in removal_paths for x in path
)

# remove nodes that are not in the current path
for node in nodes_to_probably_remove:
if node not in nodes_to_never_remove:
new_queue.remove(node)

# did we remove all remaining edges? then we have our result
if len(new_queue) == 0:
result_outputs[module_key] = output_dict

return new_queue

Expand Down
88 changes: 88 additions & 0 deletions llama-index-core/tests/query_pipeline/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,91 @@ def choose_fn(input: int) -> Dict:
output = p.run(inp1=2, inp2=3)
# should go to b
assert output == "3:2"


def test_query_pipeline_super_conditional() -> None:
"""This tests that paths are properly pruned and maintained for many conditional edges."""

def simple_fn(val: int):
print("Running simple_fn", flush=True)
return val

def over_twenty_fn(val: int):
print("Running over_twenty_fn", flush=True)
return val + 100

def final_fn(x: int, y: int, z: int):
print("Running final_fn", flush=True)
return {
"x": x,
"y": y,
"z": z,
}

simple_function_component = FnComponent(fn=simple_fn, output_key="output")
over_twenty_function_2 = FnComponent(fn=over_twenty_fn, output_key="output")
final_fn = FnComponent(fn=final_fn, output_key="output")

qp = QueryPipeline(
modules={
"first_decision": simple_function_component,
"second_decision": simple_function_component,
"under_ten": simple_function_component,
"over_twenty": simple_function_component,
"over_twenty_2": over_twenty_function_2,
"final": final_fn,
},
verbose=True,
)

qp.add_link(
"first_decision",
"under_ten",
condition_fn=lambda x: x < 10,
)
qp.add_link("under_ten", "final", dest_key="x")
qp.add_link("under_ten", "final", dest_key="y")
qp.add_link("under_ten", "final", dest_key="z")

qp.add_link(
"first_decision",
"second_decision",
condition_fn=lambda x: x >= 10,
)

qp.add_link(
"second_decision",
"over_twenty",
condition_fn=lambda x: x > 20,
)
qp.add_link(
"second_decision",
"over_twenty_2",
condition_fn=lambda x: x > 20,
)

qp.add_link(
"second_decision",
"final",
dest_key="z",
condition_fn=lambda x: x > 20,
)
qp.add_link(
"over_twenty",
"final",
dest_key="x",
)
qp.add_link(
"over_twenty_2",
"final",
dest_key="y",
)

response = qp.run(val=9)
assert response == {"x": 9, "y": 9, "z": 9}

response = qp.run(val=11)
assert response == 11

response = qp.run(val=21)
assert response == {"x": 21, "y": 121, "z": 21}

0 comments on commit b430d17

Please sign in to comment.