Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cc: refactor DataModifier for multiple-backend framework #3148

Merged
merged 4 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 88 additions & 36 deletions source/api_cc/include/DataModifier.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,92 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include "DeepPot.h"
#include <memory>

#include "common.h"

namespace deepmd {
/**
* @brief Dipole charge modifier. (Base class)
**/
class DipoleChargeModifierBase {
public:
/**
* @brief Dipole charge modifier without initialization.
**/
DipoleChargeModifierBase(){};
/**
* @brief Dipole charge modifier without initialization.
* @param[in] model The name of the frozen model file.
* @param[in] gpu_rank The GPU rank. Default is 0.
* @param[in] name_scope The name scope.
**/
DipoleChargeModifierBase(const std::string& model,
const int& gpu_rank = 0,
const std::string& name_scope = "");
virtual ~DipoleChargeModifierBase(){};
/**
* @brief Initialize the dipole charge modifier.
* @param[in] model The name of the frozen model file.
* @param[in] gpu_rank The GPU rank. Default is 0.
* @param[in] name_scope The name scope.
**/
virtual void init(const std::string& model,
const int& gpu_rank = 0,
const std::string& name_scope = "") = 0;
/**
* @brief Evaluate the force and virial correction by using this dipole charge
*modifier.
* @param[out] dfcorr_ The force correction on each atom.
* @param[out] dvcorr_ The virial correction.
* @param[in] dcoord_ The coordinates of atoms. The array should be of size
*natoms x 3.
* @param[in] datype_ The atom types. The list should contain natoms ints.
* @param[in] dbox The cell of the region. The array should be of size 9.
* @param[in] pairs The pairs of atoms. The list should contain npairs pairs
*of ints.
* @param[in] delef_ The electric field on each atom. The array should be of
*size natoms x 3.
* @param[in] nghost The number of ghost atoms.
* @param[in] lmp_list The neighbor list.
@{
**/
virtual void computew(std::vector<double>& dfcorr_,
std::vector<double>& dvcorr_,
const std::vector<double>& dcoord_,
const std::vector<int>& datype_,
const std::vector<double>& dbox,
const std::vector<std::pair<int, int>>& pairs,
const std::vector<double>& delef_,
const int nghost,
const InputNlist& lmp_list) = 0;
virtual void computew(std::vector<float>& dfcorr_,
std::vector<float>& dvcorr_,
const std::vector<float>& dcoord_,
const std::vector<int>& datype_,
const std::vector<float>& dbox,
const std::vector<std::pair<int, int>>& pairs,
const std::vector<float>& delef_,
const int nghost,
const InputNlist& lmp_list) = 0;
/** @} */
/**
* @brief Get cutoff radius.
* @return double cutoff radius.
*/
virtual double cutoff() const = 0;
/**
* @brief Get the number of atom types.
* @return int number of atom types.
*/
virtual int numb_types() const = 0;
/**
* @brief Get the list of sel types.
* @return The list of sel types.
*/
virtual std::vector<int> sel_types() const = 0;
};

/**
* @brief Dipole charge modifier.
**/
Expand Down Expand Up @@ -38,7 +121,6 @@ class DipoleChargeModifier {
**/
void print_summary(const std::string& pre) const;

public:
/**
* @brief Evaluate the force and virial correction by using this dipole charge
*modifier.
Expand Down Expand Up @@ -69,50 +151,20 @@ class DipoleChargeModifier {
* @brief Get cutoff radius.
* @return double cutoff radius.
*/
double cutoff() const {
assert(inited);
return rcut;
};
double cutoff() const;
/**
* @brief Get the number of atom types.
* @return int number of atom types.
*/
int numb_types() const {
assert(inited);
return ntypes;
};
int numb_types() const;
/**
* @brief Get the list of sel types.
* @return The list of sel types.
*/
std::vector<int> sel_types() const {
assert(inited);
return sel_type;
};
std::vector<int> sel_types() const;

private:
tensorflow::Session* session;
std::string name_scope, name_prefix;
int num_intra_nthreads, num_inter_nthreads;
tensorflow::GraphDef* graph_def;
bool inited;
double rcut;
int dtype;
double cell_size;
int ntypes;
std::string model_type;
std::vector<int> sel_type;
template <class VT>
VT get_scalar(const std::string& name) const;
template <class VT>
void get_vector(std::vector<VT>& vec, const std::string& name) const;
template <typename MODELTYPE, typename VALUETYPE>
void run_model(std::vector<VALUETYPE>& dforce,
std::vector<VALUETYPE>& dvirial,
tensorflow::Session* session,
const std::vector<std::pair<std::string, tensorflow::Tensor>>&
input_tensors,
const AtomMap& atommap,
const int nghost);
std::shared_ptr<deepmd::DipoleChargeModifierBase> dcm;
};
} // namespace deepmd
132 changes: 132 additions & 0 deletions source/api_cc/include/DataModifierTF.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include "DataModifier.h"
#include "common.h"

namespace deepmd {
/**
* @brief Dipole charge modifier.
**/
class DipoleChargeModifierTF : public DipoleChargeModifierBase {
public:
/**
* @brief Dipole charge modifier without initialization.
**/
DipoleChargeModifierTF();
/**
* @brief Dipole charge modifier without initialization.
* @param[in] model The name of the frozen model file.
* @param[in] gpu_rank The GPU rank. Default is 0.
* @param[in] name_scope The name scope.
**/
DipoleChargeModifierTF(const std::string& model,
const int& gpu_rank = 0,
const std::string& name_scope = "");
~DipoleChargeModifierTF();
/**
* @brief Initialize the dipole charge modifier.
* @param[in] model The name of the frozen model file.
* @param[in] gpu_rank The GPU rank. Default is 0.
* @param[in] name_scope The name scope.
**/
void init(const std::string& model,
const int& gpu_rank = 0,
const std::string& name_scope = "");

public:
/**
* @brief Evaluate the force and virial correction by using this dipole charge
*modifier.
* @param[out] dfcorr_ The force correction on each atom.
* @param[out] dvcorr_ The virial correction.
* @param[in] dcoord_ The coordinates of atoms. The array should be of size
*natoms x 3.
* @param[in] datype_ The atom types. The list should contain natoms ints.
* @param[in] dbox The cell of the region. The array should be of size 9.
* @param[in] pairs The pairs of atoms. The list should contain npairs pairs
*of ints.
* @param[in] delef_ The electric field on each atom. The array should be of
*size natoms x 3.
* @param[in] nghost The number of ghost atoms.
* @param[in] lmp_list The neighbor list.
**/
template <typename VALUETYPE>
void compute(std::vector<VALUETYPE>& dfcorr_,
std::vector<VALUETYPE>& dvcorr_,
const std::vector<VALUETYPE>& dcoord_,
const std::vector<int>& datype_,
const std::vector<VALUETYPE>& dbox,
const std::vector<std::pair<int, int>>& pairs,
const std::vector<VALUETYPE>& delef_,
const int nghost,
const InputNlist& lmp_list);
/**
* @brief Get cutoff radius.
* @return double cutoff radius.
*/
double cutoff() const {

Check warning on line 68 in source/api_cc/include/DataModifierTF.h

View check run for this annotation

Codecov / codecov/patch

source/api_cc/include/DataModifierTF.h#L68

Added line #L68 was not covered by tests
assert(inited);
return rcut;

Check warning on line 70 in source/api_cc/include/DataModifierTF.h

View check run for this annotation

Codecov / codecov/patch

source/api_cc/include/DataModifierTF.h#L70

Added line #L70 was not covered by tests
};
/**
* @brief Get the number of atom types.
* @return int number of atom types.
*/
int numb_types() const {

Check warning on line 76 in source/api_cc/include/DataModifierTF.h

View check run for this annotation

Codecov / codecov/patch

source/api_cc/include/DataModifierTF.h#L76

Added line #L76 was not covered by tests
assert(inited);
return ntypes;

Check warning on line 78 in source/api_cc/include/DataModifierTF.h

View check run for this annotation

Codecov / codecov/patch

source/api_cc/include/DataModifierTF.h#L78

Added line #L78 was not covered by tests
};
/**
* @brief Get the list of sel types.
* @return The list of sel types.
*/
std::vector<int> sel_types() const {
assert(inited);
return sel_type;
};
void computew(std::vector<double>& dfcorr_,
std::vector<double>& dvcorr_,
const std::vector<double>& dcoord_,
const std::vector<int>& datype_,
const std::vector<double>& dbox,
const std::vector<std::pair<int, int>>& pairs,
const std::vector<double>& delef_,
const int nghost,
const InputNlist& lmp_list);
void computew(std::vector<float>& dfcorr_,
std::vector<float>& dvcorr_,
const std::vector<float>& dcoord_,
const std::vector<int>& datype_,
const std::vector<float>& dbox,
const std::vector<std::pair<int, int>>& pairs,
const std::vector<float>& delef_,
const int nghost,
const InputNlist& lmp_list);

private:
tensorflow::Session* session;
std::string name_scope, name_prefix;
int num_intra_nthreads, num_inter_nthreads;
tensorflow::GraphDef* graph_def;
bool inited;
double rcut;
int dtype;
double cell_size;
int ntypes;
std::string model_type;
std::vector<int> sel_type;
template <class VT>
VT get_scalar(const std::string& name) const;
template <class VT>
void get_vector(std::vector<VT>& vec, const std::string& name) const;
template <typename MODELTYPE, typename VALUETYPE>
void run_model(std::vector<VALUETYPE>& dforce,
std::vector<VALUETYPE>& dvirial,
tensorflow::Session* session,
const std::vector<std::pair<std::string, tensorflow::Tensor>>&
input_tensors,
const AtomMap& atommap,
const int nghost);
};
} // namespace deepmd
Loading
Loading