Skip to content

Commit

Permalink
Refactor PacketServe to use events (close and read).
Browse files Browse the repository at this point in the history
  • Loading branch information
sbruens committed Dec 13, 2024
1 parent 36d4b27 commit f5d9ac3
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 80 deletions.
158 changes: 104 additions & 54 deletions service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ type UDPAssocationMetrics interface {
// Max UDP buffer size for the server code.
const serverUDPBufferSize = 64 * 1024

// Buffer pool used for reading UDP packets.
var readBufPool = slicepool.MakePool(serverUDPBufferSize)

// Wrapper for slog.Debug during UDP proxying.
func debugUDP(l *slog.Logger, template string, cipherID string, attr slog.Attr) {
// This is an optimization to reduce unnecessary allocations due to an interaction
Expand Down Expand Up @@ -153,58 +156,97 @@ type AssocationHandleFunc func(assocation net.Conn)
// function for each association. It uses a NAT map to track active associations
// and handles their lifecycle.
func PacketServe(clientConn net.PacketConn, handle AssocationHandleFunc, metrics NATMetrics) {
// This goroutine continuously reads from clientConn and sends the received data
// to readCh. It uses a buffer pool (readBufPool) to efficiently manage buffers
// and minimize allocations. The LazySlice is sent along with the read event
// to allow the receiver to release the buffer back to the pool after processing.
readCh := make(chan readEvent, 10)
go func() {
for {
lazySlice := readBufPool.LazySlice()
buffer := lazySlice.Acquire()
n, addr, err := clientConn.ReadFrom(buffer)
if err != nil {
lazySlice.Release()
if errors.Is(err, net.ErrClosed) {
readCh <- readEvent{err: err}
return
}
slog.Warn("Failed to read from client. Continuing to listen.", "err", err)
continue
}
readCh <- readEvent{
poolSlice: lazySlice,
pkt: buffer[:n],
addr: addr,
}
}
}()

nm := newNATmap()
defer nm.Close()
buffer := make([]byte, serverUDPBufferSize)
// This loop handles events from closeCh (connection closures) and readCh
// (incoming data). It removes NAT entries for closed connections and processes
// incoming data packets. The loop also ensures that buffers acquired from
// the readBufPool are released back to the pool after processing is complete.
closeCh := make(chan net.Addr, 10)
for {
n, addr, err := clientConn.ReadFrom(buffer)
if err != nil {
if errors.Is(err, net.ErrClosed) {
break
select {
case addr := <-closeCh:
metrics.RemoveNATEntry()
nm.Del(addr)
case read := <-readCh:
if read.err != nil {
return
}
slog.Warn("Failed to read from client. Continuing to listen.", "err", err)
continue
}
pkt := buffer[:n]

// TODO: Include server address in the NAT key as well.
conn := nm.Get(addr.String())
if conn == nil {
conn = &natconn{
PacketConn: clientConn,
raddr: addr,
doneCh: make(chan struct{}),
readBufCh: make(chan []byte, 1),
bytesReadCh: make(chan int, 1),

poolSlice := read.poolSlice
pkt := read.pkt
addr := read.addr

// TODO: Include server address in the NAT key as well.
conn := nm.Get(addr.String())
if conn == nil {
conn = &natconn{
PacketConn: clientConn,
raddr: addr,
closeCh: closeCh,
doneCh: make(chan struct{}),
readBufCh: make(chan []byte, 1),
bytesReadCh: make(chan int, 1),
}
metrics.AddNATEntry()
nm.Add(addr, conn)
go handle(conn)
}
metrics.AddNATEntry()
deleteEntry := nm.Add(addr, conn)
go func(conn net.Conn) {
defer func() {
conn.Close()
deleteEntry()
metrics.RemoveNATEntry()
}()
handle(conn)
}(conn)
}
readBuf, ok := <-conn.readBufCh
if !ok {
continue
readBuf, ok := <-conn.readBufCh
if !ok {
poolSlice.Release()
continue
}
copy(readBuf, pkt)
poolSlice.Release()
conn.bytesReadCh <- len(pkt)
}
copy(readBuf, pkt)
conn.bytesReadCh <- n
}
}

type readEvent struct {
poolSlice slicepool.LazySlice
pkt []byte
addr net.Addr
err error
}

// natconn adapts a [net.Conn] to provide a synchronized reading mechanism for NAT traversal.
//
// The application provides the buffer to `Read()` (BYOB: Bring Your Own Buffer!)
// which minimizes buffer allocations and copying.
type natconn struct {
net.PacketConn
raddr net.Addr
doneCh chan struct{}
raddr net.Addr
closeCh chan net.Addr
doneCh chan struct{}

// readBufCh provides a buffer to copy incoming packet data into.
readBufCh chan []byte
Expand All @@ -218,14 +260,16 @@ var _ net.Conn = (*natconn)(nil)

func (c *natconn) Read(p []byte) (int, error) {
select {
case <-c.doneCh:
c.closeCh <- c.raddr
return 0, net.ErrClosed
case c.readBufCh <- p:
n, ok := <-c.bytesReadCh
if !ok {
c.closeCh <- c.raddr
return 0, net.ErrClosed
}
return n, nil
case <-c.doneCh:
return 0, net.ErrClosed
}
}

Expand All @@ -235,9 +279,8 @@ func (c *natconn) Write(b []byte) (n int, err error) {

func (c *natconn) Close() error {
close(c.doneCh)
close(c.readBufCh)
close(c.bytesReadCh)
return c.PacketConn.Close()
return nil
}

func (c *natconn) RemoteAddr() net.Addr {
Expand All @@ -255,9 +298,9 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA
return
}

cipherLazySlice := h.bufPool.LazySlice()
cipherBuf := cipherLazySlice.Acquire()
defer cipherLazySlice.Release()
cipherSlice := h.bufPool.LazySlice()
cipherBuf := cipherSlice.Acquire()
defer cipherSlice.Release()

textLazySlice := h.bufPool.LazySlice()

Expand All @@ -266,7 +309,7 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA
for {
clientProxyBytes, err := clientAssociation.Read(cipherBuf)
if errors.Is(err, net.ErrClosed) {
cipherLazySlice.Release()
cipherSlice.Release()
return
}
debugUDPAddr(h.logger, "Outbound packet.", clientAssociation.RemoteAddr(), slog.Int("bytes", clientProxyBytes))
Expand Down Expand Up @@ -301,9 +344,10 @@ func (h *associationHandler) Handle(clientAssociation net.Conn, connMetrics UDPA

connMetrics.AddAuthenticated(keyID)
go func() {
defer connMetrics.AddClosed()
timedCopy(clientAssociation, targetConn, cryptoKey, connMetrics, h.logger)
connMetrics.AddClosed()
targetConn.Close()
clientAssociation.Close()
}()

} else {
Expand Down Expand Up @@ -438,19 +482,25 @@ func (m *natmap) Get(key string) *natconn {
return m.keyConn[key]
}

func (m *natmap) Del(addr net.Addr) net.PacketConn {
m.Lock()
defer m.Unlock()

entry, ok := m.keyConn[addr.String()]
if ok {
delete(m.keyConn, addr.String())
return entry
}
return nil
}

// Add adds a new UDP NAT entry to the natmap and returns a closure to delete
// the entry.
func (m *natmap) Add(addr net.Addr, pc *natconn) func() {
func (m *natmap) Add(addr net.Addr, pc *natconn) {
m.Lock()
defer m.Unlock()

key := addr.String()
m.keyConn[key] = pc
return func() {
m.Lock()
defer m.Unlock()
delete(m.keyConn, key)
}
m.keyConn[addr.String()] = pc
}

func (m *natmap) Close() error {
Expand Down
52 changes: 26 additions & 26 deletions service/udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,56 +480,56 @@ func TestTimedPacketConn(t *testing.T) {

func TestNATMap(t *testing.T) {
t.Run("Empty", func(t *testing.T) {
nat := newNATmap()
if nat.Get("foo") != nil {
nm := newNATmap()
if nm.Get("foo") != nil {
t.Error("Expected nil value from empty NAT map")
}
})

t.Run("Add", func(t *testing.T) {
nat := newNATmap()
addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
nm := newNATmap()
addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
conn1 := &natconn{}

nat.Add(addr1, conn1)
assert.Equal(t, conn1, nat.Get(addr1.String()), "Get should return the correct connection")
nm.Add(addr, conn1)
assert.Equal(t, conn1, nm.Get(addr.String()), "Get should return the correct connection")

conn2 := &natconn{}
nat.Add(addr1, conn2)
assert.Equal(t, conn2, nat.Get(addr1.String()), "Adding with the same address should overwrite the entry")
nm.Add(addr, conn2)
assert.Equal(t, conn2, nm.Get(addr.String()), "Adding with the same address should overwrite the entry")
})

t.Run("Get", func(t *testing.T) {
nat := newNATmap()
addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
conn1 := &natconn{}
nat.Add(addr1, conn1)
nm := newNATmap()
addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
conn := &natconn{}
nm.Add(addr, conn)

assert.Equal(t, conn1, nat.Get(addr1.String()), "Get should return the correct connection for an existing address")
assert.Equal(t, conn, nm.Get(addr.String()), "Get should return the correct connection for an existing address")

addr2 := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 5678}
assert.Nil(t, nat.Get(addr2.String()), "Get should return nil for a non-existent address")
assert.Nil(t, nm.Get(addr2.String()), "Get should return nil for a non-existent address")
})

t.Run("closure_deletes", func(t *testing.T) {
nat := newNATmap()
addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
conn1 := &natconn{}
deleteEntry := nat.Add(addr1, conn1)
t.Run("Del", func(t *testing.T) {
nm := newNATmap()
addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
conn := &natconn{}
nm.Add(addr, conn)

deleteEntry()
nm.Del(addr)

assert.Nil(t, nat.Get(addr1.String()), "Get should return nil after deleting the entry")
assert.Nil(t, nm.Get(addr.String()), "Get should return nil after deleting the entry")
})

t.Run("Close", func(t *testing.T) {
nat := newNATmap()
addr1 := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
nm := newNATmap()
addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
pc := makePacketConn()
conn1 := &natconn{PacketConn: pc, raddr: addr1}
nat.Add(addr1, conn1)
conn := &natconn{PacketConn: pc, raddr: addr}
nm.Add(addr, conn)

err := nat.Close()
err := nm.Close()
assert.NoError(t, err, "Close should not return an error")

// The underlying connection should be scheduled to close immediately.
Expand Down

0 comments on commit f5d9ac3

Please sign in to comment.