From f4c99ecfde69ccbb0d4ace5191c4ccdbea0c616c Mon Sep 17 00:00:00 2001 From: takamoto Date: Wed, 15 Sep 2021 03:46:10 +0000 Subject: [PATCH 1/4] use shift for derivative variable instead of cell --- tests/functions_tests/test_triplets.py | 32 +++++++++++-------- tests/test_torch_dftd3_calculator.py | 6 +++- torch_dftd/functions/dftd3.py | 17 ++++++---- torch_dftd/functions/distance.py | 8 +---- torch_dftd/functions/triplets.py | 41 +++++++++++++++---------- torch_dftd/functions/triplets_kernel.py | 31 ++++++++----------- torch_dftd/nn/base_dftd_module.py | 25 +++++++++++++-- torch_dftd/nn/dftd2_module.py | 6 ++-- torch_dftd/nn/dftd3_module.py | 8 +++-- torch_dftd/torch_dftd3_calculator.py | 12 ++++++-- 10 files changed, 115 insertions(+), 71 deletions(-) diff --git a/tests/functions_tests/test_triplets.py b/tests/functions_tests/test_triplets.py index 9abe2e2..a364131 100644 --- a/tests/functions_tests/test_triplets.py +++ b/tests/functions_tests/test_triplets.py @@ -18,9 +18,7 @@ def test_calc_triplets(): [1, 2, 3, 4, 5, 6, -1, -2, -3, -4, -5, -6], dtype=torch.float32, device=device ) # print("shift", shift.shape) - triplet_node_index, multiplicity, triplet_shift, batch_triplets = calc_triplets( - edge_index, shift - ) + triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(edge_index, shift) # print("triplet_node_index", triplet_node_index.shape, triplet_node_index) # print("multiplicity", multiplicity.shape, multiplicity) # print("triplet_shift", triplet_shift.shape, triplet_shift) @@ -38,6 +36,20 @@ def test_calc_triplets(): ) assert multiplicity.shape == (n_triplets,) assert torch.all(multiplicity.cpu() == torch.ones((n_triplets,), dtype=torch.float32)) + + assert torch.allclose( + edge_jk.cpu(), + torch.tensor([[7, 6], [8, 6], [8, 7], [9, 10], [9, 11], [11, 10]], dtype=torch.long), + ) + # shift for edge `i->j`, `i->k`, `j->k`. + triplet_shift = torch.stack( + [ + -shift[edge_jk[:, 0]], + -shift[edge_jk[:, 1]], + shift[edge_jk[:, 0]] - shift[edge_jk[:, 1]], + ], + dim=1, + ) assert torch.allclose( triplet_shift.cpu()[:, :, 0], torch.tensor( @@ -61,7 +73,7 @@ def test_calc_triplets_noshift(): edge_index = torch.tensor( [[0, 1, 1, 3, 1, 2, 3, 0], [1, 2, 3, 0, 0, 1, 1, 3]], dtype=torch.long, device=device ) - triplet_node_index, multiplicity, triplet_shift, batch_triplets = calc_triplets( + triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets( edge_index, dtype=torch.float64 ) # print("triplet_node_index", triplet_node_index.shape, triplet_node_index) @@ -78,13 +90,7 @@ def test_calc_triplets_noshift(): assert multiplicity.shape == (n_triplets,) assert multiplicity.dtype == torch.float64 assert torch.all(multiplicity.cpu() == torch.ones((n_triplets,), dtype=torch.float64)) - assert torch.all( - triplet_shift.cpu() - == torch.zeros( - (n_triplets, 3, 3), - dtype=torch.float32, - ) - ) + assert torch.all(edge_jk.cpu() == torch.tensor([[1, 0], [2, 3]], dtype=torch.long)) assert torch.all(batch_triplets.cpu() == torch.zeros((n_triplets,), dtype=torch.long)) @@ -95,7 +101,7 @@ def test_calc_triplets_noshift(): def test_calc_triplets_no_triplets(edge_index): # edge_index = edge_index.to("cuda:0") # No triplet exist in this graph. Case1: No edge, Case 2 No triplets in this edge. 
- triplet_node_index, multiplicity, triplet_shift, batch_triplets = calc_triplets(edge_index) + triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(edge_index) # print("triplet_node_index", triplet_node_index.shape, triplet_node_index) # print("multiplicity", multiplicity.shape, multiplicity) # print("triplet_shift", triplet_shift.shape, triplet_shift) @@ -104,7 +110,7 @@ def test_calc_triplets_no_triplets(edge_index): # 0 triplets exist. assert triplet_node_index.shape == (0, 3) assert multiplicity.shape == (0,) - assert triplet_shift.shape == (0, 3, 3) + assert edge_jk.shape == (0, 2) assert batch_triplets.shape == (0,) diff --git a/tests/test_torch_dftd3_calculator.py b/tests/test_torch_dftd3_calculator.py index f1351c3..b35c314 100644 --- a/tests/test_torch_dftd3_calculator.py +++ b/tests/test_torch_dftd3_calculator.py @@ -20,6 +20,9 @@ def _create_atoms() -> List[Atoms]: atoms = molecule("CH3CH2OCH3") slab = fcc111("Au", size=(2, 1, 3), vacuum=80.0) + slab.set_cell( + slab.get_cell().array @ np.array([[1.0, 0.1, 0.2], [0.05, 1.0, 0.02], [0.03, 0.04, 1.0]]) + ) slab.pbc = np.array([True, True, True]) return [atoms, slab] @@ -58,6 +61,8 @@ def _assert_energy_force_stress_equal(calc1, calc2, atoms: Atoms): atoms.calc = calc1 f1 = atoms.get_forces() e1 = atoms.get_potential_energy() + if np.all(atoms.pbc == np.array([True, True, True])): + s1 = atoms.get_stress() calc2.reset() atoms.calc = calc2 @@ -66,7 +71,6 @@ def _assert_energy_force_stress_equal(calc1, calc2, atoms: Atoms): assert np.allclose(e1, e2, atol=1e-4, rtol=1e-4) assert np.allclose(f1, f2, atol=1e-5, rtol=1e-5) if np.all(atoms.pbc == np.array([True, True, True])): - s1 = atoms.get_stress() s2 = atoms.get_stress() assert np.allclose(s1, s2, atol=1e-5, rtol=1e-5) diff --git a/torch_dftd/functions/dftd3.py b/torch_dftd/functions/dftd3.py index 6201c70..bcfbf7d 100644 --- a/torch_dftd/functions/dftd3.py +++ b/torch_dftd/functions/dftd3.py @@ -283,21 +283,25 @@ def edisp( shift_abc = None if shift_abc is None else torch.cat([shift_abc, -shift_abc], dim=0) with torch.no_grad(): # triplet_node_index, triplet_edge_index = calc_triplets_cycle(edge_index_abc, n_atoms, shift=shift_abc) - triplet_node_index, multiplicity, triplet_shift, batch_triplets = calc_triplets( - edge_index_abc, shift=shift_abc, dtype=pos.dtype, batch_edge=batch_edge_abc + triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets( + edge_index_abc, + shift=shift_abc, + dtype=pos.dtype, + batch_edge=batch_edge_abc, ) batch_triplets = None if batch_edge is None else batch_triplets # Apply `cnthr` cutoff threshold for r_kj idx_j, idx_k = triplet_node_index[:, 1], triplet_node_index[:, 2] - ts2 = triplet_shift[:, 2] + ts2 = shift_abc[edge_jk[:, 0]] - shift_abc[edge_jk[:, 1]] r_jk = calc_distances(pos, torch.stack([idx_j, idx_k], dim=0), cell, ts2, batch_triplets) kj_within_cutoff = r_jk <= cnthr + del ts2 triplet_node_index = triplet_node_index[kj_within_cutoff] - multiplicity, triplet_shift, batch_triplets = ( + multiplicity, edge_jk, batch_triplets = ( multiplicity[kj_within_cutoff], - triplet_shift[kj_within_cutoff], + edge_jk[kj_within_cutoff], None if batch_triplets is None else batch_triplets[kj_within_cutoff], ) @@ -306,7 +310,8 @@ def edisp( triplet_node_index[:, 1], triplet_node_index[:, 2], ) - ts0, ts1, ts2 = triplet_shift[:, 0], triplet_shift[:, 1], triplet_shift[:, 2] + ts0 = -shift_abc[edge_jk[:, 0]] + ts1 = -shift_abc[edge_jk[:, 1]] r_ij = calc_distances(pos, torch.stack([idx_i, idx_j], dim=0), cell, ts0, 
batch_triplets)
     r_ik = calc_distances(pos, torch.stack([idx_i, idx_k], dim=0), cell, ts1, batch_triplets)
diff --git a/torch_dftd/functions/distance.py b/torch_dftd/functions/distance.py
index 84c5786..58ac495 100644
--- a/torch_dftd/functions/distance.py
+++ b/torch_dftd/functions/distance.py
@@ -18,13 +18,7 @@ def calc_distances(
     Ri = pos[idx_i]
     Rj = pos[idx_j]
     if cell is not None:
-        if batch_edge is None:
-            # shift (n_edges, 3), cell (3, 3) -> offsets (n_edges, 3)
-            offsets = torch.mm(shift, cell)
-        else:
-            # shift (n_edges, 3), cell[batch] (n_atoms, 3, 3) -> offsets (n_edges, 3)
-            offsets = torch.bmm(shift[:, None, :], cell[batch_edge])[:, 0]
-        Rj += offsets
+        Rj += shift
     # eps is to avoid Nan in backward when Dij = 0 with sqrt.
     Dij = torch.sqrt(torch.sum((Ri - Rj) ** 2, dim=-1) + eps)
     return Dij
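
[Note: with this change, `calc_distances` expects `shift` to already be a Cartesian
offset vector rather than an integer image shift. A standalone sketch of the new
contract (toy cell and edge data; names and values are illustrative, not part of
the patch):

import torch

cell = torch.tensor([[5.0, 0.0, 0.0], [0.5, 5.0, 0.0], [0.0, 0.3, 5.0]], dtype=torch.float64)
pos = torch.rand((3, 3), dtype=torch.float64) @ cell  # 3 atoms inside the cell
S = torch.tensor([[1.0, 0.0, 0.0], [0.0, -1.0, 1.0]], dtype=torch.float64)  # image shifts
idx_i = torch.tensor([0, 1])
idx_j = torch.tensor([1, 2])

# Previously the conversion happened inside calc_distances: offsets = shift @ cell.
d_old = ((pos[idx_i] - (pos[idx_j] + torch.mm(S, cell))) ** 2).sum(dim=-1).sqrt()
# Now the caller precomputes shift = S @ cell once and the function only adds it.
shift = torch.mm(S, cell)
d_new = ((pos[idx_i] - (pos[idx_j] + shift)) ** 2).sum(dim=-1).sqrt()
assert torch.allclose(d_old, d_new)

Moving the matmul out to the caller is what allows `shift` to become an independent
leaf tensor carrying its own gradient for the stress computation further below.]
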
diff --git a/torch_dftd/functions/triplets.py b/torch_dftd/functions/triplets.py
index a6088be..f8c295a 100644
--- a/torch_dftd/functions/triplets.py
+++ b/torch_dftd/functions/triplets.py
@@ -24,8 +24,8 @@ def calc_triplets(
             i.e.: idx_i, idx_j, idx_k = triplet_node_index
         multiplicity (Tensor): (n_triplets,) multiplicity indicates duplication of same triplet pair.
             It only takes 1 in non-pbc, but it takes 2 or 3 in pbc case. dtype is specified in the argument.
-        triplet_shift (Tensor): (n_triplets, 3=(ij, ik, jk), 3=(xyz)) shift for edge `i->j`, `i->k`, `j->k`.
-            i.e.: idx_ij, idx_ik, idx_jk = triplet_shift
+        edge_jk (Tensor): (n_triplet_edges, 2=(j, k)) edge indices for j and k.
+            i.e.: idx_j, idx_k = edge_jk.T
         batch_triplets (Tensor): (n_triplets,) batch indices for each triplets.
     """
     dst, src = edge_index
@@ -38,9 +38,12 @@ def calc_triplets(
         dst = dst[sort_inds]

     if shift is None:
-        shift = torch.zeros((src.shape[0], 3), dtype=dtype, device=edge_index.device)
+        edge_indices = torch.arange(src.shape[0], dtype=torch.long, device=edge_index.device)
+        # shift = torch.zeros((src.shape[0], 3), dtype=dtype, device=edge_index.device)
     else:
-        shift = shift[is_larger][sort_inds]
+        edge_indices = torch.arange(shift.shape[0], dtype=torch.long, device=edge_index.device)
+        edge_indices = edge_indices[is_larger][sort_inds]
+        # shift = shift[is_larger][sort_inds]

     if batch_edge is None:
         batch_edge = torch.zeros(src.shape[0], dtype=torch.long, device=edge_index.device)
@@ -55,15 +58,15 @@ def calc_triplets(

     if str(unique.device) == "cpu":
         return _calc_triplets_core(
-            counts, unique, dst, shift, batch_edge, counts_cumsum, dtype=dtype
+            counts, unique, dst, edge_indices, batch_edge, counts_cumsum, dtype=dtype
         )
     else:
         return _calc_triplets_core_gpu(
-            counts, unique, dst, shift, batch_edge, counts_cumsum, dtype=dtype
+            counts, unique, dst, edge_indices, batch_edge, counts_cumsum, dtype=dtype
         )


-def _calc_triplets_core(counts, unique, dst, shift, batch_edge, counts_cumsum, dtype):
+def _calc_triplets_core(counts, unique, dst, edge_indices, batch_edge, counts_cumsum, dtype):
     device = unique.device
     n_triplets = torch.sum(counts * (counts - 1) / 2)
     if n_triplets == 0:
@@ -71,21 +74,22 @@ def _calc_triplets_core(counts, unique, dst, shift, batch_edge, counts_cumsum, d
         triplet_node_index = torch.zeros((0, 3), dtype=torch.long, device=device)
         # (n_triplet_edges)
         multiplicity = torch.zeros((0,), dtype=dtype, device=device)
-        # (n_triplet_edges, 3=(ij, ik, jk), 3=(xyz) )
-        triplet_shift = torch.zeros((0, 3, 3), dtype=dtype, device=device)
+        # (n_triplet_edges, 2=(j, k))
+        edge_jk = torch.zeros((0, 2), dtype=torch.long, device=device)
         # (n_triplet_edges)
         batch_triplets = torch.zeros((0,), dtype=torch.long, device=device)
-        return triplet_node_index, multiplicity, triplet_shift, batch_triplets
+        return triplet_node_index, multiplicity, edge_jk, batch_triplets

     triplet_node_index_list = []  # (n_triplet_edges, 3)
-    shift_list = []  # (n_triplet_edges, 3, 3) represents shift vector
+    edge_jk_list = []  # (n_triplet_edges, 2) represents j and k indices
     multiplicity_list = []  # (n_triplet_edges) represents multiplicity
     batch_triplets_list = []  # (n_triplet_edges) represents batch index for triplets
     for i in range(len(unique)):
         _src = unique[i].item()
         _n_edges = counts[i].item()
         _dst = dst[counts_cumsum[i] : counts_cumsum[i + 1]]
-        _shift = shift[counts_cumsum[i] : counts_cumsum[i + 1]]
+        # _shift = shift[counts_cumsum[i] : counts_cumsum[i + 1]]
+        _offset = counts_cumsum[i].item()
         _batch_index = batch_edge[counts_cumsum[i]].item()
         for j in range(_n_edges - 1):
             for k in range(j + 1, _n_edges):
@@ -101,8 +105,12 @@ def _calc_triplets_core(counts, unique, dst, shift, batch_edge, counts_cumsum, d
                     _j, _k = k, j

                 triplet_node_index_list.append([_src, _dst0, _dst1])
-                shift_list.append(
-                    torch.stack([-_shift[_j], -_shift[_k], _shift[_j] - _shift[_k]], dim=0)
+                edge_jk_list.append(
+                    [
+                        _offset + _j,
+                        _offset + _k,
+                    ]
+                    # torch.stack([-_shift[_j], -_shift[_k], _shift[_j] - _shift[_k]], dim=0)
                 )
                 # --- multiplicity ---
                 if _dst0 == _dst1:
@@ -126,7 +134,7 @@ def _calc_triplets_core(counts, unique, dst, shift, batch_edge, counts_cumsum, d
     triplet_node_index = torch.as_tensor(triplet_node_index_list, device=device)
     # (n_triplet_edges)
     multiplicity = torch.as_tensor(multiplicity_list, dtype=dtype, device=device)
-    # (n_triplet_edges, 3=(ij, ik, jk), 3=(xyz) )
-    triplet_shift = torch.stack(shift_list, dim=0)
+    # (n_triplet_edges, 2=(j, k))
+    edge_jk = edge_indices[torch.tensor(edge_jk_list, dtype=torch.long, device=device)]
     batch_triplets = torch.as_tensor(batch_triplets_list, dtype=torch.long, device=device)
-    return triplet_node_index, multiplicity, triplet_shift, batch_triplets
+    return triplet_node_index, multiplicity, edge_jk, batch_triplets
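
[Note: a short sketch of the new `calc_triplets` contract, mirroring
`test_calc_triplets_noshift` above (requires torch-dftd installed; zero shifts are
used for brevity). The per-triplet shift vectors are no longer materialized;
consumers rebuild them on demand from `edge_jk`:

import torch
from torch_dftd.functions.triplets import calc_triplets

edge_index = torch.tensor(
    [[0, 1, 1, 3, 1, 2, 3, 0], [1, 2, 3, 0, 0, 1, 1, 3]], dtype=torch.long
)
shift = torch.zeros((8, 3), dtype=torch.float64)  # (n_edges, 3) Cartesian shift per edge

triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(edge_index, shift)
# For this toy graph the test above expects edge_jk == [[1, 0], [2, 3]].

# As done in edisp() in dftd3.py, recover the shift vectors for edges i->j, i->k, j->k:
ts_ij = -shift[edge_jk[:, 0]]
ts_ik = -shift[edge_jk[:, 1]]
ts_jk = shift[edge_jk[:, 0]] - shift[edge_jk[:, 1]]

Storing two int64 indices per triplet instead of nine floats also shrinks what the
CPU and GPU kernels below have to write out.]
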
diff --git a/torch_dftd/functions/triplets_kernel.py b/torch_dftd/functions/triplets_kernel.py
index 4fbb92a..cf4d96a 100644
--- a/torch_dftd/functions/triplets_kernel.py
+++ b/torch_dftd/functions/triplets_kernel.py
@@ -35,8 +35,8 @@ def _cupy2torch(array: cp.ndarray) -> Tensor:

 if _cupy_available:
     _calc_triplets_core_gpu_kernel = cp.ElementwiseKernel(
-        "raw int64 counts, raw int64 unique, raw int64 dst, raw T shift, raw int64 batch_edge, raw int64 counts_cumsum",
-        "raw int64 triplet_node_index, raw T multiplicity, raw T triplet_shift, raw int64 batch_triplets",
+        "raw int64 counts, raw int64 unique, raw int64 dst, raw int64 edge_indices, raw int64 batch_edge, raw int64 counts_cumsum",
+        "raw int64 triplet_node_index, raw T multiplicity, raw int64 edge_jk, raw int64 batch_triplets",
         """
         long long n_unique = unique.size();
         long long a = 0;
@@ -100,16 +100,9 @@ def _cupy2torch(array: cp.ndarray) -> Tensor:
             }
         }

-        // --- triplet_shift ---
-        triplet_shift[9 * i] = -shift[3 * (_offset + b)];
-        triplet_shift[9 * i + 1] = -shift[3 * (_offset + b) + 1];
-        triplet_shift[9 * i + 2] = -shift[3 * (_offset + b) + 2];
-        triplet_shift[9 * i + 3] = -shift[3 * (_offset + c)];
-        triplet_shift[9 * i + 4] = -shift[3 * (_offset + c) + 1];
-        triplet_shift[9 * i + 5] = -shift[3 * (_offset + c) + 2];
-        triplet_shift[9 * i + 6] = shift[3 * (_offset + b)] - shift[3 * (_offset + c)];
-        triplet_shift[9 * i + 7] = shift[3 * (_offset + b) + 1] - shift[3 * (_offset + c) + 1];
-        triplet_shift[9 * i + 8] = shift[3 * (_offset + b) + 2] - shift[3 * (_offset + c) + 2];
+        // --- edge_jk ---
+        edge_jk[2 * i] = edge_indices[_offset + b];
+        edge_jk[2 * i + 1] = edge_indices[_offset + c];

         // --- batch_triplets ---
         batch_triplets[i] = _batch_index;
@@ -124,7 +117,7 @@ def _calc_triplets_core_gpu(
     counts: Tensor,
     unique: Tensor,
     dst: Tensor,
-    shift: Tensor,
+    edge_indices: Tensor,
     batch_edge: Tensor,
     counts_cumsum: Tensor,
     dtype: torch.dtype = torch.float32,
@@ -140,26 +133,26 @@ def _calc_triplets_core_gpu(
     triplet_node_index = torch.zeros((n_triplets, 3), dtype=torch.long, device=device)
     # (n_triplet_edges)
     multiplicity = torch.zeros((n_triplets,), dtype=dtype, device=device)
-    # (n_triplet_edges, 3=(ij, ik, jk), 3=(xyz) )
-    triplet_shift = torch.zeros((n_triplets, 3, 3), dtype=dtype, device=device)
+    # (n_triplet_edges, 2=(j, k))
+    edge_jk = torch.zeros((n_triplets, 2), dtype=torch.long, device=device)
     # (n_triplet_edges)
     batch_triplets = torch.zeros((n_triplets,), dtype=torch.long, device=device)
     if n_triplets == 0:
-        return triplet_node_index, multiplicity, triplet_shift, batch_triplets
+        return triplet_node_index, multiplicity, edge_jk, batch_triplets

     _calc_triplets_core_gpu_kernel(
         _torch2cupy(counts),
         _torch2cupy(unique),
         _torch2cupy(dst),
-        _torch2cupy(shift),
+        _torch2cupy(edge_indices),
         _torch2cupy(batch_edge),
         _torch2cupy(counts_cumsum),
         # n_triplets,
         _torch2cupy(triplet_node_index),
         _torch2cupy(multiplicity),
-        _torch2cupy(triplet_shift),
+        _torch2cupy(edge_jk),
         _torch2cupy(batch_triplets),
         size=n_triplets,
     )
     # torch tensor buffer is already modified in above cupy functions.
-    return triplet_node_index, multiplicity, triplet_shift, batch_triplets
+    return triplet_node_index, multiplicity, edge_jk, batch_triplets
diff --git a/torch_dftd/nn/base_dftd_module.py b/torch_dftd/nn/base_dftd_module.py
index a30141e..ff97f41 100644
--- a/torch_dftd/nn/base_dftd_module.py
+++ b/torch_dftd/nn/base_dftd_module.py
@@ -53,6 +53,7 @@ def calc_energy(
         batch: Optional[Tensor] = None,
         batch_edge: Optional[Tensor] = None,
         damping: str = "zero",
+        shift_int: Optional[Tensor] = None,
     ) -> List[Dict[str, Any]]:
         """Forward computation of dispersion energy

@@ -95,6 +96,7 @@ def calc_energy_and_forces(
         batch: Optional[Tensor] = None,
         batch_edge: Optional[Tensor] = None,
         damping: str = "zero",
+        shift_int: Optional[Tensor] = None,
     ) -> List[Dict[str, Any]]:
         """Forward computation of dispersion energy, force and stress

@@ -116,6 +118,7 @@ def calc_energy_and_forces(
         # We need to explicitly include this dependency to calculate cell gradient
         # for stress computation.
         # pos is assumed to be inside "cell", so relative position `rel_pos` lies between 0~1.
+        assert isinstance(shift, Tensor)
         if batch is None:
             rel_pos = torch.mm(pos, torch.inverse(cell))
             pos = torch.mm(rel_pos.detach(), cell)
@@ -125,8 +128,13 @@ def calc_energy_and_forces(
             # pos (n_atoms, 1, 3) * cell (n_atoms, 3, 3) -> (n_atoms, 3)
             pos = torch.bmm(rel_pos[:, None, :].detach(), cell[batch])[:, 0]

+        # cell_2 = cell.detach().requires_grad_(True)
+        # shift = torch.mm(shift_int, cell_2)
+        # shift = shift_int
         pos.retain_grad()
         cell.retain_grad()
+        shift.retain_grad()
+        # cell_2.retain_grad()
         E_disp = self.calc_energy_batch(
             Z, pos, edge_index, cell, pbc, shift, batch, batch_edge, damping=damping
         )
@@ -150,13 +158,26 @@ def calc_energy_and_forces(
         # Get stress in Voigt notation (xx, yy, zz, yz, xz, xy)
         if batch is None:
             cell_volume = torch.det(cell).abs()
-            stress = torch.mm(cell.grad, cell.T) / cell_volume
+            cell_grad = torch.mm(torch.inverse(cell.T), torch.mm(pos.T, pos.grad))
+            cell_grad += torch.mm(torch.inverse(cell.T), torch.mm(shift.T, shift.grad))
+            # assert torch.allclose(torch.mm(cell_grad.T, cell), torch.mm(cell_grad, cell.T))
+            # stress = torch.mm(cell.T, cell_grad) / cell_volume
+            stress = torch.mm(cell_grad, cell.T) / cell_volume
             stress = stress.view(-1)[[0, 4, 8, 5, 2, 1]]
             results_list[0]["stress"] = stress.detach().cpu().numpy()
         else:
             cell_volume = torch.det(cell).abs()
+            cell_T = cell.permute(0, 2, 1)
             # cell (bs, 3, 3)
-            stress = torch.bmm(cell.grad, cell.permute(0, 2, 1)) / cell_volume[:, None, None]
+            edge_grad = shift.new_zeros((n_graphs, 3, 3))
+            edge_grad.scatter_add_(
+                0,
+                batch_edge.view(batch_edge.size()[0], 1, 1).expand(batch_edge.size()[0], 3, 3),
+                shift[:, :, None] * shift.grad[:, None, :],
+            )
+            cell_grad = cell.grad.clone()
+            cell_grad += torch.bmm(torch.inverse(cell_T), edge_grad)
+            stress = torch.bmm(cell_grad, cell_T) / cell_volume[:, None, None]
             stress = stress.view(-1, 9)[:, [0, 4, 8, 5, 2, 1]].detach().cpu().numpy()
             for i in range(n_graphs):
                 results_list[i]["stress"] = stress[i]
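
[Note on the stress computation above: `shift` is created from a detached cell (see
torch_dftd3_calculator.py below), so `cell.grad` alone no longer contains the
image-shift contribution and it has to be reconstructed from `shift.grad`. With
pos = rel_pos @ cell and shift = S @ cell, the chain rule gives

    dE/dcell = rel_pos^T @ dE/dpos + S^T @ dE/dshift
             = cell^-T @ (pos^T @ dE/dpos + shift^T @ dE/dshift),

which is exactly the `cell_grad` assembled above; the batched branch accumulates the
same shift term per graph via scatter_add_. A standalone numeric check of this
identity (toy energy function, illustrative names; not part of the patch):

import torch

torch.manual_seed(0)
cell = (4.0 * torch.eye(3, dtype=torch.float64)
        + 0.1 * torch.rand(3, 3, dtype=torch.float64)).requires_grad_(True)
rel_pos = torch.rand(5, 3, dtype=torch.float64)     # fractional coordinates in [0, 1)
S = torch.randint(-1, 2, (7, 3)).to(torch.float64)  # integer cell-image shifts

pos = rel_pos @ cell                                # tracked dependency on cell
shift = (S @ cell.detach()).requires_grad_(True)    # detached, as in the calculator
pos.retain_grad()

E = (pos ** 2).sum() + (shift ** 2).sum()           # stand-in for the dispersion energy
E.backward()

cell_grad = torch.inverse(cell.detach().T) @ (
    pos.detach().T @ pos.grad + shift.detach().T @ shift.grad
)
assert torch.allclose(cell_grad, rel_pos.T @ pos.grad + S.T @ shift.grad)

The stress then follows as cell_grad @ cell^T / volume, reordered to Voigt notation.]
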
diff --git a/torch_dftd/nn/dftd2_module.py b/torch_dftd/nn/dftd2_module.py
index 9be70cd..0484945 100644
--- a/torch_dftd/nn/dftd2_module.py
+++ b/torch_dftd/nn/dftd2_module.py
@@ -54,12 +54,14 @@ def calc_energy_batch(
         damping: str = "zero",
     ) -> Tensor:
         """Forward computation to calculate atomic wise dispersion energy"""
+        shift = pos.new_zeros((edge_index.size()[1], 3)) if shift is None else shift
         pos_bohr = pos / d3_autoang  # angstrom -> bohr
         if cell is None:
-            cell_bohr = None
+            cell_bohr: Optional[Tensor] = None
         else:
             cell_bohr = cell / d3_autoang  # angstrom -> bohr
-        r = calc_distances(pos_bohr, edge_index, cell_bohr, shift, batch_edge=batch_edge)
+        shift_bohr = shift / d3_autoang  # angstrom -> bohr
+        r = calc_distances(pos_bohr, edge_index, cell_bohr, shift_bohr, batch_edge=batch_edge)

         # E_disp (n_graphs,): Energy in eV unit
         E_disp = d3_autoev * edisp_d2(
diff --git a/torch_dftd/nn/dftd3_module.py b/torch_dftd/nn/dftd3_module.py
index 0f756ee..52fe9d7 100644
--- a/torch_dftd/nn/dftd3_module.py
+++ b/torch_dftd/nn/dftd3_module.py
@@ -76,12 +76,14 @@ def calc_energy_batch(
         damping: str = "zero",
     ) -> Tensor:
         """Forward computation to calculate atomic wise dispersion energy"""
+        shift = pos.new_zeros((edge_index.size()[1], 3)) if shift is None else shift
         pos_bohr = pos / d3_autoang  # angstrom -> bohr
         if cell is None:
-            cell_bohr = None
+            cell_bohr: Optional[Tensor] = None
         else:
             cell_bohr = cell / d3_autoang  # angstrom -> bohr
-        r = calc_distances(pos_bohr, edge_index, cell_bohr, shift, batch_edge=batch_edge)
+        shift_bohr = shift / d3_autoang  # angstrom -> bohr
+        r = calc_distances(pos_bohr, edge_index, cell_bohr, shift_bohr, batch_edge=batch_edge)

         # E_disp (n_graphs,): Energy in eV unit
         E_disp = d3_autoev * edisp(
             Z,
@@ -96,7 +98,7 @@ def calc_energy_batch(
             cnthr=self.cnthr / Bohr,
             batch=batch,
             batch_edge=batch_edge,
-            shift=shift,
+            shift=shift_bohr,
             damping=damping,
             cutoff_smoothing=self.cutoff_smoothing,
             bidirectional=self.bidirectional,
diff --git a/torch_dftd/torch_dftd3_calculator.py b/torch_dftd/torch_dftd3_calculator.py
index 51ecebc..ffbda0a 100644
--- a/torch_dftd/torch_dftd3_calculator.py
+++ b/torch_dftd/torch_dftd3_calculator.py
@@ -97,14 +97,19 @@ def _preprocess_atoms(self, atoms: Atoms) -> Dict[str, Optional[Tensor]]:
         )
         Z = torch.tensor(atoms.get_atomic_numbers(), device=self.device)
         if any(atoms.pbc):
-            cell = torch.tensor(
+            cell: Optional[Tensor] = torch.tensor(
                 atoms.get_cell(), device=self.device, dtype=self.dtype, requires_grad=True
             )
         else:
             cell = None
         pbc = torch.tensor(atoms.pbc, device=self.device)
         edge_index, S = self._calc_edge_index(pos, cell, pbc)
-        input_dicts = dict(pos=pos, Z=Z, cell=cell, pbc=pbc, edge_index=edge_index, shift=S)
+        if cell is None:
+            shift = S
+        else:
+            shift = torch.mm(S, cell.detach())
+            shift.requires_grad_(True)
+        input_dicts = dict(pos=pos, Z=Z, cell=cell, pbc=pbc, edge_index=edge_index, shift=shift)
         return input_dicts

     def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes):
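
[Note: end-to-end usage sketch for the stress path introduced in this patch, through
the standard ASE interface (mirrors tests/test_torch_dftd3_calculator.py; the device
and damping values are just examples):

import numpy as np
from ase.build import fcc111
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator

slab = fcc111("Au", size=(2, 1, 3), vacuum=80.0)
slab.pbc = np.array([True, True, True])
slab.calc = TorchDFTD3Calculator(atoms=slab, device="cpu", damping="bj")
energy = slab.get_potential_energy()
forces = slab.get_forces()
stress = slab.get_stress()  # Voigt order (xx, yy, zz, yz, xz, xy)

The updated test additionally skews the slab cell with a non-diagonal matrix, so the
stress is now exercised for a general triclinic cell, not only an orthogonal one.]
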
From 8633e62a2c9a4b2b7731b5568b5923b67197b061 Mon Sep 17 00:00:00 2001
From: takamoto
Date: Thu, 16 Sep 2021 04:52:10 +0000
Subject: [PATCH 2/4] modify pymatgen version

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index a9da009..be8eb2d 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 setup_requires: List[str] = []
 install_requires: List[str] = [
     "ase>=3.18, <4.0.0",  # Note that we require ase==3.21.1 for pytest.
-    "pymatgen",
+    "pymatgen>=2020.1.10",
 ]
 extras_require: Dict[str, List[str]] = {
     "develop": ["pysen[lint]==0.9.1"],

From e7f73aff1c7a93327540457f894aa7eb0de7f654 Mon Sep 17 00:00:00 2001
From: takamoto
Date: Thu, 16 Sep 2021 05:06:19 +0000
Subject: [PATCH 3/4] remove debug code

---
 torch_dftd/nn/base_dftd_module.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/torch_dftd/nn/base_dftd_module.py b/torch_dftd/nn/base_dftd_module.py
index ff97f41..24f0f64 100644
--- a/torch_dftd/nn/base_dftd_module.py
+++ b/torch_dftd/nn/base_dftd_module.py
@@ -53,7 +53,6 @@ def calc_energy(
         batch: Optional[Tensor] = None,
         batch_edge: Optional[Tensor] = None,
         damping: str = "zero",
-        shift_int: Optional[Tensor] = None,
     ) -> List[Dict[str, Any]]:
         """Forward computation of dispersion energy

@@ -96,7 +95,6 @@ def calc_energy_and_forces(
         batch: Optional[Tensor] = None,
         batch_edge: Optional[Tensor] = None,
         damping: str = "zero",
-        shift_int: Optional[Tensor] = None,
     ) -> List[Dict[str, Any]]:
         """Forward computation of dispersion energy, force and stress

From 82a369e31e405d8ee1bb8e957e60cc9ee1ee2ec2 Mon Sep 17 00:00:00 2001
From: takamoto
Date: Thu, 16 Sep 2021 05:07:14 +0000
Subject: [PATCH 4/4] remove debug code

---
 torch_dftd/nn/base_dftd_module.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/torch_dftd/nn/base_dftd_module.py b/torch_dftd/nn/base_dftd_module.py
index 24f0f64..9cc37d7 100644
--- a/torch_dftd/nn/base_dftd_module.py
+++ b/torch_dftd/nn/base_dftd_module.py
@@ -126,13 +126,9 @@ def calc_energy_and_forces(
             # pos (n_atoms, 1, 3) * cell (n_atoms, 3, 3) -> (n_atoms, 3)
             pos = torch.bmm(rel_pos[:, None, :].detach(), cell[batch])[:, 0]

-        # cell_2 = cell.detach().requires_grad_(True)
-        # shift = torch.mm(shift_int, cell_2)
-        # shift = shift_int
         pos.retain_grad()
         cell.retain_grad()
         shift.retain_grad()
-        # cell_2.retain_grad()
         E_disp = self.calc_energy_batch(
             Z, pos, edge_index, cell, pbc, shift, batch, batch_edge, damping=damping
         )
@@ -154,8 +150,6 @@ def calc_energy_and_forces(
         cell_volume = torch.det(cell).abs()
         cell_grad = torch.mm(torch.inverse(cell.T), torch.mm(pos.T, pos.grad))
         cell_grad += torch.mm(torch.inverse(cell.T), torch.mm(shift.T, shift.grad))
-        # assert torch.allclose(torch.mm(cell_grad.T, cell), torch.mm(cell_grad, cell.T))
-        # stress = torch.mm(cell.T, cell_grad) / cell_volume
         stress = torch.mm(cell_grad, cell.T) / cell_volume
         stress = stress.view(-1)[[0, 4, 8, 5, 2, 1]]
         results_list[0]["stress"] = stress.detach().cpu().numpy()
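
[Note: a closing sanity check that can be run against the final state of this series,
comparing the analytic stress with a finite difference of the energy under strain
(sketch only; the step size and tolerance are illustrative):

import numpy as np
from ase.build import fcc111
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator

slab = fcc111("Au", size=(2, 1, 3), vacuum=80.0)
slab.pbc = np.array([True, True, True])
slab.calc = TorchDFTD3Calculator(atoms=slab, damping="bj")

eps = 1e-5
cell0 = slab.get_cell().array.copy()
strain = np.eye(3)
strain[0, 0] += eps
slab.set_cell(cell0 @ strain.T, scale_atoms=True)  # stretch along x
e_plus = slab.get_potential_energy()
strain[0, 0] -= 2 * eps
slab.set_cell(cell0 @ strain.T, scale_atoms=True)  # compress along x
e_minus = slab.get_potential_energy()
slab.set_cell(cell0, scale_atoms=True)

s_xx_fd = (e_plus - e_minus) / (2 * eps * slab.get_volume())
s_xx_an = slab.get_stress()[0]                     # Voigt xx component
assert np.isclose(s_xx_fd, s_xx_an, atol=1e-6)]
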