Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: set_response_header breaking openapi doc rendering #977

Merged
merged 30 commits into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
08035e2
fix: set_response_header breaking openapi doc rendering
VishnuSanal Oct 7, 2024
3385767
chore: cargo fmt
VishnuSanal Oct 7, 2024
aaf8704
feat: add `exclude_response_headers` function
VishnuSanal Oct 12, 2024
e7ba780
update tests
VishnuSanal Oct 12, 2024
f0ce35c
chore: cargo fmt & clippy
VishnuSanal Oct 12, 2024
7bd04a9
fix typng
VishnuSanal Oct 12, 2024
7f963f0
Merge branch 'main' into fix-content-type-openapi
VishnuSanal Oct 13, 2024
ee4ba8d
cleanup after merge conlicts
VishnuSanal Oct 13, 2024
19f0fe7
docs: add docs
VishnuSanal Oct 13, 2024
87c34ee
update docs & docstrings
VishnuSanal Oct 13, 2024
96900a2
exclude headers for /openapi.json
VishnuSanal Oct 13, 2024
0672fd8
update tests
VishnuSanal Oct 13, 2024
72e7c05
move `exclude_paths` to Server struct
VishnuSanal Oct 17, 2024
884dcda
fix(ci): fix cargo clippy warning https://rust-lang.github.io/rust-cl…
VishnuSanal Oct 17, 2024
b4a3471
chore: cargo fmt
VishnuSanal Oct 18, 2024
a210b3b
Merge branch 'refs/heads/main' into fix-content-type-openapi
VishnuSanal Oct 18, 2024
d9635df
fix(ci):
VishnuSanal Oct 18, 2024
1e108f2
update docs
VishnuSanal Oct 19, 2024
e54cad7
rename
VishnuSanal Oct 19, 2024
c4a6334
rename
VishnuSanal Oct 19, 2024
f08e0ff
chore formatting
VishnuSanal Oct 19, 2024
aa4c74b
rename
VishnuSanal Oct 21, 2024
ba73fd7
Apply suggestions from code review
VishnuSanal Oct 21, 2024
d9440ad
Apply suggestions from code review
VishnuSanal Oct 22, 2024
10531fb
rename
VishnuSanal Oct 22, 2024
351e952
rever tuples
VishnuSanal Oct 22, 2024
3f460d1
rename
VishnuSanal Oct 24, 2024
42c4af5
Apply suggestions from code review
VishnuSanal Oct 25, 2024
1428458
fix errors caused during the rename
VishnuSanal Oct 25, 2024
c43ef44
Merge branch 'refs/heads/main' into fix-content-type-openapi
VishnuSanal Oct 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions docs_src/src/pages/documentation/api_reference/getting_started.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -558,12 +558,10 @@ Either, by using the `headers` field in the `Response` object:
### Global Response Headers

<Row>
<Col>
Or setting the Headers globally *per* router.
</Col>
<Col>

<Col>Or setting the Headers globally *per* router.</Col>

<Col>
<CodeGroup title="Request" tag="GET" label="/hello_world">

```python {{ title: 'untyped' }}
Expand All @@ -574,16 +572,14 @@ Or setting the Headers globally *per* router.
app.add_response_header("content-type", "application/json")
```
</CodeGroup>

</Col>
</Row>

<Row>
<Col>
`add_response_header` appends the header to the list of headers, while `set_response_header` replaces the header if it exists.
</Col>
<Col>


<CodeGroup title="Request" tag="GET" label="/hello_world">

```python {{ title: 'untyped' }}
Expand All @@ -594,7 +590,25 @@ Or setting the Headers globally *per* router.
app.set_response_header("content-type", "application/json")
```
</CodeGroup>

</Col>

<Col>
To prevent the headers from getting applied to certain endpoints, you can use the `exclude_response_headers_for` function.
</Col>
<Col>
<CodeGroup title="Request" tag="GET" label="/hello_world">

```python {{ title: 'untyped' }}
app.exclude_response_headers_for(["/login", "/signup"])
```

```python {{title: 'typed'}}
app.exclude_response_headers_for(["/login", "/signup"])
```
</CodeGroup>

</Col>
</Row>

### Cookies
Expand Down
41 changes: 34 additions & 7 deletions integration_tests/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,20 @@

@pytest.mark.benchmark
def test_docs_handler():
html_response = get("/docs")
# should_check_response = False because check_response raises a
# failure if the global headers are not present in the response
# provided we are excluding headers for /docs and /openapi.json
html_response = get("/docs", should_check_response=False)
assert html_response.status_code == 200


@pytest.mark.benchmark
def test_json_handler():
openapi_spec = get("/openapi.json").json()
openapi_response = get("/openapi.json", should_check_response=False)

assert openapi_response.status_code == 200

openapi_spec = openapi_response.json()

assert isinstance(openapi_spec, dict)
assert "openapi" in openapi_spec
Expand All @@ -24,7 +31,11 @@ def test_json_handler():

@pytest.mark.benchmark
def test_add_openapi_path():
openapi_spec = get("/openapi.json").json()
openapi_response = get("/openapi.json", should_check_response=False)

assert openapi_response.status_code == 200

openapi_spec = openapi_response.json()

assert isinstance(openapi_spec, dict)

Expand All @@ -41,7 +52,11 @@ def test_add_openapi_path():

@pytest.mark.benchmark
def test_add_subrouter_paths():
openapi_spec = get("/openapi.json").json()
openapi_response = get("/openapi.json", should_check_response=False)

assert openapi_response.status_code == 200

openapi_spec = openapi_response.json()

assert isinstance(openapi_spec, dict)

Expand All @@ -58,7 +73,11 @@ def test_add_subrouter_paths():

@pytest.mark.benchmark
def test_openapi_request_body():
openapi_spec = get("/openapi.json").json()
openapi_response = get("/openapi.json", should_check_response=False)

assert openapi_response.status_code == 200

openapi_spec = openapi_response.json()

assert isinstance(openapi_spec, dict)

Expand Down Expand Up @@ -118,7 +137,11 @@ def test_openapi_request_body():

@pytest.mark.benchmark
def test_openapi_response_body():
openapi_spec = get("/openapi.json").json()
openapi_response = get("/openapi.json", should_check_response=False)

assert openapi_response.status_code == 200

openapi_spec = openapi_response.json()

assert isinstance(openapi_spec, dict)

Expand Down Expand Up @@ -153,7 +176,11 @@ def test_openapi_response_body():

@pytest.mark.benchmark
def test_openapi_query_params():
openapi_spec = get("/openapi.json").json()
openapi_response = get("/openapi.json", should_check_response=False)

assert openapi_response.status_code == 200

openapi_spec = openapi_response.json()

assert isinstance(openapi_spec, dict)

Expand Down
10 changes: 10 additions & 0 deletions robyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
self.web_socket_router = WebSocketRouter()
self.request_headers: Headers = Headers({})
self.response_headers: Headers = Headers({})
self.excluded_response_headers_paths: Optional[List[str]] = None
self.directories: List[Directory] = []
self.event_handlers = {}
self.exception_handler: Optional[Callable] = None
Expand Down Expand Up @@ -200,6 +201,13 @@ def set_request_header(self, key: str, value: str) -> None:
def set_response_header(self, key: str, value: str) -> None:
self.response_headers.set(key, value)

def exclude_response_headers_for(self, excluded_response_header_paths: Optional[List[str]]):
"""
To exclude response headers from certain routes
@param exclude_paths: the paths to exclude response headers from
"""
self.excluded_response_header_paths = excluded_response_header_paths

def add_web_socket(self, endpoint: str, ws: WebSocket) -> None:
self.web_socket_router.add_route(endpoint, ws)

Expand Down Expand Up @@ -239,6 +247,7 @@ def _add_openapi_routes(self, auth_required: bool = False):
is_const=True,
auth_required=auth_required,
)
self.exclude_response_headers_for(["/docs", "/openapi.json"])

def start(self, host: str = "127.0.0.1", port: int = 8080, _check_port: bool = True):
"""
Expand Down Expand Up @@ -284,6 +293,7 @@ def start(self, host: str = "127.0.0.1", port: int = 8080, _check_port: bool = T
self.config.workers,
self.config.processes,
self.response_headers,
self.excluded_response_header_paths,
open_browser,
)

Expand Down
10 changes: 9 additions & 1 deletion robyn/processpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import signal
import sys
import webbrowser
from typing import Dict, List
from typing import Dict, List, Optional

from multiprocess import Process

Expand All @@ -27,6 +27,7 @@ def run_processes(
workers: int,
processes: int,
response_headers: Headers,
excluded_response_headers_paths: Optional[List[str]],
open_browser: bool,
) -> List[Process]:
socket = SocketHeld(url, port)
Expand All @@ -43,6 +44,7 @@ def run_processes(
workers,
processes,
response_headers,
excluded_response_headers_paths,
)

def terminating_signal_handler(_sig, _frame):
Expand Down Expand Up @@ -76,6 +78,7 @@ def init_processpool(
workers: int,
processes: int,
response_headers: Headers,
excluded_response_headers_paths: Optional[List[str]],
) -> List[Process]:
process_pool = []
if sys.platform.startswith("win32") or processes == 1:
Expand All @@ -90,6 +93,7 @@ def init_processpool(
socket,
workers,
response_headers,
excluded_response_headers_paths,
)

return process_pool
Expand All @@ -109,6 +113,7 @@ def init_processpool(
copied_socket,
workers,
response_headers,
excluded_response_headers_paths,
),
)
process.start()
Expand Down Expand Up @@ -144,6 +149,7 @@ def spawn_process(
socket: SocketHeld,
workers: int,
response_headers: Headers,
excluded_response_headers_paths: Optional[List[str]],
):
"""
This function is called by the main process handler to create a server runtime.
Expand Down Expand Up @@ -173,6 +179,8 @@ def spawn_process(

server.apply_response_headers(response_headers)

server.set_response_headers_exclude_paths(excluded_response_headers_paths)

for route in routes:
route_type, endpoint, function, is_const = route
server.add_route(route_type, endpoint, function, is_const)
Expand Down
2 changes: 2 additions & 0 deletions robyn/robyn.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ class Server:
pass
def apply_response_headers(self, headers: Headers) -> None:
pass
def set_response_headers_exclude_paths(self, excluded_response_header_paths: Optional[list[str]] = None):
pass

def add_route(
self,
Expand Down
27 changes: 26 additions & 1 deletion src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ pub struct Server {
directories: Arc<RwLock<Vec<Directory>>>,
startup_handler: Option<Arc<FunctionInfo>>,
shutdown_handler: Option<Arc<FunctionInfo>>,
excluded_response_headers_paths: Option<Vec<String>>,
}

#[pymethods]
Expand All @@ -72,6 +73,7 @@ impl Server {
directories: Arc::new(RwLock::new(Vec::new())),
startup_handler: None,
shutdown_handler: None,
excluded_response_headers_paths: None,
}
}

Expand Down Expand Up @@ -108,6 +110,8 @@ impl Server {
let startup_handler = self.startup_handler.clone();
let shutdown_handler = self.shutdown_handler.clone();

let excluded_response_headers_paths = self.excluded_response_headers_paths.clone();

let task_locals = pyo3_asyncio::TaskLocals::new(event_loop).copy_context(py)?;
let task_locals_copy = task_locals.clone();

Expand Down Expand Up @@ -162,7 +166,8 @@ impl Server {
.app_data(web::Data::new(const_router.clone()))
.app_data(web::Data::new(middleware_router.clone()))
.app_data(web::Data::new(global_request_headers.clone()))
.app_data(web::Data::new(global_response_headers.clone()));
.app_data(web::Data::new(global_response_headers.clone()))
.app_data(web::Data::new(excluded_response_headers_paths.clone()));

let web_socket_map = web_socket_router.get_web_socket_map();
for (elem, value) in (web_socket_map.read()).iter() {
Expand Down Expand Up @@ -194,6 +199,7 @@ impl Server {
payload: web::Payload,
global_request_headers,
global_response_headers,
response_headers_exclude_paths,
req| {
pyo3_asyncio::tokio::scope_local(task_locals.clone(), async move {
index(
Expand All @@ -203,6 +209,7 @@ impl Server {
middleware_router,
global_request_headers,
global_response_headers,
response_headers_exclude_paths,
req,
)
.await
Expand Down Expand Up @@ -298,6 +305,13 @@ impl Server {
self.global_response_headers = Arc::new(headers.clone());
}

pub fn set_response_headers_exclude_paths(
&mut self,
excluded_response_headers_paths: Option<Vec<String>>,
) {
self.excluded_response_headers_paths = excluded_response_headers_paths;
}

/// Add a new route to the routing tables
/// can be called after the server has been started
pub fn add_route(
Expand Down Expand Up @@ -409,13 +423,15 @@ impl Default for Server {

/// This is our service handler. It receives a Request, routes on it
/// path, and returns a Future of a Response.
#[allow(clippy::too_many_arguments)]
async fn index(
router: web::Data<Arc<HttpRouter>>,
payload: web::Payload,
const_router: web::Data<Arc<ConstRouter>>,
middleware_router: web::Data<Arc<MiddlewareRouter>>,
global_request_headers: web::Data<Arc<Headers>>,
global_response_headers: web::Data<Arc<Headers>>,
excluded_response_headers_paths: web::Data<Option<Vec<String>>>,
req: HttpRequest,
) -> impl Responder {
let mut request = Request::from_actix_request(&req, payload, &global_request_headers).await;
Expand Down Expand Up @@ -480,6 +496,15 @@ async fn index(

response.headers.extend(&global_response_headers);

match &excluded_response_headers_paths.get_ref() {
None => {}
Some(excluded_response_headers_paths) => {
if excluded_response_headers_paths.contains(&req.uri().path().to_owned()) {
response.headers.clear();
}
}
}

debug!("Extended Response : {:?}", response);

// After middleware
Expand Down
4 changes: 4 additions & 0 deletions src/types/headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ impl Headers {
self.headers.remove(&key.to_lowercase());
}

pub fn clear(&mut self) {
self.headers.clear();
}

pub fn extend(&mut self, headers: &Headers) {
for iter in headers.headers.iter() {
let (key, values) = iter.pair();
Expand Down
Loading