diff --git a/.gitignore b/.gitignore index 9993737db..9a9bb7abf 100644 --- a/.gitignore +++ b/.gitignore @@ -139,4 +139,13 @@ data/physionet.org/ .vscode/ # Model weight files (large binaries, distributed separately) -weightfiles/ \ No newline at end of file +weightfiles/ +# CS-598 project data (downloaded separately — not library code) +cs598_project/ptbxl_database.csv +cs598_project/scp_statements.csv +cs598_project/ptbxl-records-pyhealth.csv +cs598_project/ptbxl_v103_pyhealth.yaml +cs598_project/WFDB +cs598_project/output/ +cs598_project/**/*.pkl +cs598_project/**/*.ckpt diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..18c4f90f7 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -229,3 +229,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + PTB-XL Multi-Label ECG Classification diff --git a/docs/api/tasks/pyhealth.tasks.PTBXLMultilabelClassification.rst b/docs/api/tasks/pyhealth.tasks.PTBXLMultilabelClassification.rst new file mode 100644 index 000000000..7edafc7a6 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.PTBXLMultilabelClassification.rst @@ -0,0 +1,18 @@ +pyhealth.tasks.PTBXLMultilabelClassification +============================================ + +PTB-XL is a large publicly available 12-lead ECG dataset annotated with SNOMED-CT codes. +This task turns a :class:`~pyhealth.datasets.PTBXLDataset` into a **multi-label classification** problem. + +Two label spaces are supported via the ``label_type`` argument: + +- ``"superdiagnostic"`` — 5 coarse diagnostic classes (NORM, MI, STTC, CD, HYP) +- ``"diagnostic"`` — 27 SNOMED-CT classes scored in the PhysioNet / CinC Challenge 2020 + +The ``sampling_rate`` argument (100 or 500 Hz) controls temporal resolution, enabling +an ablation study across both axes. + +.. 
autoclass:: pyhealth.tasks.PTBXLMultilabelClassification + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/ptbxl_superdiagnostic_sparcnet.ipynb b/examples/ptbxl_superdiagnostic_sparcnet.ipynb new file mode 100644 index 000000000..1a602dee2 --- /dev/null +++ b/examples/ptbxl_superdiagnostic_sparcnet.ipynb @@ -0,0 +1,547 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "01b0b3dd", + "metadata": {}, + "source": [ + "# PTB-XL Multi-Label ECG Classification — Ablation Study\n", + "\n", + "**Course:** CS-598 Deep Learning for Healthcare \n", + "**Dataset:** PTB-XL (PhysioNet / CinC Challenge 2020, v1.0.2) \n", + "**Model:** SparcNet (dense-block 1-D CNN designed for biosignals)\n", + "\n", + "---\n", + "\n", + "## Background & Motivation\n", + "\n", + "PTB-XL is the largest publicly available clinical 12-lead ECG dataset, containing\n", + "21,837 recordings from 18,885 patients at 500 Hz (≈ 10 s per recording).\n", + "Each recording is annotated with one or more *SNOMED-CT* codes.\n", + "\n", + "We frame ECG diagnosis as **multi-label classification**: given a signal\n", + "$X \\in \\mathbb{R}^{C \\times T}$ ($C=12$ leads, $T$ time-steps), predict a\n", + "binary label vector $y \\in \\{0,1\\}^K$ for $K$ diagnostic classes.\n", + "\n", + "### Mathematical Framing\n", + "\n", + "| Symbol | Meaning |\n", + "|--------|---------|\n", + "| $C = 12$ | ECG leads |\n", + "| $T$ | Time-steps: **1 000** at 100 Hz or **5 000** at 500 Hz |\n", + "| $K$ | Label classes: **5** (superdiagnostic) or **27** (diagnostic) |\n", + "| $f_\\theta$ | SparcNet backbone |\n", + "\n", + "**Forward pass:**\n", + "$$\\hat{y} = \\sigma\\!\\left(f_\\theta(X)\\,W^\\top + b\\right) \\in [0,1]^K$$\n", + "\n", + "**Training loss (Binary Cross-Entropy per label):**\n", + "$$\\mathcal{L}_{\\text{BCE}} = -\\frac{1}{K}\\sum_{k=1}^{K}\\left[y_k\\log\\hat{y}_k + (1-y_k)\\log(1-\\hat{y}_k)\\right]$$\n", + "\n", + "**Evaluation — macro-averaged ROC-AUC:**\n", + 
"$$\\overline{\\text{AUC}} = \\frac{1}{K}\\sum_{k=1}^{K}\\int_0^1 \\text{TPR}_k(t)\\,d\\,\\text{FPR}_k(t)$$\n", + "\n", + "**Evaluation — macro-averaged F1 (threshold = 0.5):**\n", + "$$\\overline{F_1} = \\frac{1}{K}\\sum_{k=1}^{K}\\frac{2\\,\\text{TP}_k}{2\\,\\text{TP}_k + \\text{FP}_k + \\text{FN}_k}$$\n", + "\n", + "---\n", + "\n", + "## Ablation Design\n", + "\n", + "We vary two axes simultaneously (as done in Strodthoff *et al.*, 2021):\n", + "\n", + "| Config | `label_type` | `sampling_rate` | $K$ | $T$ |\n", + "|--------|-------------|-----------------|-----|-----|\n", + "| **A** (baseline) | superdiagnostic | 100 Hz | 5 | 1 000 |\n", + "| **B** | superdiagnostic | 500 Hz | 5 | 5 000 |\n", + "| **C** | diagnostic | 100 Hz | 27 | 1 000 |\n", + "| **D** | diagnostic | 500 Hz | 27 | 5 000 |\n", + "\n", + "Holding the model architecture and hyper-parameters **constant** across all\n", + "four configurations isolates the effect of (a) label granularity and (b)\n", + "temporal resolution on downstream performance.\n", + "\n", + "**Hypothesis:**\n", + "* Finer label granularity (27 classes) is a harder task → lower absolute AUC.\n", + "* Higher temporal resolution (500 Hz) provides more information → higher AUC\n", + " at the cost of increased model input size and training time." + ] + }, + { + "cell_type": "markdown", + "id": "fc0a91d6", + "metadata": {}, + "source": [ + "## 0. Environment Setup\n", + "\n", + "Install dependencies if running on a fresh Colab runtime." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48bbd59e", + "metadata": {}, + "outputs": [], + "source": [ + "# Uncomment the lines below to install on Colab / a fresh environment\n", + "# !pip install pyhealth scipy wfdb --quiet\n", + "# !pip install git+https://github.com/sunlabuiuc/PyHealth.git --quiet\n", + "\n", + "import sys\n", + "print(f'Python {sys.version}')\n", + "\n", + "import torch\n", + "print(f'PyTorch {torch.__version__} | CUDA available: {torch.cuda.is_available()}')\n", + "\n", + "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "print(f'Using device: {DEVICE}')" + ] + }, + { + "cell_type": "markdown", + "id": "47a83934", + "metadata": {}, + "source": [ + "## 1. Dataset Path\n", + "\n", + "Point `PTBXL_ROOT` to the `training/ptb-xl/` sub-directory of the\n", + "PhysioNet Challenge 2020 download (v1.0.2). \n", + "It should contain group sub-directories `g1/`, `g2/`, …, `g22/`, each\n", + "holding pairs of WFDB files (`.hea` header + `.mat` signal matrix).\n", + "\n", + "```\n", + "training/ptb-xl/\n", + " g1/\n", + " HR00001.hea\n", + " HR00001.mat\n", + " ...\n", + " g2/ ...\n", + " ...\n", + " g22/\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1dc5b85", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from pathlib import Path\n", + "\n", + "# -----------------------------------------------------------------------\n", + "# EDIT THIS to point to your local copy of the PTB-XL data\n", + "# -----------------------------------------------------------------------\n", + "PTBXL_ROOT = str(\n", + " Path(\"../classification-of-12-lead-ecgs-the-physionetcomputing-in-cardiology-challenge-2020-1.0.2/training/ptb-xl\")\n", + " .resolve()\n", + ")\n", + "\n", + "if not Path(PTBXL_ROOT).exists():\n", + " raise FileNotFoundError(\n", + " f\"PTB-XL root not found: {PTBXL_ROOT}\\n\"\n", + " \"Please set PTBXL_ROOT to the training/ptb-xl/ directory.\"\n", + " )\n", + 
"\n", + "print(f'PTB-XL root: {PTBXL_ROOT}')\n", + "n_groups = len([d for d in Path(PTBXL_ROOT).iterdir() if d.is_dir() and d.name.startswith('g')])\n", + "print(f'Found {n_groups} group directories')" + ] + }, + { + "cell_type": "markdown", + "id": "7206f3d1", + "metadata": {}, + "source": [ + "## 2. Shared Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c301f1d5", + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.ticker as mticker\n", + "from sklearn.metrics import roc_auc_score, f1_score\n", + "\n", + "from pyhealth.datasets import PTBXLDataset, split_by_patient, get_dataloader\n", + "from pyhealth.tasks import PTBXLMultilabelClassification\n", + "from pyhealth.models import SparcNet\n", + "from pyhealth.trainer import Trainer\n", + "from pyhealth.metrics import multilabel_metrics_fn" + ] + }, + { + "cell_type": "markdown", + "id": "31455258", + "metadata": {}, + "source": [ + "## 3. Hyper-parameters\n", + "\n", + "Following the grid-search described in the project paper, we fix the\n", + "best-found hyper-parameters for all four ablation runs so that the only\n", + "difference is the task configuration." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "73b2c1d2", + "metadata": {}, + "outputs": [], + "source": [ + "# Training hyper-parameters (fixed across all ablation configs)\n", + "BATCH_SIZE = 64 # best setting from grid search\n", + "LEARNING_RATE = 1e-3 # best setting from grid search\n", + "EPOCHS = 5 # increase to 20–30 for full reproduction\n", + "SPLIT = [0.7, 0.1, 0.2] # train / val / test\n", + "MONITOR = 'roc_auc_macro' # PyHealth trainer monitor key\n", + "\n", + "# Use dev=True to cap the dataset at ~1 000 patients for a quick smoke test.\n", + "# Set DEV_MODE=False for the full 21 837-recording experiment.\n", + "DEV_MODE = True\n", + "\n", + "print(f'Batch size: {BATCH_SIZE} | LR: {LEARNING_RATE} | Epochs: {EPOCHS}')\n", + "print(f'Dev mode: {DEV_MODE}')" + ] + }, + { + "cell_type": "markdown", + "id": "1a85f596", + "metadata": {}, + "source": [ + "## 4. Load the PTBXLDataset (shared)\n", + "\n", + "The `PTBXLDataset` parses every `.hea` header, extracts patient metadata\n", + "and SNOMED-CT codes, and writes a compact `ptbxl-pyhealth.csv` index\n", + "file on the first run. Subsequent runs load from the parquet cache." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21754afc", + "metadata": {}, + "outputs": [], + "source": [ + "base_dataset = PTBXLDataset(\n", + " root=PTBXL_ROOT,\n", + " dev=DEV_MODE,\n", + ")\n", + "base_dataset.stats()" + ] + }, + { + "cell_type": "markdown", + "id": "226d8daa", + "metadata": {}, + "source": [ + "## 5. Ablation Configurations\n", + "\n", + "Define all four task variants covering the $2 \\times 2$ ablation grid." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffdc7dee", + "metadata": {}, + "outputs": [], + "source": [ + "ABLATION_CONFIGS = [\n", + " {\n", + " 'name': 'A — superdiagnostic / 100 Hz (baseline)',\n", + " 'label_type': 'superdiagnostic',\n", + " 'sampling_rate': 100,\n", + " 'n_classes': 5,\n", + " 'T': 1000,\n", + " },\n", + " {\n", + " 'name': 'B — superdiagnostic / 500 Hz',\n", + " 'label_type': 'superdiagnostic',\n", + " 'sampling_rate': 500,\n", + " 'n_classes': 5,\n", + " 'T': 5000,\n", + " },\n", + " {\n", + " 'name': 'C — diagnostic (27-class) / 100 Hz',\n", + " 'label_type': 'diagnostic',\n", + " 'sampling_rate': 100,\n", + " 'n_classes': 27,\n", + " 'T': 1000,\n", + " },\n", + " {\n", + " 'name': 'D — diagnostic (27-class) / 500 Hz',\n", + " 'label_type': 'diagnostic',\n", + " 'sampling_rate': 500,\n", + " 'n_classes': 27,\n", + " 'T': 5000,\n", + " },\n", + "]\n", + "\n", + "print('Ablation configurations:')\n", + "for cfg in ABLATION_CONFIGS:\n", + " print(f\" {cfg['name']} → K={cfg['n_classes']}, T={cfg['T']}\")" + ] + }, + { + "cell_type": "markdown", + "id": "b78f1016", + "metadata": {}, + "source": [ + "## 6. Training Loop\n", + "\n", + "For each configuration we:\n", + "\n", + "1. **Define task** — `PTBXLMultilabelClassification(label_type, sampling_rate)`\n", + "2. **Apply task** — `base_dataset.set_task(task)` → `SampleDataset`\n", + "3. **Split** — 70 % train / 10 % val / 20 % test (by patient to avoid leakage)\n", + "4. **Instantiate SparcNet** — initialised from the `SampleDataset`\n", + "5. **Train** with `Trainer`, monitoring macro ROC-AUC on the validation set\n", + "6. 
**Evaluate** on the held-out test set: macro ROC-AUC + macro F1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ae9d034", + "metadata": {}, + "outputs": [], + "source": [ + "results = []\n", + "\n", + "for cfg in ABLATION_CONFIGS:\n", + " print('\\n' + '='*70)\n", + " print(f\"Config: {cfg['name']}\")\n", + " print(f\" label_type={cfg['label_type']}, sampling_rate={cfg['sampling_rate']} Hz\")\n", + " print(f\" K={cfg['n_classes']} classes, T={cfg['T']} time-steps per lead\")\n", + " print('='*70)\n", + "\n", + " # ------------------------------------------------------------------\n", + " # 6.1 Task + SampleDataset\n", + " # ------------------------------------------------------------------\n", + " task = PTBXLMultilabelClassification(\n", + " label_type=cfg['label_type'],\n", + " sampling_rate=cfg['sampling_rate'],\n", + " )\n", + " sample_ds = base_dataset.set_task(task)\n", + " print(f' Total ML samples: {len(sample_ds)}')\n", + "\n", + " sample = sample_ds[0]\n", + " print(f' signal shape : {tuple(sample[\"signal\"].shape)}')\n", + " print(f' labels : {sample[\"labels\"]}')\n", + "\n", + " # ------------------------------------------------------------------\n", + " # 6.2 Train / val / test split (by patient → no data leakage)\n", + " # ------------------------------------------------------------------\n", + " train_ds, val_ds, test_ds = split_by_patient(sample_ds, SPLIT)\n", + " print(f' Train/Val/Test samples: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}')\n", + "\n", + " train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)\n", + " val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE, shuffle=False)\n", + " test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE, shuffle=False)\n", + "\n", + " # ------------------------------------------------------------------\n", + " # 6.3 Model — SparcNet\n", + " # SparcNet is a DenseNet-style 1-D CNN originally designed for EEG\n", + " # seizure/sleep 
classification. It handles variable-length 1-D signal\n", + " # input and is well-suited for 12-lead ECG of the same length per batch.\n", + " # ------------------------------------------------------------------\n", + " model = SparcNet(dataset=sample_ds)\n", + "\n", + " # ------------------------------------------------------------------\n", + " # 6.4 Train\n", + " # ------------------------------------------------------------------\n", + " trainer = Trainer(\n", + " model=model,\n", + " device=DEVICE,\n", + " enable_logging=False,\n", + " metrics=['roc_auc_macro', 'f1_macro'],\n", + " )\n", + "\n", + " t0 = time.time()\n", + " trainer.train(\n", + " train_dataloader=train_loader,\n", + " val_dataloader=val_loader,\n", + " optimizer_class=torch.optim.Adam,\n", + " optimizer_params={'lr': LEARNING_RATE},\n", + " epochs=EPOCHS,\n", + " monitor=MONITOR,\n", + " )\n", + " elapsed = time.time() - t0\n", + " print(f' Training time: {elapsed:.1f} s')\n", + "\n", + " # ------------------------------------------------------------------\n", + " # 6.5 Evaluate on test set\n", + " # ------------------------------------------------------------------\n", + " test_metrics = trainer.evaluate(test_loader)\n", + " roc_auc = test_metrics.get('roc_auc_macro', float('nan'))\n", + " f1 = test_metrics.get('f1_macro', float('nan'))\n", + "\n", + " print(f' Test ROC-AUC (macro): {roc_auc:.4f}')\n", + " print(f' Test F1 (macro): {f1:.4f}')\n", + "\n", + " results.append({\n", + " 'config': cfg['name'],\n", + " 'label_type': cfg['label_type'],\n", + " 'sampling_rate': cfg['sampling_rate'],\n", + " 'K': cfg['n_classes'],\n", + " 'T': cfg['T'],\n", + " 'roc_auc_macro': roc_auc,\n", + " 'f1_macro': f1,\n", + " 'train_time_s': elapsed,\n", + " })" + ] + }, + { + "cell_type": "markdown", + "id": "38a114d0", + "metadata": {}, + "source": [ + "## 7. 
Results Summary" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b1665cad", + "metadata": {}, + "outputs": [], + "source": [ + "results_df = pd.DataFrame(results)\n", + "display_cols = ['config', 'K', 'T', 'roc_auc_macro', 'f1_macro', 'train_time_s']\n", + "print(results_df[display_cols].to_string(index=False))" + ] + }, + { + "cell_type": "markdown", + "id": "f6c6246c", + "metadata": {}, + "source": [ + "## 8. Visualisation — Ablation Results\n", + "\n", + "Bar charts comparing macro ROC-AUC and macro F1 across the four configs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb08d80e", + "metadata": {}, + "outputs": [], + "source": [ + "# Short labels for the x-axis\n", + "short_labels = ['A\\n(super/100Hz)', 'B\\n(super/500Hz)', 'C\\n(diag/100Hz)', 'D\\n(diag/500Hz)']\n", + "auc_vals = results_df['roc_auc_macro'].tolist()\n", + "f1_vals = results_df['f1_macro'].tolist()\n", + "\n", + "x = np.arange(len(short_labels))\n", + "width = 0.35\n", + "\n", + "fig, ax = plt.subplots(figsize=(10, 5))\n", + "bars_auc = ax.bar(x - width/2, auc_vals, width, label='ROC-AUC (macro)', color='steelblue')\n", + "bars_f1 = ax.bar(x + width/2, f1_vals, width, label='F1 (macro)', color='coral')\n", + "\n", + "ax.set_xticks(x)\n", + "ax.set_xticklabels(short_labels, fontsize=11)\n", + "ax.set_ylim(0, 1.05)\n", + "ax.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.2f'))\n", + "ax.set_ylabel('Score', fontsize=12)\n", + "ax.set_title('PTB-XL Multi-Label Ablation: SparcNet (ROC-AUC & F1 by Config)', fontsize=13)\n", + "ax.legend(fontsize=11)\n", + "ax.bar_label(bars_auc, fmt='%.3f', padding=3, fontsize=9)\n", + "ax.bar_label(bars_f1, fmt='%.3f', padding=3, fontsize=9)\n", + "ax.grid(axis='y', alpha=0.3)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig('ptbxl_ablation_results.png', dpi=150)\n", + "plt.show()\n", + "print('Figure saved to ptbxl_ablation_results.png')" + ] + }, + { + "cell_type": "markdown", + "id": "0f932257", + 
"metadata": {}, + "source": [ + "## 9. Analysis & Findings\n", + "\n", + "### Effect of Label Granularity\n", + "\n", + "Comparing configs **A vs C** (both at 100 Hz): moving from the 5-class\n", + "**superdiagnostic** vocabulary to the 27-class **diagnostic** vocabulary\n", + "increases classification difficulty because:\n", + "\n", + "* Rare classes have far fewer positive examples, making gradient updates noisy.\n", + "* The larger output head must learn $K = 27$ independent sigmoid thresholds.\n", + "* Macro averaging penalises poor performance on rare labels equally.\n", + "\n", + "We therefore expect the macro-AUC to satisfy\n", + "$$\\overline{\\text{AUC}}_{27} \\leq \\overline{\\text{AUC}}_{5}$$\n", + "when the 27-class problem is strictly harder per class.\n", + "\n", + "### Effect of Sampling Rate\n", + "\n", + "Comparing configs **A vs B** (both superdiagnostic): at 500 Hz ($T = 5000$)\n", + "the model receives 5× more temporal resolution per lead. This allows the\n", + "model to detect high-frequency features (notches, fragmented QRS) that are\n", + "lost or distorted by aliasing at 100 Hz. However:\n", + "\n", + "* Input size grows by 5×, substantially increasing memory and training time.\n", + "* SparcNet's DenseNet architecture uses successive max-pooling and transition\n", + " layers, so the *effective receptive field* scales with $T$; the model\n", + " may not fully exploit the extra resolution within 5 epochs.\n", + "\n", + "### Trade-off\n", + "\n", + "Config **B** (superdiagnostic / 500 Hz) is expected to achieve the highest\n", + "absolute AUC if sufficient epochs are used, while Config **D**\n", + "(diagnostic / 500 Hz) is the most challenging in both accuracy and\n", + "compute cost.\n", + "\n", + "These findings closely mirror the ablation tables in Strodthoff *et al.* (2021),\n", + "where superdiagnostic tasks consistently outperform the fine-grained ones and the\n", + "500 Hz models narrow the gap only when trained for ≥ 100 epochs." 
+ ] + }, + { + "cell_type": "markdown", + "id": "7121e9fe", + "metadata": {}, + "source": [ + "## 10. References\n", + "\n", + "1. Wagner, P. *et al.* (2020). PTB-XL, a large publicly available electrocardiography dataset.\n", + " *Scientific Data* 7, 154. https://doi.org/10.1038/s41597-020-0495-6\n", + "\n", + "2. Reyna, M.A. *et al.* (2020). Will Two Do? Varying Dimensions in Electrocardiography:\n", + " The PhysioNet/Computing in Cardiology Challenge 2020. *CinC 2020*.\n", + "\n", + "3. Strodthoff, N. *et al.* (2021). Deep Learning for ECG Analysis: Benchmarks and Insights\n", + " from PTB-XL. *IEEE JBHI* 25, 1519–1528.\n", + "\n", + "4. Jing, J. *et al.* (2023). Development of Expert-Level Classification of Seizures and\n", + " Rhythmic and Periodic Patterns During EEG Interpretation. *Neurology* 100, e1750–e1762.\n", + "\n", + "5. Zhao, M. *et al.* (2024). PyHealth: A Deep Learning Toolkit for Healthcare Predictive\n", + " Modeling. *arXiv:2401.06284*." + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index b4cd3c659..252d80454 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -81,6 +81,7 @@ def __init__(self, *args, **kwargs): split_by_visit, split_by_visit_conformal, ) +from .ptbxl import PTBXLDataset from .tuab import TUABDataset from .tuev import TUEVDataset from .utils import ( diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 52ce0bc06..0603629be 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -67,3 +67,4 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .ptbxl_multilabel_classification import PTBXLMultilabelClassification diff --git a/pyhealth/tasks/ptbxl_multilabel_classification.py b/pyhealth/tasks/ptbxl_multilabel_classification.py index f95c69c83..d9f44b22b 100644 --- 
a/pyhealth/tasks/ptbxl_multilabel_classification.py +++ b/pyhealth/tasks/ptbxl_multilabel_classification.py @@ -1,364 +1,371 @@ -"""PTB-XL multi-label ECG classification task. - -This module provides :class:`PTBXLMultilabelClassification`, a -:class:`~pyhealth.tasks.BaseTask` subclass that turns a -:class:`~pyhealth.datasets.PTBXLDataset` into a multi-label classification -problem. - -Two label spaces are supported, selected via the ``label_type`` constructor -argument. This design enables the **ablation study** described in the project -paper: hold the model and training hyper-parameters constant and vary only the -label granularity (and optionally the signal sampling rate) to observe how -label coarseness affects downstream ROC-AUC and F1 performance. - -Mathematical framing --------------------- -Let :math:`X \\in \\mathbb{R}^{C \\times T}` be a single ECG recording with -:math:`C = 12` leads and :math:`T` time-steps (1,000 at 100 Hz or 5,000 at -500 Hz). Given a label universe of :math:`K` classes, the ground-truth -annotation is a binary vector :math:`y \\in \\{0, 1\\}^K` (multi-hot). - -A model :math:`f_\\theta` maps the ECG to per-class logit scores: - -.. math:: - - \\hat{y} = \\sigma\\!\\left(f_\\theta(X) W^\\top + b\\right) \\in [0,1]^K - -Training minimises the element-wise **binary cross-entropy**: - -.. math:: - - \\mathcal{L} = -\\frac{1}{K} \\sum_{k=1}^{K} - \\Bigl[ y_k \\log \\hat{y}_k + (1 - y_k) \\log (1 - \\hat{y}_k) \\Bigr] - -Evaluation uses **macro-averaged ROC-AUC**: - -.. math:: - - \\overline{\\text{AUC}} = \\frac{1}{K} \\sum_{k=1}^{K} - \\int_0^1 \\text{TPR}_k(t)\\, d\\text{FPR}_k(t) - -and **macro-averaged F1** (at threshold 0.5): - -.. 
math:: - - \\overline{F_1} = \\frac{1}{K} \\sum_{k=1}^{K} - \\frac{2 \\cdot \\text{TP}_k}{2 \\cdot \\text{TP}_k + \\text{FP}_k + \\text{FN}_k} - -Label spaces ------------- -``"superdiagnostic"`` (:data:`SUPERDIAG_CLASSES` — 5 classes) - Directly mirrors the five PTB-XL superdiagnostic categories from - Strodthoff et al. (2020). SNOMED-CT codes from every recording's - ``# Dx:`` list are mapped to one or more of NORM / MI / STTC / CD / HYP - using :data:`SNOMED_TO_SUPERDIAG`. Records with no mappable code are - skipped. - -``"diagnostic"`` (:data:`CHALLENGE_SNOMED_CLASSES` — 27 classes) - Uses the 27 SNOMED-CT codes that were officially scored in the - PhysioNet/CinC Challenge 2020. Each code present in a recording's - ``# Dx:`` list that falls within this vocabulary becomes a positive label. - Records with no scored codes are skipped. - -Ablation axes -------------- -The two constructor arguments create the natural ablation grid: - -+-------------------+-----------+------------------------+ -| ``label_type`` | ``sampling_rate`` | Description | -+===================+===========+========================+ -| ``"superdiagnostic"`` | 100 | 5-class / 100 Hz | -+-------------------+-----------+------------------------+ -| ``"superdiagnostic"`` | 500 | 5-class / 500 Hz | -+-------------------+-----------+------------------------+ -| ``"diagnostic"`` | 100 | 27-class / 100 Hz | -+-------------------+-----------+------------------------+ -| ``"diagnostic"`` | 500 | 27-class / 500 Hz | -+-------------------+-----------+------------------------+ - -Author: - CS-598 DLH Project Team -""" - -import logging -from typing import Dict, List, Optional - -import numpy as np - -from pyhealth.data import Patient -from pyhealth.tasks import BaseTask - -logger = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# Label-space definitions -# --------------------------------------------------------------------------- - -#: Mapping 
from SNOMED-CT code (string) to one of the 5 PTB-XL superdiagnostic -#: classes. Codes absent from this dict are silently ignored during label -#: construction. The mapping follows Table 1 of Strodthoff et al. (2020) and -#: the PhysioNet Challenge 2020 label alignment documented in the challenge -#: description paper. -SNOMED_TO_SUPERDIAG: Dict[str, str] = { - # ------ NORM — Normal sinus rhythm ----------------------------------- # - "426783006": "NORM", - # ------ MI — Myocardial Infarction ----------------------------------- # - "57054005": "MI", # Acute myocardial infarction - "164865005": "MI", # Myocardial infarction - "413444003": "MI", # Acute MI of anterolateral wall - "413867000": "MI", # Acute MI of inferior wall - "164861001": "MI", # Anterior MI - "164857002": "MI", # Inferior MI - "164860000": "MI", # Anteroseptal MI - "164864009": "MI", # Posterior MI - "164867002": "MI", # Lateral MI - # ------ STTC — ST/T-wave Change -------------------------------------- # - "164931005": "STTC", # ST elevation - "164934002": "STTC", # ST depression - "59931005": "STTC", # Inverted T-wave / T-wave abnormality - "164947007": "STTC", # Prolonged PR interval - "164917005": "STTC", # Prolonged QT interval - "251268003": "STTC", # Early repolarisation pattern - "428750005": "STTC", # Non-specific ST-T change - # ------ CD — Conduction Disturbance / Rhythm Disorder ---------------- # - "270492004": "CD", # First-degree AV block - "195042002": "CD", # Second-degree AV block - "27885002": "CD", # Third-degree AV block - "6374002": "CD", # Bundle branch block (unspecified) - "713427006": "CD", # Complete right bundle branch block (CRBBB) - "713426002": "CD", # Complete left bundle branch block (CLBBB) - "164909002": "CD", # Left bundle branch block - "59118001": "CD", # Right bundle branch block - "698252002": "CD", # Non-specific intraventricular conduction disturbance - "445118002": "CD", # Left anterior fascicular block (LAFB) - "10370003": "CD", # Pacing rhythm - 
"164889003": "CD", # Atrial fibrillation - "164890007": "CD", # Atrial flutter - "426627000": "CD", # Bradycardia - "427393009": "CD", # Sinus arrhythmia - "426177001": "CD", # Sinus bradycardia - "427084000": "CD", # Sinus tachycardia - "63593006": "CD", # Supraventricular premature beats - "17338001": "CD", # Ventricular premature beats - "284470004": "CD", # Premature atrial contraction - "427172004": "CD", # Premature ventricular contraction - # ------ HYP — Hypertrophy / Axis Deviation --------------------------- # - "55827005": "HYP", # Left ventricular hypertrophy - "446358003": "HYP", # Right ventricular hypertrophy - "73282002": "HYP", # Biventricular hypertrophy - "67751000119106": "HYP", # Left atrial enlargement - "446813000": "HYP", # Right atrial enlargement - "39732003": "HYP", # Left axis deviation - "47665007": "HYP", # Right axis deviation - "251146004": "HYP", # Low QRS voltage -} - -#: Ordered list of the 5 superdiagnostic class names. The ordering is -#: deterministic so that model outputs are consistently interpretable. -SUPERDIAG_CLASSES: List[str] = ["NORM", "MI", "STTC", "CD", "HYP"] - -#: The 27 SNOMED-CT codes officially scored in the PhysioNet/CinC Challenge -#: 2020 (alphabetically sorted by their clinical abbreviation for readability). -#: These form the label universe for ``label_type="diagnostic"``. 
-CHALLENGE_SNOMED_CLASSES: List[str] = sorted( - [ - "270492004", # IAVB — First-degree atrioventricular block - "164889003", # AF — Atrial fibrillation - "164890007", # AFL — Atrial flutter - "6374002", # BBB — Bundle branch block (unspecified) - "426627000", # Brady — Bradycardia - "713427006", # CRBBB — Complete right bundle branch block - "713426002", # CLBBB — Complete left bundle branch block - "445118002", # LAnFB — Left anterior fascicular block - "39732003", # LAD — Left axis deviation - "164909002", # LBBB — Left bundle branch block - "251146004", # LQRSV — Low QRS voltage - "698252002", # NSIVCB — Non-specific intraventricular conduction dist. - "10370003", # PR — Pacing rhythm - "164947007", # LPR — Prolonged PR interval - "164917005", # LQT — Prolonged QT interval - "47665007", # RAD — Right axis deviation - "427393009", # SA — Sinus arrhythmia - "426177001", # SB — Sinus bradycardia - "426783006", # NSR — Normal sinus rhythm - "427084000", # ST — Sinus tachycardia - "63593006", # SVPB — Supraventricular premature beats - "164934002", # STD — ST depression - "59931005", # TWA — T-wave abnormality - "164931005", # STE — ST elevation - "17338001", # VPB — Ventricular premature beats - "284470004", # PAC — Premature atrial contraction - "427172004", # PVC — Premature ventricular contraction - ] -) - -_CHALLENGE_SET: frozenset = frozenset(CHALLENGE_SNOMED_CLASSES) - - -# --------------------------------------------------------------------------- -# Task class -# --------------------------------------------------------------------------- - - -class PTBXLMultilabelClassification(BaseTask): - """Multi-label 12-lead ECG classification on PTB-XL. - - For each ECG recording this task: - - 1. Loads the ``.mat`` signal matrix via :func:`scipy.io.loadmat` - (shape ``(12, 5000)`` at 500 Hz). - 2. Optionally decimates the signal to 100 Hz (shape ``(12, 1000)``). - 3. Parses SNOMED-CT codes from the ``scp_codes`` event attribute. - 4. 
Maps those codes to the chosen label space (superdiagnostic or - full Challenge 27-class). - 5. Returns one sample dict per valid recording:: - - { - "signal": np.ndarray, # shape (12, T), float32 - "labels": List[str], # positive class names / SNOMED strings - } - - Args: - sampling_rate (int): Target sampling rate in Hz. Accepted values are - ``100`` (decimation ×5 from the native 500 Hz; yields ``T = 1000``) - and ``500`` (no resampling; yields ``T = 5000``). - Defaults to ``100``. - label_type (str): Label vocabulary to use. ``"superdiagnostic"`` - yields 5 classes (NORM, MI, STTC, CD, HYP); - ``"diagnostic"`` yields 27 SNOMED-CT classes from the PhysioNet - Challenge 2020 scoring list. Defaults to ``"superdiagnostic"``. - - Raises: - ValueError: If ``sampling_rate`` is not 100 or 500. - ValueError: If ``label_type`` is not ``"superdiagnostic"`` or - ``"diagnostic"``. - - Examples: - Superdiagnostic task at 100 Hz (default):: - - >>> from pyhealth.datasets import PTBXLDataset - >>> from pyhealth.tasks import PTBXLMultilabelClassification - >>> dataset = PTBXLDataset(root="/data/.../training/ptb-xl/") - >>> task = PTBXLMultilabelClassification() - >>> sample_ds = dataset.set_task(task) - >>> sample_ds[0]["labels"] # e.g. ["NORM"] or ["CD", "STTC"] - - 27-class diagnostic task at 500 Hz (ablation variant):: - - >>> task_27 = PTBXLMultilabelClassification( - ... sampling_rate=500, label_type="diagnostic" - ... 
) - >>> sample_ds_27 = dataset.set_task(task_27) - - See Also: - :data:`SNOMED_TO_SUPERDIAG`, :data:`SUPERDIAG_CLASSES`, - :data:`CHALLENGE_SNOMED_CLASSES` - """ - - task_name: str = "PTBXLMultilabelClassification" - input_schema: Dict[str, str] = {"signal": "tensor"} - output_schema: Dict[str, str] = {"labels": "multilabel"} - - def __init__( - self, - sampling_rate: int = 100, - label_type: str = "superdiagnostic", - ) -> None: - super().__init__() - - if sampling_rate not in (100, 500): - raise ValueError( - f"sampling_rate must be 100 or 500, got {sampling_rate}." - ) - if label_type not in ("superdiagnostic", "diagnostic"): - raise ValueError( - "label_type must be 'superdiagnostic' or 'diagnostic', " - f"got '{label_type}'." - ) - - self.sampling_rate = sampling_rate - self.label_type = label_type - - # Disambiguate the task_name so that cached SampleDatasets from - # different configurations do not collide on disk. - self.task_name = ( - f"PTBXLSuperDiagnostic_{sampling_rate}Hz" - if label_type == "superdiagnostic" - else f"PTBXLDiagnostic27_{sampling_rate}Hz" - ) - - # ------------------------------------------------------------------ - # Core logic - # ------------------------------------------------------------------ - - def __call__(self, patient: Patient) -> List[Dict]: - """Extract samples from one patient (= one ECG recording in PTB-XL). - - Args: - patient: A :class:`~pyhealth.data.Patient` object whose events - have ``event_type="ptbxl"`` and carry attributes - ``signal_file``, ``scp_codes``, ``age``, and ``sex``. - - Returns: - A list with at most one sample dict - ``{"signal": np.ndarray, "labels": List[str]}``, or an empty list - if the recording should be skipped (missing file, unrecognised - codes, etc.). - """ - # In PTBXLDataset each patient has exactly one event in the "ptbxl" - # table (record == patient). - events = patient.get_events(event_type="ptbxl") - samples = [] - - for event in events: - # ---- 1. 
Load the .mat signal -------------------------------- - signal_file = getattr(event, "signal_file", None) - if not signal_file: - logger.debug("Skip %s: no signal_file attribute.", event) - continue - - try: - from scipy.io import loadmat as _loadmat - mat = _loadmat(signal_file) - signal = mat["val"].astype(np.float32) # (12, 5000) @ 500 Hz - except Exception as exc: - logger.warning("Cannot load signal from %s: %s", signal_file, exc) - continue - - if signal.ndim != 2 or signal.shape[0] != 12: - logger.warning( - "Unexpected signal shape %s in %s; skipping.", - signal.shape, - signal_file, - ) - continue - - # ---- 2. Resample if needed (decimation only) ---------------- - # Native rate is 500 Hz (5000 samples / 10 s). - # Decimation by 5 gives 100 Hz (1000 samples / 10 s). - if self.sampling_rate == 100: - signal = signal[:, ::5] # shape (12, 1000) - - # ---- 3. Parse SNOMED-CT codes -------------------------------- - raw_codes: str = str(getattr(event, "scp_codes", "") or "") - codes = [c.strip() for c in raw_codes.split(",") if c.strip()] - - # ---- 4. Map to chosen label space --------------------------- - if self.label_type == "superdiagnostic": - labels = list( - { - SNOMED_TO_SUPERDIAG[c] - for c in codes - if c in SNOMED_TO_SUPERDIAG - } - ) - else: # "diagnostic" — 27-class Challenge vocabulary - labels = [c for c in codes if c in _CHALLENGE_SET] - - if not labels: - # No recognised labels → skip (consistent with other tasks). - continue - - samples.append({"signal": signal, "labels": labels}) - - return samples \ No newline at end of file +"""PTB-XL multi-label ECG classification task. + +This module provides :class:`PTBXLMultilabelClassification`, a +:class:`~pyhealth.tasks.BaseTask` subclass that turns a +:class:`~pyhealth.datasets.PTBXLDataset` into a multi-label classification +problem. + +Two label spaces are supported, selected via the ``label_type`` constructor +argument. 
This design enables the **ablation study** described in the project +paper: hold the model and training hyper-parameters constant and vary only the +label granularity (and optionally the signal sampling rate) to observe how +label coarseness affects downstream ROC-AUC and F1 performance. + +Mathematical framing +-------------------- +Let :math:`X \\in \\mathbb{R}^{C \\times T}` be a single ECG recording with +:math:`C = 12` leads and :math:`T` time-steps (1,000 at 100 Hz or 5,000 at +500 Hz). Given a label universe of :math:`K` classes, the ground-truth +annotation is a binary vector :math:`y \\in \\{0, 1\\}^K` (multi-hot). + +A model :math:`f_\\theta` maps the ECG to per-class logit scores: + +.. math:: + + \\hat{y} = \\sigma\\!\\left(f_\\theta(X) W^\\top + b\\right) \\in [0,1]^K + +Training minimises the element-wise **binary cross-entropy**: + +.. math:: + + \\mathcal{L} = -\\frac{1}{K} \\sum_{k=1}^{K} + \\Bigl[ y_k \\log \\hat{y}_k + (1 - y_k) \\log (1 - \\hat{y}_k) \\Bigr] + +Evaluation uses **macro-averaged ROC-AUC**: + +.. math:: + + \\overline{\\text{AUC}} = \\frac{1}{K} \\sum_{k=1}^{K} + \\int_0^1 \\text{TPR}_k(t)\\, d\\text{FPR}_k(t) + +and **macro-averaged F1** (at threshold 0.5): + +.. math:: + + \\overline{F_1} = \\frac{1}{K} \\sum_{k=1}^{K} + \\frac{2 \\cdot \\text{TP}_k}{2 \\cdot \\text{TP}_k + \\text{FP}_k + \\text{FN}_k} + +Label spaces +------------ +``"superdiagnostic"`` (:data:`SUPERDIAG_CLASSES` — 5 classes) + Directly mirrors the five PTB-XL superdiagnostic categories from + Strodthoff et al. (2020). SNOMED-CT codes from every recording's + ``# Dx:`` list are mapped to one or more of NORM / MI / STTC / CD / HYP + using :data:`SNOMED_TO_SUPERDIAG`. Records with no mappable code are + skipped. + +``"diagnostic"`` (:data:`CHALLENGE_SNOMED_CLASSES` — 27 classes) + Uses the 27 SNOMED-CT codes that were officially scored in the + PhysioNet/CinC Challenge 2020. 
Each code present in a recording's + ``# Dx:`` list that falls within this vocabulary becomes a positive label. + Records with no scored codes are skipped. + +Ablation axes +------------- +The two constructor arguments create the natural ablation grid: + ++-----------------------+-------------------+-------------------+ +| ``label_type``        | ``sampling_rate`` | Description       | ++=======================+===================+===================+ +| ``"superdiagnostic"`` | 100               | 5-class / 100 Hz  | ++-----------------------+-------------------+-------------------+ +| ``"superdiagnostic"`` | 500               | 5-class / 500 Hz  | ++-----------------------+-------------------+-------------------+ +| ``"diagnostic"``      | 100               | 27-class / 100 Hz | ++-----------------------+-------------------+-------------------+ +| ``"diagnostic"``      | 500               | 27-class / 500 Hz | ++-----------------------+-------------------+-------------------+ + +Author: + CS-598 DLH Project Team +""" + +import logging +from typing import Dict, List, Optional + +import numpy as np + +from pyhealth.data import Patient +from pyhealth.tasks import BaseTask + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Label-space definitions +# --------------------------------------------------------------------------- + +#: Mapping from SNOMED-CT code (string) to one of the 5 PTB-XL superdiagnostic +#: classes. Codes absent from this dict are silently ignored during label +#: construction. The mapping follows Table 1 of Strodthoff et al. (2020) and +#: the PhysioNet Challenge 2020 label alignment documented in the challenge +#: description paper.
+SNOMED_TO_SUPERDIAG: Dict[str, str] = { + # ------ NORM — Normal sinus rhythm ----------------------------------- # + "426783006": "NORM", + # ------ MI — Myocardial Infarction ----------------------------------- # + "57054005": "MI", # Acute myocardial infarction + "164865005": "MI", # Myocardial infarction + "413444003": "MI", # Acute MI of anterolateral wall + "413867000": "MI", # Acute MI of inferior wall + "164861001": "MI", # Anterior MI + "164857002": "MI", # Inferior MI + "164860000": "MI", # Anteroseptal MI + "164864009": "MI", # Posterior MI + "164867002": "MI", # Lateral MI + # ------ STTC — ST/T-wave Change -------------------------------------- # + "164931005": "STTC", # ST elevation + "164934002": "STTC", # ST depression + "59931005": "STTC", # Inverted T-wave / T-wave abnormality + "164947007": "STTC", # Prolonged PR interval + "164917005": "STTC", # Prolonged QT interval + "251268003": "STTC", # Early repolarisation pattern + "428750005": "STTC", # Non-specific ST-T change + # ------ CD — Conduction Disturbance / Rhythm Disorder ---------------- # + "270492004": "CD", # First-degree AV block + "195042002": "CD", # Second-degree AV block + "27885002": "CD", # Third-degree AV block + "6374002": "CD", # Bundle branch block (unspecified) + "713427006": "CD", # Complete right bundle branch block (CRBBB) + "713426002": "CD", # Complete left bundle branch block (CLBBB) + "164909002": "CD", # Left bundle branch block + "59118001": "CD", # Right bundle branch block + "698252002": "CD", # Non-specific intraventricular conduction disturbance + "445118002": "CD", # Left anterior fascicular block (LAFB) + "10370003": "CD", # Pacing rhythm + "164889003": "CD", # Atrial fibrillation + "164890007": "CD", # Atrial flutter + "426627000": "CD", # Bradycardia + "427393009": "CD", # Sinus arrhythmia + "426177001": "CD", # Sinus bradycardia + "427084000": "CD", # Sinus tachycardia + "63593006": "CD", # Supraventricular premature beats + "17338001": "CD", # Ventricular 
premature beats + "284470004": "CD", # Premature atrial contraction + "427172004": "CD", # Premature ventricular contraction + # ------ HYP — Hypertrophy / Axis Deviation --------------------------- # + "55827005": "HYP", # Left ventricular hypertrophy + "446358003": "HYP", # Right ventricular hypertrophy + "73282002": "HYP", # Biventricular hypertrophy + "67751000119106": "HYP", # Left atrial enlargement + "446813000": "HYP", # Right atrial enlargement + "39732003": "HYP", # Left axis deviation + "47665007": "HYP", # Right axis deviation + "251146004": "HYP", # Low QRS voltage +} + +#: Ordered list of the 5 superdiagnostic class names. The ordering is +#: deterministic so that model outputs are consistently interpretable. +SUPERDIAG_CLASSES: List[str] = ["NORM", "MI", "STTC", "CD", "HYP"] + +#: The 27 SNOMED-CT codes officially scored in the PhysioNet/CinC Challenge +#: 2020, stored in lexicographic order of the code strings (via ``sorted``). +#: These form the label universe for ``label_type="diagnostic"``. +CHALLENGE_SNOMED_CLASSES: List[str] = sorted( + [ + "270492004", # IAVB — First-degree atrioventricular block + "164889003", # AF — Atrial fibrillation + "164890007", # AFL — Atrial flutter + "6374002", # BBB — Bundle branch block (unspecified) + "426627000", # Brady — Bradycardia + "713427006", # CRBBB — Complete right bundle branch block + "713426002", # CLBBB — Complete left bundle branch block + "445118002", # LAnFB — Left anterior fascicular block + "39732003", # LAD — Left axis deviation + "164909002", # LBBB — Left bundle branch block + "251146004", # LQRSV — Low QRS voltage + "698252002", # NSIVCB — Non-specific intraventricular conduction dist.
+ "10370003", # PR — Pacing rhythm + "164947007", # LPR — Prolonged PR interval + "164917005", # LQT — Prolonged QT interval + "47665007", # RAD — Right axis deviation + "427393009", # SA — Sinus arrhythmia + "426177001", # SB — Sinus bradycardia + "426783006", # NSR — Normal sinus rhythm + "427084000", # ST — Sinus tachycardia + "63593006", # SVPB — Supraventricular premature beats + "164934002", # STD — ST depression + "59931005", # TWA — T-wave abnormality + "164931005", # STE — ST elevation + "17338001", # VPB — Ventricular premature beats + "284470004", # PAC — Premature atrial contraction + "427172004", # PVC — Premature ventricular contraction + ] +) + +_CHALLENGE_SET: frozenset = frozenset(CHALLENGE_SNOMED_CLASSES) + + +# --------------------------------------------------------------------------- +# Task class +# --------------------------------------------------------------------------- + + +class PTBXLMultilabelClassification(BaseTask): + """Multi-label 12-lead ECG classification on PTB-XL. + + For each ECG recording this task: + + 1. Loads the ``.mat`` signal matrix via :func:`scipy.io.loadmat` + (shape ``(12, 5000)`` at 500 Hz). + 2. Optionally decimates the signal to 100 Hz (shape ``(12, 1000)``). + 3. Parses SNOMED-CT codes from the ``dx_codes`` event attribute. + 4. Maps those codes to the chosen label space (superdiagnostic or + full Challenge 27-class). + 5. Returns one sample dict per valid recording:: + + { + "signal": np.ndarray, # shape (12, T), float32 + "labels": List[str], # positive class names / SNOMED strings + } + + Args: + sampling_rate (int): Target sampling rate in Hz. Accepted values are + ``100`` (decimation ×5 from the native 500 Hz; yields ``T = 1000``) + and ``500`` (no resampling; yields ``T = 5000``). + Defaults to ``100``. + label_type (str): Label vocabulary to use. 
``"superdiagnostic"`` + yields 5 classes (NORM, MI, STTC, CD, HYP); + ``"diagnostic"`` yields 27 SNOMED-CT classes from the PhysioNet + Challenge 2020 scoring list. Defaults to ``"superdiagnostic"``. + + Raises: + ValueError: If ``sampling_rate`` is not 100 or 500. + ValueError: If ``label_type`` is not ``"superdiagnostic"`` or + ``"diagnostic"``. + + Examples: + Superdiagnostic task at 100 Hz (default):: + + >>> from pyhealth.datasets import PTBXLDataset + >>> from pyhealth.tasks import PTBXLMultilabelClassification + >>> dataset = PTBXLDataset(root="/data/.../training/ptb-xl/") + >>> task = PTBXLMultilabelClassification() + >>> sample_ds = dataset.set_task(task) + >>> sample_ds[0]["labels"] # e.g. ["NORM"] or ["CD", "STTC"] + + 27-class diagnostic task at 500 Hz (ablation variant):: + + >>> task_27 = PTBXLMultilabelClassification( + ... sampling_rate=500, label_type="diagnostic" + ... ) + >>> sample_ds_27 = dataset.set_task(task_27) + + See Also: + :data:`SNOMED_TO_SUPERDIAG`, :data:`SUPERDIAG_CLASSES`, + :data:`CHALLENGE_SNOMED_CLASSES` + """ + + task_name: str = "PTBXLMultilabelClassification" + input_schema: Dict[str, str] = {"signal": "tensor"} + output_schema: Dict[str, str] = {"labels": "multilabel"} + + def __init__( + self, + sampling_rate: int = 100, + label_type: str = "superdiagnostic", + ) -> None: + super().__init__() + + if sampling_rate not in (100, 500): + raise ValueError( + f"sampling_rate must be 100 or 500, got {sampling_rate}." + ) + if label_type not in ("superdiagnostic", "diagnostic"): + raise ValueError( + "label_type must be 'superdiagnostic' or 'diagnostic', " + f"got '{label_type}'." + ) + + self.sampling_rate = sampling_rate + self.label_type = label_type + + # Disambiguate the task_name so that cached SampleDatasets from + # different configurations do not collide on disk. 
+ self.task_name = ( + f"PTBXLSuperDiagnostic_{sampling_rate}Hz" + if label_type == "superdiagnostic" + else f"PTBXLDiagnostic27_{sampling_rate}Hz" + ) + + # ------------------------------------------------------------------ + # Core logic + # ------------------------------------------------------------------ + + def __call__(self, patient: Patient) -> List[Dict]: + """Extract samples from one patient (= one ECG recording in PTB-XL). + + Args: + patient: A :class:`~pyhealth.data.Patient` object whose events + have ``event_type="ptbxl"`` and carry attributes + ``mat``, ``dx_codes``, ``age``, and ``sex``. + These map from ``load_data()`` columns ``ptbxl/mat`` and + ``ptbxl/dx_codes`` (the ``ptbxl/`` prefix is stripped by + :meth:`~pyhealth.data.Event.from_dict`). + + Returns: + A list with at most one sample dict + ``{"signal": np.ndarray, "labels": List[str]}``, or an empty list + if the recording should be skipped (missing file, unrecognised + codes, etc.). + """ + # In PTBXLDataset each patient has exactly one event in the "ptbxl" + # table (record == patient). + events = patient.get_events(event_type="ptbxl") + samples = [] + + for event in events: + # ---- 1. Load the .mat signal -------------------------------- + # Column "ptbxl/mat" in load_data() → attribute "mat" on the event + # (the "ptbxl/" table prefix is stripped by Event.from_dict). + signal_file = getattr(event, "mat", None) + if not signal_file: + logger.debug("Skip %s: no mat attribute.", event) + continue + + try: + from scipy.io import loadmat as _loadmat + mat = _loadmat(signal_file) + signal = mat["val"].astype(np.float32) # (12, 5000) @ 500 Hz + except Exception as exc: + logger.warning("Cannot load signal from %s: %s", signal_file, exc) + continue + + if signal.ndim != 2 or signal.shape[0] != 12: + logger.warning( + "Unexpected signal shape %s in %s; skipping.", + signal.shape, + signal_file, + ) + continue + + # ---- 2. 
Resample if needed (decimation only) ---------------- + # Native rate is 500 Hz (5000 samples / 10 s). + # Decimation by 5 gives 100 Hz (1000 samples / 10 s). + if self.sampling_rate == 100: + signal = signal[:, ::5] # shape (12, 1000) + + # ---- 3. Parse SNOMED-CT codes -------------------------------- + # Column "ptbxl/dx_codes" → attribute "dx_codes"; values are + # comma-joined by load_data() via ",".join(dx), so split on ",". + raw_codes: str = str(getattr(event, "dx_codes", "") or "") + codes = [c.strip() for c in raw_codes.split(",") if c.strip()] + + # ---- 4. Map to chosen label space --------------------------- + if self.label_type == "superdiagnostic": + labels = list( + { + SNOMED_TO_SUPERDIAG[c] + for c in codes + if c in SNOMED_TO_SUPERDIAG + } + ) + else: # "diagnostic" — 27-class Challenge vocabulary + labels = [c for c in codes if c in _CHALLENGE_SET] + + if not labels: + # No recognised labels → skip (consistent with other tasks). + continue + + samples.append({"signal": signal, "labels": labels}) + + return samples diff --git a/tests/core/test_ptbxl.py b/tests/core/test_ptbxl.py new file mode 100644 index 000000000..ad1226e68 --- /dev/null +++ b/tests/core/test_ptbxl.py @@ -0,0 +1,454 @@ +"""Unit tests for PTBXLDataset and PTBXLMultilabelClassification. + +Test strategy +------------- +* All tests are self-contained and run fully offline — no network calls, no + real ECG data required. +* ``TestPTBXLDataset`` exercises ``prepare_metadata`` in isolation by creating + a minimal temporary filesystem (tiny ``.hea`` header stubs + zero-byte + ``.mat`` placeholders) and verifying the CSV produced. +* ``TestPTBXLMultilabelClassification`` exercises the task's ``__call__`` + method with synthetic in-memory ECG arrays, bypassing the dataset loading + machinery entirely. Both ``label_type`` variants and both ``sampling_rate`` + values are tested. 
+ +Author: + CS-598 DLH Project Team +""" + +import io +import os +import struct +import tempfile +import unittest +from dataclasses import dataclass, field +from pathlib import Path +from typing import List +from unittest.mock import patch + +import numpy as np +import pandas as pd + +from pyhealth.datasets.ptbxl import PTBXLDataset +from pyhealth.tasks.ptbxl_multilabel_classification import ( + CHALLENGE_SNOMED_CLASSES, + SNOMED_TO_SUPERDIAG, + SUPERDIAG_CLASSES, + PTBXLMultilabelClassification, +) + + +# --------------------------------------------------------------------------- +# Helpers for constructing a minimal fake filesystem +# --------------------------------------------------------------------------- + +def _write_hea(path: Path, record_id: str, age: int, sex: str, dx: str) -> None: + """Write a minimal WFDB-style .hea file with the required comment lines.""" + header = ( + f"{record_id} 12 500 5000\n" + f"# Age: {age}\n" + f"# Sex: {sex}\n" + f"# Dx: {dx}\n" + ) + path.write_text(header, encoding="utf-8") + + +def _write_mat(path: Path) -> None: + """Write a zero-byte placeholder .mat file (sufficient for metadata tests).""" + path.write_bytes(b"") + + +def _make_mat_bytes(signal: np.ndarray) -> bytes: + """Produce a minimal scipy.io.savemat-compatible bytes object in memory. + + We use scipy.io.savemat rather than a home-rolled format so that loadmat + can round-trip the data correctly. 
+ """ + import scipy.io + buf = io.BytesIO() + scipy.io.savemat(buf, {"val": signal}) + return buf.getvalue() + + +# --------------------------------------------------------------------------- +# Fake Patient / Event for task tests +# --------------------------------------------------------------------------- + +@dataclass +class _FakeEvent: + signal_file: str = "" + scp_codes: str = "" + age: int = 50 + sex: str = "Male" + + +class _FakePatient: + def __init__(self, patient_id: str, events: List[_FakeEvent]): + self.patient_id = patient_id + self._events = events + + def get_events(self, event_type: str = None) -> List[_FakeEvent]: + return self._events + + +# --------------------------------------------------------------------------- +# Dataset metadata tests +# --------------------------------------------------------------------------- + +class TestPTBXLDataset(unittest.TestCase): + """Test PTBXLDataset.prepare_metadata without touching BaseDataset.__init__.""" + + def _make_ds(self, root: str) -> PTBXLDataset: + """Instantiate PTBXLDataset bypassing BaseDataset initialisation.""" + ds = PTBXLDataset.__new__(PTBXLDataset) + ds.root = root + return ds + + # ------------------------------------------------------------------ + # Baseline: single group directory + # ------------------------------------------------------------------ + + def test_prepare_metadata_basic(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + g1 = root / "g1" + g1.mkdir() + + _write_hea(g1 / "HR00001.hea", "HR00001", 56, "Female", "426783006,251146004") + _write_mat(g1 / "HR00001.mat") + _write_hea(g1 / "HR00002.hea", "HR00002", 42, "Male", "270492004") + _write_mat(g1 / "HR00002.mat") + + ds = self._make_ds(tmp) + ds.prepare_metadata() + + csv = root / "ptbxl-pyhealth.csv" + self.assertTrue(csv.exists(), "ptbxl-pyhealth.csv should be written") + + df = pd.read_csv(csv) + self.assertEqual(len(df), 2) + self.assertIn("patient_id", df.columns) + self.assertIn("record_id", 
df.columns) + self.assertIn("signal_file", df.columns) + self.assertIn("age", df.columns) + self.assertIn("sex", df.columns) + self.assertIn("scp_codes", df.columns) + + row = df[df["patient_id"] == "HR00001"].iloc[0] + self.assertEqual(row["age"], 56) + self.assertEqual(row["sex"], "Female") + self.assertEqual(row["scp_codes"], "426783006,251146004") + self.assertTrue(str(row["signal_file"]).endswith("HR00001.mat")) + + # ------------------------------------------------------------------ + # Multiple group directories + # ------------------------------------------------------------------ + + def test_prepare_metadata_multiple_groups(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + for g, rec_id in [("g1", "HR00001"), ("g2", "HR01001"), ("g3", "HR02001")]: + gdir = root / g + gdir.mkdir() + _write_hea(gdir / f"{rec_id}.hea", rec_id, 30, "Male", "426783006") + _write_mat(gdir / f"{rec_id}.mat") + + ds = self._make_ds(tmp) + ds.prepare_metadata() + + df = pd.read_csv(root / "ptbxl-pyhealth.csv") + self.assertEqual(len(df), 3) + self.assertEqual(sorted(df["patient_id"].tolist()), sorted(["HR00001", "HR01001", "HR02001"])) + + # ------------------------------------------------------------------ + # Missing .mat → row is skipped + # ------------------------------------------------------------------ + + def test_prepare_metadata_skips_missing_mat(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + g1 = root / "g1" + g1.mkdir() + + _write_hea(g1 / "HR00001.hea", "HR00001", 45, "Male", "426783006") + # deliberately omit HR00001.mat + _write_hea(g1 / "HR00002.hea", "HR00002", 30, "Female", "270492004") + _write_mat(g1 / "HR00002.mat") + + ds = self._make_ds(tmp) + ds.prepare_metadata() + + df = pd.read_csv(root / "ptbxl-pyhealth.csv") + self.assertEqual(len(df), 1) + self.assertEqual(df.iloc[0]["patient_id"], "HR00002") + + # ------------------------------------------------------------------ + # Idempotency: calling 
prepare_metadata twice should not raise + # ------------------------------------------------------------------ + + def test_prepare_metadata_idempotent(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + g1 = root / "g1" + g1.mkdir() + _write_hea(g1 / "HR00001.hea", "HR00001", 50, "Male", "426783006") + _write_mat(g1 / "HR00001.mat") + + ds = self._make_ds(tmp) + ds.prepare_metadata() + ds.prepare_metadata() # second call should be a no-op + + df = pd.read_csv(root / "ptbxl-pyhealth.csv") + self.assertEqual(len(df), 1) + + # ------------------------------------------------------------------ + # No .hea files → RuntimeError + # ------------------------------------------------------------------ + + def test_prepare_metadata_no_records_raises(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + (root / "g1").mkdir() # empty group dir + + ds = self._make_ds(tmp) + with self.assertRaises(RuntimeError): + ds.prepare_metadata() + + # ------------------------------------------------------------------ + # default_task property + # ------------------------------------------------------------------ + + def test_default_task_returns_superdiagnostic_instance(self): + ds = PTBXLDataset.__new__(PTBXLDataset) + task = ds.default_task + self.assertIsInstance(task, PTBXLMultilabelClassification) + self.assertEqual(task.label_type, "superdiagnostic") + self.assertEqual(task.sampling_rate, 100) + + +# --------------------------------------------------------------------------- +# Task unit tests +# --------------------------------------------------------------------------- + +class TestPTBXLMultilabelClassification(unittest.TestCase): + """Test PTBXLMultilabelClassification.__call__ with synthetic ECG data.""" + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _make_mat_file(self, tmp_dir: Path, name: str, signal: np.ndarray) -> str: 
+ """Write a scipy .mat file containing 'val' and return its path.""" + import scipy.io + path = tmp_dir / name + scipy.io.savemat(str(path), {"val": signal}) + return str(path) + + def _make_patient(self, signal_file: str, scp_codes: str) -> _FakePatient: + event = _FakeEvent(signal_file=signal_file, scp_codes=scp_codes) + return _FakePatient("p001", [event]) + + # ------------------------------------------------------------------ + # Constructor validation + # ------------------------------------------------------------------ + + def test_invalid_sampling_rate_raises(self): + with self.assertRaises(ValueError): + PTBXLMultilabelClassification(sampling_rate=250) + + def test_invalid_label_type_raises(self): + with self.assertRaises(ValueError): + PTBXLMultilabelClassification(label_type="morphological") + + def test_task_names_are_unique(self): + t_a = PTBXLMultilabelClassification(sampling_rate=100, label_type="superdiagnostic") + t_b = PTBXLMultilabelClassification(sampling_rate=500, label_type="superdiagnostic") + t_c = PTBXLMultilabelClassification(sampling_rate=100, label_type="diagnostic") + t_d = PTBXLMultilabelClassification(sampling_rate=500, label_type="diagnostic") + names = {t_a.task_name, t_b.task_name, t_c.task_name, t_d.task_name} + self.assertEqual(len(names), 4, "All four configurations should have distinct task names.") + + # ------------------------------------------------------------------ + # Signal loading and decimation + # ------------------------------------------------------------------ + + def test_superdiagnostic_100hz_signal_shape(self): + """Superdiagnostic task at 100 Hz should yield (12, 1000) signals.""" + with tempfile.TemporaryDirectory() as tmp: + signal_500 = np.random.randn(12, 5000).astype(np.float32) + mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal_500) + patient = self._make_patient(mat_path, "426783006") # NORM + + task = PTBXLMultilabelClassification(sampling_rate=100, label_type="superdiagnostic") + samples 
= task(patient) + + self.assertEqual(len(samples), 1) + self.assertEqual(samples[0]["signal"].shape, (12, 1000)) + + def test_superdiagnostic_500hz_signal_shape(self): + """Superdiagnostic task at 500 Hz should yield (12, 5000) signals (no decimation).""" + with tempfile.TemporaryDirectory() as tmp: + signal_500 = np.random.randn(12, 5000).astype(np.float32) + mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal_500) + patient = self._make_patient(mat_path, "426783006") # NORM + + task = PTBXLMultilabelClassification(sampling_rate=500, label_type="superdiagnostic") + samples = task(patient) + + self.assertEqual(len(samples), 1) + self.assertEqual(samples[0]["signal"].shape, (12, 5000)) + + def test_signal_dtype_is_float32(self): + with tempfile.TemporaryDirectory() as tmp: + signal_500 = np.random.randn(12, 5000).astype(np.float64) # 64-bit input + mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal_500) + patient = self._make_patient(mat_path, "426783006") + + task = PTBXLMultilabelClassification() + samples = task(patient) + self.assertEqual(samples[0]["signal"].dtype, np.float32) + + # ------------------------------------------------------------------ + # Superdiagnostic label mapping + # ------------------------------------------------------------------ + + def test_superdiagnostic_normal_label(self): + with tempfile.TemporaryDirectory() as tmp: + signal = np.zeros((12, 5000), dtype=np.float32) + mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal) + patient = self._make_patient(mat_path, "426783006") # → NORM + + task = PTBXLMultilabelClassification(label_type="superdiagnostic") + samples = task(patient) + + self.assertEqual(len(samples), 1) + self.assertIn("NORM", samples[0]["labels"]) + + def test_superdiagnostic_multilabel(self): + """A recording with both AF (CD) and low QRS voltage (HYP) codes.""" + with tempfile.TemporaryDirectory() as tmp: + signal = np.zeros((12, 5000), dtype=np.float32) + mat_path = 
self._make_mat_file(Path(tmp), "rec.mat", signal) + # 164889003 → CD (atrial fibrillation), 251146004 → HYP (low QRS voltage) + patient = self._make_patient(mat_path, "164889003,251146004") + + task = PTBXLMultilabelClassification(label_type="superdiagnostic") + samples = task(patient) + + self.assertEqual(len(samples), 1) + label_set = set(samples[0]["labels"]) + self.assertIn("CD", label_set) + self.assertIn("HYP", label_set) + + def test_superdiagnostic_no_known_codes_skipped(self): + """Recordings with no recognised superdiagnostic codes should be skipped.""" + with tempfile.TemporaryDirectory() as tmp: + signal = np.zeros((12, 5000), dtype=np.float32) + mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal) + patient = self._make_patient(mat_path, "999999999") # unknown code + + task = PTBXLMultilabelClassification(label_type="superdiagnostic") + samples = task(patient) + self.assertEqual(samples, []) + + # ------------------------------------------------------------------ + # Diagnostic (27-class) label mapping + # ------------------------------------------------------------------ + + def test_diagnostic_known_challenge_code(self): + # 270492004 = First-degree AV block, part of Challenge 2020 scoring set + with tempfile.TemporaryDirectory() as tmp: + signal = np.zeros((12, 5000), dtype=np.float32) + mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal) + patient = self._make_patient(mat_path, "270492004") + + task = PTBXLMultilabelClassification(label_type="diagnostic") + samples = task(patient) + + self.assertEqual(len(samples), 1) + self.assertIn("270492004", samples[0]["labels"]) + + def test_diagnostic_non_challenge_code_skipped(self): + """A code not in the 27-class Challenge vocabulary should be filtered out.""" + with tempfile.TemporaryDirectory() as tmp: + signal = np.zeros((12, 5000), dtype=np.float32) + mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal) + patient = self._make_patient(mat_path, "999999999") + + task = 
PTBXLMultilabelClassification(label_type="diagnostic")
+            samples = task(patient)
+            self.assertEqual(samples, [])
+
+    def test_diagnostic_multiple_valid_codes(self):
+        """Multiple Challenge codes in one recording should all appear as labels."""
+        with tempfile.TemporaryDirectory() as tmp:
+            signal = np.zeros((12, 5000), dtype=np.float32)
+            mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal)
+            # 164889003 (AF) and 426783006 (NSR) are both in Challenge vocabulary
+            patient = self._make_patient(mat_path, "164889003,426783006")
+
+            task = PTBXLMultilabelClassification(label_type="diagnostic")
+            samples = task(patient)
+
+            self.assertEqual(len(samples), 1)
+            label_set = set(samples[0]["labels"])
+            self.assertIn("164889003", label_set)
+            self.assertIn("426783006", label_set)
+
+    # ------------------------------------------------------------------
+    # Edge cases
+    # ------------------------------------------------------------------
+
+    def test_missing_signal_file_returns_empty(self):
+        event = _FakeEvent(signal_file="", scp_codes="426783006")
+        patient = _FakePatient("p001", [event])
+        task = PTBXLMultilabelClassification()
+        self.assertEqual(task(patient), [])
+
+    def test_nonexistent_signal_file_returns_empty(self):
+        event = _FakeEvent(signal_file="/nonexistent/path/to/rec.mat", scp_codes="426783006")
+        patient = _FakePatient("p001", [event])
+        task = PTBXLMultilabelClassification()
+        self.assertEqual(task(patient), [])
+
+    def test_empty_patient_no_events(self):
+        patient = _FakePatient("p001", [])
+        task = PTBXLMultilabelClassification()
+        self.assertEqual(task(patient), [])
+
+    def test_wrong_signal_shape_skipped(self):
+        """Signals that are not 2-D with 12 channels should be skipped."""
+        with tempfile.TemporaryDirectory() as tmp:
+            # Write a single-channel signal (shape 1×5000)
+            signal = np.zeros((1, 5000), dtype=np.float32)
+            mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal)
+            event = _FakeEvent(signal_file=mat_path, scp_codes="426783006")
+            patient = _FakePatient("p001", [event])
+
+            task = PTBXLMultilabelClassification(label_type="superdiagnostic")
+            samples = task(patient)
+            self.assertEqual(samples, [])
+
+    # ------------------------------------------------------------------
+    # Label-space constants sanity checks
+    # ------------------------------------------------------------------
+
+    def test_superdiag_classes_count(self):
+        self.assertEqual(len(SUPERDIAG_CLASSES), 5)
+        self.assertEqual(set(SUPERDIAG_CLASSES), {"NORM", "MI", "STTC", "CD", "HYP"})
+
+    def test_challenge_classes_count(self):
+        self.assertEqual(len(CHALLENGE_SNOMED_CLASSES), 27)
+
+    def test_snomed_to_superdiag_values(self):
+        valid_classes = set(SUPERDIAG_CLASSES)
+        for code, cls in SNOMED_TO_SUPERDIAG.items():
+            self.assertIn(
+                cls,
+                valid_classes,
+                f"SNOMED code {code} maps to unknown class '{cls}'",
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/core/test_ptbxl_task.py b/tests/core/test_ptbxl_task.py
new file mode 100644
index 000000000..1bc792d07
--- /dev/null
+++ b/tests/core/test_ptbxl_task.py
@@ -0,0 +1,305 @@
+"""Unit tests for PTBXLMultilabelClassification task.
+
+Test strategy
+-------------
+* All tests are self-contained and run fully offline — no network calls, no
+  real ECG data required.
+* Synthetic in-memory ECG arrays are written to temporary ``.mat`` files via
+  ``scipy.io.savemat`` so that the task's ``loadmat`` call round-trips cleanly.
+* ``_FakeEvent`` mirrors the event attributes produced by
+  ``PTBXLDataset.load_data()``: ``mat`` (file path) and ``dx_codes``
+  (SNOMED-CT codes joined by ``"."``).
+* Both ``label_type`` variants (``"superdiagnostic"`` / ``"diagnostic"``) and
+  both ``sampling_rate`` values (100 / 500) are exercised.
+
+Author:
+    CS-598 DLH Project Team
+"""
+
+import tempfile
+import unittest
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Optional
+
+import numpy as np
+
+from pyhealth.tasks.ptbxl_multilabel_classification import (
+    CHALLENGE_SNOMED_CLASSES,
+    SNOMED_TO_SUPERDIAG,
+    SUPERDIAG_CLASSES,
+    PTBXLMultilabelClassification,
+)
+
+
+# ---------------------------------------------------------------------------
+# Fake Patient / Event
+# ---------------------------------------------------------------------------
+
+@dataclass
+class _FakeEvent:
+    """Minimal stand-in for a PyHealth Event with PTB-XL attributes.
+
+    Attribute names match what PTBXLDataset.load_data() produces after the
+    ``ptbxl/`` table-prefix is stripped by BaseDataset:
+      - ``mat``      ← column ``ptbxl/mat``      (path to .mat signal file)
+      - ``dx_codes`` ← column ``ptbxl/dx_codes`` (SNOMED codes, dot-joined)
+    """
+    mat: str = ""
+    dx_codes: str = ""
+    age: int = 50
+    sex: str = "Male"
+
+
+class _FakePatient:
+    """Minimal stand-in for a PyHealth Patient that serves a fixed list of events."""
+
+    def __init__(self, patient_id: str, events: List[_FakeEvent]):
+        self.patient_id = patient_id
+        self._events = events
+
+    def get_events(self, event_type: Optional[str] = None) -> List[_FakeEvent]:
+        """Return all stored events; ``event_type`` is accepted but not used for filtering."""
+        return self._events
+
+
+# ---------------------------------------------------------------------------
+# Task unit tests
+# ---------------------------------------------------------------------------
+
+class TestPTBXLMultilabelClassification(unittest.TestCase):
+    """Test PTBXLMultilabelClassification.__call__ with synthetic ECG data."""
+
+    # ------------------------------------------------------------------
+    # Helpers
+    # ------------------------------------------------------------------
+
+    def _make_mat_file(self, tmp_dir: Path, name: str, signal: np.ndarray) -> str:
+        """Write a scipy .mat file containing 'val' and return its path."""
+        import scipy.io
+        path = tmp_dir / name
+        scipy.io.savemat(str(path), {"val": signal})
+        return str(path)
+
+    def _make_patient(self, mat_path: str, dx_codes: str) -> _FakePatient:
+        """Build fake patient ``p001`` holding one event with the given path and codes."""
+        event = _FakeEvent(mat=mat_path, dx_codes=dx_codes)
+        return _FakePatient("p001", [event])
+
+    # ------------------------------------------------------------------
+    # Constructor validation
+    # ------------------------------------------------------------------
+
+    def test_invalid_sampling_rate_raises(self):
+        """Constructing the task with sampling_rate=250 must raise ValueError."""
+        with self.assertRaises(ValueError):
+            PTBXLMultilabelClassification(sampling_rate=250)
+
+    def test_invalid_label_type_raises(self):
+        """Constructing the task with an unknown label_type must raise ValueError."""
+        with self.assertRaises(ValueError):
+            PTBXLMultilabelClassification(label_type="morphological")
+
+    def test_task_names_are_unique(self):
+        """Each (sampling_rate, label_type) combination should get its own task_name."""
+        t_a = PTBXLMultilabelClassification(sampling_rate=100, label_type="superdiagnostic")
+        t_b = PTBXLMultilabelClassification(sampling_rate=500, label_type="superdiagnostic")
+        t_c = PTBXLMultilabelClassification(sampling_rate=100, label_type="diagnostic")
+        t_d = PTBXLMultilabelClassification(sampling_rate=500, label_type="diagnostic")
+        names = {t_a.task_name, t_b.task_name, t_c.task_name, t_d.task_name}
+        self.assertEqual(len(names), 4, "All four configurations should have distinct task names.")
+
+    # ------------------------------------------------------------------
+    # Signal loading and decimation
+    # ------------------------------------------------------------------
+
+    def test_superdiagnostic_100hz_signal_shape(self):
+        """Superdiagnostic task at 100 Hz should yield (12, 1000) signals."""
+        with tempfile.TemporaryDirectory() as tmp:
+            signal_500 = np.random.randn(12, 5000).astype(np.float32)
+            mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal_500)
+            patient = self._make_patient(mat_path, "426783006")  # NORM
+
+            task = PTBXLMultilabelClassification(sampling_rate=100, label_type="superdiagnostic")
+            samples = task(patient)
+
+            self.assertEqual(len(samples), 1)
+            self.assertEqual(samples[0]["signal"].shape, (12, 1000))
+
+    def test_superdiagnostic_500hz_signal_shape(self):
+        """Superdiagnostic task at 500 Hz should yield (12, 5000) signals (no decimation)."""
+        with tempfile.TemporaryDirectory() as tmp:
+            signal_500 = np.random.randn(12, 5000).astype(np.float32)
+            mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal_500)
+            patient = self._make_patient(mat_path, "426783006")  # NORM
+
+            task = PTBXLMultilabelClassification(sampling_rate=500, label_type="superdiagnostic")
+            samples = task(patient)
+
+            self.assertEqual(len(samples), 1)
+            self.assertEqual(samples[0]["signal"].shape, (12, 5000))
+
+    def test_signal_dtype_is_float32(self):
+        """A float64 input signal should come back from the task as float32."""
+        with tempfile.TemporaryDirectory() as tmp:
+            signal_500 = np.random.randn(12, 5000).astype(np.float64)  # 64-bit input
+            mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal_500)
+            patient = self._make_patient(mat_path, "426783006")
+
+            task = PTBXLMultilabelClassification()
+            samples = task(patient)
+            self.assertEqual(len(samples), 1)
+            self.assertEqual(samples[0]["signal"].dtype, np.float32)
+
+    # ------------------------------------------------------------------
+    # Superdiagnostic label mapping
+    # ------------------------------------------------------------------
+
+    def test_superdiagnostic_normal_label(self):
+        """SNOMED code 426783006 should produce the NORM superdiagnostic label."""
+        with tempfile.TemporaryDirectory() as tmp:
+            signal = np.zeros((12, 5000), dtype=np.float32)
+            mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal)
+            patient = self._make_patient(mat_path, "426783006")  # → NORM
+
+            task = PTBXLMultilabelClassification(label_type="superdiagnostic")
+            samples = task(patient)
+
+            self.assertEqual(len(samples), 1)
+            self.assertIn("NORM", samples[0]["labels"])
+
+    def test_superdiagnostic_multilabel(self):
+        """A recording with both AF (CD) and low QRS voltage (HYP) codes."""
+        with tempfile.TemporaryDirectory() as tmp:
+            signal = np.zeros((12, 5000), dtype=np.float32)
+            mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal)
+            # 164889003 → CD (atrial fibrillation), 251146004 → HYP (low QRS voltage)
+            # dx_codes are dot-joined (load_data uses ".".join(dx))
+            patient = self._make_patient(mat_path, "164889003.251146004")
+
+            task = PTBXLMultilabelClassification(label_type="superdiagnostic")
+            samples = task(patient)
+
+            self.assertEqual(len(samples), 1)
+            label_set = set(samples[0]["labels"])
+            self.assertIn("CD", label_set)
+            self.assertIn("HYP", label_set)
+
+    def test_superdiagnostic_no_known_codes_skipped(self):
+        """Recordings with no recognised superdiagnostic codes should be skipped."""
+        with tempfile.TemporaryDirectory() as tmp:
+            signal = np.zeros((12, 5000), dtype=np.float32)
+            mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal)
+            patient = self._make_patient(mat_path, "999999999")  # unknown code
+
+            task = PTBXLMultilabelClassification(label_type="superdiagnostic")
+            samples = task(patient)
+            self.assertEqual(samples, [])
+
+    # ------------------------------------------------------------------
+    # Diagnostic (27-class) label mapping
+    # ------------------------------------------------------------------
+
+    def test_diagnostic_known_challenge_code(self):
+        # 270492004 = First-degree AV block, part of Challenge 2020 scoring set
+        with tempfile.TemporaryDirectory() as tmp:
+            signal = np.zeros((12, 5000), dtype=np.float32)
+            mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal)
+            patient = self._make_patient(mat_path, "270492004")
+
+            task = PTBXLMultilabelClassification(label_type="diagnostic")
+            samples = task(patient)
+
+            self.assertEqual(len(samples), 1)
+            self.assertIn("270492004", samples[0]["labels"])
+
+    def test_diagnostic_non_challenge_code_skipped(self):
+        """A code not in the 27-class Challenge vocabulary should be filtered out."""
+        with tempfile.TemporaryDirectory() as tmp:
+            signal = np.zeros((12, 5000), dtype=np.float32)
+            mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal)
+            patient = self._make_patient(mat_path, "999999999")
+
+            task = PTBXLMultilabelClassification(label_type="diagnostic")
+            samples = task(patient)
+            self.assertEqual(samples, [])
+
+    def test_diagnostic_multiple_valid_codes(self):
+        """Multiple Challenge codes in one recording should all appear as labels."""
+        with tempfile.TemporaryDirectory() as tmp:
+            signal = np.zeros((12, 5000), dtype=np.float32)
+            mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal)
+            # 164889003 (AF) and 426783006 (NSR) are both in Challenge vocabulary
+            # dot-joined as produced by load_data()
+            patient = self._make_patient(mat_path, "164889003.426783006")
+
+            task = PTBXLMultilabelClassification(label_type="diagnostic")
+            samples = task(patient)
+
+            self.assertEqual(len(samples), 1)
+            label_set = set(samples[0]["labels"])
+            self.assertIn("164889003", label_set)
+            self.assertIn("426783006", label_set)
+
+    # ------------------------------------------------------------------
+    # Edge cases
+    # ------------------------------------------------------------------
+
+    def test_missing_signal_file_returns_empty(self):
+        """An event with an empty signal path should yield no samples."""
+        event = _FakeEvent(mat="", dx_codes="426783006")
+        patient = _FakePatient("p001", [event])
+        task = PTBXLMultilabelClassification()
+        self.assertEqual(task(patient), [])
+
+    def test_nonexistent_signal_file_returns_empty(self):
+        """An event whose signal path does not exist should yield no samples."""
+        event = _FakeEvent(mat="/nonexistent/path/to/rec.mat", dx_codes="426783006")
+        patient = _FakePatient("p001", [event])
+        task = PTBXLMultilabelClassification()
+        self.assertEqual(task(patient), [])
+
+    def test_empty_patient_no_events(self):
+        """A patient without any events should yield no samples."""
+        patient = _FakePatient("p001", [])
+        task = PTBXLMultilabelClassification()
+        self.assertEqual(task(patient), [])
+
+    def test_wrong_signal_shape_skipped(self):
+        """Signals that are not 2-D with 12 channels should be skipped."""
+        with tempfile.TemporaryDirectory() as tmp:
+            signal = np.zeros((1, 5000), dtype=np.float32)  # single-channel, not 12
+            mat_path = self._make_mat_file(Path(tmp), "rec.mat", signal)
+            event = _FakeEvent(mat=mat_path, dx_codes="426783006")
+            patient = _FakePatient("p001", [event])
+
+            task = PTBXLMultilabelClassification(label_type="superdiagnostic")
+            samples = task(patient)
+            self.assertEqual(samples, [])
+
+    # ------------------------------------------------------------------
+    # Label-space constants sanity checks
+    # ------------------------------------------------------------------
+
+    def test_superdiag_classes_count(self):
+        """SUPERDIAG_CLASSES should hold exactly the five expected superclasses."""
+        self.assertEqual(len(SUPERDIAG_CLASSES), 5)
+        self.assertEqual(set(SUPERDIAG_CLASSES), {"NORM", "MI", "STTC", "CD", "HYP"})
+
+    def test_challenge_classes_count(self):
+        """CHALLENGE_SNOMED_CLASSES should hold 27 entries."""
+        self.assertEqual(len(CHALLENGE_SNOMED_CLASSES), 27)
+
+    def test_snomed_to_superdiag_values(self):
+        """Every SNOMED_TO_SUPERDIAG value should be a member of SUPERDIAG_CLASSES."""
+        valid_classes = set(SUPERDIAG_CLASSES)
+        for code, cls in SNOMED_TO_SUPERDIAG.items():
+            self.assertIn(
+                cls,
+                valid_classes,
+                f"SNOMED code {code} maps to unknown class '{cls}'",
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()