Skip to content

Commit

Permalink
Merge pull request #1889 from yuvipanda/ip-check
Browse files Browse the repository at this point in the history
Better ipv6 support when checking network bans
  • Loading branch information
minrk authored Nov 27, 2024
2 parents 0284be9 + 694628c commit 0dd37ea
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 57 deletions.
14 changes: 0 additions & 14 deletions binderhub/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,19 +736,6 @@ def _cast_ban_networks(self, proposal):

return networks

ban_networks_min_prefix_len = Integer(
1,
help="The shortest prefix in ban_networks",
)

@observe("ban_networks")
def _update_prefix_len(self, change):
if not change.new:
min_len = 1
else:
min_len = min(net.prefixlen for net in change.new)
self.ban_networks_min_prefix_len = min_len or 1

tornado_settings = Dict(
config=True,
help="""
Expand Down Expand Up @@ -928,7 +915,6 @@ def initialize(self, *args, **kwargs):
"debug": self.debug,
"launcher": self.launcher,
"ban_networks": self.ban_networks,
"ban_networks_min_prefix_len": self.ban_networks_min_prefix_len,
"build_pool": self.build_pool,
"build_token_check_origin": self.build_token_check_origin,
"build_token_secret": self.build_token_secret,
Expand Down
6 changes: 3 additions & 3 deletions binderhub/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ def check_request_ip(self):
match = ip_in_networks(
request_ip,
ban_networks,
min_prefix_len=self.settings["ban_networks_min_prefix_len"],
)
if match:
network, message = match
network_spec = match
message = ban_networks[network_spec]
app_log.warning(
f"Blocking request from {request_ip} matching banned network {network}: {message}"
f"Blocking request from {request_ip} matching banned network {network_spec}: {message}"
)
raise web.HTTPError(403, f"Requests from {message} are not allowed")

Expand Down
10 changes: 7 additions & 3 deletions binderhub/tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,14 @@ async def test_ban_networks(request, app, use_session, path, banned, prefixlen,
"255.255.255.255/32": "255.x",
"1.0.0.0/8": "1.x",
}
local_net = str(ipaddress.ip_network("127.0.0.1").supernet(new_prefix=prefixlen))
local_net = [
str(ipaddress.ip_network("127.0.0.1").supernet(new_prefix=prefixlen)),
str(ipaddress.ip_network("::1").supernet(new_prefix=prefixlen)),
]

if banned:
ban_networks[local_net] = "local"
for net in local_net:
ban_networks[net] = "local"

# pass through trait validators on app
app.ban_networks = ban_networks
Expand All @@ -106,7 +111,6 @@ def reset():
app.tornado_app.settings,
{
"ban_networks": app.ban_networks,
"ban_networks_min_prefix_len": app.ban_networks_min_prefix_len,
},
):
r = await async_requests.get(url)
Expand Down
18 changes: 4 additions & 14 deletions binderhub/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,24 +116,14 @@ def later():
("192.168.1.2", ["192.168.1.0/24", "255.255.0.0/16"], True),
("192.168.1.2", ["255.255.0.0/16", "192.168.1.0/24"], True),
("192.168.1.2", [], False),
("2001:db8:0:0:0:0:0:1", ["2001:db8::/32", "192.168.1.1/32"], True),
("3001:db8:0:0:0:0:0:1", ["2001:db8::/32", "192.168.1.1/32"], False),
],
)
def test_ip_in_networks(ip, cidrs, found):
networks = {ipaddress.ip_network(cidr): f"message {cidr}" for cidr in cidrs}
if networks:
min_prefix = min(net.prefixlen for net in networks)
else:
min_prefix = 1
match = utils.ip_in_networks(ip, networks, min_prefix)
match = utils.ip_in_networks(ip, [ipaddress.ip_network(c) for c in cidrs])
if found:
assert match
net, message = match
assert message == f"message {net}"
assert ipaddress.ip_address(ip) in net
assert str(match) in cidrs
else:
assert match is False


def test_ip_in_networks_invalid():
with pytest.raises(ValueError):
utils.ip_in_networks("1.2.3.4", {}, 0)
36 changes: 13 additions & 23 deletions binderhub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from collections import OrderedDict
from hashlib import blake2b
from typing import Iterable
from unittest.mock import Mock

from kubernetes.client import api_client
Expand Down Expand Up @@ -167,32 +168,21 @@ def url_path_join(*pieces):
return result


def ip_in_networks(ip, networks, min_prefix_len=1):
"""Return whether `ip` is in the dict of networks
This is O(1) regardless of the size of networks
Implementation based on netaddr.IPSet.__contains__
Repeatedly checks if ip/32; ip/31; ip/30; etc. is in networks
for all netmasks that match the given ip,
for a max of 32 dict key lookups for ipv4.
def ip_in_networks(
ip_addr: str, networks: Iterable[ipaddress.IPv4Network | ipaddress.IPv6Network]
):
"""
Checks if `ip_addr` is contained within any of the networks in `networks`
If all netmasks have a prefix length of e.g. 24 or greater,
min_prefix_len prevents checking wider network masks that can't possibly match.
If ip_addr is in any of the provided networks, return the first network that matches.
If not, return False
Returns `(netmask, networks[netmask])` for matching netmask
in networks, if found; False, otherwise.
Both ipv6 and ipv4 are supported
"""
if min_prefix_len < 1:
raise ValueError(f"min_prefix_len must be >= 1, got {min_prefix_len}")
if not networks:
return False
check_net = ipaddress.ip_network(ip)
while check_net.prefixlen >= min_prefix_len:
if check_net in networks:
return check_net, networks[check_net]
check_net = check_net.supernet(1)
ip = ipaddress.ip_address(ip_addr)
for network in networks:
if ip in network:
return network
return False


Expand Down

0 comments on commit 0dd37ea

Please sign in to comment.