-
from langchain_core.messages import AIMessage
from langgraph.graph import END, START, MessagesState, StateGraph
from rich import get_console
def main_1(_: MessagesState) -> MessagesState:
    """Graph node: announce itself on stdout and emit a single AI message.

    The incoming state is ignored; the returned dict is a state update
    carrying one ``AIMessage`` with the same text that was printed.
    """
    text = "This is main_1"
    print(text)
    return {"messages": [AIMessage(content=text)]}
def sub_1(_: MessagesState) -> MessagesState:
    """Graph node: announce itself on stdout and emit a single AI message.

    The incoming state is ignored; the returned dict is a state update
    carrying one ``AIMessage`` with the same text that was printed.
    """
    text = "This is sub_1"
    print(text)
    return {"messages": [AIMessage(content=text)]}
def sub_2(_: MessagesState) -> MessagesState:
    """Graph node: announce itself on stdout and emit a single AI message.

    The incoming state is ignored; the returned dict is a state update
    carrying one ``AIMessage`` with the same text that was printed.
    """
    text = "This is sub_2"
    print(text)
    return {"messages": [AIMessage(content=text)]}
def sub_3(_: MessagesState) -> MessagesState:
    """Graph node: announce itself on stdout and emit a single AI message.

    The incoming state is ignored; the returned dict is a state update
    carrying one ``AIMessage`` with the same text that was printed.
    """
    text = "This is sub_3"
    print(text)
    return {"messages": [AIMessage(content=text)]}
def main_2(_: MessagesState) -> MessagesState:
    """Graph node: announce itself on stdout and emit a single AI message.

    The incoming state is ignored; the returned dict is a state update
    carrying one ``AIMessage`` with the same text that was printed.
    """
    text = "This is main_2"
    print(text)
    return {"messages": [AIMessage(content=text)]}
print("============ sub ============")
sub_graph = StateGraph(MessagesState)
sub_graph.add_node(sub_1)
sub_graph.add_node(sub_2)
sub_graph.add_node(sub_3)
sub_graph.add_edge(START, sub_1.__name__)
sub_graph.add_edge(sub_1.__name__, sub_2.__name__)
sub_graph.add_edge(sub_2.__name__, sub_3.__name__)
sub_graph.add_edge(sub_3.__name__, END)
sub_app = sub_graph.compile()
for v in sub_app.stream(
MessagesState(messages=[]),
{"configurable": {"thread_id": "sub_thread"}},
stream_mode="updates",
subgraphs=False,
):
get_console().print(v)
print("============ main ============")
main_graph = StateGraph(MessagesState)
main_graph.add_node(main_1)
main_graph.add_node("sub_app", sub_app)
main_graph.add_node(main_2)
main_graph.add_edge(START, main_1.__name__)
main_graph.add_edge(main_1.__name__, "sub_app")
main_graph.add_edge("sub_app", main_2.__name__)
main_graph.add_edge(main_2.__name__, END)
main_app = main_graph.compile()
for v in main_app.stream(
MessagesState(messages=[]),
{"configurable": {"thread_id": "main_thread"}},
stream_mode="updates",
subgraphs=True,
):
get_console().print(v)
async def main() -> None:
    """Run ``main_app`` through ``astream_events`` (v2) and print stream chunks.

    Only events whose type is ``"on_chain_stream"`` are printed, and only
    their ``data["chunk"]`` payload. ``stream_mode`` and ``subgraphs`` are
    passed as extra keyword arguments; presumably they are forwarded to the
    graph's own streaming — confirm against the LangChain/LangGraph docs.
    Per the maintainer reply in this discussion, the printed chunks are
    expected to differ from the ``.stream(..., subgraphs=True)`` output.
    """
    print("============ astream_events ============")
    events = main_app.astream_events(
        MessagesState(messages=[]),
        {"configurable": {"thread_id": "main_thread"}},
        version="v2",
        stream_mode="updates",
        subgraphs=True,
    )
    async for event in events:
        if event["event"] != "on_chain_stream":
            continue
        get_console().print(event["data"]["chunk"])
if __name__ == "__main__":
import asyncio
asyncio.run(main()) ⬆️ When running the subgraph alone with ⬆️ When the subgraph is included within another graph and run with ⬆️ However, it works normally in Why would I get this kind of output? |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment
-
@gbaian10 This is due to a difference between the implementations of `.stream()` and `.astream_events()`. For now, this is the expected behavior, but we might look into it in the future to improve consistency.
Beta Was this translation helpful? Give feedback.
@gbaian10 This is due to a difference between the implementations of `.stream()` and `.astream_events()`. For now, this is the expected behavior, but we might look into it in the future to improve consistency.