Skip to content

Commit

Permalink
Fix body not sent upstream with opa authorization #432 (#433)
Browse files Browse the repository at this point in the history
  • Loading branch information
p53 authored Mar 16, 2024
1 parent c3be865 commit eee4ed6
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 35 deletions.
7 changes: 1 addition & 6 deletions pkg/authorization/external_opa.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ func (p *OpaAuthorizationProvider) Authorize() (AuthzDecision, error) {
defer cancel()

reqBody, err := io.ReadAll(p.req.Body)

if err != nil {
return DeniedAuthz, err
}
p.req.Body.Close()

opaReq := &OpaAuthzRequest{Input: &OpaInput{}}
opaReq.Input.Body = string(reqBody)
Expand All @@ -77,7 +77,6 @@ func (p *OpaAuthorizationProvider) Authorize() (AuthzDecision, error) {
opaReq.Input.UserAgent = p.req.UserAgent()

opaReqBody, err := json.Marshal(opaReq)

if err != nil {
return DeniedAuthz, err
}
Expand All @@ -87,7 +86,6 @@ func (p *OpaAuthorizationProvider) Authorize() (AuthzDecision, error) {
p.authzURL.String(),
bytes.NewReader(opaReqBody),
)

if err != nil {
return DeniedAuthz, err
}
Expand All @@ -99,19 +97,16 @@ func (p *OpaAuthorizationProvider) Authorize() (AuthzDecision, error) {
opaResp := &OpaAuthzResponse{}

resp, err := client.Do(httpReq)

if err != nil {
return DeniedAuthz, err
}

body, err := io.ReadAll(resp.Body)

if err != nil {
return DeniedAuthz, err
}

err = json.Unmarshal(body, opaResp)

if err != nil {
return DeniedAuthz, err
}
Expand Down
12 changes: 0 additions & 12 deletions pkg/authorization/external_opa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ func TestExternalOpa(t *testing.T) {
Name: "Test",
}
reqBody, err := json.Marshal(testInput)

if err != nil {
return nil, err
}
Expand All @@ -47,7 +46,6 @@ func TestExternalOpa(t *testing.T) {
"dummy",
bytes.NewReader(reqBody),
)

if err != nil {
return nil, err
}
Expand All @@ -74,7 +72,6 @@ func TestExternalOpa(t *testing.T) {
Name: "Test",
}
reqBody, err := json.Marshal(testInput)

if err != nil {
return nil, err
}
Expand All @@ -84,7 +81,6 @@ func TestExternalOpa(t *testing.T) {
"dummy",
bytes.NewReader(reqBody),
)

if err != nil {
return nil, err
}
Expand All @@ -111,7 +107,6 @@ func TestExternalOpa(t *testing.T) {
Name: "Test",
}
reqBody, err := json.Marshal(testInput)

if err != nil {
return nil, err
}
Expand All @@ -121,7 +116,6 @@ func TestExternalOpa(t *testing.T) {
"dummy",
bytes.NewReader(reqBody),
)

if err != nil {
return nil, err
}
Expand All @@ -140,7 +134,6 @@ func TestExternalOpa(t *testing.T) {
Name: "Test",
}
reqBody, err := json.Marshal(testInput)

if err != nil {
return nil, err
}
Expand All @@ -150,7 +143,6 @@ func TestExternalOpa(t *testing.T) {
"dummy",
bytes.NewReader(reqBody),
)

if err != nil {
return nil, err
}
Expand All @@ -169,7 +161,6 @@ func TestExternalOpa(t *testing.T) {
Name: "Test",
}
reqBody, err := yaml.Marshal(testInput)

if err != nil {
return nil, err
}
Expand All @@ -179,7 +170,6 @@ func TestExternalOpa(t *testing.T) {
"dummy",
bytes.NewReader(reqBody),
)

if err != nil {
return nil, err
}
Expand All @@ -206,7 +196,6 @@ func TestExternalOpa(t *testing.T) {
Name: "Test",
}
reqBody, err := yaml.Marshal(testInput)

if err != nil {
return nil, err
}
Expand All @@ -216,7 +205,6 @@ func TestExternalOpa(t *testing.T) {
"dummy",
bytes.NewReader(reqBody),
)

if err != nil {
return nil, err
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/authorization/resource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/go-chi/chi/v5"
"github.com/gogatekeeper/gatekeeper/pkg/utils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestDecodeResourceBad(t *testing.T) {
Expand Down Expand Up @@ -161,10 +162,10 @@ func TestResourceParseOk(t *testing.T) {
r, err := NewResource().Parse(testCase.Option)

if testCase.Ok {
assert.NoError(t, err, "case %d should not have errored with: %s", i, err)
require.NoError(t, err, "case %d should not have errored with: %s", i, err)
assert.Equal(t, r, testCase.Resource, "case %d, expected: %#v, got: %#v", i, testCase.Resource, r)
} else {
assert.Error(t, err)
require.Error(t, err)
}
}
}
Expand Down
35 changes: 27 additions & 8 deletions pkg/keycloak/proxy/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ limitations under the License.
package proxy

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
Expand Down Expand Up @@ -470,7 +472,8 @@ func authenticationMiddleware(
}
}

next.ServeHTTP(wrt, req.WithContext(ctx))
*req = *(req.WithContext(ctx))
next.ServeHTTP(wrt, req)
})
}
}
Expand Down Expand Up @@ -617,12 +620,29 @@ func authorizationMiddleware(
}
}
} else if enableOpa {
provider = authorization.NewOpaAuthorizationProvider(
opaTimeout,
*opaAuthzURL,
req,
)
decision, err = provider.Authorize()
// initially request Body is stream read from network connection,
// when read once, it is closed, so second time we would not be able to
// read it, so what we will do here is that we will read body,
// create copy of original request and pass body which we already read
// to original req and to new copy of request,
// new copy will be passed to authorizer, which also needs to read body
reqBody, varErr := io.ReadAll(req.Body)
if varErr != nil {
decision = authorization.DeniedAuthz
err = varErr
} else {
req.Body.Close()
passReq := *req
passReq.Body = io.NopCloser(bytes.NewReader(reqBody))
req.Body = io.NopCloser(bytes.NewReader(reqBody))

provider = authorization.NewOpaAuthorizationProvider(
opaTimeout,
*opaAuthzURL,
&passReq,
)
decision, err = provider.Authorize()
}
}

switch err {
Expand Down Expand Up @@ -673,7 +693,6 @@ func authorizationMiddleware(
next.ServeHTTP(wrt, req.WithContext(accessForbidden(wrt, req)))
return
}

next.ServeHTTP(wrt, req)
})
}
Expand Down
5 changes: 4 additions & 1 deletion pkg/keycloak/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,10 @@ func NewProxy(config *config.Config, log *zap.Logger, upstream reverseProxy) (*O
)
}

svc.Upstream = upstream
if upstream != nil {
svc.Upstream = upstream
}

// are we running in forwarding mode?
if config.EnableForwarding {
if err := svc.createForwardingProxy(); err != nil {
Expand Down
12 changes: 10 additions & 2 deletions pkg/testsuite/fake_upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package testsuite

import (
"encoding/json"
"io"
"net/http"
"strings"

Expand All @@ -14,16 +15,16 @@ type fakeUpstreamResponse struct {
Method string `json:"method"`
Address string `json:"address"`
Headers http.Header `json:"headers"`
Body string `json:"body"`
}

// FakeUpstreamService acts as a fake upstream service, returns the headers and request
type FakeUpstreamService struct{}

func (f *FakeUpstreamService) ServeHTTP(wrt http.ResponseWriter, req *http.Request) {
wrt.Header().Set(TestProxyAccepted, "true")

upgrade := strings.ToLower(req.Header.Get("Upgrade"))
if upgrade == "websocket" {
wrt.Header().Set(TestProxyAccepted, "true")
websocket.Handler(func(wsock *websocket.Conn) {
defer wsock.Close()
var data []byte
Expand All @@ -41,6 +42,12 @@ func (f *FakeUpstreamService) ServeHTTP(wrt http.ResponseWriter, req *http.Reque
_ = websocket.Message.Send(wsock, content)
}).ServeHTTP(wrt, req)
} else {
reqBody, err := io.ReadAll(req.Body)
if err != nil {
wrt.WriteHeader(http.StatusInternalServerError)
}

wrt.Header().Set(TestProxyAccepted, "true")
wrt.Header().Set("Content-Type", "application/json")
content, err := json.Marshal(&fakeUpstreamResponse{
// r.RequestURI is what was received by the proxy.
Expand All @@ -50,6 +57,7 @@ func (f *FakeUpstreamService) ServeHTTP(wrt http.ResponseWriter, req *http.Reque
Method: req.Method,
Address: req.RemoteAddr,
Headers: req.Header,
Body: string(reqBody),
})

if err != nil {
Expand Down
18 changes: 14 additions & 4 deletions pkg/testsuite/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2531,6 +2531,9 @@ func TestLogRealIP(t *testing.T) {

//nolint:funlen
func TestEnableOpa(t *testing.T) {
upstreamService := httptest.NewServer(&FakeUpstreamService{})
upstreamURL := upstreamService.URL

requests := []struct {
Name string
ProxySettings func(c *config.Config)
Expand All @@ -2546,17 +2549,23 @@ func TestEnableOpa(t *testing.T) {
conf.OpaTimeout = 60 * time.Second
conf.ClientID = ValidUsername
conf.ClientSecret = ValidPassword
conf.Upstream = upstreamURL
},
ExecutionSettings: []fakeRequest{
{
URI: FakeTestURL,
ExpectedProxy: true,
HasToken: true,
Redirects: false,
ExpectedCode: http.StatusOK,
Method: "POST",
FormValues: map[string]string{
"Name": "Whatever",
},
HasToken: true,
Redirects: false,
ExpectedCode: http.StatusOK,
ExpectedContent: func(body string, testNum int) {
assert.Contains(t, body, "test")
assert.Contains(t, body, "method")
assert.Contains(t, body, "Whatever")
},
},
},
Expand All @@ -2566,8 +2575,9 @@ func TestEnableOpa(t *testing.T) {
default allow := false
allow {
input.method = "GET"
input.method = "POST"
input.path = FakeTestURL
contains(input.body, "Whatever")
}
`,
StartOpa: true,
Expand Down

0 comments on commit eee4ed6

Please sign in to comment.