Skip to content

Commit 6a5705e

Browse files
committed
feat: refactor ChatPromptBuilder to make use of extract_declared_variables
1 parent 3db96ab commit 6a5705e

File tree

2 files changed

+46
-11
lines changed

2 files changed

+46
-11
lines changed

haystack/components/builders/chat_prompt_builder.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from haystack import component, default_from_dict, default_to_dict, logging
1313
from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent
1414
from haystack.lazy_imports import LazyImport
15-
from haystack.utils import Jinja2TimeExtension
15+
from haystack.utils import Jinja2TimeExtension, extract_declared_variables
1616
from haystack.utils.jinja2_chat_extension import ChatMessageExtension, templatize_part
1717

1818
logger = logging.getLogger(__name__)
@@ -171,21 +171,28 @@ def __init__(
171171

172172
extracted_variables = []
173173
if template and not variables:
174+
175+
def _extract_from_text(text: str, role: Optional[str] = None, is_filter_allowed: bool = False) -> list:
176+
if text is None:
177+
raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=role or "unknown", message=text))
178+
if is_filter_allowed and "templatize_part" in text:
179+
raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE)
180+
181+
ast = self._env.parse(text)
182+
template_variables = meta.find_undeclared_variables(ast)
183+
assigned_variables = extract_declared_variables(text, env=self._env)
184+
return list(template_variables - assigned_variables)
185+
174186
if isinstance(template, list):
175187
for message in template:
176188
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
177-
# infer variables from template
178-
if message.text is None:
179-
raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=message.role.value, message=message))
180-
if message.text and "templatize_part" in message.text:
181-
raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE)
182-
ast = self._env.parse(message.text)
183-
template_variables = meta.find_undeclared_variables(ast)
184-
extracted_variables += list(template_variables)
189+
extracted_variables += _extract_from_text(
190+
message.text, role=message.role.value, is_filter_allowed=True
191+
)
185192
elif isinstance(template, str):
186-
ast = self._env.parse(template)
187-
extracted_variables = list(meta.find_undeclared_variables(ast))
193+
extracted_variables = _extract_from_text(template, is_filter_allowed=False)
188194

195+
extracted_variables = extracted_variables or []
189196
self.variables = variables or extracted_variables
190197
self.required_variables = required_variables or []
191198

test/components/builders/test_chat_prompt_builder.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,3 +957,31 @@ def test_from_dict(self):
957957
assert builder.template == template
958958
assert builder.variables == ["name", "assistant_name"]
959959
assert builder.required_variables == ["name"]
960+
961+
def test_variables_correct_with_tuple_assignment(self):
962+
template = """{% if existing_documents is not none %}
963+
{% set x, y = (existing_documents|length, 1) %}
964+
{% else %}
965+
{% set x, y = (0, 1) %}
966+
{% endif %}
967+
{% message role="user" %}x={{ x }}, y={{ y }}{% endmessage %}
968+
"""
969+
builder = ChatPromptBuilder(template=template, required_variables="*")
970+
assert set(builder.variables) == {"existing_documents"}
971+
res = builder.run(existing_documents=None)
972+
prompt = res["prompt"]
973+
assert any("x=0, y=1" in msg.text for msg in prompt)
974+
975+
def test_variables_correct_with_list_assignment(self):
976+
template = """{% if existing_documents is not none %}
977+
{% set x, y = [existing_documents|length, 1] %}
978+
{% else %}
979+
{% set x, y = [0, 1] %}
980+
{% endif %}
981+
{% message role="user" %}x={{ x }}, y={{ y }}{% endmessage %}
982+
"""
983+
builder = ChatPromptBuilder(template=template, required_variables="*")
984+
assert set(builder.variables) == {"existing_documents"}
985+
res = builder.run(existing_documents=None)
986+
prompt = res["prompt"]
987+
assert any("x=0, y=1" in msg.text for msg in prompt)

0 commit comments

Comments
 (0)