From fd18ba0804c7cd4a80d00e0fd95549bc1d7f9644 Mon Sep 17 00:00:00 2001 From: Peter Saveliev Date: Sat, 8 Jun 2024 15:07:19 +0200 Subject: [PATCH] plan9: basic client implementation --- pyroute2/netlink/__init__.py | 10 ++-- pyroute2/netlink/nlsocket.py | 44 ++++++++------- pyroute2/plan9/__init__.py | 44 ++++++++++++--- pyroute2/plan9/client.py | 103 ++++++++++++++++++++++++++++++++++ pyroute2/plan9/filesystem.py | 6 ++ pyroute2/plan9/plan9socket.py | 15 ++++- pyroute2/plan9/server.py | 59 +++++++++++++------ 7 files changed, 227 insertions(+), 54 deletions(-) create mode 100644 pyroute2/plan9/client.py diff --git a/pyroute2/netlink/__init__.py b/pyroute2/netlink/__init__.py index 78fd155f3..c07c55494 100644 --- a/pyroute2/netlink/__init__.py +++ b/pyroute2/netlink/__init__.py @@ -877,8 +877,6 @@ def __init__( ): global cache_jit dict.__init__(self) - for i in self.fields: - self[i[0]] = 0 # FIXME: only for number values self._buf = None self.data = data or bytearray() self.offset = offset @@ -1006,12 +1004,10 @@ def __ops(self, rvalue, op0, op1): res['attrs'] = [] for attr in lvalue['attrs']: if isinstance(attr[1], nlmsg_base): - print("recursion") diff = getattr(attr[1], op0)(rvalue.get_attr(attr[0])) if diff is not None: res['attrs'].append([attr[0], diff]) else: - print("fail", type(attr[1])) if op0 == '__sub__': # operator -, complement if rvalue.get_attr(attr[0]) != attr[1]: @@ -1024,7 +1020,6 @@ def __ops(self, rvalue, op0, op1): del res['attrs'] if not res: return None - print(res) return res def __bool__(self): @@ -1759,7 +1754,10 @@ def ft_encode(self, offset): else: zs = 0 for name, fmt in self.fields: - value = self[name] + default = self.defaults.get(name) + value = self[name] if name in self else default + if value is None: + continue if isinstance(fmt, str): offset = self.encode_field(fmt, self.data, offset, value, zs) diff --git a/pyroute2/netlink/nlsocket.py b/pyroute2/netlink/nlsocket.py index 4f455db13..31feace8f 100644 --- a/pyroute2/netlink/nlsocket.py +++ b/pyroute2/netlink/nlsocket.py @@ -221,7 +221,6 @@ def parse(self, data, seq=None, callback=None, skip_alien_seq=False): not support any defragmentation on that level ''' offset = 0 - # there must be at least one header in the buffer, # 'IHHII' == 16 bytes while offset <= len(data) - 16: @@ -349,6 +348,7 @@ def __init__( fileno=None, sndbuf=1048576, rcvbuf=1048576, + rcvsize=16384, all_ns=False, async_qsize=None, nlm_generator=None, @@ -357,6 +357,7 @@ def __init__( strict_check=False, groups=0, nlm_echo=False, + use_socket=None, ): # 8<----------------------------------------- self.spec = NetlinkSocketSpec( @@ -367,6 +368,7 @@ def __init__( 'fileno': fileno, 'sndbuf': sndbuf, 'rcvbuf': rcvbuf, + 'rcvsize': rcvsize, 'all_ns': all_ns, 'async_qsize': async_qsize, 'target': target, @@ -379,8 +381,11 @@ def __init__( 'strict_check': strict_check, 'groups': groups, 'nlm_echo': nlm_echo, + 'use_socket': use_socket is not None, + 'tag_field': 'sequence_number', } ) + self.use_socket = use_socket self.status = self.spec.status self.capabilities = { 'create_bridge': config.kernel > [3, 2, 0], @@ -421,6 +426,8 @@ def nlm_request_batch(*argv, **kwarg): def restart_base_socket(self, sock=None): """Re-init a netlink socket.""" + if self.status['use_socket']: + return self.use_socket sock = self.socket if sock is None else sock if sock is not None: sock.close() @@ -447,11 +454,9 @@ def __getattr__(self, attr): 'settimeout', 'gettimeout', 'shutdown', - 'recv', 'recvfrom', 'recvfrom_into', 'fileno', - 'send', 'sendto', 'connect', 'listen', @@ -571,18 +576,13 @@ def put( self.put_header(msg, msg_type, msg_flags, msg_seq, msg_pid) msg.reset() msg.encode() - return self.sendto(msg.data, addr) + return self.send(msg.data) - def getdata(self, block=False): - if block: - flags = 0 - else: - flags = MSG_DONTWAIT - data = bytearray(16384) - log.debug("consume, block=%s", block) - bufsize = self.socket.recv_into(data, 0, flags) - log.debug("consumed bufsize=%s", bufsize) - return data + def recv(self, buffersize, flags=0): + return self.socket.recv(buffersize, flags) + + def send(self, data, flags=0): + return self.socket.send(data, flags) def get( self, @@ -601,14 +601,14 @@ def get( # step 1. receive as much as we can from the socket while True: try: - data = self.getdata(block=False) + data = self.recv(self.status['rcvsize'], MSG_DONTWAIT) if len(data) == 0 or data[0] == 0: return self.buffer.append(data) except BlockingIOError: break if len(self.buffer) == 0: - self.buffer.append(self.getdata(block=True)) + self.buffer.append(self.recv(self.status['rcvsize'])) # step 2. fetch one data block from the buffer data = self.buffer.pop(0) # step 3. parse the data block @@ -616,7 +616,12 @@ def get( if len(messages) == 0: break for msg in messages: - if msg_seq > 0 and msg['header']['sequence_number'] != msg_seq: + if started and msg['header']['type'] == NLMSG_DONE: + return + if ( + msg_seq > 0 + and msg['header'][self.status['tag_field']] != msg_seq + ): continue msg['header']['target'] = self.status['target'] msg['header']['stats'] = Stats(0, 0, 0) @@ -625,12 +630,9 @@ def get( log.debug("message %s", msg) yield msg - if started and msg['header']['type'] == NLMSG_DONE: - break - if started and ( (msg_seq == 0) - or (not msg['header']['flags'] & NLM_F_MULTI) + or (not msg['header'].get('flags', 0) & NLM_F_MULTI) or (callable(terminate) and terminate(msg)) ): enough = True diff --git a/pyroute2/plan9/__init__.py b/pyroute2/plan9/__init__.py index ca17fa3aa..cc4c5b94f 100644 --- a/pyroute2/plan9/__init__.py +++ b/pyroute2/plan9/__init__.py @@ -1,3 +1,5 @@ +import builtins +import json import struct from pyroute2.netlink import nlmsg @@ -34,6 +36,10 @@ Topenfd = 98 Ropenfd = 99 +# 9P2000.pr2 extensions +Tcall = 80 +Rcall = 81 + def array(kind, header='H'): class CustomArray: @@ -74,11 +80,14 @@ def __init__(self, qtype, vers, path): @staticmethod def decode_from(data, offset): - return dict( - zip( - ('type', 'vers', 'path'), - struct.unpack_from('=BIQ', data, offset), - ) + return ( + dict( + zip( + ('type', 'vers', 'path'), + struct.unpack_from('=BIQ', data, offset), + ) + ), + offset + struct.calcsize('=BIQ'), ) @staticmethod @@ -322,7 +331,7 @@ class msg_tread(msg_base): class msg_rread(msg_base): - defaults = {'header': {'type': Rread}} + defaults = {'header': {'type': Rread}, 'data': b''} fields = (('data', CData),) @@ -336,9 +345,19 @@ class msg_rwrite(msg_base): fields = (('count', 'I'),) +class msg_tcall(msg_base): + defaults = {'header': {'type': Tcall}} + fields = (('fid', 'I'), ('text', String), ('data', CData)) + + +class msg_rcall(msg_base): + defaults = {'header': {'type': Rcall}} + fields = (('err', 'H'), ('text', String), ('data', CData)) + + class Marshal9P(Marshal): default_message_class = msg_rerror - error_type = Rerror + msg_map = { Tversion: msg_tversion, Rversion: msg_rversion, @@ -359,18 +378,27 @@ class Marshal9P(Marshal): Rstat: msg_rstat, Twrite: msg_twrite, Rwrite: msg_rwrite, + Tcall: msg_tcall, + Rcall: msg_rcall, } def parse(self, data, seq=None, callback=None, skip_alien_seq=False): offset = 0 while offset <= len(data) - 5: - (length, key, tag) = struct.unpack_from('IBH', data, offset) + (length, key, tag) = struct.unpack_from('=IBH', data, offset) if skip_alien_seq and tag != seq: continue if not 0 < length <= len(data): break parser = self.get_parser(key, 0, tag) msg = parser(data, offset, length) + if key == Rerror: + spec = json.loads(msg['ename']) + if spec['class'] in dir(builtins): + cls = getattr(builtins, spec['class']) + else: + cls = Exception + raise cls(spec['argv']) offset += length if msg is None: continue diff --git a/pyroute2/plan9/client.py b/pyroute2/plan9/client.py new file mode 100644 index 000000000..96b3f78fe --- /dev/null +++ b/pyroute2/plan9/client.py @@ -0,0 +1,103 @@ +import json +import os +import pwd +import socket + +from pyroute2.common import AddrPool +from pyroute2.plan9 import ( + msg_tattach, + msg_tcall, + msg_tread, + msg_tversion, + msg_twalk, + msg_twrite, +) +from pyroute2.plan9.plan9socket import Plan9Socket + + +class Plan9Client: + def __init__(self, address=None, use_socket=None): + if use_socket is None: + use_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket = Plan9Socket(use_socket=use_socket) + if address is not None: + self.socket.connect(address) + self.wnames = {'': 0} + self.cwd = 0 + self.fid_pool = AddrPool(minaddr=0x00000001, maxaddr=0x0000FFFF) + + def init(self): + self.version() + self.auth() + self.attach() + + def request(self, msg, tag=0): + if tag == 0: + tag = self.socket.addr_pool.alloc() + try: + msg['header']['tag'] = tag + msg.reset() + msg.encode() + self.socket.send(msg.data) + return self.socket.get(msg_seq=tag)[0] + finally: + self.socket.addr_pool.free(tag, ban=0xFF) + + def version(self): + m = msg_tversion() + m['header']['tag'] = 0xFFFF + m['msize'] = 8192 + m['version'] = '9P2000' + return self.request(m, tag=0xFFFF) + + def auth(self): + pass + + def attach(self): + m = msg_tattach() + m['fid'] = 0 + m['afid'] = 0xFFFFFFFF + m['uname'] = pwd.getpwuid(os.getuid()).pw_name + m['aname'] = '' + return self.request(m) + + def walk(self, path, newfid=None, fid=None): + m = msg_twalk() + m['fid'] = self.cwd if fid is None else fid + m['newfid'] = newfid if newfid is not None else self.fid_pool.alloc() + m['wname'] = path.split(os.path.sep) + self.wnames[path] = m['newfid'] + return self.request(m) + + def fid(self, path): + if path not in self.wnames: + newfid = self.fid_pool.alloc() + self.walk(path, newfid) + self.wnames[path] = newfid + return self.wnames[path] + + def read(self, fid): + m = msg_tread() + m['fid'] = fid + m['offset'] = 0 + m['count'] = 8192 + return self.request(m) + + def write(self, fid, data): + m = msg_twrite() + m['fid'] = fid + m['offset'] = 0 + m['data'] = data + return self.request(m) + + def call(self, fid, fname, argv=None, kwarg=None, data=b''): + spec = { + 'call': fname, + 'argv': argv if argv is not None else [], + 'kwarg': kwarg if kwarg is not None else {}, + } + m = msg_tcall() + m['fid'] = fid + m['text'] = json.dumps(spec) + m['data'] = data + return self.request(m) diff --git a/pyroute2/plan9/filesystem.py b/pyroute2/plan9/filesystem.py index 70722612e..0d84c769d 100644 --- a/pyroute2/plan9/filesystem.py +++ b/pyroute2/plan9/filesystem.py @@ -13,6 +13,7 @@ class Inode: qid = None stat = None data = None + callbacks = None def __init__( self, @@ -29,6 +30,7 @@ def __init__( self.data = io.BytesIO(data.encode('utf-8')) self.parents = parents if parents is not None else set() self.children = children if children is not None else set() + self.callbacks = {} self.stat = Stat() self.qid = Qid(qtype, 0, path) self.stat['uid'] = ( @@ -59,6 +61,10 @@ def get_child(self, name): return child raise KeyError('file not found') + def add_callback(self, call, f): + self.callbacks[call] = f + return self + def add_parent(self, inode): return self.parents.add(inode) diff --git a/pyroute2/plan9/plan9socket.py b/pyroute2/plan9/plan9socket.py index 3c66df3eb..c7fb3db18 100644 --- a/pyroute2/plan9/plan9socket.py +++ b/pyroute2/plan9/plan9socket.py @@ -6,10 +6,18 @@ class Plan9Socket(NetlinkSocket): def __init__(self, *argv, **kwarg): - super().__init__() + kw = {} + co_varnames = super().__init__.__code__.co_varnames + for key, value in kwarg.items(): + if key in co_varnames: + kw[key] = value + super().__init__(**kw) self.marshal = Marshal9P() + self.spec['tag_field'] = 'tag' def restart_base_socket(self, sock=None): + if self.status['use_socket']: + return self.use_socket sock = self.socket if sock is None else sock if sock is not None: sock.close() @@ -21,7 +29,12 @@ def bind(self, *argv, **kwarg): return self.socket.bind(*argv, **kwarg) def accept(self): + if self.status['use_socket']: + return (self, None) (connection, address) = self.socket.accept() new_socket = self.clone() new_socket.socket = connection return (new_socket, address) + + def connect(self, address): + self.socket.connect(address) diff --git a/pyroute2/plan9/server.py b/pyroute2/plan9/server.py index 839748701..ca4643dd2 100644 --- a/pyroute2/plan9/server.py +++ b/pyroute2/plan9/server.py @@ -1,7 +1,10 @@ +import json + from pyroute2.plan9 import ( Stat, Tattach, Tauth, + Tcall, Tclunk, Topen, Tread, @@ -10,6 +13,7 @@ Twalk, Twrite, msg_rattach, + msg_rcall, msg_rclunk, msg_rerror, msg_ropen, @@ -25,18 +29,18 @@ data = str(dir()) +def get_exception_args(exc): + args = [] + if hasattr(exc, 'errno'): + args.append(exc.errno) + args.append(exc.strerror) + return args + + def route(rtable, request, state): def decorator(f): - def wrapped(*argv, **kwarg): - try: - return f(*argv, **kwarg) - except Exception as e: - m = msg_rerror() - m['ename'] = repr(e) - return m - - rtable[request] = wrapped - return wrapped + rtable[request] = f + return f return decorator @@ -104,10 +108,21 @@ def t_open(self, req): m['iounit'] = 8192 return m + @route(rtable, request=Tcall, state=(Twalk, Topen, Tstat)) + def t_call(self, req): + m = msg_rcall() + inode = self.session.get_fid(req['fid']) + m['err'] = 255 + if Tcall in inode.callbacks: + m = inode.callbacks[Tcall](self.session, inode, req, m) + return m + @route(rtable, request=Twrite, state=(Topen,)) def t_write(self, req): m = msg_rwrite() inode = self.session.get_fid(req['fid']) + if Twrite in inode.callbacks: + return inode.callbacks[Twrite](self.session, inode, req, m) if inode.qid['type'] & 0x80: raise TypeError('can not call write() on dir') inode.data.seek(req['offset']) @@ -118,6 +133,8 @@ def t_write(self, req): def t_read(self, req): m = msg_rread() inode = self.session.get_fid(req['fid']) + if Tread in inode.callbacks: + return inode.callbacks[Tread](self.session, inode, req, m) if inode.qid['type'] & 0x80: data = bytearray() offset = 0 @@ -140,23 +157,29 @@ def serve(self): if len(request) != 1: return t_message = request[0] - r_message = self.rtable[t_message['header']['type']]( - self, t_message - ) try: + r_message = self.rtable[t_message['header']['type']]( + self, t_message + ) + r_message['header']['tag'] = t_message['header']['tag'] r_message.encode() except Exception as e: r_message = msg_rerror() - r_message['ename'] = repr(e) + spec = { + 'class': e.__class__.__name__, + 'argv': get_exception_args(e), + } + r_message['ename'] = json.dumps(spec) r_message.encode() self.socket.send(r_message.data) class Plan9Server: - def __init__(self, address): - self.socket = Plan9Socket() - self.socket.bind(address) - self.socket.listen(1) + def __init__(self, address=None, use_socket=None): + self.socket = Plan9Socket(use_socket=use_socket) + if use_socket is None: + self.socket.bind(address) + self.socket.listen(1) self.filesystem = Filesystem() def accept(self):