Skip to content

Commit

Permalink
chore: refactor global stat
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Apr 8, 2024
1 parent 74cdba0 commit 6c299fa
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 78 deletions.
229 changes: 151 additions & 78 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def model_forward_auto_batch_size(*args, **kwargs):
for kk in keys:
model_predict[kk].append(
to_numpy_array(
torch.sum(sample_predict[kk], dim=1) # nf x nloc x odims
sample_predict[kk] # nf x nloc x odims
)
)
model_predict = {kk: np.concatenate(model_predict[kk]) for kk in keys}
Expand Down Expand Up @@ -246,87 +246,160 @@ def compute_output_stats(

# failed to restore the bias from stat file. compute
if bias_atom_e is None:
# only get data for once
sampled = merged() if callable(merged) else merged
# remove the keys that are not in the sample
keys = [keys] if isinstance(keys, str) else keys
assert isinstance(keys, list)
new_keys = [ii for ii in keys if ii in sampled[0].keys()]
del keys
keys = new_keys
# get label dict from sample
outputs = {kk: [item[kk] for item in sampled] for kk in keys}
data_mixed_type = "real_natoms_vec" in sampled[0]
natoms_key = "natoms" if not data_mixed_type else "real_natoms_vec"
for system in sampled:
if "atom_exclude_types" in system:
type_mask = AtomExcludeMask(
ntypes, system["atom_exclude_types"]
).get_type_mask()
system[natoms_key][:, 2:] *= type_mask.unsqueeze(0)
input_natoms = [item[natoms_key] for item in sampled]
# shape: (nframes, ndim)
merged_output = {kk: to_numpy_array(torch.cat(outputs[kk])) for kk in keys}
# shape: (nframes, ntypes)
merged_natoms = to_numpy_array(torch.cat(input_natoms)[:, 2:])
nf = merged_natoms.shape[0]
if preset_bias is not None:
assigned_atom_ener = {
kk: _make_preset_out_bias(ntypes, preset_bias[kk])
if kk in preset_bias.keys()
else None
for kk in keys
}


# only get data once, sampled is a list of dict[str, torch.Tensor]
sampled = merged() if callable(merged) else merged
if model_forward is not None:
model_pred = _compute_model_predict(sampled, keys, model_forward)
else:
assigned_atom_ener = {kk: None for kk in keys}

if model_forward is None:
stats_input = merged_output
else:
# subtract the model bias and output the delta bias
model_predict = _compute_model_predict(sampled, keys, model_forward)
stats_input = {kk: merged_output[kk] - model_predict[kk] for kk in keys}

bias_atom_e = {}
std_atom_e = {}
for kk in keys:
bias_atom_e[kk], std_atom_e[kk] = compute_stats_from_redu(
stats_input[kk],
merged_natoms,
assigned_bias=assigned_atom_ener[kk],
rcond=rcond,
model_pred = None

# split system based on label
atomic_sampled = {}
global_sampled = {}
"""
case1: system-1 global dipole and atomic polar, system-2 global dipole and global polar
dipole,sys1 --> add to global_sampled
dipole,sys2 --> add to global_sampled
polar, sys1 --> add to atomic_sampled
polar, sys2 --> do nothing
global_sampled : [sys1, sys2]
atomic_sampled : [sys1]
"""
for kk in keys:
for idx, system in enumerate(sampled):
if (("find_atom_" + kk) in system) and (system["find_atom_" + kk] > 0.0) and (idx not in atomic_sampled):
atomic_sampled[idx] = system
elif (("find_" + kk) in system) and (system["find_" + kk] > 0.0) and (idx not in global_sampled):
global_sampled[idx] = system
else:
continue

atomic_sampled = list(atomic_sampled.values())
global_sampled = list(global_sampled.values())
if len(global_sampled) > 0:
bias_atom_e, std_atom_e = compute_output_stats_global(
global_sampled,
ntypes,
keys,
rcond,
preset_bias,
model_pred,
)
bias_atom_e, std_atom_e = _post_process_stat(bias_atom_e, std_atom_e)

# unbias_e is only used for print rmse
if model_forward is None:
unbias_e = {
kk: merged_natoms @ bias_atom_e[kk].reshape(ntypes, -1) for kk in keys
}
else:
unbias_e = {
kk: model_predict[kk].reshape(nf, -1)
+ merged_natoms @ bias_atom_e[kk].reshape(ntypes, -1)
for kk in keys
}
atom_numbs = merged_natoms.sum(-1)

def rmse(x):
return np.sqrt(np.mean(np.square(x)))

for kk in keys:
rmse_ae = rmse(
(unbias_e[kk].reshape(nf, -1) - merged_output[kk].reshape(nf, -1))
/ atom_numbs[:, None]
)
log.info(
f"RMSE of {kk} per atom after linear regression is: {rmse_ae} in the unit of {kk}."

if len(atomic_sampled) > 0:
bias_atom_e, std_atom_e = compute_output_stats_atomic(
global_sampled,
ntypes,
keys,
rcond,
preset_bias,
model_pred,
)


# need to merge dict
if stat_file_path is not None:
_save_to_file(stat_file_path, bias_atom_e, std_atom_e)

ret_bias = {kk: to_torch_tensor(vv) for kk, vv in bias_atom_e.items()}
ret_std = {kk: to_torch_tensor(vv) for kk, vv in std_atom_e.items()}
bias_atom_e = {kk: to_torch_tensor(vv) for kk, vv in bias_atom_e.items()}
std_atom_e = {kk: to_torch_tensor(vv) for kk, vv in std_atom_e.items()}
return bias_atom_e, std_atom_e

return ret_bias, ret_std
def compute_output_stats_global(
sampled: List[dict],
ntypes: int,
keys: List[str],
rcond: Optional[float] = None,
preset_bias: Optional[Dict[str, List[Optional[torch.Tensor]]]] = None,
model_pred: Optional[Dict[str, np.ndarray]] = None,
):
"""This function only handle stat computation from reduced global labels."""

# remove the keys that are not in the sample
keys = [keys] if isinstance(keys, str) else keys
assert isinstance(keys, list)
new_keys = [ii for ii in keys if ii in sampled[0].keys()]
del keys
keys = new_keys

# get label dict from sample; for each key, only picking the system with global labels.
outputs = {kk: [system[kk] for system in sampled if kk in system and system.get(f"find_{kk}", 0) > 0] for kk in keys}

data_mixed_type = "real_natoms_vec" in sampled[0]
natoms_key = "natoms" if not data_mixed_type else "real_natoms_vec"
for system in sampled:
if "atom_exclude_types" in system:
type_mask = AtomExcludeMask(
ntypes, system["atom_exclude_types"]
).get_type_mask()
system[natoms_key][:, 2:] *= type_mask.unsqueeze(0)
input_natoms = [item[natoms_key] for item in sampled]
# shape: (nframes, ndim)
merged_output = {kk: to_numpy_array(torch.cat(outputs[kk])) for kk in keys}
# shape: (nframes, ntypes)
merged_natoms = to_numpy_array(torch.cat(input_natoms)[:, 2:])
nf = merged_natoms.shape[0]
if preset_bias is not None:
assigned_atom_ener = {
kk: _make_preset_out_bias(ntypes, preset_bias[kk])
if kk in preset_bias.keys()
else None
for kk in keys
}
else:
assigned_atom_ener = {kk: None for kk in keys}

if model_pred is None:
stats_input = merged_output
else:
# subtract the model bias and output the delta bias
model_pred = {kk: np.sum(model_pred[kk], axis=1) for kk in keys}
stats_input = {kk: merged_output[kk] - model_pred[kk] for kk in keys}

bias_atom_e = {}
std_atom_e = {}
for kk in keys:
bias_atom_e[kk], std_atom_e[kk] = compute_stats_from_redu(
stats_input[kk],
merged_natoms,
assigned_bias=assigned_atom_ener[kk],
rcond=rcond,
)
bias_atom_e, std_atom_e = _post_process_stat(bias_atom_e, std_atom_e)

# unbias_e is only used for print rmse
if model_pred is None:
unbias_e = {
kk: merged_natoms @ bias_atom_e[kk].reshape(ntypes, -1) for kk in keys
}
else:
unbias_e = {
kk: model_pred[kk].reshape(nf, -1)
+ merged_natoms @ bias_atom_e[kk].reshape(ntypes, -1)
for kk in keys
}
atom_numbs = merged_natoms.sum(-1)

def rmse(x):
return np.sqrt(np.mean(np.square(x)))

for kk in keys:
rmse_ae = rmse(
(unbias_e[kk].reshape(nf, -1) - merged_output[kk].reshape(nf, -1))
/ atom_numbs[:, None]
)
log.info(
f"RMSE of {kk} per atom after linear regression is: {rmse_ae} in the unit of {kk}."
)
return bias_atom_e, std_atom_e

def compute_output_stats_atomic(
sampled: List[dict],
ntypes: int,
keys: List[str],
rcond: Optional[float] = None,
preset_bias: Optional[Dict[str, List[Optional[torch.Tensor]]]] = None,
model_pred: Optional[Dict[str, np.ndarray]] = None,
):
pass
2 changes: 2 additions & 0 deletions source/tests/pt/model/test_atomic_model_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ def setUp(self):
"bar": to_torch_tensor(
np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2)
),
"find_foo": np.float32(1.0),
"find_bar": np.float32(1.0)
}
]
self.tempdir = tempfile.TemporaryDirectory()
Expand Down

0 comments on commit 6c299fa

Please sign in to comment.