diff --git a/cmd/outline-ss-server/metrics.go b/cmd/outline-ss-server/metrics.go index 780274a4..3b241a26 100644 --- a/cmd/outline-ss-server/metrics.go +++ b/cmd/outline-ss-server/metrics.go @@ -46,17 +46,19 @@ func newPrometheusServerMetrics() *serverMetrics { }), ports: prometheus.NewGauge(prometheus.GaugeOpts{ Name: "ports", - Help: "Count of open Shadowsocks ports", + Help: "Count of open ports", }), } } func (m *serverMetrics) Describe(ch chan<- *prometheus.Desc) { + m.buildInfo.Describe(ch) m.accessKeys.Describe(ch) m.ports.Describe(ch) } func (m *serverMetrics) Collect(ch chan<- prometheus.Metric) { + m.buildInfo.Describe(ch) m.accessKeys.Collect(ch) m.ports.Collect(ch) } diff --git a/cmd/outline-ss-server/server_test.go b/cmd/outline-ss-server/server_test.go index 20729a06..05999486 100644 --- a/cmd/outline-ss-server/server_test.go +++ b/cmd/outline-ss-server/server_test.go @@ -17,14 +17,17 @@ package main import ( "testing" "time" + + "github.com/Jigsaw-Code/outline-ss-server/prometheus" ) func TestRunSSServer(t *testing.T) { - m, err := newPrometheusOutlineMetrics(nil) + serverMetrics := newPrometheusServerMetrics() + serviceMetrics, err := prometheus.NewServiceMetrics(nil) if err != nil { - t.Fatalf("Failed to create Prometheus metrics: %v", err) + t.Fatalf("Failed to create Prometheus service metrics: %v", err) } - server, err := RunSSServer("config_example.yml", 30*time.Second, m, 10000) + server, err := RunSSServer("config_example.yml", 30*time.Second, serverMetrics, serviceMetrics, 10000) if err != nil { t.Fatalf("RunSSServer() error = %v", err) } diff --git a/prometheus/metrics.go b/prometheus/metrics.go new file mode 100644 index 00000000..186cba7c --- /dev/null +++ b/prometheus/metrics.go @@ -0,0 +1,574 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "fmt" + "log/slog" + "net" + "net/netip" + "sync" + "time" + + "github.com/Jigsaw-Code/outline-ss-server/ipinfo" + "github.com/Jigsaw-Code/outline-ss-server/service" + "github.com/Jigsaw-Code/outline-ss-server/service/metrics" + "github.com/prometheus/client_golang/prometheus" +) + +// `now` is stubbable for testing. +var now = time.Now + +func newTimeToCipherVec(proto string) (prometheus.ObserverVec, error) { + vec := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "time_to_cipher_ms", + Help: "Time needed to find the cipher", + Buckets: []float64{0.1, 1, 10, 100, 1000}, + }, []string{"proto", "found_key"}) + return vec.CurryWith(map[string]string{"proto": proto}) +} + +type proxyCollector struct { + // NOTE: New metrics need to be added to `newProxyCollector()`, `Describe()` and `Collect()`. + dataBytesPerKey *prometheus.CounterVec + dataBytesPerLocation *prometheus.CounterVec +} + +func newProxyCollector(proto string) (*proxyCollector, error) { + dataBytesPerKey, err := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "data_bytes", + Help: "Bytes transferred by the proxy, per access key", + }, []string{"proto", "dir", "access_key"}).CurryWith(map[string]string{"proto": proto}) + if err != nil { + return nil, err + } + dataBytesPerLocation, err := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "data_bytes_per_location", + Help: "Bytes transferred by the proxy, per location", + }, []string{"proto", "dir", "location", "asn", "asorg"}).CurryWith(map[string]string{"proto": proto}) + if err != nil { + return nil, err + } + return &proxyCollector{ + dataBytesPerKey: dataBytesPerKey, + dataBytesPerLocation: dataBytesPerLocation, + }, nil +} + +func (c *proxyCollector) Describe(ch chan<- *prometheus.Desc) { + c.dataBytesPerKey.Describe(ch) + c.dataBytesPerLocation.Describe(ch) +} + +func (c *proxyCollector) Collect(ch chan<- prometheus.Metric) { + c.dataBytesPerKey.Collect(ch) + c.dataBytesPerLocation.Collect(ch) +} + +func (c *proxyCollector) addClientTarget(clientProxyBytes, proxyTargetBytes int64, accessKey string, clientInfo ipinfo.IPInfo) { + addIfNonZero(clientProxyBytes, c.dataBytesPerKey, "c>p", accessKey) + addIfNonZero(clientProxyBytes, c.dataBytesPerLocation, "c>p", clientInfo.CountryCode.String(), asnLabel(clientInfo.ASN.Number), clientInfo.ASN.Organization) + addIfNonZero(proxyTargetBytes, c.dataBytesPerKey, "p>t", accessKey) + addIfNonZero(proxyTargetBytes, c.dataBytesPerLocation, "p>t", clientInfo.CountryCode.String(), asnLabel(clientInfo.ASN.Number), clientInfo.ASN.Organization) +} + +func (c *proxyCollector) addTargetClient(targetProxyBytes, proxyClientBytes int64, accessKey string, clientInfo ipinfo.IPInfo) { + addIfNonZero(targetProxyBytes, c.dataBytesPerKey, "p 0 { + counterVec.WithLabelValues(lvs...).Add(float64(value)) + } +} + +func asnLabel(asn int) string { + if asn == 0 { + return "" + } + return fmt.Sprint(asn) +} + +// Converts a [net.Addr] to an [IPKey]. +func toIPKey(addr net.Addr, accessKey string) (*IPKey, error) { + hostname, _, err := net.SplitHostPort(addr.String()) + if err != nil { + return nil, fmt.Errorf("failed to create IPKey: %w", err) + } + ip, err := netip.ParseAddr(hostname) + if err != nil { + return nil, fmt.Errorf("failed to create IPKey: %w", err) + } + return &IPKey{ip, accessKey}, nil +} diff --git a/prometheus/metrics_test.go b/prometheus/metrics_test.go new file mode 100644 index 00000000..5dfcf05a --- /dev/null +++ b/prometheus/metrics_test.go @@ -0,0 +1,226 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "net" + "strings" + "testing" + "time" + + "github.com/Jigsaw-Code/outline-ss-server/ipinfo" + "github.com/Jigsaw-Code/outline-ss-server/service/metrics" + "github.com/op/go-logging" + "github.com/prometheus/client_golang/prometheus" + promtest "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/require" +) + +type noopMap struct{} + +func (*noopMap) GetIPInfo(ip net.IP) (ipinfo.IPInfo, error) { + return ipinfo.IPInfo{}, nil +} + +type fakeAddr string + +func (a fakeAddr) String() string { return string(a) } +func (a fakeAddr) Network() string { return "" } + +// Sets the processing clock to be t until changed. +func setNow(t time.Time) { + now = func() time.Time { + return t + } +} + +func init() { + logging.SetLevel(logging.INFO, "") +} + +type fakeConn struct { + net.Conn +} + +func (c *fakeConn) LocalAddr() net.Addr { + return fakeAddr("127.0.0.1:9") +} + +func (c *fakeConn) RemoteAddr() net.Addr { + return fakeAddr("127.0.0.1:10") +} + +func TestMethodsDontPanic(t *testing.T) { + ssMetrics, _ := NewServiceMetrics(nil) + proxyMetrics := metrics.ProxyMetrics{ + ClientProxy: 1, + ProxyTarget: 2, + TargetProxy: 3, + ProxyClient: 4, + } + addr := fakeAddr("127.0.0.1:9") + + tcpMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) + tcpMetrics.AddAuthenticated("0") + tcpMetrics.AddClosed("OK", proxyMetrics, 10*time.Millisecond) + tcpMetrics.AddProbe("ERR_CIPHER", "eof", proxyMetrics.ClientProxy) + + udpMetrics := ssMetrics.AddUDPNatEntry(addr, "key-1") + udpMetrics.AddPacketFromClient("OK", 10, 20) + udpMetrics.AddPacketFromTarget("OK", 10, 20) + udpMetrics.RemoveNatEntry() + + ssMetrics.tcpServiceMetrics.AddCipherSearch(true, 10*time.Millisecond) + ssMetrics.udpServiceMetrics.AddCipherSearch(true, 10*time.Millisecond) +} + +func TestASNLabel(t *testing.T) { + require.Equal(t, "", asnLabel(0)) + require.Equal(t, "100", asnLabel(100)) +} + +func TestTunnelTime(t *testing.T) { + t.Run("PerKey", func(t *testing.T) { + setNow(time.Date(2010, 1, 2, 3, 4, 5, .0, time.Local)) + ssMetrics, _ := NewServiceMetrics(nil) + reg := prometheus.NewPedanticRegistry() + reg.MustRegister(ssMetrics) + + connMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) + connMetrics.AddAuthenticated("key-1") + setNow(time.Date(2010, 1, 2, 3, 4, 20, .0, time.Local)) + + expected := strings.NewReader(` + # HELP tunnel_time_seconds Tunnel time, per access key. + # TYPE tunnel_time_seconds counter + tunnel_time_seconds{access_key="key-1"} 15 + `) + err := promtest.GatherAndCompare( + reg, + expected, + "tunnel_time_seconds", + ) + require.NoError(t, err, "unexpected metric value found") + }) + + t.Run("PerLocation", func(t *testing.T) { + setNow(time.Date(2010, 1, 2, 3, 4, 5, .0, time.Local)) + ssMetrics, _ := NewServiceMetrics(&noopMap{}) + reg := prometheus.NewPedanticRegistry() + reg.MustRegister(ssMetrics) + + connMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) + connMetrics.AddAuthenticated("key-1") + setNow(time.Date(2010, 1, 2, 3, 4, 10, .0, time.Local)) + + expected := strings.NewReader(` + # HELP tunnel_time_seconds_per_location Tunnel time, per location. + # TYPE tunnel_time_seconds_per_location counter + tunnel_time_seconds_per_location{asn="",asorg="",location="XL"} 5 + `) + err := promtest.GatherAndCompare( + reg, + expected, + "tunnel_time_seconds_per_location", + ) + require.NoError(t, err, "unexpected metric value found") + }) +} + +func TestTunnelTimePerKeyDoesNotPanicOnUnknownClosedConnection(t *testing.T) { + reg := prometheus.NewPedanticRegistry() + ssMetrics, _ := NewServiceMetrics(nil) + + connMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) + connMetrics.AddClosed("OK", metrics.ProxyMetrics{}, time.Minute) + + err := promtest.GatherAndCompare( + reg, + strings.NewReader(""), + "tunnel_time_seconds", + ) + require.NoError(t, err, "unexpectedly found metric value") +} + +func BenchmarkOpenTCP(b *testing.B) { + ssMetrics, _ := NewServiceMetrics(nil) + conn := &fakeConn{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + ssMetrics.AddOpenTCPConnection(conn) + } +} + +func BenchmarkCloseTCP(b *testing.B) { + ssMetrics, _ := NewServiceMetrics(nil) + connMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) + accessKey := "key 1" + status := "OK" + data := metrics.ProxyMetrics{} + duration := time.Minute + b.ResetTimer() + for i := 0; i < b.N; i++ { + connMetrics.AddAuthenticated(accessKey) + connMetrics.AddClosed(status, data, duration) + } +} + +func BenchmarkProbe(b *testing.B) { + ssMetrics, _ := NewServiceMetrics(nil) + connMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) + status := "ERR_REPLAY" + drainResult := "other" + data := metrics.ProxyMetrics{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + connMetrics.AddProbe(status, drainResult, data.ClientProxy) + } +} + +func BenchmarkClientUDP(b *testing.B) { + ssMetrics, _ := NewServiceMetrics(nil) + addr := fakeAddr("127.0.0.1:9") + accessKey := "key 1" + udpMetrics := ssMetrics.AddUDPNatEntry(addr, accessKey) + status := "OK" + size := int64(1000) + b.ResetTimer() + for i := 0; i < b.N; i++ { + udpMetrics.AddPacketFromClient(status, size, size) + } +} + +func BenchmarkTargetUDP(b *testing.B) { + ssMetrics, _ := NewServiceMetrics(nil) + addr := fakeAddr("127.0.0.1:9") + accessKey := "key 1" + udpMetrics := ssMetrics.AddUDPNatEntry(addr, accessKey) + status := "OK" + size := int64(1000) + b.ResetTimer() + for i := 0; i < b.N; i++ { + udpMetrics.AddPacketFromTarget(status, size, size) + } +} + +func BenchmarkNAT(b *testing.B) { + ssMetrics, _ := NewServiceMetrics(nil) + addr := fakeAddr("127.0.0.1:9") + b.ResetTimer() + for i := 0; i < b.N; i++ { + udpMetrics := ssMetrics.AddUDPNatEntry(addr, "key-0") + udpMetrics.RemoveNatEntry() + } +} diff --git a/service/metrics_test.go b/service/metrics_test.go index 1e2c8b95..288b695c 100644 --- a/service/metrics_test.go +++ b/service/metrics_test.go @@ -21,7 +21,6 @@ import ( "time" "github.com/Jigsaw-Code/outline-ss-server/ipinfo" - "github.com/Jigsaw-Code/outline-ss-server/service/metrics" "github.com/op/go-logging" "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" @@ -63,166 +62,49 @@ func (c *fakeConn) RemoteAddr() net.Addr { } func TestMethodsDontPanic(t *testing.T) { - ssMetrics, _ := newPrometheusOutlineMetrics(nil) - proxyMetrics := metrics.ProxyMetrics{ - ClientProxy: 1, - ProxyTarget: 2, - TargetProxy: 3, - ProxyClient: 4, - } - addr := fakeAddr("127.0.0.1:9") - ssMetrics.SetBuildInfo("0.0.0-test") - ssMetrics.SetNumAccessKeys(20, 2) - - tcpMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) - tcpMetrics.AddAuthenticated("0") - tcpMetrics.AddClosed("OK", proxyMetrics, 10*time.Millisecond) - tcpMetrics.AddProbe("ERR_CIPHER", "eof", proxyMetrics.ClientProxy) - - udpMetrics := ssMetrics.AddUDPNatEntry(addr, "key-1") - udpMetrics.AddPacketFromClient("OK", 10, 20) - udpMetrics.AddPacketFromTarget("OK", 10, 20) - udpMetrics.RemoveNatEntry() - - ssMetrics.tcpServiceMetrics.AddCipherSearch(true, 10*time.Millisecond) - ssMetrics.udpServiceMetrics.AddCipherSearch(true, 10*time.Millisecond) -} - -func TestASNLabel(t *testing.T) { - require.Equal(t, "", asnLabel(0)) - require.Equal(t, "100", asnLabel(100)) + m := newPrometheusServerMetrics() + m.SetVersion("0.0.0-test") + m.SetNumAccessKeys(20, 2) } -func TestTunnelTime(t *testing.T) { - t.Run("PerKey", func(t *testing.T) { - setNow(time.Date(2010, 1, 2, 3, 4, 5, .0, time.Local)) - ssMetrics, _ := newPrometheusOutlineMetrics(nil) - reg := prometheus.NewPedanticRegistry() - reg.MustRegister(ssMetrics) - - connMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) - connMetrics.AddAuthenticated("key-1") - setNow(time.Date(2010, 1, 2, 3, 4, 20, .0, time.Local)) - - expected := strings.NewReader(` - # HELP tunnel_time_seconds Tunnel time, per access key. - # TYPE tunnel_time_seconds counter - tunnel_time_seconds{access_key="key-1"} 15 - `) - err := promtest.GatherAndCompare( - reg, - expected, - "tunnel_time_seconds", - ) - require.NoError(t, err, "unexpected metric value found") - }) - - t.Run("PerLocation", func(t *testing.T) { - setNow(time.Date(2010, 1, 2, 3, 4, 5, .0, time.Local)) - ssMetrics, _ := newPrometheusOutlineMetrics(&noopMap{}) - reg := prometheus.NewPedanticRegistry() - reg.MustRegister(ssMetrics) - - connMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) - connMetrics.AddAuthenticated("key-1") - setNow(time.Date(2010, 1, 2, 3, 4, 10, .0, time.Local)) - - expected := strings.NewReader(` - # HELP tunnel_time_seconds_per_location Tunnel time, per location. - # TYPE tunnel_time_seconds_per_location counter - tunnel_time_seconds_per_location{asn="",asorg="",location="XL"} 5 - `) - err := promtest.GatherAndCompare( - reg, - expected, - "tunnel_time_seconds_per_location", - ) - require.NoError(t, err, "unexpected metric value found") - }) -} - -func TestTunnelTimePerKeyDoesNotPanicOnUnknownClosedConnection(t *testing.T) { +func TestSetVersion(t *testing.T) { + m := newPrometheusServerMetrics() reg := prometheus.NewPedanticRegistry() - ssMetrics, _ := newPrometheusOutlineMetrics(nil) + reg.MustRegister(m) - connMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) - connMetrics.AddClosed("OK", metrics.ProxyMetrics{}, time.Minute) + m.SetVersion("0.0.0-test") err := promtest.GatherAndCompare( reg, - strings.NewReader(""), - "tunnel_time_seconds", + strings.NewReader(` + # HELP build_info Information on the outline-ss-server build + # TYPE build_info gauge + build_info{version="0.0.0-test"} 1 + `), + "build_info", ) - require.NoError(t, err, "unexpectedly found metric value") -} - -func BenchmarkOpenTCP(b *testing.B) { - ssMetrics, _ := newPrometheusOutlineMetrics(nil) - conn := &fakeConn{} - b.ResetTimer() - for i := 0; i < b.N; i++ { - ssMetrics.AddOpenTCPConnection(conn) - } -} - -func BenchmarkCloseTCP(b *testing.B) { - ssMetrics, _ := newPrometheusOutlineMetrics(nil) - connMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) - accessKey := "key 1" - status := "OK" - data := metrics.ProxyMetrics{} - duration := time.Minute - b.ResetTimer() - for i := 0; i < b.N; i++ { - connMetrics.AddAuthenticated(accessKey) - connMetrics.AddClosed(status, data, duration) - } + require.NoError(t, err, "unexpected metric value found") } -func BenchmarkProbe(b *testing.B) { - ssMetrics, _ := newPrometheusOutlineMetrics(nil) - connMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) - status := "ERR_REPLAY" - drainResult := "other" - data := metrics.ProxyMetrics{} - b.ResetTimer() - for i := 0; i < b.N; i++ { - connMetrics.AddProbe(status, drainResult, data.ClientProxy) - } -} +func TestSetNumAccessKeys(t *testing.T) { + m := newPrometheusServerMetrics() + reg := prometheus.NewPedanticRegistry() + reg.MustRegister(m) -func BenchmarkClientUDP(b *testing.B) { - ssMetrics, _ := newPrometheusOutlineMetrics(nil) - addr := fakeAddr("127.0.0.1:9") - accessKey := "key 1" - udpMetrics := ssMetrics.AddUDPNatEntry(addr, accessKey) - status := "OK" - size := int64(1000) - b.ResetTimer() - for i := 0; i < b.N; i++ { - udpMetrics.AddPacketFromClient(status, size, size) - } -} + m.SetNumAccessKeys(1, 2) -func BenchmarkTargetUDP(b *testing.B) { - ssMetrics, _ := newPrometheusOutlineMetrics(nil) - addr := fakeAddr("127.0.0.1:9") - accessKey := "key 1" - udpMetrics := ssMetrics.AddUDPNatEntry(addr, accessKey) - status := "OK" - size := int64(1000) - b.ResetTimer() - for i := 0; i < b.N; i++ { - udpMetrics.AddPacketFromTarget(status, size, size) - } -} - -func BenchmarkNAT(b *testing.B) { - ssMetrics, _ := newPrometheusOutlineMetrics(nil) - addr := fakeAddr("127.0.0.1:9") - b.ResetTimer() - for i := 0; i < b.N; i++ { - udpMetrics := ssMetrics.AddUDPNatEntry(addr, "key-0") - udpMetrics.RemoveNatEntry() - } + err := promtest.GatherAndCompare( + reg, + strings.NewReader(` + # HELP keys Count of access keys + # TYPE keys gauge + keys 1 + # HELP ports Count of open ports + # TYPE ports gauge + ports 2 + `), + "keys", + "ports", + ) + require.NoError(t, err, "unexpected metric value found") } diff --git a/service/shadowsocks.go b/service/shadowsocks.go index 97329c3a..87814df8 100644 --- a/service/shadowsocks.go +++ b/service/shadowsocks.go @@ -14,9 +14,119 @@ package service -import "time" +import ( + "context" + "net" + "time" + + "github.com/Jigsaw-Code/outline-sdk/transport" +) + +const ( + // 59 seconds is most common timeout for servers that do not respond to invalid requests + tcpReadTimeout time.Duration = 59 * time.Second + + // A UDP NAT timeout of at least 5 minutes is recommended in RFC 4787 Section 4.3. + defaultNatTimeout time.Duration = 5 * time.Minute +) // ShadowsocksConnMetrics is used to report Shadowsocks related metrics on connections. type ShadowsocksConnMetrics interface { AddCipherSearch(accessKeyFound bool, timeToCipher time.Duration) } + +type ServiceMetrics interface { + UDPMetrics + AddOpenTCPConnection(conn net.Conn) TCPConnMetrics + AddCipherSearch(proto string, accessKeyFound bool, timeToCipher time.Duration) +} + +type Service interface { + HandleStream(ctx context.Context, conn transport.StreamConn) + HandlePacket(conn net.PacketConn) +} + +// Option is a Shadowsocks service constructor option. +type Option func(s *ssService) + +type ssService struct { + m ServiceMetrics + ciphers CipherList + natTimeout time.Duration + replayCache *ReplayCache + + sh StreamHandler + ph PacketHandler +} + +// NewShadowsocksService creates a new service +func NewShadowsocksService(opts ...Option) (Service, error) { + s := &ssService{} + + for _, opt := range opts { + opt(s) + } + + if s.natTimeout == 0 { + s.natTimeout = defaultNatTimeout + } + return s, nil +} + +// WithCiphers option function. +func WithCiphers(ciphers CipherList) Option { + return func(s *ssService) { + s.ciphers = ciphers + } +} + +// WithMetrics option function. +func WithMetrics(metrics ServiceMetrics) Option { + return func(s *ssService) { + s.m = metrics + } +} + +// WithReplayCache option function. +func WithReplayCache(replayCache *ReplayCache) Option { + return func(s *ssService) { + s.replayCache = replayCache + } +} + +// WithNatTimeout option function. +func WithNatTimeout(natTimeout time.Duration) Option { + return func(s *ssService) { + s.natTimeout = natTimeout + } +} + +// HandleStream handles a Shadowsocks stream-based connection. +func (s *ssService) HandleStream(ctx context.Context, conn transport.StreamConn) { + if s.sh == nil { + authFunc := NewShadowsocksStreamAuthenticator(s.ciphers, s.replayCache, &ssConnMetrics{ServiceMetrics: s.m, proto: "tcp"}) + // TODO: Register initial data metrics at zero. + s.sh = NewStreamHandler(authFunc, tcpReadTimeout) + } + connMetrics := s.m.AddOpenTCPConnection(conn) + s.sh.Handle(ctx, conn, connMetrics) +} + +// HandlePacket handles a Shadowsocks packet connection. +func (s *ssService) HandlePacket(conn net.PacketConn) { + if s.ph == nil { + s.ph = NewPacketHandler(s.natTimeout, s.ciphers, s.m, &ssConnMetrics{ServiceMetrics: s.m, proto: "udp"}) + } + s.ph.Handle(conn) +} + +type ssConnMetrics struct { + ServiceMetrics + proto string +} + +var _ ShadowsocksConnMetrics = (*ssConnMetrics)(nil) + +func (cm *ssConnMetrics) AddCipherSearch(accessKeyFound bool, timeToCipher time.Duration) { + cm.ServiceMetrics.AddCipherSearch(cm.proto, accessKeyFound, timeToCipher) +}