Skip to content

Commit

Permalink
Fix HttpAuthenticationType to add callable (#133)
Browse files Browse the repository at this point in the history
Add a callable object to the `HttpAuthenticationType` type definition.

A type checking error occurs when a callable object is given as the
`auth` argument.


![image](https://github.com/jawah/niquests/assets/51289448/a460196f-1a9f-4831-a56d-1a992e85da30)

<details>
<summary>Reproduction code</summary>

```python
import niquests


def pizza_auth(request: niquests.PreparedRequest) -> niquests.PreparedRequest:
    if request.headers:
        request.headers["X-Pizza"] = "Token"
    return request


def test_callable_auth():
    r = niquests.get("https://httpbin.org/get", auth=pizza_auth)
    print(r.json()["headers"])


if __name__ == "__main__":
    test_callable_auth()
```

</details>

This is allowed at runtime.


https://github.com/jawah/niquests/blob/d83ab6b98e317bbf82ea950a693fce1fc95936a3/src/niquests/models.py#L615-L647

---------

Co-authored-by: Ahmed TAHRI <[email protected]>
  • Loading branch information
MtkN1 and Ousret authored Jul 2, 2024
1 parent d83ab6b commit 7cd3bcf
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
6 changes: 6 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
Release History
===============

3.7.1 (2024-06-??)
------------------

**Fixed**
- auth argument not accepting a function according to static type checkers. (#133)

3.7.0 (2024-06-24)
------------------

Expand Down
8 changes: 5 additions & 3 deletions src/niquests/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
from .auth import AuthBase
from .structures import CaseInsensitiveDict

if typing.TYPE_CHECKING:
from .models import PreparedRequest

#: (Restricted) list of http verb that we natively support and understand.
HttpMethodType: typing.TypeAlias = (
str # todo: have typing.Literal when ready to drop Python 3.7
)
HttpMethodType: typing.TypeAlias = str
#: List of formats accepted for URL queries parameters. (e.g. /?param1=a&param2=b)
QueryParameterType: typing.TypeAlias = typing.Union[
typing.List[typing.Tuple[str, typing.Union[str, typing.List[str]]]],
Expand Down Expand Up @@ -89,6 +90,7 @@
typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
str,
AuthBase,
typing.Callable[["PreparedRequest"], "PreparedRequest"],
]
#: Map for each protocol (http, https) associated proxy to be used.
ProxyType: typing.TypeAlias = typing.Dict[str, str]
Expand Down
8 changes: 7 additions & 1 deletion src/niquests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ def __init__(self) -> None:
self.ocsp_verified: bool | None = None
#: upload progress if any.
self.upload_progress: TransferProgress | None = None
#: internal usage only. warn us that we should re-compute content-length and await auth() outside of PreparedRequest.
self._asynchronous_auth: bool = False

@property
def oheaders(self) -> Headers:
Expand Down Expand Up @@ -636,7 +638,11 @@ def prepare_auth(self, auth: HttpAuthenticationType | None, url: str = "") -> No
"Unexpected non-callable authentication. Did you pass unsupported tuple to auth argument?"
)

if not asyncio.iscoroutinefunction(auth.__call__):
self._asynchronous_auth = hasattr(
auth, "__call__"
) and asyncio.iscoroutinefunction(auth.__call__)

if not self._asynchronous_auth:
# Allow auth to make its changes.
r = auth(self)

Expand Down

0 comments on commit 7cd3bcf

Please sign in to comment.