diff --git a/airflow-core/newsfragments/54505.significant.rst b/airflow-core/newsfragments/54505.significant.rst
new file mode 100644
index 0000000000000..b7454f6acc0e6
--- /dev/null
+++ b/airflow-core/newsfragments/54505.significant.rst
@@ -0,0 +1,60 @@
+Move task-level exception imports into the Task SDK
+
+Airflow now sources task-facing exceptions (``AirflowSkipException``, ``TaskDeferred``, etc.) from
+``airflow.sdk.exceptions``. ``airflow.exceptions`` still exposes the same exceptions, but they are
+proxies that emit ``DeprecatedImportWarning`` so Dag authors can migrate before the shim is removed.
+
+**What changed:**
+
+- Runtime code now consistently raises the SDK versions of task-level exceptions.
+- The Task SDK redefines these classes so workers no longer depend on ``airflow-core`` at runtime.
+- ``airflow.providers.common.compat.sdk`` centralizes compatibility imports for providers.
+
+**Behaviour changes:**
+
+- Sensors and other helpers that validate user input now raise ``ValueError`` (instead of
+  ``AirflowException``) when ``poke_interval`` / ``timeout`` arguments are invalid.
+- Importing deprecated exception names from ``airflow.exceptions`` logs a warning directing users to
+  the SDK import path.
+
+**Exceptions now provided by** ``airflow.sdk.exceptions``:
+
+- ``AirflowException`` and ``AirflowNotFoundException``
+- ``AirflowRescheduleException`` and ``AirflowSensorTimeout``
+- ``AirflowSkipException``, ``AirflowFailException``, ``AirflowTaskTimeout``, ``AirflowTaskTerminated``
+- ``TaskDeferred``, ``TaskDeferralTimeout``, ``TaskDeferralError``
+- ``DagRunTriggerException`` and ``DownstreamTasksSkipped``
+- ``AirflowDagCycleException`` and ``AirflowInactiveAssetInInletOrOutletException``
+- ``ParamValidationError``, ``DuplicateTaskIdFound``, ``TaskAlreadyInTaskGroup``, ``TaskNotFound``, ``XComNotFound``
+
+**Backward compatibility:**
+
+- Existing Dags/operators that still import from ``airflow.exceptions`` continue to work, though
+  they log warnings.
+- Providers can rely on ``airflow.providers.common.compat.sdk`` to keep one import path that works
+  across supported Airflow versions.
+
+**Migration:**
+
+- Update custom operators, sensors, and extensions to import exception classes from
+  ``airflow.sdk.exceptions`` (or from the provider compat shim).
+- Adjust custom validation code to expect ``ValueError`` for invalid sensor arguments if it
+  previously caught ``AirflowException``.
+
+* Types of change
+
+  * [ ] Dag changes
+  * [ ] Config changes
+  * [ ] API changes
+  * [ ] CLI changes
+  * [x] Behaviour changes
+  * [ ] Plugin changes
+  * [ ] Dependency changes
+  * [x] Code interface changes
+
+* Migration rules needed
+
+  * Import task-level exceptions such as ``AirflowSkipException``, ``TaskDeferred``,
+    ``AirflowFailException``, etc. from ``airflow.sdk.exceptions`` instead of ``airflow.exceptions``.
+  * Update custom sensors/operators that validated arguments by catching ``AirflowException`` to
+    expect ``ValueError`` for invalid ``poke_interval`` / ``timeout`` inputs.
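A minimal sketch of the Dag-author migration this newsfragment describes, assuming Airflow 3.x with the Task SDK installed; the ``AlwaysDone`` sensor is invented purely for illustration:

```python
# Before (still works, but each access is routed through the
# airflow.exceptions.__getattr__ shim shown later in this diff and
# emits DeprecatedImportWarning):
#   from airflow.exceptions import AirflowSkipException, TaskDeferred

# After: import task-level exceptions straight from the Task SDK.
from airflow.sdk.exceptions import AirflowSkipException, TaskDeferred  # noqa: F401

from airflow.sdk.bases.sensor import BaseSensorOperator


class AlwaysDone(BaseSensorOperator):
    """Toy sensor; exists only to show the new argument-validation behaviour."""

    def poke(self, context) -> bool:
        return True


# Invalid sensor arguments now raise ValueError, not AirflowException:
try:
    AlwaysDone(task_id="wait", poke_interval=-5)
except ValueError as err:
    print(f"invalid sensor arguments: {err}")
```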
diff --git a/airflow-core/src/airflow/dag_processing/dagbag.py b/airflow-core/src/airflow/dag_processing/dagbag.py index 0259992dc78f7..c4aecc651fb41 100644 --- a/airflow-core/src/airflow/dag_processing/dagbag.py +++ b/airflow-core/src/airflow/dag_processing/dagbag.py @@ -41,10 +41,8 @@ AirflowClusterPolicyError, AirflowClusterPolicySkipDag, AirflowClusterPolicyViolation, - AirflowDagCycleException, AirflowDagDuplicatedIdException, AirflowException, - AirflowTaskTimeout, UnknownExecutorException, ) from airflow.executors.executor_loader import ExecutorLoader @@ -119,6 +117,8 @@ def timeout(seconds=1, error_message="Timeout"): def handle_timeout(signum, frame): """Log information and raises AirflowTaskTimeout.""" log.error("Process timed out, PID: %s", str(os.getpid())) + from airflow.sdk.exceptions import AirflowTaskTimeout + raise AirflowTaskTimeout(error_message) try: @@ -588,6 +588,7 @@ def bag_dag(self, dag: DAG): except Exception as e: self.log.exception(e) raise AirflowClusterPolicyError(e) + from airflow.sdk.exceptions import AirflowDagCycleException try: prev_dag = self.dags.get(dag.dag_id) diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 0dd08e3a6b7b2..506cf515d39d8 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -36,7 +36,7 @@ ) from airflow.configuration import conf from airflow.dag_processing.dagbag import DagBag -from airflow.exceptions import TaskNotFound +from airflow.sdk.exceptions import TaskNotFound from airflow.sdk.execution_time.comms import ( ConnectionResult, DeleteVariable, diff --git a/airflow-core/src/airflow/example_dags/example_skip_dag.py b/airflow-core/src/airflow/example_dags/example_skip_dag.py index 7b87dd732c1f0..a2fe914c0356d 100644 --- a/airflow-core/src/airflow/example_dags/example_skip_dag.py +++ b/airflow-core/src/airflow/example_dags/example_skip_dag.py @@ -24,9 +24,9 @@ import pendulum -from airflow.exceptions import AirflowSkipException from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk import DAG, BaseOperator, TriggerRule +from airflow.sdk.exceptions import AirflowSkipException if TYPE_CHECKING: from airflow.sdk import Context diff --git a/airflow-core/src/airflow/exceptions.py b/airflow-core/src/airflow/exceptions.py index 17d85e8605004..3790941229101 100644 --- a/airflow-core/src/airflow/exceptions.py +++ b/airflow-core/src/airflow/exceptions.py @@ -21,121 +21,66 @@ from __future__ import annotations -from collections.abc import Collection, Sequence -from datetime import datetime, timedelta from http import HTTPStatus -from typing import TYPE_CHECKING, Any, NamedTuple +from typing import TYPE_CHECKING, NamedTuple if TYPE_CHECKING: from airflow.models import DagRun - from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, AssetUriRef - from airflow.utils.state import DagRunState # Re exporting AirflowConfigException from shared configuration from airflow._shared.configuration.exceptions import AirflowConfigException as AirflowConfigException +try: + from airflow.sdk.exceptions import ( + AirflowException, + AirflowNotFoundException, + AirflowRescheduleException, + TaskNotFound, + ) +except ModuleNotFoundError: + # When _AIRFLOW__AS_LIBRARY is set, airflow.sdk may not be installed. + # In that case, we define fallback exception classes that mirror the SDK ones. 
+ class AirflowException(Exception): # type: ignore[no-redef] + """Base exception for Airflow errors.""" -class AirflowException(Exception): - """ - Base class for all Airflow's errors. - - Each custom exception should be derived from this class. - """ - - status_code = HTTPStatus.INTERNAL_SERVER_ERROR - - def serialize(self): - cls = self.__class__ - return f"{cls.__module__}.{cls.__name__}", (str(self),), {} - + pass -class AirflowBadRequest(AirflowException): - """Raise when the application or server cannot handle the request.""" + class AirflowNotFoundException(AirflowException): # type: ignore[no-redef] + """Raise when a requested object is not found.""" - status_code = HTTPStatus.BAD_REQUEST + pass + class TaskNotFound(AirflowException): # type: ignore[no-redef] + """Raise when a Task is not available in the system.""" -class AirflowNotFoundException(AirflowException): - """Raise when the requested object/resource is not available in the system.""" + pass - status_code = HTTPStatus.NOT_FOUND + class AirflowRescheduleException(AirflowException): # type: ignore[no-redef] + """ + Raise when the task should be re-scheduled at a later time. + :param reschedule_date: The date when the task should be rescheduled + """ -class AirflowSensorTimeout(AirflowException): - """Raise when there is a timeout on sensor polling.""" + def __init__(self, reschedule_date): + super().__init__() + self.reschedule_date = reschedule_date + def serialize(self): + cls = self.__class__ + return f"{cls.__module__}.{cls.__name__}", (), {"reschedule_date": self.reschedule_date} -class AirflowRescheduleException(AirflowException): - """ - Raise when the task should be re-scheduled at a later time. - :param reschedule_date: The date when the task should be rescheduled - """ - - def __init__(self, reschedule_date): - super().__init__() - self.reschedule_date = reschedule_date +class AirflowBadRequest(AirflowException): + """Raise when the application or server cannot handle the request.""" - def serialize(self): - cls = self.__class__ - return f"{cls.__module__}.{cls.__name__}", (), {"reschedule_date": self.reschedule_date} + status_code = HTTPStatus.BAD_REQUEST class InvalidStatsNameException(AirflowException): """Raise when name of the stats is invalid.""" -# Important to inherit BaseException instead of AirflowException->Exception, since this Exception is used -# to explicitly interrupt ongoing task. Code that does normal error-handling should not treat -# such interrupt as an error that can be handled normally. 
(Compare with KeyboardInterrupt) -class AirflowTaskTimeout(BaseException): - """Raise when the task execution times-out.""" - - -class AirflowTaskTerminated(BaseException): - """Raise when the task execution is terminated.""" - - -class AirflowSkipException(AirflowException): - """Raise when the task should be skipped.""" - - -class AirflowFailException(AirflowException): - """Raise when the task should be failed without retrying.""" - - -class _AirflowExecuteWithInactiveAssetExecption(AirflowFailException): - main_message: str - - def __init__(self, inactive_asset_keys: Collection[AssetUniqueKey | AssetNameRef | AssetUriRef]) -> None: - self.inactive_asset_keys = inactive_asset_keys - - @staticmethod - def _render_asset_key(key: AssetUniqueKey | AssetNameRef | AssetUriRef) -> str: - from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, AssetUriRef - - if isinstance(key, AssetUniqueKey): - return f"Asset(name={key.name!r}, uri={key.uri!r})" - if isinstance(key, AssetNameRef): - return f"Asset.ref(name={key.name!r})" - if isinstance(key, AssetUriRef): - return f"Asset.ref(uri={key.uri!r})" - return repr(key) # Should not happen, but let's fails more gracefully in an exception. - - def __str__(self) -> str: - return f"{self.main_message}: {self.inactive_assets_message}" - - @property - def inactive_assets_message(self) -> str: - return ", ".join(self._render_asset_key(key) for key in self.inactive_asset_keys) - - -class AirflowInactiveAssetInInletOrOutletException(_AirflowExecuteWithInactiveAssetExecption): - """Raise when the task is executed with inactive assets in its inlet or outlet.""" - - main_message = "Task has the following inactive assets in its inlets or outlets" - - class AirflowOptionalProviderFeatureException(AirflowException): """Raise by providers when imports are missing for optional provider features.""" @@ -150,27 +95,6 @@ class AirflowInternalRuntimeError(BaseException): """ -class XComNotFound(AirflowException): - """Raise when an XCom reference is being resolved against a non-existent XCom.""" - - def __init__(self, dag_id: str, task_id: str, key: str) -> None: - super().__init__() - self.dag_id = dag_id - self.task_id = task_id - self.key = key - - def __str__(self) -> str: - return f'XComArg result from {self.task_id} at {self.dag_id} with key="{self.key}" is not found!' 
- - def serialize(self): - cls = self.__class__ - return ( - f"{cls.__module__}.{cls.__name__}", - (), - {"dag_id": self.dag_id, "task_id": self.task_id, "key": self.key}, - ) - - class AirflowDagDuplicatedIdException(AirflowException): """Raise when a DAG's ID is already used by another DAG.""" @@ -238,39 +162,10 @@ def serialize(self): ) -class DuplicateTaskIdFound(AirflowException): - """Raise when a Task with duplicate task_id is defined in the same DAG.""" - - -class TaskAlreadyInTaskGroup(AirflowException): - """Raise when a Task cannot be added to a TaskGroup since it already belongs to another TaskGroup.""" - - def __init__(self, task_id: str, existing_group_id: str | None, new_group_id: str) -> None: - super().__init__(task_id, new_group_id) - self.task_id = task_id - self.existing_group_id = existing_group_id - self.new_group_id = new_group_id - - def __str__(self) -> str: - if self.existing_group_id is None: - existing_group = "the DAG's root group" - else: - existing_group = f"group {self.existing_group_id!r}" - return f"cannot add {self.task_id!r} to {self.new_group_id!r} (already in {existing_group})" - - class SerializationError(AirflowException): """A problem occurred when trying to serialize something.""" -class ParamValidationError(AirflowException): - """Raise when DAG params is invalid.""" - - -class TaskNotFound(AirflowNotFoundException): - """Raise when a Task is not available in the system.""" - - class TaskInstanceNotFound(AirflowNotFoundException): """Raise when a task instance is not available in the system.""" @@ -337,116 +232,6 @@ class VariableNotUnique(AirflowException): """Raise when multiple values are found for the same variable name.""" -class DownstreamTasksSkipped(AirflowException): - """ - Signal by an operator to skip its downstream tasks. - - Special exception raised to signal that the operator it was raised from wishes to skip - downstream tasks. This is used in the ShortCircuitOperator. - - :param tasks: List of task_ids to skip or a list of tuples with task_id and map_index to skip. - """ - - def __init__(self, *, tasks: Sequence[str | tuple[str, int]]): - super().__init__() - self.tasks = tasks - - -# TODO: workout this to correct place https://github.com/apache/airflow/issues/44353 -class DagRunTriggerException(AirflowException): - """ - Signal by an operator to trigger a specific Dag Run of a dag. - - Special exception raised to signal that the operator it was raised from wishes to trigger - a specific Dag Run of a dag. This is used in the ``TriggerDagRunOperator``. - """ - - def __init__( - self, - *, - trigger_dag_id: str, - dag_run_id: str, - conf: dict | None, - logical_date: datetime | None, - reset_dag_run: bool, - skip_when_already_exists: bool, - wait_for_completion: bool, - allowed_states: list[str | DagRunState], - failed_states: list[str | DagRunState], - poke_interval: int, - deferrable: bool, - ): - super().__init__() - self.trigger_dag_id = trigger_dag_id - self.dag_run_id = dag_run_id - self.conf = conf - self.logical_date = logical_date - self.reset_dag_run = reset_dag_run - self.skip_when_already_exists = skip_when_already_exists - self.wait_for_completion = wait_for_completion - self.allowed_states = allowed_states - self.failed_states = failed_states - self.poke_interval = poke_interval - self.deferrable = deferrable - - -class TaskDeferred(BaseException): - """ - Signal an operator moving to deferred state. - - Special exception raised to signal that the operator it was raised from - wishes to defer until a trigger fires. 
Triggers can send execution back to task or end the task instance - directly. If the trigger should end the task instance itself, ``method_name`` does not matter, - and can be None; otherwise, provide the name of the method that should be used when - resuming execution in the task. - """ - - def __init__( - self, - *, - trigger, - method_name: str, - kwargs: dict[str, Any] | None = None, - timeout: timedelta | int | float | None = None, - ): - super().__init__() - self.trigger = trigger - self.method_name = method_name - self.kwargs = kwargs - self.timeout: timedelta | None - # Check timeout type at runtime - if isinstance(timeout, (int, float)): - self.timeout = timedelta(seconds=timeout) - else: - self.timeout = timeout - if self.timeout is not None and not hasattr(self.timeout, "total_seconds"): - raise ValueError("Timeout value must be a timedelta") - - def serialize(self): - cls = self.__class__ - return ( - f"{cls.__module__}.{cls.__name__}", - (), - { - "trigger": self.trigger, - "method_name": self.method_name, - "kwargs": self.kwargs, - "timeout": self.timeout, - }, - ) - - def __repr__(self) -> str: - return f"<TaskDeferred trigger={self.trigger} method={self.method_name}>" - - class TaskDeferralError(AirflowException): - """Raised when a task failed during deferral for some reason.""" - - class TaskDeferralTimeout(AirflowException): - """Raise when there is a timeout on the deferral.""" - - # The try/except handling is needed after we moved all k8s classes to cncf.kubernetes provider # These two exceptions are used internally by Kubernetes Executor but also by PodGenerator, so we need # to leave them here in case older version of cncf.kubernetes provider is used to run KubernetesPodOperator @@ -518,23 +303,44 @@ def __init__(self, dag_id: str | None = None, message: str | None = None): super().__init__(f"An unexpected error occurred while trying to deserialize Dag '{dag_id}'") +class AirflowClearRunningTaskException(AirflowException): + """Raise when the user attempts to clear currently running tasks.""" + + +_DEPRECATED_EXCEPTIONS = { + "AirflowTaskTerminated": "airflow.sdk.exceptions.AirflowTaskTerminated", + "DuplicateTaskIdFound": "airflow.sdk.exceptions.DuplicateTaskIdFound", + "FailFastDagInvalidTriggerRule": "airflow.sdk.exceptions.FailFastDagInvalidTriggerRule", + "TaskAlreadyInTaskGroup": "airflow.sdk.exceptions.TaskAlreadyInTaskGroup", + "TaskDeferralTimeout": "airflow.sdk.exceptions.TaskDeferralTimeout", + "XComNotFound": "airflow.sdk.exceptions.XComNotFound", + "DownstreamTasksSkipped": "airflow.sdk.exceptions.DownstreamTasksSkipped", + "AirflowSensorTimeout": "airflow.sdk.exceptions.AirflowSensorTimeout", + "DagRunTriggerException": "airflow.sdk.exceptions.DagRunTriggerException", + "TaskDeferralError": "airflow.sdk.exceptions.TaskDeferralError", + "AirflowDagCycleException": "airflow.sdk.exceptions.AirflowDagCycleException", + "AirflowInactiveAssetInInletOrOutletException": "airflow.sdk.exceptions.AirflowInactiveAssetInInletOrOutletException", + "AirflowSkipException": "airflow.sdk.exceptions.AirflowSkipException", + "AirflowTaskTimeout": "airflow.sdk.exceptions.AirflowTaskTimeout", + "AirflowFailException": "airflow.sdk.exceptions.AirflowFailException", + "ParamValidationError": "airflow.sdk.exceptions.ParamValidationError", + "TaskDeferred": "airflow.sdk.exceptions.TaskDeferred", +} + + def __getattr__(name: str): """Provide backward compatibility for moved exceptions.""" - if name == "AirflowDagCycleException": + if name in _DEPRECATED_EXCEPTIONS: import warnings - from airflow.sdk.exceptions import
AirflowDagCycleException + from airflow import DeprecatedImportWarning + from airflow.utils.module_loading import import_string + target_path = _DEPRECATED_EXCEPTIONS[name] warnings.warn( - "airflow.exceptions.AirflowDagCycleException is deprecated. " - "Use airflow.sdk.exceptions.AirflowDagCycleException instead.", - DeprecationWarning, + f"airflow.exceptions.{name} is deprecated and will be removed in a future version. Use {target_path} instead.", + DeprecatedImportWarning, stacklevel=2, ) - return AirflowDagCycleException - + return import_string(target_path) raise AttributeError(f"module '{__name__}' has no attribute '{name}'") - - -class AirflowClearRunningTaskException(AirflowException): - """Raise when the user attempts to clear currently running tasks.""" diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index 8c1446fb919db..c0f2b488322fe 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -63,7 +63,6 @@ AirflowException, DeserializationError, SerializationError, - TaskDeferred, TaskNotFound, ) from airflow.models.connection import Connection @@ -739,6 +738,7 @@ def serialize( :meta private: """ from airflow.sdk.definitions._internal.types import is_arg_set + from airflow.sdk.exceptions import TaskDeferred if not is_arg_set(var): return cls._encode(None, type_=DAT.ARG_NOT_SET) diff --git a/airflow-core/tests/unit/dags/test_assets.py b/airflow-core/tests/unit/dags/test_assets.py index 6a0b08f9ba6a1..8d34a2cfa30fc 100644 --- a/airflow-core/tests/unit/dags/test_assets.py +++ b/airflow-core/tests/unit/dags/test_assets.py @@ -19,11 +19,11 @@ from datetime import datetime -from airflow.exceptions import AirflowFailException, AirflowSkipException from airflow.models.dag import DAG from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.python import PythonOperator from airflow.sdk.definitions.asset import Asset +from airflow.sdk.exceptions import AirflowFailException, AirflowSkipException skip_task_dag_asset = Asset(uri="s3://dag_with_skip_task/output_1.txt", name="skip", extra={"hi": "bye"}) fail_task_dag_asset = Asset(uri="s3://dag_with_fail_task/output_1.txt", name="fail", extra={"hi": "bye"}) diff --git a/airflow-core/tests/unit/dags/test_on_failure_callback.py b/airflow-core/tests/unit/dags/test_on_failure_callback.py index 1e0f276e4aee4..78104f06ab17e 100644 --- a/airflow-core/tests/unit/dags/test_on_failure_callback.py +++ b/airflow-core/tests/unit/dags/test_on_failure_callback.py @@ -19,10 +19,10 @@ import os from datetime import datetime -from airflow.exceptions import AirflowFailException from airflow.models.dag import DAG from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.python import PythonOperator +from airflow.sdk.exceptions import AirflowFailException DEFAULT_DATE = datetime(2016, 1, 1) diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py b/airflow-core/tests/unit/serialization/test_serialized_objects.py index 327c505497a8a..0b069c8ce1ac7 100644 --- a/airflow-core/tests/unit/serialization/test_serialized_objects.py +++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py @@ -715,8 +715,8 @@ def test_logging_propogated_by_default(self, caplog): assert caplog.messages == ["test"] def test_resume_execution(self): - from airflow.exceptions import 
TaskDeferralTimeout from airflow.models.trigger import TriggerFailureReason + from airflow.sdk.exceptions import TaskDeferralTimeout op = BaseOperator(task_id="hi") with pytest.raises(TaskDeferralTimeout): diff --git a/airflow-core/tests/unit/ti_deps/deps/test_mapped_task_upstream_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_mapped_task_upstream_dep.py index 71355ed790adb..9905b7905e7e8 100644 --- a/airflow-core/tests/unit/ti_deps/deps/test_mapped_task_upstream_dep.py +++ b/airflow-core/tests/unit/ti_deps/deps/test_mapped_task_upstream_dep.py @@ -21,9 +21,9 @@ import pytest -from airflow.exceptions import AirflowFailException, AirflowSkipException from airflow.models.xcom import XCOM_RETURN_KEY from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.sdk.exceptions import AirflowFailException, AirflowSkipException from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.base_ti_dep import TIDepStatus from airflow.ti_deps.deps.mapped_task_upstream_dep import MappedTaskUpstreamDep diff --git a/devel-common/src/tests_common/test_utils/version_compat.py b/devel-common/src/tests_common/test_utils/version_compat.py index e30c692278fe8..3b3aedaf2d7ab 100644 --- a/devel-common/src/tests_common/test_utils/version_compat.py +++ b/devel-common/src/tests_common/test_utils/version_compat.py @@ -65,7 +65,6 @@ def get_sqlalchemy_version_tuple() -> tuple[int, int, int]: SQLALCHEMY_V_1_4 = (1, 4, 0) <= get_sqlalchemy_version_tuple() < (2, 0, 0) SQLALCHEMY_V_2_0 = (2, 0, 0) <= get_sqlalchemy_version_tuple() < (2, 1, 0) - __all__ = [ "AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_0_1", diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/datasync.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/datasync.py index 8074978bbb1a1..38ecef1bad3b8 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/datasync.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/datasync.py @@ -21,8 +21,9 @@ import time from urllib.parse import urlsplit -from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowTaskTimeout +from airflow.exceptions import AirflowBadRequest, AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.common.compat.sdk import AirflowTaskTimeout class DataSyncHook(AwsBaseHook): diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/step_function.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/step_function.py index 48da7cb1150a8..e4c99fb01a1dc 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/step_function.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/step_function.py @@ -18,8 +18,8 @@ import json -from airflow.exceptions import AirflowFailException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.common.compat.sdk import AirflowFailException class StepFunctionHook(AwsBaseHook): diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/datasync.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/datasync.py index 7b2b7282efca7..47146a19144a3 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/datasync.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/datasync.py @@ -23,11 +23,12 @@ from collections.abc import Sequence from typing import TYPE_CHECKING, Any -from airflow.exceptions import AirflowException, AirflowTaskTimeout +from airflow.exceptions import AirflowException 
from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook from airflow.providers.amazon.aws.links.datasync import DataSyncTaskExecutionLink, DataSyncTaskLink from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator from airflow.providers.amazon.aws.utils.mixins import aws_template_fields +from airflow.providers.common.compat.sdk import AirflowTaskTimeout if TYPE_CHECKING: from airflow.utils.context import Context diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index 58e83ff66d24a..2fa79da95ec47 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -51,9 +51,10 @@ lazy_load_command, ) from airflow.configuration import conf -from airflow.exceptions import AirflowProviderDeprecationWarning, AirflowTaskTimeout +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.executors.base_executor import BaseExecutor from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.providers.common.compat.sdk import AirflowTaskTimeout from airflow.stats import Stats from airflow.utils.state import TaskInstanceState diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index ef57e1310436a..9680d7807edeb 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -41,10 +41,10 @@ import airflow.settings as settings from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowTaskTimeout +from airflow.exceptions import AirflowException from airflow.executors.base_executor import BaseExecutor from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS -from airflow.providers.common.compat.sdk import timeout +from airflow.providers.common.compat.sdk import AirflowTaskTimeout, timeout from airflow.stats import Stats from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/exceptions.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/exceptions.py index 5503c743797ad..208ce8f13971c 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/exceptions.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/exceptions.py @@ -16,9 +16,7 @@ # under the License. 
from __future__ import annotations -from airflow.exceptions import ( - AirflowException, -) +from airflow.exceptions import AirflowException class PodMutationHookException(AirflowException): diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py index cf8362413111d..a9d47dc866311 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -41,11 +41,6 @@ from urllib3.exceptions import HTTPError from airflow.configuration import conf -from airflow.exceptions import ( - AirflowException, - AirflowSkipException, - TaskDeferred, -) from airflow.providers.cncf.kubernetes import pod_generator from airflow.providers.cncf.kubernetes.backcompat.backwards_compat_converters import ( convert_affinity, @@ -82,12 +77,13 @@ PodPhase, ) from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_1_PLUS -from airflow.providers.common.compat.sdk import XCOM_RETURN_KEY +from airflow.providers.common.compat.sdk import XCOM_RETURN_KEY, AirflowSkipException, TaskDeferred if AIRFLOW_V_3_1_PLUS: from airflow.sdk import BaseOperator else: from airflow.models import BaseOperator +from airflow.exceptions import AirflowException from airflow.settings import pod_mutation_hook from airflow.utils import yaml from airflow.utils.helpers import prune_dict, validate_key diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py index b7df8fa3e5183..2fb2ac93a122c 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py @@ -35,4 +35,8 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0) -__all__ = ["AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_1_PLUS"] + +__all__ = [ + "AIRFLOW_V_3_0_PLUS", + "AIRFLOW_V_3_1_PLUS", +] diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_cmd.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_cmd.py index 59c973004044b..8cc1af370d7ed 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_cmd.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_cmd.py @@ -20,7 +20,7 @@ import pytest -from airflow.exceptions import AirflowSkipException +from airflow.providers.common.compat.sdk import AirflowSkipException from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS diff --git a/providers/common/compat/src/airflow/providers/common/compat/sdk.py b/providers/common/compat/src/airflow/providers/common/compat/sdk.py index 080c2658bd5b2..7be62f61b7904 100644 --- a/providers/common/compat/src/airflow/providers/common/compat/sdk.py +++ b/providers/common/compat/src/airflow/providers/common/compat/sdk.py @@ -25,6 +25,8 @@ from typing import TYPE_CHECKING +from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_0_PLUS + if TYPE_CHECKING: import airflow.sdk.io as io # noqa: F401 import airflow.sdk.timezone as timezone # noqa: F401 @@ -77,6 +79,20 @@ from airflow.sdk.definitions.context import context_merge as 
context_merge from airflow.sdk.definitions.mappedoperator import MappedOperator as MappedOperator from airflow.sdk.definitions.template import literal as literal + from airflow.sdk.exceptions import ( + AirflowFailException as AirflowFailException, + AirflowSkipException as AirflowSkipException, + AirflowTaskTimeout as AirflowTaskTimeout, + ParamValidationError as ParamValidationError, + TaskDeferred as TaskDeferred, + ) + + # Airflow 3-only exceptions (conditionally imported) + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.exceptions import ( + DagRunTriggerException as DagRunTriggerException, + DownstreamTasksSkipped as DownstreamTasksSkipped, + ) from airflow.sdk.execution_time.context import ( AIRFLOW_VAR_NAME_FORMAT_MAPPING as AIRFLOW_VAR_NAME_FORMAT_MAPPING, context_to_airflow_vars as context_to_airflow_vars, @@ -199,8 +215,27 @@ # XCom & Task Communication # ============================================================================ "XCOM_RETURN_KEY": "airflow.models.xcom", + # ============================================================================ + # Exceptions (deprecated in airflow.exceptions, prefer SDK) + # ============================================================================ + # Exceptions available in both Airflow 2 and 3 + "AirflowSkipException": ("airflow.sdk.exceptions", "airflow.exceptions"), + "AirflowTaskTimeout": ("airflow.sdk.exceptions", "airflow.exceptions"), + "AirflowFailException": ("airflow.sdk.exceptions", "airflow.exceptions"), + "ParamValidationError": ("airflow.sdk.exceptions", "airflow.exceptions"), + "TaskDeferred": ("airflow.sdk.exceptions", "airflow.exceptions"), +} + +# Airflow 3-only exceptions (not available in Airflow 2) +_AIRFLOW_3_ONLY_EXCEPTIONS: dict[str, tuple[str, ...]] = { + "DownstreamTasksSkipped": ("airflow.sdk.exceptions", "airflow.exceptions"), + "DagRunTriggerException": ("airflow.sdk.exceptions", "airflow.exceptions"), } +# Add Airflow 3-only exceptions to _IMPORT_MAP if running Airflow 3+ +if AIRFLOW_V_3_0_PLUS: + _IMPORT_MAP.update(_AIRFLOW_3_ONLY_EXCEPTIONS) + # Module map: module_name -> module_path(s) # For entire modules that have been moved (e.g., timezone) # Usage: from airflow.providers.common.compat.lazy_compat import timezone diff --git a/providers/common/compat/src/airflow/providers/common/compat/version_compat.py b/providers/common/compat/src/airflow/providers/common/compat/version_compat.py index 02d0f1ac162b0..6fed38474aafd 100644 --- a/providers/common/compat/src/airflow/providers/common/compat/version_compat.py +++ b/providers/common/compat/src/airflow/providers/common/compat/version_compat.py @@ -32,8 +32,8 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: return airflow_version.major, airflow_version.minor, airflow_version.micro -AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) -AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0) +AIRFLOW_V_3_0_PLUS: bool = get_base_airflow_version_tuple() >= (3, 0, 0) +AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0) if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperator diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py index 8095373e7d72b..1e82a01bd1ea3 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py +++ b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py @@ -24,9 +24,14 @@ from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, 
SupportsAbs from airflow import XComArg -from airflow.exceptions import AirflowException, AirflowFailException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.models import SkipMixin -from airflow.providers.common.compat.sdk import BaseHook, BaseOperator +from airflow.providers.common.compat.sdk import ( + AirflowFailException, + AirflowSkipException, + BaseHook, + BaseOperator, +) from airflow.providers.common.sql.hooks.handlers import fetch_all_handler, return_single_query_results from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.utils.helpers import merge_dicts diff --git a/providers/docker/src/airflow/providers/docker/exceptions.py b/providers/docker/src/airflow/providers/docker/exceptions.py index b5eb0e6c11985..e6def099febd8 100644 --- a/providers/docker/src/airflow/providers/docker/exceptions.py +++ b/providers/docker/src/airflow/providers/docker/exceptions.py @@ -19,7 +19,8 @@ from __future__ import annotations -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException +from airflow.providers.common.compat.sdk import AirflowSkipException class DockerContainerFailedException(AirflowException): diff --git a/providers/edge3/src/airflow/providers/edge3/example_dags/win_test.py b/providers/edge3/src/airflow/providers/edge3/example_dags/win_test.py index 7aa2822f219dc..4b33ab76c0146 100644 --- a/providers/edge3/src/airflow/providers/edge3/example_dags/win_test.py +++ b/providers/edge3/src/airflow/providers/edge3/example_dags/win_test.py @@ -32,10 +32,11 @@ from time import sleep from typing import TYPE_CHECKING, Any -from airflow.exceptions import AirflowException, AirflowNotFoundException, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.models import BaseOperator from airflow.models.dag import DAG from airflow.models.variable import Variable +from airflow.providers.common.compat.sdk import AirflowSkipException from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk.execution_time.context import context_to_airflow_vars diff --git a/providers/google/src/airflow/providers/google/cloud/operators/bigquery.py b/providers/google/src/airflow/providers/google/cloud/operators/bigquery.py index 454127154054c..90a0175e678d2 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/bigquery.py @@ -33,7 +33,8 @@ from google.cloud.bigquery.table import RowIterator, Table, TableListItem, TableReference from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.providers.common.compat.sdk import AirflowSkipException from airflow.providers.common.sql.operators.sql import ( # for _parse_boolean SQLCheckOperator, SQLColumnCheckOperator, diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py index 5db00f60ce33f..84e88a0462b57 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py +++ b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py @@ -30,8 +30,8 @@ from google.cloud.orchestration.airflow.service_v1.types import Environment, 
ExecuteAirflowCommandResponse from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowSkipException -from airflow.providers.common.compat.sdk import BaseSensorOperator +from airflow.exceptions import AirflowException +from airflow.providers.common.compat.sdk import AirflowSkipException, BaseSensorOperator from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook from airflow.providers.google.cloud.triggers.cloud_composer import ( CloudComposerDAGRunTrigger, diff --git a/providers/google/src/airflow/providers/google/suite/transfers/local_to_drive.py b/providers/google/src/airflow/providers/google/suite/transfers/local_to_drive.py index e5b34d727df4b..4c3b2c5fd0cc9 100644 --- a/providers/google/src/airflow/providers/google/suite/transfers/local_to_drive.py +++ b/providers/google/src/airflow/providers/google/suite/transfers/local_to_drive.py @@ -23,7 +23,7 @@ from pathlib import Path from typing import TYPE_CHECKING -from airflow.exceptions import AirflowFailException +from airflow.providers.common.compat.sdk import AirflowFailException from airflow.providers.google.suite.hooks.drive import GoogleDriveHook from airflow.providers.google.version_compat import BaseOperator diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/synapse.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/synapse.py index 2157583520e35..22acec334d9a8 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/synapse.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/synapse.py @@ -24,8 +24,8 @@ from azure.synapse.artifacts import ArtifactsClient from azure.synapse.spark import SparkClient -from airflow.exceptions import AirflowException, AirflowTaskTimeout -from airflow.providers.common.compat.sdk import BaseHook +from airflow.exceptions import AirflowException +from airflow.providers.common.compat.sdk import AirflowTaskTimeout, BaseHook from airflow.providers.microsoft.azure.utils import ( add_managed_identity_connection_widgets, get_field, diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py index 0be1ccc484dc7..cbf570b43c938 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py @@ -42,8 +42,8 @@ ) from msrestazure.azure_exceptions import CloudError -from airflow.exceptions import AirflowException, AirflowTaskTimeout -from airflow.providers.common.compat.sdk import BaseOperator +from airflow.exceptions import AirflowException +from airflow.providers.common.compat.sdk import AirflowTaskTimeout, BaseOperator from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceHook from airflow.providers.microsoft.azure.hooks.container_registry import AzureContainerRegistryHook from airflow.providers.microsoft.azure.hooks.container_volume import AzureContainerVolumeHook diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py index 1d86311284406..04bc0331b2262 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -26,8 +26,8 @@ Any, ) -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred -from airflow.providers.common.compat.sdk import XCOM_RETURN_KEY, BaseOperator +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.providers.common.compat.sdk import XCOM_RETURN_KEY, BaseOperator, TaskDeferred from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook from airflow.providers.microsoft.azure.triggers.msgraph import ( MSGraphTrigger, diff --git a/providers/slack/src/airflow/providers/slack/transfers/sql_to_slack.py b/providers/slack/src/airflow/providers/slack/transfers/sql_to_slack.py index 4b3cba934c4b0..d66c0e95ef976 100644 --- a/providers/slack/src/airflow/providers/slack/transfers/sql_to_slack.py +++ b/providers/slack/src/airflow/providers/slack/transfers/sql_to_slack.py @@ -22,7 +22,8 @@ from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING, Any, Literal -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.providers.common.compat.sdk import AirflowSkipException from airflow.providers.slack.hooks.slack import SlackHook from airflow.providers.slack.transfers.base_sql_to_slack import BaseSqlToSlackOperator from airflow.providers.slack.utils import parse_filename diff --git a/providers/ssh/src/airflow/providers/ssh/operators/ssh.py b/providers/ssh/src/airflow/providers/ssh/operators/ssh.py index 3aef97df0c410..208eba6318c37 100644 --- a/providers/ssh/src/airflow/providers/ssh/operators/ssh.py +++ b/providers/ssh/src/airflow/providers/ssh/operators/ssh.py @@ -23,8 +23,8 @@ from typing import TYPE_CHECKING from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowSkipException -from airflow.providers.common.compat.sdk import BaseOperator +from airflow.exceptions import AirflowException +from airflow.providers.common.compat.sdk import AirflowSkipException, BaseOperator from airflow.providers.ssh.hooks.ssh import SSHHook try: diff --git a/providers/standard/src/airflow/providers/standard/example_dags/example_bash_decorator.py b/providers/standard/src/airflow/providers/standard/example_dags/example_bash_decorator.py index 20225be4b39fd..e5b005ce40c48 100644 --- a/providers/standard/src/airflow/providers/standard/example_dags/example_bash_decorator.py +++ b/providers/standard/src/airflow/providers/standard/example_dags/example_bash_decorator.py @@ -19,11 +19,11 @@ import pendulum -from airflow.exceptions import AirflowSkipException from airflow.providers.common.compat.sdk import TriggerRule from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.utils.weekday import WeekDay from airflow.sdk import chain, dag, task +from airflow.sdk.exceptions import AirflowSkipException @dag(schedule=None, start_date=pendulum.datetime(2023, 1, 1, tz="UTC"), catchup=False) diff --git a/providers/standard/src/airflow/providers/standard/operators/bash.py b/providers/standard/src/airflow/providers/standard/operators/bash.py index 7c5f0269a967e..59bedafbd11db 100644 --- a/providers/standard/src/airflow/providers/standard/operators/bash.py +++ b/providers/standard/src/airflow/providers/standard/operators/bash.py @@ -24,8 +24,8 @@ from functools import cached_property from typing 
import TYPE_CHECKING, Any, cast -from airflow.exceptions import AirflowException, AirflowSkipException -from airflow.providers.common.compat.sdk import context_to_airflow_vars +from airflow.exceptions import AirflowException +from airflow.providers.common.compat.sdk import AirflowSkipException, context_to_airflow_vars from airflow.providers.standard.hooks.subprocess import SubprocessHook, SubprocessResult, working_directory from airflow.providers.standard.version_compat import BaseOperator diff --git a/providers/standard/src/airflow/providers/standard/operators/python.py b/providers/standard/src/airflow/providers/standard/operators/python.py index b597ef38cc5f1..70ee108851c0e 100644 --- a/providers/standard/src/airflow/providers/standard/operators/python.py +++ b/providers/standard/src/airflow/providers/standard/operators/python.py @@ -45,11 +45,10 @@ AirflowConfigException, AirflowException, AirflowProviderDeprecationWarning, - AirflowSkipException, DeserializingResultError, ) from airflow.models.variable import Variable -from airflow.providers.common.compat.sdk import context_merge +from airflow.providers.common.compat.sdk import AirflowSkipException, context_merge from airflow.providers.standard.hooks.package_index import PackageIndexHook from airflow.providers.standard.utils.python_virtualenv import ( _execute_in_subprocess, diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py index c0f8709fa87f9..7f2ae765d145f 100644 --- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py +++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py @@ -242,7 +242,7 @@ def execute(self, context: Context): self._trigger_dag_af_2(context=context, run_id=run_id, parsed_logical_date=parsed_logical_date) def _trigger_dag_af_3(self, context, run_id, parsed_logical_date): - from airflow.exceptions import DagRunTriggerException + from airflow.providers.common.compat.sdk import DagRunTriggerException raise DagRunTriggerException( trigger_dag_id=self.trigger_dag_id, diff --git a/providers/standard/src/airflow/providers/standard/sensors/bash.py b/providers/standard/src/airflow/providers/standard/sensors/bash.py index 6283b990dae5f..5ddb03b10b183 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/bash.py +++ b/providers/standard/src/airflow/providers/standard/sensors/bash.py @@ -22,8 +22,7 @@ from tempfile import NamedTemporaryFile, TemporaryDirectory, gettempdir from typing import TYPE_CHECKING -from airflow.exceptions import AirflowFailException -from airflow.providers.common.compat.sdk import BaseSensorOperator +from airflow.providers.common.compat.sdk import AirflowFailException, BaseSensorOperator if TYPE_CHECKING: from airflow.providers.common.compat.sdk import Context diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py index 5bfe97b4143e2..28c4492ef71ae 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py +++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py @@ -24,9 +24,8 @@ from typing import TYPE_CHECKING, ClassVar from airflow.configuration import conf -from airflow.exceptions import AirflowSkipException from airflow.models.dag import DagModel -from airflow.providers.common.compat.sdk import BaseOperatorLink, BaseSensorOperator +from 
airflow.providers.common.compat.sdk import AirflowSkipException, BaseOperatorLink, BaseSensorOperator from airflow.providers.standard.exceptions import ( DuplicateStateError, ExternalDagDeletedError, diff --git a/providers/standard/src/airflow/providers/standard/sensors/time_delta.py b/providers/standard/src/airflow/providers/standard/sensors/time_delta.py index bc27afd531823..a1c1d21e7b884 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/time_delta.py +++ b/providers/standard/src/airflow/providers/standard/sensors/time_delta.py @@ -26,8 +26,8 @@ from packaging.version import Version from airflow.configuration import conf -from airflow.exceptions import AirflowProviderDeprecationWarning, AirflowSkipException -from airflow.providers.common.compat.sdk import BaseSensorOperator, timezone +from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.providers.common.compat.sdk import AirflowSkipException, BaseSensorOperator, timezone from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS diff --git a/providers/standard/src/airflow/providers/standard/triggers/hitl.py b/providers/standard/src/airflow/providers/standard/triggers/hitl.py index b36a3413dcd8e..3579c3595de88 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/hitl.py +++ b/providers/standard/src/airflow/providers/standard/triggers/hitl.py @@ -30,7 +30,7 @@ from asgiref.sync import sync_to_async -from airflow.exceptions import ParamValidationError +from airflow.providers.common.compat.sdk import ParamValidationError from airflow.sdk import Param from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.execution_time.hitl import ( diff --git a/providers/standard/src/airflow/providers/standard/utils/skipmixin.py b/providers/standard/src/airflow/providers/standard/utils/skipmixin.py index 2a00551f34f84..62e3b90fcf3b4 100644 --- a/providers/standard/src/airflow/providers/standard/utils/skipmixin.py +++ b/providers/standard/src/airflow/providers/standard/utils/skipmixin.py @@ -63,7 +63,7 @@ def _set_state_to_skipped( """ # Import is internal for backward compatibility when importing PythonOperator # from airflow.providers.common.compat.standard.operators - from airflow.exceptions import DownstreamTasksSkipped + from airflow.providers.common.compat.sdk import DownstreamTasksSkipped # The following could be applied only for non-mapped tasks, # as future mapped tasks have not been expanded yet. 
Such tasks diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py index 0292d9abef6a6..f040d99bc2f85 100644 --- a/task-sdk/src/airflow/sdk/bases/operator.py +++ b/task-sdk/src/airflow/sdk/bases/operator.py @@ -35,7 +35,6 @@ import attrs -from airflow.exceptions import RemovedInAirflow4Warning from airflow.sdk import TriggerRule, timezone from airflow.sdk._shared.secrets_masker import redact from airflow.sdk.definitions._internal.abstractoperator import ( @@ -62,6 +61,7 @@ from airflow.sdk.definitions.edges import EdgeModifier from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs from airflow.sdk.definitions.param import ParamsDict +from airflow.sdk.exceptions import RemovedInAirflow4Warning from airflow.task.priority_strategy import ( PriorityWeightStrategy, airflow_priority_weight_strategies, @@ -110,9 +110,6 @@ def db_safe_priority(priority_weight: int) -> int: "cross_downstream", ] -# TODO: Task-SDK -AirflowException = RuntimeError - class TriggerFailureReason(str, Enum): """ @@ -176,7 +173,7 @@ def parse_retries(retries: Any) -> int | None: try: parsed_retries = int(retries) except (TypeError, ValueError): - raise AirflowException(f"'retries' type must be int, not {type(retries).__name__}") + raise RuntimeError(f"'retries' type must be int, not {type(retries).__name__}") return parsed_retries @@ -407,7 +404,7 @@ def wrapper(self, *args, **kwargs): if not cls.test_mode and sentinel is not self: message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside of the Task Runner!" if not self.allow_nested_operators: - raise AirflowException(message) + raise RuntimeError(message) self.log.warning(message) # Now that we've logged, set sentinel so that `super()` calls don't log again @@ -1607,13 +1604,13 @@ def defer( be None; otherwise, provide the name of the method that should be used when resuming execution in the task. 
""" - from airflow.exceptions import TaskDeferred + from airflow.sdk.exceptions import TaskDeferred raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context): """Entrypoint method called by the Task Runner (instead of execute) when this task is resumed.""" - from airflow.exceptions import TaskDeferralError, TaskDeferralTimeout + from airflow.sdk.exceptions import TaskDeferralError, TaskDeferralTimeout if next_kwargs is None: next_kwargs = {} diff --git a/task-sdk/src/airflow/sdk/bases/sensor.py b/task-sdk/src/airflow/sdk/bases/sensor.py index 7865ed84f873a..8fa87b26b1a53 100644 --- a/task-sdk/src/airflow/sdk/bases/sensor.py +++ b/task-sdk/src/airflow/sdk/bases/sensor.py @@ -25,7 +25,10 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any -from airflow.exceptions import ( +from airflow.sdk import timezone +from airflow.sdk.bases.operator import BaseOperator +from airflow.sdk.configuration import conf +from airflow.sdk.exceptions import ( AirflowException, AirflowFailException, AirflowRescheduleException, @@ -35,9 +38,6 @@ TaskDeferralError, TaskDeferralTimeout, ) -from airflow.sdk import timezone -from airflow.sdk.bases.operator import BaseOperator -from airflow.sdk.configuration import conf if TYPE_CHECKING: from airflow.sdk.definitions.context import Context @@ -145,9 +145,7 @@ def _coerce_poke_interval(poke_interval: float | timedelta) -> timedelta: return poke_interval if isinstance(poke_interval, (int, float)) and poke_interval >= 0: return timedelta(seconds=poke_interval) - raise AirflowException( - "Operator arg `poke_interval` must be timedelta object or a non-negative number" - ) + raise ValueError("Operator arg `poke_interval` must be timedelta object or a non-negative number") @staticmethod def _coerce_timeout(timeout: float | timedelta) -> timedelta: @@ -155,7 +153,7 @@ def _coerce_timeout(timeout: float | timedelta) -> timedelta: return timeout if isinstance(timeout, (int, float)) and timeout >= 0: return timedelta(seconds=timeout) - raise AirflowException("Operator arg `timeout` must be timedelta object or a non-negative number") + raise ValueError("Operator arg `timeout` must be timedelta object or a non-negative number") @staticmethod def _coerce_max_wait(max_wait: float | timedelta | None) -> timedelta | None: @@ -163,15 +161,15 @@ def _coerce_max_wait(max_wait: float | timedelta | None) -> timedelta | None: return max_wait if isinstance(max_wait, (int, float)) and max_wait >= 0: return timedelta(seconds=max_wait) - raise AirflowException("Operator arg `max_wait` must be timedelta object or a non-negative number") + raise ValueError("Operator arg `max_wait` must be timedelta object or a non-negative number") def _validate_input_values(self) -> None: if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0: - raise AirflowException("The poke_interval must be a non-negative number") + raise ValueError("The poke_interval must be a non-negative number") if not isinstance(self.timeout, (int, float)) or self.timeout < 0: - raise AirflowException("The timeout must be a non-negative number") + raise ValueError("The timeout must be a non-negative number") if self.mode not in self.valid_modes: - raise AirflowException( + raise ValueError( f"The mode must be one of {self.valid_modes},'{self.dag.dag_id if self.has_dag() else ''} " f".{self.task_id}'; received '{self.mode}'." 
) diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/setup_teardown.py b/task-sdk/src/airflow/sdk/definitions/_internal/setup_teardown.py index 961099c6f9765..55b59d296f79a 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/setup_teardown.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/setup_teardown.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, cast -from airflow.exceptions import AirflowException +from airflow.sdk.exceptions import AirflowException if TYPE_CHECKING: from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator diff --git a/task-sdk/src/airflow/sdk/definitions/connection.py b/task-sdk/src/airflow/sdk/definitions/connection.py index 89727b11c1286..7e0ca4f44dd36 100644 --- a/task-sdk/src/airflow/sdk/definitions/connection.py +++ b/task-sdk/src/airflow/sdk/definitions/connection.py @@ -25,8 +25,7 @@ import attrs -from airflow.exceptions import AirflowException, AirflowNotFoundException -from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType +from airflow.sdk.exceptions import AirflowException, AirflowNotFoundException, AirflowRuntimeError, ErrorType log = logging.getLogger(__name__) diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 598b995b4cbf5..3abd205fc3b65 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -39,12 +39,6 @@ from dateutil.relativedelta import relativedelta from airflow import settings -from airflow.exceptions import ( - DuplicateTaskIdFound, - ParamValidationError, - RemovedInAirflow4Warning, - TaskNotFound, -) from airflow.sdk import TaskInstanceState, TriggerRule from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.definitions._internal.node import validate_key @@ -53,7 +47,14 @@ from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.deadline import DeadlineAlert from airflow.sdk.definitions.param import DagParam, ParamsDict -from airflow.sdk.exceptions import AirflowDagCycleException, FailFastDagInvalidTriggerRule +from airflow.sdk.exceptions import ( + AirflowDagCycleException, + DuplicateTaskIdFound, + FailFastDagInvalidTriggerRule, + ParamValidationError, + RemovedInAirflow4Warning, + TaskNotFound, +) from airflow.timetables.base import Timetable from airflow.timetables.simple import ( AssetTriggeredTimetable, diff --git a/task-sdk/src/airflow/sdk/definitions/decorators/condition.py b/task-sdk/src/airflow/sdk/definitions/decorators/condition.py index 50503e0ddcd8e..9c96b7389e9c5 100644 --- a/task-sdk/src/airflow/sdk/definitions/decorators/condition.py +++ b/task-sdk/src/airflow/sdk/definitions/decorators/condition.py @@ -20,8 +20,8 @@ from functools import wraps from typing import TYPE_CHECKING, Any, TypeVar -from airflow.exceptions import AirflowSkipException from airflow.sdk.bases.decorator import Task, _TaskDecorator +from airflow.sdk.exceptions import AirflowSkipException if TYPE_CHECKING: from typing import TypeAlias diff --git a/task-sdk/src/airflow/sdk/definitions/decorators/setup_teardown.py b/task-sdk/src/airflow/sdk/definitions/decorators/setup_teardown.py index e8059956e670d..0cdd8b52a5f37 100644 --- a/task-sdk/src/airflow/sdk/definitions/decorators/setup_teardown.py +++ b/task-sdk/src/airflow/sdk/definitions/decorators/setup_teardown.py @@ -20,10 +20,10 @@ from collections.abc import Callable from typing import TYPE_CHECKING, cast -from airflow.exceptions import AirflowException from airflow.sdk.bases.operator import 
BaseOperator from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext from airflow.sdk.definitions.decorators.task_group import _TaskGroupFactory +from airflow.sdk.exceptions import AirflowException if TYPE_CHECKING: from airflow.sdk.bases.decorator import _TaskDecorator diff --git a/task-sdk/src/airflow/sdk/definitions/operator_resources.py b/task-sdk/src/airflow/sdk/definitions/operator_resources.py index 4073af137443e..d6cbf10039d00 100644 --- a/task-sdk/src/airflow/sdk/definitions/operator_resources.py +++ b/task-sdk/src/airflow/sdk/definitions/operator_resources.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -from airflow.exceptions import AirflowException from airflow.sdk.configuration import conf # Constants for resources (megabytes are the base unit) @@ -41,7 +40,7 @@ class Resource: def __init__(self, name, units_str, qty): if qty < 0: - raise AirflowException( + raise ValueError( f"Received resource quantity {qty} for resource {name}, " f"but resource quantity must be non-negative." ) diff --git a/task-sdk/src/airflow/sdk/definitions/param.py b/task-sdk/src/airflow/sdk/definitions/param.py index 9d7722257a8b0..410da71fde7ea 100644 --- a/task-sdk/src/airflow/sdk/definitions/param.py +++ b/task-sdk/src/airflow/sdk/definitions/param.py @@ -23,9 +23,9 @@ from collections.abc import ItemsView, Iterable, Mapping, MutableMapping, ValuesView from typing import TYPE_CHECKING, Any, ClassVar -from airflow.exceptions import AirflowException, ParamValidationError from airflow.sdk.definitions._internal.mixins import ResolveMixin from airflow.sdk.definitions._internal.types import NOTSET, is_arg_set +from airflow.sdk.exceptions import ParamValidationError if TYPE_CHECKING: from airflow.sdk.definitions.context import Context @@ -297,7 +297,7 @@ def resolve(self, context: Context) -> Any: return self._default with contextlib.suppress(KeyError): return context["params"][self._name] - raise AirflowException(f"No value could be resolved for parameter {self._name}") + raise RuntimeError(f"No value could be resolved for parameter {self._name}") def serialize(self) -> dict: """Serialize the DagParam object into a dictionary.""" diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py index b2f9aa6b9097e..102cd541ba7f2 100644 --- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py @@ -30,14 +30,13 @@ import attrs import methodtools -from airflow.exceptions import ( - AirflowException, +from airflow.sdk import TriggerRule +from airflow.sdk.definitions._internal.node import DAGNode, validate_group_key +from airflow.sdk.exceptions import ( + AirflowDagCycleException, DuplicateTaskIdFound, TaskAlreadyInTaskGroup, ) -from airflow.sdk import TriggerRule -from airflow.sdk.definitions._internal.node import DAGNode, validate_group_key -from airflow.sdk.exceptions import AirflowDagCycleException if TYPE_CHECKING: from airflow.sdk.bases.operator import BaseOperator @@ -247,14 +246,14 @@ def add(self, task: DAGNode) -> DAGNode: if isinstance(task, TaskGroup): if self.dag: if task.dag is not None and self.dag is not task.dag: - raise RuntimeError( + raise ValueError( "Cannot mix TaskGroups from different Dags: %s and %s", self.dag, task.dag, ) task.dag = self.dag if task.children: - raise AirflowException("Cannot add a non-empty TaskGroup") + raise ValueError("Cannot add a non-empty TaskGroup") self.children[key] = task return task @@ -315,7 +314,7 
@@ def update_relative( # Handles setting relationship between a TaskGroup and a task for task in other.roots: if not isinstance(task, DAGNode): - raise AirflowException( + raise RuntimeError( "Relationships can only be set between TaskGroup " f"or operators; received {task.__class__.__name__}" ) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index a674a47e19bde..ef316574df550 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -26,12 +26,12 @@ import attrs -from airflow.exceptions import AirflowException, XComNotFound from airflow.sdk import TriggerRule from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions._internal.mixins import DependencyMixin, ResolveMixin from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext from airflow.sdk.definitions._internal.types import NOTSET, is_arg_set +from airflow.sdk.exceptions import AirflowException, XComNotFound from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence from airflow.sdk.execution_time.xcom import BaseXCom diff --git a/task-sdk/src/airflow/sdk/exceptions.py b/task-sdk/src/airflow/sdk/exceptions.py index 96df62184644b..f17297f2d923a 100644 --- a/task-sdk/src/airflow/sdk/exceptions.py +++ b/task-sdk/src/airflow/sdk/exceptions.py @@ -18,15 +18,38 @@ from __future__ import annotations import enum +from http import HTTPStatus from typing import TYPE_CHECKING, Any -from airflow.exceptions import AirflowException from airflow.sdk import TriggerRule if TYPE_CHECKING: + from collections.abc import Collection + + from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, AssetUriRef from airflow.sdk.execution_time.comms import ErrorResponse +class AirflowException(Exception): + """ + Base class for all Airflow's errors. + + Each custom exception should be derived from this class. + """ + + status_code = HTTPStatus.INTERNAL_SERVER_ERROR + + def serialize(self): + cls = self.__class__ + return f"{cls.__module__}.{cls.__name__}", (str(self),), {} + + +class AirflowNotFoundException(AirflowException): + """Raise when the requested object/resource is not available in the system.""" + + status_code = HTTPStatus.NOT_FOUND + + class AirflowDagCycleException(AirflowException): """Raise when there is a cycle in Dag definition.""" @@ -71,6 +94,229 @@ def __str__(self) -> str: return f"unmappable return type {typename!r}" +class AirflowFailException(AirflowException): + """Raise when the task should be failed without retrying.""" + + +class _AirflowExecuteWithInactiveAssetExecption(AirflowFailException): + main_message: str + + def __init__(self, inactive_asset_keys: Collection[AssetUniqueKey | AssetNameRef | AssetUriRef]) -> None: + self.inactive_asset_keys = inactive_asset_keys + + @staticmethod + def _render_asset_key(key: AssetUniqueKey | AssetNameRef | AssetUriRef) -> str: + from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, AssetUriRef + + if isinstance(key, AssetUniqueKey): + return f"Asset(name={key.name!r}, uri={key.uri!r})" + if isinstance(key, AssetNameRef): + return f"Asset.ref(name={key.name!r})" + if isinstance(key, AssetUriRef): + return f"Asset.ref(uri={key.uri!r})" + return repr(key) # Should not happen, but let's fail more gracefully in an exception.
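+ # Example (illustrative): an AssetUniqueKey(name="orders", uri="s3://orders") renders
+ # as Asset(name='orders', uri='s3://orders') in the message assembled by __str__ below.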
+ + def __str__(self) -> str: + return f"{self.main_message}: {self.inactive_assets_message}" + + @property + def inactive_assets_message(self) -> str: + return ", ".join(self._render_asset_key(key) for key in self.inactive_asset_keys) + + +class AirflowInactiveAssetInInletOrOutletException(_AirflowExecuteWithInactiveAssetExecption): + """Raise when the task is executed with inactive assets in its inlet or outlet.""" + + main_message = "Task has the following inactive assets in its inlets or outlets" + + +class AirflowRescheduleException(AirflowException): + """ + Raise when the task should be re-scheduled at a later time. + + :param reschedule_date: The date when the task should be rescheduled + """ + + def __init__(self, reschedule_date): + super().__init__() + self.reschedule_date = reschedule_date + + def serialize(self): + cls = self.__class__ + return f"{cls.__module__}.{cls.__name__}", (), {"reschedule_date": self.reschedule_date} + + +class AirflowSensorTimeout(AirflowException): + """Raise when there is a timeout on sensor polling.""" + + +class AirflowSkipException(AirflowException): + """Raise when the task should be skipped.""" + + +class AirflowTaskTerminated(BaseException): + """Raise when the task execution is terminated.""" + + +# Important to inherit BaseException instead of AirflowException->Exception, since this Exception is used +# to explicitly interrupt ongoing task. Code that does normal error-handling should not treat +# such interrupt as an error that can be handled normally. (Compare with KeyboardInterrupt) +class AirflowTaskTimeout(BaseException): + """Raise when the task execution times out.""" + + +class TaskDeferred(BaseException): + """ + Signal an operator moving to deferred state. + + Special exception raised to signal that the operator it was raised from + wishes to defer until a trigger fires. Triggers can send execution back to task or end the task instance + directly. If the trigger should end the task instance itself, ``method_name`` does not matter, + and can be None; otherwise, provide the name of the method that should be used when + resuming execution in the task. + """ + + def __init__( + self, + *, + trigger, + method_name: str, + kwargs: dict[str, Any] | None = None, + timeout=None, + ): + super().__init__() + self.trigger = trigger + self.method_name = method_name + self.kwargs = kwargs + self.timeout = timeout + + def serialize(self): + cls = self.__class__ + return ( + f"{cls.__module__}.{cls.__name__}", + (), + { + "trigger": self.trigger, + "method_name": self.method_name, + "kwargs": self.kwargs, + "timeout": self.timeout, + }, + ) + + def __repr__(self) -> str: + return f"<TaskDeferred trigger={self.trigger} method={self.method_name}>" + + +class TaskDeferralError(AirflowException): + """Raised when a task failed during deferral for some reason.""" + + +class TaskDeferralTimeout(AirflowException): + """Raise when there is a timeout on the deferral.""" + + +class DagRunTriggerException(AirflowException): + """ + Raised by an operator to trigger a specific Dag Run of a dag. + + Special exception raised to signal that the operator it was raised from wishes to trigger + a specific Dag Run of a dag. This is used in the ``TriggerDagRunOperator``.
+ """ + + def __init__( + self, + *, + trigger_dag_id: str, + dag_run_id: str, + conf: dict | None, + logical_date=None, + reset_dag_run: bool, + skip_when_already_exists: bool, + wait_for_completion: bool, + allowed_states: list[str], + failed_states: list[str], + poke_interval: int, + deferrable: bool, + ): + super().__init__() + self.trigger_dag_id = trigger_dag_id + self.dag_run_id = dag_run_id + self.conf = conf + self.logical_date = logical_date + self.reset_dag_run = reset_dag_run + self.skip_when_already_exists = skip_when_already_exists + self.wait_for_completion = wait_for_completion + self.allowed_states = allowed_states + self.failed_states = failed_states + self.poke_interval = poke_interval + self.deferrable = deferrable + + +class DownstreamTasksSkipped(AirflowException): + """ + Signal by an operator to skip its downstream tasks. + + Special exception raised to signal that the operator it was raised from wishes to skip + downstream tasks. This is used in the ShortCircuitOperator. + + :param tasks: List of task_ids to skip or a list of tuples with task_id and map_index to skip. + """ + + def __init__(self, *, tasks): + super().__init__() + self.tasks = tasks + + +class XComNotFound(AirflowException): + """Raise when an XCom reference is being resolved against a non-existent XCom.""" + + def __init__(self, dag_id: str, task_id: str, key: str) -> None: + super().__init__() + self.dag_id = dag_id + self.task_id = task_id + self.key = key + + def __str__(self) -> str: + return f'XComArg result from {self.task_id} at {self.dag_id} with key="{self.key}" is not found!' + + def serialize(self): + cls = self.__class__ + return ( + f"{cls.__module__}.{cls.__name__}", + (), + {"dag_id": self.dag_id, "task_id": self.task_id, "key": self.key}, + ) + + +class ParamValidationError(AirflowException): + """Raise when DAG params is invalid.""" + + +class DuplicateTaskIdFound(AirflowException): + """Raise when a Task with duplicate task_id is defined in the same DAG.""" + + +class TaskAlreadyInTaskGroup(AirflowException): + """Raise when a Task cannot be added to a TaskGroup since it already belongs to another TaskGroup.""" + + def __init__(self, task_id: str, existing_group_id: str | None, new_group_id: str): + super().__init__(task_id, new_group_id) + self.task_id = task_id + self.existing_group_id = existing_group_id + self.new_group_id = new_group_id + + def __str__(self) -> str: + if self.existing_group_id is None: + existing_group = "the DAG's root group" + else: + existing_group = f"group {self.existing_group_id!r}" + return f"cannot add {self.task_id!r} to {self.new_group_id!r} (already in {existing_group})" + + +class TaskNotFound(AirflowException): + """Raise when a Task is not available in the system.""" + + class FailFastDagInvalidTriggerRule(AirflowException): """Raise when a dag has 'fail_fast' enabled yet has a non-default trigger rule.""" @@ -88,3 +334,10 @@ def check(cls, *, fail_fast: bool, trigger_rule: TriggerRule): def __str__(self) -> str: return f"A 'fail_fast' dag can only have {TriggerRule.ALL_SUCCESS} trigger rule" + + +class RemovedInAirflow4Warning(DeprecationWarning): + """Issued for usage of deprecated features that will be removed in Airflow4.""" + + deprecated_since: str | None = None + "Indicates the airflow version that started raising this deprecation warning" diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index cb774ee2b4e75..8ed06333ce8da 100644 --- 
a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -40,7 +40,7 @@ AssetUriRef, BaseAssetUniqueKey, ) -from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType +from airflow.sdk.exceptions import AirflowNotFoundException, AirflowRuntimeError, ErrorType from airflow.sdk.log import mask_secret if TYPE_CHECKING: @@ -172,7 +172,6 @@ def _get_connection(conn_id: str) -> Connection: ) # If no backend found the connection, raise an error - from airflow.exceptions import AirflowNotFoundException raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined") @@ -218,7 +217,6 @@ async def _async_get_connection(conn_id: str) -> Connection: ) # If no backend found the connection, raise an error - from airflow.exceptions import AirflowNotFoundException raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined") @@ -360,8 +358,6 @@ def __hash__(self): return hash(self.__class__.__name__) def get(self, conn_id: str, default_conn: Any = None) -> Any: - from airflow.exceptions import AirflowNotFoundException - try: return _get_connection(conn_id) except AirflowRuntimeError as e: diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 82807bba902b9..3748607b50e3c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -40,7 +40,6 @@ from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock from airflow.dag_processing.bundles.manager import DagBundlesManager -from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException, AirflowTaskTimeout from airflow.listeners.listener import get_listener_manager from airflow.sdk.api.client import get_hostname, getuser from airflow.sdk.api.datamodels._generated import ( @@ -58,7 +57,14 @@ from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.param import process_params -from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType +from airflow.sdk.exceptions import ( + AirflowException, + AirflowInactiveAssetInInletOrOutletException, + AirflowRuntimeError, + AirflowTaskTimeout, + ErrorType, + TaskDeferred, +) from airflow.sdk.execution_time.callback_runner import create_executable_runner from airflow.sdk.execution_time.comms import ( AssetEventDagRunReferenceResult, @@ -117,9 +123,9 @@ from pendulum.datetime import DateTime from structlog.typing import FilteringBoundLogger as Logger - from airflow.exceptions import DagRunTriggerException, TaskDeferred from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions.context import Context + from airflow.sdk.exceptions import DagRunTriggerException from airflow.sdk.types import OutletEventAccessorsProtocol @@ -897,8 +903,7 @@ def run( """Run the task in this process.""" import signal - from airflow.exceptions import ( - AirflowException, + from airflow.sdk.exceptions import ( AirflowFailException, AirflowRescheduleException, AirflowSensorTimeout, @@ -1133,7 +1138,6 @@ def _handle_trigger_dag_run( ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id) if drte.deferrable: - from airflow.exceptions import TaskDeferred from airflow.providers.standard.triggers.external_task import DagStateTrigger defer = TaskDeferred( diff --git 
a/task-sdk/src/airflow/sdk/execution_time/timeout.py b/task-sdk/src/airflow/sdk/execution_time/timeout.py index fe4a0e8bd8c52..b1ccfb2045606 100644 --- a/task-sdk/src/airflow/sdk/execution_time/timeout.py +++ b/task-sdk/src/airflow/sdk/execution_time/timeout.py @@ -20,7 +20,7 @@ import structlog -from airflow.exceptions import AirflowTaskTimeout +from airflow.sdk.exceptions import AirflowTaskTimeout class TimeoutPosix: diff --git a/task-sdk/tests/task_sdk/bases/test_hook.py b/task-sdk/tests/task_sdk/bases/test_hook.py index eb0cefd19a71a..dd839a425ee5c 100644 --- a/task-sdk/tests/task_sdk/bases/test_hook.py +++ b/task-sdk/tests/task_sdk/bases/test_hook.py @@ -19,8 +19,8 @@ import pytest -from airflow.exceptions import AirflowNotFoundException from airflow.sdk import BaseHook +from airflow.sdk.exceptions import AirflowNotFoundException from airflow.sdk.execution_time.comms import ConnectionResult, GetConnection from tests_common.test_utils.config import conf_vars diff --git a/task-sdk/tests/task_sdk/bases/test_operator.py b/task-sdk/tests/task_sdk/bases/test_operator.py index c45f2ee23abd6..535ba511d0991 100644 --- a/task-sdk/tests/task_sdk/bases/test_operator.py +++ b/task-sdk/tests/task_sdk/bases/test_operator.py @@ -807,7 +807,7 @@ class Branch(Mixin, sql.BranchSQLOperator): pass # The following throws an exception if metaclass breaks MRO: - # airflow.exceptions.AirflowException: Invalid arguments were passed to Branch (task_id: test). Invalid arguments were: + # airflow.sdk.exceptions.AirflowException: Invalid arguments were passed to Branch (task_id: test). Invalid arguments were: # **kwargs: {'sql': 'sql', 'follow_task_ids_if_true': ['x'], 'follow_task_ids_if_false': ['y']} op = Branch( task_id="test", diff --git a/task-sdk/tests/task_sdk/bases/test_sensor.py b/task-sdk/tests/task_sdk/bases/test_sensor.py index ddb911d05fc91..14d4b7a4cf720 100644 --- a/task-sdk/tests/task_sdk/bases/test_sensor.py +++ b/task-sdk/tests/task_sdk/bases/test_sensor.py @@ -24,7 +24,12 @@ import pytest import time_machine -from airflow.exceptions import ( +from airflow.models.trigger import TriggerFailureReason +from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.sdk import TaskInstanceState, timezone +from airflow.sdk.bases.sensor import BaseSensorOperator, PokeReturnValue, poke_mode_only +from airflow.sdk.definitions.dag import DAG +from airflow.sdk.exceptions import ( AirflowException, AirflowFailException, AirflowRescheduleException, @@ -32,11 +37,6 @@ AirflowSkipException, AirflowTaskTimeout, ) -from airflow.models.trigger import TriggerFailureReason -from airflow.providers.standard.operators.empty import EmptyOperator -from airflow.sdk import TaskInstanceState, timezone -from airflow.sdk.bases.sensor import BaseSensorOperator, PokeReturnValue, poke_mode_only -from airflow.sdk.definitions.dag import DAG from airflow.sdk.execution_time.comms import RescheduleTask, TaskRescheduleStartDate from airflow.sdk.timezone import datetime @@ -274,7 +274,7 @@ def run_duration(): assert state == TaskInstanceState.SUCCESS def test_invalid_mode(self): - with pytest.raises(AirflowException): + with pytest.raises(ValueError, match="The mode must be one of"): DummySensor(task_id="a", mode="foo") def test_ok_with_custom_reschedule_exception(self, make_sensor, run_task): @@ -311,7 +311,9 @@ def test_sensor_with_invalid_poke_interval(self): negative_poke_interval = -10 non_number_poke_interval = "abcd" positive_poke_interval = 10 - with pytest.raises(AirflowException): + with pytest.raises( + 
ValueError, match="Operator arg `poke_interval` must be timedelta object or a non-negative number" + ): DummySensor( task_id="test_sensor_task_1", return_value=None, @@ -319,7 +321,9 @@ def test_sensor_with_invalid_poke_interval(self): timeout=25, ) - with pytest.raises(AirflowException): + with pytest.raises( + ValueError, match="Operator arg `poke_interval` must be timedelta object or a non-negative number" + ): DummySensor( task_id="test_sensor_task_2", return_value=None, @@ -335,12 +339,16 @@ def test_sensor_with_invalid_timeout(self): negative_timeout = -25 non_number_timeout = "abcd" positive_timeout = 25 - with pytest.raises(AirflowException): + with pytest.raises( + ValueError, match="Operator arg `timeout` must be timedelta object or a non-negative number" + ): DummySensor( task_id="test_sensor_task_1", return_value=None, poke_interval=10, timeout=negative_timeout ) - with pytest.raises(AirflowException): + with pytest.raises( + ValueError, match="Operator arg `timeout` must be timedelta object or a non-negative number" + ): DummySensor( task_id="test_sensor_task_2", return_value=None, poke_interval=10, timeout=non_number_timeout ) diff --git a/task-sdk/tests/task_sdk/definitions/decorators/test_setup_teardown.py b/task-sdk/tests/task_sdk/definitions/decorators/test_setup_teardown.py index fbadaa84ac4c6..3539bd5a03fa8 100644 --- a/task-sdk/tests/task_sdk/definitions/decorators/test_setup_teardown.py +++ b/task-sdk/tests/task_sdk/definitions/decorators/test_setup_teardown.py @@ -19,10 +19,10 @@ import pytest -from airflow.exceptions import AirflowException from airflow.providers.standard.operators.bash import BashOperator from airflow.sdk import DAG, setup, task, task_group, teardown from airflow.sdk.definitions.decorators.setup_teardown import context_wrapper +from airflow.sdk.exceptions import AirflowException def make_task(name, type_, setup_=False, teardown_=False): diff --git a/task-sdk/tests/task_sdk/definitions/test_connection.py b/task-sdk/tests/task_sdk/definitions/test_connection.py index 96f32cc97934f..7267eeace90af 100644 --- a/task-sdk/tests/task_sdk/definitions/test_connection.py +++ b/task-sdk/tests/task_sdk/definitions/test_connection.py @@ -23,10 +23,9 @@ import pytest -from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.sdk import Connection from airflow.sdk.configuration import initialize_secrets_backends -from airflow.sdk.exceptions import ErrorType +from airflow.sdk.exceptions import AirflowException, AirflowNotFoundException, ErrorType from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse from airflow.sdk.execution_time.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS diff --git a/task-sdk/tests/task_sdk/definitions/test_dag.py b/task-sdk/tests/task_sdk/definitions/test_dag.py index 99d21ecb58966..b781bcc9e1490 100644 --- a/task-sdk/tests/task_sdk/definitions/test_dag.py +++ b/task-sdk/tests/task_sdk/definitions/test_dag.py @@ -24,12 +24,11 @@ import pytest -from airflow.exceptions import DuplicateTaskIdFound, RemovedInAirflow4Warning from airflow.sdk import Context, Label, TaskGroup from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.definitions.dag import DAG, dag as dag_decorator from airflow.sdk.definitions.param import DagParam, Param, ParamsDict -from airflow.sdk.exceptions import AirflowDagCycleException +from airflow.sdk.exceptions import AirflowDagCycleException, DuplicateTaskIdFound, RemovedInAirflow4Warning DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc) diff --git 
a/task-sdk/tests/task_sdk/definitions/test_param.py b/task-sdk/tests/task_sdk/definitions/test_param.py index 9350fc1da119b..aa2aa71c51a03 100644 --- a/task-sdk/tests/task_sdk/definitions/test_param.py +++ b/task-sdk/tests/task_sdk/definitions/test_param.py @@ -20,8 +20,8 @@ import pytest -from airflow.exceptions import ParamValidationError from airflow.sdk.definitions.param import Param, ParamsDict +from airflow.sdk.exceptions import ParamValidationError from airflow.serialization.definitions.param import SerializedParam from airflow.serialization.serialized_objects import BaseSerialization diff --git a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py index 0bef4c59ad43f..af487851b07cb 100644 --- a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py +++ b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py @@ -23,10 +23,10 @@ import pytest import structlog -from airflow.exceptions import AirflowSkipException from airflow.sdk import TaskInstanceState from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions.dag import DAG +from airflow.sdk.exceptions import AirflowSkipException from airflow.sdk.execution_time.comms import GetXCom, XComResult log = structlog.get_logger(__name__) diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index 36f1a3eaca860..26cc6720b8029 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -34,7 +34,7 @@ ) from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.variable import Variable -from airflow.sdk.exceptions import ErrorType +from airflow.sdk.exceptions import AirflowNotFoundException, ErrorType from airflow.sdk.execution_time.comms import ( AssetEventDagRunReferenceResult, AssetEventResult, @@ -956,7 +956,6 @@ def get_connection(self, conn_id): def test_get_connection_not_found_raises_error(self, mock_supervisor_comms): """Test that _get_connection raises error when no backend finds connection.""" - from airflow.exceptions import AirflowNotFoundException # Backend returns None (not found) class EmptyBackend: diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 7942bfdf57ccf..e094da8d79161 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -35,15 +35,6 @@ from task_sdk import FAKE_BUNDLE from uuid6 import uuid7 -from airflow.exceptions import ( - AirflowException, - AirflowFailException, - AirflowSensorTimeout, - AirflowSkipException, - AirflowTaskTerminated, - AirflowTaskTimeout, - DownstreamTasksSkipped, -) from airflow.listeners import hookimpl from airflow.listeners.listener import get_listener_manager from airflow.providers.standard.operators.python import PythonOperator @@ -70,7 +61,16 @@ from airflow.sdk.definitions._internal.types import NOTSET, SET_DURING_EXECUTION, is_arg_set from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, Dataset, Model from airflow.sdk.definitions.param import DagParam -from airflow.sdk.exceptions import ErrorType +from airflow.sdk.exceptions import ( + AirflowException, + AirflowFailException, + AirflowSensorTimeout, + AirflowSkipException, + AirflowTaskTerminated, + AirflowTaskTimeout, + DownstreamTasksSkipped, + ErrorType, +) from airflow.sdk.execution_time.comms import ( 
AssetEventResult, AssetEventsResult, @@ -3137,7 +3137,7 @@ def _execute_success(self, context): self.results.append("execute success") def _execute_skipped(self, context): - from airflow.exceptions import AirflowSkipException + from airflow.sdk.exceptions import AirflowSkipException self.results.append("execute skipped") raise AirflowSkipException @@ -3244,7 +3244,6 @@ def test_runtime_task_instance_log_url_property(self, create_runtime_ti, base_ur def test_task_runner_on_failure_callback_context(self, create_runtime_ti): """Test that on_failure_callback context has end_date and duration.""" - from airflow.exceptions import AirflowException def failure_callback(context): ti = context["task_instance"] @@ -3304,8 +3303,6 @@ def test_task_runner_both_callbacks_have_timing_info(self, create_runtime_ti): """Test that both success and failure callbacks receive accurate timing information.""" import time - from airflow.exceptions import AirflowException - success_data = {} failure_data = {}
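Taken together, the diffs above make one mechanical change for user code: task-facing exceptions are imported from ``airflow.sdk.exceptions``, and invalid sensor arguments surface as ``ValueError``. A minimal sketch of a custom sensor written against the new surface (the sensor class and its ``flag`` parameter are hypothetical, and the try/except import fallback is an assumption for code that must also import cleanly on Airflow releases without ``airflow.sdk.exceptions``)::

    try:
        # New canonical import path introduced by this change.
        from airflow.sdk.exceptions import AirflowSkipException
    except ImportError:
        # Assumed fallback for older Airflow releases that only ship airflow.exceptions.
        from airflow.exceptions import AirflowSkipException

    from airflow.sdk.bases.sensor import BaseSensorOperator


    class WaitForFlagSensor(BaseSensorOperator):
        """Hypothetical sensor illustrating the new import and error contract."""

        def __init__(self, *, flag: bool, **kwargs):
            # Invalid poke_interval/timeout now raise ValueError in the base class.
            super().__init__(**kwargs)
            self.flag = flag

        def poke(self, context) -> bool:
            if not self.flag:
                # Skipping behaviour is unchanged; only the import path moved.
                raise AirflowSkipException("Flag not set; skipping")
            return True

Callers that previously asserted ``pytest.raises(AirflowException)`` around sensor construction should follow the updated tests and match the ``ValueError`` messages instead, e.g. ``pytest.raises(ValueError, match="The poke_interval must be a non-negative number")``.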