Skip to content

Commit

Permalink
[ENG-6681] Fix/citations (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
opaduchak authored Dec 6, 2024
1 parent b11c779 commit 8581838
Show file tree
Hide file tree
Showing 5 changed files with 362 additions and 80 deletions.
57 changes: 39 additions & 18 deletions addon_imps/citations/mendeley.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio

from addon_toolkit.async_utils import join_list
from addon_toolkit.interfaces.citation import (
CitationAddonImp,
ItemResult,
Expand All @@ -20,41 +21,60 @@ async def get_external_account_id(self, auth_result_extras: dict[str, str]) -> s
return str(user_id)

async def list_root_collections(self) -> ItemSampleResult:
async with self.network.GET("folders") as response:
response_json = await response.json_content()
raw_collections = [
collection
for collection in response_json
if "parent_id" not in collection
return ItemSampleResult(
items=[
ItemResult(
item_id="ROOT",
item_type=ItemType.COLLECTION,
item_name="All Documents",
)
]
return self._parse_collection_response(raw_collections)
)

async def list_collection_items(
    self,
    collection_id: str,
    filter_items: ItemType | None = None,
) -> ItemSampleResult:
    """List the contents of one Mendeley collection.

    Args:
        collection_id: Identifier forwarded unchanged to the two fetch
            helpers.
        filter_items: ``ItemType.COLLECTION`` skips the document lookup,
            ``ItemType.DOCUMENT`` skips the sub-collection lookup; any
            other value (including ``None``) runs both.

    Returns:
        An ``ItemSampleResult`` whose ``total_count`` is the number of
        items gathered.
    """
    wants_documents = filter_items != ItemType.COLLECTION
    wants_collections = filter_items != ItemType.DOCUMENT

    # Documents are queued before sub-collections; keep this order, the
    # combined result is built from the queue as-is.
    pending = []
    if wants_documents:
        pending.append(self._fetch_collection_documents(collection_id))
    if wants_collections:
        pending.append(self._fetch_subcollections(collection_id))

    # join_list comes from addon_toolkit.async_utils; presumably it awaits
    # every coroutine and flattens the per-coroutine lists -- confirm there.
    gathered = await join_list(pending)
    return ItemSampleResult(items=gathered, total_count=len(gathered))

async def _fetch_subcollections(self, collection_id):
async with self.network.GET(
f"folders/{collection_id}/documents",
"folders",
) as response:
document_ids = await response.json_content()
items = await self._fetch_documents_details(document_ids, filter_items)
items = self._parse_collection_response(document_ids, collection_id)

return ItemSampleResult(items=items, total_count=len(items))
return items.items

async def _fetch_collection_documents(self, collection_id: str):
    """Fetch the documents of a folder, or of the whole library.

    An empty ``collection_id`` or the literal ``"ROOT"`` selects the
    un-prefixed ``documents`` endpoint; anything else is looked up under
    ``folders/<collection_id>/documents``.

    Returns:
        Whatever ``_fetch_documents_details`` produces for the entries
        returned by the endpoint.
    """
    is_specific_folder = bool(collection_id) and collection_id != "ROOT"
    if is_specific_folder:
        endpoint = f"folders/{collection_id}/documents"
    else:
        endpoint = "documents"

    async with self.network.GET(endpoint) as response:
        document_stubs = await response.json_content()
        return await self._fetch_documents_details(document_stubs)

async def _fetch_documents_details(
self, document_ids: list[dict], filter_items: ItemType | None
self, document_ids: list[dict]
) -> list[ItemResult]:
tasks = [
self._fetch_item_details(doc["id"], filter_items) for doc in document_ids
]
tasks = [self._fetch_item_details(doc["id"]) for doc in document_ids]

return list(await asyncio.gather(*tasks))

async def _fetch_item_details(
self,
item_id: str,
filter_items: ItemType | None,
) -> ItemResult:
async with self.network.GET(f"documents/{item_id}") as item_response:
item_details = await item_response.json_content()
Expand All @@ -68,16 +88,17 @@ async def _fetch_item_details(
csl=csl_data,
)

def _parse_collection_response(self, response_json: dict) -> ItemSampleResult:
def _parse_collection_response(
self, response_json: dict, parent_id: str
) -> ItemSampleResult:
items = [
ItemResult(
item_id=collection["id"],
item_name=collection["name"],
item_type=ItemType.COLLECTION,
item_path=None,
csl=None,
)
for collection in response_json
if collection.get("parent_id", "ROOT") == parent_id
]

return ItemSampleResult(items=items, total_count=len(items))
Expand Down
64 changes: 53 additions & 11 deletions addon_imps/citations/zotero_org.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from addon_toolkit.async_utils import join_list
from addon_toolkit.interfaces.citation import (
CitationAddonImp,
ItemResult,
Expand All @@ -7,7 +8,6 @@


class ZoteroOrgCitationImp(CitationAddonImp):

async def get_external_account_id(self, auth_result_extras: dict[str, str]) -> str:
user_id = auth_result_extras.get("userID")
if user_id:
Expand Down Expand Up @@ -48,36 +48,78 @@ async def get_external_account_id(self, auth_result_extras: dict[str, str]) -> s
)

async def list_root_collections(self) -> ItemSampleResult:
"""
For Zotero this API call lists all libraries which user may access
"""
async with self.network.GET(
f"users/{self.config.external_account_id}/collections"
f"users/{self.config.external_account_id}/groups"
) as response:
collections = await response.json_content()
items = [
ItemResult(
item_id=collection["key"],
item_name=collection["data"].get("name", "Unnamed Collection"),
item_id=f'{collection["id"]}:',
item_name=collection["data"].get("name", "Unnamed Library"),
item_type=ItemType.COLLECTION,
)
for collection in collections
]
items.append(
ItemResult(
item_id="personal:",
item_name="My Library",
item_type=ItemType.COLLECTION,
)
)
return ItemSampleResult(items=items, total_count=len(items))

async def list_collection_items(
    self,
    collection_id: str,
    filter_items: ItemType | None = None,
) -> ItemSampleResult:
    """List the contents of a Zotero library or collection.

    Args:
        collection_id: Composite id of the form ``"<library>:<collection>"``.
            It must contain exactly one ``":"``; otherwise the unpacking
            below raises ``ValueError``.
        filter_items: ``ItemType.COLLECTION`` skips the document lookup,
            ``ItemType.DOCUMENT`` skips the sub-collection lookup; any
            other value (including ``None``) runs both.

    Returns:
        An ``ItemSampleResult`` whose ``total_count`` is the number of
        items gathered.
    """
    id_parts = collection_id.split(":")
    library, collection = id_parts

    include_documents = filter_items != ItemType.COLLECTION
    include_collections = filter_items != ItemType.DOCUMENT

    # Documents are queued ahead of sub-collections.
    pending = []
    if include_documents:
        pending.append(self.fetch_collection_documents(library, collection))
    if include_collections:
        pending.append(self.fetch_subcollections(library, collection))

    # join_list comes from addon_toolkit.async_utils; presumably it awaits
    # every coroutine and flattens the per-coroutine lists -- confirm there.
    gathered = await join_list(pending)
    return ItemSampleResult(items=gathered, total_count=len(gathered))

async def fetch_subcollections(self, library, collection):
    """Fetch collections below ``collection`` in ``library``.

    Issues a GET to ``<prefix>/collections/top`` where ``<prefix>`` comes
    from ``resolve_collection_prefix`` and converts every returned entry
    into a COLLECTION ``ItemResult`` whose id is ``"<library>:<key>"``.

    Returns:
        A list of ``ItemResult`` objects, one per entry in the response.
    """
    base = self.resolve_collection_prefix(library, collection)
    # NOTE(review): for a non-root collection this requests
    # "<...>/collections/<key>/collections/top"; confirm against the Zotero
    # Web API that "/top" is valid below a specific collection.
    async with self.network.GET(f"{base}/collections/top") as response:
        payload = await response.json_content()

    subcollections = []
    for entry in payload:
        subcollections.append(
            ItemResult(
                item_id=f'{library}:{entry["key"]}',
                item_name=entry["data"].get("name", "Unnamed title"),
                item_type=ItemType.COLLECTION,
            )
        )
    return subcollections

async def fetch_collection_documents(self, library, collection):
prefix = self.resolve_collection_prefix(library, collection)
async with self.network.GET(
f"users/{self.config.external_account_id}/collections/{collection_id}/items",
f"{prefix}/items/top", query={"format": "csljson"}
) as response:
items_json = await response.json_content()
items = [
return [
ItemResult(
item_id=item["key"],
item_name=item["data"].get("title", "Unnamed title"),
item_id=f'{library}:{item["id"]}',
item_name=item.get("title", "Unnamed title"),
item_type=ItemType.DOCUMENT,
csl=item,
)
for item in items_json
if filter_items is None
for item in items_json["items"]
]
return ItemSampleResult(items=items, total_count=len(items))

def resolve_collection_prefix(self, library: str, collection):
    """Build the API path prefix for a library or one of its collections.

    Args:
        library: ``"personal"`` for the authenticated user's own library
            (addressed as ``users/<external_account_id>``); any other
            value is used as a group id (``groups/<library>``).
        collection: Collection key inside the library. An empty value or
            the literal ``"ROOT"`` means the library root.

    Returns:
        The path prefix without a trailing slash, e.g. ``"users/42"`` or
        ``"groups/987/collections/ABCD1234"``.
    """
    if library == "personal":
        prefix = f"users/{self.config.external_account_id}"
    else:
        prefix = f"groups/{library}"
    # list_root_collections emits ids such as "personal:" and "<group>:",
    # which list_collection_items splits into an EMPTY collection part.
    # Treat that as the library root rather than producing a
    # ".../collections/" path with no key.
    if collection and collection != "ROOT":
        prefix = f"{prefix}/collections/{collection}"
    return prefix
113 changes: 97 additions & 16 deletions addon_imps/tests/citations/test_mendeley.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import asyncio
import unittest
from unittest.mock import AsyncMock
from unittest.mock import (
AsyncMock,
create_autospec,
sentinel,
)

from addon_imps.citations.mendeley import MendeleyCitationImp
from addon_toolkit.constrained_network.http import HttpRequestor
Expand All @@ -12,8 +15,7 @@
)


class TestMendeleyCitationImp(unittest.TestCase):

class TestMendeleyCitationImp(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.config = CitationConfig(
external_api_url="https://api.mendeley.com",
Expand All @@ -25,18 +27,18 @@ def setUp(self):
config=self.config, network=self.network
)

def test_get_external_account_id(self):
async def test_get_external_account_id(self):
mock_response = {"id": "12345"}
self.mendeley_imp.network.GET.return_value.__aenter__.return_value.json_content = AsyncMock(
return_value=mock_response
)

result = asyncio.run(self.mendeley_imp.get_external_account_id({}))
result = await self.mendeley_imp.get_external_account_id({})

self.assertEqual(result, "12345")
self.mendeley_imp.network.GET.assert_called_with("profiles/me")

def test_list_root_collections(self):
async def test_list_root_collections(self):
mock_response_data = [
{"id": "1", "name": "Collection 1"},
{"id": "2", "name": "Collection 2"},
Expand All @@ -45,15 +47,14 @@ def test_list_root_collections(self):
return_value=mock_response_data
)

result = asyncio.run(self.mendeley_imp.list_root_collections())
result = await self.mendeley_imp.list_root_collections()

expected_result = ItemSampleResult(
items=[
ItemResult(
item_id="1", item_name="Collection 1", item_type=ItemType.COLLECTION
),
ItemResult(
item_id="2", item_name="Collection 2", item_type=ItemType.COLLECTION
item_id="ROOT",
item_name="All Documents",
item_type=ItemType.COLLECTION,
),
],
total_count=2,
Expand All @@ -63,9 +64,9 @@ def test_list_root_collections(self):
sorted(result.items, key=lambda x: x.item_id),
sorted(expected_result.items, key=lambda x: x.item_id),
)
self.mendeley_imp.network.GET.assert_called_with("folders")
self.mendeley_imp.network.GET.assert_not_called()

def test_list_collection_items(self):
async def test_fetch_collection_documents(self):
mock_document_ids = [{"id": "doc1"}, {"id": "doc2"}]
mock_doc1_details = {
"id": "doc1",
Expand Down Expand Up @@ -105,7 +106,7 @@ def test_list_collection_items(self):
),
]

result = asyncio.run(self.mendeley_imp.list_collection_items("folder_id"))
result = await self.mendeley_imp._fetch_collection_documents("folder_id")

expected_items = [
ItemResult(
Expand All @@ -130,6 +131,86 @@ def test_list_collection_items(self):
]

self.assertEqual(
sorted(result.items, key=lambda x: x.item_id),
sorted(result, key=lambda x: x.item_id),
sorted(expected_items, key=lambda x: x.item_id),
)

async def test_fetch_subcollections(self):
# Verifies that _fetch_subcollections keeps only folders whose parent_id
# equals the requested collection id.
# Fixture: two folders are children of "collection_id"; the third has no
# parent_id and must be filtered out of the result.
mock_response_data = [
{"id": "1", "name": "Collection 1", "parent_id": "collection_id"},
{"id": "2", "name": "Collection 2", "parent_id": "collection_id"},
{"id": "3", "name": "Collection 3"},
]
self.mendeley_imp.network.GET.return_value.__aenter__.return_value.json_content = AsyncMock(
return_value=mock_response_data
)

result = await self.mendeley_imp._fetch_subcollections("collection_id")

# Only the two child folders are expected, as COLLECTION items.
expected_result = ItemSampleResult(
items=[
ItemResult(
item_id="1", item_name="Collection 1", item_type=ItemType.COLLECTION
),
ItemResult(
item_id="2", item_name="Collection 2", item_type=ItemType.COLLECTION
),
],
total_count=2,
)

# The helper returns a plain list (not an ItemSampleResult), so compare it
# with expected_result.items; sorting makes the check order-independent.
self.assertEqual(
sorted(result, key=lambda x: x.item_id),
sorted(expected_result.items, key=lambda x: x.item_id),
)
# Exactly one request, to the un-prefixed "folders" endpoint.
self.mendeley_imp.network.GET.assert_called_once_with("folders")

async def test_list_collection_items(self):
# Table-driven check of how filter_items selects which fetch helpers
# list_collection_items awaits and what combined result it returns.
# sentinel objects stand in for items; only identity and order matter here.
collections = [sentinel.collection1, sentinel.collection2]
documents = [sentinel.document1, sentinel.document2]
# Replace both helpers with autospec'd stand-ins so no network call is made.
self.mendeley_imp._fetch_subcollections = create_autospec(
self.mendeley_imp._fetch_subcollections, return_value=collections
)
self.mendeley_imp._fetch_collection_documents = create_autospec(
self.mendeley_imp._fetch_collection_documents, return_value=documents
)
# Each row: [filter, helpers that must be awaited once with "collection_id",
# helpers that must not be called, expected ItemSampleResult].
cases = [
[
ItemType.COLLECTION,
[self.mendeley_imp._fetch_subcollections],
[self.mendeley_imp._fetch_collection_documents],
ItemSampleResult(collections, total_count=2),
],
[
None,
[
self.mendeley_imp._fetch_subcollections,
self.mendeley_imp._fetch_collection_documents,
],
[],
# With no filter the expected order is documents first, then collections.
ItemSampleResult(documents + collections, total_count=4),
],
[
ItemType.DOCUMENT,
[self.mendeley_imp._fetch_collection_documents],
[self.mendeley_imp._fetch_subcollections],
ItemSampleResult(documents, total_count=2),
],
]
for (
item_filter,
calls_to_be_made,
calls_not_to_be_made,
expected_result,
) in cases:
with self.subTest(item_filter):
result = await self.mendeley_imp.list_collection_items(
"collection_id", filter_items=item_filter
)
for call in calls_to_be_made:
call.assert_awaited_once_with("collection_id")
for call in calls_not_to_be_made:
call.assert_not_called()
self.assertEqual(result, expected_result)
# Clear recorded calls on both stand-ins so the "awaited once" /
# "not called" assertions do not see calls from a previous case.
self.mendeley_imp._fetch_subcollections.reset_mock()
self.mendeley_imp._fetch_collection_documents.reset_mock()
Loading

0 comments on commit 8581838

Please sign in to comment.