Skip to content

Commit

Permalink
Add ability to (in/ex)clude providers by ID within client (#1412)
Browse files Browse the repository at this point in the history
* Add ability to in(ex)clude providers by ID within client

- Using `--include-providers`, `--exclude-providers` and
  `--exclude-databases` at the CLI or the corresponding Python options

* Add test case

* Add docs notes
  • Loading branch information
ml-evs authored Dec 5, 2022
1 parent 5243791 commit a8c84c5
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 7 deletions.
56 changes: 56 additions & 0 deletions docs/getting_started/client.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,62 @@ We can refine the search by manually specifying some URLs:
client.get()
```

or by including/excluding some providers by their registered IDs in the [Providers list](https://providers.optimade.org).

Query only a list of included providers (after a lookup of the providers list):

=== "Command line"
```shell
# Only query databases served by the example providers
optimade-get --include-providers exmpl,optimade
```

=== "Python"
```python
# Only query databases served by the example providers
from optimade.client import OptimadeClient
client = OptimadeClient(
include_providers={"exmpl", "optimade"},
)
client.get()
```

Exclude certain providers:

=== "Command line"
```shell
# Exclude example providers from global list
optimade-get --exclude-providers exmpl,optimade
```

=== "Python"
```python
# Exclude example providers from global list
from optimade.client import OptimadeClient
client = OptimadeClient(
exclude_providers={"exmpl", "optimade"},
)
client.get()
```

Exclude particular databases by URL:

=== "Command line"
```shell
# Exclude specific example databases
optimade-get --exclude-databases https://example.org/optimade,https://optimade.org/example
```

=== "Python"
```python
# Exclude specific example databases
from optimade.client import OptimadeClient
client = OptimadeClient(
exclude_databases={"https://example.org/optimade", "https://optimade.org/example"}
)
client.get()
```

### Filtering

By default, an empty filter will be used (which will return all entries in a database).
Expand Down
39 changes: 38 additions & 1 deletion optimade/client/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,26 @@
is_flag=True,
help="Pretty print the JSON results.",
)
@click.argument("base-url", default=None, nargs=-1)
@click.option(
"--include-providers",
default=None,
help="A string of comma-separated provider IDs to query.",
)
@click.option(
"--exclude-providers",
default=None,
help="A string of comma-separated provider IDs to exclude from queries.",
)
@click.option(
"--exclude-databases",
default=None,
help="A string of comma-separated database URLs to exclude from queries.",
)
@click.argument(
"base-url",
default=None,
nargs=-1,
)
def get(
use_async,
filter,
Expand All @@ -65,6 +84,9 @@ def get(
sort,
endpoint,
pretty_print,
include_providers,
exclude_providers,
exclude_databases,
):
return _get(
use_async,
Expand All @@ -77,6 +99,9 @@ def get(
sort,
endpoint,
pretty_print,
include_providers,
exclude_providers,
exclude_databases,
)


Expand All @@ -91,6 +116,9 @@ def _get(
sort,
endpoint,
pretty_print,
include_providers,
exclude_providers,
exclude_databases,
):

if output_file:
Expand All @@ -106,6 +134,15 @@ def _get(
base_urls=base_url,
use_async=use_async,
max_results_per_provider=max_results_per_provider,
include_providers=set(_.strip() for _ in include_providers.split(","))
if include_providers
else None,
exclude_providers=set(_.strip() for _ in exclude_providers.split(","))
if exclude_providers
else None,
exclude_databases=set(_.strip() for _ in exclude_databases.split(","))
if exclude_databases
else None,
)
if response_fields:
response_fields = response_fields.split(",")
Expand Down
32 changes: 30 additions & 2 deletions optimade/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import json
import time
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
from urllib.parse import urlparse

# External deps that are only used in the client code
Expand Down Expand Up @@ -85,6 +85,15 @@ class OptimadeClient:
use_async: bool
"""Whether or not to make all requests asynchronously using asyncio."""

_excluded_providers: Optional[Set[str]] = None
"""A set of providers IDs excluded from future queries."""

_included_providers: Optional[Set[str]] = None
"""A set of providers IDs included from future queries."""

_excluded_databases: Optional[Set[str]] = None
"""A set of child database URLs excluded from future queries."""

__current_endpoint: Optional[str] = None
"""Used internally when querying via `client.structures.get()` to set the
chosen endpoint. Should be reset to `None` outside of all `get()` calls."""
Expand All @@ -97,6 +106,9 @@ def __init__(
http_timeout: int = 10,
max_attempts: int = 5,
use_async: bool = True,
exclude_providers: Optional[List[str]] = None,
include_providers: Optional[List[str]] = None,
exclude_databases: Optional[List[str]] = None,
):
"""Create the OPTIMADE client object.
Expand All @@ -108,16 +120,32 @@ def __init__(
http_timeout: The HTTP timeout to use per request.
max_attempts: The maximum number of times to repeat a failing query.
use_async: Whether or not to make all requests asynchronously.
exclude_providers: A set or collection of provider IDs to exclude from queries.
include_providers: A set or collection of provider IDs to include in queries.
exclude_databases: A set or collection of child database URLs to exclude from queries.
"""

self.max_results_per_provider = max_results_per_provider
if self.max_results_per_provider in (-1, 0):
self.max_results_per_provider = None

self._excluded_providers = set(exclude_providers) if exclude_providers else None
self._included_providers = set(include_providers) if include_providers else None
self._excluded_databases = set(exclude_databases) if exclude_databases else None

if not base_urls:
self.base_urls = get_all_databases()
self.base_urls = get_all_databases(
exclude_providers=self._excluded_providers,
include_providers=self._included_providers,
exclude_databases=self._excluded_databases,
)
else:
if exclude_providers or include_providers or exclude_databases:
raise RuntimeError(
"Cannot provide both a list of base URLs and included/excluded databases."
)

self.base_urls = base_urls

if isinstance(self.base_urls, str):
Expand Down
32 changes: 28 additions & 4 deletions optimade/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import json
from typing import Iterable, List
from typing import Container, Iterable, List, Optional

from pydantic import ValidationError

Expand Down Expand Up @@ -101,7 +101,7 @@ def get_providers(add_mongo_id: bool = False) -> list:


def get_child_database_links(
provider: LinksResource, obey_aggregate=True
provider: LinksResource, obey_aggregate: bool = True
) -> List[LinksResource]:
"""For a provider, return a list of available child databases.
Expand Down Expand Up @@ -155,13 +155,37 @@ def get_child_database_links(
) from exc


def get_all_databases() -> Iterable[str]:
"""Iterate through all databases reported by registered OPTIMADE providers."""
def get_all_databases(
    include_providers: Optional[Container[str]] = None,
    exclude_providers: Optional[Container[str]] = None,
    exclude_databases: Optional[Container[str]] = None,
) -> Iterable[str]:
    """Iterate through all databases reported by registered OPTIMADE providers.

    Providers are filtered by ID first; the child databases of the remaining
    providers are then filtered by base URL. A falsy filter (``None`` or an
    empty container) is treated as "no filter".

    Parameters:
        include_providers: A set/container of provider IDs to include child databases for.
        exclude_providers: A set/container of provider IDs to exclude child databases for.
        exclude_databases: A set/container of specific database URLs to exclude.

    Returns:
        A generator of child database base URLs (as strings) that obey the
        given parameters.

    """
    for provider in get_providers():
        # Exclusion is checked before inclusion, so a provider that appears in
        # both containers is skipped.
        if exclude_providers and provider["id"] in exclude_providers:
            continue
        if include_providers and provider["id"] not in include_providers:
            continue

        try:
            links = get_child_database_links(provider)
        except RuntimeError:
            # Best-effort: a provider whose child database lookup raises
            # `RuntimeError` is skipped rather than aborting the whole scan.
            continue

        for link in links:
            if not link.attributes.base_url:
                continue

            # Compare on the string form: this is the value that is yielded and
            # the form in which callers supply `exclude_databases`, so the
            # check does not depend on the URL type being a `str` subclass.
            base_url = str(link.attributes.base_url)
            if exclude_databases and base_url in exclude_databases:
                continue

            yield base_url
36 changes: 36 additions & 0 deletions tests/server/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,39 @@ def test_multiple_base_urls(httpx_mocked_response, use_async):
)


@pytest.mark.parametrize("use_async", [False])
def test_include_exclude_providers(use_async):
    """Check how `OptimadeClient` reacts to provider/database filters.

    Three constructions are exercised: a provider that is both included and
    excluded, filters combined with explicit base URLs, and an included
    provider combined with an excluded database URL.
    """
    no_base_urls_pattern = "Unable to access any OPTIMADE base URLs. If you believe this is an error, try manually specifying some base URLs."
    conflicting_args_pattern = (
        "Cannot provide both a list of base URLs and included/excluded databases."
    )

    # Same provider ID in both the include and exclude sets: the client is
    # expected to exit with the "no base URLs" message.
    with pytest.raises(SystemExit, match=no_base_urls_pattern):
        OptimadeClient(
            include_providers={"exmpl"},
            exclude_providers={"exmpl"},
            use_async=use_async,
        )

    # Explicit base URLs together with a provider filter must be rejected.
    with pytest.raises(RuntimeError, match=conflicting_args_pattern):
        OptimadeClient(
            base_urls=TEST_URLS,
            include_providers={"exmpl"},
            use_async=use_async,
        )

    # Included provider plus an excluded database URL: the client is again
    # expected to exit with the "no base URLs" message.
    with pytest.raises(SystemExit, match=no_base_urls_pattern):
        OptimadeClient(
            include_providers={"exmpl"},
            exclude_databases={"https://example.org/optimade"},
            use_async=use_async,
        )


@pytest.mark.parametrize("use_async", [False])
def test_client_sort(httpx_mocked_response, use_async):
cli = OptimadeClient(base_urls=[TEST_URL], use_async=use_async)
Expand All @@ -138,6 +171,9 @@ def test_command_line_client(httpx_mocked_response, use_async, capsys):
sort=None,
endpoint="structures",
pretty_print=False,
include_providers=None,
exclude_providers=None,
exclude_databases=None,
)

# Test multi-provider query
Expand Down

0 comments on commit a8c84c5

Please sign in to comment.