Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent SEGFAULTs on consecutive exec_command() invocations #658

Open
wants to merge 12 commits into
base: devel
Choose a base branch
from
1 change: 1 addition & 0 deletions docs/changelog-fragments/658.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Repetitive calls to ``exec_channel()`` no longer crash and return reliable output -- by :user:`Jakuje`.
4 changes: 4 additions & 0 deletions src/pylibsshext/channel.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@ cdef class Channel:
cdef _session
cdef libssh.ssh_channel _libssh_channel
cdef libssh.ssh_session _libssh_session

cdef class ChannelCallback:
cdef callbacks.ssh_channel_callbacks_struct callback
cdef _userdata
28 changes: 21 additions & 7 deletions src/pylibsshext/channel.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ cdef int _process_outputs(libssh.ssh_session session,
result.stdout += data_b
return len

cdef class ChannelCallback:
def __cinit__(self):
memset(&self.callback, 0, sizeof(self.callback))
callbacks.ssh_callbacks_init(&self.callback)
self.callback.channel_data_function = <callbacks.ssh_channel_data_callback>&_process_outputs

def set_user_data(self, userdata):
self._userdata = userdata
self.callback.userdata = <void *>self._userdata

cdef class Channel:
def __cinit__(self, session):
self._session = session
Expand Down Expand Up @@ -159,19 +169,23 @@ cdef class Channel:
libssh.ssh_channel_free(channel)
raise LibsshChannelException("Failed to open_session: [{0}]".format(rc))

result = CompletedProcess(args=command, returncode=-1, stdout=b'', stderr=b'')

cb = ChannelCallback()
cb.set_user_data(result)
callbacks.ssh_set_channel_callbacks(channel, &cb.callback)
# keep the callback around in the session object to avoid use after free
self._session.push_callback(cb)

rc = libssh.ssh_channel_request_exec(channel, command.encode("utf-8"))
if rc != libssh.SSH_OK:
libssh.ssh_channel_close(channel)
libssh.ssh_channel_free(channel)
raise LibsshChannelException("Failed to execute command [{0}]: [{1}]".format(command, rc))
result = CompletedProcess(args=command, returncode=-1, stdout=b'', stderr=b'')

cdef callbacks.ssh_channel_callbacks_struct cb
memset(&cb, 0, sizeof(cb))
cb.channel_data_function = <callbacks.ssh_channel_data_callback>&_process_outputs
cb.userdata = <void *>result
callbacks.ssh_callbacks_init(&cb)
callbacks.ssh_set_channel_callbacks(channel, &cb)
# wait before remote writes all data before closing the channel
while not libssh.ssh_channel_is_eof(channel):
libssh.ssh_channel_poll(channel, 0)

libssh.ssh_channel_send_eof(channel)
result.returncode = libssh.ssh_channel_get_exit_status(channel)
Expand Down
1 change: 1 addition & 0 deletions src/pylibsshext/session.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,6 @@ cdef class Session:
cdef _hash_py
cdef _fingerprint_py
cdef _keytype_py
cdef _channel_callbacks

cdef libssh.ssh_session get_libssh_session(Session session)
7 changes: 7 additions & 0 deletions src/pylibsshext/session.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ cdef class Session(object):
self._hash_py = None
self._fingerprint_py = None
self._keytype_py = None
# Due to delayed freeing of channels, some older libssh versions might expect
# the callbacks to be around even after we free the underlying channels so
# we should free them only when we terminate the session.
self._channel_callbacks = []

def __cinit__(self, host=None, **kwargs):
self._libssh_session = libssh.ssh_new()
Expand All @@ -123,6 +127,9 @@ cdef class Session(object):
libssh.ssh_free(self._libssh_session)
self._libssh_session = NULL

def push_callback(self, callback):
self._channel_callbacks.append(callback)

@property
def port(self):
cdef unsigned int port_i
Expand Down
9 changes: 6 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Pytest plugins and fixtures configuration."""

import logging
import shutil
import socket
import subprocess
Expand Down Expand Up @@ -115,6 +116,8 @@ def ssh_client_session(ssh_session_connect):
# noqa: DAR101
"""
ssh_session = Session()
# TODO Adjust when #597 will be merged
ssh_session.set_log_level(logging.CRITICAL)
ssh_session_connect(ssh_session)
try: # noqa: WPS501
yield ssh_session
Expand Down Expand Up @@ -173,21 +176,21 @@ def sshd_addr(free_port_num, ssh_authorized_keys_path, sshd_hostkey_path, sshd_p
'/usr/sbin/sshd',
'-D',
'-f', '/dev/null',
'-E', '/dev/stderr',
opt, 'LogLevel=DEBUG3',
opt, 'HostKey={key!s}'.format(key=sshd_hostkey_path),
opt, 'PidFile={pid!s}'.format(pid=sshd_path / 'sshd.pid'),

# NOTE: 'UsePAM no' is not supported on Fedora.
# Ref: https://bugzilla.redhat.com/show_bug.cgi?id=770756#c1
opt, 'UsePAM=yes',
# But it is ok for testing as it simplifies everything
opt, 'UsePAM=no',
opt, 'PasswordAuthentication=no',
opt, 'ChallengeResponseAuthentication=no',
opt, 'GSSAPIAuthentication=no',

opt, 'StrictModes=no',
opt, 'PermitEmptyPasswords=yes',
opt, 'PermitRootLogin=yes',
opt, 'Protocol=2',
opt, 'HostbasedAuthentication=no',
opt, 'IgnoreUserKnownHosts=yes',
opt, 'Port={port:d}'.format(port=free_port_num), # port before addr
Expand Down
36 changes: 27 additions & 9 deletions tests/unit/channel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,37 @@ def ssh_channel(ssh_client_session):
chan.close()


@pytest.mark.xfail(
reason='This test causes SEGFAULT, flakily. '
'Ref: https://github.com/ansible/pylibssh/issues/57',
strict=False,
)
@pytest.mark.forked
Jakuje marked this conversation as resolved.
Show resolved Hide resolved
def exec_second_command(ssh_channel):
"""Check the standard output of ``exec_command()`` as a string."""
u_cmd = ssh_channel.exec_command('echo -n Hello Again')
assert u_cmd.returncode == 0
assert u_cmd.stderr.decode() == '' # noqa: WPS302
assert u_cmd.stdout.decode() == u'Hello Again' # noqa: WPS302


def test_exec_command(ssh_channel):
"""Test getting the output of a remotely executed command."""
u_cmd_out = ssh_channel.exec_command('echo -n Hello World').stdout.decode()
assert u_cmd_out == u'Hello World' # noqa: WPS302
u_cmd = ssh_channel.exec_command('echo -n Hello World')
assert u_cmd.returncode == 0
assert u_cmd.stderr.decode() == ''
assert u_cmd.stdout.decode() == u'Hello World' # noqa: WPS302
# Test that repeated calls to exec_command do not segfault.
u_cmd_out = ssh_channel.exec_command('echo -n Hello Again').stdout.decode()
assert u_cmd_out == u'Hello Again' # noqa: WPS302

# NOTE: Call `exec_command()` once again from another function to
# NOTE: force it to happen in another place of the call stack,
# NOTE: making sure that the context is different from one in this
# NOTE: this test function. The resulting call stack will end up
# NOTE: being more random.
exec_second_command(ssh_channel)


def test_exec_command_stderr(ssh_channel):
"""Test getting the stderr of a remotely executed command."""
u_cmd = ssh_channel.exec_command('echo -n Hello World 1>&2')
assert u_cmd.returncode == 0
assert u_cmd.stderr.decode() == u'Hello World' # noqa: WPS302
assert u_cmd.stdout.decode() == ''


def test_double_close(ssh_channel):
Expand Down
Loading