diff --git a/haystack/components/builders/chat_prompt_builder.py b/haystack/components/builders/chat_prompt_builder.py index aba2780fa7..24a5df9efa 100644 --- a/haystack/components/builders/chat_prompt_builder.py +++ b/haystack/components/builders/chat_prompt_builder.py @@ -12,7 +12,7 @@ from haystack import component, default_from_dict, default_to_dict, logging from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent from haystack.lazy_imports import LazyImport -from haystack.utils import Jinja2TimeExtension +from haystack.utils import Jinja2TemplateVariableExtractor, Jinja2TimeExtension from haystack.utils.jinja2_chat_extension import ChatMessageExtension, templatize_part logger = logging.getLogger(__name__) @@ -171,21 +171,31 @@ def __init__( extracted_variables = [] if template and not variables: + + def _extract_from_text( + text: Optional[str], role: Optional[str] = None, is_filter_allowed: bool = False + ) -> list: + if text is None: + raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=role or "unknown", message=text)) + if is_filter_allowed and "templatize_part" in text: + raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE) + + ast = self._env.parse(text) + template_variables = meta.find_undeclared_variables(ast) + jinja_var_extractor = Jinja2TemplateVariableExtractor(env=self._env) + assigned_variables = jinja_var_extractor._extract_from_text(template_str=text) + return list(template_variables - assigned_variables) + if isinstance(template, list): for message in template: if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM): - # infer variables from template - if message.text is None: - raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=message.role.value, message=message)) - if message.text and "templatize_part" in message.text: - raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE) - ast = self._env.parse(message.text) - template_variables = meta.find_undeclared_variables(ast) - extracted_variables += list(template_variables) + extracted_variables += _extract_from_text( + message.text, role=message.role.value, is_filter_allowed=True + ) elif isinstance(template, str): - ast = self._env.parse(template) - extracted_variables = list(meta.find_undeclared_variables(ast)) + extracted_variables = _extract_from_text(template, is_filter_allowed=False) + extracted_variables = extracted_variables or [] self.variables = variables or extracted_variables self.required_variables = required_variables or [] diff --git a/haystack/components/builders/prompt_builder.py b/haystack/components/builders/prompt_builder.py index 3e1d3603ca..ab2caea1f3 100644 --- a/haystack/components/builders/prompt_builder.py +++ b/haystack/components/builders/prompt_builder.py @@ -8,7 +8,7 @@ from jinja2.sandbox import SandboxedEnvironment from haystack import component, default_to_dict, logging -from haystack.utils import Jinja2TimeExtension +from haystack.utils import Jinja2TemplateVariableExtractor, Jinja2TimeExtension logger = logging.getLogger(__name__) @@ -174,11 +174,15 @@ def __init__( self._env = SandboxedEnvironment() self.template = self._env.from_string(template) + if not variables: - # infer variables from template ast = self._env.parse(template) template_variables = meta.find_undeclared_variables(ast) - variables = list(template_variables) + jinja_var_extractor = Jinja2TemplateVariableExtractor(env=self._env) + assigned_variables = jinja_var_extractor._extract_from_text(template_str=template) + + variables = list(template_variables - assigned_variables) + variables = variables or [] self.variables = variables diff --git a/haystack/utils/__init__.py b/haystack/utils/__init__.py index 7439e97658..3b8edaa6bc 100644 --- a/haystack/utils/__init__.py +++ b/haystack/utils/__init__.py @@ -15,6 +15,7 @@ "device": ["ComponentDevice", "Device", "DeviceMap", "DeviceType"], "deserialization": ["deserialize_document_store_in_init_params_inplace", "deserialize_chatgenerator_inplace"], "filters": ["document_matches_filter", "raise_on_invalid_filter_syntax"], + "jinja2": ["Jinja2TemplateVariableExtractor"], "jinja2_extensions": ["Jinja2TimeExtension"], "jupyter": ["is_in_jupyter"], "misc": ["expit", "expand_page_range"], @@ -40,6 +41,7 @@ from .device import DeviceType as DeviceType from .filters import document_matches_filter as document_matches_filter from .filters import raise_on_invalid_filter_syntax as raise_on_invalid_filter_syntax + from .jinja2 import Jinja2TemplateVariableExtractor as Jinja2TemplateVariableExtractor from .jinja2_extensions import Jinja2TimeExtension as Jinja2TimeExtension from .jupyter import is_in_jupyter as is_in_jupyter from .misc import expand_page_range as expand_page_range diff --git a/haystack/utils/jinja2.py b/haystack/utils/jinja2.py new file mode 100644 index 0000000000..60d9838896 --- /dev/null +++ b/haystack/utils/jinja2.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +from jinja2 import Environment, nodes + + +class Jinja2TemplateVariableExtractor: + """ + A utility class for extracting declared variables from Jinja2 templates. + """ + + def __init__(self, env: Optional[Environment] = None): + self.env = env or Environment() + + def _extract_from_text(self, template_str: str, role: Optional[str] = None) -> set[str]: + """ + Extract declared variables from a Jinja2 template string. + + :param template_str: The Jinja2 template string to analyze. + :param env: The Jinja2 Environment. Defaults to None. + + :returns: + A set of variable names used in the template. + """ + try: + ast = self.env.parse(template_str) + except Exception as e: + raise RuntimeError(f"Failed to parse Jinja2 template: {e}") + + # Collect all variables assigned inside the template via {% set %} + assigned_variables = set() + + for node in ast.find_all(nodes.Assign): + if isinstance(node.target, nodes.Name): + assigned_variables.add(node.target.name) + elif isinstance(node.target, (nodes.List, nodes.Tuple)): + for name_node in node.target.items: + if isinstance(name_node, nodes.Name): + assigned_variables.add(name_node.name) + + return assigned_variables diff --git a/releasenotes/notes/jinja2-set-vars-parsing-919ae03d3c8a1465.yaml b/releasenotes/notes/jinja2-set-vars-parsing-919ae03d3c8a1465.yaml new file mode 100644 index 0000000000..a617094f1f --- /dev/null +++ b/releasenotes/notes/jinja2-set-vars-parsing-919ae03d3c8a1465.yaml @@ -0,0 +1,7 @@ +fixes: + - | + Fixed an issue where Jinja2 variable assignments using the `set` directive + were not being parsed correctly in certain contexts. This fix ensures that + variables assigned with `{% set var = value %}` are now properly recognized + and can be used as expected within templates inside `PromptBuilder` and + `ChatPromptBuilder`. diff --git a/test/components/builders/test_chat_prompt_builder.py b/test/components/builders/test_chat_prompt_builder.py index fa15dcbc56..99efedff92 100644 --- a/test/components/builders/test_chat_prompt_builder.py +++ b/test/components/builders/test_chat_prompt_builder.py @@ -957,3 +957,31 @@ def test_from_dict(self): assert builder.template == template assert builder.variables == ["name", "assistant_name"] assert builder.required_variables == ["name"] + + def test_variables_correct_with_tuple_assignment(self): + template = """{% if existing_documents is not none %} + {% set x, y = (existing_documents|length, 1) %} + {% else %} + {% set x, y = (0, 1) %} + {% endif %} + {% message role="user" %}x={{ x }}, y={{ y }}{% endmessage %} + """ + builder = ChatPromptBuilder(template=template, required_variables="*") + assert set(builder.variables) == {"existing_documents"} + res = builder.run(existing_documents=None) + prompt = res["prompt"] + assert any("x=0, y=1" in msg.text for msg in prompt) + + def test_variables_correct_with_list_assignment(self): + template = """{% if existing_documents is not none %} + {% set x, y = [existing_documents|length, 1] %} + {% else %} + {% set x, y = [0, 1] %} + {% endif %} + {% message role="user" %}x={{ x }}, y={{ y }}{% endmessage %} + """ + builder = ChatPromptBuilder(template=template, required_variables="*") + assert set(builder.variables) == {"existing_documents"} + res = builder.run(existing_documents=None) + prompt = res["prompt"] + assert any("x=0, y=1" in msg.text for msg in prompt) diff --git a/test/components/builders/test_prompt_builder.py b/test/components/builders/test_prompt_builder.py index 8a5504c5fc..288d0f2554 100644 --- a/test/components/builders/test_prompt_builder.py +++ b/test/components/builders/test_prompt_builder.py @@ -337,3 +337,47 @@ def test_warning_no_required_variables(self, caplog): with caplog.at_level(logging.WARNING): _ = PromptBuilder(template="This is a {{ variable }}") assert "but `required_variables` is not set." in caplog.text + + def test_template_assigned_variables_from_required_inputs(self) -> None: + template = """{% if existing_documents is not none %} + {% set existing_doc_len = existing_documents|length %} + {% else %} + {% set existing_doc_len = 0 %} + {% endif %} + {% for doc in docs %} + + {{ doc.content }} + + {% endfor %} + """ + + builder = PromptBuilder(template=template, required_variables="*") + + builder = PromptBuilder(template=template, required_variables="*") + assert set(builder.variables) == {"docs", "existing_documents"} + + def test_variables_correct_with_tuple_assignment(self): + template = """{% if existing_documents is not none %} +{% set x, y = (existing_documents|length, 1) %} +{% else %} +{% set x, y = (0, 1) %} +{% endif %} +x={{ x }}, y={{ y }} +""" + builder = PromptBuilder(template=template, required_variables="*") + assert set(builder.variables) == {"existing_documents"} + res = builder.run(existing_documents=None) + assert "x=0, y=1" in res["prompt"] + + def test_variables_correct_with_list_assignment(self): + template = """{% if existing_documents is not none %} +{% set x, y = [existing_documents|length, 1] %} +{% else %} +{% set x, y = [0, 1] %} +{% endif %} +x={{ x }}, y={{ y }} +""" + builder = PromptBuilder(template=template, required_variables="*") + assert set(builder.variables) == {"existing_documents"} + res = builder.run(existing_documents=None) + assert "x=0, y=1" in res["prompt"]