Skip to content

Commit 21d66d7

Browse files
committed
fix: include entire state snapshot in values stream chunk alongside interrupt
1 parent ed5d5e5 commit 21d66d7

File tree

3 files changed

+92
-17
lines changed

3 files changed

+92
-17
lines changed

libs/langgraph/langgraph/pregel/_loop.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ def _put_checkpoint(self, metadata: CheckpointMetadata) -> None:
766766
]
767767
self.checkpoint["channel_values"][TASKS] = sanitized_tasks
768768
# bail if no checkpointer
769-
769+
770770
if do_checkpoint and self._checkpointer_put_after_previous is not None:
771771
self.prev_checkpoint_config = (
772772
self.checkpoint_config
@@ -938,9 +938,11 @@ def output_writes(
938938
self._emit("updates", lambda: iter(interrupts))
939939
if "values" in stream_modes:
940940
current_values = read_channels(self.channels, self.output_keys)
941+
# self.output_keys is a sequence, stream chunk conntains entire state and interrupts
941942
if isinstance(current_values, dict):
942943
current_values[INTERRUPT] = interrupts[0][INTERRUPT]
943944
self._emit("values", lambda: iter([current_values]))
945+
# self.output_keys is a string, stream chunk contains only interrupts
944946
else:
945947
self._emit("values", lambda: iter(interrupts))
946948
elif writes[0][0] != ERROR:

libs/langgraph/tests/test_pregel.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -830,10 +830,9 @@ def raise_if_above_10(input: int) -> int:
830830
assert checkpoint["channel_values"].get("total") == 5
831831

832832

833-
def test_pending_writes_resume(
834-
sync_checkpointer: BaseCheckpointSaver
835-
) -> None:
833+
def test_pending_writes_resume(sync_checkpointer: BaseCheckpointSaver) -> None:
836834
durability = "exit"
835+
837836
class State(TypedDict):
838837
value: Annotated[int, operator.add]
839838

@@ -8424,8 +8423,8 @@ def human_node_2(state: State):
84248423
}
84258424

84268425

8427-
def test_interrupt_stream_mode_values():
8428-
"""Test that interrupts are surfaced when steam_mode='values'"""
8426+
def test_interrupt_stream_mode_values(sync_checkpointer: BaseCheckpointSaver):
8427+
"""Test that interrupts are surfaced on 'values' stream mode"""
84298428

84308429
class State(TypedDict):
84318430
robot_input: str
@@ -8443,19 +8442,36 @@ def human_input_node(state: State) -> Command:
84438442
builder.add_node(human_input_node)
84448443
builder.add_edge(START, "robot_input_node")
84458444
builder.add_edge("robot_input_node", "human_input_node")
8446-
app = builder.compile()
8445+
app = builder.compile(checkpointer=sync_checkpointer)
8446+
config = {"configurable": {"thread_id": str(uuid.uuid4())}}
84478447

8448-
result = [*app.stream(State(), stream_mode=["updates", "values"])]
8449-
print("PY_DEBUG: result", result)
8448+
result = [*app.stream(State(), config, stream_mode=["updates", "values"])]
84508449
assert len(result) == 4
8451-
8452-
8453-
assert result = [
8454-
('updates', {'robot_input_node': {'robot_input': 'beep boop i am a robot'}}),
8455-
('values', {'robot_input': 'beep boop i am a robot'}),
8456-
('updates', {'__interrupt__': (Interrupt(value='interrupt', id=AnyStr()),)}),
8457-
('values', {'robot_input': 'beep boop i am a robot', '__interrupt__': (Interrupt(value='interrupt', id=AnyStr()),)})]
8458-
assert result[1] == {"robot_input": "beep boop i am a robot", "__interrupt__": (Interrupt(value="interrupt", id=AnyStr()),)}
8450+
assert result == [
8451+
("updates", {"robot_input_node": {"robot_input": "beep boop i am a robot"}}),
8452+
("values", {"robot_input": "beep boop i am a robot"}),
8453+
("updates", {"__interrupt__": (Interrupt(value="interrupt", id=AnyStr()),)}),
8454+
(
8455+
"values",
8456+
{
8457+
"robot_input": "beep boop i am a robot",
8458+
"__interrupt__": (Interrupt(value="interrupt", id=AnyStr()),),
8459+
},
8460+
),
8461+
]
8462+
resume_result = [
8463+
*app.stream(
8464+
Command(resume="i am a human"), config, stream_mode=["updates", "values"]
8465+
)
8466+
]
8467+
assert resume_result == [
8468+
("values", {"robot_input": "beep boop i am a robot"}),
8469+
("updates", {"human_input_node": {"human_input": "i am a human"}}),
8470+
(
8471+
"values",
8472+
{"robot_input": "beep boop i am a robot", "human_input": "i am a human"},
8473+
),
8474+
]
84598475

84608476

84618477
def test_supersteps_populate_task_results(

libs/langgraph/tests/test_pregel_async.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9196,6 +9196,63 @@ async def consumer() -> None:
91969196
assert all(t.done() for t in recorded_tasks)
91979197

91989198

9199+
async def test_interrupt_stream_mode_values(async_checkpointer: BaseCheckpointSaver):
9200+
"""Test that interrupts are surfaced on 'values' stream mode"""
9201+
9202+
class State(TypedDict):
9203+
robot_input: str
9204+
human_input: str
9205+
9206+
def robot_input_node(state: State) -> State:
9207+
return {"robot_input": "beep boop i am a robot"}
9208+
9209+
def human_input_node(state: State) -> Command:
9210+
human_input = interrupt("interrupt")
9211+
return Command(update={"human_input": human_input})
9212+
9213+
builder = StateGraph(State)
9214+
builder.add_node(robot_input_node)
9215+
builder.add_node(human_input_node)
9216+
builder.add_edge(START, "robot_input_node")
9217+
builder.add_edge("robot_input_node", "human_input_node")
9218+
app = builder.compile(checkpointer=async_checkpointer)
9219+
config = {"configurable": {"thread_id": str(uuid.uuid4())}}
9220+
9221+
result = [
9222+
(mode, e)
9223+
async for mode, e in app.astream(
9224+
State(), config, stream_mode=["updates", "values"]
9225+
)
9226+
]
9227+
assert len(result) == 4
9228+
assert result == [
9229+
("updates", {"robot_input_node": {"robot_input": "beep boop i am a robot"}}),
9230+
("values", {"robot_input": "beep boop i am a robot"}),
9231+
("updates", {"__interrupt__": (Interrupt(value="interrupt", id=AnyStr()),)}),
9232+
(
9233+
"values",
9234+
{
9235+
"robot_input": "beep boop i am a robot",
9236+
"__interrupt__": (Interrupt(value="interrupt", id=AnyStr()),),
9237+
},
9238+
),
9239+
]
9240+
resume_result = [
9241+
(mode, e)
9242+
async for mode, e in app.astream(
9243+
Command(resume="i am a human"), config, stream_mode=["updates", "values"]
9244+
)
9245+
]
9246+
assert resume_result == [
9247+
("values", {"robot_input": "beep boop i am a robot"}),
9248+
("updates", {"human_input_node": {"human_input": "i am a human"}}),
9249+
(
9250+
"values",
9251+
{"robot_input": "beep boop i am a robot", "human_input": "i am a human"},
9252+
),
9253+
]
9254+
9255+
91999256
async def test_supersteps_populate_task_results(
92009257
async_checkpointer: BaseCheckpointSaver,
92019258
) -> None:

0 commit comments

Comments
 (0)