Skip to content

Commit

Permalink
plan9: basic client implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
svinota committed Jun 8, 2024
1 parent ac69658 commit fd18ba0
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 54 deletions.
10 changes: 4 additions & 6 deletions pyroute2/netlink/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,8 +877,6 @@ def __init__(
):
global cache_jit
dict.__init__(self)
for i in self.fields:
self[i[0]] = 0 # FIXME: only for number values
self._buf = None
self.data = data or bytearray()
self.offset = offset
Expand Down Expand Up @@ -1006,12 +1004,10 @@ def __ops(self, rvalue, op0, op1):
res['attrs'] = []
for attr in lvalue['attrs']:
if isinstance(attr[1], nlmsg_base):
print("recursion")
diff = getattr(attr[1], op0)(rvalue.get_attr(attr[0]))
if diff is not None:
res['attrs'].append([attr[0], diff])
else:
print("fail", type(attr[1]))
if op0 == '__sub__':
# operator -, complement
if rvalue.get_attr(attr[0]) != attr[1]:
Expand All @@ -1024,7 +1020,6 @@ def __ops(self, rvalue, op0, op1):
del res['attrs']
if not res:
return None
print(res)
return res

def __bool__(self):
Expand Down Expand Up @@ -1759,7 +1754,10 @@ def ft_encode(self, offset):
else:
zs = 0
for name, fmt in self.fields:
value = self[name]
default = self.defaults.get(name)
value = self[name] if name in self else default
if value is None:
continue

if isinstance(fmt, str):
offset = self.encode_field(fmt, self.data, offset, value, zs)
Expand Down
44 changes: 23 additions & 21 deletions pyroute2/netlink/nlsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ def parse(self, data, seq=None, callback=None, skip_alien_seq=False):
not support any defragmentation on that level
'''
offset = 0

# there must be at least one header in the buffer,
# 'IHHII' == 16 bytes
while offset <= len(data) - 16:
Expand Down Expand Up @@ -349,6 +348,7 @@ def __init__(
fileno=None,
sndbuf=1048576,
rcvbuf=1048576,
rcvsize=16384,
all_ns=False,
async_qsize=None,
nlm_generator=None,
Expand All @@ -357,6 +357,7 @@ def __init__(
strict_check=False,
groups=0,
nlm_echo=False,
use_socket=None,
):
# 8<-----------------------------------------
self.spec = NetlinkSocketSpec(
Expand All @@ -367,6 +368,7 @@ def __init__(
'fileno': fileno,
'sndbuf': sndbuf,
'rcvbuf': rcvbuf,
'rcvsize': rcvsize,
'all_ns': all_ns,
'async_qsize': async_qsize,
'target': target,
Expand All @@ -379,8 +381,11 @@ def __init__(
'strict_check': strict_check,
'groups': groups,
'nlm_echo': nlm_echo,
'use_socket': use_socket is not None,
'tag_field': 'sequence_number',
}
)
self.use_socket = use_socket
self.status = self.spec.status
self.capabilities = {
'create_bridge': config.kernel > [3, 2, 0],
Expand Down Expand Up @@ -421,6 +426,8 @@ def nlm_request_batch(*argv, **kwarg):

def restart_base_socket(self, sock=None):
"""Re-init a netlink socket."""
if self.status['use_socket']:
return self.use_socket
sock = self.socket if sock is None else sock
if sock is not None:
sock.close()
Expand All @@ -447,11 +454,9 @@ def __getattr__(self, attr):
'settimeout',
'gettimeout',
'shutdown',
'recv',
'recvfrom',
'recvfrom_into',
'fileno',
'send',
'sendto',
'connect',
'listen',
Expand Down Expand Up @@ -571,18 +576,13 @@ def put(
self.put_header(msg, msg_type, msg_flags, msg_seq, msg_pid)
msg.reset()
msg.encode()
return self.sendto(msg.data, addr)
return self.send(msg.data)

def getdata(self, block=False):
if block:
flags = 0
else:
flags = MSG_DONTWAIT
data = bytearray(16384)
log.debug("consume, block=%s", block)
bufsize = self.socket.recv_into(data, 0, flags)
log.debug("consumed bufsize=%s", bufsize)
return data
def recv(self, buffersize, flags=0):
return self.socket.recv(buffersize, flags)

def send(self, data, flags=0):
return self.socket.send(data, flags)

def get(
self,
Expand All @@ -601,22 +601,27 @@ def get(
# step 1. receive as much as we can from the socket
while True:
try:
data = self.getdata(block=False)
data = self.recv(self.status['rcvsize'], MSG_DONTWAIT)
if len(data) == 0 or data[0] == 0:
return
self.buffer.append(data)
except BlockingIOError:
break
if len(self.buffer) == 0:
self.buffer.append(self.getdata(block=True))
self.buffer.append(self.recv(self.status['rcvsize']))
# step 2. fetch one data block from the buffer
data = self.buffer.pop(0)
# step 3. parse the data block
messages = tuple(self.marshal.parse(data, msg_seq, callback))
if len(messages) == 0:
break
for msg in messages:
if msg_seq > 0 and msg['header']['sequence_number'] != msg_seq:
if started and msg['header']['type'] == NLMSG_DONE:
return
if (
msg_seq > 0
and msg['header'][self.status['tag_field']] != msg_seq
):
continue
msg['header']['target'] = self.status['target']
msg['header']['stats'] = Stats(0, 0, 0)
Expand All @@ -625,12 +630,9 @@ def get(
log.debug("message %s", msg)
yield msg

if started and msg['header']['type'] == NLMSG_DONE:
break

if started and (
(msg_seq == 0)
or (not msg['header']['flags'] & NLM_F_MULTI)
or (not msg['header'].get('flags', 0) & NLM_F_MULTI)
or (callable(terminate) and terminate(msg))
):
enough = True
Expand Down
44 changes: 36 additions & 8 deletions pyroute2/plan9/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import builtins
import json
import struct

from pyroute2.netlink import nlmsg
Expand Down Expand Up @@ -34,6 +36,10 @@
Topenfd = 98
Ropenfd = 99

# 9P2000.pr2 extensions
Tcall = 80
Rcall = 81


def array(kind, header='H'):
class CustomArray:
Expand Down Expand Up @@ -74,11 +80,14 @@ def __init__(self, qtype, vers, path):

@staticmethod
def decode_from(data, offset):
return dict(
zip(
('type', 'vers', 'path'),
struct.unpack_from('=BIQ', data, offset),
)
return (
dict(
zip(
('type', 'vers', 'path'),
struct.unpack_from('=BIQ', data, offset),
)
),
offset + struct.calcsize('=BIQ'),
)

@staticmethod
Expand Down Expand Up @@ -322,7 +331,7 @@ class msg_tread(msg_base):


class msg_rread(msg_base):
defaults = {'header': {'type': Rread}}
defaults = {'header': {'type': Rread}, 'data': b''}
fields = (('data', CData),)


Expand All @@ -336,9 +345,19 @@ class msg_rwrite(msg_base):
fields = (('count', 'I'),)


class msg_tcall(msg_base):
defaults = {'header': {'type': Tcall}}
fields = (('fid', 'I'), ('text', String), ('data', CData))


class msg_rcall(msg_base):
defaults = {'header': {'type': Rcall}}
fields = (('err', 'H'), ('text', String), ('data', CData))


class Marshal9P(Marshal):
default_message_class = msg_rerror
error_type = Rerror

msg_map = {
Tversion: msg_tversion,
Rversion: msg_rversion,
Expand All @@ -359,18 +378,27 @@ class Marshal9P(Marshal):
Rstat: msg_rstat,
Twrite: msg_twrite,
Rwrite: msg_rwrite,
Tcall: msg_tcall,
Rcall: msg_rcall,
}

def parse(self, data, seq=None, callback=None, skip_alien_seq=False):
offset = 0
while offset <= len(data) - 5:
(length, key, tag) = struct.unpack_from('IBH', data, offset)
(length, key, tag) = struct.unpack_from('=IBH', data, offset)
if skip_alien_seq and tag != seq:
continue
if not 0 < length <= len(data):
break
parser = self.get_parser(key, 0, tag)
msg = parser(data, offset, length)
if key == Rerror:
spec = json.loads(msg['ename'])
if spec['class'] in dir(builtins):
cls = getattr(builtins, spec['class'])
else:
cls = Exception
raise cls(spec['argv'])
offset += length
if msg is None:
continue
Expand Down
103 changes: 103 additions & 0 deletions pyroute2/plan9/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import json
import os
import pwd
import socket

from pyroute2.common import AddrPool
from pyroute2.plan9 import (
msg_tattach,
msg_tcall,
msg_tread,
msg_tversion,
msg_twalk,
msg_twrite,
)
from pyroute2.plan9.plan9socket import Plan9Socket


class Plan9Client:
def __init__(self, address=None, use_socket=None):
if use_socket is None:
use_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket = Plan9Socket(use_socket=use_socket)
if address is not None:
self.socket.connect(address)
self.wnames = {'': 0}
self.cwd = 0
self.fid_pool = AddrPool(minaddr=0x00000001, maxaddr=0x0000FFFF)

def init(self):
self.version()
self.auth()
self.attach()

def request(self, msg, tag=0):
if tag == 0:
tag = self.socket.addr_pool.alloc()
try:
msg['header']['tag'] = tag
msg.reset()
msg.encode()
self.socket.send(msg.data)
return self.socket.get(msg_seq=tag)[0]
finally:
self.socket.addr_pool.free(tag, ban=0xFF)

def version(self):
m = msg_tversion()
m['header']['tag'] = 0xFFFF
m['msize'] = 8192
m['version'] = '9P2000'
return self.request(m, tag=0xFFFF)

def auth(self):
pass

def attach(self):
m = msg_tattach()
m['fid'] = 0
m['afid'] = 0xFFFFFFFF
m['uname'] = pwd.getpwuid(os.getuid()).pw_name
m['aname'] = ''
return self.request(m)

def walk(self, path, newfid=None, fid=None):
m = msg_twalk()
m['fid'] = self.cwd if fid is None else fid
m['newfid'] = newfid if newfid is not None else self.fid_pool.alloc()
m['wname'] = path.split(os.path.sep)
self.wnames[path] = m['newfid']
return self.request(m)

def fid(self, path):
if path not in self.wnames:
newfid = self.fid_pool.alloc()
self.walk(path, newfid)
self.wnames[path] = newfid
return self.wnames[path]

def read(self, fid):
m = msg_tread()
m['fid'] = fid
m['offset'] = 0
m['count'] = 8192
return self.request(m)

def write(self, fid, data):
m = msg_twrite()
m['fid'] = fid
m['offset'] = 0
m['data'] = data
return self.request(m)

def call(self, fid, fname, argv=None, kwarg=None, data=b''):
spec = {
'call': fname,
'argv': argv if argv is not None else [],
'kwarg': kwarg if kwarg is not None else {},
}
m = msg_tcall()
m['fid'] = fid
m['text'] = json.dumps(spec)
m['data'] = data
return self.request(m)
Loading

0 comments on commit fd18ba0

Please sign in to comment.