Skip to content

Commit

Permalink
[fixed] ReverseHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
telanflow committed Aug 10, 2020
1 parent ac053af commit 6c554d8
Show file tree
Hide file tree
Showing 11 changed files with 189 additions and 69 deletions.
2 changes: 1 addition & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ func (ctx *Context) UseFunc(fns ...MiddlewareFunc) {

func (ctx *Context) Next(req *http.Request) (*http.Response, error) {
var (
err error
total = len(ctx.middlewares)
err error
)
ctx.mi++
if ctx.mi >= total {
Expand Down
29 changes: 23 additions & 6 deletions forward_handler.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
package mps

import (
"bytes"
"io"
"io/ioutil"
"net/http"
)

// The forward proxy type. Implements http.Handler.
type ForwardHandler struct {
Ctx *Context
}

// Create a ForwardHandler
func NewForwardHandler() *ForwardHandler {
return &ForwardHandler{
Ctx: NewContext(),
}
}

//
func NewForwardHandlerWithContext(ctx *Context) *ForwardHandler {
return &ForwardHandler{
Ctx: ctx,
}
}

// Standard net/http function. You can use it alone
func (forward *ForwardHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Copying a Context preserves the Transport, Middleware
ctx := forward.Ctx.Copy()
Expand All @@ -34,26 +40,37 @@ func (forward *ForwardHandler) ServeHTTP(rw http.ResponseWriter, req *http.Reque

resp, err := ctx.Next(req)
if err != nil {
http.Error(rw, err.Error(), 500)
http.Error(rw, err.Error(), 502)
return
}

origBody := resp.Body
defer origBody.Close()
bodyRes, err := ioutil.ReadAll(resp.Body)
if err != nil {
http.Error(rw, err.Error(), 502)
return
}
resp.Body.Close()

// http.ResponseWriter will take care of filling the correct response length
// Setting it now, might impose wrong value, contradicting the actual new
// body the user returned.
// We keep the original body to remove the header only if things changed.
// This will prevent problems with HEAD requests where there's no body, yet,
// the Content-Length header should be set.
if origBody != resp.Body {
if resp.ContentLength != int64(len(bodyRes)) {
resp.Header.Del("Content-Length")
}

copyHeaders(rw.Header(), resp.Header, forward.Ctx.KeepDestinationHeaders)
rw.WriteHeader(resp.StatusCode)
io.Copy(rw, resp.Body)
resp.Body.Close()

body := ioutil.NopCloser(bytes.NewReader(bodyRes))
_, err = io.Copy(rw, body)
body.Close()
if err != nil {
http.Error(rw, err.Error(), 502)
return
}
}

func (forward *ForwardHandler) Transport() *http.Transport {
Expand Down
35 changes: 27 additions & 8 deletions http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@ import (
"net/http"
)

// The basic proxy type. Implements http.Handler.
type HttpProxy struct {
// HTTPS requests use the TunnelHandler proxy by default
HttpsHandler http.Handler

// HTTP requests use the ForwardHandler proxy by default
HttpHandler http.Handler

// HTTP requests use the ReverseHandler proxy by default
ReverseHandler http.Handler

Ctx *Context
}

Expand All @@ -27,32 +31,50 @@ func NewHttpProxy() *HttpProxy {
HttpHandler: &ForwardHandler{Ctx: ctx},
// default HTTPS proxy
HttpsHandler: &TunnelHandler{Ctx: ctx},
// default Reverse proxy
ReverseHandler: &ReverseHandler{Ctx: ctx},
}
}

// Standard net/http function.
func (proxy *HttpProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if req.Method == http.MethodConnect {
proxy.HttpsHandler.ServeHTTP(rw, req)
}
proxy.HttpHandler.ServeHTTP(rw, req)

if !req.URL.IsAbs() {
proxy.ReverseHandler.ServeHTTP(rw, req)
} else {
proxy.HttpHandler.ServeHTTP(rw, req)
}
}

// Use registers an Middleware to proxy
func (proxy *HttpProxy) Use(middleware ...Middleware) {
proxy.Ctx.Use(middleware...)
}

// UseFunc registers an MiddlewareFunc to proxy
func (proxy *HttpProxy) UseFunc(fus ...MiddlewareFunc) {
proxy.Ctx.UseFunc(fus...)
}

func (proxy *HttpProxy) OnRequest(filter ...Filter) *ReqCondition {
return &ReqCondition{proxy: proxy, filters: filter}
// OnRequest filter requests through Filters
func (proxy *HttpProxy) OnRequest(filters ...Filter) *ReqCondition {
return &ReqCondition{ctx: proxy.Ctx, filters: filters}
}

func (proxy *HttpProxy) OnResponse(filter ...Filter) *RespCondition {
return &RespCondition{proxy: proxy, filters: filter}
// OnResponse filter response through Filters
func (proxy *HttpProxy) OnResponse(filters ...Filter) *RespCondition {
return &RespCondition{ctx: proxy.Ctx, filters: filters}
}

// Transport get http.Transport instance
func (proxy *HttpProxy) Transport() *http.Transport {
return proxy.Ctx.Transport
}

// hijacker an HTTP handler to take over the connection.
func hijacker(rw http.ResponseWriter) (conn net.Conn, err error) {
hij, ok := rw.(http.Hijacker)
if !ok {
Expand All @@ -77,9 +99,6 @@ func removeProxyHeaders(r *http.Request) {
// If no Accept-Encoding header exists, Transport will add the headers it can accept
// and would wrap the response body with the relevant reader.
r.Header.Del("Accept-Encoding")
// curl can add that, see
// https://jdebp.eu./FGA/web-proxy-connection-header.html

// RFC 2616 (section 13.5.1)
// https://www.ietf.org/rfc/rfc2616.txt
r.Header.Del("Proxy-Connection")
Expand Down
53 changes: 30 additions & 23 deletions http_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,58 +17,65 @@ func NewTestServer() *httptest.Server {
}))
}

func TestNewHttpProxy(t *testing.T) {
asserts := assert.New(t)
func HttpGet(rawurl string, proxy func(r *http.Request) (*url.URL, error)) (*http.Response, error) {
req, _ := http.NewRequest(http.MethodGet, rawurl, nil)
http.DefaultClient.Transport = &http.Transport{
Proxy: proxy,
}
return http.DefaultClient.Do(req)
}

func TestNewHttpProxy(t *testing.T) {
srv := NewTestServer()
defer srv.Close()

proxy := NewHttpProxy()
proxySrv := httptest.NewServer(proxy)
defer proxySrv.Close()

req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
http.DefaultClient.Transport = &http.Transport{
Proxy: func(r *http.Request) (*url.URL, error) {
return url.Parse(srv.URL)
},
resp, err := HttpGet(srv.URL, func(r *http.Request) (*url.URL, error) {
return url.Parse(proxySrv.URL)
})
if err != nil {
t.Fatal(err)
}

resp, err := http.DefaultClient.Do(req)
asserts.Equal(err, nil, "err should be equal nil")

body, _ := ioutil.ReadAll(resp.Body)
resp.Body.Close()

asserts := assert.New(t)
asserts.Equal(resp.StatusCode, 200, "statusCode should be equal 200")
asserts.Equal(int64(len(body)), resp.ContentLength)
}

func TestMiddlewareFunc(t *testing.T) {
// target server
srv := NewTestServer()
defer srv.Close()

// proxy server
proxy := NewHttpProxy()
// use Middleware
proxy.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) {
log.Println(req.URL.String())
return ctx.Next(req)
})
srv := httptest.NewServer(proxy)
defer srv.Close()

req, _ := http.NewRequest(http.MethodGet, "https://httpbin.org/get", nil)
http.DefaultClient.Transport = &http.Transport{
Proxy: func(r *http.Request) (*url.URL, error) {
return url.Parse(srv.URL)
},
}
proxySrv := httptest.NewServer(proxy)
defer proxySrv.Close()

resp, err := http.DefaultClient.Do(req)
// send request
resp, err := HttpGet(srv.URL, func(r *http.Request) (*url.URL, error) {
return url.Parse(proxySrv.URL)
})
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()

body, _ := ioutil.ReadAll(resp.Body)
resp.Body.Close()

log.Println(err)
log.Println(resp.Status)
asserts := assert.New(t)
asserts.Equal(resp.StatusCode, 200)
asserts.Equal(int64(len(body)), resp.ContentLength)
log.Println(string(body))
}
14 changes: 14 additions & 0 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,26 @@ package mps

import "net/http"

// Middleware will "tamper" with the request coming to the proxy server
type Middleware interface {
// Execute the next middleware as a linked list. "ctx.Next(req)"
// eg:
// func Handle(req *http.Request, ctx *Context) (*http.Response, error) {
// // You can do anything to modify the http.Request ...
// resp, err := ctx.Next(req)
// // You can do anything to modify the http.Response ...
// return resp, err
// }
//
// Alternatively, you can simply return the response without executing `ctx.Next()`,
// which will interrupt subsequent middleware execution.
Handle(req *http.Request, ctx *Context) (*http.Response, error)
}

// A wrapper that would convert a function to a Middleware interface type
type MiddlewareFunc func(req *http.Request, ctx *Context) (*http.Response, error)

// MiddlewareFunc.Handle(req, ctx) <=> MiddlewareFunc(req, ctx)
func (f MiddlewareFunc) Handle(req *http.Request, ctx *Context) (*http.Response, error) {
return f(req, ctx)
}
18 changes: 8 additions & 10 deletions mitm_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,24 @@ var (
httpsRegexp = regexp.MustCompile("^https://")
)

// The Man-in-the-middle proxy type. Implements http.Handler.
type MitmHandler struct {
Ctx *Context
Certificate tls.Certificate
// CertContainer is certificate storage container
CertContainer cert.Container
}

// Create a MitmHandler, use default cert.
func NewMitmHandler() *MitmHandler {
return &MitmHandler{
Ctx: NewContext(),
// default MPS Certificate
Ctx: NewContext(),
Certificate: cert.DefaultCertificate,
CertContainer: cert.NewMemProvider(),
}
}

// Create a MitmHandler with cert file
func NewMitmHandlerWithCert(certFile, keyFile string) (*MitmHandler, error) {
certificate, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
Expand All @@ -58,6 +60,7 @@ func NewMitmHandlerWithCert(certFile, keyFile string) (*MitmHandler, error) {
}, nil
}

// Standard net/http function. You can use it alone
func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// get hijacker connection
proxyClient, err := hijacker(w)
Expand Down Expand Up @@ -171,7 +174,6 @@ func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
}

}()
}

Expand Down Expand Up @@ -232,18 +234,14 @@ func signHost(ca tls.Certificate, hosts []string) (cert *tls.Certificate, err er

// certificate template
tpl := x509.Certificate{
// SerialNumber 是 CA 颁布的唯一序列号,在此使用一个大随机数来代表它
SerialNumber: big.NewInt(rand.Int63()),
Issuer: x509ca.Subject,
// pkix.Name代表一个X.509识别名。只包含识别名的公共属性,额外的属性被忽略。
Subject: pkix.Name{
Organization: []string{"MPS untrusted MITM proxy Inc"},
},
NotBefore: time.Unix(0, 0),
NotAfter: time.Now().AddDate(20, 0, 0),
// KeyUsage 与 ExtKeyUsage 用来表明该证书是用来做服务器认证的
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
// 密钥扩展用途的序列
NotBefore: time.Unix(0, 0),
NotAfter: time.Now().AddDate(20, 0, 0),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
EmailAddresses: x509ca.EmailAddresses,
Expand Down
4 changes: 2 additions & 2 deletions req_condition.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
)

type ReqCondition struct {
proxy *HttpProxy
ctx *Context
filters []Filter
}

Expand All @@ -14,7 +14,7 @@ func (cond *ReqCondition) DoFunc(fn func(req *http.Request) (*http.Request, *htt
}

func (cond *ReqCondition) Do(fn RequestHandle) {
cond.proxy.Ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) {
cond.ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) {
total := len(cond.filters)
for i := 0; i < total; i++ {
if !cond.filters[i].Match(req) {
Expand Down
4 changes: 2 additions & 2 deletions resp_condition.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
)

type RespCondition struct {
proxy *HttpProxy
ctx *Context
filters []Filter
}

Expand All @@ -14,7 +14,7 @@ func (cond *RespCondition) DoFunc(fn func(resp *http.Response) (*http.Response,
}

func (cond *RespCondition) Do(fn ResponseHandle) {
cond.proxy.Ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) {
cond.ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) {
total := len(cond.filters)
for i := 0; i < total; i++ {
if !cond.filters[i].Match(req) {
Expand Down
Loading

0 comments on commit 6c554d8

Please sign in to comment.