From 1ee57fd3e20fc24e62d9a8ae7c88ab379a827442 Mon Sep 17 00:00:00 2001 From: Ahmed Reza Rafsanzani <31310263+medreza@users.noreply.github.com> Date: Tue, 14 Jun 2022 13:26:00 +0700 Subject: [PATCH] feat: support regex in allowed domains with prefix (#54) --- pkg/cors/cors_filter.go | 23 +++++++++++++++++++++++ pkg/cors/cors_filter_test.go | 11 +++++++++++ 2 files changed, 34 insertions(+) diff --git a/pkg/cors/cors_filter.go b/pkg/cors/cors_filter.go index 1fac724..8f0dc2f 100644 --- a/pkg/cors/cors_filter.go +++ b/pkg/cors/cors_filter.go @@ -15,6 +15,8 @@ package cors import ( + "errors" + "regexp" "strconv" "strings" @@ -39,6 +41,10 @@ type CrossOriginResourceSharing struct { Container *restful.Container } +const ( + AllowedDomainsRegexPrefix = "re:" +) + // Filter is a filter function that implements the CORS flow func (c CrossOriginResourceSharing) Filter(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) { origin := req.Request.Header.Get(restful.HEADER_Origin) @@ -135,6 +141,15 @@ func (c CrossOriginResourceSharing) isOriginAllowed(origin string) bool { if domain == origin || domain == "*" { return true } + if strings.HasPrefix(domain, AllowedDomainsRegexPrefix) { + pattern, err := getPattern(domain) + if err != nil { + return false + } + if pattern.MatchString(origin) { + return true + } + } } return false @@ -159,3 +174,11 @@ func (c CrossOriginResourceSharing) isValidAccessControlRequestHeader(header str } return false } + +func getPattern(str string) (*regexp.Regexp, error) { + split := strings.Split(str, AllowedDomainsRegexPrefix) + if len(split) < 2 { + return nil, errors.New("pattern not found") + } + return regexp.Compile(split[1]) +} diff --git a/pkg/cors/cors_filter_test.go b/pkg/cors/cors_filter_test.go index c012de3..5e55229 100644 --- a/pkg/cors/cors_filter_test.go +++ b/pkg/cors/cors_filter_test.go @@ -91,6 +91,17 @@ func TestIsOriginAllowed(t *testing.T) { assert.True(t, corsWithWildcardAllowedDomain.isOriginAllowed("https://www.example.io.something")) assert.True(t, corsWithWildcardAllowedDomain.isOriginAllowed("https://www.example.io.something.io")) + // TEST 5: Allowed domains with regex + corsWithRegex := CrossOriginResourceSharing{ + AllowedDomains: []string{"re:https://([a-z0-9]+[.])*example.io$", "https://www.example.com"}, + } + assert.True(t, corsWithRegex.isOriginAllowed("https://www.example.io")) + assert.True(t, corsWithRegex.isOriginAllowed("https://subdomain.example.io")) + assert.True(t, corsWithRegex.isOriginAllowed("https://www.example.com")) + assert.False(t, corsWithRegex.isOriginAllowed("https://subdomain.example.com")) + assert.False(t, corsWithRegex.isOriginAllowed("https://www.example.net")) + assert.False(t, corsWithRegex.isOriginAllowed("https://subdomain.example.io.something")) + assert.False(t, corsWithRegex.isOriginAllowed("https://www.example.io.something.io")) } func TestIsValidAccessControlRequestMethod(t *testing.T) {