Skip to content

Commit 14e4d09

Browse files
[INTPYTHON-690] Adds async methods to mongo db saver (#174)
1 parent 3044e4a commit 14e4d09

File tree

6 files changed

+250
-95
lines changed

6 files changed

+250
-95
lines changed

libs/langgraph-checkpoint-mongodb/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
---
44

5+
## Changes in version 0.2.0 (TBD)
6+
7+
- Implements async methods of MongoDBSaver.
8+
- Deprecates ASyncMongoDBSaver, to be removed in 0.3.0
9+
510
## Changes in version 0.1.4 (2025/06/13)
611

712
- Add TTL (time-to-live) indexes for automatic deletion of old checkpoints and writes

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import builtins
55
import sys
6+
import warnings
67
from collections.abc import AsyncIterator, Iterator, Sequence
78
from contextlib import asynccontextmanager
89
from datetime import datetime
@@ -87,6 +88,12 @@ def __init__(
8788
ttl: Optional[int] = None,
8889
**kwargs: Any,
8990
) -> None:
91+
warnings.warn(
92+
f"{self.__class__.__name__} is deprecated and will be removed in 0.3.0 release. "
93+
"Please use the async methods of MongoDBSaver instead.",
94+
DeprecationWarning,
95+
stacklevel=2,
96+
)
9097
super().__init__()
9198
self.client = client
9299
self.db = self.client[db_name]

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
from collections.abc import Iterator, Sequence
1+
import asyncio
2+
from collections.abc import AsyncIterator, Iterator, Sequence
23
from contextlib import contextmanager
34
from datetime import datetime
45
from typing import (
56
Any,
67
Optional,
78
)
89

9-
from langchain_core.runnables import RunnableConfig
10+
from langchain_core.runnables import RunnableConfig, run_in_executor
1011
from pymongo import ASCENDING, MongoClient, UpdateOne
1112
from pymongo.database import Database as MongoDatabase
1213

@@ -464,3 +465,120 @@ def delete_thread(
464465

465466
# Delete all writes associated with the thread ID
466467
self.writes_collection.delete_many({"thread_id": thread_id})
468+
469+
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
470+
"""Asynchronously fetch a checkpoint tuple using the given configuration.
471+
472+
Asynchronously wraps the blocking `self.get_tuple` method.
473+
474+
Args:
475+
config: Configuration specifying which checkpoint to retrieve.
476+
477+
Returns:
478+
Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found.
479+
480+
"""
481+
return await run_in_executor(None, self.get_tuple, config)
482+
483+
async def alist(
484+
self,
485+
config: Optional[RunnableConfig],
486+
*,
487+
filter: Optional[dict[str, Any]] = None,
488+
before: Optional[RunnableConfig] = None,
489+
limit: Optional[int] = None,
490+
) -> AsyncIterator[CheckpointTuple]:
491+
"""Asynchronously list checkpoints that match the given criteria.
492+
493+
Asynchronously wraps the blocking `self.list` generator.
494+
495+
Runs `self.list(...)` in a background thread and yields its items
496+
asynchronously from an asyncio.Queue. This allows integration of
497+
synchronous iterators into async code.
498+
499+
Args:
500+
config: Configuration object passed to `self.list`.
501+
filter: Optional filter dictionary.
502+
before: Optional parameter to limit results before a given checkpoint.
503+
limit: Optional maximum number of results to yield.
504+
505+
Yields:
506+
AsyncIterator[CheckpointTuple]: An iterator of checkpoint tuples.
507+
"""
508+
loop = asyncio.get_running_loop()
509+
queue: asyncio.Queue[CheckpointTuple] = asyncio.Queue()
510+
sentinel = object()
511+
512+
def run() -> None:
513+
try:
514+
for item in self.list(
515+
config, filter=filter, before=before, limit=limit
516+
):
517+
loop.call_soon_threadsafe(queue.put_nowait, item)
518+
finally:
519+
loop.call_soon_threadsafe(queue.put_nowait, sentinel) # type: ignore
520+
521+
await run_in_executor(None, run)
522+
while True:
523+
item = await queue.get()
524+
if item is sentinel:
525+
break
526+
yield item
527+
528+
async def aput(
529+
self,
530+
config: RunnableConfig,
531+
checkpoint: Checkpoint,
532+
metadata: CheckpointMetadata,
533+
new_versions: ChannelVersions,
534+
) -> RunnableConfig:
535+
"""Asynchronously store a checkpoint with its configuration and metadata.
536+
537+
Asynchronously wraps the blocking `self.put` method.
538+
539+
Args:
540+
config: Configuration for the checkpoint.
541+
checkpoint: The checkpoint to store.
542+
metadata: Additional metadata for the checkpoint.
543+
new_versions: New channel versions as of this write.
544+
545+
Returns:
546+
RunnableConfig: Updated configuration after storing the checkpoint.
547+
"""
548+
return await run_in_executor(
549+
None, self.put, config, checkpoint, metadata, new_versions
550+
)
551+
552+
async def aput_writes(
553+
self,
554+
config: RunnableConfig,
555+
writes: Sequence[tuple[str, Any]],
556+
task_id: str,
557+
task_path: str = "",
558+
) -> None:
559+
"""Asynchronously store intermediate writes linked to a checkpoint.
560+
561+
Asynchronously wraps the blocking `self.put_writes` method.
562+
563+
Args:
564+
config: Configuration of the related checkpoint.
565+
writes: List of writes to store.
566+
task_id: Identifier for the task creating the writes.
567+
task_path: Path of the task creating the writes.
568+
"""
569+
return await run_in_executor(
570+
None, self.put_writes, config, writes, task_id, task_path
571+
)
572+
573+
async def adelete_thread(
574+
self,
575+
thread_id: str,
576+
) -> None:
577+
"""Delete all checkpoints and writes associated with a specific thread ID.
578+
579+
Asynchronously wraps the blocking `self.delete_thread` method.
580+
581+
Args:
582+
thread_id: The thread ID whose checkpoints should be deleted.
583+
"""
584+
return await run_in_executor(None, self.delete_thread, thread_id)

libs/langgraph-checkpoint-mongodb/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ addopts = "--strict-markers --strict-config --durations=5 -vv"
4040
markers = [
4141
"requires: mark tests as requiring a specific library",
4242
"compile: mark placeholder test used to compile integration tests without running them",
43+
"asyncio: mark a test as asyncio",
4344
]
4445
asyncio_mode = "auto"
4546

libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_async.py

Lines changed: 114 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,107 +1,131 @@
11
import os
2-
from typing import Any
2+
from collections.abc import AsyncGenerator
3+
from typing import Any, Union
34

45
import pytest
6+
import pytest_asyncio
57
from bson.errors import InvalidDocument
6-
from pymongo import AsyncMongoClient
8+
from pymongo import AsyncMongoClient, MongoClient
79

8-
from langgraph.checkpoint.mongodb.aio import AsyncMongoDBSaver
10+
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver, MongoDBSaver
911

10-
MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017")
12+
MONGODB_URI = os.environ.get(
13+
"MONGODB_URI", "mongodb://localhost:27017/?directConnection=true"
14+
)
1115
DB_NAME = os.environ.get("DB_NAME", "langgraph-test")
1216
COLLECTION_NAME = "sync_checkpoints_aio"
1317

1418

15-
async def test_asearch(input_data: dict[str, Any]) -> None:
16-
# Clear collections if they exist
17-
client: AsyncMongoClient = AsyncMongoClient(MONGODB_URI)
18-
db = client[DB_NAME]
19-
20-
for clxn in await db.list_collection_names():
21-
await db.drop_collection(clxn)
22-
23-
async with AsyncMongoDBSaver.from_conn_string(
24-
MONGODB_URI, DB_NAME, COLLECTION_NAME
25-
) as saver:
26-
# save checkpoints
27-
await saver.aput(
28-
input_data["config_1"],
29-
input_data["chkpnt_1"],
30-
input_data["metadata_1"],
31-
{},
32-
)
33-
await saver.aput(
34-
input_data["config_2"],
35-
input_data["chkpnt_2"],
36-
input_data["metadata_2"],
37-
{},
38-
)
39-
await saver.aput(
40-
input_data["config_3"],
41-
input_data["chkpnt_3"],
42-
input_data["metadata_3"],
43-
{},
44-
)
45-
46-
# call method / assertions
47-
query_1 = {"source": "input"} # search by 1 key
48-
query_2 = {
49-
"step": 1,
50-
"writes": {"foo": "bar"},
51-
} # search by multiple keys
52-
query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
53-
query_4 = {"source": "update", "step": 1} # no match
54-
55-
search_results_1 = [c async for c in saver.alist(None, filter=query_1)]
56-
assert len(search_results_1) == 1
57-
assert search_results_1[0].metadata == input_data["metadata_1"]
58-
59-
search_results_2 = [c async for c in saver.alist(None, filter=query_2)]
60-
assert len(search_results_2) == 1
61-
assert search_results_2[0].metadata == input_data["metadata_2"]
62-
63-
search_results_3 = [c async for c in saver.alist(None, filter=query_3)]
64-
assert len(search_results_3) == 3
65-
66-
search_results_4 = [c async for c in saver.alist(None, filter=query_4)]
67-
assert len(search_results_4) == 0
68-
69-
# search by config (defaults to checkpoints across all namespaces)
70-
search_results_5 = [
71-
c async for c in saver.alist({"configurable": {"thread_id": "thread-2"}})
72-
]
73-
assert len(search_results_5) == 2
74-
assert {
75-
search_results_5[0].config["configurable"]["checkpoint_ns"],
76-
search_results_5[1].config["configurable"]["checkpoint_ns"],
77-
} == {"", "inner"}
78-
79-
80-
async def test_null_chars(input_data: dict[str, Any]) -> None:
19+
@pytest_asyncio.fixture(params=["run_in_executor", "aio"])
20+
async def async_saver(request: pytest.FixtureRequest) -> AsyncGenerator:
21+
if request.param == "aio":
22+
# Use async client and checkpointer
23+
aclient: AsyncMongoClient = AsyncMongoClient(MONGODB_URI)
24+
adb = aclient[DB_NAME]
25+
for clxn in await adb.list_collection_names():
26+
await adb.drop_collection(clxn)
27+
async with AsyncMongoDBSaver.from_conn_string(
28+
MONGODB_URI, DB_NAME, COLLECTION_NAME
29+
) as checkpointer:
30+
yield checkpointer
31+
await aclient.close()
32+
else:
33+
# Use sync client and checkpointer with async methods run in executor
34+
client: MongoClient = MongoClient(MONGODB_URI)
35+
db = client[DB_NAME]
36+
for clxn in db.list_collection_names():
37+
db.drop_collection(clxn)
38+
with MongoDBSaver.from_conn_string(
39+
MONGODB_URI, DB_NAME, COLLECTION_NAME
40+
) as checkpointer:
41+
yield checkpointer
42+
client.close()
43+
44+
45+
@pytest.mark.asyncio
46+
async def test_asearch(
47+
input_data: dict[str, Any], async_saver: Union[AsyncMongoDBSaver, MongoDBSaver]
48+
) -> None:
49+
# save checkpoints
50+
await async_saver.aput(
51+
input_data["config_1"],
52+
input_data["chkpnt_1"],
53+
input_data["metadata_1"],
54+
{},
55+
)
56+
await async_saver.aput(
57+
input_data["config_2"],
58+
input_data["chkpnt_2"],
59+
input_data["metadata_2"],
60+
{},
61+
)
62+
await async_saver.aput(
63+
input_data["config_3"],
64+
input_data["chkpnt_3"],
65+
input_data["metadata_3"],
66+
{},
67+
)
68+
69+
# call method / assertions
70+
query_1 = {"source": "input"} # search by 1 key
71+
query_2 = {
72+
"step": 1,
73+
"writes": {"foo": "bar"},
74+
} # search by multiple keys
75+
query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
76+
query_4 = {"source": "update", "step": 1} # no match
77+
78+
search_results_1 = [c async for c in async_saver.alist(None, filter=query_1)]
79+
assert len(search_results_1) == 1
80+
assert search_results_1[0].metadata == input_data["metadata_1"]
81+
82+
search_results_2 = [c async for c in async_saver.alist(None, filter=query_2)]
83+
assert len(search_results_2) == 1
84+
assert search_results_2[0].metadata == input_data["metadata_2"]
85+
86+
search_results_3 = [c async for c in async_saver.alist(None, filter=query_3)]
87+
assert len(search_results_3) == 3
88+
89+
search_results_4 = [c async for c in async_saver.alist(None, filter=query_4)]
90+
assert len(search_results_4) == 0
91+
92+
# search by config (defaults to checkpoints across all namespaces)
93+
search_results_5 = [
94+
c async for c in async_saver.alist({"configurable": {"thread_id": "thread-2"}})
95+
]
96+
assert len(search_results_5) == 2
97+
assert {
98+
search_results_5[0].config["configurable"]["checkpoint_ns"],
99+
search_results_5[1].config["configurable"]["checkpoint_ns"],
100+
} == {"", "inner"}
101+
102+
103+
@pytest.mark.asyncio
104+
async def test_null_chars(
105+
input_data: dict[str, Any], async_saver: Union[AsyncMongoDBSaver, MongoDBSaver]
106+
) -> None:
81107
"""In MongoDB string *values* can be any valid UTF-8 including nulls.
82108
*Field names*, however, cannot contain nulls characters."""
83-
async with AsyncMongoDBSaver.from_conn_string(
84-
MONGODB_URI, DB_NAME, COLLECTION_NAME
85-
) as saver:
86-
null_str = "\x00abc" # string containing null character
87109

88-
# 1. null string in field *value*
89-
null_value_cfg = await saver.aput(
110+
null_str = "\x00abc" # string containing null character
111+
112+
# 1. null string in field *value*
113+
null_value_cfg = await async_saver.aput(
114+
input_data["config_1"],
115+
input_data["chkpnt_1"],
116+
{"my_key": null_str},
117+
{},
118+
)
119+
null_tuple = await async_saver.aget_tuple(null_value_cfg)
120+
assert null_tuple.metadata["my_key"] == null_str # type: ignore
121+
cps = [c async for c in async_saver.alist(None, filter={"my_key": null_str})]
122+
assert cps[0].metadata["my_key"] == null_str
123+
124+
# 2. null string in field *name*
125+
with pytest.raises(InvalidDocument):
126+
await async_saver.aput(
90127
input_data["config_1"],
91128
input_data["chkpnt_1"],
92-
{"my_key": null_str},
129+
{null_str: "my_value"}, # type: ignore
93130
{},
94131
)
95-
null_tuple = await saver.aget_tuple(null_value_cfg)
96-
assert null_tuple.metadata["my_key"] == null_str # type: ignore
97-
cps = [c async for c in saver.alist(None, filter={"my_key": null_str})]
98-
assert cps[0].metadata["my_key"] == null_str
99-
100-
# 2. null string in field *name*
101-
with pytest.raises(InvalidDocument):
102-
await saver.aput(
103-
input_data["config_1"],
104-
input_data["chkpnt_1"],
105-
{null_str: "my_value"}, # type: ignore
106-
{},
107-
)

0 commit comments

Comments
 (0)