diff --git a/cmd/yggstack/main.go b/cmd/yggstack/main.go index 5aeacb8..fc17f18 100644 --- a/cmd/yggstack/main.go +++ b/cmd/yggstack/main.go @@ -362,38 +362,23 @@ func main() { remoteUdpConnections := new(sync.Map) udpBuffer := make([]byte, mtu) for { - bytesRead, remoteUdpAddr, err := udpListenConn.ReadFromUDP(udpBuffer) + bytesRead, remoteUdpAddr, err := udpListenConn.ReadFrom(udpBuffer) if err != nil { continue } var udpFwdConn *net.UDPConn = nil - remoteUdpConnections.Range(func(key, value interface{}) bool { - remoteUdpAddrCandidate, ok := key.(*net.UDPAddr) - if !ok { - logger.Errorf("Candidate for UDP address broken!") - return true - } - if remoteUdpAddrCandidate.String() == remoteUdpAddr.String() { - udpFwdConn, ok = *value.(*net.UDPConn) - if !ok { - logger.Errorf("Candidate for UDP connection broken!") - return true - } - return false - } - return true - }) + udpFwdConn, ok := remoteUdpConnections.Load(remoteUdpAddr) - if udpFwdConn == nil { + if !ok { udpFwdConn, err = net.DialUDP("udp", nil, mapping.Mapped) if err != nil { logger.Errorf("Failed to connect to %s: %s", mapping.Mapped, err) continue } - remoteUdpConnections.Store(&remoteUdpAddr, &udpFwdConn) - go types.ReverseProxyUDP(mtu, &udpListenConn, &remoteUdpAddr, &udpFwdConn) + remoteUdpConnections.Store(remoteUdpAddr, &udpFwdConn) + go types.ReverseProxyUDP(mtu, &udpListenConn, remoteUdpAddr, &udpFwdConn) } _, err = udpFwdConn.Write(udpBuffer[:bytesRead]) diff --git a/src/netstack/netstack.go b/src/netstack/netstack.go index d74c2f9..42fb056 100644 --- a/src/netstack/netstack.go +++ b/src/netstack/netstack.go @@ -91,7 +91,7 @@ func (s *YggdrasilNetstack) DialTCP(addr *net.TCPAddr) (net.Conn, error) { return gonet.DialTCP(s.stack, fa, pn) } -func (s *YggdrasilNetstack) DialUDP(addr *net.UDPAddr) (net.UDPConn, error) { +func (s *YggdrasilNetstack) DialUDP(addr *net.UDPAddr) (net.PacketConn, error) { fa, pn, _ := convertToFullAddr(addr.IP, addr.Port) return gonet.DialUDP(s.stack, nil, &fa, pn) } @@ -101,7 +101,7 @@ func (s *YggdrasilNetstack) ListenTCP(addr *net.TCPAddr) (net.Listener, error) { return gonet.ListenTCP(s.stack, fa, pn) } -func (s *YggdrasilNetstack) ListenUDP(addr *net.UDPAddr) (net.UDPConn, error) { +func (s *YggdrasilNetstack) ListenUDP(addr *net.UDPAddr) (net.PacketConn, error) { fa, pn, _ := convertToFullAddr(addr.IP, addr.Port) return gonet.DialUDP(s.stack, &fa, nil, pn) } diff --git a/src/types/udpproxy.go b/src/types/udpproxy.go index 3aa967d..6db5a3d 100644 --- a/src/types/udpproxy.go +++ b/src/types/udpproxy.go @@ -4,7 +4,7 @@ import ( "net" ) -func ReverseProxyUDP(mtu uint64, dst *net.UDPConn, dstAddr *net.UDPAddr, src *net.UDPConn) error { +func ReverseProxyUDP(mtu uint64, dst *net.PacketConn, dstAddr net.Addr, src *net.PacketConn) error { buf := make([]byte, mtu) for { n, err := src.Read(buf[:]) @@ -12,7 +12,7 @@ func ReverseProxyUDP(mtu uint64, dst *net.UDPConn, dstAddr *net.UDPAddr, src *ne return err } if n > 0 { - n, err = dst.WriteToUDP(buf[:n], dstAddr) + n, err = dst.WriteTo(buf[:n], dstAddr) if err != nil { return err }