diff --git a/pkg/auth/iam/iam.go b/pkg/auth/iam/iam.go index aa37d32..7fcc1db 100644 --- a/pkg/auth/iam/iam.go +++ b/pkg/auth/iam/iam.go @@ -300,25 +300,19 @@ func WithValidScope(scope string) FilterOption { // WithMatchedSubdomain filters request to a subdomain to match it with namespace in user's token func WithMatchedSubdomain(excludedNamespaces []string) FilterOption { return func(req *restful.Request, iamClient iam.Client, claims *iam.JWTClaims) error { - parsedURL, err := url.Parse(req.Request.Host) - if err != nil { - // error parsing means the request comes from internal call, for example service to service call - return nil - } - - part := strings.Split(parsedURL.Host, ".") + part := strings.Split(getHost(req.Request), ".") if len(part) < 3 { // url with subdomain should have at least 3 part, e.g. foo.example.com, otherwise we should not check it return nil } for _, excludedNS := range excludedNamespaces { - if excludedNS == claims.Namespace { + if strings.ToLower(excludedNS) == strings.ToLower(claims.Namespace) { return nil } } - if claims.Namespace == part[0] { + if strings.ToLower(claims.Namespace) == strings.ToLower(part[0]) { return nil } @@ -327,6 +321,17 @@ func WithMatchedSubdomain(excludedNamespaces []string) FilterOption { } } +func getHost(req *http.Request) string { + if !req.URL.IsAbs() { + host := req.Host + if i := strings.Index(host, ":"); i != -1 { + host = host[:i] + } + return host + } + return req.URL.Host +} + // parseAccessToken is used to read token from Authorization Header or Cookie. // it will return the token value and token from. func parseAccessToken(request *restful.Request) (string, string, error) { diff --git a/pkg/auth/iam/iam_test.go b/pkg/auth/iam/iam_test.go index 358ef9b..b040384 100644 --- a/pkg/auth/iam/iam_test.go +++ b/pkg/auth/iam/iam_test.go @@ -16,6 +16,7 @@ package iam import ( "net/http" + "net/url" "testing" "github.com/AccelByte/go-restful-plugins/v4/pkg/constant" @@ -161,3 +162,44 @@ func TestValidateRefererHeaderWithSubdomain(t *testing.T) { }) } } + +// nolint:paralleltest +func TestGetHost(t *testing.T) { + testcases := []struct { + name string + expected string + requestHost string + URLHost string + scheme string + }{ + { + name: "not_absolute_url", + requestHost: "host.example.com", + URLHost: "url.example.com", + expected: "host.example.com", + }, + { + name: "not_absolute_url_with_port", + requestHost: "host.example.com:80", + URLHost: "url.example.com", + expected: "host.example.com", + }, + { + name: "absolute_url_without_port", + requestHost: "host.example.com", + URLHost: "url.example.com", + scheme: "http", + expected: "url.example.com", + }, + } + + for _, testcase := range testcases { + t.Run(testcase.name, func(t *testing.T) { + req := &http.Request{ + Host: testcase.requestHost, + URL: &url.URL{Host: testcase.URLHost, Scheme: testcase.scheme}, + } + assert.Equal(t, testcase.expected, getHost(req)) + }) + } +}