diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index d85fd72a..0ed7295d 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -24,11 +24,13 @@ import ( type ServiceConfig struct { Listeners []ListenerConfig Keys []KeyConfig + Dialer DialerConfig } type ListenerType string const listenerTypeTCP ListenerType = "tcp" + const listenerTypeUDP ListenerType = "udp" type ListenerConfig struct { @@ -36,6 +38,10 @@ type ListenerConfig struct { Address string } +type DialerConfig struct { + Fwmark uint +} + type KeyConfig struct { ID string Cipher string diff --git a/cmd/outline-ss-server/config_example.yml b/cmd/outline-ss-server/config_example.yml index 7131c4cb..8dddbb6f 100644 --- a/cmd/outline-ss-server/config_example.yml +++ b/cmd/outline-ss-server/config_example.yml @@ -21,19 +21,22 @@ services: - type: udp address: "[::]:9000" keys: - - id: user-0 - cipher: chacha20-ietf-poly1305 - secret: Secret0 - - id: user-1 - cipher: chacha20-ietf-poly1305 - secret: Secret1 - + - id: user-0 + cipher: chacha20-ietf-poly1305 + secret: Secret0 + - id: user-1 + cipher: chacha20-ietf-poly1305 + secret: Secret1 + dialer: + # fwmark can be used in conjunction with other Linux networking features like cgroups, network namespaces, and TC (Traffic Control) for sophisticated network management. + # Value of 0 disables fwmark (SO_MARK) (Linux Only) + fwmark: 0 - listeners: - type: tcp address: "[::]:9001" - type: udp address: "[::]:9001" keys: - - id: user-2 - cipher: chacha20-ietf-poly1305 - secret: Secret2 + - id: user-2 + cipher: chacha20-ietf-poly1305 + secret: Secret2 diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index b497a0b0..ae20648b 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -28,17 +28,21 @@ import ( "time" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" - "github.com/Jigsaw-Code/outline-ss-server/ipinfo" - outline_prometheus "github.com/Jigsaw-Code/outline-ss-server/prometheus" - "github.com/Jigsaw-Code/outline-ss-server/service" "github.com/lmittmann/tint" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/term" + + "github.com/Jigsaw-Code/outline-ss-server/ipinfo" + onet "github.com/Jigsaw-Code/outline-ss-server/net" + outline_prometheus "github.com/Jigsaw-Code/outline-ss-server/prometheus" + "github.com/Jigsaw-Code/outline-ss-server/service" ) -var logLevel = new(slog.LevelVar) // Info by default -var logHandler slog.Handler +var ( + logLevel = new(slog.LevelVar) // Info by default + logHandler slog.Handler +) // Set by goreleaser default ldflags. See https://goreleaser.com/customization/build/ var version = "dev" @@ -221,9 +225,10 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { ssService, err := service.NewShadowsocksService( service.WithCiphers(ciphers), - service.WithNatTimeout(s.natTimeout), service.WithMetrics(s.serviceMetrics), service.WithReplayCache(&s.replayCache), + service.WithStreamDialer(service.MakeValidatingTCPStreamDialer(onet.RequirePublicIP, 0)), + service.WithPacketListener(service.MakeTargetUDPListener(s.natTimeout, 0)), service.WithLogger(slog.Default()), ) ln, err := lnSet.ListenStream(addr) @@ -248,9 +253,10 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { } ssService, err := service.NewShadowsocksService( service.WithCiphers(ciphers), - service.WithNatTimeout(s.natTimeout), service.WithMetrics(s.serviceMetrics), service.WithReplayCache(&s.replayCache), + service.WithStreamDialer(service.MakeValidatingTCPStreamDialer(onet.RequirePublicIP, serviceConfig.Dialer.Fwmark)), + service.WithPacketListener(service.MakeTargetUDPListener(s.natTimeout, serviceConfig.Dialer.Fwmark)), service.WithLogger(slog.Default()), ) if err != nil { @@ -263,14 +269,24 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { if err != nil { return err } - slog.Info("TCP service started.", "address", ln.Addr().String()) + slog.Info("TCP service started.", "address", ln.Addr().String(), "fwmark", func() any { + if serviceConfig.Dialer.Fwmark == 0 { + return "disabled" + } + return serviceConfig.Dialer.Fwmark + }()) go service.StreamServe(ln.AcceptStream, ssService.HandleStream) case listenerTypeUDP: pc, err := lnSet.ListenPacket(lnConfig.Address) if err != nil { return err } - slog.Info("UDP service started.", "address", pc.LocalAddr().String()) + slog.Info("UDP service started.", "address", pc.LocalAddr().String(), "fwmark", func() any { + if serviceConfig.Dialer.Fwmark == 0 { + return "disabled" + } + return serviceConfig.Dialer.Fwmark + }()) go service.PacketServe(pc, ssService.NewAssociation, s.serverMetrics) } } diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index cfba8b0a..680c7ad7 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -317,7 +317,7 @@ func TestUDPEcho(t *testing.T) { if err != nil { t.Fatal(err) } - proxy := service.NewPacketHandler(time.Hour, cipherList, &fakeShadowsocksMetrics{}) + proxy := service.NewPacketHandler(cipherList, &fakeShadowsocksMetrics{}) proxy.SetTargetIPValidator(allowAll) natMetrics := &natTestMetrics{} @@ -545,7 +545,7 @@ func BenchmarkUDPEcho(b *testing.B) { if err != nil { b.Fatal(err) } - proxy := service.NewPacketHandler(time.Hour, cipherList, &fakeShadowsocksMetrics{}) + proxy := service.NewPacketHandler(cipherList, &fakeShadowsocksMetrics{}) proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { @@ -591,7 +591,7 @@ func BenchmarkUDPManyKeys(b *testing.B) { if err != nil { b.Fatal(err) } - proxy := service.NewPacketHandler(time.Hour, cipherList, &fakeShadowsocksMetrics{}) + proxy := service.NewPacketHandler(cipherList, &fakeShadowsocksMetrics{}) proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { diff --git a/service/shadowsocks.go b/service/shadowsocks.go index 112fb08b..2a80ea83 100644 --- a/service/shadowsocks.go +++ b/service/shadowsocks.go @@ -21,14 +21,13 @@ import ( "time" "github.com/Jigsaw-Code/outline-sdk/transport" + + onet "github.com/Jigsaw-Code/outline-ss-server/net" ) const ( // 59 seconds is most common timeout for servers that do not respond to invalid requests tcpReadTimeout time.Duration = 59 * time.Second - - // A UDP NAT timeout of at least 5 minutes is recommended in RFC 4787 Section 4.3. - defaultNatTimeout time.Duration = 5 * time.Minute ) // ShadowsocksConnMetrics is used to report Shadowsocks related metrics on connections. @@ -51,14 +50,16 @@ type Service interface { type Option func(s *ssService) type ssService struct { - logger *slog.Logger - metrics ServiceMetrics - ciphers CipherList - natTimeout time.Duration - replayCache *ReplayCache - - sh StreamHandler - ph PacketHandler + logger *slog.Logger + metrics ServiceMetrics + ciphers CipherList + targetIPValidator onet.TargetIPValidator + replayCache *ReplayCache + + streamDialer transport.StreamDialer + sh StreamHandler + packetListener transport.PacketListener + ph PacketHandler } // NewShadowsocksService creates a new Shadowsocks service. @@ -69,10 +70,6 @@ func NewShadowsocksService(opts ...Option) (Service, error) { opt(s) } - // If no NAT timeout is provided via options, use the recommended default. - if s.natTimeout == 0 { - s.natTimeout = defaultNatTimeout - } // If no logger is provided via options, use a noop logger. if s.logger == nil { s.logger = noopLogger() @@ -83,9 +80,15 @@ func NewShadowsocksService(opts ...Option) (Service, error) { NewShadowsocksStreamAuthenticator(s.ciphers, s.replayCache, &ssConnMetrics{ServiceMetrics: s.metrics, proto: "tcp"}, s.logger), tcpReadTimeout, ) + if s.streamDialer != nil { + s.sh.SetTargetDialer(s.streamDialer) + } s.sh.SetLogger(s.logger) - s.ph = NewPacketHandler(s.natTimeout, s.ciphers, &ssConnMetrics{ServiceMetrics: s.metrics, proto: "udp"}) + s.ph = NewPacketHandler(s.ciphers, &ssConnMetrics{ServiceMetrics: s.metrics, proto: "udp"}) + if s.packetListener != nil { + s.ph.SetTargetPacketListener(s.packetListener) + } s.ph.SetLogger(s.logger) return s, nil @@ -120,10 +123,17 @@ func WithReplayCache(replayCache *ReplayCache) Option { } } -// WithNatTimeout option function. -func WithNatTimeout(natTimeout time.Duration) Option { +// WithStreamDialer option function. +func WithStreamDialer(dialer transport.StreamDialer) Option { + return func(s *ssService) { + s.streamDialer = dialer + } +} + +// WithPacketListener option function. +func WithPacketListener(listener transport.PacketListener) Option { return func(s *ssService) { - s.natTimeout = natTimeout + s.packetListener = listener } } diff --git a/service/socketopts_linux.go b/service/socketopts_linux.go new file mode 100644 index 00000000..c6d807ca --- /dev/null +++ b/service/socketopts_linux.go @@ -0,0 +1,33 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux + +package service + +import ( + "os" + "syscall" +) + +func SetFwmark(rc syscall.RawConn, fwmark uint) error { + var err error + rc.Control(func(fd uintptr) { + err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, int(fwmark)) + }) + if err != nil { + return os.NewSyscallError("failed to set fwmark for socket", err) + } + return nil +} diff --git a/service/tcp.go b/service/tcp.go index 7823e314..5cb1b864 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -25,14 +25,14 @@ import ( "net" "net/netip" "sync" - "syscall" "time" "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" + "github.com/shadowsocks/go-shadowsocks2/socks" + onet "github.com/Jigsaw-Code/outline-ss-server/net" "github.com/Jigsaw-Code/outline-ss-server/service/metrics" - "github.com/shadowsocks/go-shadowsocks2/socks" ) // TCPConnMetrics is used to report metrics on TCP connections. @@ -172,19 +172,10 @@ func NewStreamHandler(authenticate StreamAuthenticateFunc, timeout time.Duration logger: noopLogger(), readTimeout: timeout, authenticate: authenticate, - dialer: defaultDialer, + dialer: MakeValidatingTCPStreamDialer(onet.RequirePublicIP, 0), } } -var defaultDialer = makeValidatingTCPStreamDialer(onet.RequirePublicIP) - -func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator) transport.StreamDialer { - return &transport.TCPDialer{Dialer: net.Dialer{Control: func(network, address string, c syscall.RawConn) error { - ip, _, _ := net.SplitHostPort(address) - return targetIPValidator(net.ParseIP(ip)) - }}} -} - // StreamHandler is a handler that handles stream connections. type StreamHandler interface { Handle(ctx context.Context, conn transport.StreamConn, connMetrics TCPConnMetrics) @@ -399,6 +390,8 @@ type NoOpTCPConnMetrics struct{} var _ TCPConnMetrics = (*NoOpTCPConnMetrics)(nil) func (m *NoOpTCPConnMetrics) AddAuthenticated(accessKey string) {} + func (m *NoOpTCPConnMetrics) AddClosed(status string, data metrics.ProxyMetrics, duration time.Duration) { } + func (m *NoOpTCPConnMetrics) AddProbe(status, drainResult string, clientProxyBytes int64) {} diff --git a/service/tcp_linux.go b/service/tcp_linux.go new file mode 100644 index 00000000..82c63870 --- /dev/null +++ b/service/tcp_linux.go @@ -0,0 +1,40 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux + +package service + +import ( + "net" + "syscall" + + "github.com/Jigsaw-Code/outline-sdk/transport" + + onet "github.com/Jigsaw-Code/outline-ss-server/net" +) + +// fwmark can be used in conjunction with other Linux networking features like cgroups, network namespaces, and TC (Traffic Control) for sophisticated network management. +// Value of 0 disables fwmark (SO_MARK) (Linux Only) +func MakeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator, fwmark uint) transport.StreamDialer { + return &transport.TCPDialer{Dialer: net.Dialer{Control: func(network, address string, c syscall.RawConn) error { + if fwmark > 0 { + if err := SetFwmark(c, fwmark); err != nil { + return err + } + } + ip, _, _ := net.SplitHostPort(address) + return targetIPValidator(net.ParseIP(ip)) + }}} +} diff --git a/service/tcp_other.go b/service/tcp_other.go new file mode 100644 index 00000000..c86bb07c --- /dev/null +++ b/service/tcp_other.go @@ -0,0 +1,38 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !linux + +package service + +import ( + "net" + "syscall" + + "github.com/Jigsaw-Code/outline-sdk/transport" + + onet "github.com/Jigsaw-Code/outline-ss-server/net" +) + +// fwmark can be used in conjunction with other Linux networking features like cgroups, network namespaces, and TC (Traffic Control) for sophisticated network management. +// Value of 0 disables fwmark (SO_MARK) (Linux Only) +func MakeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator, fwmark uint) transport.StreamDialer { + if fwmark != 0 { + panic("fwmark is linux-specific feature and should be 0") + } + return &transport.TCPDialer{Dialer: net.Dialer{Control: func(network, address string, c syscall.RawConn) error { + ip, _, _ := net.SplitHostPort(address) + return targetIPValidator(net.ParseIP(ip)) + }}} +} diff --git a/service/tcp_test.go b/service/tcp_test.go index e69d1d1e..ab497f91 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -29,10 +29,11 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" - "github.com/Jigsaw-Code/outline-ss-server/service/metrics" logging "github.com/op/go-logging" "github.com/shadowsocks/go-shadowsocks2/socks" "github.com/stretchr/testify/require" + + "github.com/Jigsaw-Code/outline-ss-server/service/metrics" ) func init() { @@ -215,8 +216,7 @@ func BenchmarkTCPFindCipherRepeat(b *testing.B) { } // Stub implementation for shadowsocks authentication metrics. -type fakeShadowsocksMetrics struct { -} +type fakeShadowsocksMetrics struct{} var _ ShadowsocksConnMetrics = (*fakeShadowsocksMetrics)(nil) @@ -232,6 +232,7 @@ type probeTestMetrics struct { } var _ TCPConnMetrics = (*probeTestMetrics)(nil) + var _ ShadowsocksConnMetrics = (*fakeShadowsocksMetrics)(nil) func (m *probeTestMetrics) AddClosed(status string, data metrics.ProxyMetrics, duration time.Duration) { @@ -367,7 +368,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil) handler := NewStreamHandler(authFunc, 200*time.Millisecond) - handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) + handler.SetTargetDialer(MakeValidatingTCPStreamDialer(allowAll, 0)) done := make(chan struct{}) go func() { StreamServe( @@ -405,7 +406,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) { testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil) handler := NewStreamHandler(authFunc, 200*time.Millisecond) - handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) + handler.SetTargetDialer(MakeValidatingTCPStreamDialer(allowAll, 0)) done := make(chan struct{}) go func() { StreamServe( @@ -444,7 +445,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil) handler := NewStreamHandler(authFunc, 200*time.Millisecond) - handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) + handler.SetTargetDialer(MakeValidatingTCPStreamDialer(allowAll, 0)) done := make(chan struct{}) go func() { StreamServe( diff --git a/service/udp.go b/service/udp.go index 2caf9205..3c2f89c6 100644 --- a/service/udp.go +++ b/service/udp.go @@ -15,6 +15,7 @@ package service import ( + "context" "errors" "fmt" "log/slog" @@ -24,6 +25,7 @@ import ( "sync" "time" + "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" "github.com/shadowsocks/go-shadowsocks2/socks" @@ -45,8 +47,13 @@ type UDPAssocationMetrics interface { AddClosed() } -// Max UDP buffer size for the server code. -const serverUDPBufferSize = 64 * 1024 +const ( + // A UDP NAT timeout of at least 5 minutes is recommended in RFC 4787 Section 4.3. + defaultNatTimeout time.Duration = 5 * time.Minute + + // Max UDP buffer size for the server code. + serverUDPBufferSize = 64 * 1024 +) // Buffer pool used for reading UDP packets. var readBufPool = slicepool.MakePool(serverUDPBufferSize) @@ -88,13 +95,13 @@ type packetHandler struct { ciphers CipherList ssm ShadowsocksConnMetrics targetIPValidator onet.TargetIPValidator - targetConnFactory func() (net.PacketConn, error) + targetListener transport.PacketListener } var _ PacketHandler = (*packetHandler)(nil) -// NewPacketHandler creates an PacketHandler -func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, ssMetrics ShadowsocksConnMetrics) PacketHandler { +// NewPacketHandler creates a PacketHandler +func NewPacketHandler(cipherList CipherList, ssMetrics ShadowsocksConnMetrics) PacketHandler { if ssMetrics == nil { ssMetrics = &NoOpShadowsocksConnMetrics{} } @@ -104,16 +111,7 @@ func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, ssMetrics ciphers: cipherList, ssm: ssMetrics, targetIPValidator: onet.RequirePublicIP, - targetConnFactory: func() (net.PacketConn, error) { - pc, err := net.ListenPacket("udp", "") - if err != nil { - return nil, fmt.Errorf("failed to create UDP socket: %v", err) - } - return &timedPacketConn{ - PacketConn: pc, - defaultTimeout: natTimeout, - }, nil - }, + targetListener: MakeTargetUDPListener(defaultNatTimeout, 0), } } @@ -123,8 +121,8 @@ type PacketHandler interface { SetLogger(l *slog.Logger) // SetTargetIPValidator sets the function to be used to validate the target IP addresses. SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) - // SetTargetConnFactory sets the function to be used to create new target connections. - SetTargetConnFactory(factory func() (net.PacketConn, error)) + // SetTargetPacketListener sets the packet listener to use for target connections. + SetTargetPacketListener(targetListener transport.PacketListener) // NewAssociation creates a new Association instance. NewAssociation(conn net.Conn, connMetrics UDPAssocationMetrics) (Association, error) } @@ -140,17 +138,17 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali h.targetIPValidator = targetIPValidator } -func (h *packetHandler) SetTargetConnFactory(factory func() (net.PacketConn, error)) { - h.targetConnFactory = factory +func (h *packetHandler) SetTargetPacketListener(targetListener transport.PacketListener) { + h.targetListener = targetListener } func (h *packetHandler) NewAssociation(conn net.Conn, m UDPAssocationMetrics) (Association, error) { if m == nil { - m = &NoOpUDPAssocationMetrics{} + m = &NoOpUDPAssociationMetrics{} } // Create the target connection - targetConn, err := h.targetConnFactory() + targetConn, err := h.targetListener.ListenPacket(context.Background()) if err != nil { return nil, fmt.Errorf("failed to create target connection: %w", err) } @@ -602,17 +600,17 @@ func (a *association) timedCopy() { } } -// NoOpUDPAssocationMetrics is a [UDPAssocationMetrics] that doesn't do anything. Useful in tests +// NoOpUDPAssociationMetrics is a [UDPAssocationMetrics] that doesn't do anything. Useful in tests // or if you don't want to track metrics. -type NoOpUDPAssocationMetrics struct{} +type NoOpUDPAssociationMetrics struct{} -var _ UDPAssocationMetrics = (*NoOpUDPAssocationMetrics)(nil) +var _ UDPAssocationMetrics = (*NoOpUDPAssociationMetrics)(nil) -func (m *NoOpUDPAssocationMetrics) AddAuthenticated(accessKey string) {} +func (m *NoOpUDPAssociationMetrics) AddAuthenticated(accessKey string) {} -func (m *NoOpUDPAssocationMetrics) AddPacketFromClient(status string, clientProxyBytes, proxyTargetBytes int64) { +func (m *NoOpUDPAssociationMetrics) AddPacketFromClient(status string, clientProxyBytes, proxyTargetBytes int64) { } -func (m *NoOpUDPAssocationMetrics) AddPacketFromTarget(status string, targetProxyBytes, proxyClientBytes int64) { +func (m *NoOpUDPAssociationMetrics) AddPacketFromTarget(status string, targetProxyBytes, proxyClientBytes int64) { } -func (m *NoOpUDPAssocationMetrics) AddClosed() { +func (m *NoOpUDPAssociationMetrics) AddClosed() { } diff --git a/service/udp_linux.go b/service/udp_linux.go new file mode 100644 index 00000000..f0b8bad4 --- /dev/null +++ b/service/udp_linux.go @@ -0,0 +1,63 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in comlniance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by aplnicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or imlnied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux + +package service + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/Jigsaw-Code/outline-sdk/transport" +) + +type udpListener struct { + natTimeout time.Duration + // fwmark can be used in conjunction with other Linux networking features like cgroups, network + // namespaces, and TC (Traffic Control) for sophisticated network management. + // Value of 0 disables fwmark (SO_MARK) (Linux only) + fwmark uint +} + +// NewPacketListener creates a new PacketListener that listens on UDP +// and optionally sets a firewall mark on the socket (Linux only). +func MakeTargetUDPListener(natTimeout time.Duration, fwmark uint) transport.PacketListener { + return &udpListener{natTimeout: natTimeout, fwmark: fwmark} +} + +func (ln *udpListener) ListenPacket(ctx context.Context) (net.PacketConn, error) { + conn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, fmt.Errorf("Failed to create UDP socket: %w", err) + } + + if ln.fwmark > 0 { + rawConn, err := conn.SyscallConn() + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to get UDP raw connection: %w", err) + } + + err = SetFwmark(rawConn, ln.fwmark) + if err != nil { + conn.Close() + return nil, fmt.Errorf("Failed to set `fwmark`: %w", err) + + } + } + return &timedPacketConn{PacketConn: conn, defaultTimeout: ln.natTimeout}, nil +} diff --git a/service/udp_other.go b/service/udp_other.go new file mode 100644 index 00000000..c67e59f3 --- /dev/null +++ b/service/udp_other.go @@ -0,0 +1,47 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !linux + +package service + +import ( + "context" + "net" + "time" + + "github.com/Jigsaw-Code/outline-sdk/transport" +) + +type udpListener struct { + *transport.UDPListener + natTimeout time.Duration +} + +// fwmark can be used in conjunction with other Linux networking features like cgroups, network namespaces, and TC (Traffic Control) for sophisticated network management. +// Value of 0 disables fwmark (SO_MARK) +func MakeTargetUDPListener(natTimeout time.Duration, fwmark uint) transport.PacketListener { + if fwmark != 0 { + panic("fwmark is linux-specific feature and should be 0") + } + return &udpListener{UDPListener: &transport.UDPListener{Address: ""}} +} + +func (ln *udpListener) ListenPacket(ctx context.Context) (net.PacketConn, error) { + conn, err := ln.UDPListener.ListenPacket(ctx) + if err != nil { + return nil, err + } + return &timedPacketConn{PacketConn: conn, defaultTimeout: ln.natTimeout}, nil +} diff --git a/service/udp_test.go b/service/udp_test.go index a2410074..358dd32e 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -16,6 +16,7 @@ package service import ( "bytes" + "context" "errors" "fmt" "net" @@ -25,19 +26,22 @@ import ( "time" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" - onet "github.com/Jigsaw-Code/outline-ss-server/net" logging "github.com/op/go-logging" "github.com/shadowsocks/go-shadowsocks2/socks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + onet "github.com/Jigsaw-Code/outline-ss-server/net" ) const timeout = 5 * time.Minute var clientAddr = net.UDPAddr{IP: []byte{192, 0, 2, 1}, Port: 12345} + var targetAddr = net.UDPAddr{IP: []byte{192, 0, 2, 2}, Port: 54321} var localAddr = net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 9} var dnsAddr = net.UDPAddr{IP: []byte{192, 0, 2, 3}, Port: 53} + var natCryptoKey *shadowsocks.EncryptionKey func init() { @@ -51,6 +55,14 @@ type packet struct { err error } +type packetListener struct { + conn net.PacketConn +} + +func (ln *packetListener) ListenPacket(ctx context.Context) (net.PacketConn, error) { + return ln.conn, nil +} + type fakePacketConn struct { net.PacketConn send chan packet @@ -182,12 +194,10 @@ func sendSSPayload(conn *fakePacketConn, addr net.Addr, cipher *shadowsocks.Encr func startTestHandler() (PacketHandler, func(target net.Addr, payload []byte), *fakePacketConn) { ciphers, _ := MakeTestCiphers([]string{"asdf"}) cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey - handler := NewPacketHandler(10*time.Second, ciphers, nil) + handler := NewPacketHandler(ciphers, nil) clientConn := makePacketConn() targetConn := makePacketConn() - handler.SetTargetConnFactory(func() (net.PacketConn, error) { - return targetConn, nil - }) + handler.SetTargetPacketListener(&packetListener{targetConn}) go PacketServe(clientConn, func(conn net.Conn) (Association, error) { return handler.NewAssociation(conn, nil) }, &natTestMetrics{}) @@ -242,12 +252,10 @@ func TestPacketHandler_Handle_IPFilter(t *testing.T) { func TestUpstreamMetrics(t *testing.T) { ciphers, _ := MakeTestCiphers([]string{"asdf"}) cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey - handler := NewPacketHandler(10*time.Second, ciphers, nil) + handler := NewPacketHandler(ciphers, nil) clientConn := makePacketConn() targetConn := makePacketConn() - handler.SetTargetConnFactory(func() (net.PacketConn, error) { - return targetConn, nil - }) + handler.SetTargetPacketListener(&packetListener{targetConn}) metrics := &fakeUDPAssocationMetrics{} go PacketServe(clientConn, func(conn net.Conn) (Association, error) { return handler.NewAssociation(conn, metrics) @@ -302,9 +310,7 @@ func (e *fakeTimeoutError) Temporary() bool { func TestTimedPacketConn(t *testing.T) { t.Run("Write", func(t *testing.T) { handler, sendPayload, targetConn := startTestHandler() - handler.SetTargetConnFactory(func() (net.PacketConn, error) { - return &timedPacketConn{PacketConn: targetConn, defaultTimeout: timeout}, nil - }) + handler.SetTargetPacketListener(&packetListener{targetConn}) buf := []byte{1} sendPayload(&targetAddr, buf) @@ -319,9 +325,7 @@ func TestTimedPacketConn(t *testing.T) { t.Run("WriteDNS", func(t *testing.T) { handler, sendPayload, targetConn := startTestHandler() - handler.SetTargetConnFactory(func() (net.PacketConn, error) { - return &timedPacketConn{PacketConn: targetConn, defaultTimeout: timeout}, nil - }) + handler.SetTargetPacketListener(&packetListener{targetConn}) // Simulate one DNS query being sent. buf := []byte{1} @@ -338,9 +342,7 @@ func TestTimedPacketConn(t *testing.T) { t.Run("WriteDNSMultiple", func(t *testing.T) { handler, sendPayload, targetConn := startTestHandler() - handler.SetTargetConnFactory(func() (net.PacketConn, error) { - return &timedPacketConn{PacketConn: targetConn, defaultTimeout: timeout}, nil - }) + handler.SetTargetPacketListener(&packetListener{targetConn}) // Simulate three DNS queries being sent. buf := []byte{1} @@ -357,9 +359,7 @@ func TestTimedPacketConn(t *testing.T) { t.Run("WriteMixed", func(t *testing.T) { handler, sendPayload, targetConn := startTestHandler() - handler.SetTargetConnFactory(func() (net.PacketConn, error) { - return &timedPacketConn{PacketConn: targetConn, defaultTimeout: timeout}, nil - }) + handler.SetTargetPacketListener(&packetListener{targetConn}) // Simulate both non-DNS and DNS packets being sent. buf := []byte{1} @@ -375,12 +375,10 @@ func TestTimedPacketConn(t *testing.T) { t.Run("FastClose", func(t *testing.T) { ciphers, _ := MakeTestCiphers([]string{"asdf"}) cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey - handler := NewPacketHandler(10*time.Second, ciphers, nil) + handler := NewPacketHandler(ciphers, nil) clientConn := makePacketConn() targetConn := makePacketConn() - handler.SetTargetConnFactory(func() (net.PacketConn, error) { - return &timedPacketConn{PacketConn: targetConn, defaultTimeout: timeout}, nil - }) + handler.SetTargetPacketListener(&packetListener{targetConn}) go PacketServe(clientConn, func(conn net.Conn) (Association, error) { return handler.NewAssociation(conn, nil) }, &natTestMetrics{}) @@ -405,12 +403,10 @@ func TestTimedPacketConn(t *testing.T) { t.Run("NoFastClose_NotDNS", func(t *testing.T) { ciphers, _ := MakeTestCiphers([]string{"asdf"}) cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey - handler := NewPacketHandler(10*time.Second, ciphers, nil) + handler := NewPacketHandler(ciphers, nil) clientConn := makePacketConn() targetConn := makePacketConn() - handler.SetTargetConnFactory(func() (net.PacketConn, error) { - return &timedPacketConn{PacketConn: targetConn, defaultTimeout: timeout}, nil - }) + handler.SetTargetPacketListener(&packetListener{targetConn}) go PacketServe(clientConn, func(conn net.Conn) (Association, error) { return handler.NewAssociation(conn, nil) }, &natTestMetrics{}) @@ -435,12 +431,10 @@ func TestTimedPacketConn(t *testing.T) { t.Run("NoFastClose_MultipleDNS", func(t *testing.T) { ciphers, _ := MakeTestCiphers([]string{"asdf"}) cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey - handler := NewPacketHandler(10*time.Second, ciphers, nil) + handler := NewPacketHandler(ciphers, nil) clientConn := makePacketConn() targetConn := makePacketConn() - handler.SetTargetConnFactory(func() (net.PacketConn, error) { - return &timedPacketConn{PacketConn: targetConn, defaultTimeout: timeout}, nil - }) + handler.SetTargetPacketListener(&packetListener{targetConn}) go PacketServe(clientConn, func(conn net.Conn) (Association, error) { return handler.NewAssociation(conn, nil) }, &natTestMetrics{}) @@ -463,9 +457,7 @@ func TestTimedPacketConn(t *testing.T) { t.Run("Timeout", func(t *testing.T) { handler, sendPayload, targetConn := startTestHandler() - handler.SetTargetConnFactory(func() (net.PacketConn, error) { - return &timedPacketConn{PacketConn: targetConn, defaultTimeout: timeout}, nil - }) + handler.SetTargetPacketListener(&packetListener{targetConn}) // Simulate a non-DNS initial packet. sendPayload(&targetAddr, []byte{1}) @@ -497,11 +489,11 @@ func TestNATMap(t *testing.T) { addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} assoc1 := &association{} - nm.Add(addr, assoc1) + nm.Add(addr.String(), assoc1) assert.Equal(t, assoc1, nm.Get(addr.String()), "Get should return the correct connection") assoc2 := &association{} - nm.Add(addr, assoc2) + nm.Add(addr.String(), assoc2) assert.Equal(t, assoc2, nm.Get(addr.String()), "Adding with the same address should overwrite the entry") }) @@ -509,7 +501,7 @@ func TestNATMap(t *testing.T) { nm := newNATmap() addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} assoc := &association{} - nm.Add(addr, assoc) + nm.Add(addr.String(), assoc) assert.Equal(t, assoc, nm.Get(addr.String()), "Get should return the correct connection for an existing address") @@ -521,9 +513,9 @@ func TestNATMap(t *testing.T) { nm := newNATmap() addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} assoc := &association{} - nm.Add(addr, assoc) + nm.Add(addr.String(), assoc) - nm.Del(addr) + nm.Del(addr.String()) assert.Nil(t, nm.Get(addr.String()), "Get should return nil after deleting the entry") }) @@ -533,7 +525,7 @@ func TestNATMap(t *testing.T) { addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} pc := makePacketConn() assoc := &association{Conn: &natconn{PacketConn: pc, raddr: addr}} - nm.Add(addr, assoc) + nm.Add(addr.String(), assoc) err := nm.Close() assert.NoError(t, err, "Close should not return an error") @@ -626,7 +618,7 @@ func TestUDPEarlyClose(t *testing.T) { t.Fatal(err) } const testTimeout = 200 * time.Millisecond - ph := NewPacketHandler(testTimeout, cipherList, &fakeShadowsocksMetrics{}) + ph := NewPacketHandler(cipherList, &fakeShadowsocksMetrics{}) clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) if err != nil { @@ -635,7 +627,7 @@ func TestUDPEarlyClose(t *testing.T) { require.Nil(t, clientConn.Close()) // This should return quickly without timing out. go PacketServe(clientConn, func(conn net.Conn) (Association, error) { - return ph.NewAssociation(conn, &NoOpUDPAssocationMetrics{}) + return ph.NewAssociation(conn, &NoOpUDPAssociationMetrics{}) }, &natTestMetrics{}) }