From a8c84c5f193e668eb8b5fc3f0db1d6a63796f472 Mon Sep 17 00:00:00 2001 From: Matthew Evans <7916000+ml-evs@users.noreply.github.com> Date: Mon, 5 Dec 2022 12:05:59 +0000 Subject: [PATCH] Add ability to (in/ex)clude providers by ID within client (#1412) * Add ability to in(ex)clude providers by ID within client - Using `--include-providers`, `--exclude-providers` and `--exclude_databases` at the CLI or the corresponding Python options * Add test case * Add docs notes --- docs/getting_started/client.md | 56 ++++++++++++++++++++++++++++++++++ optimade/client/cli.py | 39 ++++++++++++++++++++++- optimade/client/client.py | 32 +++++++++++++++++-- optimade/utils.py | 32 ++++++++++++++++--- tests/server/test_client.py | 36 ++++++++++++++++++++++ 5 files changed, 188 insertions(+), 7 deletions(-) diff --git a/docs/getting_started/client.md b/docs/getting_started/client.md index fdf453666..8c6bd7ca6 100644 --- a/docs/getting_started/client.md +++ b/docs/getting_started/client.md @@ -72,6 +72,62 @@ We can refine the search by manually specifying some URLs: client.get() ``` +or by including/excluding some providers by their registered IDs in the [Providers list](https://providers.optimade.org). + +Query only a list of included providers (after a lookup of the providers list): + +=== "Command line" + ```shell + # Only query databases served by the example providers + optimade-get --include-providers exmpl,optimade + ``` + +=== "Python" + ```python + # Only query databases served by the example providers + from optimade.client import OptimadeClient + client = OptimadeClient( + include_providers={"exmpl", "optimade"}, + ) + client.get() + ``` + +Exclude certain providers: + +=== "Command line" + ```shell + # Exclude example providers from global list + optimade-get --exclude-providers exmpl,optimade + ``` + +=== "Python" + ```python + # Exclude example providers from global list + from optimade.client import OptimadeClient + client = OptimadeClient( + exclude_providers={"exmpl", "optimade"}, + ) + client.get() + ``` + +Exclude particular databases by URL: + +=== "Command line" + ```shell + # Exclude specific example databases + optimade-get --exclude-databases https://example.org/optimade,https://optimade.org/example + ``` + +=== "Python" + ```python + # Exclude specific example databases + from optimade.client import OptimadeClient + client = OptimadeClient( + exclude_databases={"https://example.org/optimade", "https://optimade.org/example"} + ) + client.get() + ``` + ### Filtering By default, an empty filter will be used (which will return all entries in a database). diff --git a/optimade/client/cli.py b/optimade/client/cli.py index 3aafe2d9a..d94abdfc2 100644 --- a/optimade/client/cli.py +++ b/optimade/client/cli.py @@ -53,7 +53,26 @@ is_flag=True, help="Pretty print the JSON results.", ) -@click.argument("base-url", default=None, nargs=-1) +@click.option( + "--include-providers", + default=None, + help="A string of comma-separated provider IDs to query.", +) +@click.option( + "--exclude-providers", + default=None, + help="A string of comma-separated provider IDs to exclude from queries.", +) +@click.option( + "--exclude-databases", + default=None, + help="A string of comma-separated database URLs to exclude from queries.", +) +@click.argument( + "base-url", + default=None, + nargs=-1, +) def get( use_async, filter, @@ -65,6 +84,9 @@ def get( sort, endpoint, pretty_print, + include_providers, + exclude_providers, + exclude_databases, ): return _get( use_async, @@ -77,6 +99,9 @@ def get( sort, endpoint, pretty_print, + include_providers, + exclude_providers, + exclude_databases, ) @@ -91,6 +116,9 @@ def _get( sort, endpoint, pretty_print, + include_providers, + exclude_providers, + exclude_databases, ): if output_file: @@ -106,6 +134,15 @@ def _get( base_urls=base_url, use_async=use_async, max_results_per_provider=max_results_per_provider, + include_providers=set(_.strip() for _ in include_providers.split(",")) + if include_providers + else None, + exclude_providers=set(_.strip() for _ in exclude_providers.split(",")) + if exclude_providers + else None, + exclude_databases=set(_.strip() for _ in exclude_databases.split(",")) + if exclude_databases + else None, ) if response_fields: response_fields = response_fields.split(",") diff --git a/optimade/client/client.py b/optimade/client/client.py index 64f53c333..622b34189 100644 --- a/optimade/client/client.py +++ b/optimade/client/client.py @@ -10,7 +10,7 @@ import json import time from collections import defaultdict -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from urllib.parse import urlparse # External deps that are only used in the client code @@ -85,6 +85,15 @@ class OptimadeClient: use_async: bool """Whether or not to make all requests asynchronously using asyncio.""" + _excluded_providers: Optional[Set[str]] = None + """A set of providers IDs excluded from future queries.""" + + _included_providers: Optional[Set[str]] = None + """A set of providers IDs included from future queries.""" + + _excluded_databases: Optional[Set[str]] = None + """A set of child database URLs excluded from future queries.""" + __current_endpoint: Optional[str] = None """Used internally when querying via `client.structures.get()` to set the chosen endpoint. Should be reset to `None` outside of all `get()` calls.""" @@ -97,6 +106,9 @@ def __init__( http_timeout: int = 10, max_attempts: int = 5, use_async: bool = True, + exclude_providers: Optional[List[str]] = None, + include_providers: Optional[List[str]] = None, + exclude_databases: Optional[List[str]] = None, ): """Create the OPTIMADE client object. @@ -108,6 +120,9 @@ def __init__( http_timeout: The HTTP timeout to use per request. max_attempts: The maximum number of times to repeat a failing query. use_async: Whether or not to make all requests asynchronously. + exclude_providers: A set or collection of provider IDs to exclude from queries. + include_providers: A set or collection of provider IDs to include in queries. + exclude_databases: A set or collection of child database URLs to exclude from queries. """ @@ -115,9 +130,22 @@ def __init__( if self.max_results_per_provider in (-1, 0): self.max_results_per_provider = None + self._excluded_providers = set(exclude_providers) if exclude_providers else None + self._included_providers = set(include_providers) if include_providers else None + self._excluded_databases = set(exclude_databases) if exclude_databases else None + if not base_urls: - self.base_urls = get_all_databases() + self.base_urls = get_all_databases( + exclude_providers=self._excluded_providers, + include_providers=self._included_providers, + exclude_databases=self._excluded_databases, + ) else: + if exclude_providers or include_providers or exclude_databases: + raise RuntimeError( + "Cannot provide both a list of base URLs and included/excluded databases." + ) + self.base_urls = base_urls if isinstance(self.base_urls, str): diff --git a/optimade/utils.py b/optimade/utils.py index bbdff8d8d..cb3039565 100644 --- a/optimade/utils.py +++ b/optimade/utils.py @@ -4,7 +4,7 @@ """ import json -from typing import Iterable, List +from typing import Container, Iterable, List, Optional from pydantic import ValidationError @@ -101,7 +101,7 @@ def get_providers(add_mongo_id: bool = False) -> list: def get_child_database_links( - provider: LinksResource, obey_aggregate=True + provider: LinksResource, obey_aggregate: bool = True ) -> List[LinksResource]: """For a provider, return a list of available child databases. @@ -155,13 +155,37 @@ def get_child_database_links( ) from exc -def get_all_databases() -> Iterable[str]: - """Iterate through all databases reported by registered OPTIMADE providers.""" +def get_all_databases( + include_providers: Optional[Container[str]] = None, + exclude_providers: Optional[Container[str]] = None, + exclude_databases: Optional[Container[str]] = None, +) -> Iterable[str]: + """Iterate through all databases reported by registered OPTIMADE providers. + + Parameters: + include_providers: A set/container of provider IDs to include child databases for. + exclude_providers: A set/container of provider IDs to exclude child databases for. + exclude_databases: A set/container of specific database URLs to exclude. + + Returns: + A generator of child database links that obey the given parameters. + + """ for provider in get_providers(): + if exclude_providers and provider["id"] in exclude_providers: + continue + if include_providers and provider["id"] not in include_providers: + continue + try: links = get_child_database_links(provider) for link in links: if link.attributes.base_url: + if ( + exclude_databases + and link.attributes.base_url in exclude_databases + ): + continue yield str(link.attributes.base_url) except RuntimeError: pass diff --git a/tests/server/test_client.py b/tests/server/test_client.py index 89846e02f..b92f8d2fe 100644 --- a/tests/server/test_client.py +++ b/tests/server/test_client.py @@ -116,6 +116,39 @@ def test_multiple_base_urls(httpx_mocked_response, use_async): ) +@pytest.mark.parametrize("use_async", [False]) +def test_include_exclude_providers(use_async): + with pytest.raises( + SystemExit, + match="Unable to access any OPTIMADE base URLs. If you believe this is an error, try manually specifying some base URLs.", + ): + OptimadeClient( + include_providers={"exmpl"}, + exclude_providers={"exmpl"}, + use_async=use_async, + ) + + with pytest.raises( + RuntimeError, + match="Cannot provide both a list of base URLs and included/excluded databases.", + ): + OptimadeClient( + base_urls=TEST_URLS, + include_providers={"exmpl"}, + use_async=use_async, + ) + + with pytest.raises( + SystemExit, + match="Unable to access any OPTIMADE base URLs. If you believe this is an error, try manually specifying some base URLs.", + ): + OptimadeClient( + include_providers={"exmpl"}, + exclude_databases={"https://example.org/optimade"}, + use_async=use_async, + ) + + @pytest.mark.parametrize("use_async", [False]) def test_client_sort(httpx_mocked_response, use_async): cli = OptimadeClient(base_urls=[TEST_URL], use_async=use_async) @@ -138,6 +171,9 @@ def test_command_line_client(httpx_mocked_response, use_async, capsys): sort=None, endpoint="structures", pretty_print=False, + include_providers=None, + exclude_providers=None, + exclude_databases=None, ) # Test multi-provider query