diff --git a/flask_cors/core.py b/flask_cors/core.py index 7654cb1..0b8e84b 100644 --- a/flask_cors/core.py +++ b/flask_cors/core.py @@ -111,7 +111,7 @@ def get_regexp_pattern(regexp): def get_cors_origins(options, request_origin): - origins = options.get('origins') + origins = retrieve_origins(options, request_origin) wildcard = r'.*' in origins # If the Origin header is not present terminate this set of steps. @@ -174,7 +174,8 @@ def get_allow_headers(options, acl_request_headers): def get_cors_headers(options, request_headers, request_method): - origins_to_set = get_cors_origins(options, request_headers.get('Origin')) + request_origin = request_headers.get('Origin') + origins_to_set = get_cors_origins(options, request_origin) headers = MultiDict() if not origins_to_set: # CORS is not enabled for this route @@ -211,11 +212,13 @@ def get_cors_headers(options, request_headers, request_method): # Only set header if the origin returned will vary dynamically, # i.e. if we are not returning an asterisk, and there are multiple # origins that can be matched. + origins = retrieve_origins(options, request_origin) + if headers[ACL_ORIGIN] == '*': pass - elif (len(options.get('origins')) > 1 or + elif (len(origins) > 1 or len(origins_to_set) > 1 or - any(map(probably_regex, options.get('origins')))): + any(map(probably_regex, origins))): headers.add('Vary', 'Origin') return MultiDict((k, v) for k, v in headers.items() if v) @@ -350,6 +353,11 @@ def ensure_iterable(inst): def sanitize_regex_param(param): return [re_fix(x) for x in ensure_iterable(param)] +def retrieve_origins(options, request_origin): + origins = options.get('origins') + if callable(origins): + origins = sanitize_regex_param(origins(request_origin)) + return origins def serialize_options(opts): """ @@ -362,12 +370,19 @@ def serialize_options(opts): LOG.warning("Unknown option passed to Flask-CORS: %s", key) # Ensure origins is a list of allowed origins with at least one entry. - options['origins'] = sanitize_regex_param(options.get('origins')) + origins = options.get('origins') + if not callable(origins): + origins = sanitize_regex_param(origins) # sanitize if not a fn + options['origins'] = origins # keep it as a function in options options['allow_headers'] = sanitize_regex_param(options.get('allow_headers')) # This is expressly forbidden by the spec. Raise a value error so people # don't get burned in production. - if r'.*' in options['origins'] and options['supports_credentials'] and options['send_wildcard']: + if (not callable(origins) + and r'.*' in origins + and options['supports_credentials'] + and options['send_wildcard'] + ): raise ValueError("Cannot use supports_credentials in conjunction with" "an origin string of '*'. See: " "http://www.w3.org/TR/cors/#resource-requests") diff --git a/tests/decorator/test_origins.py b/tests/decorator/test_origins.py index 5e7e20a..6296a47 100644 --- a/tests/decorator/test_origins.py +++ b/tests/decorator/test_origins.py @@ -16,6 +16,12 @@ from flask_cors.core import * letters = 'abcdefghijklmnopqrstuvwxyz' # string.letters is not PY3 compatible +dynamic_pattern = "dynamic" + +def _dynamic_origin(origin): + if dynamic_pattern in origin: + return origin + return "" class OriginsTestCase(FlaskCorsTestCase): def setUp(self): @@ -81,6 +87,11 @@ def test_regex_mixed_list(): def test_multiple_protocols(): return '' + @self.app.route('/test_dynamic_origin') + @cross_origin(origins=_dynamic_origin) + def test_dynamic_origin(): + return '' + def test_defaults_no_origin(self): ''' If there is no Origin header in the request, the Access-Control-Allow-Origin header should be '*' by default. @@ -199,6 +210,15 @@ def test_multiple_protocols(self): resp = self.get('test_multiple_protocols', origin='https://example.com') self.assertEqual('https://example.com', resp.headers.get(ACL_ORIGIN)) + def test_dynamic_header(self): + ''' If the origin contains the variable dynamic_pattern, the + Access-Control-Allow-Origin should be echoed + ''' + resp = self.get('/test_dynamic_origin', origin='http://foo.com') + self.assertEqual(resp.headers.get(ACL_ORIGIN), None) + resp = self.get('/test_dynamic_origin', origin='http://foo-dynamic.com') + self.assertEqual(resp.headers.get(ACL_ORIGIN), 'http://foo-dynamic.com') + if __name__ == "__main__": unittest.main()