Skip to content

Commit

Permalink
fix the situation when input_dict is not empty
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Jun 6, 2024
1 parent de3b048 commit 81eaea3
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions deepmd/tf/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ def __init__(
self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt))
self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config)
self.original_sel = None
self.use_tebd: Optional[bool] = None
# Whether type embedding is used
self.use_tebd: bool = False

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -747,10 +748,10 @@ def _pass_filter(
):
if input_dict is not None:
type_embedding = input_dict.get("type_embedding", None)
self.use_tebd = True
if type_embedding is not None:
self.use_tebd = True
else:
type_embedding = None
self.use_tebd = False
if self.stripped_type_embedding and type_embedding is None:
raise RuntimeError("type_embedding is required for se_a_tebd_v2 model.")
start_index = 0
Expand Down Expand Up @@ -1419,7 +1420,6 @@ def serialize(self, suffix: str = "") -> dict:
raise NotImplementedError("spin is unsupported")
assert self.davg is not None
assert self.dstd is not None
assert self.use_tebd is not None
if self.use_tebd:
raise RuntimeError(

Check warning on line 1424 in deepmd/tf/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_a.py#L1424

Added line #L1424 was not covered by tests
"Serialization is unsupported when type_embedding is used."
Expand Down

0 comments on commit 81eaea3

Please sign in to comment.