diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 748003ec..47b686a9 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -86,8 +86,9 @@ func (s *SSServer) startPort(portNum int) error { } logger.Infof("Shadowsocks UDP service listening on %v", packetConn.LocalAddr().String()) port := &ssPort{tcpListener: listener, packetConn: packetConn, cipherList: service.NewCipherList()} + authFunc := service.NewShadowsocksStreamAuthenticator(port.cipherList, &s.replayCache, s.m) // TODO: Register initial data metrics at zero. - tcpHandler := service.NewTCPHandler(portNum, port.cipherList, &s.replayCache, s.m, tcpReadTimeout) + tcpHandler := service.NewTCPHandler(portNum, authFunc, s.m, tcpReadTimeout) packetHandler := service.NewPacketHandler(s.natTimeout, port.cipherList, s.m) s.ports[portNum] = port accept := func() (transport.StreamConn, error) { diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index 364f9482..bfb03d4a 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -130,7 +130,9 @@ func TestTCPEcho(t *testing.T) { } replayCache := service.NewReplayCache(5) const testTimeout = 200 * time.Millisecond - handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, &replayCache, &service.NoOpTCPMetrics{}, testTimeout) + testMetrics := &service.NoOpTCPMetrics{} + authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) + handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { @@ -198,7 +200,8 @@ func TestRestrictedAddresses(t *testing.T) { require.NoError(t, err) const testTimeout = 200 * time.Millisecond testMetrics := &statusMetrics{} - handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, testTimeout) + authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) + handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) done := make(chan struct{}) go func() { service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) @@ -378,7 +381,9 @@ func BenchmarkTCPThroughput(b *testing.B) { b.Fatal(err) } const testTimeout = 200 * time.Millisecond - handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, nil, &service.NoOpTCPMetrics{}, testTimeout) + testMetrics := &service.NoOpTCPMetrics{} + authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) + handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { @@ -440,7 +445,9 @@ func BenchmarkTCPMultiplexing(b *testing.B) { } replayCache := service.NewReplayCache(service.MaxCapacity) const testTimeout = 200 * time.Millisecond - handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, &replayCache, &service.NoOpTCPMetrics{}, testTimeout) + testMetrics := &service.NoOpTCPMetrics{} + authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) + handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { diff --git a/service/tcp.go b/service/tcp.go index 59186bfd..71deaf8d 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -42,10 +42,7 @@ type TCPMetrics interface { // TCP metrics AddOpenTCPConnection(clientInfo ipinfo.IPInfo) AddClosedTCPConnection(clientInfo ipinfo.IPInfo, accessKey, status string, data metrics.ProxyMetrics, duration time.Duration) - - // Shadowsocks TCP metrics AddTCPProbe(status, drainResult string, port int, clientProxyBytes int64) - AddTCPCipherSearch(accessKeyFound bool, timeToCipher time.Duration) } func remoteIP(conn net.Conn) net.IP { @@ -118,26 +115,65 @@ func findEntry(firstBytes []byte, ciphers []*list.Element) (*CipherEntry, *list. return nil, nil } +type StreamAuthenticateFunc func(clientConn transport.StreamConn) (string, transport.StreamConn, *onet.ConnectionError) + +// ShadowsocksTCPMetrics is used to report Shadowsocks metrics on TCP connections. +type ShadowsocksTCPMetrics interface { + // Shadowsocks TCP metrics + AddTCPCipherSearch(accessKeyFound bool, timeToCipher time.Duration) +} + +// NewShadowsocksStreamAuthenticator creates a stream authenticator that uses Shadowsocks. +// TODO(fortuna): Offer alternative transports. +func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCache, metrics ShadowsocksTCPMetrics) StreamAuthenticateFunc { + return func(clientConn transport.StreamConn) (string, transport.StreamConn, *onet.ConnectionError) { + // Find the cipher and acess key id. + cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), ciphers) + metrics.AddTCPCipherSearch(keyErr == nil, timeToCipher) + if keyErr != nil { + const status = "ERR_CIPHER" + return "", nil, onet.NewConnectionError(status, "Failed to find a valid cipher", keyErr) + } + var id string + if cipherEntry != nil { + id = cipherEntry.ID + } + + // Check if the connection is a replay. + isServerSalt := cipherEntry.SaltGenerator.IsServerSalt(clientSalt) + // Only check the cache if findAccessKey succeeded and the salt is unrecognized. + if isServerSalt || !replayCache.Add(cipherEntry.ID, clientSalt) { + var status string + if isServerSalt { + status = "ERR_REPLAY_SERVER" + } else { + status = "ERR_REPLAY_CLIENT" + } + return id, nil, onet.NewConnectionError(status, "Replay detected", nil) + } + ssr := shadowsocks.NewReader(clientReader, cipherEntry.CryptoKey) + ssw := shadowsocks.NewWriter(clientConn, cipherEntry.CryptoKey) + ssw.SetSaltGenerator(cipherEntry.SaltGenerator) + return id, transport.WrapConn(clientConn, ssr, ssw), nil + } +} + type tcpHandler struct { - port int - ciphers CipherList - m TCPMetrics - readTimeout time.Duration - // `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports. - replayCache *ReplayCache - dialer transport.StreamDialer + port int + m TCPMetrics + readTimeout time.Duration + authenticate StreamAuthenticateFunc + dialer transport.StreamDialer } // NewTCPService creates a TCPService -// `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports. -func NewTCPHandler(port int, ciphers CipherList, replayCache *ReplayCache, m TCPMetrics, timeout time.Duration) TCPHandler { +func NewTCPHandler(port int, authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration) TCPHandler { return &tcpHandler{ - port: port, - ciphers: ciphers, - m: m, - readTimeout: timeout, - replayCache: replayCache, - dialer: defaultDialer, + port: port, + m: m, + readTimeout: timeout, + authenticate: authenticate, + dialer: defaultDialer, } } @@ -239,40 +275,6 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn logger.Debugf("Done with status %v, duration %v", status, connDuration) } -func (h *tcpHandler) authenticate(clientConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, transport.StreamConn, *onet.ConnectionError) { - // TODO(fortuna): Offer alternative transports. - // Find the cipher and acess key id. - cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), h.ciphers) - h.m.AddTCPCipherSearch(keyErr == nil, timeToCipher) - if keyErr != nil { - logger.Debugf("Failed to find a valid cipher after reading %v bytes: %v", proxyMetrics.ClientProxy, keyErr) - const status = "ERR_CIPHER" - return "", nil, onet.NewConnectionError(status, "Failed to find a valid cipher", keyErr) - } - var id string - if cipherEntry != nil { - id = cipherEntry.ID - } - - // Check if the connection is a replay. - isServerSalt := cipherEntry.SaltGenerator.IsServerSalt(clientSalt) - // Only check the cache if findAccessKey succeeded and the salt is unrecognized. - if isServerSalt || !h.replayCache.Add(cipherEntry.ID, clientSalt) { - var status string - if isServerSalt { - status = "ERR_REPLAY_SERVER" - } else { - status = "ERR_REPLAY_CLIENT" - } - logger.Debugf(status+": %v sent %d bytes", clientConn.RemoteAddr(), proxyMetrics.ClientProxy) - return id, nil, onet.NewConnectionError(status, "Replay detected", nil) - } - ssr := shadowsocks.NewReader(clientReader, cipherEntry.CryptoKey) - ssw := shadowsocks.NewWriter(clientConn, cipherEntry.CryptoKey) - ssw.SetSaltGenerator(cipherEntry.SaltGenerator) - return id, transport.WrapConn(clientConn, ssr, ssw), nil -} - func getProxyRequest(clientConn transport.StreamConn) (string, error) { // TODO(fortuna): Use Shadowsocks proxy, HTTP CONNECT or SOCKS5 based on first byte: // case 1, 3 or 4: Shadowsocks (address type) @@ -334,7 +336,7 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, out } outerConn.SetReadDeadline(readDeadline) - id, innerConn, authErr := h.authenticate(outerConn, proxyMetrics) + id, innerConn, authErr := h.authenticate(outerConn) if authErr != nil { // Drain to protect against probing attacks. h.absorbProbe(listenerPort, outerConn, authErr.Status, proxyMetrics) diff --git a/service/tcp_test.go b/service/tcp_test.go index 32ff20f0..b29adcb1 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -275,7 +275,8 @@ func TestProbeRandom(t *testing.T) { cipherList, err := MakeTestCiphers(makeTestSecrets(1)) require.NoError(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond) + authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) + handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond) done := make(chan struct{}) go func() { StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) @@ -351,7 +352,8 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond) + authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) + handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { @@ -386,7 +388,8 @@ func TestProbeClientBytesBasicModified(t *testing.T) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond) + authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) + handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { @@ -422,7 +425,8 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond) + authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) + handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { @@ -465,7 +469,8 @@ func TestProbeServerBytesModified(t *testing.T) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond) + authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) + handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond) done := make(chan struct{}) go func() { StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) @@ -495,7 +500,8 @@ func TestReplayDefense(t *testing.T) { replayCache := NewReplayCache(5) testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, &replayCache, testMetrics, testTimeout) + authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) + handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) snapshot := cipherList.SnapshotForClientIP(nil) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.CryptoKey @@ -573,7 +579,8 @@ func TestReverseReplayDefense(t *testing.T) { replayCache := NewReplayCache(5) testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, &replayCache, testMetrics, testTimeout) + authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) + handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) snapshot := cipherList.SnapshotForClientIP(nil) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.CryptoKey @@ -643,7 +650,8 @@ func probeExpectTimeout(t *testing.T, payloadSize int) { cipherList, err := MakeTestCiphers(makeTestSecrets(5)) require.NoError(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, testTimeout) + authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) + handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) done := make(chan struct{}) go func() {