diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 5bd4f87..94d0cd3 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -8,7 +8,8 @@ jobs: steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 - - uses: psf/black@stable + - run: pip install black + - run: black --check . flake8: runs-on: ubuntu-latest diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index af9e13a..7c9ce76 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,7 +16,7 @@ jobs: - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: - python-version: "3.6" + python-version: "3.8" - name: install run: | diff --git a/eve_elastic/elastic.py b/eve_elastic/elastic.py index a69381e..09e521f 100644 --- a/eve_elastic/elastic.py +++ b/eve_elastic/elastic.py @@ -216,7 +216,12 @@ def fix_query(query, top=True, context=None): elif key == "query_string": new_query[key] = val val.setdefault("lenient", True) - elif key == "query" and not top and context != "aggs": + elif ( + key == "query" + and not top + and context != "aggs" + and not isinstance(val, str) + ): new_query["bool"] = {"must": fix_query(val, top=False, context=context)} elif top: new_query[key] = fix_query(val, top=False, context=key) @@ -629,7 +634,8 @@ def find(self, resource, req, sub_resource_lookup, **kwargs): else: raise - return self._parse_hits(hits, resource) + cursor = self._parse_hits(hits, resource) + return cursor, cursor.count() def should_aggregate(self, req): """Check the environment variable and the given argument parameter to decide if aggregations needed. diff --git a/requirements.txt b/requirements.txt index 8242b74..92e00c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ nose -black mock -werkzeug<1.0 -eve>=0.4,<0.9 +werkzeug>=1.0,<1.1 +eve>1.1,<1.2 +MarkupSafe>2.0,<2.1 \ No newline at end of file diff --git a/test/test_elastic.py b/test/test_elastic.py index 1a6df39..0f68fda 100644 --- a/test/test_elastic.py +++ b/test/test_elastic.py @@ -261,16 +261,14 @@ def test_query_filter_with_filter_dsl_and_schema_filter(self): with self.app.app_context(): req = ParsedRequest() req.args = {"filter": json.dumps(query_filter)} - self.assertEqual( - 1, self.app.data.find("items_with_description", req, None).count() - ) + cursor, count = self.app.data.find("items_with_description", req, None) + self.assertEqual(1, count) with self.app.app_context(): req = ParsedRequest() req.args = {"q": "bar", "filter": json.dumps(query_filter)} - self.assertEqual( - 0, self.app.data.find("items_with_description", req, None).count() - ) + cursor, count = self.app.data.find("items_with_description", req, None) + self.assertEqual(0, count) def test_find_one_by_id(self): """elastic 1.0+ is using 'found' property instead of 'exists'""" @@ -304,7 +302,7 @@ def test_search_via_source_param(self): self.app.data.insert("items", [{"uri": "bar", "name": "bar"}]) req = ParsedRequest() req.args = {"source": json.dumps(query)} - res = self.app.data.find("items", req, None) + res, count = self.app.data.find("items", req, None) self.assertEqual(1, res.count()) def test_search_via_source_param_and_schema_filter(self): @@ -319,7 +317,7 @@ def test_search_via_source_param_and_schema_filter(self): ) req = ParsedRequest() req.args = {"source": json.dumps(query)} - res = self.app.data.find("items_with_description", req, None) + res, count = self.app.data.find("items_with_description", req, None) self.assertEqual(1, res.count()) def test_search_via_source_param_and_with_highlight(self): @@ -334,7 +332,7 @@ def test_search_via_source_param_and_with_highlight(self): ) req = ParsedRequest() req.args = {"source": json.dumps(query), "es_highlight": 1} - res = self.app.data.find("items_with_description", req, None) + res, count = self.app.data.find("items_with_description", req, None) self.assertEqual(1, res.count()) es_highlight = res[0].get("es_highlight") self.assertIsNotNone(es_highlight) @@ -353,7 +351,7 @@ def test_search_with_highlight_without_query_string_query(self): "source": json.dumps({"query": {"term": {"name": "foo"}}}), "es_highlight": 1, } - res = self.app.data.find("items_with_description", req, None) + res, count = self.app.data.find("items_with_description", req, None) self.assertEqual(0, res.count()) def test_search_via_source_param_and_without_highlight(self): @@ -368,7 +366,7 @@ def test_search_via_source_param_and_without_highlight(self): ) req = ParsedRequest() req.args = {"source": json.dumps(query), "es_highlight": 0} - res = self.app.data.find("items_with_description", req, None) + res, count = self.app.data.find("items_with_description", req, None) self.assertEqual(1, res.count()) es_highlight = res[0].get("es_highlight") self.assertIsNone(es_highlight) @@ -388,7 +386,7 @@ def test_search_via_source_param_and_with_source_projection(self): "source": json.dumps(query), "projections": json.dumps(["name"]), } - res = self.app.data.find("items_with_description", req, None) + res, count = self.app.data.find("items_with_description", req, None) self.assertEqual(1, res.count()) self.assertTrue("description" not in res.docs[0]) self.assertTrue("name" in res.docs[0]) @@ -453,7 +451,7 @@ def test_eve_projection(self): } ) - items = self.app.data.find("items", req, None) + items, count = self.app.data.find("items", req, None) fields = items[0].keys() self.assertIn("name", fields) self.assertIn("_id", fields) @@ -501,16 +499,14 @@ def test_sub_resource_lookup(self): self.app.data.insert("items", [{"uri": "foo", "name": "foo"}]) req = ParsedRequest() req.args = {} - self.assertEqual( - 1, self.app.data.find("items", req, {"name": "foo"}).count() - ) - self.assertEqual( - 0, self.app.data.find("items", req, {"name": "bar"}).count() - ) - self.assertEqual( - 1, - self.app.data.find("items", req, {"name": "foo", "uri": "foo"}).count(), + cursor, count = self.app.data.find("items", req, {"name": "foo"}) + self.assertEqual(1, count) + cursor, count = self.app.data.find("items", req, {"name": "bar"}) + self.assertEqual(0, count) + cursor, count = self.app.data.find( + "items", req, {"name": "foo", "uri": "foo"} ) + self.assertEqual(1, count) def test_sub_resource_lookup_with_schema_filter(self): with self.app.app_context(): @@ -520,18 +516,14 @@ def test_sub_resource_lookup_with_schema_filter(self): ) req = ParsedRequest() req.args = {} - self.assertEqual( - 1, - self.app.data.find( - "items_with_description", req, {"name": "foo"} - ).count(), + cursor, count = self.app.data.find( + "items_with_description", req, {"name": "foo"} ) - self.assertEqual( - 0, - self.app.data.find( - "items_with_description", req, {"name": "bar"} - ).count(), + self.assertEqual(1, count) + cursor, count = self.app.data.find( + "items_with_description", req, {"name": "bar"} ) + self.assertEqual(0, count) def test_resource_filter(self): with self.app.app_context(): @@ -544,9 +536,8 @@ def test_resource_filter(self): req.args["source"] = json.dumps( {"query": {"bool": {"must": [{"term": {"uri": "bar"}}]}}} ) - self.assertEqual( - 0, self.app.data.find("items_with_description", req, None).count() - ) + cursor, count = self.app.data.find("items_with_description", req, None) + self.assertEqual(0, count) def test_where_filter(self): with self.app.app_context(): @@ -579,7 +570,8 @@ def test_remove_by_id(self): self.app.data.remove("items", {"_id": self.ids[0]}) req = ParsedRequest() req.args = {} - self.assertEqual(1, self.app.data.find("items", req, None).count()) + cursor, count = self.app.data.find("items", req, None) + self.assertEqual(1, count) def test_remove_non_existing_item(self): with self.app.app_context(): @@ -613,8 +605,12 @@ def test_resource_aggregates(self): req = ParsedRequest() req.args = {} response = {} - item1 = self.app.data.find("items_with_description", req, {"name": "foo"}) - item2 = self.app.data.find("items_with_description", req, {"name": "bar"}) + item1, count1 = self.app.data.find( + "items_with_description", req, {"name": "foo"} + ) + item2, count2 = self.app.data.find( + "items_with_description", req, {"name": "bar"} + ) item1.extra(response) self.assertEqual(3, item1.count()) self.assertEqual(1, item2.count()) @@ -629,12 +625,12 @@ def test_resource_aggregates_no_auto(self): req = ParsedRequest() req.args = {} response = {} - cursor = self.app.data.find("items_with_description", req, {}) + cursor, count = self.app.data.find("items_with_description", req, {}) cursor.extra(response) self.assertNotIn("_aggregations", response) req.args = {"aggregations": 1} - cursor = self.app.data.find("items_with_description", req, {}) + cursor, count = self.app.data.find("items_with_description", req, {}) cursor.extra(response) self.assertIn("_aggregations", response) @@ -653,7 +649,8 @@ def test_args_filter(self): req = ParsedRequest() req.args = {} req.args["filter"] = json.dumps({"term": {"uri": "foo"}}) - self.assertEqual(1, self.app.data.find("items", req, None).count()) + cursor, count = self.app.data.find("items", req, None) + self.assertEqual(1, count) def test_filters_with_aggregations(self): with self.app.app_context(): @@ -667,7 +664,9 @@ def test_filters_with_aggregations(self): req = ParsedRequest() res = {} - cursor = self.app.data.find("items_with_description", req, {"uri": "bar"}) + cursor, count = self.app.data.find( + "items_with_description", req, {"uri": "bar"} + ) cursor.extra(res) self.assertEqual(1, cursor.count()) self.assertIn( @@ -679,10 +678,10 @@ def test_filter_without_args(self): with self.app.app_context(): self.app.data.insert("items", [{"uri": "foo"}, {"uri": "bar"}]) req = ParsedRequest() - self.assertEqual(2, self.app.data.find("items", req, None).count()) - self.assertEqual( - 1, self.app.data.find("items", req, {"uri": "foo"}).count() - ) + cursor, count = self.app.data.find("items", req, None) + self.assertEqual(2, count) + cursor, count = self.app.data.find("items", req, {"uri": "foo"}) + self.assertEqual(1, count) def test_filters_with_filtered_query(self): with self.app.app_context(): @@ -700,7 +699,7 @@ def test_filters_with_filtered_query(self): req = ParsedRequest() req.args = {"source": json.dumps(query)} - cursor = self.app.data.find("items", req, None) + cursor, count = self.app.data.find("items", req, None) self.assertEqual(0, cursor.count()) def test_basic_search_query(self): @@ -709,7 +708,7 @@ def test_basic_search_query(self): with self.app.test_request_context("/items/?q=foo"): req = parse_request("items") - cursor = self.app.data.find("items", req, None) + cursor, count = self.app.data.find("items", req, None) self.assertEquals(1, cursor.count()) def test_phrase_search_query(self): @@ -718,12 +717,12 @@ def test_phrase_search_query(self): with self.app.test_request_context('/items/?q="foo bar"'): req = parse_request("items") - cursor = self.app.data.find("items", req, None) + cursor, count = self.app.data.find("items", req, None) self.assertEquals(1, cursor.count()) with self.app.test_request_context('/items/?q="bar foo"'): req = parse_request("items") - cursor = self.app.data.find("items", req, None) + cursor, count = self.app.data.find("items", req, None) self.assertEquals(0, cursor.count()) def test_elastic_filter_callback(self): @@ -734,7 +733,7 @@ def test_elastic_filter_callback(self): with self.app.test_request_context("test?uri=foo"): req = parse_request("items_with_callback_filter") - cursor = self.app.data.find("items_with_callback_filter", req, None) + cursor, count = self.app.data.find("items_with_callback_filter", req, None) self.assertEqual(1, cursor.count()) def test_elastic_sort_by_score_if_there_is_query(self): @@ -747,7 +746,7 @@ def test_elastic_sort_by_score_if_there_is_query(self): with self.app.test_request_context("/items/"): req = parse_request("items") req.args = {"q": "foo"} - cursor = self.app.data.find("items", req, None) + cursor, count = self.app.data.find("items", req, None) self.assertEqual(2, cursor.count()) self.assertEqual("foo", cursor[0]["uri"]) @@ -755,7 +754,7 @@ def test_elastic_find_default_sort_no_mapping(self): with self.app.test_request_context("/items/"): req = parse_request("items") req.args = {} - cursor = self.app.data.find("items", req, None) + cursor, count = self.app.data.find("items", req, None) self.assertEqual(0, cursor.count()) @skip("every resource has it's own index now") @@ -802,7 +801,7 @@ def test_no_force_refresh(self): time.sleep(2) req = ParsedRequest() - cursor = self.app.data.find("items", req, None) + cursor, count = self.app.data.find("items", req, None) self.assertEqual(2, cursor.count()) def test_elastic_prefix(self): @@ -811,11 +810,11 @@ def test_elastic_prefix(self): self.assertIn("firstcreated", mapping) self.app.data.insert("items_foo_default_index", [{"uri": "test"}]) - foo_items = self.app.data.find("items_foo", ParsedRequest(), None) + foo_items, count = self.app.data.find("items_foo", ParsedRequest(), None) self.assertEqual(0, foo_items.count()) self.app.data.insert("items_foo", [{"uri": "foo"}, {"uri": "bar"}]) - foo_items = self.app.data.find("items_foo", ParsedRequest(), None) + foo_items, count = self.app.data.find("items_foo", ParsedRequest(), None) self.assertEqual(2, foo_items.count()) def test_retry_on_conflict(self): @@ -1325,7 +1324,7 @@ def test_parent_child_query(self): } req = ParsedRequest() req.args = {"source": json.dumps(query)} - results = self.app.data.find(self.parent_item, req, None) + results, count = self.app.data.find(self.parent_item, req, None) self.assertEqual(1, results.count()) self.assertEqual(results[0].get("_id"), "foo") self.assertEqual(results[0].get("_type"), self.parent_item) @@ -1428,7 +1427,7 @@ def test_inner_hits_query(self): } req = ParsedRequest() req.args = {"source": json.dumps(query)} - results = self.app.data.find("items", req, None) + results, count = self.app.data.find("items", req, None) self.assertEqual(2, results.count()) self.assertEqual(results[0].get("_id"), "foo") self.assertEqual(len(results[0].get("_inner_hits")), 1) @@ -1541,7 +1540,7 @@ def test_nested_sort(self): } req = ParsedRequest() req.args = {"source": json.dumps(query)} - results = self.app.data.find("items", req, None) + results, count = self.app.data.find("items", req, None) self.assertEqual(2, results.count()) self.assertEqual(results[0].get("_id"), "bar") self.assertEqual(results[1].get("_id"), "foo")