Skip to content

Commit aacfb57

Browse files
committed
Add tests
1 parent 7e51b21 commit aacfb57

File tree

2 files changed

+268
-3
lines changed
  • libs/langgraph-checkpoint-aws

2 files changed

+268
-3
lines changed

libs/langgraph-checkpoint-aws/langgraph_checkpoint_aws/agentcore/saver.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def delete_thread(self, thread_id: str, actor_id: str = "") -> None:
273273
"""Delete all checkpoints and writes associated with a thread."""
274274
self.checkpoint_event_client.delete_events(thread_id, actor_id)
275275

276-
# ===== Async methods ( TODO: Check running sync methods inside executor ) =====
276+
# ===== Async methods ( Running sync methods inside executor ) =====
277277
async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
278278
return await run_in_executor(None, self.get_tuple, config)
279279

@@ -285,7 +285,9 @@ async def alist(
285285
before: RunnableConfig | None = None,
286286
limit: int | None = None,
287287
) -> AsyncIterator[CheckpointTuple]:
288-
for item in await run_in_executor(None, self.list, config, filter=filter, before=before, limit=limit):
288+
for item in await run_in_executor(
289+
None, self.list, config, filter=filter, before=before, limit=limit
290+
):
289291
yield item
290292

291293
async def aput(
@@ -312,7 +314,7 @@ async def aput_writes(
312314
)
313315

314316
async def adelete_thread(self, thread_id: str, actor_id: str = "") -> None:
315-
self.delete_thread(thread_id, actor_id)
317+
await run_in_executor(None, self.delete_thread, thread_id, actor_id)
316318
return None
317319

318320
def get_next_version(

libs/langgraph-checkpoint-aws/tests/unit_tests/agentcore/test_saver.py

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
Unit tests for AgentCore Memory Checkpoint Saver.
33
"""
44

5+
import asyncio
56
import json
7+
import time
68
from unittest.mock import ANY, MagicMock, Mock, patch
79

810
import pytest
@@ -30,6 +32,63 @@
3032
)
3133
from langgraph_checkpoint_aws.agentcore.saver import AgentCoreMemorySaver
3234

35+
# Configure pytest to use anyio for async tests (asyncio backend only)
36+
pytestmark = pytest.mark.anyio
37+
38+
# Test constants for async testing
39+
N_ASYNC_CALLS = 10
40+
MOCK_SLEEP_DURATION = 0.5 / N_ASYNC_CALLS
41+
OVERHEAD_RUNNER_TIME = 0.05
42+
TOTAL_EXPECTED_TIME = MOCK_SLEEP_DURATION + OVERHEAD_RUNNER_TIME
43+
44+
45+
# Mock helper functions for async testing
46+
def _create_mock_checkpoint_tuple(
47+
thread_id="test-thread", checkpoint_id="test-checkpoint"
48+
):
49+
"""Helper to create a mock checkpoint tuple with configurable IDs."""
50+
mock_tuple = MagicMock()
51+
mock_tuple.config = {
52+
"configurable": {
53+
"thread_id": thread_id,
54+
"actor_id": "test-actor",
55+
"checkpoint_id": checkpoint_id,
56+
}
57+
}
58+
mock_tuple.checkpoint = {"id": checkpoint_id}
59+
mock_tuple.metadata = {"source": "input", "step": 0}
60+
return mock_tuple
61+
62+
63+
def slow_get_tuple(config): # noqa: ARG001
64+
"""Mock get_tuple with artificial delay for testing async concurrency."""
65+
time.sleep(MOCK_SLEEP_DURATION)
66+
return _create_mock_checkpoint_tuple()
67+
68+
69+
def slow_list(config, *, filter=None, before=None, limit=None): # noqa: ARG001 A002
70+
"""Mock list with artificial delay for testing async concurrency."""
71+
time.sleep(MOCK_SLEEP_DURATION)
72+
return [_create_mock_checkpoint_tuple()]
73+
74+
75+
def slow_put(config, checkpoint, metadata, new_versions): # noqa: ARG001
76+
"""Mock put with artificial delay for testing async concurrency."""
77+
time.sleep(MOCK_SLEEP_DURATION)
78+
return config
79+
80+
81+
def slow_put_writes(config, writes, task_id, task_path=""): # noqa: ARG001
82+
"""Mock put_writes with artificial delay for testing async concurrency."""
83+
time.sleep(MOCK_SLEEP_DURATION)
84+
return
85+
86+
87+
def slow_delete_thread(thread_id, actor_id=""): # noqa: ARG001
88+
"""Mock delete_thread with artificial delay for testing async concurrency."""
89+
time.sleep(MOCK_SLEEP_DURATION)
90+
return
91+
3392

3493
@pytest.fixture
3594
def sample_checkpoint_event():
@@ -587,6 +646,210 @@ def test_get_next_version(self, saver):
587646
)
588647
assert version.startswith("00000000000000000000000000000011.")
589648

649+
async def test_aget_tuple_calls_sync_method_with_correct_args(
650+
self, saver, runnable_config
651+
):
652+
"""
653+
Test that aget_tuple calls the sync get_tuple method with correct arguments.
654+
"""
655+
656+
with patch.object(saver, "get_tuple", side_effect=slow_get_tuple) as mock_get:
657+
result = await saver.aget_tuple(runnable_config)
658+
659+
# Verify sync method was called with correct arguments
660+
mock_get.assert_called_once_with(runnable_config)
661+
662+
# Verify result is returned correctly
663+
assert result is not None
664+
665+
async def test_alist_calls_sync_method_with_correct_args(
666+
self, saver, runnable_config
667+
):
668+
"""Test that alist calls the sync list method with correct arguments."""
669+
670+
filter_dict = {"test": "filter"}
671+
before_config = {"before": "config"}
672+
limit_value = 10
673+
674+
with patch.object(saver, "list", side_effect=slow_list) as mock_list:
675+
# Collect all items from async iterator
676+
items = []
677+
async for item in saver.alist(
678+
runnable_config,
679+
filter=filter_dict,
680+
before=before_config,
681+
limit=limit_value,
682+
):
683+
items.append(item)
684+
685+
# Verify sync method was called with correct arguments
686+
mock_list.assert_called_once_with(
687+
runnable_config,
688+
filter=filter_dict,
689+
before=before_config,
690+
limit=limit_value,
691+
)
692+
693+
async def test_aput_calls_sync_method_with_correct_args(
694+
self, saver, runnable_config, sample_checkpoint, sample_checkpoint_metadata
695+
):
696+
"""Test that aput calls the sync put method with correct arguments."""
697+
698+
new_versions = {"default": "v2"}
699+
700+
with patch.object(saver, "put", side_effect=slow_put) as mock_put:
701+
result = await saver.aput(
702+
runnable_config,
703+
sample_checkpoint,
704+
sample_checkpoint_metadata,
705+
new_versions,
706+
)
707+
708+
# Verify sync method was called with correct arguments
709+
mock_put.assert_called_once_with(
710+
runnable_config,
711+
sample_checkpoint,
712+
sample_checkpoint_metadata,
713+
new_versions,
714+
)
715+
716+
# Verify result is returned correctly
717+
assert result == runnable_config
718+
719+
async def test_aput_writes_calls_sync_method_with_correct_args(
720+
self, saver, runnable_config
721+
):
722+
"""
723+
Test that aput_writes calls the sync put_writes method with correct arguments.
724+
"""
725+
726+
writes = [("channel", "value")]
727+
task_id = "test-task"
728+
task_path = "test-path"
729+
730+
with patch.object(
731+
saver, "put_writes", side_effect=slow_put_writes
732+
) as mock_put_writes:
733+
result = await saver.aput_writes(
734+
runnable_config, writes, task_id, task_path
735+
)
736+
737+
# Verify sync method was called with correct arguments
738+
mock_put_writes.assert_called_once_with(
739+
runnable_config, writes, task_id, task_path
740+
)
741+
742+
# Verify result (should be None for put_writes)
743+
assert result is None
744+
745+
async def test_adelete_thread_calls_sync_method_with_correct_args(
746+
self, saver, runnable_config
747+
):
748+
"""
749+
Test that adelete_thread calls the sync delete_thread method
750+
with correct arguments
751+
"""
752+
753+
thread_id = runnable_config["configurable"]["thread_id"]
754+
actor_id = runnable_config["configurable"]["actor_id"]
755+
756+
with patch.object(
757+
saver, "delete_thread", side_effect=slow_delete_thread
758+
) as mock_delete:
759+
result = await saver.adelete_thread(thread_id, actor_id)
760+
761+
# Verify sync method was called with correct arguments
762+
mock_delete.assert_called_once_with(thread_id, actor_id)
763+
764+
# Verify result (should be None for delete_thread)
765+
assert result is None
766+
767+
async def test_concurrent_calls_aget_tuple(self, saver, runnable_config):
768+
"""Test that concurrent calls are faster than sequential calls."""
769+
with patch.object(saver, "get_tuple", side_effect=slow_get_tuple):
770+
await self.assert_concurrent_calls_are_faster_than_sequential(
771+
N_ASYNC_CALLS, saver.aget_tuple, runnable_config
772+
)
773+
774+
async def test_concurrent_calls_adelete_thread(self, saver, runnable_config):
775+
"""Test that concurrent calls are faster than sequential calls."""
776+
thread_id = runnable_config["configurable"]["thread_id"]
777+
actor_id = runnable_config["configurable"]["actor_id"]
778+
779+
with patch.object(saver, "delete_thread", side_effect=slow_delete_thread):
780+
await self.assert_concurrent_calls_are_faster_than_sequential(
781+
N_ASYNC_CALLS, saver.adelete_thread, thread_id, actor_id
782+
)
783+
784+
async def test_concurrent_calls_aput_writes(self, saver, runnable_config):
785+
"""Test that concurrent calls are faster than sequential calls."""
786+
writes = [("channel", "value")]
787+
task_id = "test-task"
788+
task_path = "test-path"
789+
790+
with patch.object(saver, "put_writes", side_effect=slow_put_writes):
791+
await self.assert_concurrent_calls_are_faster_than_sequential(
792+
N_ASYNC_CALLS,
793+
saver.aput_writes,
794+
runnable_config,
795+
writes,
796+
task_id,
797+
task_path,
798+
)
799+
800+
async def test_concurrent_calls_aput(
801+
self, saver, runnable_config, sample_checkpoint, sample_checkpoint_metadata
802+
):
803+
"""Test that concurrent calls are faster than sequential calls."""
804+
new_versions = {"default": "v2"}
805+
806+
with patch.object(saver, "put", side_effect=slow_put_writes):
807+
await self.assert_concurrent_calls_are_faster_than_sequential(
808+
N_ASYNC_CALLS,
809+
saver.aput,
810+
runnable_config,
811+
sample_checkpoint,
812+
sample_checkpoint_metadata,
813+
new_versions,
814+
)
815+
816+
async def test_concurrent_calls_alist(self, saver, runnable_config):
817+
"""Test that concurrent calls are faster than sequential calls."""
818+
filter_dict = {"test": "filter"}
819+
before_config = {"before": "config"}
820+
limit_value = 10
821+
822+
with patch.object(saver, "list", side_effect=slow_list):
823+
824+
async def consume_alist() -> list:
825+
"""Helper coroutine to consume the async iterator."""
826+
items = []
827+
async for item in saver.alist(
828+
runnable_config,
829+
filter=filter_dict,
830+
before=before_config,
831+
limit=limit_value,
832+
):
833+
items.append(item)
834+
return items
835+
836+
await self.assert_concurrent_calls_are_faster_than_sequential(
837+
N_ASYNC_CALLS, consume_alist
838+
)
839+
840+
async def assert_concurrent_calls_are_faster_than_sequential(
841+
self, n_async_calls: int, func, *args, **kwargs
842+
) -> None:
843+
"""Helper to run n async tasks concurrently."""
844+
tasks = [func(*args, **kwargs) for _ in range(n_async_calls)]
845+
start_time = time.time()
846+
await asyncio.gather(*tasks)
847+
concurrent_time = time.time() - start_time
848+
assert concurrent_time < TOTAL_EXPECTED_TIME, (
849+
f"Concurrent execution took {concurrent_time:.2f}s, "
850+
f"expected < {TOTAL_EXPECTED_TIME}s"
851+
)
852+
590853

591854
class TestCheckpointerConfig:
592855
"""Test suite for CheckpointerConfig."""

0 commit comments

Comments
 (0)