diff --git a/circle.yml b/circle.yml index 696a18f..fce64d7 100644 --- a/circle.yml +++ b/circle.yml @@ -2,7 +2,7 @@ dependencies: override: - pip install -r requirements-tests.txt - pip install tox tox-pyenv - - pyenv local 2.7.11 3.5.2 3.6.0 + - pyenv local 2.7.11 3.5.2 3.6.2 test: override: - tox diff --git a/dataplicity/m2mmanager.py b/dataplicity/m2mmanager.py index 708e866..a9c2218 100644 --- a/dataplicity/m2mmanager.py +++ b/dataplicity/m2mmanager.py @@ -153,7 +153,7 @@ def on_instruction(self, sender, data): elif action == 'open-portredirect': device_port = data['device_port'] m2m_port = data['m2m_port'] - self.client.port_forward.redirect_port(device_port, m2m_port) + self.client.port_forward.redirect_port(m2m_port, device_port) elif action == 'reboot-device': self.reboot() elif action == 'read-file': diff --git a/dataplicity/portforward.py b/dataplicity/portforward.py index ede21e8..71ae528 100644 --- a/dataplicity/portforward.py +++ b/dataplicity/portforward.py @@ -24,12 +24,12 @@ class Connection(threading.Thread): # Max to read at-a-time BUFFER_SIZE = 1024 * 32 - def __init__(self, service, connection_id, channel): + def __init__(self, close_event, channel, host_port): """Initialize the connection, set up callbacks.""" super(Connection, self).__init__() - self._service = weakref.ref(service) - self.connection_id = connection_id + self._close_event = close_event self.channel = channel + self.host_port = host_port self._lock = threading.RLock() self.socket = None @@ -39,18 +39,16 @@ def __init__(self, service, connection_id, channel): self.on_channel_close, self.on_channel_control) - @property - def service(self): - """Get the parent service object (weak reference, may return None).""" - return self._service() - @property def close_event(self): """Get a threading.Event object.""" - return self.service.close_event + return self._close_event def run(self): - """Main loop, connects to local server, reads data, and writes it to an m2m channel.""" + """ + Main loop, connects to local server, reads data, and writes it to an + m2m channel. + """ bytes_written = 0 try: # Connect to remote host @@ -61,9 +59,12 @@ def run(self): # Read all the data we can and write it to the channel # TODO: Rework this loop to not use the timeout while not self.close_event.is_set(): - # Block for a period of time until the socket becomes readable, or there is an error + # Block for a period of time until the socket becomes readable, + # or there is an error try: - readable, _, exceptional = select.select([self.socket], [], [self.socket], 5.0) + readable, _, exceptional = select.select( + [self.socket], [], [self.socket], 5.0 + ) except Exception as e: # For paranoia only. log.warning('error %s in select', e) @@ -88,9 +89,8 @@ def run(self): break finally: log.debug("left recv loop (read %s bytes)", bytes_written) - # Tell service we're done with this connection - self.service.on_connection_complete(self.connection_id) - # These close methods are a null operation if the objects are already closed + # These close methods are a null operation if the objects are + # already closed self.channel.close() self._shutdown_read() @@ -138,9 +138,9 @@ def _connect(self): # Set the timeout for initial connect, as default is too high _socket.settimeout(5.0) - log.debug('connecting to %s', self.service.url) + log.debug('connecting to %s:%d', *self.host_port) try: - _socket.connect(self.service.host_port) + _socket.connect(self.host_port) except socket.timeout: log.error('timed out connecting to server') return False @@ -151,7 +151,7 @@ def _connect(self): log.exception('error connecting') return False else: - log.debug("connected to %s", self.service.url) + log.debug("connected to %s:%d", *self.host_port) self.socket = _socket self._flush_buffer() return True @@ -194,8 +194,8 @@ def __init__(self, manager, name, port, host="127.0.0.1"): self.name = name self.port = port self.host = host + self.m2m_port = None self._connect_index = 0 - self._connections = {} self._lock = threading.RLock() def __repr__(self): @@ -222,30 +222,18 @@ def host_port(self): """A tuple of (host, port) as a convenience for socket.connect.""" return (self.host, self.port) - @property - def url(self): - """URL of server we're connecting to.""" - return "http://{0}:{1}".format(self.host, self.port) - def connect(self, port_no): """Add a new connection.""" + self.m2m_port = port_no log.debug('new %r connection on port %s', self, port_no) with self._lock: - connection_id = self._connect_index = self._connect_index + 1 channel = self.m2m.m2m_client.get_channel(port_no) - connection = Connection(self, connection_id, channel) - self._connections[connection_id] = connection + connection = Connection( + self.close_event, + channel, + self.host_port, + ) connection.start() - return connection_id - - def remove_connection(self, connection_id): - with self._lock: - self._connections.pop(connection_id, None) - - def on_connection_complete(self, connection_id): - """Called by a connection when it is finished.""" - with self._lock: - self.remove_connection(connection_id) class PortForwardManager(object): @@ -317,9 +305,18 @@ def open(self, m2m_port, service=None, port=None): return service.connect(m2m_port) - def redirect_service(self, m2m_port, device_port): - service = Service( - manager=self, name='port-{}'.format(device_port), - port=device_port, host='127.0.0.1' - ) - service.connect(m2m_port) + def redirect_port(self, m2m_port, device_port): + # we need to store the reference to the Service somewhere so that + # when the Connection starts in thread it wouldn't loose the value + # of service variable. However, we have to remember that there may + # be numerous connections to the same local port. + # for instance, one could be ssh'ed into a machine twice, so we + # shan't confuse these two connections. + # therefore, an easy way is to store these in a dict, so that the + # lookup would be quick + # + Connection( + close_event=self.close_event, + channel=self.m2m.m2m_client.get_channel(m2m_port), + host_port=('127.0.0.1', device_port) + ).start() diff --git a/tests/dataplicity/test_portforward.py b/tests/dataplicity/test_portforward.py index cfd46eb..dea9041 100644 --- a/tests/dataplicity/test_portforward.py +++ b/tests/dataplicity/test_portforward.py @@ -1,6 +1,9 @@ import pytest -from mock import patch +from dataplicity.m2mmanager import M2MManager from dataplicity.portforward import PortForwardManager +from mock import call, patch + +_weakref_table = {} class FakeClient(object): @@ -10,7 +13,9 @@ class FakeClient(object): @pytest.fixture def manager(): client = FakeClient() - return PortForwardManager.init(client=client) + _weakref_table['client'] = client + yield PortForwardManager.init(client=client) + del _weakref_table['client'] @pytest.fixture @@ -31,10 +36,28 @@ def test_open_service_which_doesnt_exist_results_in_noop(manager, route): def test_redirect_service(manager, route): - with patch('dataplicity.portforward.Service.connect') as connect: - manager.redirect_service(9999, 22) - - assert connect.called + manager.client.m2m = M2MManager.init('ws://localhost/') + with patch('dataplicity.portforward.Connection.start') as connection_start: + manager.redirect_port(9999, 22) + + assert connection_start.called + + +def test_calling_redirect_service_from_m2mmanager_works(): + with patch( + 'dataplicity.portforward.PortForwardManager.redirect_port' + ) as redirect_port: + client = FakeClient() + client.port_forward = PortForwardManager(client) + m2m_manager = M2MManager(client=client, url='ws://localhost/') + m2m_manager.on_instruction( + 'sender', { + 'action': 'open-portredirect', + 'device_port': 22, + 'm2m_port': 1234 + } + ) + assert redirect_port.call_args == call(1234, 22) def test_can_open_service_by_name(manager): diff --git a/tox.ini b/tox.ini index e2ed27a..13208f9 100644 --- a/tox.ini +++ b/tox.ini @@ -2,10 +2,11 @@ envlist = py{27,35,36} [testenv] +usedevelop = true passenv = CIRCLE_ARTIFACTS setenv = PYTHONPATH={toxinidir}/tests deps = -rrequirements-tests.txt commands = py.test --cov-config {toxinidir}/.coveragerc \ - --cov={envsitepackagesdir}/dataplicity \ + --cov={toxinidir}/dataplicity \ --cov-report html:{env:CIRCLE_ARTIFACTS:reports}/{envname} \ {posargs:tests/}