From f8480381c36c74d0a12eed56d3d4b4667f349ed8 Mon Sep 17 00:00:00 2001 From: Luis Gonzalez Date: Tue, 9 Jul 2024 14:11:52 -0600 Subject: [PATCH 1/7] on_state_change_callback() implementation --- airflow/jobs/scheduler_job.py | 3 +++ airflow/models/baseoperator.py | 7 ++++++ airflow/models/dagrun.py | 10 ++++++++ airflow/models/taskinstance.py | 42 +++++++++++++++++++++++++++++++++- 4 files changed, 61 insertions(+), 1 deletion(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 02065924fde5..4840f1ff7a35 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -518,6 +518,9 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = ) for ti in executable_tis: + # Handles the following states: + # - QUEUED + ti.call_state_change_callback() make_transient(ti) return executable_tis diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index ab9543c127f2..f7dc10da4133 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -216,6 +216,7 @@ def partial( on_failure_callback: Optional[TaskStateChangeCallback] = None, on_success_callback: Optional[TaskStateChangeCallback] = None, on_retry_callback: Optional[TaskStateChangeCallback] = None, + on_state_change_callback: Optional[TaskStateChangeCallback] = None, run_as_user: Optional[str] = None, executor_config: Optional[Dict] = None, inlets: Optional[Any] = None, @@ -276,6 +277,7 @@ def partial( partial_kwargs.setdefault("on_failure_callback", on_failure_callback) partial_kwargs.setdefault("on_retry_callback", on_retry_callback) partial_kwargs.setdefault("on_success_callback", on_success_callback) + partial_kwargs.setdefault("on_state_change_callback", on_state_change_callback) partial_kwargs.setdefault("run_as_user", run_as_user) partial_kwargs.setdefault("executor_config", executor_config) partial_kwargs.setdefault("inlets", inlets) @@ -564,6 +566,8 @@ class derived from this one results in the creation of a task object, that it is executed when retries occur. :param on_success_callback: much like the ``on_failure_callback`` except that it is executed when the task succeeds. + :param on_state_change_callback: much like the ``on_failure_callback`` except + that it is executed when the task state is changed. :param pre_execute: a function to be called immediately before task execution, receiving a context dictionary; raising an exception will prevent the task from being executed. @@ -667,6 +671,7 @@ class derived from this one results in the creation of a task object, 'on_failure_callback', 'on_success_callback', 'on_retry_callback', + 'on_state_change_callback', 'do_xcom_push', } @@ -730,6 +735,7 @@ def __init__( on_failure_callback: Optional[TaskStateChangeCallback] = None, on_success_callback: Optional[TaskStateChangeCallback] = None, on_retry_callback: Optional[TaskStateChangeCallback] = None, + on_state_change_callback: Optional[TaskStateChangeCallback] = None, pre_execute: Optional[TaskPreExecuteHook] = None, post_execute: Optional[TaskPostExecuteHook] = None, trigger_rule: str = DEFAULT_TRIGGER_RULE, @@ -793,6 +799,7 @@ def __init__( self.on_failure_callback = on_failure_callback self.on_success_callback = on_success_callback self.on_retry_callback = on_retry_callback + self.on_state_change_callback = on_state_change_callback self._pre_execute_hook = pre_execute self._post_execute_hook = post_execute diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 8b7f3a1c39de..197687c7ab1f 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -954,6 +954,11 @@ def _check_for_removed_or_restored_tasks( Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1) existing_indexes[task].append(ti.map_index) expected_indexes[task] = range(total_length) + + # Handles the following states: + # - REMOVED + # - None + ti.call_state_change_callback() # Check if we have some missing indexes to create ti for missing_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list) for k, v in existing_indexes.items(): @@ -1208,6 +1213,11 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = NEW_SES ) .update({TI.state: State.SCHEDULED}, synchronize_session=False) ) + + # Handles the following state + # - SCHEDULED + for ti in schedulable_tis: + ti.call_state_change_callback() # Tasks using EmptyOperator should not be executed, mark them as success if dummy_ti_ids: diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 7b8c2f670cc7..ae13c8859d53 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -219,6 +219,9 @@ def clear_task_instances( # the task is terminated and becomes eligible for retry. ti.state = TaskInstanceState.RESTARTING job_ids.append(ti.job_id) + # Handles the following state + # - RESTARTING + ti.call_state_change_callback() else: task_id = ti.task_id if dag and dag.has_task(task_id): @@ -272,6 +275,7 @@ def clear_task_instances( delete_qry = TR.__table__.delete().where(conditions) session.execute(delete_qry) + if job_ids: from airflow.jobs.base_job import BaseJob @@ -987,6 +991,11 @@ def set_state(self, state: Optional[str], session=NEW_SESSION): self.end_date = self.end_date or current_time self.duration = (self.end_date - self.start_date).total_seconds() session.merge(self) + # Handles the following states: + # - UPSTREAM_FAILED + # - SKIPPED + # - FAILED + self.call_state_change_callback() @property def is_premature(self): @@ -1350,6 +1359,9 @@ def check_and_change_state_before_execution( task_reschedule: TR = TR.query_for_task_instance(self, session=session).first() if task_reschedule: self.start_date = task_reschedule.start_date + # Handles the following states + # - UP_FOR_RESCHEDULE + self.call_state_change_callback() # Secondly we find non-runnable but requeueable tis. We reset its state. # This is because we might have hit concurrency limits, @@ -1391,6 +1403,9 @@ def check_and_change_state_before_execution( if not test_mode: session.merge(self).task = task session.commit() + # Handles state + # - RUNNING + self.call_state_change_callback() # Closing all pooled connections to prevent # "max number of connections reached" @@ -1635,6 +1650,9 @@ def _update_ti_state_for_sensing(self, session=NEW_SESSION): self.start_date = timezone.utcnow() session.merge(self) session.commit() + # Handles the following states: + # - SENSING + self.call_state_change_callback() # Raise exception for sensing state raise AirflowSmartSensorException("Task successfully registered in smart sensor.") @@ -1730,6 +1748,9 @@ def _defer_task(self, session, defer: TaskDeferred): self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout) else: self.trigger_timeout = self.start_date + execution_timeout + # Handles the following states: + # - DEFERRED + self.call_state_change_callback() def _run_execute_callback(self, context: Context, task): """Functions that need to be run before a Task is executed""" @@ -1772,6 +1793,11 @@ def _run_finished_callback(self, error: Optional[Union[str, Exception]] = None) task.on_retry_callback(context) except Exception: self.log.exception("Error when executing on_retry_callback") + # Handles the following states: + # - SUCCESS + # - UP_FOR_RETRY + # - FAILED + self.call_state_change_callback() @provide_session def run( @@ -2604,7 +2630,21 @@ def ti_selector_condition(cls, vals: Collection[Union[str, Tuple[str, int]]]) -> if len(filters) == 1: return filters[0] return or_(*filters) - + + + @classmethod + def call_state_change_callback(self): + self.log.info("State changed for DAG: %s, Task: %s, to state: %s", self.dag_id, self.task_id, self.state) + task = self.task + if task.on_state_change_callback is not None: + # Ensure the state and timestamps are up-to-date + self.refresh_from_db() + context = self.get_template_context() + try: + task.on_state_change_callback(context) + except Exception: + self.log.exception("Error when executing on_state_change_callback") + # State of the task instance. # Stores string version of the task state. From 782d05501ec7ffb1854bed916b067cdc58c107ea Mon Sep 17 00:00:00 2001 From: Luis Gonzalez Date: Tue, 9 Jul 2024 14:14:02 -0600 Subject: [PATCH 2/7] Bump version to 2.3.4.post36 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1ba4509b708c..e02c8eb06aa2 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ logger = logging.getLogger(__name__) -version = '2.3.4.post35' +version = '2.3.4.post36' AIRFLOW_SOURCES_ROOT = Path(__file__).parent.resolve() my_dir = dirname(__file__) From ca6d180cb587e79315e866cb7a0d9bf3ebef9b83 Mon Sep 17 00:00:00 2001 From: Luis Gonzalez Date: Wed, 10 Jul 2024 13:23:44 -0600 Subject: [PATCH 3/7] Update version to debug0 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e02c8eb06aa2..6cdd997240f6 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ logger = logging.getLogger(__name__) -version = '2.3.4.post36' +version = '2.3.4.post36-debug0' AIRFLOW_SOURCES_ROOT = Path(__file__).parent.resolve() my_dir = dirname(__file__) From c21d284b20dd74cba634aeee526a187b26443e3e Mon Sep 17 00:00:00 2001 From: Luis Gonzalez Date: Wed, 10 Jul 2024 14:47:43 -0600 Subject: [PATCH 4/7] Update version to 2.3.4.post36.dev0 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6cdd997240f6..ff7dfc50d1f6 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ logger = logging.getLogger(__name__) -version = '2.3.4.post36-debug0' +version = '"2.3.4.post36.dev0' AIRFLOW_SOURCES_ROOT = Path(__file__).parent.resolve() my_dir = dirname(__file__) From 5c9fe3e85ca8ae07df08883222c9e458490c1ad2 Mon Sep 17 00:00:00 2001 From: Luis Gonzalez Date: Wed, 10 Jul 2024 14:51:31 -0600 Subject: [PATCH 5/7] fix version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ff7dfc50d1f6..e5b97019fedc 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ logger = logging.getLogger(__name__) -version = '"2.3.4.post36.dev0' +version = '2.3.4.post36.dev0' AIRFLOW_SOURCES_ROOT = Path(__file__).parent.resolve() my_dir = dirname(__file__) From d3cea9e6df6c02234d56192663bdfe5dfae51547 Mon Sep 17 00:00:00 2001 From: Luis Gonzalez Date: Fri, 12 Jul 2024 17:08:27 -0600 Subject: [PATCH 6/7] WIP Fix callback call --- airflow/jobs/scheduler_job.py | 12 +++++----- airflow/models/dagrun.py | 2 +- airflow/models/taskinstance.py | 41 ++++++++++++++++++++++------------ 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 4840f1ff7a35..1f4f97b0bf89 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -517,11 +517,13 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = synchronize_session=False, ) - for ti in executable_tis: - # Handles the following states: - # - QUEUED - ti.call_state_change_callback() - make_transient(ti) + self.log.info("CallingCB for [%s] TIs", len(executable_tis)) + for ti in executable_tis: + self.log.info("TI_CB: %s", ti.call_state_change_callback) + # Handles the following states: + # - QUEUED + ti.call_state_change_callback(state=TaskInstanceState.QUEUED) + make_transient(ti) return executable_tis @provide_session diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 197687c7ab1f..4df03c0b2a24 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -1217,7 +1217,7 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = NEW_SES # Handles the following state # - SCHEDULED for ti in schedulable_tis: - ti.call_state_change_callback() + ti.call_state_change_callback(state=TaskInstanceState.SCHEDULED) # Tasks using EmptyOperator should not be executed, mark them as success if dummy_ti_ids: diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index ae13c8859d53..2af784f451a1 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -995,7 +995,7 @@ def set_state(self, state: Optional[str], session=NEW_SESSION): # - UPSTREAM_FAILED # - SKIPPED # - FAILED - self.call_state_change_callback() + self.call_state_change_callback(state) @property def is_premature(self): @@ -2630,20 +2630,33 @@ def ti_selector_condition(cls, vals: Collection[Union[str, Tuple[str, int]]]) -> if len(filters) == 1: return filters[0] return or_(*filters) - - - @classmethod - def call_state_change_callback(self): + + # 'state' will be taken from the TaskInstance if not given. + # In some cases the task state is not updated in the TaskInstance, only in the metastore, + # to avoid querying the meta store on state change, we can just send the state the + # task should be on and call the callback with the correct state. + def call_state_change_callback(self, state=None): + self.log.info("==================call_state_change_callback[%s, %s]============", state, self.state) self.log.info("State changed for DAG: %s, Task: %s, to state: %s", self.dag_id, self.task_id, self.state) - task = self.task - if task.on_state_change_callback is not None: - # Ensure the state and timestamps are up-to-date - self.refresh_from_db() - context = self.get_template_context() - try: - task.on_state_change_callback(context) - except Exception: - self.log.exception("Error when executing on_state_change_callback") + if hasattr(self, 'task'): + self.log.info("TASK: %s",self.task) + self.log.info("TASK_on_state_change_callback: %s",self.task.on_state_change_callback) + # Update state in current instance, for cases where the state is only updated in DB + # This to avoid querying the DB to get the latest state + if state: + self.state = state + + task = self.task + if task.on_state_change_callback is not None: + # Ensure the state and timestamps are up-to-date + # self.refresh_from_db() + context = self.get_template_context() + try: + task.on_state_change_callback(context) + except Exception: + self.log.exception("Error when executing on_state_change_callback") + else: + self.log.info("Couldn't get self.task object!") # State of the task instance. From 77d2c7d58562b14f053b787a86e115132054d44f Mon Sep 17 00:00:00 2001 From: Luis Gonzalez Date: Thu, 15 Aug 2024 16:27:56 -0600 Subject: [PATCH 7/7] Print works, no task yet --- airflow/jobs/scheduler_job.py | 4 +++ airflow/models/baseoperator.py | 2 +- airflow/models/dagrun.py | 2 +- airflow/models/taskinstance.py | 58 +++++++++++++++++++++------------- 4 files changed, 42 insertions(+), 24 deletions(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 1f4f97b0bf89..c85eae647f3e 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -520,10 +520,14 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = self.log.info("CallingCB for [%s] TIs", len(executable_tis)) for ti in executable_tis: self.log.info("TI_CB: %s", ti.call_state_change_callback) + self.log.info("TI_CB Function name: %s", ti.call_state_change_callback.__name__) + self.log.info("TI_CB Function bound to instance: %s", ti.call_state_change_callback.__self__) # Handles the following states: # - QUEUED ti.call_state_change_callback(state=TaskInstanceState.QUEUED) + self.log.info("Called already call_state_change_callback(state=QUEUED):\n\t") make_transient(ti) + return executable_tis @provide_session diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index f7dc10da4133..70b3c4cde7ff 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1789,4 +1789,4 @@ def get_link(self, operator: AbstractOperator, *, ti_key: "TaskInstanceKey") -> :param operator: airflow operator :param ti_key: TaskInstance ID to return link for :return: link to external system - """ + """ \ No newline at end of file diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 4df03c0b2a24..2613be1e4599 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -1261,4 +1261,4 @@ def get_log_filename_template(self, *, session: Session = NEW_SESSION) -> str: DeprecationWarning, stacklevel=2, ) - return self.get_log_template(session=session).filename + return self.get_log_template(session=session).filename \ No newline at end of file diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 2af784f451a1..f0f6b50b9dab 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2636,27 +2636,41 @@ def ti_selector_condition(cls, vals: Collection[Union[str, Tuple[str, int]]]) -> # to avoid querying the meta store on state change, we can just send the state the # task should be on and call the callback with the correct state. def call_state_change_callback(self, state=None): - self.log.info("==================call_state_change_callback[%s, %s]============", state, self.state) - self.log.info("State changed for DAG: %s, Task: %s, to state: %s", self.dag_id, self.task_id, self.state) - if hasattr(self, 'task'): - self.log.info("TASK: %s",self.task) - self.log.info("TASK_on_state_change_callback: %s",self.task.on_state_change_callback) - # Update state in current instance, for cases where the state is only updated in DB - # This to avoid querying the DB to get the latest state - if state: - self.state = state - - task = self.task - if task.on_state_change_callback is not None: - # Ensure the state and timestamps are up-to-date - # self.refresh_from_db() - context = self.get_template_context() - try: - task.on_state_change_callback(context) - except Exception: - self.log.exception("Error when executing on_state_change_callback") - else: - self.log.info("Couldn't get self.task object!") + try: + print("==================call_state_change_callback[%s, %s]============", state, self.state) + print("State changed for DAG: %s, Task: %s, to state: %s", self.dag_id, self.task_id, self.state) + print("1self:", self.__dict__) + if not hasattr(self, 'task'): + # self: {'_sa_instance_state': , 'dag_id': 'testing_logs', 'hostname': '', 'queued_dttm': datetime.datetime(2024, 8, 15, 22, 5, 9, 73142, tzinfo=Timezone('UTC')), 'next_method': None, 'run_id': 'scheduled__2024-08-15T22:00:00+00:00', 'unixname': 'root', 'queued_by_job_id': None, 'next_kwargs': None, 'map_index': -1, 'job_id': None, 'pid': None, 'start_date': None, 'pool': 'default_pool', 'executor_config': {}, 'end_date': None, 'pool_slots': 1, 'external_executor_id': None, '_try_number': 0, 'duration': None, 'queue': 'default', 'trigger_id': None, 'state': 'scheduled', 'priority_weight': 1, 'trigger_timeout': None, 'task_id': 'logging_from_presto_operator', 'max_tries': 1, 'operator': 'PrestoOperator', 'dag_run': , 'rendered_task_instance_fields': None, '_log': , 'test_mode': False, 'dag_model': } + # there is no self.task_instance + if not hasattr(self, 'task_instance'): + if not hasattr(self, 'get_task_instance'): + # Get task instance from DB + self.task = TaskInstance.get_task_instance(self.dag_id, self.task_id, self.execution_date) + # self.task = self.task_instance.task + + print("2self:", self.__dict__) + if hasattr(self, 'task'): + print("TASK: %s",self.task) + print("TASK_on_state_change_callback: %s",self.task.on_state_change_callback) + # Update state in current instance, for cases where the state is only updated in DB + # This to avoid querying the DB to get the latest state + if state: + self.state = state + + task = self.task + if task.on_state_change_callback is not None: + # Ensure the state and timestamps are up-to-date + # self.refresh_from_db() + context = self.get_template_context() + try: + task.on_state_change_callback(context) + except Exception: + self.log.exception("Error when executing on_state_change_callback") + else: + self.log.info("Couldn't get self.task object!") + except Exception as e: + log.error("Exception in call_state_change_callback: %s", e) # State of the task instance. @@ -2756,4 +2770,4 @@ def from_dict(cls, obj_dict: dict): if STATICA_HACK: # pragma: no cover from airflow.jobs.base_job import BaseJob - TaskInstance.queued_by_job = relationship(BaseJob) + TaskInstance.queued_by_job = relationship(BaseJob) \ No newline at end of file