From 1385dd547828b9c743106e5173e8f4088654df43 Mon Sep 17 00:00:00 2001 From: Zac Pullar-Strecker Date: Mon, 4 Dec 2023 13:00:10 +1300 Subject: [PATCH] Allow passing requests session for customisation --- doc/index.rst | 3 ++- src/pyzotero/zotero.py | 46 ++++++++++++++++++++++-------------------- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/doc/index.rst b/doc/index.rst index ba41f46..9bc613e 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -105,13 +105,14 @@ General Usage First, create a new Zotero instance: - .. py:class:: Zotero(library_id, library_type[, api_key, preserve_json_order, locale]) + .. py:class:: Zotero(library_id, library_type[, api_key, preserve_json_order, locale, session]) :param str library_id: a valid Zotero API user ID :param str library_type: a valid Zotero API library type: **user** or **group** :param str api_key: a valid Zotero API user key :param bool preserve_json_order: Load JSON returns with OrderedDict to preserve their order :param str locale: Set the `locale `_, allowing retrieval of localised item types, field types, and creator types. Defaults to "en-US". + :param requests.Session session: a custom requests session, for example to use `requests-cache `_ Example: diff --git a/src/pyzotero/zotero.py b/src/pyzotero/zotero.py index fae50e0..4fa9b72 100644 --- a/src/pyzotero/zotero.py +++ b/src/pyzotero/zotero.py @@ -288,6 +288,7 @@ def __init__( api_key=None, preserve_json_order=False, locale="en-US", + session=None, ): """Store Zotero credentials""" self.endpoint = "https://api.zotero.org" @@ -303,6 +304,7 @@ def __init__( self.api_key = api_key self.preserve_json_order = preserve_json_order self.locale = locale + self.s = session or requests.Session() self.url_params = None self.tag_data = False self.request = None @@ -416,7 +418,7 @@ def _retrieve_data(self, request=None, params=None): self.self_link = request # ensure that we wait if there's an active backoff self._check_backoff() - self.request = requests.get( + self.request = self.s.get( url=full_url, headers=self.default_headers(), params=params ) self.request.encoding = "utf-8" @@ -487,7 +489,7 @@ def _updated(self, url, payload, template=None): headers.update(self.default_headers()) # perform the request, and check whether the response returns 304 self._check_backoff() - req = requests.get(query, headers=headers) + req = self.s.get(query, headers=headers) try: req.raise_for_status() except requests.exceptions.HTTPError as exc: @@ -615,7 +617,7 @@ def set_fulltext(self, itemkey, payload): """ headers = self.default_headers() headers.update({"Content-Type": "application/json"}) - return requests.put( + return self.s.put( url=build_url( self.endpoint, "/{t}/{u}/items/{k}/fulltext".format( @@ -636,7 +638,7 @@ def new_fulltext(self, since): ) headers = self.default_headers() self._check_backoff() - resp = requests.get(build_url(self.endpoint, query_string), headers=headers) + resp = self.s.get(build_url(self.endpoint, query_string), headers=headers) try: resp.raise_for_status() except requests.exceptions.HTTPError as exc: @@ -1026,7 +1028,7 @@ def saved_search(self, name, conditions): headers = {"Zotero-Write-Token": token()} headers.update(self.default_headers()) self._check_backoff() - req = requests.post( + req = self.s.post( url=build_url( self.endpoint, "/{t}/{u}/searches".format(t=self.library_type, u=self.library_id), @@ -1054,7 +1056,7 @@ def delete_saved_search(self, keys): headers = {"Zotero-Write-Token": token()} headers.update(self.default_headers()) self._check_backoff() - req = requests.delete( + req = self.s.delete( url=build_url( self.endpoint, "/{t}/{u}/searches".format(t=self.library_type, u=self.library_id), @@ -1230,7 +1232,7 @@ def create_items(self, payload, parentid=None, last_modified=None): to_send = json.dumps([i for i in self._cleanup(*payload, allow=("key"))]) headers.update(self.default_headers()) self._check_backoff() - req = requests.post( + req = self.s.post( url=build_url( self.endpoint, "/{t}/{u}/items".format(t=self.library_type, u=self.library_id), @@ -1260,7 +1262,7 @@ def create_items(self, payload, parentid=None, last_modified=None): for value in resp["success"].values(): payload = json.dumps({"parentItem": parentid}) self._check_backoff() - presp = requests.patch( + presp = self.s.patch( url=build_url( self.endpoint, "/{t}/{u}/items/{v}".format( @@ -1306,7 +1308,7 @@ def create_collections(self, payload, last_modified=None): headers["If-Unmodified-Since-Version"] = str(last_modified) headers.update(self.default_headers()) self._check_backoff() - req = requests.post( + req = self.s.post( url=build_url( self.endpoint, "/{t}/{u}/collections".format(t=self.library_type, u=self.library_id), @@ -1338,7 +1340,7 @@ def update_collection(self, payload, last_modified=None): headers = {"If-Unmodified-Since-Version": str(modified)} headers.update(self.default_headers()) headers.update({"Content-Type": "application/json"}) - return requests.put( + return self.s.put( url=build_url( self.endpoint, "/{t}/{u}/collections/{c}".format( @@ -1397,7 +1399,7 @@ def update_item(self, payload, last_modified=None): ident = payload["key"] headers = {"If-Unmodified-Since-Version": str(modified)} headers.update(self.default_headers()) - return requests.patch( + return self.s.patch( url=build_url( self.endpoint, "/{t}/{u}/items/{id}".format( @@ -1420,7 +1422,7 @@ def update_items(self, payload): # anything longer for chunk in chunks(to_send, 50): self._check_backoff() - req = requests.post( + req = self.s.post( url=build_url( self.endpoint, "/{t}/{u}/items/".format(t=self.library_type, u=self.library_id), @@ -1450,7 +1452,7 @@ def update_collections(self, payload): # anything longer for chunk in chunks(to_send, 50): self._check_backoff() - req = requests.post( + req = self.s.post( url=build_url( self.endpoint, "/{t}/{u}/collections/".format( @@ -1483,7 +1485,7 @@ def addto_collection(self, collection, payload): modified_collections = payload["data"]["collections"] + [collection] headers = {"If-Unmodified-Since-Version": str(modified)} headers.update(self.default_headers()) - return requests.patch( + return self.s.patch( url=build_url( self.endpoint, "/{t}/{u}/items/{i}".format( @@ -1509,7 +1511,7 @@ def deletefrom_collection(self, collection, payload): ] headers = {"If-Unmodified-Since-Version": str(modified)} headers.update(self.default_headers()) - return requests.patch( + return self.s.patch( url=build_url( self.endpoint, "/{t}/{u}/items/{i}".format( @@ -1536,7 +1538,7 @@ def delete_tags(self, *payload): "If-Unmodified-Since-Version": self.request.headers["last-modified-version"] } headers.update(self.default_headers()) - return requests.delete( + return self.s.delete( url=build_url( self.endpoint, "/{t}/{u}/tags".format(t=self.library_type, u=self.library_id), @@ -1578,7 +1580,7 @@ def delete_item(self, payload, last_modified=None): ) headers = {"If-Unmodified-Since-Version": str(modified)} headers.update(self.default_headers()) - return requests.delete(url=url, params=params, headers=headers) + return self.s.delete(url=url, params=params, headers=headers) @backoff_check def delete_collection(self, payload, last_modified=None): @@ -1613,7 +1615,7 @@ def delete_collection(self, payload, last_modified=None): ) headers = {"If-Unmodified-Since-Version": str(modified)} headers.update(self.default_headers()) - return requests.delete(url=url, params=params, headers=headers) + return self.s.delete(url=url, params=params, headers=headers) def error_handler(zot, req, exc=None): @@ -1898,7 +1900,7 @@ def _create_prelim(self): child["parentItem"] = self.parentid to_send = json.dumps(self.payload) self.zinstance._check_backoff() - req = requests.post( + req = self.s.post( url=build_url( self.zinstance.endpoint, liblevel.format( @@ -1946,7 +1948,7 @@ def _get_auth(self, attachment, reg_key, md5=None): "params": 1, } self.zinstance._check_backoff() - auth_req = requests.post( + auth_req = self.s.post( url=build_url( self.zinstance.endpoint, "/{t}/{u}/items/{i}/file".format( @@ -1983,7 +1985,7 @@ def _upload_file(self, authdata, attachment, reg_key): upload_pairs = tuple(upload_list) try: self.zinstance._check_backoff() - upload = requests.post( + upload = self.s.post( url=authdata["url"], files=upload_pairs, headers={"User-Agent": "Pyzotero/%s" % pz.__version__}, @@ -2011,7 +2013,7 @@ def _register_upload(self, authdata, reg_key): reg_headers.update(self.zinstance.default_headers()) reg_data = {"upload": authdata.get("uploadKey")} self.zinstance._check_backoff() - upload_reg = requests.post( + upload_reg = self.s.post( url=build_url( self.zinstance.endpoint, "/{t}/{u}/items/{i}/file".format(