diff --git a/source/api_cc/include/common.h b/source/api_cc/include/common.h index 2bd0cf7135..612f699ea4 100644 --- a/source/api_cc/include/common.h +++ b/source/api_cc/include/common.h @@ -33,7 +33,13 @@ struct NeighborListData { std::vector firstneigh; public: - void copy_from_nlist(const InputNlist& inlist); + /** + * @brief Copy the neighbor list from an InputNlist. + * @param[in] inlist The input neighbor list. + * @param[in] natoms The number of atoms to copy. If natoms is -1, copy all + * atoms. + */ + void copy_from_nlist(const InputNlist& inlist, const int natoms = -1); void shuffle(const std::vector& fwd_map); void shuffle(const deepmd::AtomMap& map); void shuffle_exclude_empty(const std::vector& fwd_map); diff --git a/source/api_cc/src/DeepPotJAX.cc b/source/api_cc/src/DeepPotJAX.cc index 805380081d..07f8b9119b 100644 --- a/source/api_cc/src/DeepPotJAX.cc +++ b/source/api_cc/src/DeepPotJAX.cc @@ -566,7 +566,7 @@ void deepmd::DeepPotJAX::compute(std::vector& ener, input_list[1] = add_input(op, atype, atype_shape, data_tensor[1], status); // nlist if (ago == 0) { - nlist_data.copy_from_nlist(lmp_list); + nlist_data.copy_from_nlist(lmp_list, nall - nghost); nlist_data.shuffle_exclude_empty(fwd_map); } size_t max_size = 0; diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 6910de3ccd..abd35eaf1e 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -169,7 +169,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, at::Tensor atype_Tensor = torch::from_blob(atype_64.data(), {1, nall_real}, int_option).to(device); if (ago == 0) { - nlist_data.copy_from_nlist(lmp_list); + nlist_data.copy_from_nlist(lmp_list, nall - nghost); nlist_data.shuffle_exclude_empty(fwd_map); nlist_data.padding(); if (do_message_passing) { diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc index aef2d60150..7421b623db 100644 --- a/source/api_cc/src/DeepSpinPT.cc +++ b/source/api_cc/src/DeepSpinPT.cc @@ -177,7 +177,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, torch::from_blob(atype_64.data(), {1, nall_real}, int_option).to(device); c10::optional mapping_tensor; if (ago == 0) { - nlist_data.copy_from_nlist(lmp_list); + nlist_data.copy_from_nlist(lmp_list, nall - nghost); nlist_data.shuffle_exclude_empty(fwd_map); nlist_data.padding(); if (do_message_passing) { diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc index 5a4f05d75c..c51ae9a8b4 100644 --- a/source/api_cc/src/common.cc +++ b/source/api_cc/src/common.cc @@ -232,8 +232,9 @@ template void deepmd::select_real_atoms_coord( const int& nall, const bool aparam_nall); -void deepmd::NeighborListData::copy_from_nlist(const InputNlist& inlist) { - int inum = inlist.inum; +void deepmd::NeighborListData::copy_from_nlist(const InputNlist& inlist, + const int natoms) { + int inum = natoms >= 0 ? natoms : inlist.inum; ilist.resize(inum); jlist.resize(inum); memcpy(&ilist[0], inlist.ilist, inum * sizeof(int));