# diff --git a/hiqbind/bioassembly.py b/hiqbind/bioassembly.py
# new file mode 100644 index 0000000..01a9d45 --- /dev/null +++ b/hiqbind/bioassembly.py
# @@ -0,0 +1,521 @@
#
# NOTE(review): this file was reconstructed from a whitespace-collapsed diff.
# Exact spacing inside PDB-format string literals (record prefixes such as
# 'ATOM ', 'TER ', and the TER line template) could not be recovered from the
# collapsed source — confirm column layouts against the PDB format spec.


def duplicate_missing_info_by_chain_map(missing_residues_added, missing_atoms_added, chain_map):
    """Re-key missing-residue / missing-atom bookkeeping onto assembly copies.

    Parameters
    ----------
    missing_residues_added : iterable of (chain, res_id, res_name)
    missing_atoms_added : iterable of (chain, res_id, res_name, atoms)
    chain_map : iterable of (source_chain, output_chain)
        One pair per chain copy produced by assembly expansion; a source chain
        may map to several output chains.

    Returns
    -------
    (dup_missing_residues, dup_missing_atoms)
        Same tuples as the inputs, duplicated once per output chain generated
        from the corresponding source chain.
    """
    from collections import defaultdict
    dup_missing_residues = []
    dup_missing_atoms = []

    # source chain -> every output chain that was generated from it
    chain_targets = defaultdict(list)
    for src_chain, out_chain in chain_map:
        chain_targets[src_chain].append(out_chain)

    for chain, res_id, res_name in missing_residues_added:
        for out_chain in chain_targets.get(chain, []):
            dup_missing_residues.append((out_chain, res_id, res_name))

    for chain, res_id, res_name, atoms in missing_atoms_added:
        for out_chain in chain_targets.get(chain, []):
            # list(atoms): copy so duplicated entries don't share mutable state
            dup_missing_atoms.append((out_chain, res_id, res_name, list(atoms)))

    return dup_missing_residues, dup_missing_atoms

def _find_assembly_in_structure(st, assembly_name):
    """Return the gemmi Assembly whose name matches `assembly_name` (stringified), or None."""
    for assembly in st.assemblies:
        if assembly.name == str(assembly_name):
            return assembly
    return None


def _build_subchain_maps_from_model(model):
    """Index subchain ids of a gemmi model by residue identity.

    Returns
    -------
    exact_map : dict
        (chain name, seqid num, icode, residue name) -> subchain id.
    seq_map : dict
        (chain name, seqid num, icode) -> subchain id (first residue wins).
    subchain_to_chain : dict
        subchain id -> chain name (first chain wins).

    Residues with an empty subchain id fall back to their chain name.
    """
    exact_map = {}
    seq_map = {}
    subchain_to_chain = {}

    for chain in model:
        for residue in chain:
            subchain = residue.subchain
            if not subchain:
                subchain = chain.name

            key_exact = (chain.name, residue.seqid.num, residue.seqid.icode, residue.name)
            key_seq = (chain.name, residue.seqid.num, residue.seqid.icode)

            exact_map[key_exact] = subchain
            if key_seq not in seq_map:
                seq_map[key_seq] = subchain
            if subchain not in subchain_to_chain:
                subchain_to_chain[subchain] = chain.name

    return exact_map, seq_map, subchain_to_chain


def extract_assembly_instruction_gemmi(cif_file, assembly_name='1'):
    """Read a CIF file and flatten one assembly's generators into transform instructions.

    Returns a list of dicts (one per generator x operator) holding the source
    chains/subchains plus the 3x3 rotation `matrix` and translation `vector`
    as numpy arrays, or None (after printing a message) when the assembly is
    not present in the file.
    """
    import gemmi
    import numpy as np

    st = gemmi.read_structure(cif_file)
    assembly = _find_assembly_in_structure(st, assembly_name)
    if assembly is None:
        print(f'Assembly {assembly_name} not found in {cif_file}')
        return None

    _, _, subchain_to_chain = _build_subchain_maps_from_model(st[0])

    instructions = []
    for i_gen, gen in enumerate(assembly.generators):
        source_chains = list(gen.chains)
        source_subchains = list(gen.subchains)

        # Generators may be expressed only in subchains; derive chain names
        # from the first model so both views are always populated.
        if len(source_chains) == 0 and len(source_subchains) > 0:
            for subchain in source_subchains:
                chain_name = subchain_to_chain.get(subchain, subchain)
                if chain_name not in source_chains:
                    source_chains.append(chain_name)

        for i_op, oper in enumerate(gen.operators):
            tr = oper.transform
            instructions.append({
                'assembly_name': assembly.name,
                'generator_index': i_gen,
                'operator_index': i_op,
                'operator_name': oper.name,
                'operator_type': oper.type,
                'chains': source_chains,
                'subchains': source_subchains,
                'matrix': np.asarray(tr.mat.tolist(), dtype=float),
                'vector': np.asarray(tr.vec.tolist(), dtype=float),
            })

    return instructions


def _structure_to_numpy_arrays(st):
    """Flatten model 0 of a gemmi structure into (coords, atom_rows).

    coords is an (N, 3) float array; atom_rows is a parallel list of dicts
    with chain/subchain/residue/atom identifiers. Empty subchain ids fall
    back to the chain name, mirroring _build_subchain_maps_from_model.
    """
    import numpy as np

    model = st[0]
    atom_rows = []
    coords = []

    for chain in model:
        for residue in chain:
            subchain = residue.subchain
            if not subchain:
                subchain = chain.name

            for atom in residue:
                atom_rows.append({
                    'chain_id': chain.name,
                    'subchain_id': subchain,
                    'resname': residue.name,
                    'resseq': residue.seqid.num,
                    'icode': residue.seqid.icode,
                    'atom_name': atom.name,
                })
                coords.append([atom.pos.x, atom.pos.y, atom.pos.z])

    coords = np.asarray(coords, dtype=float)
    return coords, atom_rows


def read_pdb_coords_gemmi(pdb_file):
    """Read a PDB file and return (coords, atom_rows, structure); see _structure_to_numpy_arrays."""
    import gemmi

    st = gemmi.read_structure(pdb_file)
    coords, atom_rows = _structure_to_numpy_arrays(st)
    return coords, atom_rows, st


def _copy_subchain_ids_from_original_cif_to_refined_pdb(original_cif_file, refined_pdb_file):
    """Transfer subchain ids from the original CIF onto a refined PDB structure.

    Matching is attempted per residue: exact (chain, seqid, icode, name) first,
    then (chain, seqid, icode), finally falling back to the chain name.
    Returns (original_structure, refined_structure) with the refined one mutated.
    """
    import gemmi

    original_st = gemmi.read_structure(original_cif_file)
    refined_st = gemmi.read_structure(refined_pdb_file)

    exact_map, seq_map, _ = _build_subchain_maps_from_model(original_st[0])

    for chain in refined_st[0]:
        for residue in chain:
            key_exact = (chain.name, residue.seqid.num, residue.seqid.icode, residue.name)
            key_seq = (chain.name, residue.seqid.num, residue.seqid.icode)

            if key_exact in exact_map:
                residue.subchain = exact_map[key_exact]
            elif key_seq in seq_map:
                residue.subchain = seq_map[key_seq]
            else:
                residue.subchain = chain.name

    return original_st, refined_st


def _prune_assembly_to_present_content(assembly, model):
    """Return a copy of `assembly` whose generators only reference chains/subchains in `model`.

    Generators that would end up with no source content are dropped entirely.
    """
    import gemmi

    present_subchains = set()
    present_chains = set()

    for chain in model:
        present_chains.add(chain.name)
        for residue in chain:
            subchain = residue.subchain
            if not subchain:
                subchain = chain.name
            present_subchains.add(subchain)

    pruned = gemmi.Assembly(assembly.name)
    pruned.author_determined = assembly.author_determined
    pruned.software_determined = assembly.software_determined
    pruned.oligomeric_details = assembly.oligomeric_details
    pruned.special_kind = assembly.special_kind

    for gen in assembly.generators:
        new_gen = gemmi.Assembly.Gen()

        # Subchain-based generators take precedence over chain-based ones.
        if len(gen.subchains) > 0:
            kept_subchains = [s for s in gen.subchains if s in present_subchains]
            if len(kept_subchains) == 0:
                continue
            new_gen.subchains = kept_subchains
        else:
            kept_chains = [c for c in gen.chains if c in present_chains]
            if len(kept_chains) == 0:
                continue
            new_gen.chains = kept_chains

        for op in gen.operators:
            new_gen.operators.append(op)

        pruned.generators.append(new_gen)

    return pruned


def apply_assembly_instructions_gemmi(original_cif_file, refined_pdb_file, assembly_name='1'):
    """Apply an assembly's transforms to a refined PDB and return transformed coordinate copies.

    For each (generator, operator) instruction, selects atoms by subchain when
    the generator lists subchains, otherwise by chain, applies the rotation +
    translation, and returns a list of dicts with the transformed coords and
    selection metadata. Returns None if the assembly was not found.
    """
    import numpy as np

    instructions = extract_assembly_instruction_gemmi(original_cif_file, assembly_name=assembly_name)
    if instructions is None:
        return None

    _, refined_st = _copy_subchain_ids_from_original_cif_to_refined_pdb(original_cif_file, refined_pdb_file)
    coords, atom_rows = _structure_to_numpy_arrays(refined_st)

    present_subchains = set(row['subchain_id'] for row in atom_rows)
    present_chains = set(row['chain_id'] for row in atom_rows)

    copies = []
    for inst in instructions:
        atom_indices = []

        if len(inst['subchains']) > 0:
            wanted_subchains = [s for s in inst['subchains'] if s in present_subchains]
            for i_atom, row in enumerate(atom_rows):
                if row['subchain_id'] in wanted_subchains:
                    atom_indices.append(i_atom)
        else:
            wanted_chains = [c for c in inst['chains'] if c in present_chains]
            for i_atom, row in enumerate(atom_rows):
                if row['chain_id'] in wanted_chains:
                    atom_indices.append(i_atom)

        atom_indices = np.asarray(atom_indices, dtype=int)
        # Row-vector convention: x' = x @ R^T + t
        coords_copy = coords[atom_indices] @ inst['matrix'].T + inst['vector']

        copies.append({
            'assembly_name': inst['assembly_name'],
            'generator_index': inst['generator_index'],
            'operator_index': inst['operator_index'],
            'operator_name': inst['operator_name'],
            'operator_type': inst['operator_type'],
            'chains': list(inst['chains']),
            'subchains': list(inst['subchains']),
            # guarded by the same condition as above, so never stale/undefined
            'subchains_present': wanted_subchains if len(inst['subchains']) > 0 else [],
            'atom_indices': atom_indices,
            'coords': coords_copy,
            'matrix': inst['matrix'],
            'vector': inst['vector'],
        })

    return copies


def write_biological_assembly_pdb_gemmi(original_cif_file,
                                        refined_pdb_file,
                                        pdb_out,
                                        assembly_name='1',
                                        chain_naming='Short',
                                        merge_dist=0.0):
    """Expand the refined PDB into the named biological assembly and write it to `pdb_out`.

    Parameters
    ----------
    chain_naming : str
        Name of a gemmi.HowToNameCopiedChain member (e.g. 'Short').
    merge_dist : float
        If > 0, near-duplicate atoms in the expanded model are merged within
        this distance.

    Returns
    -------
    (out_structure, has_new_assembled_chain) or None if the assembly is absent.
    `has_new_assembled_chain` is True when expansion added chains beyond the input.
    """
    import gemmi

    original_st, refined_st = _copy_subchain_ids_from_original_cif_to_refined_pdb(original_cif_file, refined_pdb_file)

    assembly = _find_assembly_in_structure(original_st, assembly_name)
    if assembly is None:
        print(f'Assembly {assembly_name} not found in {original_cif_file}')
        return None

    pruned_assembly = _prune_assembly_to_present_content(assembly, refined_st[0])
    input_model = refined_st[0]
    input_chain_count = len(input_model)
    how = getattr(gemmi.HowToNameCopiedChain, chain_naming)
    assembled_model = gemmi.make_assembly(pruned_assembly, refined_st[0], how)
    has_new_assembled_chain = len(assembled_model) > input_chain_count
    if merge_dist > 0:
        gemmi.merge_atoms_in_expanded_model(assembled_model, gemmi.UnitCell(), max_dist=merge_dist)

    # Keep only model 0 and replace it with the assembled model.
    out_st = refined_st.clone()
    while len(out_st) > 1:
        del out_st[1]
    out_st[0] = assembled_model
    out_st.assign_serial_numbers()
    out_st.write_pdb(pdb_out)

    return out_st, has_new_assembled_chain


def _get_chain_order_from_model(model):
    """Return chain names in model order."""
    return [chain.name for chain in model]


def _get_generator_source_chain_order(pruned_assembly, source_model):
    """Return source chain names in the order gemmi's assembly expansion emits copies.

    For subchain-based generators the source chains are recovered by
    intersecting each chain's subchain set with the generator's subchains;
    the per-generator chain list is repeated once per operator.
    """
    out = []

    for gen in pruned_assembly.generators:
        if len(gen.subchains) > 0:
            source_chains = []
            for chain in source_model:
                chain_subchains = set()
                for residue in chain:
                    subchain = residue.subchain
                    if not subchain:
                        subchain = chain.name
                    chain_subchains.add(subchain)

                if len(chain_subchains.intersection(set(gen.subchains))) > 0:
                    source_chains.append(chain.name)
        else:
            source_chains = list(gen.chains)

        for op in gen.operators:
            for chain_name in source_chains:
                out.append(chain_name)

    return out


def get_assembly_chain_map_gemmi(original_cif_file, source_pdb_file, assembled_pdb_file, assembly_name='1'):
    """Pair source chain names with the chain names of the assembled output.

    Relies on gemmi.make_assembly emitting output chains in generator/operator
    order; the two orderings are zipped positionally.
    NOTE(review): assumes len(source_chain_order) == len(output_chain_order);
    zip silently truncates on mismatch — confirm against the assembly used.
    """
    import gemmi

    original_st, source_st = _copy_subchain_ids_from_original_cif_to_refined_pdb(original_cif_file, source_pdb_file)

    assembly = _find_assembly_in_structure(original_st, assembly_name)
    pruned_assembly = _prune_assembly_to_present_content(assembly, source_st[0])

    source_chain_order = _get_generator_source_chain_order(pruned_assembly, source_st[0])

    assembled_st = gemmi.read_structure(assembled_pdb_file)
    output_chain_order = _get_chain_order_from_model(assembled_st[0])

    return list(zip(source_chain_order, output_chain_order))



def _replace_char_in_line(line, idx, new_char):
    """Return `line` with the character at `idx` replaced (not used elsewhere in this file)."""
    return line[:idx] + new_char + line[idx + 1:]
def ensure_ter_between_chains_in_pdb(pdb_file):
    """Rewrite `pdb_file` in place so a TER record separates consecutive chains.

    A bare "TER" line is inserted whenever the chain id (column 22) changes
    between coordinate records without an intervening TER, and appended after
    the final coordinate block if missing.
    """
    out_lines = []
    last_chain = None
    last_was_ter = False

    with open(pdb_file, "r") as f:
        for line in f:
            if line.startswith(("ATOM ", "HETATM")):
                chain = line[21]
                if (last_chain is not None) and (chain != last_chain) and (not last_was_ter):
                    out_lines.append("TER\n")
                    last_was_ter = True

                out_lines.append(line)
                last_chain = chain
                last_was_ter = False
                continue

            if line.startswith("TER"):
                out_lines.append(line if line.endswith("\n") else line + "\n")
                last_was_ter = True
                continue

            out_lines.append(line)

    if (last_chain is not None) and (not last_was_ter):
        out_lines.append("TER\n")

    with open(pdb_file, "w") as f:
        f.writelines(out_lines)

def rewrite_assembled_pdb_header_with_chain_annotations(source_pdb_file, assembled_pdb_file, chain_map):
    """Rebuild the assembled PDB's header and body using chain annotations from the source PDB.

    Duplicates SEQRES, MODRES, REMARK 465 (missing residues) and REMARK 470
    (missing atoms) records from each source chain onto every output chain in
    `chain_map`, then rewrites the coordinate body with fresh serial numbers
    and TER records at every (record type, chain) transition. The assembled
    file is overwritten in place.
    """
    from collections import defaultdict
    from hiqbind.fix_protein import convert_to_seqres
    def _replace_char(line, idx, new_char):
        # single-character splice used to retarget chain-id columns
        return line[:idx] + new_char + line[idx + 1:]

    def _is_remark465_data_line(line):
        # data rows carry a chain id at col 20 and a numeric resseq at cols 22-26
        return (
            line.startswith('REMARK 465')
            and len(line) >= 27
            and line[19].strip() != ''
            and line[21:26].strip().isdigit()
        )

    def _is_remark470_data_line(line):
        return (
            line.startswith('REMARK 470')
            and len(line) >= 27
            and line[19].strip() != ''
            and line[20:26].strip() != ''
        )

    def _make_ter_line(serial, last_coord_line):
        # NOTE(review): inter-field spacing was collapsed in the source dump;
        # confirm TER column layout against the PDB format spec.
        resname = last_coord_line[17:20]
        chain = last_coord_line[21]
        resseq = last_coord_line[22:26]
        icode = last_coord_line[26]
        return f"TER {serial:>5d} {resname} {chain}{resseq}{icode}\n"

    with open(source_pdb_file, 'r') as f:
        source_lines = f.readlines()

    # Header = everything before the first coordinate/MODEL record.
    source_header = []
    for line in source_lines:
        if line.startswith(('ATOM ', 'HETATM', 'MODEL ')):
            break
        source_header.append(line.rstrip('\n'))

    with open(assembled_pdb_file, 'r') as f:
        assembled_lines = f.readlines()

    body_start = 0
    for i, line in enumerate(assembled_lines):
        if line.startswith(('ATOM ', 'HETATM', 'MODEL ')):
            body_start = i
            break

    body_lines = assembled_lines[body_start:]

    # Bucket chain-specific header records by their source chain id.
    other_header = []
    source_seq_by_chain = defaultdict(list)
    modres_by_src = defaultdict(list)
    remark465_prefix = []
    remark465_by_src = defaultdict(list)
    remark470_prefix = []
    remark470_by_src = defaultdict(list)

    for line in source_header:
        if line.startswith('SEQRES'):
            chain = line[11]
            source_seq_by_chain[chain].extend(line[19:].split())

        elif line.startswith('MODRES'):
            modres_by_src[line[16]].append(line)

        elif line.startswith('REMARK 465'):
            if _is_remark465_data_line(line):
                remark465_by_src[line[19]].append(line)
            else:
                remark465_prefix.append(line)

        elif line.startswith('REMARK 470'):
            if _is_remark470_data_line(line):
                remark470_by_src[line[19]].append(line)
            else:
                remark470_prefix.append(line)

        else:
            other_header.append(line)

    grouped_targets = defaultdict(list)
    for src_chain, out_chain in chain_map:
        grouped_targets[src_chain].append(out_chain)

    new_header = []
    new_header.extend(other_header)

    # regenerate SEQRES from full per-chain sequence
    for src_chain, out_chains in grouped_targets.items():
        seq = source_seq_by_chain.get(src_chain, [])
        for out_chain in out_chains:
            if len(seq) > 0:
                new_header.extend(convert_to_seqres(seq, out_chain).split('\n'))

    # duplicate MODRES
    for src_chain, out_chains in grouped_targets.items():
        for out_chain in out_chains:
            for line in modres_by_src.get(src_chain, []):
                new_header.append(_replace_char(line, 16, out_chain))

    # preserve REMARK 465 prefix once, duplicate chain-specific rows
    new_header.extend(remark465_prefix)
    for src_chain, out_chains in grouped_targets.items():
        for out_chain in out_chains:
            for line in remark465_by_src.get(src_chain, []):
                new_header.append(_replace_char(line, 19, out_chain))

    # preserve REMARK 470 prefix once, duplicate chain-specific rows
    new_header.extend(remark470_prefix)
    for src_chain, out_chains in grouped_targets.items():
        for out_chain in out_chains:
            for line in remark470_by_src.get(src_chain, []):
                new_header.append(_replace_char(line, 19, out_chain))

    # rewrite body with fresh serials and proper TER records
    new_body = []
    trailing = []

    serial = 1
    last_coord_line = None
    last_segment_key = None

    for line in body_lines:
        rec = line[:6]

        if rec in ('ATOM ', 'HETATM'):
            chain = line[21]
            record_type = rec.strip()
            segment_key = (record_type, chain)

            # new chain or ATOM<->HETATM transition: close previous segment
            if (last_segment_key is not None) and (segment_key != last_segment_key):
                new_body.append(_make_ter_line(serial, last_coord_line))
                serial += 1

            new_line = line[:6] + f"{serial:>5d}" + line[11:]
            if not new_line.endswith('\n'):
                new_line += '\n'

            new_body.append(new_line)
            last_coord_line = new_line
            last_segment_key = segment_key
            serial += 1
            continue

        if rec == 'TER ':
            # original TERs are dropped; fresh ones are emitted above
            continue

        if line.startswith('END') or line.startswith('CONECT'):
            trailing.append(line if line.endswith('\n') else line + '\n')

    if last_coord_line is not None:
        new_body.append(_make_ter_line(serial, last_coord_line))
        serial += 1

    if len(trailing) == 0:
        trailing = ['END\n']

    with open(assembled_pdb_file, 'w') as f:
        for line in new_header:
            f.write(line + '\n')
        for line in new_body:
            f.write(line)
        for line in trailing:
            f.write(line)

diff --git a/hiqbind/fix_protein.py b/hiqbind/fix_protein.py index 6149265..b7f0da5 100644 --- a/hiqbind/fix_protein.py +++ b/hiqbind/fix_protein.py @@ -98,7 +98,7 @@ def select(self, select: Select): modeller.delete(to_delete) return Structure(modeller.topology, modeller.positions) - def save(self, file, select: Optional[Select] = None, keepIds: bool = True, header: str = '', res_num_mapping: Optional[Dict[str, Dict[int, str]]] = None): + def save(self, file, select=None, keepIds=True, header='', res_num_mapping=None): if isinstance(file, str) or isinstance(file, os.PathLike): fp = open(file, 'w') self.save(fp, select, keepIds, header, res_num_mapping) @@ -111,12 +111,18 @@ def save(self, file, select: Optional[Select] = None, keepIds: bool = True, head top, pos = self.topology, self.positions if res_num_mapping: + reverse_mapping = {} + for chain, 
mapping in res_num_mapping.items(): + reverse_mapping[chain] = {v: k for k, v in mapping.items()} + for residue in top.residues(): chain = residue.chain.id - new_res_num = [k for k, v in res_num_mapping[chain].items() if v == f"{residue.id}{residue.insertionCode}".strip()][0] - residue.id = str(new_res_num) - residue.insertionCode = " " - + old_id = f"{residue.id}{residue.insertionCode}".strip() + + if (chain in reverse_mapping) and (old_id in reverse_mapping[chain]): + residue.id = str(reverse_mapping[chain][old_id]) + residue.insertionCode = " " + app.PDBFile.writeHeader(top, file) if header: print(header[:-1] if header[-1] == '\n' else header, file=file) @@ -198,6 +204,118 @@ def to_quantity(ndarray): return quantity


def save_openmm_refinement_bundle(bundle_dir, topology, positions, system, integrator, state, metadata=None):
    """Persist an OpenMM refinement setup (topology/system/integrator/state) to `bundle_dir`.

    Writes topology.pdb, system.xml, integrator.xml, state.xml and a
    metadata.json that records the file names and atom/residue counts.
    `metadata`, if given, is copied and extended (caller's dict is untouched).
    Returns a dict of the absolute paths plus the counts.
    """
    import json

    os.makedirs(bundle_dir, exist_ok=True)

    topology_pdb = os.path.join(bundle_dir, 'topology.pdb')
    system_xml = os.path.join(bundle_dir, 'system.xml')
    integrator_xml = os.path.join(bundle_dir, 'integrator.xml')
    state_xml = os.path.join(bundle_dir, 'state.xml')
    metadata_json = os.path.join(bundle_dir, 'metadata.json')

    # keepIds preserves original chain/residue identifiers in the PDB
    with open(topology_pdb, 'w') as f:
        app.PDBFile.writeFile(topology, positions, f, keepIds=True)

    with open(system_xml, 'w') as f:
        f.write(mm.XmlSerializer.serialize(system))

    with open(integrator_xml, 'w') as f:
        f.write(mm.XmlSerializer.serialize(integrator))

    with open(state_xml, 'w') as f:
        f.write(mm.XmlSerializer.serialize(state))

    if metadata is None:
        metadata = {}
    else:
        metadata = dict(metadata)

    metadata.update({
        'topology_pdb': os.path.basename(topology_pdb),
        'system_xml': os.path.basename(system_xml),
        'integrator_xml': os.path.basename(integrator_xml),
        'state_xml': os.path.basename(state_xml),
        'num_atoms': int(topology.getNumAtoms()),
        'num_residues': int(topology.getNumResidues()),
    })

    with open(metadata_json, 'w') as f:
        json.dump(metadata, f, indent=2, sort_keys=True)

    return {
        'bundle_dir': bundle_dir,
        'topology_pdb': topology_pdb,
        'system_xml': system_xml,
        'integrator_xml': integrator_xml,
        'state_xml': state_xml,
        'metadata_json': metadata_json,
        'num_atoms': int(topology.getNumAtoms()),
        'num_residues': int(topology.getNumResidues()),
    }



def load_openmm_refinement_bundle(bundle_dir, platform_name=None, platform_properties=None, set_state=True):
    """Reload a bundle written by save_openmm_refinement_bundle and build a Simulation.

    Parameters
    ----------
    platform_name / platform_properties : optional OpenMM platform selection.
    set_state : bool
        When True, the saved State is applied to the context; on failure the
        code falls back (best-effort, errors deliberately swallowed) to PDB
        positions plus whatever positions/velocities the State yields.

    Returns a dict with the metadata, topology, positions, system, integrator,
    state and the constructed Simulation.
    """
    import json

    metadata_json = os.path.join(bundle_dir, 'metadata.json')
    with open(metadata_json) as f:
        metadata = json.load(f)

    topology_pdb = os.path.join(bundle_dir, metadata['topology_pdb'])
    system_xml = os.path.join(bundle_dir, metadata['system_xml'])
    integrator_xml = os.path.join(bundle_dir, metadata['integrator_xml'])
    state_xml = os.path.join(bundle_dir, metadata['state_xml'])

    pdb = app.PDBFile(topology_pdb)

    with open(system_xml) as f:
        system = mm.XmlSerializer.deserialize(f.read())

    with open(integrator_xml) as f:
        integrator = mm.XmlSerializer.deserialize(f.read())

    with open(state_xml) as f:
        state = mm.XmlSerializer.deserialize(f.read())

    if platform_name is None:
        sim = app.Simulation(pdb.topology, system, integrator)
    else:
        platform = mm.Platform.getPlatformByName(platform_name)
        if platform_properties is None:
            sim = app.Simulation(pdb.topology, system, integrator, platform)
        else:
            sim = app.Simulation(pdb.topology, system, integrator, platform, platform_properties)

    if set_state:
        try:
            sim.context.setState(state)
        except Exception:
            # best-effort recovery: seed from the PDB, then try to layer on
            # whatever the saved State can still provide
            sim.context.setPositions(pdb.positions)
            try:
                sim.context.setPositions(state.getPositions())
            except Exception:
                pass
            try:
                sim.context.setVelocities(state.getVelocities())
            except Exception:
                pass
    else:
        sim.context.setPositions(pdb.positions)

    return {
        'metadata': metadata,
        'topology': pdb.topology,
        'positions': pdb.positions,
        'system': system,
        'integrator': integrator,
        'state': state,
        'simulation': sim,
    }
+ + class StandardizedPDBFixer(PDBFixer): """ A class to fix standarized PDB file. Here the standardized means that: @@ -211,9 +329,10 @@ def __init__(self, protein_pdb: os.PathLike, ligand_sdf: Optional[os.PathLike] = if ligand_sdf: self.addLigand(ligand_sdf) self._has_ligand = True + else: self._has_ligand = False - + #print(self._has_ligand) self.mod_res_info = [] self.missing_residues_added = [] self.missing_residues_skipped = [] @@ -369,42 +488,126 @@ def getModresRecords(cls, mod_res_info: List[Tuple[str, int, str, str, str]], re # TODO: need to figure out how to deal with modfied residues with insertion code (Eric) modres_lines.append(f"MODRES {pdb_id:>4} {res_name:3} {chain:1} {res_id:>4}{icode:1} {std_res_name:3} MODIFIED RESIDUE") return modres_lines - + + def _get_nonpro_peptide_omega_atom_quads(self): + + atom_index_by_residue = {} + + for residue in self.topology.residues(): + atom_index_by_residue[residue.index] = {atom.name: atom.index for atom in residue.atoms()} + + quads = [] + residues = list(self.topology.residues()) + + for i in range(1, len(residues)): + res_prev = residues[i - 1] + res_curr = residues[i] + + if res_prev.chain.index != res_curr.chain.index: + continue + + if res_curr.name == 'PRO': + continue + + names_prev = atom_index_by_residue[res_prev.index] + names_curr = atom_index_by_residue[res_curr.index] + + if ('CA' in names_prev) and ('C' in names_prev) and ('N' in names_curr) and ('CA' in names_curr): + quads.append(( + names_prev['CA'], + names_prev['C'], + names_curr['N'], + names_curr['CA'], + )) + + return quads + def add_trans_peptide_restraints(self, system, omega_atom_quads, k_kj_per_mol=10.0): + import openmm as mm + from openmm import unit + force = mm.PeriodicTorsionForce() + k = k_kj_per_mol * unit.kilojoule_per_mole + + for a, b, c, d in omega_atom_quads: + force.addTorsion(a, b, c, d, 1, 0.0, k) + + system.addForce(force) + return force + def get_edge_residue_of_added_residue(self,missing_atoms_info_as_dict): + 
all_residues = list(self.topology.residues()) + + missing_residue_ids = set() + for chain, res_id, res_name in self.missing_residues_added: + missing_residue_ids.add((chain, int(res_id))) + + for chain, res_id, res_name in missing_atoms_info_as_dict.keys(): + missing_residue_ids.add((chain, int(res_id))) + + edge_residue_ids = set(missing_residue_ids) + for chain, res_id in missing_residue_ids: + edge_residue_ids.add((chain, res_id - 1)) + edge_residue_ids.add((chain, res_id + 1)) + return edge_residue_ids def refineAddedAtomPositions(self, forcefield=None): if forcefield is None: - forcefield = app.ForceField('amber14-all.xml', 'tip3p.xml') + forcefield = app.ForceField('amber14-all.xml', 'amber14/tip3p.xml') # Conver List missing atoms information to dictionary, for better indexing missing_atoms_info_as_dict = defaultdict(list) for chain, res_id, res_name, atoms in self.missing_atoms_added: missing_atoms_info_as_dict[(chain, res_id, res_name)] += atoms - system = forcefield.createSystem(self.topology, nonbondedMethod=app.CutoffNonPeriodic, constraints=None, rigidWater=False) nonstd_names = [res.name for res, stdname in self.nonstandardResidues] - for residue in self.topology.residues(): + system = forcefield.createSystem(self.topology, nonbondedMethod=app.CutoffNonPeriodic, constraints=None, rigidWater=False) + original_masses = [system.getParticleMass(i) for i in range(system.getNumParticles())] + #omega_atom_quads = self._get_nonpro_peptide_omega_atom_quads() + #peptide_trans_force = self.add_trans_peptide_restraints( # TODO even though this removes most of cis peptide bonds it can still happen that the cis persist, but further increasing the restrain results in unphysical modelling. likely we need staging. 
+ # system, + # omega_atom_quads, + # k_kj_per_mol=25.0, + #) + edge_residue_ids = self.get_edge_residue_of_added_residue(missing_atoms_info_as_dict) + for idx_residue, residue in enumerate(self.topology.residues()): resdata = (residue.chain.id, int(residue.id), residue.name) if resdata in self.missing_residues_added: - self.log(f'Found fixed residue: {residue}') + self.log(f'Found free residue: {residue}') continue - + for i, atom in enumerate(residue.atoms()): # Always constrained all atoms in modified residue, including hydrogens (because we don't have good force field) if residue.name in nonstd_names: system.setParticleMass(atom.index, 0.0) continue if (resdata in missing_atoms_info_as_dict) and (atom.name in missing_atoms_info_as_dict[resdata]): - self.log(f'Found fixed atom: {atom}') + self.log(f'Found free atom: {atom}') continue if atom.element is app.element.hydrogen: continue if (self._has_ligand) and (residue.index == self.topology.getNumResidues() - 1) and (i in self.ligand_missing_atoms): continue + + if (residue.chain.id, int(residue.id)) in edge_residue_ids: + self.log(f'Found edge residue near missing content: {residue}') + continue system.setParticleMass(atom.index, 0.0) integrator = mm.LangevinIntegrator(300*unit.kelvin, 10/unit.picosecond, 5*unit.femtosecond) - context = mm.Context(system, integrator) - context.setPositions(self.positions) - mm.LocalEnergyMinimizer.minimize(context, tolerance=10) - self.positions = context.getState(getPositions=True).getPositions() - return self.positions + sim = app.Simulation(self.topology, system, integrator) + sim.context.setPositions(self.positions) + mm.LocalEnergyMinimizer.minimize(sim.context, tolerance=10) + for i, mass in enumerate(original_masses): # NOTE restore the mass + system.setParticleMass(i, mass) + sim.context.reinitialize(preserveState=True) + state = sim.context.getState(getPositions=True, + getVelocities=True, + getEnergy=True, + getParameters=True) + self.positions = state.getPositions() + 
return { + 'topology': self.topology, + 'positions': self.positions, + 'system': system, + 'integrator': integrator, + 'state': state, + 'simulation': sim, + } def runFixWorkflow( self, @@ -415,7 +618,9 @@ def runFixWorkflow( skip_long_missing_residues: Optional[int] = 10, add_hydrogens: bool = True, refine_positions: bool = True, - res_num_mapping: Optional[Dict] = None + res_num_mapping: Optional[Dict] = None, + save_refinement_bundle: bool = True, + refinement_bundle_dir: Optional[os.PathLike] = None, ): """ Parameters @@ -451,45 +656,63 @@ def runFixWorkflow( self.topology = modeller.getTopology() self.positions = modeller.getPositions() + refinement_bundle_info = None if refine_positions: for top_xml in top_xmls: app.Topology.loadBondDefinitions(top_xml) self.topology.createStandardBonds() try: - ff = app.ForceField('amber14-all.xml', 'tip3p.xml', *list(set(ff_xmls))) + ff = app.ForceField('amber14-all.xml', 'amber14/tip3p.xml', *list(set(ff_xmls))) if self._has_ligand: generator = SMIRNOFFTemplateGenerator(molecules=[self.off_mol]).generator ff.registerTemplateGenerator(generator) - self.refineAddedAtomPositions(ff) + refinement_state = self.refineAddedAtomPositions(ff) except ValueError: for residue in self.topology.residues(): if residue.name == 'PCA': print(list(residue.atoms()), list(residue.bonds())) # Some cases OpenFF will failed, use GAFF - ff = app.ForceField('amber14-all.xml', 'tip3p.xml', *list(set(ff_xmls))) + ff = app.ForceField('amber14-all.xml', 'amber14/tip3p.xml', *list(set(ff_xmls))) if self._has_ligand: generator = GAFFTemplateGenerator(molecules=[self.off_mol], forcefield='gaff-2.11').generator ff.registerTemplateGenerator(generator) - self.refineAddedAtomPositions(ff) + refinement_state = self.refineAddedAtomPositions(ff) + + if save_refinement_bundle and (refinement_state is not None): + assert refinement_bundle_dir is not None, "ABORTED. 
refinement_bundle_dir must be assigned" + refinement_bundle_info = save_openmm_refinement_bundle( + bundle_dir=refinement_bundle_dir, + topology=refinement_state['topology'], + positions=refinement_state['positions'], + system=refinement_state['system'], + integrator=refinement_state['integrator'], + state=refinement_state['state'], + metadata={ + 'pdb_id': self.pdb_id, + 'output_protein': str(output_protein), + 'output_ligand': None if output_ligand is None else str(output_ligand), + 'has_ligand': bool(self._has_ligand), + }, + ) - if self._has_ligand: + #print([residue for residue in self.topology.residues()]) protein = Structure(self.topology, self.positions).select_residues([residue for residue in self.topology.residues()][:-1]) protein_top = protein.topology protein_pos = protein.positions else: protein_top = self.topology protein_pos = self.positions - + # Save protein seqres = [(seq.chainId, convert_to_seqres(seq.residues, seq.chainId)) for seq in self.sequences] seqres.sort(key=lambda x: x[0]) headers = [x[1] for x in seqres] - headers += StandardizedPDBFixer.getFixedResidueRemarks(self.missing_residues_skipped, res_num_mapping, use_fixed_remark=False) + #headers += StandardizedPDBFixer.getFixedResidueRemarks(self.missing_residues_skipped, res_num_mapping, use_fixed_remark=False) headers += StandardizedPDBFixer.getFixedResidueRemarks(self.missing_residues_added, res_num_mapping) headers += StandardizedPDBFixer.getFixedAtomRemarks(self.missing_atoms_added, res_num_mapping) headers += StandardizedPDBFixer.getModresRecords(self.mod_res_info, res_num_mapping, self.pdb_id) - + #print(output_protein) fp = open(output_protein, 'w') app.PDBFile.writeHeader(protein_top, fp) for line in headers: @@ -497,6 +720,7 @@ def runFixWorkflow( # map back residue id and icode if res_num_mapping: for residue in protein_top.residues(): + #print(residue.chain.id, int(residue.id)) res_id = res_num_mapping[residue.chain.id][int(residue.id)] if res_id[-1].isalpha(): insert_code = 
res_id[-1] @@ -516,3 +740,5 @@ def runFixWorkflow( self.rdmol.GetConformer().SetAtomPosition(i, [vec.x * 10, vec.y * 10, vec.z * 10]) with Chem.SDWriter(output_ligand) as w: w.write(self.rdmol) + + return refinement_bundle_info diff --git a/hiqbind/process.py b/hiqbind/process.py index 8a8fe1b..473cf49 100644 --- a/hiqbind/process.py +++ b/hiqbind/process.py @@ -16,14 +16,14 @@ RDLogger.DisableLog('rdApp.*') import openmm.app as app - -from fix_protein import * -from fix_ligand import * -from fix_polymer import * -from rcsb import * -# from refine import * - - +import json +#from hiqbind.fix_protein import StandardizedPDBFixer, convert_to_three_letter_seq, convert_to_seqres , Structure +from hiqbind.fix_ligand import get_reference_smi, read_by_obabel, write_sdf, LigandFixException, fix_ligand + +from hiqbind.rcsb import download_pdb_cif, get_rcsb_data, download_ligand_sdf +from hiqbind.fix_protein import StandardizedPDBFixer, convert_to_three_letter_seq, convert_to_seqres , Structure +from hiqbind.bioassembly import write_biological_assembly_pdb_gemmi, get_assembly_chain_map_gemmi, rewrite_assembled_pdb_header_with_chain_annotations +from hiqbind.fix_polymer import mol_from_seq standardResidues = [ 'ALA', 'ASN', 'CYS', 'GLU', 'HIS', 'LEU', 'MET', 'PRO', 'THR', 'TYR', 'ARG', 'ASP', 'GLN', 'GLY', 'ILE', 'LYS', 'PHE', 'SER', 'TRP', 'VAL', @@ -35,6 +35,17 @@ # Ligands with elements other than this list will be discarded COMMON_ELEMENTS = ['H', 'C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I'] +# +AMBER14_TIP3PXML_ACCEPTABLE_RESNAME = [ + "AL", "Ag", "BA", "BR", "Be", "CA", "CD", "CE", "CL", "CO", "CR", "CS", + "CU", "Ce", "Cr", "Dy", "EU", "EU3", "Er", "F", "FE", "FE2", "GD3", "HG", + "Hf", "IN", "IOD", "K", "LA", "LI", "LU", "MG", "MN", "NA", "NI", "Nd", + "PB", "PD", "PR", "PT", "Pu", "RB", "Ra", "SM", "SR", "Sm", "Sn", "TB", + "Th", "Tl", "Tm", "U4+", "V2+", "Y", "YB2", "ZN", "Zr", "HOH" + ] +PDB_COVALENT_HETATM = [ + +] # Ligands with less than 4 heavy atoms will be 
discarded MAX_HEAVY_ATOMS = 4 # Ligands with distances < 2.0 Angstrom to the protein will be discarded @@ -394,6 +405,11 @@ def find_ligand_residues(topology: app.Topology, ligand_chain: str, ligand_resid return list(residues) + + + + + def process_everything( pdb_id: str, ligand_id: Union[None, str, List[str]] = None, @@ -402,7 +418,11 @@ def process_everything( binding_cutoff: float = BINDING_CUTOFF, hetatm_cutoff: float = HETATM_CUTOFF, find_connected_ligand_residues: bool = True, -): + do_refine_structure_with_ligand_plus_hetatm: bool = False, + do_refine_structure_with_ligand_plus_hetatm_assembly:bool = False, + do_refine_structure_with_ligand_plus_assembly:bool = False, + assembly_name:int = 1 + ): """ Run process workflow """ @@ -414,11 +434,18 @@ def process_everything( if not os.path.isdir(folder): os.mkdir(folder) + # ==================== + # RCSB + # ====================== + download_pdb_cif(pdb_id, folder) get_rcsb_data(pdb_id, os.path.join(folder, 'rcsb_data.json')) pdb_file = os.path.join(folder, f'{pdb_id}.pdb') cif_file = os.path.join(folder, f'{pdb_id}.cif') + # ============================= + # Informatio + # ============================= # read in the key properties from the original pdb and cif file key_properties, interchain_ss, modres_info, connect = extract_pdb_information(pdb_file) @@ -428,7 +455,11 @@ def process_everything( sequences[row['chain_id']] = convert_to_three_letter_seq(row['pdbx_seq_one_letter_code']) with open(os.path.join(folder, 'res_num_mapping.json'), 'w') as f: json.dump(res_num_mapping, f, indent=4) - + + # ================================= + # Read the structure file + # ================================== + from openmm import app # OpenMM will change residue namings app.PDBFile._loadNameReplacementTables() app.PDBFile._residueNameReplacements = {k:v for k, v in app.PDBFile._residueNameReplacements.items() if k == v} @@ -467,7 +498,10 @@ def process_everything( ligand_residues_list = [] for chain, residue_numbers in 
ligand_info: - ligand_residues = find_ligand_residues(struct.topology, chain, residue_numbers, max_num_residues=MAX_NUM_RES_POLY, find_connected=find_connected_ligand_residues) + ligand_residues = find_ligand_residues( + struct.topology, chain, residue_numbers, + max_num_residues=MAX_NUM_RES_POLY, + find_connected=find_connected_ligand_residues) ligand_residues_list.append(ligand_residues) @@ -551,8 +585,12 @@ def process_everything( positions = struct.get_positions_by_residues(ligand_residues + include['polymer']) for residue in hetero_residues: het_positions = struct.get_positions_by_residues([residue]) - if np.min(cdist(positions, het_positions)) * 10 < hetatm_cutoff: - include['hetatm'].append(residue) + if np.min(cdist(positions, het_positions)) * 10 > hetatm_cutoff: + continue + if residue.name not in AMBER14_TIP3PXML_ACCEPTABLE_RESNAME: + print(f"WARNING. hetatm (not ligand) {residue.name} was excluded in {pdb_id}.") + continue + include['hetatm'].append(residue) # Record ligand include['ligand'] = ligand_residues @@ -592,7 +630,9 @@ def process_everything( ref_name = 'seq:' + ','.join(seq) except: ref_smi = None - + # ================================= + # Ligand as a chemical + # ================================= if ref_smi: with open(os.path.join(subfolder, f'ref.smi'), 'w') as f: f.write(ref_smi + ' ' + ref_name) @@ -634,8 +674,16 @@ def process_everything( # all chain_properties, ssbond_lines = extract_chain_specific_information(key_properties, chains_include) all_pdb = os.path.join(subfolder, f'{basename}_protein_hetatm.pdb') - struct.select_residues(include['polymer'] + include['hetatm']).save(all_pdb, header='\n'.join(chain_properties)) - + all_header = seqres + modres + [ + line for line in chain_properties + if (not line.startswith('SEQRES')) and (not line.startswith('MODRES')) + ] + struct.select_residues(include['polymer'] + include['hetatm']).save( + all_pdb, + header='\n'.join(all_header), + res_num_mapping=res_num_mapping + ) + # Record 
def detect_tip3_residues_in_pdb(pdb_file):
    """Report which AMBER14/TIP3P-compatible residue names occur in a PDB file.

    Parameters
    ----------
    pdb_file : os.PathLike
        Path to the PDB file to scan.

    Returns
    -------
    dict
        Maps every name in AMBER14_TIP3PXML_ACCEPTABLE_RESNAME to 1 if a
        residue with that name is present in the structure, else 0.
    """
    from Bio.PDB import PDBParser

    parser = PDBParser(QUIET=True)
    structure = parser.get_structure(os.path.basename(pdb_file), pdb_file)

    # Collect the distinct residue names once, then answer each query in O(1).
    present_resnames = {residue.get_resname().strip()
                        for residue in structure.get_residues()}

    return {name: int(name in present_resnames)
            for name in AMBER14_TIP3PXML_ACCEPTABLE_RESNAME}


def refine_structure_with_ligand(folder):
    """Refine every per-ligand protein/ligand pair found under *folder*.

    For each ``*/*_protein.pdb`` the matching ``*_ligand_fixed.sdf`` is
    loaded, the complex is fixed/refined with StandardizedPDBFixer, and the
    refinement-bundle info returned by ``runFixWorkflow`` is collected.

    Parameters
    ----------
    folder : os.PathLike
        Per-entry folder whose basename is the PDB id; must contain
        ``res_num_mapping.json``.

    Returns
    -------
    list
        One ``runFixWorkflow`` result per processed protein PDB.
    """
    pdb_id = os.path.basename(folder)
    res_num_mapping = {}
    saved_states = []
    # res_num_mapping.json stores {chain: {auth_num_str: orig_label}}; keys
    # arrive as strings from JSON and are converted back to ints here.
    with open(os.path.join(folder, 'res_num_mapping.json')) as f:
        for chain, mapping in json.load(f).items():
            res_num_mapping[chain] = {int(k): v for k, v in mapping.items()}

    for protein_pdb in glob.glob(os.path.join(folder, '*/*_protein.pdb')):
        ligand_sdf = protein_pdb.replace("_protein.pdb", "_ligand_fixed.sdf")
        output_protein = protein_pdb.replace("_protein.pdb", "_protein_refined.pdb")
        output_dir = os.path.dirname(protein_pdb)
        bundle_dir = output_dir + "/refine_structure_with_ligand/"
        if VERBOSE:
            print(f'Processing {protein_pdb}')
        fixer = StandardizedPDBFixer(protein_pdb=protein_pdb,
                                     ligand_sdf=ligand_sdf,
                                     pdb_id=pdb_id,
                                     verbose=False)

        os.makedirs(bundle_dir, exist_ok=True)
        result = fixer.runFixWorkflow(
            output_protein=output_protein,
            output_ligand=ligand_sdf.replace("_fixed.sdf", "_refined.sdf"),
            res_num_mapping=res_num_mapping,
            refine_positions=True,
            skip_long_missing_residues=MAX_ADD_MISSING_RES,
            save_refinement_bundle=True,
            refinement_bundle_dir=bundle_dir,
        )
        # BUGFIX: `result` was previously discarded, so this function always
        # returned []; collect it like the *_assembly variants do.
        saved_states.append(result)

    return saved_states
def refine_structure_with_ligand_plus_assembly(folder, assembly_name='1'):
    """Expand each ``*/*_protein.pdb`` to a biological assembly and refine it.

    For every protein PDB under *folder*: write the requested biological
    assembly from the entry's mmCIF, rewrite the assembly header with the
    chain-id mapping, then refine the assembly together with its fixed
    ligand via StandardizedPDBFixer.

    Parameters
    ----------
    folder : os.PathLike
        Per-entry folder whose basename is the PDB id; must contain
        ``<pdb_id>.cif``.
    assembly_name : str
        Name of the assembly in the mmCIF (default '1').

    Returns
    -------
    tuple
        (list of runFixWorkflow results,
         dict of assembly/solvent presence counters).
    """
    pdb_id = os.path.basename(folder)
    saved_states = []
    cif_file = os.path.join(folder, f'{pdb_id}.cif')
    count_presence_hetatm_assembly = {"has_new_assembled_chain": 0}

    for protein_pdb in glob.glob(os.path.join(folder, '*/*_protein.pdb')):
        ligand_sdf = protein_pdb.replace("_protein.pdb", "_ligand_fixed.sdf")
        assembly_pdb = protein_pdb.replace("_protein.pdb", f"_protein_assembly{assembly_name}.pdb")
        output_protein = protein_pdb.replace("_protein.pdb", f"_protein_assembly{assembly_name}_refined.pdb")
        output_dir = os.path.dirname(protein_pdb)
        bundle_dir = output_dir + "/refine_structure_with_ligand_plus_assembly/"

        # 1) Expand the asymmetric unit into the biological assembly.
        out_st, has_new_assembled_chain = write_biological_assembly_pdb_gemmi(
            cif_file,
            protein_pdb,
            assembly_pdb,
            assembly_name=assembly_name,
            chain_naming='Short',
            merge_dist=0.0,
        )
        # 2) Map source chain ids -> assembly chain ids, then duplicate the
        #    header annotations (SEQRES/MODRES/SSBOND) for the new chains.
        chain_map = get_assembly_chain_map_gemmi(
            cif_file,
            protein_pdb,
            assembly_pdb,
            assembly_name=assembly_name,
        )
        rewrite_assembled_pdb_header_with_chain_annotations(
            protein_pdb,
            assembly_pdb,
            chain_map,
        )

        # 3) Refine the assembled complex with its fixed ligand.
        fixer = StandardizedPDBFixer(
            protein_pdb=assembly_pdb,
            ligand_sdf=ligand_sdf,
            pdb_id=pdb_id,
            verbose=False,
        )
        # Consistency fix: the non-assembly variant pre-creates the bundle dir.
        os.makedirs(bundle_dir, exist_ok=True)
        result = fixer.runFixWorkflow(
            output_protein=output_protein,
            output_ligand=ligand_sdf.replace("_fixed.sdf", f"_assembly{assembly_name}_refined.sdf"),
            res_num_mapping=None,
            refine_positions=True,
            # Assemblies can multiply missing stretches; relax the cap.
            skip_long_missing_residues=MAX_ADD_MISSING_RES * 20,
            save_refinement_bundle=True,
            refinement_bundle_dir=bundle_dir,
        )
        if has_new_assembled_chain:
            count_presence_hetatm_assembly['has_new_assembled_chain'] += 1
        count_presence_hetatm_assembly.update(detect_tip3_residues_in_pdb(output_protein))
        saved_states.append(result)

    return saved_states, count_presence_hetatm_assembly


def refine_structure_with_ligand_plus_hetatm(folder):
    """Refine each ``*/*_protein_hetatm.pdb`` (protein + kept HETATMs) with
    its fixed ligand via StandardizedPDBFixer.

    Parameters
    ----------
    folder : os.PathLike
        Per-entry folder whose basename is the PDB id; must contain
        ``res_num_mapping.json``.

    Returns
    -------
    list
        One runFixWorkflow result per processed structure.
    """
    pdb_id = os.path.basename(folder)
    res_num_mapping = {}
    saved_states = []
    with open(os.path.join(folder, 'res_num_mapping.json')) as f:
        for chain, mapping in json.load(f).items():
            res_num_mapping[chain] = {int(k): v for k, v in mapping.items()}

    for protein_pdb in glob.glob(os.path.join(folder, '*/*_protein_hetatm.pdb')):
        ligand_sdf = protein_pdb.replace("_protein_hetatm.pdb", "_ligand_fixed.sdf")
        output_protein = protein_pdb.replace("_protein_hetatm.pdb", "_protein_hetatm_refined.pdb")
        output_dir = os.path.dirname(protein_pdb)
        bundle_dir = output_dir + "/refine_structure_with_ligand_plus_hetatm/"
        if VERBOSE:
            print(f'Processing {protein_pdb}')
        fixer = StandardizedPDBFixer(protein_pdb=protein_pdb,
                                     ligand_sdf=ligand_sdf,
                                     pdb_id=pdb_id,
                                     verbose=False)

        # Consistency fix: match refine_structure_with_ligand's dir creation.
        os.makedirs(bundle_dir, exist_ok=True)
        result = fixer.runFixWorkflow(
            output_protein=output_protein,
            output_ligand=ligand_sdf.replace("_fixed.sdf", "_refined.sdf"),
            # NOTE(review): res_num_mapping is loaded above but None is passed
            # here, matching the original code — confirm this is intentional.
            res_num_mapping=None,
            refine_positions=True,
            skip_long_missing_residues=MAX_ADD_MISSING_RES,
            save_refinement_bundle=True,
            refinement_bundle_dir=bundle_dir,
        )
        # BUGFIX: `result` was previously discarded, so this function always
        # returned []; collect it like the *_assembly variants do.
        saved_states.append(result)

    return saved_states
def refine_structure_with_ligand_plus_hetatm_assembly(folder, assembly_name='1'):
    """Assemble each ``*/*_protein_hetatm.pdb`` into the requested biological
    assembly, refine it together with its fixed ligand, and tally assembly /
    solvent statistics.

    Returns a tuple ``(results, stats)`` where *results* holds one
    runFixWorkflow outcome per structure and *stats* counts new assembled
    chains plus presence flags for force-field-acceptable residue names.
    """
    pdb_id = os.path.basename(folder)
    cif_file = os.path.join(folder, f'{pdb_id}.cif')
    saved_states = []
    count_presence_hetatm_assembly = {"has_new_assembled_chain" : 0}

    for src_pdb in glob.glob(os.path.join(folder, '*/*_protein_hetatm.pdb')):
        def _derived(suffix):
            # Swap the "_protein_hetatm.pdb" tail of the source path for *suffix*.
            return src_pdb.replace("_protein_hetatm.pdb", suffix)

        ligand_sdf = _derived("_ligand_fixed.sdf")
        assembly_pdb = _derived(f"_protein_hetatm_assembly{assembly_name}.pdb")
        refined_pdb = _derived(f"_protein_hetatm_assembly{assembly_name}_refined.pdb")
        bundle_dir = os.path.dirname(src_pdb) + "/refine_structure_with_ligand_plus_hetatm_assembly/"

        # Expand the asymmetric unit into the biological assembly.
        out_st, has_new_assembled_chain = write_biological_assembly_pdb_gemmi(
            cif_file,
            src_pdb,
            assembly_pdb,
            assembly_name=assembly_name,
            chain_naming='Short',
            merge_dist=0.0,
        )

        # Propagate header annotations onto the newly named assembly chains.
        chain_map = get_assembly_chain_map_gemmi(
            cif_file,
            src_pdb,
            assembly_pdb,
            assembly_name=assembly_name,
        )
        rewrite_assembled_pdb_header_with_chain_annotations(
            src_pdb,
            assembly_pdb,
            chain_map,
        )

        fixer = StandardizedPDBFixer(
            protein_pdb=assembly_pdb,
            ligand_sdf=ligand_sdf,
            pdb_id=pdb_id,
            verbose=False,
        )
        outcome = fixer.runFixWorkflow(
            output_protein=refined_pdb,
            output_ligand=ligand_sdf.replace("_fixed.sdf", f"_hetatm_assembly{assembly_name}_refined.sdf"),
            res_num_mapping=None,
            refine_positions=True,
            skip_long_missing_residues=MAX_ADD_MISSING_RES*20,
            save_refinement_bundle=True,
            refinement_bundle_dir=bundle_dir,
        )

        if has_new_assembled_chain:
            count_presence_hetatm_assembly['has_new_assembled_chain'] += 1
        count_presence_hetatm_assembly.update(detect_tip3_residues_in_pdb(refined_pdb))
        saved_states.append(outcome)

    return saved_states, count_presence_hetatm_assembly
= [] @@ -702,7 +950,7 @@ def fix_ligands_in_folder(folder): err_msg = '----\n'.join(f'\nError occurs when fixing {name}: \n{err_msg}' for name, err_msg in err_log) if err_msg: raise LigandFixException(err_msg) - + if __name__ == "__main__": import warnings @@ -717,16 +965,70 @@ def fix_ligands_in_folder(folder): parser.add_argument('-d', '--output', dest='output', help='output directory') parser.add_argument('--poly', dest='poly', action='store_true', help='if polymer csv') parser.add_argument('--serial', dest='serial', action='store_true', help='if run serially (not in parallel)') + + parser.add_argument( + '--binding_cutoff', + dest='binding_cutoff', + type=float, + default=BINDING_CUTOFF, + help='binding cutoff' + ) + parser.add_argument( + '--hetatm_cutoff', + dest='hetatm_cutoff', + type=float, + default=HETATM_CUTOFF, + help='hetatm cutoff' + ) + parser.add_argument( + '--find_connected_ligand_residues', + dest='find_connected_ligand_residues', + action=argparse.BooleanOptionalAction, + default=True, + help='whether to find connected ligand residues' + ) + parser.add_argument( + '--do_refine_structure_with_ligand_plus_hetatm', + dest='do_refine_structure_with_ligand_plus_hetatm', + action='store_true', + default=False, + help='refine structure with ligand plus hetatm' + ) + parser.add_argument( + '--do_refine_structure_with_ligand_plus_assembly', + dest='do_refine_structure_with_ligand_plus_assembly', + action='store_true', + default=False, + help='refine structure with ligand plus assembly' + ) + parser.add_argument( + '--do_refine_structure_with_ligand_plus_hetatm_assembly', + dest='do_refine_structure_with_ligand_plus_hetatm_assembly', + action='store_true', + default=False, + help='refine structure with ligand plus hetatm assembly' + ) + + input_args = parser.parse_args() dataset_dir = input_args.output if not os.path.isdir(dataset_dir): os.mkdir(dataset_dir) + process_kwargs = { + 'binding_cutoff': input_args.binding_cutoff, + 'hetatm_cutoff': 
def download_file_bioassembly(url: str, fp: os.PathLike, overwrite: bool = False, raise_error: bool = True, folder=None, pdb_id=None):
    """
    Download a gzipped assembly-1 CIF file and decompress it in place.

    Parameters
    ----------
    url: str
        URL of the file to be downloaded
    fp: os.PathLike
        Local path for the downloaded ``.cif.gz`` file
    overwrite: bool
        If True, will overwrite the existing file.
        If False, will skip the download process if the file exists. Default False.
    raise_error: bool
        If True, will raise error if the download fails. Default True.
    folder: os.PathLike, optional
        Folder holding the assembly files. When given together with *pdb_id*,
        the paths ``<pdb_id>-assembly1.cif.gz`` / ``<pdb_id>-assembly1.cif``
        inside it are used (original behaviour).
    pdb_id: str, optional
        PDB id used to build those file names.
    """
    import gzip
    import shutil

    download_file(url, fp, overwrite, raise_error)

    if folder is not None and pdb_id is not None:
        gz_path = os.path.join(folder, f"{pdb_id}-assembly1.cif.gz")
        out_path = os.path.join(folder, f"{pdb_id}-assembly1.cif")
    else:
        # BUGFIX: the defaulted folder/pdb_id=None previously crashed in
        # os.path.join; fall back to decompressing the file just downloaded.
        gz_path = os.fspath(fp)
        out_path = gz_path[:-3] if gz_path.endswith('.gz') else gz_path + '.cif'

    # Stream-decompress without loading the whole file into memory.
    with gzip.open(gz_path, "rb") as f_in, open(out_path, "wb") as f_out:
        shutil.copyfileobj(f_in, f_out)