Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
4b257b2
First stab at fixing multi-chain RMSD analysis
hannahbaumann Dec 18, 2025
236ff72
Some updates
hannahbaumann Dec 18, 2025
d274b0c
Add tests
hannahbaumann Dec 18, 2025
f3634dd
Some fixes
hannahbaumann Dec 19, 2025
e92adb3
Add another test
hannahbaumann Dec 19, 2025
b528ca2
Move some tests to use skipped smaller data
hannahbaumann Jan 16, 2026
a477bc1
Test out zenodo dealings
hannahbaumann Jan 16, 2026
ad84082
Try to improve speed
hannahbaumann Jan 16, 2026
8ba8087
Try removing locking
hannahbaumann Jan 16, 2026
ead7951
Run downloads before the testing to have a single download for all th…
hannahbaumann Jan 19, 2026
f898a35
add import pooch
hannahbaumann Jan 19, 2026
c675a5c
Test out more
hannahbaumann Jan 19, 2026
88e456d
Ensure datasets get closed
hannahbaumann Jan 19, 2026
73a8e4d
Move to per test download again
hannahbaumann Jan 19, 2026
43aaca2
Remove commented out lines
hannahbaumann Jan 21, 2026
c165525
Test out adding an extra slash
hannahbaumann Jan 21, 2026
5f17770
Switch to all version doi
hannahbaumann Jan 21, 2026
c28286e
Download url directly
hannahbaumann Jan 21, 2026
197b6ba
Small fix
hannahbaumann Jan 21, 2026
b45390a
Change url
hannahbaumann Jan 21, 2026
1d70936
Add missing s
hannahbaumann Jan 21, 2026
20084c3
Switch to api url
hannahbaumann Jan 21, 2026
1a1c916
Revert to old cli
hannahbaumann Jan 22, 2026
59c7392
Update cli.py
hannahbaumann Jan 22, 2026
5e135ab
Update cli.py
hannahbaumann Jan 22, 2026
a9a8780
Update tests for new results
hannahbaumann Jan 23, 2026
92af45b
Change shift to enable other boxes
hannahbaumann Jan 23, 2026
8ea3585
Update multichain code
hannahbaumann Jan 26, 2026
220d504
Add ligand in shifting
hannahbaumann Jan 26, 2026
8c44cb2
Use new shift class instead of old minimiser since that one is no lon…
hannahbaumann Jan 26, 2026
d13495f
Update some tests
hannahbaumann Jan 26, 2026
c34c97c
Update conftest
hannahbaumann Jan 26, 2026
0161673
Update to v2
hannahbaumann Jan 26, 2026
9b6ca69
Update tests
hannahbaumann Jan 26, 2026
1d5c849
Update rmsd test, currently large rmsd till rmsd fix comes in
hannahbaumann Jan 26, 2026
f4e88e2
Make last test pass
hannahbaumann Jan 26, 2026
bd0c8ee
Switch to zenodo fetch
hannahbaumann Jan 26, 2026
ba4c912
remove lines
hannahbaumann Jan 26, 2026
157c02f
Update tests with large errors multichain failure
hannahbaumann Jan 27, 2026
98ea023
Apply suggestion from @hannahbaumann
hannahbaumann Jan 28, 2026
c5b2d70
Reuse zenodo specification
hannahbaumann Jan 28, 2026
54576ab
reorder install
hannahbaumann Jan 28, 2026
3aa52a5
Small fix
hannahbaumann Jan 28, 2026
ff6991a
Remove flaky retries
hannahbaumann Jan 28, 2026
7a30f69
Small fix
hannahbaumann Jan 28, 2026
ac1fe7b
Merge in the fix flakyness PR and update tests
hannahbaumann Jan 28, 2026
67a0913
Add wrapping to get positions to be greater than 0
hannahbaumann Jan 29, 2026
e706b11
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2026
6c1a6c5
Apply suggestion from @hannahbaumann
hannahbaumann Feb 2, 2026
f4637f4
Remove unnecessary make_whole
hannahbaumann Feb 2, 2026
e7d6935
Merge branch 'main' into fix_rmsd_multichain
hannahbaumann Feb 3, 2026
deb5126
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 3, 2026
fa3227e
Merge branch 'main' into fix_rmsd_multichain
hannahbaumann Feb 6, 2026
43eb039
Small fix
hannahbaumann Feb 6, 2026
73fe2ee
Merge branch 'main' into fix_rmsd_multichain
hannahbaumann Feb 6, 2026
7be4c53
Update tests
hannahbaumann Feb 10, 2026
68a2aab
Apply suggestion from @hannahbaumann
hannahbaumann Feb 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/openfe_analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
from .reader import FEReader
from .transformations import (
Aligner,
Minimiser,
ClosestImageShift,
NoJump,
)
30 changes: 18 additions & 12 deletions src/openfe_analysis/rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
import numpy as np
import tqdm
from MDAnalysis.analysis import rms
from MDAnalysis.lib.mdamath import make_whole
from MDAnalysis.transformations import unwrap
from numpy import typing as npt

from .reader import FEReader
from .transformations import Aligner, Minimiser, NoJump
from .transformations import Aligner, ClosestImageShift, NoJump


def make_Universe(top: pathlib.Path, trj: nc.Dataset, state: int) -> mda.Universe:
Expand Down Expand Up @@ -41,17 +43,21 @@ def make_Universe(top: pathlib.Path, trj: nc.Dataset, state: int) -> mda.Univers
ligand = u.select_atoms("resname UNK")

if prot:
# if there's a protein in the system:
# - make the protein not jump periodic images between frames
# - put the ligand in the closest periodic image as the protein
# - align everything to minimise protein RMSD
nope = NoJump(prot)
minnie = Minimiser(prot, ligand)
# Unwrap all atoms
unwrap_tr = unwrap(prot)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I changed NoJump to unwrap, since NoJump only corrects the COM. Should we do the same for the ligand below? @IAlibay, do you know why this new NoJump transformation was implemented instead of using the MDAnalysis unwrap?


# Shift chains + ligand
chains = [seg.atoms for seg in prot.segments]
shift = ClosestImageShift(chains[0], [*chains[1:], ligand])
# # Make each protein chain whole
# for frag in prot.fragments:
# make_whole(frag, reference_atom=frag[0])

align = Aligner(prot)

u.trajectory.add_transformations(
nope,
minnie,
unwrap_tr,
shift,
align,
)
else:
Expand Down Expand Up @@ -129,9 +135,9 @@ def gather_rms_data(
# TODO: Some smart guard to avoid allocating a silly amount of memory?
prot2d = np.empty((len(u.trajectory[::skip]), len(prot), 3), dtype=np.float32)

prot_start = prot.positions
# prot_weights = prot.masses / np.mean(prot.masses)
ligand_start = ligand.positions
# Would this copy be safer?
prot_start = prot.positions.copy()
ligand_start = ligand.positions.copy()
ligand_initial_com = ligand.center_of_mass()
ligand_weights = ligand.masses / np.mean(ligand.masses)

Expand Down
8 changes: 7 additions & 1 deletion src/openfe_analysis/tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,19 +156,25 @@ def test_fereader_replica_state_id_error(


def test_simulation_skipped_nc(simulation_skipped_nc, hybrid_system_skipped_pdb):
from MDAnalysis.transformations import wrap

u = mda.Universe(
hybrid_system_skipped_pdb,
simulation_skipped_nc,
format=FEReader,
replica_id=0,
)

# Wrap all atoms inside the simulation box
u.trajectory.add_transformations(wrap(u.atoms))

assert len(u.trajectory) == 51
assert u.trajectory.n_frames == 51
assert u.trajectory.dt == 100
times = np.arange(0, 5001, 100)
for inx, ts in enumerate(u.trajectory):
assert ts.time == times[inx]
# Positions are not all zero since PBC is not removed
assert np.all(u.atoms.positions > 0)
assert np.any(u.atoms.positions != 0)
with pytest.raises(mda.exceptions.NoDataError, match="This Timestep has no velocities"):
u.atoms.velocities
Expand Down
109 changes: 103 additions & 6 deletions src/openfe_analysis/tests/test_rmsd.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,34 @@
from itertools import islice

import MDAnalysis as mda
import netCDF4 as nc
import numpy as np
import pytest
from MDAnalysis.analysis import rms
from MDAnalysis.lib.mdamath import make_whole
from MDAnalysis.transformations import unwrap
from numpy.testing import assert_allclose

from openfe_analysis.rmsd import gather_rms_data
from openfe_analysis.reader import FEReader
from openfe_analysis.rmsd import gather_rms_data, make_Universe
from openfe_analysis.transformations import Aligner


@pytest.fixture
def mda_universe(hybrid_system_skipped_pdb, simulation_skipped_nc):
    """
    Provide a state-0 Universe built by ``make_Universe`` and close it afterwards.

    The Universe is created once for the requesting test and its trajectory
    is closed in the teardown step after the test has finished, so tests
    using this fixture do not need to close it themselves.
    """
    universe = make_Universe(hybrid_system_skipped_pdb, simulation_skipped_nc, state=0)
    yield universe
    # Teardown: release the trajectory once the requesting test is done.
    universe.trajectory.close()


def test_gather_rms_data_regression(simulation_nc, hybrid_system_pdb):
Expand Down Expand Up @@ -51,24 +76,24 @@ def test_gather_rms_data_regression_skippednc(simulation_skipped_nc, hybrid_syst

assert_allclose(output["time(ps)"], np.arange(0, 5001, 100))
assert len(output["protein_RMSD"]) == 11
# TODO: RMSD is very large as the multichain fix is not in yet
# RMSD is low for this multichain protein
assert_allclose(
output["protein_RMSD"][0][:6],
[0, 30.620948, 31.158894, 1.045068, 30.735975, 30.999849],
[0, 1.089747, 1.006143, 1.045068, 1.476353, 1.332893],
rtol=1e-3,
)
assert len(output["ligand_RMSD"]) == 11
# Ligand RMSD is low now that the multichain fix is in
assert_allclose(
output["ligand_RMSD"][0][:6],
[0.0, 12.607834, 13.882825, 1.228384, 14.129542, 14.535247],
[0.0, 1.092039, 0.839234, 1.228383, 1.533331, 1.276798],
rtol=1e-3,
)
assert len(output["ligand_wander"]) == 11
# Ligand wander is low now that the multichain fix is in
assert_allclose(
output["ligand_wander"][0][:6],
[0.0, 10.150182, 11.868109, 0.971329, 12.160156, 12.843338],
[0.0, 0.908097, 0.674262, 0.971328, 0.909263, 1.101882],
rtol=1e-3,
)
assert len(output["protein_2D_RMSD"]) == 11
Expand All @@ -77,6 +102,78 @@ def test_gather_rms_data_regression_skippednc(simulation_skipped_nc, hybrid_syst
# 2D RMSD is low now that the multichain fix is in
assert_allclose(
output["protein_2D_RMSD"][0][:6],
[30.620948, 31.158894, 1.045068, 30.735975, 30.999849, 31.102847],
[1.089747, 1.006143, 1.045068, 1.476353, 1.332893, 1.110507],
rtol=1e-3,
)


def test_multichain_rmsd_shifting(simulation_skipped_nc, hybrid_system_skipped_pdb):
    """Closest-image shifting should remove the RMSD jumps of a multi-chain protein.

    The protein RMSD is computed twice: once on a Universe that only unwraps
    and aligns the protein (no chain shifting), and once on the Universe
    returned by ``make_Universe``. Large frame-to-frame increases are expected
    in the first case and must be absent in the second.

    Each trajectory is closed in a ``finally`` clause so that a failing
    assertion does not leave the underlying reader open.
    """
    # --- RMSD without shifting -------------------------------------------
    u = mda.Universe(
        hybrid_system_skipped_pdb,
        simulation_skipped_nc,
        state_id=0,
        format=FEReader,
    )
    try:
        prot = u.select_atoms("protein")
        # Do other transformations, but no shifting
        unwrap_tr = unwrap(prot)
        # NOTE: make_whole is called directly here, so it acts on the frame
        # that is currently loaded; it is not part of the transformation
        # pipeline registered below.
        for frag in prot.fragments:
            make_whole(frag, reference_atom=frag[0])
        align = Aligner(prot)
        u.trajectory.add_transformations(unwrap_tr, align)
        chains = [seg.atoms for seg in prot.segments]
        assert len(chains) > 1, "Test requires multi-chain protein"

        r = rms.RMSD(prot)
        r.run()
        # The RMSD value is read from column 2 of the result array.
        rmsd_no_shift = r.rmsd[:, 2]
        assert np.max(np.diff(rmsd_no_shift[:20])) > 10  # expect jumps
    finally:
        u.trajectory.close()

    # --- RMSD with shifting ----------------------------------------------
    u2 = make_Universe(hybrid_system_skipped_pdb, simulation_skipped_nc, state=0)
    try:
        prot2 = u2.select_atoms("protein")
        r2 = rms.RMSD(prot2)
        r2.run()
        rmsd_shift = r2.rmsd[:, 2]
        assert np.max(np.diff(rmsd_shift[:20])) < 2  # jumps should disappear
    finally:
        u2.trajectory.close()


def test_chain_radius_of_gyration_stable(simulation_skipped_nc, hybrid_system_skipped_pdb):
    """The first protein chain's radius of gyration should stay stable over time.

    The radius of gyration of the first segment is sampled over the first
    50 frames (``trajectory[:50]``) of the Universe built by
    ``make_Universe``; its standard deviation must stay below 2.0.

    The trajectory is closed in a ``finally`` clause so that a failing
    assertion does not leave the underlying reader open.
    """
    u = make_Universe(hybrid_system_skipped_pdb, simulation_skipped_nc, state=0)
    try:
        protein = u.select_atoms("protein")
        chain = protein.segments[0].atoms

        # Iterating the trajectory loads each frame; the timestep itself
        # is not needed because ``chain`` reads the current positions.
        rgs = [chain.radius_of_gyration() for _ in u.trajectory[:50]]

        # Chain should not explode or collapse due to PBC errors
        assert np.std(rgs) < 2.0
    finally:
        u.trajectory.close()


def test_rmsd_reference_is_first_frame(mda_universe):
    """Protein positions read after loading the first frame give zero deviation.

    The trajectory is not closed here: the ``mda_universe`` fixture closes
    it during teardown, and closing it in the test as well would close the
    same reader twice.
    """
    u = mda_universe
    prot = u.select_atoms("protein")

    # Load the first frame yielded by iterating the trajectory.
    _ = next(iter(u.trajectory))
    ref = prot.positions.copy()

    rmsd = np.sqrt(((prot.positions - ref) ** 2).mean())
    assert rmsd == 0.0


def test_ligand_com_continuity(mda_universe):
    """The ligand centre of mass must not jump between consecutive frames.

    The centre of mass of the ``resname UNK`` selection is collected for the
    first 20 frames, and the largest displacement between consecutive frames
    must stay below 5.0.

    The trajectory is not closed here: the ``mda_universe`` fixture closes
    it during teardown, and closing it in the test as well would close the
    same reader twice.
    """
    u = mda_universe
    ligand = u.select_atoms("resname UNK")

    # Iterating loads each frame; ``ligand`` then reads the current positions.
    coms = [ligand.center_of_mass() for _ in islice(u.trajectory, 20)]
    # Displacement of the centre of mass between each pair of consecutive frames.
    jumps = [np.linalg.norm(after - before) for before, after in zip(coms, coms[1:])]

    assert max(jumps) < 5.0
7 changes: 4 additions & 3 deletions src/openfe_analysis/tests/test_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from openfe_analysis import FEReader
from openfe_analysis.transformations import (
Aligner,
Minimiser,
ClosestImageShift,
NoJump,
)

Expand All @@ -23,10 +23,10 @@ def universe(hybrid_system_skipped_pdb, simulation_skipped_nc):
u.trajectory.close()


def test_minimiser(universe):
def test_closest_image_shift(universe):
prot = universe.select_atoms("protein and name CA")
lig = universe.select_atoms("resname UNK")
m = Minimiser(prot, lig)
m = ClosestImageShift(prot, [lig])
universe.trajectory.add_transformations(m)

d = mda.lib.distances.calc_bonds(prot.center_of_mass(), lig.center_of_mass())
Expand Down Expand Up @@ -54,6 +54,7 @@ def test_nojump(hybrid_system_pdb, simulation_nc):
# without the transformation, the y coordinate would jump up to ~81.86
ref = np.array([31.79594626, 52.14568866, 30.64103877])
assert prot.center_of_mass() == pytest.approx(ref, abs=0.01)
universe.trajectory.close()


def test_aligner(universe):
Expand Down
36 changes: 17 additions & 19 deletions src/openfe_analysis/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import MDAnalysis as mda
import numpy as np
from MDAnalysis.analysis.align import rotation_matrix
from MDAnalysis.lib.mdamath import triclinic_vectors
from MDAnalysis.transformations.base import TransformationBase
from numpy import typing as npt

Expand Down Expand Up @@ -44,31 +45,27 @@ def _transform(self, ts):
return ts


class Minimiser(TransformationBase):
"""Minimises the difference from ags to central_ag by choosing image

This transformation will translate any AtomGroup in *ags* in multiples of
the box vectors in order to minimise the distance between the center of mass
to the center of mass of each ag.
class ClosestImageShift(TransformationBase):
"""
PBC-safe transformation that shifts one or more target AtomGroups
so that their COM is in the closest image relative to a reference AtomGroup.
Works for any box type (triclinic or orthorhombic).
"""

central_ag: mda.AtomGroup
other_ags: list[mda.AtomGroup]

def __init__(self, central_ag: mda.AtomGroup, *ags):
def __init__(self, reference: mda.AtomGroup, targets: list[mda.AtomGroup]):
super().__init__()
self.central_ag = central_ag
self.other_ags = ags
self.reference = reference
self.targets = targets

def _transform(self, ts):
center = self.central_ag.center_of_mass()
box = self.central_ag.dimensions[:3]
center = self.reference.center_of_mass()
box = triclinic_vectors(ts.dimensions)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, it's great that this works with all box types.


for ag in self.other_ags:
for ag in self.targets:
vec = ag.center_of_mass() - center

# this only works for orthogonal boxes
ag.positions -= np.rint(vec / box) * box
frac = np.linalg.solve(box.T, vec) # fractional coordinates
shift = np.dot(np.round(frac), box) # nearest image, then compute shift
ag.positions -= shift

return ts

Expand All @@ -87,7 +84,8 @@ class Aligner(TransformationBase):
def __init__(self, ref_ag: mda.AtomGroup):
super().__init__()
self.ref_idx = ref_ag.ix
self.ref_pos = ref_ag.positions
# Would this copy be safer?
self.ref_pos = ref_ag.positions.copy()
self.weights = np.asarray(ref_ag.masses, dtype=np.float64)
self.weights /= np.mean(self.weights) # normalise weights
# remove COM shift from reference positions
Expand Down