diff --git a/CHANGELOG.md b/CHANGELOG.md index bbee7112..0fcf5089 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ ## untagged +- refactor ssh-based execution to allow organizing remote functions in + modules. + ([#396](https://github.com/deltachat/chatmail/pull/396)) + + ## 1.4.1 2024-07-31 diff --git a/chatmaild/src/chatmaild/__init__.py b/chatmaild/src/chatmaild/__init__.py index e69de29b..8b137891 100644 --- a/chatmaild/src/chatmaild/__init__.py +++ b/chatmaild/src/chatmaild/__init__.py @@ -0,0 +1 @@ + diff --git a/cmdeploy/src/cmdeploy/cmdeploy.py b/cmdeploy/src/cmdeploy/cmdeploy.py index 67d038eb..cd992a2b 100644 --- a/cmdeploy/src/cmdeploy/cmdeploy.py +++ b/cmdeploy/src/cmdeploy/cmdeploy.py @@ -18,7 +18,7 @@ from packaging import version from termcolor import colored -from . import dns, remote_funcs +from . import dns, remote from .sshexec import SSHExec # @@ -132,7 +132,7 @@ def status_cmd(args, out): else: out.red("no privacy settings") - for line in sshexec(remote_funcs.get_systemd_running): + for line in sshexec(remote.rshell.get_systemd_running): print(line) @@ -313,7 +313,7 @@ def main(args=None): def get_sshexec(): print(f"[ssh] login to {args.config.mail_domain}") - return SSHExec(args.config.mail_domain, remote_funcs, verbose=args.verbose) + return SSHExec(args.config.mail_domain, verbose=args.verbose) args.get_sshexec = get_sshexec diff --git a/cmdeploy/src/cmdeploy/dns.py b/cmdeploy/src/cmdeploy/dns.py index 7d2e9c43..e4b95f90 100644 --- a/cmdeploy/src/cmdeploy/dns.py +++ b/cmdeploy/src/cmdeploy/dns.py @@ -3,12 +3,12 @@ from jinja2 import Template -from . import remote_funcs +from . import remote def get_initial_remote_data(sshexec, mail_domain): return sshexec.logged( - call=remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=mail_domain) + call=remote.rdns.perform_initial_checks, kwargs=dict(mail_domain=mail_domain) ) @@ -42,7 +42,7 @@ def check_full_zone(sshexec, remote_data, out, zonefile) -> int: and return (exitcode, remote_data) tuple.""" required_diff, recommended_diff = sshexec.logged( - remote_funcs.check_zonefile, kwargs=dict(zonefile=zonefile) + remote.rdns.check_zonefile, kwargs=dict(zonefile=zonefile) ) if required_diff: diff --git a/cmdeploy/src/cmdeploy/remote/__init__.py b/cmdeploy/src/cmdeploy/remote/__init__.py new file mode 100644 index 00000000..c300b44d --- /dev/null +++ b/cmdeploy/src/cmdeploy/remote/__init__.py @@ -0,0 +1,12 @@ +""" + +The 'cmdeploy.remote' sub package contains modules with remotely executing functions. + +Its "_sshexec_bootstrap" module is executed remotely through `SSHExec` +and its main() loop there stays connected via a command channel, +ready to receive function invocations ("command") and return results. +""" + +from . import rdns, rshell + +__all__ = ["rdns", "rshell"] diff --git a/cmdeploy/src/cmdeploy/remote/_sshexec_bootstrap.py b/cmdeploy/src/cmdeploy/remote/_sshexec_bootstrap.py new file mode 100644 index 00000000..f5b4c083 --- /dev/null +++ b/cmdeploy/src/cmdeploy/remote/_sshexec_bootstrap.py @@ -0,0 +1,34 @@ +import builtins +import importlib +import traceback + +## Function Execution server + + +def _run_loop(cmd_channel): + while 1: + cmd = cmd_channel.receive() + if cmd is None: + break + + cmd_channel.send(_handle_one_request(cmd)) + + +def _handle_one_request(cmd): + pymod_path, func_name, kwargs = cmd + try: + mod = importlib.import_module(pymod_path) + func = getattr(mod, func_name) + res = func(**kwargs) + return ("finish", res) + except: + data = traceback.format_exc() + return ("error", data) + + +def main(channel): + # enable simple "print" logging + + builtins.print = lambda x="": channel.send(("log", x)) + + _run_loop(channel) diff --git a/cmdeploy/src/cmdeploy/remote_funcs.py b/cmdeploy/src/cmdeploy/remote/rdns.py similarity index 67% rename from cmdeploy/src/cmdeploy/remote_funcs.py rename to cmdeploy/src/cmdeploy/remote/rdns.py index 12ff3ff4..3402f53f 100644 --- a/cmdeploy/src/cmdeploy/remote_funcs.py +++ b/cmdeploy/src/cmdeploy/remote/rdns.py @@ -11,23 +11,8 @@ """ import re -import traceback -from subprocess import CalledProcessError, check_output - -def shell(command, fail_ok=False): - print(f"$ {command}") - try: - return check_output(command, shell=True).decode().rstrip() - except CalledProcessError: - if not fail_ok: - raise - return "" - - -def get_systemd_running(): - lines = shell("systemctl --type=service --state=running").split("\n") - return [line for line in lines if line.startswith(" ")] +from .rshell import ShellError, shell def perform_initial_checks(mail_domain): @@ -59,7 +44,7 @@ def get_dkim_entry(mail_domain, dkim_selector): f"openssl rsa -in /etc/dkimkeys/{dkim_selector}.private " "-pubout 2>/dev/null | awk '/-/{next}{printf(\"%s\",$0)}'" ) - except CalledProcessError: + except ShellError: return dkim_value_raw = f"v=DKIM1;k=rsa;p={dkim_pubkey};s=email;t=s" dkim_value = '" "'.join(re.findall(".{1,255}", dkim_value_raw)) @@ -99,37 +84,3 @@ def check_zonefile(zonefile): recommended_diff.append(zf_line) return required_diff, recommended_diff - - -## Function Execution server - - -def _run_loop(cmd_channel): - while 1: - cmd = cmd_channel.receive() - if cmd is None: - break - - cmd_channel.send(_handle_one_request(cmd)) - - -def _handle_one_request(cmd): - func_name, kwargs = cmd - try: - res = globals()[func_name](**kwargs) - return ("finish", res) - except: - data = traceback.format_exc() - return ("error", data) - - -# check if this module is executed remotely -# and setup a simple serialized function-execution loop - -if __name__ == "__channelexec__": - channel = channel # noqa (channel object gets injected) - - # enable simple "print" logging for anyone changing this module - globals()["print"] = lambda x="": channel.send(("log", x)) - - _run_loop(channel) diff --git a/cmdeploy/src/cmdeploy/remote/rshell.py b/cmdeploy/src/cmdeploy/remote/rshell.py new file mode 100644 index 00000000..994223bc --- /dev/null +++ b/cmdeploy/src/cmdeploy/remote/rshell.py @@ -0,0 +1,17 @@ +from subprocess import CalledProcessError as ShellError +from subprocess import check_output + + +def shell(command, fail_ok=False): + print(f"$ {command}") + try: + return check_output(command, shell=True).decode().rstrip() + except ShellError: + if not fail_ok: + raise + return "" + + +def get_systemd_running(): + lines = shell("systemctl --type=service --state=running").split("\n") + return [line for line in lines if line.startswith(" ")] diff --git a/cmdeploy/src/cmdeploy/sshexec.py b/cmdeploy/src/cmdeploy/sshexec.py index 474c04db..8a87e781 100644 --- a/cmdeploy/src/cmdeploy/sshexec.py +++ b/cmdeploy/src/cmdeploy/sshexec.py @@ -1,12 +1,45 @@ +import inspect +import os import sys +from queue import Queue import execnet +from . import remote + class FuncError(Exception): pass +def bootstrap_remote(gateway, remote=remote): + """Return a command channel which can execute remote functions.""" + source_init_path = inspect.getfile(remote) + basedir = os.path.dirname(source_init_path) + name = os.path.basename(basedir) + + # rsync sourcedir to remote host + remote_pkg_path = f"/root/from-cmdeploy/{name}" + q = Queue() + finish = lambda: q.put(None) + rsync = execnet.RSync(sourcedir=basedir, verbose=False) + rsync.add_target(gateway, remote_pkg_path, finishedcallback=finish, delete=True) + rsync.send() + q.get() + + # start sshexec bootstrap and return its command channel + remote_sys_path = os.path.dirname(remote_pkg_path) + channel = gateway.remote_exec( + f""" + import sys + sys.path.insert(0, {remote_sys_path!r}) + from remote._sshexec_bootstrap import main + main(channel) + """ + ) + return channel + + def print_stderr(item="", end="\n"): print(item, file=sys.stderr, end=end) @@ -15,16 +48,18 @@ class SSHExec: RemoteError = execnet.RemoteError FuncError = FuncError - def __init__(self, host, remote_funcs, verbose=False, python="python3", timeout=60): + def __init__(self, host, verbose=False, python="python3", timeout=60): self.gateway = execnet.makegateway(f"ssh=root@{host}//python={python}") - self._remote_cmdloop_channel = self.gateway.remote_exec(remote_funcs) + self._remote_cmdloop_channel = bootstrap_remote(self.gateway, remote) self.timeout = timeout self.verbose = verbose def __call__(self, call, kwargs=None, log_callback=None): if kwargs is None: kwargs = {} - self._remote_cmdloop_channel.send((call.__name__, kwargs)) + assert call.__module__.startswith("cmdeploy.remote") + modname = call.__module__.replace("cmdeploy.", "") + self._remote_cmdloop_channel.send((modname, call.__name__, kwargs)) while 1: code, data = self._remote_cmdloop_channel.receive(timeout=self.timeout) if log_callback is not None and code == "log": diff --git a/cmdeploy/src/cmdeploy/tests/online/test_1_basic.py b/cmdeploy/src/cmdeploy/tests/online/test_1_basic.py index 0d3ecb8f..f5f2c023 100644 --- a/cmdeploy/src/cmdeploy/tests/online/test_1_basic.py +++ b/cmdeploy/src/cmdeploy/tests/online/test_1_basic.py @@ -2,29 +2,29 @@ import pytest -from cmdeploy import remote_funcs +from cmdeploy import remote from cmdeploy.sshexec import SSHExec class TestSSHExecutor: @pytest.fixture(scope="class") def sshexec(self, sshdomain): - return SSHExec(sshdomain, remote_funcs) + return SSHExec(sshdomain) def test_ls(self, sshexec): - out = sshexec(call=remote_funcs.shell, kwargs=dict(command="ls")) - out2 = sshexec(call=remote_funcs.shell, kwargs=dict(command="ls")) + out = sshexec(call=remote.rdns.shell, kwargs=dict(command="ls")) + out2 = sshexec(call=remote.rdns.shell, kwargs=dict(command="ls")) assert out == out2 def test_perform_initial(self, sshexec, maildomain): res = sshexec( - remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=maildomain) + remote.rdns.perform_initial_checks, kwargs=dict(mail_domain=maildomain) ) assert res["A"] or res["AAAA"] def test_logged(self, sshexec, maildomain, capsys): sshexec.logged( - remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=maildomain) + remote.rdns.perform_initial_checks, kwargs=dict(mail_domain=maildomain) ) out, err = capsys.readouterr() assert err.startswith("Collecting") @@ -33,21 +33,21 @@ def test_logged(self, sshexec, maildomain, capsys): sshexec.verbose = True sshexec.logged( - remote_funcs.perform_initial_checks, kwargs=dict(mail_domain=maildomain) + remote.rdns.perform_initial_checks, kwargs=dict(mail_domain=maildomain) ) out, err = capsys.readouterr() lines = err.split("\n") assert len(lines) > 4 - assert remote_funcs.perform_initial_checks.__doc__ in lines[0] + assert remote.rdns.perform_initial_checks.__doc__ in lines[0] def test_exception(self, sshexec, capsys): try: sshexec.logged( - remote_funcs.perform_initial_checks, + remote.rdns.perform_initial_checks, kwargs=dict(mail_domain=None), ) except sshexec.FuncError as e: - assert "remote_funcs.py" in str(e) + assert "rdns.py" in str(e) assert "AssertionError" in str(e) else: pytest.fail("didn't raise exception") diff --git a/cmdeploy/src/cmdeploy/tests/test_dns.py b/cmdeploy/src/cmdeploy/tests/test_dns.py index eba0e904..71b1baa5 100644 --- a/cmdeploy/src/cmdeploy/tests/test_dns.py +++ b/cmdeploy/src/cmdeploy/tests/test_dns.py @@ -1,6 +1,6 @@ import pytest -from cmdeploy import remote_funcs +from cmdeploy import remote from cmdeploy.dns import check_full_zone, check_initial_remote_data @@ -14,7 +14,7 @@ def query_dns(typ, domain): except KeyError: return "" - monkeypatch.setattr(remote_funcs, query_dns.__name__, query_dns) + monkeypatch.setattr(remote.rdns, query_dns.__name__, query_dns) return qdict @@ -32,13 +32,13 @@ def mockdns(mockdns_base): class TestPerformInitialChecks: def test_perform_initial_checks_ok1(self, mockdns): - remote_data = remote_funcs.perform_initial_checks("some.domain") + remote_data = remote.rdns.perform_initial_checks("some.domain") assert len(remote_data) == 7 @pytest.mark.parametrize("drop", ["A", "AAAA"]) def test_perform_initial_checks_with_one_of_A_AAAA(self, mockdns, drop): del mockdns[drop] - remote_data = remote_funcs.perform_initial_checks("some.domain") + remote_data = remote.rdns.perform_initial_checks("some.domain") assert len(remote_data) == 7 assert not remote_data[drop] @@ -49,7 +49,7 @@ def test_perform_initial_checks_with_one_of_A_AAAA(self, mockdns, drop): def test_perform_initial_checks_no_mta_sts(self, mockdns): del mockdns["CNAME"] - remote_data = remote_funcs.perform_initial_checks("some.domain") + remote_data = remote.rdns.perform_initial_checks("some.domain") assert len(remote_data) == 4 assert not remote_data["MTA_STS"] @@ -85,14 +85,14 @@ class TestZonefileChecks: def test_check_zonefile_all_ok(self, cm_data, mockdns_base): zonefile = cm_data.get("zftest.zone") parse_zonefile_into_dict(zonefile, mockdns_base) - required_diff, recommended_diff = remote_funcs.check_zonefile(zonefile) + required_diff, recommended_diff = remote.rdns.check_zonefile(zonefile) assert not required_diff and not recommended_diff def test_check_zonefile_recommended_not_set(self, cm_data, mockdns_base): zonefile = cm_data.get("zftest.zone") zonefile_mocked = zonefile.split("; Recommended")[0] parse_zonefile_into_dict(zonefile_mocked, mockdns_base) - required_diff, recommended_diff = remote_funcs.check_zonefile(zonefile) + required_diff, recommended_diff = remote.rdns.check_zonefile(zonefile) assert not required_diff assert len(recommended_diff) == 8