|
12 | 12 | from haystack import component, default_from_dict, default_to_dict, logging |
13 | 13 | from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent |
14 | 14 | from haystack.lazy_imports import LazyImport |
15 | | -from haystack.utils import Jinja2TimeExtension |
| 15 | +from haystack.utils import Jinja2TimeExtension, extract_declared_variables |
16 | 16 | from haystack.utils.jinja2_chat_extension import ChatMessageExtension, templatize_part |
17 | 17 |
|
18 | 18 | logger = logging.getLogger(__name__) |
@@ -171,21 +171,28 @@ def __init__( |
171 | 171 |
|
172 | 172 | extracted_variables = [] |
173 | 173 | if template and not variables: |
| 174 | + |
| 175 | + def _extract_from_text(text: str, role: Optional[str] = None, is_filter_allowed: bool = False) -> list: |
| 176 | + if text is None: |
| 177 | + raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=role or "unknown", message=text)) |
| 178 | + if is_filter_allowed and "templatize_part" in text: |
| 179 | + raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE) |
| 180 | + |
| 181 | + ast = self._env.parse(text) |
| 182 | + template_variables = meta.find_undeclared_variables(ast) |
| 183 | + assigned_variables = extract_declared_variables(text, env=self._env) |
| 184 | + return list(template_variables - assigned_variables) |
| 185 | + |
174 | 186 | if isinstance(template, list): |
175 | 187 | for message in template: |
176 | 188 | if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM): |
177 | | - # infer variables from template |
178 | | - if message.text is None: |
179 | | - raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=message.role.value, message=message)) |
180 | | - if message.text and "templatize_part" in message.text: |
181 | | - raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE) |
182 | | - ast = self._env.parse(message.text) |
183 | | - template_variables = meta.find_undeclared_variables(ast) |
184 | | - extracted_variables += list(template_variables) |
| 189 | + extracted_variables += _extract_from_text( |
| 190 | + message.text, role=message.role.value, is_filter_allowed=True |
| 191 | + ) |
185 | 192 | elif isinstance(template, str): |
186 | | - ast = self._env.parse(template) |
187 | | - extracted_variables = list(meta.find_undeclared_variables(ast)) |
| 193 | + extracted_variables = _extract_from_text(template, is_filter_allowed=False) |
188 | 194 |
|
| 195 | + extracted_variables = extracted_variables or [] |
189 | 196 | self.variables = variables or extracted_variables |
190 | 197 | self.required_variables = required_variables or [] |
191 | 198 |
|
|
0 commit comments