Skip to content

Commit

Permalink
clean: decouple TCPHandler from the Shadowsocks transport (#174)
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna authored Mar 15, 2024
1 parent b9cb68e commit 36bf99d
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 66 deletions.
3 changes: 2 additions & 1 deletion cmd/outline-ss-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ func (s *SSServer) startPort(portNum int) error {
}
logger.Infof("Shadowsocks UDP service listening on %v", packetConn.LocalAddr().String())
port := &ssPort{tcpListener: listener, packetConn: packetConn, cipherList: service.NewCipherList()}
authFunc := service.NewShadowsocksStreamAuthenticator(port.cipherList, &s.replayCache, s.m)
// TODO: Register initial data metrics at zero.
tcpHandler := service.NewTCPHandler(portNum, port.cipherList, &s.replayCache, s.m, tcpReadTimeout)
tcpHandler := service.NewTCPHandler(portNum, authFunc, s.m, tcpReadTimeout)
packetHandler := service.NewPacketHandler(s.natTimeout, port.cipherList, s.m)
s.ports[portNum] = port
accept := func() (transport.StreamConn, error) {
Expand Down
15 changes: 11 additions & 4 deletions internal/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ func TestTCPEcho(t *testing.T) {
}
replayCache := service.NewReplayCache(5)
const testTimeout = 200 * time.Millisecond
handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, &replayCache, &service.NoOpTCPMetrics{}, testTimeout)
testMetrics := &service.NoOpTCPMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout)
handler.SetTargetDialer(&transport.TCPDialer{})
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -198,7 +200,8 @@ func TestRestrictedAddresses(t *testing.T) {
require.NoError(t, err)
const testTimeout = 200 * time.Millisecond
testMetrics := &statusMetrics{}
handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, testTimeout)
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout)
done := make(chan struct{})
go func() {
service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -378,7 +381,9 @@ func BenchmarkTCPThroughput(b *testing.B) {
b.Fatal(err)
}
const testTimeout = 200 * time.Millisecond
handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, nil, &service.NoOpTCPMetrics{}, testTimeout)
testMetrics := &service.NoOpTCPMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout)
handler.SetTargetDialer(&transport.TCPDialer{})
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -440,7 +445,9 @@ func BenchmarkTCPMultiplexing(b *testing.B) {
}
replayCache := service.NewReplayCache(service.MaxCapacity)
const testTimeout = 200 * time.Millisecond
handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, &replayCache, &service.NoOpTCPMetrics{}, testTimeout)
testMetrics := &service.NoOpTCPMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout)
handler.SetTargetDialer(&transport.TCPDialer{})
done := make(chan struct{})
go func() {
Expand Down
108 changes: 55 additions & 53 deletions service/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,7 @@ type TCPMetrics interface {
// TCP metrics
AddOpenTCPConnection(clientInfo ipinfo.IPInfo)
AddClosedTCPConnection(clientInfo ipinfo.IPInfo, accessKey, status string, data metrics.ProxyMetrics, duration time.Duration)

// Shadowsocks TCP metrics
AddTCPProbe(status, drainResult string, port int, clientProxyBytes int64)
AddTCPCipherSearch(accessKeyFound bool, timeToCipher time.Duration)
}

func remoteIP(conn net.Conn) net.IP {
Expand Down Expand Up @@ -118,26 +115,65 @@ func findEntry(firstBytes []byte, ciphers []*list.Element) (*CipherEntry, *list.
return nil, nil
}

type StreamAuthenticateFunc func(clientConn transport.StreamConn) (string, transport.StreamConn, *onet.ConnectionError)

// ShadowsocksTCPMetrics is used to report Shadowsocks metrics on TCP connections.
type ShadowsocksTCPMetrics interface {
// Shadowsocks TCP metrics
AddTCPCipherSearch(accessKeyFound bool, timeToCipher time.Duration)
}

// NewShadowsocksStreamAuthenticator creates a stream authenticator that uses Shadowsocks.
// TODO(fortuna): Offer alternative transports.
func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCache, metrics ShadowsocksTCPMetrics) StreamAuthenticateFunc {
return func(clientConn transport.StreamConn) (string, transport.StreamConn, *onet.ConnectionError) {
// Find the cipher and acess key id.
cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), ciphers)
metrics.AddTCPCipherSearch(keyErr == nil, timeToCipher)
if keyErr != nil {
const status = "ERR_CIPHER"
return "", nil, onet.NewConnectionError(status, "Failed to find a valid cipher", keyErr)
}
var id string
if cipherEntry != nil {
id = cipherEntry.ID
}

// Check if the connection is a replay.
isServerSalt := cipherEntry.SaltGenerator.IsServerSalt(clientSalt)
// Only check the cache if findAccessKey succeeded and the salt is unrecognized.
if isServerSalt || !replayCache.Add(cipherEntry.ID, clientSalt) {
var status string
if isServerSalt {
status = "ERR_REPLAY_SERVER"
} else {
status = "ERR_REPLAY_CLIENT"
}
return id, nil, onet.NewConnectionError(status, "Replay detected", nil)
}
ssr := shadowsocks.NewReader(clientReader, cipherEntry.CryptoKey)
ssw := shadowsocks.NewWriter(clientConn, cipherEntry.CryptoKey)
ssw.SetSaltGenerator(cipherEntry.SaltGenerator)
return id, transport.WrapConn(clientConn, ssr, ssw), nil
}
}

type tcpHandler struct {
port int
ciphers CipherList
m TCPMetrics
readTimeout time.Duration
// `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports.
replayCache *ReplayCache
dialer transport.StreamDialer
port int
m TCPMetrics
readTimeout time.Duration
authenticate StreamAuthenticateFunc
dialer transport.StreamDialer
}

// NewTCPService creates a TCPService
// `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports.
func NewTCPHandler(port int, ciphers CipherList, replayCache *ReplayCache, m TCPMetrics, timeout time.Duration) TCPHandler {
func NewTCPHandler(port int, authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration) TCPHandler {
return &tcpHandler{
port: port,
ciphers: ciphers,
m: m,
readTimeout: timeout,
replayCache: replayCache,
dialer: defaultDialer,
port: port,
m: m,
readTimeout: timeout,
authenticate: authenticate,
dialer: defaultDialer,
}
}

Expand Down Expand Up @@ -239,40 +275,6 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn
logger.Debugf("Done with status %v, duration %v", status, connDuration)
}

func (h *tcpHandler) authenticate(clientConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, transport.StreamConn, *onet.ConnectionError) {
// TODO(fortuna): Offer alternative transports.
// Find the cipher and acess key id.
cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), h.ciphers)
h.m.AddTCPCipherSearch(keyErr == nil, timeToCipher)
if keyErr != nil {
logger.Debugf("Failed to find a valid cipher after reading %v bytes: %v", proxyMetrics.ClientProxy, keyErr)
const status = "ERR_CIPHER"
return "", nil, onet.NewConnectionError(status, "Failed to find a valid cipher", keyErr)
}
var id string
if cipherEntry != nil {
id = cipherEntry.ID
}

// Check if the connection is a replay.
isServerSalt := cipherEntry.SaltGenerator.IsServerSalt(clientSalt)
// Only check the cache if findAccessKey succeeded and the salt is unrecognized.
if isServerSalt || !h.replayCache.Add(cipherEntry.ID, clientSalt) {
var status string
if isServerSalt {
status = "ERR_REPLAY_SERVER"
} else {
status = "ERR_REPLAY_CLIENT"
}
logger.Debugf(status+": %v sent %d bytes", clientConn.RemoteAddr(), proxyMetrics.ClientProxy)
return id, nil, onet.NewConnectionError(status, "Replay detected", nil)
}
ssr := shadowsocks.NewReader(clientReader, cipherEntry.CryptoKey)
ssw := shadowsocks.NewWriter(clientConn, cipherEntry.CryptoKey)
ssw.SetSaltGenerator(cipherEntry.SaltGenerator)
return id, transport.WrapConn(clientConn, ssr, ssw), nil
}

func getProxyRequest(clientConn transport.StreamConn) (string, error) {
// TODO(fortuna): Use Shadowsocks proxy, HTTP CONNECT or SOCKS5 based on first byte:
// case 1, 3 or 4: Shadowsocks (address type)
Expand Down Expand Up @@ -334,7 +336,7 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, out
}
outerConn.SetReadDeadline(readDeadline)

id, innerConn, authErr := h.authenticate(outerConn, proxyMetrics)
id, innerConn, authErr := h.authenticate(outerConn)
if authErr != nil {
// Drain to protect against probing attacks.
h.absorbProbe(listenerPort, outerConn, authErr.Status, proxyMetrics)
Expand Down
24 changes: 16 additions & 8 deletions service/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ func TestProbeRandom(t *testing.T) {
cipherList, err := MakeTestCiphers(makeTestSecrets(1))
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
testMetrics := &probeTestMetrics{}
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond)
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond)
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -351,7 +352,8 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) {
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond)
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -386,7 +388,8 @@ func TestProbeClientBytesBasicModified(t *testing.T) {
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond)
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -422,7 +425,8 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) {
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond)
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -465,7 +469,8 @@ func TestProbeServerBytesModified(t *testing.T) {
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond)
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond)
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -495,7 +500,8 @@ func TestReplayDefense(t *testing.T) {
replayCache := NewReplayCache(5)
testMetrics := &probeTestMetrics{}
const testTimeout = 200 * time.Millisecond
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, &replayCache, testMetrics, testTimeout)
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout)
snapshot := cipherList.SnapshotForClientIP(nil)
cipherEntry := snapshot[0].Value.(*CipherEntry)
cipher := cipherEntry.CryptoKey
Expand Down Expand Up @@ -573,7 +579,8 @@ func TestReverseReplayDefense(t *testing.T) {
replayCache := NewReplayCache(5)
testMetrics := &probeTestMetrics{}
const testTimeout = 200 * time.Millisecond
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, &replayCache, testMetrics, testTimeout)
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout)
snapshot := cipherList.SnapshotForClientIP(nil)
cipherEntry := snapshot[0].Value.(*CipherEntry)
cipher := cipherEntry.CryptoKey
Expand Down Expand Up @@ -643,7 +650,8 @@ func probeExpectTimeout(t *testing.T, payloadSize int) {
cipherList, err := MakeTestCiphers(makeTestSecrets(5))
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
testMetrics := &probeTestMetrics{}
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, testTimeout)
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout)

done := make(chan struct{})
go func() {
Expand Down

0 comments on commit 36bf99d

Please sign in to comment.