From e7d30a055ec5c221b669f4dcd06426f0084d52fe Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Fri, 15 Mar 2024 10:21:09 -0400 Subject: [PATCH] refactor: pass a dialer to TCP serving (#150) --- internal/integration_test/integration_test.go | 9 ++- net/error.go | 20 ++++++ net/error_test.go | 49 +++++++++++++ net/private_net.go | 4 +- net/private_net_test.go | 54 +++++++++------ service/tcp.go | 68 ++++++++++--------- service/tcp_test.go | 9 ++- service/udp.go | 2 +- 8 files changed, 149 insertions(+), 66 deletions(-) create mode 100644 net/error_test.go diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index ecb5ee2b..66ae5cce 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -26,7 +26,6 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" "github.com/Jigsaw-Code/outline-ss-server/ipinfo" - onet "github.com/Jigsaw-Code/outline-ss-server/net" "github.com/Jigsaw-Code/outline-ss-server/service" "github.com/Jigsaw-Code/outline-ss-server/service/metrics" sstest "github.com/Jigsaw-Code/outline-ss-server/shadowsocks" @@ -41,7 +40,7 @@ func init() { logging.SetLevel(logging.INFO, "") } -func allowAll(ip net.IP) *onet.ConnectionError { +func allowAll(ip net.IP) error { // Allow access to localhost so that we can run integration tests with // an actual destination server. return nil @@ -114,7 +113,7 @@ func TestTCPEcho(t *testing.T) { replayCache := service.NewReplayCache(5) const testTimeout = 200 * time.Millisecond handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, &replayCache, &service.NoOpTCPMetrics{}, testTimeout) - handler.SetTargetIPValidator(allowAll) + handler.SetTargetDialer(&transport.TCPStreamDialer{}) done := make(chan struct{}) go func() { service.StreamServe(func() (transport.StreamConn, error) { return proxyListener.AcceptTCP() }, handler.Handle) @@ -362,7 +361,7 @@ func BenchmarkTCPThroughput(b *testing.B) { } const testTimeout = 200 * time.Millisecond handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, nil, &service.NoOpTCPMetrics{}, testTimeout) - handler.SetTargetIPValidator(allowAll) + handler.SetTargetDialer(&transport.TCPStreamDialer{}) done := make(chan struct{}) go func() { service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) @@ -424,7 +423,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) { replayCache := service.NewReplayCache(service.MaxCapacity) const testTimeout = 200 * time.Millisecond handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, &replayCache, &service.NoOpTCPMetrics{}, testTimeout) - handler.SetTargetIPValidator(allowAll) + handler.SetTargetDialer(&transport.TCPStreamDialer{}) done := make(chan struct{}) go func() { service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) diff --git a/net/error.go b/net/error.go index 5a160262..7d888fa4 100644 --- a/net/error.go +++ b/net/error.go @@ -24,3 +24,23 @@ type ConnectionError struct { func NewConnectionError(status, message string, cause error) *ConnectionError { return &ConnectionError{Status: status, Message: message, Cause: cause} } + +func (e *ConnectionError) Error() string { + if e == nil { + return "" + } + msg := e.Message + if len(e.Status) > 0 { + msg += " [" + e.Status + "]" + } + if e.Cause != nil { + msg += ": " + e.Cause.Error() + } + return msg +} + +func (e *ConnectionError) Unwrap() error { + return e.Cause +} + +var _ error = (*ConnectionError)(nil) diff --git a/net/error_test.go b/net/error_test.go new file mode 100644 index 00000000..80c6b7e0 --- /dev/null +++ b/net/error_test.go @@ -0,0 +1,49 @@ +// Copyright 2019 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConnectionErrorUnwrapCause(t *testing.T) { + cause := errors.New("cause") + err := &ConnectionError{Cause: cause} + require.Equal(t, cause, err.Unwrap()) + require.ErrorIs(t, err, cause) +} + +func TestConnectionErrorString(t *testing.T) { + require.Equal(t, "example message", (&ConnectionError{Message: "example message"}).Error()) + require.Equal(t, "example message [ERR_EXAMPLE]", (&ConnectionError{Message: "example message", Status: "ERR_EXAMPLE"}).Error()) + + cause := errors.New("cause") + err := &ConnectionError{Status: "ERR_EXAMPLE", Message: "example message", Cause: cause} + require.Equal(t, "example message [ERR_EXAMPLE]: cause", err.Error()) +} + +func TestConnectionErrorFromUnwrap(t *testing.T) { + connErr := &ConnectionError{Message: "connection error"} + topErr := fmt.Errorf("top error: %w", connErr) + require.NotEqual(t, topErr, connErr) + require.ErrorIs(t, topErr, connErr) + var unwrapped *ConnectionError + require.True(t, errors.As(topErr, &unwrapped)) + require.Equal(t, connErr, unwrapped) +} diff --git a/net/private_net.go b/net/private_net.go index 3e10bf56..eb69cbe2 100644 --- a/net/private_net.go +++ b/net/private_net.go @@ -48,11 +48,11 @@ func IsPrivateAddress(ip net.IP) bool { } // TargetIPValidator is a type alias for checking if an IP is allowed. -type TargetIPValidator = func(net.IP) *ConnectionError +type TargetIPValidator = func(net.IP) error // RequirePublicIP returns an error if the destination IP is not a // standard public IP. -func RequirePublicIP(ip net.IP) *ConnectionError { +func RequirePublicIP(ip net.IP) error { if !ip.IsGlobalUnicast() { return NewConnectionError("ERR_ADDRESS_INVALID", fmt.Sprintf("Address is not global unicast: %s", ip.String()), nil) } diff --git a/net/private_net_test.go b/net/private_net_test.go index 9156e9b4..518192ce 100644 --- a/net/private_net_test.go +++ b/net/private_net_test.go @@ -15,8 +15,11 @@ package net import ( + "errors" "net" "testing" + + "github.com/stretchr/testify/assert" ) var privateAddressTests = []struct { @@ -46,39 +49,50 @@ func TestIsLanAddress(t *testing.T) { } func TestRequirePublicIP(t *testing.T) { - if err := RequirePublicIP(net.ParseIP("8.8.8.8")); err != nil { - t.Error(err) - } + var err error + + assert.Nil(t, RequirePublicIP(net.ParseIP("8.8.8.8"))) if err := RequirePublicIP(net.ParseIP("2001:4860:4860::8888")); err != nil { t.Error(err) } - err := RequirePublicIP(net.ParseIP("192.168.0.23")) - if err == nil { - t.Error("Expected error") - } else if err.Status != "ERR_ADDRESS_PRIVATE" { - t.Errorf("Wrong status %s", err.Status) + err = RequirePublicIP(net.ParseIP("192.168.0.23")) + if assert.NotNil(t, err) { + var connErr *ConnectionError + if assert.IsType(t, connErr, err) && assert.True(t, errors.As(err, &connErr)) { + assert.Equal(t, "ERR_ADDRESS_PRIVATE", connErr.Status) + } } err = RequirePublicIP(net.ParseIP("::1")) - if err == nil { - t.Error("Expected error") - } else if err.Status != "ERR_ADDRESS_INVALID" { - t.Errorf("Wrong status %s", err.Status) + if assert.NotNil(t, err) { + var connErr *ConnectionError + if assert.IsType(t, connErr, err) && assert.True(t, errors.As(err, &connErr)) { + assert.Equal(t, "ERR_ADDRESS_INVALID", connErr.Status) + } } err = RequirePublicIP(net.ParseIP("224.0.0.251")) - if err == nil { - t.Error("Expected error") - } else if err.Status != "ERR_ADDRESS_INVALID" { - t.Errorf("Wrong status %s", err.Status) + if assert.NotNil(t, err) { + var connErr *ConnectionError + if assert.IsType(t, connErr, err) && assert.True(t, errors.As(err, &connErr)) { + assert.Equal(t, "ERR_ADDRESS_INVALID", connErr.Status) + } } err = RequirePublicIP(net.ParseIP("ff02::fb")) - if err == nil { - t.Error("Expected error") - } else if err.Status != "ERR_ADDRESS_INVALID" { - t.Errorf("Wrong status %s", err.Status) + if assert.NotNil(t, err) { + var connErr *ConnectionError + if assert.IsType(t, connErr, err) && assert.True(t, errors.As(err, &connErr)) { + assert.Equal(t, "ERR_ADDRESS_INVALID", connErr.Status) + } } } + +func TestRequirePublicIPInterface(t *testing.T) { + var err error + err = RequirePublicIP(net.ParseIP("8.8.8.8")) + assert.True(t, err == nil) + assert.Equal(t, nil, err) +} diff --git a/service/tcp.go b/service/tcp.go index d295931b..70fd8bbc 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -124,53 +124,53 @@ type tcpHandler struct { m TCPMetrics readTimeout time.Duration // `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports. - replayCache *ReplayCache - targetIPValidator onet.TargetIPValidator + replayCache *ReplayCache + dialer transport.StreamDialer } // NewTCPService creates a TCPService // `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports. func NewTCPHandler(port int, ciphers CipherList, replayCache *ReplayCache, m TCPMetrics, timeout time.Duration) TCPHandler { return &tcpHandler{ - port: port, - ciphers: ciphers, - m: m, - readTimeout: timeout, - replayCache: replayCache, - targetIPValidator: onet.RequirePublicIP, + port: port, + ciphers: ciphers, + m: m, + readTimeout: timeout, + replayCache: replayCache, + dialer: defaultDialer, } } +var defaultDialer = makeValidatingTCPStreamDialer(onet.RequirePublicIP) + +func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator) transport.StreamDialer { + return &transport.TCPStreamDialer{Dialer: net.Dialer{Control: func(network, address string, c syscall.RawConn) error { + ip, _, _ := net.SplitHostPort(address) + return targetIPValidator(net.ParseIP(ip)) + }}} +} + // TCPService is a Shadowsocks TCP service that can be started and stopped. type TCPHandler interface { Handle(ctx context.Context, conn transport.StreamConn) - // SetTargetIPValidator sets the function to be used to validate the target IP addresses. - SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) + // SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses. + SetTargetDialer(dialer transport.StreamDialer) } -func (s *tcpHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) { - s.targetIPValidator = targetIPValidator +func (s *tcpHandler) SetTargetDialer(dialer transport.StreamDialer) { + s.dialer = dialer } -func dialTarget(tgtAddr socks.Addr, proxyMetrics *metrics.ProxyMetrics, targetIPValidator onet.TargetIPValidator) (transport.StreamConn, *onet.ConnectionError) { - var ipError *onet.ConnectionError - dialer := net.Dialer{Control: func(network, address string, c syscall.RawConn) error { - ip, _, _ := net.SplitHostPort(address) - ipError = targetIPValidator(net.ParseIP(ip)) - if ipError != nil { - return errors.New(ipError.Message) - } +func ensureConnectionError(err error, fallbackStatus string, fallbackMsg string) *onet.ConnectionError { + if err == nil { return nil - }} - tgtConn, err := dialer.Dial("tcp", tgtAddr.String()) - if ipError != nil { - return nil, ipError - } else if err != nil { - return nil, onet.NewConnectionError("ERR_CONNECT", "Failed to connect to target", err) } - tgtTCPConn := tgtConn.(*net.TCPConn) - tgtTCPConn.SetKeepAlive(true) - return metrics.MeasureConn(tgtTCPConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy), nil + var connErr *onet.ConnectionError + if errors.As(err, &connErr) { + return connErr + } else { + return onet.NewConnectionError(fallbackStatus, fallbackMsg, err) + } } type StreamListener func() (transport.StreamConn, error) @@ -226,7 +226,7 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn measuredClientConn := metrics.MeasureConn(clientConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy) connStart := time.Now() - id, connError := h.handleConnection(h.port, measuredClientConn, &proxyMetrics) + id, connError := h.handleConnection(ctx, h.port, measuredClientConn, &proxyMetrics) connDuration := time.Since(connStart) status := "OK" @@ -239,7 +239,7 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn logger.Debugf("Done with status %v, duration %v", status, connDuration) } -func (h *tcpHandler) handleConnection(listenerPort int, clientConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) { +func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, clientConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) { // Set a deadline to receive the address to the target. clientConn.SetReadDeadline(time.Now().Add(h.readTimeout)) @@ -275,6 +275,7 @@ func (h *tcpHandler) handleConnection(listenerPort int, clientConn transport.Str // 3. Read target address and dial it. ssr := shadowsocks.NewReader(clientReader, cipherEntry.CryptoKey) tgtAddr, err := socks.ReadAddr(ssr) + // Clear the deadline for the target address clientConn.SetReadDeadline(time.Time{}) if err != nil { @@ -282,11 +283,12 @@ func (h *tcpHandler) handleConnection(listenerPort int, clientConn transport.Str io.Copy(io.Discard, clientConn) return id, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", err) } - tgtConn, dialErr := dialTarget(tgtAddr, proxyMetrics, h.targetIPValidator) + tgtConn, dialErr := h.dialer.Dial(ctx, tgtAddr.String()) if dialErr != nil { // We don't drain so dial errors and invalid addresses are communicated quickly. - return id, dialErr + return id, ensureConnectionError(dialErr, "ERR_CONNECT", "Failed to connect to target") } + tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy) defer tgtConn.Close() // 4. Bridge the client and target connections diff --git a/service/tcp_test.go b/service/tcp_test.go index 1f381111..32ff20f0 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -28,7 +28,6 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" "github.com/Jigsaw-Code/outline-ss-server/ipinfo" - onet "github.com/Jigsaw-Code/outline-ss-server/net" "github.com/Jigsaw-Code/outline-ss-server/service/metrics" logging "github.com/op/go-logging" "github.com/shadowsocks/go-shadowsocks2/socks" @@ -39,7 +38,7 @@ func init() { logging.SetLevel(logging.INFO, "") } -func allowAll(ip net.IP) *onet.ConnectionError { +func allowAll(ip net.IP) error { // Allow access to localhost so that we can run integration tests with // an actual destination server. return nil @@ -353,7 +352,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond) - handler.SetTargetIPValidator(allowAll) + handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) @@ -388,7 +387,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond) - handler.SetTargetIPValidator(allowAll) + handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) @@ -424,7 +423,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond) - handler.SetTargetIPValidator(allowAll) + handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) diff --git a/service/udp.go b/service/udp.go index 4138ed81..5ac3880b 100644 --- a/service/udp.go +++ b/service/udp.go @@ -230,7 +230,7 @@ func (h *packetHandler) validatePacket(textData []byte) ([]byte, *net.UDPAddr, * return nil, nil, onet.NewConnectionError("ERR_RESOLVE_ADDRESS", fmt.Sprintf("Failed to resolve target address %v", tgtAddr), err) } if err := h.targetIPValidator(tgtUDPAddr.IP); err != nil { - return nil, nil, err + return nil, nil, ensureConnectionError(err, "ERR_ADDRESS_INVALID", "invalid address") } payload := textData[len(tgtAddr):]