From 6d7f81e3722b2325d7f98783a6e7d6ce39b48b72 Mon Sep 17 00:00:00 2001 From: Geoffrey Ragot Date: Sun, 2 Oct 2022 22:39:14 +0200 Subject: [PATCH] feat: refactor to properly handle urls prefix stripped. --- cmd/serve.go | 12 ++---- pkg/api/authorization/module.go | 4 +- pkg/api/{routing => }/context.go | 2 +- pkg/api/module.go | 40 ++++++++++++++++---- pkg/api/routing/module.go | 48 ------------------------ pkg/api/routing/module_test.go | 64 -------------------------------- pkg/api/{routing => }/server.go | 2 +- pkg/oidc/module.go | 11 +++--- pkg/oidc/oidc_test.go | 5 +-- pkg/oidc/router.go | 7 +--- 10 files changed, 48 insertions(+), 147 deletions(-) rename pkg/api/{routing => }/context.go (96%) delete mode 100644 pkg/api/routing/module.go delete mode 100644 pkg/api/routing/module_test.go rename pkg/api/{routing => }/server.go (96%) diff --git a/cmd/serve.go b/cmd/serve.go index 5b0e8b8..a809170 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -5,7 +5,6 @@ import ( "crypto/x509" "encoding/pem" "fmt" - "net/url" auth "github.com/formancehq/auth/pkg" "github.com/formancehq/auth/pkg/api" @@ -71,11 +70,6 @@ var serveCmd = &cobra.Command{ return errors.New("base url must be defined") } - baseUrl, err := url.Parse(viper.GetString(baseUrlFlag)) - if err != nil { - return errors.Wrap(err, "parsing base url") - } - delegatedClientID := viper.GetString(delegatedClientIDFlag) if delegatedClientID == "" { return errors.New("delegated client id must be defined") @@ -127,10 +121,10 @@ var serveCmd = &cobra.Command{ Issuer: delegatedIssuer, ClientID: delegatedClientID, ClientSecret: delegatedClientSecret, - RedirectURL: fmt.Sprintf("%s/authorize/callback", baseUrl.String()), + RedirectURL: fmt.Sprintf("%s/authorize/callback", viper.GetString(baseUrlFlag)), }), - api.Module(":8080", baseUrl), - oidc.Module(key, baseUrl, o.Clients...), + api.Module(":8080"), + oidc.Module(key, viper.GetString(baseUrlFlag), o.Clients...), authorization.Module(), sqlstorage.Module(sqlstorage.KindPostgres, viper.GetString(postgresUriFlag), viper.GetBool(debugFlag), key, o.Clients...), diff --git a/pkg/api/authorization/module.go b/pkg/api/authorization/module.go index 2f9f1c6..470ac69 100644 --- a/pkg/api/authorization/module.go +++ b/pkg/api/authorization/module.go @@ -11,14 +11,14 @@ import ( func Module() fx.Option { return fx.Options( - fx.Invoke(fx.Annotate(func(router *mux.Router, o op.OpenIDProvider) error { + fx.Invoke(func(router *mux.Router, o op.OpenIDProvider) error { return router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { route.Handler( middleware(o)(route.GetHandler()), ) return nil }) - }, fx.ParamTags(`name:"prefixedRouter"`))), + }), ) } diff --git a/pkg/api/routing/context.go b/pkg/api/context.go similarity index 96% rename from pkg/api/routing/context.go rename to pkg/api/context.go index 8d8473d..9b68831 100644 --- a/pkg/api/routing/context.go +++ b/pkg/api/context.go @@ -1,4 +1,4 @@ -package routing +package api import ( "context" diff --git a/pkg/api/module.go b/pkg/api/module.go index cf4d963..295c34b 100644 --- a/pkg/api/module.go +++ b/pkg/api/module.go @@ -1,21 +1,47 @@ package api import ( - "net/url" + "context" + "net/http" - "github.com/formancehq/auth/pkg/api/routing" + "github.com/gorilla/mux" sharedhealth "github.com/numary/go-libs/sharedhealth/pkg" + "go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux" "go.uber.org/fx" ) -func Module(addr string, baseUrl *url.URL) fx.Option { +func CreateRootRouter(healthController *sharedhealth.HealthController) *mux.Router { + rootRouter := mux.NewRouter() + rootRouter.Use(otelmux.Middleware("auth")) + rootRouter.Use(func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + handler.ServeHTTP(w, r) + }) + }) + rootRouter.Path("/_healthcheck").HandlerFunc(healthController.Check) + + return rootRouter +} + +func Module(addr string) fx.Option { return fx.Options( sharedhealth.ProvideHealthCheck(delegatedOIDCServerAvailability), - routing.Module(addr, baseUrl), + sharedhealth.Module(), + fx.Provide(func(healthController *sharedhealth.HealthController) *mux.Router { + return CreateRootRouter(healthController) + }), + fx.Invoke(func(lc fx.Lifecycle, router *mux.Router, healthController *sharedhealth.HealthController) { + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return StartServer(ctx, addr, router) + }, + }) + }), fx.Invoke( - fx.Annotate(addClientRoutes, fx.ParamTags(``, `name:"prefixedRouter"`)), - fx.Annotate(addScopeRoutes, fx.ParamTags(``, `name:"prefixedRouter"`)), - fx.Annotate(addUserRoutes, fx.ParamTags(``, `name:"prefixedRouter"`)), + addClientRoutes, + addScopeRoutes, + addUserRoutes, ), ) } diff --git a/pkg/api/routing/module.go b/pkg/api/routing/module.go deleted file mode 100644 index d8fce02..0000000 --- a/pkg/api/routing/module.go +++ /dev/null @@ -1,48 +0,0 @@ -package routing - -import ( - "context" - "net/http" - "net/url" - - "github.com/gorilla/mux" - sharedhealth "github.com/numary/go-libs/sharedhealth/pkg" - "go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux" - "go.uber.org/fx" -) - -func CreateRootRouter(baseUrl *url.URL, healthController *sharedhealth.HealthController) (*mux.Router, *mux.Router) { - rootRouter := mux.NewRouter() - rootRouter.Use(otelmux.Middleware("auth")) - rootRouter.Use(func(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - handler.ServeHTTP(w, r) - }) - }) - rootRouter.Path("/_healthcheck").HandlerFunc(healthController.Check) - - subRouter := rootRouter - if baseUrl.Path != "/" { - subRouter = subRouter.PathPrefix(baseUrl.Path).Subrouter() - subRouter.Path("/_healthcheck").HandlerFunc(healthController.Check) - } - - return rootRouter, subRouter -} - -func Module(addr string, baseUrl *url.URL) fx.Option { - return fx.Options( - sharedhealth.Module(), - fx.Provide(fx.Annotate(func(healthController *sharedhealth.HealthController) (*mux.Router, *mux.Router) { - return CreateRootRouter(baseUrl, healthController) - }, fx.ResultTags(`name:"rootRouter"`, `name:"prefixedRouter"`))), - fx.Invoke(fx.Annotate(func(lc fx.Lifecycle, router *mux.Router, healthController *sharedhealth.HealthController) { - lc.Append(fx.Hook{ - OnStart: func(ctx context.Context) error { - return StartServer(ctx, addr, router) - }, - }) - }, fx.ParamTags(``, `name:"rootRouter"`))), - ) -} diff --git a/pkg/api/routing/module_test.go b/pkg/api/routing/module_test.go deleted file mode 100644 index 6be7e12..0000000 --- a/pkg/api/routing/module_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package routing - -import ( - "context" - "fmt" - "net/http" - "net/url" - "testing" - - "github.com/gorilla/mux" - "github.com/stretchr/testify/require" - "go.uber.org/fx" -) - -func testWithUrl(t *testing.T, urlStr string) { - u, err := url.Parse(urlStr) - require.NoError(t, err) - - ctx := NewContext(context.Background()) - - app := fx.New( - fx.Invoke(fx.Annotate(func(router *mux.Router) { - router.Path("/subpath-with-prefix").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNoContent) - }) - }, fx.ParamTags(`name:"prefixedRouter"`))), - fx.Invoke(fx.Annotate(func(router *mux.Router) { - router.Path("/subpath-to-root").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNoContent) - }) - }, fx.ParamTags(`name:"rootRouter"`))), - Module(":0", u), - fx.NopLogger, - ) - - require.NoError(t, app.Start(ctx)) - defer func() { - require.NoError(t, app.Stop(ctx)) - }() - - serverUrl := fmt.Sprintf("http://localhost:%d", ListeningPort(ctx)) - serverUrlWithPath := fmt.Sprintf("%s%s", serverUrl, u.Path) - - rsp, err := http.Get(fmt.Sprintf("%s/_healthcheck", serverUrlWithPath)) - require.NoError(t, err) - require.Equal(t, http.StatusOK, rsp.StatusCode) - - rsp, err = http.Get(fmt.Sprintf("%s/_healthcheck", serverUrl)) - require.NoError(t, err) - require.Equal(t, http.StatusOK, rsp.StatusCode) - - rsp, err = http.Get(fmt.Sprintf("%s/subpath-with-prefix", serverUrlWithPath)) - require.NoError(t, err) - require.Equal(t, http.StatusNoContent, rsp.StatusCode) - - rsp, err = http.Get(fmt.Sprintf("%s/subpath-to-root", serverUrl)) - require.NoError(t, err) - require.Equal(t, http.StatusNoContent, rsp.StatusCode) -} - -func TestModule(t *testing.T) { - testWithUrl(t, "http://localhost") - testWithUrl(t, "http://localhost/any/sub/path") -} diff --git a/pkg/api/routing/server.go b/pkg/api/server.go similarity index 96% rename from pkg/api/routing/server.go rename to pkg/api/server.go index 3034247..cc78d33 100644 --- a/pkg/api/routing/server.go +++ b/pkg/api/server.go @@ -1,4 +1,4 @@ -package routing +package api import ( "context" diff --git a/pkg/oidc/module.go b/pkg/oidc/module.go index fbbaa0c..1278783 100644 --- a/pkg/oidc/module.go +++ b/pkg/oidc/module.go @@ -3,7 +3,6 @@ package oidc import ( "context" "crypto/rsa" - "net/url" auth "github.com/formancehq/auth/pkg" "github.com/formancehq/auth/pkg/delegatedauth" @@ -13,11 +12,11 @@ import ( "go.uber.org/fx" ) -func Module(privateKey *rsa.PrivateKey, baseUrl *url.URL, staticClients ...auth.StaticClient) fx.Option { +func Module(privateKey *rsa.PrivateKey, issuer string, staticClients ...auth.StaticClient) fx.Option { return fx.Options( - fx.Invoke(fx.Annotate(func(router *mux.Router, provider op.OpenIDProvider, storage Storage, relyingParty rp.RelyingParty) { - AddRoutes(router, provider, storage, relyingParty, baseUrl) - }, fx.ParamTags(`name:"rootRouter"`))), + fx.Invoke(func(router *mux.Router, provider op.OpenIDProvider, storage Storage, relyingParty rp.RelyingParty) { + AddRoutes(router, provider, storage, relyingParty) + }), fx.Provide(fx.Annotate(func(storage Storage, relyingParty rp.RelyingParty) *storageFacade { return NewStorageFacade(storage, relyingParty, privateKey, staticClients...) }, fx.As(new(op.Storage)))), @@ -27,7 +26,7 @@ func Module(privateKey *rsa.PrivateKey, baseUrl *url.URL, staticClients ...auth. return nil, err } - return NewOpenIDProvider(context.TODO(), storage, baseUrl.String(), configuration.Issuer, *keySet) + return NewOpenIDProvider(context.TODO(), storage, issuer, configuration.Issuer, *keySet) }), ) } diff --git a/pkg/oidc/oidc_test.go b/pkg/oidc/oidc_test.go index 9bab11f..9410704 100644 --- a/pkg/oidc/oidc_test.go +++ b/pkg/oidc/oidc_test.go @@ -77,12 +77,9 @@ func withServer(t *testing.T, fn func(m *mockoidc.MockOIDC, storage *sqlstorage. provider, err := oidc.NewOpenIDProvider(context.TODO(), storageFacade, serverUrl, mockOIDC.Issuer(), *keySet) require.NoError(t, err) - u, err := url.Parse(serverUrl) - require.NoError(t, err) - // Create the router router := mux.NewRouter() - oidc.AddRoutes(router, provider, storage, serverRelyingParty, u) + oidc.AddRoutes(router, provider, storage, serverRelyingParty) // Create our http server for our oidc provider providerHttpServer := &http.Server{ diff --git a/pkg/oidc/router.go b/pkg/oidc/router.go index a586f50..018701f 100644 --- a/pkg/oidc/router.go +++ b/pkg/oidc/router.go @@ -1,18 +1,15 @@ package oidc import ( - "net/http" - "net/url" - "github.com/gorilla/mux" "github.com/zitadel/oidc/pkg/client/rp" "github.com/zitadel/oidc/pkg/op" ) -func AddRoutes(router *mux.Router, provider op.OpenIDProvider, storage Storage, relyingParty rp.RelyingParty, baseUrl *url.URL) { +func AddRoutes(router *mux.Router, provider op.OpenIDProvider, storage Storage, relyingParty rp.RelyingParty) { router.NewRoute().Path("/authorize/callback").Queries("code", "{code}"). Handler(authorizeCallbackHandler(provider, storage, relyingParty)) router.NewRoute().Path("/authorize/callback").Queries("error", "{error}"). Handler(authorizeErrorHandler()) - router.PathPrefix("/").Handler(http.StripPrefix(baseUrl.Path, provider.HttpHandler())) + router.PathPrefix("/").Handler(provider.HttpHandler()) }