Skip to content

Commit

Permalink
move to DeepEval.eval
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Dec 23, 2024
1 parent c601367 commit e781fd4
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions deepmd/pd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,11 @@ def eval(
The output of the evaluation. The keys are the names of the output
variables, and the values are the corresponding output arrays.
"""
# switch to eval mode
is_training = self.dp.training
if is_training:
self.dp.eval()

# convert all of the input to numpy array
atom_types = np.array(atom_types, dtype=np.int32)
coords = np.array(coords)
Expand All @@ -260,6 +265,11 @@ def eval(
aparam,
request_defs,
)

# switch back to training mode if previously enabled
if is_training:
self.dp.train()

return dict(
zip(
[x.name for x in request_defs],
Expand Down Expand Up @@ -354,12 +364,6 @@ def _eval_model(
request_defs: list[OutputVariableDef],
):
model = self.dp.to(DEVICE)

# switch to eval mode
is_training = model.training
if is_training:
model.eval()

prec = NP_PRECISION_DICT[RESERVED_PRECISON_DICT[GLOBAL_PD_FLOAT_PRECISION]]

nframes = coords.shape[0]
Expand Down Expand Up @@ -425,11 +429,6 @@ def _eval_model(
results.append(
np.full(np.abs(shape), np.nan, dtype=prec)
) # this is kinda hacky

# switch back to training mode if previously enabled
if is_training:
model.train()

return tuple(results)

def _eval_model_spin(
Expand Down

0 comments on commit e781fd4

Please sign in to comment.