Skip to content

Commit

Permalink
fix: set_response_header breaking openapi doc rendering (#977)
Browse files Browse the repository at this point in the history
* fix: set_response_header breaking openapi doc rendering

* docs: add docs

* update tests

* fix errors caused during the rename

---------

Co-authored-by: Sanskar Jethi <[email protected]>
  • Loading branch information
VishnuSanal and sansyrox authored Oct 27, 2024
1 parent 40f7f0a commit deea29a
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 16 deletions.
28 changes: 21 additions & 7 deletions docs_src/src/pages/documentation/api_reference/getting_started.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -558,12 +558,10 @@ Either, by using the `headers` field in the `Response` object:
### Global Response Headers

<Row>
<Col>
Or setting the Headers globally *per* router.
</Col>
<Col>

<Col>Or setting the Headers globally *per* router.</Col>

<Col>
<CodeGroup title="Request" tag="GET" label="/hello_world">

```python {{ title: 'untyped' }}
Expand All @@ -574,16 +572,14 @@ Or setting the Headers globally *per* router.
app.add_response_header("content-type", "application/json")
```
</CodeGroup>

</Col>
</Row>

<Row>
<Col>
`add_response_header` appends the header to the list of headers, while `set_response_header` replaces the header if it exists.
</Col>
<Col>


<CodeGroup title="Request" tag="GET" label="/hello_world">

```python {{ title: 'untyped' }}
Expand All @@ -594,7 +590,25 @@ Or setting the Headers globally *per* router.
app.set_response_header("content-type", "application/json")
```
</CodeGroup>

</Col>

<Col>
To prevent the headers from getting applied to certain endpoints, you can use the `exclude_response_headers_for` function.
</Col>
<Col>
<CodeGroup title="Request" tag="GET" label="/hello_world">

```python {{ title: 'untyped' }}
app.exclude_response_headers_for(["/login", "/signup"])
```

```python {{title: 'typed'}}
app.exclude_response_headers_for(["/login", "/signup"])
```
</CodeGroup>

</Col>
</Row>

### Cookies
Expand Down
41 changes: 34 additions & 7 deletions integration_tests/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,20 @@

@pytest.mark.benchmark
def test_docs_handler():
html_response = get("/docs")
# should_check_response = False because check_response raises a
# failure if the global headers are not present in the response
# provided we are excluding headers for /docs and /openapi.json
html_response = get("/docs", should_check_response=False)
assert html_response.status_code == 200


@pytest.mark.benchmark
def test_json_handler():
openapi_spec = get("/openapi.json").json()
openapi_response = get("/openapi.json", should_check_response=False)

assert openapi_response.status_code == 200

openapi_spec = openapi_response.json()

assert isinstance(openapi_spec, dict)
assert "openapi" in openapi_spec
Expand All @@ -24,7 +31,11 @@ def test_json_handler():

@pytest.mark.benchmark
def test_add_openapi_path():
openapi_spec = get("/openapi.json").json()
openapi_response = get("/openapi.json", should_check_response=False)

assert openapi_response.status_code == 200

openapi_spec = openapi_response.json()

assert isinstance(openapi_spec, dict)

Expand All @@ -41,7 +52,11 @@ def test_add_openapi_path():

@pytest.mark.benchmark
def test_add_subrouter_paths():
openapi_spec = get("/openapi.json").json()
openapi_response = get("/openapi.json", should_check_response=False)

assert openapi_response.status_code == 200

openapi_spec = openapi_response.json()

assert isinstance(openapi_spec, dict)

Expand All @@ -58,7 +73,11 @@ def test_add_subrouter_paths():

@pytest.mark.benchmark
def test_openapi_request_body():
openapi_spec = get("/openapi.json").json()
openapi_response = get("/openapi.json", should_check_response=False)

assert openapi_response.status_code == 200

openapi_spec = openapi_response.json()

assert isinstance(openapi_spec, dict)

Expand Down Expand Up @@ -118,7 +137,11 @@ def test_openapi_request_body():

@pytest.mark.benchmark
def test_openapi_response_body():
openapi_spec = get("/openapi.json").json()
openapi_response = get("/openapi.json", should_check_response=False)

assert openapi_response.status_code == 200

openapi_spec = openapi_response.json()

assert isinstance(openapi_spec, dict)

Expand Down Expand Up @@ -153,7 +176,11 @@ def test_openapi_response_body():

@pytest.mark.benchmark
def test_openapi_query_params():
openapi_spec = get("/openapi.json").json()
openapi_response = get("/openapi.json", should_check_response=False)

assert openapi_response.status_code == 200

openapi_spec = openapi_response.json()

assert isinstance(openapi_spec, dict)

Expand Down
10 changes: 10 additions & 0 deletions robyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
self.web_socket_router = WebSocketRouter()
self.request_headers: Headers = Headers({})
self.response_headers: Headers = Headers({})
self.excluded_response_headers_paths: Optional[List[str]] = None
self.directories: List[Directory] = []
self.event_handlers = {}
self.exception_handler: Optional[Callable] = None
Expand Down Expand Up @@ -200,6 +201,13 @@ def set_request_header(self, key: str, value: str) -> None:
def set_response_header(self, key: str, value: str) -> None:
self.response_headers.set(key, value)

def exclude_response_headers_for(self, excluded_response_header_paths: Optional[List[str]]):
"""
To exclude response headers from certain routes
@param exclude_paths: the paths to exclude response headers from
"""
self.excluded_response_header_paths = excluded_response_header_paths

def add_web_socket(self, endpoint: str, ws: WebSocket) -> None:
self.web_socket_router.add_route(endpoint, ws)

Expand Down Expand Up @@ -239,6 +247,7 @@ def _add_openapi_routes(self, auth_required: bool = False):
is_const=True,
auth_required=auth_required,
)
self.exclude_response_headers_for(["/docs", "/openapi.json"])

def start(self, host: str = "127.0.0.1", port: int = 8080, _check_port: bool = True):
"""
Expand Down Expand Up @@ -284,6 +293,7 @@ def start(self, host: str = "127.0.0.1", port: int = 8080, _check_port: bool = T
self.config.workers,
self.config.processes,
self.response_headers,
self.excluded_response_header_paths,
open_browser,
)

Expand Down
10 changes: 9 additions & 1 deletion robyn/processpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import signal
import sys
import webbrowser
from typing import Dict, List
from typing import Dict, List, Optional

from multiprocess import Process

Expand All @@ -27,6 +27,7 @@ def run_processes(
workers: int,
processes: int,
response_headers: Headers,
excluded_response_headers_paths: Optional[List[str]],
open_browser: bool,
) -> List[Process]:
socket = SocketHeld(url, port)
Expand All @@ -43,6 +44,7 @@ def run_processes(
workers,
processes,
response_headers,
excluded_response_headers_paths,
)

def terminating_signal_handler(_sig, _frame):
Expand Down Expand Up @@ -76,6 +78,7 @@ def init_processpool(
workers: int,
processes: int,
response_headers: Headers,
excluded_response_headers_paths: Optional[List[str]],
) -> List[Process]:
process_pool = []
if sys.platform.startswith("win32") or processes == 1:
Expand All @@ -90,6 +93,7 @@ def init_processpool(
socket,
workers,
response_headers,
excluded_response_headers_paths,
)

return process_pool
Expand All @@ -109,6 +113,7 @@ def init_processpool(
copied_socket,
workers,
response_headers,
excluded_response_headers_paths,
),
)
process.start()
Expand Down Expand Up @@ -144,6 +149,7 @@ def spawn_process(
socket: SocketHeld,
workers: int,
response_headers: Headers,
excluded_response_headers_paths: Optional[List[str]],
):
"""
This function is called by the main process handler to create a server runtime.
Expand Down Expand Up @@ -173,6 +179,8 @@ def spawn_process(

server.apply_response_headers(response_headers)

server.set_response_headers_exclude_paths(excluded_response_headers_paths)

for route in routes:
route_type, endpoint, function, is_const = route
server.add_route(route_type, endpoint, function, is_const)
Expand Down
2 changes: 2 additions & 0 deletions robyn/robyn.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ class Server:
pass
def apply_response_headers(self, headers: Headers) -> None:
pass
def set_response_headers_exclude_paths(self, excluded_response_header_paths: Optional[list[str]] = None):
pass

def add_route(
self,
Expand Down
27 changes: 26 additions & 1 deletion src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ pub struct Server {
directories: Arc<RwLock<Vec<Directory>>>,
startup_handler: Option<Arc<FunctionInfo>>,
shutdown_handler: Option<Arc<FunctionInfo>>,
excluded_response_headers_paths: Option<Vec<String>>,
}

#[pymethods]
Expand All @@ -72,6 +73,7 @@ impl Server {
directories: Arc::new(RwLock::new(Vec::new())),
startup_handler: None,
shutdown_handler: None,
excluded_response_headers_paths: None,
}
}

Expand Down Expand Up @@ -108,6 +110,8 @@ impl Server {
let startup_handler = self.startup_handler.clone();
let shutdown_handler = self.shutdown_handler.clone();

let excluded_response_headers_paths = self.excluded_response_headers_paths.clone();

let task_locals = pyo3_asyncio::TaskLocals::new(event_loop).copy_context(py)?;
let task_locals_copy = task_locals.clone();

Expand Down Expand Up @@ -162,7 +166,8 @@ impl Server {
.app_data(web::Data::new(const_router.clone()))
.app_data(web::Data::new(middleware_router.clone()))
.app_data(web::Data::new(global_request_headers.clone()))
.app_data(web::Data::new(global_response_headers.clone()));
.app_data(web::Data::new(global_response_headers.clone()))
.app_data(web::Data::new(excluded_response_headers_paths.clone()));

let web_socket_map = web_socket_router.get_web_socket_map();
for (elem, value) in (web_socket_map.read()).iter() {
Expand Down Expand Up @@ -194,6 +199,7 @@ impl Server {
payload: web::Payload,
global_request_headers,
global_response_headers,
response_headers_exclude_paths,
req| {
pyo3_asyncio::tokio::scope_local(task_locals.clone(), async move {
index(
Expand All @@ -203,6 +209,7 @@ impl Server {
middleware_router,
global_request_headers,
global_response_headers,
response_headers_exclude_paths,
req,
)
.await
Expand Down Expand Up @@ -298,6 +305,13 @@ impl Server {
self.global_response_headers = Arc::new(headers.clone());
}

pub fn set_response_headers_exclude_paths(
&mut self,
excluded_response_headers_paths: Option<Vec<String>>,
) {
self.excluded_response_headers_paths = excluded_response_headers_paths;
}

/// Add a new route to the routing tables
/// can be called after the server has been started
pub fn add_route(
Expand Down Expand Up @@ -409,13 +423,15 @@ impl Default for Server {

/// This is our service handler. It receives a Request, routes on it
/// path, and returns a Future of a Response.
#[allow(clippy::too_many_arguments)]
async fn index(
router: web::Data<Arc<HttpRouter>>,
payload: web::Payload,
const_router: web::Data<Arc<ConstRouter>>,
middleware_router: web::Data<Arc<MiddlewareRouter>>,
global_request_headers: web::Data<Arc<Headers>>,
global_response_headers: web::Data<Arc<Headers>>,
excluded_response_headers_paths: web::Data<Option<Vec<String>>>,
req: HttpRequest,
) -> impl Responder {
let mut request = Request::from_actix_request(&req, payload, &global_request_headers).await;
Expand Down Expand Up @@ -480,6 +496,15 @@ async fn index(

response.headers.extend(&global_response_headers);

match &excluded_response_headers_paths.get_ref() {
None => {}
Some(excluded_response_headers_paths) => {
if excluded_response_headers_paths.contains(&req.uri().path().to_owned()) {
response.headers.clear();
}
}
}

debug!("Extended Response : {:?}", response);

// After middleware
Expand Down
4 changes: 4 additions & 0 deletions src/types/headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ impl Headers {
self.headers.remove(&key.to_lowercase());
}

pub fn clear(&mut self) {
self.headers.clear();
}

pub fn extend(&mut self, headers: &Headers) {
for iter in headers.headers.iter() {
let (key, values) = iter.pair();
Expand Down

0 comments on commit deea29a

Please sign in to comment.