diff --git a/client/websocket.go b/client/websocket.go new file mode 100644 index 00000000..4241728a --- /dev/null +++ b/client/websocket.go @@ -0,0 +1,109 @@ +package client + +import ( + "crypto/tls" + "errors" + "fmt" + "net" + "net/http" + "strings" + "time" + + onet "github.com/Jigsaw-Code/outline-ss-server/net" + ss "github.com/Jigsaw-Code/outline-ss-server/shadowsocks" + "github.com/Jigsaw-Code/outline-ss-server/websocket" + "github.com/shadowsocks/go-shadowsocks2/socks" +) + +type WebsocketOptions struct { + // Addr is the address of the websocket server. It can either an IP address or a domain name. + Addr string + // Port is the destination port of the websocket connection. + Port int + // Host is the hostname to use in the Host header of HTTP request made to the websocket server. + // If empty, the header will be set to `Addr` if it is a domain name. + Host string + // SNI is the hostname to use in the server name extension of the TLS handshake. If empty, it will be set to `Host`. + SNI string + // Path is the HTTP path to use when connecting to the websocket server. + Path string + // Password is the password to use for the shadowsocks connection tunnelled inside the websocket connection. + Password string + // Ciphter is the cipher to use for the shadowsocks connection tunnelled inside the websocket connection. + Cipher string +} + +// NewWebsocketClient creates a client that routes connections to a Shadowsocks proxy +// tunneled inside a websocket connection. +func NewWebsocketClient(opts WebsocketOptions) (Client, error) { + proxy := opts.Addr + if proxy == "" { + proxy = opts.Host + } + if proxy == "" { + return nil, fmt.Errorf("neither Addr or Host are defined") + } + + ss, err := NewClient(proxy, opts.Port, opts.Password, opts.Cipher) + if err != nil { + return nil, err + } + + if strings.HasPrefix(opts.Path, "/") { + opts.Path = opts.Path[1:] + } + + addrIP := net.ParseIP(opts.Addr) + if opts.Host == "" && addrIP == nil { + opts.Host = opts.Addr + } + + if opts.SNI == "" { + opts.SNI = opts.Host + } + + return &wsClient{ + ssClient: ss.(*ssClient), + opts: opts, + }, nil +} + +type wsClient struct { + *ssClient + opts WebsocketOptions +} + +func (c *wsClient) DialTCP(laddr *net.TCPAddr, raddr string) (onet.DuplexConn, error) { + socksTargetAddr := socks.ParseAddr(raddr) + if socksTargetAddr == nil { + return nil, errors.New("Failed to parse target address") + } + + h := make(http.Header) + if c.opts.Host != "" { + h.Set("Host", c.opts.Host) + } + d := websocket.Dialer{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + ServerName: c.opts.SNI, + }, + HandshakeTimeout: websocket.DefaultHandshakeTimeout, + } + proxyConn, err := d.Dial(fmt.Sprintf("wss://%s:%d/%s", c.proxyIP, c.opts.Port, c.opts.Path), h) + if err != nil { + return nil, err + } + + ssw := ss.NewShadowsocksWriter(proxyConn, c.cipher) + _, err = ssw.LazyWrite(socksTargetAddr) + if err != nil { + proxyConn.Close() + return nil, errors.New("Failed to write target address") + } + time.AfterFunc(helloWait, func() { + ssw.Flush() + }) + ssr := ss.NewShadowsocksReader(proxyConn, c.cipher) + return onet.WrapConn(proxyConn, ssr, ssw), nil +} diff --git a/client/websocket_test.go b/client/websocket_test.go new file mode 100644 index 00000000..dc2af551 --- /dev/null +++ b/client/websocket_test.go @@ -0,0 +1,151 @@ +package client + +import ( + "crypto/tls" + "net" + "net/http" + "testing" + "time" + + onet "github.com/Jigsaw-Code/outline-ss-server/net" + ss "github.com/Jigsaw-Code/outline-ss-server/shadowsocks" + "github.com/Jigsaw-Code/outline-ss-server/websocket" +) + +const ( + testWSPath = "/test" +) + +func TestWebsocketClient(t *testing.T) { + testCases := []struct { + name string + opts WebsocketOptions + wantHost string + wantSNI string + }{ + { + name: "with_ip_host", + opts: WebsocketOptions{Addr: "127.0.0.1", Host: "example.com"}, + wantHost: "example.com", + wantSNI: "example.com", + }, + { + name: "with_ip_host_sni", + opts: WebsocketOptions{Addr: "127.0.0.1", Host: "example.com", SNI: "sni.com"}, + wantHost: "example.com", + wantSNI: "sni.com", + }, + { + name: "with_domain", + opts: WebsocketOptions{Addr: "localhost"}, + wantHost: "localhost", + wantSNI: "localhost", + }, + { + name: "with_domain_host", + opts: WebsocketOptions{Addr: "localhost", Host: "example.com"}, + wantHost: "example.com", + wantSNI: "example.com", + }, + { + name: "with_domain_host_sni", + opts: WebsocketOptions{Addr: "localhost", Host: "example.com", SNI: "sni.com"}, + wantHost: "example.com", + wantSNI: "sni.com", + }, + } + + proxy, hostCh, sniCh := startWebsocketShadowsocksEchoProxy(t) + defer close(hostCh) + defer close(sniCh) + defer proxy.Close() + _, proxyPort, err := splitHostPortNumber(proxy.Addr().String()) + if err != nil { + t.Fatalf("Failed to parse proxy address: %v", err) + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.opts.Password = testPassword + tc.opts.Cipher = ss.TestCipher + tc.opts.Port = proxyPort + tc.opts.Path = testWSPath + + d, err := NewWebsocketClient(tc.opts) + if err != nil { + t.Fatalf("Failed to create WebsocketClient: %v", err) + } + conn, err := d.DialTCP(nil, testTargetAddr) + if err != nil { + t.Fatalf("WebsocketClient.DialTCP failed: %v", err) + } + + select { + case sni := <-sniCh: + if sni != tc.wantSNI { + t.Fatalf("Wrong server name in TLS handshake server. got='%v' want='%v'", sni, tc.wantSNI) + } + case <-time.After(50 * time.Millisecond): + t.Fatal("TLS connection state not recevied") + } + select { + case host := <-hostCh: + if host != tc.wantHost { + t.Fatalf("Wrong host header. got='%v' want='%v'", host, tc.wantHost) + } + case <-time.After(50 * time.Millisecond): + t.Fatal("HTTP request not recevied") + } + + conn.SetReadDeadline(time.Now().Add(time.Second * 5)) + expectEchoPayload(conn, ss.MakeTestPayload(1024), make([]byte, 1024), t) + conn.Close() + }) + } +} + +func startWebsocketShadowsocksEchoProxy(t *testing.T) (net.Listener, chan string, chan string) { + proxy, _ := startShadowsocksTCPEchoProxy(testTargetAddr, t) + + hostCh := make(chan string, 1) + sniCh := make(chan string, 1) + + handler := func(w http.ResponseWriter, r *http.Request) { + u := websocket.Upgrader{HandshakeTimeout: 50 * time.Millisecond} + c, err := u.Upgrade(w, r, nil) + defer c.Close() + + hostCh <- r.Host + + if r.URL.Path != testWSPath { + t.Logf("Wrong Path received on request. got='%v' want='%v'", testWSPath, r.URL.Path) + return + } + + targetC, err := net.Dial("tcp", proxy.Addr().String()) + if err != nil { + t.Logf("Failed to connect to TCP echo server: %v", err) + return + } + + onet.Relay(c, targetC.(*net.TCPConn)) + } + + l, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + if err != nil { + t.Fatalf("Starting websocket listener failed: %v", err) + } + + go func() { + srv := &http.Server{Handler: http.HandlerFunc(handler)} + srv.TLSConfig = &tls.Config{ + VerifyConnection: func(cs tls.ConnectionState) error { + sniCh <- cs.ServerName + return nil + }, + } + srv.ServeTLS(l, websocket.TestCert, websocket.TestKey) + }() + + return l, hostCh, sniCh +} diff --git a/go.mod b/go.mod index 3918924d..b090947d 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,7 @@ module github.com/Jigsaw-Code/outline-ss-server require ( github.com/goreleaser/goreleaser v1.12.3 + github.com/gorilla/websocket v1.5.0 github.com/op/go-logging v0.0.0-20160315200505-970db520ece7 github.com/oschwald/geoip2-golang v1.8.0 github.com/prometheus/client_golang v1.13.0 @@ -104,7 +105,6 @@ require ( github.com/goreleaser/chglog v0.2.2 // indirect github.com/goreleaser/fileglob v1.3.0 // indirect github.com/goreleaser/nfpm/v2 v2.20.0 // indirect - github.com/gorilla/websocket v1.5.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-retryablehttp v0.7.1 // indirect github.com/hashicorp/go-version v1.2.1 // indirect diff --git a/server.go b/server.go index c3c04121..d7709a3c 100644 --- a/server.go +++ b/server.go @@ -212,19 +212,25 @@ func readConfig(filename string) (*Config, error) { func main() { var flags struct { - ConfigFile string - MetricsAddr string - IPCountryDB string - natTimeout time.Duration - replayHistory int - Verbose bool - Version bool + ConfigFile string + MetricsAddr string + IPCountryDB string + natTimeout time.Duration + replayHistory int + Verbose bool + Version bool + WebsocketServer bool + TLSCert string + TLSKey string } flag.StringVar(&flags.ConfigFile, "config", "", "Configuration filename") flag.StringVar(&flags.MetricsAddr, "metrics", "", "Address for the Prometheus metrics") flag.StringVar(&flags.IPCountryDB, "ip_country_db", "", "Path to the ip-to-country mmdb file") flag.DurationVar(&flags.natTimeout, "udptimeout", defaultNatTimeout, "UDP tunnel timeout") flag.IntVar(&flags.replayHistory, "replay_history", 0, "Replay buffer size (# of handshakes)") + flag.BoolVar(&flags.WebsocketServer, "websocket", false, "Enables the websocket serve") + flag.StringVar(&flags.TLSCert, "tls-cert", "ssl.crt", "Path to tls certificate to use for the websocket server") + flag.StringVar(&flags.TLSKey, "tls-key", "ssl.key", "Path to tls key to use for the websocket server") flag.BoolVar(&flags.Verbose, "verbose", false, "Enables verbose logging output") flag.BoolVar(&flags.Version, "version", false, "The version of the server") @@ -266,11 +272,19 @@ func main() { } m := metrics.NewPrometheusShadowsocksMetrics(ipCountryDB, prometheus.DefaultRegisterer) m.SetBuildInfo(version) - _, err = RunSSServer(flags.ConfigFile, flags.natTimeout, m, flags.replayHistory) + ssServer, err := RunSSServer(flags.ConfigFile, flags.natTimeout, m, flags.replayHistory) if err != nil { logger.Fatal(err) } + if flags.WebsocketServer { + if flags.TLSCert == "" || flags.TLSKey == "" { + log.Fatalln("TLS cert and key not specified") + flag.Usage() + } + RunWebsocketServer(ssServer, 443, flags.TLSCert, flags.TLSKey) + } + sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) <-sigCh diff --git a/service/tcp.go b/service/tcp.go index bbee770b..e1d939f1 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -143,6 +143,8 @@ type TCPService interface { Stop() error // GracefulStop calls Stop(), and then blocks until all resources have been cleaned up. GracefulStop() error + // HandleConnection takes a shadowsocks client connection and starts a relay to the destination address. + HandleConnection(listenerPort int, clientTCPConn onet.DuplexConn) } func (s *tcpService) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) { @@ -207,12 +209,12 @@ func (s *tcpService) Serve(listener *net.TCPListener) error { logger.Errorf("Panic in TCP handler: %v", r) } }() - s.handleConnection(listener.Addr().(*net.TCPAddr).Port, clientTCPConn) + s.HandleConnection(listener.Addr().(*net.TCPAddr).Port, clientTCPConn) }() } } -func (s *tcpService) handleConnection(listenerPort int, clientTCPConn *net.TCPConn) { +func (s *tcpService) HandleConnection(listenerPort int, clientTCPConn onet.DuplexConn) { clientLocation, err := s.m.GetLocation(clientTCPConn.RemoteAddr()) if err != nil { logger.Warningf("Failed location lookup: %v", err) @@ -221,7 +223,9 @@ func (s *tcpService) handleConnection(listenerPort int, clientTCPConn *net.TCPCo s.m.AddOpenTCPConnection(clientLocation) connStart := time.Now() - clientTCPConn.SetKeepAlive(true) + if tcp, ok := clientTCPConn.(*net.TCPConn); ok { + tcp.SetKeepAlive(true) + } // Set a deadline to receive the address to the target. clientTCPConn.SetReadDeadline(connStart.Add(s.readTimeout)) var proxyMetrics metrics.ProxyMetrics diff --git a/ssl.crt b/ssl.crt new file mode 100644 index 00000000..3973a86f --- /dev/null +++ b/ssl.crt @@ -0,0 +1,27 @@ +-----BEGIN CERTIFICATE----- +MIIElDCCAnwCCQDlCwAerg6zUzANBgkqhkiG9w0BAQsFADAMMQowCAYDVQQDDAEq +MB4XDTIyMTAzMTAyMDAxMloXDTI1MDcyNzAyMDAxMlowDDEKMAgGA1UEAwwBKjCC +AiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAOV0rN2wcoWuwHtvYyjKO3NL +WQSCB1u/bDrJYXEFCjCaQFuEVWjW44vHwBnTanWPt2TcEN8ZZ7Nd3fku8ZN7I7i9 +3CPRUDatBD/IE6s90JY0dR8i4ZavOorTD8D8Xg5ioh9KdyEnEmgZpDjzENBEJQZa +jUMytEwM5BJYO1QY6TInRbC0XCigouPBdhHR88kc5e/+tAqY7UC5x48CkLIcDAxQ +IzPBdGWheSPjLYLn6XM0QkatZK6/7q40f03BMt29AVvwN3XgU9xBwr4Iij5nC3AV +RfwBbySx+d+5rIxfsm6L5xOP0Zm5IueWNyCa8JDYrxdOf35bcm9VdPvay0IDp4Dr +gTqWRvLWzBCmsXKdOIDo3O0W3UcSfzjyiue+VKNJFOFEpS02RfqaQZ4k+YpS1QIz +3CyzDb7GVq4gNTO7P4hcMXvWxJPqoA6QYAeboHMQr695ucdjF4hebbVMNajXndhf +BPmvAGKZivUn/PZND1vTj331mqWUMQTJOyvUVv+ZpEobhr3GQvScbFw5zLOaR5ud +LAxgmtcyspCC4MQhu7bNPtnP5jhBXmdiNf7bbUQFgnaGDrNKELMIUeY3/TLAs6cv +U+ZbrbofL3eGMZtCeH/7izZxX2cS1OxcHNesqOU7+ZLznbu7AXGrDo6lH/4VbC44 +YaKbn4oxfplUpQcmWWvJAgMBAAEwDQYJKoZIhvcNAQELBQADggIBABdLOenVrN4/ +D2NZJiD027LVcxnWTEjCCMgoaZ1eeQdaHlpQueL1hpZDTnFCZEsbnp77GFqXhNt4 +lnUgF4n8JmFoqR39MCu76k7VlndLTG2aMPOrc6zfe2JUeaC4Q6/BwothMu1xuz6P +kbWResrXVIbdVH+NPqN3SFzye8MQzBkNcgiNRY9syzaPXUD3OfOYT6xnUU6orqsX +LWnCuakRK9YlaK9X6BdPen12wyAbKg+M4eUMTnC/VTRFjHz8H8CnYjk/bzoxUQZG +QV9XK+2dw0A0MLNXiWtoUmemS/ty+tSMnvEMdfXyskJcP3qVbywp4G4E1cUiHw9z +e+0xYZQnaX0I6ztZWliK8ELHEucS+M20VODVUEryKl/zpqq7Rpx16T4VRrnqtYj2 +MqGLWhjUvDRFsuDaVCTbP8oZS3CrR6p0pfHqatYXcRY8E1YeMaGBMz5DKpgEhbjo +2x6md9E7LWVRa4FQu0N/V5xlTUSGUzzgLAXIXhLiKT0B+FutG914zN8z6AOQ7DMY +IhhosVG535/mL8AnPAoDorOSz5Vk1OxPWYCLFiUZK8rqcP3+z0xe4S/bsoGaz7Oc +J/LzGxaanEKi53lS3zg/N4QYWmWe++l59fZVPa0HiAPsMCoHnOGPuhnULzGeGuPJ ++bLSo9WYo8HwVFk6RHeKkJcDoYRQIyO+ +-----END CERTIFICATE----- diff --git a/ssl.key b/ssl.key new file mode 100644 index 00000000..857afbfb --- /dev/null +++ b/ssl.key @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJRAIBADANBgkqhkiG9w0BAQEFAASCCS4wggkqAgEAAoICAQDldKzdsHKFrsB7 +b2MoyjtzS1kEggdbv2w6yWFxBQowmkBbhFVo1uOLx8AZ02p1j7dk3BDfGWezXd35 +LvGTeyO4vdwj0VA2rQQ/yBOrPdCWNHUfIuGWrzqK0w/A/F4OYqIfSnchJxJoGaQ4 +8xDQRCUGWo1DMrRMDOQSWDtUGOkyJ0WwtFwooKLjwXYR0fPJHOXv/rQKmO1AuceP +ApCyHAwMUCMzwXRloXkj4y2C5+lzNEJGrWSuv+6uNH9NwTLdvQFb8Dd14FPcQcK+ +CIo+ZwtwFUX8AW8ksfnfuayMX7Jui+cTj9GZuSLnljcgmvCQ2K8XTn9+W3JvVXT7 +2stCA6eA64E6lkby1swQprFynTiA6NztFt1HEn848ornvlSjSRThRKUtNkX6mkGe +JPmKUtUCM9wssw2+xlauIDUzuz+IXDF71sST6qAOkGAHm6BzEK+vebnHYxeIXm21 +TDWo153YXwT5rwBimYr1J/z2TQ9b04999ZqllDEEyTsr1Fb/maRKG4a9xkL0nGxc +OcyzmkebnSwMYJrXMrKQguDEIbu2zT7Zz+Y4QV5nYjX+221EBYJ2hg6zShCzCFHm +N/0ywLOnL1PmW626Hy93hjGbQnh/+4s2cV9nEtTsXBzXrKjlO/mS8527uwFxqw6O +pR/+FWwuOGGim5+KMX6ZVKUHJllryQIDAQABAoICAAOl8UGtFoUNnD3aLYduf7d7 +kTTDJH7O8leU8Bmt7NWM/kz2M61xDTkhueovNFgeKtpNrW7+pmlxqp/VoT2pDY5Y +ZnGjWFUmNxUUh0uHthNNTjdqhI+yxYmDhZKZ8Jzl8JHyyyYZyu8gyT2mj7PgAX6y +XeCdo8Q5yD6KbJcPtlV3zmHa3ERBGZXpc4kg/3FJJlbEg/RPLiaDTar2bXqHe6GO +fKDMCJ+9C4IIkKauLUYJpKwfAaTNpGvcpdGEqtxfru/ZR+h14p9z5DbFR/1qAgKM +NAqnsy6wLbri5t1sgBfF3ayv8rMxAF8SQlogXIbRCyehteE6bv1aLHv8pJKuIDGi +12iNvD+0JC9wJM2Lw4fudH0i10KOIaJmkMEJ9/HfRrDU++8v2PXzo2JMwFAnsD2u +OZGM+q8FEGVvIBmF+AT4/tifd6V3fRRLbvMMy0BWcFyI2+/FiF3SRd5j3MT6Vb7e +Jqd4tsJnNZHW3NBKf/szoCv7QcP8QDNvWDDmUbPzRzPgtIk9XIx8zKYRbAvv0ZI1 ++XNfnIKXhyMXT9YpLYTs608EzhemKsIgvfASwtc1X/BCqOH7tu9pjPKFTxcm0pjT +mjaZN/zgsDwX866FT4ZcqSAxASlhovjJnIFtLs5Ju8DaolQ2jaWzt6sRzWVBR1zL +3VUKhqCt1FvYYJx/0vhpAoIBAQD4EHcDU//6NAG1XjyCOZ8wBuo60WwB2X/oyNN9 +5fgyDLdHuH+ANHqoOIlHFK4RMg1gPCfRN12Hk7JIbnwjR9Bas28fXBlGR/kyUWhf +rvpFMjNCOJNhgbj5kQZOJhEc0Ncq6BWLhE3qi3mMw5resr0lXGy7Rby9P2KmOAq4 +3md3c7YFw6QElnALT9cQGdXdMusAfuFGUDnbiv+4DaPGp19hFM+YR2BmGHUEyGcc +g8iv/J0r2UYRRsq4GE2YQ+w2ANyZ+IEMevSP1ZOX1EXrvXwFf3FHzF5gVeoOxQ7Y +10/w0U9l5NQ4a8LHixG9HIeF1ySYorj3k6Lcf+00cs/+Wqf/AoIBAQDsy9CPT/Q9 +a3KbxMrZYTqpgB+F9eSM6mIEj5DhW4tXwOPNkMBvxGe4tZCoL1MqcfmgN2n35CNl +Oz7JJl+fc0+Wa3i9PQJ4eog4aGwyCr6k/fur14ctSFVhAd0/IJV+L3kfGgklZLZZ +ge1/DlM/H8OhFixkpqlzJscZxKqQRHGQ9Fkk5IbdPf+RUtkrJQUUqg6S7q4JH1Mi +pTPQfkALjqpa6LQzaqWsfmGJJKu4TL6ZDJJy/MMQRXQ7f9fyXkKh2q2LRjfAjx/N ++OZXkkjPmDWLDgSEJN+30HiTyA00EZcgZggA5uJhNNJsRGZvBibLYHF9z2V/xFwp +/cUBIM1oAKw3AoIBAQCj6P06zb5ObR7T4LjKs5hj+625v7dGYZkLD+fvQI2HRK+2 +TEqzQ/noPbM3rIp4AkKkXBtTOuoqM4WSJq8QANvDktzSM+Dfd59JiFEXKF9maY1F +LGz1+OlovlMUQEL+b2A9kazqyzlQyWg/guBKVoB0t2WBOMtFoSRmAJHVJd/oJiUY +GfW+skjGsLLCiM+voX12jl/8PfZ9ApOF4j1dfiqf00h4rnEcBP0Nc/3t8YYiAyE3 +YBHUSJqamjRrcDYcWOVrN7DNtlDy2YT0xeaNpl7UoykO8BNMRHir2bm9vkesMCHu +ig1QWqQRherqsnc6ELa1xI/Dx2HNoRnzlgmpX+2xAoIBAQCCBgBFUS/ZsfBCnDKO +Xpcpj5K/qh+PSPv9aR+yvuOqkd4EeGFSfdQ+VmRSFXpjKiZZ1VO4rGrLIVb+eLW1 +BkpDXEv2DVQX96Bo6N3QNJouWtAgsb4mHTvUgoOMMEYl/cdSTqeLAtwmFfPk+ma3 +mKeBAn3p3qHY+wgEnDrT8OEzKRjx6xIq1epJT+azjCZYDHDoOWsS00KBGZlz+H8O +WY4tUO3x9bN3HgZMmfg4wNs/ium3fhdWDe0e5roa+as42KzGdw4SDAT4wp0opMia +RQfRjSbpsJ2vfydWbljhqG8FeUEXza+slKaekIh2mjgfIJvw6zreh2HcJN5SGkLv +wr7NAoIBAQDTuakq+C5olCpr637kV5wljn1rllmkiOB+PeWJT9HhxzmGrlMdWTZp +EDU2+isRIX1oUp/8VX5/P2unacHKpHqAr8Umx67bz502AM9w5WJqRwZYXThscW6B +1UaQUicpvhmvyBbFJKmo2RLOef3AmybZ5wUobjlcDIadzpc47kB5epWwSeqTs7Uj +74LFhxgxsZk+/QsiyWevXNvUy95aBNu5bAIibcm8rrVOe+yhIJvQjPW4FTbe8HE9 +U6VylhNh4e7yuTOfwpqaYyv5zYNUW5PodTadqqNinRvZjzgkJJ95DP0cquZMPS79 +zBuok4kSfhHNHGrpe845HaOzNPv/SCcu +-----END PRIVATE KEY----- diff --git a/websocket.go b/websocket.go new file mode 100644 index 00000000..c3b2f98e --- /dev/null +++ b/websocket.go @@ -0,0 +1,59 @@ +package main + +import ( + "net" + "net/http" + "strconv" + "strings" + + "github.com/Jigsaw-Code/outline-ss-server/websocket" +) + +func RunWebsocketServer(ssServer *SSServer, port int, certPath, keyPath string) (*wsServer, error) { + ws := &wsServer{ss: ssServer} + l, err := net.ListenTCP("tcp", &net.TCPAddr{Port: port}) + if err != nil { + return nil, err + } + logger.Infof("Websocket server listening on %s", l.Addr()) + ws.listener = l + go func() { + err = http.ServeTLS(l, http.HandlerFunc(ws.handleRequest), certPath, keyPath) + if err != nil { + logger.Errorf("Websocket server closed: %v", err) + } + }() + return ws, nil +} + +type wsServer struct { + ss *SSServer + listener net.Listener +} + +func (s *wsServer) Stop() error { + return s.listener.Close() +} + +func (s *wsServer) handleRequest(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Upgrade(w, r, nil) + if err != nil { + logger.Errorf(err.Error()) + return + } + defer c.Close() + + p, err := strconv.Atoi(strings.Trim(r.URL.Path, "/")) + if err != nil { + logger.Errorf("Invalid path %s", r.URL.Path) + return + } + + port, ok := s.ss.ports[p] + if !ok { + logger.Errorf("Port %d does not exist", p) + return + } + + port.tcpService.HandleConnection(p, c) +} diff --git a/websocket/websocket.go b/websocket/websocket.go new file mode 100644 index 00000000..788b19be --- /dev/null +++ b/websocket/websocket.go @@ -0,0 +1,115 @@ +package websocket + +import ( + "fmt" + "io" + "net/http" + "time" + + onet "github.com/Jigsaw-Code/outline-ss-server/net" + "github.com/gorilla/websocket" +) + +var ( + DefaultHandshakeTimeout = 5 * time.Second + defaultDialer = Dialer{ + HandshakeTimeout: DefaultHandshakeTimeout, + } + defaultUpgrade = Upgrader{ + HandshakeTimeout: DefaultHandshakeTimeout, + } +) + +type Dialer websocket.Dialer + +func Dial(u string, h http.Header) (onet.DuplexConn, error) { + return defaultDialer.Dial(u, h) +} + +func (d *Dialer) Dial(u string, h http.Header) (onet.DuplexConn, error) { + wd := websocket.Dialer(*d) + ws, _, err := wd.Dial(u, h) + if err != nil { + return nil, err + } + return wrapWS(ws), err +} + +type Upgrader websocket.Upgrader + +func Upgrade(w http.ResponseWriter, r *http.Request, h http.Header) (onet.DuplexConn, error) { + return defaultUpgrade.Upgrade(w, r, h) +} + +func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, h http.Header) (onet.DuplexConn, error) { + wu := websocket.Upgrader(*u) + c, err := wu.Upgrade(w, r, h) + if err != nil { + return nil, fmt.Errorf("upgrading websocket connection failed: %v", err) + } + return wrapWS(c), nil +} + +type wsWrapper struct { + *websocket.Conn + r io.Reader + readClosed bool + writeClosed bool +} + +func wrapWS(c *websocket.Conn) *wsWrapper { + ws := &wsWrapper{Conn: c} + c.SetCloseHandler(ws.closeHandler) + return ws +} + +func (c *wsWrapper) Write(p []byte) (int, error) { + return len(p), c.WriteMessage(websocket.BinaryMessage, p) +} + +func (c *wsWrapper) Read(p []byte) (n int, err error) { + defer func() { + if websocket.IsCloseError(err, websocket.CloseNormalClosure) { + err = io.EOF + } + }() + if c.r == nil { + if c.readClosed { + return 0, io.EOF + } + var err error + _, c.r, err = c.Conn.NextReader() + if err != nil { + return 0, err + } + } + n, err = c.r.Read(p) + if err == io.EOF && !c.readClosed { + c.r = nil + return c.Read(p) + } + return n, err +} + +func (c *wsWrapper) CloseRead() error { + c.readClosed = true + return nil +} + +func (c *wsWrapper) closeHandler(code int, text string) error { + return c.CloseRead() +} + +func (c *wsWrapper) CloseWrite() error { + c.writeClosed = true + message := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + return c.WriteControl(websocket.CloseMessage, message, time.Now().Add(time.Second)) +} + +func (c *wsWrapper) Close() error { + return c.Conn.Close() +} + +func (c *wsWrapper) SetDeadline(t time.Time) error { + return c.Conn.UnderlyingConn().SetDeadline(t) +} diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go new file mode 100644 index 00000000..29b624b6 --- /dev/null +++ b/websocket/websocket_test.go @@ -0,0 +1,107 @@ +package websocket + +import ( + "bytes" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "strconv" + "testing" + "time" + + onet "github.com/Jigsaw-Code/outline-ss-server/net" + ss "github.com/Jigsaw-Code/outline-ss-server/shadowsocks" +) + +func TestWebsocket(t *testing.T) { + l, err := net.ListenTCP("tcp", nil) + if err != nil { + t.Fatalf("Starting listener failed: %v", err) + } + defer l.Close() + + connCh := make(chan onet.DuplexConn) + defer close(connCh) + handler := func(w http.ResponseWriter, r *http.Request) { + u := Upgrader{HandshakeTimeout: 50 * time.Millisecond} + c, err := u.Upgrade(w, r, nil) + if err != nil { + t.Logf("Upgrading websocket failed: %v", err) + } + connCh <- c + } + go func() { + http.ServeTLS(l, http.HandlerFunc(handler), TestCert, TestKey) + }() + + d := Dialer{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + HandshakeTimeout: 50 * time.Millisecond, + } + clientConn, err := d.Dial(fmt.Sprintf("wss://127.0.0.1:%d/", addrPort(t, l.Addr())), nil) + if err != nil { + t.Fatalf("Connecting to websocket server failed: %v", err) + } + + var serverConn onet.DuplexConn + select { + case <-time.After(50 * time.Millisecond): + t.Fatal("Websocket connection not accepted") + case serverConn = <-connCh: + } + + testOneWay := func(left, right onet.DuplexConn) { + payload := ss.MakeTestPayload(1200) + n, err := left.Write(payload) + if err != nil { + t.Fatalf("Writing payload failed: %v", err) + } else if n != len(payload) { + t.Fatalf("Write(), want=%d got=%d", len(payload), n) + } + + right.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) + b := make([]byte, 500) + for i := 0; i < 3; i++ { + n, err = right.Read(b) + if err != nil { + t.Fatalf("Reading payload failed: %v", err) + } + if bytes.Compare(b[:n], payload[len(b)*i:len(b)*i+n]) != 0 { + t.Fatal("Read payload does not match write payload") + } + + // Close write to show connection can still be drained afterwards + if i == 0 { + left.CloseWrite() + } + } + + n, err = right.Read(b) + if err != io.EOF { + t.Fatalf("Read after finish has no error: want=EOF got=%v", err) + } + right.CloseRead() + + _, err = left.Write(payload) + if err == nil { + t.Fatalf("Write after close: want=err got=%v", err) + } + } + + testOneWay(clientConn, serverConn) + testOneWay(serverConn, clientConn) +} + +func addrPort(t *testing.T, a net.Addr) int { + _, p, err := net.SplitHostPort(a.String()) + if err != nil { + t.Fatalf(err.Error()) + } + port, err := strconv.Atoi(p) + if err != nil { + t.Fatalf(err.Error()) + } + return port +} diff --git a/websocket/websocket_testing.go b/websocket/websocket_testing.go new file mode 100644 index 00000000..f542cbc9 --- /dev/null +++ b/websocket/websocket_testing.go @@ -0,0 +1,19 @@ +package websocket + +import ( + "path" + "path/filepath" + "runtime" +) + +var ( + TestCert string + TestKey string +) + +func init() { + _, filename, _, _ := runtime.Caller(0) + cwd := filepath.Dir(filepath.Dir(filename)) + TestCert = path.Join(cwd, "ssl.crt") + TestKey = path.Join(cwd, "ssl.key") +} diff --git a/websocket_test.go b/websocket_test.go new file mode 100644 index 00000000..5c469e88 --- /dev/null +++ b/websocket_test.go @@ -0,0 +1,131 @@ +package main + +import ( + "bytes" + "crypto/tls" + "fmt" + "net" + "strconv" + "sync" + "testing" + "time" + + onet "github.com/Jigsaw-Code/outline-ss-server/net" + "github.com/Jigsaw-Code/outline-ss-server/service" + "github.com/Jigsaw-Code/outline-ss-server/shadowsocks" + "github.com/Jigsaw-Code/outline-ss-server/websocket" +) + +func TestRunWebsocketServer(t *testing.T) { + ss := &SSServer{ports: make(map[int]*ssPort)} + ws, err := RunWebsocketServer(ss, 0, websocket.TestCert, websocket.TestKey) + if err != nil { + t.Fatalf("Failed running websocket server: %v", err) + } + + testPort := 2000 + port := &ssPort{cipherList: service.NewCipherList()} + ss.ports[testPort] = port + tcpService := &fakeTCPService{connCh: make(chan onet.DuplexConn)} + port.tcpService = tcpService + + t.Run("with registered port", func(t *testing.T) { + d := websocket.Dialer{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + HandshakeTimeout: 50 * time.Millisecond, + } + u := fmt.Sprintf("wss://127.0.0.1:%d/%d", addrPort(t, ws.listener.Addr()), testPort) + clientConn, err := d.Dial(u, nil) + if err != nil { + t.Errorf("Failed to connect to websocket server: %v", err) + } + + var serverConn onet.DuplexConn + select { + case <-time.After(50 * time.Millisecond): + t.Fatal("Failed to receive connection on server") + case serverConn = <-tcpService.connCh: + } + defer tcpService.running.Done() + + payload := shadowsocks.MakeTestPayload(1000) + b := make([]byte, 1024) + + _, err = clientConn.Write(payload) + if err != nil { + t.Errorf("Writing payload failed: %v", err) + } + n, err := serverConn.Read(b) + if err != nil { + t.Fatalf("Reading payload failed: %v", err) + } else if bytes.Compare(payload, b[:n]) != 0 { + t.Fatal("Read payload does not match write payload") + } + + _, err = serverConn.Write(payload) + if err != nil { + t.Errorf("Writing payload failed: %v", err) + } + n, err = clientConn.Read(b) + if err != nil { + t.Fatalf("Reading payload failed: %v", err) + } else if bytes.Compare(payload, b[:n]) != 0 { + t.Fatal("Read payload does not match write payload") + } + }) + + t.Run("with non registered port", func(t *testing.T) { + d := websocket.Dialer{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + HandshakeTimeout: 50 * time.Millisecond, + } + u := fmt.Sprintf("wss://127.0.0.1:%d/%d", addrPort(t, ws.listener.Addr()), 3456) + _, err := d.Dial(u, nil) + if err != nil { + t.Fatalf("Failed to connect to websocket server: %v", err) + } + + select { + case <-tcpService.connCh: + t.Fatalf("Expected not to receive connection on non existing port, but received one") + case <-time.After(50 * time.Millisecond): + } + }) +} + +type fakeTCPService struct { + connCh chan onet.DuplexConn + running sync.WaitGroup +} + +func (f *fakeTCPService) HandleConnection(listenerPort int, clientTCPConn onet.DuplexConn) { + f.running.Add(1) + f.connCh <- clientTCPConn + f.running.Wait() +} + +func (f *fakeTCPService) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) {} + +func (f *fakeTCPService) Serve(listener *net.TCPListener) error { + return nil +} + +func (f *fakeTCPService) Stop() error { + return nil +} + +func (f *fakeTCPService) GracefulStop() error { + return nil +} + +func addrPort(t *testing.T, a net.Addr) int { + _, p, err := net.SplitHostPort(a.String()) + if err != nil { + t.Fatalf(err.Error()) + } + port, err := strconv.Atoi(p) + if err != nil { + t.Fatalf(err.Error()) + } + return port +}