diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 02065924fde5d..c85eae647f3e2 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -517,8 +517,17 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = synchronize_session=False, ) - for ti in executable_tis: - make_transient(ti) + self.log.info("CallingCB for [%s] TIs", len(executable_tis)) + for ti in executable_tis: + self.log.info("TI_CB: %s", ti.call_state_change_callback) + self.log.info("TI_CB Function name: %s", ti.call_state_change_callback.__name__) + self.log.info("TI_CB Function bound to instance: %s", ti.call_state_change_callback.__self__) + # Handles the following states: + # - QUEUED + ti.call_state_change_callback(state=TaskInstanceState.QUEUED) + self.log.info("Called already call_state_change_callback(state=QUEUED):\n\t") + make_transient(ti) + return executable_tis @provide_session diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index ab9543c127f20..70b3c4cde7ff2 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -216,6 +216,7 @@ def partial( on_failure_callback: Optional[TaskStateChangeCallback] = None, on_success_callback: Optional[TaskStateChangeCallback] = None, on_retry_callback: Optional[TaskStateChangeCallback] = None, + on_state_change_callback: Optional[TaskStateChangeCallback] = None, run_as_user: Optional[str] = None, executor_config: Optional[Dict] = None, inlets: Optional[Any] = None, @@ -276,6 +277,7 @@ def partial( partial_kwargs.setdefault("on_failure_callback", on_failure_callback) partial_kwargs.setdefault("on_retry_callback", on_retry_callback) partial_kwargs.setdefault("on_success_callback", on_success_callback) + partial_kwargs.setdefault("on_state_change_callback", on_state_change_callback) partial_kwargs.setdefault("run_as_user", run_as_user) partial_kwargs.setdefault("executor_config", executor_config) 
partial_kwargs.setdefault("inlets", inlets) @@ -564,6 +566,8 @@ class derived from this one results in the creation of a task object, that it is executed when retries occur. :param on_success_callback: much like the ``on_failure_callback`` except that it is executed when the task succeeds. + :param on_state_change_callback: much like the ``on_failure_callback`` except + that it is executed when the task state is changed. :param pre_execute: a function to be called immediately before task execution, receiving a context dictionary; raising an exception will prevent the task from being executed. @@ -667,6 +671,7 @@ class derived from this one results in the creation of a task object, 'on_failure_callback', 'on_success_callback', 'on_retry_callback', + 'on_state_change_callback', 'do_xcom_push', } @@ -730,6 +735,7 @@ def __init__( on_failure_callback: Optional[TaskStateChangeCallback] = None, on_success_callback: Optional[TaskStateChangeCallback] = None, on_retry_callback: Optional[TaskStateChangeCallback] = None, + on_state_change_callback: Optional[TaskStateChangeCallback] = None, pre_execute: Optional[TaskPreExecuteHook] = None, post_execute: Optional[TaskPostExecuteHook] = None, trigger_rule: str = DEFAULT_TRIGGER_RULE, @@ -793,6 +799,7 @@ def __init__( self.on_failure_callback = on_failure_callback self.on_success_callback = on_success_callback self.on_retry_callback = on_retry_callback + self.on_state_change_callback = on_state_change_callback self._pre_execute_hook = pre_execute self._post_execute_hook = post_execute @@ -1782,4 +1789,4 @@ def get_link(self, operator: AbstractOperator, *, ti_key: "TaskInstanceKey") -> :param operator: airflow operator :param ti_key: TaskInstance ID to return link for :return: link to external system - """ + """ \ No newline at end of file diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 8b7f3a1c39de4..2613be1e45997 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -954,6 +954,11 @@ def 
_check_for_removed_or_restored_tasks( Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1) existing_indexes[task].append(ti.map_index) expected_indexes[task] = range(total_length) + + # Handles the following states: + # - REMOVED + # - None + ti.call_state_change_callback() # Check if we have some missing indexes to create ti for missing_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list) for k, v in existing_indexes.items(): @@ -1208,6 +1213,11 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = NEW_SES ) .update({TI.state: State.SCHEDULED}, synchronize_session=False) ) + + # Handles the following state + # - SCHEDULED + for ti in schedulable_tis: + ti.call_state_change_callback(state=TaskInstanceState.SCHEDULED) # Tasks using EmptyOperator should not be executed, mark them as success if dummy_ti_ids: @@ -1251,4 +1261,4 @@ def get_log_filename_template(self, *, session: Session = NEW_SESSION) -> str: DeprecationWarning, stacklevel=2, ) - return self.get_log_template(session=session).filename + return self.get_log_template(session=session).filename \ No newline at end of file diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 7b8c2f670cc72..f0f6b50b9dab4 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -219,6 +219,9 @@ def clear_task_instances( # the task is terminated and becomes eligible for retry. 
ti.state = TaskInstanceState.RESTARTING job_ids.append(ti.job_id) + # Handles the following state + # - RESTARTING + ti.call_state_change_callback() else: task_id = ti.task_id if dag and dag.has_task(task_id): @@ -272,6 +275,7 @@ def clear_task_instances( delete_qry = TR.__table__.delete().where(conditions) session.execute(delete_qry) + if job_ids: from airflow.jobs.base_job import BaseJob @@ -987,6 +991,11 @@ def set_state(self, state: Optional[str], session=NEW_SESSION): self.end_date = self.end_date or current_time self.duration = (self.end_date - self.start_date).total_seconds() session.merge(self) + # Handles the following states: + # - UPSTREAM_FAILED + # - SKIPPED + # - FAILED + self.call_state_change_callback(state) @property def is_premature(self): @@ -1350,6 +1359,9 @@ def check_and_change_state_before_execution( task_reschedule: TR = TR.query_for_task_instance(self, session=session).first() if task_reschedule: self.start_date = task_reschedule.start_date + # Handles the following states + # - UP_FOR_RESCHEDULE + self.call_state_change_callback() # Secondly we find non-runnable but requeueable tis. We reset its state. 
# This is because we might have hit concurrency limits, @@ -1391,6 +1403,9 @@ def check_and_change_state_before_execution( if not test_mode: session.merge(self).task = task session.commit() + # Handles state + # - RUNNING + self.call_state_change_callback() # Closing all pooled connections to prevent # "max number of connections reached" @@ -1635,6 +1650,9 @@ def _update_ti_state_for_sensing(self, session=NEW_SESSION): self.start_date = timezone.utcnow() session.merge(self) session.commit() + # Handles the following states: + # - SENSING + self.call_state_change_callback() # Raise exception for sensing state raise AirflowSmartSensorException("Task successfully registered in smart sensor.") @@ -1730,6 +1748,9 @@ def _defer_task(self, session, defer: TaskDeferred): self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout) else: self.trigger_timeout = self.start_date + execution_timeout + # Handles the following states: + # - DEFERRED + self.call_state_change_callback() def _run_execute_callback(self, context: Context, task): """Functions that need to be run before a Task is executed""" @@ -1772,6 +1793,11 @@ def _run_finished_callback(self, error: Optional[Union[str, Exception]] = None) task.on_retry_callback(context) except Exception: self.log.exception("Error when executing on_retry_callback") + # Handles the following states: + # - SUCCESS + # - UP_FOR_RETRY + # - FAILED + self.call_state_change_callback() @provide_session def run( @@ -2605,6 +2631,47 @@ def ti_selector_condition(cls, vals: Collection[Union[str, Tuple[str, int]]]) -> return filters[0] return or_(*filters) + # 'state' will be taken from the TaskInstance if not given. + # In some cases the task state is not updated in the TaskInstance, only in the metastore, + # to avoid querying the meta store on state change, we can just send the state the + # task should be on and call the callback with the correct state. 
def call_state_change_callback(self, state=None):
    """Run the task's ``on_state_change_callback``, if one is defined.

    :param state: the state this task instance is transitioning to. When
        provided, ``self.state`` is updated in-memory before the callback
        fires — callers that only wrote the new state to the metastore
        (bulk UPDATEs) can report it without a DB round-trip. When
        ``None``, the state already held on this instance is reported.

    This method must never raise: any failure (including an exception
    raised by the user-supplied callback) is logged and suppressed so a
    broken callback cannot take down the scheduler.
    """
    try:
        self.log.debug(
            "State change callback for dag_id=%s task_id=%s: state=%s (self.state=%s)",
            self.dag_id,
            self.task_id,
            state,
            self.state,
        )
        # TaskInstances hydrated straight from the DB (e.g. in the scheduler)
        # may have no `task` attribute attached; try to recover it so we can
        # reach the operator's `on_state_change_callback`.
        if not hasattr(self, 'task'):
            self.task = TaskInstance.get_task_instance(self.dag_id, self.task_id, self.execution_date)

        task = getattr(self, 'task', None)
        if task is None:
            self.log.info("Couldn't get self.task object!")
            return

        # Sync the in-memory state for cases where only the metastore row was
        # updated; the callback must observe the state being transitioned to.
        if state:
            self.state = state

        if task.on_state_change_callback is not None:
            context = self.get_template_context()
            try:
                task.on_state_change_callback(context)
            except Exception:
                self.log.exception("Error when executing on_state_change_callback")
    except Exception:
        # Defensive catch-all: state-change notification is best-effort and
        # must never interfere with scheduling.
        self.log.exception("Exception in call_state_change_callback")