Skip to content

Commit

Permalink
Keyword argument protocol changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Shrews committed Aug 9, 2024
1 parent b965f9c commit fc1ac8e
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#
import datetime

from ansible_runner.__main__ import VERSION
from ansible_runner.version import VERSION


# -- Project information -----------------------------------------------------
Expand Down
3 changes: 1 addition & 2 deletions src/ansible_runner/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,9 @@
from ansible_runner import cleanup
from ansible_runner.utils import dump_artifact, Bunch, register_for_cleanup
from ansible_runner.utils.capacity import get_cpu_count, get_mem_in_bytes, ensure_uuid
from ansible_runner.utils.importlib_compat import importlib_metadata
from ansible_runner.runner import Runner
from ansible_runner.version import VERSION

VERSION = importlib_metadata.version("ansible_runner")

DEFAULT_ROLES_PATH = os.getenv('ANSIBLE_ROLES_PATH', None)
DEFAULT_RUNNER_BINARY = os.getenv('RUNNER_BINARY', None)
Expand Down
5 changes: 3 additions & 2 deletions src/ansible_runner/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
sanitize_json_response,
signal_handler,
)
from ansible_runner.version import VERSION

logging.getLogger('ansible-runner').addHandler(logging.NullHandler())

Expand Down Expand Up @@ -98,11 +99,11 @@ def init_runner(**kwargs):
streamer = kwargs.pop('streamer', None)
if streamer:
if streamer == 'transmit':
stream_transmitter = Transmitter(**kwargs)
stream_transmitter = Transmitter(runner_version=VERSION, **kwargs)
return stream_transmitter

if streamer == 'worker':
stream_worker = Worker(**kwargs)
stream_worker = Worker(runner_version=VERSION, **kwargs)
return stream_worker

if streamer == 'process':
Expand Down
24 changes: 21 additions & 3 deletions src/ansible_runner/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from collections.abc import Mapping
from functools import wraps
from packaging.version import Version
from threading import Event, RLock, Thread

import ansible_runner
Expand All @@ -37,7 +38,9 @@ def __init__(self, settings):


class Transmitter:
def __init__(self, _output=None, **kwargs):
def __init__(self, runner_version: str, _output=None, **kwargs):
self.runner_version = runner_version

if _output is None:
_output = sys.stdout.buffer
self._output = _output
Expand All @@ -53,7 +56,12 @@ def __init__(self, _output=None, **kwargs):

def run(self):
self._output.write(
json.dumps({'kwargs': self.kwargs}, cls=UUIDEncoder).encode('utf-8')
json.dumps(
{
'runner_version': self.runner_version,
'kwargs': self.kwargs
},
cls=UUIDEncoder).encode('utf-8')
)
self._output.write(b'\n')
self._output.flush()
Expand All @@ -69,7 +77,9 @@ def run(self):


class Worker:
def __init__(self, _input=None, _output=None, keepalive_seconds: float | None = None, **kwargs):
def __init__(self, runner_version: str, _input=None, _output=None, keepalive_seconds: float | None = None, **kwargs):
self.runner_version = runner_version

if _input is None:
_input = sys.stdin.buffer
if _output is None:
Expand Down Expand Up @@ -187,6 +197,14 @@ def run(self):
self.finished_callback(None) # send eof line
return self.status, self.rc

if 'runner_version' in data:
if Version(self.runner_version) != Version(data['runner_version']):
self.status_handler({
'status': 'error',
'job_explanation': f"Streaming data version mismatch: worker {self.runner_version}, data {data['runner_version']}",
}, None)
self.finished_callback(None) # send eof line
return self.status, self.rc
if 'kwargs' in data:
self.job_kwargs = self.update_paths(data['kwargs'])
elif 'zipfile' in data:
Expand Down
3 changes: 3 additions & 0 deletions src/ansible_runner/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .utils.importlib_compat import importlib_metadata

VERSION = importlib_metadata.version("ansible_runner")
71 changes: 70 additions & 1 deletion test/unit/test_streaming.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import io
import os

from ansible_runner.streaming import Processor
import pytest

from ansible_runner.streaming import Processor, Transmitter, Worker


class TestProcessor:
Expand All @@ -14,3 +17,69 @@ def test_artifact_dir_with_int_ident(self, tmp_path):
assert p.artifact_dir == os.path.join(kwargs['private_data_dir'],
'artifacts',
str(kwargs['ident']))


class TestTransmitter:

def test_job_arguments(self, tmp_path, project_fixtures):
"""
Test format of sending job arguments.
"""
transmit_dir = project_fixtures / 'debug'
outgoing_buffer_file = tmp_path / 'buffer_out'
outgoing_buffer_file.touch()

kwargs = {
'playbook': 'debug.yml',
'only_transmit_kwargs': True
}

with outgoing_buffer_file.open('b+r') as outgoing_buffer:
transmitter = Transmitter(
runner_version="1.0.0",
_output=outgoing_buffer,
private_data_dir=transmit_dir,
**kwargs)
transmitter.run()
outgoing_buffer.seek(0)
sent = outgoing_buffer.read()

expected = b'{"runner_version": "1.0.0", "kwargs": {"playbook": "debug.yml"}}\n{"eof": true}\n'
assert sent == expected

def test_version_mismatch(self, project_fixtures):
transmit_dir = project_fixtures / 'debug'
transmit_buffer = io.BytesIO()
output_buffer = io.BytesIO()

for buffer in (transmit_buffer, output_buffer):
buffer.name = 'foo'

kwargs = {
'playbook': 'debug.yml',
'only_transmit_kwargs': True
}

status, rc = Transmitter(
runner_version="1.0.0",
_output=transmit_buffer,
private_data_dir=transmit_dir,
**kwargs).run()

assert rc in (None, 0)
assert status == 'unstarted'
transmit_buffer.seek(0)

worker = Worker(runner_version="0.1.0",
_input=transmit_buffer,
_output=output_buffer)

status, rc = worker.run()

assert status == 'error'
assert rc in (None, 0)

output_buffer.seek(0)
output = output_buffer.read()

assert output == b'{"status": "error", "job_explanation": "Streaming data version mismatch: worker 0.1.0, data 1.0.0"}\n{"eof": true}\n'

0 comments on commit fc1ac8e

Please sign in to comment.