From 95f30444250ac9119550d10af249eb93dec021b4 Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Fri, 30 Jan 2026 12:01:50 +0100 Subject: [PATCH 1/3] Refactor rms output to dataclass --- src/openfe_analysis/rmsd.py | 131 ++++++++++++++++++++----- src/openfe_analysis/tests/test_rmsd.py | 76 ++++++++------ 2 files changed, 153 insertions(+), 54 deletions(-) diff --git a/src/openfe_analysis/rmsd.py b/src/openfe_analysis/rmsd.py index 418d362..45995c3 100644 --- a/src/openfe_analysis/rmsd.py +++ b/src/openfe_analysis/rmsd.py @@ -1,6 +1,7 @@ import itertools import pathlib -from typing import Optional +from dataclasses import asdict, dataclass, field +from typing import List, Optional import MDAnalysis as mda import netCDF4 as nc @@ -15,6 +16,66 @@ from .transformations import Aligner, ClosestImageShift, NoJump +@dataclass +class SingleLigandRMSData: + rmsd: list[float] + com_drift: list[float] + resname: str + resid: int + segid: str + + +@dataclass +class LigandsRMSData: + ligands: list[SingleLigandRMSData] + + def __iter__(self): + return iter(self.ligands) + + def __len__(self): + return len(self.ligands) + + def __getitem__(self, idx): + return self.ligands[idx] + + +@dataclass +class StateRMSData: + protein_rmsd: list[float] | None + protein_2d_rmsd: list[float] | None + ligands: LigandsRMSData | None + + +@dataclass +class RMSResults: + time_ps: list[float] + states: list[StateRMSData] = field(default_factory=list) + + def to_dict(self) -> dict: + """Convert results to a JSON-serializable dictionary.""" + return asdict(self) + + @classmethod + def from_dict(cls, d): + return cls( + time_ps=d["time_ps"], + states=[ + StateRMSData( + protein_rmsd=s["protein_rmsd"], + protein_2d_rmsd=s["protein_2d_rmsd"], + ligands=( + LigandsRMSData( + ligands=[SingleLigandRMSData(**lig) for lig in s["ligands"]["ligands"]] + ) + if s["ligands"] is not None + else None + ), + ) + for s in d["states"] + ], + ) + + def select_protein_and_ligands( u: mda.Universe, protein_selection: str, @@ -22,10 +83,12 @@ def select_protein_and_ligands( ): prot = u.select_atoms(protein_selection) - lig_atoms = u.select_atoms(ligand_selection) + lig_residues = u.select_atoms(ligand_selection).residues + print([res.resid for res in lig_residues]) + print([res.segid for res in lig_residues]) # split into individual ligands by residue - ligands = [res.atoms for res in lig_atoms.residues] + ligands = [res.atoms for res in lig_residues] return prot, ligands @@ -111,7 +174,7 @@ def gather_rms_data( skip: Optional[int] = None, ligand_selection: str = "resname UNK", protein_selection: str = "protein and name CA", -) -> dict[str, list[float]]: +) -> RMSResults: """Generate structural analysis of RBFE simulation Parameters @@ -128,19 +191,16 @@ def gather_rms_data( protein_selection : str, default 'protein and name CA' MDAnalysis selection string for the protein atoms to consider. - Produces, for each lambda state: - - 1D protein RMSD timeseries 'protein_RMSD' - - ligand RMSD timeseries - - ligand COM motion 'ligand_COM_drift' - - 2D protein RMSD plot + Returns + ------- + RMSResults + Per-state RMSD data for protein and ligands. + Produces, for each lambda state: + - 1D protein RMSD timeseries 'protein_RMSD' + - ligand RMSD timeseries + - ligand COM motion 'ligand_COM_drift' + - 2D protein RMSD plot """ - output = { - "protein_RMSD": [], - "ligand_RMSD": [], - "ligand_COM_drift": [], - "protein_2D_RMSD": [], - } - # Open the NetCDF file safely using a context manager with nc.Dataset(dataset) as ds: n_lambda = ds.dimensions["state"].size @@ -161,6 +221,8 @@ def gather_rms_data( u_top = mda.Universe(pdb_topology) + states: list[StateRMSData] = [] + for i in range(n_lambda): # cheeky, but we can read the PDB topology once and reuse per universe # this then only hits the PDB file once for all replicas @@ -181,6 +243,9 @@ def gather_rms_data( ) prot_start = prot.positions.copy() prot_rmsd = [] + else: + prot_rmsd = [] + prot_positions = None lig_starts = [lig.positions.copy() for lig in ligands] lig_initial_coms = [lig.center_of_mass() for lig in ligands] @@ -216,18 +281,34 @@ def gather_rms_data( mda.lib.distances.calc_bonds(lig.center_of_mass(), lig_initial_coms[i]) ) - if prot: - # can ignore weights here as it's all Ca - rmsd2d = twoD_RMSD(prot_positions, w=None) # prot_weights) - output["protein_RMSD"].append(prot_rmsd) - output["protein_2D_RMSD"].append(rmsd2d) + protein_2d = twoD_RMSD(prot_positions, w=None) if prot else None + protein_rmsd_out = prot_rmsd if prot else None + + ligands_data = None if ligands: - output["ligand_RMSD"].append(lig_rmsd) - output["ligand_COM_drift"].append(lig_com_drift) + single_ligands = [ + SingleLigandRMSData( + rmsd=lig_rmsd[i], + com_drift=lig_com_drift[i], + resname=lig.residues[0].resname, + resid=lig.residues[0].resid, + segid=lig.residues[0].segid, + ) + for i, lig in enumerate(ligands) + ] + ligands_data = LigandsRMSData(single_ligands) + + states.append( + StateRMSData( + protein_rmsd=protein_rmsd_out, + protein_2d_rmsd=protein_2d, + ligands=ligands_data, + ), + ) - output["time(ps)"] = list(np.arange(len(u.trajectory))[::skip] * u.trajectory.dt) + time = list(np.arange(len(u.trajectory))[::skip] * u.trajectory.dt) - return output + return RMSResults(time_ps=time, states=states) def twoD_RMSD(positions: np.ndarray, w: Optional[npt.NDArray]) -> list[float]: diff --git a/src/openfe_analysis/tests/test_rmsd.py b/src/openfe_analysis/tests/test_rmsd.py index 26d93af..cfcc5e4 100644 --- a/src/openfe_analysis/tests/test_rmsd.py +++ b/src/openfe_analysis/tests/test_rmsd.py @@ -10,7 +10,7 @@ from numpy.testing import assert_allclose from openfe_analysis.reader import FEReader -from openfe_analysis.rmsd import gather_rms_data, make_Universe +from openfe_analysis.rmsd import RMSResults, gather_rms_data, make_Universe from openfe_analysis.transformations import Aligner @@ -38,30 +38,28 @@ def test_gather_rms_data_regression(simulation_nc, hybrid_system_pdb): skip=100, ) - assert_allclose(output["time(ps)"], [0.0, 100.0, 200.0, 300.0, 400.0, 500.0]) - assert len(output["protein_RMSD"]) == 3 + assert_allclose(output.time_ps, [0.0, 100.0, 200.0, 300.0, 400.0, 500.0]) + assert len(output.states) == 3 + state0 = output.states[0] assert_allclose( - output["protein_RMSD"][0], + state0.protein_rmsd, [0.0, 1.003, 1.276, 1.263, 1.516, 1.251], rtol=1e-3, ) - assert len(output["ligand_RMSD"]) == 3 assert_allclose( - output["ligand_RMSD"][0][0], + state0.ligands[0].rmsd, [0.0, 0.9094, 1.0398, 0.9774, 1.9108, 1.2149], rtol=1e-3, ) - assert len(output["ligand_COM_drift"]) == 3 assert_allclose( - output["ligand_COM_drift"][0][0], + state0.ligands[0].com_drift, [0.0, 0.5458, 0.8364, 0.4914, 1.1939, 0.7587], rtol=1e-3, ) - assert len(output["protein_2D_RMSD"]) == 3 # 15 entries because 6 * 6 frames // 2 - assert len(output["protein_2D_RMSD"][0]) == 15 + assert len(state0.protein_2d_rmsd) == 15 assert_allclose( - output["protein_2D_RMSD"][0][:6], + state0.protein_2d_rmsd[:6], [1.0029, 1.2756, 1.2635, 1.5165, 1.2509, 1.0882], rtol=1e-3, ) @@ -74,14 +72,12 @@ def test_gather_rms_data_septop(simulation_nc_septop, system_septop): skip=100, ) - assert_allclose(output["time(ps)"], [0.0, 10000.0]) - assert len(output["protein_RMSD"]) == 19 - assert len(output["ligand_RMSD"]) == 19 - # Check that we have two lists, one for each ligand - assert len(output["ligand_RMSD"][0]) == 2 - assert len(output["ligand_COM_drift"]) == 19 - # Check that we have two lists, one for each ligand - assert len(output["ligand_COM_drift"][0]) == 2 + assert_allclose(output.time_ps, [0.0, 10000.0]) + assert len(output.states) == 19 + state0 = output.states[0] + # Check that we have two ligands + assert len(state0.ligands) == 2 + assert state0.ligands[0].segid != state0.ligands[1].segid def test_make_universe_two_ligands(simulation_nc_septop, system_septop): @@ -106,32 +102,30 @@ def test_gather_rms_data_regression_skippednc(simulation_skipped_nc, hybrid_syst skip=None, ) - assert_allclose(output["time(ps)"], np.arange(0, 5001, 100)) - assert len(output["protein_RMSD"]) == 11 + assert_allclose(output.time_ps, np.arange(0, 5001, 100)) + assert len(output.states) == 11 + state0 = output.states[0] # RMSD is low for this multichain protein assert_allclose( - output["protein_RMSD"][0][:6], + state0.protein_rmsd[:6], [0, 1.089747, 1.006143, 1.045068, 1.476353, 1.332893], rtol=1e-3, ) - assert len(output["ligand_RMSD"]) == 11 assert_allclose( - output["ligand_RMSD"][0][0][:6], + state0.ligands[0].rmsd[:6], [0.0, 1.092039, 0.839234, 1.228383, 1.533331, 1.276798], rtol=1e-3, ) - assert len(output["ligand_COM_drift"]) == 11 assert_allclose( - output["ligand_COM_drift"][0][0][:6], + state0.ligands[0].com_drift[:6], [0.0, 0.908097, 0.674262, 0.971328, 0.909263, 1.101882], rtol=1e-3, ) - assert len(output["protein_2D_RMSD"]) == 11 # 15 entries because 6 * 6 frames // 2 - assert len(output["protein_2D_RMSD"][0]) == 1275 + assert len(state0.protein_2d_rmsd) == 1275 # TODO: very large as the multichain fix is not in yet assert_allclose( - output["protein_2D_RMSD"][0][:6], + state0.protein_2d_rmsd[:6], [1.089747, 1.006143, 1.045068, 1.476353, 1.332893, 1.110507], rtol=1e-3, ) @@ -207,3 +201,27 @@ def test_ligand_com_continuity(mda_universe): assert max(jumps) < 5.0 u.trajectory.close() + + +def test_rmsresults_serialization_roundtrip(simulation_nc, hybrid_system_pdb): + results = gather_rms_data( + hybrid_system_pdb, + simulation_nc, + skip=100, + ) + + results_dict = results.to_dict() + + loaded = RMSResults.from_dict(results_dict) + + # basic structure + assert_allclose(loaded.time_ps, results.time_ps) + assert len(loaded.states) == len(results.states) + + # spot-check first state + s0 = loaded.states[0] + r0 = results.states[0] + + assert_allclose(s0.protein_rmsd, r0.protein_rmsd) + assert_allclose(s0.ligands[0].rmsd, r0.ligands[0].rmsd) + assert_allclose(s0.ligands[0].com_drift, r0.ligands[0].com_drift) From 9693f160944e7f83122053a26325e33b1f2c0694 Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Fri, 30 Jan 2026 14:13:33 +0100 Subject: [PATCH 2/3] small fix --- src/openfe_analysis/rmsd.py | 251 ++++++++++++++++++++---------------- 1 file changed, 139 insertions(+), 112 deletions(-) diff --git a/src/openfe_analysis/rmsd.py b/src/openfe_analysis/rmsd.py index 45995c3..60866d6 100644 --- a/src/openfe_analysis/rmsd.py +++ b/src/openfe_analysis/rmsd.py @@ -76,21 +76,17 @@ def from_dict(cls, d): ) -def select_protein_and_ligands( +def _select_protein_and_ligands( u: mda.Universe, protein_selection: str, ligand_selection: str, -): - prot = u.select_atoms(protein_selection) - - lig_residues = u.select_atoms(ligand_selection).residues - print([res.resid for res in lig_residues]) - print([res.segid for res in lig_residues]) - +) -> tuple[mda.core.groups.AtomGroup, list[mda.core.groups.AtomGroup]]: + protein = u.select_atoms(protein_selection) + lig_atoms = u.select_atoms(ligand_selection) # split into individual ligands by residue - ligands = [res.atoms for res in lig_residues] + ligands = [res.atoms for res in lig_atoms.residues] - return prot, ligands + return protein, ligands def make_Universe( @@ -138,7 +134,7 @@ def make_Universe( format=FEReader, ) - prot, ligands = select_protein_and_ligands(u, protein_selection, ligand_selection) + prot, ligands = _select_protein_and_ligands(u, protein_selection, ligand_selection) if prot: # Unwrap all atoms @@ -168,6 +164,131 @@ def make_Universe( return u +def twoD_RMSD(positions: np.ndarray, w: Optional[npt.NDArray]) -> list[float]: + """2 dimensions RMSD + + Parameters + ---------- + positions : np.ndarray + the protein positions for the entire trajectory + w : np.ndarray, optional + weights array + + Returns + ------- + rmsd_matrix : list + Flattened list of RMSD values between all frame pairs. + """ + nframes, _, _ = positions.shape + + output = [] + + for i, j in itertools.combinations(range(nframes), 2): + posi, posj = positions[i], positions[j] + + rmsd = rms.rmsd(posi, posj, w, center=True, superposition=True) + + output.append(rmsd) + + return output + + +def analyze_state( + u: mda.Universe, + prot: Optional[mda.core.groups.AtomGroup], + ligands: list[mda.core.groups.AtomGroup], + skip: int, +) -> tuple[ + Optional[list[float]], + Optional[np.ndarray], + Optional[list[list[float]]], + Optional[list[list[float]]], +]: + """ + Compute RMSD and COM drift for a single lambda state. + + Parameters + ---------- + u : mda.Universe + Universe containing the trajectory. + protein : AtomGroup or None + Protein atoms to compute RMSD for. + ligands : list of AtomGroups + Ligands to compute RMSD and COM drift for. + skip : int + Step size to skip frames (e.g., every `skip`-th frame). + + Returns + ------- + StateRMSData + RMSD data for protein and ligands. + """ + traj_slice = u.trajectory[::skip] + # Prepare storage + if prot: + prot_positions = np.empty((len(traj_slice), len(prot), 3), dtype=np.float32) + prot_start = prot.positions.copy() + prot_rmsd = [] + + lig_starts = [lig.positions.copy() for lig in ligands] + lig_initial_coms = [lig.center_of_mass() for lig in ligands] + lig_rmsd: list[list[float]] = [[] for _ in ligands] + lig_com_drift: list[list[float]] = [[] for _ in ligands] + + for ts_i, ts in enumerate(traj_slice): + if prot: + prot_positions[ts_i, :, :] = prot.positions + prot_rmsd.append( + rms.rmsd( + prot.positions, + prot_start, + None, # prot_weights, + center=False, + superposition=False, + ) + ) + for i, lig in enumerate(ligands): + lig_rmsd[i].append( + rms.rmsd( + lig.positions, + lig_starts[i], + lig.masses / np.mean(lig.masses), + center=False, + superposition=False, + ) + ) + lig_com_drift[i].append( + # distance between start and current ligand position + # ignores PBC, but we've already centered the traj + mda.lib.distances.calc_bonds(lig.center_of_mass(), lig_initial_coms[i]) + ) + + protein_2d = twoD_RMSD(prot_positions, w=None) if prot else None + protein_rmsd_out = prot_rmsd if prot else None + + ligands_data = None + if ligands: + single_ligands = [ + SingleLigandRMSData( + rmsd=lig_rmsd[i], + com_drift=lig_com_drift[i], + resname=lig.residues[0].resname, + resid=lig.residues[0].resid, + segid=lig.residues[0].segid, + ) + for i, lig in enumerate(ligands) + ] + ligands_data = LigandsRMSData(single_ligands) + + state_data = StateRMSData( + protein_rmsd=protein_rmsd_out, + protein_2d_rmsd=protein_2d, + ligands=ligands_data, + ) + + return state_data + + def gather_rms_data( pdb_topology: pathlib.Path, dataset: pathlib.Path, @@ -223,118 +344,24 @@ def gather_rms_data( states: list[StateRMSData] = [] - for i in range(n_lambda): + for state in range(n_lambda): # cheeky, but we can read the PDB topology once and reuse per universe # this then only hits the PDB file once for all replicas u = make_Universe( u_top._topology, ds, - state=i, + state=state, ligand_selection=ligand_selection, protein_selection=protein_selection, ) - prot, ligands = select_protein_and_ligands(u, protein_selection, ligand_selection) + prot, ligands = _select_protein_and_ligands(u, protein_selection, ligand_selection) - # Prepare storage - if prot: - prot_positions = np.empty( - (len(u.trajectory[::skip]), len(prot), 3), dtype=np.float32 - ) - prot_start = prot.positions.copy() - prot_rmsd = [] - else: - prot_rmsd = [] - prot_positions = None - - lig_starts = [lig.positions.copy() for lig in ligands] - lig_initial_coms = [lig.center_of_mass() for lig in ligands] - lig_rmsd: list[list[float]] = [[] for _ in ligands] - lig_com_drift: list[list[float]] = [[] for _ in ligands] - - for ts_i, ts in enumerate(u.trajectory[::skip]): - pb.update() - if prot: - prot_positions[ts_i, :, :] = prot.positions - prot_rmsd.append( - rms.rmsd( - prot.positions, - prot_start, - None, # prot_weights, - center=False, - superposition=False, - ) - ) - for i, lig in enumerate(ligands): - lig_rmsd[i].append( - rms.rmsd( - lig.positions, - lig_starts[i], - lig.masses / np.mean(lig.masses), - center=False, - superposition=False, - ) - ) - lig_com_drift[i].append( - # distance between start and current ligand position - # ignores PBC, but we've already centered the traj - mda.lib.distances.calc_bonds(lig.center_of_mass(), lig_initial_coms[i]) - ) - - protein_2d = twoD_RMSD(prot_positions, w=None) if prot else None - protein_rmsd_out = prot_rmsd if prot else None - - ligands_data = None - if ligands: - single_ligands = [ - SingleLigandRMSData( - rmsd=lig_rmsd[i], - com_drift=lig_com_drift[i], - resname=lig.residues[0].resname, - resid=lig.residues[0].resid, - segid=lig.residues[0].segid, - ) - for i, lig in enumerate(ligands) - ] - ligands_data = LigandsRMSData(single_ligands) - - states.append( - StateRMSData( - protein_rmsd=protein_rmsd_out, - protein_2d_rmsd=protein_2d, - ligands=ligands_data, - ), - ) + state_data = analyze_state(u, prot, ligands, skip) + + states.append(state_data) time = list(np.arange(len(u.trajectory))[::skip] * u.trajectory.dt) + pb.update(len(u.trajectory[::skip])) return RMSResults(time_ps=time, states=states) - - -def twoD_RMSD(positions: np.ndarray, w: Optional[npt.NDArray]) -> list[float]: - """2 dimensions RMSD - - Parameters - ---------- - positions : np.ndarray - the protein positions for the entire trajectory - w : np.ndarray, optional - weights array - - Returns - ------- - rmsd_matrix : list - a flattened version of the 2d - """ - nframes, _, _ = positions.shape - - output = [] - - for i, j in itertools.combinations(range(nframes), 2): - posi, posj = positions[i], positions[j] - - rmsd = rms.rmsd(posi, posj, w, center=True, superposition=True) - - output.append(rmsd) - - return output From 8363b7d122ffb4e598ed9ad7075a44899f4e26e0 Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Fri, 30 Jan 2026 15:12:22 +0100 Subject: [PATCH 3/3] Move pb into other function --- src/openfe_analysis/rmsd.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/openfe_analysis/rmsd.py b/src/openfe_analysis/rmsd.py index 60866d6..0b5abf9 100644 --- a/src/openfe_analysis/rmsd.py +++ b/src/openfe_analysis/rmsd.py @@ -1,7 +1,7 @@ import itertools import pathlib from dataclasses import asdict, dataclass, field -from typing import List, Optional +from typing import Optional import MDAnalysis as mda import netCDF4 as nc @@ -198,12 +198,8 @@ def analyze_state( prot: Optional[mda.core.groups.AtomGroup], ligands: list[mda.core.groups.AtomGroup], skip: int, -) -> tuple[ - Optional[list[float]], - Optional[np.ndarray], - Optional[list[list[float]]], - Optional[list[list[float]]], -]: + pb: Optional[tqdm.tqdm] = None, +) -> StateRMSData: """ Compute RMSD and COM drift for a single lambda state. @@ -231,11 +227,13 @@ def analyze_state( prot_rmsd = [] lig_starts = [lig.positions.copy() for lig in ligands] - lig_initial_coms = [lig.center_of_mass() for lig in ligands] + lig_initial_coms = np.array([lig.center_of_mass() for lig in ligands]) lig_rmsd: list[list[float]] = [[] for _ in ligands] lig_com_drift: list[list[float]] = [[] for _ in ligands] for ts_i, ts in enumerate(traj_slice): + if pb: + pb.update() if prot: prot_positions[ts_i, :, :] = prot.positions prot_rmsd.append( @@ -257,11 +255,9 @@ def analyze_state( superposition=False, ) ) - lig_com_drift[i].append( - # distance between start and current ligand position - # ignores PBC, but we've already centered the traj - mda.lib.distances.calc_bonds(lig.center_of_mass(), lig_initial_coms[i]) - ) + com = lig.center_of_mass() + drift = np.linalg.norm(com - lig_initial_coms[i]) + lig_com_drift[i].append(drift) protein_2d = twoD_RMSD(prot_positions, w=None) if prot else None protein_rmsd_out = prot_rmsd if prot else None @@ -357,11 +353,10 @@ def gather_rms_data( prot, ligands = _select_protein_and_ligands(u, protein_selection, ligand_selection) - state_data = analyze_state(u, prot, ligands, skip) + state_data = analyze_state(u, prot, ligands, skip, pb) states.append(state_data) time = list(np.arange(len(u.trajectory))[::skip] * u.trajectory.dt) - pb.update(len(u.trajectory[::skip])) return RMSResults(time_ps=time, states=states)