Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkrkotzeml authored Jan 10, 2025
2 parents c8033c2 + e8f3be8 commit 73fee85
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 38 deletions.
56 changes: 32 additions & 24 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def render_templates(

def xcom_pull(
self,
task_ids: str | Iterable[str] | None = None, # TODO: Simplify to a single task_id? (breaking change)
task_ids: str | Iterable[str] | None = None,
dag_id: str | None = None,
key: str = "return_value", # TODO: Make this a constant (``XCOM_RETURN_KEY``)
include_prior_dates: bool = False, # TODO: Add support for this
Expand Down Expand Up @@ -213,40 +213,48 @@ def xcom_pull(
run_id = self.run_id

if task_ids is None:
# default to the current task if not provided
task_ids = self.task_id
elif not isinstance(task_ids, str) and isinstance(task_ids, Iterable):
# TODO: Handle multiple task_ids or remove support
raise NotImplementedError("Multiple task_ids are not supported yet")

elif isinstance(task_ids, str):
task_ids = [task_ids]
if map_indexes is None:
map_indexes = self.map_index
elif isinstance(map_indexes, Iterable):
# TODO: Handle multiple map_indexes or remove support
raise NotImplementedError("Multiple map_indexes are not supported yet")

log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(
log=log,
msg=GetXCom(
key=key,
dag_id=dag_id,
task_id=task_ids,
run_id=run_id,
map_index=map_indexes,
),
)

msg = SUPERVISOR_COMMS.get_message()
if TYPE_CHECKING:
assert isinstance(msg, XComResult)
xcoms = []
for t in task_ids:
SUPERVISOR_COMMS.send_request(
log=log,
msg=GetXCom(
key=key,
dag_id=dag_id,
task_id=t,
run_id=run_id,
map_index=map_indexes,
),
)

msg = SUPERVISOR_COMMS.get_message()
if not isinstance(msg, XComResult):
raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}")

if msg.value is not None:
from airflow.models.xcom import XCom

if msg.value is not None:
from airflow.models.xcom import XCom
# TODO: Move XCom serialization & deserialization to Task SDK
# https://github.com/apache/airflow/issues/45231
xcom = XCom.deserialize_value(msg) # type: ignore[arg-type]
xcoms.append(xcom)
else:
xcoms.append(default)

# TODO: Move XCom serialization & deserialization to Task SDK
# https://github.com/apache/airflow/issues/45231
return XCom.deserialize_value(msg) # type: ignore[arg-type]
return default
if len(xcoms) == 1:
return xcoms[0]
return xcoms

def xcom_push(self, key: str, value: Any):
"""
Expand Down
38 changes: 24 additions & 14 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,14 +735,20 @@ def test_get_variable_from_context(

assert var_from_context == Variable(key="test_key", value=expected_value)

def test_xcom_pull(self, create_runtime_ti, mock_supervisor_comms, spy_agency):
@pytest.mark.parametrize(
"task_ids",
[
"push_task",
["push_task1", "push_task2"],
{"push_task1", "push_task2"},
],
)
def test_xcom_pull(self, create_runtime_ti, mock_supervisor_comms, spy_agency, task_ids):
"""Test that a task pulls the expected XCom value if it exists."""

task_id = "push_task"

class CustomOperator(BaseOperator):
def execute(self, context):
value = context["ti"].xcom_pull(task_ids=task_id, key="key")
value = context["ti"].xcom_pull(task_ids=task_ids, key="key")
print(f"Pulled XCom Value: {value}")

task = CustomOperator(task_id="pull_task")
Expand All @@ -755,16 +761,20 @@ def execute(self, context):

run(runtime_ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_any_call(
log=mock.ANY,
msg=GetXCom(
key="key",
dag_id="test_dag",
run_id="test_run",
task_id=task_id,
map_index=None,
),
)
if isinstance(task_ids, str):
task_ids = [task_ids]

for task_id in task_ids:
mock_supervisor_comms.send_request.assert_any_call(
log=mock.ANY,
msg=GetXCom(
key="key",
dag_id="test_dag",
run_id="test_run",
task_id=task_id,
map_index=None,
),
)


class TestXComAfterTaskExecution:
Expand Down

0 comments on commit 73fee85

Please sign in to comment.