|
27 | 27 | from pathlib import Path |
28 | 28 | from typing import TYPE_CHECKING |
29 | 29 | from unittest import mock |
30 | | -from unittest.mock import patch |
| 30 | +from unittest.mock import call, patch |
31 | 31 |
|
32 | 32 | import pandas as pd |
33 | 33 | import pytest |
|
48 | 48 | from airflow.sdk import ( |
49 | 49 | DAG, |
50 | 50 | BaseOperator, |
| 51 | + BaseOperatorLink, |
51 | 52 | Connection, |
52 | 53 | dag as dag_decorator, |
53 | 54 | get_current_context, |
@@ -1723,6 +1724,93 @@ def execute(self, context): |
1723 | 1724 | map_index=runtime_ti.map_index, |
1724 | 1725 | ) |
1725 | 1726 |
|
| 1727 | + def test_task_failed_with_operator_extra_links( |
| 1728 | + self, create_runtime_ti, mock_supervisor_comms, time_machine |
| 1729 | + ): |
| 1730 | + """Test that operator extra links are pushed to xcoms even when task fails.""" |
| 1731 | + instant = timezone.datetime(2024, 12, 3, 10, 0) |
| 1732 | + time_machine.move_to(instant, tick=False) |
| 1733 | + |
| 1734 | + class DummyTestOperator(BaseOperator): |
| 1735 | + operator_extra_links = (AirflowLink(),) |
| 1736 | + |
| 1737 | + def execute(self, context): |
| 1738 | + raise ValueError("Task failed intentionally") |
| 1739 | + |
| 1740 | + task = DummyTestOperator(task_id="task_with_operator_extra_links") |
| 1741 | + runtime_ti = create_runtime_ti(task=task) |
| 1742 | + context = runtime_ti.get_template_context() |
| 1743 | + runtime_ti.start_date = instant |
| 1744 | + runtime_ti.end_date = instant |
| 1745 | + |
| 1746 | + state, _, error = run(runtime_ti, context=context, log=mock.MagicMock()) |
| 1747 | + assert state == TaskInstanceState.FAILED |
| 1748 | + assert error is not None |
| 1749 | + |
| 1750 | + with mock.patch.object(XCom, "_set_xcom_in_db") as mock_xcom_set: |
| 1751 | + finalize( |
| 1752 | + runtime_ti, |
| 1753 | + log=mock.MagicMock(), |
| 1754 | + state=TaskInstanceState.FAILED, |
| 1755 | + context=context, |
| 1756 | + error=error, |
| 1757 | + ) |
| 1758 | + assert mock_xcom_set.mock_calls == [ |
| 1759 | + call( |
| 1760 | + key="_link_AirflowLink", |
| 1761 | + value="https://airflow.apache.org", |
| 1762 | + dag_id=runtime_ti.dag_id, |
| 1763 | + task_id=runtime_ti.task_id, |
| 1764 | + run_id=runtime_ti.run_id, |
| 1765 | + map_index=runtime_ti.map_index, |
| 1766 | + ) |
| 1767 | + ] |
| 1768 | + |
| 1769 | + def test_operator_extra_links_exception_handling( |
| 1770 | + self, create_runtime_ti, mock_supervisor_comms, time_machine |
| 1771 | + ): |
| 1772 | + """Test that exceptions in get_link() don't prevent other links from being pushed.""" |
| 1773 | + instant = timezone.datetime(2024, 12, 3, 10, 0) |
| 1774 | + time_machine.move_to(instant, tick=False) |
| 1775 | + |
| 1776 | + class FailingLink(BaseOperatorLink): |
| 1777 | + """A link that raises an exception when get_link is called.""" |
| 1778 | + |
| 1779 | + name = "failing_link" |
| 1780 | + |
| 1781 | + def get_link(self, operator, *, ti_key): |
| 1782 | + raise ValueError("Link generation failed") |
| 1783 | + |
| 1784 | + class DummyTestOperator(BaseOperator): |
| 1785 | + operator_extra_links = (FailingLink(), AirflowLink()) |
| 1786 | + |
| 1787 | + def execute(self, context): |
| 1788 | + pass |
| 1789 | + |
| 1790 | + task = DummyTestOperator(task_id="task_with_multiple_links") |
| 1791 | + runtime_ti = create_runtime_ti(task=task) |
| 1792 | + context = runtime_ti.get_template_context() |
| 1793 | + runtime_ti.start_date = instant |
| 1794 | + runtime_ti.end_date = instant |
| 1795 | + |
| 1796 | + with mock.patch.object(XCom, "_set_xcom_in_db") as mock_xcom_set: |
| 1797 | + finalize( |
| 1798 | + runtime_ti, |
| 1799 | + log=mock.MagicMock(), |
| 1800 | + state=TaskInstanceState.SUCCESS, |
| 1801 | + context=context, |
| 1802 | + ) |
| 1803 | + assert mock_xcom_set.mock_calls == [ |
| 1804 | + call( |
| 1805 | + key="_link_AirflowLink", |
| 1806 | + value="https://airflow.apache.org", |
| 1807 | + dag_id=runtime_ti.dag_id, |
| 1808 | + task_id=runtime_ti.task_id, |
| 1809 | + run_id=runtime_ti.run_id, |
| 1810 | + map_index=runtime_ti.map_index, |
| 1811 | + ) |
| 1812 | + ] |
| 1813 | + |
1726 | 1814 | @pytest.mark.parametrize( |
1727 | 1815 | ["cmd", "rendered_cmd"], |
1728 | 1816 | [ |
|
0 commit comments