4141OVERHEAD_RUNNER_TIME = 0.05
4242TOTAL_EXPECTED_TIME = MOCK_SLEEP_DURATION + OVERHEAD_RUNNER_TIME
4343
44-
45- # Mock helper functions for async testing
46- def _create_mock_checkpoint_tuple (
47- thread_id = "test-thread" , checkpoint_id = "test-checkpoint"
48- ):
44+ @pytest .fixture
45+ def sample_checkpoint_tuple ():
4946 """Helper to create a mock checkpoint tuple with configurable IDs."""
50- mock_tuple = MagicMock ()
51- mock_tuple .config = {
47+ config = {
5248 "configurable" : {
53- "thread_id" : thread_id ,
54- "actor_id" : "test-actor " ,
55- "checkpoint_id" : checkpoint_id ,
49+ "thread_id" : "test_thread_id" ,
50+ "actor_id" : "test_actor " ,
51+ "checkpoint_id" : "test_checkpoint_id" ,
5652 }
5753 }
58- mock_tuple .checkpoint = {"id" : checkpoint_id }
59- mock_tuple .metadata = {"source" : "input" , "step" : 0 }
60- return mock_tuple
61-
62-
63- def slow_get_tuple (config ): # noqa: ARG001
64- """Mock get_tuple with artificial delay for testing async concurrency."""
65- time .sleep (MOCK_SLEEP_DURATION )
66- return _create_mock_checkpoint_tuple ()
67-
68-
69- def slow_list (config , * , filter = None , before = None , limit = None ): # noqa: ARG001 A002
70- """Mock list with artificial delay for testing async concurrency."""
71- time .sleep (MOCK_SLEEP_DURATION )
72- return [_create_mock_checkpoint_tuple ()]
73-
74-
75- def slow_put (config , checkpoint , metadata , new_versions ): # noqa: ARG001
76- """Mock put with artificial delay for testing async concurrency."""
77- time .sleep (MOCK_SLEEP_DURATION )
78- return config
79-
54+ checkpoint = {"id" : "test_checkpoint_id" }
55+ metadata = {"source" : "input" , "step" : 0 }
56+ return CheckpointTuple (
57+ config = config ,
58+ checkpoint = checkpoint ,
59+ metadata = metadata ,
60+ )
8061
81- def slow_put_writes (config , writes , task_id , task_path = "" ): # noqa: ARG001
82- """Mock put_writes with artificial delay for testing async concurrency."""
83- time .sleep (MOCK_SLEEP_DURATION )
84- return
8562
8663
87- def slow_delete_thread (thread_id , actor_id = "" ): # noqa: ARG001
88- """Mock delete_thread with artificial delay for testing async concurrency."""
89- time .sleep (MOCK_SLEEP_DURATION )
90- return
9164
9265
9366@pytest .fixture
@@ -203,6 +176,46 @@ def sample_checkpoint_metadata(self):
203176 "namespace2" : "parent_checkpoint_2" ,
204177 },
205178 )
179+
180+ @pytest .fixture
181+ def slow_get_tuple (self , sample_checkpoint_tuple ):
182+ """Mock get_tuple with artificial delay for testing async concurrency."""
183+ def _slow_get_tuple (config ): # noqa: ARG001
184+ time .sleep (MOCK_SLEEP_DURATION )
185+ return sample_checkpoint_tuple
186+ return _slow_get_tuple
187+
188+ @pytest .fixture
189+ def slow_list (self , sample_checkpoint_tuple ):
190+ """Mock list with artificial delay for testing async concurrency."""
191+ def _slow_list (config , * , filter = None , before = None , limit = None ): # noqa: ARG001 A002
192+ time .sleep (MOCK_SLEEP_DURATION )
193+ return [sample_checkpoint_tuple ]
194+ return _slow_list
195+
196+ @pytest .fixture
197+ def slow_put (self ):
198+ """Mock put with artificial delay for testing async concurrency."""
199+ def _slow_put (config , checkpoint , metadata , new_versions ): # noqa: ARG001
200+ time .sleep (MOCK_SLEEP_DURATION )
201+ return config
202+ return _slow_put
203+
204+ @pytest .fixture
205+ def slow_put_writes (self ):
206+ """Mock put_writes with artificial delay for testing async concurrency."""
207+ def _slow_put_writes (config , writes , task_id , task_path = "" ): # noqa: ARG001
208+ time .sleep (MOCK_SLEEP_DURATION )
209+ return
210+ return _slow_put_writes
211+
212+ @pytest .fixture
213+ def slow_delete_thread (self ):
214+ """Mock delete_thread with artificial delay for testing async concurrency."""
215+ def _slow_delete_thread (thread_id , actor_id = "" ): # noqa: ARG001
216+ time .sleep (MOCK_SLEEP_DURATION )
217+ return
218+ return _slow_delete_thread
206219
207220 def test_init_with_default_client (self , memory_id ):
208221 with patch ("boto3.client" ) as mock_boto3_client :
@@ -647,7 +660,7 @@ def test_get_next_version(self, saver):
647660 assert version .startswith ("00000000000000000000000000000011." )
648661
649662 async def test_aget_tuple_calls_sync_method_with_correct_args (
650- self , saver , runnable_config
663+ self , saver , runnable_config , slow_get_tuple
651664 ):
652665 """
653666 Test that aget_tuple calls the sync get_tuple method with correct arguments.
@@ -659,11 +672,10 @@ async def test_aget_tuple_calls_sync_method_with_correct_args(
659672 # Verify sync method was called with correct arguments
660673 mock_get .assert_called_once_with (runnable_config )
661674
662- # Verify result is returned correctly
663675 assert result is not None
664676
665677 async def test_alist_calls_sync_method_with_correct_args (
666- self , saver , runnable_config
678+ self , saver , runnable_config , slow_list
667679 ):
668680 """Test that alist calls the sync list method with correct arguments."""
669681
@@ -689,9 +701,11 @@ async def test_alist_calls_sync_method_with_correct_args(
689701 before = before_config ,
690702 limit = limit_value ,
691703 )
704+ assert len (items ) == 1
705+ assert isinstance (items [0 ], CheckpointTuple )
692706
693707 async def test_aput_calls_sync_method_with_correct_args (
694- self , saver , runnable_config , sample_checkpoint , sample_checkpoint_metadata
708+ self , saver , runnable_config , sample_checkpoint , sample_checkpoint_metadata , slow_put
695709 ):
696710 """Test that aput calls the sync put method with correct arguments."""
697711
@@ -713,11 +727,10 @@ async def test_aput_calls_sync_method_with_correct_args(
713727 new_versions ,
714728 )
715729
716- # Verify result is returned correctly
717730 assert result == runnable_config
718731
719732 async def test_aput_writes_calls_sync_method_with_correct_args (
720- self , saver , runnable_config
733+ self , saver , runnable_config , slow_put_writes
721734 ):
722735 """
723736 Test that aput_writes calls the sync put_writes method with correct arguments.
@@ -738,12 +751,10 @@ async def test_aput_writes_calls_sync_method_with_correct_args(
738751 mock_put_writes .assert_called_once_with (
739752 runnable_config , writes , task_id , task_path
740753 )
741-
742- # Verify result (should be None for put_writes)
743754 assert result is None
744755
745756 async def test_adelete_thread_calls_sync_method_with_correct_args (
746- self , saver , runnable_config
757+ self , saver , runnable_config , slow_delete_thread
747758 ):
748759 """
749760 Test that adelete_thread calls the sync delete_thread method
@@ -760,18 +771,16 @@ async def test_adelete_thread_calls_sync_method_with_correct_args(
760771
761772 # Verify sync method was called with correct arguments
762773 mock_delete .assert_called_once_with (thread_id , actor_id )
763-
764- # Verify result (should be None for delete_thread)
765774 assert result is None
766775
767- async def test_concurrent_calls_aget_tuple (self , saver , runnable_config ):
776+ async def test_concurrent_calls_aget_tuple (self , saver , runnable_config , slow_get_tuple ):
768777 """Test that concurrent calls are faster than sequential calls."""
769778 with patch .object (saver , "get_tuple" , side_effect = slow_get_tuple ):
770779 await self .assert_concurrent_calls_are_faster_than_sequential (
771780 N_ASYNC_CALLS , saver .aget_tuple , runnable_config
772781 )
773782
774- async def test_concurrent_calls_adelete_thread (self , saver , runnable_config ):
783+ async def test_concurrent_calls_adelete_thread (self , saver , runnable_config , slow_delete_thread ):
775784 """Test that concurrent calls are faster than sequential calls."""
776785 thread_id = runnable_config ["configurable" ]["thread_id" ]
777786 actor_id = runnable_config ["configurable" ]["actor_id" ]
@@ -781,7 +790,7 @@ async def test_concurrent_calls_adelete_thread(self, saver, runnable_config):
781790 N_ASYNC_CALLS , saver .adelete_thread , thread_id , actor_id
782791 )
783792
784- async def test_concurrent_calls_aput_writes (self , saver , runnable_config ):
793+ async def test_concurrent_calls_aput_writes (self , saver , runnable_config , slow_put_writes ):
785794 """Test that concurrent calls are faster than sequential calls."""
786795 writes = [("channel" , "value" )]
787796 task_id = "test-task"
@@ -798,12 +807,12 @@ async def test_concurrent_calls_aput_writes(self, saver, runnable_config):
798807 )
799808
800809 async def test_concurrent_calls_aput (
801- self , saver , runnable_config , sample_checkpoint , sample_checkpoint_metadata
810+ self , saver , runnable_config , sample_checkpoint , sample_checkpoint_metadata , slow_put
802811 ):
803812 """Test that concurrent calls are faster than sequential calls."""
804813 new_versions = {"default" : "v2" }
805814
806- with patch .object (saver , "put" , side_effect = slow_put_writes ):
815+ with patch .object (saver , "put" , side_effect = slow_put ):
807816 await self .assert_concurrent_calls_are_faster_than_sequential (
808817 N_ASYNC_CALLS ,
809818 saver .aput ,
@@ -813,7 +822,7 @@ async def test_concurrent_calls_aput(
813822 new_versions ,
814823 )
815824
816- async def test_concurrent_calls_alist (self , saver , runnable_config ):
825+ async def test_concurrent_calls_alist (self , saver , runnable_config , slow_list ):
817826 """Test that concurrent calls are faster than sequential calls."""
818827 filter_dict = {"test" : "filter" }
819828 before_config = {"before" : "config" }
0 commit comments