diff --git a/invenio_rdm_records/alembic/425b691f768b_create_collections_tables.py b/invenio_rdm_records/alembic/425b691f768b_create_collections_tables.py new file mode 100644 index 000000000..75068011e --- /dev/null +++ b/invenio_rdm_records/alembic/425b691f768b_create_collections_tables.py @@ -0,0 +1,96 @@ +# +# This file is part of Invenio. +# Copyright (C) 2016-2018 CERN. +# +# Invenio is free software; you can redistribute it and/or modify it +# under the terms of the MIT License; see LICENSE file for more details. + +"""Create collections tables.""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy_utils import UUIDType + +# revision identifiers, used by Alembic. +revision = "425b691f768b" +down_revision = "ff9bec971d30" +branch_labels = () +depends_on = None + + +def upgrade(): + """Upgrade database.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "collections_collection_tree", + sa.Column("created", sa.DateTime(), nullable=False), + sa.Column("updated", sa.DateTime(), nullable=False), + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("community_id", UUIDType(), nullable=True), + sa.Column("title", sa.String(length=255), nullable=False), + sa.Column("order", sa.Integer(), nullable=True), + sa.Column("slug", sa.String(length=255), nullable=False), + sa.ForeignKeyConstraint( + ["community_id"], + ["communities_metadata.id"], + name=op.f( + "fk_collections_collection_tree_community_id_communities_metadata" + ), + ondelete="SET NULL", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_collections_collection_tree")), + sa.UniqueConstraint( + "slug", + "community_id", + name="uq_collections_collection_tree_slug_community_id", + ), + ) + op.create_index( + op.f("ix_collections_collection_tree_community_id"), + "collections_collection_tree", + ["community_id"], + unique=False, + ) + + op.create_table( + "collections_collection", + sa.Column("created", sa.DateTime(), nullable=False), + sa.Column("updated", sa.DateTime(), nullable=False), + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("slug", sa.String(length=255), nullable=False), + sa.Column("path", sa.Text(), nullable=False), + sa.Column("tree_id", sa.Integer(), nullable=False), + sa.Column("title", sa.String(length=255), nullable=False), + sa.Column("query", sa.Text(), nullable=False), + sa.Column("order", sa.Integer(), nullable=True), + sa.Column("depth", sa.Integer(), nullable=True), + sa.Column("num_records", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["tree_id"], + ["collections_collection_tree.id"], + name=op.f("fk_collections_collection_tree_id_collections_collection_tree"), + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_collections_collection")), + sa.UniqueConstraint( + "slug", "tree_id", name="uq_collections_collection_slug_tree_id" + ), + ) + op.create_index( + op.f("ix_collections_collection_path"), + "collections_collection", + ["path"], + unique=False, + ) + + +def downgrade(): + """Downgrade database.""" + op.drop_index( + op.f("ix_collections_collection_path"), table_name="collections_collection" + ) + op.drop_table("collections_collection") + op.drop_index( + op.f("ix_collections_collection_tree_community_id"), + table_name="collections_collection_tree", + ) + op.drop_table("collections_collection_tree") diff --git a/invenio_rdm_records/collections/__init__.py b/invenio_rdm_records/collections/__init__.py new file mode 100644 index 000000000..abd0193da --- /dev/null +++ b/invenio_rdm_records/collections/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# Invenio-RDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +"""Collections entrypoint.""" + +from .models import Collection, CollectionTree +from .searchapp import search_app_context + +__all__ = ( + "Collection", + "CollectionTree", + "search_app_context", +) diff --git a/invenio_rdm_records/collections/api.py b/invenio_rdm_records/collections/api.py new file mode 100644 index 000000000..68200f60a --- /dev/null +++ b/invenio_rdm_records/collections/api.py @@ -0,0 +1,346 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# Invenio-RDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +"""Collections programmatic API.""" + +from types import ClassMethodDescriptorType + +from invenio_db import db +from sqlalchemy import select +from werkzeug.utils import cached_property + +from .models import Collection as CollectionModel +from .models import CollectionTree as CollectionTreeModel + + +class Collection: + """Collection Object.""" + + model_cls = CollectionModel + + def __init__(self, model=None): + """Instantiate a Collection object.""" + self.model = model + + @classmethod + def create(cls, slug, title, query, ctree=None, parent=None, order=None): + """Create a new collection.""" + _ctree = None + if parent: + path = f"{parent.path}{parent.id}," + _ctree = parent.collection_tree.model + elif ctree: + path = "," + _ctree = ctree if isinstance(ctree, int) else ctree.model + else: + raise ValueError("Either parent or ctree must be set.") + + return cls( + cls.model_cls.create( + slug=slug, + path=path, + title=title, + search_query=query, + order=order, + ctree_or_id=_ctree, + ) + ) + + @classmethod + def resolve(cls, id_, ctree_id=None, use_slug=False): + """Resolve a collection by ID or slug. + + To resolve by slug, the collection tree ID must be provided. + """ + if not use_slug: + return cls.get(id_) + if not ctree_id: + raise ValueError( + "Collection tree ID is required to resolve a collection by slug." + ) + return cls.get_by_slug(id_, ctree_id) + + @classmethod + def get(cls, id_): + """Get a collection by ID.""" + model = cls.model_cls.get(id_) + if not model: + return None + return cls(model) + + @classmethod + def get_by_slug(cls, slug, ctree_id): + """Get a collection by slug.""" + model = cls.model_cls.get_by_slug(slug, ctree_id) + if not model: + return None + return cls(model) + + def add( + self, + slug, + title, + query, + order=None, + ): + """Add a subcollection to the collection.""" + return self.create( + slug=slug, + title=title, + query=query, + parent=self, + order=order, + ) + + @property + def id(self): + """Get the collection ID.""" + return self.model.id + + @property + def path(self): + """Get the collection path.""" + return self.model.path + + @property + def ctree_id(self): + """Get the collection tree ID.""" + return self.model.tree_id + + @property + def order(self): + """Get the collection order.""" + return self.model.order + + @property + def title(self): + """Get the collection title.""" + return self.model.title + + @property + def ctree_title(self): + """Get the collection tree title.""" + return self.model.collection_tree.title + + @property + def collection_tree(self): + """Get the collection tree object. + + Note: this will execute a query to the collection tree table. + """ + return CollectionTree(self.model.collection_tree) + + @property + def depth(self): + """Get the collection depth in its tree.""" + return self.model.depth + + @property + def slug(self): + """Get the collection slug.""" + return self.model.slug + + @cached_property + def community(self): + """Get the community object.""" + return self.collection_tree.community + + @property + def query(self): + """Get the collection query.""" + q = "" + for _a in self.ancestors: + q += f"({_a.model.search_query}) AND " + q += f"({self.model.search_query})" + return q + + @cached_property + def ancestors(self): + """Get the collection ancestors.""" + if not self.model: + return None + + cps = self.path.split(",") + ret = [] + for cid in cps: + if not cid: + continue + cl = Collection.get(cid) + ret.append(cl) + return list(sorted(ret, key=lambda x: (x.path, x.order))) + + @cached_property + def sub_collections(self): + """Fetch all the descendants.""" + return self.get_subcollections() + + @cached_property + def direct_subcollections(self): + """Fetch only direct descendants.""" + return self.get_direct_subcollections() + + def get_direct_subcollections(self): + """Get the collection first level (direct) children. + + More preformant query to retrieve descendants, executes an exact match query. + """ + if not self.model: + return None + stmt = ( + select(self.model_cls) + .filter( + self.model_cls.path == f"{self.path}{self.id},", + self.model_cls.tree_id == self.ctree_id, + ) + .order_by(self.model_cls.path, self.model_cls.order) + ) + ret = db.session.execute(stmt).scalars().all() + return [type(self)(r) for r in ret] + + def get_subcollections(self, max_depth=3): + """Get the collection subcollections. + + This query executes a LIKE query on the path column. + """ + if not self.model: + return None + + stmt = ( + select(self.model_cls) + .filter( + self.model_cls.path.like(f"{self.path}{self.id},%"), + self.model_cls.depth < self.model.depth + max_depth, + ) + .order_by(self.model_cls.path, self.model_cls.order) + ) + ret = db.session.execute(stmt).scalars().all() + return [type(self)(r) for r in ret] + + @classmethod + def dump(cls, collection): + """Transform the collection into a dictionary.""" + res = { + "title": collection.title, + "slug": collection.slug, + "depth": collection.depth, + "order": collection.order, + "id": collection.id, + "query": collection.query, + } + return res + + def to_dict(self) -> dict: + """Return a dictionary representation of the collection. + + Uses an adjacency list. + """ + ret = { + "root": self.id, + self.id: {**Collection.dump(self), "children": set()}, + } + + for _c in self.sub_collections: + # Add the collection itself to the dictionary + if _c.id not in ret: + ret[_c.id] = {**Collection.dump(_c), "children": set()} + + # Find the parent ID from the collection's path (last valid ID in the path) + path_parts = [int(part) for part in _c.path.split(",") if part.strip()] + if path_parts: + parent_id = path_parts[-1] + # Add the collection as a child of its parent + ret[parent_id]["children"].add(_c.id) + for k, v in ret.items(): + if isinstance(v, dict): + v["children"] = list(v["children"]) + return ret + + def __repr__(self) -> str: + """Return a string representation of the collection.""" + if self.model: + return f"Collection {self.id} ({self.path})" + else: + return "Collection (None)" + + +class CollectionTree: + """Collection Tree Object.""" + + model_cls = CollectionTreeModel + + def __init__(self, model): + """Instantiate a CollectionTree object.""" + self.model = model + + @classmethod + def create(cls, title, slug, community_id=None, order=None): + """Create a new collection tree.""" + return cls( + cls.model_cls.create( + title=title, slug=slug, community_id=community_id, order=order + ) + ) + + @classmethod + def resolve(cls, id_, community_id=None, use_slug=False): + """Resolve a CollectionTree.""" + if not use_slug: + return cls.get(id_) + + if not community_id: + raise ValueError( + "Community ID is required to resolve a collection tree by slug." + ) + return cls.get_by_slug(id_, community_id) + + @classmethod + def get(cls, id_): + """Get a collection tree by ID.""" + model = cls.model_cls.get(id_) + if not model: + return None + return cls(model) + + @classmethod + def get_by_slug(cls, slug, community_id): + """Get a collection tree by slug. + + Community ID is required to avoid ambiguity. + """ + model = cls.model_cls.get_by_slug(slug, community_id) + if not model: + return None + return cls(model) + + @property + def id(self): + """Get the collection tree ID.""" + return self.model.id + + @property + def title(self): + """Get the collection tree title.""" + return self.model.title + + @property + def slug(self): + """Get the collection tree slug.""" + return self.model.slug + + @property + def community_id(self): + """Get the community ID.""" + return self.model.community_id + + @property + def order(self): + """Get the collection tree order.""" + return self.model.order + + @property + def community(self): + """Get the community object.""" + return self.model.community diff --git a/invenio_rdm_records/collections/models.py b/invenio_rdm_records/collections/models.py new file mode 100644 index 000000000..d3f510dba --- /dev/null +++ b/invenio_rdm_records/collections/models.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# Invenio-RDM-records is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +"""Collections models.""" + +from invenio_communities.communities.records.models import CommunityMetadata +from invenio_db import db +from invenio_records.models import Timestamp +from sqlalchemy import UniqueConstraint +from sqlalchemy_utils.types import UUIDType +from traitlets import ClassBasedTraitType + + +# CollectionTree Table +class CollectionTree(db.Model, Timestamp): + """Collection tree model.""" + + __tablename__ = "collections_collection_tree" + + __table_args__ = ( + # Unique constraint on slug and community_id. Slugs should be unique within a community. + UniqueConstraint( + "slug", + "community_id", + name="uq_collections_collection_tree_slug_community_id", + ), + ) + + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + community_id = db.Column( + UUIDType, + db.ForeignKey(CommunityMetadata.id, ondelete="SET NULL"), + nullable=True, + index=True, + ) + title = db.Column(db.String(255), nullable=False) + order = db.Column(db.Integer, nullable=True) + slug = db.Column(db.String(255), nullable=False) + + # Relationship to Collection + collections = db.relationship("Collection", back_populates="collection_tree") + community = db.relationship(CommunityMetadata, backref="collection_trees") + + @classmethod + def create(cls, title, slug, community_id=None, order=None): + """Create a new collection tree.""" + with db.session.begin_nested(): + collection_tree = cls( + title=title, slug=slug, community_id=community_id, order=order + ) + db.session.add(collection_tree) + return collection_tree + + @classmethod + def get(cls, id_): + """Get a collection tree by ID.""" + return cls.query.get(id_) + + @classmethod + def get_by_slug(cls, slug, community_id): + """Get a collection tree by slug.""" + return cls.query.filter( + cls.slug == slug, cls.community_id == community_id + ).one_or_none() + + +# Collection Table +class Collection(db.Model, Timestamp): + """Collection model. + + Indices: + - id + - collection_tree_id + - path + - slug + """ + + __tablename__ = "collections_collection" + __table_args__ = ( + # Unique constraint on slug and tree_id. Slugs should be unique within a tree. + UniqueConstraint( + "slug", "tree_id", name="uq_collections_collection_slug_tree_id" + ), + ) + + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + slug = db.Column(db.String(255), nullable=False) + path = db.Column(db.Text, nullable=False, index=True) + tree_id = db.Column( + db.Integer, db.ForeignKey("collections_collection_tree.id"), nullable=False + ) + title = db.Column(db.String(255), nullable=False) + search_query = db.Column("query", db.Text, nullable=False) + order = db.Column(db.Integer, nullable=True) + # TODO index depth + depth = db.Column(db.Integer, nullable=False) + num_records = db.Column(db.Integer, nullable=True) + + # Relationship to CollectionTree + collection_tree = db.relationship("CollectionTree", back_populates="collections") + + @classmethod + def create(cls, slug, path, title, search_query, ctree_or_id, **kwargs): + """Create a new collection.""" + depth = len([int(part) for part in path.split(",") if part.strip()]) + with db.session.begin_nested(): + if isinstance(ctree_or_id, CollectionTree): + collection = cls( + slug=slug, + path=path, + title=title, + search_query=search_query, + collection_tree=ctree_or_id, + depth=depth, + **kwargs, + ) + elif isinstance(ctree_or_id, int): + collection = cls( + slug=slug, + path=path, + title=title, + search_query=search_query, + tree_id=ctree_or_id, + depth=depth, + **kwargs, + ) + else: + raise ValueError( + "Either `collection_tree` or `collection_tree_id` must be provided." + ) + db.session.add(collection) + return collection + + @classmethod + def get(cls, id_): + """Get a collection by ID.""" + return cls.query.get(id_) + + @classmethod + def get_by_slug(cls, slug, tree_id): + """Get a collection by slug.""" + return cls.query.filter(cls.slug == slug, cls.tree_id == tree_id).one_or_none() diff --git a/invenio_rdm_records/collections/searchapp.py b/invenio_rdm_records/collections/searchapp.py new file mode 100644 index 000000000..a48e2699a --- /dev/null +++ b/invenio_rdm_records/collections/searchapp.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# Invenio-RDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +"""Collection search app helpers for React-SearchKit.""" + + +from functools import partial + +from flask import current_app +from invenio_search_ui.searchconfig import search_app_config + + +def search_app_context(): + """Search app context.""" + return { + "search_app_collection_config": partial( + search_app_config, + config_name="RDM_SEARCH", + available_facets=current_app.config["RDM_FACETS"], + sort_options=current_app.config["RDM_SORT_OPTIONS"], + headers={"Accept": "application/vnd.inveniordm.v1+json"}, + pagination_options=(10, 25, 50, 100), + # endpoint=/communities/eu/records + # endpoint=/api/records + # hidden_params=[ + # ["q", collection.query] + # ] + ) + } diff --git a/invenio_rdm_records/collections/service.py b/invenio_rdm_records/collections/service.py new file mode 100644 index 000000000..98b7f9e8f --- /dev/null +++ b/invenio_rdm_records/collections/service.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# Invenio-RDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +"""Collections service.""" + +from invenio_communities.proxies import current_communities +from invenio_records_resources.services.uow import ModelCommitOp, unit_of_work + +from .api import Collection, CollectionTree + + +class CollectionItem: + """Collection item.""" + + def __init__(self, collection): + """Instantiate a Collection object.""" + self._collection = collection + + def to_dict(self): + """Serialize the collection to a dictionary and add links.""" + res = {**self._collection.to_dict()} + res[self._collection.id]["links"] = self.links + + return res + + @property + def links(self): + """Return the links of the collection.""" + self_html = None + search = None + tree_slug = self._collection.collection_tree.slug + if self._collection.community: + self_html = f"/communities/{self._collection.community.slug}/{tree_slug}/{self._collection.slug}" + search = f"/api/communities/{self._collection.community.slug}/records" + else: + self_html = f"/collections/{tree_slug}/{self._collection.slug}" + search = "/api/records" + return { + "search": search, + "self_html": self_html, + } + + @property + def community(self): + """Get the collection community.""" + return self._collection.community + + @property + def query(self): + """Get the collection query.""" + return self._collection.query + + @property + def title(self): + """Get the collection title.""" + return self._collection.title + + +class CollectionsService: + """Collections service.""" + + collection_cls = Collection + + @unit_of_work() + def create( + self, identity, community_id, tree_slug, slug, title, query, uow=None, **kwargs + ): + """Create a new collection.""" + current_communities.service.require_permission( + identity, "update", community_id=community_id + ) + ctree = CollectionTree.get_by_slug(tree_slug, community_id) + if not ctree: + raise ValueError(f"Collection tree {tree_slug} not found.") + collection = self.collection_cls.create( + slug=slug, title=title, query=query, ctree=ctree, **kwargs + ) + uow.register(ModelCommitOp(collection.model)) + return CollectionItem(collection) + + def read(self, identity, id_): + """Get a collection by ID or slug.""" + collection = self.collection_cls.resolve(id_) + if not collection: + raise ValueError(f"Collection {id_} not found.") + if collection.community: + current_communities.service.require_permission( + identity, "read", community_id=collection.community.id + ) + + return CollectionItem(collection) + + def read_slug(self, identity, community_id, tree_slug, slug): + """Get a collection by slug.""" + current_communities.service.require_permission( + identity, "read", community_id=community_id + ) + + ctree = CollectionTree.get_by_slug(tree_slug, community_id) + if not ctree: + raise ValueError(f"Collection tree {tree_slug} not found.") + + collection = self.collection_cls.resolve(slug, ctree.id, use_slug=True) + if not collection: + raise ValueError(f"Collection {slug} not found.") + + return CollectionItem(collection) + + @unit_of_work() + def add(self, identity, collection, slug, title, query, uow=None, **kwargs): + """Add a subcollection to a collection.""" + current_communities.service.require_permission( + identity, "update", community_id=collection.community.id + ) + new_collection = self.collection_cls.create( + parent=collection, slug=slug, title=title, query=query, **kwargs + ) + uow.register(ModelCommitOp(new_collection.model)) + return CollectionItem(new_collection) diff --git a/invenio_rdm_records/ext.py b/invenio_rdm_records/ext.py index b9802d80f..470c9ef42 100644 --- a/invenio_rdm_records/ext.py +++ b/invenio_rdm_records/ext.py @@ -19,6 +19,7 @@ from invenio_records_resources.resources.files import FileResource from . import config +from .collections.service import CollectionsService from .oaiserver.resources.config import OAIPMHServerResourceConfig from .oaiserver.resources.resources import OAIPMHServerResource from .oaiserver.services.config import OAIPMHServerServiceConfig @@ -203,6 +204,8 @@ def init_services(self, app): config=service_configs.oaipmh_server, ) + self.collections_service = CollectionsService() + def init_resource(self, app): """Initialize resources.""" self.records_resource = RDMRecordResource( diff --git a/setup.cfg b/setup.cfg index e53590945..48663f372 100644 --- a/setup.cfg +++ b/setup.cfg @@ -111,6 +111,7 @@ invenio_celery.tasks = invenio_rdm_records_user_moderation = invenio_rdm_records.requests.user_moderation.tasks invenio_db.models = invenio_rdm_records = invenio_rdm_records.records.models + invenio_rdm_records_collections = invenio_rdm_records.collections.models invenio_db.alembic = invenio_rdm_records = invenio_rdm_records:alembic invenio_jsonschemas.schemas = diff --git a/tests/collections/__init__.py b/tests/collections/__init__.py new file mode 100644 index 000000000..03664b765 --- /dev/null +++ b/tests/collections/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# Invenio-RDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +"""Tests for collections.""" diff --git a/tests/collections/test_collections_api.py b/tests/collections/test_collections_api.py new file mode 100644 index 000000000..d6b715a80 --- /dev/null +++ b/tests/collections/test_collections_api.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# Invenio-RDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +"""Test suite for the collections programmatic API.""" + +from invenio_rdm_records.collections.api import Collection, CollectionTree + + +def test_create(running_app, db, community, community_owner): + """Test collection creation via API.""" + tree = CollectionTree.create( + title="Tree 1", + order=10, + community_id=community.id, + slug="tree-1", + ) + + # Use ORM object (collection tree) + collection = Collection.create( + title="My Collection", + query="*:*", + slug="my-collection", + ctree=tree, + ) + + assert collection.id + assert collection.title == "My Collection" + assert collection.collection_tree.id == tree.id + + # Use collection tree id + collection = Collection.create( + title="My Collection 2", + query="*:*", + slug="my-collection-2", + ctree=tree.id, + ) + assert collection.id + assert collection.title == "My Collection 2" + assert collection.collection_tree.id == tree.id + + +def test_as_dict(running_app, db): + """Test collection as dict.""" + tree = CollectionTree.create( + title="Tree 1", + order=10, + slug="tree-1", + ) + c1 = Collection.create( + title="My Collection", + query="*:*", + slug="my-collection", + ctree=tree, + ) + + c2 = Collection.create( + title="My Collection 2", + query="*:*", + slug="my-collection-2", + parent=c1, + ) + + c3 = Collection.create(title="3", query="*", slug="my-collection-3", parent=c2) + res = c1.to_dict() + assert all(k in res for k in (c1.id, c2.id, c3.id)) + assert res[c1.id]["title"] == "My Collection" + assert res[c1.id]["children"] == [c2.id] + assert res[c2.id]["children"] == [c3.id] + assert res[c3.id]["children"] == [] diff --git a/tests/services/test_collections_service.py b/tests/services/test_collections_service.py new file mode 100644 index 000000000..cf97becdb --- /dev/null +++ b/tests/services/test_collections_service.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# Invenio-RDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +""" TODO """ + +import pytest + +from invenio_rdm_records.collections.api import Collection, CollectionTree +from invenio_rdm_records.proxies import current_rdm_records + + +@pytest.fixture() +def collections_service(): + """Get collections service fixture.""" + return current_rdm_records.collections_service + + +@pytest.fixture(autouse=True) +def add_collections(running_app, db, community): + """Create collections on demand.""" + + def _inner(): + """Add collections to the app.""" + tree = CollectionTree.create( + title="Tree 1", + order=10, + community_id=community.id, + slug="tree-1", + ) + c1 = Collection.create( + title="Collection 1", query="*:*", slug="collection-1", ctree=tree + ) + c2 = Collection.create( + title="Collection 2", + query="*:*", + slug="collection-2", + ctree=tree, + parent=c1, + ) + return [c1, c2] + + return _inner + + +def test_collections_service_read( + running_app, db, add_collections, collections_service, community_owner +): + """Test collections service.""" + collections = add_collections() + c0 = collections[0] + c1 = collections[1] + res = collections_service.read(community_owner.identity, c0.id) + assert res._collection.id == c0.id