Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
309 changes: 206 additions & 103 deletions src/openfe_analysis/rmsd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import pathlib
from dataclasses import asdict, dataclass, field
from typing import Optional

import MDAnalysis as mda
Expand All @@ -15,19 +16,77 @@
from .transformations import Aligner, ClosestImageShift, NoJump


def select_protein_and_ligands(
@dataclass
class SingleLigandRMSData:
rmsd: list[float]
com_drift: list[float]
resname: str
resid: int
segid: str


@dataclass
class LigandsRMSData:
ligands: list[SingleLigandRMSData]

def __iter__(self):
return iter(self.ligands)

def __len__(self):
return len(self.ligands)

def __getitem__(self, idx):
return self.ligands[idx]


@dataclass
class StateRMSData:
protein_rmsd: list[float] | None
protein_2d_rmsd: list[float] | None
ligands: LigandsRMSData | None


@dataclass
class RMSResults:
time_ps: list[float]
states: list[StateRMSData] = field(default_factory=list)

def to_dict(self) -> dict:
"""Convert results to a JSON-serializable dictionary."""
return asdict(self)

@classmethod
def from_dict(cls, d):
return cls(
time_ps=d["time_ps"],
states=[
StateRMSData(
protein_rmsd=s["protein_rmsd"],
protein_2d_rmsd=s["protein_2d_rmsd"],
ligands=(
LigandsRMSData(
ligands=[SingleLigandRMSData(**lig) for lig in s["ligands"]["ligands"]]
)
if s["ligands"] is not None
else None
),
)
for s in d["states"]
],
)


def _select_protein_and_ligands(
u: mda.Universe,
protein_selection: str,
ligand_selection: str,
):
prot = u.select_atoms(protein_selection)

) -> tuple[mda.core.groups.AtomGroup, list[mda.core.groups.AtomGroup]]:
protein = u.select_atoms(protein_selection)
lig_atoms = u.select_atoms(ligand_selection)

# split into individual ligands by residue
ligands = [res.atoms for res in lig_atoms.residues]

return prot, ligands
return protein, ligands


def make_Universe(
Expand Down Expand Up @@ -75,7 +134,7 @@ def make_Universe(
format=FEReader,
)

prot, ligands = select_protein_and_ligands(u, protein_selection, ligand_selection)
prot, ligands = _select_protein_and_ligands(u, protein_selection, ligand_selection)

if prot:
# Unwrap all atoms
Expand Down Expand Up @@ -105,13 +164,134 @@ def make_Universe(
return u


def twoD_RMSD(positions: np.ndarray, w: Optional[npt.NDArray]) -> list[float]:
"""2 dimensions RMSD

Parameters
----------
positions : np.ndarray
the protein positions for the entire trajectory
w : np.ndarray, optional
weights array

Returns
-------
rmsd_matrix : list
Flattened list of RMSD values between all frame pairs.
"""
nframes, _, _ = positions.shape

output = []

for i, j in itertools.combinations(range(nframes), 2):
posi, posj = positions[i], positions[j]

rmsd = rms.rmsd(posi, posj, w, center=True, superposition=True)

output.append(rmsd)

return output


def analyze_state(
u: mda.Universe,
prot: Optional[mda.core.groups.AtomGroup],
ligands: list[mda.core.groups.AtomGroup],
skip: int,
pb: Optional[tqdm.tqdm] = None,
) -> StateRMSData:
"""
Compute RMSD and COM drift for a single lambda state.

Parameters
----------
u : mda.Universe
Universe containing the trajectory.
protein : AtomGroup or None
Protein atoms to compute RMSD for.
ligands : list of AtomGroups
Ligands to compute RMSD and COM drift for.
skip : int
Step size to skip frames (e.g., every `skip`-th frame).

Returns
-------
StateRMSData
RMSD data for protein and ligands.
"""
traj_slice = u.trajectory[::skip]
# Prepare storage
if prot:
prot_positions = np.empty((len(traj_slice), len(prot), 3), dtype=np.float32)
prot_start = prot.positions.copy()
prot_rmsd = []

lig_starts = [lig.positions.copy() for lig in ligands]
lig_initial_coms = np.array([lig.center_of_mass() for lig in ligands])
lig_rmsd: list[list[float]] = [[] for _ in ligands]
lig_com_drift: list[list[float]] = [[] for _ in ligands]

for ts_i, ts in enumerate(traj_slice):
if pb:
pb.update()
if prot:
prot_positions[ts_i, :, :] = prot.positions
prot_rmsd.append(
rms.rmsd(
prot.positions,
prot_start,
None, # prot_weights,
center=False,
superposition=False,
)
)
for i, lig in enumerate(ligands):
lig_rmsd[i].append(
rms.rmsd(
lig.positions,
lig_starts[i],
lig.masses / np.mean(lig.masses),
center=False,
superposition=False,
)
)
com = lig.center_of_mass()
drift = np.linalg.norm(com - lig_initial_coms[i])
lig_com_drift[i].append(drift)

protein_2d = twoD_RMSD(prot_positions, w=None) if prot else None
protein_rmsd_out = prot_rmsd if prot else None

ligands_data = None
if ligands:
single_ligands = [
SingleLigandRMSData(
rmsd=lig_rmsd[i],
com_drift=lig_com_drift[i],
resname=lig.residues[0].resname,
resid=lig.residues[0].resid,
segid=lig.residues[0].segid,
)
for i, lig in enumerate(ligands)
]
ligands_data = LigandsRMSData(single_ligands)

state_data = StateRMSData(
protein_rmsd=protein_rmsd_out,
protein_2d_rmsd=protein_2d,
ligands=ligands_data,
)

return state_data


def gather_rms_data(
pdb_topology: pathlib.Path,
dataset: pathlib.Path,
skip: Optional[int] = None,
ligand_selection: str = "resname UNK",
protein_selection: str = "protein and name CA",
) -> dict[str, list[float]]:
) -> RMSResults:
"""Generate structural analysis of RBFE simulation

Parameters
Expand All @@ -128,19 +308,16 @@ def gather_rms_data(
protein_selection : str, default 'protein and name CA'
MDAnalysis selection string for the protein atoms to consider.

Produces, for each lambda state:
- 1D protein RMSD timeseries 'protein_RMSD'
- ligand RMSD timeseries
- ligand COM motion 'ligand_COM_drift'
- 2D protein RMSD plot
Returns
-------
RMSResults
Per-state RMSD data for protein and ligands.
Produces, for each lambda state:
- 1D protein RMSD timeseries 'protein_RMSD'
- ligand RMSD timeseries
- ligand COM motion 'ligand_COM_drift'
- 2D protein RMSD plot
"""
output = {
"protein_RMSD": [],
"ligand_RMSD": [],
"ligand_COM_drift": [],
"protein_2D_RMSD": [],
}

# Open the NetCDF file safely using a context manager
with nc.Dataset(dataset) as ds:
n_lambda = ds.dimensions["state"].size
Expand All @@ -161,99 +338,25 @@ def gather_rms_data(

u_top = mda.Universe(pdb_topology)

for i in range(n_lambda):
states: list[StateRMSData] = []

for state in range(n_lambda):
# cheeky, but we can read the PDB topology once and reuse per universe
# this then only hits the PDB file once for all replicas
u = make_Universe(
u_top._topology,
ds,
state=i,
state=state,
ligand_selection=ligand_selection,
protein_selection=protein_selection,
)

prot, ligands = select_protein_and_ligands(u, protein_selection, ligand_selection)
prot, ligands = _select_protein_and_ligands(u, protein_selection, ligand_selection)

# Prepare storage
if prot:
prot_positions = np.empty(
(len(u.trajectory[::skip]), len(prot), 3), dtype=np.float32
)
prot_start = prot.positions.copy()
prot_rmsd = []

lig_starts = [lig.positions.copy() for lig in ligands]
lig_initial_coms = [lig.center_of_mass() for lig in ligands]
lig_rmsd: list[list[float]] = [[] for _ in ligands]
lig_com_drift: list[list[float]] = [[] for _ in ligands]

for ts_i, ts in enumerate(u.trajectory[::skip]):
pb.update()
if prot:
prot_positions[ts_i, :, :] = prot.positions
prot_rmsd.append(
rms.rmsd(
prot.positions,
prot_start,
None, # prot_weights,
center=False,
superposition=False,
)
)
for i, lig in enumerate(ligands):
lig_rmsd[i].append(
rms.rmsd(
lig.positions,
lig_starts[i],
lig.masses / np.mean(lig.masses),
center=False,
superposition=False,
)
)
lig_com_drift[i].append(
# distance between start and current ligand position
# ignores PBC, but we've already centered the traj
mda.lib.distances.calc_bonds(lig.center_of_mass(), lig_initial_coms[i])
)

if prot:
# can ignore weights here as it's all Ca
rmsd2d = twoD_RMSD(prot_positions, w=None) # prot_weights)
output["protein_RMSD"].append(prot_rmsd)
output["protein_2D_RMSD"].append(rmsd2d)
if ligands:
output["ligand_RMSD"].append(lig_rmsd)
output["ligand_COM_drift"].append(lig_com_drift)

output["time(ps)"] = list(np.arange(len(u.trajectory))[::skip] * u.trajectory.dt)
state_data = analyze_state(u, prot, ligands, skip, pb)

return output
states.append(state_data)

time = list(np.arange(len(u.trajectory))[::skip] * u.trajectory.dt)

def twoD_RMSD(positions: np.ndarray, w: Optional[npt.NDArray]) -> list[float]:
"""2 dimensions RMSD

Parameters
----------
positions : np.ndarray
the protein positions for the entire trajectory
w : np.ndarray, optional
weights array

Returns
-------
rmsd_matrix : list
a flattened version of the 2d
"""
nframes, _, _ = positions.shape

output = []

for i, j in itertools.combinations(range(nframes), 2):
posi, posj = positions[i], positions[j]

rmsd = rms.rmsd(posi, posj, w, center=True, superposition=True)

output.append(rmsd)

return output
return RMSResults(time_ps=time, states=states)
Loading