Perf: use F.linear for MLP
caic99 committed Dec 26, 2024
1 parent 3cecca4 commit 4645a43
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions deepmd/pt/model/network/mlp.py
@@ -8,6 +8,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from deepmd.pt.utils import (
     env,
@@ -202,20 +203,16 @@ def forward(
         ori_prec = xx.dtype
         if not env.DP_DTYPE_PROMOTION_STRICT:
             xx = xx.to(self.prec)
-        yy = (
-            torch.matmul(xx, self.matrix) + self.bias
-            if self.bias is not None
-            else torch.matmul(xx, self.matrix)
-        )
+        yy = F.linear(xx, self.matrix.t(), self.bias)
+        # some activation functions are in-place, prevent further modification on `yy`; needs to be cloned
         yy = self.activate(yy).clone()
-        yy = yy * self.idt if self.idt is not None else yy
+        if self.idt is not None:
+            yy *= self.idt
         if self.resnet:
             if xx.shape[-1] == yy.shape[-1]:
                 yy += xx
             elif 2 * xx.shape[-1] == yy.shape[-1]:
                 yy += torch.concat([xx, xx], dim=-1)
-            else:
-                yy = yy
         if not env.DP_DTYPE_PROMOTION_STRICT:
             yy = yy.to(ori_prec)
         return yy
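
For context, a minimal sketch (tensor shapes made up for illustration, not taken from deepmd) checking that the new call reproduces the old matmul-plus-bias result: F.linear(x, w, b) computes x @ w.T + b, so passing self.matrix.t() as the weight gives x @ self.matrix + b, and a None bias is handled in the same call without the explicit branch.

# Equivalence sketch for the change above; shapes are illustrative assumptions.
import torch
import torch.nn.functional as F

xx = torch.randn(4, 16)        # batch of inputs
matrix = torch.randn(16, 32)   # weight stored as (in_features, out_features), matching the matmul above
bias = torch.randn(32)

old = torch.matmul(xx, matrix) + bias   # previous formulation
new = F.linear(xx, matrix.t(), bias)    # F.linear expects weight as (out_features, in_features)

assert torch.allclose(old, new, atol=1e-6)

Besides folding the optional bias into one call, F.linear can dispatch to a fused addmm-style kernel rather than a separate matmul and add, which is presumably the "Perf" gain the commit title refers to.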