diff --git a/examples/openmm/metad/pivcv/alanine-dipeptide.py b/examples/openmm/metad/pivcv/alanine-dipeptide.py new file mode 100644 index 00000000..50520d8c --- /dev/null +++ b/examples/openmm/metad/pivcv/alanine-dipeptide.py @@ -0,0 +1,296 @@ +#!/usr/bin/env python3 + +""" +Metadynamics simulation of Alanine Dipeptide in water with OpenMM and PySAGES using +Permutation Invariant Vector (PIV) as CVs. + +Example command to run the simulation `python3 alanine-dipeptide.py --time-steps 1000` +For other supported commandline parameters, check `python3 alanine-dipeptide.py --help` +""" + + +# %% +import argparse +import os +import sys +import time + +import numpy +import pysages + +from pysages.colvars import PIV +from pysages.methods import Metadynamics, MetaDLogger +from pysages.utils import try_import +from pysages.approxfun import compute_mesh + +import numpy as onp +from jax import numpy as np +from jax_md.partition import neighbor_list as nlist, space +from jax_md.partition import NeighborListFormat + +import matplotlib.pyplot as plt + +openmm = try_import("openmm", "simtk.openmm") +unit = try_import("openmm.unit", "simtk.unit") +app = try_import("openmm.app", "simtk.openmm.app") + + +# %% +pi = numpy.pi +kB = unit.BOLTZMANN_CONSTANT_kB * unit.AVOGADRO_CONSTANT_NA +kB = kB.value_in_unit(unit.kilojoules_per_mole / unit.kelvin) + +T = 298.15 * unit.kelvin +dt = 2.0 * unit.femtoseconds +adp_pdb = os.path.join(os.pardir, os.pardir, os.pardir, "inputs", "alanine-dipeptide", "adp-explicit.pdb") + + +# %% +def generate_simulation(pdb_filename=adp_pdb, T=T, dt=dt): + pdb = app.PDBFile(pdb_filename) + + ff = app.ForceField("amber99sb.xml", "tip3p.xml") + cutoff_distance = 1.0 * unit.nanometer + topology = pdb.topology + + system = ff.createSystem( + topology, constraints=app.HBonds, nonbondedMethod=app.PME, nonbondedCutoff=cutoff_distance + ) + + # Set dispersion correction use. 
def gen_neighbor_list(pdb_filename, custom_mask_function, *, box_size=3, nl_cutoff=1,
                      dr_threshold=0.5, capacity_multiplier=0.5):
    """
    Allocate a dense JAX MD neighbor list from the coordinates stored in a PDB file.

    Parameters
    ----------
    pdb_filename: str
        Path of the PDB file whose positions seed the neighbor list.
    custom_mask_function: Callable
        Forwarded unchanged to the JAX MD neighbor list constructor as
        `custom_mask_function`.
    box_size: float = 3
        Edge length of the periodic box, used both for `space.periodic` and
        for the neighbor list.
        NOTE(review): presumably in the same length unit as the PDB positions
        returned by OpenMM -- confirm against the input file.
    nl_cutoff: float = 1
        Neighbor list cutoff radius.
    dr_threshold: float = 0.5
        Displacement threshold handed to the neighbor list constructor.
    capacity_multiplier: float = 0.5
        Capacity multiplier handed to the neighbor list constructor.
        NOTE(review): a value below 1 looks like it may under-allocate the
        neighbor buffer -- confirm against the JAX MD documentation.

    Returns
    -------
    neighbors
        The object returned by `neighbor_list_fn.allocate(positions)`.
    """
    pdb = app.PDBFile(pdb_filename)
    # Positions are converted to a float32 array before being handed to JAX MD.
    positions = np.array(pdb.getPositions(asNumpy=True), dtype=np.float32)

    # Only the displacement function is needed; the shift function is discarded.
    displacement_fn, _ = space.periodic(box_size)
    neighbor_list_fn = nlist(
        displacement_fn,
        box_size,
        nl_cutoff,
        dr_threshold,
        capacity_multiplier=capacity_multiplier,
        custom_mask_function=custom_mask_function,
        format=NeighborListFormat.Dense,
    )

    return neighbor_list_fn.allocate(positions)
def get_args(argv):
    """
    Parse the command line options of the example script.

    Boolean options accept an explicit value (`--log 1`, `--log false`, ...) or
    may be given bare (`--log`), which means true. Passing `type=bool` to
    argparse is avoided on purpose: `bool("0")` and `bool("False")` are both
    `True`, so such flags could never be switched off from the command line.

    Parameters
    ----------
    argv: list[str]
        Command line arguments, without the program name.

    Returns
    -------
    argparse.Namespace
        Namespace with the attributes `well_tempered`, `use_grids`, `log`
        (all `bool`) and `time_steps` (`int`).
    """
    truthy = ("1", "true", "t", "yes", "y", "on")
    falsy = ("0", "false", "f", "no", "n", "off", "")

    def parse_flag(text):
        # Map the usual textual spellings onto a real boolean.
        lowered = text.strip().lower()
        if lowered in truthy:
            return True
        if lowered in falsy:
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")

    available_args = [
        ("well-tempered", "w", bool, 0, "Whether to use well-tempered metadynamics"),
        ("use-grids", "g", bool, 0, "Whether to use grid acceleration"),
        ("log", "l", bool, 0, "Whether to use a callback to log data into a file"),
        ("time-steps", "t", int, 5e5, "Number of simulation steps"),
    ]
    parser = argparse.ArgumentParser(description="Example script to run metadynamics")
    for name, short, kind, val, doc in available_args:
        flags = ("--" + name, "-" + short)
        if kind is bool:
            # nargs="?" keeps the value-taking form working and also allows a bare flag.
            parser.add_argument(
                *flags, type=parse_flag, nargs="?", const=True, default=bool(val), help=doc
            )
        else:
            parser.add_argument(*flags, type=kind, default=kind(val), help=doc)
    return parser.parse_args(argv)
+ #plot_grid = pysages.Grid(lower=(-pi, -pi), upper=(pi, pi), shape=(64, 64), periodic=True) + #xi = (compute_mesh(plot_grid) + 1) / 2 * plot_grid.size + plot_grid.lower + + # determine bias factor depending on method (for standard = 1 and for well-tempered = (T+deltaT)/deltaT) + #alpha = ( + # 1 + # if method.deltaT is None + # else (T.value_in_unit(unit.kelvin) + method.deltaT) / method.deltaT + #) + #kT = kB * T.value_in_unit(unit.kelvin) + + ## extract metapotential function from result + #result = pysages.analyze(run_result) + #metapotential = result["metapotential"] + + ## report in kT and set min free energy to zero + #A = metapotential(xi) * -alpha / kT + #A = A - A.min() + #A = A.reshape(plot_grid.shape) + + ## plot and save free energy to a PNG file + #fig, ax = plt.subplots(dpi=120) + + #im = ax.imshow(A, interpolation="bicubic", origin="lower", extent=[-pi, pi, -pi, pi]) + #ax.contour(A, levels=12, linewidths=0.75, colors="k", extent=[-pi, pi, -pi, pi]) + #ax.set_xlabel(r"$\phi$") + #ax.set_ylabel(r"$\psi$") + + #cbar = plt.colorbar(im) + #cbar.ax.set_ylabel(r"$A~[k_{B}T]$", rotation=270, labelpad=20) + + #fig.savefig("adp-fe.png", dpi=fig.dpi) + + #return result + + +# %% +if __name__ == "__main__": + + main(sys.argv[1:]) diff --git a/pysages/colvars/__init__.py b/pysages/colvars/__init__.py index 0ccbf330..f0c64d35 100644 --- a/pysages/colvars/__init__.py +++ b/pysages/colvars/__init__.py @@ -26,6 +26,17 @@ import jax_md import jaxopt + +from .piv import ( + PIV +) + +from .utils import ( + get_periods, + wrap, +) + from .patterns import GeM except ImportError: pass + diff --git a/pysages/colvars/piv.py b/pysages/colvars/piv.py new file mode 100644 index 00000000..7d634484 --- /dev/null +++ b/pysages/colvars/piv.py @@ -0,0 +1,229 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2020-2021: PySAGES contributors +# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES + +""" +Implementation of Permutation Invariant Vector (PIV) 
class PIV(CollectiveVariable):
    """
    Permutation Invariant Vector (PIV) of a given system of points
    in space as described in Section 4 of
    [Handb. Mater. Model.: Theor. and Model., 597-619 (2020)]
    (https://doi.org/10.1007/978-3-319-44677-6_51).

    The PIV collective variable is generated from user-defined points in
    space, typically the coordinates of solute and solvent atoms. Single or
    multiple solutes in a given solvent are supported. Solvent atoms in the
    solvation shell around the solute are found with the
    [JAX MD](https://jax-md.readthedocs.io/en/main/jax_md.partition.html)
    neighbor list library. This requires the user to provide the indices of
    all the atoms in the system and an allocated JAX MD neighbor list
    (see the alanine dipeptide example in examples/openmm/metad/pivcv).

    Solute-solute blocks of the PIV are built from the solute pair indices
    supplied by the user. To build and sort the solute-solvent blocks, the
    user must provide the indices of the solute atoms and of the solvent
    oxygen atoms separately, together with an array listing the hydrogen
    atoms bonded to each solvent oxygen.

    Example PIV CV definition:
        cvs = [PIV(all_atoms, position_pairs, solute_list, oxygen_list,
                   hydrogen_array, {'r_0': 0.4, 'd_0': 2.3, 'n': 3, 'm': 6},
                   {'update_neighborlist': gen_neighbor_list(...)})]

    Parameters
    ----------
    indices: list
        Indices of all atoms; forwarded to `CollectiveVariable` with
        `group_length=None`.
    position_pairs: JaxArray
        One row per solute-solute pair. `piv` reads columns 1 and 3 as the
        atom indices of the pair and columns 0 and 2 as the identifiers that
        are combined into the block index used for sorting.
    solute_array: list or JaxArray
        Indices of the solute atoms. `piv` flattens it to select the solute
        rows of the neighbor list.
    solvent_oxygen_array: list
        Indices of the solvent oxygen atoms. `piv` only tests its truthiness
        to decide whether solute-solvent blocks are computed, so pass a
        Python list (empty to skip the solvent blocks).
    hydrogen_array: JaxArray
        Array indexed by atom index with two columns holding the indices of
        the hydrogen atoms bonded to each solvent oxygen.
    switching_params: dict
        Switching function parameters with keys 'r_0', 'd_0', 'n' and 'm'.
        The same dictionary is applied to every PIV block.
    update_neighborlist: dict
        Dictionary whose 'update_neighborlist' entry is the allocated
        neighbor list object; `piv` calls its `update(positions)` method.

    Returns
    -------
    piv: JaxArray
        Permutation Invariant Vector (PIV)
    """

    def __init__(self, indices, position_pairs, solute_array, solvent_oxygen_array,
                 hydrogen_array, switching_params, update_neighborlist):
        super().__init__(indices, group_length=None)
        self.position_pairs = position_pairs
        self.solute_array = solute_array
        self.solvent_oxygen_array = solvent_oxygen_array
        self.hydrogen_array = hydrogen_array
        self.switching_params = switching_params
        # The neighbor list arrives wrapped in a dictionary; keep only the object.
        self.update_neighborlist = update_neighborlist['update_neighborlist']

        # Step counter forwarded to `piv`; nothing in this class advances it.
        self.time = 0

    @property
    def function(self):
        """
        Function generator

        Returns
        -------
        Function that generates the PIV from a simulation snapshot.
        Look at `piv` in this module for details.
        """
        return lambda positions: piv(positions, self.update_neighborlist, self.time, self)
+ + Parameters + ---------- + positions: JaxArray + Contains positions of all atoms in the system. + update_neighborlist: Callable + Function to update neighbor list. + params: Object + Links to all the helper parameters. This includes + solute-solute pair indices, solvent oxygen indices, + solvent hydrogen indices, and + switching function parameters. + + Returns + ------- + piv : DeviceArray + Permutation Invariant Vector (PIV). + """ + + all_atom_positions = np.array(positions) + + update_neighborlist = update_neighborlist.update(all_atom_positions) + + print("neighbor list state:\n") + print("time is " + str(time)) + time += 1 + print(update_neighborlist.idx) + print(update_neighborlist.reference_position) + print(update_neighborlist.did_buffer_overflow) + print(update_neighborlist.cell_list_capacity) + print(update_neighborlist.max_occupancy) + print(update_neighborlist.format) + print(update_neighborlist.update_fn) + print("\n") + + position_pairs = params.position_pairs + solute_list = params.solute_array + solvent_oxygen_list = params.solvent_oxygen_array + hydrogen_array = params.hydrogen_array + + i_pos = all_atom_positions[position_pairs[:,1]] + j_pos = all_atom_positions[position_pairs[:,3]] + + piv_solute_blocks = vmap(get_piv_block, in_axes=(0, 0, None))(i_pos, j_pos, params.switching_params) + piv_solute_block_index = vmap(cantor_pair, in_axes=(0, 0))(position_pairs[:,0], position_pairs[:,2]) + + idx_solute_sort = np.argsort(piv_solute_block_index) + piv_solute_blocks = piv_solute_blocks[idx_solute_sort] + + if solvent_oxygen_list: + + nlist = update_neighborlist.idx + atom_ids = np.arange(np.shape(nlist)[0])[:, np.newaxis] + atom_ids_nlist = np.hstack((atom_ids, nlist)) + solute_atom_ids_nlist = atom_ids_nlist[np.array(solute_list).flatten()] + oxygen_array = solute_atom_ids_nlist[:,1:] + solute_array = solute_atom_ids_nlist[:,0] + solute_array = np.repeat(solute_array, np.shape(oxygen_array)[1]) + solute_array = np.reshape(solute_array, 
def get_piv_block(i_pos, j_pos, switching_params):
    """
    Evaluate one PIV entry: the rational switching function applied to the
    distance between two positions.

    Parameters
    ----------
    i_pos, j_pos: JaxArray
        The two positions whose separation is measured with
        `coordinates.distance`.
    switching_params: dict
        Switching function parameters; the keys 'r_0', 'd_0', 'n' and 'm'
        are read.

    Returns
    -------
    Value of `rational_switching_function` at the pair distance.
    """
    # Look the parameters up first so a missing key fails before any distance work.
    r_0, d_0, n, m = (switching_params[key] for key in ('r_0', 'd_0', 'n', 'm'))
    separation = coordinates.distance(i_pos, j_pos)
    return rational_switching_function(separation, r_0, d_0, n, m)
def rational_switching_function(r, r_0, d_0=0.0, n=6, m=None):
    """
    Rational switching function applied to a given variable r:

        s(r) = (1 - x**n) / (1 - x**m),    x = (r - d_0) / r_0

    Where numerator and denominator vanish together (x = 1, i.e.
    r = d_0 + r_0) the expression is 0/0; the limit value n / m is returned
    there instead of NaN. Only the value is patched at that point; the
    derivative through this branch is not the analytic limit.

    Parameters
    ----------
    r: float or JaxArray
        Variable to which the switching function is applied.

    r_0: float
        Scale that divides the shifted variable (r - d_0).

    d_0: float = 0.0
        Offset subtracted from r before scaling.

    n: int = 6
        Exponent used in the numerator.

    m: int = 2*n
        Exponent used in the denominator.

    Returns
    -------
    s : float or JaxArray
        Rational switching function applied to the given r.
    """

    if m is None:
        m = 2 * n

    x = (r - d_0) / r_0
    numerator = 1 - x**n
    denominator = 1 - x**m

    # Removable 0/0 singularity: both polynomials vanish at the same point.
    singular = (numerator == 0) & (denominator == 0)
    # m == 0 makes the denominator vanish everywhere; keep the undefined result as NaN.
    limit = n / m if m != 0 else float("nan")
    # Divide by a harmless 1 at the singular point so no NaN is ever produced
    # (this also keeps NaNs out of automatic differentiation).
    safe_denominator = np.where(singular, 1.0, denominator)

    return np.where(singular, limit, numerator / safe_denominator)