Skip to content

Commit

Permalink
Create noop metrics if nil.
Browse files Browse the repository at this point in the history
  • Loading branch information
sbruens committed Sep 16, 2024
1 parent 39da61b commit 213903d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 21 deletions.
26 changes: 17 additions & 9 deletions service/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ func StreamServe(accept StreamAcceptFunc, handle StreamHandleFunc) {
}

func (h *streamHandler) Handle(ctx context.Context, clientConn transport.StreamConn, connMetrics TCPConnMetrics) {
if connMetrics == nil {
connMetrics = &NoOpTCPConnMetrics{}
}
var proxyMetrics metrics.ProxyMetrics
measuredClientConn := metrics.MeasureConn(clientConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy)
connStart := time.Now()
Expand All @@ -253,9 +256,7 @@ func (h *streamHandler) Handle(ctx context.Context, clientConn transport.StreamC
status = connError.Status
slog.LogAttrs(nil, slog.LevelDebug, "TCP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause))
}
if connMetrics != nil {
connMetrics.AddClosed(status, proxyMetrics, connDuration)
}
connMetrics.AddClosed(status, proxyMetrics, connDuration)
measuredClientConn.Close() // Closing after the metrics are added aids integration testing.
slog.LogAttrs(nil, slog.LevelDebug, "TCP: Done.", slog.String("status", status), slog.Duration("duration", connDuration))
}
Expand Down Expand Up @@ -327,9 +328,7 @@ func (h *streamHandler) handleConnection(ctx context.Context, outerConn transpor
h.absorbProbe(outerConn, connMetrics, authErr.Status, proxyMetrics)
return authErr
}
if connMetrics != nil {
connMetrics.AddAuthenticated(id)
}
connMetrics.AddAuthenticated(id)

// Read target address and dial it.
tgtAddr, err := getProxyRequest(innerConn)
Expand Down Expand Up @@ -359,9 +358,7 @@ func (h *streamHandler) absorbProbe(clientConn io.ReadCloser, connMetrics TCPCon
_, drainErr := io.Copy(io.Discard, clientConn) // drain socket
drainResult := drainErrToString(drainErr)
slog.LogAttrs(nil, slog.LevelDebug, "Drain error.", slog.Any("err", drainErr), slog.String("result", drainResult))
if connMetrics != nil {
connMetrics.AddProbe(status, drainResult, proxyMetrics.ClientProxy)
}
connMetrics.AddProbe(status, drainResult, proxyMetrics.ClientProxy)
}

func drainErrToString(drainErr error) string {
Expand All @@ -375,3 +372,14 @@ func drainErrToString(drainErr error) string {
return "other"
}
}

// NoOpTCPConnMetrics is a [TCPConnMetrics] that doesn't do anything. Useful in tests
// or if you don't want to track metrics.
type NoOpTCPConnMetrics struct{}

var _ TCPConnMetrics = (*NoOpTCPConnMetrics)(nil)

func (m *NoOpTCPConnMetrics) AddAuthenticated(accessKey string) {}
func (m *NoOpTCPConnMetrics) AddClosed(status string, data metrics.ProxyMetrics, duration time.Duration) {
}
func (m *NoOpTCPConnMetrics) AddProbe(status, drainResult string, clientProxyBytes int64) {}
26 changes: 14 additions & 12 deletions service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,16 @@ type packetHandler struct {

// NewPacketHandler creates a UDPService
func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetrics, ssMetrics ShadowsocksConnMetrics) PacketHandler {
return &packetHandler{natTimeout: natTimeout, ciphers: cipherList, m: m, ssm: ssMetrics, targetIPValidator: onet.RequirePublicIP}
if m == nil {
m = &NoOpUDPMetrics{}
}
return &packetHandler{
natTimeout: natTimeout,
ciphers: cipherList,
m: m,
ssm: ssMetrics,
targetIPValidator: onet.RequirePublicIP,
}
}

// PacketHandler is a running UDP shadowsocks proxy that can be stopped.
Expand Down Expand Up @@ -198,7 +207,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) {
slog.LogAttrs(nil, slog.LevelDebug, "UDP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause))
status = connError.Status
}
if targetConn != nil && targetConn.metrics != nil {
if targetConn != nil {
targetConn.metrics.AddPacketFromClient(status, int64(clientProxyBytes), int64(proxyTargetBytes))
}
}
Expand Down Expand Up @@ -343,18 +352,13 @@ func (m *natmap) del(key string) net.PacketConn {
}

func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, targetConn net.PacketConn, keyID string) *natconn {
var connMetrics UDPConnMetrics
if m.metrics != nil {
connMetrics = m.metrics.AddUDPNatEntry(clientAddr, keyID)
}
connMetrics := m.metrics.AddUDPNatEntry(clientAddr, keyID)
entry := m.set(clientAddr.String(), targetConn, cryptoKey, keyID, connMetrics)

m.running.Add(1)
go func() {
timedCopy(clientAddr, clientConn, entry, keyID)
if connMetrics != nil {
connMetrics.RemoveNatEntry()
}
connMetrics.RemoveNatEntry()
if pc := m.del(clientAddr.String()); pc != nil {
pc.Close()
}
Expand Down Expand Up @@ -450,9 +454,7 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco
if expired {
break
}
if targetConn.metrics != nil {
targetConn.metrics.AddPacketFromTarget(status, int64(bodyLen), int64(proxyClientBytes))
}
targetConn.metrics.AddPacketFromTarget(status, int64(bodyLen), int64(proxyClientBytes))
}
}

Expand Down

0 comments on commit 213903d

Please sign in to comment.