Skip to content

Commit

Permalink
Merge branch 'devel' into dipole_polar_train
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd authored Mar 1, 2024
2 parents e0cd84a + ee8b82b commit a8f58a7
Show file tree
Hide file tree
Showing 14 changed files with 132 additions and 83 deletions.
2 changes: 1 addition & 1 deletion deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def main_parser() -> argparse.ArgumentParser:
"--output",
type=str,
default="out.json",
help="(Supported backend: TensorFlow) The output file of the parameters used in training.",
help="The output file of the parameters used in training.",
)
parser_train.add_argument(
"--skip-neighbor-stat",
Expand Down
12 changes: 8 additions & 4 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,6 @@ def get_trainer(
shared_links=None,
):
multi_task = "model_dict" in config.get("model", {})
# argcheck
if not multi_task:
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config)

# Initialize DDP
local_rank = os.environ.get("LOCAL_RANK")
Expand Down Expand Up @@ -236,6 +232,11 @@ def train(FLAGS):
if multi_task:
config["model"], shared_links = preprocess_shared_params(config["model"])

# argcheck
if not multi_task:
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config)

# do neighbor stat
if not FLAGS.skip_neighbor_stat:
log.info(
Expand All @@ -257,6 +258,9 @@ def train(FLAGS):
fake_global_jdata, config["model"]["model_dict"][model_item]
)

with open(FLAGS.output, "w") as fp:
json.dump(config, fp, indent=4)

trainer = get_trainer(
config,
FLAGS.init_model,
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def reinit_pair_exclude(
# export public methods that are not abstract
get_nsel = torch.jit.export(BaseAtomicModel_.get_nsel)
get_nnei = torch.jit.export(BaseAtomicModel_.get_nnei)
get_ntypes = torch.jit.export(BaseAtomicModel_.get_ntypes)

@torch.jit.export
def get_model_def_script(self) -> str:
Expand Down
1 change: 0 additions & 1 deletion deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,6 @@ def output_type_cast(
)
return model_ret

@torch.jit.export
def format_nlist(
self,
extended_coord: torch.Tensor,
Expand Down
51 changes: 49 additions & 2 deletions doc/model/dprc.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Deep Potential - Range Correction (DPRc) {{ tensorflow_icon }}
# Deep Potential - Range Correction (DPRc) {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DPModel {{ dpmodel_icon }}
:::

Deep Potential - Range Correction (DPRc) is designed to combine with QM/MM method, and corrects energies from a low-level QM/MM method to a high-level QM/MM method:
Expand Down Expand Up @@ -62,6 +62,10 @@ In a DPRc model, QM atoms and MM atoms have different atom types. Assuming we ha

As described in the paper, the DPRc model only corrects $E_\text{QM}$ and $E_\text{QM/MM}$ within the cutoff, so we use a hybrid descriptor to describe them separatedly:

::::{tab-set}

:::{tab-item} TensorFlow {{ tensorflow_icon }}

```json
"descriptor" :{
"type": "hybrid",
Expand Down Expand Up @@ -91,6 +95,45 @@ As described in the paper, the DPRc model only corrects $E_\text{QM}$ and $E_\te
}
```

:::

:::{tab-item} PyTorch {{ pytorch_icon }}

```json
"descriptor" :{
"type": "hybrid",
"list" : [
{
"type": "se_e2_a",
"sel": [6, 11, 0, 6, 0, 1],
"rcut_smth": 1.00,
"rcut": 9.00,
"neuron": [12, 25, 50],
"exclude_types": [[2, 2], [2, 4], [4, 4], [0, 2], [0, 4], [1, 2], [1, 4], [3, 2], [3, 4], [5, 2], [5, 4]],
"axis_neuron": 12,
"type_one_side": true,
"_comment": " QM/QM interaction"
},
{
"type": "se_e2_a",
"sel": [6, 11, 100, 6, 50, 1],
"rcut_smth": 0.50,
"rcut": 6.00,
"neuron": [12, 25, 50],
"exclude_types": [[0, 0], [0, 1], [0, 3], [0, 5], [1, 1], [1, 3], [1, 5], [3, 3], [3, 5], [5, 5], [2, 2], [2, 4], [4, 4]],
"axis_neuron": 12,
"set_davg_zero": true,
"type_one_side": true,
"_comment": " QM/MM interaction"
}
]
}
```

:::

::::

{ref}`exclude_types <model/descriptor[se_a_ebd_v2]/exclude_types>` can be generated by the following Python script:
```py
from itertools import combinations_with_replacement, product
Expand Down Expand Up @@ -131,6 +174,10 @@ The DPRc model has the best practices with the [AMBER](../third-party/out-of-dee

## Pairwise DPRc

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}
:::

If one wants to correct from a low-level method into a full DFT level, and the system is too large to do full DFT calculation, one may try the experimental pairwise DPRc model.
In a pairwise DPRc model, the total energy is divided into QM internal energy and the sum of QM/MM energy for each MM residue $l$:

Expand Down
6 changes: 3 additions & 3 deletions source/api_cc/include/DeepPotPT.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include <torch/script.h>
#include <torch/torch.h>

#include "DeepPot.h"
#include "commonPT.h"

namespace deepmd {
/**
Expand Down Expand Up @@ -106,7 +106,7 @@ class DeepPotPT : public DeepPotBase {
const std::vector<VALUETYPE>& coord,
const std::vector<int>& atype,
const std::vector<VALUETYPE>& box,
// const int nghost,
const int nghost,
const InputNlist& lmp_list,
const int& ago,
const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
Expand Down Expand Up @@ -322,7 +322,7 @@ class DeepPotPT : public DeepPotBase {
// copy neighbor list info from host
torch::jit::script::Module module;
double rcut;
NeighborListDataPT nlist_data;
NeighborListData nlist_data;
int max_num_neighbors;
int gpu_id;
bool gpu_enabled;
Expand Down
1 change: 1 addition & 0 deletions source/api_cc/include/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct NeighborListData {
void shuffle(const deepmd::AtomMap& map);
void shuffle_exclude_empty(const std::vector<int>& fwd_map);
void make_inlist(InputNlist& inlist);
void padding();
};

/**
Expand Down
24 changes: 0 additions & 24 deletions source/api_cc/include/commonPT.h

This file was deleted.

82 changes: 59 additions & 23 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@

#include "common.h"
using namespace deepmd;
torch::Tensor createNlistTensor(const std::vector<std::vector<int>>& data) {
std::vector<torch::Tensor> row_tensors;

for (const auto& row : data) {
torch::Tensor row_tensor = torch::tensor(row, torch::kInt32).unsqueeze(0);
row_tensors.push_back(row_tensor);
}

torch::Tensor tensor = torch::cat(row_tensors, 0).unsqueeze(0);
return tensor;
}
DeepPotPT::DeepPotPT() : inited(false) {}
DeepPotPT::DeepPotPT(const std::string& model,
const int& gpu_rank,
Expand Down Expand Up @@ -60,7 +71,7 @@ void DeepPotPT::init(const std::string& model,

auto rcut_ = module.run_method("get_rcut").toDouble();
rcut = static_cast<double>(rcut_);
ntypes = 0;
ntypes = module.run_method("get_ntypes").toInt();
ntypes_spin = 0;
dfparam = module.run_method("get_dim_fparam").toInt();
daparam = module.run_method("get_dim_aparam").toInt();
Expand All @@ -78,6 +89,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
const std::vector<VALUETYPE>& coord,
const std::vector<int>& atype,
const std::vector<VALUETYPE>& box,
const int nghost,
const InputNlist& lmp_list,
const int& ago,
const std::vector<VALUETYPE>& fparam,
Expand All @@ -86,7 +98,6 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
if (!gpu_enabled) {
device = torch::Device(torch::kCPU);
}
std::vector<VALUETYPE> coord_wrapped = coord;
int natoms = atype.size();
auto options = torch::TensorOptions().dtype(torch::kFloat64);
torch::ScalarType floatType = torch::kFloat64;
Expand All @@ -96,18 +107,29 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
}
auto int_options = torch::TensorOptions().dtype(torch::kInt64);
auto int32_options = torch::TensorOptions().dtype(torch::kInt32);

// select real atoms
std::vector<VALUETYPE> dcoord, dforce, aparam_, datom_energy, datom_virial;
std::vector<int> datype, fwd_map, bkw_map;
int nghost_real, nall_real, nloc_real;
int nall = natoms;
select_real_atoms_coord(dcoord, datype, aparam_, nghost_real, fwd_map,
bkw_map, nall_real, nloc_real, coord, atype, aparam,
nghost, ntypes, 1, daparam, nall, aparam_nall);
std::cout << datype.size() << std::endl;
std::vector<VALUETYPE> coord_wrapped = dcoord;
at::Tensor coord_wrapped_Tensor =
torch::from_blob(coord_wrapped.data(), {1, natoms, 3}, options)
torch::from_blob(coord_wrapped.data(), {1, nall_real, 3}, options)
.to(device);
std::vector<int64_t> atype_64(atype.begin(), atype.end());
std::vector<int64_t> atype_64(datype.begin(), datype.end());
at::Tensor atype_Tensor =
torch::from_blob(atype_64.data(), {1, natoms}, int_options).to(device);
torch::from_blob(atype_64.data(), {1, nall_real}, int_options).to(device);
if (ago == 0) {
nlist_data.copy_from_nlist(lmp_list, max_num_neighbors);
nlist_data.copy_from_nlist(lmp_list);
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.padding();
}
at::Tensor firstneigh =
torch::from_blob(nlist_data.jlist.data(),
{1, lmp_list.inum, max_num_neighbors}, int32_options);
at::Tensor firstneigh = createNlistTensor(nlist_data.jlist);
firstneigh_tensor = firstneigh.to(torch::kInt64).to(device);
bool do_atom_virial_tensor = true;
c10::optional<torch::Tensor> optional_tensor;
Expand All @@ -119,13 +141,13 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
.to(device);
}
c10::optional<torch::Tensor> aparam_tensor;
if (!aparam.empty()) {
aparam_tensor =
torch::from_blob(const_cast<VALUETYPE*>(aparam.data()),
{1, lmp_list.inum,
static_cast<long int>(aparam.size()) / lmp_list.inum},
options)
.to(device);
if (!aparam_.empty()) {
aparam_tensor = torch::from_blob(
const_cast<VALUETYPE*>(aparam_.data()),
{1, lmp_list.inum,
static_cast<long int>(aparam_.size()) / lmp_list.inum},
options)
.to(device);
}
c10::Dict<c10::IValue, c10::IValue> outputs =
module
Expand All @@ -145,24 +167,36 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
torch::Tensor flat_atom_energy_ =
atom_energy_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_atom_energy_ = flat_atom_energy_.to(torch::kCPU);
atom_energy.resize(natoms, 0.0); // resize to nall to be consistenet with TF.
atom_energy.assign(
datom_energy.resize(nall_real,
0.0); // resize to nall to be consistenet with TF.
datom_energy.assign(
cpu_atom_energy_.data_ptr<VALUETYPE>(),
cpu_atom_energy_.data_ptr<VALUETYPE>() + cpu_atom_energy_.numel());
torch::Tensor flat_force_ = force_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_force_ = flat_force_.to(torch::kCPU);
force.assign(cpu_force_.data_ptr<VALUETYPE>(),
cpu_force_.data_ptr<VALUETYPE>() + cpu_force_.numel());
dforce.assign(cpu_force_.data_ptr<VALUETYPE>(),
cpu_force_.data_ptr<VALUETYPE>() + cpu_force_.numel());
torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
torch::Tensor flat_atom_virial_ =
atom_virial_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_atom_virial_ = flat_atom_virial_.to(torch::kCPU);
atom_virial.assign(
datom_virial.assign(
cpu_atom_virial_.data_ptr<VALUETYPE>(),
cpu_atom_virial_.data_ptr<VALUETYPE>() + cpu_atom_virial_.numel());
int nframes = 1;
// bkw map
force.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
atom_energy.resize(static_cast<size_t>(nframes) * fwd_map.size());
atom_virial.resize(static_cast<size_t>(nframes) * fwd_map.size() * 9);
select_map<VALUETYPE>(force, dforce, bkw_map, 3, nframes, fwd_map.size(),
nall_real);
select_map<VALUETYPE>(atom_energy, datom_energy, bkw_map, 1, nframes,
fwd_map.size(), nall_real);
select_map<VALUETYPE>(atom_virial, datom_virial, bkw_map, 9, nframes,
fwd_map.size(), nall_real);
}
template void DeepPotPT::compute<double, std::vector<ENERGYTYPE>>(
std::vector<ENERGYTYPE>& ener,
Expand All @@ -173,6 +207,7 @@ template void DeepPotPT::compute<double, std::vector<ENERGYTYPE>>(
const std::vector<double>& coord,
const std::vector<int>& atype,
const std::vector<double>& box,
const int nghost,
const InputNlist& lmp_list,
const int& ago,
const std::vector<double>& fparam,
Expand All @@ -186,6 +221,7 @@ template void DeepPotPT::compute<float, std::vector<ENERGYTYPE>>(
const std::vector<float>& coord,
const std::vector<int>& atype,
const std::vector<float>& box,
const int nghost,
const InputNlist& lmp_list,
const int& ago,
const std::vector<float>& fparam,
Expand Down Expand Up @@ -353,7 +389,7 @@ void DeepPotPT::computew(std::vector<double>& ener,
const std::vector<double>& fparam,
const std::vector<double>& aparam) {
compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
inlist, ago, fparam, aparam);
nghost, inlist, ago, fparam, aparam);
}
void DeepPotPT::computew(std::vector<double>& ener,
std::vector<float>& force,
Expand All @@ -369,7 +405,7 @@ void DeepPotPT::computew(std::vector<double>& ener,
const std::vector<float>& fparam,
const std::vector<float>& aparam) {
compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
inlist, ago, fparam, aparam);
nghost, inlist, ago, fparam, aparam);
}
void DeepPotPT::computew_mixed_type(std::vector<double>& ener,
std::vector<double>& force,
Expand Down
10 changes: 10 additions & 0 deletions source/api_cc/src/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,16 @@ void deepmd::NeighborListData::shuffle_exclude_empty(
ilist = new_ilist;
jlist = new_jlist;
}
void deepmd::NeighborListData::padding() {
size_t max_length = 0;
for (const auto& row : jlist) {
max_length = std::max(max_length, row.size());
}

for (int i = 0; i < jlist.size(); i++) {
jlist[i].resize(max_length, -1);
}
}

void deepmd::NeighborListData::make_inlist(InputNlist& inlist) {
int nloc = ilist.size();
Expand Down
Loading

0 comments on commit a8f58a7

Please sign in to comment.