diff --git a/ipsae.py b/ipsae.py index 45f64de..2193b89 100644 --- a/ipsae.py +++ b/ipsae.py @@ -74,13 +74,17 @@ af3 = True boltz1 = False cif = True -elif ".cif" in pdb_path and pae_file_path.endswith(".npz"): - pdb_stem=pdb_path.replace(".cif","") - path_stem = f'{pdb_path.replace(".cif","")}_{pae_string}_{dist_string}' +elif (".cif" in pdb_path or ".pdb" in pdb_path) and pae_file_path.endswith(".npz"): + pdb_stem=pdb_path.replace(".cif","").replace(".pdb","") + # Determine extension for replacement + ext = ".cif" if ".cif" in pdb_path else ".pdb" + path_stem = f'{pdb_path.replace(ext,"")}_{pae_string}_{dist_string}' + af2 = False af3 = False boltz1 = True - cif = True + cif = False # Set to False if using PDB so it uses the PDB parser + if ext == ".cif": cif = True else: print("Wrong PDB or PAE file type ", pdb_path) sys.exit() @@ -216,7 +220,10 @@ def parse_cif_atom_line(line,fielddict): atom_num = linelist[ fielddict['id'] ] atom_name = linelist[ fielddict['label_atom_id'] ] residue_name = linelist[ fielddict['label_comp_id'] ] - chain_id = linelist[ fielddict['label_asym_id'] ] + if 'auth_asym_id' in fielddict: + chain_id = linelist[ fielddict['auth_asym_id'] ] + else: + chain_id = linelist[ fielddict['label_asym_id'] ] residue_seq_num = linelist[ fielddict['label_seq_id'] ] x = linelist[ fielddict['Cartn_x'] ] y = linelist[ fielddict['Cartn_y'] ] @@ -441,7 +448,12 @@ def classify_chains(chains, residue_types): plddt_file_path=pae_file_path.replace("pae","plddt") if os.path.exists(plddt_file_path): data_plddt=np.load(plddt_file_path) - plddt_boltz1=np.array(100.0*data_plddt['plddt']) + raw_plddt = data_plddt['plddt'] + # Only multiply by 100 if the max value is <= 1.0 (meaning it's normalized) + if np.max(raw_plddt) <= 1.0: + plddt_boltz1 = np.array(100.0 * raw_plddt) + else: + plddt_boltz1 = np.array(raw_plddt) plddt = plddt_boltz1[np.ix_(token_array.astype(bool))] cb_plddt = plddt_boltz1[np.ix_(token_array.astype(bool))] else: @@ -463,8 +475,12 @@ def classify_chains(chains, residue_types): if os.path.exists(summary_file_path): with open(summary_file_path, 'r') as file: data_summary = json.load(file) - - boltz1_chain_pair_iptm_data=data_summary['pair_chains_iptm'] + if 'pair_chains_iptm' in data_summary: + boltz1_chain_pair_iptm_data=data_summary['pair_chains_iptm'] + else: + # Boltz2 specific or missing key fallback + print(f"Warning: 'pair_chains_iptm' key not found in {summary_file_path}. ipTM scores will be 0.") + boltz1_chain_pair_iptm_data = {} for chain1 in unique_chains: nchain1= ord(chain1) - ord('A') # map A,B,C... to 0,1,2... for chain2 in unique_chains: @@ -967,4 +983,4 @@ def classify_chains(chains, residue_types): chain1_residues = f'chain {chain1} and resi {contiguous_ranges(unique_residues_chain1[chain1][chain2])}' chain2_residues = f'chain {chain2} and resi {contiguous_ranges(unique_residues_chain2[chain1][chain2])}' PML.write(f'alias {chain_pair}, color gray80, all; color {color1}, {chain1_residues}; color {color2}, {chain2_residues}\n\n') - OUT.write("\n") + OUT.write("\n") \ No newline at end of file