Skip to content

Commit 5c71fa7

Browse files
committed
Cleanup tests
1 parent aacfb57 commit 5c71fa7

File tree

2 files changed

+66
-58
lines changed
  • libs/langgraph-checkpoint-aws

2 files changed

+66
-58
lines changed

libs/langgraph-checkpoint-aws/langgraph_checkpoint_aws/agentcore/saver.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,6 @@ async def aput(
297297
metadata: CheckpointMetadata,
298298
new_versions: ChannelVersions,
299299
) -> RunnableConfig:
300-
# return self.put(config, checkpoint, metadata, new_versions)
301300
return await run_in_executor(
302301
None, self.put, config, checkpoint, metadata, new_versions
303302
)

libs/langgraph-checkpoint-aws/tests/unit_tests/agentcore/test_saver.py

Lines changed: 66 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -41,53 +41,26 @@
4141
OVERHEAD_RUNNER_TIME = 0.05
4242
TOTAL_EXPECTED_TIME = MOCK_SLEEP_DURATION + OVERHEAD_RUNNER_TIME
4343

44-
45-
# Mock helper functions for async testing
46-
def _create_mock_checkpoint_tuple(
47-
thread_id="test-thread", checkpoint_id="test-checkpoint"
48-
):
44+
@pytest.fixture
45+
def sample_checkpoint_tuple():
4946
"""Helper to create a mock checkpoint tuple with configurable IDs."""
50-
mock_tuple = MagicMock()
51-
mock_tuple.config = {
47+
config = {
5248
"configurable": {
53-
"thread_id": thread_id,
54-
"actor_id": "test-actor",
55-
"checkpoint_id": checkpoint_id,
49+
"thread_id": "test_thread_id",
50+
"actor_id": "test_actor",
51+
"checkpoint_id": "test_checkpoint_id",
5652
}
5753
}
58-
mock_tuple.checkpoint = {"id": checkpoint_id}
59-
mock_tuple.metadata = {"source": "input", "step": 0}
60-
return mock_tuple
61-
62-
63-
def slow_get_tuple(config): # noqa: ARG001
64-
"""Mock get_tuple with artificial delay for testing async concurrency."""
65-
time.sleep(MOCK_SLEEP_DURATION)
66-
return _create_mock_checkpoint_tuple()
67-
68-
69-
def slow_list(config, *, filter=None, before=None, limit=None): # noqa: ARG001 A002
70-
"""Mock list with artificial delay for testing async concurrency."""
71-
time.sleep(MOCK_SLEEP_DURATION)
72-
return [_create_mock_checkpoint_tuple()]
73-
74-
75-
def slow_put(config, checkpoint, metadata, new_versions): # noqa: ARG001
76-
"""Mock put with artificial delay for testing async concurrency."""
77-
time.sleep(MOCK_SLEEP_DURATION)
78-
return config
79-
54+
checkpoint = {"id": "test_checkpoint_id"}
55+
metadata = {"source": "input", "step": 0}
56+
return CheckpointTuple(
57+
config=config,
58+
checkpoint=checkpoint,
59+
metadata=metadata,
60+
)
8061

81-
def slow_put_writes(config, writes, task_id, task_path=""): # noqa: ARG001
82-
"""Mock put_writes with artificial delay for testing async concurrency."""
83-
time.sleep(MOCK_SLEEP_DURATION)
84-
return
8562

8663

87-
def slow_delete_thread(thread_id, actor_id=""): # noqa: ARG001
88-
"""Mock delete_thread with artificial delay for testing async concurrency."""
89-
time.sleep(MOCK_SLEEP_DURATION)
90-
return
9164

9265

9366
@pytest.fixture
@@ -203,6 +176,46 @@ def sample_checkpoint_metadata(self):
203176
"namespace2": "parent_checkpoint_2",
204177
},
205178
)
179+
180+
@pytest.fixture
181+
def slow_get_tuple(self, sample_checkpoint_tuple):
182+
"""Mock get_tuple with artificial delay for testing async concurrency."""
183+
def _slow_get_tuple(config): # noqa: ARG001
184+
time.sleep(MOCK_SLEEP_DURATION)
185+
return sample_checkpoint_tuple
186+
return _slow_get_tuple
187+
188+
@pytest.fixture
189+
def slow_list(self, sample_checkpoint_tuple):
190+
"""Mock list with artificial delay for testing async concurrency."""
191+
def _slow_list(config, *, filter=None, before=None, limit=None): # noqa: ARG001 A002
192+
time.sleep(MOCK_SLEEP_DURATION)
193+
return [sample_checkpoint_tuple]
194+
return _slow_list
195+
196+
@pytest.fixture
197+
def slow_put(self):
198+
"""Mock put with artificial delay for testing async concurrency."""
199+
def _slow_put(config, checkpoint, metadata, new_versions): # noqa: ARG001
200+
time.sleep(MOCK_SLEEP_DURATION)
201+
return config
202+
return _slow_put
203+
204+
@pytest.fixture
205+
def slow_put_writes(self):
206+
"""Mock put_writes with artificial delay for testing async concurrency."""
207+
def _slow_put_writes(config, writes, task_id, task_path=""): # noqa: ARG001
208+
time.sleep(MOCK_SLEEP_DURATION)
209+
return
210+
return _slow_put_writes
211+
212+
@pytest.fixture
213+
def slow_delete_thread(self):
214+
"""Mock delete_thread with artificial delay for testing async concurrency."""
215+
def _slow_delete_thread(thread_id, actor_id=""): # noqa: ARG001
216+
time.sleep(MOCK_SLEEP_DURATION)
217+
return
218+
return _slow_delete_thread
206219

207220
def test_init_with_default_client(self, memory_id):
208221
with patch("boto3.client") as mock_boto3_client:
@@ -647,7 +660,7 @@ def test_get_next_version(self, saver):
647660
assert version.startswith("00000000000000000000000000000011.")
648661

649662
async def test_aget_tuple_calls_sync_method_with_correct_args(
650-
self, saver, runnable_config
663+
self, saver, runnable_config, slow_get_tuple
651664
):
652665
"""
653666
Test that aget_tuple calls the sync get_tuple method with correct arguments.
@@ -659,11 +672,10 @@ async def test_aget_tuple_calls_sync_method_with_correct_args(
659672
# Verify sync method was called with correct arguments
660673
mock_get.assert_called_once_with(runnable_config)
661674

662-
# Verify result is returned correctly
663675
assert result is not None
664676

665677
async def test_alist_calls_sync_method_with_correct_args(
666-
self, saver, runnable_config
678+
self, saver, runnable_config, slow_list
667679
):
668680
"""Test that alist calls the sync list method with correct arguments."""
669681

@@ -689,9 +701,11 @@ async def test_alist_calls_sync_method_with_correct_args(
689701
before=before_config,
690702
limit=limit_value,
691703
)
704+
assert len(items) == 1
705+
assert isinstance(items[0], CheckpointTuple)
692706

693707
async def test_aput_calls_sync_method_with_correct_args(
694-
self, saver, runnable_config, sample_checkpoint, sample_checkpoint_metadata
708+
self, saver, runnable_config, sample_checkpoint, sample_checkpoint_metadata, slow_put
695709
):
696710
"""Test that aput calls the sync put method with correct arguments."""
697711

@@ -713,11 +727,10 @@ async def test_aput_calls_sync_method_with_correct_args(
713727
new_versions,
714728
)
715729

716-
# Verify result is returned correctly
717730
assert result == runnable_config
718731

719732
async def test_aput_writes_calls_sync_method_with_correct_args(
720-
self, saver, runnable_config
733+
self, saver, runnable_config, slow_put_writes
721734
):
722735
"""
723736
Test that aput_writes calls the sync put_writes method with correct arguments.
@@ -738,12 +751,10 @@ async def test_aput_writes_calls_sync_method_with_correct_args(
738751
mock_put_writes.assert_called_once_with(
739752
runnable_config, writes, task_id, task_path
740753
)
741-
742-
# Verify result (should be None for put_writes)
743754
assert result is None
744755

745756
async def test_adelete_thread_calls_sync_method_with_correct_args(
746-
self, saver, runnable_config
757+
self, saver, runnable_config, slow_delete_thread
747758
):
748759
"""
749760
Test that adelete_thread calls the sync delete_thread method
@@ -760,18 +771,16 @@ async def test_adelete_thread_calls_sync_method_with_correct_args(
760771

761772
# Verify sync method was called with correct arguments
762773
mock_delete.assert_called_once_with(thread_id, actor_id)
763-
764-
# Verify result (should be None for delete_thread)
765774
assert result is None
766775

767-
async def test_concurrent_calls_aget_tuple(self, saver, runnable_config):
776+
async def test_concurrent_calls_aget_tuple(self, saver, runnable_config, slow_get_tuple):
768777
"""Test that concurrent calls are faster than sequential calls."""
769778
with patch.object(saver, "get_tuple", side_effect=slow_get_tuple):
770779
await self.assert_concurrent_calls_are_faster_than_sequential(
771780
N_ASYNC_CALLS, saver.aget_tuple, runnable_config
772781
)
773782

774-
async def test_concurrent_calls_adelete_thread(self, saver, runnable_config):
783+
async def test_concurrent_calls_adelete_thread(self, saver, runnable_config, slow_delete_thread):
775784
"""Test that concurrent calls are faster than sequential calls."""
776785
thread_id = runnable_config["configurable"]["thread_id"]
777786
actor_id = runnable_config["configurable"]["actor_id"]
@@ -781,7 +790,7 @@ async def test_concurrent_calls_adelete_thread(self, saver, runnable_config):
781790
N_ASYNC_CALLS, saver.adelete_thread, thread_id, actor_id
782791
)
783792

784-
async def test_concurrent_calls_aput_writes(self, saver, runnable_config):
793+
async def test_concurrent_calls_aput_writes(self, saver, runnable_config, slow_put_writes):
785794
"""Test that concurrent calls are faster than sequential calls."""
786795
writes = [("channel", "value")]
787796
task_id = "test-task"
@@ -798,12 +807,12 @@ async def test_concurrent_calls_aput_writes(self, saver, runnable_config):
798807
)
799808

800809
async def test_concurrent_calls_aput(
801-
self, saver, runnable_config, sample_checkpoint, sample_checkpoint_metadata
810+
self, saver, runnable_config, sample_checkpoint, sample_checkpoint_metadata, slow_put
802811
):
803812
"""Test that concurrent calls are faster than sequential calls."""
804813
new_versions = {"default": "v2"}
805814

806-
with patch.object(saver, "put", side_effect=slow_put_writes):
815+
with patch.object(saver, "put", side_effect=slow_put):
807816
await self.assert_concurrent_calls_are_faster_than_sequential(
808817
N_ASYNC_CALLS,
809818
saver.aput,
@@ -813,7 +822,7 @@ async def test_concurrent_calls_aput(
813822
new_versions,
814823
)
815824

816-
async def test_concurrent_calls_alist(self, saver, runnable_config):
825+
async def test_concurrent_calls_alist(self, saver, runnable_config, slow_list):
817826
"""Test that concurrent calls are faster than sequential calls."""
818827
filter_dict = {"test": "filter"}
819828
before_config = {"before": "config"}

0 commit comments

Comments
 (0)