|
2 | 2 | Unit tests for AgentCore Memory Checkpoint Saver. |
3 | 3 | """ |
4 | 4 |
|
| 5 | +import asyncio |
5 | 6 | import json |
| 7 | +import time |
6 | 8 | from unittest.mock import ANY, MagicMock, Mock, patch |
7 | 9 |
|
8 | 10 | import pytest |
|
30 | 32 | ) |
31 | 33 | from langgraph_checkpoint_aws.agentcore.saver import AgentCoreMemorySaver |
32 | 34 |
|
| 35 | +# Configure pytest to use anyio for async tests (asyncio backend only) |
| 36 | +pytestmark = pytest.mark.anyio |
| 37 | + |
| 38 | +# Test constants for async testing |
| 39 | +N_ASYNC_CALLS = 10 |
| 40 | +MOCK_SLEEP_DURATION = 0.5 / N_ASYNC_CALLS |
| 41 | +OVERHEAD_RUNNER_TIME = 0.05 |
| 42 | +TOTAL_EXPECTED_TIME = MOCK_SLEEP_DURATION + OVERHEAD_RUNNER_TIME |
| 43 | + |
| 44 | + |
| 45 | +# Mock helper functions for async testing |
| 46 | +def _create_mock_checkpoint_tuple( |
| 47 | + thread_id="test-thread", checkpoint_id="test-checkpoint" |
| 48 | +): |
| 49 | + """Helper to create a mock checkpoint tuple with configurable IDs.""" |
| 50 | + mock_tuple = MagicMock() |
| 51 | + mock_tuple.config = { |
| 52 | + "configurable": { |
| 53 | + "thread_id": thread_id, |
| 54 | + "actor_id": "test-actor", |
| 55 | + "checkpoint_id": checkpoint_id, |
| 56 | + } |
| 57 | + } |
| 58 | + mock_tuple.checkpoint = {"id": checkpoint_id} |
| 59 | + mock_tuple.metadata = {"source": "input", "step": 0} |
| 60 | + return mock_tuple |
| 61 | + |
| 62 | + |
| 63 | +def slow_get_tuple(config): # noqa: ARG001 |
| 64 | + """Mock get_tuple with artificial delay for testing async concurrency.""" |
| 65 | + time.sleep(MOCK_SLEEP_DURATION) |
| 66 | + return _create_mock_checkpoint_tuple() |
| 67 | + |
| 68 | + |
| 69 | +def slow_list(config, *, filter=None, before=None, limit=None): # noqa: ARG001 A002 |
| 70 | + """Mock list with artificial delay for testing async concurrency.""" |
| 71 | + time.sleep(MOCK_SLEEP_DURATION) |
| 72 | + return [_create_mock_checkpoint_tuple()] |
| 73 | + |
| 74 | + |
| 75 | +def slow_put(config, checkpoint, metadata, new_versions): # noqa: ARG001 |
| 76 | + """Mock put with artificial delay for testing async concurrency.""" |
| 77 | + time.sleep(MOCK_SLEEP_DURATION) |
| 78 | + return config |
| 79 | + |
| 80 | + |
| 81 | +def slow_put_writes(config, writes, task_id, task_path=""): # noqa: ARG001 |
| 82 | + """Mock put_writes with artificial delay for testing async concurrency.""" |
| 83 | + time.sleep(MOCK_SLEEP_DURATION) |
| 84 | + return |
| 85 | + |
| 86 | + |
| 87 | +def slow_delete_thread(thread_id, actor_id=""): # noqa: ARG001 |
| 88 | + """Mock delete_thread with artificial delay for testing async concurrency.""" |
| 89 | + time.sleep(MOCK_SLEEP_DURATION) |
| 90 | + return |
| 91 | + |
33 | 92 |
|
34 | 93 | @pytest.fixture |
35 | 94 | def sample_checkpoint_event(): |
@@ -587,6 +646,210 @@ def test_get_next_version(self, saver): |
587 | 646 | ) |
588 | 647 | assert version.startswith("00000000000000000000000000000011.") |
589 | 648 |
|
| 649 | + async def test_aget_tuple_calls_sync_method_with_correct_args( |
| 650 | + self, saver, runnable_config |
| 651 | + ): |
| 652 | + """ |
| 653 | + Test that aget_tuple calls the sync get_tuple method with correct arguments. |
| 654 | + """ |
| 655 | + |
| 656 | + with patch.object(saver, "get_tuple", side_effect=slow_get_tuple) as mock_get: |
| 657 | + result = await saver.aget_tuple(runnable_config) |
| 658 | + |
| 659 | + # Verify sync method was called with correct arguments |
| 660 | + mock_get.assert_called_once_with(runnable_config) |
| 661 | + |
| 662 | + # Verify result is returned correctly |
| 663 | + assert result is not None |
| 664 | + |
| 665 | + async def test_alist_calls_sync_method_with_correct_args( |
| 666 | + self, saver, runnable_config |
| 667 | + ): |
| 668 | + """Test that alist calls the sync list method with correct arguments.""" |
| 669 | + |
| 670 | + filter_dict = {"test": "filter"} |
| 671 | + before_config = {"before": "config"} |
| 672 | + limit_value = 10 |
| 673 | + |
| 674 | + with patch.object(saver, "list", side_effect=slow_list) as mock_list: |
| 675 | + # Collect all items from async iterator |
| 676 | + items = [] |
| 677 | + async for item in saver.alist( |
| 678 | + runnable_config, |
| 679 | + filter=filter_dict, |
| 680 | + before=before_config, |
| 681 | + limit=limit_value, |
| 682 | + ): |
| 683 | + items.append(item) |
| 684 | + |
| 685 | + # Verify sync method was called with correct arguments |
| 686 | + mock_list.assert_called_once_with( |
| 687 | + runnable_config, |
| 688 | + filter=filter_dict, |
| 689 | + before=before_config, |
| 690 | + limit=limit_value, |
| 691 | + ) |
| 692 | + |
| 693 | + async def test_aput_calls_sync_method_with_correct_args( |
| 694 | + self, saver, runnable_config, sample_checkpoint, sample_checkpoint_metadata |
| 695 | + ): |
| 696 | + """Test that aput calls the sync put method with correct arguments.""" |
| 697 | + |
| 698 | + new_versions = {"default": "v2"} |
| 699 | + |
| 700 | + with patch.object(saver, "put", side_effect=slow_put) as mock_put: |
| 701 | + result = await saver.aput( |
| 702 | + runnable_config, |
| 703 | + sample_checkpoint, |
| 704 | + sample_checkpoint_metadata, |
| 705 | + new_versions, |
| 706 | + ) |
| 707 | + |
| 708 | + # Verify sync method was called with correct arguments |
| 709 | + mock_put.assert_called_once_with( |
| 710 | + runnable_config, |
| 711 | + sample_checkpoint, |
| 712 | + sample_checkpoint_metadata, |
| 713 | + new_versions, |
| 714 | + ) |
| 715 | + |
| 716 | + # Verify result is returned correctly |
| 717 | + assert result == runnable_config |
| 718 | + |
| 719 | + async def test_aput_writes_calls_sync_method_with_correct_args( |
| 720 | + self, saver, runnable_config |
| 721 | + ): |
| 722 | + """ |
| 723 | + Test that aput_writes calls the sync put_writes method with correct arguments. |
| 724 | + """ |
| 725 | + |
| 726 | + writes = [("channel", "value")] |
| 727 | + task_id = "test-task" |
| 728 | + task_path = "test-path" |
| 729 | + |
| 730 | + with patch.object( |
| 731 | + saver, "put_writes", side_effect=slow_put_writes |
| 732 | + ) as mock_put_writes: |
| 733 | + result = await saver.aput_writes( |
| 734 | + runnable_config, writes, task_id, task_path |
| 735 | + ) |
| 736 | + |
| 737 | + # Verify sync method was called with correct arguments |
| 738 | + mock_put_writes.assert_called_once_with( |
| 739 | + runnable_config, writes, task_id, task_path |
| 740 | + ) |
| 741 | + |
| 742 | + # Verify result (should be None for put_writes) |
| 743 | + assert result is None |
| 744 | + |
| 745 | + async def test_adelete_thread_calls_sync_method_with_correct_args( |
| 746 | + self, saver, runnable_config |
| 747 | + ): |
| 748 | + """ |
| 749 | + Test that adelete_thread calls the sync delete_thread method |
| 750 | + with correct arguments |
| 751 | + """ |
| 752 | + |
| 753 | + thread_id = runnable_config["configurable"]["thread_id"] |
| 754 | + actor_id = runnable_config["configurable"]["actor_id"] |
| 755 | + |
| 756 | + with patch.object( |
| 757 | + saver, "delete_thread", side_effect=slow_delete_thread |
| 758 | + ) as mock_delete: |
| 759 | + result = await saver.adelete_thread(thread_id, actor_id) |
| 760 | + |
| 761 | + # Verify sync method was called with correct arguments |
| 762 | + mock_delete.assert_called_once_with(thread_id, actor_id) |
| 763 | + |
| 764 | + # Verify result (should be None for delete_thread) |
| 765 | + assert result is None |
| 766 | + |
| 767 | + async def test_concurrent_calls_aget_tuple(self, saver, runnable_config): |
| 768 | + """Test that concurrent calls are faster than sequential calls.""" |
| 769 | + with patch.object(saver, "get_tuple", side_effect=slow_get_tuple): |
| 770 | + await self.assert_concurrent_calls_are_faster_than_sequential( |
| 771 | + N_ASYNC_CALLS, saver.aget_tuple, runnable_config |
| 772 | + ) |
| 773 | + |
| 774 | + async def test_concurrent_calls_adelete_thread(self, saver, runnable_config): |
| 775 | + """Test that concurrent calls are faster than sequential calls.""" |
| 776 | + thread_id = runnable_config["configurable"]["thread_id"] |
| 777 | + actor_id = runnable_config["configurable"]["actor_id"] |
| 778 | + |
| 779 | + with patch.object(saver, "delete_thread", side_effect=slow_delete_thread): |
| 780 | + await self.assert_concurrent_calls_are_faster_than_sequential( |
| 781 | + N_ASYNC_CALLS, saver.adelete_thread, thread_id, actor_id |
| 782 | + ) |
| 783 | + |
| 784 | + async def test_concurrent_calls_aput_writes(self, saver, runnable_config): |
| 785 | + """Test that concurrent calls are faster than sequential calls.""" |
| 786 | + writes = [("channel", "value")] |
| 787 | + task_id = "test-task" |
| 788 | + task_path = "test-path" |
| 789 | + |
| 790 | + with patch.object(saver, "put_writes", side_effect=slow_put_writes): |
| 791 | + await self.assert_concurrent_calls_are_faster_than_sequential( |
| 792 | + N_ASYNC_CALLS, |
| 793 | + saver.aput_writes, |
| 794 | + runnable_config, |
| 795 | + writes, |
| 796 | + task_id, |
| 797 | + task_path, |
| 798 | + ) |
| 799 | + |
| 800 | + async def test_concurrent_calls_aput( |
| 801 | + self, saver, runnable_config, sample_checkpoint, sample_checkpoint_metadata |
| 802 | + ): |
| 803 | + """Test that concurrent calls are faster than sequential calls.""" |
| 804 | + new_versions = {"default": "v2"} |
| 805 | + |
| 806 | + with patch.object(saver, "put", side_effect=slow_put_writes): |
| 807 | + await self.assert_concurrent_calls_are_faster_than_sequential( |
| 808 | + N_ASYNC_CALLS, |
| 809 | + saver.aput, |
| 810 | + runnable_config, |
| 811 | + sample_checkpoint, |
| 812 | + sample_checkpoint_metadata, |
| 813 | + new_versions, |
| 814 | + ) |
| 815 | + |
| 816 | + async def test_concurrent_calls_alist(self, saver, runnable_config): |
| 817 | + """Test that concurrent calls are faster than sequential calls.""" |
| 818 | + filter_dict = {"test": "filter"} |
| 819 | + before_config = {"before": "config"} |
| 820 | + limit_value = 10 |
| 821 | + |
| 822 | + with patch.object(saver, "list", side_effect=slow_list): |
| 823 | + |
| 824 | + async def consume_alist() -> list: |
| 825 | + """Helper coroutine to consume the async iterator.""" |
| 826 | + items = [] |
| 827 | + async for item in saver.alist( |
| 828 | + runnable_config, |
| 829 | + filter=filter_dict, |
| 830 | + before=before_config, |
| 831 | + limit=limit_value, |
| 832 | + ): |
| 833 | + items.append(item) |
| 834 | + return items |
| 835 | + |
| 836 | + await self.assert_concurrent_calls_are_faster_than_sequential( |
| 837 | + N_ASYNC_CALLS, consume_alist |
| 838 | + ) |
| 839 | + |
| 840 | + async def assert_concurrent_calls_are_faster_than_sequential( |
| 841 | + self, n_async_calls: int, func, *args, **kwargs |
| 842 | + ) -> None: |
| 843 | + """Helper to run n async tasks concurrently.""" |
| 844 | + tasks = [func(*args, **kwargs) for _ in range(n_async_calls)] |
| 845 | + start_time = time.time() |
| 846 | + await asyncio.gather(*tasks) |
| 847 | + concurrent_time = time.time() - start_time |
| 848 | + assert concurrent_time < TOTAL_EXPECTED_TIME, ( |
| 849 | + f"Concurrent execution took {concurrent_time:.2f}s, " |
| 850 | + f"expected < {TOTAL_EXPECTED_TIME}s" |
| 851 | + ) |
| 852 | + |
590 | 853 |
|
591 | 854 | class TestCheckpointerConfig: |
592 | 855 | """Test suite for CheckpointerConfig.""" |
|
0 commit comments