From cb87de967a053a48abe5051f243acbc7ce67da34 Mon Sep 17 00:00:00 2001 From: Ron Frederick Date: Tue, 10 Sep 2024 06:15:34 -0700 Subject: [PATCH] Use an executor to prevent GSSAPI calls from blocking the event loop Some operations such as GSSAPI calls can sometimes block the event loop if not run in an executor. However, doing that requires packet handlers to be asynchronous. This commit adds support for async packet handlers for key exchange and auth, and changes the GSSAPI handlers to run the step() call in an executor. --- asyncssh/auth.py | 33 ++++---- asyncssh/connection.py | 57 +++++++++---- asyncssh/kex.py | 4 +- asyncssh/kex_dh.py | 44 +++++----- asyncssh/kex_rsa.py | 4 +- asyncssh/misc.py | 9 +++ asyncssh/packet.py | 15 ++-- tests/test_auth.py | 17 ++-- tests/test_kex.py | 180 ++++++++++++++++++++++------------------- tests/util.py | 2 +- 10 files changed, 210 insertions(+), 155 deletions(-) diff --git a/asyncssh/auth.py b/asyncssh/auth.py index 7437b0ac..377b2559 100644 --- a/asyncssh/auth.py +++ b/asyncssh/auth.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2022 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -27,6 +27,7 @@ from .gss import GSSBase, GSSError from .logging import SSHLogger from .misc import ProtocolError, PasswordChangeRequired, get_symbol_names +from .misc import run_in_executor from .packet import Boolean, String, UInt32, SSHPacket, SSHPacketHandler from .public_key import SigningKey from .saslprep import saslprep, SASLPrepError @@ -199,8 +200,8 @@ def _finish(self) -> None: else: self.send_packet(MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE) - def _process_response(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_response(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a GSS response from the server""" mech = packet.get_string() @@ -212,7 +213,7 @@ def _process_response(self, _pkttype: int, _pktid: int, raise ProtocolError('Mechanism mismatch') try: - token = self._gss.step() + token = await run_in_executor(self._gss.step) assert token is not None self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token)) @@ -225,8 +226,8 @@ def _process_response(self, _pkttype: int, _pktid: int, self._conn.try_next_auth() - def _process_token(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_token(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a GSS token from the server""" token: Optional[bytes] = packet.get_string() @@ -235,7 +236,7 @@ def _process_token(self, _pkttype: int, _pktid: int, assert self._gss is not None try: - token = self._gss.step(token) + token = await run_in_executor(self._gss.step, token) if token: self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token)) @@ -261,8 +262,8 @@ def _process_error(self, _pkttype: int, _pktid: int, self.logger.debug1('GSS error from server: %s', msg) self._got_error = True - def _process_error_token(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_error_token(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a GSS error token from the server""" token = packet.get_string() @@ -271,7 +272,7 @@ def _process_error_token(self, _pkttype: int, _pktid: int, assert self._gss is not None try: - self._gss.step(token) + await run_in_executor(self._gss.step, token) except GSSError as exc: if not self._got_error: # pragma: no cover self.logger.debug1('GSS error from server: %s', str(exc)) @@ -649,15 +650,15 @@ async def _finish(self) -> None: else: self.send_failure() - def _process_token(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_token(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a GSS token from the client""" token: Optional[bytes] = packet.get_string() packet.check_end() try: - token = self._gss.step(token) + token = await run_in_executor(self._gss.step, token) if token: self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token)) @@ -682,15 +683,15 @@ def _process_exchange_complete(self, _pkttype: int, _pktid: int, else: self.send_failure() - def _process_error_token(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_error_token(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a GSS error token from the client""" token = packet.get_string() packet.check_end() try: - self._gss.step(token) + await run_in_executor(self._gss.step, token) except GSSError as exc: self.logger.debug1('GSS error from client: %s', str(exc)) diff --git a/asyncssh/connection.py b/asyncssh/connection.py index cc6a3222..2d0424c9 100644 --- a/asyncssh/connection.py +++ b/asyncssh/connection.py @@ -1326,17 +1326,7 @@ def data_received(self, data: bytes, datatype: DataType = None) -> None: self._inpbuf += data - self._reset_keepalive_timer() - - # pylint: disable=broad-except - try: - while self._inpbuf and self._recv_handler(): - pass - except DisconnectError as exc: - self._send_disconnect(exc.code, exc.reason, exc.lang) - self._force_close(exc) - except Exception: - self.internal_error() + self._recv_data() # pylint: enable=arguments-differ def eof_received(self) -> None: @@ -1442,6 +1432,21 @@ def _send_version(self) -> None: self._send(version + b'\r\n') + def _recv_data(self) -> None: + """Parse received data""" + + self._reset_keepalive_timer() + + # pylint: disable=broad-except + try: + while self._inpbuf and self._recv_handler(): + pass + except DisconnectError as exc: + self._send_disconnect(exc.code, exc.reason, exc.lang) + self._force_close(exc) + except Exception: + self.internal_error() + def _recv_version(self) -> bool: """Receive and parse the remote SSH version""" @@ -1595,11 +1600,20 @@ def _recv_packet(self) -> bool: if not skip_reason: try: - processed = handler.process_packet(pkttype, seq, packet) + result = handler.process_packet(pkttype, seq, packet) except PacketDecodeError as exc: raise ProtocolError(str(exc)) from None - if not processed: + if inspect.isawaitable(result): + # Buffer received data until current packet is processed + self._recv_handler = lambda: False + + task = self.create_task(result) + task.add_done_callback(functools.partial( + self._finish_recv_packet, pkttype, seq, is_async=True)) + + return False + elif not result: if self._strict_kex and not self._recv_encryption: exc_reason = 'Strict key exchange violation: ' \ 'unexpected packet type %d received' % pkttype @@ -1611,6 +1625,14 @@ def _recv_packet(self) -> bool: if exc_reason: raise ProtocolError(exc_reason) + self._finish_recv_packet(pkttype, seq) + return True + + def _finish_recv_packet(self, pkttype: int, seq: int, + _task: Optional[asyncio.Task] = None, + is_async: bool = False) -> None: + """Finish processing a packet""" + if pkttype > MSG_USERAUTH_LAST: self._auth_final = True @@ -1625,7 +1647,8 @@ def _recv_packet(self) -> bool: else: self._recv_seq = (seq + 1) & 0xffffffff - return True + if is_async and self._inpbuf: + self._recv_data() def send_packet(self, pkttype: int, *args: bytes, handler: Optional[SSHPacketLogger] = None) -> None: @@ -2218,8 +2241,8 @@ def _process_ext_info(self, _pkttype: int, _pktid: int, self._server_sig_algs = \ set(extensions.get(b'server-sig-algs', b'').split(b',')) - def _process_kexinit(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_kexinit(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a key exchange request""" if self._kex: @@ -2323,7 +2346,7 @@ def _process_kexinit(self, _pkttype: int, _pktid: int, self.logger.debug1('Beginning key exchange') self.logger.debug2(' Key exchange alg: %s', self._kex.algorithm) - self._kex.start() + await self._kex.start() def _process_newkeys(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: diff --git a/asyncssh/kex.py b/asyncssh/kex.py index c6dff7fa..1458540c 100644 --- a/asyncssh/kex.py +++ b/asyncssh/kex.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2022 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -58,7 +58,7 @@ def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType): self._hash_alg = hash_alg - def start(self) -> None: + async def start(self) -> None: """Start key exchange""" raise NotImplementedError diff --git a/asyncssh/kex_dh.py b/asyncssh/kex_dh.py index 71493fad..fa1625f7 100644 --- a/asyncssh/kex_dh.py +++ b/asyncssh/kex_dh.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2022 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -33,7 +33,7 @@ from .gss import GSSError from .kex import Kex, register_kex_alg, register_gss_kex_alg from .misc import HashType, KeyExchangeFailed, ProtocolError -from .misc import get_symbol_names +from .misc import get_symbol_names, run_in_executor from .packet import Boolean, MPInt, String, UInt32, SSHPacket from .public_key import SigningKey, VerifyingKey @@ -274,7 +274,7 @@ def _process_reply(self, _pkttype: int, _pktid: int, host_key = client_conn.validate_server_host_key(host_key_data) self._verify_reply(host_key, host_key_data, sig) - def start(self) -> None: + async def start(self) -> None: """Start DH key exchange""" if self._conn.is_client(): @@ -384,7 +384,7 @@ def _process_group(self, _pkttype: int, _pktid: int, self._gex_data += MPInt(p) + MPInt(g) self._perform_init() - def start(self) -> None: + async def start(self) -> None: """Start DH group exchange""" if self._conn.is_client(): @@ -455,7 +455,7 @@ def _compute_server_shared(self) -> bytes: except ValueError: raise ProtocolError('Invalid ECDH client public key') from None - def start(self) -> None: + async def start(self) -> None: """Start ECDH key exchange""" if self._conn.is_client(): @@ -567,11 +567,11 @@ def _send_continue(self) -> None: self.send_packet(MSG_KEXGSS_CONTINUE, String(self._token)) - def _process_token(self, token: Optional[bytes] = None) -> None: + async def _process_token(self, token: Optional[bytes] = None) -> None: """Process a GSS token""" try: - self._token = self._gss.step(token) + self._token = await run_in_executor(self._gss.step, token) except GSSError as exc: if self._conn.is_server(): self.send_packet(MSG_KEXGSS_ERROR, UInt32(exc.maj_code), @@ -583,8 +583,8 @@ def _process_token(self, token: Optional[bytes] = None) -> None: raise KeyExchangeFailed(str(exc)) from None - def _process_init(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_gss_init(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a GSS init message""" if self._conn.is_client(): @@ -603,7 +603,7 @@ def _process_init(self, _pkttype: int, _pktid: int, else: self._host_key_data = b'' - self._process_token(token) + await self._process_token(token) if self._gss.complete: self._check_secure() @@ -612,8 +612,8 @@ def _process_init(self, _pkttype: int, _pktid: int, else: self._send_continue() - def _process_continue(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_continue(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a GSS continue message""" token = packet.get_string() @@ -622,7 +622,7 @@ def _process_continue(self, _pkttype: int, _pktid: int, if self._conn.is_client() and self._gss.complete: raise ProtocolError('Unexpected kexgss continue msg') - self._process_token(token) + await self._process_token(token) if self._conn.is_server() and self._gss.complete: self._check_secure() @@ -630,8 +630,8 @@ def _process_continue(self, _pkttype: int, _pktid: int, else: self._send_continue() - def _process_complete(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_complete(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a GSS complete message""" if self._conn.is_server(): @@ -647,7 +647,7 @@ def _process_complete(self, _pkttype: int, _pktid: int, if self._gss.complete: raise ProtocolError('Non-empty token after complete') - self._process_token(token) + await self._process_token(token) if self._token: raise ProtocolError('Non-empty token after complete') @@ -682,12 +682,12 @@ def _process_error(self, _pkttype: int, _pktid: int, self._conn.logger.debug1('GSS error: %s', msg.decode('utf-8', errors='ignore')) - def start(self) -> None: + async def start(self) -> None: """Start GSS key exchange""" if self._conn.is_client(): - self._process_token() - super().start() + await self._process_token() + await super().start() class _KexGSS(_KexGSSBase, _KexDH): @@ -696,7 +696,7 @@ class _KexGSS(_KexGSSBase, _KexDH): _handler_names = get_symbol_names(globals(), 'MSG_KEXGSS_') _packet_handlers = { - MSG_KEXGSS_INIT: _KexGSSBase._process_init, + MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init, MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue, MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete, MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey, @@ -713,7 +713,7 @@ class _KexGSSGex(_KexGSSBase, _KexDHGex): _group_type = MSG_KEXGSS_GROUP _packet_handlers = { - MSG_KEXGSS_INIT: _KexGSSBase._process_init, + MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init, MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue, MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete, MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey, @@ -729,7 +729,7 @@ class _KexGSSECDH(_KexGSSBase, _KexECDH): _handler_names = get_symbol_names(globals(), 'MSG_KEXGSS_') _packet_handlers = { - MSG_KEXGSS_INIT: _KexGSSBase._process_init, + MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init, MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue, MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete, MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey, diff --git a/asyncssh/kex_rsa.py b/asyncssh/kex_rsa.py index 8e66a241..6f5ea465 100644 --- a/asyncssh/kex_rsa.py +++ b/asyncssh/kex_rsa.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2022 by Ron Frederick and others. +# Copyright (c) 2018-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -64,7 +64,7 @@ def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType, self._k = 0 self._encrypted_k = b'' - def start(self) -> None: + async def start(self) -> None: """Start RSA key exchange""" if self._conn.is_server(): diff --git a/asyncssh/misc.py b/asyncssh/misc.py index c45a642e..b00cddc1 100644 --- a/asyncssh/misc.py +++ b/asyncssh/misc.py @@ -20,6 +20,7 @@ """Miscellaneous utility classes and functions""" +import asyncio import functools import ipaddress import re @@ -356,6 +357,14 @@ async def maybe_wait_closed(writer: '_SupportsWaitClosed') -> None: pass +async def run_in_executor(func: Callable[..., _T], *args: object) -> _T: + """Run a function in an asyncio executor""" + + loop = asyncio.get_event_loop() + + return await loop.run_in_executor(None, func, *args) + + def set_terminal_size(tty: IO, width: int, height: int, pixwidth: int, pixheight: int) -> None: """Set the terminal size of a TTY""" diff --git a/asyncssh/packet.py b/asyncssh/packet.py index 916348d0..fcd04d77 100644 --- a/asyncssh/packet.py +++ b/asyncssh/packet.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -20,14 +20,15 @@ """SSH packet encoding and decoding functions""" -from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Union +from typing import Any, Awaitable, Callable, Iterable, Mapping, Optional +from typing import Sequence, Union from .logging import SSHLogger -from .misc import plural +from .misc import MaybeAwait, plural _LoggedPacket = Union[bytes, 'SSHPacket'] -_PacketHandler = Callable[[Any, int, int, 'SSHPacket'], None] +_PacketHandler = Callable[[Any, int, int, 'SSHPacket'], MaybeAwait[None]] class PacketDecodeError(ValueError): @@ -230,11 +231,11 @@ def logger(self) -> SSHLogger: raise NotImplementedError def process_packet(self, pkttype: int, pktid: int, - packet: SSHPacket) -> bool: + packet: SSHPacket) -> Union[bool, Awaitable[None]]: """Log and process a received packet""" if pkttype in self._packet_handlers: - self._packet_handlers[pkttype](self, pkttype, pktid, packet) - return True + return self._packet_handlers[pkttype](self, pkttype, + pktid, packet) or True else: return False diff --git a/tests/test_auth.py b/tests/test_auth.py index ecbdd5a5..86125818 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -21,6 +21,7 @@ """Unit tests for authentication""" import asyncio +import inspect import unittest import asyncssh @@ -45,7 +46,7 @@ def connection_lost(self, exc): raise NotImplementedError - def process_packet(self, data): + async def process_packet(self, data): """Process an incoming packet""" raise NotImplementedError @@ -126,7 +127,7 @@ def connection_lost(self, exc=None): self.close() - def process_packet(self, data): + async def process_packet(self, data): """Process an incoming packet""" packet = SSHPacket(data) @@ -154,7 +155,10 @@ def process_packet(self, data): self._auth = None self._auth_waiter = None else: - self._auth.process_packet(pkttype, None, packet) + result = self._auth.process_packet(pkttype, None, packet) + + if inspect.isawaitable(result): + await result async def get_auth_result(self): """Return the result of the authentication""" @@ -285,7 +289,7 @@ def connection_lost(self, exc=None): self.close() - def process_packet(self, data): + async def process_packet(self, data): """Process an incoming packet""" packet = SSHPacket(data) @@ -308,7 +312,10 @@ def process_packet(self, data): else: self._auth = lookup_server_auth(self, 'user', method, packet) else: - self._auth.process_packet(pkttype, None, packet) + result = self._auth.process_packet(pkttype, None, packet) + + if inspect.isawaitable(result): + await result def send_userauth_failure(self, partial_success): """Send a user authentication failure response""" diff --git a/tests/test_kex.py b/tests/test_kex.py index 1269b611..e81344c0 100644 --- a/tests/test_kex.py +++ b/tests/test_kex.py @@ -21,6 +21,7 @@ """Unit tests for key exchange""" import asyncio +import inspect import unittest from hashlib import sha1 @@ -59,10 +60,10 @@ def __init__(self, alg, gss, peer, server=False): self._kex = get_kex(self, alg) - def start(self): + async def start(self): """Start key exchange""" - self._kex.start() + await self._kex.start() def connection_lost(self, exc): """Handle the closing of a connection""" @@ -72,12 +73,15 @@ def connection_lost(self, exc): def enable_gss_kex_auth(self): """Ignore request to enable GSS key exchange authentication""" - def process_packet(self, data): + async def process_packet(self, data): """Process an incoming packet""" packet = SSHPacket(data) pkttype = packet.get_byte() - self._kex.process_packet(pkttype, None, packet) + result = self._kex.process_packet(pkttype, None, packet) + + if inspect.isawaitable(result): + await result def get_hash_prefix(self): """Return the bytes used in calculating unique connection hashes""" @@ -101,68 +105,72 @@ def get_gss_context(self): return self._gss - def simulate_dh_init(self, e): + async def simulate_dh_init(self, e): """Simulate receiving a DH init packet""" - self.process_packet(Byte(MSG_KEXDH_INIT) + MPInt(e)) + await self.process_packet(Byte(MSG_KEXDH_INIT) + MPInt(e)) - def simulate_dh_reply(self, host_key_data, f, sig): + async def simulate_dh_reply(self, host_key_data, f, sig): """Simulate receiving a DH reply packet""" - self.process_packet(b''.join((Byte(MSG_KEXDH_REPLY), - String(host_key_data), - MPInt(f), String(sig)))) + await self.process_packet(b''.join((Byte(MSG_KEXDH_REPLY), + String(host_key_data), + MPInt(f), String(sig)))) - def simulate_dh_gex_group(self, p, g): + async def simulate_dh_gex_group(self, p, g): """Simulate receiving a DH GEX group packet""" - self.process_packet(Byte(MSG_KEX_DH_GEX_GROUP) + MPInt(p) + MPInt(g)) + await self.process_packet(Byte(MSG_KEX_DH_GEX_GROUP) + + MPInt(p) + MPInt(g)) - def simulate_dh_gex_init(self, e): + async def simulate_dh_gex_init(self, e): """Simulate receiving a DH GEX init packet""" - self.process_packet(Byte(MSG_KEX_DH_GEX_INIT) + MPInt(e)) + await self.process_packet(Byte(MSG_KEX_DH_GEX_INIT) + MPInt(e)) - def simulate_dh_gex_reply(self, host_key_data, f, sig): + async def simulate_dh_gex_reply(self, host_key_data, f, sig): """Simulate receiving a DH GEX reply packet""" - self.process_packet(b''.join((Byte(MSG_KEX_DH_GEX_REPLY), - String(host_key_data), + await self.process_packet(b''.join((Byte(MSG_KEX_DH_GEX_REPLY), + String(host_key_data), MPInt(f), String(sig)))) - def simulate_gss_complete(self, f, sig): + async def simulate_gss_complete(self, f, sig): """Simulate receiving a GSS complete packet""" - self.process_packet(b''.join((Byte(MSG_KEXGSS_COMPLETE), MPInt(f), - String(sig), Boolean(False)))) + await self.process_packet(b''.join((Byte(MSG_KEXGSS_COMPLETE), + MPInt(f), String(sig), + Boolean(False)))) - def simulate_ecdh_init(self, client_pub): + async def simulate_ecdh_init(self, client_pub): """Simulate receiving an ECDH init packet""" - self.process_packet(Byte(MSG_KEX_ECDH_INIT) + String(client_pub)) + await self.process_packet(Byte(MSG_KEX_ECDH_INIT) + String(client_pub)) - def simulate_ecdh_reply(self, host_key_data, server_pub, sig): + async def simulate_ecdh_reply(self, host_key_data, server_pub, sig): """Simulate receiving ab ECDH reply packet""" - self.process_packet(b''.join((Byte(MSG_KEX_ECDH_REPLY), - String(host_key_data), - String(server_pub), String(sig)))) + await self.process_packet(b''.join((Byte(MSG_KEX_ECDH_REPLY), + String(host_key_data), + String(server_pub), String(sig)))) - def simulate_rsa_pubkey(self, host_key_data, trans_key_data): + async def simulate_rsa_pubkey(self, host_key_data, trans_key_data): """Simulate receiving an RSA pubkey packet""" - self.process_packet(Byte(MSG_KEXRSA_PUBKEY) + String(host_key_data) + - String(trans_key_data)) + await self.process_packet(Byte(MSG_KEXRSA_PUBKEY) + + String(host_key_data) + + String(trans_key_data)) - def simulate_rsa_secret(self, encrypted_k): + async def simulate_rsa_secret(self, encrypted_k): """Simulate receiving an RSA secret packet""" - self.process_packet(Byte(MSG_KEXRSA_SECRET) + String(encrypted_k)) + await self.process_packet(Byte(MSG_KEXRSA_SECRET) + + String(encrypted_k)) - def simulate_rsa_done(self, sig): + async def simulate_rsa_done(self, sig): """Simulate receiving an RSA done packet""" - self.process_packet(Byte(MSG_KEXRSA_DONE) + String(sig)) + await self.process_packet(Byte(MSG_KEXRSA_DONE) + String(sig)) class _KexClientStub(_KexConnectionStub): @@ -238,8 +246,8 @@ async def _check_kex(self, alg, gss_host=None): client_conn, server_conn = _KexClientStub.make_pair(alg, gss_host) try: - client_conn.start() - server_conn.start() + await client_conn.start() + await server_conn.start() self.assertEqual((await client_conn.get_key()), (await server_conn.get_key())) @@ -315,25 +323,27 @@ async def test_dh_errors(self): with self.subTest('Init sent to client'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.process_packet(Byte(MSG_KEXDH_INIT)) + await client_conn.process_packet(Byte(MSG_KEXDH_INIT)) with self.subTest('Reply sent to server'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.process_packet(Byte(MSG_KEXDH_REPLY)) + await server_conn.process_packet(Byte(MSG_KEXDH_REPLY)) with self.subTest('Invalid e value'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_dh_init(0) + await server_conn.simulate_dh_init(0) with self.subTest('Invalid f value'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.start() - client_conn.simulate_dh_reply(host_key.public_data, 0, b'') + await client_conn.start() + await client_conn.simulate_dh_reply(host_key.public_data, + 0, b'') with self.subTest('Invalid signature'): with self.assertRaises(asyncssh.KeyExchangeFailed): - client_conn.start() - client_conn.simulate_dh_reply(host_key.public_data, 2, b'') + await client_conn.start() + await client_conn.simulate_dh_reply(host_key.public_data, + 2, b'') client_conn.close() server_conn.close() @@ -347,27 +357,27 @@ async def test_dh_gex_errors(self): with self.subTest('Request sent to client'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.process_packet(Byte(MSG_KEX_DH_GEX_REQUEST)) + await client_conn.process_packet(Byte(MSG_KEX_DH_GEX_REQUEST)) with self.subTest('Group sent to server'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_dh_gex_group(1, 2) + await server_conn.simulate_dh_gex_group(1, 2) with self.subTest('Init sent to client'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.simulate_dh_gex_init(1) + await client_conn.simulate_dh_gex_init(1) with self.subTest('Init sent before group'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_dh_gex_init(1) + await server_conn.simulate_dh_gex_init(1) with self.subTest('Reply sent to server'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_dh_gex_reply(b'', 1, b'') + await server_conn.simulate_dh_gex_reply(b'', 1, b'') with self.subTest('Reply sent before group'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.simulate_dh_gex_reply(b'', 1, b'') + await client_conn.simulate_dh_gex_reply(b'', 1, b'') client_conn.close() server_conn.close() @@ -382,19 +392,19 @@ async def test_gss_errors(self): with self.subTest('Init sent to client'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.process_packet(Byte(MSG_KEXGSS_INIT)) + await client_conn.process_packet(Byte(MSG_KEXGSS_INIT)) with self.subTest('Complete sent to server'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.process_packet(Byte(MSG_KEXGSS_COMPLETE)) + await server_conn.process_packet(Byte(MSG_KEXGSS_COMPLETE)) with self.subTest('Exchange failed to complete'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.simulate_gss_complete(1, b'succeed') + await client_conn.simulate_gss_complete(1, b'succeed') with self.subTest('Error sent to server'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.process_packet(Byte(MSG_KEXGSS_ERROR)) + await server_conn.process_packet(Byte(MSG_KEXGSS_ERROR)) client_conn.close() server_conn.close() @@ -447,31 +457,32 @@ async def test_ecdh_errors(self): with self.subTest('Init sent to client'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.simulate_ecdh_init(b'') + await client_conn.simulate_ecdh_init(b'') with self.subTest('Invalid client public key'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_ecdh_init(b'') + await server_conn.simulate_ecdh_init(b'') with self.subTest('Reply sent to server'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_ecdh_reply(b'', b'', b'') + await server_conn.simulate_ecdh_reply(b'', b'', b'') with self.subTest('Invalid server host key'): with self.assertRaises(asyncssh.KeyImportError): - client_conn.simulate_ecdh_reply(b'', b'', b'') + await client_conn.simulate_ecdh_reply(b'', b'', b'') with self.subTest('Invalid server public key'): with self.assertRaises(asyncssh.ProtocolError): host_key = server_conn.get_server_host_key() - client_conn.simulate_ecdh_reply(host_key.public_data, b'', b'') + await client_conn.simulate_ecdh_reply(host_key.public_data, + b'', b'') with self.subTest('Invalid signature'): with self.assertRaises(asyncssh.KeyExchangeFailed): host_key = server_conn.get_server_host_key() server_pub = ECDH(b'nistp256').get_public() - client_conn.simulate_ecdh_reply(host_key.public_data, - server_pub, b'') + await client_conn.simulate_ecdh_reply(host_key.public_data, + server_pub, b'') client_conn.close() server_conn.close() @@ -486,26 +497,27 @@ async def test_curve25519dh_errors(self): with self.subTest('Invalid client public key'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_ecdh_init(b'') + await server_conn.simulate_ecdh_init(b'') with self.subTest('Invalid server public key'): with self.assertRaises(asyncssh.ProtocolError): host_key = server_conn.get_server_host_key() - client_conn.simulate_ecdh_reply(host_key.public_data, b'', b'') + await client_conn.simulate_ecdh_reply(host_key.public_data, + b'', b'') with self.subTest('Invalid peer public key'): with self.assertRaises(asyncssh.ProtocolError): host_key = server_conn.get_server_host_key() server_pub = b'\x01' + 31*b'\x00' - client_conn.simulate_ecdh_reply(host_key.public_data, - server_pub, b'') + await client_conn.simulate_ecdh_reply(host_key.public_data, + server_pub, b'') with self.subTest('Invalid signature'): with self.assertRaises(asyncssh.KeyExchangeFailed): host_key = server_conn.get_server_host_key() server_pub = Curve25519DH().get_public() - client_conn.simulate_ecdh_reply(host_key.public_data, - server_pub, b'') + await client_conn.simulate_ecdh_reply(host_key.public_data, + server_pub, b'') client_conn.close() server_conn.close() @@ -520,19 +532,20 @@ async def test_curve448dh_errors(self): with self.subTest('Invalid client public key'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_ecdh_init(b'') + await server_conn.simulate_ecdh_init(b'') with self.subTest('Invalid server public key'): with self.assertRaises(asyncssh.ProtocolError): host_key = server_conn.get_server_host_key() - client_conn.simulate_ecdh_reply(host_key.public_data, b'', b'') + await client_conn.simulate_ecdh_reply(host_key.public_data, + b'', b'') with self.subTest('Invalid signature'): with self.assertRaises(asyncssh.KeyExchangeFailed): host_key = server_conn.get_server_host_key() server_pub = Curve448DH().get_public() - client_conn.simulate_ecdh_reply(host_key.public_data, - server_pub, b'') + await client_conn.simulate_ecdh_reply(host_key.public_data, + server_pub, b'') client_conn.close() server_conn.close() @@ -547,24 +560,25 @@ async def test_sntrup761dh_errors(self): with self.subTest('Invalid client SNTRUP761 public key'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_ecdh_init(b'') + await server_conn.simulate_ecdh_init(b'') with self.subTest('Invalid client Curve25519 public key'): with self.assertRaises(asyncssh.ProtocolError): pub = sntrup761_pubkey_bytes * b'\0' - server_conn.simulate_ecdh_init(pub) + await server_conn.simulate_ecdh_init(pub) with self.subTest('Invalid server SNTRUP761 public key'): with self.assertRaises(asyncssh.ProtocolError): host_key = server_conn.get_server_host_key() - client_conn.simulate_ecdh_reply(host_key.public_data, b'', b'') + await client_conn.simulate_ecdh_reply(host_key.public_data, + b'', b'') with self.subTest('Invalid server Curve25519 public key'): with self.assertRaises(asyncssh.ProtocolError): host_key = server_conn.get_server_host_key() ciphertext = sntrup761_ciphertext_bytes * b'\0' - client_conn.simulate_ecdh_reply(host_key.public_data, - ciphertext, b'') + await client_conn.simulate_ecdh_reply(host_key.public_data, + ciphertext, b'') client_conn.close() server_conn.close() @@ -578,32 +592,32 @@ async def test_rsa_errors(self): with self.subTest('Pubkey sent to server'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_rsa_pubkey(b'', b'') + await server_conn.simulate_rsa_pubkey(b'', b'') with self.subTest('Secret sent to client'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.simulate_rsa_secret(b'') + await client_conn.simulate_rsa_secret(b'') with self.subTest('Done sent to server'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_rsa_done(b'') + await server_conn.simulate_rsa_done(b'') with self.subTest('Invalid transient public key'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.simulate_rsa_pubkey(b'', b'') + await client_conn.simulate_rsa_pubkey(b'', b'') with self.subTest('Invalid encrypted secret'): with self.assertRaises(asyncssh.KeyExchangeFailed): - server_conn.start() - server_conn.simulate_rsa_secret(b'') + await server_conn.start() + await server_conn.simulate_rsa_secret(b'') with self.subTest('Invalid signature'): with self.assertRaises(asyncssh.KeyExchangeFailed): host_key = server_conn.get_server_host_key() trans_key = get_test_key('ssh-rsa', 2048) - client_conn.simulate_rsa_pubkey(host_key.public_data, - trans_key.public_data) - client_conn.simulate_rsa_done(b'') + await client_conn.simulate_rsa_pubkey(host_key.public_data, + trans_key.public_data) + await client_conn.simulate_rsa_done(b'') client_conn.close() server_conn.close() diff --git a/tests/util.py b/tests/util.py index b76b0c93..2c016e83 100644 --- a/tests/util.py +++ b/tests/util.py @@ -323,7 +323,7 @@ async def _process_packets(self): self.connection_lost(data) break - self.process_packet(data) + await self.process_packet(data) def connection_lost(self, exc): """Handle the closing of a connection"""