diff --git a/grapple/types/structures.py b/grapple/types/structures.py index 3262c4b2..9c723db0 100644 --- a/grapple/types/structures.py +++ b/grapple/types/structures.py @@ -21,6 +21,21 @@ def parse_literal(ast, _variables=None): return return_value +class SearchOperatorEnum(graphene.Enum): + """ + Enum for search operator. + """ + + AND = "and" + OR = "or" + + def __str__(self): + # the core search parser expects the operator to be a string. + # the default __str__ returns SearchOperatorEnum.AND/OR, + # this __str__ returns the value and/or for compatibility. + return self.value + + class QuerySetList(graphene.List): """ List type with arguments used by Django's query sets. @@ -31,6 +46,8 @@ class QuerySetList(graphene.List): * ``limit`` * ``offset`` * ``search_query`` + * ``search_operator`` + * ``search_fields`` * ``order`` :param enable_limit: Enable limit argument. @@ -39,6 +56,10 @@ class QuerySetList(graphene.List): :type enable_offset: bool :param enable_search: Enable search query argument. :type enable_search: bool + :param enable_search_fields: Enable search fields argument, enable_search must also be True + :type enable_search_fields: bool + :param enable_search_operator: Enable search operator argument, enable_search must also be True + :type enable_search_operator: bool :param enable_order: Enable ordering via query argument. :type enable_order: bool """ @@ -46,8 +67,10 @@ class QuerySetList(graphene.List): def __init__(self, of_type, *args, **kwargs): enable_limit = kwargs.pop("enable_limit", True) enable_offset = kwargs.pop("enable_offset", True) - enable_search = kwargs.pop("enable_search", True) enable_order = kwargs.pop("enable_order", True) + enable_search = kwargs.pop("enable_search", True) + enable_search_fields = kwargs.pop("enable_search_fields", True) + enable_search_operator = kwargs.pop("enable_search_operator", True) # Check if the type is a Django model type. Do not perform the # check if value is lazy. @@ -92,6 +115,22 @@ def __init__(self, of_type, *args, **kwargs): graphene.String, description=_("Filter the results using Wagtail's search."), ) + if enable_search_operator: + kwargs["search_operator"] = graphene.Argument( + SearchOperatorEnum, + description=_( + "Specify search operator (and/or), see: https://docs.wagtail.org/en/stable/topics/search/searching.html#search-operator" + ), + default_value="and", + ) + + if enable_search_fields: + kwargs["search_fields"] = graphene.Argument( + graphene.List(graphene.String), + description=_( + "A list of fields to search in. see: https://docs.wagtail.org/en/stable/topics/search/searching.html#specifying-the-fields-to-search" + ), + ) if "id" not in kwargs: kwargs["id"] = graphene.Argument(graphene.ID, description=_("Filter by ID")) @@ -138,21 +177,29 @@ def PaginatedQuerySet(of_type, type_class, **kwargs): """ Paginated QuerySet type with arguments used by Django's query sets. - This type setts the following arguments on itself: + This type sets the following arguments on itself: * ``id`` * ``page`` * ``per_page`` * ``search_query`` + * ``search_operator`` + * ``search_fields`` * ``order`` :param enable_search: Enable search query argument. :type enable_search: bool + :param enable_search_fields: Enable search fields argument, enable_search must also be True + :type enable_search_fields: bool + :param enable_search_operator: Enable search operator argument, enable_search must also be True + :type enable_search_operator: bool :param enable_order: Enable ordering via query argument. :type enable_order: bool """ enable_search = kwargs.pop("enable_search", True) + enable_search_fields = kwargs.pop("enable_search_fields", True) + enable_search_operator = kwargs.pop("enable_search_operator", True) enable_order = kwargs.pop("enable_order", True) required = kwargs.get("required", False) type_name = type_class if isinstance(type_class, str) else type_class.__name__ @@ -199,6 +246,22 @@ def PaginatedQuerySet(of_type, type_class, **kwargs): kwargs["search_query"] = graphene.Argument( graphene.String, description=_("Filter the results using Wagtail's search.") ) + if enable_search_operator: + kwargs["search_operator"] = graphene.Argument( + SearchOperatorEnum, + description=_( + "Specify search operator (and/or), see: https://docs.wagtail.org/en/stable/topics/search/searching.html#search-operator" + ), + default_value="and", + ) + + if enable_search_fields: + kwargs["search_fields"] = graphene.Argument( + graphene.List(graphene.String), + description=_( + "A comma-separated list of fields to search in. see: https://docs.wagtail.org/en/stable/topics/search/searching.html#specifying-the-fields-to-search" + ), + ) if "id" not in kwargs: kwargs["id"] = graphene.Argument(graphene.ID, description=_("Filter by ID")) diff --git a/grapple/utils.py b/grapple/utils.py index ede62155..7e2fc874 100644 --- a/grapple/utils.py +++ b/grapple/utils.py @@ -8,6 +8,7 @@ from wagtail import VERSION as WAGTAIL_VERSION from wagtail.models import Site from wagtail.search.index import class_is_indexed +from wagtail.search.utils import parse_query_string from .settings import grapple_settings from .types.structures import BasePaginatedType, PaginationType @@ -100,6 +101,8 @@ def resolve_queryset( id=None, order=None, collection=None, + search_operator="and", + search_fields=None, **kwargs, ): """ @@ -121,6 +124,11 @@ def resolve_queryset( :type order: str :param collection: Use Wagtail's collection id to filter images or documents :type collection: int + :param search_operator: The operator to use when combining search terms. + Defaults to "and". + :type search_operator: "and" | "or" + :param search_fields: A list of fields to search. Defaults to all fields. + :type search_fields: list """ qs = qs.all() if id is None else qs.filter(pk=id) @@ -147,7 +155,14 @@ def resolve_queryset( query = Query.get(search_query) query.add_hit() - qs = qs.search(search_query, order_by_relevance=order_by_relevance) + filters, parsed_query = parse_query_string(search_query, str(search_operator)) + + qs = qs.search( + parsed_query, + order_by_relevance=order_by_relevance, + operator=search_operator, + fields=search_fields, + ) if connection.vendor != "sqlite": qs = qs.annotate_score("search_score") @@ -178,9 +193,9 @@ def get_paginated_result(qs, page, per_page): count=len(page_obj.object_list), per_page=per_page, current_page=page_obj.number, - prev_page=page_obj.previous_page_number() - if page_obj.has_previous() - else None, + prev_page=( + page_obj.previous_page_number() if page_obj.has_previous() else None + ), next_page=page_obj.next_page_number() if page_obj.has_next() else None, total_pages=paginator.num_pages, ), @@ -188,7 +203,16 @@ def get_paginated_result(qs, page, per_page): def resolve_paginated_queryset( - qs, info, page=None, per_page=None, search_query=None, id=None, order=None, **kwargs + qs, + info, + page=None, + per_page=None, + id=None, + order=None, + search_query=None, + search_operator="and", + search_fields=None, + **kwargs, ): """ Add page, per_page and search capabilities to the query. This contains @@ -202,11 +226,16 @@ def resolve_paginated_queryset( :type id: int :param per_page: The maximum number of items to include on a page. :type per_page: int + :param order: Order the query set using the Django QuerySet order_by format. + :type order: str :param search_query: Using Wagtail search, exclude objects that do not match the search query. :type search_query: str - :param order: Order the query set using the Django QuerySet order_by format. - :type order: str + :param search_operator: The operator to use when combining search terms. + Defaults to "and". + :type search_operator: "and" | "or" + :param search_fields: A list of fields to search. Defaults to all fields. + :type search_fields: list """ page = int(page or 1) per_page = min( @@ -231,7 +260,14 @@ def resolve_paginated_queryset( query = Query.get(search_query) query.add_hit() - qs = qs.search(search_query, order_by_relevance=order_by_relevance) + filters, parsed_query = parse_query_string(search_query, search_operator) + + qs = qs.search( + parsed_query, + order_by_relevance=order_by_relevance, + operator=search_operator, + fields=search_fields, + ) if connection.vendor != "sqlite": qs = qs.annotate_score("search_score") diff --git a/tests/test_grapple.py b/tests/test_grapple.py index 1c2445f9..077e7265 100644 --- a/tests/test_grapple.py +++ b/tests/test_grapple.py @@ -1,3 +1,4 @@ +from pprint import pprint import unittest from pydoc import locate @@ -474,18 +475,18 @@ class PagesSearchTest(BaseGrappleTest): @classmethod def setUpTestData(cls): cls.home = HomePage.objects.first() - BlogPageFactory(title="Alpha", parent=cls.home) - BlogPageFactory(title="Alpha Alpha", parent=cls.home) - BlogPageFactory(title="Alpha Beta", parent=cls.home) - BlogPageFactory(title="Alpha Gamma", parent=cls.home) - BlogPageFactory(title="Beta", parent=cls.home) - BlogPageFactory(title="Beta Alpha", parent=cls.home) - BlogPageFactory(title="Beta Beta", parent=cls.home) - BlogPageFactory(title="Beta Gamma", parent=cls.home) - BlogPageFactory(title="Gamma", parent=cls.home) - BlogPageFactory(title="Gamma Alpha", parent=cls.home) - BlogPageFactory(title="Gamma Beta", parent=cls.home) - BlogPageFactory(title="Gamma Gamma", parent=cls.home) + BlogPageFactory(title="Alpha", body=[("heading", "Sigma")], parent=cls.home) + BlogPageFactory(title="Alpha Alpha", body=[("heading", "Sigma Sigma")], parent=cls.home) + BlogPageFactory(title="Alpha Beta", body=[("heading", "Sigma Theta")], parent=cls.home) + BlogPageFactory(title="Alpha Gamma", body=[("heading", "Sigma Delta")], parent=cls.home) + BlogPageFactory(title="Beta", body=[("heading", "Theta")], parent=cls.home) + BlogPageFactory(title="Beta Alpha", body=[("heading", "Theta Sigma")], parent=cls.home) + BlogPageFactory(title="Beta Beta", body=[("heading", "Theta Theta")], parent=cls.home) + BlogPageFactory(title="Beta Gamma", body=[("heading", "Theta Delta")], parent=cls.home) + BlogPageFactory(title="Gamma", body=[("heading", "Delta")], parent=cls.home) + BlogPageFactory(title="Gamma Alpha", body=[("heading", "Delta Sigma")], parent=cls.home) + BlogPageFactory(title="Gamma Beta", body=[("heading", "Delta Theta")], parent=cls.home) + BlogPageFactory(title="Gamma Gamma", body=[("heading", "Delta Delta")], parent=cls.home) @unittest.skipIf( connection.vendor != "sqlite", @@ -530,7 +531,6 @@ def test_searchQuery_order_by_relevance(self): } } """ - executed = self.client.execute(query, variables={"searchQuery": "Alpha"}) page_data = executed["data"].get("pages") self.assertEqual(len(page_data), 6) @@ -559,7 +559,6 @@ def test_explicit_order(self): query, variables={"searchQuery": "Gamma", "order": "-title"} ) page_data = executed["data"].get("pages") - self.assertEqual(len(page_data), 6) self.assertEqual(page_data[0]["title"], "Gamma Gamma") self.assertEqual(page_data[1]["title"], "Gamma Beta") @@ -568,6 +567,110 @@ def test_explicit_order(self): self.assertEqual(page_data[4]["title"], "Beta Gamma") self.assertEqual(page_data[5]["title"], "Alpha Gamma") + def test_search_operator_default(self): + """default operator is and""" + query = """ + query($searchQuery: String) { + pages(searchQuery: $searchQuery) { + title + searchScore + } + } + """ + executed = self.client.execute(query, variables={"searchQuery": "Alpha Beta"}) + page_data = executed["data"].get("pages") + self.assertEqual(len(page_data), 2) + self.assertEqual(page_data[0]["title"], "Alpha Beta") + self.assertEqual(page_data[1]["title"], "Beta Alpha") + + def test_search_operator_and(self): + query = """ + query($searchQuery: String, $searchOperator: SearchOperatorEnum) { + pages(searchQuery: $searchQuery, searchOperator: $searchOperator) { + title + searchScore + } + } + """ + executed = self.client.execute( + query, variables={"searchQuery": "Alpha Beta", "searchOperator": "AND"} + ) + page_data = executed["data"].get("pages") + self.assertEqual(len(page_data), 2) + self.assertEqual(page_data[0]["title"], "Alpha Beta") + self.assertEqual(page_data[1]["title"], "Beta Alpha") + + def test_search_operator_or(self): + query = """ + query($searchQuery: String, $searchOperator: SearchOperatorEnum) { + pages(searchQuery: $searchQuery, searchOperator: $searchOperator) { + title + searchScore + } + } + """ + executed = self.client.execute( + query, variables={"searchQuery": "Alpha Beta", "searchOperator": "OR"} + ) + page_data = executed["data"].get("pages") + self.assertEqual(len(page_data), 10) + self.assertEqual(page_data[0]["title"], "Alpha") + self.assertEqual(page_data[1]["title"], "Alpha Alpha") + self.assertEqual(page_data[2]["title"], "Alpha Beta") + self.assertEqual(page_data[3]["title"], "Alpha Gamma") + self.assertEqual(page_data[4]["title"], "Beta") + self.assertEqual(page_data[5]["title"], "Beta Alpha") + self.assertEqual(page_data[6]["title"], "Beta Beta") + self.assertEqual(page_data[7]["title"], "Beta Gamma") + self.assertEqual(page_data[8]["title"], "Gamma Alpha") + self.assertEqual(page_data[9]["title"], "Gamma Beta") + + def test_search_fields(self): + query = """ + query($searchQuery: String, $searchFields: [String]) { + pages(searchQuery: $searchQuery, searchFields: $searchFields, contentType: "testapp.BlogPage") { + title + searchScore + } + } + """ + executed = self.client.execute( + query, + variables={"searchQuery": "Sigma", "searchFields": "body"}, + ) + print(executed) + page_data = executed["data"].get("pages") + self.assertEqual(len(page_data), 6) + self.assertEqual(page_data[0]["title"], "Alpha") + self.assertEqual(page_data[1]["title"], "Alpha Alpha") + self.assertEqual(page_data[2]["title"], "Alpha Beta") + self.assertEqual(page_data[3]["title"], "Alpha Gamma") + self.assertEqual(page_data[4]["title"], "Beta Alpha") + self.assertEqual(page_data[5]["title"], "Gamma Alpha") + + def test_search_fields_filter(self): + query = """ + query($searchQuery: String) { + pages(searchQuery: $searchQuery) { + title + searchScore + } + } + """ + executed = self.client.execute( + query, + variables={"searchQuery": "Beta Sigma fields:body"} , + ) + pprint(executed) + page_data = executed["data"].get("pages") + self.assertEqual(len(page_data), 6) + self.assertEqual(page_data[0]["title"], "Alpha") + self.assertEqual(page_data[1]["title"], "Alpha Alpha") + self.assertEqual(page_data[2]["title"], "Alpha Beta") + self.assertEqual(page_data[3]["title"], "Alpha Gamma") + self.assertEqual(page_data[4]["title"], "Beta Alpha") + self.assertEqual(page_data[5]["title"], "Gamma Alpha") + class PageUrlPathTest(BaseGrappleTest): def _query_by_path(self, path, *, in_site=False): diff --git a/tests/testapp/models/core.py b/tests/testapp/models/core.py index e010333f..e20f2fa5 100644 --- a/tests/testapp/models/core.py +++ b/tests/testapp/models/core.py @@ -15,6 +15,7 @@ from wagtail.fields import RichTextField, StreamField from wagtail.models import Orderable, Page from wagtail.snippets.models import register_snippet +from wagtail.search import index from wagtail_headless_preview.models import HeadlessPreviewMixin from wagtailmedia.edit_handlers import MediaChooserPanel @@ -163,6 +164,8 @@ def custom_property(self): "author": self.author.name if self.author else "Unknown", } + search_fields = Page.search_fields + [ index.SearchField('body') ] + graphql_fields = [ GraphQLString("date", required=True), GraphQLRichText("summary"),