diff --git a/README.md b/README.md index d2cabf1..597cd36 100644 --- a/README.md +++ b/README.md @@ -575,7 +575,7 @@ for normal operation of PyXero. It's only required for testing purposes. Once you've installed these dependencies, you can run the test suite by running the following from the root directory of the project: - $ python setup.py test + $ tox -e py If you find any problems with PyXero, you can log them on [Github Issues](https://github.com/freakboy3742/pyxero/issues). When reporting problems, it's extremely helpful if you can provide diff --git a/src/xero/filesmanager.py b/src/xero/filesmanager.py index 1f7b8b3..c32f777 100644 --- a/src/xero/filesmanager.py +++ b/src/xero/filesmanager.py @@ -2,6 +2,8 @@ import requests from urllib.parse import parse_qs +from xero.auth import OAuth2Credentials + from .constants import XERO_FILES_URL from .exceptions import ( XeroBadRequest, @@ -12,6 +14,7 @@ XeroNotFound, XeroNotImplemented, XeroRateLimitExceeded, + XeroTenantIdNotSet, XeroUnauthorized, XeroUnsupportedMediaType, ) @@ -67,6 +70,14 @@ def wrapper(*args, **kwargs): uri, params, method, body, headers, singleobject, files = func( *args, **kwargs ) + if headers is None: + headers = {} + + if isinstance(self.credentials, OAuth2Credentials): + if self.credentials.tenant_id: + headers["Xero-tenant-id"] = self.credentials.tenant_id + else: + raise XeroTenantIdNotSet response = getattr(requests, method)( uri, @@ -169,15 +180,19 @@ def _delete(self, id): uri = "/".join([self.base_url, self.name, id]) return uri, {}, "delete", None, None, False, None - def _upload_file(self, path, folderId=None): + def _upload_file(self, path=None, folderId=None, filename=None, file=None): if folderId is not None: uri = "/".join([self.base_url, self.name, folderId]) else: uri = "/".join([self.base_url, self.name]) - filename = self.filename(path) files = dict() - files[filename] = open(path, mode="rb") + if path: + filename = os.path.basename(path) + files[filename] = open(path, mode="rb") + + elif filename and file: + files[filename] = file return uri, {}, "post", None, None, False, files @@ -193,7 +208,3 @@ def _make_association(self, id, data): def _all(self): uri = "/".join([self.base_url, self.name]) return uri, {}, "get", None, None, False, None - - def filename(self, path): - head, tail = os.path.split(path) - return tail or os.path.basename(head) diff --git a/tests/test_filesmanager.py b/tests/test_filesmanager.py new file mode 100644 index 0000000..616f3ef --- /dev/null +++ b/tests/test_filesmanager.py @@ -0,0 +1,73 @@ +import os.path +import unittest +from time import time +from unittest.mock import Mock, patch + +from xero import Xero +from xero.auth import OAuth2Credentials + + +class FilesManagerTest(unittest.TestCase): + def setUp(self): + super().setUp() + # Create an expired token to be used by tests + self.expired_token = { + "access_token": "1234567890", + "expires_in": 1800, + "token_type": "Bearer", + "refresh_token": "0987654321", + "expires_at": time(), + } + + self.filepath = "test_file.txt" + with open(self.filepath, "w") as f: + f.write("test") + + def tearDown(self): + os.remove(self.filepath) + + @patch("requests.get") + def test_tenant_is_used_in_xero_request(self, r_get): + credentials = OAuth2Credentials( + "client_id", "client_secret", token=self.expired_token, tenant_id="12345" + ) + xero = Xero(credentials) + r_get.return_value = Mock( + status_code=200, + headers={"content-type": "text/html; charset=utf-8"}, + ) + xero.filesAPI.files.all() + + self.assertEqual(r_get.call_args[1]["headers"]["Xero-tenant-id"], "12345") + + @patch("requests.post") + def test_upload_file_as_path(self, r_get): + credentials = OAuth2Credentials( + "client_id", "client_secret", token=self.expired_token, tenant_id="12345" + ) + xero = Xero(credentials) + r_get.return_value = Mock( + status_code=200, + headers={"content-type": "text/html; charset=utf-8"}, + ) + xero.filesAPI.files.upload_file(path=self.filepath) + + self.assertIn(self.filepath, r_get.call_args[1]["files"]) + + @patch("requests.post") + def test_upload_file_as_file(self, r_get): + credentials = OAuth2Credentials( + "client_id", "client_secret", token=self.expired_token, tenant_id="12345" + ) + xero = Xero(credentials) + r_get.return_value = Mock( + status_code=200, + headers={"content-type": "text/html; charset=utf-8"}, + ) + + with open(self.filepath) as f: + xero.filesAPI.files.upload_file( + file=f, filename=os.path.basename(self.filepath) + ) + + self.assertIn(self.filepath, r_get.call_args[1]["files"])