diff --git a/context.go b/context.go index eca887b..b71d77e 100644 --- a/context.go +++ b/context.go @@ -85,8 +85,8 @@ func (ctx *Context) UseFunc(fns ...MiddlewareFunc) { func (ctx *Context) Next(req *http.Request) (*http.Response, error) { var ( - err error total = len(ctx.middlewares) + err error ) ctx.mi++ if ctx.mi >= total { diff --git a/forward_handler.go b/forward_handler.go index 1c97528..177b797 100644 --- a/forward_handler.go +++ b/forward_handler.go @@ -1,26 +1,32 @@ package mps import ( + "bytes" "io" + "io/ioutil" "net/http" ) +// The forward proxy type. Implements http.Handler. type ForwardHandler struct { Ctx *Context } +// Create a ForwardHandler func NewForwardHandler() *ForwardHandler { return &ForwardHandler{ Ctx: NewContext(), } } +// func NewForwardHandlerWithContext(ctx *Context) *ForwardHandler { return &ForwardHandler{ Ctx: ctx, } } +// Standard net/http function. You can use it alone func (forward *ForwardHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // Copying a Context preserves the Transport, Middleware ctx := forward.Ctx.Copy() @@ -34,12 +40,16 @@ func (forward *ForwardHandler) ServeHTTP(rw http.ResponseWriter, req *http.Reque resp, err := ctx.Next(req) if err != nil { - http.Error(rw, err.Error(), 500) + http.Error(rw, err.Error(), 502) return } - origBody := resp.Body - defer origBody.Close() + bodyRes, err := ioutil.ReadAll(resp.Body) + if err != nil { + http.Error(rw, err.Error(), 502) + return + } + resp.Body.Close() // http.ResponseWriter will take care of filling the correct response length // Setting it now, might impose wrong value, contradicting the actual new @@ -47,13 +57,20 @@ func (forward *ForwardHandler) ServeHTTP(rw http.ResponseWriter, req *http.Reque // We keep the original body to remove the header only if things changed. // This will prevent problems with HEAD requests where there's no body, yet, // the Content-Length header should be set. - if origBody != resp.Body { + if resp.ContentLength != int64(len(bodyRes)) { resp.Header.Del("Content-Length") } + copyHeaders(rw.Header(), resp.Header, forward.Ctx.KeepDestinationHeaders) rw.WriteHeader(resp.StatusCode) - io.Copy(rw, resp.Body) - resp.Body.Close() + + body := ioutil.NopCloser(bytes.NewReader(bodyRes)) + _, err = io.Copy(rw, body) + body.Close() + if err != nil { + http.Error(rw, err.Error(), 502) + return + } } func (forward *ForwardHandler) Transport() *http.Transport { diff --git a/http_proxy.go b/http_proxy.go index caf7a45..205ad50 100644 --- a/http_proxy.go +++ b/http_proxy.go @@ -7,6 +7,7 @@ import ( "net/http" ) +// The basic proxy type. Implements http.Handler. type HttpProxy struct { // HTTPS requests use the TunnelHandler proxy by default HttpsHandler http.Handler @@ -14,6 +15,9 @@ type HttpProxy struct { // HTTP requests use the ForwardHandler proxy by default HttpHandler http.Handler + // HTTP requests use the ReverseHandler proxy by default + ReverseHandler http.Handler + Ctx *Context } @@ -27,32 +31,50 @@ func NewHttpProxy() *HttpProxy { HttpHandler: &ForwardHandler{Ctx: ctx}, // default HTTPS proxy HttpsHandler: &TunnelHandler{Ctx: ctx}, + // default Reverse proxy + ReverseHandler: &ReverseHandler{Ctx: ctx}, } } +// Standard net/http function. func (proxy *HttpProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if req.Method == http.MethodConnect { proxy.HttpsHandler.ServeHTTP(rw, req) } - proxy.HttpHandler.ServeHTTP(rw, req) + + if !req.URL.IsAbs() { + proxy.ReverseHandler.ServeHTTP(rw, req) + } else { + proxy.HttpHandler.ServeHTTP(rw, req) + } } +// Use registers an Middleware to proxy func (proxy *HttpProxy) Use(middleware ...Middleware) { proxy.Ctx.Use(middleware...) } +// UseFunc registers an MiddlewareFunc to proxy func (proxy *HttpProxy) UseFunc(fus ...MiddlewareFunc) { proxy.Ctx.UseFunc(fus...) } -func (proxy *HttpProxy) OnRequest(filter ...Filter) *ReqCondition { - return &ReqCondition{proxy: proxy, filters: filter} +// OnRequest filter requests through Filters +func (proxy *HttpProxy) OnRequest(filters ...Filter) *ReqCondition { + return &ReqCondition{ctx: proxy.Ctx, filters: filters} } -func (proxy *HttpProxy) OnResponse(filter ...Filter) *RespCondition { - return &RespCondition{proxy: proxy, filters: filter} +// OnResponse filter response through Filters +func (proxy *HttpProxy) OnResponse(filters ...Filter) *RespCondition { + return &RespCondition{ctx: proxy.Ctx, filters: filters} } +// Transport get http.Transport instance +func (proxy *HttpProxy) Transport() *http.Transport { + return proxy.Ctx.Transport +} + +// hijacker an HTTP handler to take over the connection. func hijacker(rw http.ResponseWriter) (conn net.Conn, err error) { hij, ok := rw.(http.Hijacker) if !ok { @@ -77,9 +99,6 @@ func removeProxyHeaders(r *http.Request) { // If no Accept-Encoding header exists, Transport will add the headers it can accept // and would wrap the response body with the relevant reader. r.Header.Del("Accept-Encoding") - // curl can add that, see - // https://jdebp.eu./FGA/web-proxy-connection-header.html - // RFC 2616 (section 13.5.1) // https://www.ietf.org/rfc/rfc2616.txt r.Header.Del("Proxy-Connection") diff --git a/http_proxy_test.go b/http_proxy_test.go index e8ce314..74f5611 100644 --- a/http_proxy_test.go +++ b/http_proxy_test.go @@ -17,9 +17,15 @@ func NewTestServer() *httptest.Server { })) } -func TestNewHttpProxy(t *testing.T) { - asserts := assert.New(t) +func HttpGet(rawurl string, proxy func(r *http.Request) (*url.URL, error)) (*http.Response, error) { + req, _ := http.NewRequest(http.MethodGet, rawurl, nil) + http.DefaultClient.Transport = &http.Transport{ + Proxy: proxy, + } + return http.DefaultClient.Do(req) +} +func TestNewHttpProxy(t *testing.T) { srv := NewTestServer() defer srv.Close() @@ -27,48 +33,49 @@ func TestNewHttpProxy(t *testing.T) { proxySrv := httptest.NewServer(proxy) defer proxySrv.Close() - req, _ := http.NewRequest(http.MethodGet, srv.URL, nil) - http.DefaultClient.Transport = &http.Transport{ - Proxy: func(r *http.Request) (*url.URL, error) { - return url.Parse(srv.URL) - }, + resp, err := HttpGet(srv.URL, func(r *http.Request) (*url.URL, error) { + return url.Parse(proxySrv.URL) + }) + if err != nil { + t.Fatal(err) } - resp, err := http.DefaultClient.Do(req) - asserts.Equal(err, nil, "err should be equal nil") - body, _ := ioutil.ReadAll(resp.Body) resp.Body.Close() + asserts := assert.New(t) asserts.Equal(resp.StatusCode, 200, "statusCode should be equal 200") asserts.Equal(int64(len(body)), resp.ContentLength) } func TestMiddlewareFunc(t *testing.T) { + // target server + srv := NewTestServer() + defer srv.Close() + + // proxy server proxy := NewHttpProxy() + // use Middleware proxy.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) { log.Println(req.URL.String()) return ctx.Next(req) }) - srv := httptest.NewServer(proxy) - defer srv.Close() - - req, _ := http.NewRequest(http.MethodGet, "https://httpbin.org/get", nil) - http.DefaultClient.Transport = &http.Transport{ - Proxy: func(r *http.Request) (*url.URL, error) { - return url.Parse(srv.URL) - }, - } + proxySrv := httptest.NewServer(proxy) + defer proxySrv.Close() - resp, err := http.DefaultClient.Do(req) + // send request + resp, err := HttpGet(srv.URL, func(r *http.Request) (*url.URL, error) { + return url.Parse(proxySrv.URL) + }) if err != nil { t.Fatal(err) } - defer resp.Body.Close() body, _ := ioutil.ReadAll(resp.Body) + resp.Body.Close() - log.Println(err) - log.Println(resp.Status) + asserts := assert.New(t) + asserts.Equal(resp.StatusCode, 200) + asserts.Equal(int64(len(body)), resp.ContentLength) log.Println(string(body)) } diff --git a/middleware.go b/middleware.go index 8287db5..c39dba5 100644 --- a/middleware.go +++ b/middleware.go @@ -2,12 +2,26 @@ package mps import "net/http" +// Middleware will "tamper" with the request coming to the proxy server type Middleware interface { + // Execute the next middleware as a linked list. "ctx.Next(req)" + // eg: + // func Handle(req *http.Request, ctx *Context) (*http.Response, error) { + // // You can do anything to modify the http.Request ... + // resp, err := ctx.Next(req) + // // You can do anything to modify the http.Response ... + // return resp, err + // } + // + // Alternatively, you can simply return the response without executing `ctx.Next()`, + // which will interrupt subsequent middleware execution. Handle(req *http.Request, ctx *Context) (*http.Response, error) } +// A wrapper that would convert a function to a Middleware interface type type MiddlewareFunc func(req *http.Request, ctx *Context) (*http.Response, error) +// MiddlewareFunc.Handle(req, ctx) <=> MiddlewareFunc(req, ctx) func (f MiddlewareFunc) Handle(req *http.Request, ctx *Context) (*http.Response, error) { return f(req, ctx) } diff --git a/mitm_handler.go b/mitm_handler.go index 30fdb5b..29a694c 100644 --- a/mitm_handler.go +++ b/mitm_handler.go @@ -30,6 +30,7 @@ var ( httpsRegexp = regexp.MustCompile("^https://") ) +// The Man-in-the-middle proxy type. Implements http.Handler. type MitmHandler struct { Ctx *Context Certificate tls.Certificate @@ -37,15 +38,16 @@ type MitmHandler struct { CertContainer cert.Container } +// Create a MitmHandler, use default cert. func NewMitmHandler() *MitmHandler { return &MitmHandler{ - Ctx: NewContext(), - // default MPS Certificate + Ctx: NewContext(), Certificate: cert.DefaultCertificate, CertContainer: cert.NewMemProvider(), } } +// Create a MitmHandler with cert file func NewMitmHandlerWithCert(certFile, keyFile string) (*MitmHandler, error) { certificate, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { @@ -58,6 +60,7 @@ func NewMitmHandlerWithCert(certFile, keyFile string) (*MitmHandler, error) { }, nil } +// Standard net/http function. You can use it alone func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // get hijacker connection proxyClient, err := hijacker(w) @@ -171,7 +174,6 @@ func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } } - }() } @@ -232,18 +234,14 @@ func signHost(ca tls.Certificate, hosts []string) (cert *tls.Certificate, err er // certificate template tpl := x509.Certificate{ - // SerialNumber 是 CA 颁布的唯一序列号,在此使用一个大随机数来代表它 SerialNumber: big.NewInt(rand.Int63()), Issuer: x509ca.Subject, - // pkix.Name代表一个X.509识别名。只包含识别名的公共属性,额外的属性被忽略。 Subject: pkix.Name{ Organization: []string{"MPS untrusted MITM proxy Inc"}, }, - NotBefore: time.Unix(0, 0), - NotAfter: time.Now().AddDate(20, 0, 0), - // KeyUsage 与 ExtKeyUsage 用来表明该证书是用来做服务器认证的 - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - // 密钥扩展用途的序列 + NotBefore: time.Unix(0, 0), + NotAfter: time.Now().AddDate(20, 0, 0), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, EmailAddresses: x509ca.EmailAddresses, diff --git a/req_condition.go b/req_condition.go index dde0b9d..968ea52 100644 --- a/req_condition.go +++ b/req_condition.go @@ -5,7 +5,7 @@ import ( ) type ReqCondition struct { - proxy *HttpProxy + ctx *Context filters []Filter } @@ -14,7 +14,7 @@ func (cond *ReqCondition) DoFunc(fn func(req *http.Request) (*http.Request, *htt } func (cond *ReqCondition) Do(fn RequestHandle) { - cond.proxy.Ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) { + cond.ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) { total := len(cond.filters) for i := 0; i < total; i++ { if !cond.filters[i].Match(req) { diff --git a/resp_condition.go b/resp_condition.go index e91d9c9..4923f2d 100644 --- a/resp_condition.go +++ b/resp_condition.go @@ -5,7 +5,7 @@ import ( ) type RespCondition struct { - proxy *HttpProxy + ctx *Context filters []Filter } @@ -14,7 +14,7 @@ func (cond *RespCondition) DoFunc(fn func(resp *http.Response) (*http.Response, } func (cond *RespCondition) Do(fn ResponseHandle) { - cond.proxy.Ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) { + cond.ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) { total := len(cond.filters) for i := 0; i < total; i++ { if !cond.filters[i].Match(req) { diff --git a/reverse_handler.go b/reverse_handler.go index 2096810..559026d 100644 --- a/reverse_handler.go +++ b/reverse_handler.go @@ -1,6 +1,11 @@ package mps -import "net/http" +import ( + "bytes" + "io" + "io/ioutil" + "net/http" +) // ReverseHandler is a reverse proxy server implementation type ReverseHandler struct { @@ -13,6 +18,63 @@ func NewReverseHandler() *ReverseHandler { } } +// Standard net/http function. You can use it alone func (reverse *ReverseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + // Copying a Context preserves the Transport, Middleware + ctx := reverse.Ctx.Copy() + ctx.Request = req + resp, err := ctx.Next(req) + if err != nil { + http.Error(rw, err.Error(), 502) + return + } + + bodyRes, err := ioutil.ReadAll(resp.Body) + if err != nil { + http.Error(rw, err.Error(), 502) + return + } + resp.Body.Close() + + // http.ResponseWriter will take care of filling the correct response length + // Setting it now, might impose wrong value, contradicting the actual new + // body the user returned. + // We keep the original body to remove the header only if things changed. + // This will prevent problems with HEAD requests where there's no body, yet, + // the Content-Length header should be set. + if resp.ContentLength != int64(len(bodyRes)) { + resp.Header.Del("Content-Length") + } + + copyHeaders(rw.Header(), resp.Header, reverse.Ctx.KeepDestinationHeaders) + rw.WriteHeader(resp.StatusCode) + + body := ioutil.NopCloser(bytes.NewReader(bodyRes)) + _, err = io.Copy(rw, body) + body.Close() + if err != nil { + http.Error(rw, err.Error(), 502) + return + } +} + +// Use registers an Middleware to proxy +func (reverse *ReverseHandler) Use(middleware ...Middleware) { + reverse.Ctx.Use(middleware...) +} + +// UseFunc registers an MiddlewareFunc to proxy +func (reverse *ReverseHandler) UseFunc(fus ...MiddlewareFunc) { + reverse.Ctx.UseFunc(fus...) +} + +// OnRequest filter requests through Filters +func (reverse *ReverseHandler) OnRequest(filters ...Filter) *ReqCondition { + return &ReqCondition{ctx: reverse.Ctx, filters: filters} +} + +// OnResponse filter response through Filters +func (reverse *ReverseHandler) OnResponse(filters ...Filter) *RespCondition { + return &RespCondition{ctx: reverse.Ctx, filters: filters} } diff --git a/tunnel_handler.go b/tunnel_handler.go index 9b5e640..a9829ff 100644 --- a/tunnel_handler.go +++ b/tunnel_handler.go @@ -15,16 +15,19 @@ var ( hasPort = regexp.MustCompile(`:\d+$`) ) +// The tunnel proxy type. Implements http.Handler. type TunnelHandler struct { Ctx *Context } +// Create a tunnel handler func NewTunnelHandler() *TunnelHandler { return &TunnelHandler{ Ctx: NewContext(), } } +// Standard net/http function. You can use it alone func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // hijacker connection proxyClient, err := hijacker(rw) @@ -34,10 +37,10 @@ func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request } var ( - u *url.URL = nil - targetConn net.Conn = nil - targetAddr = hostAndPort(req.URL.Host) - isProxy = false + u *url.URL = nil + targetConn net.Conn = nil + targetAddr = hostAndPort(req.URL.Host) + isCascadeProxy = false ) if tunnel.Ctx.Transport != nil && tunnel.Ctx.Transport.Proxy != nil { u, err = tunnel.Ctx.Transport.Proxy(req) @@ -48,7 +51,7 @@ func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request if u != nil { // connect addr eg. "localhost:80" targetAddr = hostAndPort(u.Host) - isProxy = true + isCascadeProxy = true } } @@ -60,9 +63,11 @@ func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request } // The cascade proxy needs to forward the request - if isProxy { + if isCascadeProxy { + // The cascading agent needs to send it as-is _ = req.Write(targetConn) } else { + // Tell the client that the tunnel is ready _, _ = proxyClient.Write(HttpTunnelOk) } diff --git a/tunnel_handler_test.go b/tunnel_handler_test.go index 4f3c915..c947378 100644 --- a/tunnel_handler_test.go +++ b/tunnel_handler_test.go @@ -1,8 +1,7 @@ package mps import ( - "io/ioutil" - "log" + "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" "net/url" @@ -10,15 +9,17 @@ import ( ) func TestNewTunnelHandler(t *testing.T) { + srv := NewTestServer() + defer srv.Close() + tunnel := NewTunnelHandler() //tunnel.Transport().Proxy = func(r *http.Request) (*url.URL, error) { - // //return url.Parse("http://121.56.39.197:4283") // return url.Parse("http://127.0.0.1:7890") //} tunnelSrv := httptest.NewServer(tunnel) defer tunnelSrv.Close() - req, _ := http.NewRequest(http.MethodGet, "https://httpbin.org/get", nil) + req, _ := http.NewRequest(http.MethodGet, srv.URL, nil) http.DefaultClient.Transport = &http.Transport{ Proxy: func(r *http.Request) (*url.URL, error) { return url.Parse(tunnelSrv.URL) @@ -29,11 +30,8 @@ func TestNewTunnelHandler(t *testing.T) { if err != nil { t.Fatal(err) } - defer resp.Body.Close() - - body, _ := ioutil.ReadAll(resp.Body) + resp.Body.Close() - log.Println(err) - log.Println(resp.Status) - log.Println(string(body)) + asserts := assert.New(t) + asserts.Equal(resp.StatusCode, 200) }