Commit

add more tests
a-kore committed May 3, 2024
1 parent 806eab3 commit a82850b
Showing 4 changed files with 73 additions and 26 deletions.
54 changes: 42 additions & 12 deletions README.md
@@ -7,9 +7,50 @@
[![codecov](https://codecov.io/gh/VectorInstitute/aieng-template/branch/main/graph/badge.svg)](https://codecov.io/gh/VectorInstitute/aieng-template)
[![license](https://img.shields.io/github/license/VectorInstitute/aieng-template.svg)](https://github.com/VectorInstitute/aieng-template/blob/main/LICENSE)

## Table of Contents

- [Introduction](#introduction)
- [Datasets](#datasets)
- [Models](#models)
- [Tasks](#tasks)
- [Installation](#installation)
- [Developing](#developing)

## Introduction

-AtomGen is a Python package that generates atomistic structures for molecular simulations. The package is designed to be used with the [ASE](https://wiki.fysik.dtu.dk/ase/) package and provides a simple interface to generate structures of various materials.
AtomGen provides a robust framework for handling atomistic graph datasets, with a focus on transformer-based implementations. It includes utilities for training various models, experimenting with different pre-training tasks, and working with pre-trained models.

It streamlines the process of aggregation, standardization, and utilization of datasets from diverse sources, enabling large-scale pre-training and generative modeling on atomistic graphs.

## Datasets

AtomGen facilitates the aggregation and standardization of datasets, including but not limited to:

- **S2EF Datasets**: Aggregated from multiple sources such as OC20, OC22, ODAC23, MPtrj, and SPICE with structures and energies/forces for pre-training.

- **Misc. Atomistic Graph Datasets**: Including Molecule3D, Protein Data Bank (PDB), and the Open Quantum Materials Database (OQMD).

Currently, AtomGen provides pre-processed datasets for the S2EF pre-training task: one for OC20 and a mixed dataset combining OC20, OC22, ODAC23, MPtrj, and SPICE. Both have been uploaded to the Hugging Face Hub and can be accessed through the `datasets` API.
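For illustration, records in these atomistic-graph datasets pair per-atom token ids with 3D coordinates — the same schema exercised by this repository's data-collator test. A minimal sketch that builds such records locally (field names are taken from the tests; no real Hub repo id is shown here):

```python
import torch

# Build a small list of atomistic-graph records in the schema the
# data collator expects: per-atom token ids plus xyz coordinates.
sizes = torch.randint(4, 16, (10,)).tolist()
dataset = [
    {
        "input_ids": torch.randint(0, 123, (n,)).tolist(),
        "coords": torch.rand(n, 3).tolist(),
    }
    for n in sizes
]
print(len(dataset), len(dataset[0]["coords"][0]))  # prints: 10 3
```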

## Models

AtomGen supports a variety of models for training on atomistic graph datasets, including:

- SchNet
- TokenGT
- Uni-Mol+ (Modified)

## Tasks

AtomGen facilitates experimentation with several pre-training tasks, including:

- **Structure to Energy & Forces**: Predicting energies and forces for atomistic graphs.

- **Masked Atom Modeling**: Masking atoms and predicting their properties.

- **Coordinate Denoising**: Denoising atom coordinates.

These tasks are all handled by the `DataCollatorForAtomModeling` class and can be used individually or in combination.
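As a rough sketch of what the masked-atom-modeling objective involves — the helper below is hypothetical; the real masking logic lives inside `DataCollatorForAtomModeling`:

```python
import torch

def mask_atoms(input_ids: torch.Tensor, mask_token_id: int, mask_prob: float = 0.15):
    """Hypothetical helper: mask random atom tokens, BERT-style.

    Labels hold the original ids at masked positions and -100
    (the ignore index for cross-entropy loss) everywhere else.
    """
    input_ids = input_ids.clone()
    mask = torch.rand(input_ids.shape) < mask_prob
    labels = torch.where(mask, input_ids, torch.full_like(input_ids, -100))
    input_ids[mask] = mask_token_id
    return input_ids, labels

# Mask roughly half of a toy sequence of atom tokens.
masked, labels = mask_atoms(torch.arange(10), mask_token_id=99, mask_prob=0.5)
```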

## Installation

@@ -20,17 +61,6 @@
```bash
python3 -m poetry install
source $(poetry env info --path)/bin/activate
```

-## Usage
-
-The package can be used to generate structures of various materials. For example, to generate a diamond structure:
-
-```python
-from atomgen import Diamond
-
-diamond = Diamond()
-diamond.generate()
-```


## 🧑🏿‍💻 Developing

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -136,7 +136,7 @@ markers = [

[tool.coverage]
[tool.coverage.run]
-source=["aieng_template"]
+source=["atomgen"]
omit=["tests/*", "*__init__.py"]

[build-system]
36 changes: 25 additions & 11 deletions tests/atomgen/data/test_datacollator.py
@@ -1,14 +1,28 @@
"""Tests for DataCollator."""
import torch

from atomgen.data.data_collator import DataCollatorForAtomModeling
from atomgen.data.tokenizer import AtomTokenizer


def test_data_collator():
-    data_collator = DataCollatorForAtomModeling()
-    input_ids = [1,2,10]
-    coords = [[0.5, 0.2, 0.1], [0.3, 0.4, 0.5], [0.1, 0.2, 0.3]]
-    labels = [[1, 2, 3], [4, 5, 6]]
-    batch = data_collator(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
-    assert batch["input_ids"].tolist() == input_ids
-    assert batch["attention_mask"].tolist() == attention_mask
-    assert batch["labels"].tolist() == labels
-    assert batch["decoder_input_ids"].tolist() == [[1, 2, 3], [4, 5, 6]]
-    assert batch["decoder_attention_mask"].tolist() == [[1, 1, 1], [1, 1, 1]]
-    assert batch["decoder_labels"].tolist() == [[1, 2, 3], [4, 5, 6]]
+    """Test DataCollatorForAtomModeling."""
+    tokenizer = AtomTokenizer(vocab_file="atomgen/data/tokenizer.json")
+    data_collator = DataCollatorForAtomModeling(
+        tokenizer=tokenizer,
+        mam=False,
+        coords_perturb=False,
+        causal=False,
+        return_edge_indices=False,
+        pad=True,
+    )
+    size = torch.randint(4, 16, (10,)).tolist()
+    dataset = [
+        {
+            "input_ids": torch.randint(0, 123, (size[i],)).tolist(),
+            "coords": torch.randint(0, 123, (size[i], 3)).tolist(),
+        }
+        for i in range(10)
+    ]
+    batch = data_collator(dataset)
+    assert len(batch["input_ids"]) == 10
7 changes: 5 additions & 2 deletions tests/atomgen/data/test_tokenizer.py
@@ -1,7 +1,10 @@
"""Test AtomTokenizer."""
from atomgen.data.tokenizer import AtomTokenizer


def test_tokenizer():
    """Test AtomTokenizer."""
    tokenizer = AtomTokenizer(vocab_file="atomgen/data/tokenizer.json")
-    text = "BaCCHeNNN"
+    text = "MgCCHeNNN"
    tokens = tokenizer.tokenize(text)
-    assert tokens == ["Ba", "C", "C", "He", "N", "N", "N"]
+    assert tokens == ["Mg", "C", "C", "He", "N", "N", "N"]
