Skip to content

Commit

Permalink
organize remotely executing functions in "cmdeploy.remote" sub package
Browse files Browse the repository at this point in the history
  • Loading branch information
hpk42 committed Aug 1, 2024
1 parent e7a9bf2 commit 14a6927
Show file tree
Hide file tree
Showing 11 changed files with 132 additions and 77 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

## untagged

- refactor ssh-based execution to allow organizing remote functions in
modules.
([#396](https://github.com/deltachat/chatmail/pull/396))



## 1.4.1 2024-07-31

Expand Down
1 change: 1 addition & 0 deletions chatmaild/src/chatmaild/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

6 changes: 3 additions & 3 deletions cmdeploy/src/cmdeploy/cmdeploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from packaging import version
from termcolor import colored

from . import dns, remote_funcs
from . import dns, remote
from .sshexec import SSHExec

#
Expand Down Expand Up @@ -132,7 +132,7 @@ def status_cmd(args, out):
else:
out.red("no privacy settings")

for line in sshexec(remote_funcs.get_systemd_running):
for line in sshexec(remote.rshell.get_systemd_running):
print(line)


Expand Down Expand Up @@ -313,7 +313,7 @@ def main(args=None):

def get_sshexec():
print(f"[ssh] login to {args.config.mail_domain}")
return SSHExec(args.config.mail_domain, remote_funcs, verbose=args.verbose)
return SSHExec(args.config.mail_domain, verbose=args.verbose)

args.get_sshexec = get_sshexec

Expand Down
6 changes: 3 additions & 3 deletions cmdeploy/src/cmdeploy/dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from jinja2 import Template

from . import remote_funcs
from . import remote


def get_initial_remote_data(sshexec, mail_domain):
return sshexec.logged(
call=remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=mail_domain)
call=remote.rdns.perform_initial_checks, kwargs=dict(mail_domain=mail_domain)
)


Expand Down Expand Up @@ -42,7 +42,7 @@ def check_full_zone(sshexec, remote_data, out, zonefile) -> int:
and return (exitcode, remote_data) tuple."""

required_diff, recommended_diff = sshexec.logged(
remote_funcs.check_zonefile, kwargs=dict(zonefile=zonefile)
remote.rdns.check_zonefile, kwargs=dict(zonefile=zonefile)
)

if required_diff:
Expand Down
12 changes: 12 additions & 0 deletions cmdeploy/src/cmdeploy/remote/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
The 'cmdeploy.remote' sub package contains modules with remotely executing functions.
Its "_sshexec_bootstrap" module is executed remotely through `SSHExec`
and its main() loop there stays connected via a command channel,
ready to receive function invocations ("command") and return results.
"""

from . import rdns, rshell

__all__ = ["rdns", "rshell"]
34 changes: 34 additions & 0 deletions cmdeploy/src/cmdeploy/remote/_sshexec_bootstrap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import builtins
import importlib
import traceback

## Function Execution server


def _run_loop(cmd_channel):
while 1:
cmd = cmd_channel.receive()
if cmd is None:
break

cmd_channel.send(_handle_one_request(cmd))


def _handle_one_request(cmd):
pymod_path, func_name, kwargs = cmd
try:
mod = importlib.import_module(pymod_path)
func = getattr(mod, func_name)
res = func(**kwargs)
return ("finish", res)
except:
data = traceback.format_exc()
return ("error", data)


def main(channel):
# enable simple "print" logging

builtins.print = lambda x="": channel.send(("log", x))

_run_loop(channel)
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,8 @@
"""

import re
import traceback
from subprocess import CalledProcessError, check_output


def shell(command, fail_ok=False):
print(f"$ {command}")
try:
return check_output(command, shell=True).decode().rstrip()
except CalledProcessError:
if not fail_ok:
raise
return ""


def get_systemd_running():
lines = shell("systemctl --type=service --state=running").split("\n")
return [line for line in lines if line.startswith(" ")]
from .rshell import ShellError, shell


def perform_initial_checks(mail_domain):
Expand Down Expand Up @@ -59,7 +44,7 @@ def get_dkim_entry(mail_domain, dkim_selector):
f"openssl rsa -in /etc/dkimkeys/{dkim_selector}.private "
"-pubout 2>/dev/null | awk '/-/{next}{printf(\"%s\",$0)}'"
)
except CalledProcessError:
except ShellError:
return
dkim_value_raw = f"v=DKIM1;k=rsa;p={dkim_pubkey};s=email;t=s"
dkim_value = '" "'.join(re.findall(".{1,255}", dkim_value_raw))
Expand Down Expand Up @@ -99,37 +84,3 @@ def check_zonefile(zonefile):
recommended_diff.append(zf_line)

return required_diff, recommended_diff


## Function Execution server


def _run_loop(cmd_channel):
while 1:
cmd = cmd_channel.receive()
if cmd is None:
break

cmd_channel.send(_handle_one_request(cmd))


def _handle_one_request(cmd):
func_name, kwargs = cmd
try:
res = globals()[func_name](**kwargs)
return ("finish", res)
except:
data = traceback.format_exc()
return ("error", data)


# check if this module is executed remotely
# and setup a simple serialized function-execution loop

if __name__ == "__channelexec__":
channel = channel # noqa (channel object gets injected)

# enable simple "print" logging for anyone changing this module
globals()["print"] = lambda x="": channel.send(("log", x))

_run_loop(channel)
17 changes: 17 additions & 0 deletions cmdeploy/src/cmdeploy/remote/rshell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from subprocess import CalledProcessError as ShellError
from subprocess import check_output


def shell(command, fail_ok=False):
print(f"$ {command}")
try:
return check_output(command, shell=True).decode().rstrip()
except ShellError:
if not fail_ok:
raise
return ""


def get_systemd_running():
lines = shell("systemctl --type=service --state=running").split("\n")
return [line for line in lines if line.startswith(" ")]
41 changes: 38 additions & 3 deletions cmdeploy/src/cmdeploy/sshexec.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,45 @@
import inspect
import os
import sys
from queue import Queue

import execnet

from . import remote


class FuncError(Exception):
pass


def bootstrap_remote(gateway, remote=remote):
"""Return a command channel which can execute remote functions."""
source_init_path = inspect.getfile(remote)
basedir = os.path.dirname(source_init_path)
name = os.path.basename(basedir)

# rsync sourcedir to remote host
remote_pkg_path = f"/root/from-cmdeploy/{name}"
q = Queue()
finish = lambda: q.put(None)
rsync = execnet.RSync(sourcedir=basedir, verbose=False)
rsync.add_target(gateway, remote_pkg_path, finishedcallback=finish, delete=True)
rsync.send()
q.get()

# start sshexec bootstrap and return its command channel
remote_sys_path = os.path.dirname(remote_pkg_path)
channel = gateway.remote_exec(
f"""
import sys
sys.path.insert(0, {remote_sys_path!r})
from remote._sshexec_bootstrap import main
main(channel)
"""
)
return channel


def print_stderr(item="", end="\n"):
print(item, file=sys.stderr, end=end)

Expand All @@ -15,16 +48,18 @@ class SSHExec:
RemoteError = execnet.RemoteError
FuncError = FuncError

def __init__(self, host, remote_funcs, verbose=False, python="python3", timeout=60):
def __init__(self, host, verbose=False, python="python3", timeout=60):
self.gateway = execnet.makegateway(f"ssh=root@{host}//python={python}")
self._remote_cmdloop_channel = self.gateway.remote_exec(remote_funcs)
self._remote_cmdloop_channel = bootstrap_remote(self.gateway, remote)
self.timeout = timeout
self.verbose = verbose

def __call__(self, call, kwargs=None, log_callback=None):
if kwargs is None:
kwargs = {}
self._remote_cmdloop_channel.send((call.__name__, kwargs))
assert call.__module__.startswith("cmdeploy.remote")
modname = call.__module__.replace("cmdeploy.", "")
self._remote_cmdloop_channel.send((modname, call.__name__, kwargs))
while 1:
code, data = self._remote_cmdloop_channel.receive(timeout=self.timeout)
if log_callback is not None and code == "log":
Expand Down
20 changes: 10 additions & 10 deletions cmdeploy/src/cmdeploy/tests/online/test_1_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,29 @@

import pytest

from cmdeploy import remote_funcs
from cmdeploy import remote
from cmdeploy.sshexec import SSHExec


class TestSSHExecutor:
@pytest.fixture(scope="class")
def sshexec(self, sshdomain):
return SSHExec(sshdomain, remote_funcs)
return SSHExec(sshdomain)

def test_ls(self, sshexec):
out = sshexec(call=remote_funcs.shell, kwargs=dict(command="ls"))
out2 = sshexec(call=remote_funcs.shell, kwargs=dict(command="ls"))
out = sshexec(call=remote.rdns.shell, kwargs=dict(command="ls"))
out2 = sshexec(call=remote.rdns.shell, kwargs=dict(command="ls"))
assert out == out2

def test_perform_initial(self, sshexec, maildomain):
res = sshexec(
remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=maildomain)
remote.rdns.perform_initial_checks, kwargs=dict(mail_domain=maildomain)
)
assert res["A"] or res["AAAA"]

def test_logged(self, sshexec, maildomain, capsys):
sshexec.logged(
remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=maildomain)
remote.rdns.perform_initial_checks, kwargs=dict(mail_domain=maildomain)
)
out, err = capsys.readouterr()
assert err.startswith("Collecting")
Expand All @@ -33,21 +33,21 @@ def test_logged(self, sshexec, maildomain, capsys):

sshexec.verbose = True
sshexec.logged(
remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=maildomain)
remote.rdns.perform_initial_checks, kwargs=dict(mail_domain=maildomain)
)
out, err = capsys.readouterr()
lines = err.split("\n")
assert len(lines) > 4
assert remote_funcs.perform_initial_checks.__doc__ in lines[0]
assert remote.rdns.perform_initial_checks.__doc__ in lines[0]

def test_exception(self, sshexec, capsys):
try:
sshexec.logged(
remote_funcs.perform_initial_checks,
remote.rdns.perform_initial_checks,
kwargs=dict(mail_domain=None),
)
except sshexec.FuncError as e:
assert "remote_funcs.py" in str(e)
assert "rdns.py" in str(e)
assert "AssertionError" in str(e)
else:
pytest.fail("didn't raise exception")
Expand Down
14 changes: 7 additions & 7 deletions cmdeploy/src/cmdeploy/tests/test_dns.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from cmdeploy import remote_funcs
from cmdeploy import remote
from cmdeploy.dns import check_full_zone, check_initial_remote_data


Expand All @@ -14,7 +14,7 @@ def query_dns(typ, domain):
except KeyError:
return ""

monkeypatch.setattr(remote_funcs, query_dns.__name__, query_dns)
monkeypatch.setattr(remote.rdns, query_dns.__name__, query_dns)
return qdict


Expand All @@ -32,13 +32,13 @@ def mockdns(mockdns_base):

class TestPerformInitialChecks:
def test_perform_initial_checks_ok1(self, mockdns):
remote_data = remote_funcs.perform_initial_checks("some.domain")
remote_data = remote.rdns.perform_initial_checks("some.domain")
assert len(remote_data) == 7

@pytest.mark.parametrize("drop", ["A", "AAAA"])
def test_perform_initial_checks_with_one_of_A_AAAA(self, mockdns, drop):
del mockdns[drop]
remote_data = remote_funcs.perform_initial_checks("some.domain")
remote_data = remote.rdns.perform_initial_checks("some.domain")
assert len(remote_data) == 7
assert not remote_data[drop]

Expand All @@ -49,7 +49,7 @@ def test_perform_initial_checks_with_one_of_A_AAAA(self, mockdns, drop):

def test_perform_initial_checks_no_mta_sts(self, mockdns):
del mockdns["CNAME"]
remote_data = remote_funcs.perform_initial_checks("some.domain")
remote_data = remote.rdns.perform_initial_checks("some.domain")
assert len(remote_data) == 4
assert not remote_data["MTA_STS"]

Expand Down Expand Up @@ -85,14 +85,14 @@ class TestZonefileChecks:
def test_check_zonefile_all_ok(self, cm_data, mockdns_base):
zonefile = cm_data.get("zftest.zone")
parse_zonefile_into_dict(zonefile, mockdns_base)
required_diff, recommended_diff = remote_funcs.check_zonefile(zonefile)
required_diff, recommended_diff = remote.rdns.check_zonefile(zonefile)
assert not required_diff and not recommended_diff

def test_check_zonefile_recommended_not_set(self, cm_data, mockdns_base):
zonefile = cm_data.get("zftest.zone")
zonefile_mocked = zonefile.split("; Recommended")[0]
parse_zonefile_into_dict(zonefile_mocked, mockdns_base)
required_diff, recommended_diff = remote_funcs.check_zonefile(zonefile)
required_diff, recommended_diff = remote.rdns.check_zonefile(zonefile)
assert not required_diff
assert len(recommended_diff) == 8

Expand Down

0 comments on commit 14a6927

Please sign in to comment.