diff --git a/deepmd/main.py b/deepmd/main.py index 5dab029d83..870a04a088 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -240,7 +240,7 @@ def main_parser() -> argparse.ArgumentParser: "--output", type=str, default="out.json", - help="(Supported backend: TensorFlow) The output file of the parameters used in training.", + help="The output file of the parameters used in training.", ) parser_train.add_argument( "--skip-neighbor-stat", diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 023bc5305e..736e8dde09 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -78,10 +78,6 @@ def get_trainer( shared_links=None, ): multi_task = "model_dict" in config.get("model", {}) - # argcheck - if not multi_task: - config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json") - config = normalize(config) # Initialize DDP local_rank = os.environ.get("LOCAL_RANK") @@ -236,6 +232,11 @@ def train(FLAGS): if multi_task: config["model"], shared_links = preprocess_shared_params(config["model"]) + # argcheck + if not multi_task: + config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json") + config = normalize(config) + # do neighbor stat if not FLAGS.skip_neighbor_stat: log.info( @@ -257,6 +258,9 @@ def train(FLAGS): fake_global_jdata, config["model"]["model_dict"][model_item] ) + with open(FLAGS.output, "w") as fp: + json.dump(config, fp, indent=4) + trainer = get_trainer( config, FLAGS.init_model, diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index a1ca81bc1b..7c4c103509 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -54,6 +54,7 @@ def reinit_pair_exclude( # export public methods that are not abstract get_nsel = torch.jit.export(BaseAtomicModel_.get_nsel) get_nnei = torch.jit.export(BaseAtomicModel_.get_nnei) + get_ntypes = torch.jit.export(BaseAtomicModel_.get_ntypes) @torch.jit.export def get_model_def_script(self) -> str: diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 196e1497b1..4020f5edc9 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -301,7 +301,6 @@ def output_type_cast( ) return model_ret - @torch.jit.export def format_nlist( self, extended_coord: torch.Tensor, diff --git a/doc/model/dprc.md b/doc/model/dprc.md index 48e18e8d89..ac1ab0e261 100644 --- a/doc/model/dprc.md +++ b/doc/model/dprc.md @@ -1,7 +1,7 @@ -# Deep Potential - Range Correction (DPRc) {{ tensorflow_icon }} +# Deep Potential - Range Correction (DPRc) {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }} +**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DPModel {{ dpmodel_icon }} ::: Deep Potential - Range Correction (DPRc) is designed to combine with QM/MM method, and corrects energies from a low-level QM/MM method to a high-level QM/MM method: @@ -62,6 +62,10 @@ In a DPRc model, QM atoms and MM atoms have different atom types. Assuming we ha As described in the paper, the DPRc model only corrects $E_\text{QM}$ and $E_\text{QM/MM}$ within the cutoff, so we use a hybrid descriptor to describe them separatedly: +::::{tab-set} + +:::{tab-item} TensorFlow {{ tensorflow_icon }} + ```json "descriptor" :{ "type": "hybrid", @@ -91,6 +95,45 @@ As described in the paper, the DPRc model only corrects $E_\text{QM}$ and $E_\te } ``` +::: + +:::{tab-item} PyTorch {{ pytorch_icon }} + +```json +"descriptor" :{ + "type": "hybrid", + "list" : [ + { + "type": "se_e2_a", + "sel": [6, 11, 0, 6, 0, 1], + "rcut_smth": 1.00, + "rcut": 9.00, + "neuron": [12, 25, 50], + "exclude_types": [[2, 2], [2, 4], [4, 4], [0, 2], [0, 4], [1, 2], [1, 4], [3, 2], [3, 4], [5, 2], [5, 4]], + "axis_neuron": 12, + "type_one_side": true, + "_comment": " QM/QM interaction" + }, + { + "type": "se_e2_a", + "sel": [6, 11, 100, 6, 50, 1], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [12, 25, 50], + "exclude_types": [[0, 0], [0, 1], [0, 3], [0, 5], [1, 1], [1, 3], [1, 5], [3, 3], [3, 5], [5, 5], [2, 2], [2, 4], [4, 4]], + "axis_neuron": 12, + "set_davg_zero": true, + "type_one_side": true, + "_comment": " QM/MM interaction" + } + ] +} +``` + +::: + +:::: + {ref}`exclude_types ` can be generated by the following Python script: ```py from itertools import combinations_with_replacement, product @@ -131,6 +174,10 @@ The DPRc model has the best practices with the [AMBER](../third-party/out-of-dee ## Pairwise DPRc +:::{note} +**Supported backends**: TensorFlow {{ tensorflow_icon }} +::: + If one wants to correct from a low-level method into a full DFT level, and the system is too large to do full DFT calculation, one may try the experimental pairwise DPRc model. In a pairwise DPRc model, the total energy is divided into QM internal energy and the sum of QM/MM energy for each MM residue $l$: diff --git a/source/api_cc/include/DeepPotPT.h b/source/api_cc/include/DeepPotPT.h index d50d338d33..a7fc910b46 100644 --- a/source/api_cc/include/DeepPotPT.h +++ b/source/api_cc/include/DeepPotPT.h @@ -1,10 +1,10 @@ // SPDX-License-Identifier: LGPL-3.0-or-later #pragma once +#include #include #include "DeepPot.h" -#include "commonPT.h" namespace deepmd { /** @@ -106,7 +106,7 @@ class DeepPotPT : public DeepPotBase { const std::vector& coord, const std::vector& atype, const std::vector& box, - // const int nghost, + const int nghost, const InputNlist& lmp_list, const int& ago, const std::vector& fparam = std::vector(), @@ -322,7 +322,7 @@ class DeepPotPT : public DeepPotBase { // copy neighbor list info from host torch::jit::script::Module module; double rcut; - NeighborListDataPT nlist_data; + NeighborListData nlist_data; int max_num_neighbors; int gpu_id; bool gpu_enabled; diff --git a/source/api_cc/include/common.h b/source/api_cc/include/common.h index 72382169f8..4743336e0c 100644 --- a/source/api_cc/include/common.h +++ b/source/api_cc/include/common.h @@ -32,6 +32,7 @@ struct NeighborListData { void shuffle(const deepmd::AtomMap& map); void shuffle_exclude_empty(const std::vector& fwd_map); void make_inlist(InputNlist& inlist); + void padding(); }; /** diff --git a/source/api_cc/include/commonPT.h b/source/api_cc/include/commonPT.h deleted file mode 100644 index 57ffd5b295..0000000000 --- a/source/api_cc/include/commonPT.h +++ /dev/null @@ -1,24 +0,0 @@ -// SPDX-License-Identifier: LGPL-3.0-or-later -#include - -#include -#include -#include -#include - -#include "neighbor_list.h" -namespace deepmd { -struct NeighborListDataPT { - /// Array stores the core region atom's index - std::vector ilist; - /// Array stores the core region atom's neighbor index - std::vector jlist; - /// Array stores the number of neighbors of core region atoms - std::vector numneigh; - /// Array stores the the location of the first neighbor of core region atoms - std::vector firstneigh; - - public: - void copy_from_nlist(const InputNlist& inlist, int& max_num_neighbors); -}; -} // namespace deepmd diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 9514a9769c..919d690bed 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -4,6 +4,17 @@ #include "common.h" using namespace deepmd; +torch::Tensor createNlistTensor(const std::vector>& data) { + std::vector row_tensors; + + for (const auto& row : data) { + torch::Tensor row_tensor = torch::tensor(row, torch::kInt32).unsqueeze(0); + row_tensors.push_back(row_tensor); + } + + torch::Tensor tensor = torch::cat(row_tensors, 0).unsqueeze(0); + return tensor; +} DeepPotPT::DeepPotPT() : inited(false) {} DeepPotPT::DeepPotPT(const std::string& model, const int& gpu_rank, @@ -60,7 +71,7 @@ void DeepPotPT::init(const std::string& model, auto rcut_ = module.run_method("get_rcut").toDouble(); rcut = static_cast(rcut_); - ntypes = 0; + ntypes = module.run_method("get_ntypes").toInt(); ntypes_spin = 0; dfparam = module.run_method("get_dim_fparam").toInt(); daparam = module.run_method("get_dim_aparam").toInt(); @@ -78,6 +89,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, const std::vector& coord, const std::vector& atype, const std::vector& box, + const int nghost, const InputNlist& lmp_list, const int& ago, const std::vector& fparam, @@ -86,7 +98,6 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, if (!gpu_enabled) { device = torch::Device(torch::kCPU); } - std::vector coord_wrapped = coord; int natoms = atype.size(); auto options = torch::TensorOptions().dtype(torch::kFloat64); torch::ScalarType floatType = torch::kFloat64; @@ -96,18 +107,29 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, } auto int_options = torch::TensorOptions().dtype(torch::kInt64); auto int32_options = torch::TensorOptions().dtype(torch::kInt32); + + // select real atoms + std::vector dcoord, dforce, aparam_, datom_energy, datom_virial; + std::vector datype, fwd_map, bkw_map; + int nghost_real, nall_real, nloc_real; + int nall = natoms; + select_real_atoms_coord(dcoord, datype, aparam_, nghost_real, fwd_map, + bkw_map, nall_real, nloc_real, coord, atype, aparam, + nghost, ntypes, 1, daparam, nall, aparam_nall); + std::cout << datype.size() << std::endl; + std::vector coord_wrapped = dcoord; at::Tensor coord_wrapped_Tensor = - torch::from_blob(coord_wrapped.data(), {1, natoms, 3}, options) + torch::from_blob(coord_wrapped.data(), {1, nall_real, 3}, options) .to(device); - std::vector atype_64(atype.begin(), atype.end()); + std::vector atype_64(datype.begin(), datype.end()); at::Tensor atype_Tensor = - torch::from_blob(atype_64.data(), {1, natoms}, int_options).to(device); + torch::from_blob(atype_64.data(), {1, nall_real}, int_options).to(device); if (ago == 0) { - nlist_data.copy_from_nlist(lmp_list, max_num_neighbors); + nlist_data.copy_from_nlist(lmp_list); + nlist_data.shuffle_exclude_empty(fwd_map); + nlist_data.padding(); } - at::Tensor firstneigh = - torch::from_blob(nlist_data.jlist.data(), - {1, lmp_list.inum, max_num_neighbors}, int32_options); + at::Tensor firstneigh = createNlistTensor(nlist_data.jlist); firstneigh_tensor = firstneigh.to(torch::kInt64).to(device); bool do_atom_virial_tensor = true; c10::optional optional_tensor; @@ -119,13 +141,13 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, .to(device); } c10::optional aparam_tensor; - if (!aparam.empty()) { - aparam_tensor = - torch::from_blob(const_cast(aparam.data()), - {1, lmp_list.inum, - static_cast(aparam.size()) / lmp_list.inum}, - options) - .to(device); + if (!aparam_.empty()) { + aparam_tensor = torch::from_blob( + const_cast(aparam_.data()), + {1, lmp_list.inum, + static_cast(aparam_.size()) / lmp_list.inum}, + options) + .to(device); } c10::Dict outputs = module @@ -145,14 +167,15 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, torch::Tensor flat_atom_energy_ = atom_energy_.toTensor().view({-1}).to(floatType); torch::Tensor cpu_atom_energy_ = flat_atom_energy_.to(torch::kCPU); - atom_energy.resize(natoms, 0.0); // resize to nall to be consistenet with TF. - atom_energy.assign( + datom_energy.resize(nall_real, + 0.0); // resize to nall to be consistenet with TF. + datom_energy.assign( cpu_atom_energy_.data_ptr(), cpu_atom_energy_.data_ptr() + cpu_atom_energy_.numel()); torch::Tensor flat_force_ = force_.toTensor().view({-1}).to(floatType); torch::Tensor cpu_force_ = flat_force_.to(torch::kCPU); - force.assign(cpu_force_.data_ptr(), - cpu_force_.data_ptr() + cpu_force_.numel()); + dforce.assign(cpu_force_.data_ptr(), + cpu_force_.data_ptr() + cpu_force_.numel()); torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType); torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU); virial.assign(cpu_virial_.data_ptr(), @@ -160,9 +183,20 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, torch::Tensor flat_atom_virial_ = atom_virial_.toTensor().view({-1}).to(floatType); torch::Tensor cpu_atom_virial_ = flat_atom_virial_.to(torch::kCPU); - atom_virial.assign( + datom_virial.assign( cpu_atom_virial_.data_ptr(), cpu_atom_virial_.data_ptr() + cpu_atom_virial_.numel()); + int nframes = 1; + // bkw map + force.resize(static_cast(nframes) * fwd_map.size() * 3); + atom_energy.resize(static_cast(nframes) * fwd_map.size()); + atom_virial.resize(static_cast(nframes) * fwd_map.size() * 9); + select_map(force, dforce, bkw_map, 3, nframes, fwd_map.size(), + nall_real); + select_map(atom_energy, datom_energy, bkw_map, 1, nframes, + fwd_map.size(), nall_real); + select_map(atom_virial, datom_virial, bkw_map, 9, nframes, + fwd_map.size(), nall_real); } template void DeepPotPT::compute>( std::vector& ener, @@ -173,6 +207,7 @@ template void DeepPotPT::compute>( const std::vector& coord, const std::vector& atype, const std::vector& box, + const int nghost, const InputNlist& lmp_list, const int& ago, const std::vector& fparam, @@ -186,6 +221,7 @@ template void DeepPotPT::compute>( const std::vector& coord, const std::vector& atype, const std::vector& box, + const int nghost, const InputNlist& lmp_list, const int& ago, const std::vector& fparam, @@ -353,7 +389,7 @@ void DeepPotPT::computew(std::vector& ener, const std::vector& fparam, const std::vector& aparam) { compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box, - inlist, ago, fparam, aparam); + nghost, inlist, ago, fparam, aparam); } void DeepPotPT::computew(std::vector& ener, std::vector& force, @@ -369,7 +405,7 @@ void DeepPotPT::computew(std::vector& ener, const std::vector& fparam, const std::vector& aparam) { compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box, - inlist, ago, fparam, aparam); + nghost, inlist, ago, fparam, aparam); } void DeepPotPT::computew_mixed_type(std::vector& ener, std::vector& force, diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc index d2923c8d9e..f104433468 100644 --- a/source/api_cc/src/common.cc +++ b/source/api_cc/src/common.cc @@ -293,6 +293,16 @@ void deepmd::NeighborListData::shuffle_exclude_empty( ilist = new_ilist; jlist = new_jlist; } +void deepmd::NeighborListData::padding() { + size_t max_length = 0; + for (const auto& row : jlist) { + max_length = std::max(max_length, row.size()); + } + + for (int i = 0; i < jlist.size(); i++) { + jlist[i].resize(max_length, -1); + } +} void deepmd::NeighborListData::make_inlist(InputNlist& inlist) { int nloc = ilist.size(); diff --git a/source/api_cc/src/commonPT.cc b/source/api_cc/src/commonPT.cc deleted file mode 100644 index 4ed3b21fe8..0000000000 --- a/source/api_cc/src/commonPT.cc +++ /dev/null @@ -1,23 +0,0 @@ -// SPDX-License-Identifier: LGPL-3.0-or-later -#ifdef BUILD_PYTORCH -#include "commonPT.h" -using namespace deepmd; -void NeighborListDataPT::copy_from_nlist(const InputNlist& inlist, - int& max_num_neighbors) { - int inum = inlist.inum; - ilist.resize(inum); - numneigh.resize(inum); - memcpy(&ilist[0], inlist.ilist, inum * sizeof(int)); - int* max_element = std::max_element(inlist.numneigh, inlist.numneigh + inum); - max_num_neighbors = *max_element; - unsigned long nlist_size = (unsigned long)inum * max_num_neighbors; - jlist.resize(nlist_size); - jlist.assign(nlist_size, -1); - for (int ii = 0; ii < inum; ++ii) { - int jnum = inlist.numneigh[ii]; - numneigh[ii] = inlist.numneigh[ii]; - memcpy(&jlist[(unsigned long)ii * max_num_neighbors], inlist.firstneigh[ii], - jnum * sizeof(int)); - } -} -#endif diff --git a/source/api_cc/tests/test_deeppot_pt.cc b/source/api_cc/tests/test_deeppot_pt.cc index e0e90ac75c..cc30e606c0 100644 --- a/source/api_cc/tests/test_deeppot_pt.cc +++ b/source/api_cc/tests/test_deeppot_pt.cc @@ -402,7 +402,6 @@ TYPED_TEST(TestInferDeepPotAPt, cpu_lmp_nlist_2rc) { } TYPED_TEST(TestInferDeepPotAPt, cpu_lmp_nlist_type_sel) { - GTEST_SKIP() << "Skipping this test for unsupported"; using VALUETYPE = TypeParam; std::vector& coord = this->coord; std::vector& atype = this->atype; @@ -465,7 +464,6 @@ TYPED_TEST(TestInferDeepPotAPt, cpu_lmp_nlist_type_sel) { } TYPED_TEST(TestInferDeepPotAPt, cpu_lmp_nlist_type_sel_atomic) { - GTEST_SKIP() << "Skipping this test for unsupported"; using VALUETYPE = TypeParam; std::vector& coord = this->coord; std::vector& atype = this->atype; diff --git a/source/tests/infer/deeppot_sea.pth b/source/tests/infer/deeppot_sea.pth index 98aaa8a2ad..c830f0df9e 100644 Binary files a/source/tests/infer/deeppot_sea.pth and b/source/tests/infer/deeppot_sea.pth differ diff --git a/source/tests/infer/fparam_aparam.pth b/source/tests/infer/fparam_aparam.pth index c433ced49b..703f7267be 100644 Binary files a/source/tests/infer/fparam_aparam.pth and b/source/tests/infer/fparam_aparam.pth differ