diff --git a/README.md b/README.md index d2db3df..289abaf 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ Additionally, you may need to install additional development tools. Depending on - If you have sudo privileges: ```bash - sudo apt install built-essential + sudo apt install build-essential ``` - For HPC cluster environment, it is recommended to use [Conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) (or [Mamba](https://mamba.readthedocs.io/en/latest/)): @@ -57,7 +57,7 @@ Additionally, you may need to install additional development tools. Depending on ## Building Datasets -This step build the datasets necessary for Oncodrive3D to run the 3D clustering analysis. It is required once after installation or whenever you need to generate datasets for a different organism or apply a specific threshold to define amino acid contacts. +This step builds the datasets necessary for Oncodrive3D to run the 3D clustering analysis. It is required once after installation or whenever you need to generate datasets for a different organism or apply a specific threshold to define amino acid contacts. > [!WARNING] > This step is highly time- and resource-intensive, requiring a significant amount of free disk space and computational power. It will download and process a large amount of data. Ensure sufficient resources are available before proceeding, as insufficient capacity may result in extended runtimes or processing failures. @@ -65,11 +65,14 @@ This step build the datasets necessary for Oncodrive3D to run the 3D clustering > Reliable internet access is required because AlphaFold structures, Ensembl annotations, Pfam files, and other resources are downloaded on demand during the build. > [!NOTE] -> The first time that you run Oncodrive3D building dataset step with a given reference genome, it will download it from our servers. By default the downloaded datasets go to`~/.bgdata`. If you want to move these datasets to another folder you have to define the system environment variable `BGDATA_LOCAL` with an export command. +> The first time that you run Oncodrive3D building dataset step with a given reference genome, it will download it from our servers. By default the downloaded datasets go to `~/.bgdata`. If you want to move these datasets to another folder you have to define the system environment variable `BGDATA_LOCAL` with an export command. > [!NOTE] > Human datasets built with the default settings pull canonical transcript metadata from the January 2024 Ensembl archive (release 111 / GENCODE v45). For maximum compatibility, annotate your input variants with the same Ensembl/Gencode release or supply the unfiltered VEP output together with `--o3d_transcripts --use_input_symbols`. +> [!NOTE] Predicted Aligned Error (PAE) files for older AlphaFold DB versions (e.g., v4) are no longer hosted after 2025. If you need PAE for an older AF version, download and supply them locally via `--custom_pae_dir`. +> MANE structures are only available from the AlphaFold DB v4 release. Non‑MANE builds default to v6; MANE mode forces v4 for structures, so you should provide PAE files via `--custom_pae_dir`. + ``` Usage: oncodrive3d build-datasets [OPTIONS] @@ -83,13 +86,16 @@ Examples: Options: -o, --output_dir PATH Path to the directory where the output files will be saved. Default: ./datasets/ - -s, --organism PATH Specifies the organism (`human` or `mouse`). 
- Default: human + -s, --organism TEXT Specifies the organism (`human` or `mouse`; also accepts `Homo sapiens` / `Mus musculus`). + Default: Homo sapiens -m, --mane Use structures predicted from MANE Select transcripts (applicable to Homo sapiens only). -M, --mane_only Use only structures predicted from MANE Select transcripts (applicable to Homo sapiens only). - -C, --custom_mane_pdb_dir PATH Path to directory containing custom MANE PDB structures. + -C, --custom_mane_pdb_dir PATH Path to directory containing custom MANE PDB structures (requires --mane_only). + Default: None + --custom_pae_dir PATH Path to directory containing pre-downloaded PAE JSON files. + The directory will be copied into the build as `pae/`. Default: None -f, --custom_mane_metadata_path Path to a dataframe (typically a samplesheet.csv) including Ensembl IDs and sequences of the custom pdbs. @@ -98,8 +104,8 @@ Options: Default: 10 -c, --cores INT Number of CPU cores for computation. Default: All available CPU cores - --af_version INT Version of the AlphaFold Protein Structure Database release. - Default: 4 + --af_version INT AlphaFold DB version for non-MANE builds (MANE uses v4). + Default: 6 -y, --yes Run without interactive prompts. -v, --verbose Enables verbose output. -h, --help Show this message and exit. @@ -112,7 +118,7 @@ For more information on the output of this step, please refer to the [Building D > To maximize structural coverage of **MANE Select transcripts**, you can [predict missing structures locally and integrate them into Oncodrive3D](tools/preprocessing/README.md) using: > > - `tools/preprocessing/prepare_samplesheet.py`: a standalone utility that: -> - Retrieve the full MANE entries from NCBI. +> - Retrieves the full MANE entries from NCBI. > - Identifies proteins missing from the AlphaFold MANE dataset. > - Generates: > - A `samplesheet.csv` with Ensembl protein IDs, FASTA paths, and optional sequences. @@ -133,7 +139,7 @@ For more information on the output of this step, please refer to the [Building D ## Running 3D clustering Analysis -For in depth information on how to obtain the required input data and for comprehensive information about the output, please refer to the [Input and Output Documentation](https://github.com/bbglab/oncodrive3d/tree/master/docs/run_input_output.md) of the 3D clustering analysis. +For in-depth information on how to obtain the required input data and for comprehensive information about the output, please refer to the [Input and Output Documentation](https://github.com/bbglab/oncodrive3d/tree/master/docs/run_input_output.md) of the 3D clustering analysis. ### Input @@ -256,8 +262,6 @@ For more information, refer to the [Oncodrive3D Pipeline](https://github.com/bbg ### Usage ---- - > [!WARNING] > When using the Nextflow script, ensure that your input files are organized in the following directory structure (you only need either the `maf/` or `vep/` directory): > @@ -302,10 +306,10 @@ Options: --vep_input BOOL Use `vep/` subdir as input and select transcripts matching the Ensembl transcript IDs in Oncodrive3D built datasets. Default: false - --mane BOOL Prioritize structures corresponding to MANE transcrips if + --mane BOOL Prioritize structures corresponding to MANE transcripts if multiple structures are associated to the same gene. Default: false - --seed INT: Seed value for reproducibility. + --seed INT Seed value for reproducibility. 
Default: 128 ``` diff --git a/scripts/datasets/af_merge.py b/scripts/datasets/af_merge.py index 90169ff..c7c88b5 100755 --- a/scripts/datasets/af_merge.py +++ b/scripts/datasets/af_merge.py @@ -223,13 +223,18 @@ def get_pdb_seqres_records(lst_res): def add_refseq_record_to_pdb(path_structure): """ - Add the SEQREF records to the pdb file. + Add the SEQRES records to the pdb file. + Returns True if SEQRES was inserted, False if skipped because SEQRES already exists. """ # Open the PDB file and get SEQRES insert index with open(path_structure, 'r') as file: pdb_lines = file.readlines() - insert_index = next(i for i, line in enumerate(pdb_lines) if line.startswith('MODEL')) + + if any(line.startswith('SEQRES') for line in pdb_lines): + return False + + insert_index = next(i for i, line in enumerate(pdb_lines) if line.startswith('MODEL')) # Get seares records residues = get_res_from_chain(path_structure) @@ -243,6 +248,8 @@ def add_refseq_record_to_pdb(path_structure): output_file.truncate() output_file.writelines(pdb_lines) + return True + # Other functions @@ -306,6 +313,8 @@ def merge_af_fragments(input_dir, output_dir=None, af_version=4, gzip=False): else: # Get list of fragmented Uniprot ID and max AF-F not_processed = [] + refseq_added = 0 + refseq_skipped_existing = 0 for uni_id, max_f in tqdm(fragments, total=len(fragments), desc="Merging AF fragments"): processed = False @@ -329,7 +338,10 @@ def merge_af_fragments(input_dir, output_dir=None, af_version=4, gzip=False): tmp_name = os.path.join(output_dir, f"AF-{uni_id}-FM-model_v{af_version}.pdb") name = os.path.join(output_dir, f"AF-{uni_id}-F{max_f}M-model_v{af_version}.pdb") os.rename(tmp_name, name) - add_refseq_record_to_pdb(name) + if add_refseq_record_to_pdb(name): + refseq_added += 1 + else: + refseq_skipped_existing += 1 if len(not_processed) > 0: logger.warning(f"Not processed: {not_processed}") @@ -338,6 +350,11 @@ def merge_af_fragments(input_dir, output_dir=None, af_version=4, gzip=False): save_unprocessed_ids(not_processed, os.path.join(output_dir, "fragmented_pdbs", "ids_not_merged.txt")) + if refseq_skipped_existing: + logger.info( + "Skipped SEQRES insertion for %s merged structures (SEQRES already present).", + refseq_skipped_existing, + ) logger.info("Merge of structures completed!") else: diff --git a/scripts/datasets/build_datasets.py b/scripts/datasets/build_datasets.py index 04a4f45..adc683c 100644 --- a/scripts/datasets/build_datasets.py +++ b/scripts/datasets/build_datasets.py @@ -21,6 +21,7 @@ import os +import shutil import daiquiri from scripts import __logger_name__ @@ -43,6 +44,7 @@ def build(output_datasets, mane, mane_only, custom_pdb_dir, + custom_pae_dir, custom_mane_metadata_path, distance_threshold, num_cores, @@ -57,6 +59,13 @@ def build(output_datasets, # Download PDB structures species = get_species(organism) + if mane and str(af_version) != "4": + logger.warning( + "MANE structures are only available in AlphaFold DB v4. 
" + "Ignoring --af_version=%s and using v4 for this build.", + af_version, + ) + af_version = 4 if not mane_only: logger.info("Downloading AlphaFold (AF) predicted structures...") get_structures( @@ -69,7 +78,11 @@ def build(output_datasets, # Merge fragmented structures logger.info("Merging fragmented structures...") - merge_af_fragments(input_dir=os.path.join(output_datasets,"pdb_structures"), gzip=True) + merge_af_fragments( + input_dir=os.path.join(output_datasets,"pdb_structures"), + af_version=af_version, + gzip=True + ) # Download PDB MANE structures if species == "Homo sapiens" and mane: @@ -78,6 +91,7 @@ def build(output_datasets, path=os.path.join(output_datasets,"pdb_structures_mane"), species=species, mane=True, + af_version=str(af_version), threads=num_cores ) mv_mane_pdb(output_datasets, "pdb_structures", "pdb_structures_mane") @@ -85,6 +99,13 @@ def build(output_datasets, # Copy custom PDB structures and optinally add SEQRES if custom_pdb_dir is not None: + if not mane_only: + logger.error( + "custom_pdb_dir requires --mane_only. Use --mane_only when providing custom MANE structures." + ) + raise ValueError( + "custom_pdb_dir requires --mane_only" + ) if custom_mane_metadata_path is None: logger.error( "custom_mane_metadata_path must be provided when custom_pdb_dir is specified" @@ -112,9 +133,11 @@ def build(output_datasets, output_seq_df=os.path.join(output_datasets, "seq_for_mut_prob.tsv"), organism=species, mane=mane, + mane_only=mane_only, num_cores=num_cores, mane_version=mane_version, - custom_mane_metadata_path=custom_mane_metadata_path + custom_mane_metadata_path=custom_mane_metadata_path, + af_version=af_version, ) logger.info("Generation of sequences dataframe completed!") @@ -127,18 +150,32 @@ def build(output_datasets, ) # Get PAE - logger.info("Downloading AF predicted aligned error (PAE)...") - get_pae( - input_dir=os.path.join(output_datasets,"pdb_structures"), - output_dir=os.path.join(output_datasets,"pae"), - num_cores=num_cores, - af_version=str(af_version), - custom_pdb_dir=custom_pdb_dir - ) + pae_output_dir = os.path.join(output_datasets, "pae") + if custom_pae_dir is not None: + logger.info("Copying precomputed PAE directory...") + if os.path.exists(custom_pae_dir): + if os.path.exists(pae_output_dir): + shutil.rmtree(pae_output_dir) + shutil.copytree(custom_pae_dir, pae_output_dir) + else: + logger.warning( + "Custom PAE directory does not exist: %s. Skipping copy. 
" + "Contact maps will be computed without PAE (binary maps).", + custom_pae_dir, + ) + else: + logger.info("Downloading AF predicted aligned error (PAE)...") + get_pae( + input_dir=os.path.join(output_datasets,"pdb_structures"), + output_dir=pae_output_dir, + num_cores=num_cores, + af_version=str(af_version), + custom_pdb_dir=custom_pdb_dir + ) # Parse PAE logger.info("Parsing PAE...") - parse_pae(input=os.path.join(output_datasets, 'pae')) + parse_pae(input=pae_output_dir) logger.info("Parsing PAE completed!") # Get pCAMPs @@ -159,15 +196,6 @@ def build(output_datasets, logger.info("Datasets have been successfully built and are ready for analysis!") if __name__ == "__main__": - build( - output_datasets="/data/bbg/nobackup/scratch/oncodrive3d/mane_missing/oncodrive3d/datasets/datasets-mane_only-mane_custom-250729", - organism="Homo sapiens", - mane=False, - mane_only=True, - custom_pdb_dir="/data/bbg/nobackup/scratch/oncodrive3d/mane_missing/data/250724-no_fragments/all_pdbs-pred_and_retrieved/pdbs", - custom_mane_metadata_path="/data/bbg/nobackup/scratch/oncodrive3d/mane_missing/data/250724-no_fragments/all_pdbs-pred_and_retrieved/samplesheet.csv", - distance_threshold=10, - num_cores=8, - af_version=4, - mane_version=1.4 - ) + raise SystemExit( + "This module is intended to be used via the CLI: `oncodrive3d build-datasets`." + ) diff --git a/scripts/datasets/custom_pdb.py b/scripts/datasets/custom_pdb.py index 63f5a46..212582f 100644 --- a/scripts/datasets/custom_pdb.py +++ b/scripts/datasets/custom_pdb.py @@ -39,7 +39,7 @@ def get_pdb_seqres_records(lst_res): return records -def add_seqres_to_pdb(path_pdb: str, residues: list) -> None: +def add_seqres_to_pdb(path_pdb: str, residues: list) -> bool: """ Insert SEQRES records at the very top of a PDB file (supports gzipped and plain). @@ -56,6 +56,9 @@ def add_seqres_to_pdb(path_pdb: str, residues: list) -> None: with open_in(path_pdb, mode_in) as fh: lines = fh.readlines() + if any(line.startswith("SEQRES") for line in lines): + return False + # Generate SEQRES lines seqres = get_pdb_seqres_records(residues) @@ -65,6 +68,8 @@ def add_seqres_to_pdb(path_pdb: str, residues: list) -> None: # Write back with open_out(path_pdb, mode_out) as fh: fh.writelines(new_lines) + + return True def copy_and_parse_custom_pdbs( @@ -99,13 +104,20 @@ def copy_and_parse_custom_pdbs( samplesheet_df = None # Copy and gzip pdb and optionally add REFSEQ + total_pdb_files = 0 + copied = 0 + skipped_format = 0 + seqres_inserted = 0 + seqres_skipped_existing = 0 for fname in os.listdir(src_dir): if not fname.endswith('.pdb'): continue + total_pdb_files += 1 parts = fname.split('.') # e.g. 
[ACCESSION, fragment_code, 'alphafold', 'pdb'] if len(parts) < 4: logger.warning(f"Skipping unexpected filename format: {fname}") + skipped_format += 1 continue accession = parts[0] @@ -119,6 +131,7 @@ def copy_and_parse_custom_pdbs( with open(src_path, 'rb') as fin, gzip.open(dst_path, 'wb') as fout: shutil.copyfileobj(fin, fout) + copied += 1 logger.debug(f'Copied and gzipped: {fname} -> {new_name}') # Optionally add SEQRES records @@ -130,8 +143,11 @@ def copy_and_parse_custom_pdbs( if not pd.isna(seq): seq = [one_to_three_res_map[aa] for aa in seq] - add_seqres_to_pdb(path_pdb=dst_path, residues=seq) - logger.debug(f"Inserted SEQRES records into: {new_name}") + if add_seqres_to_pdb(path_pdb=dst_path, residues=seq): + logger.debug(f"Inserted SEQRES records into: {new_name}") + seqres_inserted += 1 + else: + seqres_skipped_existing += 1 else: try: seq = "".join(list(get_seq_from_pdb(dst_path))) @@ -141,4 +157,18 @@ def copy_and_parse_custom_pdbs( logger.warning(f"SEQRES not found in samplesheet and its extraction from structure failed: {new_name}") except Exception as e: logger.warning(f"SEQRES not found in samplesheet and its extraction from structure failed: {new_name}") - logger.warning(f"Exception captured: {e}") \ No newline at end of file + logger.warning(f"Exception captured: {e}") + + logger.info( + "Custom PDB copy summary: %s/%s structures copied (skipped %s invalid filenames).", + copied, + total_pdb_files, + skipped_format, + ) + if seqres_inserted: + logger.debug("Inserted SEQRES records into %s custom structures.", seqres_inserted) + if seqres_skipped_existing: + logger.info( + "Skipped SEQRES insertion for %s custom structures (SEQRES already present).", + seqres_skipped_existing, + ) diff --git a/scripts/datasets/get_pae.py b/scripts/datasets/get_pae.py index fe60087..4af4a61 100644 --- a/scripts/datasets/get_pae.py +++ b/scripts/datasets/get_pae.py @@ -17,7 +17,7 @@ def download_pae( af_version: int, output_dir: str, max_retries: int = 100 - ) -> None: + ) -> str: """ Download Predicted Aligned Error (PAE) file from AlphaFold DB. @@ -26,6 +26,9 @@ def download_pae( af_version: AlphaFold 2 version. output_dir: Output directory where to download the PAE files. max_retries: Break the loop if the download fails too many times. + + Returns: + "ok" if downloaded, "missing" if 404/410, "failed" otherwise. """ file_path = os.path.join(output_dir, f"{uniprot_id}-F1-predicted_aligned_error.json") @@ -42,6 +45,17 @@ def download_pae( time.sleep(30) try: response = requests.get(download_url, timeout=30) + if response.status_code in (404, 410): + logger.warning( + "PAE not available for %s (AF v%s). 
Skipping.", + uniprot_id, + af_version, + ) + return "missing" + if not response.ok: + raise requests.exceptions.HTTPError( + f"PAE download failed with status {response.status_code}" + ) content = response.content if content.endswith(b'}]') and not content.endswith(b''): with open(file_path, 'wb') as output_file: @@ -51,6 +65,7 @@ def download_pae( status = "ERROR" if i % 10 == 0: logger.debug(f"Request failed {e}: Retrying") + return "ok" if status == "FINISHED" else "failed" def get_pae( @@ -88,13 +103,38 @@ def get_pae( custom_uniprot_ids = [fname.split('.')[0] for fname in os.listdir(custom_pdb_dir) if fname.endswith('.pdb')] uniprot_ids = [uni_id for uni_id in uniprot_ids if uni_id not in custom_uniprot_ids] - with concurrent.futures.ThreadPoolExecutor(max_workers=num_cores) as executor: - tasks = [executor.submit(download_pae, uniprot_id, af_version, output_dir) for uniprot_id in uniprot_ids] - - for _ in tqdm(concurrent.futures.as_completed(tasks), total=len(tasks), desc="Downloading PAE"): - pass + consecutive_missing = 0 + if uniprot_ids: + probe_ids = uniprot_ids[:10] + remaining_ids = uniprot_ids[10:] + for uniprot_id in tqdm(probe_ids, desc="Downloading PAE"): + result = download_pae(uniprot_id, af_version, output_dir) + if result == "missing": + consecutive_missing += 1 + elif result == "ok": + consecutive_missing = 0 + + if consecutive_missing >= 10: + logger.warning( + "Detected %s consecutive missing PAE downloads (HTTP 404/410). " + "PAE files for AlphaFold DB v%s are likely unavailable. " + "Skipping remaining PAE downloads. Contact maps will be computed without PAE (binary maps). " + "Re-run build-datasets with --custom_pae_dir to use precomputed PAE.", + consecutive_missing, + af_version, + ) + with open(checkpoint, "w") as f: + f.write('') + return + + if remaining_ids: + with concurrent.futures.ThreadPoolExecutor(max_workers=num_cores) as executor: + tasks = [executor.submit(download_pae, uniprot_id, af_version, output_dir) for uniprot_id in remaining_ids] + + for _ in tqdm(concurrent.futures.as_completed(tasks), total=len(tasks), desc="Downloading PAE"): + pass with open(checkpoint, "w") as f: f.write('') - logger.info('Download of PAE completed!') \ No newline at end of file + logger.info('Download of PAE completed!') diff --git a/scripts/datasets/seq_for_mut_prob.py b/scripts/datasets/seq_for_mut_prob.py index 535fb33..d213337 100644 --- a/scripts/datasets/seq_for_mut_prob.py +++ b/scripts/datasets/seq_for_mut_prob.py @@ -13,10 +13,11 @@ import ast import json +import logging import multiprocessing import os import re -import shlex +import shutil import subprocess import sys import time @@ -41,6 +42,12 @@ logger = daiquiri.getLogger(__logger_name__ + ".build.seq_for_mut_prob") +_ENSEMBL_REST_SERVER = "https://rest.ensembl.org" +_ENSEMBL_REST_TIMEOUT = (10, 160) # (connect, read) seconds +_ENSEMBL_REST_HEADERS = {"Accept": "text/plain"} +_ENSEMBL_CDS_MAX_CORES = 8 + + #=========== # region Initialize #=========== @@ -84,7 +91,7 @@ def initialize_seq_df(input_path, uniprot_to_gene_dict): # (not 100% reliable but available fo most seq) #============================================== -def backtranseq(protein_seqs, organism = "Homo sapiens"): +def backtranseq(protein_seqs, organism = "Homo sapiens", max_attempts=5, total_timeout=45 * 60): """ Perform backtranslation from proteins to DNA sequences using EMBOS backtranseq. 
""" @@ -101,44 +108,101 @@ def backtranseq(protein_seqs, organism = "Homo sapiens"): "molecule": "dna", "organism": organism} - # Submit the job request and retrieve the job ID - response = "INIT" - while str(response) != "": - if response != "INIT": + for attempt in range(1, max_attempts + 1): + attempt_start = time.monotonic() + + # Submit the job request and retrieve the job ID + job_id = None + while job_id is None: + if time.monotonic() - attempt_start > total_timeout: + logger.warning( + "Backtranseq submit timed out after %ss (attempt %s/%s).", + total_timeout, + attempt, + max_attempts, + ) + break + try: + response = requests.post(run_url, data=params, timeout=160) + if response.ok: + job_id = response.text.strip() + break + logger.debug( + "Backtranseq submit returned HTTP %s: %s", + response.status_code, + response.text[:200], + ) + except requests.exceptions.RequestException as e: + logger.debug(f"Request failed: {e}") time.sleep(10) - try: - response = requests.post(run_url, data=params, timeout=160) - except requests.exceptions.RequestException as e: - response = "ERROR" - logger.debug(f"Request failed: {e}") - - job_id = response.text.strip() - - # Wait for the job to complete - status = "INIT" - while status != "FINISHED": - time.sleep(20) - try: - result = requests.get(status_url + job_id, timeout=160) - status = result.text.strip() - except requests.exceptions.RequestException as e: - status = "ERROR" - logger.debug(f"Request failed {e}: Retrying..") - - # Retrieve the results of the job - status = "INIT" - while status != "FINISHED": - try: - result = requests.get(result_url + job_id + "/out", timeout=160) - status = "FINISHED" - except requests.exceptions.RequestException as e: - status = "ERROR" - logger.debug(f"Request failed {e}: Retrying..") - time.sleep(10) - dna_seq = result.text.strip() + if job_id is None: + continue + + # Wait for the job to complete + status = "INIT" + while True: + if time.monotonic() - attempt_start > total_timeout: + logger.warning( + "Backtranseq status polling timed out after %ss (attempt %s/%s).", + total_timeout, + attempt, + max_attempts, + ) + status = "TIMEOUT" + break + time.sleep(20) + try: + result = requests.get(status_url + job_id, timeout=160) + if not result.ok: + logger.debug( + "Backtranseq status returned HTTP %s: %s", + result.status_code, + result.text[:200], + ) + continue + status = result.text.strip() + if status == "FINISHED": + break + if status in {"ERROR", "FAILED"}: + logger.warning( + "Backtranseq returned terminal status '%s' (attempt %s/%s).", + status, + attempt, + max_attempts, + ) + break + except requests.exceptions.RequestException as e: + logger.debug(f"Request failed {e}: Retrying..") + + if status != "FINISHED": + continue + + # Retrieve the results of the job + while True: + if time.monotonic() - attempt_start > total_timeout: + logger.warning( + "Backtranseq result fetch timed out after %ss (attempt %s/%s).", + total_timeout, + attempt, + max_attempts, + ) + break + try: + result = requests.get(result_url + job_id + "/out", timeout=160) + if result.ok: + return result.text.strip() + logger.debug( + "Backtranseq result returned HTTP %s: %s", + result.status_code, + result.text[:200], + ) + except requests.exceptions.RequestException as e: + logger.debug(f"Request failed {e}: Retrying..") + time.sleep(10) - return dna_seq + logger.warning("Backtranseq failed after %s attempts; returning empty result.", max_attempts) + return None def batch_backtranseq(df, batch_size, organism = "Homo sapiens"): @@ 
-166,6 +230,16 @@ def batch_backtranseq(df, batch_size, organism = "Homo sapiens"): # Run backtranseq batch_dna = backtranseq(batch_seq, organism = organism) + if not batch_dna: + logger.warning( + "Backtranseq failed for batch %s/%s; setting Seq_dna to NaN.", + i + 1, + len(batches), + ) + batch["Seq_dna"] = np.nan + lst_batches.append(batch) + continue + # Parse output batch_dna = re.split(">EMBOSS_\d+", batch_dna.replace("\n", ""))[1:] @@ -307,8 +381,30 @@ def get_exons_coord(ids, ens_canonical_transcripts_lst, batch_size=100): https://doi.org/10.1093/nar/gkx237 """ + ids = [str(i) for i in ids] + ens_prot_ids = [i for i in ids if i.upper().startswith("ENSP")] + proteins_api_ids = [i for i in ids if not i.upper().startswith("ENSP")] + + if ens_prot_ids: + logger.debug( + "Skipping %s ENSP IDs for Proteins API (will fall back to Backtranseq).", + len(ens_prot_ids), + ) + lst_df = [] - batches_ids = [ids[i:i+batch_size] for i in range(0, len(ids), batch_size)] + if not proteins_api_ids: + nan = np.repeat(np.nan, len(ids)) + return pd.DataFrame({ + "Uniprot_ID": ids, + "Ens_Gene_ID": nan, + "Ens_Transcr_ID": nan, + "Seq": nan, + "Chr": nan, + "Reverse_strand": nan, + "Exons_coord": nan, + }) + + batches_ids = [proteins_api_ids[i:i+batch_size] for i in range(0, len(proteins_api_ids), batch_size)] for batch_ids in tqdm(batches_ids, total=len(batches_ids), desc="Adding exons coordinate"): @@ -328,6 +424,19 @@ def get_exons_coord(ids, ens_canonical_transcripts_lst, batch_size=100): batch_df = pd.concat([batch_df, nan_rows], ignore_index=True) lst_df.append(batch_df) + if ens_prot_ids: + nan = np.repeat(np.nan, len(ens_prot_ids)) + nan_rows = pd.DataFrame({ + "Uniprot_ID": ens_prot_ids, + "Ens_Gene_ID": nan, + "Ens_Transcr_ID": nan, + "Seq": nan, + "Chr": nan, + "Reverse_strand": nan, + "Exons_coord": nan, + }) + lst_df.append(nan_rows) + return pd.concat(lst_df).reset_index(drop=True) @@ -444,6 +553,9 @@ def add_extra_genes_to_seq_df(seq_df, uniprot_to_gene_dict): lst_extra_genes_rows.append(row) lst_added_genes.append(gene) + if not lst_extra_genes_rows: + return seq_df + seq_df_extra_genes = pd.concat(lst_extra_genes_rows, axis=1).T # Remove rows with multiple symbols and drop duplicated ones @@ -519,6 +631,41 @@ def load_custom_ens_prot_ids(path): return ids +def load_custom_symbol_map(path): + """ + Read a samplesheet and return a mapping from ENSP (sequence) to gene symbol. + Falls back to 'gene' column if 'symbol' is not present. 
+ """ + if not os.path.isfile(path): + logger.error(f"Custom MANE metadata path does not exist: {path!r}") + raise FileNotFoundError(f"Custom MANE metadata not found: {path!r}") + df = pd.read_csv(path) + if "sequence" not in df.columns: + logger.debug("Custom MANE metadata missing 'sequence' column; skipping symbol mapping.") + return {} + + symbol_col = None + if "symbol" in df.columns: + symbol_col = "symbol" + elif "gene" in df.columns: + symbol_col = "gene" + else: + logger.debug("Custom MANE metadata missing 'symbol'/'gene' column; skipping symbol mapping.") + return {} + + seq = df["sequence"].astype(str).str.split(".", n=1).str[0] + sym = df[symbol_col] + mapping = {} + for k, v in zip(seq, sym): + if pd.isna(v): + continue + v = str(v).strip() + if not v: + continue + mapping[k] = v + return mapping + + def get_mane_to_af_mapping( datasets_dir, uniprot_ids, @@ -592,6 +739,18 @@ def get_mane_to_af_mapping( base_ens = mane_mapping["Ens_Prot_ID"].str.split(".", n=1).str[0] mask = base_ens.isin(custom_ids) mane_mapping.loc[mask, "Uniprot_ID"] = base_ens[mask] + if logger.isEnabledFor(logging.DEBUG): + summary_ens = set(mane_summary["Ens_Prot_ID"].astype(str).str.split(".", n=1).str[0]) + missing_custom = sorted(set(custom_ids) - summary_ens) + if missing_custom: + preview = ", ".join(missing_custom[:10]) + suffix = "..." if len(missing_custom) > 10 else "" + logger.debug( + "Custom MANE ENSP IDs not found in MANE summary (%s): %s%s", + len(missing_custom), + preview, + suffix, + ) # Select available Uniprot ID, fist one if multiple are present mane_mapping = mane_mapping.dropna(subset=["Uniprot_ID"]).reset_index(drop=True) @@ -607,37 +766,185 @@ def get_mane_to_af_mapping( return mane_mapping -# def download_biomart_metadata(path_to_file, max_attempts=15, cores=8): -# """ -# Query biomart to get the list of transcript corresponding to the downloaded -# structures (a few structures are missing) and other information. -# """ - -# url = 'http://jan2024.archive.ensembl.org/biomart/martservice?query=' -# attempts = 0 - -# while not os.path.exists(path_to_file): -# download_single_file(url, path_to_file, threads=cores) -# attempts += 1 -# if attempts >= max_attempts: -# raise RuntimeError(f"Failed to download MANE summary file after {max_attempts} attempts. Exiting..") -# time.sleep(5) - - -def download_biomart_metadata(path_to_file): +def download_biomart_metadata(path_to_file, max_attempts=3, wait_seconds=10, use_archive=True): """ Query biomart to get the list of transcript corresponding to the downloaded structures (a few structures are missing) and other information. 
""" - command = f""" - wget -O {path_to_file} 'http://jan2024.archive.ensembl.org/biomart/martservice?query=' - """ + base_archive = "http://jan2024.archive.ensembl.org" + base_latest = "https://www.ensembl.org" + query = ( + '/biomart/martservice?query=' + '' + '' + '' + '' + '' + '' + ) + url = f"{base_archive}{query}" + fallback_url = f"{base_latest}{query}" + logger.debug("Starting BioMart metadata download to %s (archive: %s, latest: %s).", path_to_file, base_archive, base_latest) + + if shutil.which("wget") is None: + logger.warning("wget not found; falling back to Python downloader for BioMart metadata.") + last_exc = None + if use_archive: + ssl_verify_archive = url.startswith("https://") + for attempt in range(1, max_attempts + 1): + logger.debug("Starting BioMart download attempt %s/%s (archive).", attempt, max_attempts) + try: + download_single_file(url, path_to_file, threads=4, ssl=ssl_verify_archive) + return + except Exception as exc: + last_exc = exc + logger.warning( + "BioMart download failed (attempt %s/%s). Retrying in %ss... Error: %s", + attempt, + max_attempts, + wait_seconds, + exc, + ) + logger.debug("BioMart download exception details:", exc_info=True) + time.sleep(wait_seconds) + + logger.warning("Falling back to latest Ensembl BioMart URL after failure on %s.", base_archive) + else: + logger.debug("Skipping archive BioMart URL; using latest only.") + + if os.path.exists(path_to_file): + try: + os.remove(path_to_file) + except OSError as exc: + logger.warning( + "Failed to remove partial BioMart metadata file %s before fallback: %s", + path_to_file, + exc, + ) + ssl_verify_fallback = fallback_url.startswith("https://") + for attempt in range(1, max_attempts + 1): + logger.debug("Starting BioMart download attempt %s/%s (latest).", attempt, max_attempts) + try: + download_single_file(fallback_url, path_to_file, threads=4, ssl=ssl_verify_fallback) + return + except Exception as exc: + last_exc = exc + logger.warning( + "Fallback BioMart download failed (attempt %s/%s). Retrying in %ss... Error: %s", + attempt, + max_attempts, + wait_seconds, + exc, + ) + logger.debug("Fallback BioMart download exception details:", exc_info=True) + time.sleep(wait_seconds) + + if use_archive: + message = ( + f"Failed to download BioMart metadata after {max_attempts} attempts on archive and " + f"{max_attempts} attempts on latest." + ) + else: + message = f"Failed to download BioMart metadata after {max_attempts} attempts on latest." + raise RuntimeError(message) from last_exc + + command = [ + "wget", + "--no-hsts", + "--continue", + "--read-timeout=120", + "--timeout=120", + "--tries=1", + "-O", + path_to_file, + url, + ] + + if use_archive: + for attempt in range(1, max_attempts + 1): + logger.debug("Starting BioMart wget attempt %s/%s (archive).", attempt, max_attempts) + result = subprocess.run(command, capture_output=True, text=True) + if result.returncode == 0: + return + stderr = (result.stderr or "").strip() + if stderr: + logger.warning( + "BioMart download failed (attempt %s/%s, return code %s). stderr: %s", + attempt, + max_attempts, + result.returncode, + stderr, + ) + else: + logger.warning( + "BioMart download failed (attempt %s/%s, return code %s). 
Retrying in %ss...", + attempt, + max_attempts, + result.returncode, + wait_seconds, + ) + if result.stdout: + logger.debug("BioMart wget stdout (attempt %s/%s): %s", attempt, max_attempts, result.stdout.strip()) + time.sleep(wait_seconds) + + logger.warning("Falling back to latest Ensembl BioMart URL after failure on %s.", base_archive) + else: + logger.debug("Skipping archive BioMart URL; using latest only.") - subprocess.run(shlex.split(command)) + if os.path.exists(path_to_file): + try: + os.remove(path_to_file) + except OSError as exc: + logger.warning( + "Failed to remove partial BioMart metadata file %s before fallback: %s", + path_to_file, + exc, + ) + command[-1] = fallback_url + for attempt in range(1, max_attempts + 1): + logger.debug("Starting BioMart wget attempt %s/%s (latest).", attempt, max_attempts) + result = subprocess.run(command, capture_output=True, text=True) + if result.returncode == 0: + return + stderr = (result.stderr or "").strip() + if stderr: + logger.warning( + "Fallback BioMart download failed (attempt %s/%s, return code %s). stderr: %s", + attempt, + max_attempts, + result.returncode, + stderr, + ) + else: + logger.warning( + "Fallback BioMart download failed (attempt %s/%s, return code %s). Retrying in %ss...", + attempt, + max_attempts, + result.returncode, + wait_seconds, + ) + if result.stdout: + logger.debug( + "Fallback BioMart wget stdout (attempt %s/%s): %s", + attempt, + max_attempts, + result.stdout.strip(), + ) + time.sleep(wait_seconds) + if use_archive: + message = ( + f"Failed to download BioMart metadata after {max_attempts} attempts on archive and " + f"{max_attempts} attempts on latest." + ) + else: + message = f"Failed to download BioMart metadata after {max_attempts} attempts on latest." + raise RuntimeError(message) -def get_biomart_metadata(datasets_dir, uniprot_ids): + +def get_biomart_metadata(datasets_dir, uniprot_ids, use_archive=True): """ Download a dataframe including ensembl canonical transcript IDs, HGNC IDs, Uniprot IDs, and other useful information. @@ -646,7 +953,7 @@ def get_biomart_metadata(datasets_dir, uniprot_ids): try: path_biomart_metadata = os.path.join(datasets_dir, "biomart_metadata.tsv") if not os.path.exists(path_biomart_metadata): - download_biomart_metadata(path_biomart_metadata) + download_biomart_metadata(path_biomart_metadata, use_archive=use_archive) # Parse biomart_df = pd.read_csv(path_biomart_metadata, sep="\t", header=None, low_memory=False) @@ -679,6 +986,159 @@ def get_biomart_metadata(datasets_dir, uniprot_ids): return canonical_transcripts +def get_ref_dna_from_ensembl_batch(transcript_ids, max_attempts=8, wait_seconds=3): + """ + Retrieve Ensembl CDS DNA sequences for up to 50 stable IDs in a single request. + + Ensembl REST docs: POST /sequence/id (max POST size = 50). + """ + + pid = os.getpid() + start_time = time.perf_counter() + + if not transcript_ids: + return [] + + # Keep output aligned with the input order (including any NA values). 
+ results = [np.nan] * len(transcript_ids) + request_ids = [] + request_positions = [] + for pos, tid in enumerate(transcript_ids): + if pd.isna(tid): + continue + tid_str = str(tid) + request_ids.append(tid_str) + request_positions.append(pos) + + if not request_ids: + return results + + if len(request_ids) > 50: + raise ValueError(f"Ensembl POST /sequence/id supports max 50 IDs per request; got {len(request_ids)}.") + + url = f"{_ENSEMBL_REST_SERVER}/sequence/id" + headers = {"Content-Type": "application/json", "Accept": "application/json"} + payload = {"ids": request_ids} + params = {"type": "cds"} + + last_error = None + for attempt in range(1, max_attempts + 1): + try: + r = requests.post(url, headers=headers, json=payload, params=params, timeout=_ENSEMBL_REST_TIMEOUT) + if r.status_code == 429: + retry_after = r.headers.get("Retry-After") + try: + retry_after = float(retry_after) + except (TypeError, ValueError): + retry_after = wait_seconds + if attempt >= max_attempts: + logger.warning( + "Ensembl CDS batch rate limited (pid=%s, attempt=%s/%s). Giving up after %ss.", + pid, + attempt, + max_attempts, + retry_after, + ) + return results + logger.warning( + "Ensembl CDS batch rate limited (pid=%s, attempt=%s/%s). Retrying after %ss.", + pid, + attempt, + max_attempts, + retry_after, + ) + time.sleep(retry_after) + continue + + r.raise_for_status() + decoded = r.json() + + seq_by_id = {} + errors_by_id = {} + if isinstance(decoded, dict): + decoded = [decoded] + if not isinstance(decoded, list): + raise ValueError(f"Unexpected Ensembl response type: {type(decoded).__name__}") + + for item in decoded: + if not isinstance(item, dict): + continue + + item_id = item.get("id") + item_query = item.get("query") + item_keys = [k for k in (item_query, item_id) if k] + if not item_keys: + continue + + if "seq" in item and item.get("seq") is not None: + seq_val = item.get("seq") + for key in item_keys: + seq_by_id[key] = seq_val + elif "error" in item: + err_val = item.get("error") + for key in item_keys: + errors_by_id[key] = err_val + + missing = 0 + for pos, tid in zip(request_positions, request_ids): + seq_dna = seq_by_id.get(tid) + if seq_dna is None and "." in tid: + seq_dna = seq_by_id.get(tid.split(".", 1)[0]) + if not seq_dna: + missing += 1 + continue + results[pos] = seq_dna[:-3] if len(seq_dna) >= 3 else np.nan + + elapsed = time.perf_counter() - start_time + if missing > 0 and logger.isEnabledFor(logging.DEBUG): + example_missing = [tid for tid in request_ids if tid not in seq_by_id][:5] + logger.debug( + "Ensembl CDS batch completed with missing IDs (pid=%s, elapsed=%.2fs, requested=%s, missing=%s). Example missing: %s", + pid, + elapsed, + len(request_ids), + missing, + example_missing, + ) + if errors_by_id: + example_errors = {k: errors_by_id[k] for k in list(errors_by_id)[:3]} + logger.debug("Ensembl CDS batch errors (pid=%s): %s", pid, example_errors) + else: + logger.debug( + "Ensembl CDS batch completed (pid=%s, elapsed=%.2fs, requested=%s)", + pid, + elapsed, + len(request_ids), + ) + + return results + + except (requests.exceptions.RequestException, ValueError, json.JSONDecodeError) as exc: + last_error = exc + if attempt < max_attempts: + logger.debug( + "Ensembl CDS batch request failed (pid=%s, attempt=%s/%s): %s. Retrying in %ss...", + pid, + attempt, + max_attempts, + exc, + wait_seconds, + ) + time.sleep(wait_seconds) + continue + + elapsed = time.perf_counter() - start_time + logger.warning( + "Ensembl CDS batch failed (pid=%s, elapsed=%.2fs, requested=%s). 
Last error: %s", + pid, + elapsed, + len(request_ids), + last_error, + ) + + return results + + def get_ref_dna_from_ensembl(transcript_id): """ Use Ensembl GET sequence rest API to obtain CDS DNA @@ -687,15 +1147,25 @@ def get_ref_dna_from_ensembl(transcript_id): https://rest.ensembl.org/documentation/info/sequence_id """ - server = "https://rest.ensembl.org" - ext = f"/sequence/id/{transcript_id}?type=cds" + pid = os.getpid() + start_time = time.perf_counter() + failures = 0 + + if pd.isna(transcript_id): + logger.debug("Ensembl CDS start: (pid=%s) -> skipping", pid) + return np.nan + + transcript_id = str(transcript_id) + logger.debug("Ensembl CDS start: %s (pid=%s)", transcript_id, pid) + + url = f"{_ENSEMBL_REST_SERVER}/sequence/id/{transcript_id}?type=cds" status = "INIT" - i = 0 + last_error = None while status != "FINISHED": try: - r = requests.get(server+ext, headers={ "Content-Type" : "text/x-fasta"}, timeout=160) + r = requests.get(url, headers=_ENSEMBL_REST_HEADERS, timeout=_ENSEMBL_REST_TIMEOUT) if not r.ok: r.raise_for_status() @@ -705,29 +1175,57 @@ def get_ref_dna_from_ensembl(transcript_id): status = "FINISHED" except requests.exceptions.RequestException as e: - i += 1 + failures += 1 + last_error = e status = "ERROR" - if i%10 == 0: - logger.debug(f"Failed to retrieve sequence for {transcript_id} {e}: Retrying..") + if failures % 10 == 0: + logger.debug( + "Failed to retrieve sequence for %s (pid=%s, failures=%s) %s: Retrying..", + transcript_id, + pid, + failures, + e, + ) time.sleep(5) - if i == 100: - logger.debug(f"Failed to retrieve sequence for {transcript_id} {e}: Skipping..") + if failures == 100: + elapsed = time.perf_counter() - start_time + logger.warning( + "Ensembl CDS failed: %s (pid=%s, elapsed=%.2fs, failures=%s). Last error: %s", + transcript_id, + pid, + elapsed, + failures, + last_error, + ) return np.nan time.sleep(1) - seq_dna = "".join(r.text.strip().split("\n")[1:]) - - return seq_dna[:len(seq_dna)-3] - - -def get_ref_dna_from_ensembl_wrapper(ensembl_id): - """ - Wrapper for multiple processing function using - Ensembl Get sequence rest API. 
- """ + text = r.text.strip() + if text.startswith(">"): + seq_dna = "".join(text.splitlines()[1:]) + else: + seq_dna = "".join(text.splitlines()) + + elapsed = time.perf_counter() - start_time + if failures > 0: + logger.info( + "Ensembl CDS completed: %s (pid=%s, elapsed=%.2fs, failures=%s)", + transcript_id, + pid, + elapsed, + failures, + ) + else: + logger.debug( + "Ensembl CDS completed: %s (pid=%s, elapsed=%.2fs, failures=%s)", + transcript_id, + pid, + elapsed, + failures, + ) - return get_ref_dna_from_ensembl(ensembl_id) + return seq_dna[:-3] if len(seq_dna) >= 3 else np.nan def get_ref_dna_from_ensembl_mp(seq_df, cores): @@ -738,12 +1236,57 @@ def get_ref_dna_from_ensembl_mp(seq_df, cores): https://rest.ensembl.org/documentation/info/sequence_id """ - pool = multiprocessing.Pool(processes=cores) seq_df = seq_df.copy() - seq_df["Seq_dna"] = pool.map(get_ref_dna_from_ensembl_wrapper, seq_df.Ens_Transcr_ID) - pool.close() - pool.join() - + transcript_ids = seq_df.Ens_Transcr_ID.tolist() + total = len(transcript_ids) + if total == 0: + seq_df["Seq_dna"] = [] + return seq_df + + if cores > _ENSEMBL_CDS_MAX_CORES: + logger.info( + "Capping Ensembl CDS batch workers from %s to %s.", + cores, + _ENSEMBL_CDS_MAX_CORES, + ) + cores = _ENSEMBL_CDS_MAX_CORES + + logger.debug("Retrieving CDS DNA from Ensembl for %s transcripts with %s cores.", total, cores) + + if cores <= 1: + results = [] + batch_size = 50 + with tqdm(total=total, desc="Ensembl CDS") as pbar: + for i in range(0, total, batch_size): + batch_ids = transcript_ids[i : i + batch_size] + batch_results = get_ref_dna_from_ensembl_batch(batch_ids) + results.extend(batch_results) + pbar.update(len(batch_ids)) + seq_df["Seq_dna"] = results + retrieved = int(pd.Series(results).notna().sum()) + logger.debug( + "Completed Ensembl CDS retrieval: %s/%s sequences retrieved.", + retrieved, + total, + ) + return seq_df + + batch_size = 50 + batches = [transcript_ids[i : i + batch_size] for i in range(0, total, batch_size)] + results = [] + with multiprocessing.Pool(processes=cores) as pool: + results_iter = pool.imap(get_ref_dna_from_ensembl_batch, batches) + with tqdm(total=total, desc="Ensembl CDS") as pbar: + for batch_ids, batch_results in zip(batches, results_iter): + results.extend(batch_results) + pbar.update(len(batch_ids)) + seq_df["Seq_dna"] = results + retrieved = int(pd.Series(results).notna().sum()) + logger.debug( + "Completed Ensembl CDS retrieval: %s/%s sequences retrieved.", + retrieved, + total, + ) return seq_df @@ -783,19 +1326,31 @@ def process_seq_df(seq_df, datasets_dir, organism, uniprot_to_gene_dict, - ens_canonical_transcripts_lst, - num_cores=1): + use_archive_biomart=True): """ Retrieve DNA sequence and tri-nucleotide context for each structure in the initialized dataframe prioritizing structures obtained from transcripts whose exon coordinates are available in the Proteins API. + Canonical transcript metadata is retrieved internally from BioMart + for the provided dataset directory and Uniprot IDs. 
+ Reference_info labels: 1 : Transcript ID, exons coord, seq DNA obtained from Proteins API -1 : Not available transcripts, seq DNA retrieved from Backtranseq API """ + if seq_df.empty: + logger.error("No sequences to process in process_seq_df; this should not happen.") + raise RuntimeError("Empty sequence dataframe: no structures to process.") + + ens_canonical_transcripts_lst = get_biomart_metadata( + datasets_dir, + seq_df["Uniprot_ID"].unique(), + use_archive=use_archive_biomart, + ) + # Process entries in Proteins API (Reference_info 1) #--------------------------------------------------- @@ -821,13 +1376,16 @@ def process_seq_df(seq_df, # Process entries not in Proteins API (Reference_info -1) #------------------------------------------------------------ - # Add DNA seq from Backtranseq for any other entry - logger.debug(f"Retrieving CDS DNA seq for entries without available transcript ID (Backtranseq API): {len(seq_df_not_uniprot)} structures..") - seq_df_not_uniprot = batch_backtranseq(seq_df_not_uniprot, 500, organism=organism) + if len(seq_df_not_uniprot) > 0: + # Add DNA seq from Backtranseq for any other entry + logger.debug(f"Retrieving CDS DNA seq for entries without available transcript ID (Backtranseq API): {len(seq_df_not_uniprot)} structures..") + seq_df_not_uniprot = batch_backtranseq(seq_df_not_uniprot, 100, organism=organism) - # Get trinucleotide context - seq_df_not_uniprot["Tri_context"] = seq_df_not_uniprot["Seq_dna"].apply( - lambda x: ",".join(per_site_trinucleotide_context(x, no_flanks=True))) + # Get trinucleotide context + seq_df_not_uniprot["Tri_context"] = np.nan + valid_seq_mask = seq_df_not_uniprot["Seq_dna"].notna() + seq_df_not_uniprot.loc[valid_seq_mask, "Tri_context"] = seq_df_not_uniprot.loc[valid_seq_mask, "Seq_dna"].apply( + lambda x: ",".join(per_site_trinucleotide_context(x, no_flanks=True))) # Prepare final output @@ -839,7 +1397,14 @@ def process_seq_df(seq_df, seq_df.Reference_info.value_counts().values)]) logger.info(f"Built of sequence dataframe completed. 
Retrieved {len(seq_df)} structures ({logger_report})") seq_df = add_extra_genes_to_seq_df(seq_df, uniprot_to_gene_dict) + pre_drop = len(seq_df) seq_df = drop_gene_duplicates(seq_df) + logger.debug( + "Duplicate gene removal: %s removed (from %s to %s).", + pre_drop - len(seq_df), + pre_drop, + len(seq_df), + ) return seq_df @@ -849,10 +1414,9 @@ def process_seq_df_mane(seq_df, uniprot_to_gene_dict, mane_mapping, mane_mapping_not_af, - ens_canonical_transcripts_lst, - custom_mane_metadata_path=None, + mane_only=False, num_cores=1, - mane_version=1.4): + use_archive_biomart=True): """ Retrieve DNA sequence and tri-nucleotide context for each structure in the initialized dataframe @@ -872,51 +1436,136 @@ def process_seq_df_mane(seq_df, seq_df_mane = seq_df_mane.drop(columns=["Gene"]).merge(mane_mapping, how="left", on="Uniprot_ID") seq_df_mane["Reference_info"] = 0 - # Add DNA seq from Ensembl for structures with available transcript ID - logger.debug(f"Retrieving CDS DNA seq from transcript ID (Ensembl API): {len(seq_df_mane)} structures..") - seq_df_mane = get_ref_dna_from_ensembl_mp(seq_df_mane, cores=num_cores) - - # Set failed and len-mismatching entries as no-transcripts entries - failed_ix = seq_df_mane.apply(lambda x: True if pd.isna(x.Seq_dna) else len(x.Seq_dna) / 3 != len(x.Seq), axis=1) - if sum(failed_ix) > 0: - seq_df_mane_failed = seq_df_mane[failed_ix] - seq_df_mane = seq_df_mane[~failed_ix] - seq_df_mane_failed = seq_df_mane_failed.drop(columns=[ - "Ens_Gene_ID", - "Ens_Transcr_ID", - "Reverse_strand", - "Chr", - "Refseq_prot", - "Reference_info", - "Seq_dna" - ]) - seq_df_nomane = pd.concat((seq_df_nomane, seq_df_mane_failed)) + if seq_df_mane.empty: + logger.warning("No MANE sequences to process; skipping Ensembl CDS retrieval.") + if "Seq_dna" not in seq_df_mane.columns: + seq_df_mane["Seq_dna"] = pd.Series(dtype=object) + else: + # Add DNA seq from Ensembl for structures with available transcript ID + logger.debug(f"Retrieving CDS DNA seq from transcript ID (Ensembl API): {len(seq_df_mane)} structures..") + seq_df_mane = get_ref_dna_from_ensembl_mp(seq_df_mane, cores=num_cores) + + # Retry missing entries using single-request API (bounded parallelism) + missing_mask = seq_df_mane["Seq_dna"].isna() + if missing_mask.any(): + missing_ids = seq_df_mane.loc[missing_mask, "Ens_Transcr_ID"].tolist() + logger.debug( + "Retrying %s missing Ensembl CDS entries with %s workers.", + len(missing_ids), + min(num_cores, _ENSEMBL_CDS_MAX_CORES), + ) + retry_workers = min(num_cores, _ENSEMBL_CDS_MAX_CORES) + if retry_workers <= 1: + retry_results = [get_ref_dna_from_ensembl(tid) for tid in missing_ids] + else: + with multiprocessing.Pool(processes=retry_workers) as pool: + retry_results = pool.map(get_ref_dna_from_ensembl, missing_ids) + recovered = sum(pd.notna(val) for val in retry_results) + seq_df_mane.loc[missing_mask, "Seq_dna"] = retry_results + if recovered > 0: + logger.debug( + "Recovered %s missing Ensembl CDS entries after single-request retry.", + recovered, + ) + + # Set failed and len-mismatching entries as no-transcripts entries + seq_len = seq_df_mane["Seq"].str.len() + dna_len = seq_df_mane["Seq_dna"].str.len() + failed_nan = seq_df_mane["Seq_dna"].isna() + failed_mismatch = (~failed_nan) & (dna_len / 3 != seq_len) + failed_ix = failed_nan | failed_mismatch + logger.debug( + "Ensembl CDS failures: total=%s (missing=%s, length_mismatch=%s).", + int(failed_ix.sum()), + int(failed_nan.sum()), + int(failed_mismatch.sum()), + ) + if sum(failed_ix) > 0: + seq_df_mane_failed 
= seq_df_mane[failed_ix] + seq_df_mane = seq_df_mane[~failed_ix] + seq_df_mane_failed = seq_df_mane_failed.drop(columns=[ + "Ens_Gene_ID", + "Ens_Transcr_ID", + "Reverse_strand", + "Chr", + "Refseq_prot", + "Reference_info", + "Seq_dna" + ]) + seq_df_nomane = pd.concat((seq_df_nomane, seq_df_mane_failed)) + logger.debug( + "Moved %s failed MANE entries to non-MANE pool.", + len(seq_df_mane_failed), + ) # Seq df not MANE # --------------- - seq_df_nomane = add_extra_genes_to_seq_df(seq_df_nomane, uniprot_to_gene_dict) # Filter out genes with NA - seq_df_nomane = seq_df_nomane[seq_df_nomane.Gene.isin(mane_mapping_not_af.Gene)] # Filter out genes that are not in MANE list - - # Retrieve seq from coordinates - logger.debug(f"Retrieving CDS DNA seq from reference genome (Proteins API): {len(seq_df_nomane['Uniprot_ID'].unique())} structures..") - coord_df = get_exons_coord(seq_df_nomane["Uniprot_ID"].unique(), ens_canonical_transcripts_lst) - seq_df_nomane = seq_df_nomane.merge(coord_df, on=["Seq", "Uniprot_ID"], how="left").reset_index(drop=True) # Discard entries whose Seq obtained by Proteins API don't exactly match the one in structure - seq_df_nomane = add_ref_dna_and_context(seq_df_nomane, hg38) - seq_df_nomane_tr = seq_df_nomane[seq_df_nomane["Reference_info"] == 1] - seq_df_nomane_notr = seq_df_nomane[seq_df_nomane["Reference_info"] == -1] + if seq_df_nomane.empty: + logger.debug("No non-MANE sequences to process; skipping Proteins/Backtranseq retrieval.") + seq_df_nomane_tr = seq_df_nomane.copy() + seq_df_nomane_notr = seq_df_nomane.copy() + else: + before_nomane = len(seq_df_nomane) + seq_df_nomane = add_extra_genes_to_seq_df(seq_df_nomane, uniprot_to_gene_dict) # Filter out genes with NA + after_extra = len(seq_df_nomane) + if not mane_only: + logger.debug("Filtering non-MANE entries using mane_mapping_not_af (gene whitelist).") + seq_df_nomane = seq_df_nomane[seq_df_nomane.Gene.isin(mane_mapping_not_af.Gene)] # Filter out genes that are not in MANE list + logger.debug( + "Non-MANE pool sizes: initial=%s, after_extra_genes=%s, after_mane_filter=%s.", + before_nomane, + after_extra, + len(seq_df_nomane), + ) + else: + logger.debug( + "Non-MANE pool sizes: initial=%s, after_extra_genes=%s (mane_only, no MANE filter applied).", + before_nomane, + after_extra, + ) - # Add DNA seq from Backtranseq for any other entry - logger.debug(f"Retrieving CDS DNA seq for genes without available transcript ID (Backtranseq API): {len(seq_df_nomane_notr)} structures..") - seq_df_nomane_notr = batch_backtranseq(seq_df_nomane_notr, 500, organism="Homo sapiens") + if seq_df_nomane.empty: + logger.debug("No non-MANE sequences after filtering; skipping Proteins/Backtranseq retrieval.") + seq_df_nomane_tr = seq_df_nomane.copy() + seq_df_nomane_notr = seq_df_nomane.copy() + else: + ens_canonical_transcripts_lst = get_biomart_metadata( + datasets_dir, + seq_df_nomane["Uniprot_ID"].unique(), + use_archive=use_archive_biomart, + ) + # Retrieve seq from coordinates + logger.debug(f"Retrieving CDS DNA seq from reference genome (Proteins API): {len(seq_df_nomane['Uniprot_ID'].unique())} structures..") + coord_df = get_exons_coord(seq_df_nomane["Uniprot_ID"].unique(), ens_canonical_transcripts_lst) + seq_df_nomane = seq_df_nomane.merge(coord_df, on=["Seq", "Uniprot_ID"], how="left").reset_index(drop=True) # Discard entries whose Seq obtained by Proteins API don't exactly match the one in structure + seq_df_nomane = add_ref_dna_and_context(seq_df_nomane, hg38) + seq_df_nomane_tr = 
seq_df_nomane[seq_df_nomane["Reference_info"] == 1] + seq_df_nomane_notr = seq_df_nomane[seq_df_nomane["Reference_info"] == -1] + + # Add DNA seq from Backtranseq for any other entry + if len(seq_df_nomane_notr) > 0: + logger.debug(f"Retrieving CDS DNA seq for genes without available transcript ID (Backtranseq API): {len(seq_df_nomane_notr)} structures..") + seq_df_nomane_notr = batch_backtranseq(seq_df_nomane_notr, 100, organism="Homo sapiens") # Get trinucleotide context seq_df_not_uniprot = pd.concat((seq_df_mane, seq_df_nomane_notr)) - seq_df_not_uniprot["Tri_context"] = seq_df_not_uniprot["Seq_dna"].apply( + if "Seq_dna" not in seq_df_not_uniprot.columns: + seq_df_not_uniprot["Seq_dna"] = pd.Series(dtype=object) + seq_df_not_uniprot["Tri_context"] = np.nan + valid_seq_mask = seq_df_not_uniprot["Seq_dna"].notna() + seq_df_not_uniprot.loc[valid_seq_mask, "Tri_context"] = seq_df_not_uniprot.loc[valid_seq_mask, "Seq_dna"].apply( lambda x: ",".join(per_site_trinucleotide_context(x, no_flanks=True))) # Prepare final output seq_df = pd.concat((seq_df_not_uniprot, seq_df_nomane_tr)).reset_index(drop=True) + pre_drop = len(seq_df) seq_df = drop_gene_duplicates(seq_df) + logger.debug( + "Duplicate gene removal: %s removed (from %s to %s).", + pre_drop - len(seq_df), + pre_drop, + len(seq_df), + ) report_df = seq_df.Reference_info.value_counts().reset_index() report_df = report_df.rename(columns={"index" : "Source"}) report_df.Source = report_df.Source.map({1 : "Proteins API", @@ -933,9 +1582,11 @@ def get_seq_df(datasets_dir, output_seq_df, organism = "Homo sapiens", mane=False, + mane_only=False, custom_mane_metadata_path=None, num_cores=1, - mane_version=1.4): + mane_version=1.4, + af_version=None): """ Generate a dataframe including IDs mapping information, the protein sequence, the DNA sequence and its tri-nucleotide context, which is @@ -968,6 +1619,19 @@ def get_seq_df(datasets_dir, ) uniprot_to_gene_dict = dict(zip(mane_mapping["Uniprot_ID"], mane_mapping["Gene"])) + if custom_mane_metadata_path is not None: + custom_symbol_map = load_custom_symbol_map(custom_mane_metadata_path) + if custom_symbol_map: + filled = 0 + for ens_id, symbol in custom_symbol_map.items(): + if ens_id in uniprot_ids and (ens_id not in uniprot_to_gene_dict or pd.isna(uniprot_to_gene_dict[ens_id])): + uniprot_to_gene_dict[ens_id] = symbol + filled += 1 + if filled > 0: + logger.debug( + "Filled %s gene symbols from custom MANE samplesheet for ENSP-only entries.", + filled, + ) missing_uni_ids = list(set(uniprot_ids) - set(mane_mapping.Uniprot_ID)) uniprot_to_gene_dict = uniprot_to_gene_dict | uniprot_to_hugo(missing_uni_ids) else: @@ -981,30 +1645,37 @@ def get_seq_df(datasets_dir, # uniprot_to_gene_dict = uniprot_to_hugo_pressed(uniprot_ids) # --- - # Get biomart metadata and canonical transcript IDs - ens_canonical_transcripts_lst = get_biomart_metadata(datasets_dir, uniprot_ids) + use_archive_biomart = True + if af_version is not None and str(af_version) != "4": + use_archive_biomart = False + logger.debug( + "Using latest BioMart URL only because af_version=%s (archive disabled).", + af_version, + ) # Create a dataframe with protein sequences logger.debug("Initializing sequence df..") seq_df = initialize_seq_df(pdb_dir, uniprot_to_gene_dict) if mane: - seq_df = process_seq_df_mane(seq_df, - datasets_dir, - uniprot_to_gene_dict, - mane_mapping, - mane_mapping_not_af, - ens_canonical_transcripts_lst, - custom_mane_metadata_path, - num_cores, - mane_version=mane_version) + seq_df = process_seq_df_mane( + seq_df, 
+ datasets_dir, + uniprot_to_gene_dict, + mane_mapping, + mane_mapping_not_af, + mane_only, + num_cores, + use_archive_biomart=use_archive_biomart, + ) else: - seq_df = process_seq_df(seq_df, - datasets_dir, - organism, - uniprot_to_gene_dict, - ens_canonical_transcripts_lst, - num_cores) + seq_df = process_seq_df( + seq_df, + datasets_dir, + organism, + uniprot_to_gene_dict, + use_archive_biomart=use_archive_biomart, + ) # Save seq_df_cols = ['Gene', 'HGNC_ID', 'Ens_Gene_ID', @@ -1019,13 +1690,6 @@ def get_seq_df(datasets_dir, if __name__ == "__main__": - output_datasets = '/data/bbg/nobackup/scratch/oncodrive3d/tests/datasets_mane_240725_mane_missing_dev' - get_seq_df( - datasets_dir=output_datasets, - output_seq_df=os.path.join(output_datasets, "seq_for_mut_prob.tsv"), - organism='Homo sapiens', - mane=True, - num_cores=8, - mane_version=1.4, - custom_mane_metadata_path="/data/bbg/nobackup/scratch/oncodrive3d/mane_missing/data/250724-no_fragments/af_predictions/previously_pred/samplesheet.csv" - ) \ No newline at end of file + raise SystemExit( + "This module is intended to be used via the CLI: `oncodrive3d build-datasets`." + ) diff --git a/scripts/datasets/utils.py b/scripts/datasets/utils.py index 70612b2..58abaef 100644 --- a/scripts/datasets/utils.py +++ b/scripts/datasets/utils.py @@ -115,7 +115,7 @@ def calculate_hash(filepath: str, hash_func=hashlib.sha256) -> str: return hash_obj.hexdigest() -def download_single_file(url: str, destination: str, threads: int, proteome=None) -> None: +def download_single_file(url: str, destination: str, threads: int, proteome=None, ssl=False) -> None: """ Downloads a file from a URL and saves it to the specified destination. @@ -124,7 +124,7 @@ def download_single_file(url: str, destination: str, threads: int, proteome=None destination (str): The local path where the file will be saved. 
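+        threads (int): Number of threads available for the download; at most
+            min(threads, 8) parallel connections are opened.
+        ssl (bool): Forwarded to the underlying Downloader as its `ssl` option.
+            Defaults to False, matching the previously hard-coded value.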
""" - num_connections = 15 if threads > 40 else threads + num_connections = min(threads, 8) if os.path.exists(destination): logger.debug(f"File {destination} already exists..") @@ -138,7 +138,7 @@ def download_single_file(url: str, destination: str, threads: int, proteome=None logger.debug(f'Downloading {url}') logger.debug(f"Downloading to {destination}") - dl = Downloader(timeout=aiohttp.ClientTimeout(sock_read=400), ssl=False) + dl = Downloader(timeout=aiohttp.ClientTimeout(sock_read=400), ssl=ssl) dl.start(url, destination, segments=num_connections, display=True, retries=10, clear_terminal=False) logger.debug('Download complete') @@ -392,4 +392,4 @@ def uniprot_to_hugo(uniprot_ids, hugo_as_keys=False, batch_size=5000): # if hugo_as_keys: # result_dict = convert_dict_hugo_to_uniprot(result_dict) -# return result_dict \ No newline at end of file +# return result_dict diff --git a/scripts/main.py b/scripts/main.py index 8fd1af3..42ed143 100755 --- a/scripts/main.py +++ b/scripts/main.py @@ -43,7 +43,9 @@ def oncodrive3D(): @click.option("-M", "--mane_only", help="Use only structures predicted from MANE Select transcripts", is_flag=True) @click.option("-C", "--custom_mane_pdb_dir", - help="Directory where to load custom MANE PDB structures (overwriting existing ones)") + help="Directory where to load custom MANE PDB structures (overwriting existing ones; requires --mane_only)") +@click.option("--custom_pae_dir", + help="Directory containing pre-downloaded PAE JSON files to copy into the build (renamed to 'pae')") @click.option("-f", "--custom_mane_metadata_path", help="Path to a dataframe including the Ensembl Protein ID and the amino acid sequence of the custom MANE PDB structures") @click.option("-j", "--mane_version", default=1.4, @@ -52,8 +54,8 @@ def oncodrive3D(): help="Distance threshold (Å) to define contact between amino acids") @click.option("-c", "--cores", type=click.IntRange(min=1, max=len(os.sched_getaffinity(0)), clamp=False), default=len(os.sched_getaffinity(0)), help="Number of cores to use in the computation") -@click.option("--af_version", type=click.IntRange(min=1, clamp=False), default=4, - help="Version of AlphaFold 2 predictions") +@click.option("--af_version", type=click.IntRange(min=1, clamp=False), default=6, + help="AlphaFold DB version for non-MANE downloads (MANE uses v4)") @click.option("-y", "--yes", help="No interaction", is_flag=True) @click.option("-v", "--verbose", @@ -64,6 +66,7 @@ def build_datasets(output_dir, mane, mane_only, custom_mane_pdb_dir, + custom_pae_dir, custom_mane_metadata_path, distance_threshold, cores, @@ -85,6 +88,7 @@ def build_datasets(output_dir, logger.info(f"MANE Select: {mane}") logger.info(f"MANE Select only: {mane_only}") logger.info(f"Custom MANE PDB directory: {custom_mane_pdb_dir}") + logger.info(f"Custom PAE directory: {custom_pae_dir}") logger.info(f"Custom MANE PDB metadata path: {custom_mane_metadata_path}") logger.info(f"Distance threshold: {distance_threshold}Å") logger.info(f"CPU cores: {cores}") @@ -100,6 +104,7 @@ def build_datasets(output_dir, mane, mane_only, custom_mane_pdb_dir, + custom_pae_dir, custom_mane_metadata_path, distance_threshold, cores, diff --git a/tools/preprocessing/README.md b/tools/preprocessing/README.md index 6e35e46..79706aa 100644 --- a/tools/preprocessing/README.md +++ b/tools/preprocessing/README.md @@ -99,7 +99,7 @@ Arguments: - `--max-workers`: Parallel workers for canonical indexing (default = all cores). 
- `--filter-long-sequences/--no-filter-long-sequences`: Whether to drop long proteins from the nf-core input (default enabled).
- `--max-sequence-length`: Length cutoff applied when filtering (default `2700` residues).
-- `--include-metadata/--no-include-metadata`: Add `symbol`, `CGC`, and `length` columns to every emitted `samplesheet.csv` (default disabled).
+- `--include-metadata/--no-include-metadata`: Add `CGC` and `length` columns to every emitted `samplesheet.csv` (`symbol` is always included; default disabled).
- `--config-path`: YAML with path templates describing where to place predicted/missing/retrieved bundles relative to `--samplesheet-folder` (default `config.yaml`).

> [!NOTE]
diff --git a/tools/preprocessing/prepare_samplesheet.py b/tools/preprocessing/prepare_samplesheet.py
index cb3401e..a05db2e 100644
--- a/tools/preprocessing/prepare_samplesheet.py
+++ b/tools/preprocessing/prepare_samplesheet.py
@@ -13,11 +13,16 @@
 import os
+import sys
+from pathlib import Path
 import click
 import gzip
 import time
 import pandas as pd
-from pathlib import Path
+REPO_ROOT = Path(__file__).resolve().parents[2]
+if str(REPO_ROOT) not in sys.path:
+    sys.path.insert(0, str(REPO_ROOT))
+
 from scripts.datasets.utils import download_single_file

 # import logging
diff --git a/tools/preprocessing/update_samplesheet_and_structures.py b/tools/preprocessing/update_samplesheet_and_structures.py
index 06b6b3f..0edc9e1 100644
--- a/tools/preprocessing/update_samplesheet_and_structures.py
+++ b/tools/preprocessing/update_samplesheet_and_structures.py
@@ -381,6 +381,38 @@ def prepare_annotation_maps(
     return seq_map[["symbol", "CGC"]], refseq_map[["symbol", "CGC"]], cgc_symbols


+def prepare_symbol_maps(mane_summary_path: Path) -> tuple[pd.DataFrame, pd.DataFrame]:
+    """Build lookup tables that map ENSP/RefSeq identifiers to symbols."""
+    mane_summary = pd.read_table(mane_summary_path)
+
+    column_aliases = {
+        "ensembl_prot": {"Ensembl_prot", "ensembl_prot", "Ens_Prot_ID"},
+        "refseq_prot": {"RefSeq_prot", "refseq_prot"},
+        "symbol": {"symbol", "Gene Symbol", "gene_symbol"},
+    }
+
+    rename_map = {}
+    for target, candidates in column_aliases.items():
+        for candidate in candidates:
+            if candidate in mane_summary.columns:
+                rename_map[candidate] = target
+                break
+    required = {"ensembl_prot", "refseq_prot", "symbol"}
+    if not required.issubset(rename_map.values()):
+        missing = required - set(rename_map.values())
+        raise KeyError(f"Missing columns in MANE summary: {missing}")
+
+    mane_summary = mane_summary.rename(columns=rename_map)
+    annotations = mane_summary[["ensembl_prot", "refseq_prot", "symbol"]].copy()
+    annotations["ensembl_prot"] = strip_version_suffix(annotations["ensembl_prot"])
+    annotations["refseq_prot"] = strip_version_suffix(annotations["refseq_prot"])
+    annotations = annotations.drop_duplicates()
+
+    seq_map = annotations.dropna(subset=["ensembl_prot"]).set_index("ensembl_prot")
+    refseq_map = annotations.dropna(subset=["refseq_prot"]).set_index("refseq_prot")
+    return seq_map[["symbol"]], refseq_map[["symbol"]]
+
+
 def attach_symbol_and_cgc(
     df: pd.DataFrame,
     seq_map: pd.DataFrame,
@@ -389,12 +421,13 @@
     """Annotate `df` with symbol/CGC columns using the provided lookup tables."""
     annotated = df.copy()
     seq_keys = strip_version_suffix(annotated["sequence"])
+    if "symbol" not in annotated.columns:
+        annotated["symbol"] = pd.NA
+    if "CGC" not in annotated.columns:
+        annotated["CGC"] = pd.NA
     if not seq_map.empty:
-        annotated["symbol"] = seq_keys.map(seq_map["symbol"])
-        annotated["CGC"] = seq_keys.map(seq_map["CGC"])
- else: - annotated["symbol"] = pd.Series("", index=annotated.index) - annotated["CGC"] = pd.Series(0, index=annotated.index, dtype="Int64") + annotated["symbol"] = annotated["symbol"].fillna(seq_keys.map(seq_map["symbol"])) + annotated["CGC"] = annotated["CGC"].fillna(seq_keys.map(seq_map["CGC"])) if "refseq_prot" in annotated.columns and not refseq_map.empty: refseq_keys = strip_version_suffix(annotated["refseq_prot"]) @@ -406,17 +439,12 @@ def attach_symbol_and_cgc( return annotated -def build_metadata_map( - samplesheet: pd.DataFrame, - fasta_dir: Path, - mane_summary_path: Path, - cgc_path: Optional[Path], -) -> pd.DataFrame: - """Return a dataframe with one row per sequence containing symbol/CGC/length metadata.""" +def build_symbol_map(samplesheet: pd.DataFrame, mane_summary_path: Path) -> pd.DataFrame: + """Return a dataframe with one row per sequence containing symbol metadata.""" if samplesheet.empty: - return pd.DataFrame(columns=["sequence", "symbol", "CGC", "length"]) + return pd.DataFrame(columns=["sequence", "symbol"]) - seq_map, refseq_map, _ = prepare_annotation_maps(mane_summary_path, cgc_path) + seq_map, refseq_map = prepare_symbol_maps(mane_summary_path) metadata = samplesheet[["sequence"]].drop_duplicates().copy() if "refseq_prot" in samplesheet.columns: @@ -427,11 +455,44 @@ def build_metadata_map( ) metadata["refseq_prot"] = metadata["sequence"].map(refseq_lookup) - metadata = attach_symbol_and_cgc(metadata, seq_map, refseq_map) + seq_keys = strip_version_suffix(metadata["sequence"]) + if not seq_map.empty: + metadata["symbol"] = seq_keys.map(seq_map["symbol"]) + else: + metadata["symbol"] = pd.Series(pd.NA, index=metadata.index) + + if "refseq_prot" in metadata.columns and not refseq_map.empty: + refseq_keys = strip_version_suffix(metadata["refseq_prot"]) + metadata["symbol"] = metadata["symbol"].fillna(refseq_keys.map(refseq_map["symbol"])) + + metadata["symbol"] = metadata["symbol"].fillna("") + return metadata.drop(columns=["refseq_prot"], errors="ignore") + + +def build_metadata_map( + samplesheet: pd.DataFrame, + fasta_dir: Path, + mane_summary_path: Path, + cgc_path: Optional[Path], + symbol_map: Optional[pd.DataFrame] = None, +) -> pd.DataFrame: + """Return a dataframe with one row per sequence containing symbol/CGC/length metadata.""" + if samplesheet.empty: + return pd.DataFrame(columns=["sequence", "symbol", "CGC", "length"]) + + metadata = samplesheet[["sequence"]].drop_duplicates().copy() + if symbol_map is None: + symbol_map = build_symbol_map(samplesheet, mane_summary_path) + metadata = attach_metadata(metadata, symbol_map) + metadata["symbol"] = metadata.get("symbol", pd.Series("", index=metadata.index)).fillna("") + cgc_symbols = load_cgc_symbols(cgc_path) + if cgc_symbols: + metadata["CGC"] = metadata["symbol"].isin(cgc_symbols).astype(int) + else: + metadata["CGC"] = 0 fasta_paths = metadata["sequence"].map(lambda seq: (fasta_dir / f"{seq}.fasta").as_posix()) metadata["length"] = compute_fasta_lengths(pd.Series(fasta_paths.values, index=metadata.index)) - metadata = metadata.drop(columns=["refseq_prot"], errors="ignore") return metadata @@ -439,12 +500,23 @@ def attach_metadata(df: pd.DataFrame, metadata_map: Optional[pd.DataFrame]) -> p """Merge symbol/CGC/length metadata into df when available.""" if metadata_map is None or df.empty: return df - columns = ["sequence", "symbol", "CGC", "length"] - metadata = metadata_map[columns].drop_duplicates(subset=["sequence"]) - return ( - df.drop(columns=["symbol", "CGC", "length"], errors="ignore") - 
.merge(metadata, on="sequence", how="left") - ) + available = [col for col in ["symbol", "CGC", "length"] if col in metadata_map.columns] + if not available: + return df + metadata = metadata_map[["sequence"] + available].drop_duplicates(subset=["sequence"]) + merged = df.merge(metadata, on="sequence", how="left", suffixes=("", "_meta")) + for col in available: + meta_col = f"{col}_meta" + if meta_col in merged.columns: + if col in merged.columns: + merged[col] = merged[meta_col].combine_first(merged[col]) + else: + merged[col] = merged[meta_col] + merged = merged.drop(columns=[meta_col]) + for col in ("CGC", "length"): + if col not in available and col in merged.columns: + merged = merged.drop(columns=[col]) + return merged def attach_refseq(df: pd.DataFrame, master_samplesheet: Optional[pd.DataFrame]) -> pd.DataFrame: @@ -555,7 +627,7 @@ def filter_long_sequences( if include_metadata: removed_clean = removed.copy() else: - removed_clean = removed.drop(columns=["symbol", "CGC", "length"], errors="ignore") + removed_clean = removed.drop(columns=["CGC", "length"], errors="ignore") removed_path = missing_dir / "samplesheet_removed_long.csv" removed_clean.to_csv(removed_path, index=False) @@ -690,6 +762,7 @@ def run_pipeline( ) -> None: """Execute the MANE maintenance workflow end-to-end.""" samplesheet = pd.read_csv(paths.samplesheet_path) + symbol_map = build_symbol_map(samplesheet, paths.mane_summary_path) metadata_map = None if settings.include_metadata: metadata_map = build_metadata_map( @@ -697,16 +770,22 @@ def run_pipeline( paths.fasta_dir, paths.mane_summary_path, paths.cgc_list_path, + symbol_map=symbol_map, ) - samplesheet = attach_metadata(samplesheet, metadata_map) - samplesheet.to_csv(paths.samplesheet_path, index=False) + + metadata_for_outputs = metadata_map if metadata_map is not None else symbol_map + samplesheet = attach_metadata(samplesheet, metadata_for_outputs) + samplesheet.to_csv(paths.samplesheet_path, index=False) + if settings.include_metadata: print(f"Annotated master samplesheet with metadata at {paths.samplesheet_path}") + else: + print(f"Annotated master samplesheet with symbols at {paths.samplesheet_path}") print(f"Loaded {len(samplesheet):,} samples from {paths.samplesheet_path}") if predicted_raw_dir: if not predicted_raw_dir.exists(): raise FileNotFoundError(f"--predicted-dir not found: {predicted_raw_dir}") - sync_predicted_bundle(predicted_raw_dir, paths.predicted_bundle_dir, metadata_map) + sync_predicted_bundle(predicted_raw_dir, paths.predicted_bundle_dir, metadata_for_outputs) else: print("No --predicted-dir supplied; skipping nf-core sync and using existing predicted bundle.") @@ -727,7 +806,7 @@ def run_pipeline( samplesheet_missing = samplesheet_missing.iloc[0:0].copy() if settings.enable_canonical_reuse: - retrieved_df = reuse_canonical_structures(samplesheet_missing, paths, settings, metadata_map) + retrieved_df = reuse_canonical_structures(samplesheet_missing, paths, settings, metadata_for_outputs) if retrieved_df.empty: print("Skipping canonical reuse because no PDBs were harvested.") else: @@ -749,11 +828,8 @@ def run_pipeline( include_metadata=settings.include_metadata, ) - if settings.include_metadata: - annotated_missing_df.to_csv(paths.missing_samplesheet_path, index=False) - else: - clean_missing_df = annotated_missing_df.drop(columns=["symbol", "CGC", "length"], errors="ignore") - clean_missing_df.to_csv(paths.missing_samplesheet_path, index=False) + clean_missing_df = annotated_missing_df.drop(columns=["symbol", "CGC", "length"], 
errors="ignore") + clean_missing_df.to_csv(paths.missing_samplesheet_path, index=False) print(f"Updated missing samplesheet saved to {paths.missing_samplesheet_path}") if settings.filter_long_sequences and not removed_long_df.empty: @@ -772,7 +848,7 @@ def run_pipeline( merge_structure_bundles( bundles_to_merge, paths.final_bundle_dir, - metadata_map, + metadata_for_outputs, master_samplesheet=samplesheet, ) print(f"Final bundle written to {paths.final_bundle_dir}") @@ -834,7 +910,7 @@ def run_pipeline( "--include-metadata/--no-include-metadata", default=False, show_default=True, - help="Attach symbol/CGC/length columns to every emitted samplesheet.", + help="Attach CGC/length columns (symbol is always included) to every emitted samplesheet.", ) def cli( samplesheet_folder: str,