From deea29aec9348a96d4e2d4c9207c7a2ceb0bdbce Mon Sep 17 00:00:00 2001
From: Vishnu Sanal T <50027064+VishnuSanal@users.noreply.github.com>
Date: Sun, 27 Oct 2024 19:57:10 +0530
Subject: [PATCH] fix: set_response_header breaking openapi doc rendering
(#977)
* fix: set_response_header breaking openapi doc rendering
* docs: add docs
* update tests
* fix errors caused during the rename
---------
Co-authored-by: Sanskar Jethi <29942790+sansyrox@users.noreply.github.com>
---
.../api_reference/getting_started.mdx | 28 +++++++++----
integration_tests/test_openapi.py | 41 +++++++++++++++----
robyn/__init__.py | 10 +++++
robyn/processpool.py | 10 ++++-
robyn/robyn.pyi | 2 +
src/server.rs | 27 +++++++++++-
src/types/headers.rs | 4 ++
7 files changed, 106 insertions(+), 16 deletions(-)
diff --git a/docs_src/src/pages/documentation/api_reference/getting_started.mdx b/docs_src/src/pages/documentation/api_reference/getting_started.mdx
index 77c6629d5..933a381fa 100644
--- a/docs_src/src/pages/documentation/api_reference/getting_started.mdx
+++ b/docs_src/src/pages/documentation/api_reference/getting_started.mdx
@@ -558,12 +558,10 @@ Either, by using the `headers` field in the `Response` object:
### Global Response Headers
-
-Or setting the Headers globally *per* router.
-
-
+Or setting the Headers globally *per* router.
+
```python {{ title: 'untyped' }}
@@ -574,16 +572,14 @@ Or setting the Headers globally *per* router.
app.add_response_header("content-type", "application/json")
```
+
-
-
`add_response_header` appends the header to the list of headers, while `set_response_header` replaces the header if it exists.
-
```python {{ title: 'untyped' }}
@@ -594,7 +590,25 @@ Or setting the Headers globally *per* router.
app.set_response_header("content-type", "application/json")
```
+
+
+
+To prevent the headers from getting applied to certain endpoints, you can use the `exclude_response_headers_for` function.
+
+
+
+
+ ```python {{ title: 'untyped' }}
+ app.exclude_response_headers_for(["/login", "/signup"])
+ ```
+
+ ```python {{title: 'typed'}}
+ app.exclude_response_headers_for(["/login", "/signup"])
+ ```
+
+
+
### Cookies
diff --git a/integration_tests/test_openapi.py b/integration_tests/test_openapi.py
index fb066f9c9..bf163b953 100644
--- a/integration_tests/test_openapi.py
+++ b/integration_tests/test_openapi.py
@@ -5,13 +5,20 @@
@pytest.mark.benchmark
def test_docs_handler():
- html_response = get("/docs")
+ # should_check_response = False because check_response raises a
+ # failure if the global headers are not present in the response
+ # provided we are excluding headers for /docs and /openapi.json
+ html_response = get("/docs", should_check_response=False)
assert html_response.status_code == 200
@pytest.mark.benchmark
def test_json_handler():
- openapi_spec = get("/openapi.json").json()
+ openapi_response = get("/openapi.json", should_check_response=False)
+
+ assert openapi_response.status_code == 200
+
+ openapi_spec = openapi_response.json()
assert isinstance(openapi_spec, dict)
assert "openapi" in openapi_spec
@@ -24,7 +31,11 @@ def test_json_handler():
@pytest.mark.benchmark
def test_add_openapi_path():
- openapi_spec = get("/openapi.json").json()
+ openapi_response = get("/openapi.json", should_check_response=False)
+
+ assert openapi_response.status_code == 200
+
+ openapi_spec = openapi_response.json()
assert isinstance(openapi_spec, dict)
@@ -41,7 +52,11 @@ def test_add_openapi_path():
@pytest.mark.benchmark
def test_add_subrouter_paths():
- openapi_spec = get("/openapi.json").json()
+ openapi_response = get("/openapi.json", should_check_response=False)
+
+ assert openapi_response.status_code == 200
+
+ openapi_spec = openapi_response.json()
assert isinstance(openapi_spec, dict)
@@ -58,7 +73,11 @@ def test_add_subrouter_paths():
@pytest.mark.benchmark
def test_openapi_request_body():
- openapi_spec = get("/openapi.json").json()
+ openapi_response = get("/openapi.json", should_check_response=False)
+
+ assert openapi_response.status_code == 200
+
+ openapi_spec = openapi_response.json()
assert isinstance(openapi_spec, dict)
@@ -118,7 +137,11 @@ def test_openapi_request_body():
@pytest.mark.benchmark
def test_openapi_response_body():
- openapi_spec = get("/openapi.json").json()
+ openapi_response = get("/openapi.json", should_check_response=False)
+
+ assert openapi_response.status_code == 200
+
+ openapi_spec = openapi_response.json()
assert isinstance(openapi_spec, dict)
@@ -153,7 +176,11 @@ def test_openapi_response_body():
@pytest.mark.benchmark
def test_openapi_query_params():
- openapi_spec = get("/openapi.json").json()
+ openapi_response = get("/openapi.json", should_check_response=False)
+
+ assert openapi_response.status_code == 200
+
+ openapi_spec = openapi_response.json()
assert isinstance(openapi_spec, dict)
diff --git a/robyn/__init__.py b/robyn/__init__.py
index 268be447d..09e9a204f 100644
--- a/robyn/__init__.py
+++ b/robyn/__init__.py
@@ -70,6 +70,7 @@ def __init__(
self.web_socket_router = WebSocketRouter()
self.request_headers: Headers = Headers({})
self.response_headers: Headers = Headers({})
+ self.excluded_response_headers_paths: Optional[List[str]] = None
self.directories: List[Directory] = []
self.event_handlers = {}
self.exception_handler: Optional[Callable] = None
@@ -200,6 +201,13 @@ def set_request_header(self, key: str, value: str) -> None:
def set_response_header(self, key: str, value: str) -> None:
self.response_headers.set(key, value)
+ def exclude_response_headers_for(self, excluded_response_header_paths: Optional[List[str]]):
+ """
+ To exclude response headers from certain routes
+ @param exclude_paths: the paths to exclude response headers from
+ """
+ self.excluded_response_header_paths = excluded_response_header_paths
+
def add_web_socket(self, endpoint: str, ws: WebSocket) -> None:
self.web_socket_router.add_route(endpoint, ws)
@@ -239,6 +247,7 @@ def _add_openapi_routes(self, auth_required: bool = False):
is_const=True,
auth_required=auth_required,
)
+ self.exclude_response_headers_for(["/docs", "/openapi.json"])
def start(self, host: str = "127.0.0.1", port: int = 8080, _check_port: bool = True):
"""
@@ -284,6 +293,7 @@ def start(self, host: str = "127.0.0.1", port: int = 8080, _check_port: bool = T
self.config.workers,
self.config.processes,
self.response_headers,
+ self.excluded_response_header_paths,
open_browser,
)
diff --git a/robyn/processpool.py b/robyn/processpool.py
index 85de6683e..1391d21bf 100644
--- a/robyn/processpool.py
+++ b/robyn/processpool.py
@@ -2,7 +2,7 @@
import signal
import sys
import webbrowser
-from typing import Dict, List
+from typing import Dict, List, Optional
from multiprocess import Process
@@ -27,6 +27,7 @@ def run_processes(
workers: int,
processes: int,
response_headers: Headers,
+ excluded_response_headers_paths: Optional[List[str]],
open_browser: bool,
) -> List[Process]:
socket = SocketHeld(url, port)
@@ -43,6 +44,7 @@ def run_processes(
workers,
processes,
response_headers,
+ excluded_response_headers_paths,
)
def terminating_signal_handler(_sig, _frame):
@@ -76,6 +78,7 @@ def init_processpool(
workers: int,
processes: int,
response_headers: Headers,
+ excluded_response_headers_paths: Optional[List[str]],
) -> List[Process]:
process_pool = []
if sys.platform.startswith("win32") or processes == 1:
@@ -90,6 +93,7 @@ def init_processpool(
socket,
workers,
response_headers,
+ excluded_response_headers_paths,
)
return process_pool
@@ -109,6 +113,7 @@ def init_processpool(
copied_socket,
workers,
response_headers,
+ excluded_response_headers_paths,
),
)
process.start()
@@ -144,6 +149,7 @@ def spawn_process(
socket: SocketHeld,
workers: int,
response_headers: Headers,
+ excluded_response_headers_paths: Optional[List[str]],
):
"""
This function is called by the main process handler to create a server runtime.
@@ -173,6 +179,8 @@ def spawn_process(
server.apply_response_headers(response_headers)
+ server.set_response_headers_exclude_paths(excluded_response_headers_paths)
+
for route in routes:
route_type, endpoint, function, is_const = route
server.add_route(route_type, endpoint, function, is_const)
diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi
index f354f99bd..cd9367c2a 100644
--- a/robyn/robyn.pyi
+++ b/robyn/robyn.pyi
@@ -327,6 +327,8 @@ class Server:
pass
def apply_response_headers(self, headers: Headers) -> None:
pass
+ def set_response_headers_exclude_paths(self, excluded_response_header_paths: Optional[list[str]] = None):
+ pass
def add_route(
self,
diff --git a/src/server.rs b/src/server.rs
index 7d7771986..757b3b5e9 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -56,6 +56,7 @@ pub struct Server {
directories: Arc>>,
startup_handler: Option>,
shutdown_handler: Option>,
+ excluded_response_headers_paths: Option>,
}
#[pymethods]
@@ -72,6 +73,7 @@ impl Server {
directories: Arc::new(RwLock::new(Vec::new())),
startup_handler: None,
shutdown_handler: None,
+ excluded_response_headers_paths: None,
}
}
@@ -108,6 +110,8 @@ impl Server {
let startup_handler = self.startup_handler.clone();
let shutdown_handler = self.shutdown_handler.clone();
+ let excluded_response_headers_paths = self.excluded_response_headers_paths.clone();
+
let task_locals = pyo3_asyncio::TaskLocals::new(event_loop).copy_context(py)?;
let task_locals_copy = task_locals.clone();
@@ -162,7 +166,8 @@ impl Server {
.app_data(web::Data::new(const_router.clone()))
.app_data(web::Data::new(middleware_router.clone()))
.app_data(web::Data::new(global_request_headers.clone()))
- .app_data(web::Data::new(global_response_headers.clone()));
+ .app_data(web::Data::new(global_response_headers.clone()))
+ .app_data(web::Data::new(excluded_response_headers_paths.clone()));
let web_socket_map = web_socket_router.get_web_socket_map();
for (elem, value) in (web_socket_map.read()).iter() {
@@ -194,6 +199,7 @@ impl Server {
payload: web::Payload,
global_request_headers,
global_response_headers,
+ response_headers_exclude_paths,
req| {
pyo3_asyncio::tokio::scope_local(task_locals.clone(), async move {
index(
@@ -203,6 +209,7 @@ impl Server {
middleware_router,
global_request_headers,
global_response_headers,
+ response_headers_exclude_paths,
req,
)
.await
@@ -298,6 +305,13 @@ impl Server {
self.global_response_headers = Arc::new(headers.clone());
}
+ pub fn set_response_headers_exclude_paths(
+ &mut self,
+ excluded_response_headers_paths: Option>,
+ ) {
+ self.excluded_response_headers_paths = excluded_response_headers_paths;
+ }
+
/// Add a new route to the routing tables
/// can be called after the server has been started
pub fn add_route(
@@ -409,6 +423,7 @@ impl Default for Server {
/// This is our service handler. It receives a Request, routes on it
/// path, and returns a Future of a Response.
+#[allow(clippy::too_many_arguments)]
async fn index(
router: web::Data>,
payload: web::Payload,
@@ -416,6 +431,7 @@ async fn index(
middleware_router: web::Data>,
global_request_headers: web::Data>,
global_response_headers: web::Data>,
+ excluded_response_headers_paths: web::Data