diff --git a/binderhub/app.py b/binderhub/app.py index 2cf9e2b50..430162314 100644 --- a/binderhub/app.py +++ b/binderhub/app.py @@ -736,19 +736,6 @@ def _cast_ban_networks(self, proposal): return networks - ban_networks_min_prefix_len = Integer( - 1, - help="The shortest prefix in ban_networks", - ) - - @observe("ban_networks") - def _update_prefix_len(self, change): - if not change.new: - min_len = 1 - else: - min_len = min(net.prefixlen for net in change.new) - self.ban_networks_min_prefix_len = min_len or 1 - tornado_settings = Dict( config=True, help=""" @@ -928,7 +915,6 @@ def initialize(self, *args, **kwargs): "debug": self.debug, "launcher": self.launcher, "ban_networks": self.ban_networks, - "ban_networks_min_prefix_len": self.ban_networks_min_prefix_len, "build_pool": self.build_pool, "build_token_check_origin": self.build_token_check_origin, "build_token_secret": self.build_token_secret, diff --git a/binderhub/base.py b/binderhub/base.py index 5f198c401..3695f1cdc 100644 --- a/binderhub/base.py +++ b/binderhub/base.py @@ -38,12 +38,12 @@ def check_request_ip(self): match = ip_in_networks( request_ip, ban_networks, - min_prefix_len=self.settings["ban_networks_min_prefix_len"], ) if match: - network, message = match + network_spec = match + message = ban_networks[network_spec] app_log.warning( - f"Blocking request from {request_ip} matching banned network {network}: {message}" + f"Blocking request from {request_ip} matching banned network {network_spec}: {message}" ) raise web.HTTPError(403, f"Requests from {message} are not allowed") diff --git a/binderhub/tests/test_auth.py b/binderhub/tests/test_auth.py index 2780df0bb..ecb4c400a 100644 --- a/binderhub/tests/test_auth.py +++ b/binderhub/tests/test_auth.py @@ -91,9 +91,14 @@ async def test_ban_networks(request, app, use_session, path, banned, prefixlen, "255.255.255.255/32": "255.x", "1.0.0.0/8": "1.x", } - local_net = str(ipaddress.ip_network("127.0.0.1").supernet(new_prefix=prefixlen)) + local_net = [ + str(ipaddress.ip_network("127.0.0.1").supernet(new_prefix=prefixlen)), + str(ipaddress.ip_network("::1").supernet(new_prefix=prefixlen)), + ] + if banned: - ban_networks[local_net] = "local" + for net in local_net: + ban_networks[net] = "local" # pass through trait validators on app app.ban_networks = ban_networks @@ -106,7 +111,6 @@ def reset(): app.tornado_app.settings, { "ban_networks": app.ban_networks, - "ban_networks_min_prefix_len": app.ban_networks_min_prefix_len, }, ): r = await async_requests.get(url) diff --git a/binderhub/tests/test_utils.py b/binderhub/tests/test_utils.py index d3d242c10..dcca83277 100644 --- a/binderhub/tests/test_utils.py +++ b/binderhub/tests/test_utils.py @@ -116,24 +116,14 @@ def later(): ("192.168.1.2", ["192.168.1.0/24", "255.255.0.0/16"], True), ("192.168.1.2", ["255.255.0.0/16", "192.168.1.0/24"], True), ("192.168.1.2", [], False), + ("2001:db8:0:0:0:0:0:1", ["2001:db8::/32", "192.168.1.1/32"], True), + ("3001:db8:0:0:0:0:0:1", ["2001:db8::/32", "192.168.1.1/32"], False), ], ) def test_ip_in_networks(ip, cidrs, found): - networks = {ipaddress.ip_network(cidr): f"message {cidr}" for cidr in cidrs} - if networks: - min_prefix = min(net.prefixlen for net in networks) - else: - min_prefix = 1 - match = utils.ip_in_networks(ip, networks, min_prefix) + match = utils.ip_in_networks(ip, [ipaddress.ip_network(c) for c in cidrs]) if found: assert match - net, message = match - assert message == f"message {net}" - assert ipaddress.ip_address(ip) in net + assert str(match) in cidrs else: assert match is False - - -def test_ip_in_networks_invalid(): - with pytest.raises(ValueError): - utils.ip_in_networks("1.2.3.4", {}, 0) diff --git a/binderhub/utils.py b/binderhub/utils.py index 400aaa956..45a9f1c75 100644 --- a/binderhub/utils.py +++ b/binderhub/utils.py @@ -4,6 +4,7 @@ import time from collections import OrderedDict from hashlib import blake2b +from typing import Iterable from unittest.mock import Mock from kubernetes.client import api_client @@ -167,32 +168,21 @@ def url_path_join(*pieces): return result -def ip_in_networks(ip, networks, min_prefix_len=1): - """Return whether `ip` is in the dict of networks - - This is O(1) regardless of the size of networks - - Implementation based on netaddr.IPSet.__contains__ - - Repeatedly checks if ip/32; ip/31; ip/30; etc. is in networks - for all netmasks that match the given ip, - for a max of 32 dict key lookups for ipv4. +def ip_in_networks( + ip_addr: str, networks: Iterable[ipaddress.IPv4Network | ipaddress.IPv6Network] +): + """ + Checks if `ip_addr` is contained within any of the networks in `networks` - If all netmasks have a prefix length of e.g. 24 or greater, - min_prefix_len prevents checking wider network masks that can't possibly match. + If ip_addr is in any of the provided networks, return the first network that matches. + If not, return False - Returns `(netmask, networks[netmask])` for matching netmask - in networks, if found; False, otherwise. + Both ipv6 and ipv4 are supported """ - if min_prefix_len < 1: - raise ValueError(f"min_prefix_len must be >= 1, got {min_prefix_len}") - if not networks: - return False - check_net = ipaddress.ip_network(ip) - while check_net.prefixlen >= min_prefix_len: - if check_net in networks: - return check_net, networks[check_net] - check_net = check_net.supernet(1) + ip = ipaddress.ip_address(ip_addr) + for network in networks: + if ip in network: + return network return False