diff --git a/CodeEntropy/levels/axes.py b/CodeEntropy/levels/axes.py index 57ef891..20ab150 100644 --- a/CodeEntropy/levels/axes.py +++ b/CodeEntropy/levels/axes.py @@ -67,7 +67,7 @@ def get_residue_axes(self, data_container, index: int, residue=None): The translational and rotational axes at the residue level. - Identify the residue (either provided or selected by `resindex index`). - - Determine whether the residue is bonded to neighboring residues + - Determine whether the residue is bonded to neighbouring residues (previous/next in sequence) using MDAnalysis bonded selections. - If there are *no* bonds to other residues: * Use a custom principal axes, from a moment-of-inertia (MOI) tensor @@ -76,7 +76,10 @@ def get_residue_axes(self, data_container, index: int, residue=None): * Set translational axes equal to rotational axes (as per the original code convention). - If bonded to other residues: - * Use default axes and MOI (MDAnalysis principal axes / inertia). + Find edge heavy atoms (i.e. heavy atoms bonded to neighbour residues) + and find the shortest chain between them: the backbone. Edge + atoms + backbone COM are used to determine UA translational axes + (see get_residue_custom_axes) Args: data_container (MDAnalysis.Universe or AtomGroup): @@ -99,43 +102,85 @@ def get_residue_axes(self, data_container, index: int, residue=None): If the residue selection is empty. 
""" # TODO refine selection so that it will work for branched polymers + # match indexing to MDAnalysis indexing index_prev = index - 1 index_next = index + 1 - if residue is None: residue = data_container.select_atoms(f"resindex {index}") + # residue of interest if len(residue) == 0: raise ValueError(f"Empty residue selection for resindex={index}") - - center = residue.atoms.center_of_mass(unwrap=True) - atom_set = data_container.select_atoms( - f"(resindex {index_prev} or resindex {index_next}) and bonded resid {index}" + edge_atom_set = data_container.atoms.select_atoms( + f" resindex {index} and " + f"(bonded resindex {index_prev} or " + f"resindex {index_next})" ) - if len(atom_set) == 0: - # No bonds to other residues. + uas = residue.select_atoms("mass 2 to 999") + ua_masses = self.get_UA_masses(residue) + + if len(edge_atom_set) == 0: + # No UAS are bonded to other residues # Use a custom principal axes, from a MOI tensor that uses positions of # heavy atoms only, but including masses of heavy atom + bonded H. - uas = residue.select_atoms("mass 2 to 999") - ua_masses = self.get_UA_masses(residue) moi_tensor = self.get_moment_of_inertia_tensor( - center_of_mass=center, + center_of_mass=np.array(residue.center_of_mass()), positions=uas.positions, masses=ua_masses, dimensions=data_container.dimensions[:3], ) rot_axes, moment_of_inertia = self.get_custom_principal_axes(moi_tensor) trans_axes = rot_axes # per original convention + center = np.array(residue.center_of_mass()) else: - # If bonded to other residues, use default axes and MOI. + # If bonded to other residues, use local axes. 
make_whole(data_container.atoms) trans_axes = data_container.atoms.principal_axes() - rot_axes, moment_of_inertia = self.get_vanilla_axes(residue) - center = residue.center_of_mass(unwrap=True) - + residue = data_container.residues[index] + if len(edge_atom_set) == 1: + if index == 0: + # first residue + # use first heavy atom + edges = [residue.atoms[0], edge_atom_set[0]] + backbone = self.get_chain( + residue, residue.atoms[0], edge_atom_set[0] + ) + else: + # last residue + last_index = len(uas) - 1 + last = None + # look for last heavy atom + # with only one bond to another + while last_index > 0 and last is None: + heavy_atom = uas[last_index] + bonded_atoms = residue.atoms.select_atoms( + f"(mass 2 to 999) and bonded index {heavy_atom.index}" + ) + if len(bonded_atoms) == 1: + last = heavy_atom + else: + last_index -= 1 + edges = [edge_atom_set[0], last] + backbone = self.get_chain(residue, edge_atom_set[0], last) + else: + # residue has two bonds to other residues + edges = [edge_atom_set[0], edge_atom_set[1]] + backbone = self.get_chain(residue, edge_atom_set[0], edge_atom_set[1]) + # get edge atoms of the residue + # for terminal residues, this will include the C/N terminus + center = np.array(backbone.center_of_mass()) + rot_axes = self.get_residue_custom_axes(edges, center) + + moment_of_inertia = self.get_custom_residue_moment_of_inertia( + center_of_mass=center, + positions=uas.positions, + masses=ua_masses, + custom_rot_axes=rot_axes, + dimensions=data_container.dimensions[:3], + ) return trans_axes, rot_axes, center, moment_of_inertia - def get_UA_axes(self, data_container, index: int): + def get_UA_axes(self, data_container, index: int, res_position): """Compute united-atom-level translational and rotational axes. The translational and rotational axes at the united-atom level. 
@@ -143,7 +188,12 @@ def get_UA_axes(self, data_container, index: int): This preserves the original behaviour and its rationale: - Translational axes: - Use the same custom principal-axes approach as residue level: + Use the same approach as residue level rotational. + Identify residue of interest and neighbours, then select + edge heavy atoms (i.e. heavy atoms bonded to neighbour residues) + and find the shortest chain between them: the backbone. Edge + atoms + backbone COM are used to determine UA translational axes + (see get_residue_custom_axes) compute a custom MOI tensor using heavy-atom coordinates but UA masses (heavy + bonded H masses), then compute the principal axes from it. @@ -158,7 +208,8 @@ def get_UA_axes(self, data_container, index: int): Molecule and trajectory data. index (int): Bead index (ordinal among heavy atoms). - + res_position: where the residue of interest is + in data_container Returns: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - trans_axes: Translational axes (3, 3). @@ -170,54 +221,164 @@ def get_UA_axes(self, data_container, index: int): IndexError: If `index` does not correspond to an existing heavy atom. ValueError: - If bonded-axis construction fails. + If axis construction fails. 
""" index = int(index) # bead index + heavy_atoms = data_container.atoms.select_atoms("mass 2 to 999") # use the same customPI trans axes as the residue level - heavy_atoms = data_container.select_atoms("prop mass > 1.1") if len(heavy_atoms) > 1: - UA_masses = self.get_UA_masses(data_container.atoms) - center = data_container.atoms.center_of_mass(unwrap=True) - moment_of_inertia_tensor = self.get_moment_of_inertia_tensor( - center, heavy_atoms.positions, UA_masses, data_container.dimensions[:3] - ) - trans_axes, _moment_of_inertia = self.get_custom_principal_axes( - moment_of_inertia_tensor - ) + if len(data_container.residues) == 1: + # only the one residue => use principal axes + residue = data_container + trans_center = data_container.atoms.center_of_mass(unwrap=True) + trans_axes = data_container.atoms.principal_axes() + else: + # residue of interest has at least one neighbour + if res_position == -1: + residue = data_container.residues[0] + index_next = residue.resid + 1 + # the .resid attribute gives 1-indexing + # substract 1 to match indexing later + second_edge = data_container.atoms.select_atoms( + f"resindex {residue.resid - 1} and " + f"bonded resindex {index_next - 1}" + ) + edges = [residue.atoms[0], second_edge.atoms[0]] + backbone = self.get_chain( + residue, residue.atoms[0], second_edge.atoms[0] + ) + + elif res_position == 0: + # between 2 residues + residue = data_container.residues[1] + index_prev = residue.resid - 1 + index_next = residue.resid + 1 + edge_set = data_container.atoms.select_atoms( + f"resindex {residue.resid - 1} and " + f"(bonded resindex {index_next - 1} or " + f"resindex {index_prev - 1})" + ) + edges = [edge_set[0], edge_set[1]] + backbone = self.get_chain(residue, edge_set[0], edge_set[1]) + + else: + # last resid + residue = data_container.residues[1] + index_prev = residue.resid - 1 + first_edge = data_container.atoms.select_atoms( + f"resindex {residue.resid - 1} and " + f"bonded resindex {index_prev - 1}" + ) + 
last_index = len(heavy_atoms) - 1 + last = None + # look for last heavy atom + # with only one bond to another + while last_index > 0 and last is None: + heavy_atom = heavy_atoms[last_index] + bonded_atoms = residue.atoms.select_atoms( + f"(mass 2 to 999) and bonded index {heavy_atom.index}" + ) + if len(bonded_atoms) == 1: + last = heavy_atom + else: + last_index -= 1 + edges = [first_edge.atoms[0], last] + backbone = self.get_chain(residue, first_edge.atoms[0], last) + + trans_center = np.array(backbone.center_of_mass()) + trans_axes = self.get_residue_custom_axes(edges, trans_center) + else: - # use standard PA for UA not bonded to anything else + # only one heavy atom or hydrogen molecule make_whole(data_container.atoms) + residue = data_container + # trans_center is center of mass + trans_center = np.array(data_container.center_of_mass()) trans_axes = data_container.atoms.principal_axes() + residue_heavy_atoms = residue.atoms.select_atoms("mass 2 to 999") # look for heavy atoms in residue of interest heavy_atom_indices = [] - for atom in heavy_atoms: + for atom in residue_heavy_atoms: heavy_atom_indices.append(atom.index) # we find the nth heavy atom # where n is the bead index heavy_atom_index = heavy_atom_indices[index] - heavy_atom = data_container.select_atoms(f"index {heavy_atom_index}") + heavy_atom = residue.atoms.select_atoms(f"index {heavy_atom_index}") + + if trans_axes is None: + raise ValueError("Unable to compute translation axes for UA bead.") - center = heavy_atom.positions[0] + rot_center = heavy_atom.positions[0] rot_axes, moment_of_inertia = self.get_bonded_axes( system=data_container, atom=heavy_atom[0], dimensions=data_container.dimensions[:3], ) + if rot_axes is None or moment_of_inertia is None: raise ValueError("Unable to compute bonded axes for UA bead.") - logger.debug(f"Translational Axes: {trans_axes}") - logger.debug(f"Rotational Axes: {rot_axes}") - logger.debug(f"Center: {center}") - logger.debug(f"Moment of Inertia: 
def get_residue_custom_axes(self, edges, center):
    """Build an orthonormal residue frame from two edge atoms and a centre.

    The frame is oriented by the rotation centre ``O`` and the two edge
    atoms ``E1`` and ``E2``:

    - x axis points from ``O`` towards ``E1``.
    - y axis is the component of ``E1 -> E2`` perpendicular to the x axis,
      so it lies in the plane containing ``O``, ``E1`` and ``E2``
      (this is the vector ``P -> E2``, equivalently ``O -> Q``, below).
    - z axis is ``x cross y``, which makes the frame right-handed.

            Q --- E2
            |     |
            |     |
    E1 ---- O --- P

    ``P`` is the foot of the perpendicular dropped from ``E2`` onto the
    line through ``E1`` and ``O``.

    Args:
        edges: Sequence of two atom-like objects, each exposing a
            ``position`` attribute with 3 coordinates (``E1`` then ``E2``).
        center: (3,) coordinates of the rotation centre ``O``.

    Returns:
        np.ndarray: (3, 3) array whose rows are the unit x, y and z axes.

    Raises:
        ValueError: If the frame is undefined because ``E1`` coincides
            with the centre, because ``E2`` lies on the line through the
            centre and ``E1``, or because a coordinate is not finite.
    """
    # Lengths below this are treated as zero when normalising.
    tol = 1e-12

    origin = np.asarray(center, dtype=float)
    e1 = np.asarray(edges[0].position, dtype=float)
    e2 = np.asarray(edges[1].position, dtype=float)

    # x axis: O -> E1.
    x_axis = e1 - origin
    x_len = np.linalg.norm(x_axis)
    if not np.isfinite(x_len) or x_len <= tol:
        raise ValueError(
            "Unable to compute residue axes: first edge atom coincides "
            "with the rotation centre."
        )
    x_axis = x_axis / x_len

    # y axis: remove from E1 -> E2 its component along x. What is left is
    # perpendicular to x and still in the O/E1/E2 plane.
    span = e2 - e1
    y_axis = span - np.dot(span, x_axis) * x_axis
    y_len = np.linalg.norm(y_axis)
    if not np.isfinite(y_len) or y_len <= tol * max(1.0, np.linalg.norm(span)):
        raise ValueError(
            "Unable to compute residue axes: edge atoms are collinear "
            "with the rotation centre."
        )
    y_axis = y_axis / y_len

    # x and y are orthonormal, so their cross product is already unit length.
    z_axis = np.cross(x_axis, y_axis)

    return np.array([x_axis, y_axis, z_axis])
def get_custom_residue_moment_of_inertia(
    self,
    center_of_mass: np.ndarray,
    positions: np.ndarray,
    masses: np.ndarray,
    custom_rot_axes: np.ndarray,
    dimensions: np.ndarray,
):
    """Compute the moment of inertia of a multi-UA bead about custom axes.

    For every united atom with mass ``m`` and displacement ``r`` from the
    bead centre, the contribution to axis ``a`` is ``m * |a x r|**2``; the
    contributions are summed over all united atoms.

    Args:
        center_of_mass: (3,) centre of the bead.
        positions: (N, 3) positions of the united atoms in the bead.
        masses: (N,) masses of the united atoms in the bead.
        custom_rot_axes: (3, 3) array whose rows are the rotation axes.
            NOTE(review): rows are assumed to be unit vectors; this is
            not checked here.
        dimensions: (3,) simulation box dimensions, forwarded to
            ``self.get_vector``.

    Returns:
        np.ndarray: (3,) moment of inertia about each axis.

    Raises:
        ValueError: If the number of masses differs from the number of
            positions.
    """
    # Displacements from the centre are delegated to get_vector.
    # NOTE(review): presumably box-aware via `dimensions` -- confirm there.
    displacements = np.asarray(
        self.get_vector(center_of_mass, positions, dimensions), dtype=float
    ).reshape(-1, 3)
    weights = np.asarray(masses, dtype=float).reshape(-1)
    if weights.shape[0] != displacements.shape[0]:
        raise ValueError(
            "masses and positions must have the same length "
            f"(got {weights.shape[0]} and {displacements.shape[0]})."
        )
    if displacements.shape[0] == 0:
        return np.zeros(3, dtype=float)

    axes = np.asarray(custom_rot_axes, dtype=float)
    # lever[n, i, :] = axes[i] x displacements[n]
    lever = np.cross(axes[np.newaxis, :, :], displacements[:, np.newaxis, :])
    # Sum mass * |lever|^2 over atoms (n) and vector components (j).
    return np.einsum("n,nij,nij->i", weights, lever, lever)


def get_chain(self, residue, first, last):
    """Return the shortest heavy-atom bond path between two atoms.

    A breadth-first search is run over heavy-atom bonds (selection
    ``mass 2 to 999``) inside ``residue``, starting at ``first``. Each atom
    records its parent the first time it is discovered, so the path traced
    back from ``last`` has the minimum number of bonds, including when the
    residue contains rings.

    Args:
        residue: MDAnalysis AtomGroup (or object exposing ``.atoms``)
            representing the residue/monomer of interest.
        first: First atom of the chain; must belong to ``residue``.
        last: Last heavy atom of the chain; must belong to ``residue``.

    Returns:
        MDAnalysis AtomGroup containing the atoms of the chain, obtained
        through ``select_atoms("all")`` on the path atoms.

    Raises:
        ValueError: If ``first`` or ``last`` is ``None``, is not part of
            ``residue``, or if no heavy-atom bond path connects them.
    """
    # Function-scope import keeps this block self-contained.
    from collections import deque

    if first is None or last is None:
        raise ValueError("Both chain end atoms must be provided.")

    atoms = residue.atoms
    # Map global atom index -> position inside the residue, so the result
    # does not depend on the residue's indices being contiguous.
    position_of = {int(idx): pos for pos, idx in enumerate(atoms.indices)}
    start, goal = int(first.index), int(last.index)
    if start not in position_of or goal not in position_of:
        raise ValueError("Chain end atoms must belong to the residue.")

    # parent[i] is the atom from which i was first reached; being a key of
    # `parent` also marks the atom as already discovered.
    parent = {start: None}
    queue = deque([start])
    while queue and goal not in parent:
        current = queue.popleft()
        bonded_heavy = atoms.select_atoms(
            f"(mass 2 to 999) and bonded index {current}"
        )
        for neighbour in bonded_heavy.indices:
            neighbour = int(neighbour)
            if neighbour not in parent:
                parent[neighbour] = current
                queue.append(neighbour)

    if goal not in parent:
        raise ValueError(
            f"No heavy-atom path between atoms {start} and {goal} in residue."
        )

    # Walk back from the last atom to the first one, then reverse.
    path = []
    node = goal
    while node is not None:
        path.append(node)
        node = parent[node]
    path.reverse()

    chain_positions = np.asarray([position_of[idx] for idx in path], dtype=int)
    return atoms[chain_positions].select_atoms("all")
diff --git a/CodeEntropy/levels/nodes/covariance.py b/CodeEntropy/levels/nodes/covariance.py index 25375d4..cbf2c3b 100644 --- a/CodeEntropy/levels/nodes/covariance.py +++ b/CodeEntropy/levels/nodes/covariance.py @@ -200,7 +200,30 @@ def _process_united_atom( Returns: None. Mutates out_force/out_torque and molcount in-place. """ + for local_res_i, res in enumerate(mol.residues): + if len(mol.residues) > 1: + # there are multiple residues in the molecule + # build residue group here + if local_res_i == 0: + # first residue + res_position = -1 + res_next = mol.residues[1] + residue_group = res + res_next + elif local_res_i == len(mol.residues) - 1: + # last residue + res_position = 1 + res_prev = mol.residues[-2] + residue_group = res + res_prev + else: + res_position = 0 + res_prev = mol.residues[local_res_i - 1] + res_next = mol.residues[local_res_i + 1] + residue_group = res_prev + res + res_next + else: + # only one residue + res_position = None + residue_group = res bead_key = (mol_id, "united_atom", local_res_i) bead_idx_list = beads.get(bead_key, []) if not bead_idx_list: @@ -211,13 +234,14 @@ def _process_united_atom( continue force_vecs, torque_vecs = self._build_ua_vectors( - residue_atoms=res.atoms, + residue_group=residue_group.atoms, bead_groups=bead_groups, axes_manager=axes_manager, box=box, force_partitioning=force_partitioning, customised_axes=customised_axes, is_highest=is_highest, + res_position=res_position, ) F, T = self._ft.compute_frame_covariance(force_vecs, torque_vecs) @@ -413,23 +437,26 @@ def _build_ua_vectors( self, *, bead_groups: list[Any], - residue_atoms: Any, + residue_group: Any, axes_manager: Any, box: np.ndarray | None, force_partitioning: float, customised_axes: bool, is_highest: bool, + res_position: int, ) -> tuple[list[np.ndarray], list[np.ndarray]]: """Build force/torque vectors for UA-level beads of one residue. Args: bead_groups: List of UA bead AtomGroups for the residue. 
- residue_atoms: AtomGroup for the residue atoms (used for axes when vanilla). + residue_group: AtomGroup for the residue group atoms. axes_manager: Axes manager used to determine axes/centers/MOI. box: Optional box vector used for PBC-aware displacements. force_partitioning: Force scaling factor applied at highest level. customised_axes: Whether to use customised axes methods when available. is_highest: Whether UA level is the highest level for the molecule. + res_position: Where the residue is in the residue group + Returns: A tuple (force_vecs, torque_vecs), each a list of (3,) vectors ordered @@ -441,13 +468,21 @@ def _build_ua_vectors( for ua_i, bead in enumerate(bead_groups): if customised_axes: trans_axes, rot_axes, center, moi = axes_manager.get_UA_axes( - residue_atoms, ua_i + residue_group, ua_i, res_position ) else: - make_whole(residue_atoms) + make_whole(residue_group) make_whole(bead) - - trans_axes = residue_atoms.principal_axes() + if res_position == -1: + # first residue in group + residue = residue_group.residues[0] + elif res_position == 0 or res_position == 1: + # middle or last residue => second in group + residue = residue_group.residues[1] + else: + # res_position is None bc there is only one residue + residue = residue_group + trans_axes = residue.atoms.principal_axes() rot_axes, moi = axes_manager.get_vanilla_axes(bead) center = bead.center_of_mass(unwrap=True) diff --git a/tests/unit/CodeEntropy/levels/nodes/test_frame_covariance_node.py b/tests/unit/CodeEntropy/levels/nodes/test_frame_covariance_node.py index 1b51006..00233f1 100644 --- a/tests/unit/CodeEntropy/levels/nodes/test_frame_covariance_node.py +++ b/tests/unit/CodeEntropy/levels/nodes/test_frame_covariance_node.py @@ -401,7 +401,7 @@ def test_build_ua_vectors_customised_axes_true_calls_get_UA_axes(): node = FrameCovarianceNode() bead = _BeadGroup(1) - residue_atoms = MagicMock() + residue_group = MagicMock() axes_manager = MagicMock() axes_manager.get_UA_axes.return_value = ( @@ 
-416,12 +416,13 @@ def test_build_ua_vectors_customised_axes_true_calls_get_UA_axes(): force_vecs, torque_vecs = node._build_ua_vectors( bead_groups=[bead], - residue_atoms=residue_atoms, + residue_group=residue_group, axes_manager=axes_manager, box=np.array([10.0, 10.0, 10.0]), force_partitioning=1.0, customised_axes=True, is_highest=True, + res_position=None, ) axes_manager.get_UA_axes.assert_called_once() @@ -451,12 +452,13 @@ def test_build_ua_vectors_vanilla_path_uses_principal_axes_and_vanilla_axes( force_vecs, torque_vecs = node._build_ua_vectors( bead_groups=[bead], - residue_atoms=residue_atoms, + residue_group=residue_atoms, axes_manager=axes_manager, box=np.array([10.0, 10.0, 10.0]), force_partitioning=1.0, customised_axes=False, is_highest=True, + res_position=None, ) axes_manager.get_vanilla_axes.assert_called_once() diff --git a/tests/unit/CodeEntropy/levels/test_axes.py b/tests/unit/CodeEntropy/levels/test_axes.py index d6d0093..4fc0c56 100644 --- a/tests/unit/CodeEntropy/levels/test_axes.py +++ b/tests/unit/CodeEntropy/levels/test_axes.py @@ -184,7 +184,7 @@ def _sel(q): lambda system, atom, dimensions: (np.eye(3), np.array([1.0, 2.0, 3.0])), ) - trans, rot, center, moi = ax.get_UA_axes(u, index=0) + trans, rot, center, moi = ax.get_UA_axes(u, index=0, res_position=None) assert np.allclose(trans, np.eye(3)) assert np.allclose(rot, np.eye(3)) @@ -217,7 +217,7 @@ def _sel(q): monkeypatch.setattr(ax, "get_bonded_axes", lambda **kwargs: (None, None)) with pytest.raises(ValueError): - ax.get_UA_axes(u, index=0) + ax.get_UA_axes(u, index=0, res_position=None) def test_get_custom_axes_degenerate_axis1_raises(): @@ -531,7 +531,7 @@ def _select_atoms(q): ax, "get_vanilla_axes", lambda mol: (np.eye(3) * 2, np.array([9.0, 8.0, 7.0])) ) - trans, rot, center, moi = ax.get_residue_axes(u, index=10) + trans, rot, center, moi = ax.get_residue_axes(u, index=10, residue=residue) assert np.allclose(trans, np.eye(3) * 2) assert np.allclose(rot, np.eye(3) * 2) @@ -664,7 
+664,9 @@ def _select_atoms(q): got_custom_axes = MagicMock(return_value=(np.eye(3), np.array([3.0, 2.0, 1.0]))) monkeypatch.setattr(ax, "get_custom_principal_axes", got_custom_axes) - trans_axes, rot_axes, center, moi = ax.get_UA_axes(data_container, index=0) + trans_axes, rot_axes, center, moi = ax.get_UA_axes( + data_container, index=0, res_position=None + ) assert trans_axes.shape == (3, 3) assert rot_axes.shape == (3, 3)