diff --git a/context.go b/context.go index a40ca76..eca887b 100644 --- a/context.go +++ b/context.go @@ -3,13 +3,24 @@ package mps import ( "context" "crypto/tls" + "net" "net/http" + "time" ) +// Context for the request +// which contains Middleware, Transport, and other values type Context struct { + // context.Context Context context.Context - Request *http.Request - Response *http.Response + + // Request context-dependent requests + Request *http.Request + + // Response is associated with Request + Response *http.Response + + // Transport is used for global HTTP requests, and it will be reused. Transport *http.Transport // In some cases it is not always necessary to remove the Proxy Header. @@ -20,7 +31,10 @@ type Context struct { // present in the http.Response before proxying KeepDestinationHeaders bool - // requests Middleware + // middlewares ACTS on Request and Response. + // It's going to be reused by the Context + // mi is the index subscript of the middlewares traversal + // the default value for the index is -1 mi int middlewares []Middleware } @@ -29,25 +43,25 @@ func NewContext() *Context { return &Context{ Context: context.Background(), Transport: &http.Transport{ - //DialContext: (&net.Dialer{ - // Timeout: 15 * time.Second, - // KeepAlive: 30 * time.Second, - // DualStack: true, - //}).DialContext, - ////ForceAttemptHTTP2: true, - //MaxIdleConns: 100, - //IdleConnTimeout: 90 * time.Second, - //TLSHandshakeTimeout: 10 * time.Second, - //ExpectContinueTimeout: 1 * time.Second, + DialContext: (&net.Dialer{ + Timeout: 15 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, Proxy: http.ProxyFromEnvironment, }, - Request: nil, - Response: nil, - KeepHeader: false, + Request: nil, + Response: nil, + KeepHeader: false, KeepDestinationHeaders: false, - mi: -1, - middlewares: make([]Middleware, 0), + mi: -1, + middlewares: make([]Middleware, 0), } } @@ -88,7 +102,7 @@ func (ctx *Context) Next(req *http.Request) (*http.Response, error) { func (ctx *Context) Copy() *Context { return &Context{ - Context: context.Background(), + Context: context.Background(), Request: nil, Response: nil, KeepHeader: false, diff --git a/filter.go b/filter.go index c0769ed..aeaa3cb 100644 --- a/filter.go +++ b/filter.go @@ -3,29 +3,77 @@ package mps import ( "net/http" "regexp" + "strings" ) type Filter interface { - Match(expr string) bool + Match(req *http.Request) bool } -type FilterFunc func(expr string) bool +type FilterFunc func(req *http.Request) bool -func (f FilterFunc) Match(expr string) bool { - return f(expr) +func (f FilterFunc) Match(req *http.Request) bool { + return f(req) } -// 匹配域名 -var MatchIsHost = func(expr string, req *http.Request) Filter { - exp, err := regexp.Compile(expr) - if err != nil { - panic(err) +// FilterHostMatches for request.Host +func FilterHostMatches(regexps ...*regexp.Regexp) Filter { + return FilterFunc(func(req *http.Request) bool { + for _, re := range regexps { + if re.MatchString(req.Host) { + return true + } + } + return false + }) +} + +// FilterHostIs returns a Filter, testing whether the host to which the request is directed to equal +// to one of the given strings +func FilterHostIs(hosts ...string) Filter { + hostSet := make(map[string]bool) + for _, h := range hosts { + hostSet[h] = true } - return FilterFunc(func(expr string) bool { - return exp.MatchString(req.Host) + return FilterFunc(func(req *http.Request) bool { + _, ok := hostSet[req.URL.Host] + return ok + }) +} + +// FilterUrlMatches returns a Filter testing whether the destination URL +// of the request matches the given regexp, with or without prefix +func FilterUrlMatches(re *regexp.Regexp) Filter { + return FilterFunc(func(req *http.Request) bool { + return re.MatchString(req.URL.Path) || + re.MatchString(req.URL.Host+req.URL.Path) }) } -type ReqHandler interface { - Handler(ctx *Context) +// FilterUrlHasPrefix returns a Filter checking wether the destination URL the proxy client has requested +// has the given prefix, with or without the host. +// For example FilterUrlHasPrefix("host/x") will match requests of the form 'GET host/x', and will match +// requests to url 'http://host/x' +func FilterUrlHasPrefix(prefix string) Filter { + return FilterFunc(func(req *http.Request) bool { + return strings.HasPrefix(req.URL.Path, prefix) || + strings.HasPrefix(req.URL.Host+req.URL.Path, prefix) || + strings.HasPrefix(req.URL.Scheme+req.URL.Host+req.URL.Path, prefix) + }) +} + +// FilterUrlIs returns a Filter, testing whether or not the request URL is one of the given strings +// with or without the host prefix. +// FilterUrlIs("google.com/","foo") will match requests 'GET /' to 'google.com', requests `'GET google.com/' to +// any host, and requests of the form 'GET foo'. +func FilterUrlIs(urls ...string) Filter { + urlSet := make(map[string]bool) + for _, u := range urls { + urlSet[u] = true + } + return FilterFunc(func(req *http.Request) bool { + _, pathOk := urlSet[req.URL.Path] + _, hostAndOk := urlSet[req.URL.Host+req.URL.Path] + return pathOk || hostAndOk + }) } diff --git a/http_proxy.go b/http_proxy.go index 67cc119..caf7a45 100644 --- a/http_proxy.go +++ b/http_proxy.go @@ -22,9 +22,9 @@ func NewHttpProxy() *HttpProxy { ctx := NewContext() return &HttpProxy{ - Ctx: ctx, + Ctx: ctx, // default HTTP proxy - HttpHandler: &ForwardHandler{Ctx: ctx}, + HttpHandler: &ForwardHandler{Ctx: ctx}, // default HTTPS proxy HttpsHandler: &TunnelHandler{Ctx: ctx}, } @@ -45,6 +45,14 @@ func (proxy *HttpProxy) UseFunc(fus ...MiddlewareFunc) { proxy.Ctx.UseFunc(fus...) } +func (proxy *HttpProxy) OnRequest(filter ...Filter) *ReqCondition { + return &ReqCondition{proxy: proxy, filters: filter} +} + +func (proxy *HttpProxy) OnResponse(filter ...Filter) *RespCondition { + return &RespCondition{proxy: proxy, filters: filter} +} + func hijacker(rw http.ResponseWriter) (conn net.Conn, err error) { hij, ok := rw.(http.Hijacker) if !ok { diff --git a/http_proxy_test.go b/http_proxy_test.go index afe03f9..a0587c5 100644 --- a/http_proxy_test.go +++ b/http_proxy_test.go @@ -61,4 +61,4 @@ func TestMiddlewareFunc(t *testing.T) { log.Println(err) log.Println(resp.Status) log.Println(string(body)) -} \ No newline at end of file +} diff --git a/mitm_handler.go b/mitm_handler.go index 7c12a36..30fdb5b 100644 --- a/mitm_handler.go +++ b/mitm_handler.go @@ -31,10 +31,9 @@ var ( ) type MitmHandler struct { + Ctx *Context Certificate tls.Certificate - Ctx *Context - - TLSConfig *tls.Config + // CertContainer is certificate storage container CertContainer cert.Container } @@ -42,15 +41,23 @@ func NewMitmHandler() *MitmHandler { return &MitmHandler{ Ctx: NewContext(), // default MPS Certificate - Certificate: cert.DefaultCertificate, - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - // Certificate cache storage container + Certificate: cert.DefaultCertificate, CertContainer: cert.NewMemProvider(), } } +func NewMitmHandlerWithCert(certFile, keyFile string) (*MitmHandler, error) { + certificate, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + return &MitmHandler{ + Ctx: NewContext(), + Certificate: certificate, + CertContainer: cert.NewMemProvider(), + }, nil +} + func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // get hijacker connection proxyClient, err := hijacker(w) @@ -173,10 +180,11 @@ func (mitm *MitmHandler) TLSConfigFromCA(host string) (*tls.Config, error) { // Returned existing certificate for the host crt, err := mitm.CertContainer.Get(host) - if err == nil { - config := cloneTLSConfig(mitm.TLSConfig) - config.Certificates = append(config.Certificates, *crt) - return config, nil + if err == nil && crt != nil { + return &tls.Config{ + InsecureSkipVerify: true, + Certificates: []tls.Certificate{*crt}, + }, nil } // Issue a certificate for host @@ -189,10 +197,10 @@ func (mitm *MitmHandler) TLSConfigFromCA(host string) (*tls.Config, error) { // Set certificate to container mitm.CertContainer.Set(host, crt) - config := &tls.Config{ - Certificates: []tls.Certificate{*crt}, - } - return config, nil + return &tls.Config{ + InsecureSkipVerify: true, + Certificates: []tls.Certificate{*crt}, + }, nil } func signHost(ca tls.Certificate, hosts []string) (cert *tls.Certificate, err error) { @@ -286,7 +294,9 @@ func hashHosts(lst []string) []byte { // client or server. func cloneTLSConfig(cfg *tls.Config) *tls.Config { if cfg == nil { - return &tls.Config{} + return &tls.Config{ + InsecureSkipVerify: true, + } } return cfg.Clone() } diff --git a/req_condition.go b/req_condition.go new file mode 100644 index 0000000..dde0b9d --- /dev/null +++ b/req_condition.go @@ -0,0 +1,32 @@ +package mps + +import ( + "net/http" +) + +type ReqCondition struct { + proxy *HttpProxy + filters []Filter +} + +func (cond *ReqCondition) DoFunc(fn func(req *http.Request) (*http.Request, *http.Response)) { + cond.Do(RequestHandleFunc(fn)) +} + +func (cond *ReqCondition) Do(fn RequestHandle) { + cond.proxy.Ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) { + total := len(cond.filters) + for i := 0; i < total; i++ { + if !cond.filters[i].Match(req) { + return ctx.Next(req) + } + } + + req, resp := fn.Handle(req) + if resp != nil { + return resp, nil + } + + return ctx.Next(req) + }) +} diff --git a/resp_condition.go b/resp_condition.go new file mode 100644 index 0000000..e91d9c9 --- /dev/null +++ b/resp_condition.go @@ -0,0 +1,32 @@ +package mps + +import ( + "net/http" +) + +type RespCondition struct { + proxy *HttpProxy + filters []Filter +} + +func (cond *RespCondition) DoFunc(fn func(resp *http.Response) (*http.Response, error)) { + cond.Do(ResponseHandleFunc(fn)) +} + +func (cond *RespCondition) Do(fn ResponseHandle) { + cond.proxy.Ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) { + total := len(cond.filters) + for i := 0; i < total; i++ { + if !cond.filters[i].Match(req) { + return ctx.Next(req) + } + } + + resp, err := ctx.Next(req) + if err != nil { + return nil, err + } + + return fn.Handle(resp) + }) +} diff --git a/response_handle.go b/response_handle.go index ab2f29e..29e6f72 100644 --- a/response_handle.go +++ b/response_handle.go @@ -3,11 +3,11 @@ package mps import "net/http" type ResponseHandle interface { - Handle(resp *http.Response) *http.Response + Handle(resp *http.Response) (*http.Response, error) } -type ResponseHandleFunc func(resp *http.Response) *http.Response +type ResponseHandleFunc func(resp *http.Response) (*http.Response, error) -func (f ResponseHandleFunc) Handle(resp *http.Response) *http.Response { +func (f ResponseHandleFunc) Handle(resp *http.Response) (*http.Response, error) { return f(resp) } diff --git a/reverse_handler.go b/reverse_handler.go index 39d0769..2096810 100644 --- a/reverse_handler.go +++ b/reverse_handler.go @@ -2,6 +2,7 @@ package mps import "net/http" +// ReverseHandler is a reverse proxy server implementation type ReverseHandler struct { Ctx *Context } diff --git a/tunnel_handler.go b/tunnel_handler.go index 2e9ce85..657f04b 100644 --- a/tunnel_handler.go +++ b/tunnel_handler.go @@ -37,6 +37,7 @@ func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request u *url.URL = nil targetConn net.Conn = nil targetAddr = hostAndPort(req.URL.Host) + isProxy = false ) if tunnel.Ctx.Transport != nil && tunnel.Ctx.Transport.Proxy != nil { u, err = tunnel.Ctx.Transport.Proxy(req) @@ -47,6 +48,7 @@ func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request if u != nil { // connect addr eg. "localhost:80" targetAddr = hostAndPort(u.Host) + isProxy = true } } else { @@ -60,6 +62,11 @@ func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request return } + // The cascade proxy needs to forward the request + if isProxy { + _ = req.Write(targetConn) + } + go func() { buf := make([]byte, 2048) _, _ = io.CopyBuffer(targetConn, proxyClient, buf) diff --git a/tunnel_handler_test.go b/tunnel_handler_test.go index 62372c6..4f3c915 100644 --- a/tunnel_handler_test.go +++ b/tunnel_handler_test.go @@ -12,14 +12,13 @@ import ( func TestNewTunnelHandler(t *testing.T) { tunnel := NewTunnelHandler() //tunnel.Transport().Proxy = func(r *http.Request) (*url.URL, error) { - // //return url.Parse("http://59.58.58.92:4235") + // //return url.Parse("http://121.56.39.197:4283") // return url.Parse("http://127.0.0.1:7890") //} - tunnel.Transport().Dial = nil tunnelSrv := httptest.NewServer(tunnel) defer tunnelSrv.Close() - req, _ := http.NewRequest(http.MethodGet, "http://httpbin.org/get", nil) + req, _ := http.NewRequest(http.MethodGet, "https://httpbin.org/get", nil) http.DefaultClient.Transport = &http.Transport{ Proxy: func(r *http.Request) (*url.URL, error) { return url.Parse(tunnelSrv.URL)