@@ -367,6 +367,7 @@ def mmr_traversal_search(
367367 lambda_mult : float = 0.5 ,
368368 score_threshold : float = float ("-inf" ),
369369 metadata_filter : dict [str , Any ] = {}, # noqa: B006
370+ tag_filter : set [tuple [str , str ]],
370371 ) -> Iterable [Node ]:
371372 """Retrieve documents from this graph store using MMR-traversal.
372373
@@ -398,6 +399,7 @@ def mmr_traversal_search(
398399 score_threshold: Only documents with a score greater than or equal
399400 this threshold will be chosen. Defaults to -infinity.
400401 metadata_filter: Optional metadata to filter the results.
402+ tag_filter: Optional tags to filter graph edges to be traversed.
401403 """
402404 query_embedding = self ._embedding .embed_query (query )
403405 helper = MmrHelper (
@@ -444,9 +446,14 @@ def fetch_neighborhood(neighborhood: Sequence[str]) -> None:
444446 new_candidates = {}
445447 for adjacent in adjacents :
446448 if adjacent .target_content_id not in outgoing_tags :
447- outgoing_tags [adjacent .target_content_id ] = (
448- adjacent .target_link_to_tags
449- )
449+ if tag_filter .len () == 0 :
450+ outgoing_tags [adjacent .target_content_id ] = (
451+ adjacent .target_link_to_tags
452+ )
453+ else :
454+ outgoing_tags [adjacent .target_content_id ] = (
455+ tag_filter .intersection (adjacent .target_link_to_tags )
456+ )
450457
451458 new_candidates [adjacent .target_content_id ] = (
452459 adjacent .target_text_embedding
@@ -474,7 +481,10 @@ def fetch_initial_candidates() -> None:
474481 for row in fetched :
475482 if row .content_id not in outgoing_tags :
476483 candidates [row .content_id ] = row .text_embedding
477- outgoing_tags [row .content_id ] = set (row .link_to_tags or [])
484+ if tag_filter .len () == 0 :
485+ outgoing_tags [row .content_id ] = set (row .link_to_tags or [])
486+ else :
487+ outgoing_tags [row .content_id ] = tag_filter .intersection (set (row .link_to_tags or []))
478488 helper .add_candidates (candidates )
479489
480490 if initial_roots :
@@ -522,9 +532,14 @@ def fetch_initial_candidates() -> None:
522532 new_candidates = {}
523533 for adjacent in adjacents :
524534 if adjacent .target_content_id not in outgoing_tags :
525- outgoing_tags [adjacent .target_content_id ] = (
526- adjacent .target_link_to_tags
527- )
535+ if tag_filter .len () == 0 :
536+ outgoing_tags [adjacent .target_content_id ] = (
537+ adjacent .target_link_to_tags
538+ )
539+ else :
540+ outgoing_tags [adjacent .target_content_id ] = (
541+ tag_filter .intersection (adjacent .target_link_to_tags )
542+ )
528543 new_candidates [adjacent .target_content_id ] = (
529544 adjacent .target_text_embedding
530545 )
0 commit comments