From d22aa3995749e65da07ca1ee5e14d16733b49103 Mon Sep 17 00:00:00 2001 From: Vasyl Gello Date: Thu, 18 Jul 2024 22:24:46 +0300 Subject: [PATCH] [WIP] Introduce TCP/UDP local/remote port forwarding Signed-off-by: Vasyl Gello --- cmd/yggstack/main.go | 126 +++++++++++-- src/types/mapping.go | 379 ++++++++++++++++++++++++++++++++------ src/types/mapping_test.go | 60 ++++++ src/types/udpproxy.go | 19 +- 4 files changed, 508 insertions(+), 76 deletions(-) diff --git a/cmd/yggstack/main.go b/cmd/yggstack/main.go index 7d6ca98..ef9d02a 100644 --- a/cmd/yggstack/main.go +++ b/cmd/yggstack/main.go @@ -40,15 +40,22 @@ type node struct { socks5Listener net.Listener } -type UDPSession struct { - conn *net.UDPConn +type UDPConnSession struct { + conn *net.UDPConn + remoteAddr net.Addr +} + +type UDPPacketConnSession struct { + conn *net.PacketConn remoteAddr net.Addr } // The main function is responsible for configuring and starting Yggdrasil. func main() { - var exposetcp types.TCPMappings - var exposeudp types.UDPMappings + var localtcp types.TCPLocalMappings + var localudp types.UDPLocalMappings + var remotetcp types.TCPRemoteMappings + var remoteudp types.UDPRemoteMappings genconf := flag.Bool("genconf", false, "print a new config to stdout") useconf := flag.Bool("useconf", false, "read HJSON/JSON config from stdin") useconffile := flag.String("useconffile", "", "read HJSON/JSON config from specified file path") @@ -64,8 +71,10 @@ func main() { loglevel := flag.String("loglevel", "info", "loglevel to enable") socks := flag.String("socks", "", "address to listen on for SOCKS, i.e. :1080; or UNIX socket file path, i.e. /tmp/yggstack.sock") nameserver := flag.String("nameserver", "", "the Yggdrasil IPv6 address to use as a DNS server for SOCKS") - flag.Var(&exposetcp, "exposetcp", "TCP ports to expose to the network, e.g. 22, 2022:22, 22:192.168.1.1:2022") - flag.Var(&exposeudp, "exposeudp", "UDP ports to expose to the network, e.g. 22, 2022:22, 22:192.168.1.1:2022") + flag.Var(&localtcp, "local-tcp", "TCP ports to forward to the remote Yggdradil node, e.g. 22:[a:b:c:d]:22, 127.0.0.1:22:[a:b:c:d]:22") + flag.Var(&localudp, "local-udp", "UDP ports to forward to the remote Yggdrasil node, e.g. 22:[a:b:c:d]:2022, 127.0.0.1:[a:b:c:d]:22") + flag.Var(&remotetcp, "remote-tcp", "TCP ports to expose to the network, e.g. 22, 2022:22, 22:192.168.1.1:2022") + flag.Var(&remoteudp, "remote-udp", "UDP ports to expose to the network, e.g. 22, 2022:22, 22:192.168.1.1:2022") flag.Parse() // Catch interrupts from the operating system to exit gracefully. @@ -328,9 +337,96 @@ func main() { } } - // Create TCP mappings + // Create local TCP mappings (forwarding connections from local port + // to remote Yggdrasil node) { - for _, mapping := range exposetcp { + for _, mapping := range localtcp { + go func(mapping types.TCPMapping) { + listener, err := net.ListenTCP("tcp", mapping.Listen) + if err != nil { + panic(err) + } + logger.Infof("Mapping local TCP port %d to Yggdrasil %s", mapping.Listen.Port, mapping.Mapped) + for { + c, err := listener.Accept() + if err != nil { + panic(err) + } + r, err := s.DialTCP(mapping.Mapped) + if err != nil { + logger.Errorf("Failed to connect to %s: %s", mapping.Mapped, err) + _ = c.Close() + continue + } + go types.ProxyTCP(n.core.MTU(), c, r) + } + }(mapping) + } + } + + // Create local UDP mappings (forwarding connections from local port + // to remote Yggdrasil node) + { + for _, mapping := range localudp { + go func(mapping types.UDPMapping) { + mtu := n.core.MTU() + udpListenConn, err := net.ListenUDP("udp", mapping.Listen) + if err != nil { + panic(err) + } + logger.Infof("Mapping local UDP port %d to Yggdrasil %s", mapping.Listen.Port, mapping.Mapped) + localUdpConnections := new(sync.Map) + udpBuffer := make([]byte, mtu) + for { + bytesRead, remoteUdpAddr, err := udpListenConn.ReadFrom(udpBuffer) + if err != nil { + if bytesRead == 0 { + continue + } + } + + remoteUdpAddrStr := remoteUdpAddr.String() + + connVal, ok := localUdpConnections.Load(remoteUdpAddrStr) + + if !ok { + logger.Infof("Creating new session for %s", remoteUdpAddr.String()) + udpFwdConn, err := s.DialUDP(mapping.Mapped) + if err != nil { + logger.Errorf("Failed to connect to %s: %s", mapping.Mapped, err) + continue + } + udpSession := &UDPPacketConnSession{ + conn: &udpFwdConn, + remoteAddr: remoteUdpAddr, + } + localUdpConnections.Store(remoteUdpAddrStr, udpSession) + go types.ReverseProxyUDPPacketConn(mtu, udpListenConn, remoteUdpAddr, udpFwdConn) + } + + udpSession, ok := connVal.(*UDPPacketConnSession) + if !ok { + continue + } + + udpFwdConn := *udpSession.conn + + _, err = udpFwdConn.WriteTo(udpBuffer[:bytesRead], mapping.Mapped) + if err != nil { + logger.Debugf("Cannot write from yggdrasil to udp listener: %q", err) + udpFwdConn.Close() + localUdpConnections.Delete(remoteUdpAddrStr) + continue + } + } + }(mapping) + } + } + + // Create remote TCP mappings (forwarding connections from Yggdrasil + // node to local port) + { + for _, mapping := range remotetcp { go func(mapping types.TCPMapping) { listener, err := s.ListenTCP(mapping.Listen) if err != nil { @@ -354,9 +450,10 @@ func main() { } } - // Create UDP mappings + // Create remote UDP mappings (forwarding connections from Yggdrasil + // node to local port) { - for _, mapping := range exposeudp { + for _, mapping := range remoteudp { go func(mapping types.UDPMapping) { mtu := n.core.MTU() udpListenConn, err := s.ListenUDP(mapping.Listen) @@ -385,16 +482,15 @@ func main() { logger.Errorf("Failed to connect to %s: %s", mapping.Mapped, err) continue } - udpSession := &UDPSession{ - conn: udpFwdConn, + udpSession := &UDPConnSession{ + conn: udpFwdConn, remoteAddr: remoteUdpAddr, } remoteUdpConnections.Store(remoteUdpAddrStr, udpSession) - go types.ReverseProxyUDP(mtu, udpListenConn, remoteUdpAddr, *udpFwdConn) + go types.ReverseProxyUDPConn(mtu, udpListenConn, remoteUdpAddr, *udpFwdConn) } - - udpSession, ok := connVal.(*UDPSession) + udpSession, ok := connVal.(*UDPConnSession) if !ok { continue } diff --git a/src/types/mapping.go b/src/types/mapping.go index 03e99c3..015755f 100644 --- a/src/types/mapping.go +++ b/src/types/mapping.go @@ -7,62 +7,269 @@ import ( "strings" ) +func parseMappingString(value string) (first_address string, first_port int, second_address string, second_port int, err error) { + var first_port_string string = "" + var second_port_string string = "" + + tokens := strings.Split(value, ":") + tokens_len := len(tokens) + + // If token count is 1, then it is first and second port the same + + if tokens_len == 1 { + first_port, err = strconv.Atoi(tokens[0]) + if err != nil { + return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value) + } + second_port = first_port + } + + // If token count is 2, then it is : + + if tokens_len == 2 { + first_port, err = strconv.Atoi(tokens[0]) + if err != nil { + return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value) + } + second_port, err = strconv.Atoi(tokens[1]) + if err != nil { + return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value) + } + } + + // If token count is 3, parse it as + // :: + + if tokens_len == 3 { + first_port, err = strconv.Atoi(tokens[0]) + if err != nil { + return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value) + } + second_address, second_port_string, err = net.SplitHostPort( + tokens[1] + ":" + tokens[2]) + if err != nil { + return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value) + } + second_port, err = strconv.Atoi(second_port_string) + if err != nil { + return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value) + } + } + + // If token count is 4, parse it as + // ::: + + if tokens_len == 4 { + first_address, first_port_string, err = net.SplitHostPort( + tokens[0] + ":" + tokens[1]) + if err != nil { + return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value) + } + second_address, second_port_string, err = net.SplitHostPort( + tokens[0] + ":" + tokens[1]) + if err != nil { + return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value) + } + first_port, err = strconv.Atoi(first_port_string) + if err != nil { + return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value) + } + second_port, err = strconv.Atoi(second_port_string) + if err != nil { + return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value) + } + } + + if tokens_len > 4 { + // Last token needs to be the second_port + + second_port, err = strconv.Atoi(tokens[tokens_len-1]) + if err != nil { + return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value) + } + + // Cut seen tokens + + tokens = tokens[:tokens_len-1] + tokens_len = len(tokens) + + if strings.HasSuffix(tokens[tokens_len-1], "]") { + // Reverse-walk over tokens to find the end of + // numeric ipv6 address + + for i := tokens_len - 1; i >= 0; i-- { + if strings.HasPrefix(tokens[i], "[") { + // Store second address + second_address = strings.Join(tokens[i:], ":") + second_address, _ = strings.CutPrefix(second_address, "[") + second_address, _ = strings.CutSuffix(second_address, "]") + // Cut seen tokens + tokens = tokens[:i] + // break from loop + break + } + } + } else { + // next is second address in non-numerical-ipv6 form + second_address = tokens[tokens_len-1] + tokens = tokens[:tokens_len-1] + } + + tokens_len = len(tokens) + + if tokens_len < 1 { + return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value) + } + + // Last token needs to be the first_port + + first_port, err = strconv.Atoi(tokens[tokens_len-1]) + if err != nil { + return "", 0, "", 0, fmt.Errorf("Malformed mapping spec '%s'", value) + } + + // Cut seen tokens + + tokens = tokens[:tokens_len-1] + tokens_len = len(tokens) + + if tokens_len > 0 { + if strings.HasSuffix(tokens[tokens_len-1], "]") { + // Reverse-walk over tokens to find the end of + // numeric ipv6 address + + for i := tokens_len - 1; i >= 0; i-- { + if strings.HasPrefix(tokens[i], "[") { + // Store first address + first_address = strings.Join(tokens[i:], ":") + first_address, _ = strings.CutPrefix(first_address, "[") + first_address, _ = strings.CutSuffix(first_address, "]") + // break from loop + break + } + } + } else { + // next is first address in non-numerical-ipv6 form + first_address = tokens[tokens_len-1] + } + } + } + + if first_port == 0 || second_port == 0 { + return "", 0, "", 0, fmt.Errorf("Ports must not be zero") + } + + return first_address, first_port, second_address, second_port, nil +} + type TCPMapping struct { Listen *net.TCPAddr Mapped *net.TCPAddr } -type TCPMappings []TCPMapping +type TCPLocalMappings []TCPMapping -func (m *TCPMappings) String() string { +func (m *TCPLocalMappings) String() string { return "" } -func (m *TCPMappings) Set(value string) error { - tokens := strings.Split(value, ":") - if len(tokens) > 2 { - tokens = strings.SplitN(value, ":", 2) - host, port, err := net.SplitHostPort(tokens[1]) - if err != nil { - return fmt.Errorf("failed to split host and port: %w", err) - } - tokens = append(tokens[:1], host, port) - } - listenport, err := strconv.Atoi(tokens[0]) +func (m *TCPLocalMappings) Set(value string) error { + first_address, first_port, second_address, second_port, err := + parseMappingString(value) + if err != nil { - return fmt.Errorf("listen port is invalid: %w", err) + return err } - if listenport == 0 { - return fmt.Errorf("listen port must not be zero") + + // First address can be ipv4/ipv6 + // Second address can be only Yggdrasil ipv6 + + if !strings.Contains(second_address, ":") { + return fmt.Errorf("Yggdrasil listening address can be only IPv6") } + + // Create mapping + mapping := TCPMapping{ Listen: &net.TCPAddr{ - Port: listenport, + Port: first_port, }, Mapped: &net.TCPAddr{ IP: net.IPv6loopback, - Port: listenport, + Port: second_port, }, } - tokens = tokens[1:] - if len(tokens) > 0 { - mappedaddr := net.ParseIP(tokens[0]) + + if first_address != "" { + listenaddr := net.ParseIP(first_address) + if listenaddr == nil { + return fmt.Errorf("invalid listen address %q", first_address) + } + mapping.Listen.IP = listenaddr + } + + if second_address != "" { + mappedaddr := net.ParseIP(second_address) if mappedaddr == nil { - return fmt.Errorf("invalid mapped address %q", tokens[0]) + return fmt.Errorf("invalid mapped address %q", second_address) } + // TODO: Filter Yggdrasil IPs here mapping.Mapped.IP = mappedaddr - tokens = tokens[1:] } - if len(tokens) > 0 { - mappedport, err := strconv.Atoi(tokens[0]) - if err != nil { - return fmt.Errorf("mapped port is invalid: %w", err) + + *m = append(*m, mapping) + return nil +} + +type TCPRemoteMappings []TCPMapping + +func (m *TCPRemoteMappings) String() string { + return "" +} + +func (m *TCPRemoteMappings) Set(value string) error { + first_address, first_port, second_address, second_port, err := + parseMappingString(value) + + if err != nil { + return err + } + + // First address must be empty + // Second address can be ipv4/ipv6 + + if first_address != "" { + return fmt.Errorf("Yggdrasil listening must be empty") + } + + // Create mapping + + mapping := TCPMapping{ + Listen: &net.TCPAddr{ + Port: first_port, + }, + Mapped: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: second_port, + }, + } + + if first_address != "" { + listenaddr := net.ParseIP(first_address) + if listenaddr == nil { + return fmt.Errorf("invalid listen address %q", first_address) } - if mappedport == 0 { - return fmt.Errorf("mapped port must not be zero") + mapping.Listen.IP = listenaddr + } + + if second_address != "" { + mappedaddr := net.ParseIP(second_address) + if mappedaddr == nil { + return fmt.Errorf("invalid mapped address %q", second_address) } - mapping.Mapped.Port = mappedport + mapping.Mapped.IP = mappedaddr } + *m = append(*m, mapping) return nil } @@ -72,57 +279,109 @@ type UDPMapping struct { Mapped *net.UDPAddr } -type UDPMappings []UDPMapping +type UDPLocalMappings []UDPMapping -func (m *UDPMappings) String() string { +func (m *UDPLocalMappings) String() string { return "" } -func (m *UDPMappings) Set(value string) error { - tokens := strings.Split(value, ":") - if len(tokens) > 2 { - tokens = strings.SplitN(value, ":", 2) - host, port, err := net.SplitHostPort(tokens[1]) - if err != nil { - return fmt.Errorf("failed to split host and port: %w", err) - } - tokens = append(tokens[:1], host, port) - } - listenport, err := strconv.Atoi(tokens[0]) +func (m *UDPLocalMappings) Set(value string) error { + first_address, first_port, second_address, second_port, err := + parseMappingString(value) + if err != nil { - return fmt.Errorf("listen port is invalid: %w", err) + return err } - if listenport == 0 { - return fmt.Errorf("listen port must not be zero") + + // First address can be ipv4/ipv6 + // Second address can be only Yggdrasil ipv6 + + if !strings.Contains(second_address, ":") { + return fmt.Errorf("Yggdrasil listening address can be only IPv6") } + + // Create mapping + mapping := UDPMapping{ Listen: &net.UDPAddr{ - Port: listenport, + Port: first_port, }, Mapped: &net.UDPAddr{ IP: net.IPv6loopback, - Port: listenport, + Port: second_port, }, } - tokens = tokens[1:] - if len(tokens) > 0 { - mappedaddr := net.ParseIP(tokens[0]) + + if first_address != "" { + listenaddr := net.ParseIP(first_address) + if listenaddr == nil { + return fmt.Errorf("invalid listen address %q", first_address) + } + mapping.Listen.IP = listenaddr + } + + if second_address != "" { + mappedaddr := net.ParseIP(second_address) if mappedaddr == nil { - return fmt.Errorf("invalid mapped address %q", tokens[0]) + return fmt.Errorf("invalid mapped address %q", second_address) } + // TODO: Filter Yggdrasil IPs here mapping.Mapped.IP = mappedaddr - tokens = tokens[1:] } - if len(tokens) > 0 { - mappedport, err := strconv.Atoi(tokens[0]) - if err != nil { - return fmt.Errorf("mapped port is invalid: %w", err) + + *m = append(*m, mapping) + return nil +} + +type UDPRemoteMappings []UDPMapping + +func (m *UDPRemoteMappings) String() string { + return "" +} + +func (m *UDPRemoteMappings) Set(value string) error { + first_address, first_port, second_address, second_port, err := + parseMappingString(value) + + if err != nil { + return err + } + + // First address must be empty + // Second address can be ipv4/ipv6 + + if first_address != "" { + return fmt.Errorf("Yggdrasil listening must be empty") + } + + // Create mapping + + mapping := UDPMapping{ + Listen: &net.UDPAddr{ + Port: first_port, + }, + Mapped: &net.UDPAddr{ + IP: net.IPv6loopback, + Port: second_port, + }, + } + + if first_address != "" { + listenaddr := net.ParseIP(first_address) + if listenaddr == nil { + return fmt.Errorf("invalid listen address %q", first_address) } - if mappedport == 0 { - return fmt.Errorf("mapped port must not be zero") + mapping.Listen.IP = listenaddr + } + + if second_address != "" { + mappedaddr := net.ParseIP(second_address) + if mappedaddr == nil { + return fmt.Errorf("invalid mapped address %q", second_address) } - mapping.Mapped.Port = mappedport + mapping.Mapped.IP = mappedaddr } + *m = append(*m, mapping) return nil } diff --git a/src/types/mapping_test.go b/src/types/mapping_test.go index 23e68b6..accc273 100644 --- a/src/types/mapping_test.go +++ b/src/types/mapping_test.go @@ -13,18 +13,48 @@ func TestEndpointMappings(t *testing.T) { if err := tcpMappings.Set("1234:192.168.1.1:4321"); err != nil { t.Fatal(err) } + if err := tcpMappings.Set("192.168.1.2:1234:192.168.1.1:4321"); err != nil { + t.Fatal(err) + } if err := tcpMappings.Set("1234:[2000::1]:4321"); err != nil { t.Fatal(err) } + if err := tcpMappings.Set("[2001:1]:1234:[2000::1]:4321"); err != nil { + t.Fatal(err) + } if err := tcpMappings.Set("a"); err == nil { t.Fatal("'a' should be an invalid exposed port") } if err := tcpMappings.Set("1234:localhost"); err == nil { t.Fatal("mapped address must be an IP literal") } + if err := tcpMappings.Set("127.0.0.1:1234:localhost"); err == nil { + t.Fatal("mapped address must be an IP literal") + } + if err := tcpMappings.Set("[2000:1]:1234:localhost"); err == nil { + t.Fatal("mapped address must be an IP literal") + } + if err := tcpMappings.Set("localhost:1234:127.0.0.1"); err == nil { + t.Fatal("listen address must be an IP literal") + } + if err := tcpMappings.Set("localhost:1234:127.0.0.1"); err == nil { + t.Fatal("listen address must be an IP literal") + } + if err := tcpMappings.Set("localhost:1234:[2000:1]"); err == nil { + t.Fatal("listen address must be an IP literal") + } + if err := tcpMappings.Set("localhost:1234:[2000:1]"); err == nil { + t.Fatal("listen address must be an IP literal") + } if err := tcpMappings.Set("1234:localhost:a"); err == nil { t.Fatal("'a' should be an invalid mapped port") } + if err := tcpMappings.Set("127.0.0.1:1234:127.0.0.1:a"); err == nil { + t.Fatal("'a' should be an invalid mapped port") + } + if err := tcpMappings.Set("[2000::1]:1234:[2000::1]:a"); err == nil { + t.Fatal("'a' should be an invalid mapped port") + } var udpMappings UDPMappings if err := udpMappings.Set("1234"); err != nil { t.Fatal(err) @@ -35,16 +65,46 @@ func TestEndpointMappings(t *testing.T) { if err := udpMappings.Set("1234:192.168.1.1:4321"); err != nil { t.Fatal(err) } + if err := udpMappings.Set("192.168.1.2:1234:192.168.1.1:4321"); err != nil { + t.Fatal(err) + } if err := udpMappings.Set("1234:[2000::1]:4321"); err != nil { t.Fatal(err) } + if err := udpMappings.Set("[2001:1]:1234:[2000::1]:4321"); err != nil { + t.Fatal(err) + } if err := udpMappings.Set("a"); err == nil { t.Fatal("'a' should be an invalid exposed port") } if err := udpMappings.Set("1234:localhost"); err == nil { t.Fatal("mapped address must be an IP literal") } + if err := udpMappings.Set("127.0.0.1:1234:localhost"); err == nil { + t.Fatal("mapped address must be an IP literal") + } + if err := udpMappings.Set("[2000:1]:1234:localhost"); err == nil { + t.Fatal("mapped address must be an IP literal") + } + if err := udpMappings.Set("localhost:1234:127.0.0.1"); err == nil { + t.Fatal("listen address must be an IP literal") + } + if err := udpMappings.Set("localhost:1234:127.0.0.1"); err == nil { + t.Fatal("listen address must be an IP literal") + } + if err := udpMappings.Set("localhost:1234:[2000:1]"); err == nil { + t.Fatal("listen address must be an IP literal") + } + if err := udpMappings.Set("localhost:1234:[2000:1]"); err == nil { + t.Fatal("listen address must be an IP literal") + } if err := udpMappings.Set("1234:localhost:a"); err == nil { t.Fatal("'a' should be an invalid mapped port") } + if err := udpMappings.Set("127.0.0.1:1234:127.0.0.1:a"); err == nil { + t.Fatal("'a' should be an invalid mapped port") + } + if err := udpMappings.Set("[2000::1]:1234:[2000::1]:a"); err == nil { + t.Fatal("'a' should be an invalid mapped port") + } } diff --git a/src/types/udpproxy.go b/src/types/udpproxy.go index 7fc4375..005cbac 100644 --- a/src/types/udpproxy.go +++ b/src/types/udpproxy.go @@ -4,7 +4,7 @@ import ( "net" ) -func ReverseProxyUDP(mtu uint64, dst net.PacketConn, dstAddr net.Addr, src net.UDPConn) error { +func ReverseProxyUDPConn(mtu uint64, dst net.PacketConn, dstAddr net.Addr, src net.UDPConn) error { buf := make([]byte, mtu) for { n, err := src.Read(buf[:]) @@ -20,3 +20,20 @@ func ReverseProxyUDP(mtu uint64, dst net.PacketConn, dstAddr net.Addr, src net.U } return nil } + +func ReverseProxyUDPPacketConn(mtu uint64, dst net.PacketConn, dstAddr net.Addr, src net.PacketConn) error { + buf := make([]byte, mtu) + for { + n, _, err := src.ReadFrom(buf[:]) + if err != nil { + return err + } + if n > 0 { + n, err = dst.WriteTo(buf[:n], dstAddr) + if err != nil { + return err + } + } + } + return nil +}