From dc8363d22cfa2863fe6e38c9c0e85c25122bf0a5 Mon Sep 17 00:00:00 2001 From: uoosef Date: Sat, 9 Sep 2023 00:17:55 +0330 Subject: [PATCH] fragmentation rework --- cmd/cli/main.go | 26 +- cmd/core/core.go | 163 ------- cmd/gui/main.go | 12 +- cmd/mobile/tun2socks.go | 7 +- config/config.go | 37 ++ dialer/dialer.go | 20 +- dialer/dialer_tcp_test.go | 4 +- dialer/http.go | 8 +- dialer/tcp.go | 8 +- dialer/tls.go | 4 +- logger/logger.go | 5 - net/adapter/fragment/conn.go | 200 +++++++++ net/adapter/http/conn.go | 102 +++++ .../adapter/ws/conn.go | 4 +- server/handle.go | 307 +++++++++++++ server/server.go | 421 ++++-------------- transport/transport.go | 22 +- transport/ws.go | 23 +- 18 files changed, 815 insertions(+), 558 deletions(-) delete mode 100644 cmd/core/core.go create mode 100644 config/config.go create mode 100644 net/adapter/fragment/conn.go create mode 100644 net/adapter/http/conn.go rename wsconnadapter/wsconnadapter.go => net/adapter/ws/conn.go (95%) create mode 100644 server/handle.go diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 35b8d5a..49fd305 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -2,8 +2,9 @@ package main import ( - "bepass/cmd/core" + "bepass/config" "bepass/logger" + "bepass/server" "encoding/json" "errors" "fmt" @@ -25,47 +26,46 @@ func main() { err := ff.Parse(fs, os.Args[1:]) switch { case errors.Is(err, ff.ErrHelp): - fmt.Fprintf(os.Stderr, "%s\n", ffhelp.Flags(fs)) + logger.Errorf("%s\n", ffhelp.Flags(fs)) os.Exit(0) case err != nil: - fmt.Fprintf(os.Stderr, "error: %v\n", err) + logger.Errorf("error: %v\n", err) os.Exit(1) } // Load and validate configuration from JSON file - config, err := loadConfig(configPath) + err = loadConfig(configPath) if err != nil { logger.Fatal("", err) } // Run the server with the loaded configuration - err = core.RunServer(config, true) + err = server.Run(true) if err != nil { logger.Fatal("", err) } - // Handle graceful shutdown + // HandleTCPTunnel graceful shutdown handleShutdown() } -func loadConfig(configPath string) (*core.Config, error) { +func loadConfig(configPath string) error { file, err := os.Open(configPath) if err != nil { - return nil, err + return err } defer file.Close() - config := &core.Config{} decoder := json.NewDecoder(file) - err = decoder.Decode(config) + err = decoder.Decode(config.G) if err != nil { if strings.Contains(err.Error(), "invalid character") { - return nil, fmt.Errorf("configuration file is not valid JSON") + return fmt.Errorf("configuration file is not valid JSON") } - return nil, err + return err } - return config, nil + return nil } func handleShutdown() { diff --git a/cmd/core/core.go b/cmd/core/core.go deleted file mode 100644 index a715369..0000000 --- a/cmd/core/core.go +++ /dev/null @@ -1,163 +0,0 @@ -package core - -import ( - "bepass/bufferpool" - "bepass/dialer" - "bepass/doh" - "bepass/resolve" - "bepass/server" - "bepass/socks5" - "bepass/transport" - "bepass/utils" - "context" - "fmt" - "io" - "os" - "os/signal" - "strings" - "syscall" - "time" -) - -type Config struct { - TLSHeaderLength int `mapstructure:"TLSHeaderLength"` - TLSPaddingEnabled bool `mapstructure:"TLSPaddingEnabled"` - TLSPaddingSize [2]int `mapstructure:"TLSPaddingSize"` - DnsCacheTTL int `mapstructure:"DnsCacheTTL"` - DnsRequestTimeout int `mapstructure:"DnsRequestTimeout"` - WorkerAddress string `mapstructure:"WorkerAddress"` - WorkerIPPortAddress string `mapstructure:"WorkerIPPortAddress"` - WorkerEnabled bool `mapstructure:"WorkerEnabled"` - WorkerDNSOnly bool `mapstructure:"WorkerDNSOnly"` - EnableLowLevelSockets bool `mapstructure:"EnableLowLevelSockets"` - EnableDNSFragmentation bool `mapstructure:"EnableDNSFragmentation"` - RemoteDNSAddr string `mapstructure:"RemoteDNSAddr"` - BindAddress string `mapstructure:"BindAddress"` - UDPBindAddress string `mapstructure:"UDPBindAddress"` - ChunksLengthBeforeSni [2]int `mapstructure:"ChunksLengthBeforeSni"` - UDPReadTimeout int `mapstructure:"UDPReadTimeout"` - UDPWriteTimeout int `mapstructure:"UDPWriteTimeout"` - UDPLinkIdleTimeout int64 `mapstructure:"UDPLinkIdleTimeout"` - SniChunksLength [2]int `mapstructure:"SniChunksLength"` - ChunksLengthAfterSni [2]int `mapstructure:"ChunksLengthAfterSni"` - DelayBetweenChunks [2]int `mapstructure:"DelayBetweenChunks"` - Hosts []resolve.Hosts `mapstructure:"Hosts"` - ResolveSystem string `mapstructure:"-"` - DoHClient *doh.Client `mapstructure:"-"` -} - -var s5 *socks5.Server - -func RunServer(config *Config, captureCTRLC bool) error { - appCache := utils.NewCache(time.Duration(config.DnsCacheTTL) * time.Second) - - var resolveSystem string - var dohClient *doh.Client - - localResolver := &resolve.LocalResolver{ - Hosts: config.Hosts, - } - - dialer_ := &dialer.Dialer{ - EnableLowLevelSockets: config.EnableLowLevelSockets, - TLSPaddingEnabled: config.TLSPaddingEnabled, - TLSPaddingSize: config.TLSPaddingSize, - ProxyAddress: fmt.Sprintf("socks5://%s", config.BindAddress), - } - - wsTunnel := &transport.WSTunnel{ - BindAddress: config.BindAddress, - Dialer: dialer_, - ReadTimeout: config.UDPReadTimeout, - WriteTimeout: config.UDPWriteTimeout, - LinkIdleTimeout: config.UDPLinkIdleTimeout, - EstablishedTunnels: make(map[string]*transport.EstablishedTunnel), - ShortClientID: utils.ShortID(6), - } - - transport_ := &transport.Transport{ - WorkerAddress: config.WorkerAddress, - BindAddress: config.BindAddress, - Dialer: dialer_, - BufferPool: bufferpool.NewPool(32 * 1024), - UDPBind: config.UDPBindAddress, - Tunnel: wsTunnel, - } - - if strings.HasPrefix(config.RemoteDNSAddr, "https://") { - resolveSystem = "doh" - dohClient = doh.NewClient( - doh.WithDNSFragmentation((config.WorkerEnabled && config.WorkerDNSOnly) || config.EnableDNSFragmentation), - doh.WithDialer(dialer_), - doh.WithLocalResolver(localResolver), - ) - } else { - resolveSystem = "DNSCrypt" - } - - chunkConfig := server.ChunkConfig{ - BeforeSniLength: config.SniChunksLength, - AfterSniLength: config.ChunksLengthAfterSni, - Delay: config.DelayBetweenChunks, - TLSHeaderLength: config.TLSHeaderLength, - } - - workerConfig := server.WorkerConfig{ - WorkerAddress: config.WorkerAddress, - WorkerIPPortAddress: config.WorkerIPPortAddress, - WorkerEnabled: config.WorkerEnabled, - WorkerDNSOnly: config.WorkerDNSOnly, - } - - serverHandler := &server.Server{ - RemoteDNSAddr: config.RemoteDNSAddr, - Cache: appCache, - ResolveSystem: resolveSystem, - DoHClient: dohClient, - ChunkConfig: chunkConfig, - WorkerConfig: workerConfig, - BindAddress: config.BindAddress, - EnableLowLevelSockets: config.EnableLowLevelSockets, - Dialer: dialer_, - LocalResolver: localResolver, - Transport: transport_, - } - - if captureCTRLC { - c := make(chan os.Signal) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - go func() { - <-c - _ = ShutDown() - os.Exit(0) - }() - } - - if workerConfig.WorkerEnabled && !workerConfig.WorkerDNSOnly { - s5 = socks5.NewServer( - socks5.WithConnectHandle(func(ctx context.Context, w io.Writer, req *socks5.Request) error { - return serverHandler.Handle(ctx, w, req, "tcp") - }), - socks5.WithAssociateHandle(func(ctx context.Context, w io.Writer, req *socks5.Request) error { - return serverHandler.Handle(ctx, w, req, "udp") - }), - ) - } else { - s5 = socks5.NewServer( - socks5.WithConnectHandle(func(ctx context.Context, w io.Writer, req *socks5.Request) error { - return serverHandler.Handle(ctx, w, req, "tcp") - }), - ) - } - - fmt.Println("Starting socks, http server:", config.BindAddress) - if err := s5.ListenAndServe("tcp", config.BindAddress); err != nil { - return err - } - - return nil -} - -func ShutDown() error { - return s5.Shutdown() -} diff --git a/cmd/gui/main.go b/cmd/gui/main.go index 3233d55..525176e 100644 --- a/cmd/gui/main.go +++ b/cmd/gui/main.go @@ -1,7 +1,8 @@ package main import ( - "bepass/cmd/core" + "bepass/config" + "bepass/server" "encoding/json" "fmt" "io" @@ -48,7 +49,7 @@ type UIComponents struct { openFileButton *widget.Button connectButton *widget.Button isConnected bool - coreConfig *core.Config + coreConfig *config.Config } func createUIComponents(myWindow *fyne.Window) *UIComponents { @@ -132,7 +133,7 @@ func (ui *UIComponents) Connect(myWindow *fyne.Window) { firstValue := ui.dohInput.Text secondValue := ui.listenInput.Text - ui.coreConfig = &core.Config{ + ui.coreConfig = &config.Config{ TLSHeaderLength: 5, DnsCacheTTL: 3600, WorkerAddress: "worker.example.com", @@ -146,7 +147,6 @@ func (ui *UIComponents) Connect(myWindow *fyne.Window) { ChunksLengthAfterSni: [2]int{50, 60}, DelayBetweenChunks: [2]int{70, 80}, ResolveSystem: "doh", - DoHClient: nil, // Initialize appropriately } } @@ -156,7 +156,7 @@ func (ui *UIComponents) Connect(myWindow *fyne.Window) { } go func() { - err := core.RunServer(ui.coreConfig, true) + err := server.Run(true) if err != nil { dialog.ShowError(err, *myWindow) ui.isConnected = false @@ -175,7 +175,7 @@ func (ui *UIComponents) Connect(myWindow *fyne.Window) { func (ui *UIComponents) Disconnect(myWindow *fyne.Window) { go func() { if ui.coreConfig != nil { - err := core.ShutDown() + err := server.ShutDown() if err != nil { dialog.ShowError(err, *myWindow) diff --git a/cmd/mobile/tun2socks.go b/cmd/mobile/tun2socks.go index 70e444b..f483ee4 100644 --- a/cmd/mobile/tun2socks.go +++ b/cmd/mobile/tun2socks.go @@ -1,6 +1,7 @@ package tun2socks import ( + "bepass/config" "encoding/json" "errors" "io" @@ -18,18 +19,18 @@ import ( "github.com/eycorsican/go-tun2socks/core" "github.com/eycorsican/go-tun2socks/proxy/socks" - bepassCore "bepass/cmd/core" + bepassCore "bepass/server" "github.com/songgao/water" ) func StartClient(cfg string) bool { - config := &bepassCore.Config{} + config := &config.Config{} err := json.Unmarshal([]byte(cfg), config) if err != nil { return false } - err = bepassCore.RunServer(config, false) + err = bepassCore.Run(false) if err != nil { return false } diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..12985b9 --- /dev/null +++ b/config/config.go @@ -0,0 +1,37 @@ +package config + +import ( + "bepass/resolve" +) + +type Config struct { + TLSHeaderLength int `mapstructure:"TLSHeaderLength"` + TLSPaddingEnabled bool `mapstructure:"TLSPaddingEnabled"` + TLSPaddingSize [2]int `mapstructure:"TLSPaddingSize"` + DnsCacheTTL int `mapstructure:"DnsCacheTTL"` + DnsRequestTimeout int `mapstructure:"DnsRequestTimeout"` + WorkerAddress string `mapstructure:"WorkerAddress"` + WorkerIPPortAddress string `mapstructure:"WorkerIPPortAddress"` + WorkerEnabled bool `mapstructure:"WorkerEnabled"` + WorkerDNSOnly bool `mapstructure:"WorkerDNSOnly"` + EnableLowLevelSockets bool `mapstructure:"EnableLowLevelSockets"` + EnableDNSFragmentation bool `mapstructure:"EnableDNSFragmentation"` + RemoteDNSAddr string `mapstructure:"RemoteDNSAddr"` + BindAddress string `mapstructure:"BindAddress"` + UDPBindAddress string `mapstructure:"UDPBindAddress"` + ChunksLengthBeforeSni [2]int `mapstructure:"ChunksLengthBeforeSni"` + SniChunksLength [2]int `mapstructure:"SniChunksLength"` + ChunksLengthAfterSni [2]int `mapstructure:"ChunksLengthAfterSni"` + UDPReadTimeout int `mapstructure:"UDPReadTimeout"` + UDPWriteTimeout int `mapstructure:"UDPWriteTimeout"` + UDPLinkIdleTimeout int64 `mapstructure:"UDPLinkIdleTimeout"` + DelayBetweenChunks [2]int `mapstructure:"DelayBetweenChunks"` + Hosts []resolve.Hosts `mapstructure:"Hosts"` + ResolveSystem string `mapstructure:"-"` +} + +var G *Config + +func init() { + G = &Config{} +} diff --git a/dialer/dialer.go b/dialer/dialer.go index 0b4de0a..70dbdb1 100644 --- a/dialer/dialer.go +++ b/dialer/dialer.go @@ -2,11 +2,13 @@ package dialer import ( + "bepass/net/adapter/fragment" + "bepass/net/adapter/http" "net" ) // PlainTCPDial is a type representing a function for plain TCP dialing. -type PlainTCPDial func(network, addr, hostPort string) (net.Conn, error) +type PlainTCPDial func(network, addr string) (net.Conn, error) // Dialer is a struct that holds various options for custom dialing. type Dialer struct { @@ -15,3 +17,19 @@ type Dialer struct { TLSPaddingSize [2]int // Size of TLS padding. ProxyAddress string // Address of the proxy server. } + +func (d *Dialer) FragmentDial(network, addr string) (net.Conn, error) { + tcpConn, err := d.TCPDial(network, addr) + if err != nil { + return nil, err + } + return fragment.New(tcpConn), nil +} + +func (d *Dialer) HttpDial(network, addr string) (net.Conn, error) { + tcpConn, err := d.TCPDial(network, addr) + if err != nil { + return nil, err + } + return http.New(tcpConn), nil +} diff --git a/dialer/dialer_tcp_test.go b/dialer/dialer_tcp_test.go index 32f9eaa..613b5c4 100644 --- a/dialer/dialer_tcp_test.go +++ b/dialer/dialer_tcp_test.go @@ -17,13 +17,11 @@ func TestDialerAndTCPDial(t *testing.T) { testCases := []struct { network string addr string - hostPort string expected error }{ { network: "tcp", addr: "example.com:80", - hostPort: "", expected: nil, // Modify this based on your expected outcome }, // Add more test cases here @@ -32,7 +30,7 @@ func TestDialerAndTCPDial(t *testing.T) { // Run the test cases for _, tc := range testCases { t.Run(tc.addr, func(t *testing.T) { - conn, err := d.TCPDial(tc.network, tc.addr, tc.hostPort) + conn, err := d.TCPDial(tc.network, tc.addr) if err != nil { t.Fatalf("TCPDial failed: %v", err) } diff --git a/dialer/http.go b/dialer/http.go index 020c168..d17f867 100644 --- a/dialer/http.go +++ b/dialer/http.go @@ -14,12 +14,12 @@ func (d *Dialer) MakeHTTPClient(hostPort string, enableProxy bool) *http.Client transport := &http.Transport{ ForceAttemptHTTP2: false, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return d.TCPDial(network, addr, hostPort) + return d.TCPDial(network, addr) }, DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return d.TLSDial(func(network, addr, hostPort string) (net.Conn, error) { - return d.TCPDial(network, addr, hostPort) - }, network, addr, hostPort) + return d.TLSDial(func(network, addr string) (net.Conn, error) { + return d.TCPDial(network, addr) + }, network, addr) }, } if enableProxy { diff --git a/dialer/tcp.go b/dialer/tcp.go index 0e96c5f..8887f4a 100644 --- a/dialer/tcp.go +++ b/dialer/tcp.go @@ -11,16 +11,12 @@ import ( ) // TCPDial connects to the destination address. -func (d *Dialer) TCPDial(network, addr, hostPort string) (*net.TCPConn, error) { +func (d *Dialer) TCPDial(network, addr string) (*net.TCPConn, error) { var ( tcpAddr *net.TCPAddr err error ) - if hostPort != "" { - tcpAddr, err = net.ResolveTCPAddr(network, hostPort) - } else { - tcpAddr, err = net.ResolveTCPAddr(network, addr) - } + tcpAddr, err = net.ResolveTCPAddr(network, addr) if err != nil { return nil, err } diff --git a/dialer/tls.go b/dialer/tls.go index 6b8c055..c2f0491 100644 --- a/dialer/tls.go +++ b/dialer/tls.go @@ -210,12 +210,12 @@ func removeProtocolFromALPN(spec *tls.ClientHelloSpec, protocol string) *tls.Cli } // TLSDial dials a TLS connection. -func (d *Dialer) TLSDial(plainDialer PlainTCPDial, network, addr, hostPort string) (net.Conn, error) { +func (d *Dialer) TLSDial(plainDialer PlainTCPDial, network, addr string) (net.Conn, error) { sni, _, err := net.SplitHostPort(addr) if err != nil { return nil, err } - plainConn, err := plainDialer(network, addr, hostPort) + plainConn, err := plainDialer(network, addr) if err != nil { return nil, err } diff --git a/logger/logger.go b/logger/logger.go index a8820e8..e75ce42 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -67,11 +67,6 @@ func init() { } } -// GetLogger returns the global logger instance. -func GetLogger() *slog.Logger { - return logger -} - func log(ctx context.Context, level slog.Level, msg string, args ...interface{}) { if !logger.Enabled(ctx, level) { return diff --git a/net/adapter/fragment/conn.go b/net/adapter/fragment/conn.go new file mode 100644 index 0000000..25dc48f --- /dev/null +++ b/net/adapter/fragment/conn.go @@ -0,0 +1,200 @@ +package fragment + +import ( + "bepass/config" + "bepass/sni" + "bytes" + "math/rand" + "net" + "sync" + "time" +) + +// Adapter represents an adapter for implementing fragmentation as net.Conn interface +type Adapter struct { + conn net.Conn + readMutex sync.Mutex + writeMutex sync.Mutex + isFirstWrite bool + // search for sni and if sni was found, initially split client hello packet to 3 packets + // first chunk is contents of original tls hello packet before reaching sni + // second packet is sni itself + // and third package is contents of original tls hello packet after sni + // we fragment each part separately BSL indicates each fragment's size(a range) for + // original packet contents before reaching the sni + // SL indicates each fragment's size(a range) for the sni itself + // ASL indicates each fragment's size(a range) for remaining contents of original packet that comes after sni + // and delay indicates how much delay system should take before sending next fragment as a separate packet + BSL [2]int + SL [2]int + ASL [2]int + Delay [2]int +} + +// New creates a new Adapter from a net.Conn connection. +func New(conn net.Conn) *Adapter { + return &Adapter{ + conn: conn, + isFirstWrite: true, + BSL: config.G.ChunksLengthBeforeSni, + SL: config.G.SniChunksLength, + ASL: config.G.ChunksLengthBeforeSni, + } +} + +// it will search for sni or host in package and if found then chunks Write writes data to the net.Conn connection. +func (a *Adapter) writeFragments(b []byte, index int) (int, error) { + nw := 0 + position := 0 + lengthMin, lengthMax := 0, 0 + if index == 0 { + lengthMin, lengthMax = a.BSL[0], a.BSL[1] + } else if index == 1 { // if its sni + lengthMin, lengthMax = a.SL[0], a.SL[1] + } else { // if its after sni + lengthMin, lengthMax = a.ASL[0], a.ASL[1] + } + for position < len(b) { + var fragmentLength int + if lengthMax-lengthMin > 0 { + fragmentLength = rand.Intn(lengthMax-lengthMin) + lengthMin + } else { + fragmentLength = lengthMin + } + + if fragmentLength > len(b)-position { + fragmentLength = len(b) - position + } + + var delay int + if a.Delay[1]-a.Delay[0] > 0 { + delay = rand.Intn(a.Delay[1]-a.Delay[0]) + a.Delay[0] + } else { + delay = a.Delay[0] + } + + tnw, ew := a.conn.Write(b[position : position+fragmentLength]) + if ew != nil { + return 0, ew + } + + nw += tnw + + position += fragmentLength + time.Sleep(time.Duration(delay) * time.Millisecond) + } + + return nw, nil +} + +// it will search for sni or host in package and if found then chunks Write writes data to the net.Conn connection. +func (a *Adapter) fragmentAndWriteFirstPacket(b []byte) (int, error) { + hello, err := sni.ReadClientHello(bytes.NewReader(b)) + if err != nil { + return a.conn.Write(b) + } + helloPacketSni := []byte(hello.ServerName) + chunks := make(map[int][]byte) + + /* + splitting original hello packet to BeforeSNI, SNI, AfterSNI chunks + */ + // search for sni through original tls client hello + index := bytes.Index(b, helloPacketSni) + if index == -1 { + return a.conn.Write(b) + } + // before helloPacketSni + chunks[0] = make([]byte, index) + copy(chunks[0], b[:index]) + // helloPacketSni + chunks[1] = make([]byte, len(helloPacketSni)) + copy(chunks[1], b[index:index+len(helloPacketSni)]) + // after helloPacketSni + chunks[2] = make([]byte, len(b)-index-len(helloPacketSni)) + copy(chunks[2], b[index+len(helloPacketSni):]) + + /* + sending fragments + */ + // number of written packets + nw := 0 + var ew error = nil + + for i := 0; i < 3; i++ { + tnw, ew := a.writeFragments(chunks[i], i) + nw += tnw + if ew != nil { + return 0, ew + } + } + + return nw, ew +} + +// Write writes data to the net.Conn connection. +func (a *Adapter) Write(b []byte) (int, error) { + a.writeMutex.Lock() + defer a.writeMutex.Unlock() + + var ( + bytesWritten int + err error + ) + + if a.isFirstWrite { + a.isFirstWrite = false + return a.fragmentAndWriteFirstPacket(b) + } else { + bytesWritten, err = a.conn.Write(b) + } + + return bytesWritten, err +} + +// Read reads data from the net.Conn connection. +func (a *Adapter) Read(b []byte) (int, error) { + // Read() can be called concurrently, and we mutate some internal state here + a.readMutex.Lock() + defer a.readMutex.Unlock() + + bytesRead, err := a.conn.Read(b) + if err != nil { + return 0, err + } + return bytesRead, err +} + +// Close closes the net.Conn connection. +func (a *Adapter) Close() error { + return a.conn.Close() +} + +// LocalAddr returns the local network address. +func (a *Adapter) LocalAddr() net.Addr { + return a.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (a *Adapter) RemoteAddr() net.Addr { + return a.conn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines for the connection. +func (a *Adapter) SetDeadline(t time.Time) error { + if err := a.SetReadDeadline(t); err != nil { + return err + } + + return a.SetWriteDeadline(t) +} + +// SetReadDeadline sets the read deadline for the connection. +func (a *Adapter) SetReadDeadline(t time.Time) error { + return a.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline for the connection. +func (a *Adapter) SetWriteDeadline(t time.Time) error { + return a.conn.SetWriteDeadline(t) +} diff --git a/net/adapter/http/conn.go b/net/adapter/http/conn.go new file mode 100644 index 0000000..5008900 --- /dev/null +++ b/net/adapter/http/conn.go @@ -0,0 +1,102 @@ +package http + +import ( + "bepass/logger" + "bepass/sni" + "bytes" + "net" + "sync" + "time" +) + +// Adapter represents an adapter for implementing fragmentation as net.Conn interface +type Adapter struct { + conn net.Conn + readMutex sync.Mutex + writeMutex sync.Mutex + isFirstWrite bool +} + +// New creates a new Adapter from a net.Conn connection. +func New(conn net.Conn) *Adapter { + return &Adapter{ + conn: conn, + isFirstWrite: true, + } +} + +// Read reads data from the net.Conn connection. +func (a *Adapter) Read(b []byte) (int, error) { + // Read() can be called concurrently, and we mutate some internal state here + a.readMutex.Lock() + defer a.readMutex.Unlock() + + bytesRead, err := a.conn.Read(b) + if err != nil { + return 0, err + } + return bytesRead, err +} + +// Write writes data to the net.Conn connection. +func (a *Adapter) Write(b []byte) (int, error) { + a.writeMutex.Lock() + defer a.writeMutex.Unlock() + + var ( + bytesWritten int + err error + ) + + if a.isFirstWrite { + a.isFirstWrite = false + host, httpPacketData, err := sni.ParseHTTPHost(bytes.NewReader(b)) + if err != nil { + return a.conn.Write(b) + } + logger.Info("found http packet host: %s", host) + _, err = a.conn.Write(httpPacketData) + if err != nil { + return 0, err + } + return len(b), nil + } else { + bytesWritten, err = a.conn.Write(b) + } + + return bytesWritten, err +} + +// Close closes the net.Conn connection. +func (a *Adapter) Close() error { + return a.conn.Close() +} + +// LocalAddr returns the local network address. +func (a *Adapter) LocalAddr() net.Addr { + return a.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (a *Adapter) RemoteAddr() net.Addr { + return a.conn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines for the connection. +func (a *Adapter) SetDeadline(t time.Time) error { + if err := a.SetReadDeadline(t); err != nil { + return err + } + + return a.SetWriteDeadline(t) +} + +// SetReadDeadline sets the read deadline for the connection. +func (a *Adapter) SetReadDeadline(t time.Time) error { + return a.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline for the connection. +func (a *Adapter) SetWriteDeadline(t time.Time) error { + return a.conn.SetWriteDeadline(t) +} diff --git a/wsconnadapter/wsconnadapter.go b/net/adapter/ws/conn.go similarity index 95% rename from wsconnadapter/wsconnadapter.go rename to net/adapter/ws/conn.go index c556384..598bdbe 100644 --- a/wsconnadapter/wsconnadapter.go +++ b/net/adapter/ws/conn.go @@ -1,6 +1,6 @@ -// Package wsconnadapter provides an adapter for representing WebSocket connections as net.Conn objects. +// Package ws provides an adapter for representing WebSocket connections as net.Conn objects. // It allows you to use WebSocket connections as if they were standard network connections. -package wsconnadapter +package ws import ( "errors" diff --git a/server/handle.go b/server/handle.go new file mode 100644 index 0000000..dbae7a3 --- /dev/null +++ b/server/handle.go @@ -0,0 +1,307 @@ +package server + +import ( + "bepass/dialer" + "bepass/doh" + "bepass/logger" + "bepass/resolve" + "bepass/sni" + "bepass/socks5" + "bepass/socks5/statute" + "bepass/transport" + "bepass/utils" + "bytes" + "context" + "fmt" + "io" + "net" + "net/url" + "strconv" + "strings" + "time" + + "github.com/ameshkov/dnscrypt/v2" + + "github.com/miekg/dns" +) + +// FragmentConfig Constants for chunk lengths and delays. +type FragmentConfig struct { + BSL [2]int + ASL [2]int + Delay [2]int +} + +// WorkerConfig Constants for cloudflare worker. +type WorkerConfig struct { + WorkerAddress string + WorkerIPPortAddress string + WorkerEnabled bool + WorkerDNSOnly bool +} + +type Server struct { + RemoteDNSAddr string + Cache *utils.Cache + ResolveSystem string + DoHClient *doh.Client + ChunkConfig FragmentConfig + WorkerConfig WorkerConfig + Dialer *dialer.Dialer + BindAddress string + EnableLowLevelSockets bool + LocalResolver *resolve.LocalResolver + Transport *transport.Transport +} + +// extractHostnameOrChangeHTTPHostHeader This function extracts the tls sni or http +func (s *Server) extractHostnameOrChangeHTTPHostHeader(data []byte) ( + hostname []byte, firstPacketData []byte, isHTTP bool, err error) { + hello, err := sni.ReadClientHello(bytes.NewReader(data)) + if err != nil { + host, httpPacketData, err := sni.ParseHTTPHost(bytes.NewReader(data)) + if err != nil { + return nil, data, false, err + } + return []byte(host), httpPacketData, true, nil + } + return []byte(hello.ServerName), data, false, nil +} + +func (s *Server) processFirstPacket(ctx context.Context, w io.Writer, req *socks5.Request) ( + *socks5.Request, string, bool, error, +) { + if err := socks5.SendReply(w, statute.RepSuccess, nil); err != nil { + logger.Errorf("failed to send reply: %v", err) + return nil, "", false, err + } + + firstPacket := make([]byte, 32*1024) + read, err := req.Reader.Read(firstPacket) + if err != nil { + return nil, "", false, err + } + + hostname, firstPacketData, isHTTP, err := s.extractHostnameOrChangeHTTPHostHeader(firstPacket[:read]) + + if hostname != nil { + logger.Infof("Hostname %s", string(hostname)) + } + + IPPort, err := s.resolveDestination(ctx, req) + if err != nil { + return nil, "", false, err + } + + // if user has a faulty dns, and it returns dpi ip, + // we resolve destination based on extracted tls sni or http hostname + if hostname != nil && strings.Contains(IPPort, "10.10.3") { + logger.Infof("%s is dpi ip extracting destination host from packets...", IPPort) + req.RawDestAddr.FQDN = string(hostname) + IPPort, err = s.resolveDestination(ctx, req) + if err != nil { + // if destination resolved to dpi and we cant resolve to actual destination + // it's pointless to connect to dpi + logger.Infof("system was unable to extract destination host from packets!") + return nil, "", false, err + } + } + + req.Reader = &utils.BufferedReader{ + FirstPacketData: firstPacketData, + BufReader: req.Reader, + FirstTime: true, + } + + return req, IPPort, isHTTP, nil +} + +func (s *Server) HandleTCPTunnel(ctx context.Context, w io.Writer, req *socks5.Request) error { + r, _, _, err := s.processFirstPacket(ctx, w, req) + if err != nil { + return err + } + return s.Transport.TunnelTCP(w, r) +} + +func (s *Server) HandleUDPTunnel(_ context.Context, w io.Writer, req *socks5.Request) error { + return s.Transport.TunnelUDP(w, req) +} + +// HandleTCPFragment handles the SOCKS5 request and forwards traffic to the destination. +func (s *Server) HandleTCPFragment(ctx context.Context, w io.Writer, req *socks5.Request) error { + r, IPPort, isHTTP, err := s.processFirstPacket(ctx, w, req) + if err != nil { + return err + } + + logger.Infof("Dialing %s...", IPPort) + + var conn net.Conn + + if isHTTP { + conn, err = s.Dialer.HttpDial("tcp", IPPort) + } else { + conn, err = s.Dialer.FragmentDial("tcp", IPPort) + } + + if err != nil { + return err + } + defer func() { + _ = conn.Close() + }() + + // Start proxying + errCh := make(chan error, 2) + go func() { errCh <- s.Copy(r.Reader, conn) }() + go func() { errCh <- s.Copy(conn, w) }() + // Wait + for i := 0; i < 2; i++ { + e := <-errCh + if e != nil { + // return from this function closes target (and conn). + return e + } + } + return nil +} + +func (s *Server) Copy(reader io.Reader, writer io.Writer) error { + buf := make([]byte, 32*1024) + + _, err := io.CopyBuffer(writer, reader, buf[:cap(buf)]) + return err +} + +func (s *Server) resolveDestination(_ context.Context, req *socks5.Request) (string, error) { + dest := req.RawDestAddr + + if dest.FQDN != "" { + ip, err := s.Resolve(dest.FQDN) + if err != nil { + return "", err + } + dest.IP = net.ParseIP(ip) + logger.Infof("resolved %s to %s", req.RawDestAddr, dest) + } else { + logger.Infof("skipping resolution for %s", req.RawDestAddr) + } + + addr := net.JoinHostPort(dest.IP.String(), strconv.Itoa(dest.Port)) + return addr, nil +} + +// Resolve resolves the FQDN to an IP address using the specified resolution mechanism. +func (s *Server) Resolve(fqdn string) (string, error) { + if s.WorkerConfig.WorkerEnabled && + strings.Contains(s.WorkerConfig.WorkerAddress, fqdn) { + dh, _, err := net.SplitHostPort(s.WorkerConfig.WorkerIPPortAddress) + if strings.Contains(dh, ":") { + // its ipv6 + dh = "[" + dh + "]" + } + if err != nil { + return "", err + } + return dh, nil + } + + if h := s.LocalResolver.CheckHosts(fqdn); h != "" { + return h, nil + } + + if s.ResolveSystem == "doh" { + u, err := url.Parse(s.RemoteDNSAddr) + if err == nil { + if u.Hostname() == fqdn { + return s.LocalResolver.Resolve(u.Hostname()), nil + } + } + } + + // Ensure fqdn ends with a period + if !strings.HasSuffix(fqdn, ".") { + fqdn += "." + } + + // Check the cache for fqdn + if cachedValue, _ := s.Cache.Get(fqdn); cachedValue != nil { + logger.Infof("using cached value for %s", fqdn) + return cachedValue.(string), nil + } + + // Build request message + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{{ + Name: fqdn, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }} + + // Determine which DNS resolution mechanism to use + var exchange *dns.Msg + var err error + switch s.ResolveSystem { + case "doh": + exchange, err = s.resolveDNSWithDOH(&req) + default: + exchange, err = s.resolveDNSWithDNSCrypt(&req) + } + if err != nil { + return "", err + } + // Parse answer and store in cache + answer := exchange.Answer[0] + logger.Infof("resolved %s to %s", fqdn, strings.Replace(answer.String(), "\t", " ", -1)) + record := strings.Fields(answer.String()) + if record[3] == "CNAME" { + ip, err := s.Resolve(record[4]) + if err != nil { + return "", err + } + s.Cache.Set(fqdn, ip) + return ip, nil + } + ip := record[4] + s.Cache.Set(fqdn, ip) + return ip, nil +} + +// resolveDNSWithDOH resolves DNS using DNS-over-HTTP (DoH) client. +func (s *Server) resolveDNSWithDOH(req *dns.Msg) (*dns.Msg, error) { + dnsAddr := s.RemoteDNSAddr + if s.WorkerConfig.WorkerEnabled && s.WorkerConfig.WorkerDNSOnly { + dnsAddr = s.WorkerConfig.WorkerAddress + } + + exchange, _, err := s.DoHClient.Exchange(req, dnsAddr) + if err != nil { + return nil, err + } + if len(exchange.Answer) == 0 { + return nil, fmt.Errorf("no answer") + } + return exchange, nil +} + +// resolveDNSWithDNSCrypt resolves DNS using DNSCrypt client. +func (s *Server) resolveDNSWithDNSCrypt(req *dns.Msg) (*dns.Msg, error) { + c := dnscrypt.Client{ + Net: "tcp", Timeout: 10 * time.Second, + } + resolverInfo, err := c.Dial(s.RemoteDNSAddr) + if err != nil { + return nil, err + } + exchange, err := c.Exchange(req, resolverInfo) + if err != nil { + return nil, err + } + if len(exchange.Answer) == 0 { + return nil, fmt.Errorf("no answer") + } + return exchange, nil +} diff --git a/server/server.go b/server/server.go index 2147fcb..d7001d4 100644 --- a/server/server.go +++ b/server/server.go @@ -1,360 +1,135 @@ package server import ( + "bepass/bufferpool" + "bepass/config" "bepass/dialer" "bepass/doh" - "bepass/logger" "bepass/resolve" - "bepass/sni" "bepass/socks5" - "bepass/socks5/statute" "bepass/transport" "bepass/utils" - "bytes" "context" "fmt" "io" - "math/rand" - "net" - "net/url" - "strconv" + "os" + "os/signal" "strings" + "syscall" "time" - - "github.com/ameshkov/dnscrypt/v2" - - "github.com/miekg/dns" ) -// ChunkConfig Constants for chunk lengths and delays. -type ChunkConfig struct { - TLSHeaderLength int - BeforeSniLength [2]int - AfterSniLength [2]int - Delay [2]int -} +var s5 *socks5.Server -// WorkerConfig Constants for cloudflare worker. -type WorkerConfig struct { - WorkerAddress string - WorkerIPPortAddress string - WorkerEnabled bool - WorkerDNSOnly bool -} +func Run(captureCTRLC bool) error { + appCache := utils.NewCache(time.Duration(config.G.DnsCacheTTL) * time.Second) -type Server struct { - RemoteDNSAddr string - Cache *utils.Cache - ResolveSystem string - DoHClient *doh.Client - ChunkConfig ChunkConfig - WorkerConfig WorkerConfig - Dialer *dialer.Dialer - BindAddress string - EnableLowLevelSockets bool - LocalResolver *resolve.LocalResolver - Transport *transport.Transport -} + var resolveSystem string + var dohClient *doh.Client -// extractHostnameOrChangeHTTPHostHeader This function extracts the tls sni or http -func (s *Server) extractHostnameOrChangeHTTPHostHeader(data []byte) ( - hostname []byte, firstPacketData []byte, isHTTP bool, err error) { - hello, err := sni.ReadClientHello(bytes.NewReader(data)) - if err != nil { - host, httpPacketData, err := sni.ParseHTTPHost(bytes.NewReader(data)) - if err != nil { - return nil, data, false, err - } - return []byte(host), httpPacketData, true, nil + localResolver := &resolve.LocalResolver{ + Hosts: config.G.Hosts, } - return []byte(hello.ServerName), data, false, nil -} -func (s *Server) getChunkedPackets(data []byte, host []byte) map[int][]byte { - chunks := make(map[int][]byte) - index := bytes.Index(data, host) - if index == -1 { - chunks[0] = data - return chunks + appDialer := &dialer.Dialer{ + EnableLowLevelSockets: config.G.EnableLowLevelSockets, + TLSPaddingEnabled: config.G.TLSPaddingEnabled, + TLSPaddingSize: config.G.TLSPaddingSize, + ProxyAddress: fmt.Sprintf("socks5://%s", config.G.BindAddress), } - // before sni - chunks[0] = make([]byte, index) - copy(chunks[0], data[:index]) - // sni - chunks[1] = make([]byte, len(host)) - copy(chunks[1], data[index:index+len(host)]) - // after sni - chunks[2] = make([]byte, len(data)-index-len(host)) - copy(chunks[2], data[index+len(host):]) - return chunks -} -func (s *Server) sendSplitChunks(dst io.Writer, chunks map[int][]byte) { - chunkLengthMin, chunkLengthMax := s.ChunkConfig.BeforeSniLength[0], s.ChunkConfig.BeforeSniLength[1] - if len(chunks) > 1 { - chunkLengthMin, chunkLengthMax = s.ChunkConfig.AfterSniLength[0], s.ChunkConfig.AfterSniLength[1] + wsTunnel := &transport.WSTunnel{ + BindAddress: config.G.BindAddress, + Dialer: appDialer, + ReadTimeout: config.G.UDPReadTimeout, + WriteTimeout: config.G.UDPWriteTimeout, + LinkIdleTimeout: config.G.UDPLinkIdleTimeout, + EstablishedTunnels: make(map[string]*transport.EstablishedTunnel), + ShortClientID: utils.ShortID(6), } - for _, chunk := range chunks { - position := 0 - - for position < len(chunk) { - var chunkLength int - if chunkLengthMax-chunkLengthMin > 0 { - chunkLength = rand.Intn(chunkLengthMax-chunkLengthMin) + chunkLengthMin - } else { - chunkLength = chunkLengthMin - } - - if chunkLength > len(chunk)-position { - chunkLength = len(chunk) - position - } - - var delay int - if s.ChunkConfig.Delay[1]-s.ChunkConfig.Delay[0] > 0 { - delay = rand.Intn(s.ChunkConfig.Delay[1]-s.ChunkConfig.Delay[0]) + s.ChunkConfig.Delay[0] - } else { - delay = s.ChunkConfig.Delay[0] - } - - _, errWrite := dst.Write(chunk[position : position+chunkLength]) - if errWrite != nil { - return - } - - position += chunkLength - time.Sleep(time.Duration(delay) * time.Millisecond) - } - } -} - -// Handle handles the SOCKS5 request and forwards traffic to the destination. -func (s *Server) Handle(ctx context.Context, w io.Writer, req *socks5.Request, network string) error { - if s.WorkerConfig.WorkerEnabled && !s.WorkerConfig.WorkerDNSOnly && network == "udp" { - return s.Transport.TunnelUDP(w, req) + tunnelTransport := &transport.Transport{ + WorkerAddress: config.G.WorkerAddress, + BindAddress: config.G.BindAddress, + Dialer: appDialer, + BufferPool: bufferpool.NewPool(32 * 1024), + UDPBind: config.G.UDPBindAddress, + Tunnel: wsTunnel, } - if err := socks5.SendReply(w, statute.RepSuccess, nil); err != nil { - logger.Errorf("failed to send reply: %v", err) - return err - } - - firstPacket := make([]byte, 32*1024) - read, err := req.Reader.Read(firstPacket) - if err != nil { - return err - } - - hostname, firstPacketData, isHTTP, err := s.extractHostnameOrChangeHTTPHostHeader(firstPacket[:read]) - - if hostname != nil { - logger.Infof("Hostname %s", string(hostname)) - } - - IPPort, err := s.resolveDestination(ctx, req) - if err != nil { - return err - } - - // if user has a faulty dns, and it returns dpi ip, - // we resolve destination based on extracted tls sni or http hostname - if hostname != nil && strings.Contains(IPPort, "10.10.3") { - logger.Infof("%s is dpi ip extracting destination host from packets...", IPPort) - req.RawDestAddr.FQDN = string(hostname) - IPPort, err = s.resolveDestination(ctx, req) - if err != nil { - // if destination resolved to dpi and we cant resolve to actual destination - // it's pointless to connect to dpi - logger.Infof("system was unable to extract destination host from packets!") - return err - } - } - - if s.WorkerConfig.WorkerEnabled && - !s.WorkerConfig.WorkerDNSOnly && - (!strings.Contains(s.WorkerConfig.WorkerAddress, req.DstAddr.FQDN) || strings.TrimSpace(req.DstAddr.FQDN) == "") { - req.Reader = &utils.BufferedReader{ - FirstPacketData: firstPacketData, - BufReader: req.Reader, - FirstTime: true, - } - return s.Transport.TunnelTCP(w, req) - } - - firstPacketChunks := make(map[int][]byte) - - if isHTTP || err != nil || hostname == nil { - firstPacketChunks[0] = firstPacketData + if strings.HasPrefix(config.G.RemoteDNSAddr, "https://") { + resolveSystem = "doh" + dohClient = doh.NewClient( + doh.WithDNSFragmentation((config.G.WorkerEnabled && config.G.WorkerDNSOnly) || config.G.EnableDNSFragmentation), + doh.WithDialer(appDialer), + doh.WithLocalResolver(localResolver), + ) } else { - firstPacketChunks = s.getChunkedPackets(firstPacketData, hostname) - } - - logger.Infof("Dialing %s...", IPPort) - - conn, err := s.Dialer.TCPDial("tcp", "", IPPort) - if err != nil { - return err + resolveSystem = "DNSCrypt" + } + + chunkConfig := FragmentConfig{ + BSL: config.G.SniChunksLength, + ASL: config.G.ChunksLengthAfterSni, + Delay: config.G.DelayBetweenChunks, + } + + workerConfig := WorkerConfig{ + WorkerAddress: config.G.WorkerAddress, + WorkerIPPortAddress: config.G.WorkerIPPortAddress, + WorkerEnabled: config.G.WorkerEnabled, + WorkerDNSOnly: config.G.WorkerDNSOnly, + } + + serverHandler := &Server{ + RemoteDNSAddr: config.G.RemoteDNSAddr, + Cache: appCache, + ResolveSystem: resolveSystem, + DoHClient: dohClient, + ChunkConfig: chunkConfig, + WorkerConfig: workerConfig, + BindAddress: config.G.BindAddress, + EnableLowLevelSockets: config.G.EnableLowLevelSockets, + Dialer: appDialer, + LocalResolver: localResolver, + Transport: tunnelTransport, + } + + if captureCTRLC { + c := make(chan os.Signal) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + <-c + _ = ShutDown() + os.Exit(0) + }() + } + + if workerConfig.WorkerEnabled && !workerConfig.WorkerDNSOnly { + s5 = socks5.NewServer( + socks5.WithConnectHandle(func(ctx context.Context, w io.Writer, req *socks5.Request) error { + return serverHandler.HandleTCPTunnel(ctx, w, req) + }), + socks5.WithAssociateHandle(func(ctx context.Context, w io.Writer, req *socks5.Request) error { + return serverHandler.HandleTCPTunnel(ctx, w, req) + }), + ) + } else { + s5 = socks5.NewServer( + socks5.WithConnectHandle(func(ctx context.Context, w io.Writer, req *socks5.Request) error { + return serverHandler.HandleTCPFragment(ctx, w, req) + }), + ) } - defer conn.Close() - if err := conn.SetNoDelay(true); err != nil { - logger.Errorf("failed to set NODELAY option: %v", err) + fmt.Println("Starting socks, http server:", config.G.BindAddress) + if err := s5.ListenAndServe("tcp", config.G.BindAddress); err != nil { return err } - // writing first packet - s.sendSplitChunks(conn, firstPacketChunks) - - // Start proxying - errCh := make(chan error, 2) - go func() { errCh <- s.Copy(req.Reader, conn) }() - go func() { errCh <- s.Copy(conn, w) }() - // Wait - for i := 0; i < 2; i++ { - e := <-errCh - if e != nil { - // return from this function closes target (and conn). - return e - } - } return nil } -func (s *Server) Copy(reader io.Reader, writer io.Writer) error { - buf := make([]byte, 32*1024) - - _, err := io.CopyBuffer(writer, reader, buf[:cap(buf)]) - return err -} - -func (s *Server) resolveDestination(ctx context.Context, req *socks5.Request) (string, error) { - dest := req.RawDestAddr - - if dest.FQDN != "" { - ip, err := s.Resolve(dest.FQDN) - if err != nil { - return "", err - } - dest.IP = net.ParseIP(ip) - logger.Infof("resolved %s to %s", req.RawDestAddr, dest) - } else { - logger.Infof("skipping resolution for %s", req.RawDestAddr) - } - - addr := net.JoinHostPort(dest.IP.String(), strconv.Itoa(dest.Port)) - return addr, nil -} - -// Resolve resolves the FQDN to an IP address using the specified resolution mechanism. -func (s *Server) Resolve(fqdn string) (string, error) { - if s.WorkerConfig.WorkerEnabled && - strings.Contains(s.WorkerConfig.WorkerAddress, fqdn) { - dh, _, err := net.SplitHostPort(s.WorkerConfig.WorkerIPPortAddress) - if strings.Contains(dh, ":") { - // its ipv6 - dh = "[" + dh + "]" - } - if err != nil { - return "", err - } - return dh, nil - } - - if h := s.LocalResolver.CheckHosts(fqdn); h != "" { - return h, nil - } - - if s.ResolveSystem == "doh" { - u, err := url.Parse(s.RemoteDNSAddr) - if err == nil { - if u.Hostname() == fqdn { - return s.LocalResolver.Resolve(u.Hostname()), nil - } - } - } - - // Ensure fqdn ends with a period - if !strings.HasSuffix(fqdn, ".") { - fqdn += "." - } - - // Check the cache for fqdn - if cachedValue, _ := s.Cache.Get(fqdn); cachedValue != nil { - logger.Infof("using cached value for %s", fqdn) - return cachedValue.(string), nil - } - - // Build request message - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{{ - Name: fqdn, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }} - - // Determine which DNS resolution mechanism to use - var exchange *dns.Msg - var err error - switch s.ResolveSystem { - case "doh": - exchange, err = s.resolveDNSWithDOH(&req) - default: - exchange, err = s.resolveDNSWithDNSCrypt(&req) - } - if err != nil { - return "", err - } - // Parse answer and store in cache - answer := exchange.Answer[0] - logger.Infof("resolved %s to %s", fqdn, strings.Replace(answer.String(), "\t", " ", -1)) - record := strings.Fields(answer.String()) - if record[3] == "CNAME" { - ip, err := s.Resolve(record[4]) - if err != nil { - return "", err - } - s.Cache.Set(fqdn, ip) - return ip, nil - } - ip := record[4] - s.Cache.Set(fqdn, ip) - return ip, nil -} - -// resolveDNSWithDOH resolves DNS using DNS-over-HTTP (DoH) client. -func (s *Server) resolveDNSWithDOH(req *dns.Msg) (*dns.Msg, error) { - dnsAddr := s.RemoteDNSAddr - if s.WorkerConfig.WorkerEnabled && s.WorkerConfig.WorkerDNSOnly { - dnsAddr = s.WorkerConfig.WorkerAddress - } - - exchange, _, err := s.DoHClient.Exchange(req, dnsAddr) - if err != nil { - return nil, err - } - if len(exchange.Answer) == 0 { - return nil, fmt.Errorf("no answer") - } - return exchange, nil -} - -// resolveDNSWithDNSCrypt resolves DNS using DNSCrypt client. -func (s *Server) resolveDNSWithDNSCrypt(req *dns.Msg) (*dns.Msg, error) { - c := dnscrypt.Client{ - Net: "tcp", Timeout: 10 * time.Second, - } - resolverInfo, err := c.Dial(s.RemoteDNSAddr) - if err != nil { - return nil, err - } - exchange, err := c.Exchange(req, resolverInfo) - if err != nil { - return nil, err - } - if len(exchange.Answer) == 0 { - return nil, fmt.Errorf("no answer") - } - return exchange, nil +func ShutDown() error { + return s5.Shutdown() } diff --git a/transport/transport.go b/transport/transport.go index e141736..c4f5e17 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -5,10 +5,10 @@ import ( "bepass/bufferpool" "bepass/dialer" "bepass/logger" + "bepass/net/adapter/ws" "bepass/socks5" "bepass/socks5/statute" "bepass/utils" - "bepass/wsconnadapter" "fmt" "io" "net" @@ -19,7 +19,7 @@ import ( type UDPBind struct { Source *net.UDPAddr Destination string - TCPTunnel *wsconnadapter.Adapter + TCPTunnel *ws.Adapter TunnelStatus bool SocksWriter io.Writer SocksReq *socks5.Request @@ -70,8 +70,10 @@ func (t *Transport) TunnelTCP(w io.Writer, req *socks5.Request) error { return err } - conn := wsconnadapter.New(wsConn) - defer conn.Close() + conn := ws.New(wsConn) + defer func() { + _ = conn.Close() + }() if err != nil { return err @@ -80,16 +82,14 @@ func (t *Transport) TunnelTCP(w io.Writer, req *socks5.Request) error { // flush ws stream to write conn.Write([]byte{}) - errCh := make(chan error, 2) + errCh := make(chan error) go func() { errCh <- t.Copy(req.Reader, conn) }() go func() { errCh <- t.Copy(conn, w) }() // Wait - for i := 0; i < 2; i++ { - e := <-errCh - if e != nil { - // return from this function closes target (and conn). - return e - } + e := <-errCh + if e != nil { + // return from this function closes target (and conn). + return e } return nil } diff --git a/transport/ws.go b/transport/ws.go index 8c8665f..cccb91f 100644 --- a/transport/ws.go +++ b/transport/ws.go @@ -2,9 +2,10 @@ package transport import ( + "bepass/config" "bepass/dialer" "bepass/logger" - "bepass/wsconnadapter" + "bepass/net/adapter/ws" "context" "encoding/binary" "net" @@ -12,7 +13,6 @@ import ( "time" "github.com/gorilla/websocket" - "golang.org/x/net/proxy" ) // EstablishedTunnel represents an established tunnel. @@ -33,26 +33,17 @@ type WSTunnel struct { ShortClientID string } -// socks5TCPDial dials using SOCKS5 proxy. -func (w *WSTunnel) socks5TCPDial(_ context.Context, network, addr string) (net.Conn, error) { - d, err := proxy.SOCKS5("tcp", w.BindAddress, nil, proxy.Direct) - if err != nil { - return nil, err - } - return d.Dial(network, addr) -} - // Dial establishes a WebSocket connection. func (w *WSTunnel) Dial(endpoint string) (*websocket.Conn, error) { d := websocket.Dialer{ NetDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return w.socks5TCPDial(ctx, network, addr) + return w.Dialer.HttpDial(network, config.G.WorkerIPPortAddress) }, NetDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return w.Dialer.TLSDial(func(network, addr, hostPort string) (net.Conn, error) { - return w.socks5TCPDial(ctx, network, addr) - }, network, addr, "") + return w.Dialer.TLSDial(func(network, addr string) (net.Conn, error) { + return w.Dialer.FragmentDial(network, config.G.WorkerIPPortAddress) + }, network, addr) }, } conn, _, err := d.Dial(endpoint, nil) @@ -91,7 +82,7 @@ func (w *WSTunnel) PersistentDial(tunnelEndpoint string, bindWriteChannel chan U logger.Infof("connecting to %s\r\n", tunnelEndpoint) c, err := w.Dial(tunnelEndpoint) - conn := wsconnadapter.New(c) + conn := ws.New(c) if err != nil { logger.Errorf("error dialing udp over tcp tunnel: %v\r\n", err)