From cb8868929957f8d1f28b21a2a42b53f898de2c1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E6=B3=BD=E8=BD=A9?= Date: Wed, 21 Feb 2024 11:46:40 +0800 Subject: [PATCH] add oidc plugin (#301) Signed-off-by: spacewander --- Makefile | 4 +- go.mod | 8 +- go.sum | 10 + pkg/filtermanager/filtermanager.go | 36 +-- pkg/filtermanager/filtermanager_test.go | 2 +- pkg/request/request.go | 71 +++++- plugins/oidc/config.go | 93 ++++++++ plugins/oidc/config.pb.go | 219 +++++++++++++++++ plugins/oidc/config.pb.validate.go | 200 ++++++++++++++++ plugins/oidc/config.proto | 36 +++ plugins/oidc/config_test.go | 61 +++++ plugins/oidc/filter.go | 223 ++++++++++++++++++ plugins/oidc/filter_test.go | 202 ++++++++++++++++ plugins/plugins.go | 1 + .../integration/data_plane/data_plane.go | 7 +- plugins/tests/integration/oidc_test.go | 113 +++++++++ .../testdata/services/docker-compose.yml | 52 ++++ .../testdata/services/hydra/hydra.yml | 22 ++ plugins/tests/pkg/envoy/capi.go | 6 +- 19 files changed, 1339 insertions(+), 27 deletions(-) create mode 100644 plugins/oidc/config.go create mode 100644 plugins/oidc/config.pb.go create mode 100644 plugins/oidc/config.pb.validate.go create mode 100644 plugins/oidc/config.proto create mode 100644 plugins/oidc/config_test.go create mode 100644 plugins/oidc/filter.go create mode 100644 plugins/oidc/filter_test.go create mode 100644 plugins/tests/integration/oidc_test.go create mode 100644 plugins/tests/integration/testdata/services/hydra/hydra.yml diff --git a/Makefile b/Makefile index 7e083bd6..884e04f7 100644 --- a/Makefile +++ b/Makefile @@ -134,7 +134,7 @@ run-demo: -v $(PWD)/libgolang.so:/etc/libgolang.so \ -p 10000:10000 \ ${PROXY_IMAGE} \ - envoy -c /etc/demo.yaml --log-level debug + envoy -c /etc/demo.yaml --log-level info .PHONY: dev-tools dev-tools: @@ -242,7 +242,7 @@ verify-example: .PHONY: start-service start-service: - cd ./plugins/tests/integration/testdata/services && docker-compose up -d + cd ./plugins/tests/integration/testdata/services && docker-compose up -d --build # E2E KUBECTL ?= $(LOCALBIN)/kubectl diff --git a/go.mod b/go.mod index 1715e361..51d0dd64 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/agiledragon/gomonkey/v2 v2.11.0 github.com/casbin/casbin/v2 v2.82.0 github.com/cncf/xds/go v0.0.0-20231109132714-523115ebc101 + github.com/coreos/go-oidc/v3 v3.9.0 github.com/envoyproxy/envoy v1.29.1-0.20240208055117-b788e1a92347 github.com/envoyproxy/go-control-plane v0.11.2-0.20231019082134-6e4589f570e1 // version used by istio 1.20 github.com/envoyproxy/protoc-gen-validate v1.0.2 @@ -19,6 +20,7 @@ require ( github.com/golang/protobuf v1.5.3 github.com/google/cel-go v0.20.0 github.com/google/uuid v1.6.0 + github.com/gorilla/securecookie v1.1.2 github.com/jellydator/ttlcache/v3 v3.1.1 github.com/nacos-group/nacos-sdk-go v1.1.4 github.com/onsi/ginkgo/v2 v2.15.0 @@ -29,6 +31,8 @@ require ( github.com/spf13/viper v1.18.2 github.com/stretchr/testify v1.8.4 go.uber.org/zap v1.26.0 + golang.org/x/net v0.19.0 + golang.org/x/oauth2 v0.15.0 golang.org/x/text v0.14.0 golang.org/x/time v0.5.0 google.golang.org/genproto/googleapis/api v0.0.0-20231106174013-bbf56f31fb17 @@ -63,6 +67,7 @@ require ( github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/go-errors/errors v1.0.1 // indirect github.com/go-ini/ini v1.67.0 // indirect + github.com/go-jose/go-jose/v3 v3.0.1 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/jsonpointer v0.20.0 // indirect github.com/go-openapi/jsonreference v0.20.2 // indirect @@ -114,9 +119,8 @@ require ( go.opentelemetry.io/otel/sdk v1.21.0 // indirect go.opentelemetry.io/otel/trace v1.21.0 // indirect go.uber.org/multierr v1.11.0 // indirect + golang.org/x/crypto v0.16.0 // indirect golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 // indirect - golang.org/x/net v0.19.0 // indirect - golang.org/x/oauth2 v0.15.0 // indirect golang.org/x/sync v0.5.0 // indirect golang.org/x/sys v0.16.0 // indirect golang.org/x/term v0.15.0 // indirect diff --git a/go.sum b/go.sum index 30b5ca06..e42ef180 100644 --- a/go.sum +++ b/go.sum @@ -41,6 +41,8 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/xds/go v0.0.0-20231109132714-523115ebc101 h1:7To3pQ+pZo0i3dsWEbinPNFs5gPSBOsJtx3wTT94VBY= github.com/cncf/xds/go v0.0.0-20231109132714-523115ebc101/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/coreos/go-oidc/v3 v3.9.0 h1:0J/ogVOd4y8P0f0xUh8l9t07xRP/d8tccvjHl2dcsSo= +github.com/coreos/go-oidc/v3 v3.9.0/go.mod h1:rTKz2PYwftcrtoCzV5g5kvfJoWcm0Mk8AF8y1iAQro4= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -84,6 +86,8 @@ github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6 github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= +github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA= +github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= @@ -132,6 +136,7 @@ github.com/google/flatbuffers v1.12.1/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I= github.com/google/gnostic-models v0.6.8/go.mod h1:5n7qKqH0f5wFt+aWF8CW6pZLLNOfYuF5OpfBSENuI8U= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= @@ -147,6 +152,8 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 h1:YBftPWNWd4WwGqtY2yeZL2ef8rHAxPBD8KFhJpmcqms= github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0/go.mod h1:YN5jB8ie0yfIUg6VvR9Kz84aCaG7AsGZnLjhHbUqwPg= @@ -306,9 +313,12 @@ go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= +golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 h1:+iq7lrkxmFNBM7xx+Rae2W6uyPfhPeDWD+n+JgppptE= golang.org/x/exp v0.0.0-20231219180239-dc181d75b848/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= diff --git a/pkg/filtermanager/filtermanager.go b/pkg/filtermanager/filtermanager.go index f8802fc6..58af4adb 100644 --- a/pkg/filtermanager/filtermanager.go +++ b/pkg/filtermanager/filtermanager.go @@ -51,7 +51,7 @@ type FilterManagerConfig struct { } type filterManagerConfig struct { - authnFiltersEndAt int + consumerFiltersEndAt int current []*model.ParsedFilterConfig pool *sync.Pool @@ -110,7 +110,7 @@ func (p *FilterManagerConfigParser) Parse(any *anypb.Any, callbacks capi.ConfigC conf := initFilterManagerConfig(fmConfig.Namespace) conf.current = make([]*model.ParsedFilterConfig, 0, len(plugins)) - authnFiltersEndAt := 0 + consumerFiltersEndAt := 0 i := 0 for _, proto := range plugins { @@ -119,7 +119,7 @@ func (p *FilterManagerConfigParser) Parse(any *anypb.Any, callbacks capi.ConfigC // For now, we have nothing to provide as config callbacks config, err := plugin.ConfigParser.Parse(proto.Config, nil) if err != nil { - api.LogErrorf("%w during parsing plugin %s in filtermanager", err, name) + api.LogErrorf("%s during parsing plugin %s in filtermanager", err, name) // Return an error from the Parse method will cause assertion failure. // See https://github.com/envoyproxy/envoy/blob/f301eebf7acc680e27e03396a1be6be77e1ae3a5/contrib/golang/filters/http/source/golang_filter.cc#L1736-L1737 @@ -137,9 +137,9 @@ func (p *FilterManagerConfigParser) Parse(any *anypb.Any, callbacks capi.ConfigC Factory: plugin.Factory, }) - p := pkgPlugins.LoadHttpPlugin(name) - if p.Order().Position == pkgPlugins.OrderPositionAuthn { - authnFiltersEndAt = i + 1 + _, ok := pkgPlugins.LoadHttpPlugin(name).(pkgPlugins.ConsumerPlugin) + if ok { + consumerFiltersEndAt = i + 1 } } i++ @@ -148,7 +148,7 @@ func (p *FilterManagerConfigParser) Parse(any *anypb.Any, callbacks capi.ConfigC api.LogErrorf("plugin %s not found, ignored", name) } } - conf.authnFiltersEndAt = authnFiltersEndAt + conf.consumerFiltersEndAt = consumerFiltersEndAt return conf, nil } @@ -181,8 +181,8 @@ func newFilterWrapper(name string, f api.Filter) *filterWrapper { } type filterManager struct { - filters []*filterWrapper - authnFilters []*filterWrapper + filters []*filterWrapper + consumerFilters []*filterWrapper decodeRequestNeeded bool decodeIdx int @@ -208,7 +208,7 @@ type filterManager struct { func (m *filterManager) Reset() { m.filters = nil - m.authnFilters = nil + m.consumerFilters = nil m.decodeRequestNeeded = false m.decodeIdx = -1 @@ -373,11 +373,11 @@ func FilterManagerFactory(c interface{}, cb capi.FilterCallbackHandler) capi.Str fm.filters = filters - if conf.authnFiltersEndAt != 0 { - authnFiltersEndAt := conf.authnFiltersEndAt - authnFilters := filters[:authnFiltersEndAt] - fm.authnFilters = authnFilters - fm.filters = filters[authnFiltersEndAt:] + if conf.consumerFiltersEndAt != 0 { + consumerFiltersEndAt := conf.consumerFiltersEndAt + consumerFilters := filters[:consumerFiltersEndAt] + fm.consumerFilters = consumerFilters + fm.filters = filters[consumerFiltersEndAt:] } // The skip check is based on the compiled code. So if the DecodeRequest is defined, @@ -474,9 +474,9 @@ func (m *filterManager) DecodeHeaders(headers api.RequestHeaderMap, endStream bo var res api.ResultAction m.reqHdr = headers - if len(m.authnFilters) > 0 { - for _, f := range m.authnFilters { - // Authn plugins only use DecodeHeaders for now + if len(m.consumerFilters) > 0 { + for _, f := range m.consumerFilters { + // Consumer plugins only use DecodeHeaders for now res = f.DecodeHeaders(headers, endStream) if m.handleAction(res, phaseDecodeHeaders) { return diff --git a/pkg/filtermanager/filtermanager_test.go b/pkg/filtermanager/filtermanager_test.go index 06a0761a..1f0298bd 100644 --- a/pkg/filtermanager/filtermanager_test.go +++ b/pkg/filtermanager/filtermanager_test.go @@ -346,7 +346,7 @@ func (f *addReqFilter) DecodeHeaders(headers api.RequestHeaderMap, endStream boo func TestFiltersFromConsumer(t *testing.T) { cb := envoy.NewCAPIFilterCallbackHandler() config := initFilterManagerConfig("ns") - config.authnFiltersEndAt = 1 + config.consumerFiltersEndAt = 1 config.current = []*model.ParsedFilterConfig{ { Name: "set_consumer", diff --git a/pkg/request/request.go b/pkg/request/request.go index 05f5755e..7a6636ac 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -16,13 +16,17 @@ package request import ( "fmt" + "net/http" + "net/textproto" "net/url" + "strings" "github.com/envoyproxy/envoy/contrib/golang/common/go/api" + "golang.org/x/net/http/httpguts" ) -func GetUrl(header api.RequestHeaderMap) *url.URL { - path := header.Path() +func GetUrl(headers api.RequestHeaderMap) *url.URL { + path := headers.Path() // TODO: cache it uri, err := url.ParseRequestURI(path) if err != nil { @@ -31,6 +35,69 @@ func GetUrl(header api.RequestHeaderMap) *url.URL { return uri } +// The cookie parser is from Go's http/cookie.go, which are not exported + +func isNotToken(r rune) bool { + return !httpguts.IsTokenRune(r) +} + +func isCookieNameValid(raw string) bool { + if raw == "" { + return false + } + return strings.IndexFunc(raw, isNotToken) < 0 +} + +func validCookieValueByte(b byte) bool { + return 0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\' +} + +func parseCookieValue(raw string, allowDoubleQuote bool) (string, bool) { + // Strip the quotes, if present. + if allowDoubleQuote && len(raw) > 1 && raw[0] == '"' && raw[len(raw)-1] == '"' { + raw = raw[1 : len(raw)-1] + } + for i := 0; i < len(raw); i++ { + if !validCookieValueByte(raw[i]) { + return "", false + } + } + return raw, true +} + +// If multiple cookies match the given name, only one cookie will be returned. +func GetCookies(headers api.RequestHeaderMap) map[string]*http.Cookie { + lines := headers.Values("Cookie") + if len(lines) == 0 { + return map[string]*http.Cookie{} + } + + cookies := make(map[string]*http.Cookie, len(lines)+strings.Count(lines[0], ";")) + for _, line := range lines { + line = textproto.TrimString(line) + + var part string + for len(line) > 0 { // continue since we have rest + part, line, _ = strings.Cut(line, ";") + part = textproto.TrimString(part) + if part == "" { + continue + } + name, val, _ := strings.Cut(part, "=") + name = textproto.TrimString(name) + if !isCookieNameValid(name) { + continue + } + val, ok := parseCookieValue(val, true) + if !ok { + continue + } + cookies[name] = &http.Cookie{Name: name, Value: val} + } + } + return cookies +} + // GetHeaders returns a plain map represents the headers. The returned headers won't // contain any pseudo header like `:authority`. func GetHeaders(header api.RequestHeaderMap) map[string][]string { diff --git a/plugins/oidc/config.go b/plugins/oidc/config.go new file mode 100644 index 00000000..c23df2c9 --- /dev/null +++ b/plugins/oidc/config.go @@ -0,0 +1,93 @@ +// Copyright The HTNN Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package oidc + +import ( + "context" + "net/http" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/gorilla/securecookie" + "golang.org/x/oauth2" + + "mosn.io/htnn/pkg/filtermanager/api" + "mosn.io/htnn/pkg/plugins" +) + +const ( + Name = "oidc" +) + +func init() { + plugins.RegisterHttpPlugin(Name, &plugin{}) +} + +type plugin struct { + plugins.PluginMethodDefaultImpl +} + +func (p *plugin) Type() plugins.PluginType { + return plugins.TypeAuthn +} + +func (p *plugin) Order() plugins.PluginOrder { + return plugins.PluginOrder{ + Position: plugins.OrderPositionAuthn, + } +} + +func (p *plugin) Factory() api.FilterFactory { + return factory +} + +func (p *plugin) Config() api.PluginConfig { + return &config{} +} + +type config struct { + Config + + oauth2Config *oauth2.Config + verifier *oidc.IDTokenVerifier + cookieEncoding *securecookie.SecureCookie +} + +func ctxWithClient(ctx context.Context) context.Context { + httpClient := &http.Client{Timeout: 3 * time.Second} + return context.WithValue(ctx, oauth2.HTTPClient, httpClient) +} + +func (conf *config) Init(cb api.ConfigCallbackHandler) error { + ctx := ctxWithClient(context.Background()) + provider, err := oidc.NewProvider(ctx, conf.Issuer) + if err != nil { + return err + } + + conf.oauth2Config = &oauth2.Config{ + ClientID: conf.ClientId, + ClientSecret: conf.ClientSecret, + // ScopeOpenID is the mandatory scope for all OpenID Connect OAuth2 requests. + Scopes: append([]string{oidc.ScopeOpenID}, conf.Scopes...), + RedirectURL: conf.RedirectUrl, + + // Discovery returns the OAuth2 endpoints. + Endpoint: provider.Endpoint(), + } + conf.verifier = provider.Verifier(&oidc.Config{ClientID: conf.ClientId}) + conf.cookieEncoding = securecookie.New([]byte(conf.ClientSecret), nil) + return nil +} diff --git a/plugins/oidc/config.pb.go b/plugins/oidc/config.pb.go new file mode 100644 index 00000000..6091ec7f --- /dev/null +++ b/plugins/oidc/config.pb.go @@ -0,0 +1,219 @@ +// Copyright The HTNN Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.31.0 +// protoc v4.24.4 +// source: plugins/oidc/config.proto + +package oidc + +import ( + reflect "reflect" + sync "sync" + + _ "github.com/envoyproxy/protoc-gen-validate/validate" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ClientId string `protobuf:"bytes,1,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` + ClientSecret string `protobuf:"bytes,2,opt,name=client_secret,json=clientSecret,proto3" json:"client_secret,omitempty"` + // The issuer is the URL identifier for the service. For example: "https://accounts.google.com" + // or "https://login.salesforce.com". + Issuer string `protobuf:"bytes,3,opt,name=issuer,proto3" json:"issuer,omitempty"` + // The configured URL MUST exactly match one of the Redirection URI values + // for the Client pre-registered at the OpenID Provider + RedirectUrl string `protobuf:"bytes,4,opt,name=redirect_url,json=redirectUrl,proto3" json:"redirect_url,omitempty"` + Scopes []string `protobuf:"bytes,5,rep,name=scopes,proto3" json:"scopes,omitempty"` + // This option is provided to skip the nonce verification. It is designed for local development. + SkipNonceVerify bool `protobuf:"varint,6,opt,name=skip_nonce_verify,json=skipNonceVerify,proto3" json:"skip_nonce_verify,omitempty"` +} + +func (x *Config) Reset() { + *x = Config{} + if protoimpl.UnsafeEnabled { + mi := &file_plugins_oidc_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_plugins_oidc_config_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_plugins_oidc_config_proto_rawDescGZIP(), []int{0} +} + +func (x *Config) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +func (x *Config) GetClientSecret() string { + if x != nil { + return x.ClientSecret + } + return "" +} + +func (x *Config) GetIssuer() string { + if x != nil { + return x.Issuer + } + return "" +} + +func (x *Config) GetRedirectUrl() string { + if x != nil { + return x.RedirectUrl + } + return "" +} + +func (x *Config) GetScopes() []string { + if x != nil { + return x.Scopes + } + return nil +} + +func (x *Config) GetSkipNonceVerify() bool { + if x != nil { + return x.SkipNonceVerify + } + return false +} + +var File_plugins_oidc_config_proto protoreflect.FileDescriptor + +var file_plugins_oidc_config_proto_rawDesc = []byte{ + 0x0a, 0x19, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2f, 0x6f, 0x69, 0x64, 0x63, 0x2f, 0x63, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0c, 0x70, 0x6c, 0x75, + 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x6f, 0x69, 0x64, 0x63, 0x1a, 0x17, 0x76, 0x61, 0x6c, 0x69, 0x64, + 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x22, 0xef, 0x01, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, + 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x42, 0x07, 0xfa, 0x42, 0x04, 0x72, 0x02, 0x10, 0x01, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, + 0x74, 0x49, 0x64, 0x12, 0x2c, 0x0a, 0x0d, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, + 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x07, 0xfa, 0x42, 0x04, 0x72, + 0x02, 0x10, 0x01, 0x52, 0x0c, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, + 0x74, 0x12, 0x20, 0x0a, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x09, 0x42, 0x08, 0xfa, 0x42, 0x05, 0x72, 0x03, 0x88, 0x01, 0x01, 0x52, 0x06, 0x69, 0x73, 0x73, + 0x75, 0x65, 0x72, 0x12, 0x2b, 0x0a, 0x0c, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, + 0x75, 0x72, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x42, 0x08, 0xfa, 0x42, 0x05, 0x72, 0x03, + 0x88, 0x01, 0x01, 0x52, 0x0b, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x72, 0x6c, + 0x12, 0x16, 0x0a, 0x06, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x06, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x73, 0x12, 0x2a, 0x0a, 0x11, 0x73, 0x6b, 0x69, 0x70, + 0x5f, 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x5f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x79, 0x18, 0x06, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x0f, 0x73, 0x6b, 0x69, 0x70, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x56, 0x65, + 0x72, 0x69, 0x66, 0x79, 0x42, 0x1b, 0x5a, 0x19, 0x6d, 0x6f, 0x73, 0x6e, 0x2e, 0x69, 0x6f, 0x2f, + 0x68, 0x74, 0x6e, 0x6e, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2f, 0x6f, 0x69, 0x64, + 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_plugins_oidc_config_proto_rawDescOnce sync.Once + file_plugins_oidc_config_proto_rawDescData = file_plugins_oidc_config_proto_rawDesc +) + +func file_plugins_oidc_config_proto_rawDescGZIP() []byte { + file_plugins_oidc_config_proto_rawDescOnce.Do(func() { + file_plugins_oidc_config_proto_rawDescData = protoimpl.X.CompressGZIP(file_plugins_oidc_config_proto_rawDescData) + }) + return file_plugins_oidc_config_proto_rawDescData +} + +var file_plugins_oidc_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_plugins_oidc_config_proto_goTypes = []interface{}{ + (*Config)(nil), // 0: plugins.oidc.Config +} +var file_plugins_oidc_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_plugins_oidc_config_proto_init() } +func file_plugins_oidc_config_proto_init() { + if File_plugins_oidc_config_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_plugins_oidc_config_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Config); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_plugins_oidc_config_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_plugins_oidc_config_proto_goTypes, + DependencyIndexes: file_plugins_oidc_config_proto_depIdxs, + MessageInfos: file_plugins_oidc_config_proto_msgTypes, + }.Build() + File_plugins_oidc_config_proto = out.File + file_plugins_oidc_config_proto_rawDesc = nil + file_plugins_oidc_config_proto_goTypes = nil + file_plugins_oidc_config_proto_depIdxs = nil +} diff --git a/plugins/oidc/config.pb.validate.go b/plugins/oidc/config.pb.validate.go new file mode 100644 index 00000000..86e3a514 --- /dev/null +++ b/plugins/oidc/config.pb.validate.go @@ -0,0 +1,200 @@ +// Code generated by protoc-gen-validate. DO NOT EDIT. +// source: plugins/oidc/config.proto + +package oidc + +import ( + "bytes" + "errors" + "fmt" + "net" + "net/mail" + "net/url" + "regexp" + "sort" + "strings" + "time" + "unicode/utf8" + + "google.golang.org/protobuf/types/known/anypb" +) + +// ensure the imports are used +var ( + _ = bytes.MinRead + _ = errors.New("") + _ = fmt.Print + _ = utf8.UTFMax + _ = (*regexp.Regexp)(nil) + _ = (*strings.Reader)(nil) + _ = net.IPv4len + _ = time.Duration(0) + _ = (*url.URL)(nil) + _ = (*mail.Address)(nil) + _ = anypb.Any{} + _ = sort.Sort +) + +// Validate checks the field values on Config with the rules defined in the +// proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *Config) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on Config with the rules defined in the +// proto definition for this message. If any rules are violated, the result is +// a list of violation errors wrapped in ConfigMultiError, or nil if none found. +func (m *Config) ValidateAll() error { + return m.validate(true) +} + +func (m *Config) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + if utf8.RuneCountInString(m.GetClientId()) < 1 { + err := ConfigValidationError{ + field: "ClientId", + reason: "value length must be at least 1 runes", + } + if !all { + return err + } + errors = append(errors, err) + } + + if utf8.RuneCountInString(m.GetClientSecret()) < 1 { + err := ConfigValidationError{ + field: "ClientSecret", + reason: "value length must be at least 1 runes", + } + if !all { + return err + } + errors = append(errors, err) + } + + if uri, err := url.Parse(m.GetIssuer()); err != nil { + err = ConfigValidationError{ + field: "Issuer", + reason: "value must be a valid URI", + cause: err, + } + if !all { + return err + } + errors = append(errors, err) + } else if !uri.IsAbs() { + err := ConfigValidationError{ + field: "Issuer", + reason: "value must be absolute", + } + if !all { + return err + } + errors = append(errors, err) + } + + if uri, err := url.Parse(m.GetRedirectUrl()); err != nil { + err = ConfigValidationError{ + field: "RedirectUrl", + reason: "value must be a valid URI", + cause: err, + } + if !all { + return err + } + errors = append(errors, err) + } else if !uri.IsAbs() { + err := ConfigValidationError{ + field: "RedirectUrl", + reason: "value must be absolute", + } + if !all { + return err + } + errors = append(errors, err) + } + + // no validation rules for SkipNonceVerify + + if len(errors) > 0 { + return ConfigMultiError(errors) + } + + return nil +} + +// ConfigMultiError is an error wrapping multiple validation errors returned by +// Config.ValidateAll() if the designated constraints aren't met. +type ConfigMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m ConfigMultiError) Error() string { + var msgs []string + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m ConfigMultiError) AllErrors() []error { return m } + +// ConfigValidationError is the validation error returned by Config.Validate if +// the designated constraints aren't met. +type ConfigValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e ConfigValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e ConfigValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e ConfigValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e ConfigValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e ConfigValidationError) ErrorName() string { return "ConfigValidationError" } + +// Error satisfies the builtin error interface +func (e ConfigValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sConfig.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = ConfigValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = ConfigValidationError{} diff --git a/plugins/oidc/config.proto b/plugins/oidc/config.proto new file mode 100644 index 00000000..54f0bd5f --- /dev/null +++ b/plugins/oidc/config.proto @@ -0,0 +1,36 @@ +// Copyright The HTNN Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package plugins.oidc; + +import "validate/validate.proto"; + +option go_package = "mosn.io/htnn/plugins/oidc"; + +message Config { + string client_id = 1 [(validate.rules).string = {min_len: 1}]; + string client_secret = 2 [(validate.rules).string = {min_len: 1}]; + // The issuer is the URL identifier for the service. For example: "https://accounts.google.com" + // or "https://login.salesforce.com". + string issuer = 3 [(validate.rules).string = {uri: true}]; + // The configured URL MUST exactly match one of the Redirection URI values + // for the Client pre-registered at the OpenID Provider + string redirect_url = 4 [(validate.rules).string = {uri: true}]; + repeated string scopes = 5; + + // This option is provided to skip the nonce verification. It is designed for local development. + bool skip_nonce_verify = 6; +} diff --git a/plugins/oidc/config_test.go b/plugins/oidc/config_test.go new file mode 100644 index 00000000..918fd328 --- /dev/null +++ b/plugins/oidc/config_test.go @@ -0,0 +1,61 @@ +// Copyright The HTNN Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package oidc + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/encoding/protojson" +) + +func TestBadIssuer(t *testing.T) { + c := config{ + Config: Config{ + Issuer: "http://github.com", + }, + } + err := c.Init(nil) + assert.Error(t, err) +} + +func TestConfig(t *testing.T) { + tests := []struct { + name string + input string + err string + }{ + { + name: "bad issuer url", + input: `{"clientId":"a", "clientSecret":"b", "issuer":"google.com"}`, + err: "invalid Config.Issuer:", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conf := &config{} + err := protojson.Unmarshal([]byte(tt.input), conf) + if err == nil { + err = conf.Validate() + } + if tt.err == "" { + assert.Nil(t, err) + } else { + assert.ErrorContains(t, err, tt.err) + } + }) + } +} diff --git a/plugins/oidc/filter.go b/plugins/oidc/filter.go new file mode 100644 index 00000000..ea9272de --- /dev/null +++ b/plugins/oidc/filter.go @@ -0,0 +1,223 @@ +// Copyright The HTNN Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package oidc + +import ( + "context" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "golang.org/x/oauth2" + + "mosn.io/htnn/pkg/filtermanager/api" + "mosn.io/htnn/pkg/request" +) + +func factory(c interface{}, callbacks api.FilterCallbackHandler) api.Filter { + return &filter{ + callbacks: callbacks, + config: c.(*config), + } +} + +type filter struct { + api.PassThroughFilter + + callbacks api.FilterCallbackHandler + config *config +} + +type Tokens struct { + IDToken string `json:"id_token"` +} + +func generateState(verifier string, secret string, url string) string { + encodedRedirectUrl := base64.URLEncoding.EncodeToString([]byte(url)) + state := fmt.Sprintf("%s.%s", verifier, encodedRedirectUrl) + signature := signState(state, secret) + // fmt: verifier.originUrl.signature + return fmt.Sprintf("%s.%s", state, signature) +} + +func verifyState(state string, secret string) bool { + pieces := strings.Split(state, ".") + if len(pieces) != 3 { + return false + } + data := fmt.Sprintf("%s.%s", pieces[0], pieces[1]) + signature := signState(data, secret) + return pieces[2] == signature +} + +func signState(state string, secret string) string { + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write([]byte(state)) + return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) +} + +func (f *filter) handleInitRequest(headers api.RequestHeaderMap) api.ResultAction { + config := f.config + o2conf := config.oauth2Config + + b := make([]byte, 8) + _, _ = rand.Read(b) + nonce := base64.RawURLEncoding.EncodeToString(b) + verifier := oauth2.GenerateVerifier() + originUrl := fmt.Sprintf("%s://%s%s", headers.Scheme(), headers.Host(), headers.Path()) + s := generateState(verifier, config.ClientSecret, originUrl) + url := o2conf.AuthCodeURL(s, + // use PKCE to protect against CSRF attacks if possible + // https://www.ietf.org/archive/id/draft-ietf-oauth-security-topics-22.html#name-countermeasures-6 + oauth2.S256ChallengeOption(verifier), + oauth2.SetAuthURLParam("nonce", nonce)) + + n, err := config.cookieEncoding.Encode("htnn_oidc_nonce", nonce) + if err != nil { + api.LogErrorf("failed to encode cookie: %v", err) + return &api.LocalResponse{Code: 503, Msg: "failed to encode cookie"} + } + cookieNonce := &http.Cookie{ + Name: "htnn_oidc_nonce", + Value: n, + MaxAge: int(time.Hour.Seconds()), + HttpOnly: true, + // TODO: allow configuring the cookie attributes + } + + return &api.LocalResponse{ + Code: http.StatusFound, + Header: http.Header{ + "Location": []string{url}, + "Set-Cookie": []string{cookieNonce.String()}, + }, + } +} + +func (f *filter) handleCallback(headers api.RequestHeaderMap, query url.Values) api.ResultAction { + config := f.config + o2conf := config.oauth2Config + ctx := context.Background() + code := query.Get("code") + state := query.Get("state") + + // Here we provide the mechanism below to ensure the id token is client's: + // 1. sign the state to avoid being forged by the attacker + // 2. use PKCE to ensure the code is bound with the state, which is trusted after being verified + // 3. use nonce to ensure the id token is coming from the authorization request we initiated + if !verifyState(state, config.ClientSecret) { + api.LogInfof("bad state: %s", state) + return &api.LocalResponse{Code: 403, Msg: "bad state"} + } + verifier, encodedUrl, _ := strings.Cut(state, ".") + b, _ := base64.URLEncoding.DecodeString(encodedUrl) + originUrl := string(b) + + ctx = ctxWithClient(ctx) + oauth2Token, err := o2conf.Exchange(ctx, code, oauth2.VerifierOption(verifier)) + if err != nil { + api.LogErrorf("failed to exchange code to the token: %v", err) + return &api.LocalResponse{Code: 503, Msg: "failed to exchange code to the token"} + } + + // TODO: handle refresh_token + rawIDToken, ok := oauth2Token.Extra("id_token").(string) + if !ok { + api.LogErrorf("failed to lookup id token: %v", err) + return &api.LocalResponse{Code: 503, Msg: "failed to lookup id token"} + } + + idToken, err := config.verifier.Verify(ctx, rawIDToken) + if err != nil { + api.LogInfof("bad token: %s", err) + return &api.LocalResponse{Code: 403, Msg: "bad token"} + } + + if !config.SkipNonceVerify { + nonce, ok := request.GetCookies(headers)["htnn_oidc_nonce"] + if !ok { + api.LogInfof("bad nonce, expected %s", idToken.Nonce) + return &api.LocalResponse{Code: 403, Msg: "bad nonce"} + } + + var p string + err := config.cookieEncoding.Decode("htnn_oidc_nonce", nonce.Value, &p) + if err != nil || p != idToken.Nonce { + if err != nil { + api.LogInfof("bad nonce: %s, expected %s", err, idToken.Nonce) + } else { + api.LogInfof("bad nonce: %s, expected %s", p, idToken.Nonce) + } + return &api.LocalResponse{Code: 403, Msg: "bad nonce"} + } + } + + value := Tokens{ + IDToken: rawIDToken, + } + token, err := config.cookieEncoding.Encode("htnn_oidc_token", &value) + if err != nil { + api.LogErrorf("failed to encode cookie: %v", err) + return &api.LocalResponse{Code: 503, Msg: "failed to encode cookie"} + } + + cookie := &http.Cookie{ + Name: "htnn_oidc_token", + Value: token, + MaxAge: int(time.Until(idToken.Expiry).Seconds()), + HttpOnly: true, + } + return &api.LocalResponse{ + Code: http.StatusFound, + Header: http.Header{ + "Location": []string{originUrl}, + "Set-Cookie": []string{cookie.String()}, + }, + } +} + +func (f *filter) attachInfo(headers api.RequestHeaderMap, encodedToken string) api.ResultAction { + config := f.config + + value := Tokens{} + err := config.cookieEncoding.Decode("htnn_oidc_token", encodedToken, &value) + if err != nil { + api.LogInfof("bad oidc cookie: %s, err: %s", encodedToken, err.Error()) + return &api.LocalResponse{Code: 403, Msg: "bad oidc cookie"} + } + headers.Set("authorization", fmt.Sprintf("Bearer %s", value.IDToken)) + return api.Continue +} + +func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api.ResultAction { + token, ok := request.GetCookies(headers)["htnn_oidc_token"] + if ok { + return f.attachInfo(headers, token.Value) + } + + query := request.GetUrl(headers).Query() + code := query.Get("code") + if code == "" { + return f.handleInitRequest(headers) + } + + return f.handleCallback(headers, query) +} diff --git a/plugins/oidc/filter_test.go b/plugins/oidc/filter_test.go new file mode 100644 index 00000000..ddecc908 --- /dev/null +++ b/plugins/oidc/filter_test.go @@ -0,0 +1,202 @@ +// Copyright The HTNN Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package oidc + +import ( + "errors" + "net/http" + "strings" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/gorilla/securecookie" + "github.com/stretchr/testify/assert" + "golang.org/x/oauth2" + + "mosn.io/htnn/pkg/filtermanager/api" + "mosn.io/htnn/plugins/tests/pkg/envoy" +) + +func getCfg() *config { + return &config{ + Config: Config{ + ClientId: "9119df09-b20b-4c08-ba08-72472dda2cd2", + ClientSecret: "dSYo5hBwjX_DC57_tfZHlfrDel", + RedirectUrl: "http://127.0.0.1:10000", + }, + oauth2Config: &oauth2.Config{}, + verifier: &oidc.IDTokenVerifier{}, + cookieEncoding: securecookie.New([]byte("dSYo5hBwjX_DC57_tfZHlfrDel"), nil), + } +} + +func TestInitRequest(t *testing.T) { + conf := getCfg() + url := "http://host.docker.internal:4444/oauth2/auth?client_id=ef34cf65-016c-4b17-9864-8bd04dc22555&code_challenge=i3aZkytxb-6b4zvopxeT8AY21kon7EnJ7TlumdMlVuU&code_challenge_method=S256&nonce=yFyviTyEYAw&redirect_uri=http%3A%2F%2F127.0.0.1%3A10000%2Fecho&response_type=code&scope=openid&state=hqV183kqqtJxk_10F_5Y9" + patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "AuthCodeURL", url) + defer patches.Reset() + + cb := envoy.NewFilterCallbackHandler() + f := factory(getCfg(), cb).(*filter) + h := http.Header{} + hdr := envoy.NewRequestHeaderMap(h) + res := f.DecodeHeaders(hdr, true) + resp := res.(*api.LocalResponse) + assert.Equal(t, url, resp.Header.Get("Location")) + // other fields are checked in the integration test +} + +func TestCallback(t *testing.T) { + conf := getCfg() + verifier := oauth2.GenerateVerifier() + state := generateState(verifier, conf.ClientSecret, "https://127.0.0.1:2379/x?y=1") + rawIDToken := "rawIDToken" + token := (&oauth2.Token{}).WithExtra(map[string]interface{}{ + "id_token": rawIDToken, + }) + nonce, _ := conf.cookieEncoding.Encode("htnn_oidc_nonce", "xxx") + + tests := []struct { + name string + state string + cookie string + mock func() *gomonkey.Patches + res api.ResultAction + checkRedirectClientBack func(f *filter, headers http.Header) + }{ + { + name: "sanity", + state: state, + cookie: "htnn_oidc_nonce=" + nonce, + mock: func() *gomonkey.Patches { + patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "Exchange", token, nil) + patches.ApplyMethodReturn(conf.verifier, "Verify", &oidc.IDToken{ + Nonce: "xxx", Expiry: time.Now().Add(2 * time.Hour), + }, nil) + return patches + }, + checkRedirectClientBack: func(f *filter, headers http.Header) { + s := headers.Get("Location") + assert.Equal(t, "https://127.0.0.1:2379/x?y=1", s) + cookie := headers.Get("Set-Cookie") + assert.Contains(t, cookie, "Max-Age=7199;") + + // verify the cookie value + v := strings.Split(strings.Split(cookie, ";")[0], "=")[1] + h := http.Header{} + hdr := envoy.NewRequestHeaderMap(h) + assert.Equal(t, api.Continue, f.attachInfo(hdr, v)) + bearer, _ := hdr.Get("authorization") + assert.Equal(t, "Bearer rawIDToken", bearer) + }, + }, + { + name: "sanity", + state: state + "x", + cookie: "htnn_oidc_nonce=" + nonce, + res: &api.LocalResponse{Code: 403, Msg: "bad state"}, + }, + { + name: "failed to exchange", + state: state, + cookie: "htnn_oidc_nonce=" + nonce, + mock: func() *gomonkey.Patches { + patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "Exchange", nil, errors.New("timed out")) + return patches + }, + res: &api.LocalResponse{Code: 503, Msg: "failed to exchange code to the token"}, + }, + { + name: "failed to lookup token", + state: state, + cookie: "htnn_oidc_nonce=" + nonce, + mock: func() *gomonkey.Patches { + patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "Exchange", &oauth2.Token{}, nil) + return patches + }, + res: &api.LocalResponse{Code: 503, Msg: "failed to lookup id token"}, + }, + { + name: "bad token", + state: state, + cookie: "htnn_oidc_nonce=" + nonce, + mock: func() *gomonkey.Patches { + patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "Exchange", token, nil) + patches.ApplyMethodReturn(conf.verifier, "Verify", nil, errors.New("ouch")) + return patches + }, + res: &api.LocalResponse{Code: 403, Msg: "bad token"}, + }, + { + name: "bad nonce", + state: state, + cookie: "htnn_oidc_nonce=xxy", + mock: func() *gomonkey.Patches { + patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "Exchange", token, nil) + patches.ApplyMethodReturn(conf.verifier, "Verify", &oidc.IDToken{Nonce: "xxx"}, nil) + return patches + }, + res: &api.LocalResponse{Code: 403, Msg: "bad nonce"}, + }, + { + name: "bad nonce, no cookie", + state: state, + mock: func() *gomonkey.Patches { + patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "Exchange", token, nil) + patches.ApplyMethodReturn(conf.verifier, "Verify", &oidc.IDToken{Nonce: "xxx"}, nil) + return patches + }, + res: &api.LocalResponse{Code: 403, Msg: "bad nonce"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.mock != nil { + patches := tt.mock() + defer patches.Reset() + } + + cb := envoy.NewFilterCallbackHandler() + f := factory(getCfg(), cb).(*filter) + h := http.Header{} + h.Set(":path", "/echo?code=123&state="+tt.state) + h.Set("cookie", tt.cookie) + hdr := envoy.NewRequestHeaderMap(h) + res := f.DecodeHeaders(hdr, true) + if tt.res != nil { + assert.Equal(t, tt.res, res) + } + + if tt.checkRedirectClientBack != nil { + resp := res.(*api.LocalResponse) + tt.checkRedirectClientBack(f, resp.Header) + } + }) + } +} + +func TestAttachInfo(t *testing.T) { + cb := envoy.NewFilterCallbackHandler() + f := factory(getCfg(), cb).(*filter) + h := http.Header{} + h.Set("Cookie", "htnn_oidc_token=xxx") + hdr := envoy.NewRequestHeaderMap(h) + res := f.DecodeHeaders(hdr, true) + resp := res.(*api.LocalResponse) + assert.Equal(t, 403, resp.Code) + assert.Equal(t, "bad oidc cookie", resp.Msg) +} diff --git a/plugins/plugins.go b/plugins/plugins.go index 0a313527..3ceb4d49 100644 --- a/plugins/plugins.go +++ b/plugins/plugins.go @@ -24,5 +24,6 @@ import ( _ "mosn.io/htnn/plugins/key_auth" _ "mosn.io/htnn/plugins/limit_count_redis" _ "mosn.io/htnn/plugins/limit_req" + _ "mosn.io/htnn/plugins/oidc" _ "mosn.io/htnn/plugins/opa" ) diff --git a/plugins/tests/integration/data_plane/data_plane.go b/plugins/tests/integration/data_plane/data_plane.go index 7aa192dd..1e7e072d 100644 --- a/plugins/tests/integration/data_plane/data_plane.go +++ b/plugins/tests/integration/data_plane/data_plane.go @@ -294,7 +294,12 @@ func (dp *DataPlane) do(method string, path string, header http.Header, body io. return net.DialTimeout("tcp", ":10000", 1*time.Second) }} - client := &http.Client{Transport: tr, Timeout: 10 * time.Second} + client := &http.Client{Transport: tr, + Timeout: 10 * time.Second, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } resp, err := client.Do(req) return resp, err } diff --git a/plugins/tests/integration/oidc_test.go b/plugins/tests/integration/oidc_test.go new file mode 100644 index 00000000..b0dbc119 --- /dev/null +++ b/plugins/tests/integration/oidc_test.go @@ -0,0 +1,113 @@ +// Copyright The HTNN Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "encoding/base64" + "encoding/json" + "net/http" + "net/url" + "os/exec" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "mosn.io/htnn/plugins/tests/integration/control_plane" + "mosn.io/htnn/plugins/tests/integration/data_plane" + "mosn.io/htnn/plugins/tests/integration/helper" +) + +func TestOIDC(t *testing.T) { + dp, err := data_plane.StartDataPlane(t, &data_plane.Option{}) + if err != nil { + t.Fatalf("failed to start data plane: %v", err) + return + } + defer dp.Stop() + + helper.WaitServiceUp(t, ":4444", "hydra") + + redirectUrl := "http://127.0.0.1:10000/echo" + hydraCmd := "hydra create client --response-type code,id_token " + + "--grant-type authorization_code,refresh_token -e http://127.0.0.1:4445 " + + "--redirect-uri " + redirectUrl + " --format json" + cmdline := "docker compose -f ./testdata/services/docker-compose.yml " + + "exec --no-TTY hydra " + hydraCmd + cmds := strings.Fields(cmdline) + cmd := exec.Command(cmds[0], cmds[1:]...) + stdout, err := cmd.Output() + if err != nil { + reason := string(err.(*exec.ExitError).Stderr) + require.NoError(t, err, reason) + } + t.Logf("hydra output: %s", stdout) + + type hydraOutput struct { + ClientId string `json:"client_id"` + ClientSecret string `json:"client_secret"` + } + + var hydra hydraOutput + json.Unmarshal(stdout, &hydra) + + config := control_plane.NewSinglePluinConfig("oidc", map[string]interface{}{ + "clientId": hydra.ClientId, + "clientSecret": hydra.ClientSecret, + "redirectUrl": redirectUrl, + "issuer": "http://hydra:4444", + }) + controlPlane.UseGoPluginConfig(config, dp) + + uri := "" + var resp *http.Response + require.Eventually(t, func() bool { + resp, err = dp.Get("/echo?a=1", nil) + require.Nil(t, err) + uri = resp.Header.Get("Location") + return uri != "" + }, 5*time.Second, 1*time.Second) + + u, err := url.ParseRequestURI(uri) + require.NoError(t, err) + require.Equal(t, "hydra:4444", u.Host) + require.Equal(t, hydra.ClientId, u.Query().Get("client_id")) + require.Equal(t, redirectUrl, u.Query().Get("redirect_uri")) + encodedUrl := strings.Split(u.Query().Get("state"), ".")[1] + b, _ := base64.URLEncoding.DecodeString(encodedUrl) + originUrl := string(b) + require.Equal(t, "http://localhost:10000/echo?a=1", originUrl) + require.NotEmpty(t, u.Query().Get("nonce")) + require.NotEmpty(t, u.Query().Get("code_challenge")) + cookie := resp.Header.Get("Set-Cookie") + require.Regexp(t, `^htnn_oidc_nonce=[^;]+; Max-Age=3600; HttpOnly$`, cookie) + + // the request is sent from the host + uri = strings.Replace(uri, "http://hydra:4444", "http://127.0.0.1:4444", 1) + req, err := http.NewRequest("GET", uri, nil) + require.NoError(t, err) + + client := &http.Client{ + Timeout: 10 * time.Second, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + resp, err = client.Do(req) + require.NoError(t, err) + require.Equal(t, 302, resp.StatusCode) + require.Contains(t, resp.Header.Get("Location"), "http://127.0.0.1:3000/login") +} diff --git a/plugins/tests/integration/testdata/services/docker-compose.yml b/plugins/tests/integration/testdata/services/docker-compose.yml index ef69d50c..da0e4422 100644 --- a/plugins/tests/integration/testdata/services/docker-compose.yml +++ b/plugins/tests/integration/testdata/services/docker-compose.yml @@ -2,6 +2,54 @@ version: "3.8" services: # names in alphabetical order + hydra: + image: oryd/hydra:v2.2.0-rc.3 + ports: + - "4444:4444" # Public port + - "4445:4445" # Admin port + - "5555:5555" # Port for hydra token user + command: serve -c /etc/config/hydra/hydra.yml all --dev + volumes: + - type: volume + source: hydra-sqlite + target: /var/lib/sqlite + read_only: false + - type: bind + source: ./hydra + target: /etc/config/hydra + environment: + - DSN=sqlite:///var/lib/sqlite/db.sqlite?_fk=true + restart: unless-stopped + depends_on: + - hydra-migrate + networks: + service: + hydra-consent: + environment: + - HYDRA_ADMIN_URL=http://hydra:4445 + image: oryd/hydra-login-consent-node:v2.2.0-rc.3 + ports: + - "3000:3000" + restart: unless-stopped + networks: + service: + hydra-migrate: + image: oryd/hydra:v2.2.0-rc.3 + environment: + - DSN=sqlite:///var/lib/sqlite/db.sqlite?_fk=true + command: migrate -c /etc/config/hydra/hydra.yml sql -e --yes + volumes: + - type: volume + source: hydra-sqlite + target: /var/lib/sqlite + read_only: false + - type: bind + source: ./hydra + target: /etc/config/hydra + restart: on-failure + networks: + service: + opa: image: openpolicyagent/opa:0.58.0 restart: unless-stopped @@ -14,6 +62,7 @@ services: target: /test.rego networks: service: + redis: image: redis:latest restart: unless-stopped @@ -35,3 +84,6 @@ services: networks: service: + +volumes: + hydra-sqlite: diff --git a/plugins/tests/integration/testdata/services/hydra/hydra.yml b/plugins/tests/integration/testdata/services/hydra/hydra.yml new file mode 100644 index 00000000..8da3ff91 --- /dev/null +++ b/plugins/tests/integration/testdata/services/hydra/hydra.yml @@ -0,0 +1,22 @@ +serve: + cookies: + same_site_mode: Lax + +urls: + self: + issuer: http://hydra:4444 + consent: http://127.0.0.1:3000/consent + login: http://127.0.0.1:3000/login + logout: http://127.0.0.1:3000/logout + +secrets: + system: + - youReallyNeedToChangeThis + +oidc: + subject_identifiers: + supported_types: + - pairwise + - public + pairwise: + salt: youReallyNeedToChangeThis diff --git a/plugins/tests/pkg/envoy/capi.go b/plugins/tests/pkg/envoy/capi.go index e44ad904..9e40144f 100644 --- a/plugins/tests/pkg/envoy/capi.go +++ b/plugins/tests/pkg/envoy/capi.go @@ -106,7 +106,11 @@ func NewRequestHeaderMap(hdr http.Header) *RequestHeaderMap { } func (i *RequestHeaderMap) Scheme() string { - return "http" + scheme, ok := i.Get(":scheme") + if !ok { + return "http" + } + return scheme } func (i *RequestHeaderMap) Method() string {