Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

agent/http: refactor proxy for simpler error handling #271

Merged
merged 1 commit into from
Aug 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 66 additions & 48 deletions pkg/agent/protocol/http/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,6 @@ func NewProxy(c ProxyConfig, d Disruption) (protocol.Proxy, error) {
}, nil
}

// contains verifies if a list of strings contains the given string
func contains(list []string, target string) bool {
for _, element := range list {
if element == target {
return true
}
}
return false
}

// httpClient defines the method for executing HTTP requests. It is used to allow mocking
// the client in tests
type httpClient interface {
Expand All @@ -97,56 +87,84 @@ type httpHandler struct {
client httpClient
}

func (h *httpHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
var statusCode int
headers := http.Header{}
body := io.NopCloser(strings.NewReader(h.disruption.ErrorBody))

excluded := contains(h.disruption.Excluded, req.URL.Path)

if !excluded && h.disruption.ErrorRate > 0 && rand.Float32() <= h.disruption.ErrorRate {
// force error code
statusCode = int(h.disruption.ErrorCode)
} else {
req.Host = h.upstreamURL.Host
req.URL.Host = h.upstreamURL.Host
req.URL.Scheme = h.upstreamURL.Scheme
req.RequestURI = ""
originServerResponse, srvErr := h.client.Do(req)
if srvErr != nil {
rw.WriteHeader(http.StatusInternalServerError)
_, _ = fmt.Fprint(rw, srvErr)
return
// isExcluded checks whether a request should be proxied through without any kind of modification whatsoever.
func (h *httpHandler) isExcluded(r *http.Request) bool {
for _, excluded := range h.disruption.Excluded {
if strings.EqualFold(r.URL.Path, excluded) {
return true
}
}

headers = originServerResponse.Header
statusCode = originServerResponse.StatusCode
body = originServerResponse.Body
return false
}

defer func() {
_ = originServerResponse.Body.Close()
}()
}
// forward forwards a request to the upstream URL.
// Request is performed immediately, but response won't be sent before the duration specified in delay.
func (h *httpHandler) forward(rw http.ResponseWriter, req *http.Request, delay time.Duration) {
timer := time.After(delay)

if !excluded && h.disruption.AverageDelay > 0 {
delay := int64(h.disruption.AverageDelay)
if h.disruption.DelayVariation > 0 {
variation := int64(h.disruption.DelayVariation)
delay = delay + variation - 2*rand.Int63n(variation)
}
time.Sleep(time.Duration(delay))
upstreamReq := req.Clone(context.Background())
upstreamReq.Host = h.upstreamURL.Host
upstreamReq.URL.Host = h.upstreamURL.Host
upstreamReq.URL.Scheme = h.upstreamURL.Scheme
upstreamReq.RequestURI = "" // It is an error to set this field in an HTTP client request.

response, err := h.client.Do(req)
<-timer
if err != nil {
rw.WriteHeader(http.StatusBadGateway)
_, _ = fmt.Fprint(rw, err)
return
}

// return response to the client
for key, values := range headers {
defer func() {
// Fully consume and then close upstream response body.
_, _ = io.Copy(io.Discard, response.Body)
_ = response.Body.Close()
}()

// Mirror headers.
for key, values := range response.Header {
for _, value := range values {
rw.Header().Add(key, value)
}
}
rw.WriteHeader(statusCode)

// Mirror status code.
rw.WriteHeader(response.StatusCode)

// ignore errors writing body, nothing to do.
_, _ = io.Copy(rw, body)
_, _ = io.Copy(rw, response.Body)
}

// injectError waits sleeps the duration specified in delay and then writes the configured error downstream.
func (h *httpHandler) injectError(rw http.ResponseWriter, delay time.Duration) {
time.Sleep(delay)

rw.WriteHeader(int(h.disruption.ErrorCode))
_, _ = rw.Write([]byte(h.disruption.ErrorBody))
}

func (h *httpHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if h.isExcluded(req) {
//nolint:contextcheck // Unclear which context the linter requires us to propagate here.
h.forward(rw, req, 0)
return
}

delay := h.disruption.AverageDelay
if h.disruption.DelayVariation > 0 {
variation := int64(h.disruption.DelayVariation)
delay += time.Duration(variation - 2*rand.Int63n(variation))
}

if h.disruption.ErrorRate > 0 && rand.Float32() <= h.disruption.ErrorRate {
h.injectError(rw, delay)
return
}

//nolint:contextcheck // Unclear which context the linter requires us to propagate here.
h.forward(rw, req, delay)
}

// Start starts the execution of the proxy
Expand Down
Loading