Skip to content

Commit

Permalink
Started refactoring MBTR.
Browse files Browse the repository at this point in the history
  • Loading branch information
lauri-codes committed Aug 13, 2024
1 parent 893722c commit 87b976f
Show file tree
Hide file tree
Showing 10 changed files with 2,148 additions and 742 deletions.
856 changes: 124 additions & 732 deletions dscribe/descriptors/mbtr.py

Large diffs are not rendered by default.

953 changes: 953 additions & 0 deletions dscribe/descriptors/mbtr_old.py

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions dscribe/ext/constants.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#ifndef CONSTANTS_H
#define CONSTANTS_H
#define PI 3.1415926535897932384626433832795028841971693993751058209749445923078164062
#define SQRT2 sqrt(2.0)
#define SQRT2PI sqrt(2.0 * PI)
#endif
2 changes: 1 addition & 1 deletion dscribe/ext/descriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class Descriptor {
Descriptor(bool periodic, string average="", double cutoff=0);
const bool periodic;
const string average;
const double cutoff;
double cutoff;
};

#endif
52 changes: 44 additions & 8 deletions dscribe/ext/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ limitations under the License.
#include "coulombmatrix.h"
#include "soap.h"
#include "acsf.h"
#include "mbtr.h"
#include "mbtr2.h"
#include "geometry.h"

namespace py = pybind11;
Expand Down Expand Up @@ -110,13 +110,49 @@ PYBIND11_MODULE(ext, m) {
));

// MBTR
py::class_<MBTR>(m, "MBTRWrapper")
.def(py::init< map<int,int>, int , vector<vector<int>> >())
.def("get_k1", &MBTR::getK1)
.def("get_k2", &MBTR::getK2)
.def("get_k3", &MBTR::getK3)
.def("get_k2_local", &MBTR::getK2Local)
.def("get_k3_local", &MBTR::getK3Local);
// py::class_<MBTR>(m, "MBTRWrapper")
// .def(py::init< map<int,int>, int , vector<vector<int>> >())
// .def("get_k1", &MBTR::getK1)
// .def("get_k2", &MBTR::getK2)
// .def("get_k3", &MBTR::getK3)
// .def("get_k2_local", &MBTR::getK2Local)
// .def("get_k3_local", &MBTR::getK3Local);
// MBTR
py::class_<MBTR>(m, "MBTR")
.def(py::init<py::dict, py::dict, py::dict, bool, string, py::array_t<int>, bool>())
.def("create", overload_cast_<py::array_t<double>, py::array_t<double>, py::array_t<int>, py::array_t<double>, py::array_t<bool>>()(&DescriptorGlobal::create))
.def("get_number_of_features", &MBTR::get_number_of_features)
.def("get_location", overload_cast_<int>()(&MBTR::get_location))
.def("get_location", overload_cast_<int, int>()(&MBTR::get_location))
.def("get_location", overload_cast_<int, int, int>()(&MBTR::get_location))
.def_property("geometry", &MBTR::get_geometry, &MBTR::set_geometry)
.def_property("grid", &MBTR::get_grid, &MBTR::set_grid)
.def_property("weighting", &MBTR::get_weighting, &MBTR::set_weighting)
.def_property_readonly("k", &MBTR::get_k)
.def_property("species", &MBTR::get_species, &MBTR::set_species)
.def_property("normalization", &MBTR::get_normalization, &MBTR::set_normalization)
.def_property("normalize_gaussians", &MBTR::get_normalize_gaussians, &MBTR::set_normalize_gaussians)
.def_property_readonly("species_index_map", &MBTR::get_species_index_map)
.def("derivatives_numerical", &MBTR::derivatives_numerical)
.def(py::pickle(
[](const MBTR &p) {
return py::make_tuple(p.geometry, p.grid, p.weighting, p.normalize_gaussians, p.normalization, p.species, p.periodic);
},
[](py::tuple t) {
if (t.size() != 7)
throw std::runtime_error("Invalid state!");
MBTR p(
t[0].cast<py::dict>(),
t[1].cast<py::dict>(),
t[2].cast<py::dict>(),
t[3].cast<bool>(),
t[4].cast<string>(),
t[5].cast<py::array_t<int>>(),
t[6].cast<bool>()
);
return p;
}
));

// CellList
py::class_<CellList>(m, "CellList")
Expand Down
51 changes: 51 additions & 0 deletions dscribe/ext/geometry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,57 @@ inline double norm(const vector<double>& a) {
return sqrt(accum);
};

System::System(
py::array_t<double> positions,
py::array_t<int> atomic_numbers,
bool extra
)
: positions(positions)
, atomic_numbers(atomic_numbers)
{
if (!extra) { return; }

// Create the default set of interactive atoms, which encompasses the whole
// system
unordered_set<int> interactive_atoms = unordered_set<int>();
int n_atoms = atomic_numbers.size();
for (int i = 0; i < n_atoms; ++i) {
interactive_atoms.insert(i);
}
this->interactive_atoms = interactive_atoms;

// Create the default cell indices
py::array_t<int> cell_indices({n_atoms});
auto cell_indices_mu = cell_indices.mutable_unchecked<1>();
for (int i = 0; i < n_atoms; ++i) {
cell_indices_mu(i) = 0;
}
this->cell_indices = cell_indices;

// Create the default indices
py::array_t<int> indices({uint(n_atoms)});
auto indices_mu = indices.mutable_unchecked<1>();
for (int i = 0; i < n_atoms; ++i) {
indices_mu(i) = i;
}
this->indices = indices;
}

System::System(
py::array_t<double> positions,
py::array_t<int> atomic_numbers,
py::array_t<int> indices,
py::array_t<int> cell_indices,
unordered_set<int> interactive_atoms
)
: positions(positions)
, atomic_numbers(atomic_numbers)
, indices(indices)
, cell_indices(cell_indices)
, interactive_atoms(interactive_atoms)
{
}

ExtendedSystem extend_system(
py::array_t<double> positions,
py::array_t<int> atomic_numbers,
Expand Down
42 changes: 42 additions & 0 deletions dscribe/ext/geometry.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,48 @@ struct ExtendedSystem {
py::array_t<int> indices;
};

class System {
public:
System(
py::array_t<double> positions,
py::array_t<int> atomic_numbers,
bool extra = true
);
System(
py::array_t<double> positions,
py::array_t<int> atomic_numbers,
py::array_t<int> indices,
py::array_t<int> cell_indices,
unordered_set<int> interactive_atoms
);
py::array_t<double> positions;
py::array_t<int> atomic_numbers;
/**
* Indices is a one-dimensional array that links each atom in the system
* into an index in the original, non-repeated system.
*/
py::array_t<int> indices;
/**
* Cell indices is a {n_atoms, 3} array that links each atom in the
* system into the index of a repeated cell. For non-extended systems
* all atoms are always tied to cell with index (0, 0, 0), but for
* extended atoms the index will vary.
*/
py::array_t<int> cell_indices;
/**
* Interactive atoms contains the indices of the interacting atoms in
* the system. Interacting atoms are the ones which will act as local
* centers when creating a descriptor.
*/
unordered_set<int> interactive_atoms;

py::array_t<double> get_positions() {return this->positions;};
py::array_t<int> get_atomic_numbers() {return this->atomic_numbers;};
py::array_t<int> get_indices() {return this->indices;};
py::array_t<int> get_cell_indices() {return this->cell_indices;};
unordered_set<int> get_interactive_atoms() {return this->interactive_atoms;};
};

inline vector<double> cross(const vector<double>& a, const vector<double>& b);
inline double dot(const vector<double>& a, const vector<double>& b);
inline double norm(const vector<double>& a);
Expand Down
Loading

0 comments on commit 87b976f

Please sign in to comment.