diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 3a04af0b..23bd143c 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -238,7 +238,7 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { return err } slog.Info("UDP service started.", "address", pc.LocalAddr().String()) - go ssService.HandlePacket(pc) + go service.PacketServe(pc, ssService.HandlePacket) } for _, serviceConfig := range config.Services { @@ -271,7 +271,7 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { return err } slog.Info("UDP service started.", "address", pc.LocalAddr().String()) - go ssService.HandlePacket(pc) + go service.PacketServe(pc, ssService.HandlePacket) } } totalCipherCount += len(serviceConfig.Keys) diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index 0994b90f..e54e8625 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -17,6 +17,7 @@ package integration_test import ( "bytes" "context" + "errors" "fmt" "io" "net" @@ -104,6 +105,9 @@ func startUDPEchoServer(t testing.TB) (*net.UDPConn, *sync.WaitGroup) { for { n, clientAddr, err := conn.ReadFromUDP(buf) if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } t.Logf("Failed to read from UDP conn: %v", err) return } @@ -268,16 +272,25 @@ type fakeUDPConnMetrics struct { clientAddr net.Addr accessKey string up, down []udpRecord + mu sync.Mutex } var _ service.UDPConnMetrics = (*fakeUDPConnMetrics)(nil) func (m *fakeUDPConnMetrics) AddPacketFromClient(status string, clientProxyBytes, proxyTargetBytes int64) { + fmt.Println("===>AddPacketFromClient") + m.mu.Lock() + defer m.mu.Unlock() m.up = append(m.up, udpRecord{m.clientAddr, m.accessKey, status, clientProxyBytes, proxyTargetBytes}) } + func (m *fakeUDPConnMetrics) AddPacketFromTarget(status string, targetProxyBytes, proxyClientBytes int64) { + fmt.Println("===>AddPacketFromTarget") + m.mu.Lock() + defer m.mu.Unlock() m.down = append(m.down, udpRecord{m.clientAddr, m.accessKey, status, targetProxyBytes, proxyClientBytes}) } + func (m *fakeUDPConnMetrics) RemoveNatEntry() { // Not tested because it requires waiting for a long timeout. } @@ -315,7 +328,7 @@ func TestUDPEcho(t *testing.T) { proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(proxyConn) + service.PacketServe(proxyConn, proxy.Handle) done <- struct{}{} }() @@ -361,29 +374,26 @@ func TestUDPEcho(t *testing.T) { snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) keyID := snapshot[0].Value.(*service.CipherEntry).ID - if len(testMetrics.connMetrics) != 1 { - t.Errorf("Wrong NAT count: %d", len(testMetrics.connMetrics)) - } - if len(testMetrics.connMetrics[0].up) != 1 { - t.Errorf("Wrong number of packets sent: %v", testMetrics.connMetrics[0].up) - } else { - record := testMetrics.connMetrics[0].up[0] - require.Equal(t, conn.LocalAddr(), record.clientAddr, "Bad upstream metrics") - require.Equal(t, keyID, record.accessKey, "Bad upstream metrics") - require.Equal(t, "OK", record.status, "Bad upstream metrics") - require.Greater(t, record.in, record.out, "Bad upstream metrics") - require.Equal(t, int64(N), record.out, "Bad upstream metrics") - } - if len(testMetrics.connMetrics[0].down) != 1 { - t.Errorf("Wrong number of packets received: %v", testMetrics.connMetrics[0].down) - } else { - record := testMetrics.connMetrics[0].down[0] - require.Equal(t, conn.LocalAddr(), record.clientAddr, "Bad downstream metrics") - require.Equal(t, keyID, record.accessKey, "Bad downstream metrics") - require.Equal(t, "OK", record.status, "Bad downstream metrics") - require.Greater(t, record.out, record.in, "Bad downstream metrics") - require.Equal(t, int64(N), record.in, "Bad downstream metrics") - } + require.Lenf(t, testMetrics.connMetrics, 1, "Wrong NAT count") + + testMetrics.connMetrics[0].mu.Lock() + defer testMetrics.connMetrics[0].mu.Unlock() + + require.Lenf(t, testMetrics.connMetrics[0].up, 1, "Wrong number of packets sent") + record := testMetrics.connMetrics[0].up[0] + require.Equal(t, conn.LocalAddr(), record.clientAddr, "Bad upstream metrics") + require.Equal(t, keyID, record.accessKey, "Bad upstream metrics") + require.Equal(t, "OK", record.status, "Bad upstream metrics") + require.Greater(t, record.in, record.out, "Bad upstream metrics") + require.Equal(t, int64(N), record.out, "Bad upstream metrics") + + require.Lenf(t, testMetrics.connMetrics[0].down, 1, "Wrong number of packets received") + record = testMetrics.connMetrics[0].down[0] + require.Equal(t, conn.LocalAddr(), record.clientAddr, "Bad downstream metrics") + require.Equal(t, keyID, record.accessKey, "Bad downstream metrics") + require.Equal(t, "OK", record.status, "Bad downstream metrics") + require.Greater(t, record.out, record.in, "Bad downstream metrics") + require.Equal(t, int64(N), record.in, "Bad downstream metrics") } func BenchmarkTCPThroughput(b *testing.B) { @@ -548,7 +558,7 @@ func BenchmarkUDPEcho(b *testing.B) { proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(server) + service.PacketServe(server, proxy.Handle) done <- struct{}{} }() @@ -592,7 +602,7 @@ func BenchmarkUDPManyKeys(b *testing.B) { proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(proxyConn) + service.PacketServe(proxyConn, proxy.Handle) done <- struct{}{} }() diff --git a/service/shadowsocks.go b/service/shadowsocks.go index 636fa94e..7b8221d3 100644 --- a/service/shadowsocks.go +++ b/service/shadowsocks.go @@ -44,7 +44,7 @@ type ServiceMetrics interface { type Service interface { HandleStream(ctx context.Context, conn transport.StreamConn) - HandlePacket(conn net.PacketConn) + HandlePacket(conn net.Conn, pkt []byte) } // Option is a Shadowsocks service constructor option. @@ -137,8 +137,8 @@ func (s *ssService) HandleStream(ctx context.Context, conn transport.StreamConn) } // HandlePacket handles a Shadowsocks packet connection. -func (s *ssService) HandlePacket(conn net.PacketConn) { - s.ph.Handle(conn) +func (s *ssService) HandlePacket(conn net.Conn, pkt []byte) { + s.ph.Handle(conn, pkt) } type ssConnMetrics struct { diff --git a/service/udp.go b/service/udp.go index 94078029..b19c784b 100644 --- a/service/udp.go +++ b/service/udp.go @@ -43,6 +43,12 @@ type UDPMetrics interface { // Max UDP buffer size for the server code. const serverUDPBufferSize = 64 * 1024 +var bufferPool = sync.Pool{ + New: func() interface{} { + return make([]byte, serverUDPBufferSize) + }, +} + // Wrapper for slog.Debug during UDP proxying. func debugUDP(l *slog.Logger, template string, cipherID string, attr slog.Attr) { // This is an optimization to reduce unnecessary allocations due to an interaction @@ -83,7 +89,7 @@ type packetHandler struct { logger *slog.Logger natTimeout time.Duration ciphers CipherList - m UDPMetrics + nm *natmap ssm ShadowsocksConnMetrics targetIPValidator onet.TargetIPValidator } @@ -96,11 +102,11 @@ func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetr if ssMetrics == nil { ssMetrics = &NoOpShadowsocksConnMetrics{} } + nm := newNATmap(natTimeout, m, noopLogger()) return &packetHandler{ logger: noopLogger(), - natTimeout: natTimeout, ciphers: cipherList, - m: m, + nm: nm, ssm: ssMetrics, targetIPValidator: onet.RequirePublicIP, } @@ -108,12 +114,11 @@ func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetr // PacketHandler is a running UDP shadowsocks proxy that can be stopped. type PacketHandler interface { + Handle(conn net.Conn, pkt []byte) // SetLogger sets the logger used to log messages. Uses a no-op logger if nil. SetLogger(l *slog.Logger) // SetTargetIPValidator sets the function to be used to validate the target IP addresses. SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) - // Handle returns after clientConn closes and all the sub goroutines return. - Handle(clientConn net.PacketConn) } func (h *packetHandler) SetLogger(l *slog.Logger) { @@ -127,97 +132,126 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali h.targetIPValidator = targetIPValidator } -// Listen on addr for encrypted packets and basically do UDP NAT. -// We take the ciphers as a pointer because it gets replaced on config updates. -func (h *packetHandler) Handle(clientConn net.PacketConn) { - nm := newNATmap(h.natTimeout, h.m, h.logger) - defer nm.Close() - cipherBuf := make([]byte, serverUDPBufferSize) - textBuf := make([]byte, serverUDPBufferSize) +type PacketHandleFunc func(conn net.Conn, pkt []byte) + +// PacketServe listens for packets and calls `handle` to handle them until the connection +// returns [ErrClosed]. +func PacketServe(clientConn net.PacketConn, handle PacketHandleFunc) { + buffer := bufferPool.Get().([]byte) + defer bufferPool.Put(buffer) for { - clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf) - if errors.Is(err, net.ErrClosed) { - break + n, addr, err := clientConn.ReadFrom(buffer) + if err != nil { + if errors.Is(err, net.ErrClosed) { + break + } + slog.Warn("Failed to read from client. Continuing to listen.", "err", err) + continue } + pkt := buffer[:n] - var proxyTargetBytes int - var targetConn *natconn - - connError := func() (connError *onet.ConnectionError) { + func() { defer func() { if r := recover(); r != nil { - slog.Error("Panic in UDP loop: %v. Continuing to listen.", r) + slog.Error("Panic in UDP loop. Continuing to listen.", "err", r) debug.PrintStack() } }() + handle(&wrappedPacketConn{PacketConn: clientConn, raddr: addr}, pkt) + }() + } +} + +type wrappedPacketConn struct { + net.PacketConn + raddr net.Addr +} + +var _ net.Conn = (*wrappedPacketConn)(nil) + +func (pc *wrappedPacketConn) Read(p []byte) (int, error) { + n, _, err := pc.PacketConn.ReadFrom(p) + return n, err +} + +func (pc *wrappedPacketConn) RemoteAddr() net.Addr { + return pc.raddr +} + +func (pc *wrappedPacketConn) Write(b []byte) (n int, err error) { + return pc.PacketConn.WriteTo(b, pc.raddr) +} + +func (h *packetHandler) Handle(clientConn net.Conn, pkt []byte) { + debugUDPAddr(h.logger, "Outbound packet.", clientConn.RemoteAddr(), slog.Int("bytes", len(pkt))) + + var err error + var proxyTargetBytes int + var targetConn *natconn + + connError := func() (connError *onet.ConnectionError) { + defer slog.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientConn.RemoteAddr().String())) + + var payload []byte + var tgtUDPAddr *net.UDPAddr + targetConn = h.nm.Get(clientConn.RemoteAddr().String()) + if targetConn == nil { + ip := clientConn.RemoteAddr().(*net.UDPAddr).AddrPort().Addr() + var textData []byte + var cryptoKey *shadowsocks.EncryptionKey + buffer := bufferPool.Get().([]byte) + defer bufferPool.Put(buffer) + unpackStart := time.Now() + textData, keyID, cryptoKey, err := findAccessKeyUDP(ip, buffer, pkt, h.ciphers, h.logger) + timeToCipher := time.Since(unpackStart) + h.ssm.AddCipherSearch(err == nil, timeToCipher) - // Error from ReadFrom if err != nil { - return onet.NewConnectionError("ERR_READ", "Failed to read from client", err) + return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err) } - defer slog.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAddr.String())) - debugUDPAddr(h.logger, "Outbound packet.", clientAddr, slog.Int("bytes", clientProxyBytes)) - - cipherData := cipherBuf[:clientProxyBytes] - var payload []byte - var tgtUDPAddr *net.UDPAddr - targetConn = nm.Get(clientAddr.String()) - if targetConn == nil { - ip := clientAddr.(*net.UDPAddr).AddrPort().Addr() - var textData []byte - var cryptoKey *shadowsocks.EncryptionKey - unpackStart := time.Now() - textData, keyID, cryptoKey, err := findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers, h.logger) - timeToCipher := time.Since(unpackStart) - h.ssm.AddCipherSearch(err == nil, timeToCipher) - - if err != nil { - return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err) - } - var onetErr *onet.ConnectionError - if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil { - return onetErr - } - - udpConn, err := net.ListenPacket("udp", "") - if err != nil { - return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) - } - targetConn = nm.Add(clientAddr, clientConn, cryptoKey, udpConn, keyID) - } else { - unpackStart := time.Now() - textData, err := shadowsocks.Unpack(nil, cipherData, targetConn.cryptoKey) - timeToCipher := time.Since(unpackStart) - h.ssm.AddCipherSearch(err == nil, timeToCipher) - - if err != nil { - return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err) - } + var onetErr *onet.ConnectionError + if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil { + return onetErr + } - var onetErr *onet.ConnectionError - if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil { - return onetErr - } + udpConn, err := net.ListenPacket("udp", "") + if err != nil { + return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) } + targetConn = h.nm.Add(clientConn, udpConn, cryptoKey, keyID) + } else { + unpackStart := time.Now() + textData, err := shadowsocks.Unpack(nil, pkt, targetConn.cryptoKey) + timeToCipher := time.Since(unpackStart) + h.ssm.AddCipherSearch(err == nil, timeToCipher) - debugUDPAddr(h.logger, "Proxy exit.", clientAddr, slog.Any("target", targetConn.LocalAddr())) - proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature if err != nil { - return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) + return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err) } - return nil - }() - status := "OK" - if connError != nil { - slog.LogAttrs(nil, slog.LevelDebug, "UDP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) - status = connError.Status + var onetErr *onet.ConnectionError + if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil { + return onetErr + } } - if targetConn != nil { - targetConn.metrics.AddPacketFromClient(status, int64(clientProxyBytes), int64(proxyTargetBytes)) + + debugUDPAddr(h.logger, "Proxy exit.", clientConn.RemoteAddr(), slog.Any("target", targetConn.LocalAddr())) + proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature + if err != nil { + return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) } + return nil + }() + + status := "OK" + if connError != nil { + slog.LogAttrs(nil, slog.LevelDebug, "UDP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) + status = connError.Status + } + if targetConn != nil { + targetConn.metrics.AddPacketFromClient(status, int64(len(pkt)), int64(proxyTargetBytes)) } } @@ -355,14 +389,14 @@ func (m *natmap) del(key string) net.PacketConn { return nil } -func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, targetConn net.PacketConn, keyID string) *natconn { - connMetrics := m.metrics.AddUDPNatEntry(clientAddr, keyID) - entry := m.set(clientAddr.String(), targetConn, cryptoKey, connMetrics) +func (m *natmap) Add(clientConn net.Conn, targetConn net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, keyID string) *natconn { + connMetrics := m.metrics.AddUDPNatEntry(clientConn.RemoteAddr(), keyID) + entry := m.set(clientConn.RemoteAddr().String(), targetConn, cryptoKey, connMetrics) go func() { - timedCopy(clientAddr, clientConn, entry, m.logger) + timedCopy(clientConn, entry, m.logger) connMetrics.RemoveNatEntry() - if pc := m.del(clientAddr.String()); pc != nil { + if pc := m.del(clientConn.RemoteAddr().String()); pc != nil { pc.Close() } }() @@ -388,7 +422,7 @@ func (m *natmap) Close() error { var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) // copy from target to client until read timeout -func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, l *slog.Logger) { +func timedCopy(clientConn net.Conn, targetConn *natconn, l *slog.Logger) { // pkt is used for in-place encryption of downstream UDP packets, with the layout // [padding?][salt][address][body][tag][extra] // Padding is only used if the address is IPv4. @@ -421,7 +455,7 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco return onet.NewConnectionError("ERR_READ", "Failed to read from target", err) } - debugUDPAddr(l, "Got response.", clientAddr, slog.Any("target", raddr)) + debugUDPAddr(l, "Got response.", clientConn.RemoteAddr(), slog.Any("target", raddr)) srcAddr := socks.ParseAddr(raddr.String()) addrStart := bodyStart - len(srcAddr) // `plainTextBuf` concatenates the SOCKS address and body: @@ -442,7 +476,7 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco if err != nil { return onet.NewConnectionError("ERR_PACK", "Failed to pack data to client", err) } - proxyClientBytes, err = clientConn.WriteTo(buf, clientAddr) + proxyClientBytes, err = clientConn.Write(buf) if err != nil { return onet.NewConnectionError("ERR_WRITE", "Failed to write to client", err) } diff --git a/service/udp_test.go b/service/udp_test.go index 6f620316..9bf657ce 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -139,7 +139,7 @@ func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTest handler.SetTargetIPValidator(validator) done := make(chan struct{}) go func() { - handler.Handle(clientConn) + PacketServe(clientConn, handler.Handle) done <- struct{}{} }() @@ -216,7 +216,7 @@ func setupNAT() (*fakePacketConn, *fakePacketConn, *natconn) { nat := newNATmap(timeout, &natTestMetrics{}, noopLogger()) clientConn := makePacketConn() targetConn := makePacketConn() - nat.Add(&clientAddr, clientConn, natCryptoKey, targetConn, "key id") + nat.Add(&wrappedPacketConn{PacketConn: clientConn, raddr: &clientAddr}, targetConn, natCryptoKey, "key id") entry := nat.Get(clientAddr.String()) return clientConn, targetConn, entry } @@ -481,7 +481,7 @@ func TestUDPEarlyClose(t *testing.T) { } testMetrics := &natTestMetrics{} const testTimeout = 200 * time.Millisecond - s := NewPacketHandler(testTimeout, cipherList, testMetrics, &fakeShadowsocksMetrics{}) + ph := NewPacketHandler(testTimeout, cipherList, testMetrics, &fakeShadowsocksMetrics{}) clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) if err != nil { @@ -489,7 +489,7 @@ func TestUDPEarlyClose(t *testing.T) { } require.Nil(t, clientConn.Close()) // This should return quickly without timing out. - s.Handle(clientConn) + PacketServe(clientConn, ph.Handle) } // Makes sure the UDP listener returns [io.ErrClosed] on reads and writes after Close().