Skip to content

Commit

Permalink
MAINT: Parameters of cg solver in scipy changed with version 1.12.0
Browse files Browse the repository at this point in the history
  • Loading branch information
pastewka committed Jul 16, 2024
1 parent 6b6544a commit f0112ca
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 10 deletions.
5 changes: 3 additions & 2 deletions matscipy/calculators/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
nonaffine_elastic_contribution,
)

from ..numerical import numerical_nonaffine_forces

from ..compat import compat_cg_parameters
from ..numpy_tricks import mabincount


Expand Down Expand Up @@ -368,6 +367,8 @@ def get_non_affine_contribution_to_elastic_constants(self, atoms, eigenvalues=No
"This function is deprecated and will be removed in the future. Use 'elasticity.nonaffine_elastic_contribution' instead.",
DeprecationWarning)

cg_parameters = compat_cg_parameters(cg_parameters)

nat = len(atoms)

calc = self
Expand Down
13 changes: 13 additions & 0 deletions matscipy/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import scipy
from packaging.version import Version


def compat_cg_parameters(cg_parameters):
if Version(scipy.__version__) >= Version('1.12.0'):
return cg_parameters
else:
cg_parameters = cg_parameters.copy()
if 'rtol' in cg_parameters:
cg_parameters['tol'] = cg_parameters['rtol']
del cg_parameters['rtol']
return cg_parameters
4 changes: 4 additions & 0 deletions matscipy/elasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import ase.units as units
from ase.atoms import Atoms

from .compat import compat_cg_parameters

###

# The indices of the full stiffness matrix of (orthorhombic) interest
Expand Down Expand Up @@ -1335,6 +1337,8 @@ def _sym(C_abab):
symmetry_group = [(0, 1, 2, 3), (1, 0, 2, 3), (0, 1, 3, 2), (1, 0, 3, 2)]
return 0.25 * np.add.reduce([C_abab.transpose(s) for s in symmetry_group])

cg_parameters = compat_cg_parameters(cg_parameters)

nat = len(atoms)
naforces_icab = atoms.calc.get_property('nonaffine_forces')

Expand Down
8 changes: 4 additions & 4 deletions matscipy/io/opls.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import copy
import re

from looseversion import LooseVersion
from packaging.version import Version

import numpy as np
import ase
Expand Down Expand Up @@ -359,7 +359,7 @@ def write_lammps_atoms(prefix, atoms, units='metal'):
fileobj.write('%d dihedral types\n' % (len(dtypes)))

# cell
if LooseVersion(ase_version_str) > LooseVersion('3.11.0'):
if Version(ase_version_str) > Version('3.11.0'):
p = ase.calculators.lammpsrun.Prism(atoms.get_cell())
else:
p = ase.calculators.lammpsrun.prism(atoms.get_cell())
Expand All @@ -385,9 +385,9 @@ def write_lammps_atoms(prefix, atoms, units='metal'):
molid = [1] * len(atoms)

pos = ase.calculators.lammpsrun.convert(atoms.get_positions(), 'distance', 'ASE', units)
if LooseVersion(ase_version_str) > LooseVersion('3.17.0'):
if Version(ase_version_str) > Version('3.17.0'):
positions_lammps_str = p.vector_to_lammps(pos).astype(str)
elif LooseVersion(ase_version_str) > LooseVersion('3.13.0'):
elif Version(ase_version_str) > Version('3.13.0'):
positions_lammps_str = p.positions_to_lammps_strs(pos)
else:
positions_lammps_str = map(p.pos_to_lammps_str, pos)
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ requires-python = ">=3.8.0"
dynamic = ["version"]
dependencies = [
"numpy>=1.16.0",
"scipy>=1.12.0", # rtol in scipy cg
"scipy>=1.2.3",
"ase>=3.16.0",
"looseversion"
"packaging"
]

[project.optional-dependencies]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_electrochemistry_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import tempfile
import unittest

from looseversion import LooseVersion
from packaging.version import Version


class ElectrochemistryCliTest(matscipytest.MatSciPyTestCase):
Expand Down Expand Up @@ -116,7 +116,7 @@ def test_c2d_input_format_txt_output_format_xyz(self):
== self.ref_xyz.get_initial_charges() ).all() )
self.assertTrue( ( xyz.cell == self.ref_xyz.cell ).all() )

@unittest.skipUnless(LooseVersion(ase.__version__) > LooseVersion('3.19.0'),
@unittest.skipUnless(Version(ase.__version__) > Version('3.19.0'),
""" LAMMPS data file won't work for ASE version up until 3.18.1,
LAMMPS data file input broken in ASE 3.19.0, skipped""")
def test_c2d_input_format_npz_output_format_lammps(self):
Expand Down

0 comments on commit f0112ca

Please sign in to comment.