diff --git a/src/openfe_analysis/rmsd.py b/src/openfe_analysis/rmsd.py index 418d362..0b5abf9 100644 --- a/src/openfe_analysis/rmsd.py +++ b/src/openfe_analysis/rmsd.py @@ -1,5 +1,6 @@ import itertools import pathlib +from dataclasses import asdict, dataclass, field from typing import Optional import MDAnalysis as mda @@ -15,19 +16,77 @@ from .transformations import Aligner, ClosestImageShift, NoJump -def select_protein_and_ligands( +@dataclass +class SingleLigandRMSData: + rmsd: list[float] + com_drift: list[float] + resname: str + resid: int + segid: str + + +@dataclass +class LigandsRMSData: + ligands: list[SingleLigandRMSData] + + def __iter__(self): + return iter(self.ligands) + + def __len__(self): + return len(self.ligands) + + def __getitem__(self, idx): + return self.ligands[idx] + + +@dataclass +class StateRMSData: + protein_rmsd: list[float] | None + protein_2d_rmsd: list[float] | None + ligands: LigandsRMSData | None + + +@dataclass +class RMSResults: + time_ps: list[float] + states: list[StateRMSData] = field(default_factory=list) + + def to_dict(self) -> dict: + """Convert results to a JSON-serializable dictionary.""" + return asdict(self) + + @classmethod + def from_dict(cls, d): + return cls( + time_ps=d["time_ps"], + states=[ + StateRMSData( + protein_rmsd=s["protein_rmsd"], + protein_2d_rmsd=s["protein_2d_rmsd"], + ligands=( + LigandsRMSData( + ligands=[SingleLigandRMSData(**lig) for lig in s["ligands"]["ligands"]] + ) + if s["ligands"] is not None + else None + ), + ) + for s in d["states"] + ], + ) + + +def _select_protein_and_ligands( u: mda.Universe, protein_selection: str, ligand_selection: str, -): - prot = u.select_atoms(protein_selection) - +) -> tuple[mda.core.groups.AtomGroup, list[mda.core.groups.AtomGroup]]: + protein = u.select_atoms(protein_selection) lig_atoms = u.select_atoms(ligand_selection) - # split into individual ligands by residue ligands = [res.atoms for res in lig_atoms.residues] - return prot, ligands + return protein, 
def twoD_RMSD(positions: np.ndarray, w: Optional[npt.NDArray]) -> list[float]:
    """Pairwise RMSD between every pair of frames.

    Parameters
    ----------
    positions : np.ndarray
        the protein positions for the entire trajectory, indexed as
        (frame, atom, xyz)
    w : np.ndarray, optional
        weights array, passed through to ``rms.rmsd``

    Returns
    -------
    rmsd_matrix : list
        Flattened list of RMSD values for all frame pairs ``(i, j)`` with
        ``i < j``, in the order produced by ``itertools.combinations``.
    """
    nframes, _, _ = positions.shape

    return [
        rms.rmsd(positions[i], positions[j], w, center=True, superposition=True)
        for i, j in itertools.combinations(range(nframes), 2)
    ]


def analyze_state(
    u: mda.Universe,
    prot: Optional[mda.core.groups.AtomGroup],
    ligands: list[mda.core.groups.AtomGroup],
    skip: int,
    pb: Optional[tqdm.tqdm] = None,
) -> StateRMSData:
    """
    Compute RMSD and COM drift for a single lambda state.

    The reference coordinates are the positions held by ``prot`` and
    ``ligands`` when this function is called, i.e. before iteration starts.

    Parameters
    ----------
    u : mda.Universe
        Universe containing the trajectory.
    prot : AtomGroup or None
        Protein atoms to compute RMSD for. ``None`` or an empty group
        disables the protein analysis.
    ligands : list of AtomGroups
        Ligands to compute RMSD and COM drift for.
    skip : int
        Step size to skip frames (e.g., every `skip`-th frame).
    pb : tqdm.tqdm, optional
        Progress bar advanced by one for every analysed frame.

    Returns
    -------
    StateRMSData
        RMSD data for protein and ligands. The protein fields are ``None``
        when no protein atoms were given, ``ligands`` is ``None`` when the
        ligand list is empty.
    """
    traj_slice = u.trajectory[::skip]
    has_protein = bool(prot)

    # Prepare storage
    if has_protein:
        prot_positions = np.empty((len(traj_slice), len(prot), 3), dtype=np.float32)
        prot_start = prot.positions.copy()
        prot_rmsd: list[float] = []

    lig_starts = [lig.positions.copy() for lig in ligands]
    lig_initial_coms = np.array([lig.center_of_mass() for lig in ligands])
    # the mass weights do not depend on the frame, so build them once
    lig_weights = [lig.masses / np.mean(lig.masses) for lig in ligands]
    lig_rmsd: list[list[float]] = [[] for _ in ligands]
    lig_com_drift: list[list[float]] = [[] for _ in ligands]

    for frame_i, _ in enumerate(traj_slice):
        # tqdm defines truthiness through its ``total``, so compare with None
        if pb is not None:
            pb.update()
        if has_protein:
            prot_positions[frame_i, :, :] = prot.positions
            prot_rmsd.append(
                float(
                    rms.rmsd(
                        prot.positions,
                        prot_start,
                        None,  # prot_weights,
                        center=False,
                        superposition=False,
                    )
                )
            )
        for i, lig in enumerate(ligands):
            lig_rmsd[i].append(
                float(
                    rms.rmsd(
                        lig.positions,
                        lig_starts[i],
                        lig_weights[i],
                        center=False,
                        superposition=False,
                    )
                )
            )
            # plain Euclidean distance between current and initial COM
            drift = np.linalg.norm(lig.center_of_mass() - lig_initial_coms[i])
            lig_com_drift[i].append(float(drift))

    protein_2d = twoD_RMSD(prot_positions, w=None) if has_protein else None
    protein_rmsd_out = prot_rmsd if has_protein else None

    ligands_data = None
    if ligands:
        # coerce to builtin types so the dataclass fields match their
        # annotations and stay serializable
        ligands_data = LigandsRMSData(
            [
                SingleLigandRMSData(
                    rmsd=lig_rmsd[i],
                    com_drift=lig_com_drift[i],
                    resname=str(lig.residues[0].resname),
                    resid=int(lig.residues[0].resid),
                    segid=str(lig.residues[0].segid),
                )
                for i, lig in enumerate(ligands)
            ]
        )

    return StateRMSData(
        protein_rmsd=protein_rmsd_out,
        protein_2d_rmsd=protein_2d,
        ligands=ligands_data,
    )
MDAnalysis selection string for the protein atoms to consider. - Produces, for each lambda state: - - 1D protein RMSD timeseries 'protein_RMSD' - - ligand RMSD timeseries - - ligand COM motion 'ligand_COM_drift' - - 2D protein RMSD plot + Returns + ------- + RMSResults + Per-state RMSD data for protein and ligands. + Produces, for each lambda state: + - 1D protein RMSD timeseries 'protein_RMSD' + - ligand RMSD timeseries + - ligand COM motion 'ligand_COM_drift' + - 2D protein RMSD plot """ - output = { - "protein_RMSD": [], - "ligand_RMSD": [], - "ligand_COM_drift": [], - "protein_2D_RMSD": [], - } - # Open the NetCDF file safely using a context manager with nc.Dataset(dataset) as ds: n_lambda = ds.dimensions["state"].size @@ -161,99 +338,25 @@ def gather_rms_data( u_top = mda.Universe(pdb_topology) - for i in range(n_lambda): + states: list[StateRMSData] = [] + + for state in range(n_lambda): # cheeky, but we can read the PDB topology once and reuse per universe # this then only hits the PDB file once for all replicas u = make_Universe( u_top._topology, ds, - state=i, + state=state, ligand_selection=ligand_selection, protein_selection=protein_selection, ) - prot, ligands = select_protein_and_ligands(u, protein_selection, ligand_selection) + prot, ligands = _select_protein_and_ligands(u, protein_selection, ligand_selection) - # Prepare storage - if prot: - prot_positions = np.empty( - (len(u.trajectory[::skip]), len(prot), 3), dtype=np.float32 - ) - prot_start = prot.positions.copy() - prot_rmsd = [] - - lig_starts = [lig.positions.copy() for lig in ligands] - lig_initial_coms = [lig.center_of_mass() for lig in ligands] - lig_rmsd: list[list[float]] = [[] for _ in ligands] - lig_com_drift: list[list[float]] = [[] for _ in ligands] - - for ts_i, ts in enumerate(u.trajectory[::skip]): - pb.update() - if prot: - prot_positions[ts_i, :, :] = prot.positions - prot_rmsd.append( - rms.rmsd( - prot.positions, - prot_start, - None, # prot_weights, - center=False, - 
superposition=False, - ) - ) - for i, lig in enumerate(ligands): - lig_rmsd[i].append( - rms.rmsd( - lig.positions, - lig_starts[i], - lig.masses / np.mean(lig.masses), - center=False, - superposition=False, - ) - ) - lig_com_drift[i].append( - # distance between start and current ligand position - # ignores PBC, but we've already centered the traj - mda.lib.distances.calc_bonds(lig.center_of_mass(), lig_initial_coms[i]) - ) - - if prot: - # can ignore weights here as it's all Ca - rmsd2d = twoD_RMSD(prot_positions, w=None) # prot_weights) - output["protein_RMSD"].append(prot_rmsd) - output["protein_2D_RMSD"].append(rmsd2d) - if ligands: - output["ligand_RMSD"].append(lig_rmsd) - output["ligand_COM_drift"].append(lig_com_drift) - - output["time(ps)"] = list(np.arange(len(u.trajectory))[::skip] * u.trajectory.dt) + state_data = analyze_state(u, prot, ligands, skip, pb) - return output + states.append(state_data) + time = list(np.arange(len(u.trajectory))[::skip] * u.trajectory.dt) -def twoD_RMSD(positions: np.ndarray, w: Optional[npt.NDArray]) -> list[float]: - """2 dimensions RMSD - - Parameters - ---------- - positions : np.ndarray - the protein positions for the entire trajectory - w : np.ndarray, optional - weights array - - Returns - ------- - rmsd_matrix : list - a flattened version of the 2d - """ - nframes, _, _ = positions.shape - - output = [] - - for i, j in itertools.combinations(range(nframes), 2): - posi, posj = positions[i], positions[j] - - rmsd = rms.rmsd(posi, posj, w, center=True, superposition=True) - - output.append(rmsd) - - return output + return RMSResults(time_ps=time, states=states) diff --git a/src/openfe_analysis/tests/test_rmsd.py b/src/openfe_analysis/tests/test_rmsd.py index 26d93af..cfcc5e4 100644 --- a/src/openfe_analysis/tests/test_rmsd.py +++ b/src/openfe_analysis/tests/test_rmsd.py @@ -10,7 +10,7 @@ from numpy.testing import assert_allclose from openfe_analysis.reader import FEReader -from openfe_analysis.rmsd import 
gather_rms_data, make_Universe +from openfe_analysis.rmsd import RMSResults, gather_rms_data, make_Universe from openfe_analysis.transformations import Aligner @@ -38,30 +38,28 @@ def test_gather_rms_data_regression(simulation_nc, hybrid_system_pdb): skip=100, ) - assert_allclose(output["time(ps)"], [0.0, 100.0, 200.0, 300.0, 400.0, 500.0]) - assert len(output["protein_RMSD"]) == 3 + assert_allclose(output.time_ps, [0.0, 100.0, 200.0, 300.0, 400.0, 500.0]) + assert len(output.states) == 3 + state0 = output.states[0] assert_allclose( - output["protein_RMSD"][0], + state0.protein_rmsd, [0.0, 1.003, 1.276, 1.263, 1.516, 1.251], rtol=1e-3, ) - assert len(output["ligand_RMSD"]) == 3 assert_allclose( - output["ligand_RMSD"][0][0], + state0.ligands[0].rmsd, [0.0, 0.9094, 1.0398, 0.9774, 1.9108, 1.2149], rtol=1e-3, ) - assert len(output["ligand_COM_drift"]) == 3 assert_allclose( - output["ligand_COM_drift"][0][0], + state0.ligands[0].com_drift, [0.0, 0.5458, 0.8364, 0.4914, 1.1939, 0.7587], rtol=1e-3, ) - assert len(output["protein_2D_RMSD"]) == 3 # 15 entries because 6 * 6 frames // 2 - assert len(output["protein_2D_RMSD"][0]) == 15 + assert len(state0.protein_2d_rmsd) == 15 assert_allclose( - output["protein_2D_RMSD"][0][:6], + state0.protein_2d_rmsd[:6], [1.0029, 1.2756, 1.2635, 1.5165, 1.2509, 1.0882], rtol=1e-3, ) @@ -74,14 +72,12 @@ def test_gather_rms_data_septop(simulation_nc_septop, system_septop): skip=100, ) - assert_allclose(output["time(ps)"], [0.0, 10000.0]) - assert len(output["protein_RMSD"]) == 19 - assert len(output["ligand_RMSD"]) == 19 - # Check that we have two lists, one for each ligand - assert len(output["ligand_RMSD"][0]) == 2 - assert len(output["ligand_COM_drift"]) == 19 - # Check that we have two lists, one for each ligand - assert len(output["ligand_COM_drift"][0]) == 2 + assert_allclose(output.time_ps, [0.0, 10000.0]) + assert len(output.states) == 19 + state0 = output.states[0] + # Check that we have two ligands + assert len(state0.ligands) 
def test_rmsresults_serialization_roundtrip(simulation_nc, hybrid_system_pdb):
    """``to_dict`` followed by ``from_dict`` must reproduce the results."""
    results = gather_rms_data(
        hybrid_system_pdb,
        simulation_nc,
        skip=100,
    )

    results_dict = results.to_dict()

    loaded = RMSResults.from_dict(results_dict)

    # basic structure
    assert_allclose(loaded.time_ps, results.time_ps)
    assert len(loaded.states) == len(results.states)

    # spot-check first state
    s0 = loaded.states[0]
    r0 = results.states[0]

    assert_allclose(s0.protein_rmsd, r0.protein_rmsd)
    assert_allclose(s0.protein_2d_rmsd, r0.protein_2d_rmsd)

    # every ligand must come back, with both its identity and its timeseries
    assert len(s0.ligands) == len(r0.ligands)
    for lig_loaded, lig_ref in zip(s0.ligands, r0.ligands):
        assert lig_loaded.resname == lig_ref.resname
        assert lig_loaded.resid == lig_ref.resid
        assert lig_loaded.segid == lig_ref.segid
        assert_allclose(lig_loaded.rmsd, lig_ref.rmsd)
        assert_allclose(lig_loaded.com_drift, lig_ref.com_drift)