diff --git a/src/openfe_analysis/__init__.py b/src/openfe_analysis/__init__.py index b65c451..e751fd2 100644 --- a/src/openfe_analysis/__init__.py +++ b/src/openfe_analysis/__init__.py @@ -3,6 +3,6 @@ from .reader import FEReader from .transformations import ( Aligner, - Minimiser, + ClosestImageShift, NoJump, ) diff --git a/src/openfe_analysis/reader.py b/src/openfe_analysis/reader.py index b9ac9cd..83cca1c 100644 --- a/src/openfe_analysis/reader.py +++ b/src/openfe_analysis/reader.py @@ -1,4 +1,5 @@ -from typing import Optional +import pathlib +from typing import Literal, Optional import netCDF4 as nc import numpy as np @@ -52,16 +53,20 @@ def _determine_iteration_dt(dataset) -> float: class FEReader(ReaderBase): - """A MDAnalysis Reader for NetCDF files created by + """ + MDAnalysis Reader for NetCDF files created by `openmmtools.multistate.MultiStateReporter` - Looks along a multistate NetCDF file along one of two axes: - - constant state/lambda (varying replica) - - constant replica (varying lambda) + Provides a 1D trajectory along either: + + - constant Hamiltonian state (`index_method="state"`) + - constant replica (`index_method="replica"`) + + selected via the `index` argument. """ - _state_id: Optional[int] - _replica_id: Optional[int] + _multistate_index: Optional[int] + _index_method: Optional[str] _frame_index: int _dataset: nc.Dataset _dataset_owner: bool @@ -70,35 +75,27 @@ class FEReader(ReaderBase): units = {"time": "ps", "length": "nanometer"} - def __init__(self, filename, convert_units=True, state_id=None, replica_id=None, **kwargs): + def __init__( + self, + filename: str | pathlib.Path | nc.Dataset, + *, + index: int, + index_method: Literal["state", "replica"] = "state", + convert_units: bool = True, + **kwargs, + ): """ Parameters ---------- filename : pathlike or nc.Dataset - path to the .nc file + Path to the .nc file or an open Dataset. + index : int + Index of the state or replica to extract. May be negative. 
+ index_method : {"state", "replica"}, default "state" + Whether `index` refers to a Hamiltonian state or a replica. convert_units : bool - convert positions to Angstrom - state_id : Optional[int] - The Hamiltonian state index to extract. Must be defined if - ``replica_id`` is not defined. May be negative (see notes below). - replica_id : Optional[int] - The replica index to extract. Must be defined if ``state_id`` - is not defined. May be negative (see notes below). - - Notes - ----- - A negative index may be passed to either ``state_id`` or - ``replica_id``. This will be interpreted as indexing in reverse - starting from the last state/replica. For example, passing a - value of -2 for ``replica_id`` will select the before last replica. + Convert positions to Angstrom. """ - if not ((state_id is None) ^ (replica_id is None)): - raise ValueError( - "Specify one and only one of state or replica, " - f"got state id={state_id} " - f"replica_id={replica_id}" - ) - super().__init__(filename, convert_units, **kwargs) if isinstance(filename, nc.Dataset): @@ -108,15 +105,18 @@ def __init__(self, filename, convert_units=True, state_id=None, replica_id=None, self._dataset = nc.Dataset(filename) self._dataset_owner = True - # Handle the negative ID case - if state_id is not None and state_id < 0: - state_id = range(self._dataset.dimensions["state"].size)[state_id] + if index_method not in {"state", "replica"}: + raise ValueError(f"index_method must be 'state' or 'replica', got {index_method}") + + self._index_method = index_method - if replica_id is not None and replica_id < 0: - replica_id = range(self._dataset.dimensions["replica"].size)[replica_id] + # Handle the negative ID case + if index_method == "state": + size = self._dataset.dimensions["state"].size + else: + size = self._dataset.dimensions["replica"].size - self._state_id = state_id - self._replica_id = replica_id + self._multistate_index = index % size self._n_atoms = self._dataset.dimensions["atom"].size self.ts 
= Timestep(self._n_atoms) @@ -131,6 +131,10 @@ def _format_hint(thing) -> bool: # can pass raw nc datasets through to reduce open/close operations return isinstance(thing, nc.Dataset) + @property + def multistate_index(self) -> int: + return self._multistate_index + @property def n_atoms(self) -> int: return self._n_atoms @@ -139,6 +143,10 @@ def n_atoms(self) -> int: def n_frames(self) -> int: return len(self._frames) + @property + def index_method(self) -> str: + return self._index_method + @staticmethod def parse_n_atoms(filename, **kwargs) -> int: with nc.Dataset(filename) as ds: @@ -153,17 +161,19 @@ def _read_next_timestep(self, ts=None) -> Timestep: def _read_frame(self, frame: int) -> Timestep: self._frame_index = frame - if self._state_id is not None: + frame = self._frames[self._frame_index] + + if self._index_method == "state": rep = multistate._state_to_replica( - self._dataset, self._state_id, self._frames[self._frame_index] + self._dataset, + self._multistate_index, + frame, ) else: - rep = self._replica_id + rep = self._multistate_index - pos = multistate._replica_positions_at_frame( - self._dataset, rep, self._frames[self._frame_index] - ) - dim = multistate._get_unitcell(self._dataset, rep, self._frames[self._frame_index]) + pos = multistate._replica_positions_at_frame(self._dataset, rep, frame) + dim = multistate._get_unitcell(self._dataset, rep, frame) if pos is None: errmsg = ( diff --git a/src/openfe_analysis/rmsd.py b/src/openfe_analysis/rmsd.py index 1519d64..9f04c97 100644 --- a/src/openfe_analysis/rmsd.py +++ b/src/openfe_analysis/rmsd.py @@ -7,10 +7,12 @@ import numpy as np import tqdm from MDAnalysis.analysis import rms +from MDAnalysis.lib.mdamath import make_whole +from MDAnalysis.transformations import unwrap from numpy import typing as npt from .reader import FEReader -from .transformations import Aligner, Minimiser, NoJump +from .transformations import Aligner, ClosestImageShift, NoJump def make_Universe(top: pathlib.Path, trj: 
nc.Dataset, state: int) -> mda.Universe: @@ -34,24 +36,29 @@ def make_Universe(top: pathlib.Path, trj: nc.Dataset, state: int) -> mda.Univers u = mda.Universe( top, trj, - state_id=state, + index=state, + view="state", format=FEReader, ) prot = u.select_atoms("protein and name CA") ligand = u.select_atoms("resname UNK") if prot: - # if there's a protein in the system: - # - make the protein not jump periodic images between frames - # - put the ligand in the closest periodic image as the protein - # - align everything to minimise protein RMSD - nope = NoJump(prot) - minnie = Minimiser(prot, ligand) + # Unwrap all atoms + unwrap_tr = unwrap(prot) + + # Shift chains + ligand + chains = [seg.atoms for seg in prot.segments] + shift = ClosestImageShift(chains[0], [*chains[1:], ligand]) + # Make each protein chain whole + for frag in prot.fragments: + make_whole(frag, reference_atom=frag[0]) + align = Aligner(prot) u.trajectory.add_transformations( - nope, - minnie, + unwrap_tr, + shift, align, ) else: @@ -129,9 +136,9 @@ def gather_rms_data( # TODO: Some smart guard to avoid allocating a silly amount of memory? prot2d = np.empty((len(u.trajectory[::skip]), len(prot), 3), dtype=np.float32) - prot_start = prot.positions - # prot_weights = prot.masses / np.mean(prot.masses) - ligand_start = ligand.positions + # Would this copy be safer? 
+ prot_start = prot.positions.copy() + ligand_start = ligand.positions.copy() ligand_initial_com = ligand.center_of_mass() ligand_weights = ligand.masses / np.mean(ligand.masses) diff --git a/src/openfe_analysis/tests/test_reader.py b/src/openfe_analysis/tests/test_reader.py index d6a7ca1..9cba171 100644 --- a/src/openfe_analysis/tests/test_reader.py +++ b/src/openfe_analysis/tests/test_reader.py @@ -44,7 +44,7 @@ def test_determine_position_indices_warns_for_old_nc(tmp_path): def test_universe_creation(simulation_nc, hybrid_system_pdb): - u = mda.Universe(hybrid_system_pdb, simulation_nc, format=FEReader, state_id=0) + u = mda.Universe(hybrid_system_pdb, simulation_nc, format=FEReader, index=0) # Check that a Universe exists assert u @@ -92,7 +92,7 @@ def test_universe_creation(simulation_nc, hybrid_system_pdb): def test_universe_from_nc_file(simulation_skipped_nc, hybrid_system_skipped_pdb): with nc.Dataset(simulation_skipped_nc) as ds: - u = mda.Universe(hybrid_system_skipped_pdb, ds, format="MultiStateReporter", state_id=0) + u = mda.Universe(hybrid_system_skipped_pdb, ds, format="MultiStateReporter", index=0) assert u assert len(u.atoms) == 9178 @@ -105,7 +105,7 @@ def test_universe_creation_noconversion(simulation_skipped_nc, hybrid_system_ski hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, - state_id=0, + index=0, convert_units=False, ) assert u.trajectory.ts.frame == 0 @@ -124,20 +124,23 @@ def test_universe_creation_noconversion(simulation_skipped_nc, hybrid_system_ski def test_fereader_negative_state(simulation_skipped_nc, hybrid_system_skipped_pdb): - u = mda.Universe(hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, state_id=-1) + u = mda.Universe(hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, index=-1) - assert u.trajectory._state_id == 10 - assert u.trajectory._replica_id is None + assert u.trajectory._multistate_index == 10 u.trajectory.close() def 
test_fereader_negative_replica(simulation_skipped_nc, hybrid_system_skipped_pdb): u = mda.Universe( - hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, replica_id=-2 + hybrid_system_skipped_pdb, + simulation_skipped_nc, + format=FEReader, + index=-2, + index_method="replica", ) - assert u.trajectory._state_id is None - assert u.trajectory._replica_id == 9 + assert u.trajectory._multistate_index == 9 + assert u.trajectory._index_method == "replica" u.trajectory.close() @@ -145,30 +148,35 @@ def test_fereader_negative_replica(simulation_skipped_nc, hybrid_system_skipped_ def test_fereader_replica_state_id_error( simulation_skipped_nc, hybrid_system_skipped_pdb, rep_id, state_id ): - with pytest.raises(ValueError, match="Specify one and only one"): + with pytest.raises(ValueError, match="index_method must be 'state'"): _ = mda.Universe( hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, - state_id=state_id, - replica_id=rep_id, + index=0, + index_method="wrong", ) def test_simulation_skipped_nc(simulation_skipped_nc, hybrid_system_skipped_pdb): + from MDAnalysis.transformations import wrap + u = mda.Universe( hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, - replica_id=0, + index=0, + index_method="replica", ) + # Wrap all atoms inside the simulation box + u.trajectory.add_transformations(wrap(u.atoms)) assert len(u.trajectory) == 51 assert u.trajectory.n_frames == 51 assert u.trajectory.dt == 100 times = np.arange(0, 5001, 100) for inx, ts in enumerate(u.trajectory): assert ts.time == times[inx] - # Positions are not all zero since PBC is not removed + assert np.all(u.atoms.positions > 0) assert np.any(u.atoms.positions != 0) with pytest.raises(mda.exceptions.NoDataError, match="This Timestep has no velocities"): u.atoms.velocities diff --git a/src/openfe_analysis/tests/test_rmsd.py b/src/openfe_analysis/tests/test_rmsd.py index e70dd08..a9e8b09 100644 --- a/src/openfe_analysis/tests/test_rmsd.py +++ 
@pytest.fixture
def mda_universe(hybrid_system_skipped_pdb, simulation_skipped_nc):
    """
    Provide the state-0 Universe built by ``make_Universe`` from the
    skipped-frame test data.

    The Universe's trajectory is closed once the requesting test has
    finished, so tests using this fixture do not need to close it themselves.
    """
    universe = make_Universe(hybrid_system_skipped_pdb, simulation_skipped_nc, state=0)
    yield universe
    # Teardown: runs after the test body returns.
    universe.trajectory.close()
def test_multichain_rmsd_shifting(simulation_skipped_nc, hybrid_system_skipped_pdb):
    """Shifting chains to the closest image must remove the RMSD jumps.

    The same trajectory is analysed twice: once with unwrap + align only
    (no chain shifting), and once through ``make_Universe`` which adds the
    closest-image shift.
    """
    u = mda.Universe(
        hybrid_system_skipped_pdb,
        simulation_skipped_nc,
        index=0,
        format=FEReader,
    )
    # try/finally so the NetCDF handle is released even if an assertion or
    # the analysis itself fails.
    try:
        prot = u.select_atoms("protein")
        # Do other transformations, but no shifting
        unwrap_tr = unwrap(prot)
        for frag in prot.fragments:
            make_whole(frag, reference_atom=frag[0])
        align = Aligner(prot)
        u.trajectory.add_transformations(unwrap_tr, align)
        chains = [seg.atoms for seg in prot.segments]
        assert len(chains) > 1, "Test requires multi-chain protein"

        # RMSD without shifting (column 2 of the result holds the RMSD values)
        no_shift = rms.RMSD(prot)
        no_shift.run()
        rmsd_no_shift = no_shift.results.rmsd[:, 2]
        assert np.max(np.diff(rmsd_no_shift[:20])) > 10  # expect jumps
    finally:
        u.trajectory.close()

    # RMSD with shifting
    u2 = make_Universe(hybrid_system_skipped_pdb, simulation_skipped_nc, state=0)
    try:
        prot2 = u2.select_atoms("protein")
        shifted = rms.RMSD(prot2)
        shifted.run()
        rmsd_shift = shifted.results.rmsd[:, 2]
        assert np.max(np.diff(rmsd_shift[:20])) < 2  # jumps should disappear
    finally:
        u2.trajectory.close()


def test_chain_radius_of_gyration_stable(simulation_skipped_nc, hybrid_system_skipped_pdb):
    """The first chain's radius of gyration must stay roughly constant."""
    u = make_Universe(hybrid_system_skipped_pdb, simulation_skipped_nc, state=0)
    try:
        protein = u.select_atoms("protein")
        chain = protein.segments[0].atoms

        # Iterating the trajectory updates the positions behind ``chain``.
        rgs = [chain.radius_of_gyration() for _ in u.trajectory[:50]]

        # Chain should not explode or collapse due to PBC errors
        assert np.std(rgs) < 2.0
    finally:
        u.trajectory.close()


def test_rmsd_reference_is_first_frame(mda_universe):
    """Positions read at the first frame have zero deviation from their copy."""
    u = mda_universe
    prot = u.select_atoms("protein")

    _ = next(iter(u.trajectory))  # move to the first frame
    ref = prot.positions.copy()

    # NOTE(review): this compares frame 0 with a copy of itself, so it can
    # only fail if the positions contain NaN - confirm this is the intended
    # check for "reference is first frame".
    rmsd = np.sqrt(((prot.positions - ref) ** 2).mean())
    assert rmsd == 0.0
    # No explicit close here: the ``mda_universe`` fixture closes the
    # trajectory during teardown.


def test_ligand_com_continuity(mda_universe):
    """The ligand centre of mass must not jump between consecutive frames."""
    u = mda_universe
    ligand = u.select_atoms("resname UNK")

    coms = [ligand.center_of_mass() for _ in islice(u.trajectory, 20)]
    # Frame-to-frame displacement of the centre of mass.
    jumps = np.linalg.norm(np.diff(coms, axis=0), axis=1)

    assert jumps.max() < 5.0
    # No explicit close here: the ``mda_universe`` fixture closes the
    # trajectory during teardown.
def test_determine_position_indices_inconsistent(monkeypatch, dataset):
    """Unevenly spaced position frames must be rejected."""

    # Force np.diff to return inconsistent spacing. The stub accepts any
    # call signature so it keeps working if the implementation passes
    # extra arguments (e.g. ``n=`` or ``axis=``) to np.diff.
    def fake_diff(*args, **kwargs):
        return np.array([1, 2, 1])

    monkeypatch.setattr(np, "diff", fake_diff)

    with pytest.raises(ValueError, match="consistent frame rate"):
        _determine_position_indices(dataset)


def test_trajectory_invalid_index_method(tmp_path):
    """An unknown ``index_method`` raises and writes no output file."""
    dummy_input = tmp_path / "dummy.nc"
    dummy_output = tmp_path / "out.nc"

    # Create minimal NetCDF; the context manager closes the file even if
    # building it fails.
    with nc.Dataset(dummy_input, "w", format="NETCDF3_64BIT_OFFSET") as ds:
        ds.createDimension("atom", 1)
        ds.createDimension("frame", 1)
        pos = ds.createVariable("positions", "f4", ("frame", "atom"))
        pos[:] = 0.0

    with pytest.raises(ValueError, match="index_method must be 'state' or 'replica'"):
        trajectory_from_multistate(dummy_input, dummy_output, index=0, index_method="foo")

    # The argument check happens before any file is opened or created.
    assert not dummy_output.exists()


def test_trajectory_frame_without_positions(tmp_path):
    """A frame whose positions are fully masked aborts the extraction."""
    dummy_input = tmp_path / "dummy.nc"
    dummy_output = tmp_path / "out.nc"

    # Minimal NetCDF file
    with nc.Dataset(dummy_input, "w", format="NETCDF4") as ds:
        ds.createDimension("frame", 2)  # at least 2 frames
        ds.createDimension("replica", 1)
        ds.createDimension("atom", 1)
        ds.createDimension("spatial", 3)
        ds.createDimension("iteration", 2)  # at least 2 iterations

        positions = ds.createVariable("positions", "f4", ("frame", "replica", "atom", "spatial"))
        positions.units = "nanometer"
        positions[:] = np.ma.masked  # All positions masked

    # Expect RuntimeError due to missing positions
    with pytest.raises(RuntimeError, match="Frame without positions encountered"):
        trajectory_from_multistate(dummy_input, dummy_output, index=0, index_method="replica")
class ClosestImageShift(TransformationBase):
    """Shift target AtomGroups into the periodic image closest to a reference.

    On every frame each target group is translated, as a rigid unit, by an
    integer combination of the box vectors so that its center of mass ends
    up in the periodic image nearest to the center of mass of ``reference``.
    The box vectors are taken from the frame's ``dimensions`` via
    ``triclinic_vectors``, so both orthorhombic and triclinic boxes are
    handled.

    Parameters
    ----------
    reference : mda.AtomGroup
        Group whose center of mass defines the image to move towards.
        It is never moved by this transformation.
    targets : mda.AtomGroup or list of mda.AtomGroup
        Group(s) to shift. A single AtomGroup is accepted as shorthand for a
        one-element list.

    Raises
    ------
    ValueError
        From the transformation itself when a frame carries no box or a
        degenerate (all-zero) one, since no periodic image exists then.
    """

    def __init__(
        self,
        reference: mda.AtomGroup,
        targets: mda.AtomGroup | list[mda.AtomGroup],
    ):
        super().__init__()
        self.reference = reference
        # Iterating a bare AtomGroup yields Atoms, which have no
        # center_of_mass(); wrap it so the group is shifted as a whole.
        if isinstance(targets, mda.AtomGroup):
            targets = [targets]
        # Own copy: later edits to the caller's list must not change
        # what an already-attached transformation does.
        self.targets = list(targets)

    def _transform(self, ts):
        if ts.dimensions is None:
            raise ValueError("ClosestImageShift requires box dimensions, but the frame has none")

        # Rows of ``box`` are the three box vectors.
        box = triclinic_vectors(ts.dimensions)
        if not box.any():
            # triclinic_vectors returns an all-zero matrix for an invalid box.
            raise ValueError(
                f"ClosestImageShift requires a valid box, got dimensions {ts.dimensions}"
            )

        center = self.reference.center_of_mass()
        box_t = box.T  # loop invariant

        for ag in self.targets:
            vec = ag.center_of_mass() - center
            # Solve frac @ box == vec for the fractional coordinates.
            frac = np.linalg.solve(box_t, vec)
            # Whole number of box vectors separating the group from the
            # nearest image.
            ag.positions -= np.rint(frac) @ box

        return ts
def trajectory_from_multistate(
    input_file: Path,
    output_file: Path,
    index: int,
    index_method: Literal["state", "replica"] = "state",
) -> None:
    """
    Extract a 1D trajectory (in an AMBER compliant format) from a MultiState
    sampler generated NetCDF file.

    Parameters
    ----------
    input_file : pathlib.Path
        Path to the input MultiState sampler generated NetCDF file.
    output_file : pathlib.Path
        Path to the AMBER-style NetCDF trajectory to be written.
    index : int
        Index of the state or replica to extract. May be negative, in which
        case it counts back from the last state/replica.
    index_method : {"state", "replica"}, default "state"
        Whether `index` refers to a Hamiltonian state or a replica.

    Raises
    ------
    ValueError
        If ``index_method`` is not ``"state"`` or ``"replica"``, or if
        ``index`` lies outside the range of available states/replicas.
    RuntimeError
        If a frame selected for output has no positions or no box vectors.
    """
    if index_method not in {"state", "replica"}:
        raise ValueError(f"index_method must be 'state' or 'replica', got {index_method}")

    # Context managers guarantee that both files are closed even when the
    # extraction aborts part-way (e.g. on a frame without positions).
    with nc.Dataset(input_file, "r") as multistate:
        # "positions" is indexed as (frame, replica, atom, spatial); take the
        # atom count from the variable's shape instead of reading a frame.
        n_atoms = multistate.variables["positions"].shape[2]
        frame_list = _determine_position_indices(multistate)

        # The dimension holding the count is named after the index method.
        size = multistate.dimensions[index_method].size
        # Reject out-of-range values rather than silently wrapping them onto
        # a different state/replica; only negative indices are normalised.
        if not -size <= index < size:
            raise ValueError(
                f"{index_method} index {index} is out of range for {size} {index_method}s"
            )
        index %= size

        # Create output AMBER NetCDF convention file
        with _create_new_dataset(
            output_file,
            n_atoms,
            title=f"{index_method} {index} trajectory from {input_file}",
        ) as traj:
            coordinates = traj.variables["coordinates"]
            cell_lengths = traj.variables["cell_lengths"]
            cell_angles = traj.variables["cell_angles"]

            # Output frames are numbered 0 -> n_frames regardless of which
            # input frames actually carried positions.
            for out_frame, src_frame in enumerate(frame_list):
                if index_method == "state":
                    # The replica occupying a state can change every frame.
                    replica_id = _state_to_replica(multistate, index, src_frame)
                else:
                    replica_id = index

                pos = _replica_positions_at_frame(multistate, replica_id, src_frame)
                if pos is None:
                    raise RuntimeError("Frame without positions encountered")
                coordinates[out_frame] = pos.to("angstrom").m

                unitcell = _get_unitcell(multistate, replica_id, src_frame)
                if unitcell is None:
                    raise RuntimeError("Frame without box vectors encountered")
                # First three entries are written as lengths, the rest as angles.
                cell_lengths[out_frame] = unitcell[:3]
                cell_angles[out_frame] = unitcell[3:]