@@ -8,7 +8,7 @@
 from collections.abc import AsyncIterator, Iterator, Sequence
 from typing import Any, TypeAlias, cast
 
-from langchain_core.runnables import RunnableConfig
+from langchain_core.runnables import RunnableConfig, run_in_executor
 from langgraph.checkpoint.base import (
     BaseCheckpointSaver,
     ChannelVersions,
@@ -273,9 +273,9 @@ def delete_thread(self, thread_id: str, actor_id: str = "") -> None:
         """Delete all checkpoints and writes associated with a thread."""
         self.checkpoint_event_client.delete_events(thread_id, actor_id)
 
-    # ===== Async methods ( TODO: NOT IMPLEMENTED YET ) =====
+    # ===== Async methods ( Running sync methods inside executor ) =====
     async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
-        return self.get_tuple(config)
+        return await run_in_executor(None, self.get_tuple, config)
 
     async def alist(
         self,
@@ -285,7 +285,9 @@ async def alist(
         before: RunnableConfig | None = None,
         limit: int | None = None,
     ) -> AsyncIterator[CheckpointTuple]:
-        for item in self.list(config, filter=filter, before=before, limit=limit):
+        for item in await run_in_executor(
+            None, self.list, config, filter=filter, before=before, limit=limit
+        ):
             yield item
 
     async def aput(
@@ -295,7 +297,9 @@ async def aput(
         metadata: CheckpointMetadata,
         new_versions: ChannelVersions,
     ) -> RunnableConfig:
-        return self.put(config, checkpoint, metadata, new_versions)
+        return await run_in_executor(
+            None, self.put, config, checkpoint, metadata, new_versions
+        )
 
     async def aput_writes(
         self,
@@ -304,10 +308,12 @@ async def aput_writes(
         task_id: str,
         task_path: str = "",
     ) -> None:
-        return self.put_writes(config, writes, task_id, task_path)
+        return await run_in_executor(
+            None, self.put_writes, config, writes, task_id, task_path
+        )
 
     async def adelete_thread(self, thread_id: str, actor_id: str = "") -> None:
-        self.delete_thread(thread_id, actor_id)
+        await run_in_executor(None, self.delete_thread, thread_id, actor_id)
         return None
 
     def get_next_version(
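
For context, a minimal standalone sketch (not part of this PR) of the delegation pattern the new async methods use: the blocking sync call is handed to a worker thread so the event loop stays free. SyncOnlySaver and its 0.05 s sleep are illustrative stand-ins; langchain_core's run_in_executor(None, func, *args) behaves much like the plain loop.run_in_executor call shown here, falling back to asyncio's default thread-pool executor when no executor or config is supplied.

import asyncio
import time


class SyncOnlySaver:
    """Hypothetical saver with only a blocking sync lookup."""

    def get_tuple(self, config: dict) -> dict:
        time.sleep(0.05)  # stands in for a blocking AgentCore Memory API call
        return {"config": config}

    async def aget_tuple(self, config: dict) -> dict:
        # Same idea as `await run_in_executor(None, self.get_tuple, config)`:
        # run the blocking method in the default thread-pool executor so other
        # coroutines can make progress while it waits.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.get_tuple, config)


async def main() -> None:
    saver = SyncOnlySaver()
    # Five lookups overlap in worker threads instead of running back to back.
    results = await asyncio.gather(*(saver.aget_tuple({"i": i}) for i in range(5)))
    print(f"fetched {len(results)} tuples")


asyncio.run(main())
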
272 changes: 272 additions & 0 deletions libs/langgraph-checkpoint-aws/tests/unit_tests/agentcore/test_saver.py
@@ -2,7 +2,9 @@
Unit tests for AgentCore Memory Checkpoint Saver.
"""

import asyncio
import json
import time
from unittest.mock import ANY, MagicMock, Mock, patch

import pytest
@@ -30,6 +32,30 @@
)
from langgraph_checkpoint_aws.agentcore.saver import AgentCoreMemorySaver

# Configure pytest to use anyio for async tests
pytestmark = pytest.mark.anyio

# Test constants for async testing
N_ASYNC_CALLS = 5
MOCK_SLEEP_DURATION = 0.1 / N_ASYNC_CALLS
OVERHEAD_DURATION = 0.01
TOTAL_EXPECTED_TIME = MOCK_SLEEP_DURATION + OVERHEAD_DURATION


@pytest.fixture
def sample_checkpoint_tuple():
return CheckpointTuple(
config={
"configurable": {
"thread_id": "test_thread_id",
"actor_id": "test_actor",
"checkpoint_id": "test_checkpoint_id",
}
},
checkpoint={"id": "test_checkpoint_id"},
metadata={"source": "input", "step": 0},
)


@pytest.fixture
def sample_checkpoint_event():
@@ -145,6 +171,56 @@ def sample_checkpoint_metadata(self):
},
)

@pytest.fixture
def mock_slow_get_tuple(self, sample_checkpoint_tuple):
"""Mock get_tuple with artificial delay for testing async concurrency."""

def _mock_slow_get_tuple(config): # noqa: ARG001
time.sleep(MOCK_SLEEP_DURATION)
return sample_checkpoint_tuple

return _mock_slow_get_tuple

@pytest.fixture
def mock_slow_list(self, sample_checkpoint_tuple):
"""Mock list with artificial delay for testing async concurrency."""

def _mock_slow_list(config, *, filter=None, before=None, limit=None): # noqa: ARG001 A002
time.sleep(MOCK_SLEEP_DURATION)
return [sample_checkpoint_tuple]

return _mock_slow_list

@pytest.fixture
def mock_slow_put(self):
"""Mock put with artificial delay for testing async concurrency."""

def _mock_slow_put(config, checkpoint, metadata, new_versions): # noqa: ARG001
time.sleep(MOCK_SLEEP_DURATION)
return config

return _mock_slow_put

@pytest.fixture
def mock_slow_put_writes(self):
"""Mock put_writes with artificial delay for testing async concurrency."""

def _mock_slow_put_writes(config, writes, task_id, task_path=""): # noqa: ARG001
time.sleep(MOCK_SLEEP_DURATION)
return

return _mock_slow_put_writes

@pytest.fixture
def mock_slow_delete_thread(self):
"""Mock delete_thread with artificial delay for testing async concurrency."""

def _mock_slow_delete_thread(thread_id, actor_id=""): # noqa: ARG001
time.sleep(MOCK_SLEEP_DURATION)
return

return _mock_slow_delete_thread

def test_init_with_default_client(self, memory_id):
with patch("boto3.client") as mock_boto3_client:
mock_client = Mock()
@@ -587,6 +663,202 @@ def test_get_next_version(self, saver):
)
assert version.startswith("00000000000000000000000000000011.")

async def test_aget_tuple_calls_sync_method_with_correct_args(
self, saver, runnable_config, mock_slow_get_tuple
):
with patch.object(
saver, "get_tuple", side_effect=mock_slow_get_tuple
) as mock_get:
result = await saver.aget_tuple(runnable_config)

# Verify sync method was called with correct arguments
mock_get.assert_called_once_with(runnable_config)

assert result is not None

async def test_alist_calls_sync_method_with_correct_args(
self, saver, runnable_config, mock_slow_list
):
filter_dict = {"test": "filter"}
before_config = {"before": "config"}
limit_value = 10

with patch.object(saver, "list", side_effect=mock_slow_list) as mock_list:
# Collect all items from async iterator
items = []
async for item in saver.alist(
runnable_config,
filter=filter_dict,
before=before_config,
limit=limit_value,
):
items.append(item)

# Verify sync method was called with correct arguments
mock_list.assert_called_once_with(
runnable_config,
filter=filter_dict,
before=before_config,
limit=limit_value,
)
assert len(items) == 1
assert isinstance(items[0], CheckpointTuple)

async def test_aput_calls_sync_method_with_correct_args(
self,
saver,
runnable_config,
sample_checkpoint,
sample_checkpoint_metadata,
mock_slow_put,
):
new_versions = {"default": "v2"}

with patch.object(saver, "put", side_effect=mock_slow_put) as mock_put:
result = await saver.aput(
runnable_config,
sample_checkpoint,
sample_checkpoint_metadata,
new_versions,
)

# Verify sync method was called with correct arguments
mock_put.assert_called_once_with(
runnable_config,
sample_checkpoint,
sample_checkpoint_metadata,
new_versions,
)

assert result == runnable_config

async def test_aput_writes_calls_sync_method_with_correct_args(
self, saver, runnable_config, mock_slow_put_writes
):
writes = [("channel", "value")]
task_id = "test-task"
task_path = "test-path"

with patch.object(
saver, "put_writes", side_effect=mock_slow_put_writes
) as mock_put_writes:
result = await saver.aput_writes(
runnable_config, writes, task_id, task_path
)

# Verify sync method was called with correct arguments
mock_put_writes.assert_called_once_with(
runnable_config, writes, task_id, task_path
)
assert result is None

async def test_adelete_thread_calls_sync_method_with_correct_args(
self, saver, runnable_config, mock_slow_delete_thread
):
thread_id = runnable_config["configurable"]["thread_id"]
actor_id = runnable_config["configurable"]["actor_id"]

with patch.object(
saver, "delete_thread", side_effect=mock_slow_delete_thread
) as mock_delete:
result = await saver.adelete_thread(thread_id, actor_id)

# Verify sync method was called with correct arguments
mock_delete.assert_called_once_with(thread_id, actor_id)
assert result is None

async def test_concurrent_calls_aget_tuple(
self, saver, runnable_config, mock_slow_get_tuple
):
with patch.object(saver, "get_tuple", side_effect=mock_slow_get_tuple):
await self.assert_concurrent_calls_are_faster_than_sequential(
N_ASYNC_CALLS, saver.aget_tuple, runnable_config
)

async def test_concurrent_calls_adelete_thread(
self, saver, runnable_config, mock_slow_delete_thread
):
thread_id = runnable_config["configurable"]["thread_id"]
actor_id = runnable_config["configurable"]["actor_id"]

with patch.object(saver, "delete_thread", side_effect=mock_slow_delete_thread):
await self.assert_concurrent_calls_are_faster_than_sequential(
N_ASYNC_CALLS, saver.adelete_thread, thread_id, actor_id
)

async def test_concurrent_calls_aput_writes(
self, saver, runnable_config, mock_slow_put_writes
):
writes = [("channel", "value")]
task_id = "test-task"
task_path = "test-path"

with patch.object(saver, "put_writes", side_effect=mock_slow_put_writes):
await self.assert_concurrent_calls_are_faster_than_sequential(
N_ASYNC_CALLS,
saver.aput_writes,
runnable_config,
writes,
task_id,
task_path,
)

async def test_concurrent_calls_aput(
self,
saver,
runnable_config,
sample_checkpoint,
sample_checkpoint_metadata,
mock_slow_put,
):
new_versions = {"default": "v2"}

with patch.object(saver, "put", side_effect=mock_slow_put):
await self.assert_concurrent_calls_are_faster_than_sequential(
N_ASYNC_CALLS,
saver.aput,
runnable_config,
sample_checkpoint,
sample_checkpoint_metadata,
new_versions,
)

async def test_concurrent_calls_alist(self, saver, runnable_config, mock_slow_list):
filter_dict = {"test": "filter"}
before_config = {"before": "config"}
limit_value = 10

with patch.object(saver, "list", side_effect=mock_slow_list):

async def consume_alist() -> list:
"""Helper coroutine to consume the async iterator."""
items = []
async for item in saver.alist(
runnable_config,
filter=filter_dict,
before=before_config,
limit=limit_value,
):
items.append(item)
return items

await self.assert_concurrent_calls_are_faster_than_sequential(
N_ASYNC_CALLS, consume_alist
)

async def assert_concurrent_calls_are_faster_than_sequential(
self, n_async_calls: int, func, *args, **kwargs
) -> None:
"""Helper to run n async tasks concurrently."""
tasks = [func(*args, **kwargs) for _ in range(n_async_calls)]
start_time = time.time()
await asyncio.gather(*tasks)
concurrent_time = time.time() - start_time
assert concurrent_time < TOTAL_EXPECTED_TIME, (
f"Concurrent execution took {concurrent_time:.2f}s, "
f"expected < {TOTAL_EXPECTED_TIME}s"
)


class TestCheckpointerConfig:
"""Test suite for CheckpointerConfig."""
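
The concurrency tests above rest on a timing argument: each mocked sync method sleeps MOCK_SLEEP_DURATION = 0.1 / 5 = 0.02 s, so five calls run back to back would take about 0.1 s, while five calls offloaded to worker threads and awaited with asyncio.gather should finish in roughly one sleep plus the 0.01 s OVERHEAD_DURATION allowance. The standalone sketch below is not repo code; it only mirrors the constants above to show that budget in isolation.

import asyncio
import time

N_ASYNC_CALLS = 5
MOCK_SLEEP_DURATION = 0.1 / N_ASYNC_CALLS  # 0.02 s per mocked call
OVERHEAD_DURATION = 0.01
TOTAL_EXPECTED_TIME = MOCK_SLEEP_DURATION + OVERHEAD_DURATION  # 0.03 s budget


async def offloaded_call() -> None:
    # Mirrors the executor-backed a* methods: the blocking sleep runs in a
    # worker thread, so concurrent calls overlap instead of serializing.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, time.sleep, MOCK_SLEEP_DURATION)


async def main() -> None:
    start = time.time()
    await asyncio.gather(*(offloaded_call() for _ in range(N_ASYNC_CALLS)))
    elapsed = time.time() - start
    # ~0.02 s when concurrent; a blocking implementation would need ~0.1 s,
    # far over the 0.03 s budget, which is what the assertion is meant to catch.
    assert elapsed < TOTAL_EXPECTED_TIME, f"{elapsed:.3f}s >= {TOTAL_EXPECTED_TIME}s"
    print(f"elapsed={elapsed:.3f}s")


asyncio.run(main())
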