Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chemiscope.metatensor_featurizer #357

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ dependencies with:
pip install chemiscope[explore]
```

To use `chemiscope.metatensor_featurizer` for providing your trained model
to get the features for `chemiscope.explore`, install the dependencies with:
```bash
pip install chemiscope[metatensor]
```

## sphinx and sphinx-gallery integration

Chemiscope provides also extensions for `sphinx` and `sphinx-gallery` to
Expand Down
2 changes: 2 additions & 0 deletions docs/src/python/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,5 @@
.. autofunction:: chemiscope.ase_tensors_to_ellipsoids

.. autofunction:: chemiscope.explore

.. autofunction:: chemiscope.metatensor_featurizer
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,8 @@ explore = [
"dscribe",
"scikit-learn",
]

metatensor = [
"metatensor",
"metatensor[torch]"
]
2 changes: 1 addition & 1 deletion python/chemiscope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
extract_properties,
librascal_atomic_environments,
)
from .explore import explore # noqa: F401
from .explore import explore, metatensor_featurizer # noqa: F401
from .version import __version__ # noqa: F401

from .jupyter import show, show_input # noqa
191 changes: 186 additions & 5 deletions python/chemiscope/explore.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

import numpy as np

from .jupyter import show


Expand Down Expand Up @@ -33,8 +35,7 @@ def explore(frames, featurize=None, properties=None, environments=None, mode="de
:param environments: optional. List of environments (described as
``(structure id, center id, cutoff)``) to include when extracting the
atomic properties. Can be extracted from frames with
:py:func:`all_atomic_environments`.
or manually defined.
:py:func:`all_atomic_environments` or manually defined.

:param str mode: optional. Visualization mode for the chemiscope widget. Can be one
of "default", "structure", or "map". The default mode is "default".
Expand Down Expand Up @@ -197,15 +198,195 @@ def soap_pca_featurize(frames, environments=None):
return pca.fit_transform(feats)


def _extract_environment_indices(envs):
def metatensor_featurizer(
model, extensions_directory=None, check_consistency=False, device=None
):
"""
Create a featurizer function using a `metatensor`_ model to obtain the features
from structures with a pre-trained ``metatensor`` model.

If a list of `ase.Atoms <ase-io_>`_ is provided, ``featurize`` processes each frame
to obtain features and stacks them into a single array. If a single frame
(`ase.Atoms <ase-io_>`_) is provided, it processes it directly.

:param model: a model to use for the calculation. It can be a file path, a Python
instance of `MetatensorAtomisticModel
<https://docs.metatensor.org/latest/atomistic/reference/models/export.html#metatensor.torch.atomistic.MetatensorAtomisticModel>`_,
or the output of `torch.jit.script
<https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script>`_
on :py:class:`MetatensorAtomisticModel`.

:param extensions_directory: a directory where model extensions are located, default
is ``"extensions/"``.

:param check_consistency: should we check the model for consistency when running,
defaults to False.

:param device: a torch device to use for the calculation. If ``None``, the function
will use the options in model's ``supported_device`` attribute.

:returns: a function that takes a list of frames and returns the features.

To use this function, additional dependencies are required, specifically,
`metatensor`_. It can be installed with the following command:

.. code:: bash

pip install chemiscope[metatensor]

Here is an example using a pre-trained `metatensor`_ model, stored as a
``model.pt`` file with the compiled extensions stored in the ``extensions/``
directory. To obtain the details on how to get it, see metatensor `tutorial
<https://lab-cosmo.github.io/metatrain/latest/getting-started/usage.html>`_.
The frames are obtained by reading structures from a file that `ase <ase-io_>`_
can read.

.. code-block:: python

import chemiscope
import ase.io

# Read the structures from the dataset
frames = ase.io.read("data/explore_c-gap-20u.xyz", ":")

# Provide model file ("model.pt") to 'metatensor_featurizer', it's result will
# be visualised with a chemiscope widget
chemiscope.explore(
frames,
featurize=chemiscope.metatensor_featurizer("model.pt"),
)

For more examples, see the related `documentation <chemiscope-explore-metatensor>`_.

.. _metatensor: https://docs.metatensor.org/latest/index.html
.. _chemiscope-explore-torchscript:
https://chemiscope.org/docs/examples/7-explore-advanced.html#example-with-metatensor-model
.. _torch.jit.script:
https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script
"""

# Check if dependencies were installed
try:
from metatensor.torch.atomistic import ModelOutput
from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator
except ImportError as e:
raise ImportError(
f"Required package not found: {e}. Please install the dependency using "
"'pip install chemiscope[metatensor]'."
)

# Initialize metatensor calculator
mtt_calculator = MetatensorCalculator(
model=model,
extensions_directory=extensions_directory,
check_consistency=check_consistency,
device=device,
)

def get_llf(model, atoms, environments):
"""
Run the model on a single atomic structure and extract the features

:param model: a Metatensor calculator model
:param atoms: an atomic structure to be featurized
:param environments: a subset of atoms on which to run the calculation
"""
# Run the model
per_atom = environments is not None
outputs = {"features": ModelOutput(per_atom=per_atom)}
selected_atoms = _create_selected_atoms(environments)
output = model.run_model(atoms, outputs, selected_atoms)

# Extract and return features
llf = output["features"]
return llf.block().values.detach().cpu().numpy()

def featurize(frames, environments):
"""
Featurize the given frames with the loaded model

:param list frames: a list of atomic structures
:param list environments: optional. List of environments (described as
``(structure_id, atom_id, cutoff)``), can be extracted from frames with
:py:func:`all_atomic_environments` or manually defined
"""
# Process list of frames
if isinstance(frames, list):
Luthaf marked this conversation as resolved.
Show resolved Hide resolved
# Group environments per frame
envs_per_frame = _get_environments_per_frame(environments, len(frames))

# Get features for each frame
outputs = [
get_llf(mtt_calculator, frame, envs)
for frame, envs in zip(frames, envs_per_frame)
]
return np.vstack(outputs)

# If a single frame is provided, process it directly
return get_llf(mtt_calculator, frames, environments)

return featurize


def _extract_environment_indices(environments):
"""
Convert from chemiscope's environements to DScribe's centers selection

:param: list envs: each element is a list of [env_index, atom_index, cutoff]
:param: list environments: each element is a list of [env_index, atom_index, cutoff]
"""
grouped_envs = {}
for [env_index, atom_index, _cutoff] in envs:
for [env_index, atom_index, _cutoff] in environments:
if env_index not in grouped_envs:
grouped_envs[env_index] = []
grouped_envs[env_index].append(atom_index)
return list(grouped_envs.values())


def _get_environments_per_frame(environments, num_frames):
"""
Organize the environments for each frame

:param list environments: a list of atomic environments
:param int num_frames: total number of frames
"""
envs_per_frame = [None] * num_frames

if environments is None:
return envs_per_frame

frames_dict = {}

# Group environments by structure_id
for env in environments:
structure_id = env[0]
if structure_id not in frames_dict:
frames_dict[structure_id] = []
frames_dict[structure_id].append(env)

# Insert environments to the frame indices
for structure_id, envs in frames_dict.items():
if structure_id < num_frames:
envs_per_frame[structure_id] = envs

return envs_per_frame


def _create_selected_atoms(environments):
"""
Convert environments into a metatensor.torch.Labels object for selected_atoms

:param environments: a list of atom-centered environments
"""
import torch # noqa
from metatensor.torch import Labels # noqa

if environments is None:
return None

# Extract system and atom indices from environments
values = torch.tensor(
[(structure_id, atom_id) for structure_id, atom_id, _ in environments]
)

# Create Labels with system and atom dimensions
return Labels(names=["system", "atom"], values=values)
1 change: 1 addition & 0 deletions python/examples/7-explore-advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def mace_mp0_tsne(frames, environments):
fetch_dataset("mace-mp-tsne-m3cd.json.gz")
chemiscope.show_input("data/mace-mp-tsne-m3cd.json.gz")


# %%
#
# Example with SOAP, t-SNE and environments
Expand Down
Loading
Loading