Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(self, *args, **kwargs):
from .bmd_hs import BMDHSDataset
from .support2 import Support2Dataset
from .tcga_prad import TCGAPRADDataset
from .tcga_paad import TCGAPAADDataset
from .splitter import (
split_by_patient,
split_by_patient_conformal,
Expand Down
22 changes: 22 additions & 0 deletions pyhealth/datasets/configs/tcga_paad.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
version: "1.0"
tables:
mutations:
file_path: "tcga_paad_mutations-pyhealth.csv"
patient_id: "patient_id"
timestamp: null
attributes:
- "hugo_symbol"
- "variant_classification"
- "variant_type"
- "hgvsc"
- "hgvsp"
- "tumor_sample_barcode"
clinical:
file_path: "tcga_paad_clinical-pyhealth.csv"
patient_id: "patient_id"
timestamp: null
attributes:
- "age_at_diagnosis"
- "vital_status"
- "days_to_death"
- "tumor_stage"
271 changes: 271 additions & 0 deletions pyhealth/datasets/tcga_paad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
"""TCGA-PAAD dataset for PyHealth.

This module provides the TCGAPAADDataset class for loading and processing
TCGA Pancreatic Adenocarcinoma (PAAD) data for machine learning tasks.
"""

import logging
import os
from pathlib import Path
from typing import List, Optional

import pandas as pd

from .base_dataset import BaseDataset

logger = logging.getLogger(__name__)


class TCGAPAADDataset(BaseDataset):
    """TCGA Pancreatic Adenocarcinoma (PAAD) dataset.

    The Cancer Genome Atlas (TCGA) PAAD dataset contains multi-omics data
    for pancreatic adenocarcinoma patients, including somatic mutations,
    clinical data, and survival outcomes. This dataset enables cancer
    survival prediction and mutation analysis tasks.

    Dataset is available at:
    https://portal.gdc.cancer.gov/projects/TCGA-PAAD

    On construction, the raw mutation / clinical files found in ``root`` are
    converted to two standardized CSVs (``tcga_paad_mutations-pyhealth.csv``
    and ``tcga_paad_clinical-pyhealth.csv``) unless those already exist.

    Args:
        root: Root directory of the raw data containing the TCGA-PAAD files.
        tables: Optional list of additional tables to load beyond defaults.
        dataset_name: Optional name of the dataset. Defaults to "tcga_paad".
        config_path: Optional path to the configuration file. If not provided,
            uses the default config in the configs directory.

    Attributes:
        root: Root directory of the raw data.
        dataset_name: Name of the dataset.
        config_path: Path to the configuration file.

    Examples:
        >>> from pyhealth.datasets import TCGAPAADDataset
        >>> dataset = TCGAPAADDataset(root="/path/to/tcga_paad")
        >>> dataset.stats()
        >>> samples = dataset.set_task()
        >>> print(samples[0])
    """

    # File names of the standardized CSVs written into ``root``.
    _MUTATIONS_CSV = "tcga_paad_mutations-pyhealth.csv"
    _CLINICAL_CSV = "tcga_paad_clinical-pyhealth.csv"

    # Raw file names probed in ``root``, in priority order.
    _RAW_MUTATION_FILES = (
        "PAAD_mutations.csv",
        "TCGA.PAAD.mutect.maf",
        "TCGA.PAAD.mutect.maf.gz",
        "PAAD.maf",
        "PAAD.maf.gz",
        "mutations.maf",
    )
    _RAW_CLINICAL_FILES = (
        "PAAD_clinical.csv",
        "clinical.tsv",
        "clinical.csv",
        "nationwidechildrens.org_clinical_patient_paad.txt",
    )

    # (standardized column, raw column names accepted for it in priority
    # order). Only the FIRST raw name present in a file is used, so a file
    # that carries several aliases (e.g. both HGVSp_Short and HGVSp) does not
    # end up with duplicated output columns.
    _MUTATION_ALIASES = (
        ("hugo_symbol", ("hugo_symbol", "Hugo_Symbol")),
        (
            "variant_classification",
            ("variant_classification", "Variant_Classification"),
        ),
        ("variant_type", ("variant_type", "Variant_Type")),
        ("hgvsc", ("hgvsc", "HGVSc")),
        ("hgvsp", ("hgvsp", "HGVSp_Short", "HGVSp")),
        (
            "tumor_sample_barcode",
            ("tumor_sample_barcode", "Tumor_Sample_Barcode"),
        ),
    )
    _CLINICAL_ALIASES = (
        (
            "patient_id",
            ("patient_id", "submitter_id", "bcr_patient_barcode", "case_id"),
        ),
        (
            "age_at_diagnosis",
            ("age_at_diagnosis", "age_at_initial_pathologic_diagnosis"),
        ),
        ("vital_status", ("vital_status",)),
        ("days_to_death", ("days_to_death",)),
        (
            "tumor_stage",
            ("tumor_stage", "ajcc_pathologic_stage", "pathologic_stage"),
        ),
    )

    def __init__(
        self,
        root: str,
        tables: Optional[List[str]] = None,
        dataset_name: Optional[str] = None,
        config_path: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Prepare the standardized CSVs if needed, then load the dataset.

        Args:
            root: Directory holding the raw and/or standardized files.
            tables: Extra table names appended to the default
                ``["mutations", "clinical"]``.
            dataset_name: Dataset name; "tcga_paad" when omitted.
            config_path: YAML config path; the bundled ``tcga_paad.yaml``
                when omitted.
            **kwargs: Forwarded unchanged to ``BaseDataset.__init__``.
        """
        if config_path is None:
            logger.info("No config path provided, using default config")
            config_path = Path(__file__).parent / "configs" / "tcga_paad.yaml"

        mutations_csv = os.path.join(root, self._MUTATIONS_CSV)
        clinical_csv = os.path.join(root, self._CLINICAL_CSV)
        need_mutations = not os.path.exists(mutations_csv)
        need_clinical = not os.path.exists(clinical_csv)

        # Only (re)build what is actually missing, so an already-prepared
        # CSV is never regenerated as a side effect of the other one being
        # absent.
        if need_mutations or need_clinical:
            logger.info("Preparing TCGA-PAAD metadata...")
            if need_mutations:
                self._prepare_mutations(root)
            if need_clinical:
                self._prepare_clinical(root)

        default_tables = ["mutations", "clinical"]
        tables = default_tables + (tables or [])

        super().__init__(
            root=root,
            tables=tables,
            dataset_name=dataset_name or "tcga_paad",
            config_path=config_path,
            **kwargs,
        )

    @staticmethod
    def prepare_metadata(root: str) -> None:
        """Prepare metadata for the TCGA-PAAD dataset.

        Converts raw TCGA MAF and clinical files to standardized CSV format.
        Both standardized CSVs are rebuilt from the raw files when those are
        present.

        Args:
            root: Root directory containing the TCGA-PAAD files.
        """
        TCGAPAADDataset._prepare_mutations(root)
        TCGAPAADDataset._prepare_clinical(root)

    @staticmethod
    def _find_raw_file(root: str, candidates) -> Optional[str]:
        """Return the path of the first candidate file existing in ``root``."""
        for fname in candidates:
            fpath = os.path.join(root, fname)
            if os.path.exists(fpath):
                return fpath
        return None

    @staticmethod
    def _write_placeholder(output_path: str, columns: List[str]) -> None:
        """Write a header-only CSV unless ``output_path`` already exists."""
        if os.path.exists(output_path):
            # Never replace previously prepared data with an empty file.
            logger.info("Keeping existing standardized file %s", output_path)
            return
        pd.DataFrame(columns=columns).to_csv(output_path, index=False)

    @staticmethod
    def _standardize_columns(df: pd.DataFrame, aliases) -> pd.DataFrame:
        """Return a frame holding one column per standardized name.

        For every ``(target, raw_names)`` pair the first raw name present in
        ``df`` is copied to ``target``; targets with no match are left out.
        """
        out = pd.DataFrame(index=df.index)
        for target, raw_names in aliases:
            source = next((n for n in raw_names if n in df.columns), None)
            if source is not None:
                out[target] = df[source]
        return out

    @staticmethod
    def _prepare_mutations(root: str) -> None:
        """Prepare mutations data from MAF file.

        Writes ``tcga_paad_mutations-pyhealth.csv`` into ``root``. If no raw
        mutations file is found, a header-only placeholder is written (an
        existing standardized file is left untouched).

        Args:
            root: Root directory containing the TCGA-PAAD files.
        """
        output_path = os.path.join(root, TCGAPAADDataset._MUTATIONS_CSV)
        output_cols = ["patient_id"] + [
            target for target, _ in TCGAPAADDataset._MUTATION_ALIASES
        ]
        raw_file = TCGAPAADDataset._find_raw_file(
            root, TCGAPAADDataset._RAW_MUTATION_FILES
        )

        if raw_file is None:
            logger.warning(
                "No raw TCGA-PAAD mutations file found in %s. "
                "Please download from GDC portal or use TCGAmutations R package.",
                root,
            )
            TCGAPAADDataset._write_placeholder(output_path, output_cols)
            return

        logger.info("Processing TCGA-PAAD mutations file: %s", raw_file)

        # MAF files are tab separated and may start with '#' comment lines;
        # anything else is treated as a plain comma-separated file.
        if raw_file.endswith(".gz"):
            df = pd.read_csv(
                raw_file, sep="\t", compression="gzip", comment="#", low_memory=False
            )
        elif raw_file.endswith(".maf"):
            df = pd.read_csv(raw_file, sep="\t", comment="#", low_memory=False)
        else:
            df = pd.read_csv(raw_file, low_memory=False)

        df_out = TCGAPAADDataset._standardize_columns(
            df, TCGAPAADDataset._MUTATION_ALIASES
        )

        # Patient id is the first 12 characters of the sample barcode; fall
        # back to a patient_id column of the raw file, then to the row index.
        if "tumor_sample_barcode" in df_out.columns:
            df_out["patient_id"] = df_out["tumor_sample_barcode"].str[:12]
        elif "patient_id" in df.columns:
            df_out["patient_id"] = df["patient_id"]
        else:
            df_out["patient_id"] = df.index.astype(str)

        # Fixed column order; columns absent from the raw file are written
        # empty so the CSV always has the same header as the placeholder.
        df_out = df_out.reindex(columns=output_cols)

        df_out.to_csv(output_path, index=False)
        logger.info("Saved %d mutations to %s", len(df_out), output_path)

    @staticmethod
    def _prepare_clinical(root: str) -> None:
        """Prepare clinical data file.

        Writes ``tcga_paad_clinical-pyhealth.csv`` into ``root`` with one row
        per patient id. If no raw clinical file is found, a header-only
        placeholder is written (an existing standardized file is left
        untouched).

        Args:
            root: Root directory containing the TCGA-PAAD files.
        """
        output_path = os.path.join(root, TCGAPAADDataset._CLINICAL_CSV)
        output_cols = [target for target, _ in TCGAPAADDataset._CLINICAL_ALIASES]
        raw_file = TCGAPAADDataset._find_raw_file(
            root, TCGAPAADDataset._RAW_CLINICAL_FILES
        )

        if raw_file is None:
            logger.warning(
                "No raw TCGA-PAAD clinical file found in %s. "
                "Please download from GDC portal.",
                root,
            )
            TCGAPAADDataset._write_placeholder(output_path, output_cols)
            return

        logger.info("Processing TCGA-PAAD clinical file: %s", raw_file)

        sep = "\t" if raw_file.endswith((".tsv", ".txt")) else ","
        df = pd.read_csv(raw_file, sep=sep, low_memory=False)

        # TCGA uses various naming conventions; keep one source per column.
        df_out = TCGAPAADDataset._standardize_columns(
            df, TCGAPAADDataset._CLINICAL_ALIASES
        )

        # If no patient identifier column exists, create one from the index.
        if "patient_id" not in df_out.columns:
            df_out["patient_id"] = df.index.astype(str)

        # Fixed column order with missing columns written empty, then keep
        # the first record of each patient.
        df_out = df_out.reindex(columns=output_cols).drop_duplicates(
            subset=["patient_id"]
        )

        df_out.to_csv(output_path, index=False)
        logger.info("Saved %d clinical records to %s", len(df_out), output_path)

    @property
    def default_task(self):
        """Returns the default task for this dataset.

        Returns:
            CancerSurvivalPrediction: The default prediction task.
        """
        # Imported lazily, as in the original code, rather than at module
        # import time.
        from pyhealth.tasks import CancerSurvivalPrediction

        return CancerSurvivalPrediction()
4 changes: 4 additions & 0 deletions test-resources/tcga_paad/PAAD_clinical.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
submitter_id,age_at_diagnosis,vital_status,days_to_death,tumor_stage
TCGA-AB-1234,23000,Alive,,Stage II
TCGA-AB-5678,25000,Dead,300,Stage III
TCGA-AB-9012,28000,Alive,,Stage II
4 changes: 4 additions & 0 deletions test-resources/tcga_paad/PAAD_mutations.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Hugo_Symbol,Variant_Classification,Variant_Type,HGVSc,HGVSp_Short,Tumor_Sample_Barcode
KRAS,Missense_Mutation,SNP,c.35G>T,p.G12V,TCGA-AB-1234-01A-01D-1234-08
TP53,Nonsense_Mutation,SNP,c.743G>A,p.R248Q,TCGA-AB-5678-01A-01D-1234-08
SMAD4,Frame_Shift_Del,DEL,c.123_124del,p.L41fs,TCGA-AB-9012-01A-01D-1234-08
4 changes: 4 additions & 0 deletions test-resources/tcga_paad/tcga_paad_clinical-pyhealth.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
patient_id,age_at_diagnosis,vital_status,days_to_death,tumor_stage
TCGA-AB-1234,23000,Alive,,Stage II
TCGA-AB-5678,25000,Dead,300.0,Stage III
TCGA-AB-9012,28000,Alive,,Stage II
4 changes: 4 additions & 0 deletions test-resources/tcga_paad/tcga_paad_mutations-pyhealth.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
patient_id,hugo_symbol,variant_classification,variant_type,hgvsc,hgvsp,tumor_sample_barcode
TCGA-AB-1234,KRAS,Missense_Mutation,SNP,c.35G>T,p.G12V,TCGA-AB-1234-01A-01D-1234-08
TCGA-AB-5678,TP53,Nonsense_Mutation,SNP,c.743G>A,p.R248Q,TCGA-AB-5678-01A-01D-1234-08
TCGA-AB-9012,SMAD4,Frame_Shift_Del,DEL,c.123_124del,p.L41fs,TCGA-AB-9012-01A-01D-1234-08
57 changes: 57 additions & 0 deletions tests/core/test_tcga_paad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
Unit tests for TCGAPAADDataset, mirroring the style of the TCGA-PRAD dataset tests.
"""
import unittest
from pathlib import Path

from pyhealth.datasets import TCGAPAADDataset
from pyhealth.tasks import CancerSurvivalPrediction


class TestTCGAPAADDataset(unittest.TestCase):
    """Exercise TCGAPAADDataset against the bundled tcga_paad fixtures."""

    @classmethod
    def setUpClass(cls):
        # Fixture directory: <repo>/test-resources/tcga_paad, three levels up
        # from this test file.
        repo_root = Path(__file__).parent.parent.parent
        cls.test_resources = repo_root.joinpath("test-resources", "tcga_paad")

    def _load(self):
        """Build a fresh dataset rooted at the fixture directory."""
        return TCGAPAADDataset(root=str(self.test_resources))

    def test_dataset_initialization(self):
        ds = self._load()
        self.assertIsNotNone(ds)
        self.assertEqual(ds.dataset_name, "tcga_paad")

    def test_stats(self):
        # Only checks that stats() runs without raising.
        self._load().stats()

    def test_get_patient(self):
        record = self._load().get_patient("TCGA-AB-1234")
        self.assertIsNotNone(record)
        self.assertEqual(record.patient_id, "TCGA-AB-1234")

    def test_get_mutation_events(self):
        record = self._load().get_patient("TCGA-AB-1234")
        mutation_events = record.get_events(event_type="mutations")
        self.assertGreaterEqual(len(mutation_events), 1)

    def test_get_clinical_events(self):
        record = self._load().get_patient("TCGA-AB-1234")
        clinical_events = record.get_events(event_type="clinical")
        self.assertEqual(len(clinical_events), 1)

    def test_default_task(self):
        ds = self._load()
        self.assertIsInstance(ds.default_task, CancerSurvivalPrediction)

    def test_set_task_survival(self):
        ds = self._load()
        survival_task = CancerSurvivalPrediction()
        produced = ds.set_task(survival_task)
        self.assertGreater(len(produced), 0)


if __name__ == "__main__":
    unittest.main()