diff --git a/cmd/yggstack/main.go b/cmd/yggstack/main.go index ef9d02a..2a68496 100644 --- a/cmd/yggstack/main.go +++ b/cmd/yggstack/main.go @@ -40,13 +40,8 @@ type node struct { socks5Listener net.Listener } -type UDPConnSession struct { - conn *net.UDPConn - remoteAddr net.Addr -} - -type UDPPacketConnSession struct { - conn *net.PacketConn +type UDPSession struct { + conn interface{} remoteAddr net.Addr } @@ -396,22 +391,23 @@ func main() { logger.Errorf("Failed to connect to %s: %s", mapping.Mapped, err) continue } - udpSession := &UDPPacketConnSession{ - conn: &udpFwdConn, + udpSession := &UDPSession{ + conn: udpFwdConn, remoteAddr: remoteUdpAddr, } localUdpConnections.Store(remoteUdpAddrStr, udpSession) - go types.ReverseProxyUDPPacketConn(mtu, udpListenConn, remoteUdpAddr, udpFwdConn) + go types.ReverseProxyUDP(mtu, udpListenConn, remoteUdpAddr, *udpFwdConn) } - udpSession, ok := connVal.(*UDPPacketConnSession) + udpSession, ok := connVal.(*UDPSession) if !ok { continue } - udpFwdConn := *udpSession.conn + udpFwdConnPtr := udpSession.conn.(*net.Conn) + udpFwdConn := *udpFwdConnPtr - _, err = udpFwdConn.WriteTo(udpBuffer[:bytesRead], mapping.Mapped) + _, err = udpFwdConn.Write(udpBuffer[:bytesRead]) if err != nil { logger.Debugf("Cannot write from yggdrasil to udp listener: %q", err) udpFwdConn.Close() @@ -482,23 +478,26 @@ func main() { logger.Errorf("Failed to connect to %s: %s", mapping.Mapped, err) continue } - udpSession := &UDPConnSession{ + udpSession := &UDPSession{ conn: udpFwdConn, remoteAddr: remoteUdpAddr, } remoteUdpConnections.Store(remoteUdpAddrStr, udpSession) - go types.ReverseProxyUDPConn(mtu, udpListenConn, remoteUdpAddr, *udpFwdConn) + go types.ReverseProxyUDP(mtu, udpListenConn, remoteUdpAddr, *udpFwdConn) } - udpSession, ok := connVal.(*UDPConnSession) + udpSession, ok := connVal.(*UDPSession) if !ok { continue } - _, err = udpSession.conn.Write(udpBuffer[:bytesRead]) + udpFwdConnPtr := udpSession.conn.(*net.Conn) + udpFwdConn := *udpFwdConnPtr + + _, err = udpFwdConn.Write(udpBuffer[:bytesRead]) if err != nil { logger.Debugf("Cannot write from yggdrasil to udp listener: %q", err) - udpSession.conn.Close() + udpFwdConn.Close() remoteUdpConnections.Delete(remoteUdpAddrStr) continue } diff --git a/src/netstack/netstack.go b/src/netstack/netstack.go index 42fb056..9b723d9 100644 --- a/src/netstack/netstack.go +++ b/src/netstack/netstack.go @@ -86,12 +86,12 @@ func (s *YggdrasilNetstack) DialContext(ctx context.Context, network, address st } } -func (s *YggdrasilNetstack) DialTCP(addr *net.TCPAddr) (net.Conn, error) { +func (s *YggdrasilNetstack) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { fa, pn, _ := convertToFullAddr(addr.IP, addr.Port) return gonet.DialTCP(s.stack, fa, pn) } -func (s *YggdrasilNetstack) DialUDP(addr *net.UDPAddr) (net.PacketConn, error) { +func (s *YggdrasilNetstack) DialUDP(addr *net.UDPAddr) (*gonet.UDPConn, error) { fa, pn, _ := convertToFullAddr(addr.IP, addr.Port) return gonet.DialUDP(s.stack, nil, &fa, pn) } @@ -101,7 +101,7 @@ func (s *YggdrasilNetstack) ListenTCP(addr *net.TCPAddr) (net.Listener, error) { return gonet.ListenTCP(s.stack, fa, pn) } -func (s *YggdrasilNetstack) ListenUDP(addr *net.UDPAddr) (net.PacketConn, error) { +func (s *YggdrasilNetstack) ListenUDP(addr *net.UDPAddr) (*gonet.UDPConn, error) { fa, pn, _ := convertToFullAddr(addr.IP, addr.Port) return gonet.DialUDP(s.stack, &fa, nil, pn) } diff --git a/src/types/udpproxy.go b/src/types/udpproxy.go index 005cbac..999f931 100644 --- a/src/types/udpproxy.go +++ b/src/types/udpproxy.go @@ -4,27 +4,11 @@ import ( "net" ) -func ReverseProxyUDPConn(mtu uint64, dst net.PacketConn, dstAddr net.Addr, src net.UDPConn) error { +func ReverseProxyUDP(mtu uint64, dst net.PacketConn, dstAddr net.Addr, src *net.Conn) error { buf := make([]byte, mtu) for { - n, err := src.Read(buf[:]) - if err != nil { - return err - } - if n > 0 { - n, err = dst.WriteTo(buf[:n], dstAddr) - if err != nil { - return err - } - } - } - return nil -} - -func ReverseProxyUDPPacketConn(mtu uint64, dst net.PacketConn, dstAddr net.Addr, src net.PacketConn) error { - buf := make([]byte, mtu) - for { - n, _, err := src.ReadFrom(buf[:]) + srcConn := *src + n, err := srcConn.Read(buf[:]) if err != nil { return err }