diff --git a/prometheus/metrics_test.go b/prometheus/metrics_test.go index 5dfcf05a..c435ee0c 100644 --- a/prometheus/metrics_test.go +++ b/prometheus/metrics_test.go @@ -70,17 +70,17 @@ func TestMethodsDontPanic(t *testing.T) { TargetProxy: 3, ProxyClient: 4, } - addr := fakeAddr("127.0.0.1:9") tcpMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) tcpMetrics.AddAuthenticated("0") tcpMetrics.AddClosed("OK", proxyMetrics, 10*time.Millisecond) tcpMetrics.AddProbe("ERR_CIPHER", "eof", proxyMetrics.ClientProxy) - udpMetrics := ssMetrics.AddUDPNatEntry(addr, "key-1") + udpMetrics := ssMetrics.AddOpenUDPAssociation(&fakeConn{}) + udpMetrics.AddAuthenticated("0") udpMetrics.AddPacketFromClient("OK", 10, 20) udpMetrics.AddPacketFromTarget("OK", 10, 20) - udpMetrics.RemoveNatEntry() + udpMetrics.AddClosed() ssMetrics.tcpServiceMetrics.AddCipherSearch(true, 10*time.Millisecond) ssMetrics.udpServiceMetrics.AddCipherSearch(true, 10*time.Millisecond) @@ -191,9 +191,7 @@ func BenchmarkProbe(b *testing.B) { func BenchmarkClientUDP(b *testing.B) { ssMetrics, _ := NewServiceMetrics(nil) - addr := fakeAddr("127.0.0.1:9") - accessKey := "key 1" - udpMetrics := ssMetrics.AddUDPNatEntry(addr, accessKey) + udpMetrics := ssMetrics.AddOpenUDPAssociation(&fakeConn{}) status := "OK" size := int64(1000) b.ResetTimer() @@ -204,9 +202,7 @@ func BenchmarkClientUDP(b *testing.B) { func BenchmarkTargetUDP(b *testing.B) { ssMetrics, _ := NewServiceMetrics(nil) - addr := fakeAddr("127.0.0.1:9") - accessKey := "key 1" - udpMetrics := ssMetrics.AddUDPNatEntry(addr, accessKey) + udpMetrics := ssMetrics.AddOpenUDPAssociation(&fakeConn{}) status := "OK" size := int64(1000) b.ResetTimer() @@ -215,12 +211,11 @@ func BenchmarkTargetUDP(b *testing.B) { } } -func BenchmarkNAT(b *testing.B) { +func BenchmarkClose(b *testing.B) { ssMetrics, _ := NewServiceMetrics(nil) - addr := fakeAddr("127.0.0.1:9") b.ResetTimer() for i := 0; i < b.N; i++ { - udpMetrics := ssMetrics.AddUDPNatEntry(addr, "key-0") - udpMetrics.RemoveNatEntry() + udpMetrics := ssMetrics.AddOpenUDPAssociation(&fakeConn{}) + udpMetrics.AddClosed() } } diff --git a/service/udp.go b/service/udp.go index 220b3819..6ca1cd2c 100644 --- a/service/udp.go +++ b/service/udp.go @@ -17,7 +17,6 @@ package service import ( "errors" "fmt" - "io" "log/slog" "net" "net/netip" @@ -26,8 +25,10 @@ import ( "time" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" - onet "github.com/Jigsaw-Code/outline-ss-server/net" "github.com/shadowsocks/go-shadowsocks2/socks" + + "github.com/Jigsaw-Code/outline-ss-server/internal/slicepool" + onet "github.com/Jigsaw-Code/outline-ss-server/net" ) // NATMetrics is used to report NAT related metrics. @@ -47,6 +48,9 @@ type UDPAssocationMetrics interface { // Max UDP buffer size for the server code. const serverUDPBufferSize = 64 * 1024 +// Buffer pool used for reading UDP packets. +var readBufPool = slicepool.MakePool(serverUDPBufferSize) + // Wrapper for slog.Debug during UDP proxying. func debugUDP(l *slog.Logger, template string, cipherID string, attr slog.Attr) { // This is an optimization to reduce unnecessary allocations due to an interaction @@ -84,8 +88,9 @@ func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherLis } type associationHandler struct { - logger *slog.Logger - bufferPool sync.Pool + logger *slog.Logger + // bufPool stores the byte slices used for reading and decrypting packets. + bufPool slicepool.Pool ciphers CipherList ssm ShadowsocksConnMetrics targetIPValidator onet.TargetIPValidator @@ -94,19 +99,14 @@ type associationHandler struct { var _ AssociationHandler = (*associationHandler)(nil) -// NewAssociationHandler creates a UDPService +// NewAssociationHandler creates an AssociationHandler func NewAssociationHandler(natTimeout time.Duration, cipherList CipherList, ssMetrics ShadowsocksConnMetrics) AssociationHandler { if ssMetrics == nil { ssMetrics = &NoOpShadowsocksConnMetrics{} } - bufferPool := sync.Pool{ - New: func() interface{} { - return make([]byte, serverUDPBufferSize) - }, - } return &associationHandler{ logger: noopLogger(), - bufferPool: bufferPool, + bufPool: slicepool.MakePool(serverUDPBufferSize), ciphers: cipherList, ssm: ssMetrics, targetIPValidator: onet.RequirePublicIP, @@ -156,68 +156,137 @@ type AssocationHandleFunc func(assocation net.Conn) // function for each association. It uses a NAT map to track active associations // and handles their lifecycle. func PacketServe(clientConn net.PacketConn, handle AssocationHandleFunc, metrics NATMetrics) { + // This goroutine continuously reads from clientConn and sends the received data + // to readCh. It uses a buffer pool (readBufPool) to efficiently manage buffers + // and minimize allocations. The LazySlice is sent along with the read event + // to allow the receiver to release the buffer back to the pool after processing. + readCh := make(chan readEvent, 10) + go func() { + for { + lazySlice := readBufPool.LazySlice() + buffer := lazySlice.Acquire() + n, addr, err := clientConn.ReadFrom(buffer) + if err != nil { + lazySlice.Release() + if errors.Is(err, net.ErrClosed) { + readCh <- readEvent{err: err} + return + } + slog.Warn("Failed to read from client. Continuing to listen.", "err", err) + continue + } + readCh <- readEvent{ + poolSlice: lazySlice, + pkt: buffer[:n], + addr: addr, + } + } + }() + nm := newNATmap() defer nm.Close() - buffer := make([]byte, serverUDPBufferSize) + // This loop handles events from closeCh (connection closures) and readCh + // (incoming data). It removes NAT entries for closed connections and processes + // incoming data packets. The loop also ensures that buffers acquired from + // the readBufPool are released back to the pool after processing is complete. + closeCh := make(chan net.Addr, 10) for { - n, addr, err := clientConn.ReadFrom(buffer) - if err != nil { - if errors.Is(err, net.ErrClosed) { - break + select { + case addr := <-closeCh: + metrics.RemoveNATEntry() + nm.Del(addr) + case read := <-readCh: + if read.err != nil { + return } - slog.Warn("Failed to read from client. Continuing to listen.", "err", err) - continue - } - pkt := buffer[:n] - - // TODO: Include server address in the NAT key as well. - conn := nm.Get(addr.String()) - if conn == nil { - conn = &natconn{ - Conn: &packetConnWrapper{PacketConn: clientConn, raddr: addr}, - readCh: make(chan []byte, 1), + + poolSlice := read.poolSlice + pkt := read.pkt + addr := read.addr + + // TODO: Include server address in the NAT key as well. + conn := nm.Get(addr.String()) + if conn == nil { + conn = &natconn{ + PacketConn: clientConn, + raddr: addr, + closeCh: closeCh, + doneCh: make(chan struct{}), + readBufCh: make(chan []byte, 1), + bytesReadCh: make(chan int, 1), + } + metrics.AddNATEntry() + nm.Add(addr, conn) + go handle(conn) } - metrics.AddNATEntry() - deleteEntry := nm.Add(addr, conn) - go func(conn *natconn) { - defer func() { - conn.Close() - deleteEntry() - metrics.RemoveNATEntry() - }() - handle(conn) - }(conn) + readBuf, ok := <-conn.readBufCh + if !ok { + poolSlice.Release() + continue + } + copy(readBuf, pkt) + poolSlice.Release() + conn.bytesReadCh <- len(pkt) } - conn.readCh <- pkt } } +type readEvent struct { + poolSlice slicepool.LazySlice + pkt []byte + addr net.Addr + err error +} + +// natconn adapts a [net.Conn] to provide a synchronized reading mechanism for NAT traversal. +// +// The application provides the buffer to `Read()` (BYOB: Bring Your Own Buffer!) +// which minimizes buffer allocations and copying. type natconn struct { - net.Conn - readCh chan []byte + net.PacketConn + raddr net.Addr + closeCh chan net.Addr + doneCh chan struct{} + + // readBufCh provides a buffer to copy incoming packet data into. + readBufCh chan []byte + + // bytesReadCh is used to signal the availability of new data and carries + // the length of the received packet. + bytesReadCh chan int } var _ net.Conn = (*natconn)(nil) func (c *natconn) Read(p []byte) (int, error) { select { - case pkt := <-c.readCh: - if pkt == nil { - break + case <-c.doneCh: + c.closeCh <- c.raddr + return 0, net.ErrClosed + case c.readBufCh <- p: + n, ok := <-c.bytesReadCh + if !ok { + c.closeCh <- c.raddr + return 0, net.ErrClosed } - return copy(p, pkt), nil - case <-time.After(30 * time.Second): - break + return n, nil } - return 0, io.EOF +} + +func (c *natconn) Write(b []byte) (n int, err error) { + return c.PacketConn.WriteTo(b, c.raddr) } func (c *natconn) Close() error { - close(c.readCh) - c.Conn.Close() + close(c.doneCh) + close(c.bytesReadCh) return nil } +func (c *natconn) RemoteAddr() net.Addr { + return c.raddr +} + func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPAssocationMetrics) { if connMetrics == nil { connMetrics = &NoOpUDPAssocationMetrics{} @@ -225,36 +294,36 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA targetConn, err := h.targetConnFactory() if err != nil { - slog.Error("UDP: failed to create target connection", slog.Any("err", err)) + h.logger.Error("UDP: failed to create target connection", slog.Any("err", err)) return } - cipherBuf := h.bufferPool.Get().([]byte) - textBuf := h.bufferPool.Get().([]byte) - defer func() { - h.bufferPool.Put(cipherBuf) - h.bufferPool.Put(textBuf) - }() + cipherSlice := h.bufPool.LazySlice() + cipherBuf := cipherSlice.Acquire() + defer cipherSlice.Release() + + textLazySlice := h.bufPool.LazySlice() var cryptoKey *shadowsocks.EncryptionKey var proxyTargetBytes int for { clientProxyBytes, err := clientAssociation.Read(cipherBuf) if errors.Is(err, net.ErrClosed) { + cipherSlice.Release() return } debugUDPAddr(h.logger, "Outbound packet.", clientAssociation.RemoteAddr(), slog.Int("bytes", clientProxyBytes)) - cipherData := cipherBuf[:clientProxyBytes] connError := func() *onet.ConnectionError { defer func() { if r := recover(); r != nil { - slog.Error("Panic in UDP loop. Continuing to listen.", "err", r) + h.logger.Error("Panic in UDP loop. Continuing to listen.", "err", r) debug.PrintStack() } - slog.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAssociation.RemoteAddr().String())) + h.logger.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAssociation.RemoteAddr().String())) }() + cipherData := cipherBuf[:clientProxyBytes] var textData []byte var err error @@ -262,9 +331,11 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA ip := clientAssociation.RemoteAddr().(*net.UDPAddr).AddrPort().Addr() var keyID string var cryptoKey *shadowsocks.EncryptionKey + textBuf := textLazySlice.Acquire() unpackStart := time.Now() textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers, h.logger) timeToCipher := time.Since(unpackStart) + textLazySlice.Release() h.ssm.AddCipherSearch(err == nil, timeToCipher) if err != nil { @@ -273,9 +344,10 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA connMetrics.AddAuthenticated(keyID) go func() { - defer connMetrics.AddClosed() timedCopy(clientAssociation, targetConn, cryptoKey, connMetrics, h.logger) + connMetrics.AddClosed() targetConn.Close() + clientAssociation.Close() }() } else { @@ -410,21 +482,13 @@ func (m *natmap) Get(key string) *natconn { return m.keyConn[key] } -func (m *natmap) set(key string, pc *natconn) { +func (m *natmap) Del(addr net.Addr) net.PacketConn { m.Lock() defer m.Unlock() - m.keyConn[key] = pc - return -} - -func (m *natmap) del(key string) *natconn { - m.Lock() - defer m.Unlock() - - entry, ok := m.keyConn[key] + entry, ok := m.keyConn[addr.String()] if ok { - delete(m.keyConn, key) + delete(m.keyConn, addr.String()) return entry } return nil @@ -432,12 +496,11 @@ func (m *natmap) del(key string) *natconn { // Add adds a new UDP NAT entry to the natmap and returns a closure to delete // the entry. -func (m *natmap) Add(addr net.Addr, pc *natconn) func() { - key := addr.String() - m.set(key, pc) - return func() { - m.del(key) - } +func (m *natmap) Add(addr net.Addr, pc *natconn) { + m.Lock() + defer m.Unlock() + + m.keyConn[addr.String()] = pc } func (m *natmap) Close() error { @@ -454,31 +517,6 @@ func (m *natmap) Close() error { return err } -// packetConnWrapper wraps a [net.PacketConn] and provides a [net.Conn] interface -// with a given remote address. -type packetConnWrapper struct { - net.PacketConn - raddr net.Addr -} - -var _ net.Conn = (*packetConnWrapper)(nil) - -// ReadFrom reads data from the connection. -func (pcw *packetConnWrapper) Read(b []byte) (n int, err error) { - n, _, err = pcw.PacketConn.ReadFrom(b) - return -} - -// WriteTo writes data to the connection. -func (pcw *packetConnWrapper) Write(b []byte) (n int, err error) { - return pcw.PacketConn.WriteTo(b, pcw.raddr) -} - -// RemoteAddr returns the remote network address. -func (pcw *packetConnWrapper) RemoteAddr() net.Addr { - return pcw.raddr -} - // Get the maximum length of the shadowsocks address header by parsing // and serializing an IPv6 address from the example range. var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) diff --git a/service/udp_test.go b/service/udp_test.go index 160c9697..8a7eb043 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -195,6 +195,24 @@ func startTestHandler() (AssociationHandler, func(target net.Addr, payload []byt }, targetConn } +func TestNatconnCloseWhileReading(t *testing.T) { + nc := &natconn{ + PacketConn: makePacketConn(), + raddr: &clientAddr, + doneCh: make(chan struct{}), + readBufCh: make(chan []byte, 1), + bytesReadCh: make(chan int, 1), + } + go func() { + buf := make([]byte, 1024) + nc.Read(buf) + }() + + err := nc.Close() + + assert.NoError(t, err, "Close should not panic or return an error") +} + func TestAssociationHandler_Handle_IPFilter(t *testing.T) { t.Run("RequirePublicIP blocks localhost", func(t *testing.T) { handler, sendPayload, targetConn := startTestHandler() @@ -462,56 +480,56 @@ func TestTimedPacketConn(t *testing.T) { func TestNATMap(t *testing.T) { t.Run("Empty", func(t *testing.T) { - nat := newNATmap() - if nat.Get("foo") != nil { + nm := newNATmap() + if nm.Get("foo") != nil { t.Error("Expected nil value from empty NAT map") } }) t.Run("Add", func(t *testing.T) { - nat := newNATmap() - addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} conn1 := &natconn{} - nat.Add(addr1, conn1) - assert.Equal(t, conn1, nat.Get(addr1.String()), "Get should return the correct connection") + nm.Add(addr, conn1) + assert.Equal(t, conn1, nm.Get(addr.String()), "Get should return the correct connection") conn2 := &natconn{} - nat.Add(addr1, conn2) - assert.Equal(t, conn2, nat.Get(addr1.String()), "Adding with the same address should overwrite the entry") + nm.Add(addr, conn2) + assert.Equal(t, conn2, nm.Get(addr.String()), "Adding with the same address should overwrite the entry") }) t.Run("Get", func(t *testing.T) { - nat := newNATmap() - addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} - conn1 := &natconn{} - nat.Add(addr1, conn1) + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + conn := &natconn{} + nm.Add(addr, conn) - assert.Equal(t, conn1, nat.Get(addr1.String()), "Get should return the correct connection for an existing address") + assert.Equal(t, conn, nm.Get(addr.String()), "Get should return the correct connection for an existing address") addr2 := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 5678} - assert.Nil(t, nat.Get(addr2.String()), "Get should return nil for a non-existent address") + assert.Nil(t, nm.Get(addr2.String()), "Get should return nil for a non-existent address") }) - t.Run("closure_deletes", func(t *testing.T) { - nat := newNATmap() - addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} - conn1 := &natconn{} - deleteEntry := nat.Add(addr1, conn1) + t.Run("Del", func(t *testing.T) { + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + conn := &natconn{} + nm.Add(addr, conn) - deleteEntry() + nm.Del(addr) - assert.Nil(t, nat.Get(addr1.String()), "Get should return nil after deleting the entry") + assert.Nil(t, nm.Get(addr.String()), "Get should return nil after deleting the entry") }) t.Run("Close", func(t *testing.T) { - nat := newNATmap() - addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} pc := makePacketConn() - conn1 := &natconn{Conn: &packetConnWrapper{PacketConn: pc, raddr: addr1}} - nat.Add(addr1, conn1) + conn := &natconn{PacketConn: pc, raddr: addr} + nm.Add(addr, conn) - err := nat.Close() + err := nm.Close() assert.NoError(t, err, "Close should not return an error") // The underlying connection should be scheduled to close immediately.