diff --git a/README.md b/README.md index 740a8c5..f497e72 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ In this diagram, the client has generated and installed WireGuard configuration 1. Download binaries from the [releases](https://github.com/sandialabs/wiretap/releases) page, one for your client machine and one for your server (if different os/arch) 2. Run `./wiretap configure --port --endpoint --routes ` with the appropriate arguments -3. Import the resulting `wiretap_relay.conf` and `wiretap_e2ee.conf` files into WireGuard on the client machine +3. Import the resulting `wiretap.conf` and `wiretap_relay.conf` files into WireGuard on the client machine 4. Copy and paste the server command output that best suits your target system into Wiretap on the server machine 5. Add more servers and clients as needed with the `add` subcommand @@ -79,7 +79,7 @@ PublicKey = kMj7HwfYYFO/XEHNFK2kz9cBd7vTHk63fhygyuYLMzI= AllowedIPs = 172.17.0.0/24,fd:17::/48 ──────────────────────────────── -config: wiretap_e2ee.conf +config: wiretap.conf ──────────────────────────────── [Interface] PrivateKey = YCTRVwB4xOEcBtifVmhjMhRYL7+DOlDP5VdHZGclZGg= @@ -105,12 +105,12 @@ Config File: ./wiretap serve -f wiretap_server.conf > **Note** > Wiretap uses 2 WireGuard interfaces per node in order to safely and scalably chain together servers. This means your client will bind to more than one port, but only the Relay Interface port needs to be accessible by the Server. See the [How It Works](#how-it-works) section for details. Use `--simple` if your setup requires a single interface on the client -Install the resulting config either by copying and pasting the output or by importing the new `wiretap_relay.conf` and `wiretap_e2ee.conf` files into WireGuard: +Install the resulting config either by copying and pasting the output or by importing the new `wiretap.conf` and `wiretap_relay.conf` files into WireGuard: * If using a GUI, select the menu option similar to *Import Tunnel(s) From File* -* If you have `wg-quick` installed, `sudo wg-quick up ./wiretap_relay.conf` and `sudo wg-quick up ./wiretap_e2ee.conf` +* If you have `wg-quick` installed, `sudo wg-quick up ./wiretap.conf` and `sudo wg-quick up ./wiretap_relay.conf` -Don't forget to disable or remove the tunnels when you're done (e.g., `sudo wg-quick down ./wiretap_relay.conf` and `sudo wg-quick down ./wiretap_e2ee.conf`) +Don't forget to disable or remove the tunnels when you're done (e.g., `sudo wg-quick down ./wiretap.conf` and `sudo wg-quick down ./wiretap_relay.conf`) ### Deploy @@ -176,7 +176,7 @@ If you plan to attach a server directly to the client, the status command just c Configurations successfully generated. Import the updated config(s) into WireGuard locally and pass the arguments below to Wiretap on the new remote server. -config: wiretap_e2ee.conf +config: wiretap.conf ──────────────────────────────── [Interface] PrivateKey = YCTRVwB4xOEcBtifVmhjMhRYL7+DOlDP5VdHZGclZGg= @@ -203,7 +203,7 @@ POSIX Shell: WIRETAP_RELAY_INTERFACE_PRIVATEKEY=sLERnxT2+VdwwcJOTUHK5fa5sIN7oJ1 Config File: ./wiretap serve -f wiretap_server_1.conf ``` -The client's E2EE configuration will be modified, so you need to reimport it. For example, `wg-quick down ./wiretap_e2ee.conf` and `wg-quick up ./wiretap_e2ee.conf`. If you are attaching a server directly to the client, the Relay interface will also need to be refreshed. +The client's E2EE configuration will be modified, so you need to reimport it. For example, `wg-quick down ./wiretap.conf` and `wg-quick up ./wiretap.conf`. If you are attaching a server directly to the client, the Relay interface will also need to be refreshed. Now you can use any of the server command options to deploy Wiretap to the new server. It will then connect to the already existing server. @@ -258,7 +258,7 @@ The `add client` subcommand can be used to share access to the Wiretap network w > **Note** > All servers must be deployed *before* adding additional clients -Adding a client is very similar to the other commands. It will generate a `wiretap_relay.conf` and `wiretap_e2ee.conf` for sharing. Make sure that all of the first-hop servers (any server directly attached to the original client) can reach or be reached by the new client. Once you get the endpoint information from whoever will be running the new client run: +Adding a client is very similar to the other commands. It will generate a `wiretap.conf` and `wiretap_relay.conf` for sharing. Make sure that all of the first-hop servers (any server directly attached to the original client) can reach or be reached by the new client. Once you get the endpoint information from whoever will be running the new client run: ```bash ./wiretap add client --port 1337 --endpoint 1.3.3.8:1337 @@ -280,7 +280,7 @@ PublicKey = kMj7HwfYYFO/XEHNFK2kz9cBd7vTHk63fhygyuYLMzI= AllowedIPs = 172.17.0.0/24,fd:17::/48 ──────────────────────────────── -config: wiretap_e2ee_1.conf +config: wiretap_1.conf ──────────────────────────────── [Interface] PrivateKey = 8AhL1kDjwBn/IoY4KLd5mMP4GQsyMYNsqYm3aM/bHnE= @@ -303,6 +303,46 @@ Endpoint = 172.17.0.3:51821 Send these files and have the recipient import them into WireGuard to have access to everything in the Wiretap network! By default the routes (AllowedIPs) are copied over, but can be modified by the recipient as needed. +### Port Forwarding + +> **Warning** +> Port forwarding exposes services on your local machine to the remote network, use with caution + +You can expose a service on the client by using the `expose` subcommand. For example, to allow remote systems to access port 80/tcp on your local machine, you could run: + +``` +./wiretap expose --local 80 --remote 8080 +``` + +Now all Wiretap servers will be bound to port 8080/tcp and proxy connections to your services on port 80/tcp. By default this uses IPv6, so make sure any listening services support IPv6 as well. +To configure Wiretap to only use IPv4, use the `configure` subcommand's `--disable-ipv6` option. + +To dynamically forward all ports using SOCKS5: + +``` +./wiretap expose --dynamic --remote 8080 +``` + +All servers will spin up a SOCKS5 server on port 8080 and proxy traffic to your local machine and can be used like this: + +``` +curl -x socks5://:8080 http://:1337 +``` + +The destination IP will be rewritten by the server so you can put any address. + +#### List + +Use `./wiretap expose list` to see all forwarding rules currently configured. + +#### Remove + +Use `./wiretap remove` with the same arguments used in `expose` to delete a rule. For example, to remove the SOCKS5 example above: + +``` +./wiretap expose remove --dynamic --remote 8080 +``` + ## How It Works A traditional VPN can't be installed by unprivileged users because VPNs rely on dangerous operations like changing network routes and working with raw packets. @@ -329,6 +369,7 @@ Usage: Available Commands: add Add peer to wiretap configure Build wireguard config + expose Expose local services to servers help Help about any command ping Ping wiretap server API serve Listen and proxy traffic into target network @@ -353,9 +394,12 @@ Use "wiretap [command] --help" for more information about a command. - TCP - Transparent connections - RST response when port is unreachable + - Reverse Port Forward + - Reverse Socks5 Support - UDP - Transparent "connections" - ICMP Destination Unreachable when port is unreachable + - Reverse Port Forward * Application - API internal to Wiretap for dynamic configuration - Chain servers together to tunnel traffic through an arbitrary number of machines @@ -460,7 +504,7 @@ Install the newly created WireGuard configs with: ```bash wg-quick up ./wiretap_relay.conf -wg-quick up ./wiretap_e2ee.conf +wg-quick up ./wiretap.conf ``` Copy and paste the Wiretap arguments printed by the configure command into the server machine prompt. It should look like this: @@ -540,7 +584,7 @@ To bring down the WireGuard interfaces on the client machine, run: ```bash wg-quick down ./wiretap_relay.conf -wg-quick down ./wiretap_e2ee.conf +wg-quick down ./wiretap.conf ``` ## Experimental diff --git a/demo.tape b/demo.tape index fccc2c0..40e0e53 100644 --- a/demo.tape +++ b/demo.tape @@ -43,6 +43,8 @@ # # Run `socat TCP-LISTEN:6000,reuseaddr,fork UNIX-CLIENT:\"$DISPLAY\"` before recording to enable clipboard operations # If using XQuartz, also run `xhost + localhost` +# +# Postprocess with `ffmpeg -an -i wiretap_demo.mp4 -vf "scale=1600:-1,fps=30" -c:v libx264 -preset slow -crf 28 output.mp4` Output media/wiretap_demo.mp4 @@ -103,7 +105,7 @@ Sleep 2s Type "curl http://10.2.0.4 --connect-timeout 3" Sleep 1s Enter Sleep 6s Type "./wiretap configure --endpoint 10.1.0.2:51820 --routes 10.2.0.0/16,fd:2::/64 -c" Sleep 1s Enter Sleep 4s Type "wg-quick up ./wiretap_relay.conf" Sleep 1s Enter Sleep 2s -Type "wg-quick up ./wiretap_e2ee.conf" Sleep 1s Enter Sleep 2s +Type "wg-quick up ./wiretap.conf" Sleep 1s Enter Sleep 2s Ctrl+b Left diff --git a/docker-compose.yml b/docker-compose.yml index 8eeaee4..817c8f6 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -55,6 +55,8 @@ services: depends_on: - client image: wiretap:latest + ports: + - "6060:6060" networks: exposed: ipv4_address: 10.1.0.3 diff --git a/media/wiretap_demo.mp4 b/media/wiretap_demo.mp4 index 8def80d..65e905c 100644 Binary files a/media/wiretap_demo.mp4 and b/media/wiretap_demo.mp4 differ diff --git a/src/api/api.go b/src/api/api.go index 9490dac..b7adadc 100644 --- a/src/api/api.go +++ b/src/api/api.go @@ -155,3 +155,74 @@ func AddAllowedIPs(apiAddr netip.AddrPort, pubKey wgtypes.Key, allowedIPs []net. return err } + +func Expose(apiAddr netip.AddrPort, localPort uint, remotePort uint, protocol string, dynamic bool) error { + req := serverapi.ExposeRequest{ + Action: serverapi.ExposeActionExpose, + LocalPort: localPort, + RemotePort: remotePort, + Protocol: protocol, + Dynamic: dynamic, + } + + body, err := json.Marshal(req) + if err != nil { + return err + } + + _, err = makeRequest(request{ + URL: makeUrl(apiAddr, "expose", []string{}), + Method: "POST", + Body: body, + }) + + return err +} + +func ExposeList(apiAddr netip.AddrPort) ([]serverapi.ExposeTuple, error) { + req := serverapi.ExposeRequest{ + Action: serverapi.ExposeActionList, + } + + body, err := json.Marshal(req) + if err != nil { + return nil, err + } + + body, err = makeRequest(request{ + URL: makeUrl(apiAddr, "expose", []string{}), + Method: "POST", + Body: body, + }) + if err != nil { + return nil, err + } + + var tuples []serverapi.ExposeTuple + err = json.Unmarshal(body, &tuples) + + return tuples, err +} + +func ExposeDelete(apiAddr netip.AddrPort, localPort uint, remotePort uint, protocol string, dynamic bool) error { + req := serverapi.ExposeRequest{ + Action: serverapi.ExposeActionDelete, + LocalPort: localPort, + RemotePort: remotePort, + Protocol: protocol, + Dynamic: dynamic, + } + + body, err := json.Marshal(req) + if err != nil { + return err + } + + _, err = makeRequest(request{ + URL: makeUrl(apiAddr, "expose", []string{}), + Method: "POST", + Body: body, + }) + + return err +} diff --git a/src/cmd/add_client.go b/src/cmd/add_client.go index cdb9566..b91c2c9 100644 --- a/src/cmd/add_client.go +++ b/src/cmd/add_client.go @@ -70,15 +70,29 @@ func (c addClientCmdConfig) Run() { addresses, err := api.AllocateClientNode(apiAddrPort) check("failed to retrieve address allocation from server", err) + disableV6 := false + if len(baseConfigE2EE.GetPeers()[0].GetAllowedIPs()) < 3 { + disableV6 = true + } + // Make new configs for client. + relayAddrs := []string{addresses.NextClientRelayAddr4.String() + "/32"} + if !disableV6 { + relayAddrs = append(relayAddrs, addresses.NextClientRelayAddr6.String()+"/128") + } clientConfigRelay, err := peer.GetConfig(peer.ConfigArgs{ ListenPort: addCmdArgs.port, - Addresses: []string{addresses.NextClientRelayAddr4.String() + "/32", addresses.NextClientRelayAddr6.String() + "/128"}, + Addresses: relayAddrs, }) check("failed to generate client relay config", err) + + e2eeAddrs := []string{addresses.NextClientE2EEAddr4.String() + "/32"} + if !disableV6 { + e2eeAddrs = append(e2eeAddrs, addresses.NextClientE2EEAddr6.String()+"/128") + } clientConfigE2EE, err := peer.GetConfig(peer.ConfigArgs{ ListenPort: E2EEPort, - Addresses: []string{addresses.NextClientE2EEAddr4.String() + "/32", addresses.NextClientE2EEAddr6.String() + "/128"}, + Addresses: e2eeAddrs, MTU: c.mtu - 80, }) check("failed to generate relay e2ee config", err) diff --git a/src/cmd/add_server.go b/src/cmd/add_server.go index b830db8..0695ee1 100644 --- a/src/cmd/add_server.go +++ b/src/cmd/add_server.go @@ -76,6 +76,14 @@ func (c addServerCmdConfig) Run() { serverConfigE2EE, err := peer.GetConfig(peer.ConfigArgs{}) check("failed to generate server e2ee config", err) + // Set APIBits based on v4/v6 address. + disableV6 := false + apiAddr := clientConfigE2EE.GetPeers()[0].GetApiAddr() + if apiAddr.Is4() { + disableV6 = true + APIBits = APIV4Bits + } + // Connect new server directly to client if no server address was provided. // This branch sets up the server and client configs like `configure` does. if len(c.serverAddress) == 0 { @@ -86,18 +94,12 @@ func (c addServerCmdConfig) Run() { } newRelayPrefixes := peer.GetNextPrefixesForPeers(relayPeers) - if len(newRelayPrefixes) != 2 { - check("failed to get next relay prefixes", errors.New("need two relay prefixes")) - } - - // Find next API subnet + basePrefix := netip.PrefixFrom(apiAddr, APIBits).Masked() e2eePeers := clientConfigE2EE.GetPeers() - baseAllowedIPs := e2eePeers[0].GetAllowedIPs() - basePrefix := netip.PrefixFrom(netip.MustParsePrefix(baseAllowedIPs[len(baseAllowedIPs)-1].String()).Addr(), APIBits).Masked() for _, p := range e2eePeers { - prefixes := p.GetAllowedIPs() + apiAddr := p.GetApiAddr() - apiPrefix := netip.PrefixFrom(netip.MustParsePrefix(prefixes[len(prefixes)-1].String()).Addr(), APIBits).Masked() + apiPrefix := netip.PrefixFrom(apiAddr, APIBits).Masked() if basePrefix.Addr().Less(apiPrefix.Addr()) { basePrefix = apiPrefix } @@ -133,7 +135,7 @@ func (c addServerCmdConfig) Run() { clientConfigRelay.AddPeer(serverRelayPeer) // Add new server as E2EE peer. - c.allowedIPs = append(c.allowedIPs, apiPrefix.Addr().Next().Next().String()+"/128") + c.allowedIPs = append(c.allowedIPs, fmt.Sprintf("%s/%d", apiPrefix.Addr().Next().Next().String(), apiPrefix.Addr().BitLen())) serverE2EEPeer, err := peer.GetPeerConfig(peer.PeerConfigArgs{ PublicKey: serverConfigE2EE.GetPublicKey(), AllowedIPs: c.allowedIPs, @@ -160,9 +162,14 @@ func (c addServerCmdConfig) Run() { serverConfigRelay.AddPeer(clientPeerConfigRelay) serverConfigE2EE.AddPeer(clientPeerConfigE2EE) - err = serverConfigRelay.SetAddresses([]string{newRelayPrefixes[0].Addr().Next().Next().String() + "/32", newRelayPrefixes[1].Addr().Next().Next().String() + "/128"}) + relayAddrs := []string{newRelayPrefixes[0].Addr().Next().Next().String() + "/32"} + if !disableV6 { + relayAddrs = append(relayAddrs, newRelayPrefixes[1].Addr().Next().Next().String()+"/128") + } + err = serverConfigRelay.SetAddresses(relayAddrs) check("failed to set addresses", err) - err = serverConfigE2EE.SetAddresses([]string{apiPrefix.Addr().Next().Next().String() + "/128"}) + + err = serverConfigE2EE.SetAddresses([]string{fmt.Sprintf("%s/%d", apiPrefix.Addr().Next().Next().String(), apiPrefix.Addr().BitLen())}) check("failed to set addresses", err) } else { // Get leaf server info @@ -178,9 +185,7 @@ func (c addServerCmdConfig) Run() { leafApiPrefix := netip.PrefixFrom(leafApiAddr, APIBits) apiAddr := leafApiAddr for _, p := range clientConfigE2EE.GetPeers() { - aps := p.GetAllowedIPs() - aa := netip.MustParsePrefix(aps[len(aps)-1].String()).Addr() - + aa := p.GetApiAddr() if leafApiPrefix.Contains(aa) && aa.Less(apiAddr) { apiAddr = aa } @@ -205,13 +210,17 @@ func (c addServerCmdConfig) Run() { check("failed to set endpoint", err) } } - err = leafServerPeerConfigRelay.SetAllowedIPs([]string{ClientRelaySubnet4.String(), ClientRelaySubnet6.String()}) + relayAddrs := []string{ClientRelaySubnet4.String()} + if !disableV6 { + relayAddrs = append(relayAddrs, ClientRelaySubnet6.String()) + } + err = leafServerPeerConfigRelay.SetAllowedIPs(relayAddrs) check("failed to set allowedIPs", err) serverConfigRelay.AddPeer(leafServerPeerConfigRelay) serverConfigE2EE.AddPeer(clientPeerConfigE2EE) // Make E2EE peer for local config. - c.allowedIPs = append(c.allowedIPs, addresses.ApiAddr.String()+"/128") + c.allowedIPs = append(c.allowedIPs, fmt.Sprintf("%s/%d", addresses.ApiAddr.String(), addresses.ApiAddr.BitLen())) serverPeerConfigE2EE, err := peer.GetPeerConfig(peer.PeerConfigArgs{ PublicKey: serverConfigE2EE.GetPublicKey(), AllowedIPs: c.allowedIPs, @@ -221,9 +230,13 @@ func (c addServerCmdConfig) Run() { clientConfigE2EE.AddPeer(serverPeerConfigE2EE) // Make peer config for the server that this new server will connect to. + addrs := []string{addresses.NextServerRelayAddr4.String() + "/32"} + if !disableV6 { + addrs = append(addrs, addresses.NextServerRelayAddr6.String()+"/128") + } serverPeerConfigRelay, err := peer.GetPeerConfig(peer.PeerConfigArgs{ PublicKey: serverConfigRelay.GetPublicKey(), - AllowedIPs: []string{addresses.NextServerRelayAddr4.String() + "/32", addresses.NextServerRelayAddr6.String() + "/128"}, + AllowedIPs: addrs, Endpoint: func() string { if addArgs.outbound { return addArgs.endpoint @@ -245,9 +258,9 @@ func (c addServerCmdConfig) Run() { err = api.AddRelayPeer(leafApiAddrPort, serverPeerConfigRelay) check("failed to add peer to leaf server", err) - err = serverConfigRelay.SetAddresses([]string{addresses.NextServerRelayAddr4.String() + "/32", addresses.NextServerRelayAddr6.String() + "/128"}) + err = serverConfigRelay.SetAddresses(addrs) check("failed to set addresses", err) - err = serverConfigE2EE.SetAddresses([]string{addresses.ApiAddr.String() + "/128"}) + err = serverConfigE2EE.SetAddresses([]string{fmt.Sprintf("%s/%d", addresses.ApiAddr.String(), addresses.ApiAddr.BitLen())}) check("failed to set addresses", err) // Update routes for every node in path to new server (after getting addresses) @@ -265,8 +278,8 @@ func (c addServerCmdConfig) Run() { // Find which of our E2EE peers has an endpoint that matches the first Allowed IP of this peer: for _, e2ee_p := range clientConfigE2EE.GetPeers() { if p.GetAllowedIPs()[0].Contains(e2ee_p.GetEndpoint().IP) { - aps := e2ee_p.GetAllowedIPs() - serverApi = netip.MustParseAddrPort(net.JoinHostPort(aps[len(aps)-1].IP.String(), fmt.Sprint(ApiPort))) + aa := e2ee_p.GetApiAddr() + serverApi = netip.MustParseAddrPort(net.JoinHostPort(aa.String(), fmt.Sprint(ApiPort))) continue outer } } @@ -320,7 +333,7 @@ func (c addServerCmdConfig) Run() { // Copy to clipboard if requested. var clipboardStatus string if c.writeToClipboard { - err = clipboard.WriteAll(peer.CreateServerCommand(serverConfigRelay, serverConfigE2EE, peer.POSIX, false)) + err = clipboard.WriteAll(peer.CreateServerCommand(serverConfigRelay, serverConfigE2EE, peer.POSIX, false, disableV6)) if err != nil { clipboardStatus = fmt.Sprintf("%s %s", RedBold("clipboard:"), Red(fmt.Sprintf("error copying to clipboard: %v", err))) } else { @@ -347,8 +360,8 @@ func (c addServerCmdConfig) Run() { fmt.Fprintln(color.Output) fmt.Fprintln(color.Output, fileStatusServer) fmt.Fprintln(color.Output) - fmt.Fprintln(color.Output, Cyan("POSIX Shell: "), Green(peer.CreateServerCommand(serverConfigRelay, serverConfigE2EE, peer.POSIX, false))) - fmt.Fprintln(color.Output, Cyan(" PowerShell: "), Green(peer.CreateServerCommand(serverConfigRelay, serverConfigE2EE, peer.PowerShell, false))) + fmt.Fprintln(color.Output, Cyan("POSIX Shell: "), Green(peer.CreateServerCommand(serverConfigRelay, serverConfigE2EE, peer.POSIX, false, disableV6))) + fmt.Fprintln(color.Output, Cyan(" PowerShell: "), Green(peer.CreateServerCommand(serverConfigRelay, serverConfigE2EE, peer.PowerShell, false, disableV6))) fmt.Fprintln(color.Output, Cyan("Config File: "), Green("./wiretap serve -f "+c.configFileServer)) fmt.Fprintln(color.Output) if c.writeToClipboard { diff --git a/src/cmd/configure.go b/src/cmd/configure.go index d809cdb..c36e132 100644 --- a/src/cmd/configure.go +++ b/src/cmd/configure.go @@ -30,8 +30,10 @@ type configureCmdConfig struct { serverAddr4Relay string serverAddr6Relay string apiAddr string + apiv4Addr string keepalive int mtu int + disableV6 bool } // Defaults for configure command. @@ -53,8 +55,10 @@ var configureCmdArgs = configureCmdConfig{ serverAddr4Relay: RelaySubnets4.Addr().Next().Next().String() + "/32", serverAddr6Relay: RelaySubnets6.Addr().Next().Next().String() + "/128", apiAddr: ApiSubnets.Addr().Next().Next().String() + "/128", + apiv4Addr: ApiV4Subnets.Addr().Next().Next().String() + "/32", keepalive: Keepalive, mtu: MTU, + disableV6: false, } // configureCmd represents the configure command. @@ -91,6 +95,7 @@ func init() { configureCmd.Flags().IntVarP(&configureCmdArgs.keepalive, "keepalive", "k", configureCmdArgs.keepalive, "tunnel keepalive in seconds, only applies to outbound handshakes") configureCmd.Flags().IntVarP(&configureCmdArgs.mtu, "mtu", "m", configureCmdArgs.mtu, "tunnel MTU") + configureCmd.Flags().BoolVarP(&configureCmdArgs.disableV6, "disable-ipv6", "", configureCmdArgs.disableV6, "disables IPv6") err := configureCmd.MarkFlagRequired("routes") check("failed to mark flag required", err) @@ -102,7 +107,7 @@ func init() { helpFunc := configureCmd.HelpFunc() configureCmd.SetHelpFunc(func(cmd *cobra.Command, args []string) { if !ShowHidden { - for _, f := range []string{"api", "ipv4-relay", "ipv6-relay", "ipv4-e2ee", "ipv6-e2ee", "ipv4-relay-server", "ipv6-relay-server", "keepalive", "mtu"} { + for _, f := range []string{"api", "ipv4-relay", "ipv6-relay", "ipv4-e2ee", "ipv6-e2ee", "ipv4-relay-server", "ipv6-relay-server", "keepalive", "mtu", "disable-ipv6"} { err := cmd.Flags().MarkHidden(f) if err != nil { fmt.Printf("Failed to hide flag %v: %v\n", f, err) @@ -118,6 +123,9 @@ func init() { func (c configureCmdConfig) Run() { var err error + if c.disableV6 && netip.MustParsePrefix(c.apiAddr).Addr().Is6() { + c.apiAddr = c.apiv4Addr + } c.allowedIPs = append(c.allowedIPs, c.apiAddr) // Generate client and server configs. @@ -137,6 +145,21 @@ func (c configureCmdConfig) Run() { relaySubnet4 = netip.PrefixFrom(relaySubnet4.Addr(), SubnetV4Bits).Masked() relaySubnet6 = netip.PrefixFrom(relaySubnet6.Addr(), SubnetV6Bits).Masked() + relaySubnets := []netip.Prefix{relaySubnet4} + if !c.disableV6 { + relaySubnets = append(relaySubnets, relaySubnet6) + } + + clientRelayAddrs := []string{c.clientAddr4Relay} + if !c.disableV6 { + clientRelayAddrs = append(clientRelayAddrs, c.clientAddr6Relay) + } + + clientE2EEAddrs := []string{c.clientAddr4E2EE} + if !c.disableV6 { + clientE2EEAddrs = append(clientE2EEAddrs, c.clientAddr6E2EE) + } + clientConfigRelayArgs := peer.ConfigArgs{ ListenPort: c.port, Peers: []peer.PeerConfigArgs{ @@ -146,7 +169,13 @@ func (c configureCmdConfig) Run() { if c.simple { return c.allowedIPs } else { - return []string{relaySubnet4.String(), relaySubnet6.String()} + return func() []string { + var s []string + for _, r := range relaySubnets { + s = append(s, r.String()) + } + return s + }() } }(), Endpoint: func() string { @@ -165,7 +194,7 @@ func (c configureCmdConfig) Run() { }(), }, }, - Addresses: []string{c.clientAddr4Relay, c.clientAddr6Relay}, + Addresses: clientRelayAddrs, } clientConfigE2EEArgs := peer.ConfigArgs{ @@ -177,7 +206,7 @@ func (c configureCmdConfig) Run() { Endpoint: net.JoinHostPort(relaySubnet4.Addr().Next().Next().String(), fmt.Sprint(E2EEPort)), }, }, - Addresses: []string{c.clientAddr4E2EE, c.clientAddr6E2EE}, + Addresses: clientE2EEAddrs, MTU: c.mtu - 80, } @@ -215,6 +244,10 @@ func (c configureCmdConfig) Run() { c.configFileE2EE = peer.FindAvailableFilename(c.configFileE2EE) c.configFileServer = peer.FindAvailableFilename(c.configFileServer) + if c.simple { + c.configFileRelay = c.configFileE2EE + } + // Write config file and get status string. var fileStatusRelay string err = os.WriteFile(c.configFileRelay, []byte(clientConfigRelay.AsFile()), 0600) @@ -249,11 +282,14 @@ func (c configureCmdConfig) Run() { if c.simple { serverConfigFile = fmt.Sprintf("%s --simple", serverConfigFile) } + if c.disableV6 { + serverConfigFile = fmt.Sprintf("%s --disable-ipv6", serverConfigFile) + } // Copy to clipboard if requested. var clipboardStatus string if c.writeToClipboard { - err = clipboard.WriteAll(peer.CreateServerCommand(serverConfigRelay, serverConfigE2EE, peer.POSIX, c.simple)) + err = clipboard.WriteAll(peer.CreateServerCommand(serverConfigRelay, serverConfigE2EE, peer.POSIX, c.simple, c.disableV6)) if err != nil { clipboardStatus = fmt.Sprintf("%s %s", RedBold("clipboard:"), Red(fmt.Sprintf("error copying to clipboard: %v", err))) } else { @@ -281,8 +317,8 @@ func (c configureCmdConfig) Run() { fmt.Fprintln(color.Output, fileStatusServer) fmt.Fprintln(color.Output) fmt.Fprintln(color.Output, GreenBold("server command:")) - fmt.Fprintln(color.Output, Cyan("POSIX Shell: "), Green(peer.CreateServerCommand(serverConfigRelay, serverConfigE2EE, peer.POSIX, c.simple))) - fmt.Fprintln(color.Output, Cyan(" PowerShell: "), Green(peer.CreateServerCommand(serverConfigRelay, serverConfigE2EE, peer.PowerShell, c.simple))) + fmt.Fprintln(color.Output, Cyan("POSIX Shell: "), Green(peer.CreateServerCommand(serverConfigRelay, serverConfigE2EE, peer.POSIX, c.simple, c.disableV6))) + fmt.Fprintln(color.Output, Cyan(" PowerShell: "), Green(peer.CreateServerCommand(serverConfigRelay, serverConfigE2EE, peer.PowerShell, c.simple, c.disableV6))) fmt.Fprintln(color.Output, Cyan("Config File: "), Green(serverConfigFile)) fmt.Fprintln(color.Output) if c.writeToClipboard { diff --git a/src/cmd/expose.go b/src/cmd/expose.go new file mode 100644 index 0000000..71e91dc --- /dev/null +++ b/src/cmd/expose.go @@ -0,0 +1,246 @@ +package cmd + +import ( + "fmt" + "log" + "net/netip" + + "github.com/fatih/color" + "github.com/spf13/cobra" + + "wiretap/api" + "wiretap/peer" +) + +type exposeCmdConfig struct { + serverAddr string + localPort uint + remotePort uint + protocol string + dynamic bool + configFile string +} + +// Defaults for expose command. +// See root command for shared defaults. +var exposeCmd = exposeCmdConfig{ + serverAddr: "", + localPort: 0, + remotePort: 0, + protocol: "tcp", + dynamic: false, + configFile: ConfigE2EE, +} + +// Add command and set flags. +func init() { + // Base command. + cmd := &cobra.Command{ + Use: "expose", + Short: "Expose local services to servers", + Long: `Expose a port statically or allow dynamic forwarding through a remote server to the local network`, + ValidArgs: []string{"remove", "list"}, + Args: cobra.OnlyValidArgs, + Run: func(cmd *cobra.Command, args []string) { + exposeCmd.Run() + }, + } + + rootCmd.AddCommand(cmd) + + cmd.Flags().UintVarP(&exposeCmd.localPort, "local", "l", exposeCmd.localPort, "Local port to expose") + cmd.Flags().UintVarP(&exposeCmd.remotePort, "remote", "r", exposeCmd.remotePort, "Remote port to forward if different from local port") + cmd.Flags().StringVarP(&exposeCmd.protocol, "protocol", "p", exposeCmd.protocol, "Port protocol, tcp/udp") + cmd.Flags().BoolVarP(&exposeCmd.dynamic, "dynamic", "d", exposeCmd.dynamic, "Dynamic port forwarding, SOCKS proxy service opens on remote port") + cmd.PersistentFlags().StringVarP(&exposeCmd.serverAddr, "server-address", "s", exposeCmd.serverAddr, "API address of server that ports should be forwarded from, exposes service to all servers by default") + cmd.PersistentFlags().StringVarP(&exposeCmd.configFile, "config", "c", exposeCmd.configFile, "Config file needed when talking to all serves (the default)") + + cmd.MarkFlagsMutuallyExclusive("dynamic", "local") + + cmd.Flags().SortFlags = false + + listCmd := &cobra.Command{ + Use: "list", + Short: "List exposed ports", + Long: `List all static and dynamically forwarded ports`, + Run: func(cmd *cobra.Command, args []string) { + exposeCmd.List() + }, + } + + cmd.AddCommand(listCmd) + + deleteCmd := &cobra.Command{ + Use: "remove", + Short: "Remove exposed ports", + Long: `Remove exposed ports`, + Run: func(cmd *cobra.Command, args []string) { + exposeCmd.Delete() + }, + } + + deleteCmd.Flags().UintVarP(&exposeCmd.localPort, "local", "l", exposeCmd.localPort, "Local port") + deleteCmd.Flags().UintVarP(&exposeCmd.remotePort, "remote", "r", exposeCmd.remotePort, "Remote port") + deleteCmd.Flags().StringVarP(&exposeCmd.protocol, "protocol", "p", exposeCmd.protocol, "Port protocol, tcp/udp") + deleteCmd.Flags().BoolVarP(&exposeCmd.dynamic, "dynamic", "d", exposeCmd.dynamic, "Dynamic port forwarding") + + cmd.AddCommand(deleteCmd) +} + +// Run attempts to ping server API and prints response. +func (c exposeCmdConfig) Run() { + var apiAddrs []netip.Addr + + // Get list of all API addrs *or* just use provided addr + if c.serverAddr != "" { + apiAddr, err := netip.ParseAddr(c.serverAddr) + check("failed to parse server address", err) + + apiAddrs = append(apiAddrs, apiAddr) + } else { + config, err := peer.ParseConfig(c.configFile) + check("failed to parse config file", err) + + for _, p := range config.GetPeers() { + apiAddrs = append(apiAddrs, p.GetApiAddr()) + } + } + + if c.dynamic { + // Validate options required for dynamic forwarding. + if c.remotePort < 1 || c.remotePort > 65535 { + log.Fatalln("invalid remote port:", c.remotePort) + } + } else { + // Validate options required for static forwarding. + if c.localPort < 1 || c.localPort > 65535 { + log.Fatalln("invalid local port:", c.localPort) + } + if c.remotePort == 0 { + c.remotePort = c.localPort + } else { + if c.remotePort > 65535 { + log.Fatalln("invalid remote port:", c.remotePort) + } + } + + if c.protocol != "tcp" && c.protocol != "udp" { + log.Fatalln("invalid protocol:", c.protocol) + } + } + + // Make API requests to the list of API addresses with the parameters: localPort, remotePort, protocol, dynamic + fmt.Fprintf(color.Output, "%s: local %s <- remote %d\n", GreenBold("expose"), func() string { + if c.dynamic { + return "*" + } else { + return fmt.Sprint(c.localPort) + } + }(), c.remotePort) + for _, a := range apiAddrs { + err := api.Expose(netip.AddrPortFrom(a, uint16(ApiPort)), c.localPort, c.remotePort, c.protocol, c.dynamic) + if err != nil { + fmt.Fprintf(color.Output, "\t[%v] %s: %s\n", RedBold(a), RedBold("error"), Red(err)) + } else { + fmt.Fprintf(color.Output, "\t[%v] %s\n", GreenBold(a), Green("OK")) + } + } +} + +// List lists the exposed port configuration for server(s). +func (c exposeCmdConfig) List() { + var apiAddrs []netip.Addr + + // Get list of all API addrs *or* just use provided addr + if c.serverAddr != "" { + apiAddr, err := netip.ParseAddr(c.serverAddr) + check("failed to parse server address", err) + + apiAddrs = append(apiAddrs, apiAddr) + } else { + config, err := peer.ParseConfig(c.configFile) + check("failed to parse config file", err) + + for _, p := range config.GetPeers() { + apiAddrs = append(apiAddrs, p.GetApiAddr()) + } + } + + for _, a := range apiAddrs { + tuples, err := api.ExposeList(netip.AddrPortFrom(a, uint16(ApiPort))) + if err != nil { + fmt.Fprintf(color.Output, "[%v] %s: %s\n", RedBold(a), RedBold("error"), Red(err)) + } else { + fmt.Fprintf(color.Output, "[%v]: %s\n", GreenBold(a), Cyan(len(tuples))) + for _, t := range tuples { + fmt.Fprintf(color.Output, "\tlocal %s <- remote %d/%s\n", func() string { + if t.LocalPort == 0 { + return "*" + } else { + return fmt.Sprintf("%d/%s", t.LocalPort, t.Protocol) + } + }(), t.RemotePort, t.Protocol) + } + } + } +} + +// Delete removes +func (c exposeCmdConfig) Delete() { + var apiAddrs []netip.Addr + + // Get list of all API addrs *or* just use provided addr + if c.serverAddr != "" { + apiAddr, err := netip.ParseAddr(c.serverAddr) + check("failed to parse server address", err) + + apiAddrs = append(apiAddrs, apiAddr) + } else { + config, err := peer.ParseConfig(c.configFile) + check("failed to parse config file", err) + + for _, p := range config.GetPeers() { + apiAddrs = append(apiAddrs, p.GetApiAddr()) + } + } + + if c.dynamic { + // Validate options required for dynamic forwarding. + if c.remotePort < 1 || c.remotePort > 65535 { + log.Fatalln("invalid remote port:", c.remotePort) + } + } else { + // Validate options required for static forwarding. + if c.localPort < 1 || c.localPort > 65535 { + log.Fatalln("invalid local port:", c.localPort) + } + if c.remotePort == 0 { + c.remotePort = c.localPort + } else { + if c.remotePort > 65535 { + log.Fatalln("invalid remote port:", c.remotePort) + } + } + + if c.protocol != "tcp" && c.protocol != "udp" { + log.Fatalln("invalid protocol:", c.protocol) + } + } + + // Make API requests to the list of API addresses with the parameters: localPort, remotePort, protocol, dynamic + fmt.Fprintf(color.Output, "%s: local %s <- remote %d\n", GreenBold("delete"), func() string { + if c.dynamic { + return "*" + } else { + return fmt.Sprint(c.localPort) + } + }(), c.remotePort) + for _, a := range apiAddrs { + err := api.ExposeDelete(netip.AddrPortFrom(a, uint16(ApiPort)), c.localPort, c.remotePort, c.protocol, c.dynamic) + if err != nil { + fmt.Fprintf(color.Output, "\t[%v] %s: %s\n", RedBold(a), RedBold("error"), Red(err)) + } else { + fmt.Fprintf(color.Output, "\t[%v] %s\n", GreenBold(a), Green("Removed")) + } + } +} diff --git a/src/cmd/root.go b/src/cmd/root.go index 9cddc54..6cde237 100644 --- a/src/cmd/root.go +++ b/src/cmd/root.go @@ -18,12 +18,13 @@ var ( Port = 51820 E2EEPort = 51821 ConfigRelay = "wiretap_relay.conf" - ConfigE2EE = "wiretap_e2ee.conf" + ConfigE2EE = "wiretap.conf" ConfigServer = "wiretap_server.conf" Keepalive = 25 MTU = 1420 ShowHidden = false ApiSubnets = netip.MustParsePrefix("::/8") + ApiV4Subnets = netip.MustParsePrefix("192.0.2.0/24") ApiPort = 80 ClientRelaySubnet4 = netip.MustParsePrefix("172.16.0.0/16") ClientRelaySubnet6 = netip.MustParsePrefix("fd:16::/40") @@ -36,6 +37,7 @@ var ( SubnetV4Bits = 24 SubnetV6Bits = 48 APIBits = 16 + APIV4Bits = 24 ) // Define colors. diff --git a/src/cmd/serve.go b/src/cmd/serve.go index 08d821d..98e9a78 100644 --- a/src/cmd/serve.go +++ b/src/cmd/serve.go @@ -19,6 +19,8 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + gtcp "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + gudp "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "wiretap/peer" "wiretap/transport/api" @@ -44,6 +46,7 @@ type serveCmdConfig struct { keepaliveIdle uint keepaliveCount uint keepaliveInterval uint + disableV6 bool } type wiretapDefaultConfig struct { @@ -56,6 +59,7 @@ type wiretapDefaultConfig struct { serverAddr4E2EE string serverAddr6E2EE string apiAddr string + apiV4Addr string keepalive int mtu int } @@ -77,6 +81,7 @@ var serveCmd = serveCmdConfig{ keepaliveIdle: 60, keepaliveCount: 3, keepaliveInterval: 60, + disableV6: false, } var wiretapDefault = wiretapDefaultConfig{ @@ -89,6 +94,7 @@ var wiretapDefault = wiretapDefaultConfig{ serverAddr4E2EE: E2EESubnets4.Addr().Next().Next().String(), serverAddr6E2EE: E2EESubnets6.Addr().Next().Next().String(), apiAddr: ApiSubnets.Addr().Next().Next().String(), + apiV4Addr: ApiV4Subnets.Addr().Next().Next().String(), keepalive: Keepalive, mtu: MTU, } @@ -118,9 +124,10 @@ func init() { cmd.Flags().StringVarP(&serveCmd.logFile, "log-file", "o", serveCmd.logFile, "write log to this filename") cmd.Flags().UintVarP(&serveCmd.catchTimeout, "completion-timeout", "", serveCmd.catchTimeout, "time in ms for client to complete TCP connection to server") cmd.Flags().UintVarP(&serveCmd.connTimeout, "conn-timeout", "", serveCmd.connTimeout, "time in ms for server to wait for outgoing TCP handshakes to complete") - cmd.Flags().UintVarP(&serveCmd.connTimeout, "keepalive-idle", "", serveCmd.keepaliveIdle, "time in seconds before TCP keepalives are sent to client") - cmd.Flags().UintVarP(&serveCmd.connTimeout, "keepalive-interval", "", serveCmd.keepaliveInterval, "time in seconds between TCP keepalives") - cmd.Flags().UintVarP(&serveCmd.connTimeout, "keepalive-count", "", serveCmd.keepaliveCount, "number of unacknowledged TCP keepalives before closing connection") + cmd.Flags().UintVarP(&serveCmd.keepaliveIdle, "keepalive-idle", "", serveCmd.keepaliveIdle, "time in seconds before TCP keepalives are sent to client") + cmd.Flags().UintVarP(&serveCmd.keepaliveInterval, "keepalive-interval", "", serveCmd.keepaliveInterval, "time in seconds between TCP keepalives") + cmd.Flags().UintVarP(&serveCmd.keepaliveCount, "keepalive-count", "", serveCmd.keepaliveCount, "number of unacknowledged TCP keepalives before closing connection") + cmd.Flags().BoolVarP(&serveCmd.disableV6, "disable-ipv6", "", serveCmd.disableV6, "disable ipv6") cmd.Flags().StringVarP(&serveCmd.clientAddr4Relay, "ipv4-relay-client", "", serveCmd.clientAddr4Relay, "ipv4 relay address of client") cmd.Flags().StringVarP(&serveCmd.clientAddr6Relay, "ipv6-relay-client", "", serveCmd.clientAddr6Relay, "ipv6 relay address of client") @@ -131,6 +138,9 @@ func init() { err = viper.BindPFlag("simple", cmd.Flags().Lookup("simple")) check("error binding flag to viper", err) + err = viper.BindPFlag("disableipv6", cmd.Flags().Lookup("disable-ipv6")) + check("error binding flag to viper", err) + // Quiet and debug flags must be used independently. cmd.MarkFlagsMutuallyExclusive("debug", "quiet") @@ -233,6 +243,7 @@ func init() { "keepalive-interval", "keepalive-count", "keepalive-idle", + "disable-ipv6", } { err := cmd.Flags().MarkHidden(f) if err != nil { @@ -289,6 +300,22 @@ func (c serveCmdConfig) Run() { check("config error", errors.New("public key of peer is required")) } + if viper.IsSet("disableipv6") && netip.MustParseAddr(viper.GetString("E2EE.Interface.api")).Is6() { + viper.Set("E2EE.Interface.api", wiretapDefault.apiV4Addr) + } + + relayAddresses := []string{viper.GetString("Relay.Interface.ipv4") + "/32"} + if !viper.IsSet("disableipv6") { + relayAddresses = append(relayAddresses, viper.GetString("Relay.Interface.ipv6")+"/128") + } + aips := []string{} + for _, ip := range strings.Split(viper.GetString("Relay.Peer.allowed"), ",") { + if viper.IsSet("disableipv6") && netip.MustParsePrefix(ip).Addr().Is6() { + continue + } + + aips = append(aips, ip) + } configRelayArgs := peer.ConfigArgs{ PrivateKey: viper.GetString("Relay.Interface.privatekey"), ListenPort: viper.GetInt("Relay.Interface.port"), @@ -303,15 +330,23 @@ func (c serveCmdConfig) Run() { return 0 } }(), - AllowedIPs: strings.Split(viper.GetString("Relay.Peer.allowed"), ","), + AllowedIPs: aips, }, }, - Addresses: []string{viper.GetString("Relay.Interface.ipv4") + "/32", viper.GetString("Relay.Interface.ipv6") + "/128"}, + Addresses: relayAddresses, } configRelay, err := peer.GetConfig(configRelayArgs) check("failed to make relay configuration", err) + allowedIPs := []string{c.clientAddr4E2EE + "/32"} + if !viper.IsSet("disableipv6") { + allowedIPs = append(allowedIPs, c.clientAddr6E2EE+"/128") + } + e2eeAddresses := []string{viper.GetString("E2EE.Interface.ipv4") + "/32"} + if !viper.IsSet("disableipv6") { + e2eeAddresses = append(e2eeAddresses, viper.GetString("E2EE.Interface.ipv6")+"/128") + } var configE2EE peer.Config if !viper.GetBool("simple") { configE2EEArgs := peer.ConfigArgs{ @@ -321,11 +356,11 @@ func (c serveCmdConfig) Run() { { PublicKey: viper.GetString("E2EE.Peer.publickey"), Endpoint: viper.GetString("E2EE.Peer.endpoint"), - AllowedIPs: []string{c.clientAddr4E2EE + "/32", c.clientAddr6E2EE + "/128"}, + AllowedIPs: allowedIPs, PersistentKeepaliveInterval: viper.GetInt("Relay.Peer.keepalive"), }, }, - Addresses: []string{viper.GetString("E2EE.Interface.ipv4") + "/32", viper.GetString("E2EE.Interface.ipv6") + "/128", viper.GetString("E2EE.Interface.api") + "/128"}, + Addresses: e2eeAddresses, } configE2EE, err = peer.GetConfig(configE2EEArgs) check("failed to make e2ee configuration", err) @@ -353,10 +388,14 @@ func (c serveCmdConfig) Run() { ipv4Addr, err := netip.ParseAddr(viper.GetString("Relay.Interface.ipv4")) check("failed to parse ipv4 address", err) - ipv6Addr, err := netip.ParseAddr(viper.GetString("Relay.Interface.ipv6")) - check("failed to parse ipv6 address", err) + relayAddrs := []netip.Addr{ipv4Addr} + + if !viper.IsSet("disableipv6") { + ipv6Addr, err := netip.ParseAddr(viper.GetString("Relay.Interface.ipv6")) + check("failed to parse ipv6 address", err) + relayAddrs = append(relayAddrs, ipv6Addr) + } - relayAddrs := []netip.Addr{ipv4Addr, ipv6Addr} if viper.GetBool("simple") { relayAddrs = append(relayAddrs, apiAddr) } @@ -377,21 +416,28 @@ func (c serveCmdConfig) Run() { if tcpipErr != nil { check("failed to enable forwarding", errors.New(tcpipErr.String())) } - tcpipErr = s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true) - if tcpipErr != nil { - check("failed to enable forwarding", errors.New(tcpipErr.String())) + if !viper.IsSet("disableipv6") { + tcpipErr = s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true) + if tcpipErr != nil { + check("failed to enable forwarding", errors.New(tcpipErr.String())) + } } // Create virtual e2ee interface with this address and MTU - 80. ipv4Addr, err = netip.ParseAddr(viper.GetString("E2EE.Interface.ipv4")) check("failed to parse ipv4 address", err) - ipv6Addr, err = netip.ParseAddr(viper.GetString("E2EE.Interface.ipv6")) - check("failed to parse ipv6 address", err) + e2eeAddrs := []netip.Addr{ipv4Addr, apiAddr} + + if !viper.IsSet("disableipv6") { + ipv6Addr, err := netip.ParseAddr(viper.GetString("E2EE.Interface.ipv6")) + check("failed to parse ipv6 address", err) + e2eeAddrs = append(e2eeAddrs, ipv6Addr) + } if !viper.GetBool("simple") { tunE2EE, tnetE2EE, err = netstack.CreateNetTUN( - []netip.Addr{ipv4Addr, ipv6Addr, apiAddr}, + e2eeAddrs, []netip.Addr{}, viper.GetInt("Relay.Interface.mtu")-80, ) @@ -399,6 +445,14 @@ func (c serveCmdConfig) Run() { } } + transportHandler := func() *netstack.Net { + if viper.GetBool("simple") { + return tnetRelay + } else { + return tnetE2EE + } + }() + var logger int if c.debug { logger = device.LogLevelVerbose @@ -408,6 +462,29 @@ func (c serveCmdConfig) Run() { logger = device.LogLevelError } + s := transportHandler.Stack() + s.SetPromiscuousMode(1, true) + + // TCP Forwarding mechanism. + tcpConfig := tcp.Config{ + CatchTimeout: time.Duration(c.catchTimeout) * time.Millisecond, + ConnTimeout: time.Duration(c.connTimeout) * time.Millisecond, + KeepaliveIdle: time.Duration(c.keepaliveIdle) * time.Second, + KeepaliveInterval: time.Duration(c.keepaliveInterval) * time.Second, + KeepaliveCount: int(c.keepaliveCount), + Tnet: transportHandler, + StackLock: &lock, + } + tcpForwarder := gtcp.NewForwarder(s, 0, 65535, tcp.Handler(tcpConfig)) + s.SetTransportProtocolHandler(gtcp.ProtocolNumber, tcpForwarder.HandlePacket) + + // UDP Forwarding mechanism. + udpConfig := udp.Config{ + Tnet: transportHandler, + StackLock: &lock, + } + s.SetTransportProtocolHandler(gudp.ProtocolNumber, udp.Handler(udpConfig)) + // Make new relay device. devRelay := device.NewDevice(tunRelay, conn.NewDefaultBind(), device.NewLogger(logger, "")) // Configure wireguard. @@ -430,61 +507,29 @@ func (c serveCmdConfig) Run() { check("failed to bring up e2ee device", err) } - transportHandler := func() *netstack.Net { - if viper.GetBool("simple") { - return tnetRelay - } else { - return tnetE2EE - } - }() + // Handlers that require long-running routines: - // Start transport layer handlers under the e2ee device. - wg.Add(1) - lock.Lock() - go func() { - config := tcp.TcpConfig{ - CatchTimeout: time.Duration(c.catchTimeout) * time.Millisecond, - ConnTimeout: time.Duration(c.connTimeout) * time.Millisecond, - KeepaliveIdle: time.Duration(c.keepaliveIdle) * time.Second, - KeepaliveInterval: time.Duration(c.keepaliveInterval) * time.Second, - KeepaliveCount: int(c.keepaliveCount), - Ipv4Addr: ipv4Addr, - Ipv6Addr: ipv6Addr, - Port: 1337, - } - tcp.Handle(transportHandler, config, &lock) - wg.Done() - }() - - lock.Lock() - wg.Add(1) - go func() { - udp.Handle(transportHandler, ipv4Addr, ipv6Addr, 1337, &lock) - wg.Done() - }() - - lock.Lock() + // Start ICMP Handler. wg.Add(1) go func() { icmp.Handle(transportHandler, &lock) wg.Done() }() - // Start API handler. Starting last because firewall rule needs to be first. - ns := api.NetworkState{ - NextClientRelayAddr4: netip.MustParseAddr(c.clientAddr4Relay), - NextClientRelayAddr6: netip.MustParseAddr(c.clientAddr6Relay), - NextServerRelayAddr4: netip.MustParseAddr(viper.GetString("Relay.Interface.ipv4")), - NextServerRelayAddr6: netip.MustParseAddr(viper.GetString("Relay.Interface.ipv6")), - NextClientE2EEAddr4: netip.MustParseAddr(c.clientAddr4E2EE), - NextClientE2EEAddr6: netip.MustParseAddr(c.clientAddr6E2EE), - NextServerE2EEAddr4: netip.MustParseAddr(viper.GetString("E2EE.Interface.ipv4")), - NextServerE2EEAddr6: netip.MustParseAddr(viper.GetString("E2EE.Interface.ipv6")), - ApiAddr: netip.MustParseAddr(viper.GetString("E2EE.Interface.api")), - } - lock.Lock() + // Start API handler. wg.Add(1) go func() { + ns := api.NetworkState{ + NextClientRelayAddr4: netip.MustParseAddr(c.clientAddr4Relay), + NextClientRelayAddr6: netip.MustParseAddr(c.clientAddr6Relay), + NextServerRelayAddr4: netip.MustParseAddr(viper.GetString("Relay.Interface.ipv4")), + NextServerRelayAddr6: netip.MustParseAddr(viper.GetString("Relay.Interface.ipv6")), + NextClientE2EEAddr4: netip.MustParseAddr(c.clientAddr4E2EE), + NextClientE2EEAddr6: netip.MustParseAddr(c.clientAddr6E2EE), + NextServerE2EEAddr4: netip.MustParseAddr(viper.GetString("E2EE.Interface.ipv4")), + NextServerE2EEAddr6: netip.MustParseAddr(viper.GetString("E2EE.Interface.ipv6")), + ApiAddr: netip.MustParseAddr(viper.GetString("E2EE.Interface.api")), + } api.Handle(transportHandler, devRelay, devE2EE, &configRelay, &configE2EE, apiAddr, uint16(ApiPort), &lock, &ns) wg.Done() }() diff --git a/src/cmd/status.go b/src/cmd/status.go index f26fca0..e510471 100644 --- a/src/cmd/status.go +++ b/src/cmd/status.go @@ -93,8 +93,6 @@ func (c statusCmdConfig) Run() { for _, rp := range current.relayConfig.GetPeers() { // Skip client-facing peers. for _, ip := range rp.GetAllowedIPs() { - ClientRelaySubnet4 = netip.MustParsePrefix("172.16.0.0/16") - ClientRelaySubnet6 = netip.MustParsePrefix("fd:16::/40") if ClientRelaySubnet4.Contains(netip.MustParseAddr(ip.IP.String())) || ClientRelaySubnet6.Contains(netip.MustParseAddr(ip.IP.String())) { continue outer } diff --git a/src/go.mod b/src/go.mod index 9c1a5c5..9165ce2 100644 --- a/src/go.mod +++ b/src/go.mod @@ -2,9 +2,7 @@ module wiretap go 1.20 -replace golang.zx2c4.com/wireguard => github.com/luker983/wireguard-go v0.0.0-20230405143335-420d4b1d8857 - -//replace golang.zx2c4.com/wireguard => ../custom-wireguard-go +replace golang.zx2c4.com/wireguard => github.com/luker983/wireguard-go v0.0.0-20230628150900-2e22d4a23db1 require ( github.com/atotto/clipboard v0.1.4 @@ -22,6 +20,7 @@ require ( ) require ( + github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/google/btree v1.1.2 // indirect github.com/google/uuid v1.3.0 // indirect diff --git a/src/go.sum b/src/go.sum index 609b7b7..5cfc8dd 100644 --- a/src/go.sum +++ b/src/go.sum @@ -38,6 +38,8 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -149,8 +151,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/libp2p/go-reuseport v0.2.0 h1:18PRvIMlpY6ZK85nIAicSBuXXvrYoSw3dsBAR7zc560= github.com/libp2p/go-reuseport v0.2.0/go.mod h1:bvVho6eLMm6Bz5hmU0LYN3ixd3nPPvtIlaURZZgOY4k= -github.com/luker983/wireguard-go v0.0.0-20230405143335-420d4b1d8857 h1:PCgvXGTboyf0jWw1sCGS70OfEo9/LRGBaLXpzt+DyV8= -github.com/luker983/wireguard-go v0.0.0-20230405143335-420d4b1d8857/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= +github.com/luker983/wireguard-go v0.0.0-20230628150900-2e22d4a23db1 h1:+Tr/ZJIEOLVxTNqt2yv0b8l2CONqPfot4511omlEgks= +github.com/luker983/wireguard-go v0.0.0-20230628150900-2e22d4a23db1/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= github.com/m1gwings/treedrawer v0.3.3-beta h1:VeeQ4I90+NL0G2Tga3H4EY4hbOyVP3ID4T93r21oLbQ= github.com/m1gwings/treedrawer v0.3.3-beta/go.mod h1:Sebh5tCtjQWAG/B9xWct163vB9pCbBcA1ykaUErDUTY= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= diff --git a/src/peer/config.go b/src/peer/config.go index 20ca0e9..15fcd57 100644 --- a/src/peer/config.go +++ b/src/peer/config.go @@ -419,7 +419,7 @@ func (c *Config) AsIPC() string { return s.String() } -func CreateServerCommand(relayConfig Config, e2eeConfig Config, shell Shell, simple bool) string { +func CreateServerCommand(relayConfig Config, e2eeConfig Config, shell Shell, simple bool, disableV6 bool) string { var s strings.Builder var keys []string var vals []string @@ -428,9 +428,11 @@ func CreateServerCommand(relayConfig Config, e2eeConfig Config, shell Shell, sim keys = append(keys, "WIRETAP_RELAY_INTERFACE_PRIVATEKEY") vals = append(vals, relayConfig.GetPrivateKey()) - if len(relayConfig.addresses) == 2 { + if len(relayConfig.addresses) >= 1 { keys = append(keys, "WIRETAP_RELAY_INTERFACE_IPV4") vals = append(vals, relayConfig.addresses[0].IP.String()) + } + if len(relayConfig.addresses) >= 2 { keys = append(keys, "WIRETAP_RELAY_INTERFACE_IPV6") vals = append(vals, relayConfig.addresses[1].IP.String()) } @@ -488,6 +490,11 @@ func CreateServerCommand(relayConfig Config, e2eeConfig Config, shell Shell, sim vals = append(vals, "true") } + if disableV6 { + keys = append(keys, "WIRETAP_DISABLEIPV6") + vals = append(vals, "true") + } + switch shell { case POSIX: for i := 0; i < len(keys); i++ { @@ -511,8 +518,10 @@ func CreateServerFile(relayConfig Config, e2eeConfig Config) string { s.WriteString("[Relay.Interface]\n") s.WriteString(fmt.Sprintf("PrivateKey = %s\n", relayConfig.GetPrivateKey())) - if len(relayConfig.addresses) == 2 { + if len(relayConfig.addresses) >= 1 { s.WriteString(fmt.Sprintf("IPv4 = %s\n", relayConfig.addresses[0].IP.String())) + } + if len(relayConfig.addresses) >= 2 { s.WriteString(fmt.Sprintf("IPv6 = %s\n", relayConfig.addresses[1].IP.String())) } diff --git a/src/transport/api/api.go b/src/transport/api/api.go index e9233cc..3621d90 100644 --- a/src/transport/api/api.go +++ b/src/transport/api/api.go @@ -2,7 +2,6 @@ package api import ( - "bytes" "encoding/json" "errors" "fmt" @@ -20,8 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "wiretap/peer" "wiretap/transport" @@ -66,41 +63,46 @@ type AddAllowedIPsRequest struct { AllowedIPs []net.IPNet } -var clientAddresses map[uint64]NetworkState -var serverAddresses map[uint64]NetworkState +type ExposeTuple struct { + RemoteAddr netip.Addr + LocalPort uint + RemotePort uint + Protocol string +} +type ExposeAction int -// Handle adds rule to top of firewall rules that accepts direct connections to API. -func Handle(tnet *netstack.Net, devRelay *device.Device, devE2EE *device.Device, relayConfig *peer.Config, e2eeConfig *peer.Config, addr netip.Addr, port uint16, lock *sync.Mutex, ns *NetworkState) { - s := tnet.Stack() +const ( + ExposeActionExpose ExposeAction = iota + ExposeActionList + ExposeActionDelete +) - headerFilter := stack.IPHeaderFilter{ - Protocol: tcp.ProtocolNumber, - CheckProtocol: true, - Dst: tcpip.Address(addr.AsSlice()), - DstMask: tcpip.Address(bytes.Repeat([]byte("\xff"), addr.BitLen()/8)), - } +type ExposeRequest struct { + Action ExposeAction + LocalPort uint + RemotePort uint + Protocol string + Dynamic bool +} - rule := stack.Rule{ - Filter: headerFilter, - Target: &stack.AcceptTarget{ - NetworkProtocol: func() tcpip.NetworkProtocolNumber { - if addr.Is4() { - return ipv4.ProtocolNumber - } - return ipv6.ProtocolNumber - }(), - }, - } +type ExposeConn struct { + TcpListener *net.Listener + UdpConn *net.UDPConn +} - tid := stack.NATID - transport.PushRule(s, rule, tid, addr.Is6()) - lock.Unlock() +var clientAddresses map[uint64]NetworkState +var serverAddresses map[uint64]NetworkState +// Handle adds rule to top of firewall rules that accepts direct connections to API. +func Handle(tnet *netstack.Net, devRelay *device.Device, devE2EE *device.Device, relayConfig *peer.Config, e2eeConfig *peer.Config, addr netip.Addr, port uint16, lock *sync.Mutex, ns *NetworkState) { configs := ServerConfigs{ RelayConfig: relayConfig, E2EEConfig: e2eeConfig, } + exposeMap := make(map[ExposeTuple]ExposeConn) + var exposeLock sync.RWMutex + serverAddresses = make(map[uint64]NetworkState) clientAddresses = make(map[uint64]NetworkState) @@ -120,11 +122,17 @@ func Handle(tnet *netstack.Net, devRelay *device.Device, devE2EE *device.Device, log.Panic(err) } + localAddr := tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.Address(addr.AsSlice()), + } + http.HandleFunc("/ping", wrapApi(handlePing())) http.HandleFunc("/serverinfo", wrapApi(handleServerInfo(configs))) http.HandleFunc("/addpeer", wrapApi(handleAddPeer(devRelay, devE2EE, configs))) http.HandleFunc("/allocate", wrapApi(handleAllocate(ns))) http.HandleFunc("/addallowedips", wrapApi(handleAddAllowedIPs(devRelay, configs))) + http.HandleFunc("/expose", wrapApi(handleExpose(tnet, &exposeMap, &exposeLock, localAddr))) log.Println("API: API listener up") err = http.Serve(listener, nil) @@ -211,6 +219,7 @@ func handleAddPeer(devRelay *device.Device, devE2EE *device.Device, config Serve err = p.UnmarshalJSON(body) if err != nil { writeErr(w, err) + return } // If addresses not assigned, error out, should have been determined from a previous API call. @@ -355,3 +364,152 @@ func handleAddAllowedIPs(devRelay *device.Device, config ServerConfigs) http.Han w.WriteHeader(http.StatusOK) } } + +func handleExpose(tnet *netstack.Net, exposeMap *map[ExposeTuple]ExposeConn, exposeLock *sync.RWMutex, localAddr tcpip.FullAddress) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + // Parse query parameters. + decoder := json.NewDecoder(r.Body) + var requestArgs ExposeRequest + err := decoder.Decode(&requestArgs) + if err != nil { + writeErr(w, err) + return + } + + remoteAddr, _, _ := net.SplitHostPort(r.RemoteAddr) + + et := ExposeTuple{ + netip.MustParseAddr(remoteAddr), + requestArgs.LocalPort, + requestArgs.RemotePort, + requestArgs.Protocol, + } + + switch requestArgs.Action { + // Return list of all exposed ports. + case ExposeActionList: + exposeLock.RLock() + defer exposeLock.RUnlock() + + ett := []ExposeTuple{} + + for k := range *exposeMap { + ett = append(ett, k) + } + + body, err := json.Marshal(ett) + if err != nil { + writeErr(w, err) + return + } + + _, err = w.Write(body) + if err != nil { + log.Printf("API Error: %v", err) + } + return + // Start exposing port if not already done. + case ExposeActionExpose: + exposeLock.Lock() + defer exposeLock.Unlock() + + _, ok := (*exposeMap)[et] + if ok { + // Already exists, cancel. + writeErr(w, errors.New("port already exposed")) + return + } + + proto := ipv4.ProtocolNumber + if et.RemoteAddr.Is6() { + proto = ipv6.ProtocolNumber + } + + if requestArgs.Dynamic { + // Handle Dynamic. + l, err := net.Listen("tcp", fmt.Sprintf(":%d", requestArgs.RemotePort)) + if err != nil { + writeErr(w, err) + return + } + + // Bind successful, perform dynamic forwarding. + go transport.ForwardDynamic( + tnet.Stack(), + &l, + localAddr, + tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(et.RemoteAddr.AsSlice())}, + proto, + ) + + (*exposeMap)[et] = ExposeConn{TcpListener: &l} + } else if requestArgs.Protocol == "tcp" { + // Handle TCP. + l, err := net.Listen(requestArgs.Protocol, fmt.Sprintf(":%d", requestArgs.RemotePort)) + if err != nil { + writeErr(w, err) + return + } + + // Bind successful, expose port. + go transport.ForwardTcpPort( + tnet.Stack(), + l, + localAddr, + tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(et.RemoteAddr.AsSlice()), Port: uint16(et.LocalPort)}, + proto, + ) + + (*exposeMap)[et] = ExposeConn{TcpListener: &l} + } else { + // Handle UDP. + addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", requestArgs.RemotePort)) + if err != nil { + writeErr(w, err) + return + } + c, err := net.ListenUDP("udp", addr) + if err != nil { + writeErr(w, err) + return + } + + // Bind successful, expose port. + go transport.ForwardUdpPort( + tnet.Stack(), + c, + localAddr, + tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(et.RemoteAddr.AsSlice()), Port: uint16(et.LocalPort)}, + proto, + ) + + (*exposeMap)[et] = ExposeConn{UdpConn: c} + } + + // Stop listener and delete from map. + case ExposeActionDelete: + exposeLock.Lock() + defer exposeLock.Unlock() + + c, ok := (*exposeMap)[et] + if ok { + if et.Protocol == "tcp" && c.TcpListener != nil { + (*c.TcpListener).Close() + } else if c.UdpConn != nil { + c.UdpConn.Close() + } + delete(*exposeMap, et) + } else { + writeErr(w, errors.New("not found")) + return + } + } + + w.WriteHeader(http.StatusOK) + } +} diff --git a/src/transport/icmp/icmp.go b/src/transport/icmp/icmp.go index b058df2..e06f75c 100644 --- a/src/transport/icmp/icmp.go +++ b/src/transport/icmp/icmp.go @@ -2,6 +2,7 @@ package icmp import ( + "bytes" "log" "net" "net/netip" @@ -18,100 +19,101 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/waiter" "wiretap/transport" ) -// preroutingMatch matches packets in the prerouting stage and clones: -// packet into channel for processing. -type preroutingMatch struct { - pktChan chan stack.PacketBufferPtr -} - var pinger Ping = nil -// When a new ICMP message hits the prerouting stage, the packet is cloned -// to the ICMP handler and dropped here. -func (m preroutingMatch) Match(hook stack.Hook, packet stack.PacketBufferPtr, inputInterfaceName, outputInterfaceName string) (matches bool, hotdrop bool) { - if hook == stack.Prerouting { - m.pktChan <- packet.Clone() - return false, true - } - - return false, false -} - -// handleICMP proxies ICMP messages using whatever means it can with the permissions this binary -// has on the system. func Handle(tnet *netstack.Net, lock *sync.Mutex) { - s := tnet.Stack() - - // create iptables rule that drops icmp, but clones packet and sends it to this handler. - headerFilter4 := stack.IPHeaderFilter{ - Protocol: icmp.ProtocolNumber4, - CheckProtocol: true, - } - - headerFilter6 := stack.IPHeaderFilter{ - Protocol: icmp.ProtocolNumber6, - CheckProtocol: true, - } - - match := preroutingMatch{ - pktChan: make(chan stack.PacketBufferPtr), - } + handler := func(t tcpip.TransportProtocolNumber, n tcpip.NetworkProtocolNumber) { + var wq waiter.Queue + lock.Lock() + ep, err := tnet.Stack().NewRawEndpoint(t, n, &wq, true) + lock.Unlock() + if err != nil { + log.Panic("icmp handler error:", err) + } - rule4 := stack.Rule{ - Filter: headerFilter4, - Matchers: []stack.Matcher{match}, - Target: &stack.DropTarget{ - NetworkProtocol: ipv4.ProtocolNumber, - }, - } + // Need this to get destination address. + ep.SocketOptions().SetIPv6ReceivePacketInfo(true) + + waitEntry, notifyChan := waiter.NewChannelEntry(waiter.ReadableEvents) + wq.EventRegister(&waitEntry) + defer wq.EventUnregister(&waitEntry) + + for { + var buf bytes.Buffer + res, err := ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) + if err != nil { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + log.Panic("icmp handler error:", err) + } + } else { + var netHeader header.Network + if n == ipv4.ProtocolNumber { + netHeader = header.IPv4(buf.Bytes()) + } else { + // TODO: Come up with a better way to do this than manually building ipv6 header. + version := 6 //IPv6 + nextHeader := 58 // ICMPv6 + payloadLength := len(buf.Bytes()) + hopLimit := 64 // TTL + src := netip.MustParseAddr(res.RemoteAddr.Addr.String()).AsSlice() + dst := netip.MustParseAddr(res.ControlMessages.IPv6PacketInfo.Addr.String()).AsSlice() + + ipv6Header := []byte{ + uint8(version << 4), // Version / Traffic Class + 0, 0, 0, // Traffic Class / Flow Label + uint8((payloadLength >> 8) & 0xFF), // Payload Length MSN + uint8(payloadLength & 0xFF), // Payload Length LSN + uint8(nextHeader), + uint8(hopLimit), + } + + ipv6Header = append(ipv6Header, src...) + ipv6Header = append(ipv6Header, dst...) + packet := append(ipv6Header, buf.Bytes()...) + netHeader = header.IPv6(packet) + } + + go func() { + handleMessage(tnet.Stack(), netHeader) + }() + + continue + } + + <-notifyChan + } - rule6 := stack.Rule{ - Filter: headerFilter6, - Matchers: []stack.Matcher{match}, - Target: &stack.DropTarget{ - NetworkProtocol: ipv6.ProtocolNumber, - }, } - tid := stack.NATID - transport.PushRule(s, rule4, tid, false) - transport.PushRule(s, rule6, tid, true) - lock.Unlock() - - log.Println("Transport: ICMP listener up") - for { - clonedPacket := <-match.pktChan - go func() { - handleMessage(s, clonedPacket) - clonedPacket.DecRef() - }() - } + // Start handler for ipv4 and ipv6. + go handler(icmp.ProtocolNumber4, ipv4.ProtocolNumber) + handler(icmp.ProtocolNumber6, ipv6.ProtocolNumber) } // handleICMPMessage parses ICMP packets and proxies them if possible. -func handleMessage(s *stack.Stack, packet stack.PacketBufferPtr) { +func handleMessage(s *stack.Stack, pkt header.Network) { // Parse ICMP packet type. - netHeader := packet.Network() - log.Printf("(client %v) - Transport: ICMP -> %v", netHeader.SourceAddress(), netHeader.DestinationAddress()) + log.Printf("(client %v) - Transport: ICMP -> %v", pkt.SourceAddress(), pkt.DestinationAddress()) - isIpv6 := !netip.MustParseAddr(netHeader.SourceAddress().String()).Is4() + isIpv6 := !netip.MustParseAddr(pkt.SourceAddress().String()).Is4() if isIpv6 { - transHeader := header.ICMPv6(netHeader.Payload()) + transHeader := header.ICMPv6(pkt.Payload()) switch transHeader.Type() { case header.ICMPv6EchoRequest: - handleEcho(s, packet) + handleEcho(s, pkt) default: log.Println("ICMPv6 type not implemented:", transHeader.Type()) } } else { - transHeader := header.ICMPv4(netHeader.Payload()) + transHeader := header.ICMPv4(pkt.Payload()) switch transHeader.Type() { case header.ICMPv4Echo: - handleEcho(s, packet) + handleEcho(s, pkt) default: log.Println("ICMPv4 type not implemented:", transHeader.Type()) } @@ -121,12 +123,12 @@ func handleMessage(s *stack.Stack, packet stack.PacketBufferPtr) { // handleICMPEcho tries to send ICMP echo requests to the true destination however it can. // If successful, it sends an echo response to the peer. -func handleEcho(s *stack.Stack, packet stack.PacketBufferPtr) { +func handleEcho(s *stack.Stack, pkt header.Network) { var success bool var err error // Parse network header for destination address. - dest := packet.Network().DestinationAddress().String() + dest := pkt.DestinationAddress().String() if pinger == nil { pinger, success, err = getPing(dest) @@ -136,7 +138,7 @@ func handleEcho(s *stack.Stack, packet stack.PacketBufferPtr) { if err == nil { if success { - sendEchoResponse(s, packet) + sendEchoResponse(s, pkt) } return @@ -146,19 +148,17 @@ func handleEcho(s *stack.Stack, packet stack.PacketBufferPtr) { } // sendICMPEchoResponse sends an echo response to the peer with a spoofed source address. -func sendEchoResponse(s *stack.Stack, packet stack.PacketBufferPtr) { +func sendEchoResponse(s *stack.Stack, pkt header.Network) { var response []byte var ipHeader []byte var err error - netHeader := packet.Network() - - isIpv6 := netHeader.DestinationAddress().To4() == "" + isIpv6 := pkt.DestinationAddress().To4() == "" netProto := ipv4.ProtocolNumber if isIpv6 { netProto = ipv6.ProtocolNumber - transHeader := header.ICMPv6(netHeader.Payload()) + transHeader := header.ICMPv6(pkt.Payload()) // Create ICMP response and marshal it. response, err = (&neticmp.Message{ Type: netipv6.ICMPTypeEchoReply, @@ -168,25 +168,25 @@ func sendEchoResponse(s *stack.Stack, packet stack.PacketBufferPtr) { Seq: int(transHeader.Sequence()), Data: transHeader.Payload(), }, - }).Marshal(neticmp.IPv6PseudoHeader(net.ParseIP(netHeader.DestinationAddress().String()), net.ParseIP(netHeader.SourceAddress().String()))) + }).Marshal(neticmp.IPv6PseudoHeader(net.ParseIP(pkt.DestinationAddress().String()), net.ParseIP(pkt.SourceAddress().String()))) if err != nil { - log.Println("Failed to marshal response:", err) + log.Println("failed to marshal response:", err) return } // Assert type to get network header bytes. - ipv6Header, ok := netHeader.(header.IPv6) + ipv6Header, ok := pkt.(header.IPv6) if !ok { - log.Println("Could not assert network header as IPv6 header") + log.Println("could not assert network header as IPv6 header") return } // Swap source and destination addresses from original request. tmp := ipv6Header.DestinationAddress() ipv6Header.SetDestinationAddress(ipv6Header.SourceAddress()) ipv6Header.SetSourceAddress(tmp) - ipHeader = ipv6Header + ipHeader = ipv6Header[:40] } else { - transHeader := header.ICMPv4(netHeader.Payload()) + transHeader := header.ICMPv4(pkt.Payload()) // Create ICMP response and marshal it. response, err = (&neticmp.Message{ Type: netipv4.ICMPTypeEchoReply, @@ -198,26 +198,26 @@ func sendEchoResponse(s *stack.Stack, packet stack.PacketBufferPtr) { }, }).Marshal(nil) if err != nil { - log.Println("Failed to marshal response:", err) + log.Println("failed to marshal response:", err) return } // Assert type to get network header bytes. - ipv4Header, ok := netHeader.(header.IPv4) + ipv4Header, ok := pkt.(header.IPv4) if !ok { - log.Println("Could not assert network header as IPv4 header") + log.Println("could not assert network header as IPv4 header") return } // Swap source and destination addresses from original request. tmp := ipv4Header.DestinationAddress() ipv4Header.SetDestinationAddress(ipv4Header.SourceAddress()) ipv4Header.SetSourceAddress(tmp) - ipHeader = ipv4Header + ipHeader = ipv4Header[:ipv4Header.HeaderLength()] } - tcpipErr := transport.SendPacket(s, append(ipHeader, response...), &tcpip.FullAddress{NIC: 1, Addr: netHeader.SourceAddress()}, netProto) + tcpipErr := transport.SendPacket(s, append(ipHeader, response...), &tcpip.FullAddress{NIC: 1, Addr: pkt.SourceAddress()}, netProto) if tcpipErr != nil { - log.Println("Failed to write:", tcpipErr) + log.Println("failed to write:", tcpipErr) return } } diff --git a/src/transport/tcp/adapter.go b/src/transport/tcp/adapter.go deleted file mode 100644 index 56df927..0000000 --- a/src/transport/tcp/adapter.go +++ /dev/null @@ -1,643 +0,0 @@ -// Copy of gvisor gonet TCPListener so we can implement new method that gets correct remote address even after connection has closed. -// Modifications have been made to the original file. -// TODO: Raise issue with gvisor about RemoteAddress() behavior so this can be removed. -package tcp - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "net" - "net/netip" - "time" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// Address conversion adapted from https://git.zx2c4.com/wireguard-go/tree/tun/netstack/tun.go. -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. - */ - -// Reimplementation of the private function netstack.convertToFullAddr. -func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { - var protoNumber tcpip.NetworkProtocolNumber - if endpoint.Addr().Is4() { - protoNumber = ipv4.ProtocolNumber - } else { - protoNumber = ipv6.ProtocolNumber - } - return tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(endpoint.Addr().AsSlice()), - Port: endpoint.Port(), - }, protoNumber -} - -// Reimplementation of netstack.Net.ListenTCP to call custom ListenTCP. -func listenTCP(s *stack.Stack, laddr *net.TCPAddr) (*TCPListener, error) { - ip, _ := netip.AddrFromSlice(laddr.IP) - addrPort := netip.AddrPortFrom(ip, uint16(laddr.Port)) - addr, network := convertToFullAddr(addrPort) - return ListenTCP(s, addr, network) -} - -// netstack Net adapation ends here. - -// Adaptation of gonet TCP adapter from https://github.com/google/gvisor/blob/df1f4cbd9fcbf56fbfc6fab82c4f3930f0343026/pkg/tcpip/adapters/gonet/gonet.go. - -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -var ( - errCanceled = errors.New("operation canceled") -) - -// timeoutError is how the net package reports timeouts. -type timeoutError struct{} - -func (e *timeoutError) Error() string { return "i/o timeout" } -func (e *timeoutError) Timeout() bool { return true } -func (e *timeoutError) Temporary() bool { return true } - -// A TCPListener is a wrapper around a TCP tcpip.Endpoint that implements -// net.Listener. -type TCPListener struct { - stack *stack.Stack - ep tcpip.Endpoint - wq *waiter.Queue - cancelOnce sync.Once - cancel chan struct{} -} - -// NewTCPListener creates a new TCPListener from a listening tcpip.Endpoint. -func NewTCPListener(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *TCPListener { - return &TCPListener{ - stack: s, - ep: ep, - wq: wq, - cancel: make(chan struct{}), - } -} - -// maxListenBacklog is set to be reasonably high for most uses of gonet. Go net -// package uses the value in /proc/sys/net/core/somaxconn file in Linux as the -// default listen backlog. The value below matches the default in common linux -// distros. -// -// See: https://cs.opensource.google/go/go/+/refs/tags/go1.18.1:src/net/sock_linux.go;drc=refs%2Ftags%2Fgo1.18.1;l=66 -const maxListenBacklog = 4096 - -// ListenTCP creates a new TCPListener. -func ListenTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPListener, error) { - // Create a TCP endpoint, bind it, then start listening. - var wq waiter.Queue - ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) - if err != nil { - return nil, errors.New(err.String()) - } - - if err := ep.Bind(addr); err != nil { - ep.Close() - return nil, &net.OpError{ - Op: "bind", - Net: "tcp", - Addr: fullToTCPAddr(addr), - Err: errors.New(err.String()), - } - } - - if err := ep.Listen(maxListenBacklog); err != nil { - ep.Close() - return nil, &net.OpError{ - Op: "listen", - Net: "tcp", - Addr: fullToTCPAddr(addr), - Err: errors.New(err.String()), - } - } - - return NewTCPListener(s, &wq, ep), nil -} - -// Close implements net.Listener.Close. -func (l *TCPListener) Close() error { - l.ep.Close() - return nil -} - -// Shutdown stops the HTTP server. -func (l *TCPListener) Shutdown() { - l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) - l.cancelOnce.Do(func() { - close(l.cancel) // broadcast cancellation - }) -} - -// Addr implements net.Listener.Addr. -func (l *TCPListener) Addr() net.Addr { - a, err := l.ep.GetLocalAddress() - if err != nil { - return nil - } - return fullToTCPAddr(a) -} - -type deadlineTimer struct { - // mu protects the fields below. - mu sync.Mutex - - readTimer *time.Timer - readCancelCh chan struct{} - writeTimer *time.Timer - writeCancelCh chan struct{} -} - -func (d *deadlineTimer) init() { - d.readCancelCh = make(chan struct{}) - d.writeCancelCh = make(chan struct{}) -} - -func (d *deadlineTimer) readCancel() <-chan struct{} { - d.mu.Lock() - c := d.readCancelCh - d.mu.Unlock() - return c -} -func (d *deadlineTimer) writeCancel() <-chan struct{} { - d.mu.Lock() - c := d.writeCancelCh - d.mu.Unlock() - return c -} - -// setDeadline contains the shared logic for setting a deadline. -// -// cancelCh and timer must be pointers to deadlineTimer.readCancelCh and -// deadlineTimer.readTimer or deadlineTimer.writeCancelCh and -// deadlineTimer.writeTimer. -// -// setDeadline must only be called while holding d.mu. -func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) { - if *timer != nil && !(*timer).Stop() { - *cancelCh = make(chan struct{}) - } - - // Create a new channel if we already closed it due to setting an already - // expired time. We won't race with the timer because we already handled - // that above. - select { - case <-*cancelCh: - *cancelCh = make(chan struct{}) - default: - } - - // "A zero value for t means I/O operations will not time out." - // - net.Conn.SetDeadline - if t.IsZero() { - return - } - - timeout := time.Until(t) - if timeout <= 0 { - close(*cancelCh) - return - } - - // Timer.Stop returns whether or not the AfterFunc has started, but - // does not indicate whether or not it has completed. Make a copy of - // the cancel channel to prevent this code from racing with the next - // call of setDeadline replacing *cancelCh. - ch := *cancelCh - *timer = time.AfterFunc(timeout, func() { - close(ch) - }) -} - -// SetReadDeadline implements net.Conn.SetReadDeadline and -// net.PacketConn.SetReadDeadline. -func (d *deadlineTimer) SetReadDeadline(t time.Time) error { - d.mu.Lock() - d.setDeadline(&d.readCancelCh, &d.readTimer, t) - d.mu.Unlock() - return nil -} - -// SetWriteDeadline implements net.Conn.SetWriteDeadline and -// net.PacketConn.SetWriteDeadline. -func (d *deadlineTimer) SetWriteDeadline(t time.Time) error { - d.mu.Lock() - d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) - d.mu.Unlock() - return nil -} - -// SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline. -func (d *deadlineTimer) SetDeadline(t time.Time) error { - d.mu.Lock() - d.setDeadline(&d.readCancelCh, &d.readTimer, t) - d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) - d.mu.Unlock() - return nil -} - -// A TCPConn is a wrapper around a TCP tcpip.Endpoint that implements the net.Conn -// interface. -type TCPConn struct { - deadlineTimer - - wq *waiter.Queue - ep tcpip.Endpoint - - // readMu serializes reads and implicitly protects read. - // - // Lock ordering: - // If both readMu and deadlineTimer.mu are to be used in a single - // request, readMu must be acquired before deadlineTimer.mu. - readMu sync.Mutex -} - -// NewTCPConn creates a new TCPConn. -func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn { - c := &TCPConn{ - wq: wq, - ep: ep, - } - c.deadlineTimer.init() - return c -} - -// Changed from original: -// AcceptFrom is identical to Accept except that it also returns the Remote Address as seen by the endpoint. -func (l *TCPListener) AcceptFrom(c *TcpConfig) (net.Conn, net.Addr, error) { - remoteAddr := tcpip.FullAddress{} - n, wq, err := l.ep.Accept(&remoteAddr) - - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Create wait queue entry that notifies a channel. - waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents) - l.wq.EventRegister(&waitEntry) - defer l.wq.EventUnregister(&waitEntry) - - for { - n, wq, err = l.ep.Accept(&remoteAddr) - - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - break - } - - select { - case <-l.cancel: - return nil, nil, errCanceled - case <-notifyCh: - } - } - } - - if err != nil { - return nil, nil, &net.OpError{ - Op: "accept", - Net: "tcp", - Addr: l.Addr(), - Err: errors.New(err.String()), - } - } - - // Enable keepalive and set defaults so that after (idle + (count * interval)) connection will be dropped if unresponsive. - n.SocketOptions().SetKeepAlive(true) - keepaliveIdle := tcpip.KeepaliveIdleOption(c.KeepaliveIdle) - err = n.SetSockOpt(&keepaliveIdle) - if err != nil { - return nil, nil, &net.OpError{ - Op: "accept", - Net: "tcp", - Addr: l.Addr(), - Err: errors.New(err.String()), - } - } - keepaliveInterval := tcpip.KeepaliveIntervalOption(c.KeepaliveInterval) - err = n.SetSockOpt(&keepaliveInterval) - if err != nil { - return nil, nil, &net.OpError{ - Op: "accept", - Net: "tcp", - Addr: l.Addr(), - Err: errors.New(err.String()), - } - } - err = n.SetSockOptInt(tcpip.KeepaliveCountOption, c.KeepaliveCount) - if err != nil { - return nil, nil, &net.OpError{ - Op: "accept", - Net: "tcp", - Addr: l.Addr(), - Err: errors.New(err.String()), - } - } - - return NewTCPConn(wq, n), fullToTCPAddr(remoteAddr), nil -} - -// Accept implements net.Conn.Accept. -func (l *TCPListener) Accept() (net.Conn, error) { - remoteAddr := tcpip.FullAddress{} - n, wq, err := l.ep.Accept(&remoteAddr) - - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Create wait queue entry that notifies a channel. - waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents) - l.wq.EventRegister(&waitEntry) - defer l.wq.EventUnregister(&waitEntry) - - for { - n, wq, err = l.ep.Accept(&remoteAddr) - - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - break - } - - select { - case <-l.cancel: - return nil, errCanceled - case <-notifyCh: - } - } - } - - if err != nil { - return nil, &net.OpError{ - Op: "accept", - Net: "tcp", - Addr: l.Addr(), - Err: errors.New(err.String()), - } - } - - return NewTCPConn(wq, n), nil -} - -type opErrorer interface { - newOpError(op string, err error) *net.OpError -} - -// commonRead implements the common logic between net.Conn.Read and -// net.PacketConn.ReadFrom. -func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) (int, error) { - select { - case <-deadline: - return 0, errorer.newOpError("read", &timeoutError{}) - default: - } - - w := tcpip.SliceWriter(b) - opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil} - res, err := ep.Read(&w, opts) - - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Create wait queue entry that notifies a channel. - waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents) - wq.EventRegister(&waitEntry) - defer wq.EventUnregister(&waitEntry) - for { - res, err = ep.Read(&w, opts) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - break - } - select { - case <-deadline: - return 0, errorer.newOpError("read", &timeoutError{}) - case <-notifyCh: - } - } - } - - if _, ok := err.(*tcpip.ErrClosedForReceive); ok { - return 0, io.EOF - } - - if err != nil { - return 0, errorer.newOpError("read", errors.New(err.String())) - } - - if addr != nil { - *addr = res.RemoteAddr - } - return res.Count, nil -} - -// Read implements net.Conn.Read. -func (c *TCPConn) Read(b []byte) (int, error) { - c.readMu.Lock() - defer c.readMu.Unlock() - - deadline := c.readCancel() - - n, err := commonRead(b, c.ep, c.wq, deadline, nil, c) - if n != 0 { - c.ep.ModerateRecvBuf(n) - } - return n, err -} - -// Write implements net.Conn.Write. -func (c *TCPConn) Write(b []byte) (int, error) { - deadline := c.writeCancel() - - // Check if deadlineTimer has already expired. - select { - case <-deadline: - return 0, c.newOpError("write", &timeoutError{}) - default: - } - - // We must handle two soft failure conditions simultaneously: - // 1. Write may write nothing and return *tcpip.ErrWouldBlock. - // If this happens, we need to register for notifications if we have - // not already and wait to try again. - // 2. Write may write fewer than the full number of bytes and return - // without error. In this case we need to try writing the remaining - // bytes again. I do not need to register for notifications. - // - // What is more, these two soft failure conditions can be interspersed. - // There is no guarantee that all of the condition #1s will occur before - // all of the condition #2s or visa-versa. - var ( - r bytes.Reader - nbytes int - entry waiter.Entry - ch <-chan struct{} - ) - for nbytes != len(b) { - r.Reset(b[nbytes:]) - n, err := c.ep.Write(&r, tcpip.WriteOptions{}) - nbytes += int(n) - switch err.(type) { - case nil: - case *tcpip.ErrWouldBlock: - if ch == nil { - entry, ch = waiter.NewChannelEntry(waiter.WritableEvents) - c.wq.EventRegister(&entry) - defer c.wq.EventUnregister(&entry) - } else { - // Don't wait immediately after registration in case more data - // became available between when we last checked and when we setup - // the notification. - select { - case <-deadline: - return nbytes, c.newOpError("write", &timeoutError{}) - case <-ch: - continue - } - } - default: - return nbytes, c.newOpError("write", errors.New(err.String())) - } - } - return nbytes, nil -} - -// Close implements net.Conn.Close. -func (c *TCPConn) Close() error { - c.ep.Close() - return nil -} - -// CloseRead shuts down the reading side of the TCP connection. Most callers -// should just use Close. -// -// A TCP Half-Close is performed the same as CloseRead for *net.TCPConn. -func (c *TCPConn) CloseRead() error { - if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil { - return c.newOpError("close", errors.New(terr.String())) - } - return nil -} - -// CloseWrite shuts down the writing side of the TCP connection. Most callers -// should just use Close. -// -// A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn. -func (c *TCPConn) CloseWrite() error { - if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil { - return c.newOpError("close", errors.New(terr.String())) - } - return nil -} - -// LocalAddr implements net.Conn.LocalAddr. -func (c *TCPConn) LocalAddr() net.Addr { - a, err := c.ep.GetLocalAddress() - if err != nil { - return nil - } - return fullToTCPAddr(a) -} - -// RemoteAddr implements net.Conn.RemoteAddr. -func (c *TCPConn) RemoteAddr() net.Addr { - a, err := c.ep.GetRemoteAddress() - if err != nil { - return nil - } - return fullToTCPAddr(a) -} - -func (c *TCPConn) newOpError(op string, err error) *net.OpError { - return &net.OpError{ - Op: op, - Net: "tcp", - Source: c.LocalAddr(), - Addr: c.RemoteAddr(), - Err: err, - } -} - -func fullToTCPAddr(addr tcpip.FullAddress) *net.TCPAddr { - return &net.TCPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} -} - -// DialTCP creates a new TCPConn connected to the specified address. -func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { - return DialContextTCP(context.Background(), s, addr, network) -} - -// DialTCPWithBind creates a new TCPConn connected to the specified -// remoteAddress with its local address bound to localAddr. -func DialTCPWithBind(ctx context.Context, s *stack.Stack, localAddr, remoteAddr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { - // Create TCP endpoint, then connect. - var wq waiter.Queue - ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) - if err != nil { - return nil, errors.New(err.String()) - } - - // Create wait queue entry that notifies a channel. - // - // We do this unconditionally as Connect will always return an error. - waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents) - wq.EventRegister(&waitEntry) - defer wq.EventUnregister(&waitEntry) - - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - // Bind before connect if requested. - if localAddr != (tcpip.FullAddress{}) { - if err = ep.Bind(localAddr); err != nil { - return nil, fmt.Errorf("ep.Bind(%+v) = %s", localAddr, err) - } - } - - err = ep.Connect(remoteAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); ok { - select { - case <-ctx.Done(): - ep.Close() - return nil, ctx.Err() - case <-notifyCh: - } - - err = ep.LastError() - } - if err != nil { - ep.Close() - return nil, &net.OpError{ - Op: "connect", - Net: "tcp", - Addr: fullToTCPAddr(remoteAddr), - Err: errors.New(err.String()), - } - } - - return NewTCPConn(&wq, ep), nil -} - -// DialContextTCP creates a new TCPConn connected to the specified address -// with the option of adding cancellation and timeouts. -func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { - return DialTCPWithBind(ctx, s, tcpip.FullAddress{} /* localAddr */, addr /* remoteAddr */, network) -} - -// gonet adaptation ends here. diff --git a/src/transport/tcp/tcp.go b/src/transport/tcp/tcp.go index 395b347..8a95268 100644 --- a/src/transport/tcp/tcp.go +++ b/src/transport/tcp/tcp.go @@ -1,419 +1,147 @@ // Package tcp proxies TCP connections between a WireGuard peer and a destination // accessible by the machine where Wiretap is running. +// +// Adapted from https://github.com/tailscale/tailscale/blob/2cf6e127907641bdb9eb5cd8e8cf14e968b571d7/wgengine/netstack/netstack.go +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause package tcp import ( + "errors" "fmt" - "io" "log" "net" "os" "sync" "syscall" "time" + "wiretap/transport" "net/netip" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" - "golang.zx2c4.com/wireguard/tun/netstack" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - - "wiretap/transport" + "gvisor.dev/gvisor/pkg/waiter" ) // Configure TCP handler. -type TcpConfig struct { - Ipv4Addr netip.Addr - Ipv6Addr netip.Addr - Port uint16 +type Config struct { CatchTimeout time.Duration ConnTimeout time.Duration KeepaliveIdle time.Duration KeepaliveInterval time.Duration KeepaliveCount int + Tnet *netstack.Net + StackLock *sync.Mutex } -// tcpConn tracks a connection, source and destination IP and Port. -type tcpConn struct { - Source string - Dest string -} - -// connTrack holds the net.Conn to a final destination -// and the status of that connection. -type connTrack struct { - Connecting bool - Conn net.Conn - Caught chan bool -} - -// Keep track of connections so we don't duplicate work. -var isOpen = make(map[tcpConn]connTrack) -var isOpenLock = sync.RWMutex{} - -// preroutingMatch matches packets in the prerouting stage. -type preroutingMatch struct { - pktChan chan stack.PacketBufferPtr - endpoint *channel.Endpoint - config *TcpConfig -} - -// Match looks for SYN packets (start of a tcp conn). Before proxying connection, we need to check -// if intendend destination is up. Drop the packet to prevent blocking, but start goroutine that -// connects to destination. If destination is up, reinject packet and allow it through. -func (m preroutingMatch) Match(hook stack.Hook, packet stack.PacketBufferPtr, inputInterfaceName, outputInterfaceName string) (matches bool, hotdrop bool) { - if hook == stack.Prerouting { - // If SYN flag set, see if connection possible. - netHeader := packet.Network() - transHeader := header.TCP(netHeader.Payload()) - - flags := transHeader.Flags() - if flags.Contains(header.TCPFlagSyn) && !flags.Contains(header.TCPFlagAck) { - dest := net.JoinHostPort(netHeader.DestinationAddress().String(), fmt.Sprint(transHeader.DestinationPort())) - source := net.JoinHostPort(netHeader.SourceAddress().String(), fmt.Sprint(transHeader.SourcePort())) - c := tcpConn{source, dest} +// Handler manages a single TCP flow. +func Handler(c Config) func(*tcp.ForwarderRequest) { + return func(req *tcp.ForwarderRequest) { + // Received TCP flow, add address so we can work with it. + s := req.ID() + log.Printf("(client %s) - Transport: TCP -> %s", net.JoinHostPort(s.RemoteAddress.String(), fmt.Sprint(s.RemotePort)), net.JoinHostPort(s.LocalAddress.String(), fmt.Sprint(s.LocalPort))) - isOpenLock.RLock() - ctrack, ok := isOpen[c] - isOpenLock.RUnlock() + // Add address to stack. + addr, _ := netip.AddrFromSlice(net.IP(s.LocalAddress)) + err := transport.GetConnCounts().AddAddress(addr, c.Tnet.Stack(), c.StackLock) + if err != nil { + log.Println("failed to add address:", err) + req.Complete(false) + return + } + defer func() { + err := transport.GetConnCounts().RemoveAddress(addr, c.Tnet.Stack(), c.StackLock) + if err != nil { + log.Println("failed to remove address:", err) + } + }() - if !ok { - // If not in conn map, drop this packet for now, but clone so it can - // be reinjected if connections are successful. - isOpenLock.Lock() - // In progress, but not ready to forward SYN packets yet. - isOpen[c] = connTrack{ - Connecting: true, - } - isOpenLock.Unlock() + // Address is added, now test if remote endpoint is available. + dstConn, caughtChan, rst := checkDst(&c, s) + if dstConn == nil { + req.Complete(rst) + return + } - packetClone := packet.Clone() - go func() { - checkIfOpen(c, m, packetClone) - packetClone.DecRef() - }() + // Accept conn. + srcConn, err := accept(&c, req) + if err != nil { + dstConn.Close() + log.Println("failed to create endpoint:", err) + return + } - // Hotdrop because we're taking control of the packet. - return false, true - } else if ctrack.Connecting { - // Already checking if port is open. Do nothing. - return false, false - } else { - // Connection is verified to be open. Allow this connection and reset conn map. + // Tell checker that this connection was caught, timer can shutdown. + caughtChan <- true - return true, false - } - } - // ACK here means ACK without prior connection, drop. - if transHeader.Flags() == header.TCPFlagAck { - return false, false - } + transport.Proxy(srcConn, dstConn) } - - return false, false } -// If destination is open, whitelist and reinject. Otherwise send reset. -func checkIfOpen(conn tcpConn, m preroutingMatch, packet stack.PacketBufferPtr) { - log.Printf("(client %v) - Transport: TCP -> %v", conn.Source, conn.Dest) - c, err := net.DialTimeout("tcp", conn.Dest, m.config.ConnTimeout) +// checkDst determines if a tcp connection can be made to a destination. +// Returns the connection on success, +// a channel for the caller to populate when the connection is used, +// and whether or not to send RST to source. +func checkDst(config *Config, s stack.TransportEndpointID) (net.Conn, chan bool, bool) { + c, err := net.DialTimeout("tcp", net.JoinHostPort(s.LocalAddress.String(), fmt.Sprint(s.LocalPort)), config.ConnTimeout) if err != nil { - //log.Printf("Error connecting to %s: %s\n", conn.Dest, err) - // If connection refused, we can send a reset to let peer know. if oerr, ok := err.(*net.OpError); ok { if syserr, ok := oerr.Err.(*os.SyscallError); ok { if syserr.Err == syscall.ECONNREFUSED { - //log.Println("Connection refused, sending reset") - m.pktChan <- packet.Clone() + return nil, nil, true } } } - // Error, reset connection progress. - isOpenLock.Lock() - delete(isOpen, conn) - isOpenLock.Unlock() - return - } - caughtChan := make(chan bool) - // No error, mark successful and reinject packet. - isOpenLock.Lock() - isOpen[conn] = connTrack{ - Connecting: false, - Conn: c, - Caught: caughtChan, + // Different error, don't send reset. + return nil, nil, false } - isOpenLock.Unlock() // Start "catch" timer to make sure connection is actually used. + caughtChan := make(chan bool) go func() { select { - case <-time.After(m.config.CatchTimeout): + case <-time.After(config.CatchTimeout): c.Close() - isOpenLock.Lock() - delete(isOpen, conn) - isOpenLock.Unlock() case <-caughtChan: } }() - isIpv6 := !netip.MustParseAddrPort(c.RemoteAddr().String()).Addr().Is4() - netProto := ipv4.ProtocolNumber - if isIpv6 { - netProto = ipv6.ProtocolNumber - } - new_packet := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: packet.ToBuffer(), - }) - m.endpoint.InjectInbound(netProto, new_packet) -} - -// Handle creates a DNAT rule that forwards destination packets to a tcp listener. -// Once a connection is accepted, it gets handed off to handleConn(). -func Handle(tnet *netstack.Net, config TcpConfig, lock *sync.Mutex) { - s := tnet.Stack() - - // Create iptables rule. - // iptables -t nat -A PREROUTING -p tcp -j DNAT --to-destination 192.168.0.1:80 - headerFilter := stack.IPHeaderFilter{Protocol: tcp.ProtocolNumber, - CheckProtocol: true, - } - - match := preroutingMatch{ - pktChan: make(chan stack.PacketBufferPtr, 1), - endpoint: tnet.Endpoint(), - config: &config, - } - - rule4 := stack.Rule{ - Filter: headerFilter, - Matchers: []stack.Matcher{match}, - Target: &stack.DNATTarget{ - Addr: tcpip.Address(config.Ipv4Addr.AsSlice()), - Port: config.Port, - NetworkProtocol: ipv4.ProtocolNumber, - }, - } - - rule6 := stack.Rule{ - Filter: headerFilter, - Matchers: []stack.Matcher{match}, - Target: &stack.DNATTarget{ - Addr: tcpip.Address(config.Ipv6Addr.AsSlice()), - Port: config.Port, - NetworkProtocol: ipv6.ProtocolNumber, - }, - } - - tid := stack.NATID - transport.PushRule(s, rule4, tid, false) - transport.PushRule(s, rule6, tid, true) - lock.Unlock() - - // RST handler - go func() { - for { - packetClone := <-match.pktChan - go func() { - sendRST(s, packetClone) - packetClone.DecRef() - }() - } - }() - - go startListener(tnet, s.IPTables(), &net.TCPAddr{Port: int(config.Port)}, config.Ipv4Addr, config.Ipv6Addr, s, &config) + return c, caughtChan, false } -// startListener accepts connections from WireGuard peer. -func startListener(tnet *netstack.Net, tables *stack.IPTables, listenAddr *net.TCPAddr, localAddr4 netip.Addr, localAddr6 netip.Addr, s *stack.Stack, c *TcpConfig) { - // Workaround to get true remote address even when connection closes prematurely. - l, err := listenTCP(s, listenAddr) +// accept converts a forwarder request to an endpoint, sets sockopts, then converts to conn. +// "Completes" forwarding request without RST. +func accept(c *Config, req *tcp.ForwarderRequest) (net.Conn, error) { + // We want to accept this flow, setup endpoint to complete handshake. + var wq waiter.Queue + ep, err := req.CreateEndpoint(&wq) + req.Complete(false) if err != nil { - log.Panic(err) + return nil, errors.New(err.String()) } - defer l.Close() - - log.Println("Transport: TCP listener up") - for { - // Every TCP connection gets accepted here, modified Accept function sets correct remote address. - c, remoteAddr, err := l.AcceptFrom(c) - if err != nil || remoteAddr == nil { - log.Println("Failed to accept connection:", err) - continue - } - - go func() { - isIpv6 := !netip.MustParseAddrPort(remoteAddr.String()).Addr().Is4() - netProto := ipv4.ProtocolNumber - localAddr := localAddr4 - if isIpv6 { - netProto = ipv6.ProtocolNumber - localAddr = localAddr6 - } - - handleConn(c, localAddr, remoteAddr, netProto, tables) - }() - } -} - -// handleConn finds the intended target of a peer's connection, -// connects to that target, then copies data between the two. -func handleConn(c net.Conn, ipAddr netip.Addr, remoteAddr net.Addr, netProto tcpip.NetworkProtocolNumber, tables *stack.IPTables) { - var wg sync.WaitGroup - defer c.Close() - - // Lookup original destination of this connection. - addr, port, tcpipErr := tables.OriginalDst(stack.TransportEndpointID{ - LocalPort: 1337, - LocalAddress: tcpip.Address(ipAddr.AsSlice()), RemotePort: netip.MustParseAddrPort(remoteAddr.String()).Port(), - RemoteAddress: tcpip.Address(netip.MustParseAddrPort(remoteAddr.String()).Addr().AsSlice()), - }, netProto, tcp.ProtocolNumber) - if tcpipErr != nil { - log.Println("Error reading original destination:", tcpipErr) - return - } - - dest := net.JoinHostPort(addr.String(), fmt.Sprint(port)) - source := remoteAddr.String() - cString := tcpConn{source, dest} - - // Original destination should be dialed already for when we checked if it was open: - isOpenLock.Lock() - ctrack, ok := isOpen[cString] - isOpenLock.Unlock() - if !ok { - log.Printf("Error looking up conn to destination: %v\n", net.JoinHostPort(addr.String(), fmt.Sprint(port))) - return - } - - // Delete original destination from map so it can be remade. - newConn := ctrack.Conn - isOpenLock.Lock() - // Notify catch timer that this conn is being used. - ctrack.Caught <- true - delete(isOpen, cString) - isOpenLock.Unlock() - - // Copy from new connection to peer - wg.Add(1) - go func() { - _, err := io.Copy(c, newConn) - if err != nil { - log.Printf("Error copying between connections: %v\n", err) - } - wg.Done() - c.Close() - }() - - // Copy from peer to new connection. - _, err := io.Copy(newConn, c) + // Enable keepalive and set defaults so that after (idle + (count * interval)) connection will be dropped if unresponsive. + ep.SocketOptions().SetKeepAlive(true) + keepaliveIdle := tcpip.KeepaliveIdleOption(c.KeepaliveIdle) + err = ep.SetSockOpt(&keepaliveIdle) if err != nil { - log.Printf("Error copying between connections: %v\n", err) + return nil, errors.New(err.String()) } - newConn.Close() - - // Wait for both copies to finish. - wg.Wait() -} - -// sendRST sends an RST back to the original source of a packet. -func sendRST(s *stack.Stack, packet stack.PacketBufferPtr) { - var err error - var ipv4Layer *layers.IPv4 - var ipv6Layer *layers.IPv6 - - netHeader := packet.Network() - transHeader := header.TCP(netHeader.Payload()) - - isIpv6 := netHeader.DestinationAddress().To4() == "" - - if isIpv6 { - ipv6Layer = &layers.IPv6{} - ipv6Layer, err = transport.GetNetworkLayer[header.IPv6](netHeader, ipv6Layer) - ipv6Layer.SrcIP, ipv6Layer.DstIP = ipv6Layer.DstIP, ipv6Layer.SrcIP - } else { - ipv4Layer = &layers.IPv4{} - ipv4Layer, err = transport.GetNetworkLayer[header.IPv4](netHeader, ipv4Layer) - ipv4Layer.SrcIP, ipv4Layer.DstIP = ipv4Layer.DstIP, ipv4Layer.SrcIP - } - - if err != nil { - log.Println("Could not decode Network header:", err) - return - } - - // Create transport layer and swap ports, fix flags. - tcpLayer := &layers.TCP{} - err = tcpLayer.DecodeFromBytes(transHeader, gopacket.NilDecodeFeedback) - if err != nil { - log.Println("Could not decode TCP header:", err) - return - } - - tcpLayer.SrcPort, tcpLayer.DstPort = tcpLayer.DstPort, tcpLayer.SrcPort - tcpLayer.Ack = tcpLayer.Seq + 1 - tcpLayer.Seq = 0 - tcpLayer.DataOffset = 5 - tcpLayer.SYN = false - tcpLayer.RST = true - tcpLayer.ACK = true - tcpLayer.Window = 0 - tcpLayer.Options = nil - tcpLayer.Padding = nil - - if isIpv6 { - err = tcpLayer.SetNetworkLayerForChecksum(ipv6Layer) - } else { - err = tcpLayer.SetNetworkLayerForChecksum(ipv4Layer) - } - + keepaliveInterval := tcpip.KeepaliveIntervalOption(c.KeepaliveInterval) + err = ep.SetSockOpt(&keepaliveInterval) if err != nil { - log.Println("Could not set layer for checksum:", err) - return - } - - buf := gopacket.NewSerializeBuffer() - options := gopacket.SerializeOptions{ - ComputeChecksums: true, - FixLengths: true, - } - if isIpv6 { - err = gopacket.SerializeLayers(buf, options, - ipv6Layer, - tcpLayer, - ) - } else { - err = gopacket.SerializeLayers(buf, options, - ipv4Layer, - tcpLayer, - ) + return nil, errors.New(err.String()) } + err = ep.SetSockOptInt(tcpip.KeepaliveCountOption, c.KeepaliveCount) if err != nil { - log.Println("Failed to serialize layers:", err) - return + return nil, errors.New(err.String()) } - response := buf.Bytes() - - // Create network layer endpoint for spoofing source address. - proto := ipv4.ProtocolNumber - if isIpv6 { - proto = ipv6.ProtocolNumber - } - - tcpipErr := transport.SendPacket(s, response, &tcpip.FullAddress{NIC: 1, Addr: netHeader.SourceAddress()}, proto) - if tcpipErr != nil { - log.Println("Failed to send reset:", tcpipErr) - return - } + return gonet.NewTCPConn(&wq, ep), nil } diff --git a/src/transport/transport.go b/src/transport/transport.go index 012d0f1..4420245 100644 --- a/src/transport/transport.go +++ b/src/transport/transport.go @@ -3,22 +3,26 @@ package transport import ( "bytes" + "context" "errors" + "io" + "log" + "net" + "net/netip" + "strconv" + "sync" + "github.com/armon/go-socks5" "github.com/google/gopacket" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) -// PushRule pushes a rule onto a firewall table. -func PushRule(s *stack.Stack, rule stack.Rule, tableId stack.TableID, ipv6 bool) { - table := s.IPTables().GetTable(tableId, ipv6) - table.Rules = append([]stack.Rule{rule}, table.Rules...) - s.IPTables().ReplaceTable(tableId, table, ipv6) -} - // IPHeader is a type interface used by GetNetworkLayer. type IPHeader interface { header.IPv4 | header.IPv6 @@ -29,13 +33,57 @@ type IPLayer interface { DecodeFromBytes(data []byte, df gopacket.DecodeFeedback) error } +type ConnCounts struct { + counts map[netip.Addr]int + lock sync.Mutex +} + +var connCounts ConnCounts + +func init() { + connCounts.counts = make(map[netip.Addr]int) +} + +func GetConnCounts() *ConnCounts { + return &connCounts +} + +func (c *ConnCounts) AddAddress(addr netip.Addr, s *stack.Stack, stackLock *sync.Mutex) error { + c.lock.Lock() + defer c.lock.Unlock() + + c.counts[addr]++ + + if c.counts[addr] > 1 { + return nil + } + + var protoNumber tcpip.NetworkProtocolNumber + if addr.Is4() { + protoNumber = ipv4.ProtocolNumber + } else if addr.Is6() { + protoNumber = ipv6.ProtocolNumber + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: protoNumber, + AddressWithPrefix: tcpip.Address(addr.AsSlice()).WithPrefix(), + } + + stackLock.Lock() + err := s.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) + stackLock.Unlock() + if err != nil { + return errors.New(err.String()) + } + return nil +} + // GetNetworkLayer parses a network header, then converts it to bytes. func GetNetworkLayer[H IPHeader, L IPLayer](netHeader header.Network, ipLayer L) (L, error) { h, ok := netHeader.(H) if !ok { - return ipLayer, errors.New("Could not assert network header as provided type") + return ipLayer, errors.New("could not assert network header as provided type") } - err := ipLayer.DecodeFromBytes(h, gopacket.NilDecodeFeedback) if err != nil { return ipLayer, err @@ -44,10 +92,34 @@ func GetNetworkLayer[H IPHeader, L IPLayer](netHeader header.Network, ipLayer L) return ipLayer, nil } +// RemoveAddress removes an address from the stack once it is no longer needed. +func (c *ConnCounts) RemoveAddress(addr netip.Addr, s *stack.Stack, stackLock *sync.Mutex) error { + c.lock.Lock() + defer c.lock.Unlock() + + c.counts[addr]-- + + if c.counts[addr] > 0 { + return nil + } + + delete(c.counts, addr) + + stackLock.Lock() + err := s.RemoveAddress(1, tcpip.Address(addr.AsSlice())) + stackLock.Unlock() + + if err != nil { + return errors.New(err.String()) + } + return nil +} + // SendPacket sends a network-layer packet. func SendPacket(s *stack.Stack, packet []byte, addr *tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) tcpip.Error { // Create network layer endpoint for spoofing source address. var wq waiter.Queue + ep, tcpipErr := s.NewPacketEndpoint(true, netProto, &wq) if tcpipErr != nil { return tcpipErr @@ -65,3 +137,162 @@ func SendPacket(s *stack.Stack, packet []byte, addr *tcpip.FullAddress, netProto return nil } + +func Proxy(src net.Conn, dst net.Conn) { + var wg sync.WaitGroup + + wg.Add(1) + go func() { + _, err := io.Copy(src, dst) + if err != nil { + log.Printf("error copying between connections: %v\n", err) + } + src.Close() + wg.Done() + }() + + // Copy from peer to new connection. + _, nerr := io.Copy(dst, src) + if nerr != nil { + log.Printf("error copying between connections: %v\n", nerr) + } + dst.Close() + + // Wait for both copies to finish. + wg.Wait() +} + +// ForwardTcpPort proxies TCP connections by accepting connections and proxying them back to the client. +func ForwardTcpPort(s *stack.Stack, l net.Listener, localAddr tcpip.FullAddress, remoteAddr tcpip.FullAddress, np tcpip.NetworkProtocolNumber) { + ctx, cancel := context.WithCancel(context.Background()) + for { + conn, err := l.Accept() + if err != nil { + cancel() + return + } + + // Proxy between conns. + go func() { + var nc net.Conn + nc, err = gonet.DialTCPWithBind( + ctx, + s, + localAddr, + remoteAddr, + np, + ) + if err != nil { + log.Println("failed to proxy conn:", err) + conn.Close() + return + } + + Proxy(conn, nc) + }() + } +} + +// ForwardUdpPort proxies UDP datagrams by forwarding datagrams to a peer, and then returns responses to the last remote address to talk to this endpoint. +// No connection tracking is in place at this time. +func ForwardUdpPort(s *stack.Stack, conn *net.UDPConn, localAddr tcpip.FullAddress, remoteAddr tcpip.FullAddress, np tcpip.NetworkProtocolNumber) { + var wg sync.WaitGroup + var clientAddr *netip.AddrPort + var lock sync.Mutex + + const bufSize = 65535 + + // Connect to forwarded port. + nc, err := gonet.DialUDP( + s, + &localAddr, + &remoteAddr, + np, + ) + if err != nil { + log.Println("failed to proxy conn:", err) + conn.Close() + return + } + + // Accept packets and forward to peer. + wg.Add(1) + go func() { + buf := make([]byte, bufSize) + for { + n, addr, err := conn.ReadFromUDPAddrPort(buf) + if err != nil { + nc.Close() + log.Println("conn closed:", err) + break + } + lock.Lock() + clientAddr = &addr + lock.Unlock() + _, err = nc.Write(buf[:n]) + if err != nil { + log.Println("failed to send:", err) + continue + } + } + wg.Done() + }() + + for { + buf := make([]byte, bufSize) + n, err := nc.Read(buf) + if err != nil { + log.Println("conn closed:", err) + conn.Close() + break + } + lock.Lock() + if clientAddr == nil { + lock.Unlock() + continue + } + addr := *clientAddr + lock.Unlock() + _, err = conn.WriteToUDPAddrPort(buf[:n], addr) + if err != nil { + log.Println("failed to send:", err) + continue + } + } + + wg.Wait() +} + +// ForwardTcpPort proxies TCP connections by accepting connections and proxying them back to the client. +func ForwardDynamic(s *stack.Stack, l *net.Listener, localAddr tcpip.FullAddress, remoteAddr tcpip.FullAddress, np tcpip.NetworkProtocolNumber) { + dialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + _, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + dport, err := strconv.Atoi(port) + if err != nil { + return nil, err + } + + return gonet.DialTCPWithBind( + ctx, + s, + localAddr, + tcpip.FullAddress{NIC: remoteAddr.NIC, Addr: remoteAddr.Addr, Port: uint16(dport)}, + np, + ) + } + + conf := &socks5.Config{Dial: dialer} + server, err := socks5.New(conf) + if err != nil { + log.Println("failed to make socks server:", err) + return + } + + if err := server.Serve(*l); err != nil { + log.Println("socks server stopped:", err) + } +} diff --git a/src/transport/udp/udp.go b/src/transport/udp/udp.go index be2a008..8d2356c 100644 --- a/src/transport/udp/udp.go +++ b/src/transport/udp/udp.go @@ -1,3 +1,5 @@ +// Package udp proxies UDP messages between a WireGuard peer and a destination accessible +// by the machine where Wiretap is running. package udp import ( @@ -24,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "wiretap/transport" ) @@ -48,25 +49,25 @@ var sourceMapLock = sync.RWMutex{} var connMap = make(map[udpConn](chan stack.PacketBufferPtr)) var connMapLock = sync.RWMutex{} -// preroutingMatch matches packets in the prerouting stage. -type preroutingMatch struct{} +type Config struct { + Tnet *netstack.Net + StackLock *sync.Mutex +} -var s *stack.Stack +// Handler handles UDP packets. Returns function that returns true if packet is handled, or false if ICMP Destination Unreachable should be sent. +// TODO: Clean this up. Can't use UDPForwarder because it doesn't offer a way to return false, which is required to send Unreachables. +func Handler(c Config) func(stack.TransportEndpointID, stack.PacketBufferPtr) bool { + return func(teid stack.TransportEndpointID, pkb stack.PacketBufferPtr) bool { + log.Printf("(client %s) - Transport: UDP -> %s", net.JoinHostPort(teid.RemoteAddress.String(), fmt.Sprint(teid.RemotePort)), net.JoinHostPort(teid.LocalAddress.String(), fmt.Sprint(teid.LocalPort))) -// Match rejects all packets, but clones every prerouting packet to the packet handler. -func (m preroutingMatch) Match(hook stack.Hook, packet stack.PacketBufferPtr, inputInterfaceName, outputInterfaceName string) (matches bool, hotdrop bool) { - if hook == stack.Prerouting { - packetClone := packet.Clone() + packetClone := pkb.Clone() go func() { - newPacket(packetClone) + newPacket(packetClone, c.Tnet.Stack()) packetClone.DecRef() }() - // Taking control of packet, hotdrop. - return false, true + return true } - - return false, false } func sourceMapLookup(n netip.AddrPort) (dialerCount, bool) { @@ -130,15 +131,13 @@ func getDataFromPacket(packet stack.PacketBufferPtr) []byte { } // NewPacket handles every new packet and sending it to the proper UDP dialer. -func newPacket(packet stack.PacketBufferPtr) { +func newPacket(packet stack.PacketBufferPtr, s *stack.Stack) { netHeader := packet.Network() transHeader := header.UDP(netHeader.Payload()) source := netip.MustParseAddrPort(net.JoinHostPort(netHeader.SourceAddress().String(), fmt.Sprint(transHeader.SourcePort()))) dest := netip.MustParseAddrPort(net.JoinHostPort(netHeader.DestinationAddress().String(), fmt.Sprint(transHeader.DestinationPort()))) - log.Printf("(client %v) - Transport: UDP -> %v", source, dest) - var pktChan chan stack.PacketBufferPtr var ok bool @@ -162,56 +161,13 @@ func newPacket(packet stack.PacketBufferPtr) { pktChan = make(chan stack.PacketBufferPtr, 1) connMapWrite(conn, pktChan) - go handleConn(conn, port) + go handleConn(conn, port, s) pktChan <- packet.Clone() } -// Handle creates a DNAT rule that forwards destination packets to a udp listener. -// Once a connection is accepted, it gets handed off to handleConn(). -func Handle(tnet *netstack.Net, ipv4Addr netip.Addr, ipv6Addr netip.Addr, port uint16, lock *sync.Mutex) { - s = tnet.Stack() - - // Create NATing firewall rule. - // iptables -t nat -A PREROUTING -p udp -j DNAT --to-destination : - headerFilter := stack.IPHeaderFilter{ - Protocol: udp.ProtocolNumber, - CheckProtocol: true, - } - - match := preroutingMatch{} - - rule4 := stack.Rule{ - Filter: headerFilter, - Matchers: []stack.Matcher{match}, - Target: &stack.DNATTarget{ - Addr: tcpip.Address(ipv4Addr.AsSlice()), - Port: port, - NetworkProtocol: ipv4.ProtocolNumber, - }, - } - - rule6 := stack.Rule{ - Filter: headerFilter, - Matchers: []stack.Matcher{match}, - Target: &stack.DNATTarget{ - Addr: tcpip.Address(ipv6Addr.AsSlice()), - Port: port, - NetworkProtocol: ipv6.ProtocolNumber, - }, - } - - tid := stack.NATID - transport.PushRule(s, rule4, tid, false) - transport.PushRule(s, rule6, tid, true) - lock.Unlock() - - // UDP listener is handled in the prerouting rule, we can return. - log.Println("Transport: UDP listener up") -} - // handleConn proxies traffic between a source and destination. -func handleConn(conn udpConn, port int) { +func handleConn(conn udpConn, port int, s *stack.Stack) { defer func() { connMapDelete(conn) }() @@ -221,12 +177,12 @@ func handleConn(conn udpConn, port int) { // New dialer from source to destination. laddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) if err != nil { - log.Println("Failed to parse laddr", err) + log.Println("failed to parse laddr", err) return } raddr, err := net.ResolveUDPAddr("udp", conn.Dest.String()) if err != nil { - log.Println("Failed to parse raddr", err) + log.Println("failed to parse raddr", err) return } @@ -234,7 +190,7 @@ func handleConn(conn udpConn, port int) { // Would like to use ListenUDP, but we don't get ICMP unreachable. newConn, err := reuse.Dial("udp", laddr.String(), raddr.String()) if err != nil { - log.Println("Failed new UDP bind", err) + log.Println("failed new UDP bind", err) return } defer newConn.Close() @@ -249,7 +205,7 @@ func handleConn(conn udpConn, port int) { err = newConn.SetDeadline(time.Now().Add(30 * time.Second)) if err != nil { - log.Println("Failed to set deadline", err) + log.Println("failed to set deadline", err) } // Sends packet from peer to destination. @@ -269,7 +225,7 @@ func handleConn(conn udpConn, port int) { _, err := newConn.Write(data) pkt.DecRef() if err != nil { - log.Println("Error sending packet:", err) + log.Println("error sending packet:", err) newConn.Close() return } @@ -277,7 +233,7 @@ func handleConn(conn udpConn, port int) { // Reset timer, we got a packet. err = newConn.SetDeadline(time.Now().Add(30 * time.Second)) if err != nil { - log.Println("Failed to set deadline:", err) + log.Println("failed to set deadline:", err) } } }() @@ -291,7 +247,7 @@ func handleConn(conn udpConn, port int) { if oerr, ok := err.(*net.OpError); ok { if syserr, ok := oerr.Err.(*os.SyscallError); ok { if syserr.Err == syscall.ECONNREFUSED { - go sendUnreachable(mostRecentPacket) + go sendUnreachable(mostRecentPacket, s) } } } @@ -312,17 +268,17 @@ func handleConn(conn udpConn, port int) { // Reset timer, we got a packet. err = newConn.SetDeadline(time.Now().Add(30 * time.Second)) if err != nil { - log.Println("Failed to set deadline:", err) + log.Println("failed to set deadline:", err) } // Write packet back to peer. - sendResponse(conn, newBuf[:n]) + sendResponse(conn, newBuf[:n], s) } } // sendResponse builds a UDP packet to return to the peer. // TCP doesn't need this because the NATing works fine, but with UDP the OriginalDst function fails. -func sendResponse(conn udpConn, data []byte) { +func sendResponse(conn udpConn, data []byte, s *stack.Stack) { var err error var ipv4Layer *layers.IPv4 var ipv6Layer *layers.IPv6 @@ -354,7 +310,7 @@ func sendResponse(conn udpConn, data []byte) { err = udpLayer.SetNetworkLayerForChecksum(ipv4Layer) } if err != nil { - log.Println("Failed to marshal response:", err) + log.Println("failed to marshal response:", err) return } @@ -381,20 +337,20 @@ func sendResponse(conn udpConn, data []byte) { } if err != nil { - log.Println("Failed to serialize layers:", err) + log.Println("failed to serialize layers:", err) return } tcpipErr := transport.SendPacket(s, buf.Bytes(), &tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(conn.Source.Addr().AsSlice())}, proto) if tcpipErr != nil { - log.Println("Failed to write:", tcpipErr) + log.Println("failed to write:", tcpipErr) return } } // sendUnreachable sends an ICMP Port Unreachable packet to peer as if from // the original destination of the packet. -func sendUnreachable(packet stack.PacketBufferPtr) { +func sendUnreachable(packet stack.PacketBufferPtr, s *stack.Stack) { var err error var ipv4Layer *layers.IPv4 var ipv6Layer *layers.IPv6 @@ -412,7 +368,7 @@ func sendUnreachable(packet stack.PacketBufferPtr) { ipv6Layer = &layers.IPv6{} ipv6Layer, err = transport.GetNetworkLayer[header.IPv6](netHeader, ipv6Layer) if err != nil { - log.Println("Could not decode Network header:", err) + log.Println("could not decode Network header:", err) return } ipv6Layer = &layers.IPv6{ @@ -424,7 +380,7 @@ func sendUnreachable(packet stack.PacketBufferPtr) { } ipv6Header, ok := netHeader.(header.IPv6) if !ok { - log.Println("Could not type assert IPv6 Network Header") + log.Println("could not type assert IPv6 Network Header") return } icmpLayer, err = (&neticmp.Message{ @@ -439,7 +395,7 @@ func sendUnreachable(packet stack.PacketBufferPtr) { ipv4Layer = &layers.IPv4{} ipv4Layer, err = transport.GetNetworkLayer[header.IPv4](netHeader, ipv4Layer) if err != nil { - log.Println("Could not decode Network header:", err) + log.Println("could not decode Network header:", err) return } ipv4Layer = &layers.IPv4{ @@ -452,7 +408,7 @@ func sendUnreachable(packet stack.PacketBufferPtr) { } ipv4Header, ok := netHeader.(header.IPv4) if !ok { - log.Println("Could not type assert IPv6 Network Header") + log.Println("could not type assert IPv6 Network Header") return } icmpLayer, err = (&neticmp.Message{ @@ -465,7 +421,7 @@ func sendUnreachable(packet stack.PacketBufferPtr) { ipv4Layer.Length = uint16((int(ipv4Layer.IHL) * 4) + len(icmpLayer)) } if err != nil { - log.Println("Failed to marshal response:", err) + log.Println("failed to marshal response:", err) return } @@ -486,7 +442,7 @@ func sendUnreachable(packet stack.PacketBufferPtr) { ) } if err != nil { - log.Println("Failed to serialize layers:", err) + log.Println("failed to serialize layers:", err) return } @@ -494,7 +450,7 @@ func sendUnreachable(packet stack.PacketBufferPtr) { tcpipErr := transport.SendPacket(s, response, &tcpip.FullAddress{NIC: 1, Addr: netHeader.SourceAddress()}, proto) if tcpipErr != nil { - log.Println("Failed to write:", tcpipErr) + log.Println("failed to write:", tcpipErr) return } } diff --git a/wiretap.Dockerfile b/wiretap.Dockerfile index 746668b..ab632c5 100644 --- a/wiretap.Dockerfile +++ b/wiretap.Dockerfile @@ -5,7 +5,7 @@ ARG https_proxy # Utilities for testing RUN apt-get update -RUN apt-get install net-tools nmap dnsutils tcpdump iproute2 vim netcat iputils-ping wireguard iperf xsel masscan -y +RUN apt-get install net-tools nmap dnsutils tcpdump iproute2 vim netcat-openbsd iputils-ping wireguard iperf xsel masscan -y WORKDIR /wiretap COPY ./src/go.mod ./src/go.sum ./