#!/usr/bin/env python3
# spline_string.py -- forked from SSAGESLabs/PySAGES.
import argparse
import importlib
import sys
import hoomd
import hoomd.dlext
import hoomd.md as md
import matplotlib.pyplot as plt
import numpy as np
import pysages
from pysages.colvars import Component
from pysages.methods import SerialExecutor, SplineString
# Coefficients shared by the simulation and the reference curve: generate_context
# passes them to the periodic external force for particle type "A", and
# plot_energy expands them as keyword arguments into external_field.
params = {"A": 0.5, "w": 0.2, "p": 2}
def generate_context(**kwargs):
    """Initialize HOOMD and build the simulation context for one replica.

    Keyword arguments read here:
        mpi_enabled: when truthy, ``mpi4py.MPI`` is imported and HOOMD is
            initialized with ``mpi_comm=MPI.COMM_SELF``; otherwise HOOMD is
            initialized without an explicit communicator.
        replica_num: only used in the progress message.

    Returns:
        The ``hoomd.context.SimulationContext`` in which the system was set up.
    """
    init_kwargs = {}
    # The original condition ended in ``and False``, which made this branch
    # unreachable and silently ignored ``mpi_enabled``.
    # NOTE(review): passing ``mpi_comm`` presumably requires an MPI-enabled
    # HOOMD build -- confirm for the target installation.
    if kwargs.get("mpi_enabled"):
        # Imported lazily so that serial runs do not need mpi4py installed.
        MPI = importlib.import_module("mpi4py.MPI")
        init_kwargs["mpi_comm"] = MPI.COMM_SELF

    hoomd.context.initialize("--single-mpi", **init_kwargs)
    context = hoomd.context.SimulationContext()

    with context:
        print(f"Operating replica {kwargs.get('replica_num')}")
        hoomd.init.read_gsd("start.gsd")

        # NVE integration of all particles; the DPD pair force below is
        # configured with kT=1.0 and a fixed seed.
        md.integrate.nve(group=hoomd.group.all())
        md.integrate.mode_standard(dt=0.01)

        nl = md.nlist.cell()
        dpd = md.pair.dpd(r_cut=1, nlist=nl, seed=42, kT=1.0)
        dpd.pair_coeff.set("A", "A", A=5.0, gamma=1.0)
        dpd.pair_coeff.set("A", "B", A=5.0, gamma=1.0)
        dpd.pair_coeff.set("B", "B", A=5.0, gamma=1.0)

        # Periodic external force: type "A" uses the module-level ``params``,
        # type "B" gets zero amplitude.
        periodic = md.external.periodic()
        periodic.force_coeff.set("A", A=params["A"], i=0, w=params["w"], p=params["p"])
        periodic.force_coeff.set("B", A=0.0, i=0, w=0.02, p=1)

    return context
def external_field(r, A, p, w):
    """Evaluate ``A * tanh(cos(p * r) / (2 * pi * p * w))``.

    Works on scalars and on NumPy arrays (element-wise) alike, since only
    NumPy ufuncs and arithmetic are applied to ``r``.
    """
    prefactor = 1 / (2 * np.pi * p * w)
    modulation = np.cos(p * r)
    return A * np.tanh(prefactor * modulation)
def get_args(argv):
    """Parse the command-line options of the example.

    Args:
        argv: list of argument strings, without the program name.

    Returns:
        The ``argparse.Namespace`` with attributes ``k_spring``, ``replicas``,
        ``time_steps``, ``log_period``, ``log_delay``, ``start_path``,
        ``end_path``, ``string_steps`` and the boolean ``mpi``.
    """
    # (long name, short flag, type, default, help text)
    available_args = [
        ("k-spring", "k", float, 50, "Spring constant for each replica"),
        ("replicas", "N", int, 25, "Number of replicas along the path"),
        ("time-steps", "t", int, 1e5, "Number of simulation steps for each replica"),
        ("log-period", "l", int, 50, "Frequency of logging the CVs into each histogram"),
        ("log-delay", "d", int, 5e3, "Number of timesteps to discard before logging"),
        ("start-path", "s", float, -1.5, "Start point of the path"),
        # The help text used to be a copy of the start-path description.
        ("end-path", "e", float, 1.5, "End point of the path"),
        ("string-steps", "p", int, 15, "Iteration of the string algorithm"),
    ]
    parser = argparse.ArgumentParser(description="Example script to run string method.")
    for name, short, T, val, doc in available_args:
        # Defaults are coerced with T so that e.g. 1e5 is stored as the int 100000.
        parser.add_argument("--" + name, "-" + short, type=T, default=T(val), help=doc)
    parser.add_argument("--mpi", action="store_true", help="Use MPI executor")
    return parser.parse_args(argv)
def post_run_action(**kwargs):
    """Dump the state of all particles to a GSD file after a run.

    The file name is built from the ``stringstep`` and ``replica_num`` keyword
    arguments. HOOMD's notice level is set to 0 for the duration of the dump
    and to 2 afterwards.
    """
    snapshot_name = f"final_{kwargs.get('stringstep')}_{kwargs.get('replica_num')}.gsd"

    hoomd.option.set_notice_level(0)
    hoomd.dump.gsd(
        filename=snapshot_name,
        overwrite=True,
        period=None,
        group=hoomd.group.all(),
    )
    hoomd.option.set_notice_level(2)
def get_executor(args):
    """Return the executor selected by ``args.mpi``.

    A ``SerialExecutor`` is returned unless ``args.mpi`` is truthy, in which
    case an ``MPIPoolExecutor`` from ``mpi4py.futures`` is created.
    """
    if not args.mpi:
        return SerialExecutor()

    # mpi4py is only imported when it is actually requested.
    mpi_futures = importlib.import_module("mpi4py.futures")
    return mpi_futures.MPIPoolExecutor()
def plot_energy(result):
    """Plot the free energy along the path and save it to ``energy.pdf``.

    ``result["path"]`` is indexed as a 2-D array whose first column provides
    the x values; ``result["free_energy"]`` provides the y values. Both curves
    (the sampled free energy and ``external_field`` evaluated with the
    module-level ``params``) are shifted so that their minimum is zero.
    """
    figure, axes = plt.subplots()
    axes.set_xlabel("CV")
    axes.set_ylabel("Free energy $[\\epsilon]$")

    # Sampled free energy at the first CV coordinate of each path point.
    path_x = np.asarray(result["path"])[:, 0]
    energies = np.asarray(result["free_energy"])
    axes.plot(path_x, energies - np.min(energies), "o", color="teal")

    # Reference curve on a fixed grid.
    grid = np.linspace(-3, 3, 50)
    reference = external_field(grid, **params)
    axes.plot(grid, reference - np.min(reference), label="test")

    figure.savefig("energy.pdf")
def main(argv):
    """Run the spline-string example end to end.

    Parses ``argv``, sets up the string method on the three Cartesian
    components of particle 0, runs it, prints the analysis results and
    writes the free-energy plot.
    """
    args = get_args(argv)

    # One collective variable per component (0, 1, 2) of particle index 0.
    cvs = [Component([0], axis) for axis in range(3)]

    # Initial path: first coordinate spaced linearly between the requested end
    # points, the other two held at -1 and 1.
    first_coords = np.linspace(args.start_path, args.end_path, args.replicas)
    centers = [[x0, -1, 1] for x0 in first_coords]

    # NOTE(review): the meaning of the positional 1e-2 is defined by
    # SplineString's signature -- confirm against the PySAGES documentation.
    method = SplineString(cvs, args.k_spring, centers, 1e-2, args.log_period, args.log_delay)

    executor = get_executor(args)
    raw_result = pysages.run(
        method,
        generate_context,
        args.time_steps,
        args.string_steps,
        context_args={"mpi_enabled": args.mpi},
        post_run_action=post_run_action,
        executor=executor,
    )

    result = pysages.analyze(raw_result)
    print(np.asarray(result["path_history"]))
    print(result["path"])
    print(result["point_convergence"])
    plot_energy(result)


if __name__ == "__main__":
    main(sys.argv[1:])