Skip to content

Commit

Permalink
Allow passing requests session for customisation
Browse files Browse the repository at this point in the history
  • Loading branch information
zacps committed Dec 4, 2023
1 parent c208d4f commit f681dbd
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 23 deletions.
3 changes: 2 additions & 1 deletion doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,14 @@ General Usage
First, create a new Zotero instance:


.. py:class:: Zotero(library_id, library_type[, api_key, preserve_json_order, locale])
.. py:class:: Zotero(library_id, library_type[, api_key, preserve_json_order, locale, session])
:param str library_id: a valid Zotero API user ID
:param str library_type: a valid Zotero API library type: **user** or **group**
:param str api_key: a valid Zotero API user key
:param bool preserve_json_order: Load JSON returns with OrderedDict to preserve their order
:param str locale: Set the `locale <https://www.zotero.org/support/dev/web_api/v3/types_and_fields#zotero_web_api_item_typefield_requests>`_, allowing retrieval of localised item types, field types, and creator types. Defaults to "en-US".
:param requests.Session session: a custom requests session


Example:
Expand Down
46 changes: 24 additions & 22 deletions src/pyzotero/zotero.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def __init__(
api_key=None,
preserve_json_order=False,
locale="en-US",
session=None,
):
"""Store Zotero credentials"""
self.endpoint = "https://api.zotero.org"
Expand All @@ -303,6 +304,7 @@ def __init__(
self.api_key = api_key
self.preserve_json_order = preserve_json_order
self.locale = locale
self.s = session or requests.Session()
self.url_params = None
self.tag_data = False
self.request = None
Expand Down Expand Up @@ -416,7 +418,7 @@ def _retrieve_data(self, request=None, params=None):
self.self_link = request
# ensure that we wait if there's an active backoff
self._check_backoff()
self.request = requests.get(
self.request = self.s.get(
url=full_url, headers=self.default_headers(), params=params
)
self.request.encoding = "utf-8"
Expand Down Expand Up @@ -487,7 +489,7 @@ def _updated(self, url, payload, template=None):
headers.update(self.default_headers())
# perform the request, and check whether the response returns 304
self._check_backoff()
req = requests.get(query, headers=headers)
req = self.s.get(query, headers=headers)
try:
req.raise_for_status()
except requests.exceptions.HTTPError as exc:
Expand Down Expand Up @@ -615,7 +617,7 @@ def set_fulltext(self, itemkey, payload):
"""
headers = self.default_headers()
headers.update({"Content-Type": "application/json"})
return requests.put(
return self.s.put(
url=build_url(
self.endpoint,
"/{t}/{u}/items/{k}/fulltext".format(
Expand All @@ -636,7 +638,7 @@ def new_fulltext(self, since):
)
headers = self.default_headers()
self._check_backoff()
resp = requests.get(build_url(self.endpoint, query_string), headers=headers)
resp = self.s.get(build_url(self.endpoint, query_string), headers=headers)
try:
resp.raise_for_status()
except requests.exceptions.HTTPError as exc:
Expand Down Expand Up @@ -1026,7 +1028,7 @@ def saved_search(self, name, conditions):
headers = {"Zotero-Write-Token": token()}
headers.update(self.default_headers())
self._check_backoff()
req = requests.post(
req = self.s.post(
url=build_url(
self.endpoint,
"/{t}/{u}/searches".format(t=self.library_type, u=self.library_id),
Expand Down Expand Up @@ -1054,7 +1056,7 @@ def delete_saved_search(self, keys):
headers = {"Zotero-Write-Token": token()}
headers.update(self.default_headers())
self._check_backoff()
req = requests.delete(
req = self.s.delete(
url=build_url(
self.endpoint,
"/{t}/{u}/searches".format(t=self.library_type, u=self.library_id),
Expand Down Expand Up @@ -1230,7 +1232,7 @@ def create_items(self, payload, parentid=None, last_modified=None):
to_send = json.dumps([i for i in self._cleanup(*payload, allow=("key"))])
headers.update(self.default_headers())
self._check_backoff()
req = requests.post(
req = self.s.post(
url=build_url(
self.endpoint,
"/{t}/{u}/items".format(t=self.library_type, u=self.library_id),
Expand Down Expand Up @@ -1260,7 +1262,7 @@ def create_items(self, payload, parentid=None, last_modified=None):
for value in resp["success"].values():
payload = json.dumps({"parentItem": parentid})
self._check_backoff()
presp = requests.patch(
presp = self.s.patch(
url=build_url(
self.endpoint,
"/{t}/{u}/items/{v}".format(
Expand Down Expand Up @@ -1306,7 +1308,7 @@ def create_collections(self, payload, last_modified=None):
headers["If-Unmodified-Since-Version"] = str(last_modified)
headers.update(self.default_headers())
self._check_backoff()
req = requests.post(
req = self.s.post(
url=build_url(
self.endpoint,
"/{t}/{u}/collections".format(t=self.library_type, u=self.library_id),
Expand Down Expand Up @@ -1338,7 +1340,7 @@ def update_collection(self, payload, last_modified=None):
headers = {"If-Unmodified-Since-Version": str(modified)}
headers.update(self.default_headers())
headers.update({"Content-Type": "application/json"})
return requests.put(
return self.s.put(
url=build_url(
self.endpoint,
"/{t}/{u}/collections/{c}".format(
Expand Down Expand Up @@ -1397,7 +1399,7 @@ def update_item(self, payload, last_modified=None):
ident = payload["key"]
headers = {"If-Unmodified-Since-Version": str(modified)}
headers.update(self.default_headers())
return requests.patch(
return self.s.patch(
url=build_url(
self.endpoint,
"/{t}/{u}/items/{id}".format(
Expand All @@ -1420,7 +1422,7 @@ def update_items(self, payload):
# anything longer
for chunk in chunks(to_send, 50):
self._check_backoff()
req = requests.post(
req = self.s.post(
url=build_url(
self.endpoint,
"/{t}/{u}/items/".format(t=self.library_type, u=self.library_id),
Expand Down Expand Up @@ -1450,7 +1452,7 @@ def update_collections(self, payload):
# anything longer
for chunk in chunks(to_send, 50):
self._check_backoff()
req = requests.post(
req = self.s.post(
url=build_url(
self.endpoint,
"/{t}/{u}/collections/".format(
Expand Down Expand Up @@ -1483,7 +1485,7 @@ def addto_collection(self, collection, payload):
modified_collections = payload["data"]["collections"] + [collection]
headers = {"If-Unmodified-Since-Version": str(modified)}
headers.update(self.default_headers())
return requests.patch(
return self.s.patch(
url=build_url(
self.endpoint,
"/{t}/{u}/items/{i}".format(
Expand All @@ -1509,7 +1511,7 @@ def deletefrom_collection(self, collection, payload):
]
headers = {"If-Unmodified-Since-Version": str(modified)}
headers.update(self.default_headers())
return requests.patch(
return self.s.patch(
url=build_url(
self.endpoint,
"/{t}/{u}/items/{i}".format(
Expand All @@ -1536,7 +1538,7 @@ def delete_tags(self, *payload):
"If-Unmodified-Since-Version": self.request.headers["last-modified-version"]
}
headers.update(self.default_headers())
return requests.delete(
return self.s.delete(
url=build_url(
self.endpoint,
"/{t}/{u}/tags".format(t=self.library_type, u=self.library_id),
Expand Down Expand Up @@ -1578,7 +1580,7 @@ def delete_item(self, payload, last_modified=None):
)
headers = {"If-Unmodified-Since-Version": str(modified)}
headers.update(self.default_headers())
return requests.delete(url=url, params=params, headers=headers)
return self.s.delete(url=url, params=params, headers=headers)

@backoff_check
def delete_collection(self, payload, last_modified=None):
Expand Down Expand Up @@ -1613,7 +1615,7 @@ def delete_collection(self, payload, last_modified=None):
)
headers = {"If-Unmodified-Since-Version": str(modified)}
headers.update(self.default_headers())
return requests.delete(url=url, params=params, headers=headers)
return self.s.delete(url=url, params=params, headers=headers)


def error_handler(zot, req, exc=None):
Expand Down Expand Up @@ -1898,7 +1900,7 @@ def _create_prelim(self):
child["parentItem"] = self.parentid
to_send = json.dumps(self.payload)
self.zinstance._check_backoff()
req = requests.post(
req = self.s.post(
url=build_url(
self.zinstance.endpoint,
liblevel.format(
Expand Down Expand Up @@ -1946,7 +1948,7 @@ def _get_auth(self, attachment, reg_key, md5=None):
"params": 1,
}
self.zinstance._check_backoff()
auth_req = requests.post(
auth_req = self.s.post(
url=build_url(
self.zinstance.endpoint,
"/{t}/{u}/items/{i}/file".format(
Expand Down Expand Up @@ -1983,7 +1985,7 @@ def _upload_file(self, authdata, attachment, reg_key):
upload_pairs = tuple(upload_list)
try:
self.zinstance._check_backoff()
upload = requests.post(
upload = self.s.post(
url=authdata["url"],
files=upload_pairs,
headers={"User-Agent": "Pyzotero/%s" % pz.__version__},
Expand Down Expand Up @@ -2011,7 +2013,7 @@ def _register_upload(self, authdata, reg_key):
reg_headers.update(self.zinstance.default_headers())
reg_data = {"upload": authdata.get("uploadKey")}
self.zinstance._check_backoff()
upload_reg = requests.post(
upload_reg = self.s.post(
url=build_url(
self.zinstance.endpoint,
"/{t}/{u}/items/{i}/file".format(
Expand Down

0 comments on commit f681dbd

Please sign in to comment.