From 81eaea31eb84221b627e4cc6846ec1aa5bcfae73 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 6 Jun 2024 17:51:05 -0400 Subject: [PATCH] fix the situation when input_dict is not empty Signed-off-by: Jinzhe Zeng --- deepmd/tf/descriptor/se_a.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepmd/tf/descriptor/se_a.py b/deepmd/tf/descriptor/se_a.py index 1175842a5a..108e486da7 100644 --- a/deepmd/tf/descriptor/se_a.py +++ b/deepmd/tf/descriptor/se_a.py @@ -301,7 +301,8 @@ def __init__( self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt)) self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config) self.original_sel = None - self.use_tebd: Optional[bool] = None + # Whether type embedding is used + self.use_tebd: bool = False def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -747,10 +748,10 @@ def _pass_filter( ): if input_dict is not None: type_embedding = input_dict.get("type_embedding", None) - self.use_tebd = True + if type_embedding is not None: + self.use_tebd = True else: type_embedding = None - self.use_tebd = False if self.stripped_type_embedding and type_embedding is None: raise RuntimeError("type_embedding is required for se_a_tebd_v2 model.") start_index = 0 @@ -1419,7 +1420,6 @@ def serialize(self, suffix: str = "") -> dict: raise NotImplementedError("spin is unsupported") assert self.davg is not None assert self.dstd is not None - assert self.use_tebd is not None if self.use_tebd: raise RuntimeError( "Serialization is unsupported when type_embedding is used."