From 033ba9d32a072df6d4bd12538765636b9c32aa8e Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 3 Dec 2024 15:32:02 -0500 Subject: [PATCH 1/8] Fix metrics test. --- prometheus/metrics_test.go | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/prometheus/metrics_test.go b/prometheus/metrics_test.go index 5dfcf05a..73f39fd2 100644 --- a/prometheus/metrics_test.go +++ b/prometheus/metrics_test.go @@ -70,17 +70,17 @@ func TestMethodsDontPanic(t *testing.T) { TargetProxy: 3, ProxyClient: 4, } - addr := fakeAddr("127.0.0.1:9") tcpMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) tcpMetrics.AddAuthenticated("0") tcpMetrics.AddClosed("OK", proxyMetrics, 10*time.Millisecond) tcpMetrics.AddProbe("ERR_CIPHER", "eof", proxyMetrics.ClientProxy) - udpMetrics := ssMetrics.AddUDPNatEntry(addr, "key-1") + udpMetrics := ssMetrics.AddOpenUDPAssociation(&fakeConn{}) + udpMetrics.AddAuthenticated("0") udpMetrics.AddPacketFromClient("OK", 10, 20) udpMetrics.AddPacketFromTarget("OK", 10, 20) - udpMetrics.RemoveNatEntry() + udpMetrics.AddClosed() ssMetrics.tcpServiceMetrics.AddCipherSearch(true, 10*time.Millisecond) ssMetrics.udpServiceMetrics.AddCipherSearch(true, 10*time.Millisecond) @@ -191,9 +191,7 @@ func BenchmarkProbe(b *testing.B) { func BenchmarkClientUDP(b *testing.B) { ssMetrics, _ := NewServiceMetrics(nil) - addr := fakeAddr("127.0.0.1:9") - accessKey := "key 1" - udpMetrics := ssMetrics.AddUDPNatEntry(addr, accessKey) + udpMetrics := ssMetrics.AddOpenUDPAssociation(&fakeConn{}) status := "OK" size := int64(1000) b.ResetTimer() @@ -204,9 +202,7 @@ func BenchmarkClientUDP(b *testing.B) { func BenchmarkTargetUDP(b *testing.B) { ssMetrics, _ := NewServiceMetrics(nil) - addr := fakeAddr("127.0.0.1:9") - accessKey := "key 1" - udpMetrics := ssMetrics.AddUDPNatEntry(addr, accessKey) + udpMetrics := ssMetrics.AddOpenUDPAssociation(&fakeConn{}) status := "OK" size := int64(1000) b.ResetTimer() @@ -215,12 +211,12 @@ func BenchmarkTargetUDP(b *testing.B) { } } -func BenchmarkNAT(b *testing.B) { +func BenchmarkClose(b *testing.B) { ssMetrics, _ := NewServiceMetrics(nil) addr := fakeAddr("127.0.0.1:9") b.ResetTimer() for i := 0; i < b.N; i++ { - udpMetrics := ssMetrics.AddUDPNatEntry(addr, "key-0") - udpMetrics.RemoveNatEntry() + udpMetrics := ssMetrics.AddOpenUDPAssociation(&fakeConn{}) + udpMetrics.AddClosed() } } From 87a9d23fc77c8788db3bcc5547705c3bcdb76ebe Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 3 Dec 2024 15:52:48 -0500 Subject: [PATCH 2/8] Use `slicepool`. --- service/udp.go | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/service/udp.go b/service/udp.go index 220b3819..433ac139 100644 --- a/service/udp.go +++ b/service/udp.go @@ -26,8 +26,10 @@ import ( "time" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" - onet "github.com/Jigsaw-Code/outline-ss-server/net" "github.com/shadowsocks/go-shadowsocks2/socks" + + "github.com/Jigsaw-Code/outline-ss-server/internal/slicepool" + onet "github.com/Jigsaw-Code/outline-ss-server/net" ) // NATMetrics is used to report NAT related metrics. @@ -84,8 +86,9 @@ func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherLis } type associationHandler struct { - logger *slog.Logger - bufferPool sync.Pool + logger *slog.Logger + // bufPool stores the byte slices used for reading and decrypting packets. + bufPool slicepool.Pool ciphers CipherList ssm ShadowsocksConnMetrics targetIPValidator onet.TargetIPValidator @@ -94,19 +97,14 @@ type associationHandler struct { var _ AssociationHandler = (*associationHandler)(nil) -// NewAssociationHandler creates a UDPService +// NewAssociationHandler creates an AssociationHandler func NewAssociationHandler(natTimeout time.Duration, cipherList CipherList, ssMetrics ShadowsocksConnMetrics) AssociationHandler { if ssMetrics == nil { ssMetrics = &NoOpShadowsocksConnMetrics{} } - bufferPool := sync.Pool{ - New: func() interface{} { - return make([]byte, serverUDPBufferSize) - }, - } return &associationHandler{ logger: noopLogger(), - bufferPool: bufferPool, + bufPool: slicepool.MakePool(serverUDPBufferSize), ciphers: cipherList, ssm: ssMetrics, targetIPValidator: onet.RequirePublicIP, @@ -229,18 +227,16 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA return } - cipherBuf := h.bufferPool.Get().([]byte) - textBuf := h.bufferPool.Get().([]byte) - defer func() { - h.bufferPool.Put(cipherBuf) - h.bufferPool.Put(textBuf) - }() + cipherLazySlice := h.bufPool.LazySlice() + textLazySlice := h.bufPool.LazySlice() var cryptoKey *shadowsocks.EncryptionKey var proxyTargetBytes int for { + cipherBuf := cipherLazySlice.Acquire() clientProxyBytes, err := clientAssociation.Read(cipherBuf) if errors.Is(err, net.ErrClosed) { + cipherLazySlice.Release() return } debugUDPAddr(h.logger, "Outbound packet.", clientAssociation.RemoteAddr(), slog.Int("bytes", clientProxyBytes)) @@ -262,9 +258,11 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA ip := clientAssociation.RemoteAddr().(*net.UDPAddr).AddrPort().Addr() var keyID string var cryptoKey *shadowsocks.EncryptionKey + textBuf := textLazySlice.Acquire() unpackStart := time.Now() textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers, h.logger) timeToCipher := time.Since(unpackStart) + textLazySlice.Release() h.ssm.AddCipherSearch(err == nil, timeToCipher) if err != nil { @@ -302,6 +300,8 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA return nil }() + cipherLazySlice.Release() + status := "OK" if connError != nil { h.logger.LogAttrs(nil, slog.LevelDebug, "UDP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) From 50feaa29ff3b9032a2fefbe8ec4327be6925a930 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 4 Dec 2024 12:24:20 -0500 Subject: [PATCH 3/8] Let the assocation handler provide the buffer. --- prometheus/metrics_test.go | 1 - service/udp.go | 50 +++++++++++++++++++++++--------------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/prometheus/metrics_test.go b/prometheus/metrics_test.go index 73f39fd2..c435ee0c 100644 --- a/prometheus/metrics_test.go +++ b/prometheus/metrics_test.go @@ -213,7 +213,6 @@ func BenchmarkTargetUDP(b *testing.B) { func BenchmarkClose(b *testing.B) { ssMetrics, _ := NewServiceMetrics(nil) - addr := fakeAddr("127.0.0.1:9") b.ResetTimer() for i := 0; i < b.N; i++ { udpMetrics := ssMetrics.AddOpenUDPAssociation(&fakeConn{}) diff --git a/service/udp.go b/service/udp.go index 433ac139..87c5ba4b 100644 --- a/service/udp.go +++ b/service/udp.go @@ -17,7 +17,6 @@ package service import ( "errors" "fmt" - "io" "log/slog" "net" "net/netip" @@ -173,7 +172,8 @@ func PacketServe(clientConn net.PacketConn, handle AssocationHandleFunc, metrics if conn == nil { conn = &natconn{ Conn: &packetConnWrapper{PacketConn: clientConn, raddr: addr}, - readCh: make(chan []byte, 1), + readBufCh: make(chan []byte, 1), + bytesReadCh: make(chan int, 1), } metrics.AddNATEntry() deleteEntry := nm.Add(addr, conn) @@ -186,32 +186,44 @@ func PacketServe(clientConn net.PacketConn, handle AssocationHandleFunc, metrics handle(conn) }(conn) } - conn.readCh <- pkt + readBuf, ok := <-conn.readBufCh + if !ok { + continue + } + copy(readBuf, pkt) + conn.bytesReadCh <- n } } +// natconn adapts a [net.Conn] to provide a synchronized reading mechanism for NAT traversal. +// +// The application provides the buffer to `Read()` (BYOB: Bring Your Own Buffer!) +// which minimizes buffer allocations and copying. type natconn struct { net.Conn - readCh chan []byte + + // readBufCh provides a buffer to copy incoming packet data into. + readBufCh chan []byte + + // bytesReadCh is used to signal the availability of new data and carries + // the length of the received packet. + bytesReadCh chan int } var _ net.Conn = (*natconn)(nil) func (c *natconn) Read(p []byte) (int, error) { - select { - case pkt := <-c.readCh: - if pkt == nil { - break - } - return copy(p, pkt), nil - case <-time.After(30 * time.Second): - break + c.readBufCh <- p + n, ok := <-c.bytesReadCh + if !ok { + return 0, net.ErrClosed } - return 0, io.EOF + return n, nil } func (c *natconn) Close() error { - close(c.readCh) + close(c.readBufCh) + close(c.bytesReadCh) c.Conn.Close() return nil } @@ -228,20 +240,21 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA } cipherLazySlice := h.bufPool.LazySlice() + cipherBuf := cipherLazySlice.Acquire() + defer cipherLazySlice.Release() + textLazySlice := h.bufPool.LazySlice() var cryptoKey *shadowsocks.EncryptionKey var proxyTargetBytes int for { - cipherBuf := cipherLazySlice.Acquire() clientProxyBytes, err := clientAssociation.Read(cipherBuf) if errors.Is(err, net.ErrClosed) { cipherLazySlice.Release() return } debugUDPAddr(h.logger, "Outbound packet.", clientAssociation.RemoteAddr(), slog.Int("bytes", clientProxyBytes)) - cipherData := cipherBuf[:clientProxyBytes] - + connError := func() *onet.ConnectionError { defer func() { if r := recover(); r != nil { @@ -251,6 +264,7 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA slog.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAssociation.RemoteAddr().String())) }() + cipherData := cipherBuf[:clientProxyBytes] var textData []byte var err error @@ -300,8 +314,6 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA return nil }() - cipherLazySlice.Release() - status := "OK" if connError != nil { h.logger.LogAttrs(nil, slog.LevelDebug, "UDP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) From b7a4a3cd3618af55e004624f789bdc59d932841d Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 10 Dec 2024 14:44:22 -0500 Subject: [PATCH 4/8] Remove the `packetConnWrapper` and move the logic into `natconn` instead. --- service/udp.go | 42 ++++++++++++++---------------------------- service/udp_test.go | 2 +- 2 files changed, 15 insertions(+), 29 deletions(-) diff --git a/service/udp.go b/service/udp.go index 87c5ba4b..5d1583a2 100644 --- a/service/udp.go +++ b/service/udp.go @@ -171,7 +171,8 @@ func PacketServe(clientConn net.PacketConn, handle AssocationHandleFunc, metrics conn := nm.Get(addr.String()) if conn == nil { conn = &natconn{ - Conn: &packetConnWrapper{PacketConn: clientConn, raddr: addr}, + PacketConn: clientConn, + raddr: addr, readBufCh: make(chan []byte, 1), bytesReadCh: make(chan int, 1), } @@ -200,7 +201,8 @@ func PacketServe(clientConn net.PacketConn, handle AssocationHandleFunc, metrics // The application provides the buffer to `Read()` (BYOB: Bring Your Own Buffer!) // which minimizes buffer allocations and copying. type natconn struct { - net.Conn + net.PacketConn + raddr net.Addr // readBufCh provides a buffer to copy incoming packet data into. readBufCh chan []byte @@ -221,13 +223,22 @@ func (c *natconn) Read(p []byte) (int, error) { return n, nil } +func (c *natconn) Write(b []byte) (n int, err error) { + return c.PacketConn.WriteTo(b, c.raddr) +} + func (c *natconn) Close() error { close(c.readBufCh) close(c.bytesReadCh) - c.Conn.Close() + c.PacketConn.Close() return nil } +func (c *natconn) RemoteAddr() net.Addr { + return c.raddr +} + + func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPAssocationMetrics) { if connMetrics == nil { connMetrics = &NoOpUDPAssocationMetrics{} @@ -466,31 +477,6 @@ func (m *natmap) Close() error { return err } -// packetConnWrapper wraps a [net.PacketConn] and provides a [net.Conn] interface -// with a given remote address. -type packetConnWrapper struct { - net.PacketConn - raddr net.Addr -} - -var _ net.Conn = (*packetConnWrapper)(nil) - -// ReadFrom reads data from the connection. -func (pcw *packetConnWrapper) Read(b []byte) (n int, err error) { - n, _, err = pcw.PacketConn.ReadFrom(b) - return -} - -// WriteTo writes data to the connection. -func (pcw *packetConnWrapper) Write(b []byte) (n int, err error) { - return pcw.PacketConn.WriteTo(b, pcw.raddr) -} - -// RemoteAddr returns the remote network address. -func (pcw *packetConnWrapper) RemoteAddr() net.Addr { - return pcw.raddr -} - // Get the maximum length of the shadowsocks address header by parsing // and serializing an IPv6 address from the example range. var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) diff --git a/service/udp_test.go b/service/udp_test.go index 160c9697..1ca0ec96 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -508,7 +508,7 @@ func TestNATMap(t *testing.T) { nat := newNATmap() addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} pc := makePacketConn() - conn1 := &natconn{Conn: &packetConnWrapper{PacketConn: pc, raddr: addr1}} + conn1 := &natconn{PacketConn: pc, raddr: addr1} nat.Add(addr1, conn1) err := nat.Close() From f80294c39c894311b8b9138cb4bea105459a07f2 Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 10 Dec 2024 16:35:19 -0500 Subject: [PATCH 5/8] Fix close while reading of `natconn`. --- service/udp.go | 32 ++++++++++++++++++-------------- service/udp_test.go | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/service/udp.go b/service/udp.go index 5d1583a2..4b822ce0 100644 --- a/service/udp.go +++ b/service/udp.go @@ -171,9 +171,10 @@ func PacketServe(clientConn net.PacketConn, handle AssocationHandleFunc, metrics conn := nm.Get(addr.String()) if conn == nil { conn = &natconn{ - PacketConn: clientConn, - raddr: addr, - readBufCh: make(chan []byte, 1), + PacketConn: clientConn, + raddr: addr, + doneCh: make(chan struct{}), + readBufCh: make(chan []byte, 1), bytesReadCh: make(chan int, 1), } metrics.AddNATEntry() @@ -202,10 +203,11 @@ func PacketServe(clientConn net.PacketConn, handle AssocationHandleFunc, metrics // which minimizes buffer allocations and copying. type natconn struct { net.PacketConn - raddr net.Addr + raddr net.Addr + doneCh chan struct{} // readBufCh provides a buffer to copy incoming packet data into. - readBufCh chan []byte + readBufCh chan []byte // bytesReadCh is used to signal the availability of new data and carries // the length of the received packet. @@ -215,12 +217,16 @@ type natconn struct { var _ net.Conn = (*natconn)(nil) func (c *natconn) Read(p []byte) (int, error) { - c.readBufCh <- p - n, ok := <-c.bytesReadCh - if !ok { + select { + case c.readBufCh <- p: + n, ok := <-c.bytesReadCh + if !ok { + return 0, net.ErrClosed + } + return n, nil + case <-c.doneCh: return 0, net.ErrClosed } - return n, nil } func (c *natconn) Write(b []byte) (n int, err error) { @@ -228,17 +234,15 @@ func (c *natconn) Write(b []byte) (n int, err error) { } func (c *natconn) Close() error { + close(c.doneCh) close(c.readBufCh) close(c.bytesReadCh) - c.PacketConn.Close() - return nil + return c.PacketConn.Close() } - func (c *natconn) RemoteAddr() net.Addr { return c.raddr } - func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPAssocationMetrics) { if connMetrics == nil { connMetrics = &NoOpUDPAssocationMetrics{} @@ -265,7 +269,7 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA return } debugUDPAddr(h.logger, "Outbound packet.", clientAssociation.RemoteAddr(), slog.Int("bytes", clientProxyBytes)) - + connError := func() *onet.ConnectionError { defer func() { if r := recover(); r != nil { diff --git a/service/udp_test.go b/service/udp_test.go index 1ca0ec96..6bba5f0b 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -195,6 +195,24 @@ func startTestHandler() (AssociationHandler, func(target net.Addr, payload []byt }, targetConn } +func TestNatconnCloseWhileReading(t *testing.T) { + nc := &natconn{ + PacketConn: makePacketConn(), + raddr: &clientAddr, + doneCh: make(chan struct{}), + readBufCh: make(chan []byte, 1), + bytesReadCh: make(chan int, 1), + } + go func() { + buf := make([]byte, 1024) + nc.Read(buf) + }() + + err := nc.Close() + + assert.NoError(t, err, "Close should not panic or return an error") +} + func TestAssociationHandler_Handle_IPFilter(t *testing.T) { t.Run("RequirePublicIP blocks localhost", func(t *testing.T) { handler, sendPayload, targetConn := startTestHandler() From 36d4b275a0252862ba4593ecc436b626a93bcbb4 Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 10 Dec 2024 16:51:46 -0500 Subject: [PATCH 6/8] Simplify the natmap a little. --- service/udp.go | 32 +++++++++----------------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/service/udp.go b/service/udp.go index 4b822ce0..57818a78 100644 --- a/service/udp.go +++ b/service/udp.go @@ -179,7 +179,7 @@ func PacketServe(clientConn net.PacketConn, handle AssocationHandleFunc, metrics } metrics.AddNATEntry() deleteEntry := nm.Add(addr, conn) - go func(conn *natconn) { + go func(conn net.Conn) { defer func() { conn.Close() deleteEntry() @@ -239,6 +239,7 @@ func (c *natconn) Close() error { close(c.bytesReadCh) return c.PacketConn.Close() } + func (c *natconn) RemoteAddr() net.Addr { return c.raddr } @@ -437,33 +438,18 @@ func (m *natmap) Get(key string) *natconn { return m.keyConn[key] } -func (m *natmap) set(key string, pc *natconn) { - m.Lock() - defer m.Unlock() - - m.keyConn[key] = pc - return -} - -func (m *natmap) del(key string) *natconn { - m.Lock() - defer m.Unlock() - - entry, ok := m.keyConn[key] - if ok { - delete(m.keyConn, key) - return entry - } - return nil -} - // Add adds a new UDP NAT entry to the natmap and returns a closure to delete // the entry. func (m *natmap) Add(addr net.Addr, pc *natconn) func() { + m.Lock() + defer m.Unlock() + key := addr.String() - m.set(key, pc) + m.keyConn[key] = pc return func() { - m.del(key) + m.Lock() + defer m.Unlock() + delete(m.keyConn, key) } } From f5d9ac3a73cd97ea00a7c53a54ba68cd202db557 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 13 Dec 2024 12:56:59 -0500 Subject: [PATCH 7/8] Refactor `PacketServe` to use events (close and read). --- service/udp.go | 158 +++++++++++++++++++++++++++++--------------- service/udp_test.go | 52 +++++++-------- 2 files changed, 130 insertions(+), 80 deletions(-) diff --git a/service/udp.go b/service/udp.go index 57818a78..2431583f 100644 --- a/service/udp.go +++ b/service/udp.go @@ -48,6 +48,9 @@ type UDPAssocationMetrics interface { // Max UDP buffer size for the server code. const serverUDPBufferSize = 64 * 1024 +// Buffer pool used for reading UDP packets. +var readBufPool = slicepool.MakePool(serverUDPBufferSize) + // Wrapper for slog.Debug during UDP proxying. func debugUDP(l *slog.Logger, template string, cipherID string, attr slog.Attr) { // This is an optimization to reduce unnecessary allocations due to an interaction @@ -153,58 +156,97 @@ type AssocationHandleFunc func(assocation net.Conn) // function for each association. It uses a NAT map to track active associations // and handles their lifecycle. func PacketServe(clientConn net.PacketConn, handle AssocationHandleFunc, metrics NATMetrics) { + // This goroutine continuously reads from clientConn and sends the received data + // to readCh. It uses a buffer pool (readBufPool) to efficiently manage buffers + // and minimize allocations. The LazySlice is sent along with the read event + // to allow the receiver to release the buffer back to the pool after processing. + readCh := make(chan readEvent, 10) + go func() { + for { + lazySlice := readBufPool.LazySlice() + buffer := lazySlice.Acquire() + n, addr, err := clientConn.ReadFrom(buffer) + if err != nil { + lazySlice.Release() + if errors.Is(err, net.ErrClosed) { + readCh <- readEvent{err: err} + return + } + slog.Warn("Failed to read from client. Continuing to listen.", "err", err) + continue + } + readCh <- readEvent{ + poolSlice: lazySlice, + pkt: buffer[:n], + addr: addr, + } + } + }() + nm := newNATmap() defer nm.Close() - buffer := make([]byte, serverUDPBufferSize) + // This loop handles events from closeCh (connection closures) and readCh + // (incoming data). It removes NAT entries for closed connections and processes + // incoming data packets. The loop also ensures that buffers acquired from + // the readBufPool are released back to the pool after processing is complete. + closeCh := make(chan net.Addr, 10) for { - n, addr, err := clientConn.ReadFrom(buffer) - if err != nil { - if errors.Is(err, net.ErrClosed) { - break + select { + case addr := <-closeCh: + metrics.RemoveNATEntry() + nm.Del(addr) + case read := <-readCh: + if read.err != nil { + return } - slog.Warn("Failed to read from client. Continuing to listen.", "err", err) - continue - } - pkt := buffer[:n] - - // TODO: Include server address in the NAT key as well. - conn := nm.Get(addr.String()) - if conn == nil { - conn = &natconn{ - PacketConn: clientConn, - raddr: addr, - doneCh: make(chan struct{}), - readBufCh: make(chan []byte, 1), - bytesReadCh: make(chan int, 1), + + poolSlice := read.poolSlice + pkt := read.pkt + addr := read.addr + + // TODO: Include server address in the NAT key as well. + conn := nm.Get(addr.String()) + if conn == nil { + conn = &natconn{ + PacketConn: clientConn, + raddr: addr, + closeCh: closeCh, + doneCh: make(chan struct{}), + readBufCh: make(chan []byte, 1), + bytesReadCh: make(chan int, 1), + } + metrics.AddNATEntry() + nm.Add(addr, conn) + go handle(conn) } - metrics.AddNATEntry() - deleteEntry := nm.Add(addr, conn) - go func(conn net.Conn) { - defer func() { - conn.Close() - deleteEntry() - metrics.RemoveNATEntry() - }() - handle(conn) - }(conn) - } - readBuf, ok := <-conn.readBufCh - if !ok { - continue + readBuf, ok := <-conn.readBufCh + if !ok { + poolSlice.Release() + continue + } + copy(readBuf, pkt) + poolSlice.Release() + conn.bytesReadCh <- len(pkt) } - copy(readBuf, pkt) - conn.bytesReadCh <- n } } +type readEvent struct { + poolSlice slicepool.LazySlice + pkt []byte + addr net.Addr + err error +} + // natconn adapts a [net.Conn] to provide a synchronized reading mechanism for NAT traversal. // // The application provides the buffer to `Read()` (BYOB: Bring Your Own Buffer!) // which minimizes buffer allocations and copying. type natconn struct { net.PacketConn - raddr net.Addr - doneCh chan struct{} + raddr net.Addr + closeCh chan net.Addr + doneCh chan struct{} // readBufCh provides a buffer to copy incoming packet data into. readBufCh chan []byte @@ -218,14 +260,16 @@ var _ net.Conn = (*natconn)(nil) func (c *natconn) Read(p []byte) (int, error) { select { + case <-c.doneCh: + c.closeCh <- c.raddr + return 0, net.ErrClosed case c.readBufCh <- p: n, ok := <-c.bytesReadCh if !ok { + c.closeCh <- c.raddr return 0, net.ErrClosed } return n, nil - case <-c.doneCh: - return 0, net.ErrClosed } } @@ -235,9 +279,8 @@ func (c *natconn) Write(b []byte) (n int, err error) { func (c *natconn) Close() error { close(c.doneCh) - close(c.readBufCh) close(c.bytesReadCh) - return c.PacketConn.Close() + return nil } func (c *natconn) RemoteAddr() net.Addr { @@ -255,9 +298,9 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA return } - cipherLazySlice := h.bufPool.LazySlice() - cipherBuf := cipherLazySlice.Acquire() - defer cipherLazySlice.Release() + cipherSlice := h.bufPool.LazySlice() + cipherBuf := cipherSlice.Acquire() + defer cipherSlice.Release() textLazySlice := h.bufPool.LazySlice() @@ -266,7 +309,7 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA for { clientProxyBytes, err := clientAssociation.Read(cipherBuf) if errors.Is(err, net.ErrClosed) { - cipherLazySlice.Release() + cipherSlice.Release() return } debugUDPAddr(h.logger, "Outbound packet.", clientAssociation.RemoteAddr(), slog.Int("bytes", clientProxyBytes)) @@ -301,9 +344,10 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA connMetrics.AddAuthenticated(keyID) go func() { - defer connMetrics.AddClosed() timedCopy(clientAssociation, targetConn, cryptoKey, connMetrics, h.logger) + connMetrics.AddClosed() targetConn.Close() + clientAssociation.Close() }() } else { @@ -438,19 +482,25 @@ func (m *natmap) Get(key string) *natconn { return m.keyConn[key] } +func (m *natmap) Del(addr net.Addr) net.PacketConn { + m.Lock() + defer m.Unlock() + + entry, ok := m.keyConn[addr.String()] + if ok { + delete(m.keyConn, addr.String()) + return entry + } + return nil +} + // Add adds a new UDP NAT entry to the natmap and returns a closure to delete // the entry. -func (m *natmap) Add(addr net.Addr, pc *natconn) func() { +func (m *natmap) Add(addr net.Addr, pc *natconn) { m.Lock() defer m.Unlock() - key := addr.String() - m.keyConn[key] = pc - return func() { - m.Lock() - defer m.Unlock() - delete(m.keyConn, key) - } + m.keyConn[addr.String()] = pc } func (m *natmap) Close() error { diff --git a/service/udp_test.go b/service/udp_test.go index 6bba5f0b..8a7eb043 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -480,56 +480,56 @@ func TestTimedPacketConn(t *testing.T) { func TestNATMap(t *testing.T) { t.Run("Empty", func(t *testing.T) { - nat := newNATmap() - if nat.Get("foo") != nil { + nm := newNATmap() + if nm.Get("foo") != nil { t.Error("Expected nil value from empty NAT map") } }) t.Run("Add", func(t *testing.T) { - nat := newNATmap() - addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} conn1 := &natconn{} - nat.Add(addr1, conn1) - assert.Equal(t, conn1, nat.Get(addr1.String()), "Get should return the correct connection") + nm.Add(addr, conn1) + assert.Equal(t, conn1, nm.Get(addr.String()), "Get should return the correct connection") conn2 := &natconn{} - nat.Add(addr1, conn2) - assert.Equal(t, conn2, nat.Get(addr1.String()), "Adding with the same address should overwrite the entry") + nm.Add(addr, conn2) + assert.Equal(t, conn2, nm.Get(addr.String()), "Adding with the same address should overwrite the entry") }) t.Run("Get", func(t *testing.T) { - nat := newNATmap() - addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} - conn1 := &natconn{} - nat.Add(addr1, conn1) + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + conn := &natconn{} + nm.Add(addr, conn) - assert.Equal(t, conn1, nat.Get(addr1.String()), "Get should return the correct connection for an existing address") + assert.Equal(t, conn, nm.Get(addr.String()), "Get should return the correct connection for an existing address") addr2 := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 5678} - assert.Nil(t, nat.Get(addr2.String()), "Get should return nil for a non-existent address") + assert.Nil(t, nm.Get(addr2.String()), "Get should return nil for a non-existent address") }) - t.Run("closure_deletes", func(t *testing.T) { - nat := newNATmap() - addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} - conn1 := &natconn{} - deleteEntry := nat.Add(addr1, conn1) + t.Run("Del", func(t *testing.T) { + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + conn := &natconn{} + nm.Add(addr, conn) - deleteEntry() + nm.Del(addr) - assert.Nil(t, nat.Get(addr1.String()), "Get should return nil after deleting the entry") + assert.Nil(t, nm.Get(addr.String()), "Get should return nil after deleting the entry") }) t.Run("Close", func(t *testing.T) { - nat := newNATmap() - addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} pc := makePacketConn() - conn1 := &natconn{PacketConn: pc, raddr: addr1} - nat.Add(addr1, conn1) + conn := &natconn{PacketConn: pc, raddr: addr} + nm.Add(addr, conn) - err := nat.Close() + err := nm.Close() assert.NoError(t, err, "Close should not return an error") // The underlying connection should be scheduled to close immediately. From e0547f28a46f1f1e6fac0cb6d1f5ddb89d1fcf5f Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 13 Dec 2024 13:37:33 -0500 Subject: [PATCH 8/8] Use correct logger. --- service/udp.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/service/udp.go b/service/udp.go index 2431583f..6ca1cd2c 100644 --- a/service/udp.go +++ b/service/udp.go @@ -294,7 +294,7 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA targetConn, err := h.targetConnFactory() if err != nil { - slog.Error("UDP: failed to create target connection", slog.Any("err", err)) + h.logger.Error("UDP: failed to create target connection", slog.Any("err", err)) return } @@ -317,10 +317,10 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA connError := func() *onet.ConnectionError { defer func() { if r := recover(); r != nil { - slog.Error("Panic in UDP loop. Continuing to listen.", "err", r) + h.logger.Error("Panic in UDP loop. Continuing to listen.", "err", r) debug.PrintStack() } - slog.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAssociation.RemoteAddr().String())) + h.logger.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAssociation.RemoteAddr().String())) }() cipherData := cipherBuf[:clientProxyBytes]