Skip to content

Commit

Permalink
Improvements, Bug Fixes and Year Updates
Browse files Browse the repository at this point in the history
  • Loading branch information
iDigitalFlame committed Dec 11, 2021
1 parent bc5f166 commit e9632d0
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 51 deletions.
23 changes: 11 additions & 12 deletions new.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2021 PurpleSec Team
// Copyright 2021 - 2022 PurpleSec Team
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
Expand Down Expand Up @@ -48,21 +48,22 @@ func (t Timeout) config(p *Proxy) {
p.server.ReadHeaderTimeout = time.Duration(t)
}

// TLS creates a config paramater with the specified Key and Value file
// paths.
// TLS creates a config paramater with the specified Key and Value file paths.
func TLS(cert, key string) Parameter {
return &keys{Cert: cert, Key: key}
}

// New creates a new Proxy instance from the specified listen
// address and optional parameters.
// New creates a new Proxy instance from the specified listen address and
// optional parameters.
func New(listen string, c ...Parameter) *Proxy {
return NewContext(context.Background(), listen, c...)
}

// NewContext creates a new Proxy instance from the specified listen
// address and optional parameters. This function allows the caller to specify
// a context to specify when to shutdown the Proxy.
// NewContext creates a new Proxy instance from the specified listen address and
// optional parameters.
//
// This function allows the caller to specify a context to specify when to shutdown
// the Proxy.
func NewContext(x context.Context, listen string, c ...Parameter) *Proxy {
p := &Proxy{
pool: &sync.Pool{
Expand All @@ -86,10 +87,8 @@ func NewContext(x context.Context, listen string, c ...Parameter) *Proxy {
c[i].config(p)
}
if len(c) == 0 {
p.server.ReadTimeout = DefaultTimeout
p.server.IdleTimeout = DefaultTimeout
p.server.WriteTimeout = DefaultTimeout
p.server.ReadHeaderTimeout = DefaultTimeout
p.server.ReadTimeout, p.server.IdleTimeout = DefaultTimeout, DefaultTimeout
p.server.WriteTimeout, p.server.ReadHeaderTimeout = DefaultTimeout, DefaultTimeout
}
return p
}
58 changes: 39 additions & 19 deletions proxy.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2021 PurpleSec Team
// Copyright 2021 - 2022 PurpleSec Team
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
Expand All @@ -19,21 +19,21 @@ package switchproxy
import (
"bytes"
"context"
"crypto/tls"
"io"
"net"
"net/http"
"sync"
"time"
)

const (
// DefaultTimeout is the default timeout value used when a Timeout is not
// specified in NewProxy.
DefaultTimeout = time.Second * time.Duration(15)
)
// DefaultTimeout is the default timeout value used when a Timeout is not
// specified in NewProxy.
const DefaultTimeout = time.Second * time.Duration(15)

// Proxy is a struct that represents a stacked proxy that allows a forwarding proxy
// with secondary read only Switch connections that allow logging and storing the connection data.
// with secondary read only Switch connections that allow logging and storing
// the connection data.
type Proxy struct {
ctx context.Context
key string
Expand All @@ -51,29 +51,48 @@ type transfer struct {
data []byte
}

// Stop attempts to gracefully close and Stop the proxy and all remaining connextions.
func (p *Proxy) Stop() error {
// Close attempts to gracefully close and stop the proxy and all remaining
// connextions.
func (p *Proxy) Close() error {
p.cancel()
return p.server.Close()
}

// Start starts the Server listening loop and returns an error if the server could not be started.
// Start starts the Server listening loop and returns an error if the server
// could not be started.
//
// Only returns an error if any IO issues occur during operation.
func (p *Proxy) Start() error {
defer p.Stop()
var err error
if len(p.cert) > 0 && len(p.key) > 0 {
return p.server.ListenAndServeTLS(p.cert, p.key)
p.server.TLSConfig = &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
MinVersion: tls.VersionTLS12,
CipherSuites: []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
},
CurvePreferences: []tls.CurveID{tls.CurveP256, tls.X25519},
PreferServerCipherSuites: true,
}
err = p.server.ListenAndServeTLS(p.cert, p.key)
} else {
err = p.server.ListenAndServe()
}
return p.server.ListenAndServe()
p.Close()
return err
}

// Primary sets the primary Proxy Switch context.
func (p *Proxy) Primary(s *Switch) {
p.primary = s
}
func (p *Proxy) clear(t *transfer) {
t.in = nil
t.data = nil
t.in, t.data = nil, nil
t.out.Reset()
t.read.Reset()
p.pool.Put(t)
Expand All @@ -90,15 +109,14 @@ func (p *Proxy) context(_ net.Listener) context.Context {
// ServeHTTP satisfies the http.Handler interface.
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
t := p.pool.Get().(*transfer)
defer p.clear(t)
defer r.Body.Close()
if _, err := io.Copy(t.read, r.Body); err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
p.clear(t)
r.Body.Close()
return
}
t.data = t.read.Bytes()
t.in = bytes.NewReader(t.data)
if p.primary != nil {
if t.in = bytes.NewReader(t.data); p.primary != nil {
if s, h, err := p.primary.process(p.ctx, r, t); err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
} else {
Expand All @@ -120,4 +138,6 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
p.secondary[i].process(p.ctx, r, t)
}
}
p.clear(t)
r.Body.Close()
}
49 changes: 29 additions & 20 deletions switch.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2021 PurpleSec Team
// Copyright 2021 - 2022 PurpleSec Team
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
Expand All @@ -18,7 +18,7 @@ package switchproxy

import (
"context"
"fmt"
"errors"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -63,6 +63,7 @@ func (r Result) IsResponse() bool {
}

// Rewrite adds a URL rewrite from the Switch.
//
// If a URL starts with the 'from' paramater, it will be replaced with the 'to'
// paramater, only if starting with on the URL path.
func (s *Switch) Rewrite(from, to string) {
Expand All @@ -80,12 +81,14 @@ func NewSwitch(target string) (*Switch, error) {
return NewSwitchTimeout(target, DefaultTimeout)
}

// NewSwitchTimeout creates a switching context that allows the connection to be proxied
// to the specified server. This function will set the specified timeout.
// NewSwitchTimeout creates a switching context that allows the connection to be
// proxied to the specified server.
//
// This function will set the specified timeout.
func NewSwitchTimeout(target string, t time.Duration) (*Switch, error) {
u, err := url.Parse(target)
if err != nil {
return nil, fmt.Errorf("unable to resolve URL: %w", err)
return nil, errors.New("unable to resolve URL: " + err.Error())
}
if !u.IsAbs() {
u.Scheme = "http"
Expand Down Expand Up @@ -113,31 +116,33 @@ func NewSwitchTimeout(target string, t time.Duration) (*Switch, error) {
return s, nil
}
func (s Switch) process(x context.Context, r *http.Request, t *transfer) (int, http.Header, error) {
s.URL.Path = r.URL.Path
s.URL.User = r.URL.User
s.URL.Opaque = r.URL.Opaque
s.URL.Fragment = r.URL.Fragment
s.URL.RawQuery = r.URL.RawQuery
s.URL.ForceQuery = r.URL.ForceQuery
s.Path = r.URL.Path
s.User = r.URL.User
s.Opaque = r.URL.Opaque
s.Fragment = r.URL.Fragment
s.RawQuery = r.URL.RawQuery
s.ForceQuery = r.URL.ForceQuery
for k, v := range s.rewrite {
if strings.HasPrefix(s.URL.Path, k) {
s.URL.Path = path.Join(v, s.URL.Path[len(k):])
if strings.HasPrefix(s.Path, k) {
s.Path = path.Join(v, s.Path[len(k):])
}
}
var f func()
if s.timeout > 0 {
var f context.CancelFunc
x, f = context.WithTimeout(x, s.timeout)
defer f()
}
u := uuid.New().String()
q, err := http.NewRequestWithContext(x, r.Method, s.URL.String(), t.in)
var (
u = uuid.New().String()
q, err = http.NewRequestWithContext(x, r.Method, s.String(), t.in)
)
if err != nil {
f()
return 0, nil, err
}
if s.Pre != nil {
s.Pre(Result{
IP: r.RemoteAddr,
URL: s.URL.String(),
URL: s.String(),
UUID: u,
Path: s.Path,
Method: r.Method,
Expand All @@ -150,16 +155,18 @@ func (s Switch) process(x context.Context, r *http.Request, t *transfer) (int, h
q.TransferEncoding = r.TransferEncoding
o, err := s.client.Do(q)
if err != nil {
f()
return 0, nil, err
}
defer o.Body.Close()
if _, err := io.Copy(t.out, o.Body); err != nil {
f()
o.Body.Close()
return 0, nil, err
}
if s.Post != nil {
s.Post(Result{
IP: r.RemoteAddr,
URL: s.URL.String(),
URL: s.String(),
Path: s.Path,
UUID: u,
Status: uint16(o.StatusCode),
Expand All @@ -168,5 +175,7 @@ func (s Switch) process(x context.Context, r *http.Request, t *transfer) (int, h
Headers: o.Header,
})
}
f()
o.Body.Close()
return o.StatusCode, o.Header, nil
}

0 comments on commit e9632d0

Please sign in to comment.