From 2e2e15d6fce7c7b9824bf49da2bcec36803dc2a5 Mon Sep 17 00:00:00 2001 From: A A Date: Mon, 11 Apr 2022 17:10:15 +0300 Subject: [PATCH] Add rate limiting capabilities --- config_example.yml | 3 + go.mod | 1 + go.sum | 2 + integration_test/integration_test.go | 205 ++++++++++++++++++++++++++- server.go | 31 ++-- service/limiter.go | 128 +++++++++++++++++ service/limiter_test.go | 121 ++++++++++++++++ service/limiter_testing.go | 29 ++++ service/tcp.go | 27 ++-- service/tcp_test.go | 26 ++-- service/udp.go | 34 ++++- service/udp_test.go | 17 ++- 12 files changed, 576 insertions(+), 48 deletions(-) create mode 100644 service/limiter.go create mode 100644 service/limiter_test.go create mode 100644 service/limiter_testing.go diff --git a/config_example.yml b/config_example.yml index 8895b86d..474730d0 100644 --- a/config_example.yml +++ b/config_example.yml @@ -3,13 +3,16 @@ keys: port: 9000 cipher: chacha20-ietf-poly1305 secret: Secret0 + rate_limit: 128000 - id: user-1 port: 9000 cipher: chacha20-ietf-poly1305 secret: Secret1 + rate_limit: 128000 - id: user-2 port: 9001 cipher: chacha20-ietf-poly1305 secret: Secret2 + rate_limit: 128000 diff --git a/go.mod b/go.mod index fcfa48d5..016c778e 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/prometheus/procfs v0.1.3 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8 // indirect + golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 // indirect google.golang.org/protobuf v1.23.0 // indirect gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect ) diff --git a/go.sum b/go.sum index 07d7952a..c1db5c2f 100644 --- a/go.sum +++ b/go.sum @@ -110,6 +110,8 @@ golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8 h1:AvbQYmiaaaza3cW3QXRyPo5kYgpFIzOAfeAAN7m3qQ4= golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 h1:M73Iuj3xbbb9Uk1DYhzydthsj6oOd6l9bpuFcNoUvTs= +golang.org/x/time v0.0.0-20220224211638-0e9765cccd65/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= diff --git a/integration_test/integration_test.go b/integration_test/integration_test.go index de674e7e..3ae2fdf9 100644 --- a/integration_test/integration_test.go +++ b/integration_test/integration_test.go @@ -97,6 +97,11 @@ func startUDPEchoServer(t testing.TB) (*net.UDPConn, *sync.WaitGroup) { return conn, &running } +func makeLimiter(cipherList service.CipherList) service.TrafficLimiter { + c := service.MakeTestTrafficLimiterConfig(cipherList) + return service.NewTrafficLimiter(&c) +} + func TestTCPEcho(t *testing.T) { echoListener, echoRunning := startTCPEchoServer(t) @@ -111,7 +116,7 @@ func TestTCPEcho(t *testing.T) { } replayCache := service.NewReplayCache(5) const testTimeout = 200 * time.Millisecond - proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout) + proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout, makeLimiter(cipherList)) proxy.SetTargetIPValidator(allowAll) go proxy.Serve(proxyListener) @@ -164,6 +169,192 @@ func TestTCPEcho(t *testing.T) { echoRunning.Wait() } +func TestTrafficLimiterTCP(t *testing.T) { + echoListener, echoRunning := startTCPEchoServer(t) + + proxyListener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + if err != nil { + t.Fatalf("ListenTCP failed: %v", err) + } + secrets := ss.MakeTestSecrets(1) + cipherList, err := service.MakeTestCiphers(secrets) + if err != nil { + t.Fatal(err) + } + replayCache := service.NewReplayCache(5) + const testTimeout = 5 * time.Second + + key := cipherList.SnapshotForClientIP(net.IP{})[0].Value.(*service.CipherEntry).ID + trafficLimiter := service.NewTrafficLimiter(&service.TrafficLimiterConfig{ + KeyToRateLimit: map[string]int{ + key: 1000, + }, + }) + + proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout, trafficLimiter) + proxy.SetTargetIPValidator(allowAll) + go proxy.Serve(proxyListener) + + proxyHost, proxyPort, err := net.SplitHostPort(proxyListener.Addr().String()) + if err != nil { + t.Fatal(err) + } + portNum, err := strconv.Atoi(proxyPort) + if err != nil { + t.Fatal(err) + } + client, err := client.NewClient(proxyHost, portNum, secrets[0], ss.TestCipher) + if err != nil { + t.Fatalf("Failed to create ShadowsocksClient: %v", err) + } + + doWriteRead := func(N int, repeats int) time.Duration { + up := ss.MakeTestPayload(N) + conn, err := client.DialTCP(nil, echoListener.Addr().String()) + defer conn.Close() + if err != nil { + t.Fatalf("ShadowsocksClient.DialTCP failed: %v", err) + } + start := time.Now() + down := make([]byte, N) + + for i := 0; i < repeats; i++ { + n, err := conn.Write(up) + if err != nil { + t.Fatal(err) + } + if n != N { + t.Fatalf("Tried to upload %d bytes, but only sent %d", N, n) + } + + n, err = io.ReadFull(conn, down) + if err != nil && err != io.EOF { + t.Fatal(err) + } + if n != N { + t.Fatalf("Expected to download %d bytes, but only received %d", N, n) + } + + if !bytes.Equal(up, down) { + t.Fatal("Echo mismatch") + } + } + + return time.Now().Sub(start) + } + + period1 := doWriteRead(200, 4) + if period1 < 500*time.Millisecond { + t.Fatalf("Write-read loop is too fast") + } + + time.Sleep(1 * time.Second) + + period2 := doWriteRead(200, 2) + if period2 > 100*time.Millisecond { + t.Fatalf("Write-read loop is too slow") + } + + period3 := doWriteRead(500, 2) + if period3 < 500*time.Millisecond { + t.Fatalf("Write-read loop is too fast") + } + + proxy.Stop() + echoListener.Close() + echoRunning.Wait() +} + +func TestTrafficLimiterUDP(t *testing.T) { + echoConn, echoRunning := startUDPEchoServer(t) + + proxyConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + if err != nil { + t.Fatalf("ListenTCP failed: %v", err) + } + secrets := ss.MakeTestSecrets(1) + cipherList, err := service.MakeTestCiphers(secrets) + if err != nil { + t.Fatal(err) + } + testMetrics := &fakeUDPMetrics{fakeLocation: "QQ"} + + key := cipherList.SnapshotForClientIP(net.IP{})[0].Value.(*service.CipherEntry).ID + trafficLimiter := service.NewTrafficLimiter(&service.TrafficLimiterConfig{ + KeyToRateLimit: map[string]int{ + key: 1000, + }, + }) + + proxy := service.NewUDPService(time.Hour, cipherList, testMetrics, trafficLimiter) + proxy.SetTargetIPValidator(allowAll) + go proxy.Serve(proxyConn) + + proxyHost, proxyPort, err := net.SplitHostPort(proxyConn.LocalAddr().String()) + if err != nil { + t.Fatal(err) + } + portNum, err := strconv.Atoi(proxyPort) + if err != nil { + t.Fatal(err) + } + client, err := client.NewClient(proxyHost, portNum, secrets[0], ss.TestCipher) + if err != nil { + t.Fatalf("Failed to create ShadowsocksClient: %v", err) + } + conn, err := client.ListenUDP(nil) + if err != nil { + t.Fatalf("ShadowsocksClient.ListenUDP failed: %v", err) + } + + run := func(N int, expectReadError bool) { + up := ss.MakeTestPayload(N) + n, err := conn.WriteTo(up, echoConn.LocalAddr()) + if err != nil { + t.Fatal(err) + } + if n != N { + t.Fatalf("Tried to upload %d bytes, but only sent %d", N, n) + } + + conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) + + down := make([]byte, N) + n, addr, err := conn.ReadFrom(down) + if err != nil { + if !expectReadError { + t.Fatalf("Unexpected read error: %v", err) + } + return + } else { + if expectReadError { + t.Fatalf("Expected read error") + } + } + if n != N { + t.Fatalf("Tried to download %d bytes, but only sent %d", N, n) + } + if addr.String() != echoConn.LocalAddr().String() { + t.Errorf("Reported address mismatch: %s != %s", addr.String(), echoConn.LocalAddr().String()) + } + + if !bytes.Equal(up, down) { + t.Fatal("Echo mismatch") + } + } + + for i := 0; i < 3; i++ { + run(300, false) + run(300, true) + time.Sleep(time.Second) + } + + conn.Close() + echoConn.Close() + echoRunning.Wait() + proxy.GracefulStop() +} + type statusMetrics struct { metrics.NoOpMetrics sync.Mutex @@ -184,7 +375,7 @@ func TestRestrictedAddresses(t *testing.T) { require.NoError(t, err) const testTimeout = 200 * time.Millisecond testMetrics := &statusMetrics{} - proxy := service.NewTCPService(cipherList, nil, testMetrics, testTimeout) + proxy := service.NewTCPService(cipherList, nil, testMetrics, testTimeout, makeLimiter(cipherList)) go proxy.Serve(proxyListener) proxyHost, proxyPort, err := net.SplitHostPort(proxyListener.Addr().String()) @@ -266,7 +457,7 @@ func TestUDPEcho(t *testing.T) { t.Fatal(err) } testMetrics := &fakeUDPMetrics{fakeLocation: "QQ"} - proxy := service.NewUDPService(time.Hour, cipherList, testMetrics) + proxy := service.NewUDPService(time.Hour, cipherList, testMetrics, makeLimiter(cipherList)) proxy.SetTargetIPValidator(allowAll) go proxy.Serve(proxyConn) @@ -363,7 +554,7 @@ func BenchmarkTCPThroughput(b *testing.B) { b.Fatal(err) } const testTimeout = 200 * time.Millisecond - proxy := service.NewTCPService(cipherList, nil, &metrics.NoOpMetrics{}, testTimeout) + proxy := service.NewTCPService(cipherList, nil, &metrics.NoOpMetrics{}, testTimeout, makeLimiter(cipherList)) proxy.SetTargetIPValidator(allowAll) go proxy.Serve(proxyListener) @@ -430,7 +621,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) { } replayCache := service.NewReplayCache(service.MaxCapacity) const testTimeout = 200 * time.Millisecond - proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout) + proxy := service.NewTCPService(cipherList, &replayCache, &metrics.NoOpMetrics{}, testTimeout, makeLimiter(cipherList)) proxy.SetTargetIPValidator(allowAll) go proxy.Serve(proxyListener) @@ -505,7 +696,7 @@ func BenchmarkUDPEcho(b *testing.B) { if err != nil { b.Fatal(err) } - proxy := service.NewUDPService(time.Hour, cipherList, &metrics.NoOpMetrics{}) + proxy := service.NewUDPService(time.Hour, cipherList, &metrics.NoOpMetrics{}, makeLimiter(cipherList)) proxy.SetTargetIPValidator(allowAll) go proxy.Serve(proxyConn) @@ -554,7 +745,7 @@ func BenchmarkUDPManyKeys(b *testing.B) { if err != nil { b.Fatal(err) } - proxy := service.NewUDPService(time.Hour, cipherList, &metrics.NoOpMetrics{}) + proxy := service.NewUDPService(time.Hour, cipherList, &metrics.NoOpMetrics{}, makeLimiter(cipherList)) proxy.SetTargetIPValidator(allowAll) go proxy.Serve(proxyConn) diff --git a/server.go b/server.go index c3c04121..e8b28b2b 100644 --- a/server.go +++ b/server.go @@ -74,7 +74,7 @@ type SSServer struct { ports map[int]*ssPort } -func (s *SSServer) startPort(portNum int) error { +func (s *SSServer) startPort(portNum int, trafficLimiterConfig *service.TrafficLimiterConfig) error { listener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: portNum}) if err != nil { return fmt.Errorf("Failed to start TCP on port %v: %v", portNum, err) @@ -85,9 +85,11 @@ func (s *SSServer) startPort(portNum int) error { } logger.Infof("Listening TCP and UDP on port %v", portNum) port := &ssPort{cipherList: service.NewCipherList()} + + limiter := service.NewTrafficLimiter(trafficLimiterConfig) // TODO: Register initial data metrics at zero. - port.tcpService = service.NewTCPService(port.cipherList, &s.replayCache, s.m, tcpReadTimeout) - port.udpService = service.NewUDPService(s.natTimeout, port.cipherList, s.m) + port.tcpService = service.NewTCPService(port.cipherList, &s.replayCache, s.m, tcpReadTimeout, limiter) + port.udpService = service.NewUDPService(s.natTimeout, port.cipherList, s.m, limiter) s.ports[portNum] = port go port.tcpService.Serve(listener) go port.udpService.Serve(packetConn) @@ -120,6 +122,7 @@ func (s *SSServer) loadConfig(filename string) error { portChanges := make(map[int]int) portCiphers := make(map[int]*list.List) // Values are *List of *CipherEntry. + portKeyLimits := make(map[int]map[string]int) for _, keyConfig := range config.Keys { portChanges[keyConfig.Port] = 1 cipherList, ok := portCiphers[keyConfig.Port] @@ -133,6 +136,13 @@ func (s *SSServer) loadConfig(filename string) error { } entry := service.MakeCipherEntry(keyConfig.ID, cipher, keyConfig.Secret) cipherList.PushBack(&entry) + var keyLimits map[string]int + keyLimits, ok = portKeyLimits[keyConfig.Port] + if !ok { + keyLimits = make(map[string]int) + portKeyLimits[keyConfig.Port] = keyLimits + } + keyLimits[keyConfig.ID] = keyConfig.RateLimit } for port := range s.ports { portChanges[port] = portChanges[port] - 1 @@ -143,7 +153,8 @@ func (s *SSServer) loadConfig(filename string) error { return fmt.Errorf("Failed to remove port %v: %v", portNum, err) } } else if count == +1 { - if err := s.startPort(portNum); err != nil { + trafficLimiterConfig := &service.TrafficLimiterConfig{KeyToRateLimit: portKeyLimits[portNum]} + if err := s.startPort(portNum, trafficLimiterConfig); err != nil { return fmt.Errorf("Failed to start port %v: %v", portNum, err) } } @@ -193,10 +204,11 @@ func RunSSServer(filename string, natTimeout time.Duration, sm metrics.Shadowsoc type Config struct { Keys []struct { - ID string - Port int - Cipher string - Secret string + ID string + Port int + Cipher string + Secret string + RateLimit int } } @@ -207,6 +219,9 @@ func readConfig(filename string) (*Config, error) { return nil, err } err = yaml.Unmarshal(configData, &config) + if err != nil { + return nil, err + } return &config, err } diff --git a/service/limiter.go b/service/limiter.go new file mode 100644 index 00000000..c8a14086 --- /dev/null +++ b/service/limiter.go @@ -0,0 +1,128 @@ +// Copyright 2018 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "io" + "time" + + "golang.org/x/time/rate" +) + +type TrafficLimiterConfig struct { + // Rate limit in bytes per second. + // If the corresponding limit is zero, it means no limits + KeyToRateLimit map[string]int +} + +type TrafficLimiter interface { + WrapReaderWriter(accessKey string, reader io.Reader, writer io.Writer) (io.Reader, io.Writer) + Allow(accessKey string, n int) bool +} + +func NewTrafficLimiter(config *TrafficLimiterConfig) TrafficLimiter { + keyToLimiter := make(map[string]*rate.Limiter, len(config.KeyToRateLimit)) + for accessKey, limit := range config.KeyToRateLimit { + var limiter *rate.Limiter + if limit == 0 { + limiter = nil + } else { + limiter = createLimiter(limit) + } + keyToLimiter[accessKey] = limiter + } + return &trafficLimiter{keyToLimiter: keyToLimiter} +} + +type trafficLimiter struct { + keyToLimiter map[string]*rate.Limiter +} + +// doWait does the waiting for any n, even exceeding the burst limit. +func doWait(l *rate.Limiter, n int) error { + ctx := context.TODO() + b := l.Burst() + for b < n { + err := l.WaitN(ctx, b) + if err != nil { + return err + } + n -= b + } + return l.WaitN(ctx, n) +} + +type limitedReader struct { + reader io.Reader + limiter *rate.Limiter +} + +func (r *limitedReader) Read(b []byte) (int, error) { + n, err := r.reader.Read(b) + if n <= 0 { + return n, err + } + waitErr := doWait(r.limiter, n) + if waitErr != nil { + return 0, waitErr + } + return n, err +} + +type limitedWriter struct { + writer io.Writer + limiter *rate.Limiter +} + +func (w *limitedWriter) Write(b []byte) (int, error) { + n, err := w.writer.Write(b) + if n <= 0 { + return n, err + } + waitErr := doWait(w.limiter, n) + if waitErr != nil { + return 0, waitErr + } + return n, err +} + +func createLimiter(limit int) *rate.Limiter { + burst := limit + theRate := rate.Every(time.Second) * rate.Limit(burst) + return rate.NewLimiter(theRate, burst) +} + +func (l *trafficLimiter) WrapReaderWriter(accessKey string, reader io.Reader, writer io.Writer) (io.Reader, io.Writer) { + limiter, ok := l.keyToLimiter[accessKey] + if !ok { + logger.Panicf("Access key %v not found", accessKey) + } + if limiter == nil { + return reader, writer + } + return &limitedReader{reader: reader, limiter: limiter}, &limitedWriter{writer: writer, limiter: limiter} +} + +func (l *trafficLimiter) Allow(accessKey string, n int) bool { + limiter, ok := l.keyToLimiter[accessKey] + if !ok { + logger.Panicf("Access key %v not found", accessKey) + } + if limiter == nil { + return true + } + return limiter.AllowN(time.Now(), n) +} diff --git a/service/limiter_test.go b/service/limiter_test.go new file mode 100644 index 00000000..e27eaa81 --- /dev/null +++ b/service/limiter_test.go @@ -0,0 +1,121 @@ +// Copyright 2020 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "bytes" + "io" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func makeRandBuffer(n int64) *bytes.Buffer { + arr := make([]byte, n) + rand.Read(arr) + return bytes.NewBuffer(arr) +} + +func TestTrafficLimiter(t *testing.T) { + key1 := "key1" + key2 := "key2" + key3 := "key3" + + config := TrafficLimiterConfig{ + KeyToRateLimit: map[string]int{ + key1: 20, + key2: 30, + key3: 0, + }, + } + + limiter := NewTrafficLimiter(&config) + + src1Buf := makeRandBuffer(100) + src1 := src1Buf.Bytes() + dst1 := &bytes.Buffer{} + + src2Buf := makeRandBuffer(100) + src2 := src2Buf.Bytes() + dst2 := &bytes.Buffer{} + + src3Buf := makeRandBuffer(100) + src3 := src3Buf.Bytes() + dst3 := &bytes.Buffer{} + + r1, w1 := limiter.WrapReaderWriter(key1, src1Buf, dst1) + r2, w2 := limiter.WrapReaderWriter(key2, src2Buf, dst2) + r3, w3 := limiter.WrapReaderWriter(key3, src3Buf, dst3) + + doRead := func(buf []byte, r io.Reader, n int) time.Duration { + b := make([]byte, n) + start := time.Now() + _, err := io.ReadFull(r, b) + require.NoError(t, err) + require.Equal(t, b, buf[:len(b)]) + return time.Now().Sub(start) + } + + period1 := doRead(src1, r1, 20) + if period1 > 10*time.Millisecond { + t.Errorf("read took too long") + } + + period2 := doRead(src2, r2, 30) + if period2 > 10*time.Millisecond { + t.Errorf("read took too long") + } + + period3 := doRead(src3, r3, 100) + if period3 > 10*time.Millisecond { + t.Errorf("read took too long") + } + + doWrite := func(buf []byte, dst *bytes.Buffer, w io.Writer) time.Duration { + start := time.Now() + size := len(buf) + _, err := w.Write(buf) + require.NoError(t, err) + require.Equal(t, buf, dst.Bytes()[:size]) + return time.Now().Sub(start) + } + + // Waiting works even for payload exceeding one second burst. + period4 := doWrite(src1[:30], dst1, w1) + if period4 < 500*time.Millisecond { + t.Fatalf("write took too short") + } + + allowed := limiter.Allow(key2, 30) + require.True(t, allowed) + + allowed = limiter.Allow(key2, 10) + require.False(t, allowed) + + period5 := doWrite(src2[:30], dst2, w2) + if period5 < 500*time.Millisecond { + t.Fatalf("write took too short") + } + + allowed = limiter.Allow(key3, 1000) + require.True(t, allowed) + + period6 := doWrite(src3[:100], dst3, w3) + if period6 > 10*time.Millisecond { + t.Fatalf("write took too long") + } +} diff --git a/service/limiter_testing.go b/service/limiter_testing.go new file mode 100644 index 00000000..081d7a26 --- /dev/null +++ b/service/limiter_testing.go @@ -0,0 +1,29 @@ +// Copyright 2018 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "net" +) + +func MakeTestTrafficLimiterConfig(ciphers CipherList) TrafficLimiterConfig { + elts := ciphers.SnapshotForClientIP(net.IP{}) + keyToLimit := make(map[string]int) + for _, elt := range elts { + entry := elt.Value.(*CipherEntry) + keyToLimit[entry.ID] = 0 + } + return TrafficLimiterConfig{KeyToRateLimit: keyToLimit} +} diff --git a/service/tcp.go b/service/tcp.go index 69d71a05..f7abfaac 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -102,7 +102,6 @@ func findEntry(firstBytes []byte, ciphers []*list.Element) (*CipherEntry, *list. continue } debugTCP(id, "Found cipher at index %d", ci) - // Move the active cipher to the front, so that the search is quicker next time. return entry, elt } return nil, nil @@ -114,6 +113,7 @@ type tcpService struct { stopped bool ciphers CipherList m metrics.ShadowsocksMetrics + limiter TrafficLimiter running sync.WaitGroup readTimeout time.Duration // `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports. @@ -123,11 +123,13 @@ type tcpService struct { // NewTCPService creates a TCPService // `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports. -func NewTCPService(ciphers CipherList, replayCache *ReplayCache, m metrics.ShadowsocksMetrics, timeout time.Duration) TCPService { +func NewTCPService(ciphers CipherList, replayCache *ReplayCache, m metrics.ShadowsocksMetrics, + timeout time.Duration, limiter TrafficLimiter) TCPService { return &tcpService{ ciphers: ciphers, m: m, readTimeout: timeout, + limiter: limiter, replayCache: replayCache, targetIPValidator: onet.RequirePublicIP, } @@ -227,6 +229,13 @@ func (s *tcpService) handleConnection(listenerPort int, clientTCPConn *net.TCPCo var proxyMetrics metrics.ProxyMetrics clientConn := metrics.MeasureConn(clientTCPConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy) cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientTCPConn), s.ciphers) + var clientWriter io.Writer = clientConn + + var accessKey string + if cipherEntry != nil { + accessKey = cipherEntry.ID + clientReader, clientWriter = s.limiter.WrapReaderWriter(accessKey, clientReader, clientWriter) + } connError := func() *onet.ConnectionError { if keyErr != nil { @@ -256,9 +265,10 @@ func (s *tcpService) handleConnection(listenerPort int, clientTCPConn *net.TCPCo clientTCPConn.SetReadDeadline(time.Time{}) if err != nil { // Drain to prevent a close on cipher error. - io.Copy(ioutil.Discard, clientConn) + io.Copy(ioutil.Discard, clientReader) return onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", err) } + logger.Debugf("address %s", clientTCPConn.RemoteAddr().String()) tgtConn, dialErr := dialTarget(tgtAddr, &proxyMetrics, s.targetIPValidator) if dialErr != nil { @@ -268,15 +278,16 @@ func (s *tcpService) handleConnection(listenerPort int, clientTCPConn *net.TCPCo defer tgtConn.Close() logger.Debugf("proxy %s <-> %s", clientTCPConn.RemoteAddr().String(), tgtConn.RemoteAddr().String()) - ssw := ss.NewShadowsocksWriter(clientConn, cipherEntry.Cipher) + ssw := ss.NewShadowsocksWriter(clientWriter, cipherEntry.Cipher) ssw.SetSaltGenerator(cipherEntry.SaltGenerator) fromClientErrCh := make(chan error) go func() { _, fromClientErr := ssr.WriteTo(tgtConn) + logger.Debugf("fromClientErr: %v", fromClientErr) if fromClientErr != nil { // Drain to prevent a close in the case of a cipher error. - io.Copy(ioutil.Discard, clientConn) + io.Copy(ioutil.Discard, clientReader) } clientConn.CloseRead() // Send FIN to target. @@ -306,11 +317,7 @@ func (s *tcpService) handleConnection(listenerPort int, clientTCPConn *net.TCPCo logger.Debugf("TCP Error: %v: %v", connError.Message, connError.Cause) status = connError.Status } - var id string - if cipherEntry != nil { - id = cipherEntry.ID - } - s.m.AddClosedTCPConnection(clientLocation, id, status, proxyMetrics, timeToCipher, connDuration) + s.m.AddClosedTCPConnection(clientLocation, accessKey, status, proxyMetrics, timeToCipher, connDuration) clientConn.Close() // Closing after the metrics are added aids integration testing. logger.Debugf("Done with status %v, duration %v", status, connDuration) } diff --git a/service/tcp_test.go b/service/tcp_test.go index a6080be6..38004105 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -277,12 +277,18 @@ func probe(serverAddr *net.TCPAddr, bytesToSend []byte) error { return nil } +func makeLimiter(cipherList CipherList) TrafficLimiter { + c := MakeTestTrafficLimiterConfig(cipherList) + return NewTrafficLimiter(&c) +} + func TestProbeRandom(t *testing.T) { listener := makeLocalhostListener(t) cipherList, err := MakeTestCiphers(ss.MakeTestSecrets(1)) + require.Nil(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} - s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond) + s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond, makeLimiter(cipherList)) go s.Serve(listener) // 221 is the largest random probe reported by https://gfw.report/blog/gfw_shadowsocks/ @@ -349,7 +355,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { require.Nil(t, err, "MakeTestCiphers failed: %v", err) cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} - s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond) + s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond, makeLimiter(cipherList)) s.SetTargetIPValidator(allowAll) go s.Serve(listener) @@ -379,7 +385,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) { require.Nil(t, err, "MakeTestCiphers failed: %v", err) cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} - s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond) + s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond, makeLimiter(cipherList)) s.SetTargetIPValidator(allowAll) go s.Serve(listener) @@ -410,7 +416,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { require.Nil(t, err, "MakeTestCiphers failed: %v", err) cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} - s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond) + s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond, makeLimiter(cipherList)) s.SetTargetIPValidator(allowAll) go s.Serve(listener) @@ -448,7 +454,7 @@ func TestProbeServerBytesModified(t *testing.T) { require.Nil(t, err, "MakeTestCiphers failed: %v", err) cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} - s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond) + s := NewTCPService(cipherList, nil, testMetrics, 200*time.Millisecond, makeLimiter(cipherList)) go s.Serve(listener) initialBytes := makeServerBytes(t, cipher) @@ -473,7 +479,7 @@ func TestReplayDefense(t *testing.T) { replayCache := NewReplayCache(5) testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond - s := NewTCPService(cipherList, &replayCache, testMetrics, testTimeout) + s := NewTCPService(cipherList, &replayCache, testMetrics, testTimeout, makeLimiter(cipherList)) snapshot := cipherList.SnapshotForClientIP(nil) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.Cipher @@ -546,7 +552,7 @@ func TestReverseReplayDefense(t *testing.T) { replayCache := NewReplayCache(5) testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond - s := NewTCPService(cipherList, &replayCache, testMetrics, testTimeout) + s := NewTCPService(cipherList, &replayCache, testMetrics, testTimeout, makeLimiter(cipherList)) snapshot := cipherList.SnapshotForClientIP(nil) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.Cipher @@ -611,7 +617,7 @@ func probeExpectTimeout(t *testing.T, payloadSize int) { cipherList, err := MakeTestCiphers(ss.MakeTestSecrets(5)) require.Nil(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} - s := NewTCPService(cipherList, nil, testMetrics, testTimeout) + s := NewTCPService(cipherList, nil, testMetrics, testTimeout, makeLimiter(cipherList)) testPayload := ss.MakeTestPayload(payloadSize) done := make(chan bool) @@ -674,7 +680,7 @@ func TestTCPDoubleServe(t *testing.T) { replayCache := NewReplayCache(5) testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond - s := NewTCPService(cipherList, &replayCache, testMetrics, testTimeout) + s := NewTCPService(cipherList, &replayCache, testMetrics, testTimeout, makeLimiter(cipherList)) c := make(chan error) for i := 0; i < 2; i++ { @@ -704,7 +710,7 @@ func TestTCPEarlyStop(t *testing.T) { replayCache := NewReplayCache(5) testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond - s := NewTCPService(cipherList, &replayCache, testMetrics, testTimeout) + s := NewTCPService(cipherList, &replayCache, testMetrics, testTimeout, makeLimiter(cipherList)) if err := s.Stop(); err != nil { t.Error(err) diff --git a/service/udp.go b/service/udp.go index 42160497..008ea77f 100644 --- a/service/udp.go +++ b/service/udp.go @@ -78,11 +78,18 @@ type udpService struct { m metrics.ShadowsocksMetrics running sync.WaitGroup targetIPValidator onet.TargetIPValidator + limiter TrafficLimiter } // NewUDPService creates a UDPService -func NewUDPService(natTimeout time.Duration, cipherList CipherList, m metrics.ShadowsocksMetrics) UDPService { - return &udpService{natTimeout: natTimeout, ciphers: cipherList, m: m, targetIPValidator: onet.RequirePublicIP} +func NewUDPService(natTimeout time.Duration, cipherList CipherList, m metrics.ShadowsocksMetrics, limiter TrafficLimiter) UDPService { + return &udpService{ + natTimeout: natTimeout, + ciphers: cipherList, + m: m, + targetIPValidator: onet.RequirePublicIP, + limiter: limiter, + } } // UDPService is a running UDP shadowsocks proxy that can be stopped. @@ -119,7 +126,7 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { s.mu.Unlock() defer s.running.Done() - nm := newNATmap(s.natTimeout, s.m, &s.running) + nm := newNATmap(s.natTimeout, s.m, &s.running, s.limiter) defer nm.Close() cipherBuf := make([]byte, serverUDPBufferSize) textBuf := make([]byte, serverUDPBufferSize) @@ -221,6 +228,11 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { } debugUDPAddr(clientAddr, "Proxy exit %v", targetConn.LocalAddr()) + allowed := s.limiter.Allow(keyID, len(payload)) + if !allowed { + debugUDPAddr(clientAddr, "Rate limite exceeded, dropping packet from client (len %v)", len(payload)) + return onet.NewConnectionError("ERR_LIMIT", "Rate limit exceeded", nil) + } proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature if err != nil { return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) @@ -342,10 +354,11 @@ type natmap struct { timeout time.Duration metrics metrics.ShadowsocksMetrics running *sync.WaitGroup + limiter TrafficLimiter } -func newNATmap(timeout time.Duration, sm metrics.ShadowsocksMetrics, running *sync.WaitGroup) *natmap { - m := &natmap{metrics: sm, running: running} +func newNATmap(timeout time.Duration, sm metrics.ShadowsocksMetrics, running *sync.WaitGroup, limiter TrafficLimiter) *natmap { + m := &natmap{metrics: sm, running: running, limiter: limiter} m.keyConn = make(map[string]*natconn) m.timeout = timeout return m @@ -391,7 +404,7 @@ func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cipher *ss. m.metrics.AddUDPNatEntry() m.running.Add(1) go func() { - timedCopy(clientAddr, clientConn, entry, keyID, m.metrics) + timedCopy(clientAddr, clientConn, entry, keyID, m.metrics, m.limiter) m.metrics.RemoveUDPNatEntry() if pc := m.del(clientAddr.String()); pc != nil { pc.Close() @@ -421,7 +434,7 @@ var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) // copy from target to client until read timeout func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, - keyID string, sm metrics.ShadowsocksMetrics) { + keyID string, sm metrics.ShadowsocksMetrics, limiter TrafficLimiter) { // pkt is used for in-place encryption of downstream UDP packets, with the layout // [padding?][salt][address][body][tag][extra] // Padding is only used if the address is IPv4. @@ -455,6 +468,13 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco } debugUDPAddr(clientAddr, "Got response from %v", raddr) + + allowed := limiter.Allow(keyID, bodyLen) + if !allowed { + debugUDPAddr(clientAddr, "Rate limite exceeded, dropping packet to client (len %v)", bodyLen) + return onet.NewConnectionError("ERR_LIMIT", "Rate limit exceeded", nil) + } + srcAddr := socks.ParseAddr(raddr.String()) addrStart := bodyStart - len(srcAddr) // `plainTextBuf` concatenates the SOCKS address and body: diff --git a/service/udp_test.go b/service/udp_test.go index 12139fdc..6a7b6569 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -130,7 +130,7 @@ func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTest cipher := ciphers.SnapshotForClientIP(nil)[0].Value.(*CipherEntry).Cipher clientConn := makePacketConn() metrics := &natTestMetrics{} - service := NewUDPService(timeout, ciphers, metrics) + service := NewUDPService(timeout, ciphers, metrics, makeLimiter(ciphers)) service.SetTargetIPValidator(validator) go service.Serve(clientConn) @@ -202,17 +202,22 @@ func assertAlmostEqual(t *testing.T, a, b time.Time) { } func TestNATEmpty(t *testing.T) { - nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}) + nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}, makeLimiter(NewCipherList())) if nat.Get("foo") != nil { t.Error("Expected nil value from empty NAT map") } } func setupNAT() (*fakePacketConn, *fakePacketConn, *natconn) { - nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}) + cipherList, err := MakeTestCiphers(ss.MakeTestSecrets(1)) + if err != nil { + logger.Fatal(err) + } + nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}, makeLimiter(cipherList)) clientConn := makePacketConn() targetConn := makePacketConn() - nat.Add(&clientAddr, clientConn, natCipher, targetConn, "ZZ", "key id") + key := cipherList.SnapshotForClientIP(net.IP{})[0].Value.(*CipherEntry).ID + nat.Add(&clientAddr, clientConn, natCipher, targetConn, "ZZ", key) entry := nat.Get(clientAddr.String()) return clientConn, targetConn, entry } @@ -474,7 +479,7 @@ func TestUDPDoubleServe(t *testing.T) { } testMetrics := &natTestMetrics{} const testTimeout = 200 * time.Millisecond - s := NewUDPService(testTimeout, cipherList, testMetrics) + s := NewUDPService(testTimeout, cipherList, testMetrics, makeLimiter(cipherList)) c := make(chan error) for i := 0; i < 2; i++ { @@ -508,7 +513,7 @@ func TestUDPEarlyStop(t *testing.T) { } testMetrics := &natTestMetrics{} const testTimeout = 200 * time.Millisecond - s := NewUDPService(testTimeout, cipherList, testMetrics) + s := NewUDPService(testTimeout, cipherList, testMetrics, makeLimiter(cipherList)) if err := s.Stop(); err != nil { t.Error(err)