Skip to content

Commit

Permalink
ruff select PERF SLOT
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Jun 21, 2023
1 parent 2ca8e80 commit 755e17c
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.272
rev: v0.0.273
hooks:
- id: ruff
args: [--fix]
Expand Down
2 changes: 1 addition & 1 deletion chgnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def __init__(

self.keys = []
for mp_id, dic in self.labels.items():
for graph_id, _ in dic.items():
for graph_id in dic:
self.keys.append((mp_id, graph_id))
random.shuffle(self.keys)
print(f"{len(self.labels)} mp_ids, {len(self)} frames imported")
Expand Down
2 changes: 1 addition & 1 deletion chgnet/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class Node:
"""A node in a graph."""

def __init__(self, index: int, info: dict = None) -> None:
def __init__(self, index: int, info: dict | None = None) -> None:
"""Initialize a Node.
Args:
Expand Down
6 changes: 5 additions & 1 deletion chgnet/model/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,11 @@ class GaussianExpansion(nn.Module):
"""

def __init__(
self, min: float = 0, max: float = 5, step: float = 0.5, var: float = None
self,
min: float = 0,
max: float = 5,
step: float = 0.5,
var: float | None = None,
) -> None:
"""Gaussian Expansion
expand a scalar feature to a soft-one-hot feature vector.
Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
class CHGNetCalculator(Calculator):
"""CHGNet Calculator for ASE applications."""

implemented_properties = ["energy", "forces", "stress", "magmoms"]
implemented_properties = ("energy", "forces", "stress", "magmoms")

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def find_activation(name: str) -> nn.Module:
raise NotImplementedError from exc


def find_normalization(name: str, dim: int = None) -> nn.Module | None:
def find_normalization(name: str, dim: int | None = None) -> nn.Module | None:
"""Return an normalization function using name."""
if name is None:
return None
Expand Down
2 changes: 1 addition & 1 deletion chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def save(self, filename: str = "training_result.pth.tar") -> None:
torch.save(state, filename)

def save_checkpoint(
self, epoch: int, mae_error: dict, save_dir: str = None
self, epoch: int, mae_error: dict, save_dir: str | None = None
) -> None:
"""Function to save CHGNet trained weights after each epoch.
Expand Down
54 changes: 28 additions & 26 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,32 +49,34 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] }
target-version = "py38"
line-length = 95
select = [
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"D", # pydocstyle
"E", # pycodestyle error
"EXE", # flake8-executable
"F", # pyflakes
"FLY", # flynt
"I", # isort
"ICN", # flake8-import-conventions
"ISC", # flake8-implicit-str-concat
"PD", # pandas-vet
"PIE", # flake8-pie
"PL", # pylint
"PT", # flake8-pytest-style
"PYI", # flakes8-pyi
"Q", # flake8-quotes
"RET", # flake8-return
"RSE", # flake8-raise
"RUF", # Ruff-specific rules
"SIM", # flake8-simplify
"TCH", # flake8-type-checking
"TID", # tidy imports
"TID", # flake8-tidy-imports
"UP", # pyupgrade
"W", # pycodestyle warning
"YTT", # flake8-2020
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"D", # pydocstyle
"E", # pycodestyle error
"EXE", # flake8-executable
"F", # pyflakes
"FLY", # flynt
"I", # isort
"ICN", # flake8-import-conventions
"ISC", # flake8-implicit-str-concat
"PD", # pandas-vet
"PERF", # perflint
"PIE", # flake8-pie
"PL", # pylint
"PT", # flake8-pytest-style
"PYI", # flakes8-pyi
"Q", # flake8-quotes
"RET", # flake8-return
"RSE", # flake8-raise
"RUF", # Ruff-specific rules
"SIM", # flake8-simplify
"SLOT", # flakes8-slot
"TCH", # flake8-type-checking
"TID", # tidy imports
"TID", # flake8-tidy-imports
"UP", # pyupgrade
"W", # pycodestyle warning
"YTT", # flake8-2020
]
ignore = [
"B019", # Use of functools.lru_cache on methods can lead to memory leaks
Expand Down
9 changes: 5 additions & 4 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
forces.append(np.random.random([2, 3]))
stresses.append(np.random.random([3, 3]))
magmoms.append(np.random.random(2))

data = StructureData(
structures=structures,
energies=energies,
Expand Down Expand Up @@ -48,8 +49,8 @@ def test_trainer(tmp_path) -> None:
assert test_dir.is_dir(), "Training dir was not created"

saved_weight = [f for f in test_dir.iterdir() if f.name.startswith("epoch")]
bestE_weight = [f for f in test_dir.iterdir() if f.name.startswith("bestE")]
bestF_weight = [f for f in test_dir.iterdir() if f.name.startswith("bestF")]
best_e_weight = [f for f in test_dir.iterdir() if f.name.startswith("bestE")]
best_f_weight = [f for f in test_dir.iterdir() if f.name.startswith("bestF")]
assert len(saved_weight) == 1
assert len(bestE_weight) == 1
assert len(bestF_weight) == 1
assert len(best_e_weight) == 1
assert len(best_f_weight) == 1

0 comments on commit 755e17c

Please sign in to comment.