Skip to content

Commit

Permalink
add ut for extend stat
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Jun 11, 2024
1 parent af6c8b2 commit ad838c4
Show file tree
Hide file tree
Showing 17 changed files with 346 additions and 0 deletions.
5 changes: 5 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,14 @@ def set_stat_mean_and_stddev(
mean: np.ndarray,
stddev: np.ndarray,
) -> None:
"""Update mean and stddev for descriptor."""
self.se_atten.mean = mean
self.se_atten.stddev = stddev

def get_stat_mean_and_stddev(self) -> Tuple[np.ndarray, np.ndarray]:
"""Get mean and stddev for descriptor."""
return self.se_atten.mean, self.se_atten.stddev

def change_type_map(
self, type_map: List[str], model_with_new_type_stat=None
) -> None:
Expand Down
17 changes: 17 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,23 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None)
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError

def set_stat_mean_and_stddev(
self,
mean: List[np.ndarray],
stddev: List[np.ndarray],
) -> None:
"""Update mean and stddev for descriptor."""
for ii, descrpt in enumerate([self.repinit, self.repformers]):
descrpt.mean = mean[ii]
descrpt.stddev = stddev[ii]

def get_stat_mean_and_stddev(self) -> Tuple[List[np.ndarray], List[np.ndarray]]:
"""Get mean and stddev for descriptor."""
return [self.repinit.mean, self.repformers.mean], [
self.repinit.stddev,
self.repformers.stddev,
]

def call(
self,
coord_ext: np.ndarray,
Expand Down
24 changes: 24 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,30 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None)
for descrpt in self.descrpt_list:
descrpt.compute_input_stats(merged, path)

def set_stat_mean_and_stddev(
self,
mean: List[Union[np.ndarray, List[np.ndarray]]],
stddev: List[Union[np.ndarray, List[np.ndarray]]],
) -> None:
"""Update mean and stddev for descriptor."""
for ii, descrpt in enumerate(self.descrpt_list):
descrpt.set_stat_mean_and_stddev(mean[ii], stddev[ii])

def get_stat_mean_and_stddev(
self,
) -> Tuple[
List[Union[np.ndarray, List[np.ndarray]]],
List[Union[np.ndarray, List[np.ndarray]]],
]:
"""Get mean and stddev for descriptor."""
mean_list = []
stddev_list = []
for ii, descrpt in enumerate(self.descrpt_list):
mean_item, stddev_item = descrpt.get_stat_mean_and_stddev()
mean_list.append(mean_item)
stddev_list.append(stddev_item)
return mean_list, stddev_list

def call(
self,
coord_ext,
Expand Down
10 changes: 10 additions & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ def change_type_map(
"""
pass

@abstractmethod
def set_stat_mean_and_stddev(self, mean, stddev) -> None:
"""Update mean and stddev for descriptor."""
pass

@abstractmethod
def get_stat_mean_and_stddev(self):
"""Get mean and stddev for descriptor."""
pass

def compute_input_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
Expand Down
13 changes: 13 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,19 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None)
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError

def set_stat_mean_and_stddev(
self,
mean: np.ndarray,
stddev: np.ndarray,
) -> None:
"""Update mean and stddev for descriptor."""
self.davg = mean
self.dstd = stddev

def get_stat_mean_and_stddev(self) -> Tuple[np.ndarray, np.ndarray]:
"""Get mean and stddev for descriptor."""
return self.davg, self.dstd

def cal_g(
self,
ss,
Expand Down
13 changes: 13 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,19 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None)
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError

def set_stat_mean_and_stddev(
self,
mean: np.ndarray,
stddev: np.ndarray,
) -> None:
"""Update mean and stddev for descriptor."""
self.davg = mean
self.dstd = stddev

def get_stat_mean_and_stddev(self) -> Tuple[np.ndarray, np.ndarray]:
"""Get mean and stddev for descriptor."""
return self.davg, self.dstd

def cal_g(
self,
ss,
Expand Down
13 changes: 13 additions & 0 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,19 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None)
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError

def set_stat_mean_and_stddev(
self,
mean: np.ndarray,
stddev: np.ndarray,
) -> None:
"""Update mean and stddev for descriptor."""
self.davg = mean
self.dstd = stddev

def get_stat_mean_and_stddev(self) -> Tuple[np.ndarray, np.ndarray]:
"""Get mean and stddev for descriptor."""
return self.davg, self.dstd

def reinit_exclude(
self,
exclude_types: List[Tuple[int, int]] = [],
Expand Down
5 changes: 5 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,9 +419,14 @@ def set_stat_mean_and_stddev(
mean: torch.Tensor,
stddev: torch.Tensor,
) -> None:
"""Update mean and stddev for descriptor."""
self.se_atten.mean = mean
self.se_atten.stddev = stddev

def get_stat_mean_and_stddev(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get mean and stddev for descriptor."""
return self.se_atten.mean, self.se_atten.stddev

def change_type_map(
self, type_map: List[str], model_with_new_type_stat=None
) -> None:
Expand Down
17 changes: 17 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,23 @@ def compute_input_stats(
for ii, descrpt in enumerate([self.repinit, self.repformers]):
descrpt.compute_input_stats(merged, path)

def set_stat_mean_and_stddev(
self,
mean: List[torch.Tensor],
stddev: List[torch.Tensor],
) -> None:
"""Update mean and stddev for descriptor."""
for ii, descrpt in enumerate([self.repinit, self.repformers]):
descrpt.mean = mean[ii]
descrpt.stddev = stddev[ii]

def get_stat_mean_and_stddev(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Get mean and stddev for descriptor."""
return [self.repinit.mean, self.repformers.mean], [
self.repinit.stddev,
self.repformers.stddev,
]

def serialize(self) -> dict:
repinit = self.repinit
repformers = self.repformers
Expand Down
24 changes: 24 additions & 0 deletions deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,30 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None)
for descrpt in self.descrpt_list:
descrpt.compute_input_stats(merged, path)

def set_stat_mean_and_stddev(
self,
mean: List[Union[torch.Tensor, List[torch.Tensor]]],
stddev: List[Union[torch.Tensor, List[torch.Tensor]]],
) -> None:
"""Update mean and stddev for descriptor."""
for ii, descrpt in enumerate(self.descrpt_list):
descrpt.set_stat_mean_and_stddev(mean[ii], stddev[ii])

def get_stat_mean_and_stddev(
self,
) -> Tuple[
List[Union[torch.Tensor, List[torch.Tensor]]],
List[Union[torch.Tensor, List[torch.Tensor]]],
]:
"""Get mean and stddev for descriptor."""
mean_list = []
stddev_list = []
for ii, descrpt in enumerate(self.descrpt_list):
mean_item, stddev_item = descrpt.get_stat_mean_and_stddev()
mean_list.append(mean_item)
stddev_list.append(stddev_item)
return mean_list, stddev_list

def forward(
self,
coord_ext: torch.Tensor,
Expand Down
5 changes: 5 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,14 @@ def set_stat_mean_and_stddev(
mean: torch.Tensor,
stddev: torch.Tensor,
) -> None:
"""Update mean and stddev for descriptor."""
self.sea.mean = mean
self.sea.stddev = stddev

def get_stat_mean_and_stddev(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get mean and stddev for descriptor."""
return self.sea.mean, self.sea.stddev

def serialize(self) -> dict:
obj = self.sea
return {
Expand Down
5 changes: 5 additions & 0 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,14 @@ def set_stat_mean_and_stddev(
mean: torch.Tensor,
stddev: torch.Tensor,
) -> None:
"""Update mean and stddev for descriptor."""
self.mean = mean
self.stddev = stddev

def get_stat_mean_and_stddev(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get mean and stddev for descriptor."""
return self.mean, self.stddev

def serialize(self) -> dict:
return {
"@class": "Descriptor",
Expand Down
5 changes: 5 additions & 0 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,14 @@ def set_stat_mean_and_stddev(
mean: torch.Tensor,
stddev: torch.Tensor,
) -> None:
"""Update mean and stddev for descriptor."""
self.seat.mean = mean
self.seat.stddev = stddev

def get_stat_mean_and_stddev(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get mean and stddev for descriptor."""
return self.seat.mean, self.seat.stddev

def serialize(self) -> dict:
obj = self.seat
return {
Expand Down
10 changes: 10 additions & 0 deletions source/tests/universal/common/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,13 @@ def modules_to_test(self) -> list:
@abstractmethod
def forward_wrapper(self, x):
pass

@classmethod
@abstractmethod
def convert_to_numpy(cls, xx):
pass

@classmethod
@abstractmethod
def convert_from_numpy(cls, xx):
pass
Loading

0 comments on commit ad838c4

Please sign in to comment.