Skip to content

Commit

Permalink
refactor: pass a dialer to TCP serving (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna authored Mar 15, 2024
1 parent 6fc944e commit e7d30a0
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 66 deletions.
9 changes: 4 additions & 5 deletions internal/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
onet "github.com/Jigsaw-Code/outline-ss-server/net"
"github.com/Jigsaw-Code/outline-ss-server/service"
"github.com/Jigsaw-Code/outline-ss-server/service/metrics"
sstest "github.com/Jigsaw-Code/outline-ss-server/shadowsocks"
Expand All @@ -41,7 +40,7 @@ func init() {
logging.SetLevel(logging.INFO, "")
}

func allowAll(ip net.IP) *onet.ConnectionError {
func allowAll(ip net.IP) error {
// Allow access to localhost so that we can run integration tests with
// an actual destination server.
return nil
Expand Down Expand Up @@ -114,7 +113,7 @@ func TestTCPEcho(t *testing.T) {
replayCache := service.NewReplayCache(5)
const testTimeout = 200 * time.Millisecond
handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, &replayCache, &service.NoOpTCPMetrics{}, testTimeout)
handler.SetTargetIPValidator(allowAll)
handler.SetTargetDialer(&transport.TCPStreamDialer{})
done := make(chan struct{})
go func() {
service.StreamServe(func() (transport.StreamConn, error) { return proxyListener.AcceptTCP() }, handler.Handle)
Expand Down Expand Up @@ -362,7 +361,7 @@ func BenchmarkTCPThroughput(b *testing.B) {
}
const testTimeout = 200 * time.Millisecond
handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, nil, &service.NoOpTCPMetrics{}, testTimeout)
handler.SetTargetIPValidator(allowAll)
handler.SetTargetDialer(&transport.TCPStreamDialer{})
done := make(chan struct{})
go func() {
service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -424,7 +423,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) {
replayCache := service.NewReplayCache(service.MaxCapacity)
const testTimeout = 200 * time.Millisecond
handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, cipherList, &replayCache, &service.NoOpTCPMetrics{}, testTimeout)
handler.SetTargetIPValidator(allowAll)
handler.SetTargetDialer(&transport.TCPStreamDialer{})
done := make(chan struct{})
go func() {
service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle)
Expand Down
20 changes: 20 additions & 0 deletions net/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,23 @@ type ConnectionError struct {
func NewConnectionError(status, message string, cause error) *ConnectionError {
return &ConnectionError{Status: status, Message: message, Cause: cause}
}

func (e *ConnectionError) Error() string {
if e == nil {
return "<nil>"
}
msg := e.Message
if len(e.Status) > 0 {
msg += " [" + e.Status + "]"
}
if e.Cause != nil {
msg += ": " + e.Cause.Error()
}
return msg
}

func (e *ConnectionError) Unwrap() error {
return e.Cause
}

var _ error = (*ConnectionError)(nil)
49 changes: 49 additions & 0 deletions net/error_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright 2019 Jigsaw Operations LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package net

import (
"errors"
"fmt"
"testing"

"github.com/stretchr/testify/require"
)

func TestConnectionErrorUnwrapCause(t *testing.T) {
cause := errors.New("cause")
err := &ConnectionError{Cause: cause}
require.Equal(t, cause, err.Unwrap())
require.ErrorIs(t, err, cause)
}

func TestConnectionErrorString(t *testing.T) {
require.Equal(t, "example message", (&ConnectionError{Message: "example message"}).Error())
require.Equal(t, "example message [ERR_EXAMPLE]", (&ConnectionError{Message: "example message", Status: "ERR_EXAMPLE"}).Error())

cause := errors.New("cause")
err := &ConnectionError{Status: "ERR_EXAMPLE", Message: "example message", Cause: cause}
require.Equal(t, "example message [ERR_EXAMPLE]: cause", err.Error())
}

func TestConnectionErrorFromUnwrap(t *testing.T) {
connErr := &ConnectionError{Message: "connection error"}
topErr := fmt.Errorf("top error: %w", connErr)
require.NotEqual(t, topErr, connErr)
require.ErrorIs(t, topErr, connErr)
var unwrapped *ConnectionError
require.True(t, errors.As(topErr, &unwrapped))
require.Equal(t, connErr, unwrapped)
}
4 changes: 2 additions & 2 deletions net/private_net.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ func IsPrivateAddress(ip net.IP) bool {
}

// TargetIPValidator is a type alias for checking if an IP is allowed.
type TargetIPValidator = func(net.IP) *ConnectionError
type TargetIPValidator = func(net.IP) error

// RequirePublicIP returns an error if the destination IP is not a
// standard public IP.
func RequirePublicIP(ip net.IP) *ConnectionError {
func RequirePublicIP(ip net.IP) error {
if !ip.IsGlobalUnicast() {
return NewConnectionError("ERR_ADDRESS_INVALID", fmt.Sprintf("Address is not global unicast: %s", ip.String()), nil)
}
Expand Down
54 changes: 34 additions & 20 deletions net/private_net_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
package net

import (
"errors"
"net"
"testing"

"github.com/stretchr/testify/assert"
)

var privateAddressTests = []struct {
Expand Down Expand Up @@ -46,39 +49,50 @@ func TestIsLanAddress(t *testing.T) {
}

func TestRequirePublicIP(t *testing.T) {
if err := RequirePublicIP(net.ParseIP("8.8.8.8")); err != nil {
t.Error(err)
}
var err error

assert.Nil(t, RequirePublicIP(net.ParseIP("8.8.8.8")))

if err := RequirePublicIP(net.ParseIP("2001:4860:4860::8888")); err != nil {
t.Error(err)
}

err := RequirePublicIP(net.ParseIP("192.168.0.23"))
if err == nil {
t.Error("Expected error")
} else if err.Status != "ERR_ADDRESS_PRIVATE" {
t.Errorf("Wrong status %s", err.Status)
err = RequirePublicIP(net.ParseIP("192.168.0.23"))
if assert.NotNil(t, err) {
var connErr *ConnectionError
if assert.IsType(t, connErr, err) && assert.True(t, errors.As(err, &connErr)) {
assert.Equal(t, "ERR_ADDRESS_PRIVATE", connErr.Status)
}
}

err = RequirePublicIP(net.ParseIP("::1"))
if err == nil {
t.Error("Expected error")
} else if err.Status != "ERR_ADDRESS_INVALID" {
t.Errorf("Wrong status %s", err.Status)
if assert.NotNil(t, err) {
var connErr *ConnectionError
if assert.IsType(t, connErr, err) && assert.True(t, errors.As(err, &connErr)) {
assert.Equal(t, "ERR_ADDRESS_INVALID", connErr.Status)
}
}

err = RequirePublicIP(net.ParseIP("224.0.0.251"))
if err == nil {
t.Error("Expected error")
} else if err.Status != "ERR_ADDRESS_INVALID" {
t.Errorf("Wrong status %s", err.Status)
if assert.NotNil(t, err) {
var connErr *ConnectionError
if assert.IsType(t, connErr, err) && assert.True(t, errors.As(err, &connErr)) {
assert.Equal(t, "ERR_ADDRESS_INVALID", connErr.Status)
}
}

err = RequirePublicIP(net.ParseIP("ff02::fb"))
if err == nil {
t.Error("Expected error")
} else if err.Status != "ERR_ADDRESS_INVALID" {
t.Errorf("Wrong status %s", err.Status)
if assert.NotNil(t, err) {
var connErr *ConnectionError
if assert.IsType(t, connErr, err) && assert.True(t, errors.As(err, &connErr)) {
assert.Equal(t, "ERR_ADDRESS_INVALID", connErr.Status)
}
}
}

func TestRequirePublicIPInterface(t *testing.T) {
var err error
err = RequirePublicIP(net.ParseIP("8.8.8.8"))
assert.True(t, err == nil)
assert.Equal(t, nil, err)
}
68 changes: 35 additions & 33 deletions service/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,53 +124,53 @@ type tcpHandler struct {
m TCPMetrics
readTimeout time.Duration
// `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports.
replayCache *ReplayCache
targetIPValidator onet.TargetIPValidator
replayCache *ReplayCache
dialer transport.StreamDialer
}

// NewTCPService creates a TCPService
// `replayCache` is a pointer to SSServer.replayCache, to share the cache among all ports.
func NewTCPHandler(port int, ciphers CipherList, replayCache *ReplayCache, m TCPMetrics, timeout time.Duration) TCPHandler {
return &tcpHandler{
port: port,
ciphers: ciphers,
m: m,
readTimeout: timeout,
replayCache: replayCache,
targetIPValidator: onet.RequirePublicIP,
port: port,
ciphers: ciphers,
m: m,
readTimeout: timeout,
replayCache: replayCache,
dialer: defaultDialer,
}
}

var defaultDialer = makeValidatingTCPStreamDialer(onet.RequirePublicIP)

func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator) transport.StreamDialer {
return &transport.TCPStreamDialer{Dialer: net.Dialer{Control: func(network, address string, c syscall.RawConn) error {
ip, _, _ := net.SplitHostPort(address)
return targetIPValidator(net.ParseIP(ip))
}}}
}

// TCPService is a Shadowsocks TCP service that can be started and stopped.
type TCPHandler interface {
Handle(ctx context.Context, conn transport.StreamConn)
// SetTargetIPValidator sets the function to be used to validate the target IP addresses.
SetTargetIPValidator(targetIPValidator onet.TargetIPValidator)
// SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses.
SetTargetDialer(dialer transport.StreamDialer)
}

func (s *tcpHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) {
s.targetIPValidator = targetIPValidator
func (s *tcpHandler) SetTargetDialer(dialer transport.StreamDialer) {
s.dialer = dialer
}

func dialTarget(tgtAddr socks.Addr, proxyMetrics *metrics.ProxyMetrics, targetIPValidator onet.TargetIPValidator) (transport.StreamConn, *onet.ConnectionError) {
var ipError *onet.ConnectionError
dialer := net.Dialer{Control: func(network, address string, c syscall.RawConn) error {
ip, _, _ := net.SplitHostPort(address)
ipError = targetIPValidator(net.ParseIP(ip))
if ipError != nil {
return errors.New(ipError.Message)
}
func ensureConnectionError(err error, fallbackStatus string, fallbackMsg string) *onet.ConnectionError {
if err == nil {
return nil
}}
tgtConn, err := dialer.Dial("tcp", tgtAddr.String())
if ipError != nil {
return nil, ipError
} else if err != nil {
return nil, onet.NewConnectionError("ERR_CONNECT", "Failed to connect to target", err)
}
tgtTCPConn := tgtConn.(*net.TCPConn)
tgtTCPConn.SetKeepAlive(true)
return metrics.MeasureConn(tgtTCPConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy), nil
var connErr *onet.ConnectionError
if errors.As(err, &connErr) {
return connErr
} else {
return onet.NewConnectionError(fallbackStatus, fallbackMsg, err)
}
}

type StreamListener func() (transport.StreamConn, error)
Expand Down Expand Up @@ -226,7 +226,7 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn
measuredClientConn := metrics.MeasureConn(clientConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy)
connStart := time.Now()

id, connError := h.handleConnection(h.port, measuredClientConn, &proxyMetrics)
id, connError := h.handleConnection(ctx, h.port, measuredClientConn, &proxyMetrics)

connDuration := time.Since(connStart)
status := "OK"
Expand All @@ -239,7 +239,7 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn
logger.Debugf("Done with status %v, duration %v", status, connDuration)
}

func (h *tcpHandler) handleConnection(listenerPort int, clientConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, clientConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
// Set a deadline to receive the address to the target.
clientConn.SetReadDeadline(time.Now().Add(h.readTimeout))

Expand Down Expand Up @@ -275,18 +275,20 @@ func (h *tcpHandler) handleConnection(listenerPort int, clientConn transport.Str
// 3. Read target address and dial it.
ssr := shadowsocks.NewReader(clientReader, cipherEntry.CryptoKey)
tgtAddr, err := socks.ReadAddr(ssr)

// Clear the deadline for the target address
clientConn.SetReadDeadline(time.Time{})
if err != nil {
// Drain to prevent a close on cipher error.
io.Copy(io.Discard, clientConn)
return id, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", err)
}
tgtConn, dialErr := dialTarget(tgtAddr, proxyMetrics, h.targetIPValidator)
tgtConn, dialErr := h.dialer.Dial(ctx, tgtAddr.String())
if dialErr != nil {
// We don't drain so dial errors and invalid addresses are communicated quickly.
return id, dialErr
return id, ensureConnectionError(dialErr, "ERR_CONNECT", "Failed to connect to target")
}
tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy)
defer tgtConn.Close()

// 4. Bridge the client and target connections
Expand Down
9 changes: 4 additions & 5 deletions service/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
onet "github.com/Jigsaw-Code/outline-ss-server/net"
"github.com/Jigsaw-Code/outline-ss-server/service/metrics"
logging "github.com/op/go-logging"
"github.com/shadowsocks/go-shadowsocks2/socks"
Expand All @@ -39,7 +38,7 @@ func init() {
logging.SetLevel(logging.INFO, "")
}

func allowAll(ip net.IP) *onet.ConnectionError {
func allowAll(ip net.IP) error {
// Allow access to localhost so that we can run integration tests with
// an actual destination server.
return nil
Expand Down Expand Up @@ -353,7 +352,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) {
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond)
handler.SetTargetIPValidator(allowAll)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -388,7 +387,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) {
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond)
handler.SetTargetIPValidator(allowAll)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -424,7 +423,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) {
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, cipherList, nil, testMetrics, 200*time.Millisecond)
handler.SetTargetIPValidator(allowAll)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down
2 changes: 1 addition & 1 deletion service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func (h *packetHandler) validatePacket(textData []byte) ([]byte, *net.UDPAddr, *
return nil, nil, onet.NewConnectionError("ERR_RESOLVE_ADDRESS", fmt.Sprintf("Failed to resolve target address %v", tgtAddr), err)
}
if err := h.targetIPValidator(tgtUDPAddr.IP); err != nil {
return nil, nil, err
return nil, nil, ensureConnectionError(err, "ERR_ADDRESS_INVALID", "invalid address")
}

payload := textData[len(tgtAddr):]
Expand Down

0 comments on commit e7d30a0

Please sign in to comment.