From 51d6e38f4968179d70d3a6bf7574d2ac78867730 Mon Sep 17 00:00:00 2001 From: clickyotomy Date: Sat, 14 Dec 2024 02:37:45 +0000 Subject: [PATCH] netfilter: Support multiport matching (-m multiport) This set of changes adds: - support for `xt_multiport_{,v1}` matchers for matching for a range of ports and their inverse, i.e.,: ``` -m multiport [!] --[s|d]ports (PORT,...|PORT:PORT,...) ``` - support for `IP{,6}T_SO_GET_REVISION_MATCH` socket options, which allows `iptables` to query for the highest supported revision for a given matcher --- pkg/abi/linux/netfilter.go | 60 ++++ pkg/abi/linux/netfilter_test.go | 2 + pkg/sentry/socket/netfilter/BUILD | 2 + pkg/sentry/socket/netfilter/extensions.go | 21 +- .../socket/netfilter/multiport_matcher.go | 242 +++++++++++++++ .../socket/netfilter/multiport_matcher_v1.go | 214 ++++++++++++++ pkg/sentry/socket/netfilter/netfilter.go | 73 ++++- pkg/sentry/socket/netstack/netstack.go | 47 +++ test/iptables/filter_input.go | 184 ++++++++++++ test/iptables/filter_output.go | 277 ++++++++++++++++++ test/iptables/iptables_test.go | 19 ++ 11 files changed, 1137 insertions(+), 4 deletions(-) create mode 100644 pkg/sentry/socket/netfilter/multiport_matcher.go create mode 100644 pkg/sentry/socket/netfilter/multiport_matcher_v1.go diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index d5ddf1199b..2b430fe501 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -727,3 +727,63 @@ const ( // packets do not have an associated socket. XT_OWNER_SOCKET = 1 << 2 ) + +// XT_MULTI_PORTS is the maximum number of ports that the +// multiport match can handle. +const XT_MULTI_PORTS = 15 + +// Flags in XTMultiport{,V1}.Flags; values from "enum xt_multiport_flags" +// in "include/uapi/linux/netfilter/xt_multiport.h". +const ( + XT_MULTIPORT_SOURCE uint8 = 0x0 // Match against source ports. + XT_MULTIPORT_DESTINATION uint8 = 0x1 // Match against destination ports. + XT_MULTIPORT_EITHER uint8 = 0x2 // Match against either ports. +) + +// XTMultiport holds data for matching packets against a set +// of ports. It corresponds to "struct xt_multiport" defined +// in "include/uapi/linux/netfilter/xt_multiport.h". +// +// +marshal +type XTMultiport struct { + // Flags indicates whether the match applies to + // source ports, destination ports, or either, as + // defined by "enum xt_multiport_flags". + Flags uint8 + + // Count is the number of ports in the "Ports" + // slice that the match will check. It must be + // between 1 and "XT_MULTI_PORTS" (inclusive). + Count uint8 + + // Ports is the set of ports that will be matched. + // Only the first "Count" entries are considered. + Ports [XT_MULTI_PORTS]uint16 +} + +// XTMultiportV1 holds data for matching packets against a set +// of ports. It corresponds to "struct xt_multiport_v1" defined +// in "include/uapi/linux/netfilter/xt_multiport.h". +// +// +marshal +type XTMultiportV1 struct { + // Fields same as "XTMultiport". + Flags uint8 + Count uint8 + Ports [XT_MULTI_PORTS]uint16 + + // Pflags is an array of port-specific flags. Each entry + // in "Pflags" corresponds to the port at the same index + // in "Ports". + Pflags [XT_MULTI_PORTS]uint8 + + // Invert is a flag that, if nonzero, indicates + // that the match result should be inverted. + Invert uint8 +} + +// SizeOfXTMultiport is the size of XTMultiport (in bytes). +const SizeOfXTMultiport = 2 + (XT_MULTI_PORTS * 2) + +// SizeOfXTMultiportV1 is the size of XTMultiportV1 (in bytes). +const SizeOfXTMultiportV1 = SizeOfXTMultiport + XT_MULTI_PORTS + 1 diff --git a/pkg/abi/linux/netfilter_test.go b/pkg/abi/linux/netfilter_test.go index e854a88f17..90c818c483 100644 --- a/pkg/abi/linux/netfilter_test.go +++ b/pkg/abi/linux/netfilter_test.go @@ -38,6 +38,8 @@ func TestSizes(t *testing.T) { {IP6TReplace{}, SizeOfIP6TReplace}, {IP6TEntry{}, SizeOfIP6TEntry}, {IP6TIP{}, SizeOfIP6TIP}, + {XTMultiport{}, SizeOfXTMultiport}, + {XTMultiportV1{}, SizeOfXTMultiportV1}, } for _, tc := range testCases { diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 2835c80029..cfe51d607a 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -12,6 +12,8 @@ go_library( "extensions.go", "ipv4.go", "ipv6.go", + "multiport_matcher.go", + "multiport_matcher_v1.go", "netfilter.go", "owner_matcher.go", "owner_matcher_v1.go", diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index d4e632ac58..6207591e5e 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -105,7 +105,7 @@ func marshalEntryMatch(name string, data []byte) []byte { return buf } -func unmarshalMatcher(mapper IDMapper, match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) { +func unmarshalMatcher(mapper IDMapper, match *linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) { key := matchKey{ name: match.Name.String(), revision: match.Revision, @@ -117,6 +117,25 @@ func unmarshalMatcher(mapper IDMapper, match linux.XTEntryMatch, filter stack.IP return matchMaker.unmarshal(mapper, buf, filter) } +// matchMakerRevision returns the maximum supported version of the +// matcher with "name" up to "rev" and whether any such matcher +// with that name exists. +func matchMakerRevision(name string, rev uint8) (uint8, bool) { + var found bool + var ret uint8 + + for matcher := range matchMakers { + if name == matcher.name { + found = true + if matcher.revision > ret { + ret = matcher.revision + } + } + } + + return ret, found +} + // targetMaker knows how to (un)marshal a target. Once registered, // marshalTarget and unmarshalTarget can be used. type targetMaker interface { diff --git a/pkg/sentry/socket/netfilter/multiport_matcher.go b/pkg/sentry/socket/netfilter/multiport_matcher.go new file mode 100644 index 0000000000..0a8e183399 --- /dev/null +++ b/pkg/sentry/socket/netfilter/multiport_matcher.go @@ -0,0 +1,242 @@ +// Copyright 2024 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + matcherNameMultiport string = "multiport" + matcherRevMultiport uint8 = 0 + matcherPfxMultiport string = (matcherNameMultiport + ".0") +) + +// multiportMarshaler handles marshalling and +// unmarshalling of "xt_multiport" matchers. +type multiportMarshaler struct{} + +// multiportMatcher represents a multiport matcher +// with source and/or destination ports. +type multiportMatcher struct { + flags uint8 // Port match flag (source/destination/either). + count uint8 // Number of ports. + ports []uint16 // List of ports to match against. +} + +// init registers the "multiportMarshaler" with the matcher registry. +func init() { + registerMatchMaker(multiportMarshaler{}) +} + +// name returns the name of the marshaler. +func (multiportMarshaler) name() string { + return matcherNameMultiport +} + +// revision returns the revision number of the marshaler. +func (multiportMarshaler) revision() uint8 { + return matcherRevMultiport +} + +// marshal converts a matcher into its binary representation. +func (multiportMarshaler) marshal(mr matcher) []byte { + m := mr.(*multiportMatcher) + var xtmp linux.XTMultiport + + nflog("%s: marshal: XTMultiport: %+v", matcherPfxMultiport, m) + + // Set the match criteria flag. + xtmp.Flags = m.flags + + // Set the count of ports and populate the "Ports" slice. + xtmp.Count = uint8(len(m.ports)) + + // Truncate the "ports" slice to the maximum allowed + // by "XT_MULTI_PORTS" to prevent out-of-bounds writes. + if xtmp.Count > linux.XT_MULTI_PORTS { + xtmp.Count = linux.XT_MULTI_PORTS + } + + // Copy over the ports. + for i := uint8(0); i < xtmp.Count; i++ { + xtmp.Ports[i] = m.ports[i] + } + + // Marshal the XTMultiport structure into binary format. + return marshalEntryMatch(matcherNameMultiport, marshal.Marshal(&xtmp)) +} + +// unmarshal converts binary data into a multiportMatcher instance. +func (multiportMarshaler) unmarshal(_ IDMapper, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { + var matchData linux.XTMultiport + + nflog("%s: raw: XTMultiport: %+v", matcherPfxMultiport, buf) + + // Check if the buffer has enough data for XTMultiport. + if len(buf) < linux.SizeOfXTMultiport { + return nil, fmt.Errorf( + "%s: insufficient data, got %d, want: >= %d", + matcherPfxMultiport, + len(buf), + linux.SizeOfXTMultiport, + ) + } + + // Unmarshal the buffer into the XTMultiport structure. + matchData.UnmarshalUnsafe(buf) + nflog("%s: parsed XTMultiport: %+v", matcherPfxMultiport, matchData) + + // Validate the port count. + if matchData.Count == 0 || matchData.Count > linux.XT_MULTI_PORTS { + return nil, fmt.Errorf( + "%s: invalid port count, got %d, want: [1, %d]", + matcherPfxMultiport, matchData.Count, linux.XT_MULTI_PORTS, + ) + } + + // Extract the list of ports from the match data. + ports := make([]uint16, matchData.Count) + for i := 0; i < int(matchData.Count); i++ { + ports[i] = matchData.Ports[i] + } + + // Initialize "multiportMatcher" with the extracted ports. + matcher := &multiportMatcher{ + flags: matchData.Flags, + count: matchData.Count, + ports: ports, + } + + return matcher, nil +} + +// name returns the name of the matcher. +func (multiportMatcher) name() string { + return matcherNameMultiport +} + +// revision returns the revision number of the matcher. +func (multiportMatcher) revision() uint8 { + return matcherRevMultiport +} + +// Match determines if the packet matches any of the specified ports +// and returns true if a match is found. The second boolean returned +// indicates whether the packet should be "hot" dropped, or processed +// with other matchers. +func (m *multiportMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { + // Extract source and destination ports from the packet. + srcPort, dstPort, ok := extractPorts(pkt) + // The packet does not contain valid transport + // headers or uses an unsupported protocol. + if !ok { + return false, true + } + + // Iterate through the list of ports to check for a match based on + // the specified match criteria: source, destination or either. + for i := uint8(0); i < m.count; i++ { + if exactPortMatch(m.flags, srcPort, dstPort, m.ports[i]) { + return true, false + } + } + + // No match. + return false, false +} + +// extractTransportHeaderPorts is a helper routine that extracts +// the source and destination ports from the provided transport +// header based on the specified transport protocol. It supports +// TCP and UDP protocols and returns the source port, destination +// port, and a boolean indicating whether the extraction was +// successful. If the protocol is unsupported or the transport +// header is too short, it returns (0, 0, false). +func extractTransportHeaderPorts(hdr []byte, prot tcpip.TransportProtocolNumber) (uint16, uint16, bool) { + switch prot { + case header.TCPProtocolNumber: + // Ensure the TCP header has the minimum required length. + if len(hdr) < header.TCPMinimumSize { + return 0, 0, false + } + // Extract and return the source and destination ports. + tcpHdr := header.TCP(hdr) + return tcpHdr.SourcePort(), tcpHdr.DestinationPort(), true + + case header.UDPProtocolNumber: + // Similar to TCP. + if len(hdr) < header.UDPMinimumSize { + return 0, 0, false + } + udpHdr := header.UDP(hdr) + return udpHdr.SourcePort(), udpHdr.DestinationPort(), true + + default: + // Unsupported transport protocol; cannot extract ports. + return 0, 0, false + } +} + +// extractPorts extracts the source and destination ports from the given +// packet buffer. It supports both IPv4 and IPv6 packets and handles TCP +// and UDP transport protocols. It returns the source port, destination +// port, and a boolean indicating success. If the packet does not contain +// enough data or uses an unsupported protocol, it returns (0, 0, false). +func extractPorts(pkt *stack.PacketBuffer) (uint16, uint16, bool) { + // Retrieve the transport header (TCP/UDP) from the packet buffer. + transportHdr := pkt.TransportHeader().Slice() + + // Determine the network protocol. + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + // Extract the IPv4 header from the network header + // slice, then the transport protocol from it. + ipv4 := header.IPv4(pkt.NetworkHeader().Slice()) + prot := ipv4.TransportProtocol() + return extractTransportHeaderPorts(transportHdr, prot) + + case header.IPv6ProtocolNumber: + // Similar to IPv4. + ipv6 := header.IPv6(pkt.NetworkHeader().Slice()) + prot := ipv6.TransportProtocol() + return extractTransportHeaderPorts(transportHdr, prot) + + default: + // Unsupported network protocol; cannot extract ports. + return 0, 0, false + } +} + +// exactPortMatch return true if "srcPort" or "dstPort" are the +// same as "matchPort" depending on the matching criteria specified +// in "flags". +func exactPortMatch(flags uint8, srcPort, dstPort, matchPort uint16) bool { + switch flags { + case linux.XT_MULTIPORT_SOURCE: + return srcPort == matchPort + case linux.XT_MULTIPORT_DESTINATION: + return dstPort == matchPort + case linux.XT_MULTIPORT_EITHER: + return (srcPort == matchPort) || (dstPort == matchPort) + } + return false +} diff --git a/pkg/sentry/socket/netfilter/multiport_matcher_v1.go b/pkg/sentry/socket/netfilter/multiport_matcher_v1.go new file mode 100644 index 0000000000..d4f910ccd3 --- /dev/null +++ b/pkg/sentry/socket/netfilter/multiport_matcher_v1.go @@ -0,0 +1,214 @@ +// Copyright 2024 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + matcherRevMultiportV1 uint8 = 1 + matcherPfxMultiportV1 string = (matcherNameMultiport + ".1") +) + +// multiportMarshalerV1 handles marshalling and +// unmarshalling of "xt_multiport_v1" matchers. +type multiportMarshalerV1 struct{} + +// multiportMatcherV1 represents a multiport matcher with +// source and/or destination ports, per-port flags, and an +// inversion flag. +type multiportMatcherV1 struct { + flags uint8 // Port match flag (source/destination/either). + count uint8 // Number of ports. + ports []uint16 // List of ports to match against. + pflags []uint8 // Per-port flags (for range matches). + invert bool // Invert match result. +} + +// init registers the "multiportMarshalerV1" with the matcher registry. +func init() { + registerMatchMaker(multiportMarshalerV1{}) +} + +// name returns the name of the marshaler. +func (multiportMarshalerV1) name() string { + return matcherNameMultiport +} + +// revision returns the revision number of the marshaler. +func (multiportMarshalerV1) revision() uint8 { + return matcherRevMultiportV1 +} + +// marshal converts a "multiportMatcherV1" into its binary representation. +func (multiportMarshalerV1) marshal(mr matcher) []byte { + m := mr.(*multiportMatcherV1) + var xtmp linux.XTMultiportV1 + + // Set the match criteria flag. + xtmp.Flags = m.flags + + // Set the count of ports and populate the "Ports" slice. + xtmp.Count = uint8(len(m.ports)) + + // Truncate the "ports" slice to the maximum allowed + // by "XT_MULTI_PORTS" to prevent out-of-bounds writes. + if xtmp.Count > linux.XT_MULTI_PORTS { + xtmp.Count = linux.XT_MULTI_PORTS + } + + // Copy over the ports, and per-port flags. + for i := uint8(0); i < xtmp.Count; i++ { + xtmp.Ports[i] = m.ports[i] + xtmp.Pflags[i] = m.pflags[i] + } + + // If the match result is to be inverted. + if m.invert { + xtmp.Invert = uint8(1) + } + + // Marshal the XTMultiportV1 structure into binary format. + return marshalEntryMatch(matcherNameMultiport, marshal.Marshal(&xtmp)) +} + +// unmarshal converts binary data into a multiportMatcherV1 instance. +func (multiportMarshalerV1) unmarshal(_ IDMapper, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { + var matchData linux.XTMultiportV1 + + nflog("%s: raw XTMultiportV1: %+v", matcherPfxMultiportV1, buf) + + // Check if the buffer has enough data for XTMultiportV1. + if len(buf) < linux.SizeOfXTMultiportV1 { + return nil, fmt.Errorf( + "%s: insufficient data, got %d, want: >= %d", + matcherPfxMultiportV1, + len(buf), + linux.SizeOfXTMultiportV1, + ) + } + + // Unmarshal the buffer into the XTMultiportV1 structure. + matchData.UnmarshalUnsafe(buf) + nflog("%s: parsed XTMultiportV1: %+v", matcherPfxMultiportV1, matchData) + + // Validate the port count. + if matchData.Count == 0 || matchData.Count > linux.XT_MULTI_PORTS { + return nil, fmt.Errorf( + "%s: invalid port count, got %d, want: [1, %d]", + matcherPfxMultiportV1, matchData.Count, linux.XT_MULTI_PORTS, + ) + } + + // Extract the list of ports and their + // corresponding flags from the match data. + ports := make([]uint16, matchData.Count) + pflags := make([]uint8, matchData.Count) + for i := 0; i < int(matchData.Count); i++ { + ports[i] = matchData.Ports[i] + pflags[i] = matchData.Pflags[i] + } + + // Initialize "multiportMatcherV1" with the extracted ports and flags. + matcher := &multiportMatcherV1{ + flags: matchData.Flags, + count: matchData.Count, + ports: ports, + pflags: pflags, + invert: (matchData.Invert != 0), + } + + return matcher, nil +} + +// name returns the name of the matcher. +func (multiportMatcherV1) name() string { + return matcherNameMultiport +} + +// revision returns the revision number of the matcher. +func (multiportMatcherV1) revision() uint8 { + return matcherRevMultiportV1 +} + +// Match determines if the packet matches any of the specified ports +// and returns true if a match is found. The second boolean returned +// indicates whether the packet should be "hot" dropped, or processed +// with other matchers. +func (m *multiportMatcherV1) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { + // Extract source and destination ports from the packet. + srcPort, dstPort, ok := extractPorts(pkt) + // The packet does not contain valid transport + // headers or uses an unsupported protocol. + if !ok { + return false, true + } + + // Iterate through the list of ports to check for a match based on + // the specified match criteria: source, destination or either. + i := uint8(0) + for i < m.count { + exact := (m.pflags[i] == 0) + + // This is unlikely, but if range match is enabled for the + // last port in the list, treat it as an exact port match. + if i == (m.count - 1) { + exact = true + } + + if exact { + // Exact port match. + if exactPortMatch(m.flags, srcPort, dstPort, m.ports[i]) { + return (true != m.invert), false + } + + i++ + continue + } + + if rangedPortMatch(m.flags, srcPort, dstPort, m.ports[i], m.ports[i+1]) { + return (true != m.invert), false + } + i += 2 + } + + // No match; invert if needed. + return (false != m.invert), false +} + +// rangedPortMatch return true if "srcPort" or "dstPort" are +// the same in the range of "matchPort{Beg,End}" depending on +// the matching criteria specified in "flags". +func rangedPortMatch(flags uint8, srcPort, dstPort, begPort, endPort uint16) bool { + minPort, maxPort := min(begPort, endPort), max(begPort, endPort) + srcPortMatch := (srcPort >= minPort) && (srcPort <= maxPort) + dstPortMatch := (dstPort >= minPort) && (dstPort <= maxPort) + + switch flags { + case linux.XT_MULTIPORT_SOURCE: + return srcPortMatch + case linux.XT_MULTIPORT_DESTINATION: + return dstPortMatch + case linux.XT_MULTIPORT_EITHER: + return srcPortMatch || dstPortMatch + } + + return false +} diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index c7827fb662..31bba396ee 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -310,11 +310,14 @@ func parseMatchers(mapper IDMapper, filter stack.IPHeaderFilter, optVal []byte) return nil, fmt.Errorf("optVal has insufficient size for match: %d", len(optVal)) } - // Parse the specific matcher. - matcher, err := unmarshalMatcher(mapper, match, filter, optVal[linux.SizeOfXTEntryMatch:match.MatchSize]) + // Starting with the highest supported revision, try to unmarshal + // with each revision down to 0; if all revisions fail, give up. + matcher, err := unmarshalMatcherRevs(mapper, &match, filter, optVal) if err != nil { - return nil, fmt.Errorf("failed to create matcher: %v", err) + return nil, fmt.Errorf("failed to create matcher: %v", match) } + + nflog("set entries: found matcher for: %+v", match) matchers = append(matchers, matcher) // TODO(gvisor.dev/issue/6167): Check the revision field. @@ -328,6 +331,45 @@ func parseMatchers(mapper IDMapper, filter stack.IPHeaderFilter, optVal []byte) return matchers, nil } +// unmarshalMatcherRevs tries to unmarshal matchers with the same name, +// starting with the highest revision down to 0. If all revisions fail, +// it returns the most recent (lowest revision's) "unmarshalMatcher" +// error. +func unmarshalMatcherRevs(mapper IDMapper, match *linux.XTEntryMatch, filter stack.IPHeaderFilter, optVal []byte) (stack.Matcher, error) { + var ( + matcher stack.Matcher + err error + ) + + // Get the highest supported revision for the matcher. + maxRev, found := matchMakerRevision(match.Name.String(), 0) + if !found { + return nil, fmt.Errorf( + "unmarshalMatcherRevs: failed to find matcher with name: %s", + match.Name, + ) + } + + for maxRev >= 0 { + match.Revision = maxRev + + nflog("unmarshalMatcherRevs: attempting to find matcher: %+v", match) + matcher, err = unmarshalMatcher( + mapper, match, filter, + optVal[linux.SizeOfXTEntryMatch:match.MatchSize], + ) + + // A match was found. + if err == nil { + break + } + + maxRev-- + } + + return matcher, err +} + func validUnderflow(rule stack.Rule, ipv6 bool) bool { if len(rule.Matchers) != 0 { return false @@ -367,6 +409,31 @@ func hookFromLinux(hook int) stack.Hook { panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook)) } +// MatchRevision returns a "linux.XTGetRevision" for a given +// matcher. It sets "Revision" to the highest supported value, +// unless the provided revision number is higher. +func MatchRevision(t *kernel.Task, revPtr hostarch.Addr) (linux.XTGetRevision, *syserr.Error) { + // Read in the matcher name and version. + var rev linux.XTGetRevision + + if _, err := rev.CopyIn(t, revPtr); err != nil { + return linux.XTGetRevision{}, syserr.FromError(err) + } + + maxSupported, ok := matchMakerRevision(rev.Name.String(), rev.Revision) + if !ok { + // Return ENOENT if there's no matcher with that name. + return linux.XTGetRevision{}, syserr.ErrNoFileOrDir + } + + if maxSupported < rev.Revision { + // Return EPROTONOSUPPORT if we have an insufficient revision. + return linux.XTGetRevision{}, syserr.ErrProtocolNotSupported + } + + return rev, nil +} + // TargetRevision returns a linux.XTGetRevision for a given target. It sets // Revision to the highest supported value, unless the provided revision number // is larger. diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 808fe1d2da..2277b92abe 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1549,6 +1549,29 @@ func getSockOptIPv6(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int } return &entries, nil + case linux.IP6T_SO_GET_REVISION_MATCH: + if outLen < linux.SizeOfXTGetRevision { + return nil, syserr.ErrInvalidArgument + } + + // Only valid for raw IPv6 sockets. + if skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stk := inet.StackFromContext(t) + if stk == nil { + return nil, syserr.ErrNoDevice + } + + // Get the highest support matcher revision. + ret, err := netfilter.MatchRevision(t, outPtr) + if err != nil { + return nil, err + } + + return &ret, nil + case linux.IP6T_SO_GET_REVISION_TARGET: if outLen < linux.SizeOfXTGetRevision { return nil, syserr.ErrInvalidArgument @@ -1754,6 +1777,30 @@ func getSockOptIP(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int, } return &entries, nil + case linux.IPT_SO_GET_REVISION_MATCH: + if outLen < linux.SizeOfXTGetRevision { + return nil, syserr.ErrInvalidArgument + } + + // Only valid for raw IPv4 sockets. + family, skType, _ := s.Type() + if family != linux.AF_INET || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stk := inet.StackFromContext(t) + if stk == nil { + return nil, syserr.ErrNoDevice + } + + // Get the highest support matcher revision. + ret, err := netfilter.MatchRevision(t, outPtr) + if err != nil { + return nil, err + } + + return &ret, nil + case linux.IPT_SO_GET_REVISION_TARGET: if outLen < linux.SizeOfXTGetRevision { return nil, syserr.ErrInvalidArgument diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go index a0545e2cd3..cd08e042d0 100644 --- a/test/iptables/filter_input.go +++ b/test/iptables/filter_input.go @@ -59,6 +59,8 @@ func init() { RegisterTestCase(&FilterInputInterfaceInvertAccept{}) RegisterTestCase(&FilterInputInvertDportAccept{}) RegisterTestCase(&FilterInputInvertDportDrop{}) + RegisterTestCase(&FilterInputDropAllSrcPorts{}) + RegisterTestCase(&FilterInputDropAllExceptOneDstPort{}) } // FilterInputDropUDP tests that we can drop UDP traffic. @@ -1056,3 +1058,185 @@ func (*FilterInputInvertDportDrop) LocalAction(ctx context.Context, ip net.IP, i return nil } + +// FilterInputDropAllSrcPorts tests that all TCP packets, regardless +// of source port, are dropped. The rule covers all the source ports +// so that no incoming TCP packet on INPUT is accepted. +// +// Rule(s): +// +// -A INPUT -p tcp -m multiport --sports 0,1,2:32000,32001:65535 -j DROP +type FilterInputDropAllSrcPorts struct { + containerCase +} + +var _ TestCase = (*FilterInputDropAllSrcPorts)(nil) + +// Name implements TestCase.Name. +func (*FilterInputDropAllSrcPorts) Name() string { + return "FilterInputDropAllSrcPorts" +} + +// ContainerAction implements TestCase.ContainerAction. +// The container will then attempt to receive a UDP packet, +// which should never arrive due to the DROP rule. +func (*FilterInputDropAllSrcPorts) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Add the multiport rule that drops all TCP packets from any source port. + err := filterTable( + ipv6, + "-A", "INPUT", "-p", "tcp", "-m", "multiport", + "--sports", "0,1,2:32000,32001:65535", "-j", "DROP", + ) + if err != nil { + return err + } + + testPort := 42 + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + + // listenTCP attempts to receive a TCP packet. Since all + // TCP packets are dropped, it should time out and return + // an error (DeadlineExceeded). + err = listenTCP(timedCtx, testPort, ipv6) + if err == nil { + return fmt.Errorf("unexpected receive on port: %d", testPort) + } + + if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("expected timeout error, vut got: %w", err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +// It tries to connect to the container's test port, but the +// DROP rule ensures the packet never arrives at the port. +func (*FilterInputDropAllSrcPorts) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + testPort := 42 + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + + if err := connectTCP(timedCtx, ip, testPort, ipv6); err == nil { + return fmt.Errorf( + "expected connect failure on port: %d", + testPort, + ) + } + + return nil +} + +// FilterInputDropAllExceptOneDstPort tests that only packets destined +// to a specific port are accepted, while connections to any other port +// are dropped. The rule uses a negated multiport destination port +// specification to allow only one port. +// +// Rule(s): +// +// -P INPUT DROP +// -A INPUT -p tcp -m multiport ! --dports 0:442,444:32000,32001:65535 -j ACCEPT +type FilterInputDropAllExceptOneDstPort struct { + containerCase +} + +var _ TestCase = (*FilterInputDropAllExceptOneDstPort)(nil) + +// Name implements TestCase.Name. +func (*FilterInputDropAllExceptOneDstPort) Name() string { + return "FilterInputDropAllExceptOneDstPort" +} + +// ContainerAction implements TestCase.ContainerAction. +// It installs a catch-all DROP policy for the input chain and a single +// ACCEPT rule for packets destined to the allowed port. The container +// listens on allowed and blocked ports; only the former should receive +// a connection. +func (*FilterInputDropAllExceptOneDstPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Add the multiport rule that allows inbound on 443 only. + rules := [][]string{ + {"-A", "INPUT", "-p", "tcp", "-m", "multiport", + "!", "--dports", "0:442,444:32000,32001:65535", "-j", "ACCEPT"}, + {"-P", "INPUT", "DROP"}, + } + + if err := filterTableRules(ipv6, rules); err != nil { + return err + } + + allowedPort := 443 + blockedPort := 80 + errCh := make(chan error, 2) + + // Listen on port allowed port. + go func() { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + + if err := listenTCP(timedCtx, allowedPort, ipv6); err != nil { + errCh <- fmt.Errorf( + "unexpected error on allowed port: %s, got: %w", + allowedPort, err, + ) + return + } + errCh <- nil + }() + + // Listen on blocked port. + go func() { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + + err := listenTCP(timedCtx, blockedPort, ipv6) + if err == nil { + // Should not receive any traffic. + errCh <- fmt.Errorf("unexpected receive on port: %d", blockedPort) + return + } + + if !errors.Is(err, context.DeadlineExceeded) { + errCh <- fmt.Errorf( + "expected timeout error on port: %d, but got: %w", + blockedPort, err, + ) + return + } + + errCh <- nil + }() + + // Wait for both listeners. + for i := 0; i < 2; i++ { + if err := <-errCh; err != nil { + return err + } + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +// It connects to both the allowed port and the +// blocked port, only the former should succeed. +func (*FilterInputDropAllExceptOneDstPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + allowedPort := 443 + blockedPort := 80 + + // Connect to allowed port. + allowTimedCtx, allowCancel := context.WithTimeout(ctx, NegativeTimeout) + defer allowCancel() + if err := connectTCP(allowTimedCtx, ip, allowedPort, ipv6); err != nil { + return fmt.Errorf("failed to connect on port %d: %w", allowedPort, err) + } + + // Connect to blocked port. + blockTimedCtx, blockCancel := context.WithTimeout(ctx, NegativeTimeout) + defer blockCancel() + if err := connectTCP(blockTimedCtx, ip, blockedPort, ipv6); err == nil { + return fmt.Errorf("expected connect error on port: %d", blockedPort) + } + + return nil +} diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go index b05c8a79bf..674fd8ab9b 100644 --- a/test/iptables/filter_output.go +++ b/test/iptables/filter_output.go @@ -44,8 +44,15 @@ func init() { RegisterTestCase(&FilterOutputInterfaceInvertAccept{}) RegisterTestCase(&FilterOutputInvertSportAccept{}) RegisterTestCase(&FilterOutputInvertSportDrop{}) + RegisterTestCase(&FilterOutputAcceptInvertSrcPorts{}) + RegisterTestCase(&FilterOutputDropSrcPorts{}) + RegisterTestCase(&FilterOutputAcceptInvertPorts{}) } +// multiportPortCountLimit is the maximum number of +// ports that can be specified for a multiport match. +const multiportPortCountLimit = 15 + // FilterOutputDropTCPDestPort tests that connections are not accepted on // specified source ports. type FilterOutputDropTCPDestPort struct{ baseCase } @@ -780,3 +787,273 @@ func (*FilterOutputInvertSportDrop) LocalAction(ctx context.Context, ip net.IP, return nil } + +// FilterOutputAcceptInvertSrcPorts tests that all UDP outbound connections +// are allowed except those going to specific source ports. The rule uses +// a negated multiport match to ACCEPT traffic for any destination port not +// listed. +// +// Rule(s): +// +// -A OUTPUT -p udp -m multiport ! --sports 53,15008,32000 -j ACCEPT +type FilterOutputAcceptInvertSrcPorts struct { + containerCase +} + +var _ TestCase = (*FilterOutputAcceptInvertSrcPorts)(nil) + +// Name implements TestCase.Name. +func (*FilterOutputAcceptInvertSrcPorts) Name() string { + return "FilterOutputAcceptInvertSrcPorts" +} + +// ContainerAction implements TestCase.ContainerAction. +// It installs the single ACCEPT rule with negation and then +// attempts to connect to a local UDP server listening on a +// blocked port. +func (*FilterOutputAcceptInvertSrcPorts) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Add the multiport rule that accepts all but the blocked ports. + err := filterTable( + ipv6, + "-A", "OUTPUT", "-p", "udp", "-m", "multiport", + "!", "--sports", "53,15008", "-j", "ACCEPT", + ) + if err != nil { + return err + } + + testPort := 53 + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + + // No response will be sent. + if err = listenUDP(timedCtx, testPort, ipv6); err != nil { + if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("expected timeout error, vut got: %w", err) + } + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +// It attempts to connect to the container on the specified port. +// Since the container cannot send back responses, the connection +// attempt will fail or time out. +func (*FilterOutputAcceptInvertSrcPorts) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + testPort := 53 + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + + // This should time out. + if err := sendUDPLoop(timedCtx, ip, testPort, ipv6); err == nil { + return fmt.Errorf("expected connect failure on port: %d", testPort) + } + + return nil +} + +// FilterOutputDropSrcPorts tests that any TCP packet leaving the +// container from a source port in set is dropped, preventing the +// container from making outbound responses on these ports. +// +// Rule(s): +// +// -A OUTPUT -p tcp -m multiport --sports 22,53,80:443 -j DROP +type FilterOutputDropSrcPorts struct { + containerCase +} + +var _ TestCase = (*FilterOutputDropSrcPorts)(nil) + +// Name implements TestCase.Name. +func (*FilterOutputDropSrcPorts) Name() string { + return "FilterOutputDropSrcPorts" +} + +// ContainerAction implements TestCase.ContainerAction. +// It installs the DROP rule for outbound packets with the specified +// source ports. The container then listens on those ports, expecting +// connection attempts from the local side. Because responses from +// these ports are dropped, no handshake completes. +func (*FilterOutputDropSrcPorts) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Add the DROP rule for outbound packets from the specified source ports. + err := filterTable( + ipv6, + "-A", "OUTPUT", "-p", "tcp", "-m", "multiport", + "--sports", "22,53,80:443", "-j", "DROP", + ) + if err != nil { + return err + } + + // Listen on a set of ports within the blocked range. + ports := []int{22, 53, 80, 443} + errCh := make(chan error, len(ports)) + + for _, p := range ports { + go func(port int) { + // Attempt to accept connections. Even if an inbound + // connection is created, it won't receive are reply. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + + if err := listenTCP(timedCtx, port, ipv6); err != nil { + if !errors.Is(err, context.DeadlineExceeded) { + errCh <- fmt.Errorf( + "unexpected error on port %d: %w", + port, err, + ) + return + } + } + // Timing out or no successful connection is expected. + errCh <- nil + }(p) + } + + // Wait for all listeners to report. + for i := 0; i < len(ports); i++ { + if err := <-errCh; err != nil { + return err + } + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +// It attempts to connect to the container on each of the blocked +// source ports. Since the container cannot send back responses, +// the connection attempts will fail or time out. +func (*FilterOutputDropSrcPorts) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + ports := []int{22, 53, 80, 443} + errCh := make(chan error, len(ports)) + + for _, p := range ports { + go func(port int) { + // Attempt to connect, but it will time out. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + + if err := connectTCP(timedCtx, ip, port, ipv6); err == nil { + errCh <- fmt.Errorf( + "expected timout error on port %d, but got: %w", + port, err, + ) + return + } + errCh <- nil + }(p) + } + + // Wait for all client to report. + for i := 0; i < len(ports); i++ { + if err := <-errCh; err != nil { + return err + } + } + + return nil +} + +// FilterOutputAcceptInvertPorts tests a negation of either ports +// matching on OUTPUT. The rule accepts all UDP packets if either +// their source and destination ports fall into the matched set. +// +// Rule(s): +// +// -A OUTPUT -p tcp -m multiport ! --ports 22,53:80,443 -j ACCEPT +type FilterOutputAcceptInvertPorts struct { + containerCase +} + +var _ TestCase = (*FilterOutputAcceptInvertPorts)(nil) + +// Name implements TestCase.Name. +func (*FilterOutputAcceptInvertPorts) Name() string { + return "FilterOutputAcceptInvertPorts" +} + +// ContainerAction implements TestCase.ContainerAction. +// It installs the single ACCEPT rule with negation. The container then +// listens on those ports, expecting connection attempts from the local +// side, which will all succeed. +func (*FilterOutputAcceptInvertPorts) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + err := filterTable( + ipv6, + "-A", "OUTPUT", "-p", "tcp", "-m", "multiport", + "!", "--ports", "53:80,22,443", "-j", "ACCEPT", + ) + if err != nil { + return err + } + + // Even though some of the ports belong to the inverted set, the + // combination of the source and destination port will not match. + // Since the listener ports "low" ports, the chances of this failing + // is low. + testPorts := []int{22, 27017} + errCh := make(chan error, len(testPorts)) + + for _, p := range testPorts { + go func(port int) { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + + if err := listenTCP(timedCtx, port, ipv6); err != nil { + errCh <- fmt.Errorf( + "unexpected error on allowed port: %s, got: %w", + port, err, + ) + return + } + + errCh <- nil + }(p) + } + + // Wait for listeners. + for i := 0; i < 2; i++ { + if err := <-errCh; err != nil { + return err + } + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +// It attempts to connect to the container on each ports +// being listened on. Since this is an either port match, +// both connections should succeed. +func (*FilterOutputAcceptInvertPorts) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + testPorts := []int{22, 27017} + errCh := make(chan error, len(testPorts)) + + // All connections should succeed. + for _, p := range testPorts { + go func(port int) { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, port, ipv6); err != nil { + errCh <- fmt.Errorf( + "failed to connect on port %d: %w", + port, err, + ) + return + } + + errCh <- nil + }(p) + } + + // Wait for clients. + for i := 0; i < len(testPorts); i++ { + if err := <-errCh; err != nil { + return err + } + } + + return nil +} diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 2d36fc5c01..d5fcccca06 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -498,3 +498,22 @@ func TestNATPostSNATUDP(t *testing.T) { func TestNATPostSNATTCP(t *testing.T) { singleTest(t, &NATPostSNATTCP{}) } + +func TestFilterInputDropAllSrcPorts(t *testing.T) { + singleTest(t, &FilterInputDropAllSrcPorts{}) +} +func TestFilterInputDropAllExceptOneDstPort(t *testing.T) { + singleTest(t, &FilterInputDropAllExceptOneDstPort{}) +} + +func TestFilterOutputAcceptInvertSrcPorts(t *testing.T) { + singleTest(t, &FilterOutputAcceptInvertSrcPorts{}) +} + +func TestFilterOutputDropSrcPorts(t *testing.T) { + singleTest(t, &FilterOutputDropSrcPorts{}) +} + +func TestFilterOutputAcceptInvertPorts(t *testing.T) { + singleTest(t, &FilterOutputAcceptInvertPorts{}) +}