diff --git a/src/openfe_analysis/__init__.py b/src/openfe_analysis/__init__.py index b65c451..e751fd2 100644 --- a/src/openfe_analysis/__init__.py +++ b/src/openfe_analysis/__init__.py @@ -3,6 +3,6 @@ from .reader import FEReader from .transformations import ( Aligner, - Minimiser, + ClosestImageShift, NoJump, ) diff --git a/src/openfe_analysis/rmsd.py b/src/openfe_analysis/rmsd.py index 1519d64..68d642e 100644 --- a/src/openfe_analysis/rmsd.py +++ b/src/openfe_analysis/rmsd.py @@ -7,10 +7,12 @@ import numpy as np import tqdm from MDAnalysis.analysis import rms +from MDAnalysis.lib.mdamath import make_whole +from MDAnalysis.transformations import unwrap from numpy import typing as npt from .reader import FEReader -from .transformations import Aligner, Minimiser, NoJump +from .transformations import Aligner, ClosestImageShift, NoJump def make_Universe(top: pathlib.Path, trj: nc.Dataset, state: int) -> mda.Universe: @@ -41,17 +43,21 @@ def make_Universe(top: pathlib.Path, trj: nc.Dataset, state: int) -> mda.Univers ligand = u.select_atoms("resname UNK") if prot: - # if there's a protein in the system: - # - make the protein not jump periodic images between frames - # - put the ligand in the closest periodic image as the protein - # - align everything to minimise protein RMSD - nope = NoJump(prot) - minnie = Minimiser(prot, ligand) + # Unwrap all atoms + unwrap_tr = unwrap(prot) + + # Shift chains + ligand + chains = [seg.atoms for seg in prot.segments] + shift = ClosestImageShift(chains[0], [*chains[1:], ligand]) + # # Make each protein chain whole + # for frag in prot.fragments: + # make_whole(frag, reference_atom=frag[0]) + align = Aligner(prot) u.trajectory.add_transformations( - nope, - minnie, + unwrap_tr, + shift, align, ) else: @@ -129,9 +135,9 @@ def gather_rms_data( # TODO: Some smart guard to avoid allocating a silly amount of memory? prot2d = np.empty((len(u.trajectory[::skip]), len(prot), 3), dtype=np.float32) - prot_start = prot.positions - # prot_weights = prot.masses / np.mean(prot.masses) - ligand_start = ligand.positions + # Would this copy be safer? + prot_start = prot.positions.copy() + ligand_start = ligand.positions.copy() ligand_initial_com = ligand.center_of_mass() ligand_weights = ligand.masses / np.mean(ligand.masses) diff --git a/src/openfe_analysis/tests/test_reader.py b/src/openfe_analysis/tests/test_reader.py index d6a7ca1..c8d4bb2 100644 --- a/src/openfe_analysis/tests/test_reader.py +++ b/src/openfe_analysis/tests/test_reader.py @@ -156,19 +156,25 @@ def test_fereader_replica_state_id_error( def test_simulation_skipped_nc(simulation_skipped_nc, hybrid_system_skipped_pdb): + from MDAnalysis.transformations import wrap + u = mda.Universe( hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, replica_id=0, ) + + # Wrap all atoms inside the simulation box + u.trajectory.add_transformations(wrap(u.atoms)) + assert len(u.trajectory) == 51 assert u.trajectory.n_frames == 51 assert u.trajectory.dt == 100 times = np.arange(0, 5001, 100) for inx, ts in enumerate(u.trajectory): assert ts.time == times[inx] - # Positions are not all zero since PBC is not removed + assert np.all(u.atoms.positions > 0) assert np.any(u.atoms.positions != 0) with pytest.raises(mda.exceptions.NoDataError, match="This Timestep has no velocities"): u.atoms.velocities diff --git a/src/openfe_analysis/tests/test_rmsd.py b/src/openfe_analysis/tests/test_rmsd.py index e70dd08..b9022ce 100644 --- a/src/openfe_analysis/tests/test_rmsd.py +++ b/src/openfe_analysis/tests/test_rmsd.py @@ -1,9 +1,34 @@ +from itertools import islice + +import MDAnalysis as mda import netCDF4 as nc import numpy as np import pytest +from MDAnalysis.analysis import rms +from MDAnalysis.lib.mdamath import make_whole +from MDAnalysis.transformations import unwrap from numpy.testing import assert_allclose -from openfe_analysis.rmsd import gather_rms_data +from openfe_analysis.reader import FEReader +from openfe_analysis.rmsd import gather_rms_data, make_Universe +from openfe_analysis.transformations import Aligner + + +@pytest.fixture +def mda_universe(hybrid_system_skipped_pdb, simulation_skipped_nc): + """ + Safely create and destroy an MDAnalysis Universe. + + Guarantees: + - NetCDF file is opened exactly once + """ + u = make_Universe( + hybrid_system_skipped_pdb, + simulation_skipped_nc, + state=0, + ) + yield u + u.trajectory.close() def test_gather_rms_data_regression(simulation_nc, hybrid_system_pdb): @@ -51,24 +76,24 @@ def test_gather_rms_data_regression_skippednc(simulation_skipped_nc, hybrid_syst assert_allclose(output["time(ps)"], np.arange(0, 5001, 100)) assert len(output["protein_RMSD"]) == 11 - # TODO: RMSD is very large as the multichain fix is not in yet + # RMSD is low for this multichain protein assert_allclose( output["protein_RMSD"][0][:6], - [0, 30.620948, 31.158894, 1.045068, 30.735975, 30.999849], + [0, 1.089747, 1.006143, 1.045068, 1.476353, 1.332893], rtol=1e-3, ) assert len(output["ligand_RMSD"]) == 11 # TODO: RMSD is very large as the multichain fix is not in yet assert_allclose( output["ligand_RMSD"][0][:6], - [0.0, 12.607834, 13.882825, 1.228384, 14.129542, 14.535247], + [0.0, 1.092039, 0.839234, 1.228383, 1.533331, 1.276798], rtol=1e-3, ) assert len(output["ligand_wander"]) == 11 # TODO: very large as the multichain fix is not in yet assert_allclose( output["ligand_wander"][0][:6], - [0.0, 10.150182, 11.868109, 0.971329, 12.160156, 12.843338], + [0.0, 0.908097, 0.674262, 0.971328, 0.909263, 1.101882], rtol=1e-3, ) assert len(output["protein_2D_RMSD"]) == 11 @@ -77,6 +102,78 @@ def test_gather_rms_data_regression_skippednc(simulation_skipped_nc, hybrid_syst # TODO: very large as the multichain fix is not in yet assert_allclose( output["protein_2D_RMSD"][0][:6], - [30.620948, 31.158894, 1.045068, 30.735975, 30.999849, 31.102847], + [1.089747, 1.006143, 1.045068, 1.476353, 1.332893, 1.110507], rtol=1e-3, ) + + +def test_multichain_rmsd_shifting(simulation_skipped_nc, hybrid_system_skipped_pdb): + u = mda.Universe( + hybrid_system_skipped_pdb, + simulation_skipped_nc, + state_id=0, + format=FEReader, + ) + prot = u.select_atoms("protein") + # Do other transformations, but no shifting + unwrap_tr = unwrap(prot) + for frag in prot.fragments: + make_whole(frag, reference_atom=frag[0]) + align = Aligner(prot) + u.trajectory.add_transformations(unwrap_tr, align) + chains = [seg.atoms for seg in prot.segments] + assert len(chains) > 1, "Test requires multi-chain protein" + + # RMSD without shifting + r = rms.RMSD(prot) + r.run() + rmsd_no_shift = r.rmsd[:, 2] + assert np.max(np.diff(rmsd_no_shift[:20])) > 10 # expect jumps + u.trajectory.close() + + # RMSD with shifting + u2 = make_Universe(hybrid_system_skipped_pdb, simulation_skipped_nc, state=0) + prot2 = u2.select_atoms("protein") + R2 = rms.RMSD(prot2) + R2.run() + rmsd_shift = R2.rmsd[:, 2] + assert np.max(np.diff(rmsd_shift[:20])) < 2 # jumps should disappear + u2.trajectory.close() + + +def test_chain_radius_of_gyration_stable(simulation_skipped_nc, hybrid_system_skipped_pdb): + u = make_Universe(hybrid_system_skipped_pdb, simulation_skipped_nc, state=0) + + protein = u.select_atoms("protein") + chain = protein.segments[0].atoms + + rgs = [] + for ts in u.trajectory[:50]: + rgs.append(chain.radius_of_gyration()) + + # Chain should not explode or collapse due to PBC errors + assert np.std(rgs) < 2.0 + u.trajectory.close() + + +def test_rmsd_reference_is_first_frame(mda_universe): + u = mda_universe + prot = u.select_atoms("protein") + + _ = next(iter(u.trajectory)) # SAFE + ref = prot.positions.copy() + + rmsd = np.sqrt(((prot.positions - ref) ** 2).mean()) + assert rmsd == 0.0 + u.trajectory.close() + + +def test_ligand_com_continuity(mda_universe): + u = mda_universe + ligand = u.select_atoms("resname UNK") + + coms = [ligand.center_of_mass() for ts in islice(u.trajectory, 20)] + jumps = [np.linalg.norm(coms[i + 1] - coms[i]) for i in range(len(coms) - 1)] + + assert max(jumps) < 5.0 + u.trajectory.close() diff --git a/src/openfe_analysis/tests/test_transformations.py b/src/openfe_analysis/tests/test_transformations.py index 3081acb..2f6c267 100644 --- a/src/openfe_analysis/tests/test_transformations.py +++ b/src/openfe_analysis/tests/test_transformations.py @@ -6,7 +6,7 @@ from openfe_analysis import FEReader from openfe_analysis.transformations import ( Aligner, - Minimiser, + ClosestImageShift, NoJump, ) @@ -23,10 +23,10 @@ def universe(hybrid_system_skipped_pdb, simulation_skipped_nc): u.trajectory.close() -def test_minimiser(universe): +def test_closest_image_shift(universe): prot = universe.select_atoms("protein and name CA") lig = universe.select_atoms("resname UNK") - m = Minimiser(prot, lig) + m = ClosestImageShift(prot, [lig]) universe.trajectory.add_transformations(m) d = mda.lib.distances.calc_bonds(prot.center_of_mass(), lig.center_of_mass()) @@ -54,6 +54,7 @@ def test_nojump(hybrid_system_pdb, simulation_nc): # without the transformation, the y coordinate would jump up to ~81.86 ref = np.array([31.79594626, 52.14568866, 30.64103877]) assert prot.center_of_mass() == pytest.approx(ref, abs=0.01) + universe.trajectory.close() def test_aligner(universe): diff --git a/src/openfe_analysis/transformations.py b/src/openfe_analysis/transformations.py index 1a4ef34..d9e635d 100644 --- a/src/openfe_analysis/transformations.py +++ b/src/openfe_analysis/transformations.py @@ -8,6 +8,7 @@ import MDAnalysis as mda import numpy as np from MDAnalysis.analysis.align import rotation_matrix +from MDAnalysis.lib.mdamath import triclinic_vectors from MDAnalysis.transformations.base import TransformationBase from numpy import typing as npt @@ -44,31 +45,27 @@ def _transform(self, ts): return ts -class Minimiser(TransformationBase): - """Minimises the difference from ags to central_ag by choosing image - - This transformation will translate any AtomGroup in *ags* in multiples of - the box vectors in order to minimise the distance between the center of mass - to the center of mass of each ag. +class ClosestImageShift(TransformationBase): + """ + PBC-safe transformation that shifts one or more target AtomGroups + so that their COM is in the closest image relative to a reference AtomGroup. + Works for any box type (triclinic or orthorhombic). """ - central_ag: mda.AtomGroup - other_ags: list[mda.AtomGroup] - - def __init__(self, central_ag: mda.AtomGroup, *ags): + def __init__(self, reference: mda.AtomGroup, targets: list[mda.AtomGroup]): super().__init__() - self.central_ag = central_ag - self.other_ags = ags + self.reference = reference + self.targets = targets def _transform(self, ts): - center = self.central_ag.center_of_mass() - box = self.central_ag.dimensions[:3] + center = self.reference.center_of_mass() + box = triclinic_vectors(ts.dimensions) - for ag in self.other_ags: + for ag in self.targets: vec = ag.center_of_mass() - center - - # this only works for orthogonal boxes - ag.positions -= np.rint(vec / box) * box + frac = np.linalg.solve(box.T, vec) # fractional coordinates + shift = np.dot(np.round(frac), box) # nearest image, then compute shift + ag.positions -= shift return ts @@ -87,7 +84,8 @@ class Aligner(TransformationBase): def __init__(self, ref_ag: mda.AtomGroup): super().__init__() self.ref_idx = ref_ag.ix - self.ref_pos = ref_ag.positions + # Would this copy be safer? + self.ref_pos = ref_ag.positions.copy() self.weights = np.asarray(ref_ag.masses, dtype=np.float64) self.weights /= np.mean(self.weights) # normalise weights # remove COM shift from reference positions