Skip to content
This repository has been archived by the owner on Dec 10, 2018. It is now read-only.

Asyncio Support #299

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions examples/asyncio/ping_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-

import thriftpy
from thriftpy.contrib.async import make_client

import asyncio


pp_thrift = thriftpy.load("pingpong.thrift", module_name="pp_thrift")

@asyncio.coroutine
def main():
c = yield from make_client(pp_thrift.PingService)

pong = yield from c.ping()
print(pong)

c.close()

if __name__ == '__main__':
loop = asyncio.get_event_loop()
loop.run_until_complete(main())
loop.close()
29 changes: 29 additions & 0 deletions examples/asyncio/ping_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-

import thriftpy
import asyncio
from thriftpy.contrib.async import make_server


pp_thrift = thriftpy.load("pingpong.thrift", module_name="pp_thrift")


class Dispatcher(object):
@asyncio.coroutine
def ping(self):
print("ping pong!")
return 'pong'

if __name__ == '__main__':
loop = asyncio.get_event_loop()
server = loop.run_until_complete(
make_server(pp_thrift.PingService, Dispatcher()))

try:
loop.run_forever()
except KeyboardInterrupt:
pass

server.close()
loop.run_until_complete(server.wait_closed())
loop.close()
7 changes: 7 additions & 0 deletions examples/asyncio/pingpong.thrift
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# ping service demo
service PingService {
/*
* Sexy c style comment
*/
string ping(),
}
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import sys

collect_ignore = ["setup.py"]
if sys.version_info < (3, 5):
collect_ignore.append("test_asyncio.py")
135 changes: 135 additions & 0 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# -*- coding: utf-8 -*-

from __future__ import absolute_import

from os import path

import thriftpy
from thriftpy.contrib.async import make_client, make_server
from thriftpy.rpc import make_client as make_sync_client
from thriftpy.transport import TFramedTransportFactory

import pytest
import asyncio
import threading

addressbook = thriftpy.load(path.join(path.dirname(__file__),
"addressbook.thrift"))


class Dispatcher(object):
def __init__(self):
self.registry = {}

@asyncio.coroutine
def add(self, person):
"""
bool add(1: Person person);
"""
if person.name in self.registry:
return False
self.registry[person.name] = person
return True

@asyncio.coroutine
def get(self, name):
"""
Person get(1: string name) throws (1: PersonNotExistsError not_exists);
"""
if name not in self.registry:
raise addressbook.PersonNotExistsError(
'Person "{0}" does not exist!'.format(name))
return self.registry[name]

@asyncio.coroutine
def remove(self, name):
"""
bool remove(1: string name) throws (1: PersonNotExistsError not_exists)
"""
# delay action for later
yield from asyncio.sleep(.1)
if name not in self.registry:
raise addressbook.PersonNotExistsError(
'Person "{0}" does not exist!'.format(name))
del self.registry[name]
return True


class Server(threading.Thread):
def __init__(self):
self.loop = loop = asyncio.new_event_loop()
self.server = loop.run_until_complete(make_server(
service=addressbook.AddressBookService,
handler=Dispatcher(),
loop=loop
))
super().__init__()

def run(self):
loop = self.loop
server = self.server
asyncio.set_event_loop(loop)

loop.run_forever()

server.close()
loop.run_until_complete(server.wait_closed())

loop.close()

def stop(self):
self.loop.call_soon_threadsafe(self.loop.stop)
self.join()


@pytest.fixture
def server():
server = Server()
server.start()
yield server
server.stop()


class TestAsyncClient:
@pytest.fixture
async def client(self, request, server):
client = await make_client(addressbook.AddressBookService)
request.addfinalizer(client.close)
return client

@pytest.mark.asyncio
async def test_result(self, client):
dennis = addressbook.Person(name='Dennis Ritchie')
success = await client.add(dennis)
assert success
success = await client.add(dennis)
assert not success
person = await client.get(dennis.name)
assert person.name == dennis.name

@pytest.mark.asyncio
async def test_exception(self, client):
with pytest.raises(addressbook.PersonNotExistsError):
await client.get('Brian Kernighan')


class TestSyncClient:
@pytest.fixture
async def client(self, request, server):
client = make_sync_client(addressbook.AddressBookService,
trans_factory=TFramedTransportFactory())
request.addfinalizer(client.close)
return client

def test_result(self, client):
dennis = addressbook.Person(name='Dennis Ritchie')
success = client.add(dennis)
assert success
success = client.add(dennis)
assert not success
person = client.get(dennis.name)
assert person.name == dennis.name

def test_exception(self, client):
with pytest.raises(addressbook.PersonNotExistsError):
client.get('Brian Kernighan')
190 changes: 190 additions & 0 deletions thriftpy/contrib/async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
from thriftpy.thrift import TType, TMessageType, TApplicationException, TProcessor, TClient, args2kwargs
from thriftpy.transport import TMemoryBuffer
from thriftpy.protocol import TBinaryProtocolFactory

import asyncio
import struct

import logging
LOG = logging.getLogger(__name__)


class TAsyncTransport(TMemoryBuffer):
def __init__(self, trans):
super().__init__()
self._trans = trans
self._io_lock = asyncio.Lock()

def flush(self):
buf = self.getvalue()
self._trans.write(struct.pack("!i", len(buf)) + buf)
self.setvalue(b'')

@asyncio.coroutine
def read_frame(self):
# do not yield the event loop on a single reader
# between reading the frame_size and the buffer
with (yield from self._io_lock):
buff = yield from self._trans.readexactly(4)
sz, = struct.unpack('!i', buff)

frame = yield from self._trans.readexactly(sz)
self.setvalue(frame)

@asyncio.coroutine
def drain(self):
# drain cannot be called concurrently
with (yield from self._io_lock):
yield from self._trans.drain()


class TAsyncReader(TAsyncTransport):
def close(self):
self._trans.feed_eof()
super().close()


class TAsyncWriter(TAsyncTransport):
def close(self):
self._trans.write_eof()
super().close()


class TAsyncProcessor(TProcessor):
def __init__(self, service, handler):
self._service = service
self._handler = handler

@asyncio.coroutine
def process(self, iprot, oprot):
# the standard thrift protocol packs a single request per frame
# note that chunked requests are not supported, and would require
# additional sequence information
yield from iprot.trans.read_frame()
api, seqid, result, call = self.process_in(iprot)

if isinstance(result, TApplicationException):
self.send_exception(oprot, api, result, seqid)
yield from oprot.trans.drain()

try:
result.success = yield from call()
except Exception as e:
# raise if api don't have throws
self.handle_exception(e, result)

if not result.oneway:
self.send_result(oprot, api, result, seqid)
yield from oprot.trans.drain()


class TAsyncServer(object):
def __init__(self, processor,
iprot_factory=None,
oprot_factory=None,
timeout=None):
self.processor = processor
self.iprot_factory = iprot_factory or TBinaryProtocolFactory()
self.oprot_factory = oprot_factory or self.iprot_factory
self.timeout = timeout

@asyncio.coroutine
def __call__(self, reader, writer):
itrans = TAsyncReader(reader)
iproto = self.iprot_factory.get_protocol(itrans)

otrans = TAsyncWriter(writer)
oproto = self.oprot_factory.get_protocol(otrans)

while not reader.at_eof():
try:
fut = self.processor.process(iproto, oproto)
yield from asyncio.wait_for(fut, self.timeout)
except ConnectionError:
LOG.debug('client has closed the connection')
writer.close()
except asyncio.TimeoutError:
LOG.debug('timeout when processing the client request')
writer.close()
except asyncio.IncompleteReadError:
LOG.debug('client has closed the connection')
writer.close()
except Exception:
# app exception
LOG.exception('unhandled app exception')
writer.close()
writer.close()


class TAsyncClient(TClient):
def __init__(self, *args, timeout=None, **kwargs):
super().__init__(*args, **kwargs)
self.timeout = timeout

@asyncio.coroutine
def _req(self, _api, *args, **kwargs):
fut = self._req_impl(_api, *args, **kwargs)
result = yield from asyncio.wait_for(fut, self.timeout)
return result

@asyncio.coroutine
def _req_impl(self, _api, *args, **kwargs):
args_cls = getattr(self._service, _api + "_args")
_kw = args2kwargs(args_cls.thrift_spec, *args)

kwargs.update(_kw)
result_cls = getattr(self._service, _api + "_result")

self._send(_api, **kwargs)
yield from self._oprot.trans.drain()

# wait result only if non-oneway
if not getattr(result_cls, "oneway"):
yield from self._iprot.trans.read_frame()
return self._recv(_api)

def close(self):
self._iprot.trans.close()
self._oprot.trans.close()


@asyncio.coroutine
def make_server(
service,
handler,
host = 'localhost',
port = 9090,
proto_factory = TBinaryProtocolFactory(),
loop = None,
timeout = None
):
"""
create a thrift server running on an asyncio event-loop.
"""
processor = TAsyncProcessor(service, handler)
if loop is None:
loop = asyncio.get_event_loop()
server = yield from asyncio.start_server(
TAsyncServer(processor, proto_factory, timeout=timeout), host, port, loop=loop)
return server


@asyncio.coroutine
def make_client(service,
host = 'localhost',
port = 9090,
proto_factory = TBinaryProtocolFactory(),
timeout = None,
loop = None):
if loop is None:
loop = asyncio.get_event_loop()

reader, writer = yield from asyncio.open_connection(
host, port, loop=loop)

itrans = TAsyncReader(reader)
iproto = proto_factory.get_protocol(itrans)

otrans = TAsyncWriter(writer)
oproto = proto_factory.get_protocol(otrans)
return TAsyncClient(service, iproto, oproto)
Loading