From deea29aec9348a96d4e2d4c9207c7a2ceb0bdbce Mon Sep 17 00:00:00 2001 From: Vishnu Sanal T <50027064+VishnuSanal@users.noreply.github.com> Date: Sun, 27 Oct 2024 19:57:10 +0530 Subject: [PATCH] fix: set_response_header breaking openapi doc rendering (#977) * fix: set_response_header breaking openapi doc rendering * docs: add docs * update tests * fix errors caused during the rename --------- Co-authored-by: Sanskar Jethi <29942790+sansyrox@users.noreply.github.com> --- .../api_reference/getting_started.mdx | 28 +++++++++---- integration_tests/test_openapi.py | 41 +++++++++++++++---- robyn/__init__.py | 10 +++++ robyn/processpool.py | 10 ++++- robyn/robyn.pyi | 2 + src/server.rs | 27 +++++++++++- src/types/headers.rs | 4 ++ 7 files changed, 106 insertions(+), 16 deletions(-) diff --git a/docs_src/src/pages/documentation/api_reference/getting_started.mdx b/docs_src/src/pages/documentation/api_reference/getting_started.mdx index 77c6629d5..933a381fa 100644 --- a/docs_src/src/pages/documentation/api_reference/getting_started.mdx +++ b/docs_src/src/pages/documentation/api_reference/getting_started.mdx @@ -558,12 +558,10 @@ Either, by using the `headers` field in the `Response` object: ### Global Response Headers - -Or setting the Headers globally *per* router. - - +Or setting the Headers globally *per* router. + ```python {{ title: 'untyped' }} @@ -574,16 +572,14 @@ Or setting the Headers globally *per* router. app.add_response_header("content-type", "application/json") ``` + - - `add_response_header` appends the header to the list of headers, while `set_response_header` replaces the header if it exists. - ```python {{ title: 'untyped' }} @@ -594,7 +590,25 @@ Or setting the Headers globally *per* router. app.set_response_header("content-type", "application/json") ``` + + + +To prevent the headers from getting applied to certain endpoints, you can use the `exclude_response_headers_for` function. + + + + + ```python {{ title: 'untyped' }} + app.exclude_response_headers_for(["/login", "/signup"]) + ``` + + ```python {{title: 'typed'}} + app.exclude_response_headers_for(["/login", "/signup"]) + ``` + + + ### Cookies diff --git a/integration_tests/test_openapi.py b/integration_tests/test_openapi.py index fb066f9c9..bf163b953 100644 --- a/integration_tests/test_openapi.py +++ b/integration_tests/test_openapi.py @@ -5,13 +5,20 @@ @pytest.mark.benchmark def test_docs_handler(): - html_response = get("/docs") + # should_check_response = False because check_response raises a + # failure if the global headers are not present in the response + # provided we are excluding headers for /docs and /openapi.json + html_response = get("/docs", should_check_response=False) assert html_response.status_code == 200 @pytest.mark.benchmark def test_json_handler(): - openapi_spec = get("/openapi.json").json() + openapi_response = get("/openapi.json", should_check_response=False) + + assert openapi_response.status_code == 200 + + openapi_spec = openapi_response.json() assert isinstance(openapi_spec, dict) assert "openapi" in openapi_spec @@ -24,7 +31,11 @@ def test_json_handler(): @pytest.mark.benchmark def test_add_openapi_path(): - openapi_spec = get("/openapi.json").json() + openapi_response = get("/openapi.json", should_check_response=False) + + assert openapi_response.status_code == 200 + + openapi_spec = openapi_response.json() assert isinstance(openapi_spec, dict) @@ -41,7 +52,11 @@ def test_add_openapi_path(): @pytest.mark.benchmark def test_add_subrouter_paths(): - openapi_spec = get("/openapi.json").json() + openapi_response = get("/openapi.json", should_check_response=False) + + assert openapi_response.status_code == 200 + + openapi_spec = openapi_response.json() assert isinstance(openapi_spec, dict) @@ -58,7 +73,11 @@ def test_add_subrouter_paths(): @pytest.mark.benchmark def test_openapi_request_body(): - openapi_spec = get("/openapi.json").json() + openapi_response = get("/openapi.json", should_check_response=False) + + assert openapi_response.status_code == 200 + + openapi_spec = openapi_response.json() assert isinstance(openapi_spec, dict) @@ -118,7 +137,11 @@ def test_openapi_request_body(): @pytest.mark.benchmark def test_openapi_response_body(): - openapi_spec = get("/openapi.json").json() + openapi_response = get("/openapi.json", should_check_response=False) + + assert openapi_response.status_code == 200 + + openapi_spec = openapi_response.json() assert isinstance(openapi_spec, dict) @@ -153,7 +176,11 @@ def test_openapi_response_body(): @pytest.mark.benchmark def test_openapi_query_params(): - openapi_spec = get("/openapi.json").json() + openapi_response = get("/openapi.json", should_check_response=False) + + assert openapi_response.status_code == 200 + + openapi_spec = openapi_response.json() assert isinstance(openapi_spec, dict) diff --git a/robyn/__init__.py b/robyn/__init__.py index 268be447d..09e9a204f 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -70,6 +70,7 @@ def __init__( self.web_socket_router = WebSocketRouter() self.request_headers: Headers = Headers({}) self.response_headers: Headers = Headers({}) + self.excluded_response_headers_paths: Optional[List[str]] = None self.directories: List[Directory] = [] self.event_handlers = {} self.exception_handler: Optional[Callable] = None @@ -200,6 +201,13 @@ def set_request_header(self, key: str, value: str) -> None: def set_response_header(self, key: str, value: str) -> None: self.response_headers.set(key, value) + def exclude_response_headers_for(self, excluded_response_header_paths: Optional[List[str]]): + """ + To exclude response headers from certain routes + @param exclude_paths: the paths to exclude response headers from + """ + self.excluded_response_header_paths = excluded_response_header_paths + def add_web_socket(self, endpoint: str, ws: WebSocket) -> None: self.web_socket_router.add_route(endpoint, ws) @@ -239,6 +247,7 @@ def _add_openapi_routes(self, auth_required: bool = False): is_const=True, auth_required=auth_required, ) + self.exclude_response_headers_for(["/docs", "/openapi.json"]) def start(self, host: str = "127.0.0.1", port: int = 8080, _check_port: bool = True): """ @@ -284,6 +293,7 @@ def start(self, host: str = "127.0.0.1", port: int = 8080, _check_port: bool = T self.config.workers, self.config.processes, self.response_headers, + self.excluded_response_header_paths, open_browser, ) diff --git a/robyn/processpool.py b/robyn/processpool.py index 85de6683e..1391d21bf 100644 --- a/robyn/processpool.py +++ b/robyn/processpool.py @@ -2,7 +2,7 @@ import signal import sys import webbrowser -from typing import Dict, List +from typing import Dict, List, Optional from multiprocess import Process @@ -27,6 +27,7 @@ def run_processes( workers: int, processes: int, response_headers: Headers, + excluded_response_headers_paths: Optional[List[str]], open_browser: bool, ) -> List[Process]: socket = SocketHeld(url, port) @@ -43,6 +44,7 @@ def run_processes( workers, processes, response_headers, + excluded_response_headers_paths, ) def terminating_signal_handler(_sig, _frame): @@ -76,6 +78,7 @@ def init_processpool( workers: int, processes: int, response_headers: Headers, + excluded_response_headers_paths: Optional[List[str]], ) -> List[Process]: process_pool = [] if sys.platform.startswith("win32") or processes == 1: @@ -90,6 +93,7 @@ def init_processpool( socket, workers, response_headers, + excluded_response_headers_paths, ) return process_pool @@ -109,6 +113,7 @@ def init_processpool( copied_socket, workers, response_headers, + excluded_response_headers_paths, ), ) process.start() @@ -144,6 +149,7 @@ def spawn_process( socket: SocketHeld, workers: int, response_headers: Headers, + excluded_response_headers_paths: Optional[List[str]], ): """ This function is called by the main process handler to create a server runtime. @@ -173,6 +179,8 @@ def spawn_process( server.apply_response_headers(response_headers) + server.set_response_headers_exclude_paths(excluded_response_headers_paths) + for route in routes: route_type, endpoint, function, is_const = route server.add_route(route_type, endpoint, function, is_const) diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi index f354f99bd..cd9367c2a 100644 --- a/robyn/robyn.pyi +++ b/robyn/robyn.pyi @@ -327,6 +327,8 @@ class Server: pass def apply_response_headers(self, headers: Headers) -> None: pass + def set_response_headers_exclude_paths(self, excluded_response_header_paths: Optional[list[str]] = None): + pass def add_route( self, diff --git a/src/server.rs b/src/server.rs index 7d7771986..757b3b5e9 100644 --- a/src/server.rs +++ b/src/server.rs @@ -56,6 +56,7 @@ pub struct Server { directories: Arc>>, startup_handler: Option>, shutdown_handler: Option>, + excluded_response_headers_paths: Option>, } #[pymethods] @@ -72,6 +73,7 @@ impl Server { directories: Arc::new(RwLock::new(Vec::new())), startup_handler: None, shutdown_handler: None, + excluded_response_headers_paths: None, } } @@ -108,6 +110,8 @@ impl Server { let startup_handler = self.startup_handler.clone(); let shutdown_handler = self.shutdown_handler.clone(); + let excluded_response_headers_paths = self.excluded_response_headers_paths.clone(); + let task_locals = pyo3_asyncio::TaskLocals::new(event_loop).copy_context(py)?; let task_locals_copy = task_locals.clone(); @@ -162,7 +166,8 @@ impl Server { .app_data(web::Data::new(const_router.clone())) .app_data(web::Data::new(middleware_router.clone())) .app_data(web::Data::new(global_request_headers.clone())) - .app_data(web::Data::new(global_response_headers.clone())); + .app_data(web::Data::new(global_response_headers.clone())) + .app_data(web::Data::new(excluded_response_headers_paths.clone())); let web_socket_map = web_socket_router.get_web_socket_map(); for (elem, value) in (web_socket_map.read()).iter() { @@ -194,6 +199,7 @@ impl Server { payload: web::Payload, global_request_headers, global_response_headers, + response_headers_exclude_paths, req| { pyo3_asyncio::tokio::scope_local(task_locals.clone(), async move { index( @@ -203,6 +209,7 @@ impl Server { middleware_router, global_request_headers, global_response_headers, + response_headers_exclude_paths, req, ) .await @@ -298,6 +305,13 @@ impl Server { self.global_response_headers = Arc::new(headers.clone()); } + pub fn set_response_headers_exclude_paths( + &mut self, + excluded_response_headers_paths: Option>, + ) { + self.excluded_response_headers_paths = excluded_response_headers_paths; + } + /// Add a new route to the routing tables /// can be called after the server has been started pub fn add_route( @@ -409,6 +423,7 @@ impl Default for Server { /// This is our service handler. It receives a Request, routes on it /// path, and returns a Future of a Response. +#[allow(clippy::too_many_arguments)] async fn index( router: web::Data>, payload: web::Payload, @@ -416,6 +431,7 @@ async fn index( middleware_router: web::Data>, global_request_headers: web::Data>, global_response_headers: web::Data>, + excluded_response_headers_paths: web::Data>>, req: HttpRequest, ) -> impl Responder { let mut request = Request::from_actix_request(&req, payload, &global_request_headers).await; @@ -480,6 +496,15 @@ async fn index( response.headers.extend(&global_response_headers); + match &excluded_response_headers_paths.get_ref() { + None => {} + Some(excluded_response_headers_paths) => { + if excluded_response_headers_paths.contains(&req.uri().path().to_owned()) { + response.headers.clear(); + } + } + } + debug!("Extended Response : {:?}", response); // After middleware diff --git a/src/types/headers.rs b/src/types/headers.rs index 37503cf09..2e615c004 100644 --- a/src/types/headers.rs +++ b/src/types/headers.rs @@ -158,6 +158,10 @@ impl Headers { self.headers.remove(&key.to_lowercase()); } + pub fn clear(&mut self) { + self.headers.clear(); + } + pub fn extend(&mut self, headers: &Headers) { for iter in headers.headers.iter() { let (key, values) = iter.pair();