Skip to content

Commit 0dfa738

Browse files
chore: address PR feedback
1 parent a544766 commit 0dfa738

File tree

2 files changed

+43
-22
lines changed

2 files changed

+43
-22
lines changed

libs/langgraph-checkpoint-aws/langgraph_checkpoint_aws/store/valkey/search_strategies.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,31 @@ def _execute_vector_search(self, op: SearchOp) -> list[SearchItem]:
154154
)
155155

156156
# Generate embedding for the query
157+
# Try sync method first, fall back to async only if no event loop exists
157158
if hasattr(self.store.embeddings, "embed_query"):
158159
query_vector = self.store.embeddings.embed_query(op.query)
160+
elif hasattr(self.store.embeddings, "aembed_query"):
161+
# Check if we're in an async context
162+
try:
163+
asyncio.get_running_loop()
164+
# We're in an async context - cannot use asyncio.run()
165+
raise SearchIndexError(
166+
"Cannot generate embeddings: sync method not available "
167+
"and already in async context. Use AsyncValkeyStore instead.",
168+
index_name=index_name,
169+
index_operation="embedding_generation",
170+
)
171+
except RuntimeError:
172+
# No running event loop, safe to create one
173+
query_vector = asyncio.run(
174+
self.store.embeddings.aembed_query(op.query)
175+
)
159176
else:
160-
# Fallback to async if only async method available
161-
query_vector = asyncio.run(self.store.embeddings.aembed_query(op.query))
177+
raise SearchIndexError(
178+
"No embedding method available (embed_query or aembed_query)",
179+
index_name=index_name,
180+
index_operation="embedding_generation",
181+
)
162182

163183
# Pack vector to binary bytes for FT.SEARCH
164184
vec_bytes = struct.pack(f"{len(query_vector)}f", *query_vector)

libs/langgraph-checkpoint-aws/tests/unit_tests/store/valkey/test_valkey_index_config_unit.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,25 @@ def fake_valkey_client():
2525
return fakeredis.FakeStrictRedis(decode_responses=False)
2626

2727

28+
def convert_hash_fields_to_str(hash_fields: dict) -> dict:
29+
"""Helper to convert hash field bytes to strings for test assertions.
30+
31+
Args:
32+
hash_fields: Raw hash fields from Valkey (may contain bytes)
33+
34+
Returns:
35+
Dictionary with bytes converted to strings (except binary vector field)
36+
"""
37+
result = {}
38+
for k, v in hash_fields.items():
39+
key_str = k.decode("utf-8") if isinstance(k, bytes) else k
40+
# Skip binary vector field - it's not UTF-8 text
41+
if key_str == "vector" and isinstance(v, bytes):
42+
result[key_str] = v # Keep as bytes
43+
else:
44+
result[key_str] = v.decode("utf-8") if isinstance(v, bytes) else v
45+
return result
46+
2847
class TestValkeyIndexConfig:
2948
"""Test suite for ValkeyIndexConfig TypedDict."""
3049

@@ -682,16 +701,7 @@ def test_document_creation_includes_searchable_fields(self, fake_valkey_client):
682701
hash_fields = fake_valkey_client.hgetall(key)
683702

684703
# Convert bytes keys/values to strings for easier checking
685-
hash_fields_str = {}
686-
for k, v in hash_fields.items():
687-
key_str = k.decode("utf-8") if isinstance(k, bytes) else k
688-
# Skip binary vector field - it's not UTF-8 text
689-
if key_str == "vector" and isinstance(v, bytes):
690-
val_str = v # Keep as bytes
691-
else:
692-
val_str = v.decode("utf-8") if isinstance(v, bytes) else v
693-
hash_fields_str[key_str] = val_str
694-
704+
hash_fields_str = convert_hash_fields_to_str(hash_fields)
695705
# Verify that searchable fields are included in hash fields
696706
# (without value_ prefix)
697707
assert "user_id" in hash_fields_str
@@ -744,16 +754,7 @@ def test_list_fields_handled_correctly(self, fake_valkey_client):
744754
hash_fields = fake_valkey_client.hgetall(key)
745755

746756
# Convert bytes keys/values to strings for easier checking
747-
hash_fields_str = {}
748-
for k, v in hash_fields.items():
749-
key_str = k.decode("utf-8") if isinstance(k, bytes) else k
750-
# Skip binary vector field - it's not UTF-8 text
751-
if key_str == "vector" and isinstance(v, bytes):
752-
val_str = v # Keep as bytes
753-
else:
754-
val_str = v.decode("utf-8") if isinstance(v, bytes) else v
755-
hash_fields_str[key_str] = val_str
756-
757+
hash_fields_str = convert_hash_fields_to_str(hash_fields)
757758
# Verify that list fields are converted to comma-separated strings
758759
assert hash_fields_str["tags"] == "machine-learning,ai,python"
759760
assert hash_fields_str["categories"] == "tech,tutorial"

0 commit comments

Comments
 (0)