diff --git a/service/udp.go b/service/udp.go index 57818a78..2431583f 100644 --- a/service/udp.go +++ b/service/udp.go @@ -48,6 +48,9 @@ type UDPAssocationMetrics interface { // Max UDP buffer size for the server code. const serverUDPBufferSize = 64 * 1024 +// Buffer pool used for reading UDP packets. +var readBufPool = slicepool.MakePool(serverUDPBufferSize) + // Wrapper for slog.Debug during UDP proxying. func debugUDP(l *slog.Logger, template string, cipherID string, attr slog.Attr) { // This is an optimization to reduce unnecessary allocations due to an interaction @@ -153,58 +156,97 @@ type AssocationHandleFunc func(assocation net.Conn) // function for each association. It uses a NAT map to track active associations // and handles their lifecycle. func PacketServe(clientConn net.PacketConn, handle AssocationHandleFunc, metrics NATMetrics) { + // This goroutine continuously reads from clientConn and sends the received data + // to readCh. It uses a buffer pool (readBufPool) to efficiently manage buffers + // and minimize allocations. The LazySlice is sent along with the read event + // to allow the receiver to release the buffer back to the pool after processing. + readCh := make(chan readEvent, 10) + go func() { + for { + lazySlice := readBufPool.LazySlice() + buffer := lazySlice.Acquire() + n, addr, err := clientConn.ReadFrom(buffer) + if err != nil { + lazySlice.Release() + if errors.Is(err, net.ErrClosed) { + readCh <- readEvent{err: err} + return + } + slog.Warn("Failed to read from client. Continuing to listen.", "err", err) + continue + } + readCh <- readEvent{ + poolSlice: lazySlice, + pkt: buffer[:n], + addr: addr, + } + } + }() + nm := newNATmap() defer nm.Close() - buffer := make([]byte, serverUDPBufferSize) + // This loop handles events from closeCh (connection closures) and readCh + // (incoming data). It removes NAT entries for closed connections and processes + // incoming data packets. The loop also ensures that buffers acquired from + // the readBufPool are released back to the pool after processing is complete. + closeCh := make(chan net.Addr, 10) for { - n, addr, err := clientConn.ReadFrom(buffer) - if err != nil { - if errors.Is(err, net.ErrClosed) { - break + select { + case addr := <-closeCh: + metrics.RemoveNATEntry() + nm.Del(addr) + case read := <-readCh: + if read.err != nil { + return } - slog.Warn("Failed to read from client. Continuing to listen.", "err", err) - continue - } - pkt := buffer[:n] - - // TODO: Include server address in the NAT key as well. - conn := nm.Get(addr.String()) - if conn == nil { - conn = &natconn{ - PacketConn: clientConn, - raddr: addr, - doneCh: make(chan struct{}), - readBufCh: make(chan []byte, 1), - bytesReadCh: make(chan int, 1), + + poolSlice := read.poolSlice + pkt := read.pkt + addr := read.addr + + // TODO: Include server address in the NAT key as well. + conn := nm.Get(addr.String()) + if conn == nil { + conn = &natconn{ + PacketConn: clientConn, + raddr: addr, + closeCh: closeCh, + doneCh: make(chan struct{}), + readBufCh: make(chan []byte, 1), + bytesReadCh: make(chan int, 1), + } + metrics.AddNATEntry() + nm.Add(addr, conn) + go handle(conn) } - metrics.AddNATEntry() - deleteEntry := nm.Add(addr, conn) - go func(conn net.Conn) { - defer func() { - conn.Close() - deleteEntry() - metrics.RemoveNATEntry() - }() - handle(conn) - }(conn) - } - readBuf, ok := <-conn.readBufCh - if !ok { - continue + readBuf, ok := <-conn.readBufCh + if !ok { + poolSlice.Release() + continue + } + copy(readBuf, pkt) + poolSlice.Release() + conn.bytesReadCh <- len(pkt) } - copy(readBuf, pkt) - conn.bytesReadCh <- n } } +type readEvent struct { + poolSlice slicepool.LazySlice + pkt []byte + addr net.Addr + err error +} + // natconn adapts a [net.Conn] to provide a synchronized reading mechanism for NAT traversal. // // The application provides the buffer to `Read()` (BYOB: Bring Your Own Buffer!) // which minimizes buffer allocations and copying. type natconn struct { net.PacketConn - raddr net.Addr - doneCh chan struct{} + raddr net.Addr + closeCh chan net.Addr + doneCh chan struct{} // readBufCh provides a buffer to copy incoming packet data into. readBufCh chan []byte @@ -218,14 +260,16 @@ var _ net.Conn = (*natconn)(nil) func (c *natconn) Read(p []byte) (int, error) { select { + case <-c.doneCh: + c.closeCh <- c.raddr + return 0, net.ErrClosed case c.readBufCh <- p: n, ok := <-c.bytesReadCh if !ok { + c.closeCh <- c.raddr return 0, net.ErrClosed } return n, nil - case <-c.doneCh: - return 0, net.ErrClosed } } @@ -235,9 +279,8 @@ func (c *natconn) Write(b []byte) (n int, err error) { func (c *natconn) Close() error { close(c.doneCh) - close(c.readBufCh) close(c.bytesReadCh) - return c.PacketConn.Close() + return nil } func (c *natconn) RemoteAddr() net.Addr { @@ -255,9 +298,9 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA return } - cipherLazySlice := h.bufPool.LazySlice() - cipherBuf := cipherLazySlice.Acquire() - defer cipherLazySlice.Release() + cipherSlice := h.bufPool.LazySlice() + cipherBuf := cipherSlice.Acquire() + defer cipherSlice.Release() textLazySlice := h.bufPool.LazySlice() @@ -266,7 +309,7 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA for { clientProxyBytes, err := clientAssociation.Read(cipherBuf) if errors.Is(err, net.ErrClosed) { - cipherLazySlice.Release() + cipherSlice.Release() return } debugUDPAddr(h.logger, "Outbound packet.", clientAssociation.RemoteAddr(), slog.Int("bytes", clientProxyBytes)) @@ -301,9 +344,10 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA connMetrics.AddAuthenticated(keyID) go func() { - defer connMetrics.AddClosed() timedCopy(clientAssociation, targetConn, cryptoKey, connMetrics, h.logger) + connMetrics.AddClosed() targetConn.Close() + clientAssociation.Close() }() } else { @@ -438,19 +482,25 @@ func (m *natmap) Get(key string) *natconn { return m.keyConn[key] } +func (m *natmap) Del(addr net.Addr) net.PacketConn { + m.Lock() + defer m.Unlock() + + entry, ok := m.keyConn[addr.String()] + if ok { + delete(m.keyConn, addr.String()) + return entry + } + return nil +} + // Add adds a new UDP NAT entry to the natmap and returns a closure to delete // the entry. -func (m *natmap) Add(addr net.Addr, pc *natconn) func() { +func (m *natmap) Add(addr net.Addr, pc *natconn) { m.Lock() defer m.Unlock() - key := addr.String() - m.keyConn[key] = pc - return func() { - m.Lock() - defer m.Unlock() - delete(m.keyConn, key) - } + m.keyConn[addr.String()] = pc } func (m *natmap) Close() error { diff --git a/service/udp_test.go b/service/udp_test.go index 6bba5f0b..8a7eb043 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -480,56 +480,56 @@ func TestTimedPacketConn(t *testing.T) { func TestNATMap(t *testing.T) { t.Run("Empty", func(t *testing.T) { - nat := newNATmap() - if nat.Get("foo") != nil { + nm := newNATmap() + if nm.Get("foo") != nil { t.Error("Expected nil value from empty NAT map") } }) t.Run("Add", func(t *testing.T) { - nat := newNATmap() - addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} conn1 := &natconn{} - nat.Add(addr1, conn1) - assert.Equal(t, conn1, nat.Get(addr1.String()), "Get should return the correct connection") + nm.Add(addr, conn1) + assert.Equal(t, conn1, nm.Get(addr.String()), "Get should return the correct connection") conn2 := &natconn{} - nat.Add(addr1, conn2) - assert.Equal(t, conn2, nat.Get(addr1.String()), "Adding with the same address should overwrite the entry") + nm.Add(addr, conn2) + assert.Equal(t, conn2, nm.Get(addr.String()), "Adding with the same address should overwrite the entry") }) t.Run("Get", func(t *testing.T) { - nat := newNATmap() - addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} - conn1 := &natconn{} - nat.Add(addr1, conn1) + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + conn := &natconn{} + nm.Add(addr, conn) - assert.Equal(t, conn1, nat.Get(addr1.String()), "Get should return the correct connection for an existing address") + assert.Equal(t, conn, nm.Get(addr.String()), "Get should return the correct connection for an existing address") addr2 := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 5678} - assert.Nil(t, nat.Get(addr2.String()), "Get should return nil for a non-existent address") + assert.Nil(t, nm.Get(addr2.String()), "Get should return nil for a non-existent address") }) - t.Run("closure_deletes", func(t *testing.T) { - nat := newNATmap() - addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} - conn1 := &natconn{} - deleteEntry := nat.Add(addr1, conn1) + t.Run("Del", func(t *testing.T) { + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + conn := &natconn{} + nm.Add(addr, conn) - deleteEntry() + nm.Del(addr) - assert.Nil(t, nat.Get(addr1.String()), "Get should return nil after deleting the entry") + assert.Nil(t, nm.Get(addr.String()), "Get should return nil after deleting the entry") }) t.Run("Close", func(t *testing.T) { - nat := newNATmap() - addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} pc := makePacketConn() - conn1 := &natconn{PacketConn: pc, raddr: addr1} - nat.Add(addr1, conn1) + conn := &natconn{PacketConn: pc, raddr: addr} + nm.Add(addr, conn) - err := nat.Close() + err := nm.Close() assert.NoError(t, err, "Close should not return an error") // The underlying connection should be scheduled to close immediately.