From 4b2b80f55a1a6d294769ff8d2c48ea51c23575ee Mon Sep 17 00:00:00 2001 From: Rian354 Date: Mon, 8 Dec 2025 03:08:40 -0500 Subject: [PATCH 1/8] medlink bounty implementation --- examples/medlink_mimic3.ipynb | 680 ++++++++++++++++++++++++++++ pyhealth/__init__.py | 1 + pyhealth/datasets/sample_dataset.py | 34 +- pyhealth/models/__init__.py | 3 +- pyhealth/models/medlink/model.py | 19 +- tests/core/test_medlink.py | 122 +++++ 6 files changed, 846 insertions(+), 13 deletions(-) create mode 100644 examples/medlink_mimic3.ipynb create mode 100644 tests/core/test_medlink.py diff --git a/examples/medlink_mimic3.ipynb b/examples/medlink_mimic3.ipynb new file mode 100644 index 000000000..9f408d889 --- /dev/null +++ b/examples/medlink_mimic3.ipynb @@ -0,0 +1,680 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "1ee5347e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PROJECT_ROOT: /Users/saurabhatri/Downloads/PyHealth\n", + "✓ PyTorch is installed\n", + "✓ pyhealth is importable, version: 1.1.4\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/saurabhatri/Downloads/PyHealth/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import os\n", + "import sys\n", + "\n", + "# Ensure project root is on sys.path when running from examples/\n", + "PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n", + "if PROJECT_ROOT not in sys.path:\n", + " sys.path.insert(0, PROJECT_ROOT)\n", + "\n", + "print(\"PROJECT_ROOT:\", PROJECT_ROOT)\n", + "\n", + "# Basic sanity check for torch and pyhealth\n", + "try:\n", + " import torch\n", + " print(\"✓ PyTorch is installed\")\n", + "except ImportError as e:\n", + " raise RuntimeError(\n", + " \"PyTorch is not installed. Install it into your environment first \"\n", + " \"(e.g., `pip install torch` matching your CUDA/CPU).\" \n", + " ) from e\n", + "\n", + "try:\n", + " import pyhealth\n", + " print(\"✓ pyhealth is importable, version:\", getattr(pyhealth, \"__version__\", \"unknown\"))\n", + "except ImportError as e:\n", + " raise RuntimeError(\n", + " \"pyhealth is not importable. 
From the project root, run\\n\"\n", + " \" pip install -e .\\n\"\n", + " \"to install PyHealth in editable mode.\"\n", + " ) from e\n", + "\n", + "# Core dataset + MedLink imports\n", + "from pyhealth.datasets import MIMIC3Dataset\n", + "from pyhealth.tasks import BaseTask\n", + "from pyhealth.models.medlink import (\n", + " BM25Okapi,\n", + " convert_to_ir_format,\n", + " filter_by_candidates,\n", + " generate_candidates,\n", + " get_bm25_hard_negatives,\n", + " get_eval_dataloader,\n", + " get_train_dataloader,\n", + " tvt_split,\n", + ")\n", + "from pyhealth.models.medlink.model import MedLink\n", + "from pyhealth.metrics import ranking_metrics_fn\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "240e358e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MIMIC-III demo root: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4\n" + ] + } + ], + "source": [ + "# Downloaded from: https://physionet.org/content/mimiciii-demo/1.4/\n", + "MIMIC3_DEMO_ROOT = \"/path/to/mimic-iii-clinical-database-demo-1.4\" # <-- adjust for real\n", + "\n", + "print(\"MIMIC-III demo root:\", MIMIC3_DEMO_ROOT)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f0851481", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No config path provided, using default config\n", + "Initializing mimic3 dataset from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4 (dev mode: False)\n", + "Scanning table: patients from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/PATIENTS.csv.gz\n", + "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/PATIENTS.csv\n", + "Scanning table: admissions from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv.gz\n", + "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv\n", + "Scanning table: icustays from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ICUSTAYS.csv.gz\n", + "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ICUSTAYS.csv\n", + "Scanning table: diagnoses_icd from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/DIAGNOSES_ICD.csv.gz\n", + "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/DIAGNOSES_ICD.csv\n", + "Joining with table: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv.gz\n", + "Original path does not exist. 
Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv\n", + "Collecting global event dataframe...\n", + "Collected dataframe with shape: (2126, 31)\n", + "Dataset: mimic3\n", + "Dev mode: False\n", + "Number of patients: 100\n", + "Number of events: 2126\n" + ] + } + ], + "source": [ + "# STEP 1: Load base MIMIC-III dataset from the demo\n", + "\n", + "base_dataset = MIMIC3Dataset(\n", + " root=MIMIC3_DEMO_ROOT,\n", + " tables=[\"diagnoses_icd\"], # matches `diagnoses_icd` in configs/mimic3.yaml\n", + " dev=False, # True => small subset of patients\n", + ")\n", + "\n", + "base_dataset.stats()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8969fcbd", + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "from collections import defaultdict\n", + "import math\n", + "\n", + "class PatientLinkageMIMIC3Task(BaseTask):\n", + " \"\"\"\n", + " Patient linkage task for MIMIC-III using the new Patient/Event API.\n", + "\n", + " It produces samples with the same keys as the old\n", + " `patient_linkage_mimic3_fn`, so that medlink.utils.convert_to_ir_format\n", + " works unchanged.\n", + " \"\"\"\n", + "\n", + " task_name = \"patient_linkage_mimic3\"\n", + " # MedLink actually consumes `conditions` / `d_conditions` as sequences,\n", + " # but we don't rely heavily on the feature processors here.\n", + " input_schema = {\n", + " \"conditions\": \"sequence\",\n", + " \"d_conditions\": \"sequence\",\n", + " }\n", + " # No supervised label for MedLink retrieval\n", + " output_schema = {}\n", + "\n", + " def __call__(self, patient):\n", + " \"\"\"\n", + " Process a single patient into MedLink samples.\n", + "\n", + " Requirements (same as original task):\n", + " - At least 2 visits (admissions)\n", + " - Age >= 18 at both visits\n", + " - Non-empty conditions for both visits\n", + " \"\"\"\n", + " # --- 1) Get admissions (visits), sorted by time ---\n", + " admissions = patient.get_events(event_type=\"admissions\")\n", + " if len(admissions) < 2:\n", + " return []\n", + "\n", + " admissions = sorted(admissions, key=lambda e: e.timestamp)\n", + " q_visit = admissions[-1] # last visit (query)\n", + " d_visit = admissions[-2] # second last visit (document)\n", + "\n", + " # --- 2) Get patient demographics (gender, dob) ---\n", + " patients_events = patient.get_events(event_type=\"patients\")\n", + " if not patients_events:\n", + " return []\n", + " demo = patients_events[0]\n", + "\n", + " gender = str(demo.attr_dict.get(\"gender\") or \"\")\n", + "\n", + " dob_raw = demo.attr_dict.get(\"dob\")\n", + " birth_dt = None\n", + " if isinstance(dob_raw, datetime):\n", + " birth_dt = dob_raw\n", + " elif dob_raw is not None:\n", + " # In the MIMIC CSV it's a string like \"2111-04-20 00:00:00\"\n", + " try:\n", + " birth_dt = datetime.fromisoformat(str(dob_raw))\n", + " except Exception:\n", + " birth_dt = None\n", + "\n", + " def compute_age(ts):\n", + " if birth_dt is None or ts is None:\n", + " return None\n", + " # rough years\n", + " return int((ts - birth_dt).days // 365.25)\n", + "\n", + " q_age = compute_age(q_visit.timestamp)\n", + " d_age = compute_age(d_visit.timestamp)\n", + "\n", + " # Exclude under 18 or missing age\n", + " if q_age is None or d_age is None or q_age < 18 or d_age < 18:\n", + " return []\n", + "\n", + " # --- 3) Collect diagnosis codes per admission (hadm_id) ---\n", + " diag_events = patient.get_events(event_type=\"diagnoses_icd\")\n", + " hadm_to_codes = defaultdict(list)\n", + " 
for ev in diag_events:\n", + " hadm = ev.attr_dict.get(\"hadm_id\")\n", + " code = ev.attr_dict.get(\"icd9_code\")\n", + " if hadm is None or code is None:\n", + " continue\n", + " hadm_to_codes[str(hadm)].append(str(code))\n", + "\n", + " q_hadm = str(q_visit.attr_dict.get(\"hadm_id\"))\n", + " d_hadm = str(d_visit.attr_dict.get(\"hadm_id\"))\n", + "\n", + " q_conditions = hadm_to_codes.get(q_hadm, [])\n", + " d_conditions = hadm_to_codes.get(d_hadm, [])\n", + "\n", + " # Exclude if any side has no conditions\n", + " if len(q_conditions) == 0 or len(d_conditions) == 0:\n", + " return []\n", + "\n", + " # --- 4) Identifier strings (gender + admin attributes) ---\n", + " def clean(x):\n", + " # mimic old NaN handling: empty string if missing/NaN\n", + " if x is None:\n", + " return \"\"\n", + " if isinstance(x, float) and math.isnan(x):\n", + " return \"\"\n", + " return str(x)\n", + "\n", + " def build_identifiers(adm_event):\n", + " insurance = clean(adm_event.attr_dict.get(\"insurance\"))\n", + " language = clean(adm_event.attr_dict.get(\"language\"))\n", + " religion = clean(adm_event.attr_dict.get(\"religion\"))\n", + " marital_status = clean(adm_event.attr_dict.get(\"marital_status\"))\n", + " ethnicity = clean(adm_event.attr_dict.get(\"ethnicity\"))\n", + " return \"+\".join(\n", + " [gender, insurance, language, religion, marital_status, ethnicity]\n", + " )\n", + "\n", + " q_identifiers = build_identifiers(q_visit)\n", + " d_identifiers = build_identifiers(d_visit)\n", + "\n", + " # --- 5) Build sample dict (same keys as old function) ---\n", + " sample = {\n", + " \"patient_id\": patient.patient_id,\n", + " \"visit_id\": q_hadm, # query visit_id\n", + " \"conditions\": [\"\"] + q_conditions,\n", + " \"age\": q_age,\n", + " \"identifiers\": q_identifiers,\n", + "\n", + " \"d_visit_id\": d_hadm, # document visit_id\n", + " \"d_conditions\": [\"\"] + d_conditions,\n", + " \"d_age\": d_age,\n", + " \"d_identifiers\": d_identifiers,\n", + " }\n", + "\n", + " return [sample]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "bce967de", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Setting task patient_linkage_mimic3 for mimic3 base dataset...\n", + "Generating samples with 1 worker(s)...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating samples for patient_linkage_mimic3 with 1 worker: 100%|██████████| 100/100 [00:00<00:00, 1499.64it/s]\n", + "Processing samples: 100%|██████████| 14/14 [00:00<00:00, 35246.25it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generated 14 samples for task patient_linkage_mimic3\n", + "Number of samples generated: 14\n", + "Example sample:\n", + " {'patient_id': '42346', 'visit_id': '175880', 'conditions': tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,\n", + " 19, 20]), 'age': 88, 'identifiers': 'F+Medicare+ENGL+NOT SPECIFIED+SINGLE+WHITE', 'd_visit_id': '180391', 'd_conditions': tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,\n", + " 19, 20, 21, 22, 23]), 'd_age': 88, 'd_identifiers': 'F+Medicare+ENGL+NOT SPECIFIED+SINGLE+WHITE'}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# STEP 3: Set the patient linkage task and build the sample dataset\n", + "\n", + "patient_linkage_task = PatientLinkageMIMIC3Task()\n", + "sample_dataset = base_dataset.set_task(task=patient_linkage_task)\n", + "\n", + 
"print(\"Number of samples generated:\", len(sample_dataset.samples))\n", + "if sample_dataset.samples:\n", + " print(\"Example sample:\\n\", sample_dataset.samples[0])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "831ac79a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Corpus / Query / Qrel summary: corpus=14, queries=14, qrels=14\n", + "Train queries: 9, Val queries: 2, Test queries: 3\n" + ] + } + ], + "source": [ + "# Convert samples to IR format and split train/val/test\n", + "from pyhealth.models.medlink import convert_to_ir_format, tvt_split\n", + "\n", + "corpus, queries, qrels, corpus_meta, queries_meta = convert_to_ir_format(\n", + " sample_dataset.samples\n", + ")\n", + "\n", + "tr_queries, va_queries, te_queries, tr_qrels, va_qrels, te_qrels = tvt_split(queries, qrels)\n", + "\n", + "print(f\"Corpus / Query / Qrel summary: corpus={len(corpus)}, queries={len(queries)}, qrels={len(qrels)}\")\n", + "print(f\"Train queries: {len(tr_queries)}, Val queries: {len(va_queries)}, Test queries: {len(te_queries)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "1f69690e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded 9 training pairs.\n", + "Loaded 2 training pairs.\n", + "Loaded 14 eval corpus.\n", + "Loaded 3 eval queries.\n", + "query_id 9\n", + "id_p 9\n", + "s_q 9\n", + "s_p 9\n" + ] + } + ], + "source": [ + "USE_BM25_HARDNEGS = False\n", + "\n", + "# Optionally refine training qrels with BM25-based hard negatives\n", + "if USE_BM25_HARDNEGS:\n", + " bm25_model = BM25Okapi(corpus)\n", + " tr_qrels = get_bm25_hard_negatives(\n", + " bm25_model, corpus, tr_queries, tr_qrels\n", + " )\n", + "\n", + "# STEP 4: Dataloaders for training / validation / test\n", + "train_dataloader = get_train_dataloader(\n", + " corpus=corpus,\n", + " queries=tr_queries,\n", + " qrels=tr_qrels,\n", + " batch_size=32,\n", + " shuffle=True,\n", + ")\n", + "\n", + "val_dataloader = get_train_dataloader(\n", + " corpus=corpus,\n", + " queries=va_queries,\n", + " qrels=va_qrels,\n", + " batch_size=32,\n", + " shuffle=False,\n", + ")\n", + "\n", + "test_corpus_dataloader, test_queries_dataloader = get_eval_dataloader(\n", + " corpus=corpus,\n", + " queries=te_queries,\n", + " batch_size=32,\n", + ")\n", + "\n", + "batch = next(iter(train_dataloader))\n", + "for k, v in batch.items():\n", + " print(k, type(v), (len(v) if hasattr(v, \"__len__\") else None))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "eae98819", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded 9 training pairs.\n", + "dict_keys(['query_id', 'id_p', 's_q', 's_p'])\n" + ] + } + ], + "source": [ + "# Build train_loader for MedLink (run this before the Step 5 MedLink cell)\n", + "\n", + "from pyhealth.models.medlink import get_train_dataloader, tvt_split\n", + "\n", + "tr_queries, va_queries, te_queries, tr_qrels, va_qrels, te_qrels = tvt_split(\n", + " queries, qrels\n", + ")\n", + "\n", + "train_loader = get_train_dataloader(\n", + " corpus=corpus,\n", + " queries=tr_queries,\n", + " qrels=tr_qrels,\n", + " batch_size=32,\n", + " shuffle=True,\n", + ")\n", + "\n", + "# quick sanity check\n", + "batch = next(iter(train_loader))\n", + "print(batch.keys())\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c877b5ba", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "Raw batch keys: dict_keys(['query_id', 'id_p', 's_q', 's_p'])\n", + "MedLink outputs keys: dict_keys(['loss'])\n", + "Loss: 32.6289176940918\n", + "Backward pass completed.\n" + ] + } + ], + "source": [ + "import torch\n", + "from pyhealth.models import BaseModel\n", + "from pyhealth.datasets import SampleDataset\n", + "from pyhealth.models.medlink.model import MedLink\n", + "\n", + "# ---------------------------------------------------------\n", + "# 1) Patch BaseModel.__init__ so MedLink's legacy kwargs are ignored\n", + "# ---------------------------------------------------------\n", + "if not hasattr(BaseModel, \"_orig_init_for_medlink\"):\n", + " BaseModel._orig_init_for_medlink = BaseModel.__init__\n", + "\n", + " def _patched_bm_init(self, dataset=None, *args, **kwargs):\n", + " # MedLink passes feature_keys, label_key, mode; ignore them here\n", + " return BaseModel._orig_init_for_medlink(self, dataset=dataset)\n", + "\n", + " BaseModel.__init__ = _patched_bm_init\n", + "\n", + "# ---------------------------------------------------------\n", + "# 2) Patch SampleDataset.get_all_tokens used in MedLink.__init__\n", + "# ---------------------------------------------------------\n", + "if not hasattr(SampleDataset, \"get_all_tokens\"):\n", + " def _get_all_tokens(self, key, remove_duplicates=True, sort=False):\n", + " tokens = []\n", + "\n", + " for sample in self.samples:\n", + " if key not in sample:\n", + " continue\n", + " value = sample[key]\n", + "\n", + " # Flatten nested lists/tuples\n", + " stack = [value]\n", + " while stack:\n", + " cur = stack.pop()\n", + " if isinstance(cur, (list, tuple)):\n", + " stack.extend(cur)\n", + " else:\n", + " tokens.append(cur)\n", + "\n", + " if remove_duplicates:\n", + " seen = set()\n", + " uniq = []\n", + " for t in tokens:\n", + " if t in seen:\n", + " continue\n", + " seen.add(t)\n", + " uniq.append(t)\n", + " tokens = uniq\n", + "\n", + " if sort:\n", + " try:\n", + " tokens = sorted(tokens)\n", + " except Exception:\n", + " pass\n", + "\n", + " return tokens\n", + "\n", + " SampleDataset.get_all_tokens = _get_all_tokens\n", + "\n", + "# ---------------------------------------------------------\n", + "# 3) Helper: normalize sequences so tokenizer sees lists, not tensors\n", + "# ---------------------------------------------------------\n", + "def _normalize_seqs(obj):\n", + " \"\"\"\n", + " Convert batch field (tensor or list of tensors/lists) into\n", + " List[List[str]] as expected by Tokenizer.batch_encode_2d.\n", + " \"\"\"\n", + " if torch.is_tensor(obj):\n", + " obj = obj.tolist() # -> list[list[int]]\n", + "\n", + " seqs_out = []\n", + " for seq in obj:\n", + " if torch.is_tensor(seq):\n", + " seq = seq.tolist()\n", + " # at this point seq is list[int] or list[str]\n", + " seqs_out.append([str(tok) for tok in seq])\n", + " return seqs_out\n", + "\n", + "# ---------------------------------------------------------\n", + "# 4) Instantiate MedLink and run a single forward/backward pass\n", + "# ---------------------------------------------------------\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "# sample_dataset and train_loader must already be defined in earlier cells\n", + "model = MedLink(\n", + " dataset=sample_dataset,\n", + " feature_keys=[\"conditions\"],\n", + " embedding_dim=128,\n", + ").to(device)\n", + "\n", + "# Take one batch from the MedLink train dataloader\n", + "batch = next(iter(train_loader))\n", + "print(\"Raw batch keys:\", 
batch.keys())\n", + "\n", + "# Normalize the sequence fields so AdmissionPrediction/Tokenizer work\n", + "if \"s_q\" in batch:\n", + " batch[\"s_q\"] = _normalize_seqs(batch[\"s_q\"])\n", + "if \"s_p\" in batch:\n", + " batch[\"s_p\"] = _normalize_seqs(batch[\"s_p\"])\n", + "if \"s_n\" in batch and batch[\"s_n\"] is not None:\n", + " batch[\"s_n\"] = _normalize_seqs(batch[\"s_n\"])\n", + "\n", + "model.train()\n", + "outputs = model(**batch)\n", + "print(\"MedLink outputs keys:\", outputs.keys())\n", + "print(\"Loss:\", float(outputs[\"loss\"]))\n", + "\n", + "outputs[\"loss\"].backward()\n", + "print(\"Backward pass completed.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "03113472", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 0: avg loss = 30.2854\n", + "epoch 1: avg loss = 34.0946\n", + "epoch 2: avg loss = 27.1799\n" + ] + } + ], + "source": [ + "#Sanity\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", + "\n", + "for epoch in range(3):\n", + " total = 0.0\n", + " n = 0\n", + " for batch in train_loader:\n", + " # normalize s_q / s_p as before\n", + " batch[\"s_q\"] = _normalize_seqs(batch[\"s_q\"])\n", + " batch[\"s_p\"] = _normalize_seqs(batch[\"s_p\"])\n", + " if \"s_n\" in batch and batch[\"s_n\"] is not None:\n", + " batch[\"s_n\"] = _normalize_seqs(batch[\"s_n\"])\n", + "\n", + " optimizer.zero_grad()\n", + " out = model(**batch)\n", + " loss = out[\"loss\"]\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " total += float(loss)\n", + " n += 1\n", + " print(f\"epoch {epoch}: avg loss = {total / max(n,1):.4f}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "ed96b498", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing samples: 100%|███████████████████████| 2/2 [00:00<00:00, 6641.81it/s]\n", + "Processing samples: 100%|██████████████████████| 2/2 [00:00<00:00, 60787.01it/s]\n", + "Processing samples: 100%|██████████████████████| 2/2 [00:00<00:00, 64527.75it/s]\n", + ".\n", + "----------------------------------------------------------------------\n", + "Ran 3 tests in 0.037s\n", + "\n", + "OK\n" + ] + } + ], + "source": [ + "!python /Users/saurabhatri/Downloads/PyHealth/tests/core/test_medlink.py" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e8b452d2", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyhealth/__init__.py b/pyhealth/__init__.py index efd7e39b7..722483dcb 100755 --- a/pyhealth/__init__.py +++ b/pyhealth/__init__.py @@ -18,3 +18,4 @@ formatter = logging.Formatter("%(message)s") handler.setFormatter(formatter) logger.addHandler(handler) + diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 54c40420c..3aeeb9616 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -86,6 +86,38 @@ def __init__( self.validate() self.build() + def get_all_tokens(self, key: str) -> List[str]: + """ + Collect all tokens under a given key across samples. 
+ + This is mainly used by MedLink to build its vocabulary. + It assumes that sample[key] is either: + - a sequence (list/tuple) of tokens, or + - a scalar token (str/int/etc.). + """ + tokens: List[str] = [] + seen = set() + + for sample in self.samples: + if key not in sample: + continue + value = sample[key] + + if isinstance(value, (list, tuple)): + values = value + else: + values = [value] + + for v in values: + if v is None: + continue + s = str(v) + if s in seen: + continue + seen.add(s) + tokens.append(s) + + return tokens def _get_processor_instance(self, processor_spec): """Get processor instance from either string alias, class reference, processor instance, or tuple with kwargs. @@ -182,4 +214,4 @@ def __len__(self) -> int: Returns: int: The number of samples. """ - return len(self.samples) + return len(self.samples) \ No newline at end of file diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5c3683bc1..ee606158f 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -26,4 +26,5 @@ from .transformer import Transformer, TransformerLayer from .transformers_model import TransformersModel from .vae import VAE -from .sdoh import SdohClassifier \ No newline at end of file +from .sdoh import SdohClassifier +from .medlink import MedLink diff --git a/pyhealth/models/medlink/model.py b/pyhealth/models/medlink/model.py index f4e2b4ddb..1dea4fe5e 100644 --- a/pyhealth/models/medlink/model.py +++ b/pyhealth/models/medlink/model.py @@ -103,12 +103,7 @@ def __init__( **kwargs, ): assert len(feature_keys) == 1, "MedLink only supports one feature key" - super(MedLink, self).__init__( - dataset=dataset, - feature_keys=feature_keys, - label_key=None, - mode=None, - ) + super(MedLink, self).__init__(dataset=dataset) self.feature_key = feature_keys[0] self.embedding_dim = embedding_dim self.alpha = alpha @@ -121,13 +116,15 @@ def __init__( special_tokens=["", "", ""], ) self.fwd_adm_pred = AdmissionPrediction(tokenizer, embedding_dim, **kwargs) + self.forward_encoder = self.fwd_adm_pred.encoder self.bwd_adm_pred = AdmissionPrediction(tokenizer, embedding_dim, **kwargs) + self.backward_encoder = self.bwd_adm_pred.encoder self.criterion = nn.CrossEntropyLoss() - self.vocabs_size = tokenizer.get_vocabulary_size() + self.vocab_size = tokenizer.get_vocabulary_size() return def encode_queries(self, queries: List[str]): - all_vocab = torch.tensor(list(range(self.vocabs_size)), device=self.device) + all_vocab = torch.tensor(list(range(self.vocab_size)), device=self.device) bwd_vocab_emb = self.bwd_adm_pred.embedding(all_vocab) pred_corpus, queries_one_hot = self.bwd_adm_pred( queries, bwd_vocab_emb, device=self.device @@ -137,7 +134,7 @@ def encode_queries(self, queries: List[str]): return queries_emb def encode_corpus(self, corpus: List[str]): - all_vocab = torch.tensor(list(range(self.vocabs_size)), device=self.device) + all_vocab = torch.tensor(list(range(self.vocab_size)), device=self.device) fwd_vocab_emb = self.fwd_adm_pred.embedding(all_vocab) pred_queries, corpus_one_hot = self.fwd_adm_pred( corpus, fwd_vocab_emb, device=self.device @@ -165,7 +162,7 @@ def get_loss(self, scores): def forward(self, query_id, id_p, s_q, s_p, s_n=None) -> Dict[str, torch.Tensor]: corpus = s_p if s_n is None else s_p + s_n queries = s_q - all_vocab = torch.tensor(list(range(self.vocabs_size)), device=self.device) + all_vocab = torch.tensor(list(range(self.vocab_size)), device=self.device) fwd_vocab_emb = self.fwd_adm_pred.embedding(all_vocab) bwd_vocab_emb = 
self.bwd_adm_pred.embedding(all_vocab) pred_queries, corpus_one_hot = self.fwd_adm_pred( @@ -262,4 +259,4 @@ def evaluate(self, corpus_dataloader, queries_dataloader): with torch.autograd.detect_anomaly(): o = model(**batch) print("loss:", o["loss"]) - o["loss"].backward() + o["loss"].backward() \ No newline at end of file diff --git a/tests/core/test_medlink.py b/tests/core/test_medlink.py new file mode 100644 index 000000000..c718d3dc9 --- /dev/null +++ b/tests/core/test_medlink.py @@ -0,0 +1,122 @@ +import unittest +import torch + +from pyhealth.datasets import SampleDataset +from pyhealth.models import MedLink + + +class TestMedLink(unittest.TestCase): + """Basic tests for the MedLink model on pseudo data.""" + + def setUp(self): + # Each "sample" here is a simple patient-record placeholder + # The dataset is only used to build the vocabulary via get_all_tokens. + self.samples = [ + { + "patient_id": "p0", + "visit_id": "v0", + # query-side codes + "conditions": ["A", "B", "C"], + # corpus-side codes ("d_" + feature_key) + "d_conditions": ["A", "D"], + }, + { + "patient_id": "p1", + "visit_id": "v1", + "conditions": ["B", "E"], + "d_conditions": ["C", "E", "F"], + }, + ] + + # Two sequence-type inputs: conditions and d_conditions + self.input_schema = { + "conditions": "sequence", + "d_conditions": "sequence", + } + # No labels are needed; MedLink is self-supervised + self.output_schema = {} + + self.dataset = SampleDataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="medlink_test", + ) + + self.model = MedLink( + dataset=self.dataset, + feature_keys=["conditions"], + embedding_dim=32, + alpha=0.5, + beta=0.5, + gamma=1.0, + ) + + def _make_batch(self): + # Construct a tiny batch in the format expected by MedLink.forward + # s_q: list of query sequences + s_q = [ + ["A", "B", "C"], + ["B", "E"], + ] + # s_p: list of positive corpus sequences + s_p = [ + ["A", "D"], + ["C", "E", "F"], + ] + # Optionally you could also define negatives s_n = [...] 
+ batch = { + "query_id": ["q0", "q1"], + "id_p": ["p0", "p1"], + "s_q": s_q, + "s_p": s_p, + # no s_n -> defaults to None + } + return batch + + def test_model_initialization(self): + """Model constructs with correct vocabulary size and encoders.""" + self.assertIsInstance(self.model, MedLink) + self.assertEqual(self.model.feature_key, "conditions") + self.assertGreater(self.model.vocab_size, 0) + self.assertIsNotNone(self.model.forward_encoder) + self.assertIsNotNone(self.model.backward_encoder) + + def test_forward_and_backward(self): + """Forward pass returns a scalar loss and backward computes gradients.""" + batch = self._make_batch() + + # Forward + ret = self.model(**batch) + self.assertIn("loss", ret) + loss = ret["loss"] + self.assertTrue(torch.is_tensor(loss)) + self.assertEqual(loss.dim(), 0) # scalar + + # Backward + loss.backward() + has_grad = any( + p.requires_grad and p.grad is not None + for p in self.model.parameters() + ) + self.assertTrue(has_grad, "No gradients after backward pass") + + def test_encoding_helpers(self): + """encode_queries / encode_corpus produce consistent shapes.""" + queries = [["A", "B"], ["C"]] + corpus = [["A"], ["B", "C"]] + + q_emb = self.model.encode_queries(queries) + c_emb = self.model.encode_corpus(corpus) + + self.assertEqual(q_emb.shape[1], self.model.vocab_size) + self.assertEqual(c_emb.shape[1], self.model.vocab_size) + self.assertEqual(q_emb.shape[0], len(queries)) + self.assertEqual(c_emb.shape[0], len(corpus)) + + scores = self.model.compute_scores(q_emb, c_emb) + self.assertEqual(scores.shape, (len(queries), len(corpus))) + + +if __name__ == "__main__": + unittest.main() From b14293833dacb929e58626ca358795eaa2b308b6 Mon Sep 17 00:00:00 2001 From: Rian354 Date: Mon, 8 Dec 2025 12:50:22 -0500 Subject: [PATCH 2/8] Notebook Clean Up --- examples/medlink_mimic3.ipynb | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/medlink_mimic3.ipynb b/examples/medlink_mimic3.ipynb index 9f408d889..83a6160a4 100644 --- a/examples/medlink_mimic3.ipynb +++ b/examples/medlink_mimic3.ipynb @@ -624,7 +624,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "id": "ed96b498", "metadata": {}, "outputs": [ @@ -644,7 +644,8 @@ } ], "source": [ - "!python /Users/saurabhatri/Downloads/PyHealth/tests/core/test_medlink.py" + "#Medlink Unit Tests\n", + "!pytest tests/core/test_medlink.py" ] }, { From 8369ed39b329cbda56f18886b3a86436be5eee5d Mon Sep 17 00:00:00 2001 From: Rian354 Date: Mon, 8 Dec 2025 12:55:14 -0500 Subject: [PATCH 3/8] Further notebook modification --- examples/medlink_mimic3.ipynb | 82 +++++++++++++---------------------- 1 file changed, 31 insertions(+), 51 deletions(-) diff --git a/examples/medlink_mimic3.ipynb b/examples/medlink_mimic3.ipynb index 83a6160a4..58d45141a 100644 --- a/examples/medlink_mimic3.ipynb +++ b/examples/medlink_mimic3.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 25, "id": "1ee5347e", "metadata": {}, "outputs": [ @@ -11,16 +11,8 @@ "output_type": "stream", "text": [ "PROJECT_ROOT: /Users/saurabhatri/Downloads/PyHealth\n", - "✓ PyTorch is installed\n", - "✓ pyhealth is importable, version: 1.1.4\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/saurabhatri/Downloads/PyHealth/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" + "PyTorch is installed\n", + "pyhealth is importable, version: 1.1.4\n" ] } ], @@ -38,21 +30,18 @@ "# Basic sanity check for torch and pyhealth\n", "try:\n", " import torch\n", - " print(\"✓ PyTorch is installed\")\n", + " print(\"PyTorch is installed\")\n", "except ImportError as e:\n", " raise RuntimeError(\n", - " \"PyTorch is not installed. Install it into your environment first \"\n", - " \"(e.g., `pip install torch` matching your CUDA/CPU).\" \n", + " \"PyTorch is not installed. Install it into your environment first \" \n", " ) from e\n", "\n", "try:\n", " import pyhealth\n", - " print(\"✓ pyhealth is importable, version:\", getattr(pyhealth, \"__version__\", \"unknown\"))\n", + " print(\"pyhealth is importable, version:\", getattr(pyhealth, \"__version__\", \"unknown\"))\n", "except ImportError as e:\n", " raise RuntimeError(\n", - " \"pyhealth is not importable. From the project root, run\\n\"\n", - " \" pip install -e .\\n\"\n", - " \"to install PyHealth in editable mode.\"\n", + " \"pyhealth is not importable.\"\n", " ) from e\n", "\n", "# Core dataset + MedLink imports\n", @@ -95,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "f0851481", "metadata": {}, "outputs": [ @@ -125,11 +114,11 @@ } ], "source": [ - "# STEP 1: Load base MIMIC-III dataset from the demo\n", + "# Load base MIMIC-III dataset from the demo\n", "\n", "base_dataset = MIMIC3Dataset(\n", " root=MIMIC3_DEMO_ROOT,\n", - " tables=[\"diagnoses_icd\"], # matches `diagnoses_icd` in configs/mimic3.yaml\n", + " tables=[\"diagnoses_icd\"], # matches in configs/mimic3.yaml\n", " dev=False, # True => small subset of patients\n", ")\n", "\n", @@ -138,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "8969fcbd", "metadata": {}, "outputs": [], @@ -152,12 +141,12 @@ " Patient linkage task for MIMIC-III using the new Patient/Event API.\n", "\n", " It produces samples with the same keys as the old\n", - " `patient_linkage_mimic3_fn`, so that medlink.utils.convert_to_ir_format\n", + " 'patient_linkage_mimic3_fn', so that medlink.utils.convert_to_ir_format\n", " works unchanged.\n", " \"\"\"\n", "\n", " task_name = \"patient_linkage_mimic3\"\n", - " # MedLink actually consumes `conditions` / `d_conditions` as sequences,\n", + " # MedLink actually consumes conditions / d_conditions as sequences,\n", " # but we don't rely heavily on the feature processors here.\n", " input_schema = {\n", " \"conditions\": \"sequence\",\n", @@ -175,7 +164,7 @@ " - Age >= 18 at both visits\n", " - Non-empty conditions for both visits\n", " \"\"\"\n", - " # --- 1) Get admissions (visits), sorted by time ---\n", + " # Get admissions (visits), sorted by time\n", " admissions = patient.get_events(event_type=\"admissions\")\n", " if len(admissions) < 2:\n", " return []\n", @@ -184,7 +173,7 @@ " q_visit = admissions[-1] # last visit (query)\n", " d_visit = admissions[-2] # second last visit (document)\n", "\n", - " # --- 2) Get patient demographics (gender, dob) ---\n", + " # get patient demographics (gender, dob)\n", " patients_events = patient.get_events(event_type=\"patients\")\n", " if not patients_events:\n", " return []\n", @@ -216,7 +205,7 @@ " if q_age is None or d_age is None or q_age < 18 or d_age < 18:\n", " return []\n", "\n", - " # --- 3) Collect diagnosis codes per admission (hadm_id) ---\n", + " # collect diagnosis codes per 
admission (hadm_id)\n", " diag_events = patient.get_events(event_type=\"diagnoses_icd\")\n", " hadm_to_codes = defaultdict(list)\n", " for ev in diag_events:\n", @@ -236,7 +225,7 @@ " if len(q_conditions) == 0 or len(d_conditions) == 0:\n", " return []\n", "\n", - " # --- 4) Identifier strings (gender + admin attributes) ---\n", + " #id strings\n", " def clean(x):\n", " # mimic old NaN handling: empty string if missing/NaN\n", " if x is None:\n", @@ -258,7 +247,7 @@ " q_identifiers = build_identifiers(q_visit)\n", " d_identifiers = build_identifiers(d_visit)\n", "\n", - " # --- 5) Build sample dict (same keys as old function) ---\n", + " #Build sample dict\n", " sample = {\n", " \"patient_id\": patient.patient_id,\n", " \"visit_id\": q_hadm, # query visit_id\n", @@ -277,7 +266,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "bce967de", "metadata": {}, "outputs": [ @@ -318,7 +307,7 @@ } ], "source": [ - "# STEP 3: Set the patient linkage task and build the sample dataset\n", + "# patient linkage task and build the sample dataset\n", "\n", "patient_linkage_task = PatientLinkageMIMIC3Task()\n", "sample_dataset = base_dataset.set_task(task=patient_linkage_task)\n", @@ -359,7 +348,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "1f69690e", "metadata": {}, "outputs": [ @@ -381,14 +370,14 @@ "source": [ "USE_BM25_HARDNEGS = False\n", "\n", - "# Optionally refine training qrels with BM25-based hard negatives\n", + "# optionally refine training qrels with BM25-based hard negatives\n", "if USE_BM25_HARDNEGS:\n", " bm25_model = BM25Okapi(corpus)\n", " tr_qrels = get_bm25_hard_negatives(\n", " bm25_model, corpus, tr_queries, tr_qrels\n", " )\n", "\n", - "# STEP 4: Dataloaders for training / validation / test\n", + "#Dataloaders for training / validation / test\n", "train_dataloader = get_train_dataloader(\n", " corpus=corpus,\n", " queries=tr_queries,\n", @@ -418,7 +407,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "eae98819", "metadata": {}, "outputs": [ @@ -432,7 +421,7 @@ } ], "source": [ - "# Build train_loader for MedLink (run this before the Step 5 MedLink cell)\n", + "# Build train_loader for MedLink\n", "\n", "from pyhealth.models.medlink import get_train_dataloader, tvt_split\n", "\n", @@ -455,7 +444,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "c877b5ba", "metadata": {}, "outputs": [ @@ -465,7 +454,7 @@ "text": [ "Raw batch keys: dict_keys(['query_id', 'id_p', 's_q', 's_p'])\n", "MedLink outputs keys: dict_keys(['loss'])\n", - "Loss: 32.6289176940918\n", + "Loss: 29.643001556396484\n", "Backward pass completed.\n" ] } @@ -476,9 +465,7 @@ "from pyhealth.datasets import SampleDataset\n", "from pyhealth.models.medlink.model import MedLink\n", "\n", - "# ---------------------------------------------------------\n", - "# 1) Patch BaseModel.__init__ so MedLink's legacy kwargs are ignored\n", - "# ---------------------------------------------------------\n", + "# Patch BaseModel.__init__ so MedLink's legacy kwargs are ignored\n", "if not hasattr(BaseModel, \"_orig_init_for_medlink\"):\n", " BaseModel._orig_init_for_medlink = BaseModel.__init__\n", "\n", @@ -487,10 +474,7 @@ " return BaseModel._orig_init_for_medlink(self, dataset=dataset)\n", "\n", " BaseModel.__init__ = _patched_bm_init\n", - "\n", - "# ---------------------------------------------------------\n", - "# 2) Patch SampleDataset.get_all_tokens used in MedLink.__init__\n", - 
"# ---------------------------------------------------------\n", + "#monkey patching, corrected now\n", "if not hasattr(SampleDataset, \"get_all_tokens\"):\n", " def _get_all_tokens(self, key, remove_duplicates=True, sort=False):\n", " tokens = []\n", @@ -529,9 +513,7 @@ "\n", " SampleDataset.get_all_tokens = _get_all_tokens\n", "\n", - "# ---------------------------------------------------------\n", - "# 3) Helper: normalize sequences so tokenizer sees lists, not tensors\n", - "# ---------------------------------------------------------\n", + "# normalize sequences so tokenizer sees lists, not tensors\n", "def _normalize_seqs(obj):\n", " \"\"\"\n", " Convert batch field (tensor or list of tensors/lists) into\n", @@ -548,9 +530,7 @@ " seqs_out.append([str(tok) for tok in seq])\n", " return seqs_out\n", "\n", - "# ---------------------------------------------------------\n", - "# 4) Instantiate MedLink and run a single forward/backward pass\n", - "# ---------------------------------------------------------\n", + "#init medlink and run a single forward/backward pass\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# sample_dataset and train_loader must already be defined in earlier cells\n", From 456b5bceba4c8a9068bd4c039164feed583f1c4a Mon Sep 17 00:00:00 2001 From: Rian354 Date: Mon, 8 Dec 2025 14:57:14 -0500 Subject: [PATCH 4/8] Removed redecleration of methods --- examples/medlink_mimic3.ipynb | 112 +++++++--------------------------- 1 file changed, 23 insertions(+), 89 deletions(-) diff --git a/examples/medlink_mimic3.ipynb b/examples/medlink_mimic3.ipynb index 58d45141a..48fdf2fa4 100644 --- a/examples/medlink_mimic3.ipynb +++ b/examples/medlink_mimic3.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 25, + "execution_count": 15, "id": "1ee5347e", "metadata": {}, "outputs": [ @@ -10,7 +10,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "PROJECT_ROOT: /Users/saurabhatri/Downloads/PyHealth\n", + "PROJECT_ROOT: /Users/saurabhatri/Downloads\n", "PyTorch is installed\n", "pyhealth is importable, version: 1.1.4\n" ] @@ -78,13 +78,13 @@ "source": [ "# Downloaded from: https://physionet.org/content/mimiciii-demo/1.4/\n", "MIMIC3_DEMO_ROOT = \"/path/to/mimic-iii-clinical-database-demo-1.4\" # <-- adjust for real\n", - "\n", + "#MIMIC3_DEMO_ROOT = \"/Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4\"\n", "print(\"MIMIC-III demo root:\", MIMIC3_DEMO_ROOT)\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "id": "f0851481", "metadata": {}, "outputs": [ @@ -127,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "id": "8969fcbd", "metadata": {}, "outputs": [], @@ -266,7 +266,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "id": "bce967de", "metadata": {}, "outputs": [ @@ -282,8 +282,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "Generating samples for patient_linkage_mimic3 with 1 worker: 100%|██████████| 100/100 [00:00<00:00, 1499.64it/s]\n", - "Processing samples: 100%|██████████| 14/14 [00:00<00:00, 35246.25it/s]" + "Generating samples for patient_linkage_mimic3 with 1 worker: 100%|██████████| 100/100 [00:00<00:00, 1569.65it/s]\n", + "Processing samples: 100%|██████████| 14/14 [00:00<00:00, 1738.06it/s]" ] }, { @@ -293,9 +293,7 @@ "Generated 14 samples for task patient_linkage_mimic3\n", "Number of samples generated: 14\n", "Example sample:\n", - " {'patient_id': '42346', 'visit_id': 
'175880', 'conditions': tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,\n", - " 19, 20]), 'age': 88, 'identifiers': 'F+Medicare+ENGL+NOT SPECIFIED+SINGLE+WHITE', 'd_visit_id': '180391', 'd_conditions': tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,\n", - " 19, 20, 21, 22, 23]), 'd_age': 88, 'd_identifiers': 'F+Medicare+ENGL+NOT SPECIFIED+SINGLE+WHITE'}\n" + " {'patient_id': '44083', 'visit_id': '198330', 'conditions': tensor([1, 2, 3, 4, 5, 6, 7]), 'age': 54, 'identifiers': 'M+Private+ENGL+CATHOLIC+SINGLE+WHITE', 'd_visit_id': '131048', 'd_conditions': tensor([1, 2, 3, 4, 5, 6]), 'd_age': 54, 'd_identifiers': 'M+Private+ENGL+CATHOLIC+SINGLE+WHITE'}\n" ] }, { @@ -319,7 +317,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 20, "id": "831ac79a", "metadata": {}, "outputs": [ @@ -348,7 +346,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "id": "1f69690e", "metadata": {}, "outputs": [ @@ -407,7 +405,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "id": "eae98819", "metadata": {}, "outputs": [ @@ -444,7 +442,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "id": "c877b5ba", "metadata": {}, "outputs": [ @@ -454,7 +452,7 @@ "text": [ "Raw batch keys: dict_keys(['query_id', 'id_p', 's_q', 's_p'])\n", "MedLink outputs keys: dict_keys(['loss'])\n", - "Loss: 29.643001556396484\n", + "Loss: 24.614633560180664\n", "Backward pass completed.\n" ] } @@ -465,54 +463,6 @@ "from pyhealth.datasets import SampleDataset\n", "from pyhealth.models.medlink.model import MedLink\n", "\n", - "# Patch BaseModel.__init__ so MedLink's legacy kwargs are ignored\n", - "if not hasattr(BaseModel, \"_orig_init_for_medlink\"):\n", - " BaseModel._orig_init_for_medlink = BaseModel.__init__\n", - "\n", - " def _patched_bm_init(self, dataset=None, *args, **kwargs):\n", - " # MedLink passes feature_keys, label_key, mode; ignore them here\n", - " return BaseModel._orig_init_for_medlink(self, dataset=dataset)\n", - "\n", - " BaseModel.__init__ = _patched_bm_init\n", - "#monkey patching, corrected now\n", - "if not hasattr(SampleDataset, \"get_all_tokens\"):\n", - " def _get_all_tokens(self, key, remove_duplicates=True, sort=False):\n", - " tokens = []\n", - "\n", - " for sample in self.samples:\n", - " if key not in sample:\n", - " continue\n", - " value = sample[key]\n", - "\n", - " # Flatten nested lists/tuples\n", - " stack = [value]\n", - " while stack:\n", - " cur = stack.pop()\n", - " if isinstance(cur, (list, tuple)):\n", - " stack.extend(cur)\n", - " else:\n", - " tokens.append(cur)\n", - "\n", - " if remove_duplicates:\n", - " seen = set()\n", - " uniq = []\n", - " for t in tokens:\n", - " if t in seen:\n", - " continue\n", - " seen.add(t)\n", - " uniq.append(t)\n", - " tokens = uniq\n", - "\n", - " if sort:\n", - " try:\n", - " tokens = sorted(tokens)\n", - " except Exception:\n", - " pass\n", - "\n", - " return tokens\n", - "\n", - " SampleDataset.get_all_tokens = _get_all_tokens\n", - "\n", "# normalize sequences so tokenizer sees lists, not tensors\n", "def _normalize_seqs(obj):\n", " \"\"\"\n", @@ -563,7 +513,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 24, "id": "03113472", "metadata": {}, "outputs": [ @@ -571,9 +521,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "epoch 0: avg loss = 30.2854\n", - "epoch 1: avg loss = 34.0946\n", - "epoch 2: avg loss = 27.1799\n" + "epoch 0: avg loss = 24.8576\n", + 
"epoch 1: avg loss = 24.3715\n", + "epoch 2: avg loss = 17.6286\n" ] } ], @@ -604,34 +554,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "id": "ed96b498", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Processing samples: 100%|███████████████████████| 2/2 [00:00<00:00, 6641.81it/s]\n", - "Processing samples: 100%|██████████████████████| 2/2 [00:00<00:00, 60787.01it/s]\n", - "Processing samples: 100%|██████████████████████| 2/2 [00:00<00:00, 64527.75it/s]\n", - ".\n", - "----------------------------------------------------------------------\n", - "Ran 3 tests in 0.037s\n", - "\n", - "OK\n" - ] - } - ], + "outputs": [], "source": [ - "#Medlink Unit Tests\n", - "!pytest tests/core/test_medlink.py" + "#Unit test script - pytest tests/core/test_medlink.py\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "e8b452d2", + "id": "b77a6ee2", "metadata": {}, "outputs": [], "source": [] @@ -653,7 +587,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.13.3" } }, "nbformat": 4, From 851f696457c0f541e818c809bd7d332e84494438 Mon Sep 17 00:00:00 2001 From: Rian354 Date: Mon, 22 Dec 2025 01:27:39 -0500 Subject: [PATCH 5/8] MedLink bounty, processor-native model + tests + MIMIC-III notebook --- examples/medlink_mimic3.ipynb | 436 ++++++++++++--------- examples/test_eICU_addition.py | 25 +- pyhealth/datasets/sample_dataset.py | 31 -- pyhealth/models/embedding.py | 118 +++++- pyhealth/models/medlink.py | 423 +++++++++++++++++++++ pyhealth/models/medlink/model.py | 462 +++++++++++++++-------- pyhealth/tasks/__init__.py | 1 + pyhealth/tasks/patient_linkage_mimic3.py | 106 ++++++ tests/__init__.py | 2 + tests/core/test_medlink.py | 5 +- tests/core/test_sdoh.py | 2 +- tests/nlp/test_metrics.py | 2 +- tests/todo/test_datasets/test_eicu.py | 2 +- tests/todo/test_datasets/test_mimic3.py | 4 +- tests/todo/test_datasets/test_mimic4.py | 19 +- tests/todo/test_datasets/test_omop.py | 4 +- tests/todo/test_mortality_prediction.py | 6 +- 17 files changed, 1257 insertions(+), 391 deletions(-) create mode 100644 pyhealth/models/medlink.py create mode 100644 pyhealth/tasks/patient_linkage_mimic3.py create mode 100644 tests/__init__.py diff --git a/examples/medlink_mimic3.ipynb b/examples/medlink_mimic3.ipynb index 48fdf2fa4..421aeeefc 100644 --- a/examples/medlink_mimic3.ipynb +++ b/examples/medlink_mimic3.ipynb @@ -2,15 +2,28 @@ "cells": [ { "cell_type": "code", - "execution_count": 15, + "execution_count": 1, "id": "1ee5347e", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2025-12-22T06:22:52.819501Z", + "iopub.status.busy": "2025-12-22T06:22:52.819215Z", + "iopub.status.idle": "2025-12-22T06:22:59.559734Z", + "shell.execute_reply": "2025-12-22T06:22:59.559502Z" + } + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "PROJECT_ROOT: /Users/saurabhatri/Downloads\n", + "PROJECT_ROOT: /Users/saurabhatri/Downloads/PyHealth\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "PyTorch is installed\n", "pyhealth is importable, version: 1.1.4\n" ] @@ -65,7 +78,14 @@ "cell_type": "code", "execution_count": null, "id": "240e358e", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2025-12-22T06:22:59.561383Z", + "iopub.status.busy": "2025-12-22T06:22:59.561144Z", + "iopub.status.idle": "2025-12-22T06:22:59.563040Z", + "shell.execute_reply": 
"2025-12-22T06:22:59.562832Z" + } + }, "outputs": [ { "name": "stdout", @@ -84,28 +104,119 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 3, "id": "f0851481", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2025-12-22T06:22:59.564228Z", + "iopub.status.busy": "2025-12-22T06:22:59.564145Z", + "iopub.status.idle": "2025-12-22T06:22:59.579422Z", + "shell.execute_reply": "2025-12-22T06:22:59.579179Z" + } + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "No config path provided, using default config\n", - "Initializing mimic3 dataset from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4 (dev mode: False)\n", - "Scanning table: patients from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/PATIENTS.csv.gz\n", - "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/PATIENTS.csv\n", - "Scanning table: admissions from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv.gz\n", - "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv\n", - "Scanning table: icustays from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ICUSTAYS.csv.gz\n", - "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ICUSTAYS.csv\n", - "Scanning table: diagnoses_icd from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/DIAGNOSES_ICD.csv.gz\n", - "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/DIAGNOSES_ICD.csv\n", - "Joining with table: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv.gz\n", - "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv\n", - "Collecting global event dataframe...\n", - "Collected dataframe with shape: (2126, 31)\n", + "No config path provided, using default config\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Initializing mimic3 dataset from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4 (dev mode: False)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scanning table: patients from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/PATIENTS.csv.gz\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/PATIENTS.csv\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scanning table: admissions from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv.gz\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scanning table: icustays from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ICUSTAYS.csv.gz\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original path does not exist. 
Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ICUSTAYS.csv\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scanning table: diagnoses_icd from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/DIAGNOSES_ICD.csv.gz\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/DIAGNOSES_ICD.csv\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Joining with table: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv.gz\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting global event dataframe...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collected dataframe with shape: (2126, 31)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "Dataset: mimic3\n", "Dev mode: False\n", "Number of patients: 100\n", @@ -127,154 +238,48 @@ }, { "cell_type": "code", - "execution_count": 18, - "id": "8969fcbd", - "metadata": {}, + "execution_count": 4, + "id": "5d18d87c", + "metadata": { + "execution": { + "iopub.execute_input": "2025-12-22T06:22:59.580593Z", + "iopub.status.busy": "2025-12-22T06:22:59.580517Z", + "iopub.status.idle": "2025-12-22T06:22:59.582093Z", + "shell.execute_reply": "2025-12-22T06:22:59.581921Z" + } + }, "outputs": [], "source": [ + "from pyhealth.tasks.patient_linkage_mimic3 import PatientLinkageMIMIC3Task\n", "from datetime import datetime\n", "from collections import defaultdict\n", - "import math\n", - "\n", - "class PatientLinkageMIMIC3Task(BaseTask):\n", - " \"\"\"\n", - " Patient linkage task for MIMIC-III using the new Patient/Event API.\n", - "\n", - " It produces samples with the same keys as the old\n", - " 'patient_linkage_mimic3_fn', so that medlink.utils.convert_to_ir_format\n", - " works unchanged.\n", - " \"\"\"\n", - "\n", - " task_name = \"patient_linkage_mimic3\"\n", - " # MedLink actually consumes conditions / d_conditions as sequences,\n", - " # but we don't rely heavily on the feature processors here.\n", - " input_schema = {\n", - " \"conditions\": \"sequence\",\n", - " \"d_conditions\": \"sequence\",\n", - " }\n", - " # No supervised label for MedLink retrieval\n", - " output_schema = {}\n", - "\n", - " def __call__(self, patient):\n", - " \"\"\"\n", - " Process a single patient into MedLink samples.\n", - "\n", - " Requirements (same as original task):\n", - " - At least 2 visits (admissions)\n", - " - Age >= 18 at both visits\n", - " - Non-empty conditions for both visits\n", - " \"\"\"\n", - " # Get admissions (visits), sorted by time\n", - " admissions = patient.get_events(event_type=\"admissions\")\n", - " if len(admissions) < 2:\n", - " return []\n", - "\n", - " admissions = sorted(admissions, key=lambda e: e.timestamp)\n", - " q_visit = admissions[-1] # last visit (query)\n", - " d_visit = admissions[-2] # second last visit (document)\n", - "\n", - " # get patient demographics (gender, dob)\n", - " patients_events = patient.get_events(event_type=\"patients\")\n", - " if not patients_events:\n", - " return []\n", - " demo = patients_events[0]\n", - "\n", - " gender = 
str(demo.attr_dict.get(\"gender\") or \"\")\n", - "\n", - " dob_raw = demo.attr_dict.get(\"dob\")\n", - " birth_dt = None\n", - " if isinstance(dob_raw, datetime):\n", - " birth_dt = dob_raw\n", - " elif dob_raw is not None:\n", - " # In the MIMIC CSV it's a string like \"2111-04-20 00:00:00\"\n", - " try:\n", - " birth_dt = datetime.fromisoformat(str(dob_raw))\n", - " except Exception:\n", - " birth_dt = None\n", - "\n", - " def compute_age(ts):\n", - " if birth_dt is None or ts is None:\n", - " return None\n", - " # rough years\n", - " return int((ts - birth_dt).days // 365.25)\n", - "\n", - " q_age = compute_age(q_visit.timestamp)\n", - " d_age = compute_age(d_visit.timestamp)\n", - "\n", - " # Exclude under 18 or missing age\n", - " if q_age is None or d_age is None or q_age < 18 or d_age < 18:\n", - " return []\n", - "\n", - " # collect diagnosis codes per admission (hadm_id)\n", - " diag_events = patient.get_events(event_type=\"diagnoses_icd\")\n", - " hadm_to_codes = defaultdict(list)\n", - " for ev in diag_events:\n", - " hadm = ev.attr_dict.get(\"hadm_id\")\n", - " code = ev.attr_dict.get(\"icd9_code\")\n", - " if hadm is None or code is None:\n", - " continue\n", - " hadm_to_codes[str(hadm)].append(str(code))\n", - "\n", - " q_hadm = str(q_visit.attr_dict.get(\"hadm_id\"))\n", - " d_hadm = str(d_visit.attr_dict.get(\"hadm_id\"))\n", - "\n", - " q_conditions = hadm_to_codes.get(q_hadm, [])\n", - " d_conditions = hadm_to_codes.get(d_hadm, [])\n", - "\n", - " # Exclude if any side has no conditions\n", - " if len(q_conditions) == 0 or len(d_conditions) == 0:\n", - " return []\n", - "\n", - " #id strings\n", - " def clean(x):\n", - " # mimic old NaN handling: empty string if missing/NaN\n", - " if x is None:\n", - " return \"\"\n", - " if isinstance(x, float) and math.isnan(x):\n", - " return \"\"\n", - " return str(x)\n", - "\n", - " def build_identifiers(adm_event):\n", - " insurance = clean(adm_event.attr_dict.get(\"insurance\"))\n", - " language = clean(adm_event.attr_dict.get(\"language\"))\n", - " religion = clean(adm_event.attr_dict.get(\"religion\"))\n", - " marital_status = clean(adm_event.attr_dict.get(\"marital_status\"))\n", - " ethnicity = clean(adm_event.attr_dict.get(\"ethnicity\"))\n", - " return \"+\".join(\n", - " [gender, insurance, language, religion, marital_status, ethnicity]\n", - " )\n", - "\n", - " q_identifiers = build_identifiers(q_visit)\n", - " d_identifiers = build_identifiers(d_visit)\n", - "\n", - " #Build sample dict\n", - " sample = {\n", - " \"patient_id\": patient.patient_id,\n", - " \"visit_id\": q_hadm, # query visit_id\n", - " \"conditions\": [\"\"] + q_conditions,\n", - " \"age\": q_age,\n", - " \"identifiers\": q_identifiers,\n", - "\n", - " \"d_visit_id\": d_hadm, # document visit_id\n", - " \"d_conditions\": [\"\"] + d_conditions,\n", - " \"d_age\": d_age,\n", - " \"d_identifiers\": d_identifiers,\n", - " }\n", - "\n", - " return [sample]\n" + "import math\n" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 5, "id": "bce967de", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2025-12-22T06:22:59.583209Z", + "iopub.status.busy": "2025-12-22T06:22:59.583137Z", + "iopub.status.idle": "2025-12-22T06:22:59.666946Z", + "shell.execute_reply": "2025-12-22T06:22:59.666677Z" + } + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Setting task patient_linkage_mimic3 for mimic3 base dataset...\n", + "Setting task patient_linkage_mimic3 for mimic3 base dataset...\n" + ] + }, + { 
+ "name": "stdout", + "output_type": "stream", + "text": [ "Generating samples with 1 worker(s)...\n" ] }, @@ -282,18 +287,57 @@ "name": "stderr", "output_type": "stream", "text": [ - "Generating samples for patient_linkage_mimic3 with 1 worker: 100%|██████████| 100/100 [00:00<00:00, 1569.65it/s]\n", - "Processing samples: 100%|██████████| 14/14 [00:00<00:00, 1738.06it/s]" + "\r\n", + "Generating samples for patient_linkage_mimic3 with 1 worker: 0%| | 0/100 [00:00 List[str]: - """ - Collect all tokens under a given key across samples. - - This is mainly used by MedLink to build its vocabulary. - It assumes that sample[key] is either: - - a sequence (list/tuple) of tokens, or - - a scalar token (str/int/etc.). - """ - tokens: List[str] = [] - seen = set() - for sample in self.samples: - if key not in sample: - continue - value = sample[key] - - if isinstance(value, (list, tuple)): - values = value - else: - values = [value] - - for v in values: - if v is None: - continue - s = str(v) - if s in seen: - continue - seen.add(s) - tokens.append(s) - - return tokens def _get_processor_instance(self, processor_spec): """Get processor instance from either string alias, class reference, processor instance, or tuple with kwargs. diff --git a/pyhealth/models/embedding.py b/pyhealth/models/embedding.py index 3b5a39d3c..7c1280162 100644 --- a/pyhealth/models/embedding.py +++ b/pyhealth/models/embedding.py @@ -1,4 +1,7 @@ -from typing import Dict +from __future__ import annotations + +from typing import Dict, Any, Optional, Union +import os import torch import torch.nn as nn @@ -18,6 +21,94 @@ ) from .base_model import BaseModel + +def _iter_text_vectors( + path: str, + embedding_dim: int, + wanted_tokens: set[str], + encoding: str = "utf-8", +) -> Dict[str, torch.Tensor]: + """Loads word vectors from a text file (e.g., GloVe) for a subset of tokens. + + Expected format: one token per line followed by embedding_dim floats. + + This function reads the file line-by-line and only retains vectors for + tokens present in `wanted_tokens`. + """ + + if not os.path.exists(path): + raise FileNotFoundError(f"pretrained embedding file not found: {path}") + + vectors: Dict[str, torch.Tensor] = {} + with open(path, "r", encoding=encoding) as f: + for line in f: + line = line.strip() + if not line: + continue + parts = line.split() + # token + embedding_dim values + if len(parts) < embedding_dim + 1: + continue + token = parts[0] + if token not in wanted_tokens: + continue + try: + vec = torch.tensor( + [float(x) for x in parts[1 : embedding_dim + 1]], + dtype=torch.float, + ) + except ValueError: + continue + vectors[token] = vec + return vectors + + +def init_embedding_with_pretrained( + embedding: nn.Embedding, + code_vocab: Dict[Any, int], + pretrained_path: str, + embedding_dim: int, + pad_token: str = "", + unk_token: str = "", + normalize: bool = False, + freeze: bool = False, +) -> int: + """Initializes an nn.Embedding from a pretrained text-vector file. + + Tokens not found in the pretrained file are left as the module's existing + random initialization. + + Returns: + int: number of tokens successfully initialized from the file. 
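+
+    Example (illustrative sketch only; ``glove.128d.txt`` is a placeholder path
+    for a GloVe-style vector file, not part of this patch):
+        >>> vocab = {"250.0": 0, "401.9": 1}
+        >>> emb = nn.Embedding(len(vocab), 128)
+        >>> loaded = init_embedding_with_pretrained(emb, vocab, "glove.128d.txt", embedding_dim=128)
+        >>> loaded <= len(vocab)
+        True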
+ """ + + # Build wanted token set (stringified) + vocab_tokens = {str(t) for t in code_vocab.keys()} + vectors = _iter_text_vectors(pretrained_path, embedding_dim, vocab_tokens) + + loaded = 0 + with torch.no_grad(): + for tok, idx in code_vocab.items(): + tok_s = str(tok) + if tok_s in vectors: + vec = vectors[tok_s] + if normalize: + vec = vec / (vec.norm(p=2) + 1e-12) + embedding.weight[idx].copy_(vec) + loaded += 1 + + # Ensure pad row is zero + if pad_token in code_vocab: + embedding.weight[code_vocab[pad_token]].zero_() + # If embedding has a padding_idx, keep it consistent + if embedding.padding_idx is not None: + embedding.weight[embedding.padding_idx].zero_() + + if freeze: + embedding.weight.requires_grad_(False) + + return loaded + class EmbeddingModel(BaseModel): """ EmbeddingModel is responsible for creating embedding layers for different types of input data. @@ -46,7 +137,14 @@ class EmbeddingModel(BaseModel): - MultiHotProcessor: nn.Linear over multi-hot vector """ - def __init__(self, dataset: SampleDataset, embedding_dim: int = 128): + def __init__( + self, + dataset: SampleDataset, + embedding_dim: int = 128, + pretrained_emb_path: Optional[Union[str, Dict[str, str]]] = None, + freeze_pretrained: bool = False, + normalize_pretrained: bool = False, + ): super().__init__(dataset) self.embedding_dim = embedding_dim self.embedding_layers = nn.ModuleDict() @@ -81,6 +179,22 @@ def __init__(self, dataset: SampleDataset, embedding_dim: int = 128): padding_idx=0, ) + # Optional pretrained initialization (e.g., GloVe). + if pretrained_emb_path is not None: + if isinstance(pretrained_emb_path, str): + path = pretrained_emb_path + else: + path = pretrained_emb_path.get(field_name) + if path: + init_embedding_with_pretrained( + self.embedding_layers[field_name], + processor.code_vocab, + path, + embedding_dim=embedding_dim, + normalize=normalize_pretrained, + freeze=freeze_pretrained, + ) + # Numeric features (including deep nested floats) -> nn.Linear over last dim elif isinstance( processor, diff --git a/pyhealth/models/medlink.py b/pyhealth/models/medlink.py new file mode 100644 index 000000000..f11334e53 --- /dev/null +++ b/pyhealth/models/medlink.py @@ -0,0 +1,423 @@ +from typing import Dict, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +import tqdm + +from pyhealth.datasets import SampleEHRDataset +from pyhealth.models import BaseModel +from pyhealth.models.transformer import TransformerLayer +from pyhealth.tokenizer import Tokenizer + + +def batch_to_multi_hot(label_batch: torch.Tensor, num_classes: int) -> torch.Tensor: + """ + Convert a 2D batch of label indices into a multi-hot representation. + + Parameters + ---------- + label_batch: + Long tensor of shape (batch_size, seq_len) with token indices. + num_classes: + Size of vocabulary. + + Returns + ------- + multi_hot: + Float tensor of shape (batch_size, num_classes), entries in {0,1}. + """ + # label_batch: (B, T) + batch_size, seq_len = label_batch.shape + flat = label_batch.view(-1) # (B*T,) + # Build index for scatter + row_idx = torch.arange(batch_size, device=label_batch.device).repeat_interleave(seq_len) + multi_hot = torch.zeros(batch_size, num_classes, device=label_batch.device, dtype=torch.float32) + multi_hot.index_put_((row_idx, flat), torch.ones_like(flat, dtype=torch.float32), accumulate=True) + multi_hot.clamp_max_(1.0) + return multi_hot + + +class AdmissionEncoder(nn.Module): + """ + Encodes a sequence of discrete tokens (code sequence) for MedLink. 
+ + It uses: + - a learnable embedding over the vocabulary + - a TransformerLayer backbone + - a BCE-with-logits loss over multi-hot targets + """ + + def __init__( + self, + tokenizer: Tokenizer, + embedding_dim: int, + heads: int = 2, + dropout: float = 0.5, + num_layers: int = 1, + ) -> None: + super().__init__() + self.tokenizer = tokenizer + self.vocab_size = tokenizer.get_vocabulary_size() + + self.embedding = nn.Embedding( + num_embeddings=self.vocab_size, + embedding_dim=embedding_dim, + padding_idx=tokenizer.get_padding_index(), + ) + + self.encoder = TransformerLayer( + feature_size=embedding_dim, + heads=heads, + dropout=dropout, + num_layers=num_layers, + ) + + self.criterion = nn.BCEWithLogitsLoss() + + def _encode_tokens(self, seqs: List[List[str]], device: torch.device): + """ + Turn a batch of token sequences into contextual embeddings and a padding mask. + + seqs: list of list of token strings, e.g. [["250.0","401.9"], ["414.0"], ...] + """ + token_ids = self.tokenizer.batch_encode_2d(seqs, padding=True) + token_ids = torch.tensor(token_ids, dtype=torch.long, device=device) # (B, T) + pad_idx = self.tokenizer.get_padding_index() + mask = token_ids != pad_idx # (B, T) + + emb = self.embedding(token_ids) # (B, T, D) + encoded, _ = self.encoder(emb) # (B, T, D) + return encoded, mask, token_ids + + def _multi_hot_targets(self, token_ids: torch.Tensor) -> torch.Tensor: + """ + Build a multi-hot target vector for each sequence in the batch. + """ + multi_hot = batch_to_multi_hot(token_ids, self.vocab_size) # (B, V) + # Clear special tokens + pad_id = self.tokenizer.vocabulary("") + cls_id = self.tokenizer.vocabulary("") + if pad_id is not None: + multi_hot[:, pad_id] = 0.0 + if cls_id is not None: + multi_hot[:, cls_id] = 0.0 + return multi_hot + + def logits_and_targets( + self, + seqs: List[List[str]], + vocab_embeddings: torch.Tensor, + device: torch.device, + ): + """ + Compute: + - per-token logits against vocab embeddings + - multi-hot label vectors for the sequence. + + Returns + ------- + logits: (B, V) tensor + targets: (B, V) tensor multi-hot + """ + encoded, mask, token_ids = self._encode_tokens(seqs, device=device) # (B,T,D), (B,T), (B,T) + targets = self._multi_hot_targets(token_ids) # (B,V) + + # encoded: (B,T,D), vocab_embeddings: (V,D) + # per-token logits: (B,T,V) + logits = torch.matmul(encoded, vocab_embeddings.T) # (B,T,V) + # mask padded positions with large negative value + mask_expanded = mask.unsqueeze(-1) # (B,T,1) + logits = logits.masked_fill(~mask_expanded, -1e9) + # max-pool over time + logits = logits.max(dim=1).values # (B,V) + + return logits, targets + + def classification_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """ + BCE loss on multi-hot labels; handles potential size mismatches defensively. + """ + # In case of tiny mismatches, truncate to the smaller dimension. + batch = min(logits.size(0), targets.size(0)) + logits = logits[:batch] + targets = targets[:batch] + return self.criterion(logits, targets) + + +class MedLink(BaseModel): + """ + MedLink: de-identified patient record linkage model (KDD 2023). 
+ + This implementation is adapted to the PyHealth framework: + * no pre-trained GloVe; embeddings are learned from scratch + * training monitored via loss instead of ranking metrics + + It implements three losses: + - forward admission prediction (corpus -> queries) + - backward admission prediction (queries -> corpus) + - retrieval loss via TF-IDF-style matching + """ + + def __init__( + self, + dataset: SampleEHRDataset, + feature_keys: List[str], + embedding_dim: int = 128, + alpha: float = 0.5, + beta: float = 0.5, + gamma: float = 1.0, + heads: int = 2, + dropout: float = 0.5, + num_layers: int = 1, + **kwargs, + ) -> None: + # MedLink is defined over a single textual / code sequence feature + assert len(feature_keys) == 1, "MedLink supports exactly one feature key" + # BaseModel only accepts dataset parameter, not feature_keys, label_key, or mode + super().__init__(dataset=dataset) + # Set feature_keys manually since BaseModel extracts it from dataset.input_schema + # but MedLink needs to use the provided feature_keys + self.feature_keys = feature_keys + self.feature_key = feature_keys[0] + self.embedding_dim = embedding_dim + self.alpha = alpha + self.beta = beta + self.gamma = gamma + + # Build vocabulary for both queries and corpus sides + q_tokens = dataset.get_all_tokens(key=self.feature_key) + d_tokens = dataset.get_all_tokens(key="d_" + self.feature_key) + + tokenizer = Tokenizer( + tokens=q_tokens + d_tokens, + special_tokens=["", "", ""], + ) + self.tokenizer = tokenizer + self.vocab_size = tokenizer.get_vocabulary_size() + + # Two direction-specific encoders (forward / backward) + self.forward_encoder = AdmissionEncoder( + tokenizer=tokenizer, + embedding_dim=embedding_dim, + heads=heads, + dropout=dropout, + num_layers=num_layers, + ) + self.backward_encoder = AdmissionEncoder( + tokenizer=tokenizer, + embedding_dim=embedding_dim, + heads=heads, + dropout=dropout, + num_layers=num_layers, + ) + + # Retrieval / ranking loss + self.rank_loss = nn.CrossEntropyLoss() + + # ------------------------ + # Encoding utilities + # ------------------------ + def _all_vocab_ids(self) -> torch.Tensor: + return torch.arange(self.vocab_size, device=self.device, dtype=torch.long) + + def encode_queries(self, queries: List[List[str]]) -> torch.Tensor: + """ + Encode query records into embeddings for retrieval. + + queries: list of token sequences, e.g. [["250.0","401.9"], ...] + Returns: (num_queries, vocab_size) embedding matrix. + """ + all_vocab = self._all_vocab_ids() # (V,) + bwd_vocab_emb = self.backward_encoder.embedding(all_vocab) # (V,D) + + logits, multi_hot = self.backward_encoder.logits_and_targets( + seqs=queries, + vocab_embeddings=bwd_vocab_emb, + device=self.device, + ) + logits = torch.log1p(F.relu(logits)) # smooth nonlinearity + return logits + multi_hot # (Q,V) + + def encode_corpus(self, corpus: List[List[str]]) -> torch.Tensor: + """ + Encode corpus records into embeddings for retrieval. + + corpus: list of token sequences. + Returns: (num_docs, vocab_size) embedding matrix. 
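+
+        Example (illustrative; ``model`` is an instantiated MedLink):
+            >>> corpus_emb = model.encode_corpus([["250.0", "401.9"], ["414.0"]])
+            >>> corpus_emb.shape   # torch.Size([2, vocab_size])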
+ """ + all_vocab = self._all_vocab_ids() + fwd_vocab_emb = self.forward_encoder.embedding(all_vocab) # (V,D) + + logits, multi_hot = self.forward_encoder.logits_and_targets( + seqs=corpus, + vocab_embeddings=fwd_vocab_emb, + device=self.device, + ) + logits = torch.log1p(F.relu(logits)) + return logits + multi_hot # (D,V) + + # ------------------------ + # Retrieval scoring + # ------------------------ + @staticmethod + def compute_scores(queries_emb: torch.Tensor, corpus_emb: torch.Tensor) -> torch.Tensor: + """ + Compute TF-IDF-like matching scores between queries and corpus. + + queries_emb: (Q,V) + corpus_emb: (D,V) + + Returns: + scores: (Q,D) + """ + # Inverse document frequency per term + n_docs = torch.tensor(corpus_emb.shape[0], device=corpus_emb.device, dtype=torch.float32) + df = (corpus_emb > 0).sum(dim=0) # (V,) + idf = torch.log1p(n_docs) - torch.log1p(df) + + # Term-frequency contribution per (query, doc, term) + tf = torch.einsum("qv,dv->qdv", queries_emb, corpus_emb) # (Q,D,V) + tf_idf = tf * idf # broadcast idf over last dim + + scores = tf_idf.sum(dim=-1) # (Q,D) + return scores + + def get_loss(self, scores: torch.Tensor) -> torch.Tensor: + """ + Retrieval loss: each query is matched to its corresponding positive + document at the same index. + """ + num_queries = scores.size(0) + target = torch.arange(num_queries, device=scores.device, dtype=torch.long) + return self.rank_loss(scores, target) + + # ------------------------ + # Training forward + # ------------------------ + def forward( + self, + query_id, + id_p, + s_q, + s_p, + s_n=None, + ) -> Dict[str, torch.Tensor]: + """ + Forward pass used for training. + + Parameters in the batch (dict passed as **batch): + - query_id: list of query identifiers (unused by the loss) + - id_p: list of positive record ids (unused here, used for evaluation) + - s_q: list of query sequences (list[list[str]]) + - s_p: list of positive corpus sequences (list[list[str]]) + - s_n: optional list of negative corpus sequences (list[list[str]]) + + Returns + ------- + dict with key "loss": scalar tensor. 
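+
+        Example (illustrative; ``batch`` is one batch from the MedLink train
+        dataloader):
+
+        >>> out = model(**batch)
+        >>> out["loss"].backward()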
+ """ + # Build full corpus: positives plus negatives if provided + if s_n is None: + corpus = s_p + else: + corpus = s_p + s_n + queries = s_q + + # Precompute vocab embeddings for both encoders + all_vocab = self._all_vocab_ids() + fwd_vocab_emb = self.forward_encoder.embedding(all_vocab) # (V,D) + bwd_vocab_emb = self.backward_encoder.embedding(all_vocab) # (V,D) + + # Forward and backward admission prediction losses + # Corpus -> query distributions + pred_queries, corpus_targets = self.forward_encoder.logits_and_targets( + seqs=corpus, + vocab_embeddings=fwd_vocab_emb, + device=self.device, + ) + # Query -> corpus distributions + pred_corpus, query_targets = self.backward_encoder.logits_and_targets( + seqs=queries, + vocab_embeddings=bwd_vocab_emb, + device=self.device, + ) + + fwd_cls_loss = self.forward_encoder.classification_loss(pred_queries, query_targets) + bwd_cls_loss = self.backward_encoder.classification_loss(pred_corpus, corpus_targets) + + # Turn predictions into dense embeddings + pred_queries_act = torch.log1p(F.relu(pred_queries)) + pred_corpus_act = torch.log1p(F.relu(pred_corpus)) + + corpus_emb = corpus_targets + pred_queries_act + queries_emb = query_targets + pred_corpus_act + + scores = self.compute_scores(queries_emb, corpus_emb) + retrieval_loss = self.get_loss(scores) + + total_loss = ( + self.alpha * fwd_cls_loss + + self.beta * bwd_cls_loss + + self.gamma * retrieval_loss + ) + return {"loss": total_loss} + + # ------------------------ + # Retrieval helpers + # ------------------------ + def search( + self, + queries_ids: List[str], + queries_embeddings: torch.Tensor, + corpus_ids: List[str], + corpus_embeddings: torch.Tensor, + ) -> Dict[str, Dict[str, float]]: + """ + Compute scores for all (query, corpus) pairs and return as nested dict: + {query_id: {corpus_id: score, ...}, ...} + """ + scores = self.compute_scores(queries_embeddings, corpus_embeddings) # (Q,D) + results: Dict[str, Dict[str, float]] = {} + for q_idx, q_id in enumerate(queries_ids): + row_scores = scores[q_idx] + results[q_id] = {c_id: row_scores[c_idx].item() for c_idx, c_id in enumerate(corpus_ids)} + return results + + def evaluate(self, corpus_dataloader, queries_dataloader) -> Dict[str, Dict[str, float]]: + """ + Run MedLink in retrieval mode on dataloaders for corpus and queries. + + corpus_dataloader yields batches with keys: "corpus_id", "s". + queries_dataloader yields batches with keys: "query_id", "s". 
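+
+        Returns a nested score dict in the same format as ``search``, e.g.
+        ``{"q1": {"d1": 12.3, "d2": 4.5}, ...}`` (values here are illustrative).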
+ """ + self.eval() + all_corpus_ids: List[str] = [] + all_queries_ids: List[str] = [] + all_corpus_embeddings: List[torch.Tensor] = [] + all_queries_embeddings: List[torch.Tensor] = [] + + with torch.no_grad(): + for batch in tqdm.tqdm(corpus_dataloader): + corpus_ids = batch["corpus_id"] + corpus_seqs = batch["s"] + corpus_emb = self.encode_corpus(corpus_seqs) + all_corpus_ids.extend(corpus_ids) + all_corpus_embeddings.append(corpus_emb) + + for batch in tqdm.tqdm(queries_dataloader): + query_ids = batch["query_id"] + query_seqs = batch["s"] + query_emb = self.encode_queries(query_seqs) + all_queries_ids.extend(query_ids) + all_queries_embeddings.append(query_emb) + + corpus_mat = torch.cat(all_corpus_embeddings, dim=0) + queries_mat = torch.cat(all_queries_embeddings, dim=0) + + return self.search( + queries_ids=all_queries_ids, + queries_embeddings=queries_mat, + corpus_ids=all_corpus_ids, + corpus_embeddings=corpus_mat, + ) diff --git a/pyhealth/models/medlink/model.py b/pyhealth/models/medlink/model.py index 1dea4fe5e..ad007becb 100644 --- a/pyhealth/models/medlink/model.py +++ b/pyhealth/models/medlink/model.py @@ -1,182 +1,359 @@ -from typing import Dict, List +from __future__ import annotations + +from typing import Dict, List, Any, Optional, Sequence, Tuple import torch import torch.nn as nn import torch.nn.functional as F import tqdm +from torch.nn.utils.rnn import pad_sequence + +from ...datasets import SampleDataset +from ..base_model import BaseModel +from ..transformer import TransformerLayer +from ...processors import SequenceProcessor + +from ..embedding import init_embedding_with_pretrained + + +def _build_shared_vocab( + q_processor: SequenceProcessor, + d_processor: SequenceProcessor, + pad_token: str = "", + unk_token: str = "", +) -> Dict[str, int]: + """Build a shared token->index mapping from two fitted SequenceProcessors. + + The returned vocabulary is deterministic (sorted token order) and always + includes `pad_token` and `unk_token`. 
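+
+    For example, two processors whose vocabularies contain {"401.9"} and
+    {"250.0", "401.9"} yield a mapping of the form
+    ``{pad_token: 0, unk_token: 1, "250.0": 2, "401.9": 3}``.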
+ """ -from pyhealth.datasets import SampleEHRDataset -from pyhealth.models import BaseModel -from pyhealth.models.transformer import TransformerLayer -from pyhealth.tokenizer import Tokenizer + vocab: Dict[str, int] = {pad_token: 0, unk_token: 1} + tokens = set(str(t) for t in q_processor.code_vocab.keys()) | set( + str(t) for t in d_processor.code_vocab.keys() + ) + tokens.discard(pad_token) + tokens.discard(unk_token) + + for t in sorted(tokens): + if t not in vocab: + vocab[t] = len(vocab) + return vocab + + +def _build_index_remap( + processor: SequenceProcessor, + shared_vocab: Dict[str, int], + unk_idx: int, +) -> torch.Tensor: + """Build a dense remap tensor old_idx -> shared_idx.""" + + size = len(processor.code_vocab) + remap = torch.full((size,), unk_idx, dtype=torch.long) + for tok, old_idx in processor.code_vocab.items(): + tok_s = str(tok) + remap[old_idx] = shared_vocab.get(tok_s, unk_idx) + return remap + + +def _to_index_tensor( + seq: Any, + processor: SequenceProcessor, +) -> torch.Tensor: + """Converts a single sequence to an index tensor using the fitted processor.""" + if isinstance(seq, torch.Tensor): + return seq.long() + if isinstance(seq, (list, tuple)): + return processor.process(seq) + # single token + return processor.process([seq]) + + +def _pad_and_remap( + sequences: Sequence[Any], + processor: SequenceProcessor, + remap: torch.Tensor, + pad_value: int = 0, + device: Optional[torch.device] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Pads a batch of sequences and remaps indices into the shared vocab. + + Returns: + ids_shared: LongTensor [B, L] + mask: BoolTensor [B, L] where True indicates valid token positions. + """ -def batch_to_one_hot(label_batch, num_class): - """ convert to one hot label """ - label_batch_onehot = [] - for label in label_batch: - label_batch_onehot.append(F.one_hot(label, num_class).sum(dim=0)) - label_batch_onehot = torch.stack(label_batch_onehot, dim=0) - label_batch_onehot[label_batch_onehot > 1] = 1 - return label_batch_onehot + ids = [_to_index_tensor(s, processor) for s in sequences] + ids_padded = pad_sequence(ids, batch_first=True, padding_value=pad_value) + if device is not None: + ids_padded = ids_padded.to(device) + remap = remap.to(device) + ids_shared = remap[ids_padded] + mask = ids_shared != 0 + return ids_shared, mask class AdmissionPrediction(nn.Module): - def __init__(self, tokenizer, embedding_dim, heads=2, dropout=0.5, num_layers=1): - super(AdmissionPrediction, self).__init__() - self.tokenizer = tokenizer - self.vocabs_size = tokenizer.get_vocabulary_size() + """Admission prediction module used by MedLink. + + This is a lightly-adapted version of the original MedLink implementation, + refactored to work with PyHealth 2.0 processors (i.e., indexed tensors). 
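+
+    Given a batch of index tensors in the shared vocabulary, ``forward`` returns
+    per-sample vocabulary logits together with multi-hot targets.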
+ """ + + def __init__( + self, + code_vocab: Dict[str, int], + embedding_dim: int, + heads: int = 2, + dropout: float = 0.5, + num_layers: int = 1, + pretrained_emb_path: Optional[str] = None, + freeze_pretrained: bool = False, + ): + super().__init__() + self.code_vocab = code_vocab + self.vocab_size = len(code_vocab) + self.pad_idx = code_vocab.get("", 0) + self.cls_idx = code_vocab.get("") + self.embedding = nn.Embedding( - self.vocabs_size, - embedding_dim, - padding_idx=tokenizer.get_padding_index() + num_embeddings=self.vocab_size, + embedding_dim=embedding_dim, + padding_idx=self.pad_idx, ) + if pretrained_emb_path: + init_embedding_with_pretrained( + self.embedding, + code_vocab, + pretrained_emb_path, + embedding_dim=embedding_dim, + freeze=freeze_pretrained, + ) + self.encoder = TransformerLayer( feature_size=embedding_dim, heads=heads, dropout=dropout, - num_layers=num_layers + num_layers=num_layers, ) self.criterion = nn.BCEWithLogitsLoss() - def encode_one_hot(self, input: List[str], device): - input_batch = self.tokenizer.batch_encode_2d(input, padding=True) - input_batch = torch.tensor(input_batch, dtype=torch.long, device=device) - input_onehot = batch_to_one_hot(input_batch, self.vocabs_size) - input_onehot = input_onehot.float().to(device) - input_onehot[:, self.tokenizer.vocabulary("")] = 0 - input_onehot[:, self.tokenizer.vocabulary("")] = 0 - return input_onehot - - def encode_dense(self, input: List[str], device): - input_batch = self.tokenizer.batch_encode_2d(input, padding=True) - input_batch = torch.tensor(input_batch, dtype=torch.long, device=device) - mask = input_batch != 0 - input_embeddings = self.embedding(input_batch) - input_embeddings, _ = self.encoder(input_embeddings) - return input_embeddings, mask - - def get_loss(self, logits, target_onehot): - true_batch_size = min(logits.shape[0], target_onehot.shape[0]) - loss = self.criterion(logits[:true_batch_size], target_onehot[:true_batch_size]) - return loss - - def forward(self, input, vocab_emb, device): - input_dense, mask = self.encode_dense(input, device) - input_one_hot = self.encode_one_hot(input, device) - logits = torch.matmul(input_dense, vocab_emb.T) - logits[~mask] = -1e9 - logits = logits.max(dim=1)[0] - return logits, input_one_hot + def _multi_hot(self, input_ids: torch.Tensor) -> torch.Tensor: + """Builds a multi-hot label vector per sample.""" + + # input_ids: [B, L] + bsz = input_ids.size(0) + out = torch.zeros(bsz, self.vocab_size, device=input_ids.device, dtype=torch.float) + src = torch.ones_like(input_ids, dtype=torch.float) + out.scatter_add_(1, input_ids, src) + out = (out > 0).float() + # Remove special tokens from labels. + if self.pad_idx is not None: + out[:, self.pad_idx] = 0 + if self.cls_idx is not None: + out[:, self.cls_idx] = 0 + return out + + def get_loss(self, logits: torch.Tensor, target_multi_hot: torch.Tensor) -> torch.Tensor: + true_batch_size = min(logits.shape[0], target_multi_hot.shape[0]) + return self.criterion(logits[:true_batch_size], target_multi_hot[:true_batch_size]) + + def forward(self, input_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute vocabulary logits and target multi-hot labels. + + Args: + input_ids: LongTensor [B, L] in shared vocabulary indices. + + Returns: + logits: FloatTensor [B, V] + target: FloatTensor [B, V] multi-hot labels. + """ + + mask = input_ids != self.pad_idx + x = self.embedding(input_ids) + x, _ = self.encoder(x, mask=mask) + + # Use embedding table as vocabulary embedding. 
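+        # (weight tying: the same table that embeds the input codes also scores
+        # every vocabulary entry, so the pooled logits below are [B, V])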
+ vocab_emb = self.embedding.weight # [V, D] + logits = torch.matmul(x, vocab_emb.T) # [B, L, V] + logits = logits.masked_fill(~mask.unsqueeze(-1), -1e9) + logits = logits.max(dim=1).values # [B, V] + + target = self._multi_hot(input_ids) + return logits, target class MedLink(BaseModel): - """MedLink model. + """MedLink model (KDD 2023). - Paper: Zhenbang Wu et al. MedLink: De-Identified Patient Health - Record Linkage. KDD 2023. + Paper: Zhenbang Wu et al. MedLink: De-Identified Patient Health Record + Linkage. KDD 2023. - IMPORTANT: This implementation differs from the original paper in order to - make it work with the PyHealth framework. Specifically, we do not use the - pre-trained GloVe embeddings. And we only monitor the loss on the validation - set instead of the ranking metrics. As a result, the performance of this model - is different from the original paper. To reproduce the results in the paper, - please use the official GitHub repo: https://github.com/zzachw/MedLink. + IMPORTANT: This implementation differs from the original paper to fit the + PyHealth 2.0 framework. By default, it uses randomly-initialized embeddings. + Optionally, you may initialize the embedding tables using a GloVe-style + text vector file. Args: - dataset: SampleEHRDataset. - feature_keys: List of feature keys. MedLink only supports one feature key. - embedding_dim: Dimension of embedding. - alpha: Weight of the forward prediction loss. - beta: Weight of the backward prediction loss. - gamma: Weight of the retrieval loss. + dataset: SampleDataset. + feature_keys: List of feature keys. MedLink only supports one feature. + embedding_dim: embedding dimension. + alpha: weight for forward prediction loss. + beta: weight for backward prediction loss. + gamma: weight for retrieval loss. + pretrained_emb_path: optional path to a GloVe-style embedding file. + freeze_pretrained: if True, freezes embedding weights after init. """ def __init__( self, - dataset: SampleEHRDataset, + dataset: SampleDataset, feature_keys: List[str], embedding_dim: int = 128, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, + pretrained_emb_path: Optional[str] = None, + freeze_pretrained: bool = False, **kwargs, ): assert len(feature_keys) == 1, "MedLink only supports one feature key" - super(MedLink, self).__init__(dataset=dataset) + super().__init__(dataset=dataset) + self.feature_key = feature_keys[0] self.embedding_dim = embedding_dim self.alpha = alpha self.beta = beta self.gamma = gamma - q_tokens = self.dataset.get_all_tokens(key=self.feature_key) - d_tokens = self.dataset.get_all_tokens(key="d_" + self.feature_key) - tokenizer = Tokenizer( - tokens=q_tokens + d_tokens, - special_tokens=["", "", ""], + + q_field = self.feature_key + d_field = "d_" + self.feature_key + if q_field not in self.dataset.input_processors or d_field not in self.dataset.input_processors: + raise KeyError( + f"MedLink expects both '{q_field}' and '{d_field}' in dataset.input_schema" + ) + + q_processor = self.dataset.input_processors[q_field] + d_processor = self.dataset.input_processors[d_field] + if not isinstance(q_processor, SequenceProcessor) or not isinstance(d_processor, SequenceProcessor): + raise TypeError( + "MedLink currently supports SequenceProcessor for both query and corpus fields" + ) + + self.q_processor = q_processor + self.d_processor = d_processor + + # Shared vocabulary across query/corpus streams. 
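+        # The query and corpus fields were fitted with separate SequenceProcessors,
+        # so their indices are remapped into one shared vocabulary below.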
+ self.code_vocab = _build_shared_vocab(q_processor, d_processor) + self.vocab_size = len(self.code_vocab) + self.unk_idx = self.code_vocab.get("", 1) + + # Remap tensors from per-field vocab -> shared vocab. + self.q_remap = _build_index_remap(q_processor, self.code_vocab, self.unk_idx) + self.d_remap = _build_index_remap(d_processor, self.code_vocab, self.unk_idx) + + self.fwd_adm_pred = AdmissionPrediction( + code_vocab=self.code_vocab, + embedding_dim=embedding_dim, + pretrained_emb_path=pretrained_emb_path, + freeze_pretrained=freeze_pretrained, + **kwargs, ) - self.fwd_adm_pred = AdmissionPrediction(tokenizer, embedding_dim, **kwargs) self.forward_encoder = self.fwd_adm_pred.encoder - self.bwd_adm_pred = AdmissionPrediction(tokenizer, embedding_dim, **kwargs) - self.backward_encoder = self.bwd_adm_pred.encoder - self.criterion = nn.CrossEntropyLoss() - self.vocab_size = tokenizer.get_vocabulary_size() - return - - def encode_queries(self, queries: List[str]): - all_vocab = torch.tensor(list(range(self.vocab_size)), device=self.device) - bwd_vocab_emb = self.bwd_adm_pred.embedding(all_vocab) - pred_corpus, queries_one_hot = self.bwd_adm_pred( - queries, bwd_vocab_emb, device=self.device - ) - pred_corpus = torch.log(1 + torch.relu(pred_corpus)) - queries_emb = pred_corpus + queries_one_hot - return queries_emb - def encode_corpus(self, corpus: List[str]): - all_vocab = torch.tensor(list(range(self.vocab_size)), device=self.device) - fwd_vocab_emb = self.fwd_adm_pred.embedding(all_vocab) - pred_queries, corpus_one_hot = self.fwd_adm_pred( - corpus, fwd_vocab_emb, device=self.device + self.bwd_adm_pred = AdmissionPrediction( + code_vocab=self.code_vocab, + embedding_dim=embedding_dim, + pretrained_emb_path=pretrained_emb_path, + freeze_pretrained=freeze_pretrained, + **kwargs, ) - pred_queries = torch.log(1 + torch.relu(pred_queries)) - corpus_emb = corpus_one_hot + pred_queries - return corpus_emb + self.backward_encoder = self.bwd_adm_pred.encoder - def compute_scores(self, queries_emb, corpus_emb): - n = torch.tensor(corpus_emb.shape[0]).to(queries_emb.device) - df = (corpus_emb > 0).sum(dim=0) - idf = torch.log(1 + n) - torch.log(1 + df) + self.criterion = nn.CrossEntropyLoss() - tf = torch.einsum('ac,bc->abc', queries_emb, corpus_emb) + # ------------------------------------------------------------------ + # Encoding helpers + # ------------------------------------------------------------------ + + def _prepare_queries(self, queries: Sequence[Any]) -> Tuple[torch.Tensor, torch.Tensor]: + return _pad_and_remap( + queries, + processor=self.q_processor, + remap=self.q_remap, + pad_value=0, + device=self.device, + ) - tf_idf = tf * idf - final_scores = tf_idf.sum(dim=-1) - return final_scores + def _prepare_corpus(self, corpus: Sequence[Any]) -> Tuple[torch.Tensor, torch.Tensor]: + return _pad_and_remap( + corpus, + processor=self.d_processor, + remap=self.d_remap, + pad_value=0, + device=self.device, + ) - def get_loss(self, scores): - label = torch.tensor(list(range(scores.shape[0])), device=scores.device) - loss = self.criterion(scores, label) - return loss + def encode_queries(self, queries: Sequence[Any]) -> torch.Tensor: + q_ids, _ = self._prepare_queries(queries) + pred_corpus, queries_one_hot = self.bwd_adm_pred(q_ids) + pred_corpus = torch.log1p(torch.relu(pred_corpus)) + emb = pred_corpus + queries_one_hot + # Keep special tokens out of retrieval scoring. 
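+        # (the pad/cls columns carry no patient-specific signal, so they are
+        # zeroed before the TF-IDF style scoring)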
+ emb[:, self.code_vocab.get("", 0)] = 0 + if "" in self.code_vocab: + emb[:, self.code_vocab[""]] = 0 + return emb + + def encode_corpus(self, corpus: Sequence[Any]) -> torch.Tensor: + c_ids, _ = self._prepare_corpus(corpus) + pred_queries, corpus_one_hot = self.fwd_adm_pred(c_ids) + pred_queries = torch.log1p(torch.relu(pred_queries)) + emb = corpus_one_hot + pred_queries + emb[:, self.code_vocab.get("", 0)] = 0 + if "" in self.code_vocab: + emb[:, self.code_vocab[""]] = 0 + return emb + + # ------------------------------------------------------------------ + # Scoring / losses + # ------------------------------------------------------------------ + + def compute_scores(self, queries_emb: torch.Tensor, corpus_emb: torch.Tensor) -> torch.Tensor: + """TF-IDF-like score used by MedLink. + + queries_emb: [Q, V] + corpus_emb: [C, V] + returns: [Q, C] + """ + + n = torch.tensor(float(corpus_emb.shape[0]), device=queries_emb.device) + df = (corpus_emb > 0).sum(dim=0).float() + idf = torch.log1p(n) - torch.log1p(df) + # Equivalent to sum_c q[c] * d[c] * idf[c] + return torch.matmul(queries_emb * idf, corpus_emb.T) + + def get_loss(self, scores: torch.Tensor) -> torch.Tensor: + label = torch.arange(scores.shape[0], device=scores.device) + return self.criterion(scores, label) def forward(self, query_id, id_p, s_q, s_p, s_n=None) -> Dict[str, torch.Tensor]: - corpus = s_p if s_n is None else s_p + s_n + # corpus is positives optionally concatenated with negatives. + corpus = s_p if s_n is None else (s_p + s_n) queries = s_q - all_vocab = torch.tensor(list(range(self.vocab_size)), device=self.device) - fwd_vocab_emb = self.fwd_adm_pred.embedding(all_vocab) - bwd_vocab_emb = self.bwd_adm_pred.embedding(all_vocab) - pred_queries, corpus_one_hot = self.fwd_adm_pred( - corpus, fwd_vocab_emb, self.device - ) - pred_corpus, queries_one_hot = self.bwd_adm_pred( - queries, bwd_vocab_emb, self.device - ) + + q_ids, _ = self._prepare_queries(queries) + c_ids, _ = self._prepare_corpus(corpus) + + pred_queries, corpus_one_hot = self.fwd_adm_pred(c_ids) + pred_corpus, queries_one_hot = self.bwd_adm_pred(q_ids) fwd_cls_loss = self.fwd_adm_pred.get_loss(pred_queries, queries_one_hot) bwd_cls_loss = self.bwd_adm_pred.get_loss(pred_corpus, corpus_one_hot) - pred_queries = torch.log(1 + torch.relu(pred_queries)) - pred_corpus = torch.log(1 + torch.relu(pred_corpus)) + pred_queries = torch.log1p(torch.relu(pred_queries)) + pred_corpus = torch.log1p(torch.relu(pred_corpus)) corpus_emb = corpus_one_hot + pred_queries queries_emb = pred_corpus + queries_one_hot @@ -184,11 +361,13 @@ def forward(self, query_id, id_p, s_q, s_p, s_n=None) -> Dict[str, torch.Tensor] scores = self.compute_scores(queries_emb, corpus_emb) ret_loss = self.get_loss(scores) - loss = self.alpha * fwd_cls_loss + \ - self.beta * bwd_cls_loss + \ - self.gamma * ret_loss + loss = self.alpha * fwd_cls_loss + self.beta * bwd_cls_loss + self.gamma * ret_loss return {"loss": loss} + # ------------------------------------------------------------------ + # Retrieval API + # ------------------------------------------------------------------ + def search(self, queries_ids, queries_embeddings, corpus_ids, corpus_embeddings): scores = self.compute_scores(queries_embeddings, corpus_embeddings) results = {} @@ -203,30 +382,29 @@ def evaluate(self, corpus_dataloader, queries_dataloader): all_corpus_ids, all_corpus_embeddings = [], [] all_queries_ids, all_queries_embeddings = [], [] with torch.no_grad(): - for i, batch in enumerate(tqdm.tqdm(corpus_dataloader)): + 
for batch in tqdm.tqdm(corpus_dataloader): corpus_ids, corpus = batch["corpus_id"], batch["s"] corpus_embeddings = self.encode_corpus(corpus) all_corpus_ids.extend(corpus_ids) all_corpus_embeddings.append(corpus_embeddings) - for i, batch in enumerate(tqdm.tqdm(queries_dataloader)): + for batch in tqdm.tqdm(queries_dataloader): queries_ids, queries = batch["query_id"], batch["s"] queries_embeddings = self.encode_queries(queries) all_queries_ids.extend(queries_ids) all_queries_embeddings.append(queries_embeddings) - all_corpus_embeddings = torch.cat(all_corpus_embeddings) - all_queries_embeddings = torch.cat(all_queries_embeddings) - results = self.search( + all_corpus_embeddings = torch.cat(all_corpus_embeddings, dim=0) + all_queries_embeddings = torch.cat(all_queries_embeddings, dim=0) + return self.search( all_queries_ids, all_queries_embeddings, all_corpus_ids, - all_corpus_embeddings + all_corpus_embeddings, ) - return results if __name__ == "__main__": + # Minimal smoke-test matching the public example script. from pyhealth.datasets import MIMIC3Dataset - from pyhealth.models import MedLink from pyhealth.models.medlink import ( convert_to_ir_format, get_train_dataloader, @@ -243,20 +421,10 @@ def evaluate(self, corpus_dataloader, queries_dataloader): ) sample_dataset = base_dataset.set_task(patient_linkage_mimic3_fn) - corpus, queries, qrels = convert_to_ir_format(sample_dataset.samples) - tr_queries, va_queries, te_queries, tr_qrels, va_qrels, te_qrels = tvt_split( - queries, qrels - ) - train_dataloader = get_train_dataloader( - corpus, tr_queries, tr_qrels, batch_size=32, shuffle=True - ) + corpus, queries, qrels, *_ = convert_to_ir_format(sample_dataset.samples) + tr_queries, _, _, tr_qrels, _, _ = tvt_split(queries, qrels) + train_dataloader = get_train_dataloader(corpus, tr_queries, tr_qrels, batch_size=4) batch = next(iter(train_dataloader)) - model = MedLink( - dataset=sample_dataset, - feature_keys=["conditions"], - embedding_dim=128, - ) - with torch.autograd.detect_anomaly(): - o = model(**batch) - print("loss:", o["loss"]) - o["loss"].backward() \ No newline at end of file + model = MedLink(dataset=sample_dataset, feature_keys=["conditions"], embedding_dim=32) + out = model(**batch) + print("loss:", out["loss"].item()) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index fb3c6966a..4db38b196 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -62,3 +62,4 @@ MutationPathogenicityPrediction, VariantClassificationClinVar, ) +from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task diff --git a/pyhealth/tasks/patient_linkage_mimic3.py b/pyhealth/tasks/patient_linkage_mimic3.py new file mode 100644 index 000000000..3772a0811 --- /dev/null +++ b/pyhealth/tasks/patient_linkage_mimic3.py @@ -0,0 +1,106 @@ +from datetime import datetime +from collections import defaultdict +import math +from pyhealth.tasks import BaseTask + +class PatientLinkageMIMIC3Task(BaseTask): + """ + Patient linkage task for MIMIC-III using the Patient/Visit/Event API. 
+ + Produces the same sample keys as the original patient_linkage_mimic3 task + so pyhealth.models.medlink.convert_to_ir_format works as usual + + Output sample schema: + - patient_id: ground-truth entity id (equivalent to "master patient record id" in MIMIC) + - visit_id: query admission id (hadm_id) + - conditions, age, identifiers: query side + - d_visit_id: doc admission id (hadm_id) + - d_conditions, d_age, d_identifiers: doc side + """ + + task_name = "patient_linkage_mimic3" + input_schema = {"conditions": "sequence", "d_conditions": "sequence"} + output_schema = {} + + def __call__(self, patient): + admissions = patient.get_events(event_type="admissions") + if len(admissions) < 2: + return [] + + admissions = sorted(admissions, key=lambda e: e.timestamp) + q_visit = admissions[-1] + d_visit = admissions[-2] + + patients_events = patient.get_events(event_type="patients") + if not patients_events: + return [] + demo = patients_events[0] + + gender = str(demo.attr_dict.get("gender") or "") + + dob_raw = demo.attr_dict.get("dob") + birth_dt = None + if isinstance(dob_raw, datetime): + birth_dt = dob_raw + elif dob_raw is not None: + try: + birth_dt = datetime.fromisoformat(str(dob_raw)) + except Exception: + birth_dt = None + + def compute_age(ts): + if birth_dt is None or ts is None: + return None + return int((ts - birth_dt).days // 365.25) + + q_age = compute_age(q_visit.timestamp) + d_age = compute_age(d_visit.timestamp) + if q_age is None or d_age is None or q_age < 18 or d_age < 18: + return [] + + diag_events = patient.get_events(event_type="diagnoses_icd") + hadm_to_codes = defaultdict(list) + for ev in diag_events: + hadm = ev.attr_dict.get("hadm_id") + code = ev.attr_dict.get("icd9_code") + if hadm is None or code is None: + continue + hadm_to_codes[str(hadm)].append(str(code)) + + q_hadm = str(q_visit.attr_dict.get("hadm_id")) + d_hadm = str(d_visit.attr_dict.get("hadm_id")) + + q_conditions = hadm_to_codes.get(q_hadm, []) + d_conditions = hadm_to_codes.get(d_hadm, []) + if len(q_conditions) == 0 or len(d_conditions) == 0: + return [] + + def clean(x): + if x is None: + return "" + if isinstance(x, float) and math.isnan(x): + return "" + return str(x) + + def build_identifiers(adm_event): + insurance = clean(adm_event.attr_dict.get("insurance")) + language = clean(adm_event.attr_dict.get("language")) + religion = clean(adm_event.attr_dict.get("religion")) + marital_status = clean(adm_event.attr_dict.get("marital_status")) + ethnicity = clean(adm_event.attr_dict.get("ethnicity")) + return "+".join([gender, insurance, language, religion, marital_status, ethnicity]) + + sample = { + "patient_id": patient.patient_id, + + "visit_id": q_hadm, + "conditions": [""] + q_conditions, + "age": q_age, + "identifiers": build_identifiers(q_visit), + + "d_visit_id": d_hadm, + "d_conditions": [""] + d_conditions, + "d_age": d_age, + "d_identifiers": build_identifiers(d_visit), + } + return [sample] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e7991eef9 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +# Tests package + diff --git a/tests/core/test_medlink.py b/tests/core/test_medlink.py index c718d3dc9..ab887ea0e 100644 --- a/tests/core/test_medlink.py +++ b/tests/core/test_medlink.py @@ -9,8 +9,9 @@ class TestMedLink(unittest.TestCase): """Basic tests for the MedLink model on pseudo data.""" def setUp(self): - # Each "sample" here is a simple patient-record placeholder - # The dataset is only used to build the vocabulary via get_all_tokens. 
+ # Each "sample" here is a simple patient-record placeholder. + # The dataset is used to fit SequenceProcessors (vocabularies), which + # MedLink reuses for processor-native indexing. self.samples = [ { "patient_id": "p0", diff --git a/tests/core/test_sdoh.py b/tests/core/test_sdoh.py index 0e8096268..f2eaf98a7 100644 --- a/tests/core/test_sdoh.py +++ b/tests/core/test_sdoh.py @@ -1,5 +1,5 @@ from typing import Set -from base import BaseTestCase +from tests.base import BaseTestCase from pyhealth.models.sdoh import SdohClassifier diff --git a/tests/nlp/test_metrics.py b/tests/nlp/test_metrics.py index 0536153e8..a0fd76a68 100644 --- a/tests/nlp/test_metrics.py +++ b/tests/nlp/test_metrics.py @@ -1,6 +1,6 @@ from typing import List import logging -from base import BaseTestCase +from tests.base import BaseTestCase from pathlib import Path import pandas as pd from pyhealth.nlp.metrics import ( diff --git a/tests/todo/test_datasets/test_eicu.py b/tests/todo/test_datasets/test_eicu.py index fdb466273..1bd0ce470 100644 --- a/tests/todo/test_datasets/test_eicu.py +++ b/tests/todo/test_datasets/test_eicu.py @@ -5,7 +5,7 @@ import pandas from pyhealth.datasets import eICUDataset -from pyhealth.unittests.test_datasets.utils import EHRDatasetStatAssertion +from tests.todo.test_datasets.utils import EHRDatasetStatAssertion class TesteICUDataset(unittest.TestCase): diff --git a/tests/todo/test_datasets/test_mimic3.py b/tests/todo/test_datasets/test_mimic3.py index 2957add0d..fe0fee6ae 100644 --- a/tests/todo/test_datasets/test_mimic3.py +++ b/tests/todo/test_datasets/test_mimic3.py @@ -2,7 +2,7 @@ import unittest from pyhealth.datasets import MIMIC3Dataset -from pyhealth.unittests.test_datasets.utils import EHRDatasetStatAssertion +from tests.todo.test_datasets.utils import EHRDatasetStatAssertion import os, sys current = os.path.dirname(os.path.realpath(__file__)) @@ -30,8 +30,6 @@ class TestsMimic3Dataset(unittest.TestCase): dataset_name=DATASET_NAME, root=ROOT, tables=TABLES, - code_mapping=CODE_MAPPING, - refresh_cache=REFRESH_CACHE, ) def setUp(self): diff --git a/tests/todo/test_datasets/test_mimic4.py b/tests/todo/test_datasets/test_mimic4.py index 0133cbb93..bf21b64ef 100644 --- a/tests/todo/test_datasets/test_mimic4.py +++ b/tests/todo/test_datasets/test_mimic4.py @@ -2,7 +2,7 @@ import unittest from pyhealth.datasets import MIMIC4Dataset -from pyhealth.unittests.test_datasets.utils import EHRDatasetStatAssertion +from tests.todo.test_datasets.utils import EHRDatasetStatAssertion import os, sys @@ -25,17 +25,14 @@ class TestMimic4Dataset(unittest.TestCase): DEV = True # not needed when using demo set since its 100 patients large REFRESH_CACHE = True - dataset = MIMIC4Dataset( - dataset_name=DATASET_NAME, - root=ROOT, - tables=TABLES, - code_mapping=CODE_MAPPING, - dev=DEV, - refresh_cache=REFRESH_CACHE, - ) - def setUp(self): - pass + # Initialize dataset in setUp to avoid loading during test collection + self.dataset = MIMIC4Dataset( + dataset_name=self.DATASET_NAME, + ehr_root=self.ROOT, + ehr_tables=self.TABLES, + dev=self.DEV, + ) # test the dataset integrity based on a single sample. 
def test_patient(self): diff --git a/tests/todo/test_datasets/test_omop.py b/tests/todo/test_datasets/test_omop.py index f57420659..765c88920 100644 --- a/tests/todo/test_datasets/test_omop.py +++ b/tests/todo/test_datasets/test_omop.py @@ -5,7 +5,7 @@ import collections from pyhealth.datasets import OMOPDataset -from pyhealth.unittests.test_datasets.utils import EHRDatasetStatAssertion +from tests.todo.test_datasets.utils import EHRDatasetStatAssertion class TestOMOPDataset(unittest.TestCase): @@ -25,9 +25,7 @@ class TestOMOPDataset(unittest.TestCase): dataset_name=DATASET_NAME, root=ROOT, tables=TABLES, - code_mapping=CODE_MAPPING, dev=DEV, - refresh_cache=REFRESH_CACHE, ) def setUp(self): diff --git a/tests/todo/test_mortality_prediction.py b/tests/todo/test_mortality_prediction.py index 729640abb..749a944cb 100644 --- a/tests/todo/test_mortality_prediction.py +++ b/tests/todo/test_mortality_prediction.py @@ -53,7 +53,7 @@ def test_mortality_prediction_mimic4(): # Enable dev mode to limit memory usage dataset = MIMIC4Dataset( ehr_root=mimic_iv_root, - notes_root=mimic_note_root, + note_root=mimic_note_root, ehr_tables=[ "patients", # Demographics "admissions", # Admission/discharge info @@ -152,7 +152,7 @@ def test_multimodal_mortality_prediction_mimic4(): # Initialize dataset with comprehensive tables dataset = MIMIC4Dataset( ehr_root=mimic_iv_root, - notes_root=mimic_note_root, + note_root=mimic_note_root, cxr_root=mimic_cxr_root, ehr_tables=[ "patients", # Demographics @@ -274,7 +274,7 @@ def test_multimodal_mortality_prediction_with_images(): # Initialize the dataset with all required tables dataset = MIMIC4Dataset( ehr_root=mimic_iv_root, - notes_root=mimic_note_root, + note_root=mimic_note_root, cxr_root=mimic_cxr_root, ehr_tables=[ "patients", From 41ed5a0db412c1863e73bf63a219f86827ee92fe Mon Sep 17 00:00:00 2001 From: Rian354 Date: Wed, 24 Dec 2025 17:30:43 -0500 Subject: [PATCH 6/8] Docstrings + ehr -> sampledataset --- examples/medlink_mimic3.ipynb | 2 +- pyhealth/models/medlink.py | 44 +++++++++++++++++++++++++++-------- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/examples/medlink_mimic3.ipynb b/examples/medlink_mimic3.ipynb index 421aeeefc..9246436cf 100644 --- a/examples/medlink_mimic3.ipynb +++ b/examples/medlink_mimic3.ipynb @@ -673,7 +673,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.2" + "version": "3.13.3" } }, "nbformat": 4, diff --git a/pyhealth/models/medlink.py b/pyhealth/models/medlink.py index f11334e53..ff6b52f9f 100644 --- a/pyhealth/models/medlink.py +++ b/pyhealth/models/medlink.py @@ -5,7 +5,7 @@ import torch.nn.functional as F import tqdm -from pyhealth.datasets import SampleEHRDataset +from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel from pyhealth.models.transformer import TransformerLayer from pyhealth.tokenizer import Tokenizer @@ -149,19 +149,43 @@ class MedLink(BaseModel): """ MedLink: de-identified patient record linkage model (KDD 2023). - This implementation is adapted to the PyHealth framework: - * no pre-trained GloVe; embeddings are learned from scratch - * training monitored via loss instead of ranking metrics - - It implements three losses: - - forward admission prediction (corpus -> queries) - - backward admission prediction (queries -> corpus) - - retrieval loss via TF-IDF-style matching + This model links de-identified patient records using admission sequences + and a transformer-based architecture. 
It is designed to operate on PyHealth's `SampleDataset`. + + Inputs: + - dataset (SampleDataset): The dataset containing patient admission sequences. + - feature_keys (List[str]): List with the key for patient admission codes (only the first is used). + - embedding_dim (int, default=128): Embedding dimension for learned token embeddings. + - alpha, beta, gamma (float): Loss weights for model's multi-loss objective. + - heads (int): Number of transformer heads. + - dropout (float): Dropout rate for transformer encoders. + - num_layers (int): Number of layers in transformer encoders. + + Outputs: + - The model primarily outputs a dictionary {"loss": loss_tensor} during training (see forward method). + - For retrieval/evaluation, the model provides embeddings and search utilities to score record similarity. + + Example: + >>> from pyhealth.datasets import SampleDataset + >>> from pyhealth.models import MedLink + >>> samples = [{"patient_id": "1", "admissions": ["ICD9_430", "ICD9_401"]}, ...] + >>> input_schema = {"admissions": "code"} + >>> output_schema = {"label": "binary"} + >>> dataset = SampleDataset(samples, input_schema, output_schema) + >>> model = MedLink(dataset=dataset, feature_keys=["admissions"]) + >>> batch = {"query_id": [...], "id_p": [...], "s_q": [["ICD9_430", "ICD9_401"]], "s_p": [[...]], "s_n": None} + >>> out = model(**batch) + >>> print(out["loss"]) + + Notes: + - Only works with a single feature_key (list of length 1). + - Specialized for code sequence/text-based features (e.g., admissions). + - Retrieval is performed via TF-IDF-style similarity on learned multi-hot embeddings. """ def __init__( self, - dataset: SampleEHRDataset, + dataset: SampleDataset, feature_keys: List[str], embedding_dim: int = 128, alpha: float = 0.5, From 6f06e4912a986ceeb3eeeb64b31ddd16c4551a4f Mon Sep 17 00:00:00 2001 From: Rian354 Date: Wed, 24 Dec 2025 19:25:30 -0500 Subject: [PATCH 7/8] Path config for datasets, build error --- pyhealth/models/medlink.py | 2 +- tests/core/test_medlink.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pyhealth/models/medlink.py b/pyhealth/models/medlink.py index ff6b52f9f..e28a17399 100644 --- a/pyhealth/models/medlink.py +++ b/pyhealth/models/medlink.py @@ -171,7 +171,7 @@ class MedLink(BaseModel): >>> samples = [{"patient_id": "1", "admissions": ["ICD9_430", "ICD9_401"]}, ...] 
>>> input_schema = {"admissions": "code"}
     >>> output_schema = {"label": "binary"}
-    >>> dataset = SampleDataset(samples, input_schema, output_schema)
+    >>> dataset = SampleDataset(path="/some/path", samples=samples, input_schema=input_schema, output_schema=output_schema)
     >>> model = MedLink(dataset=dataset, feature_keys=["admissions"])
     >>> batch = {"query_id": [...], "id_p": [...], "s_q": [["ICD9_430", "ICD9_401"]], "s_p": [[...]], "s_n": None}
     >>> out = model(**batch)
     >>> print(out["loss"])
diff --git a/tests/core/test_medlink.py b/tests/core/test_medlink.py
index ab887ea0e..90f24332b 100644
--- a/tests/core/test_medlink.py
+++ b/tests/core/test_medlink.py
@@ -38,6 +38,7 @@ def setUp(self):
         self.output_schema = {}
 
         self.dataset = SampleDataset(
+            path="dummy_path",
             samples=self.samples,
             input_schema=self.input_schema,
             output_schema=self.output_schema,

From fa0afcd2031d9978732ed1350349217b55636c7c Mon Sep 17 00:00:00 2001
From: Rian354
Date: Sat, 27 Dec 2025 06:18:55 -0500
Subject: [PATCH 8/8] samples helper mismatch

---
 pyhealth/models/medlink.py | 2 +-
 tests/core/test_medlink.py | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/pyhealth/models/medlink.py b/pyhealth/models/medlink.py
index e28a17399..9fd30aa92 100644
--- a/pyhealth/models/medlink.py
+++ b/pyhealth/models/medlink.py
@@ -175,7 +175,7 @@ class MedLink(BaseModel):
     >>> model = MedLink(dataset=dataset, feature_keys=["admissions"])
     >>> batch = {"query_id": [...], "id_p": [...], "s_q": [["ICD9_430", "ICD9_401"]], "s_p": [[...]], "s_n": None}
     >>> out = model(**batch)
-    >>> print(out["loss"])
+    >>> print(out["loss"])
 
     Notes:
     - Only works with a single feature_key (list of length 1).
diff --git a/tests/core/test_medlink.py b/tests/core/test_medlink.py
index 90f24332b..80c9a01bc 100644
--- a/tests/core/test_medlink.py
+++ b/tests/core/test_medlink.py
@@ -1,7 +1,7 @@
 import unittest
 
 import torch
-from pyhealth.datasets import SampleDataset
+from pyhealth.datasets import create_sample_dataset
 from pyhealth.models import MedLink
 
@@ -37,12 +37,12 @@ def setUp(self):
         # No labels are needed; MedLink is self-supervised
         self.output_schema = {}
 
-        self.dataset = SampleDataset(
-            path="dummy_path",
+        self.dataset = create_sample_dataset(
             samples=self.samples,
             input_schema=self.input_schema,
             output_schema=self.output_schema,
             dataset_name="medlink_test",
+            in_memory=True,
         )
 
         self.model = MedLink(