diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index 31162fe80e..6ef6c7f733 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -8,6 +8,7 @@ import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from deepmd.pt.utils import ( env, @@ -202,20 +203,16 @@ def forward( ori_prec = xx.dtype if not env.DP_DTYPE_PROMOTION_STRICT: xx = xx.to(self.prec) - yy = ( - torch.matmul(xx, self.matrix) + self.bias - if self.bias is not None - else torch.matmul(xx, self.matrix) - ) + yy = F.linear(xx, self.matrix.t(), self.bias) + # some activation functions are in-place, prevent further modification on `yy`; needs to be cloned yy = self.activate(yy).clone() - yy = yy * self.idt if self.idt is not None else yy + if self.idt is not None: + yy *= self.idt if self.resnet: if xx.shape[-1] == yy.shape[-1]: yy += xx elif 2 * xx.shape[-1] == yy.shape[-1]: yy += torch.concat([xx, xx], dim=-1) - else: - yy = yy if not env.DP_DTYPE_PROMOTION_STRICT: yy = yy.to(ori_prec) return yy