Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean: modularize TCP connection handling #173

Merged
merged 1 commit into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions service/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@ func (c *measuredConn) Write(b []byte) (int, error) {
return n, err
}

func (c *measuredConn) ReadFrom(r io.Reader) (int64, error) {
n, err := io.Copy(c.StreamConn, r)
func (c *measuredConn) ReadFrom(r io.Reader) (n int64, err error) {
if rf, ok := c.StreamConn.(io.ReaderFrom); ok {
// Prefer ReadFrom if we are calling ReadFrom. Otherwise io.Copy will try WriteTo first.
n, err = rf.ReadFrom(r)
} else {
n, err = io.Copy(c.StreamConn, r)
}
*c.writeCount += n
return n, err
}
Expand Down
97 changes: 67 additions & 30 deletions service/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,25 +239,22 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn
logger.Debugf("Done with status %v, duration %v", status, connDuration)
}

func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, clientConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
// Set a deadline to receive the address to the target.
clientConn.SetReadDeadline(time.Now().Add(h.readTimeout))

// 1. Find the cipher and acess key id.
func (h *tcpHandler) authenticate(clientConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, transport.StreamConn, *onet.ConnectionError) {
// TODO(fortuna): Offer alternative transports.
// Find the cipher and acess key id.
cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), h.ciphers)
h.m.AddTCPCipherSearch(keyErr == nil, timeToCipher)
if keyErr != nil {
logger.Debugf("Failed to find a valid cipher after reading %v bytes: %v", proxyMetrics.ClientProxy, keyErr)
const status = "ERR_CIPHER"
h.absorbProbe(listenerPort, clientConn, status, proxyMetrics)
return "", onet.NewConnectionError(status, "Failed to find a valid cipher", keyErr)
return "", nil, onet.NewConnectionError(status, "Failed to find a valid cipher", keyErr)
}
var id string
if cipherEntry != nil {
id = cipherEntry.ID
}

// 2. Check if the connection is a replay.
// Check if the connection is a replay.
isServerSalt := cipherEntry.SaltGenerator.IsServerSalt(clientSalt)
// Only check the cache if findAccessKey succeeded and the salt is unrecognized.
if isServerSalt || !h.replayCache.Add(cipherEntry.ID, clientSalt) {
Expand All @@ -267,38 +264,39 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli
} else {
status = "ERR_REPLAY_CLIENT"
}
h.absorbProbe(listenerPort, clientConn, status, proxyMetrics)
logger.Debugf(status+": %v sent %d bytes", clientConn.RemoteAddr(), proxyMetrics.ClientProxy)
return id, onet.NewConnectionError(status, "Replay detected", nil)
return id, nil, onet.NewConnectionError(status, "Replay detected", nil)
}

// 3. Read target address and dial it.
ssr := shadowsocks.NewReader(clientReader, cipherEntry.CryptoKey)
tgtAddr, err := socks.ReadAddr(ssr)
ssw := shadowsocks.NewWriter(clientConn, cipherEntry.CryptoKey)
ssw.SetSaltGenerator(cipherEntry.SaltGenerator)
return id, transport.WrapConn(clientConn, ssr, ssw), nil
}

// Clear the deadline for the target address
clientConn.SetReadDeadline(time.Time{})
func getProxyRequest(clientConn transport.StreamConn) (string, error) {
// TODO(fortuna): Use Shadowsocks proxy, HTTP CONNECT or SOCKS5 based on first byte:
// case 1, 3 or 4: Shadowsocks (address type)
// case 5: SOCKS5 (protocol version)
// case "C": HTTP CONNECT (first char of method)
tgtAddr, err := socks.ReadAddr(clientConn)
if err != nil {
// Drain to prevent a close on cipher error.
io.Copy(io.Discard, clientConn)
return id, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", err)
return "", err
}
tgtConn, dialErr := h.dialer.DialStream(ctx, tgtAddr.String())
return tgtAddr.String(), nil
}

func proxyConnection(ctx context.Context, dialer transport.StreamDialer, tgtAddr string, clientConn transport.StreamConn) *onet.ConnectionError {
tgtConn, dialErr := dialer.DialStream(ctx, tgtAddr)
if dialErr != nil {
// We don't drain so dial errors and invalid addresses are communicated quickly.
return id, ensureConnectionError(dialErr, "ERR_CONNECT", "Failed to connect to target")
return ensureConnectionError(dialErr, "ERR_CONNECT", "Failed to connect to target")
}
tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy)
defer tgtConn.Close()

// 4. Bridge the client and target connections
logger.Debugf("proxy %s <-> %s", clientConn.RemoteAddr().String(), tgtConn.RemoteAddr().String())
ssw := shadowsocks.NewWriter(clientConn, cipherEntry.CryptoKey)
ssw.SetSaltGenerator(cipherEntry.SaltGenerator)

fromClientErrCh := make(chan error)
go func() {
_, fromClientErr := ssr.WriteTo(tgtConn)
_, fromClientErr := io.Copy(tgtConn, clientConn)
if fromClientErr != nil {
// Drain to prevent a close in the case of a cipher error.
io.Copy(io.Discard, clientConn)
Expand All @@ -310,19 +308,58 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli
tgtConn.CloseWrite()
fromClientErrCh <- fromClientErr
}()
_, fromTargetErr := ssw.ReadFrom(tgtConn)
_, fromTargetErr := io.Copy(clientConn, tgtConn)
// Send FIN to client.
clientConn.CloseWrite()
tgtConn.CloseRead()

fromClientErr := <-fromClientErrCh
if fromClientErr != nil {
return id, onet.NewConnectionError("ERR_RELAY_CLIENT", "Failed to relay traffic from client", fromClientErr)
return onet.NewConnectionError("ERR_RELAY_CLIENT", "Failed to relay traffic from client", fromClientErr)
}
if fromTargetErr != nil {
return id, onet.NewConnectionError("ERR_RELAY_TARGET", "Failed to relay traffic from target", fromTargetErr)
return onet.NewConnectionError("ERR_RELAY_TARGET", "Failed to relay traffic from target", fromTargetErr)
}
return id, nil
return nil
}

func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
// Set a deadline to receive the address to the target.
readDeadline := time.Now().Add(h.readTimeout)
if deadline, ok := ctx.Deadline(); ok {
outerConn.SetDeadline(deadline)
if deadline.Before(readDeadline) {
readDeadline = deadline
}
}
outerConn.SetReadDeadline(readDeadline)

id, innerConn, authErr := h.authenticate(outerConn, proxyMetrics)
if authErr != nil {
// Drain to protect against probing attacks.
h.absorbProbe(listenerPort, outerConn, authErr.Status, proxyMetrics)
return id, authErr
}

// Read target address and dial it.
tgtAddr, err := getProxyRequest(innerConn)
// Clear the deadline for the target address
outerConn.SetReadDeadline(time.Time{})
if err != nil {
// Drain to prevent a close on cipher error.
io.Copy(io.Discard, outerConn)
return id, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", err)
}

dialer := transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) {
tgtConn, err := h.dialer.DialStream(ctx, tgtAddr)
if err != nil {
return nil, err
}
tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy)
return tgtConn, nil
})
return id, proxyConnection(ctx, dialer, tgtAddr, innerConn)
}

// Keep the connection open until we hit the authentication deadline to protect against probing attacks
Expand Down
Loading