Skip to content

Commit

Permalink
ipvs: Merge pull request #1187 from svinota/ipvs-1138
Browse files Browse the repository at this point in the history
basic ipvs support

Bug-Url: #1187
Bug-Url: #1138
  • Loading branch information
svinota authored Mar 14, 2024
2 parents 0f41932 + 8f73638 commit 3ca8a9f
Show file tree
Hide file tree
Showing 8 changed files with 453 additions and 43 deletions.
4 changes: 4 additions & 0 deletions pyroute2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from pyroute2.iproute import ChaoticIPRoute, IPBatch, IPRoute, RawIPRoute
from pyroute2.iproute.ipmock import IPRoute as IPMock
from pyroute2.ipset import IPSet
from pyroute2.ipvs import IPVS, IPVSDest, IPVSService
from pyroute2.iwutil import IW
from pyroute2.ndb.main import NDB
from pyroute2.ndb.noipdb import NoIPDB
Expand Down Expand Up @@ -81,6 +82,9 @@
IPRoute,
IPRSocket,
IPSet,
IPVS,
IPVSDest,
IPVSService,
IW,
GenericNetlinkSocket,
L2tp,
Expand Down
33 changes: 0 additions & 33 deletions pyroute2/iproute/linux.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,9 @@
from pyroute2.lab import LAB_API
from pyroute2.netlink import (
NLM_F_ACK,
NLM_F_APPEND,
NLM_F_ATOMIC,
NLM_F_CREATE,
NLM_F_DUMP,
NLM_F_ECHO,
NLM_F_EXCL,
NLM_F_REPLACE,
NLM_F_REQUEST,
NLM_F_ROOT,
NLMSG_ERROR,
Expand Down Expand Up @@ -186,35 +182,6 @@ def filter_messages(*argv, **kwarg):
self._genmatch = self.filter_messages
self.filter_messages = filter_messages

def make_request_type(self, command, command_map):
if isinstance(command, basestring):
return (lambda x: (x[0], self.make_request_flags(x[1])))(
command_map[command]
)
elif isinstance(command, int):
return command, self.make_request_flags('create')
elif isinstance(command, (list, tuple)):
return command
else:
raise TypeError('allowed command types: int, str, list, tuple')

def make_request_flags(self, mode):
flags = {
'dump': NLM_F_REQUEST | NLM_F_DUMP,
'get': NLM_F_REQUEST | NLM_F_ACK,
'req': NLM_F_REQUEST | NLM_F_ACK,
}
flags['create'] = flags['req'] | NLM_F_CREATE | NLM_F_EXCL
flags['append'] = flags['req'] | NLM_F_CREATE | NLM_F_APPEND
flags['change'] = flags['req'] | NLM_F_REPLACE
flags['replace'] = flags['change'] | NLM_F_CREATE

return flags[mode] | (
NLM_F_ECHO
if (self.config['nlm_echo'] and mode not in ('get', 'dump'))
else 0
)

def filter_messages(self, dump_filter, msgs):
'''
Filter messages using `dump_filter`. The filter might be a
Expand Down
195 changes: 195 additions & 0 deletions pyroute2/ipvs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
'''
IPVS -- IP Virtual Server
-------------------------
IPVS configuration is done via generic netlink protocol.
At the low level one can use it with a GenericNetlinkSocket,
binding it to "IPVS" generic netlink family.
But for the convenience the library provides utility classes:
* IPVS -- a socket class to access the API
* IPVSService -- a class to define IPVS service records
* IPVSDest -- a class to define real server records
Dump all the records::
from pyroute2 import IPVS, IPVSDest, IPVSService
# run the socket
ipvs = IPVS()
# iterate all the IPVS services
for s in ipvs.service("dump"):
# create a utility object from a netlink message
service = IPVSService.from_message(s)
print("Service: ", service)
# iterate all the real servers for this service
for d in ipvs.dest("dump", service=service):
# create and print a utility object
dest = IPVSDest.from_message(d)
print(" Real server: ", dest)
Create a service and a real server record::
from socket import IPPROTO_TCP
from pyroute2 import IPVS, IPVSDest, IPVSService
ipvs = IPVS()
service = IPVSService(addr="192.168.122.1", port=80, protocol=IPPROTO_TCP)
real_server = IPVSDest(addr="10.0.2.20", port=80)
ipvs.service("add", service=service)
ipvs.dest("add", service=service, dest=real_server)
Delete a service::
from pyroute2 import IPVS, IPVSService
ipvs = IPVS()
ipvs.service("del",
service=IPVSService(
addr="192.168.122.1",
port=80,
protocol=IPPROTO_TCP
)
)
'''

from socket import AF_INET

from pyroute2.common import get_address_family
from pyroute2.netlink.generic import ipvs
from pyroute2.requests.common import NLAKeyTransform
from pyroute2.requests.main import RequestProcessor


class ServiceFieldFilter(NLAKeyTransform):
_nla_prefix = 'IPVS_SVC_ATTR_'

def set_addr(self, context, value):
ret = {"addr": value}
if "af" in context.keys():
family = context["af"]
else:
family = ret["af"] = get_address_family(value)
if family == AF_INET and "netmask" not in context.keys():
ret["netmask"] = "255.255.255.255"
return ret


class DestFieldFilter(NLAKeyTransform):
_nla_prefix = 'IPVS_DEST_ATTR_'

def set_addr(self, context, value):
ret = {"addr": value}
if "addr_family" not in context.keys():
ret["addr_family"] = get_address_family(value)
return ret


class NLAFilter(RequestProcessor):
msg = None
keys = tuple()
field_filter = None
nla = None
default_values = {}

def __init__(self, **kwarg):
dict.update(self, self.default_values)
super().__init__(prime=kwarg)

@classmethod
def from_message(cls, msg):
obj = cls()
for key, value in msg.get(cls.nla)["attrs"]:
obj[key] = value
obj.pop("stats", None)
obj.pop("stats64", None)
return obj

def dump_nla(self, items=None):
if items is None:
items = self.items()
self.update(self)
self.finalize()
return {
"attrs": list(
map(lambda x: (self.msg.name2nla(x[0]), x[1]), items)
)
}

def dump_key(self):
return self.dump_nla(
items=filter(lambda x: x[0] in self.key_fields, self.items())
)


class IPVSService(NLAFilter):
field_filter = ServiceFieldFilter()
msg = ipvs.ipvsmsg.service
key_fields = ("af", "protocol", "addr", "port")
nla = "IPVS_CMD_ATTR_SERVICE"
default_values = {
"timeout": 0,
"sched_name": "wlc",
"flags": {"flags": 0, "mask": 0xFFFF},
}


class IPVSDest(NLAFilter):
field_filter = DestFieldFilter()
msg = ipvs.ipvsmsg.dest
nla = "IPVS_CMD_ATTR_DEST"
default_values = {
"fwd_method": 3,
"weight": 1,
"tun_type": 0,
"tun_port": 0,
"tun_flags": 0,
"u_thresh": 0,
"l_thresh": 0,
}


class IPVS(ipvs.IPVSSocket):

def service(self, command, service=None):
command_map = {
"add": (ipvs.IPVS_CMD_NEW_SERVICE, "create"),
"set": (ipvs.IPVS_CMD_SET_SERVICE, "change"),
"update": (ipvs.IPVS_CMD_DEL_SERVICE, "change"),
"del": (ipvs.IPVS_CMD_DEL_SERVICE, "req"),
"get": (ipvs.IPVS_CMD_GET_SERVICE, "get"),
"dump": (ipvs.IPVS_CMD_GET_SERVICE, "dump"),
}
cmd, flags = self.make_request_type(command, command_map)
msg = ipvs.ipvsmsg()
msg["cmd"] = cmd
msg["version"] = ipvs.GENL_VERSION
if service is not None:
msg["attrs"] = [("IPVS_CMD_ATTR_SERVICE", service.dump_nla())]
return self.nlm_request(msg, msg_type=self.prid, msg_flags=flags)

def dest(self, command, service, dest=None):
command_map = {
"add": (ipvs.IPVS_CMD_NEW_DEST, "create"),
"set": (ipvs.IPVS_CMD_SET_DEST, "change"),
"update": (ipvs.IPVS_CMD_DEL_DEST, "change"),
"del": (ipvs.IPVS_CMD_DEL_DEST, "req"),
"get": (ipvs.IPVS_CMD_GET_DEST, "get"),
"dump": (ipvs.IPVS_CMD_GET_DEST, "dump"),
}
cmd, flags = self.make_request_type(command, command_map)
msg = ipvs.ipvsmsg()
msg["cmd"] = cmd
msg["version"] = 0x1
msg["attrs"] = [("IPVS_CMD_ATTR_SERVICE", service.dump_key())]
if dest is not None:
msg["attrs"].append(("IPVS_CMD_ATTR_DEST", dest.dump_nla()))
return self.nlm_request(msg, msg_type=self.prid, msg_flags=flags)
42 changes: 34 additions & 8 deletions pyroute2/netlink/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ class my_msg(nlmsg):

import io
import logging
import socket
import struct
import sys
import threading
Expand Down Expand Up @@ -894,7 +895,7 @@ def __init__(
self.value = NotInitialized
# work only on non-empty mappings
if self.nla_map and not self.__class__.__compiled_nla:
self.compile_nla()
self.compile_nla_table()
if self.header:
self['header'] = {}

Expand Down Expand Up @@ -1434,7 +1435,7 @@ def getvalue(self):

return self

def compile_nla(self):
def compile_nla_table(self):
# Bug-Url: https://github.com/svinota/pyroute2/issues/980
# Bug-Url: https://github.com/svinota/pyroute2/pull/981
if isinstance(self.nla_map, NlaMapAdapter):
Expand Down Expand Up @@ -2051,15 +2052,37 @@ class target(nla_base_string):
__slots__ = ()
sql_type = 'TEXT'
family = None
family_attr = None
own_parent = True

def __init__(self, *argv, **kwarg):
init = kwarg.get('init', None)
if init is not None:
key, value = init.split(',')
if key == 'family' and value.startswith('AF_'):
self.family = getattr(socket, value)
elif key == 'nla':
self.family_attr = value
super().__init__(*argv, **kwarg)

def get_family(self):
if self.family is not None:
return self.family
pointer = self
if self.family_attr is not None:
nla = self.family_attr
else:
nla = 'family'
while pointer.parent is not None:
pointer = pointer.parent
return pointer.get('family', AF_UNSPEC)
family = pointer.get(nla)
if family is not None:
return family
return AF_UNSPEC

@staticmethod
def get_addrlen(family):
return {AF_INET: 4, AF_INET6: 16, AF_MPLS: 4}.get(family, 4)

def encode(self):
family = self.get_family()
Expand Down Expand Up @@ -2096,14 +2119,17 @@ def encode(self):
def decode(self):
nla_base_string.decode(self)
family = self.get_family()
data = self['value']
if family in (AF_INET, AF_INET6):
self.value = inet_ntop(family, self['value'])
if family == AF_INET:
data = data[:4]
elif family == AF_INET6:
data = data[:16]
self.value = inet_ntop(family, data)
elif family == AF_MPLS:
self.value = []
for i in range(len(self['value']) // 4):
label = struct.unpack(
'>I', self['value'][i * 4 : i * 4 + 4]
)[0]
for i in range(len(data) // 4):
label = struct.unpack('>I', data[i * 4 : i * 4 + 4])[0]
record = {
'label': (label & 0xFFFFF000) >> 12,
'tc': (label & 0x00000E00) >> 9,
Expand Down
Loading

0 comments on commit 3ca8a9f

Please sign in to comment.