Skip to content

Commit

Permalink
Perfect middleware, request interception
Browse files Browse the repository at this point in the history
  • Loading branch information
telanflow committed Aug 9, 2020
1 parent 05f580c commit 235a412
Show file tree
Hide file tree
Showing 11 changed files with 209 additions and 58 deletions.
52 changes: 33 additions & 19 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,24 @@ package mps
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)

// Context for the request
// which contains Middleware, Transport, and other values
type Context struct {
// context.Context
Context context.Context
Request *http.Request
Response *http.Response

// Request context-dependent requests
Request *http.Request

// Response is associated with Request
Response *http.Response

// Transport is used for global HTTP requests, and it will be reused.
Transport *http.Transport

// In some cases it is not always necessary to remove the Proxy Header.
Expand All @@ -20,7 +31,10 @@ type Context struct {
// present in the http.Response before proxying
KeepDestinationHeaders bool

// requests Middleware
// middlewares ACTS on Request and Response.
// It's going to be reused by the Context
// mi is the index subscript of the middlewares traversal
// the default value for the index is -1
mi int
middlewares []Middleware
}
Expand All @@ -29,25 +43,25 @@ func NewContext() *Context {
return &Context{
Context: context.Background(),
Transport: &http.Transport{
//DialContext: (&net.Dialer{
// Timeout: 15 * time.Second,
// KeepAlive: 30 * time.Second,
// DualStack: true,
//}).DialContext,
////ForceAttemptHTTP2: true,
//MaxIdleConns: 100,
//IdleConnTimeout: 90 * time.Second,
//TLSHandshakeTimeout: 10 * time.Second,
//ExpectContinueTimeout: 1 * time.Second,
DialContext: (&net.Dialer{
Timeout: 15 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
Proxy: http.ProxyFromEnvironment,
},
Request: nil,
Response: nil,
KeepHeader: false,
Request: nil,
Response: nil,
KeepHeader: false,
KeepDestinationHeaders: false,
mi: -1,
middlewares: make([]Middleware, 0),
mi: -1,
middlewares: make([]Middleware, 0),
}
}

Expand Down Expand Up @@ -88,7 +102,7 @@ func (ctx *Context) Next(req *http.Request) (*http.Response, error) {

func (ctx *Context) Copy() *Context {
return &Context{
Context: context.Background(),
Context: context.Background(),
Request: nil,
Response: nil,
KeepHeader: false,
Expand Down
74 changes: 61 additions & 13 deletions filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,77 @@ package mps
import (
"net/http"
"regexp"
"strings"
)

type Filter interface {
Match(expr string) bool
Match(req *http.Request) bool
}

type FilterFunc func(expr string) bool
type FilterFunc func(req *http.Request) bool

func (f FilterFunc) Match(expr string) bool {
return f(expr)
func (f FilterFunc) Match(req *http.Request) bool {
return f(req)
}

// 匹配域名
var MatchIsHost = func(expr string, req *http.Request) Filter {
exp, err := regexp.Compile(expr)
if err != nil {
panic(err)
// FilterHostMatches for request.Host
func FilterHostMatches(regexps ...*regexp.Regexp) Filter {
return FilterFunc(func(req *http.Request) bool {
for _, re := range regexps {
if re.MatchString(req.Host) {
return true
}
}
return false
})
}

// FilterHostIs returns a Filter, testing whether the host to which the request is directed to equal
// to one of the given strings
func FilterHostIs(hosts ...string) Filter {
hostSet := make(map[string]bool)
for _, h := range hosts {
hostSet[h] = true
}
return FilterFunc(func(expr string) bool {
return exp.MatchString(req.Host)
return FilterFunc(func(req *http.Request) bool {
_, ok := hostSet[req.URL.Host]
return ok
})
}

// FilterUrlMatches returns a Filter testing whether the destination URL
// of the request matches the given regexp, with or without prefix
func FilterUrlMatches(re *regexp.Regexp) Filter {
return FilterFunc(func(req *http.Request) bool {
return re.MatchString(req.URL.Path) ||
re.MatchString(req.URL.Host+req.URL.Path)
})
}

type ReqHandler interface {
Handler(ctx *Context)
// FilterUrlHasPrefix returns a Filter checking wether the destination URL the proxy client has requested
// has the given prefix, with or without the host.
// For example FilterUrlHasPrefix("host/x") will match requests of the form 'GET host/x', and will match
// requests to url 'http://host/x'
func FilterUrlHasPrefix(prefix string) Filter {
return FilterFunc(func(req *http.Request) bool {
return strings.HasPrefix(req.URL.Path, prefix) ||
strings.HasPrefix(req.URL.Host+req.URL.Path, prefix) ||
strings.HasPrefix(req.URL.Scheme+req.URL.Host+req.URL.Path, prefix)
})
}

// FilterUrlIs returns a Filter, testing whether or not the request URL is one of the given strings
// with or without the host prefix.
// FilterUrlIs("google.com/","foo") will match requests 'GET /' to 'google.com', requests `'GET google.com/' to
// any host, and requests of the form 'GET foo'.
func FilterUrlIs(urls ...string) Filter {
urlSet := make(map[string]bool)
for _, u := range urls {
urlSet[u] = true
}
return FilterFunc(func(req *http.Request) bool {
_, pathOk := urlSet[req.URL.Path]
_, hostAndOk := urlSet[req.URL.Host+req.URL.Path]
return pathOk || hostAndOk
})
}
12 changes: 10 additions & 2 deletions http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ func NewHttpProxy() *HttpProxy {
ctx := NewContext()

return &HttpProxy{
Ctx: ctx,
Ctx: ctx,
// default HTTP proxy
HttpHandler: &ForwardHandler{Ctx: ctx},
HttpHandler: &ForwardHandler{Ctx: ctx},
// default HTTPS proxy
HttpsHandler: &TunnelHandler{Ctx: ctx},
}
Expand All @@ -45,6 +45,14 @@ func (proxy *HttpProxy) UseFunc(fus ...MiddlewareFunc) {
proxy.Ctx.UseFunc(fus...)
}

func (proxy *HttpProxy) OnRequest(filter ...Filter) *ReqCondition {
return &ReqCondition{proxy: proxy, filters: filter}
}

func (proxy *HttpProxy) OnResponse(filter ...Filter) *RespCondition {
return &RespCondition{proxy: proxy, filters: filter}
}

func hijacker(rw http.ResponseWriter) (conn net.Conn, err error) {
hij, ok := rw.(http.Hijacker)
if !ok {
Expand Down
2 changes: 1 addition & 1 deletion http_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@ func TestMiddlewareFunc(t *testing.T) {
log.Println(err)
log.Println(resp.Status)
log.Println(string(body))
}
}
44 changes: 27 additions & 17 deletions mitm_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,33 @@ var (
)

type MitmHandler struct {
Ctx *Context
Certificate tls.Certificate
Ctx *Context

TLSConfig *tls.Config
// CertContainer is certificate storage container
CertContainer cert.Container
}

func NewMitmHandler() *MitmHandler {
return &MitmHandler{
Ctx: NewContext(),
// default MPS Certificate
Certificate: cert.DefaultCertificate,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
// Certificate cache storage container
Certificate: cert.DefaultCertificate,
CertContainer: cert.NewMemProvider(),
}
}

func NewMitmHandlerWithCert(certFile, keyFile string) (*MitmHandler, error) {
certificate, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
return &MitmHandler{
Ctx: NewContext(),
Certificate: certificate,
CertContainer: cert.NewMemProvider(),
}, nil
}

func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// get hijacker connection
proxyClient, err := hijacker(w)
Expand Down Expand Up @@ -173,10 +180,11 @@ func (mitm *MitmHandler) TLSConfigFromCA(host string) (*tls.Config, error) {

// Returned existing certificate for the host
crt, err := mitm.CertContainer.Get(host)
if err == nil {
config := cloneTLSConfig(mitm.TLSConfig)
config.Certificates = append(config.Certificates, *crt)
return config, nil
if err == nil && crt != nil {
return &tls.Config{
InsecureSkipVerify: true,
Certificates: []tls.Certificate{*crt},
}, nil
}

// Issue a certificate for host
Expand All @@ -189,10 +197,10 @@ func (mitm *MitmHandler) TLSConfigFromCA(host string) (*tls.Config, error) {
// Set certificate to container
mitm.CertContainer.Set(host, crt)

config := &tls.Config{
Certificates: []tls.Certificate{*crt},
}
return config, nil
return &tls.Config{
InsecureSkipVerify: true,
Certificates: []tls.Certificate{*crt},
}, nil
}

func signHost(ca tls.Certificate, hosts []string) (cert *tls.Certificate, err error) {
Expand Down Expand Up @@ -286,7 +294,9 @@ func hashHosts(lst []string) []byte {
// client or server.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
return &tls.Config{
InsecureSkipVerify: true,
}
}
return cfg.Clone()
}
Expand Down
32 changes: 32 additions & 0 deletions req_condition.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package mps

import (
"net/http"
)

type ReqCondition struct {
proxy *HttpProxy
filters []Filter
}

func (cond *ReqCondition) DoFunc(fn func(req *http.Request) (*http.Request, *http.Response)) {
cond.Do(RequestHandleFunc(fn))
}

func (cond *ReqCondition) Do(fn RequestHandle) {
cond.proxy.Ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) {
total := len(cond.filters)
for i := 0; i < total; i++ {
if !cond.filters[i].Match(req) {
return ctx.Next(req)
}
}

req, resp := fn.Handle(req)
if resp != nil {
return resp, nil
}

return ctx.Next(req)
})
}
32 changes: 32 additions & 0 deletions resp_condition.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package mps

import (
"net/http"
)

type RespCondition struct {
proxy *HttpProxy
filters []Filter
}

func (cond *RespCondition) DoFunc(fn func(resp *http.Response) (*http.Response, error)) {
cond.Do(ResponseHandleFunc(fn))
}

func (cond *RespCondition) Do(fn ResponseHandle) {
cond.proxy.Ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) {
total := len(cond.filters)
for i := 0; i < total; i++ {
if !cond.filters[i].Match(req) {
return ctx.Next(req)
}
}

resp, err := ctx.Next(req)
if err != nil {
return nil, err
}

return fn.Handle(resp)
})
}
6 changes: 3 additions & 3 deletions response_handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ package mps
import "net/http"

type ResponseHandle interface {
Handle(resp *http.Response) *http.Response
Handle(resp *http.Response) (*http.Response, error)
}

type ResponseHandleFunc func(resp *http.Response) *http.Response
type ResponseHandleFunc func(resp *http.Response) (*http.Response, error)

func (f ResponseHandleFunc) Handle(resp *http.Response) *http.Response {
func (f ResponseHandleFunc) Handle(resp *http.Response) (*http.Response, error) {
return f(resp)
}
1 change: 1 addition & 0 deletions reverse_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mps

import "net/http"

// ReverseHandler is a reverse proxy server implementation
type ReverseHandler struct {
Ctx *Context
}
Expand Down
Loading

0 comments on commit 235a412

Please sign in to comment.