Skip to content
This repository has been archived by the owner on Dec 10, 2018. It is now read-only.

Raise TTransportException in operations on closed socket #278

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,29 @@ def test_client_socket_close():
server_socket.close()


def test_client_socket_closed():
server_socket = TServerSocket(host="localhost", port=12345)
server_socket.listen()

client_socket = TSocket(host="localhost", port=12345)
client_socket.open()

conn = server_socket.accept()
client_socket.close()
assert not client_socket.is_open()

with pytest.raises(TTransportException) as e:
client_socket.read(1024)
assert "Could not read from closed socket" in e.value.message

with pytest.raises(TTransportException) as e:
client_socket.write(b"world")
assert "Could not write into closed socket" in e.value.message

conn.close()
server_socket.close()


def test_server_socket_close():
server_socket = TServerSocket(host="localhost", port=12345)
server_socket.listen()
Expand All @@ -124,6 +147,17 @@ def test_server_socket_close():
server_socket.close()


def test_server_socket_closed():
server_socket = TServerSocket(host="localhost", port=12345)
server_socket.listen()

server_socket.close()

with pytest.raises(TTransportException) as e:
server_socket.accept()
assert "Could not accept on closed socket" in e.value.message


def test_client_socket_set_timeout():
server_socket = TServerSocket(host="localhost", port=12345,
client_timeout=100)
Expand Down
14 changes: 14 additions & 0 deletions tests/test_sslsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,17 @@ def test_persist_ssl_context():
ssl_context=client_ssl_context)

_test_socket(server_socket, client_socket)


def test_server_socket_closed():
server_ssl_context = create_thriftpy_context(server_side=True)
server_ssl_context.load_cert_chain(certfile="ssl/server.pem")
server_socket = TSSLServerSocket(host="localhost", port=12345,
ssl_context=server_ssl_context)
server_socket.listen()

server_socket.close()

with pytest.raises(TTransportException) as e:
server_socket.accept()
assert "Could not accept on closed socket" in e.value.message
16 changes: 16 additions & 0 deletions thriftpy/transport/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ def open(self):
message="Could not connect to %s" % str(addr))

def read(self, sz):
if self.sock is None:
raise TTransportException(
type=TTransportException.NOT_OPEN,
message="Could not read from closed socket")

try:
buff = self.sock.recv(sz)
except socket.error as e:
Expand All @@ -126,6 +131,11 @@ def read(self, sz):
return buff

def write(self, buff):
if self.sock is None:
raise TTransportException(
type=TTransportException.NOT_OPEN,
message="Could not write into closed socket")

self.sock.sendall(buff)

def flush(self):
Expand Down Expand Up @@ -209,6 +219,11 @@ def listen(self):
self.sock.listen(self.backlog)

def accept(self):
if self.sock is None:
raise TTransportException(
type=TTransportException.NOT_OPEN,
message="Could not accept on closed socket")

client, _ = self.sock.accept()
if self.client_timeout:
client.settimeout(self.client_timeout)
Expand All @@ -221,5 +236,6 @@ def close(self):
try:
self.sock.shutdown(socket.SHUT_RDWR)
self.sock.close()
self.sock = None
except (socket.error, OSError):
pass
6 changes: 6 additions & 0 deletions thriftpy/transport/sslsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ssl
import struct

from . import TTransportException
from ._ssl import (
create_thriftpy_context,
RESTRICTED_SERVER_CIPHERS,
Expand Down Expand Up @@ -109,6 +110,11 @@ def __init__(self, host, port, socket_family=socket.AF_INET,
self.ssl_context.load_cert_chain(certfile=certfile)

def accept(self):
if self.sock is None:
raise TTransportException(
type=TTransportException.NOT_OPEN,
message="Could not accept on closed socket")

sock, _ = self.sock.accept()
try:
ssl_sock = self.ssl_context.wrap_socket(sock, server_side=True)
Expand Down