diff --git a/Gopkg.lock b/Gopkg.lock index 43f5c570..d43f1792 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -143,22 +143,12 @@ version = "v1.0.0" [[projects]] - digest = "1:a9fe0f8ff72c388d0128e88ce5f3c27d37dcd0950acd7cdb8323555f12463396" - name = "github.com/pressly/chi" - packages = [ - ".", - "middleware", - ] - pruneopts = "UT" - revision = "b5294d10673813fac8558e7f47242bc9e61b4c25" - version = "v3.3.3" - -[[projects]] - digest = "1:75d51eeab0df85a3cea9e1297ccd3183b20a10cb4b48c753d8ec8d113cc14250" + digest = "1:93a746f1060a8acbcf69344862b2ceced80f854170e1caae089b2834c5fbf7f4" name = "github.com/prometheus/client_golang" packages = [ "prometheus", "prometheus/internal", + "prometheus/promhttp", ] pruneopts = "UT" revision = "505eaef017263e299324067d40ca2c48f6a2cf50" @@ -369,11 +359,11 @@ "github.com/coreos/go-oidc/oidc", "github.com/elazarl/goproxy", "github.com/fsnotify/fsnotify", + "github.com/go-chi/chi", "github.com/go-chi/chi/middleware", "github.com/go-resty/resty", - "github.com/pressly/chi", - "github.com/pressly/chi/middleware", "github.com/prometheus/client_golang/prometheus", + "github.com/prometheus/client_golang/prometheus/promhttp", "github.com/rs/cors", "github.com/satori/go.uuid", "github.com/stretchr/testify/assert", diff --git a/admin_test.go b/admin_test.go new file mode 100644 index 00000000..9aea41b8 --- /dev/null +++ b/admin_test.go @@ -0,0 +1,349 @@ +/* +Copyright 2015 All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "fmt" + "io" + "io/ioutil" + "log" + "net/http" + "net/url" + "path" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/stretchr/testify/assert" +) + +const ( + e2eAdminProxyListener = "127.0.0.1:24329" + e2eAdminEndpointListener = "127.0.0.1:24330" + + e2eAdminProxyListener2 = "127.0.0.1:44329" + + e2eAdminOauthListener = "127.0.0.1:23457" + e2eAdminUpstreamListener = "127.0.0.1:28512" + e2eAdminAppListener = "127.0.0.1:33996" + e2eAdminOauthURL = "/auth/realms/hod-test/.well-known/openid-configuration" + e2eAdminOauthAuthorizeURL = "/auth/realms/hod-test/protocol/openid-connect/auth" + // #nosec + e2eAdminOauthTokenURL = "/auth/realms/hod-test/protocol/openid-connect/token" + e2eAdminOauthJWKSURL = "/auth/realms/hod-test/protocol/openid-connect/certs" + e2eAdminAppURL = "/ok" +) + +// checkListenOrBail waits on a endpoint listener to respond. +// This avoids race conditions with test listieners as go routines +func checkListenOrBail(endpoint string) bool { + const ( + maxWaitCycles = 10 + waitTime = 100 * time.Millisecond + ) + checkListen := http.Client{} + _, err := checkListen.Get(endpoint) + limit := 0 + for err != nil && limit < maxWaitCycles { + time.Sleep(waitTime) + _, err = checkListen.Get(endpoint) + limit++ + } + return limit < maxWaitCycles +} + +func runAdminTestAuth(t *testing.T) error { + // a stub OIDC provider + fake := newFakeAuthServer() + fake.location, _ = url.Parse("http://" + e2eAdminOauthListener) + go func() { + mux := http.NewServeMux() + configurationHandler := func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{ + "issuer": "http://`+e2eAdminOauthListener+`/auth/realms/hod-test", + "subject_types_supported":["public","pairwise"], + "id_token_signing_alg_values_supported":["ES384","RS384","HS256","HS512","ES256","RS256","HS384","ES512","RS512"], + "userinfo_signing_alg_values_supported":["ES384","RS384","HS256","HS512","ES256","RS256","HS384","ES512","RS512","none"], + "authorization_endpoint":"http://`+e2eAdminOauthListener+e2eAdminOauthAuthorizeURL+`", + "token_endpoint":"http://`+e2eAdminOauthListener+e2eAdminOauthTokenURL+`", + "jwks_uri":"http://`+e2eAdminOauthListener+e2eAdminOauthJWKSURL+`" + }`) + } + + authorizeHandler := func(w http.ResponseWriter, req *http.Request) { + redirect := req.FormValue("redirect_uri") + state := req.FormValue("state") + code := "zyx" + location, _ := url.PathUnescape(redirect) + u, _ := url.Parse(location) + v := u.Query() + v.Set("code", code) + v.Set("state", state) + u.RawQuery = v.Encode() + http.Redirect(w, req, u.String(), http.StatusFound) + } + + tokenHandler := func(w http.ResponseWriter, req *http.Request) { + fake.tokenHandler(w, req) + } + + keysHandler := func(w http.ResponseWriter, req *http.Request) { + fake.keysHandler(w, req) + } + mux.HandleFunc(e2eAdminOauthURL, configurationHandler) + mux.HandleFunc(e2eAdminOauthAuthorizeURL, authorizeHandler) + mux.HandleFunc(e2eAdminOauthTokenURL, tokenHandler) + mux.HandleFunc(e2eAdminOauthJWKSURL, keysHandler) + _ = http.ListenAndServe(e2eAdminOauthListener, mux) + }() + if !assert.True(t, checkListenOrBail("http://"+path.Join(e2eAdminOauthListener, e2eAdminOauthURL))) { + err := fmt.Errorf("cannot connect to test http listener on: %s", "http://"+path.Join(e2eAdminOauthListener, e2eAdminOauthURL)) + t.Logf("%v", err) + t.FailNow() + return err + } + return nil +} + +func runAdminTestApp(t *testing.T) error { + go func() { + mux := http.NewServeMux() + appHandler := func(w http.ResponseWriter, req *http.Request) { + _, _ = io.WriteString(w, `{"message": "ok"}`) + w.Header().Set("Content-Type", "application/json") + } + mux.HandleFunc(e2eAdminAppURL, appHandler) + _ = http.ListenAndServe(e2eAdminAppListener, mux) + }() + if !assert.True(t, checkListenOrBail("http://"+path.Join(e2eAdminAppListener, e2eAdminAppURL))) { + err := fmt.Errorf("cannot connect to test http listener on: %s", "http://"+path.Join(e2eAdminAppListener, e2eAdminAppURL)) + t.Logf("%v", err) + t.FailNow() + return err + } + return nil +} + +func runAdminTestGatekeeper(t *testing.T, config *Config) error { + proxy, err := newProxy(config) + if err != nil { + return err + } + _ = proxy.Run() + if !assert.True(t, checkListenOrBail("http://"+config.Listen+"/oauth/login")) { + err := fmt.Errorf("cannot connect to test http listener on: %s", "http://"+config.Listen+"/oauth/login") + t.Logf("%v", err) + t.FailNow() + return err + } + return nil +} + +func runAdminTestUpstream(t *testing.T) error { + // a stub upstream API server + go func() { + getUpstream := func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Upstream-Response-Header", "test") + _, _ = io.WriteString(w, `{"message": "test"}`) + } + + upstream := chi.NewRouter() + upstream.Route("/", func(r chi.Router) { + r.Get("/fake", getUpstream) + }) + + _ = http.ListenAndServe(e2eAdminUpstreamListener, upstream) + }() + if !assert.True(t, checkListenOrBail("http://"+path.Join(e2eAdminUpstreamListener, "/fake"))) { + err := fmt.Errorf("cannot connect to test http listener on: %s", "http://"+path.Join(e2eAdminUpstreamListener, "/fake")) + t.Logf("%v", err) + t.FailNow() + return err + } + return nil +} + +func TestAdmin(t *testing.T) { + log.SetOutput(ioutil.Discard) + + config := newDefaultConfig() + config.Verbose = false + config.DisableAllLogging = false + config.EnableLogging = false + + config.Listen = e2eAdminProxyListener + config.ListenAdmin = e2eAdminEndpointListener + config.EnableMetrics = true + config.EnableProfiling = true + config.DiscoveryURL = "http://" + e2eAdminOauthListener + e2eAdminOauthURL + config.Upstream = "http://" + e2eAdminUpstreamListener + + config.CorsOrigins = []string{"*"} + config.HTTPOnlyCookie = false // since we want to inspect the cookie for testing + config.SecureCookie = false // since we are testing over HTTP + config.AccessTokenDuration = 30 * time.Minute + config.EnableEncryptedToken = false + config.EnableSessionCookies = true + config.EnableAuthorizationCookies = false + config.EnableTokenHeader = false + config.EnableAuthorizationHeader = true + config.ClientID = fakeClientID + config.ClientSecret = fakeSecret + config.Resources = []*Resource{ + { + URL: "/fake", + Methods: []string{"GET", "POST", "DELETE"}, + WhiteListed: false, + }, + } + config.Resources = append(config.Resources, &Resource{ + URL: "/another-fake", + Methods: []string{"GET", "POST", "DELETE"}, + WhiteListed: false, + }) + config.EncryptionKey = "A123456789B123456789C123456789D1" + if !assert.NoError(t, config.isValid()) { + t.FailNow() + } + + // launch fake oauth OIDC server + err := runAdminTestAuth(t) + if !assert.NoError(t, err) { + t.FailNow() + } + + // launch fake upstream resource server + err = runAdminTestUpstream(t) + if !assert.NoError(t, err) { + t.FailNow() + } + + // launch fake app server where to land after authentication + err = runAdminTestApp(t) + if !assert.NoError(t, err) { + t.FailNow() + } + + // launch keycloak-gatekeeper proxy + err = runAdminTestGatekeeper(t, config) + if !assert.NoError(t, err) { + t.FailNow() + } + + // scenario 1: dedicated admin port + + // test health status endpoint + client := http.Client{} + u, _ := url.Parse("http://" + e2eAdminEndpointListener + "/oauth/health") + h := make(http.Header, 10) + h.Set("Content-Type", "application/json") + h.Add("Accept", "application/json") + req := &http.Request{ + Method: "GET", + URL: u, + Header: h, + } + + resp, err := client.Do(req) + if !assert.NoError(t, err) { + t.FailNow() + } + + assert.Equal(t, http.StatusOK, resp.StatusCode) + buf, erb := ioutil.ReadAll(resp.Body) + assert.NoError(t, erb) + assert.Equal(t, "OK\n", string(buf)) // check this is our test resource being called + + // test prometheus metrics endpoint + u, _ = url.Parse("http://" + e2eAdminEndpointListener + "/oauth/metrics") + req = &http.Request{ + Method: "GET", + URL: u, + Header: h, + } + + resp, err = client.Do(req) + if !assert.NoError(t, err) { + t.FailNow() + } + + assert.Equal(t, http.StatusOK, resp.StatusCode) + buf, erb = ioutil.ReadAll(resp.Body) + assert.NoError(t, erb) + assert.Contains(t, string(buf), `proxy_request_duration_sec`) + + // test profiling/debug endpoint + u, _ = url.Parse("http://" + e2eAdminEndpointListener + debugURL + "/symbol") + req = &http.Request{ + Method: "GET", + URL: u, + Header: h, + } + + resp, err = client.Do(req) + if !assert.NoError(t, err) { + t.FailNow() + } + + assert.Equal(t, http.StatusOK, resp.StatusCode) + buf, erb = ioutil.ReadAll(resp.Body) + assert.NoError(t, erb) + assert.Contains(t, string(buf), "num_symbols: 1\n") + + // scenario 2: admin endpoints beside other routes + config.Listen = e2eAdminProxyListener2 + config.ListenAdmin = "" + config.LocalhostMetrics = true + + // launch a new keycloak-gatekeeper proxy + err = runAdminTestGatekeeper(t, config) + if !assert.NoError(t, err) { + t.FailNow() + } + + // test health status endpoint, unauthenticated + u, _ = url.Parse("http://" + e2eAdminProxyListener2 + "/oauth/health") + req = &http.Request{ + Method: "GET", + URL: u, + Header: h, + } + resp, err = client.Do(req) + if !assert.NoError(t, err) { + t.FailNow() + } + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // test metrics + u, _ = url.Parse("http://" + e2eAdminProxyListener2 + "/oauth/metrics") + req = &http.Request{ + Method: "GET", + URL: u, + Header: h, + } + + resp, err = client.Do(req) + if !assert.NoError(t, err) { + t.FailNow() + } + + assert.Equal(t, http.StatusOK, resp.StatusCode) + buf, erb = ioutil.ReadAll(resp.Body) + assert.NoError(t, erb) + assert.Contains(t, string(buf), `proxy_request_duration_sec`) +} diff --git a/config.go b/config.go index cd4410af..56c49d3a 100644 --- a/config.go +++ b/config.go @@ -31,7 +31,7 @@ func newDefaultConfig() *Config { if name, err := os.Hostname(); err == nil { hostnames = append(hostnames, name) } - hostnames = append(hostnames, []string{"localhost", "127.0.0.1"}...) + hostnames = append(hostnames, []string{"localhost", "127.0.0.1", "::1"}...) return &Config{ AccessTokenDuration: time.Duration(720) * time.Hour, @@ -86,6 +86,15 @@ func (r *Config) isValid() error { if r.Listen == "" { return errors.New("you have not specified the listening interface") } + if r.ListenAdmin == r.Listen { + r.ListenAdmin = "" + } + if r.ListenAdminScheme == "" { + r.ListenAdminScheme = secureScheme + } + if r.ListenAdminScheme != secureScheme && r.ListenAdminScheme != unsecureScheme { + return errors.New("scheme for admin listener must be one of [http, https]") + } if r.MaxIdleConns <= 0 { return errors.New("max-idle-connections must be a number > 0") } @@ -95,21 +104,39 @@ func (r *Config) isValid() error { if r.TLSCertificate != "" && r.TLSPrivateKey == "" { return errors.New("you have not provided a private key") } + if r.TLSAdminCertificate != "" && r.TLSAdminPrivateKey == "" { + return errors.New("you have not provided a private key for admin endpoint") + } if r.TLSPrivateKey != "" && r.TLSCertificate == "" { return errors.New("you have not provided a certificate file") } + if r.TLSAdminPrivateKey != "" && r.TLSAdminCertificate == "" { + return errors.New("you have not provided a certificate file for admin endpoint") + } if r.TLSCertificate != "" && !fileExists(r.TLSCertificate) { return fmt.Errorf("the tls certificate %s does not exist", r.TLSCertificate) } + if r.TLSAdminCertificate != "" && !fileExists(r.TLSAdminCertificate) { + return fmt.Errorf("the tls certificate %s does not exist for admin endpoint", r.TLSAdminCertificate) + } if r.TLSPrivateKey != "" && !fileExists(r.TLSPrivateKey) { return fmt.Errorf("the tls private key %s does not exist", r.TLSPrivateKey) } + if r.TLSAdminPrivateKey != "" && !fileExists(r.TLSAdminPrivateKey) { + return fmt.Errorf("the tls private key %s does not exist for admin endpoint", r.TLSAdminPrivateKey) + } if r.TLSCaCertificate != "" && !fileExists(r.TLSCaCertificate) { return fmt.Errorf("the tls ca certificate file %s does not exist", r.TLSCaCertificate) } + if r.TLSAdminCaCertificate != "" && !fileExists(r.TLSAdminCaCertificate) { + return fmt.Errorf("the tls ca certificate file %s does not exist for admin endpoint", r.TLSAdminCaCertificate) + } if r.TLSClientCertificate != "" && !fileExists(r.TLSClientCertificate) { return fmt.Errorf("the tls client certificate %s does not exist", r.TLSClientCertificate) } + if r.TLSAdminClientCertificate != "" && !fileExists(r.TLSAdminClientCertificate) { + return fmt.Errorf("the tls client certificate %s does not exist for admin endpoint", r.TLSAdminClientCertificate) + } if r.UseLetsEncrypt && r.LetsEncryptCacheDir == "" { return fmt.Errorf("the letsencrypt cache dir has not been set") } diff --git a/doc.go b/doc.go index 0b11c4b8..90906d4f 100644 --- a/doc.go +++ b/doc.go @@ -153,10 +153,14 @@ type Resource struct { type Config struct { // ConfigFile is the binding interface ConfigFile string `json:"config" yaml:"config" usage:"path the a configuration file" env:"CONFIG_FILE"` - // Listen is the binding interface - Listen string `json:"listen" yaml:"listen" usage:"the interface the service should be listening on" env:"LISTEN"` + // Listen defines the binding interface for main listener, e.g. {address}:{port}. This is required and there is no default value. + Listen string `json:"listen" yaml:"listen" usage:"Defines the binding interface for main listener, e.g. {address}:{port}. This is required and there is no default value" env:"LISTEN"` // ListenHTTP is the interface to bind the http only service on - ListenHTTP string `json:"listen-http" yaml:"listen-http" usage:"interface we should be listening" env:"LISTEN_HTTP"` + ListenHTTP string `json:"listen-http" yaml:"listen-http" usage:"interface we should be listening to for HTTP traffic" env:"LISTEN_HTTP"` + // ListenAdmin defines the interface to bind admin-only endpoint (live-status, debug, prometheus...). If not defined, this defaults to the main listener defined by Listen. + ListenAdmin string `json:"listen-admin" yaml:"listen-admin" usage:"defines the interface to bind admin-only endpoint (live-status, debug, prometheus...). If not defined, this defaults to the main listener defined by Listen" env:"LISTEN_ADMIN"` + // ListenAdminScheme defines the scheme admin endpoints are served with. If not defined, same as main listener. + ListenAdminScheme string `json:"listen-admin-scheme" yaml:"listen-admin-scheme" usage:"scheme to serve admin-only endpoint (http or https)." env:"LISTEN_ADMIN_SCHEME"` // DiscoveryURL is the url for the keycloak server DiscoveryURL string `json:"discovery-url" yaml:"discovery-url" usage:"discovery url to retrieve the openid configuration" env:"DISCOVERY_URL"` // ClientID is the client id @@ -243,7 +247,7 @@ type Config struct { EnableFrameDeny bool `json:"filter-frame-deny" yaml:"filter-frame-deny" usage:"enable to the frame deny header"` // ContentSecurityPolicy allows the Content-Security-Policy header value to be set with a custom value ContentSecurityPolicy string `json:"content-security-policy" yaml:"content-security-policy" usage:"specify the content security policy"` - // LocalhostMetrics indicated the metrics can only be consume via localhost + // LocalhostMetrics indicates that metrics can only be consumed from localhost LocalhostMetrics bool `json:"localhost-metrics" yaml:"localhost-metrics" usage:"enforces the metrics page can only been requested from 127.0.0.1"` // AccessTokenDuration is default duration applied to the access token cookie @@ -277,6 +281,15 @@ type Config struct { // SkipUpstreamTLSVerify skips the verification of any upstream tls SkipUpstreamTLSVerify bool `json:"skip-upstream-tls-verify" yaml:"skip-upstream-tls-verify" usage:"skip the verification of any upstream TLS"` + // TLSAdminCertificate is the location for a tls certificate for admin https endpoint. Defaults to TLSCertificate. + TLSAdminCertificate string `json:"tls-admin-cert" yaml:"tls-admin-cert" usage:"path to ths TLS certificate" env:"TLS_ADMIN_CERTIFICATE"` + // TLSAdminPrivateKey is the location of a tls private key for admin https endpoint. Default to TLSPrivateKey + TLSAdminPrivateKey string `json:"tls-admin-private-key" yaml:"tls-admin-private-key" usage:"path to the private key for TLS" env:"TLS_ADMIN_PRIVATE_KEY"` + // TLSCaCertificate is the CA certificate which the client cert must be signed + TLSAdminCaCertificate string `json:"tls-admin-ca-certificate" yaml:"tls-admin-ca-certificate" usage:"path to the ca certificate used for signing requests" env:"TLS_ADMIN_CA_CERTIFICATE"` + // TLSAdinClientCertificate is path to a client certificate to use for outbound connections + TLSAdminClientCertificate string `json:"tls-admin-client-certificate" yaml:"tls-admin-client-certificate" usage:"path to the client certificate for outbound connections in reverse and forwarding proxy modes" env:"TLS_ADMIN_CLIENT_CERTIFICATE"` + // CorsOrigins is a list of origins permitted CorsOrigins []string `json:"cors-origins" yaml:"cors-origins" usage:"origins to add to the CORE origins control (Access-Control-Allow-Origin)"` // CorsMethods is a set of access control methods @@ -302,7 +315,7 @@ type Config struct { InvalidAuthRedirectsWith303 bool `json:"invalid-auth-redirects-with-303" yaml:"invalid-auth-redirects-with-303" usage:"use HTTP 303 redirects instead of 307 for invalid auth tokens"` // NoRedirects informs we should hand back a 401 not a redirect NoRedirects bool `json:"no-redirects" yaml:"no-redirects" usage:"do not have back redirects when no authentication is present, 401 them"` - // SkipTokenVerification tells the service to skipp verifying the access token - for testing purposes + // SkipTokenVerification tells the service to skip verifying the access token - for testing purposes SkipTokenVerification bool `json:"skip-token-verification" yaml:"skip-token-verification" usage:"TESTING ONLY; bypass token verification, only expiration and roles enforced"` // UpstreamKeepalives specifies whether we use keepalives on the upstream UpstreamKeepalives bool `json:"upstream-keepalives" yaml:"upstream-keepalives" usage:"enables or disables the keepalive connections for upstream endpoint"` diff --git a/handlers.go b/handlers.go index 5245900e..d0d86c46 100644 --- a/handlers.go +++ b/handlers.go @@ -33,7 +33,7 @@ import ( "github.com/coreos/go-oidc/oauth2" - "github.com/pressly/chi" + "github.com/go-chi/chi" "go.uber.org/zap" ) diff --git a/middleware.go b/middleware.go index 38c48353..e4cf42cd 100644 --- a/middleware.go +++ b/middleware.go @@ -285,7 +285,7 @@ func (r *oauthProxy) checkClaim(user *userContext, claimName string, match *rege return false } -// admissionMiddleware is responsible checking the access token against the protected resource +// admissionMiddleware is responsible for checking the access token against the protected resource func (r *oauthProxy) admissionMiddleware(resource *Resource) func(http.Handler) http.Handler { claimMatches := make(map[string]*regexp.Regexp) for k, v := range r.config.MatchClaims { @@ -359,7 +359,7 @@ func (r *oauthProxy) responseHeaderMiddleware(headers map[string]string) func(ht } } -// identityHeadersMiddleware is responsible for add the authentication headers for the upstream +// identityHeadersMiddleware is responsible for adding the authentication headers to upstream func (r *oauthProxy) identityHeadersMiddleware(custom []string) func(http.Handler) http.Handler { customClaims := make(map[string]string) for _, x := range custom { diff --git a/oauth_test.go b/oauth_test.go index 70786be8..ef623479 100644 --- a/oauth_test.go +++ b/oauth_test.go @@ -30,8 +30,8 @@ import ( "github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/oauth2" - "github.com/pressly/chi" - "github.com/pressly/chi/middleware" + "github.com/go-chi/chi" + "github.com/go-chi/chi/middleware" "github.com/stretchr/testify/assert" ) diff --git a/server.go b/server.go index fc308e64..afe56a3b 100644 --- a/server.go +++ b/server.go @@ -28,6 +28,7 @@ import ( "net/http" "net/url" "os" + "path" "runtime" "strings" "time" @@ -39,8 +40,8 @@ import ( proxyproto "github.com/armon/go-proxyproto" "github.com/coreos/go-oidc/oidc" "github.com/elazarl/goproxy" - "github.com/pressly/chi" - "github.com/pressly/chi/middleware" + "github.com/go-chi/chi" + "github.com/go-chi/chi/middleware" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rs/cors" @@ -57,6 +58,7 @@ type oauthProxy struct { log *zap.Logger metricsHandler http.Handler router http.Handler + adminRouter http.Handler server *http.Server store storage templates *template.Template @@ -106,7 +108,7 @@ func newProxy(config *Config) (*oauthProxy, error) { return nil, err } } else { - log.Warn("TESTING ONLY CONFIG - the verification of the token have been disabled") + log.Warn("TESTING ONLY CONFIG - access token verification has been disabled") } if config.ClientID == "" && config.ClientSecret == "" { @@ -152,19 +154,15 @@ func createLogger(config *Config) (*zap.Logger, error) { return c.Build() } -// createReverseProxy creates a reverse proxy -func (r *oauthProxy) createReverseProxy() error { - r.log.Info("enabled reverse proxy mode, upstream url", zap.String("url", r.config.Upstream)) - if err := r.createUpstreamProxy(r.endpoint); err != nil { - return err - } - engine := chi.NewRouter() +// useDefaultStack sets the default middleware stack for router +func (r *oauthProxy) useDefaultStack(engine chi.Router) { engine.MethodNotAllowed(emptyHandler) engine.NotFound(emptyHandler) engine.Use(middleware.Recoverer) + // @check if the request tracking id middleware is enabled if r.config.EnableRequestID { - r.log.Info("enabled the correlation request id middlware") + r.log.Info("enabled the correlation request id middleware") engine.Use(r.requestIDMiddleware(r.config.RequestIDHeader)) } // @step: enable the entrypoint middleware @@ -173,10 +171,22 @@ func (r *oauthProxy) createReverseProxy() error { if r.config.EnableLogging { engine.Use(r.loggingMiddleware) } + if r.config.EnableSecurityFilter { engine.Use(r.securityMiddleware) } +} + +// createReverseProxy creates a reverse proxy +func (r *oauthProxy) createReverseProxy() error { + r.log.Info("enabled reverse proxy mode, upstream url", zap.String("url", r.config.Upstream)) + if err := r.createUpstreamProxy(r.endpoint); err != nil { + return err + } + engine := chi.NewRouter() + r.useDefaultStack(engine) + // @step: configure CORS middleware if len(r.config.CorsOrigins) > 0 { c := cors.New(cors.Options{ AllowedOrigins: r.config.CorsOrigins, @@ -196,32 +206,62 @@ func (r *oauthProxy) createReverseProxy() error { engine.Use(r.responseHeaderMiddleware(r.config.ResponseHeaders)) } + // step: define admin subrouter: health and metrics + adminEngine := chi.NewRouter() + r.log.Info("enabled health service", zap.String("path", path.Clean(r.config.WithOAuthURI(healthURL)))) + adminEngine.Get(healthURL, r.healthHandler) + if r.config.EnableMetrics { + r.log.Info("enabled the service metrics middleware", zap.String("path", path.Clean(r.config.WithOAuthURI(metricsURL)))) + adminEngine.Get(metricsURL, r.proxyMetricsHandler) + } + // step: add the routing for oauth engine.With(proxyDenyMiddleware).Route(r.config.OAuthURI, func(e chi.Router) { e.MethodNotAllowed(methodNotAllowHandlder) e.HandleFunc(authorizationURL, r.oauthAuthorizationHandler) e.Get(callbackURL, r.oauthCallbackHandler) e.Get(expiredURL, r.expirationHandler) - e.Get(healthURL, r.healthHandler) e.Get(logoutURL, r.logoutHandler) e.Get(tokenURL, r.tokenHandler) e.Post(loginURL, r.loginHandler) - if r.config.EnableMetrics { - r.log.Info("enabled the service metrics middleware", zap.String("path", r.config.WithOAuthURI(metricsURL))) - e.Get(metricsURL, r.proxyMetricsHandler) + + if r.config.ListenAdmin == "" { + e.Mount("/", adminEngine) } }) + // step: define profiling subrouter + var debugEngine chi.Router if r.config.EnableProfiling { - engine.With(proxyDenyMiddleware).Route(debugURL, func(e chi.Router) { - r.log.Warn("enabling the debug profiling on /debug/pprof") - e.Get("/{name}", r.debugHandler) - e.Post("/{name}", r.debugHandler) - }) + r.log.Warn("enabling the debug profiling on " + debugURL) + debugEngine = chi.NewRouter() + debugEngine.Get("/{name}", r.debugHandler) + debugEngine.Post("/{name}", r.debugHandler) + // @check if the server write-timeout is still set and throw a warning if r.config.ServerWriteTimeout > 0 { - r.log.Warn("you must disable the server write timeout (--server-write-timeout) when using pprof profiling") + r.log.Warn("you should disable the server write timeout (--server-write-timeout) when using pprof profiling") } + if r.config.ListenAdmin == "" { + engine.With(proxyDenyMiddleware).Mount(debugURL, debugEngine) + } + } + + if r.config.ListenAdmin != "" { + // mount admin and debug engines separately + r.log.Info("mounting admin endpoints on separate listener") + admin := chi.NewRouter() + admin.MethodNotAllowed(emptyHandler) + admin.NotFound(emptyHandler) + admin.Use(middleware.Recoverer) + admin.Use(proxyDenyMiddleware) + admin.Route("/", func(e chi.Router) { + e.Mount(r.config.OAuthURI, adminEngine) + if debugEngine != nil { + e.Mount(debugURL, debugEngine) + } + }) + r.adminRouter = admin } if r.config.EnableSessionCookies { @@ -351,25 +391,12 @@ func (r *oauthProxy) createForwardingProxy() error { // Run starts the proxy service func (r *oauthProxy) Run() error { - listener, err := r.createHTTPListener(listenerConfig{ - ca: r.config.TLSCaCertificate, - certificate: r.config.TLSCertificate, - clientCert: r.config.TLSClientCertificate, - hostnames: r.config.Hostnames, - letsEncryptCacheDir: r.config.LetsEncryptCacheDir, - listen: r.config.Listen, - privateKey: r.config.TLSPrivateKey, - proxyProtocol: r.config.EnableProxyProtocol, - redirectionURL: r.config.RedirectionURL, - useFileTLS: r.config.TLSPrivateKey != "" && r.config.TLSCertificate != "", - useLetsEncryptTLS: r.config.UseLetsEncrypt, - useSelfSignedTLS: r.config.EnabledSelfSignedTLS, - }) - + listener, err := r.createHTTPListener(makeListenerConfig(r.config)) if err != nil { return err } - // step: create the http server + + // step: create the main http(s) server server := &http.Server{ Addr: r.config.Listen, Handler: r.router, @@ -413,6 +440,62 @@ func (r *oauthProxy) Run() error { }() } + // step: are we running specific admin service as well? + // if not, admin endpoints are added as routes in the main service + if r.config.ListenAdmin != "" { + r.log.Info("keycloak proxy admin service starting", zap.String("interface", r.config.ListenAdmin)) + var ( + adminListener net.Listener + err error + ) + + if r.config.ListenAdminScheme == unsecureScheme { + // run the admin endpoint (metrics, health) with http + adminListener, err = r.createHTTPListener(listenerConfig{ + listen: r.config.ListenAdmin, + proxyProtocol: r.config.EnableProxyProtocol, + }) + if err != nil { + return err + } + } else { + adminListenerConfig := makeListenerConfig(r.config) + + // admin specific overides + adminListenerConfig.listen = r.config.ListenAdmin + + // TLS configuration defaults to the one for the main service, + // and may be overidden + if r.config.TLSAdminPrivateKey != "" && r.config.TLSAdminCertificate != "" { + adminListenerConfig.useFileTLS = true + adminListenerConfig.certificate = r.config.TLSAdminCertificate + adminListenerConfig.privateKey = r.config.TLSAdminPrivateKey + } + if r.config.TLSAdminCaCertificate != "" { + adminListenerConfig.ca = r.config.TLSAdminCaCertificate + } + if r.config.TLSAdminClientCertificate != "" { + adminListenerConfig.clientCert = r.config.TLSAdminClientCertificate + } + adminListener, err = r.createHTTPListener(adminListenerConfig) + if err != nil { + return err + } + } + adminsvc := &http.Server{ + Addr: r.config.ListenAdmin, + Handler: r.adminRouter, + ReadTimeout: r.config.ServerReadTimeout, + WriteTimeout: r.config.ServerWriteTimeout, + IdleTimeout: r.config.ServerIdleTimeout, + } + + go func() { + if ers := adminsvc.Serve(adminListener); err != nil { + r.log.Fatal("failed to start the admin service", zap.Error(ers)) + } + }() + } return nil } @@ -432,6 +515,26 @@ type listenerConfig struct { useSelfSignedTLS bool // indicates we are using the self-signed tls } +// makeListenerConfig extracts a listener configuration from a proxy Config +func makeListenerConfig(config *Config) listenerConfig { + return listenerConfig{ + hostnames: config.Hostnames, + letsEncryptCacheDir: config.LetsEncryptCacheDir, + listen: config.Listen, + proxyProtocol: config.EnableProxyProtocol, + redirectionURL: config.RedirectionURL, + + // TLS settings + useFileTLS: config.TLSPrivateKey != "" && config.TLSCertificate != "", + privateKey: config.TLSPrivateKey, + ca: config.TLSCaCertificate, + certificate: config.TLSCertificate, + clientCert: config.TLSClientCertificate, + useLetsEncryptTLS: config.UseLetsEncrypt, + useSelfSignedTLS: config.EnabledSelfSignedTLS, + } +} + // ErrHostNotConfigured indicates the hostname was not configured var ErrHostNotConfigured = errors.New("acme/autocert: host not configured") @@ -528,7 +631,9 @@ func (r *oauthProxy) createHTTPListener(config listenerConfig) (net.Listener, er } tlsConfig := &tls.Config{ - GetCertificate: getCertificate, + GetCertificate: getCertificate, + // Causes servers to use Go's default ciphersuite preferences, + // which are tuned to avoid attacks. Does nothing on clients. PreferServerCipherSuites: true, NextProtos: []string{"h2", "http/1.1"}, }