Skip to content

Commit 301772c

Browse files
blink1073Abhishekabhishek499NoahStapp
authored
INTPYTHON-678 Allow agent_toolkit parser to handle Python and JS objects (#193)
Co-authored-by: Abhishek <[email protected]> Co-authored-by: abhishek499 <[email protected]> Co-authored-by: Noah Stapp <[email protected]>
1 parent 14e4d09 commit 301772c

File tree

2 files changed

+98
-11
lines changed

2 files changed

+98
-11
lines changed

libs/langchain-mongodb/langchain_mongodb/agent_toolkit/database.py

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import json
66
import re
7-
from datetime import date, datetime
7+
from datetime import date, datetime, timezone
88
from typing import Any, Dict, Iterable, List, Optional, Union
99

1010
from bson import ObjectId
@@ -209,16 +209,34 @@ def _elide_doc(self, doc: dict[str, Any]) -> None:
209209
doc[key] = value[: MAX_STRING_LENGTH_OF_SAMPLE_DOCUMENT_VALUE + 1]
210210

211211
def _parse_command(self, command: str) -> Any:
212-
# Convert a JavaScript command to a python object.
212+
"""
213+
Extracts and parses the aggregation pipeline from a JavaScript-style MongoDB command.
214+
Handles ObjectId(), ISODate(), new Date() and converts them into Python constructs.
215+
"""
213216
command = re.sub(r"\s+", " ", command.strip())
214-
# Handle missing closing parens.
215217
if command.endswith("]"):
216218
command += ")"
217-
agg_command = command[command.index("[") : -1]
219+
220+
try:
221+
agg_str = command.split(".aggregate(", 1)[1].rsplit(")", 1)[0]
222+
except Exception as e:
223+
raise ValueError(f"Could not extract aggregation pipeline: {e}") from e
224+
225+
# Convert JavaScript-style constructs to Python syntax
226+
agg_str = self._convert_mongo_js_to_python(agg_str)
227+
218228
try:
219-
return json.loads(agg_command)
229+
eval_globals = {
230+
"ObjectId": ObjectId,
231+
"datetime": datetime,
232+
"timezone": timezone,
233+
}
234+
agg_pipeline = eval(agg_str, eval_globals)
235+
if not isinstance(agg_pipeline, list):
236+
raise ValueError("Aggregation pipeline must be a list.")
237+
return agg_pipeline
220238
except Exception as e:
221-
raise ValueError(f"Cannot execute command {command}") from e
239+
raise ValueError(f"Failed to parse aggregation pipeline: {e}") from e
222240

223241
def run(self, command: str) -> Union[str, Cursor]:
224242
"""Execute a MongoDB aggregation command and return a string representing the results.
@@ -230,14 +248,29 @@ def run(self, command: str) -> Union[str, Cursor]:
230248
"""
231249
if not command.startswith("db."):
232250
raise ValueError(f"Cannot run command {command}")
233-
col_name = command.split(".")[1]
251+
252+
try:
253+
col_name = command.split(".")[1]
254+
except IndexError as e:
255+
raise ValueError(
256+
"Invalid command format. Could not extract collection name."
257+
) from e
258+
234259
if col_name not in self.get_usable_collection_names():
235260
raise ValueError(f"Collection {col_name} does not exist!")
236-
coll = self._db[col_name]
261+
237262
if ".aggregate(" not in command:
238-
raise ValueError(f"Cannot execute command {command}")
239-
agg = self._parse_command(command)
240-
return dumps(list(coll.aggregate(agg)), indent=2)
263+
raise ValueError("Only aggregate(...) queries are currently supported.")
264+
265+
# Parse pipeline using helper
266+
agg_pipeline = self._parse_command(command)
267+
268+
try:
269+
coll = self._db[col_name]
270+
result = coll.aggregate(agg_pipeline)
271+
return dumps(list(result), indent=2)
272+
except Exception as e:
273+
raise ValueError(f"Error executing aggregation: {e}") from e
241274

242275
def get_collection_info_no_throw(
243276
self, collection_names: Optional[List[str]] = None
@@ -280,3 +313,39 @@ def get_context(self) -> Dict[str, Any]:
280313
"collection_info": collection_info,
281314
"collection_names": ", ".join(collection_names),
282315
}
316+
317+
def _convert_mongo_js_to_python(self, code: str) -> str:
318+
"""Convert JavaScript-style MongoDB syntax into Python-safe code."""
319+
320+
def _handle_iso_date(match: Any) -> str:
321+
date_str = match.group(1)
322+
if not date_str:
323+
raise ValueError("ISODate must contain a date string.")
324+
dt = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
325+
return f"datetime({dt.year}, {dt.month}, {dt.day}, {dt.hour}, {dt.minute}, {dt.second}, tzinfo=timezone.utc)"
326+
327+
def _handle_new_date(match: Any) -> str:
328+
date_str = match.group(1)
329+
if not date_str:
330+
raise ValueError(
331+
"new Date() without arguments is not allowed. Please pass an explicit date string."
332+
)
333+
dt = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
334+
return f"datetime({dt.year}, {dt.month}, {dt.day}, {dt.hour}, {dt.minute}, {dt.second}, tzinfo=timezone.utc)"
335+
336+
def _handle_object_id(match: Any) -> str:
337+
oid_str = match.group(1)
338+
if not oid_str:
339+
raise ValueError("ObjectId must contain a value.")
340+
return f"ObjectId('{oid_str}')"
341+
342+
patterns = [
343+
(r'ISODate\(\s*["\']([^"\']*)["\']\s*\)', _handle_iso_date),
344+
(r'new\s+Date\(\s*["\']([^"\']*)["\']\s*\)', _handle_new_date),
345+
(r'ObjectId\(\s*["\']([^"\']*)["\']\s*\)', _handle_object_id),
346+
]
347+
348+
for pattern, replacer in patterns:
349+
code = re.sub(pattern, replacer, code)
350+
351+
return code

libs/langchain-mongodb/tests/unit_tests/test_tools.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
from datetime import datetime
34
from typing import Type
45

6+
from bson import ObjectId
57
from langchain_tests.unit_tests import ToolsUnitTests
68

79
from langchain_mongodb.agent_toolkit import MongoDBDatabase
@@ -69,3 +71,19 @@ def tool_constructor_params(self) -> dict:
6971
@property
7072
def tool_invoke_params_example(self) -> dict:
7173
return dict(query="db.foo.aggregate()")
74+
75+
76+
def test_database_parse_command() -> None:
77+
db = MongoDBDatabase(MockClient(), "test") # type:ignore[arg-type]
78+
79+
command = """db.user.aggregate([ { "$match": { "_id": ObjectId("123412341234123412341234") } } ])"""
80+
result = db._parse_command(command)
81+
assert isinstance(result[0]["$match"]["_id"], ObjectId)
82+
83+
command = """db.user.aggregate([ { "$match": { "date": ISODate("2017-04-27T04:26:42.709Z") } } ])"""
84+
result = db._parse_command(command)
85+
assert isinstance(result[0]["$match"]["date"], datetime)
86+
87+
command = """db.user.aggregate([ { "$match": { "date": new Date("2017-04-27T04:26:42.709Z") } } ])"""
88+
result = db._parse_command(command)
89+
assert isinstance(result[0]["$match"]["date"], datetime)

0 commit comments

Comments
 (0)