From 3cf3b451517407ff9f084fcfb31d824a34ad9c46 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Tue, 16 Dec 2025 03:17:29 +0000 Subject: [PATCH 01/10] ReadmissionPredictionMIMIC3 task (untested) --- pyhealth/tasks/readmission_prediction.py | 93 +++++++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/pyhealth/tasks/readmission_prediction.py b/pyhealth/tasks/readmission_prediction.py index 9d04b5eb7..55f33f108 100644 --- a/pyhealth/tasks/readmission_prediction.py +++ b/pyhealth/tasks/readmission_prediction.py @@ -1,5 +1,96 @@ -from pyhealth.data import Patient, Visit +from datetime import datetime, timedelta +from typing import Dict, List + +import polars as pl + +from pyhealth.data import Event, Patient +from pyhealth.tasks import BaseTask + +class ReadmissionPredictionMIMIC3(BaseTask): + #todo: add unit tests (demo dataset? synthetic dataset? my own synthetic dataset?) + #todo: add doc strings + #todo: replace examples + #todo: update docs (replace all references to readmission_prediction_mimic3_fn) + #todo: deprecate readmission_prediction_mimic3_fn (make it a wrapper around this) + #todo: review other similar tasks for best practices and common patterns + task_name: str = "ReadmissionPredictionMIMIC3" + input_schema: Dict[str, str] = {"conditions": "sequence", "procedures": "sequence", "drugs": "sequence"} + output_schema: Dict[str, str] = {"label": "binary"} + + def __init__(self, window: timedelta, min_admission_length: timedelta = timedelta(0)) -> None: + self.window = window + self.min_admission_len = min_admission_length + + def __call__(self, patient: Patient) -> List[Dict]: + patients: List[Event] = patient.get_events(event_type="patients") + assert len(patients) == 1 + if int(patients[0]["anchor_age"]) < 18: + return [] + + admissions: List[Event] = patient.get_events(event_type="admissions") + if len(admissions) < 2: + return [] + + samples = [] + for i in range(len(admissions) - 1): # Skip the last admission since we need a "next" admission + #todo: Exclude visits where the patient is under 18 + # if int(admissions[0].timestamp - patients[0]["dob"]) < 18: + # continue + + discharge_time = datetime.strptime(admissions[i].dischtime, "%Y-%m-%d %H:%M:%S") + if (discharge_time - admissions[i].timestamp) < self.min_admission_len: + continue + + diagnoses_icd = patient.get_events( + event_type="diagnoses_icd", + start=admissions[i].timestamp, + end=discharge_time, + return_df=True + ) + conditions = diagnoses_icd.select( + pl.concat_str(["diagnoses_icd/icd_version", "diagnoses_icd/icd_code"], separator="_") + ).to_series().to_list() + if len(conditions) == 0: #todo: move the short-circuit before the conversion (here and below) + continue + + procedures_icd = patient.get_events( + event_type="procedures_icd", + start=admissions[i].timestamp, + end=discharge_time, + return_df=True + ) + procedures = procedures_icd.select( + pl.concat_str(["procedures_icd/icd_version", "procedures_icd/icd_code"], separator="_") + ).to_series().to_list() + if len(procedures) == 0: #todo: confirm we want conditions AND procedures AND drugs (instead of OR) + continue + + prescriptions = patient.get_events( + event_type="prescriptions", + start=admissions[i].timestamp, + end=discharge_time, + return_df=True + ) + drugs = prescriptions.select( + pl.concat_str(["prescriptions/drug"], separator="_") + ).to_series().to_list() + if len(drugs) == 0: + continue + + readmission = int((admissions[i + 1].timestamp - discharge_time) < self.window) + + samples.append( + { + "patient_id": patient.patient_id, + "admission_id": admissions[0].hadm_id, + "conditions": conditions, + "procedures": procedures, + "drugs": drugs, + "readmission": readmission, + } + ) + return samples # TODO: time_window cannot be passed in to base_dataset def readmission_prediction_mimic3_fn(patient: Patient, time_window=15): From f6d596ad74959140ba72cefd01029ec714046e7f Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Mon, 22 Dec 2025 18:59:54 +0000 Subject: [PATCH 02/10] Basic task schema unit test --- pyhealth/tasks/__init__.py | 1 + pyhealth/tasks/readmission_prediction.py | 2 -- tests/core/test_mimic3_readmission_prediction.py | 12 ++++++++++++ 3 files changed, 13 insertions(+), 2 deletions(-) create mode 100644 tests/core/test_mimic3_readmission_prediction.py diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index bcfb19f7a..92b90f1d3 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -54,6 +54,7 @@ readmission_prediction_mimic3_fn, readmission_prediction_mimic4_fn, readmission_prediction_omop_fn, + ReadmissionPredictionMIMIC3, ) from .sleep_staging import ( sleep_staging_isruc_fn, diff --git a/pyhealth/tasks/readmission_prediction.py b/pyhealth/tasks/readmission_prediction.py index 55f33f108..853d1ac54 100644 --- a/pyhealth/tasks/readmission_prediction.py +++ b/pyhealth/tasks/readmission_prediction.py @@ -24,8 +24,6 @@ def __init__(self, window: timedelta, min_admission_length: timedelta = timedelt def __call__(self, patient: Patient) -> List[Dict]: patients: List[Event] = patient.get_events(event_type="patients") assert len(patients) == 1 - if int(patients[0]["anchor_age"]) < 18: - return [] admissions: List[Event] = patient.get_events(event_type="admissions") if len(admissions) < 2: diff --git a/tests/core/test_mimic3_readmission_prediction.py b/tests/core/test_mimic3_readmission_prediction.py new file mode 100644 index 000000000..8eb88ba4c --- /dev/null +++ b/tests/core/test_mimic3_readmission_prediction.py @@ -0,0 +1,12 @@ +import unittest + +from pyhealth.tasks import ReadmissionPredictionMIMIC3 + +class TestReadmissionPredictionMIMIC3(unittest.TestCase): + def test_task_schema(self): + self.assertIn("task_name", vars(ReadmissionPredictionMIMIC3)) + self.assertIn("input_schema", vars(ReadmissionPredictionMIMIC3)) + self.assertIn("output_schema", vars(ReadmissionPredictionMIMIC3)) + +if __name__ == "__main__": + unittest.main() From 0bab01788428788edf446762768480872fb36ff7 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Tue, 23 Dec 2025 19:16:08 +0000 Subject: [PATCH 03/10] Add basic unit tests --- pyhealth/tasks/readmission_prediction.py | 56 +++----- .../test_mimic3_readmission_prediction.py | 125 ++++++++++++++++++ 2 files changed, 140 insertions(+), 41 deletions(-) diff --git a/pyhealth/tasks/readmission_prediction.py b/pyhealth/tasks/readmission_prediction.py index 853d1ac54..fe3202edb 100644 --- a/pyhealth/tasks/readmission_prediction.py +++ b/pyhealth/tasks/readmission_prediction.py @@ -13,13 +13,13 @@ class ReadmissionPredictionMIMIC3(BaseTask): #todo: update docs (replace all references to readmission_prediction_mimic3_fn) #todo: deprecate readmission_prediction_mimic3_fn (make it a wrapper around this) #todo: review other similar tasks for best practices and common patterns + #todo: add short-circuits to loop and test sample generation speed difference task_name: str = "ReadmissionPredictionMIMIC3" - input_schema: Dict[str, str] = {"conditions": "sequence", "procedures": "sequence", "drugs": "sequence"} - output_schema: Dict[str, str] = {"label": "binary"} + input_schema: Dict[str, str] = {"diagnoses": "sequence", "prescriptions": "sequence", "procedures": "sequence"} + output_schema: Dict[str, str] = {"readmission": "binary"} - def __init__(self, window: timedelta, min_admission_length: timedelta = timedelta(0)) -> None: + def __init__(self, window: timedelta) -> None: self.window = window - self.min_admission_len = min_admission_length def __call__(self, patient: Patient) -> List[Dict]: patients: List[Event] = patient.get_events(event_type="patients") @@ -36,43 +36,17 @@ def __call__(self, patient: Patient) -> List[Dict]: # continue discharge_time = datetime.strptime(admissions[i].dischtime, "%Y-%m-%d %H:%M:%S") - if (discharge_time - admissions[i].timestamp) < self.min_admission_len: - continue - diagnoses_icd = patient.get_events( - event_type="diagnoses_icd", - start=admissions[i].timestamp, - end=discharge_time, - return_df=True - ) - conditions = diagnoses_icd.select( - pl.concat_str(["diagnoses_icd/icd_version", "diagnoses_icd/icd_code"], separator="_") - ).to_series().to_list() - if len(conditions) == 0: #todo: move the short-circuit before the conversion (here and below) - continue + filter = ("hadm_id", "==", admissions[i].hadm_id) + diagnoses = patient.get_events(event_type="diagnoses_icd", filters=[filter]) + procedures = patient.get_events(event_type="procedures_icd", filters=[filter]) + prescriptions = patient.get_events(event_type="prescriptions", filters=[filter]) - procedures_icd = patient.get_events( - event_type="procedures_icd", - start=admissions[i].timestamp, - end=discharge_time, - return_df=True - ) - procedures = procedures_icd.select( - pl.concat_str(["procedures_icd/icd_version", "procedures_icd/icd_code"], separator="_") - ).to_series().to_list() - if len(procedures) == 0: #todo: confirm we want conditions AND procedures AND drugs (instead of OR) - continue + diagnoses = [event.icd9_code for event in diagnoses] + procedures = [event.icd9_code for event in procedures] + prescriptions = [event.drug for event in prescriptions] - prescriptions = patient.get_events( - event_type="prescriptions", - start=admissions[i].timestamp, - end=discharge_time, - return_df=True - ) - drugs = prescriptions.select( - pl.concat_str(["prescriptions/drug"], separator="_") - ).to_series().to_list() - if len(drugs) == 0: + if len(diagnoses) * len(procedures) * len(prescriptions) == 0: continue readmission = int((admissions[i + 1].timestamp - discharge_time) < self.window) @@ -80,10 +54,10 @@ def __call__(self, patient: Patient) -> List[Dict]: samples.append( { "patient_id": patient.patient_id, - "admission_id": admissions[0].hadm_id, - "conditions": conditions, + "admission_id": admissions[i].hadm_id, + "diagnoses": diagnoses, + "prescriptions": prescriptions, "procedures": procedures, - "drugs": drugs, "readmission": readmission, } ) diff --git a/tests/core/test_mimic3_readmission_prediction.py b/tests/core/test_mimic3_readmission_prediction.py index 8eb88ba4c..2b29e60ce 100644 --- a/tests/core/test_mimic3_readmission_prediction.py +++ b/tests/core/test_mimic3_readmission_prediction.py @@ -1,12 +1,137 @@ +from datetime import timedelta +import os +import shutil import unittest +from pyhealth.datasets import MIMIC3Dataset from pyhealth.tasks import ReadmissionPredictionMIMIC3 class TestReadmissionPredictionMIMIC3(unittest.TestCase): + @classmethod + def setUpClass(cls): + if os.path.exists("test"): + shutil.rmtree("test") + os.makedirs("test") + + patients = [ + "row_id,subject_id,gender,dob,dod,dod_hosp,dod_ssn,expire_flag", + "1,1,,2000-01-01 00:00:00,,,,", + "2,2,,2000-01-01 00:00:00,,,,", + "3,3,,2000-01-01 00:00:00,,,,", + ] + with open("test/PATIENTS.csv", 'w') as f: + f.write("\n".join(patients)) + + admissions = [ + "row_id,subject_id,hadm_id,admittime,dischtime,deathtime,admission_type,admission_location,discharge_location,insurance,language,religion,marital_status,ethnicity,edregtime,edouttime,diagnosis,hospital_expire_flag,has_chartevents_data", + "1,1,1,2020-01-01 00:00:00,2020-01-01 12:00:00,,,,,,,,,,,,,,", + "2,2,2,2020-01-01 00:00:00,2020-01-01 12:00:00,,,,,,,,,,,,,,", + "3,2,3,2020-01-31 11:00:00,2020-01-31 12:00:00,,,,,,,,,,,,,,", + "4,3,4,2020-01-01 00:00:00,2020-01-01 12:00:00,,,,,,,,,,,,,,", + "5,3,5,2020-01-31 12:00:00,2020-01-31 13:00:00,,,,,,,,,,,,,,", + ] + with open("test/ADMISSIONS.csv", 'w') as f: + f.write("\n".join(admissions)) + + icu_stays = [ + "subject_id,intime,icustay_id,first_careunit,dbsource,last_careunit,outtime", + ] + with open("test/ICUSTAYS.csv", 'w') as f: + f.write("\n".join(icu_stays)) + + diagnoses = [ + "row_id,subject_id,hadm_id,seq_num,icd9_code", + "1,1,1,1,", + "2,2,2,1,", + "3,2,3,1,", + "4,3,4,1,", + "5,3,5,1,", + ] + with open("test/DIAGNOSES_ICD.csv", 'w') as f: + f.write("\n".join(diagnoses)) + + prescriptions = [ + "row_id,subject_id,hadm_id,icustay_id,startdate,enddate,drug_type,drug,drug_name_poe,drug_name_generic,formulary_drug_cd,gsn,ndc,prod_strength,dose_val_rx,dose_unit_rx,form_val_disp,form_unit_disp,route", + "1,1,1,,,,,,,,,,,,,,,,", + "2,2,2,,,,,,,,,,,,,,,,", + "3,2,3,,,,,,,,,,,,,,,,", + "4,3,4,,,,,,,,,,,,,,,,", + "5,3,5,,,,,,,,,,,,,,,,", + ] + with open("test/PRESCRIPTIONS.csv", 'w') as f: + f.write("\n".join(prescriptions)) + + procedures = [ + "row_id,subject_id,hadm_id,seq_num,icd9_code", + "1,1,1,1,", + "2,2,2,1,", + "3,2,3,1,", + "4,3,4,1,", + "5,3,5,1,", + ] + with open("test/PROCEDURES_ICD.csv", 'w') as f: + f.write("\n".join(procedures)) + + dataset = MIMIC3Dataset(root="./test", tables=["diagnoses_icd", "prescriptions", "procedures_icd"]) + + cls.samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=30))) + + @classmethod + def tearDownClass(cls): + if os.path.exists("test"): + shutil.rmtree("test") + def test_task_schema(self): self.assertIn("task_name", vars(ReadmissionPredictionMIMIC3)) self.assertIn("input_schema", vars(ReadmissionPredictionMIMIC3)) self.assertIn("output_schema", vars(ReadmissionPredictionMIMIC3)) + def test_sample_schema(self): + for sample in self.samples: + self.assertIn("patient_id", sample) + self.assertIn("admission_id", sample) + self.assertIn("diagnoses", sample) + self.assertIn("prescriptions", sample) + self.assertIn("procedures", sample) + self.assertIn("readmission", sample) + + def test_expected_num_samples(self): + self.assertEqual(len(self.samples), 2) + + def test_patient_with_only_one_visit_is_excluded(self): + self.assertTrue(all(sample["patient_id"] != '1' for sample in self.samples)) + self.assertTrue(all(sample["admission_id"] != '1' for sample in self.samples)) + + def test_positive_sample(self): + samples = [sample for sample in self.samples if sample["admission_id"] == '2'] + self.assertEqual(len(samples), 1) + self.assertEqual(samples[0]["readmission"], 1) + + def test_negative_sample(self): + samples = [sample for sample in self.samples if sample["admission_id"] == '4'] + self.assertEqual(len(samples), 1) + self.assertEqual(samples[0]["readmission"], 0) + + def test_patient_with_positive_and_negative_samples(self): + pass + + def test_last_admission_is_excluded_since_no_readmission_data(self): + samples = [sample for sample in self.samples if sample["admission_id"] == '3'] + self.assertEqual(len(samples), 0) + samples = [sample for sample in self.samples if sample["admission_id"] == '5'] + self.assertEqual(len(samples), 0) + + def test_admissions_without_diagnoses_are_excluded(self): + pass + + def test_admissions_without_prescriptions_are_excluded(self): + pass + + def test_admissions_without_procedures_are_excluded(self): + pass + + def test_admissions_of_minors_are_excluded(self): + pass + if __name__ == "__main__": unittest.main() From f275473462f8c4d3892c67f7fa82f5829a3f37a4 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Tue, 23 Dec 2025 21:23:22 +0000 Subject: [PATCH 04/10] Add age check and more unit tests --- pyhealth/tasks/readmission_prediction.py | 12 ++-- .../test_mimic3_readmission_prediction.py | 64 ++++++++++++++++--- 2 files changed, 61 insertions(+), 15 deletions(-) diff --git a/pyhealth/tasks/readmission_prediction.py b/pyhealth/tasks/readmission_prediction.py index fe3202edb..b38595f1b 100644 --- a/pyhealth/tasks/readmission_prediction.py +++ b/pyhealth/tasks/readmission_prediction.py @@ -7,7 +7,6 @@ from pyhealth.tasks import BaseTask class ReadmissionPredictionMIMIC3(BaseTask): - #todo: add unit tests (demo dataset? synthetic dataset? my own synthetic dataset?) #todo: add doc strings #todo: replace examples #todo: update docs (replace all references to readmission_prediction_mimic3_fn) @@ -24,6 +23,7 @@ def __init__(self, window: timedelta) -> None: def __call__(self, patient: Patient) -> List[Dict]: patients: List[Event] = patient.get_events(event_type="patients") assert len(patients) == 1 + dob = datetime.strptime(patients[0].dob, "%Y-%m-%d %H:%M:%S") admissions: List[Event] = patient.get_events(event_type="admissions") if len(admissions) < 2: @@ -31,11 +31,10 @@ def __call__(self, patient: Patient) -> List[Dict]: samples = [] for i in range(len(admissions) - 1): # Skip the last admission since we need a "next" admission - #todo: Exclude visits where the patient is under 18 - # if int(admissions[0].timestamp - patients[0]["dob"]) < 18: - # continue - - discharge_time = datetime.strptime(admissions[i].dischtime, "%Y-%m-%d %H:%M:%S") + age = admissions[i].timestamp.year - dob.year + age = age-1 if ((admissions[i].timestamp.month, admissions[i].timestamp.day) < (dob.month, dob.day)) else age + if age < 18: + continue filter = ("hadm_id", "==", admissions[i].hadm_id) diagnoses = patient.get_events(event_type="diagnoses_icd", filters=[filter]) @@ -49,6 +48,7 @@ def __call__(self, patient: Patient) -> List[Dict]: if len(diagnoses) * len(procedures) * len(prescriptions) == 0: continue + discharge_time = datetime.strptime(admissions[i].dischtime, "%Y-%m-%d %H:%M:%S") readmission = int((admissions[i + 1].timestamp - discharge_time) < self.window) samples.append( diff --git a/tests/core/test_mimic3_readmission_prediction.py b/tests/core/test_mimic3_readmission_prediction.py index 2b29e60ce..967acc648 100644 --- a/tests/core/test_mimic3_readmission_prediction.py +++ b/tests/core/test_mimic3_readmission_prediction.py @@ -18,6 +18,9 @@ def setUpClass(cls): "1,1,,2000-01-01 00:00:00,,,,", "2,2,,2000-01-01 00:00:00,,,,", "3,3,,2000-01-01 00:00:00,,,,", + "4,4,,2000-01-01 00:00:00,,,,", + "5,5,,2000-01-01 00:00:00,,,,", + "6,6,,2000-01-01 00:00:00,,,,", ] with open("test/PATIENTS.csv", 'w') as f: f.write("\n".join(patients)) @@ -29,6 +32,16 @@ def setUpClass(cls): "3,2,3,2020-01-31 11:00:00,2020-01-31 12:00:00,,,,,,,,,,,,,,", "4,3,4,2020-01-01 00:00:00,2020-01-01 12:00:00,,,,,,,,,,,,,,", "5,3,5,2020-01-31 12:00:00,2020-01-31 13:00:00,,,,,,,,,,,,,,", + "6,4,6,2020-01-01 00:00:00,2020-01-01 12:00:00,,,,,,,,,,,,,,", + "7,4,7,2020-02-01 00:00:00,2020-02-01 12:00:00,,,,,,,,,,,,,,", + "8,4,8,2020-02-02 00:00:00,2020-02-02 12:00:00,,,,,,,,,,,,,,", + "9,5,9,2020-01-01 00:00:00,2020-01-01 12:00:00,,,,,,,,,,,,,,", + "10,5,10,2020-01-02 00:00:00,2020-01-02 12:00:00,,,,,,,,,,,,,,", + "11,5,11,2020-01-03 00:00:00,2020-01-03 12:00:00,,,,,,,,,,,,,,", + "12,5,12,2020-01-04 00:00:00,2020-01-04 12:00:00,,,,,,,,,,,,,,", + "13,6,13,2017-12-31 23:59:59,2018-01-01 00:00:00,,,,,,,,,,,,,,", + "14,6,14,2018-01-01 00:00:00,2018-01-01 12:00:00,,,,,,,,,,,,,,", + "15,6,15,2020-01-01 00:00:00,2020-01-01 12:00:00,,,,,,,,,,,,,,", ] with open("test/ADMISSIONS.csv", 'w') as f: f.write("\n".join(admissions)) @@ -46,6 +59,15 @@ def setUpClass(cls): "3,2,3,1,", "4,3,4,1,", "5,3,5,1,", + "6,4,6,1,", + "7,4,7,1,", + "8,4,8,1,", + "9,5,9,1,", + "10,5,10,1,", + "11,5,12,1,", + "12,6,13,1,", + "13,6,14,1,", + "14,6,15,1,", ] with open("test/DIAGNOSES_ICD.csv", 'w') as f: f.write("\n".join(diagnoses)) @@ -57,6 +79,15 @@ def setUpClass(cls): "3,2,3,,,,,,,,,,,,,,,,", "4,3,4,,,,,,,,,,,,,,,,", "5,3,5,,,,,,,,,,,,,,,,", + "6,4,6,,,,,,,,,,,,,,,,", + "7,4,7,,,,,,,,,,,,,,,,", + "8,4,8,,,,,,,,,,,,,,,,", + "9,5,9,,,,,,,,,,,,,,,,", + "10,5,11,,,,,,,,,,,,,,,,", + "11,5,12,,,,,,,,,,,,,,,,", + "12,6,13,,,,,,,,,,,,,,,,", + "13,6,14,,,,,,,,,,,,,,,,", + "14,6,15,,,,,,,,,,,,,,,,", ] with open("test/PRESCRIPTIONS.csv", 'w') as f: f.write("\n".join(prescriptions)) @@ -68,6 +99,15 @@ def setUpClass(cls): "3,2,3,1,", "4,3,4,1,", "5,3,5,1,", + "6,4,6,1,", + "7,4,7,1,", + "8,4,8,1,", + "9,5,10,1,", + "10,5,11,1,", + "11,5,12,1,", + "12,6,13,1,", + "13,6,14,1,", + "14,6,15,1,", ] with open("test/PROCEDURES_ICD.csv", 'w') as f: f.write("\n".join(procedures)) @@ -96,7 +136,7 @@ def test_sample_schema(self): self.assertIn("readmission", sample) def test_expected_num_samples(self): - self.assertEqual(len(self.samples), 2) + self.assertEqual(len(self.samples), 5) def test_patient_with_only_one_visit_is_excluded(self): self.assertTrue(all(sample["patient_id"] != '1' for sample in self.samples)) @@ -113,25 +153,31 @@ def test_negative_sample(self): self.assertEqual(samples[0]["readmission"], 0) def test_patient_with_positive_and_negative_samples(self): - pass + samples = [sample for sample in self.samples if sample["admission_id"] == '6'] + self.assertEqual(len(samples), 1) + self.assertEqual(samples[0]["readmission"], 0) + samples = [sample for sample in self.samples if sample["admission_id"] == '7'] + self.assertEqual(len(samples), 1) + self.assertEqual(samples[0]["readmission"], 1) def test_last_admission_is_excluded_since_no_readmission_data(self): - samples = [sample for sample in self.samples if sample["admission_id"] == '3'] - self.assertEqual(len(samples), 0) - samples = [sample for sample in self.samples if sample["admission_id"] == '5'] + samples = [sample for sample in self.samples if sample["admission_id"] in ('3', '5', '8', '12', '15')] self.assertEqual(len(samples), 0) def test_admissions_without_diagnoses_are_excluded(self): - pass + self.assertTrue(all(sample["admission_id"] != '11' for sample in self.samples)) def test_admissions_without_prescriptions_are_excluded(self): - pass + self.assertTrue(all(sample["admission_id"] != '10' for sample in self.samples)) def test_admissions_without_procedures_are_excluded(self): - pass + self.assertTrue(all(sample["admission_id"] != '9' for sample in self.samples)) def test_admissions_of_minors_are_excluded(self): - pass + self.assertTrue(all(sample["admission_id"] != '13' for sample in self.samples)) + samples = [sample for sample in self.samples if sample["admission_id"] == '14'] + self.assertEqual(len(samples), 1) + self.assertEqual(samples[0]["readmission"], 0) if __name__ == "__main__": unittest.main() From f72cd5b56c4f6408bdc4bbe415f09439fb7e21f8 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Tue, 23 Dec 2025 23:34:39 +0000 Subject: [PATCH 05/10] Add MIMIC3 mock for generating test cases --- .../test_mimic3_readmission_prediction.py | 327 ++++++++++-------- 1 file changed, 176 insertions(+), 151 deletions(-) diff --git a/tests/core/test_mimic3_readmission_prediction.py b/tests/core/test_mimic3_readmission_prediction.py index 967acc648..b1d47ce34 100644 --- a/tests/core/test_mimic3_readmission_prediction.py +++ b/tests/core/test_mimic3_readmission_prediction.py @@ -1,133 +1,30 @@ from datetime import timedelta import os -import shutil import unittest from pyhealth.datasets import MIMIC3Dataset from pyhealth.tasks import ReadmissionPredictionMIMIC3 + class TestReadmissionPredictionMIMIC3(unittest.TestCase): - @classmethod - def setUpClass(cls): - if os.path.exists("test"): - shutil.rmtree("test") - os.makedirs("test") - - patients = [ - "row_id,subject_id,gender,dob,dod,dod_hosp,dod_ssn,expire_flag", - "1,1,,2000-01-01 00:00:00,,,,", - "2,2,,2000-01-01 00:00:00,,,,", - "3,3,,2000-01-01 00:00:00,,,,", - "4,4,,2000-01-01 00:00:00,,,,", - "5,5,,2000-01-01 00:00:00,,,,", - "6,6,,2000-01-01 00:00:00,,,,", - ] - with open("test/PATIENTS.csv", 'w') as f: - f.write("\n".join(patients)) - - admissions = [ - "row_id,subject_id,hadm_id,admittime,dischtime,deathtime,admission_type,admission_location,discharge_location,insurance,language,religion,marital_status,ethnicity,edregtime,edouttime,diagnosis,hospital_expire_flag,has_chartevents_data", - "1,1,1,2020-01-01 00:00:00,2020-01-01 12:00:00,,,,,,,,,,,,,,", - "2,2,2,2020-01-01 00:00:00,2020-01-01 12:00:00,,,,,,,,,,,,,,", - "3,2,3,2020-01-31 11:00:00,2020-01-31 12:00:00,,,,,,,,,,,,,,", - "4,3,4,2020-01-01 00:00:00,2020-01-01 12:00:00,,,,,,,,,,,,,,", - "5,3,5,2020-01-31 12:00:00,2020-01-31 13:00:00,,,,,,,,,,,,,,", - "6,4,6,2020-01-01 00:00:00,2020-01-01 12:00:00,,,,,,,,,,,,,,", - "7,4,7,2020-02-01 00:00:00,2020-02-01 12:00:00,,,,,,,,,,,,,,", - "8,4,8,2020-02-02 00:00:00,2020-02-02 12:00:00,,,,,,,,,,,,,,", - "9,5,9,2020-01-01 00:00:00,2020-01-01 12:00:00,,,,,,,,,,,,,,", - "10,5,10,2020-01-02 00:00:00,2020-01-02 12:00:00,,,,,,,,,,,,,,", - "11,5,11,2020-01-03 00:00:00,2020-01-03 12:00:00,,,,,,,,,,,,,,", - "12,5,12,2020-01-04 00:00:00,2020-01-04 12:00:00,,,,,,,,,,,,,,", - "13,6,13,2017-12-31 23:59:59,2018-01-01 00:00:00,,,,,,,,,,,,,,", - "14,6,14,2018-01-01 00:00:00,2018-01-01 12:00:00,,,,,,,,,,,,,,", - "15,6,15,2020-01-01 00:00:00,2020-01-01 12:00:00,,,,,,,,,,,,,,", - ] - with open("test/ADMISSIONS.csv", 'w') as f: - f.write("\n".join(admissions)) - - icu_stays = [ - "subject_id,intime,icustay_id,first_careunit,dbsource,last_careunit,outtime", - ] - with open("test/ICUSTAYS.csv", 'w') as f: - f.write("\n".join(icu_stays)) - - diagnoses = [ - "row_id,subject_id,hadm_id,seq_num,icd9_code", - "1,1,1,1,", - "2,2,2,1,", - "3,2,3,1,", - "4,3,4,1,", - "5,3,5,1,", - "6,4,6,1,", - "7,4,7,1,", - "8,4,8,1,", - "9,5,9,1,", - "10,5,10,1,", - "11,5,12,1,", - "12,6,13,1,", - "13,6,14,1,", - "14,6,15,1,", - ] - with open("test/DIAGNOSES_ICD.csv", 'w') as f: - f.write("\n".join(diagnoses)) - - prescriptions = [ - "row_id,subject_id,hadm_id,icustay_id,startdate,enddate,drug_type,drug,drug_name_poe,drug_name_generic,formulary_drug_cd,gsn,ndc,prod_strength,dose_val_rx,dose_unit_rx,form_val_disp,form_unit_disp,route", - "1,1,1,,,,,,,,,,,,,,,,", - "2,2,2,,,,,,,,,,,,,,,,", - "3,2,3,,,,,,,,,,,,,,,,", - "4,3,4,,,,,,,,,,,,,,,,", - "5,3,5,,,,,,,,,,,,,,,,", - "6,4,6,,,,,,,,,,,,,,,,", - "7,4,7,,,,,,,,,,,,,,,,", - "8,4,8,,,,,,,,,,,,,,,,", - "9,5,9,,,,,,,,,,,,,,,,", - "10,5,11,,,,,,,,,,,,,,,,", - "11,5,12,,,,,,,,,,,,,,,,", - "12,6,13,,,,,,,,,,,,,,,,", - "13,6,14,,,,,,,,,,,,,,,,", - "14,6,15,,,,,,,,,,,,,,,,", - ] - with open("test/PRESCRIPTIONS.csv", 'w') as f: - f.write("\n".join(prescriptions)) - - procedures = [ - "row_id,subject_id,hadm_id,seq_num,icd9_code", - "1,1,1,1,", - "2,2,2,1,", - "3,2,3,1,", - "4,3,4,1,", - "5,3,5,1,", - "6,4,6,1,", - "7,4,7,1,", - "8,4,8,1,", - "9,5,10,1,", - "10,5,11,1,", - "11,5,12,1,", - "12,6,13,1,", - "13,6,14,1,", - "14,6,15,1,", - ] - with open("test/PROCEDURES_ICD.csv", 'w') as f: - f.write("\n".join(procedures)) - - dataset = MIMIC3Dataset(root="./test", tables=["diagnoses_icd", "prescriptions", "procedures_icd"]) - - cls.samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=30))) - - @classmethod - def tearDownClass(cls): - if os.path.exists("test"): - shutil.rmtree("test") - - def test_task_schema(self): + def setUp(self): + """Seed dataset with neg and pos 5 day readmission examples (min required for sample generation)""" + self.mock = MockMICIC3Dataset() + self.patient1 = self.mock.add_patient() + self.admission1 = self.mock.add_admission(self.patient1, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + self.admission2 = self.mock.add_admission(self.patient1, "2020-01-06 12:00:00", "2020-01-06 12:00:01") # Exactly 5 days later + self.admission3 = self.mock.add_admission(self.patient1, "2020-01-11 12:00:00", "2020-01-11 12:00:01") # 5 days later less 1 second + + def test_patient_with_pos_and_neg_samples(self): + dataset = self.mock.create() + + samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) + self.assertIn("task_name", vars(ReadmissionPredictionMIMIC3)) self.assertIn("input_schema", vars(ReadmissionPredictionMIMIC3)) self.assertIn("output_schema", vars(ReadmissionPredictionMIMIC3)) - def test_sample_schema(self): - for sample in self.samples: + for sample in samples: self.assertIn("patient_id", sample) self.assertIn("admission_id", sample) self.assertIn("diagnoses", sample) @@ -135,49 +32,177 @@ def test_sample_schema(self): self.assertIn("procedures", sample) self.assertIn("readmission", sample) - def test_expected_num_samples(self): - self.assertEqual(len(self.samples), 5) + self.assertEqual(len(samples), 2) + + neg_samples = [s for s in samples if s["readmission"] == 0] + pos_samples = [s for s in samples if s["readmission"] == 1] + + self.assertEqual(len(neg_samples), 1) + self.assertEqual(len(pos_samples), 1) + + self.assertEqual(neg_samples[0]["admission_id"], str(self.admission1)) + self.assertEqual(pos_samples[0]["admission_id"], str(self.admission2)) + + self.assertTrue(all(s["admission_id"] != str(self.admission3) for s in samples)) # Patient's last admission not included def test_patient_with_only_one_visit_is_excluded(self): - self.assertTrue(all(sample["patient_id"] != '1' for sample in self.samples)) - self.assertTrue(all(sample["admission_id"] != '1' for sample in self.samples)) - - def test_positive_sample(self): - samples = [sample for sample in self.samples if sample["admission_id"] == '2'] - self.assertEqual(len(samples), 1) - self.assertEqual(samples[0]["readmission"], 1) - - def test_negative_sample(self): - samples = [sample for sample in self.samples if sample["admission_id"] == '4'] - self.assertEqual(len(samples), 1) - self.assertEqual(samples[0]["readmission"], 0) - - def test_patient_with_positive_and_negative_samples(self): - samples = [sample for sample in self.samples if sample["admission_id"] == '6'] - self.assertEqual(len(samples), 1) - self.assertEqual(samples[0]["readmission"], 0) - samples = [sample for sample in self.samples if sample["admission_id"] == '7'] - self.assertEqual(len(samples), 1) - self.assertEqual(samples[0]["readmission"], 1) - - def test_last_admission_is_excluded_since_no_readmission_data(self): - samples = [sample for sample in self.samples if sample["admission_id"] in ('3', '5', '8', '12', '15')] - self.assertEqual(len(samples), 0) + patient = self.mock.add_patient() + admission = self.mock.add_admission(patient, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + dataset = self.mock.create() + + samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) + + self.assertTrue(all(s["patient_id"] != str(patient) for s in samples)) + self.assertTrue(all(s["admission_id"] != str(admission) for s in samples)) def test_admissions_without_diagnoses_are_excluded(self): - self.assertTrue(all(sample["admission_id"] != '11' for sample in self.samples)) + patient1 = self.mock.add_patient() + patient2 = self.mock.add_patient() + admission1 = self.mock.add_admission(patient1, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + admission2 = self.mock.add_admission(patient1, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + admission3 = self.mock.add_admission(patient2, "2020-01-01 00:00:00", "2020-01-01 12:00:00", add_diagnosis=False) + admission4 = self.mock.add_admission(patient2, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + dataset = self.mock.create() + + samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) + + admission_ids = [int(s["admission_id"]) for s in samples] + + self.assertIn (admission1, admission_ids) + self.assertNotIn(admission2, admission_ids) # Patient's last admission should not be included + self.assertNotIn(admission3, admission_ids) + self.assertNotIn(admission4, admission_ids) # Patient's last admission should not be included def test_admissions_without_prescriptions_are_excluded(self): - self.assertTrue(all(sample["admission_id"] != '10' for sample in self.samples)) + patient1 = self.mock.add_patient() + patient2 = self.mock.add_patient() + admission1 = self.mock.add_admission(patient1, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + admission2 = self.mock.add_admission(patient1, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + admission3 = self.mock.add_admission(patient2, "2020-01-01 00:00:00", "2020-01-01 12:00:00", add_prescription=False) + admission4 = self.mock.add_admission(patient2, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + dataset = self.mock.create() + + samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) + + admission_ids = [int(s["admission_id"]) for s in samples] + + self.assertIn (admission1, admission_ids) + self.assertNotIn(admission2, admission_ids) # Patient's last admission should not be included + self.assertNotIn(admission3, admission_ids) + self.assertNotIn(admission4, admission_ids) # Patient's last admission should not be included def test_admissions_without_procedures_are_excluded(self): - self.assertTrue(all(sample["admission_id"] != '9' for sample in self.samples)) + patient1 = self.mock.add_patient() + patient2 = self.mock.add_patient() + admission1 = self.mock.add_admission(patient1, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + admission2 = self.mock.add_admission(patient1, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + admission3 = self.mock.add_admission(patient2, "2020-01-01 00:00:00", "2020-01-01 12:00:00", add_procedure=False) + admission4 = self.mock.add_admission(patient2, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + dataset = self.mock.create() + + samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) + + admission_ids = [int(s["admission_id"]) for s in samples] + + self.assertIn (admission1, admission_ids) + self.assertNotIn(admission2, admission_ids) # Patient's last admission should not be included + self.assertNotIn(admission3, admission_ids) + self.assertNotIn(admission4, admission_ids) # Patient's last admission should not be included def test_admissions_of_minors_are_excluded(self): - self.assertTrue(all(sample["admission_id"] != '13' for sample in self.samples)) - samples = [sample for sample in self.samples if sample["admission_id"] == '14'] - self.assertEqual(len(samples), 1) - self.assertEqual(samples[0]["readmission"], 0) + patient = self.mock.add_patient(dob="2000-01-01 00:00:00") + admission1 = self.mock.add_admission(patient, admittime="2017-12-31 23:59:59", dischtime="2018-01-01 00:00:00") # Admitted 1 second before turning 18 + admission2 = self.mock.add_admission(patient, admittime="2018-01-01 00:00:00", dischtime="2018-01-01 12:00:00") # Admitted at exactly 18 + admission3 = self.mock.add_admission(patient, admittime="2020-01-01 00:00:00", dischtime="2020-01-01 12:00:00") + dataset = self.mock.create() + + samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) + + admission_ids = [int(s["admission_id"]) for s in samples] + + self.assertNotIn(admission1, admission_ids) + self.assertIn (admission2, admission_ids) + self.assertNotIn(admission3, admission_ids) # Patient's last admission should not be included + + +class MockMICIC3Dataset: + def __init__(self): + self.patients = ["row_id,subject_id,gender,dob,dod,dod_hosp,dod_ssn,expire_flag"] + self.admissions = ["row_id,subject_id,hadm_id,admittime,dischtime,deathtime,admission_type,admission_location,discharge_location,insurance,language,religion,marital_status,ethnicity,edregtime,edouttime,diagnosis,hospital_expire_flag,has_chartevents_data"] + self.icu_stays = ["subject_id,intime,icustay_id,first_careunit,dbsource,last_careunit,outtime"] + self.diagnoses = ["row_id,subject_id,hadm_id,seq_num,icd9_code"] + self.prescriptions = ["row_id,subject_id,hadm_id,icustay_id,startdate,enddate,drug_type,drug,drug_name_poe,drug_name_generic,formulary_drug_cd,gsn,ndc,prod_strength,dose_val_rx,dose_unit_rx,form_val_disp,form_unit_disp,route"] + self.procedures = ["row_id,subject_id,hadm_id,seq_num,icd9_code"] + + self.next_subject_id = 1 + self.next_hadm_id = 1 + self.next_diagnosis_id = 1 + self.next_prescription_id = 1 + self.next_procedure_id = 1 + + def add_patient(self,dob: str = "2000-01-01 00:00:00") -> int: + subject_id = self.next_subject_id + self.next_subject_id += 1 + self.patients.append(f"{subject_id},{subject_id},,{dob},,,,") + return subject_id + + def add_admission(self, + subject_id: int, + admittime: str, + dischtime: str, + add_diagnosis: bool = True, + add_prescription: bool = True, + add_procedure: bool = True + ) -> int: + hadm_id = self.next_hadm_id + self.next_hadm_id += 1 + self.admissions.append(f"{hadm_id},{subject_id},{hadm_id},{admittime},{dischtime},,,,,,,,,,,,,,") + if add_diagnosis: self.add_diagnosis(subject_id, hadm_id) + if add_prescription: self.add_prescription(subject_id, hadm_id) + if add_procedure: self.add_procedure(subject_id, hadm_id) + return hadm_id + + def add_diagnosis(self, subject_id, hadm_id, seq_num: int=1, icd9_code: str="") -> int: + row_id = self.next_diagnosis_id + self.next_diagnosis_id += 1 + self.diagnoses.append(f"{row_id},{subject_id},{hadm_id},{seq_num},{icd9_code}") + return row_id + + def add_prescription(self, subject_id, hadm_id) -> int: + row_id = self.next_prescription_id + self.next_prescription_id += 1 + self.prescriptions.append(f"{row_id},{subject_id},{hadm_id},,,,,,,,,,,,,,,,") + return row_id + + def add_procedure(self, subject_id, hadm_id, seq_num: int=1, icd9_code: str="") -> int: + row_id = self.next_procedure_id + self.next_procedure_id += 1 + self.procedures.append(f"{row_id},{subject_id},{hadm_id},{seq_num},{icd9_code}") + return row_id + + def create(self, tables=["diagnoses_icd", "prescriptions", "procedures_icd"]): + files = { + "PATIENTS.csv": "\n".join(self.patients), + "ADMISSIONS.csv": "\n".join(self.admissions), + "ICUSTAYS.csv": "\n".join(self.icu_stays), + "DIAGNOSES_ICD.csv": "\n".join(self.diagnoses), + "PRESCRIPTIONS.csv": "\n".join(self.prescriptions), + "PROCEDURES_ICD.csv": "\n".join(self.procedures), + } + + for k, v in files.items(): + with open(k, 'w') as f: f.write(v) + + return MIMIC3Dataset(root=".", tables=tables) + + def __del__(self): + if os.path.exists("PATIENTS.csv"): os.remove("PATIENTS.csv") + if os.path.exists("ADMISSIONS.csv"): os.remove("ADMISSIONS.csv") + if os.path.exists("ICUSTAYS.csv"): os.remove("ICUSTAYS.csv") + if os.path.exists("DIAGNOSES_ICD.csv"): os.remove("DIAGNOSES_ICD.csv") + if os.path.exists("PRESCRIPTIONS.csv"): os.remove("PRESCRIPTIONS.csv") + if os.path.exists("PROCEDURES_ICD.csv"): os.remove("PROCEDURES_ICD.csv") + if __name__ == "__main__": unittest.main() From 858f81e7f8c0478d5aac1243b711051db9762fe2 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Wed, 24 Dec 2025 02:03:35 +0000 Subject: [PATCH 06/10] Replace references to the old task --- README.rst | 86 +++++++++---------- .../pyhealth.tasks.readmission_prediction.rst | 2 +- examples/readmission_mimic3_fairness.py | 4 +- examples/readmission_mimic3_rnn.py | 4 +- pyhealth/tasks/__init__.py | 3 +- pyhealth/tasks/readmission_prediction.py | 74 ++-------------- .../test_mimic3_readmission_prediction.py | 68 +++++++-------- 7 files changed, 90 insertions(+), 151 deletions(-) diff --git a/README.rst b/README.rst index ad8d1b257..683302ea1 100644 --- a/README.rst +++ b/README.rst @@ -13,7 +13,7 @@ Welcome to PyHealth! .. image:: https://readthedocs.org/projects/pyhealth/badge/?version=latest :target: https://pyhealth.readthedocs.io/en/latest/ :alt: Documentation status - + .. image:: https://img.shields.io/github/stars/sunlabuiuc/pyhealth.svg :target: https://github.com/sunlabuiuc/pyhealth/stargazers @@ -121,7 +121,7 @@ PyHealth is a comprehensive deep learning toolkit for supporting clinical predic You can use the following functions independently: - **Dataset**: ``MIMIC-III``, ``MIMIC-IV``, ``eICU``, ``OMOP-CDM``, ``customized EHR datasets``, etc. -- **Tasks**: ``diagnosis-based drug recommendation``, ``patient hospitalization and mortality prediction``, ``length stay forecasting``, etc. +- **Tasks**: ``diagnosis-based drug recommendation``, ``patient hospitalization and mortality prediction``, ``length stay forecasting``, etc. - **ML models**: ``CNN``, ``LSTM``, ``GRU``, ``LSTM``, ``RETAIN``, ``SafeDrug``, ``Deepr``, etc. *Building a healthcare AI pipeline can be as short as 10 lines of code in PyHealth*. @@ -130,7 +130,7 @@ You can use the following functions independently: 3. Build ML Pipelines :trophy: --------------------------------- -All healthcare tasks in our package follow a **five-stage pipeline**: +All healthcare tasks in our package follow a **five-stage pipeline**: .. image:: figure/five-stage-pipeline.png :width: 640 @@ -150,7 +150,7 @@ Module 1: mimic3base = MIMIC3Dataset( # root directory of the dataset - root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/", + root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/", # raw CSV table name tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], # map all NDC codes to CCS codes in these tables @@ -169,9 +169,9 @@ Module 2: .. code-block:: python - from pyhealth.tasks import readmission_prediction_mimic3_fn + from pyhealth.tasks import ReadmissionPredictionMIMIC3 - mimic3sample = mimic3base.set_task(task_fn=readmission_prediction_mimic3_fn) # use default task + mimic3sample = mimic3base.set_task(ReadmissionPredictionMIMIC3()) mimic3sample.samples[0] # show the information of the first sample """ { @@ -180,7 +180,7 @@ Module 2: 'conditions': ['5990', '4280', '2851', '4240', '2749', '9982', 'E8499', '42831', '34600'], 'procedures': ['0040', '3931', '7769'], 'drugs': ['N06DA02', 'V06DC01', 'B01AB01', 'A06AA02', 'R03AC02', 'H03AA01', 'J01FA09'], - 'label': 0 + 'readmission': 0 } """ @@ -213,7 +213,7 @@ Module 4: ``pyhealth.trainer`` can specify training arguments, such as epochs, optimizer, learning rate, etc. The trainer will automatically save the best model and output the path in the end. .. code-block:: python - + from pyhealth.trainer import Trainer trainer = Trainer(model=model) @@ -233,19 +233,19 @@ Module 5: # method 1 trainer.evaluate(test_loader) - + # method 2 from pyhealth.metrics.binary import binary_metrics_fn y_true, y_prob, loss = trainer.inference(test_loader) binary_metrics_fn(y_true, y_prob, metrics=["pr_auc", "roc_auc"]) -4. Medical Code Map :hospital: +4. Medical Code Map :hospital: --------------------------------- ``pyhealth.codemap`` provides two core functionalities. **This module can be used independently.** -* For code ontology lookup within one medical coding system (e.g., name, category, sub-concept); +* For code ontology lookup within one medical coding system (e.g., name, category, sub-concept); .. code-block:: python @@ -256,7 +256,7 @@ Module 5: # `Congestive heart failure, unspecified` icd9cm.get_ancestors("428.0") # ['428', '420-429.99', '390-459.99', '001-999.99'] - + atc = InnerMap.load("ATC") atc.lookup("M01AE51") # `ibuprofen, combinations` @@ -267,7 +267,7 @@ Module 5: atc.lookup("M01AE51", "indication") # Ibuprofen is the most commonly used and prescribed NSAID. It is very common over the ... -* For code mapping between two coding systems (e.g., ICD9CM to CCSCM). +* For code mapping between two coding systems (e.g., ICD9CM to CCSCM). .. code-block:: python @@ -300,12 +300,12 @@ Module 5: 'A12B', 'A12C', 'A13A', 'A14A', 'A14B', 'A16A'] tokenizer = Tokenizer(tokens=token_space, special_tokens=["", ""]) - # 2d encode + # 2d encode tokens = [['A03C', 'A03D', 'A03E', 'A03F'], ['A04A', 'B035', 'C129']] - indices = tokenizer.batch_encode_2d(tokens) + indices = tokenizer.batch_encode_2d(tokens) # [[8, 9, 10, 11], [12, 1, 1, 0]] - # 2d decode + # 2d decode indices = [[8, 9, 10, 11], [12, 1, 1, 0]] tokens = tokenizer.batch_decode_2d(indices) # [['A03C', 'A03D', 'A03E', 'A03F'], ['A04A', '', '']] @@ -331,69 +331,69 @@ Module 5: .. - We provide the following tutorials to help users get started with our pyhealth. + We provide the following tutorials to help users get started with our pyhealth. -`Tutorial 0: Introduction to pyhealth.data `_ `[Video] `__ +`Tutorial 0: Introduction to pyhealth.data `_ `[Video] `__ -`Tutorial 1: Introduction to pyhealth.datasets `_ `[Video] `__ +`Tutorial 1: Introduction to pyhealth.datasets `_ `[Video] `__ -`Tutorial 2: Introduction to pyhealth.tasks `_ `[Video] `__ +`Tutorial 2: Introduction to pyhealth.tasks `_ `[Video] `__ -`Tutorial 3: Introduction to pyhealth.models `_ `[Video] `__ +`Tutorial 3: Introduction to pyhealth.models `_ `[Video] `__ -`Tutorial 4: Introduction to pyhealth.trainer `_ `[Video] `__ +`Tutorial 4: Introduction to pyhealth.trainer `_ `[Video] `__ -`Tutorial 5: Introduction to pyhealth.metrics `_ `[Video] `__ +`Tutorial 5: Introduction to pyhealth.metrics `_ `[Video] `__ -`Tutorial 6: Introduction to pyhealth.tokenizer `_ `[Video] `__ +`Tutorial 6: Introduction to pyhealth.tokenizer `_ `[Video] `__ -`Tutorial 7: Introduction to pyhealth.medcode `_ `[Video] `__ +`Tutorial 7: Introduction to pyhealth.medcode `_ `[Video] `__ The following tutorials will help users build their own task pipelines. `Pipeline 1: Drug Recommendation `_ `[Video] `__ +www.youtube.com/watch?v=GGP3Dhfyisc&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=12>`__ `Pipeline 2: Length of Stay Prediction `_ `[Video] `__ +www.youtube.com/watch?v=GGP3Dhfyisc&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=12>`__ `Pipeline 3: Readmission Prediction `_ `[Video] `__ +www.youtube.com/watch?v=GGP3Dhfyisc&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=12>`__ `Pipeline 4: Mortality Prediction `_ `[Video] `__ +www.youtube.com/watch?v=GGP3Dhfyisc&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=12>`__ -`Pipeline 5: Sleep Staging `_ `[Video] `__ +`Pipeline 5: Sleep Staging `_ `[Video] `__ - We provided the advanced tutorials for supporting various needs. + We provided the advanced tutorials for supporting various needs. -`Advanced Tutorial 1: Fit your dataset into our pipeline `_ `[Video] `__ +`Advanced Tutorial 1: Fit your dataset into our pipeline `_ `[Video] `__ -`Advanced Tutorial 2: Define your own healthcare task `_ +`Advanced Tutorial 2: Define your own healthcare task `_ -`Advanced Tutorial 3: Adopt customized model into pyhealth `_ `[Video] `__ +`Advanced Tutorial 3: Adopt customized model into pyhealth `_ `[Video] `__ -`Advanced Tutorial 4: Load your own processed data into pyhealth and try out our ML models `_ `[Video] `__ +`Advanced Tutorial 4: Load your own processed data into pyhealth and try out our ML models `_ `[Video] `__ 7. Datasets :mountain_snow: ----------------------------- We provide the processing files for the following open EHR datasets: -=================== ======================================= ======================================== ======================================================================================================== -Dataset Module Year Information =================== ======================================= ======================================== ======================================================================================================== -MIMIC-III ``pyhealth.datasets.MIMIC3Dataset`` 2016 `MIMIC-III Clinical Database `_ -MIMIC-IV ``pyhealth.datasets.MIMIC4Dataset`` 2020 `MIMIC-IV Clinical Database `_ -eICU ``pyhealth.datasets.eICUDataset`` 2018 `eICU Collaborative Research Database `_ -OMOP ``pyhealth.datasets.OMOPDataset`` `OMOP-CDM schema based dataset `_ +Dataset Module Year Information +=================== ======================================= ======================================== ======================================================================================================== +MIMIC-III ``pyhealth.datasets.MIMIC3Dataset`` 2016 `MIMIC-III Clinical Database `_ +MIMIC-IV ``pyhealth.datasets.MIMIC4Dataset`` 2020 `MIMIC-IV Clinical Database `_ +eICU ``pyhealth.datasets.eICUDataset`` 2018 `eICU Collaborative Research Database `_ +OMOP ``pyhealth.datasets.OMOPDataset`` `OMOP-CDM schema based dataset `_ SleepEDF ``pyhealth.datasets.SleepEDFDataset`` 2018 `Sleep-EDF dataset `_ -SHHS ``pyhealth.datasets.SHHSDataset`` 2016 `Sleep Heart Health Study dataset `_ -ISRUC ``pyhealth.datasets.ISRUCDataset`` 2016 `ISRUC-SLEEP dataset `_ +SHHS ``pyhealth.datasets.SHHSDataset`` 2016 `Sleep Heart Health Study dataset `_ +ISRUC ``pyhealth.datasets.ISRUCDataset`` 2016 `ISRUC-SLEEP dataset `_ =================== ======================================= ======================================== ======================================================================================================== diff --git a/docs/api/tasks/pyhealth.tasks.readmission_prediction.rst b/docs/api/tasks/pyhealth.tasks.readmission_prediction.rst index f9c56f3aa..cfa5ea6e6 100644 --- a/docs/api/tasks/pyhealth.tasks.readmission_prediction.rst +++ b/docs/api/tasks/pyhealth.tasks.readmission_prediction.rst @@ -1,7 +1,7 @@ pyhealth.tasks.readmission_prediction ======================================= -.. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_mimic3_fn +.. autofunction:: pyhealth.tasks.readmission_prediction.ReadmissionPredictionMIMIC3 .. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_mimic4_fn .. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_eicu_fn .. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_eicu_fn2 diff --git a/examples/readmission_mimic3_fairness.py b/examples/readmission_mimic3_fairness.py index d53b67c9e..f8be824d4 100644 --- a/examples/readmission_mimic3_fairness.py +++ b/examples/readmission_mimic3_fairness.py @@ -1,5 +1,5 @@ from pyhealth.datasets import MIMIC3Dataset -from pyhealth.tasks import readmission_prediction_mimic3_fn +from pyhealth.tasks import ReadmissionPredictionMIMIC3 from pyhealth.datasets import split_by_patient, get_dataloader from pyhealth.metrics import fairness_metrics_fn from pyhealth.models import Transformer @@ -14,7 +14,7 @@ base_dataset.stat() # STEP 2: set task -sample_dataset = base_dataset.set_task(readmission_prediction_mimic3_fn) +sample_dataset = base_dataset.set_task(ReadmissionPredictionMIMIC3()) sample_dataset.stat() train_dataset, val_dataset, test_dataset = split_by_patient(sample_dataset, [0.8, 0.1, 0.1]) diff --git a/examples/readmission_mimic3_rnn.py b/examples/readmission_mimic3_rnn.py index 9870fcd64..9e2fbb364 100644 --- a/examples/readmission_mimic3_rnn.py +++ b/examples/readmission_mimic3_rnn.py @@ -1,7 +1,7 @@ from pyhealth.datasets import MIMIC3Dataset from pyhealth.datasets import split_by_patient, get_dataloader from pyhealth.models import RNN -from pyhealth.tasks import readmission_prediction_mimic3_fn +from pyhealth.tasks import ReadmissionPredictionMIMIC3 from pyhealth.trainer import Trainer # STEP 1: load data @@ -15,7 +15,7 @@ base_dataset.stat() # STEP 2: set task -sample_dataset = base_dataset.set_task(readmission_prediction_mimic3_fn) +sample_dataset = base_dataset.set_task(ReadmissionPredictionMIMIC3()) sample_dataset.stat() train_dataset, val_dataset, test_dataset = split_by_patient( diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 92b90f1d3..76632467b 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -49,12 +49,11 @@ from .patient_linkage import patient_linkage_mimic3_fn from .readmission_30days_mimic4 import Readmission30DaysMIMIC4 from .readmission_prediction import ( + ReadmissionPredictionMIMIC3, readmission_prediction_eicu_fn, readmission_prediction_eicu_fn2, - readmission_prediction_mimic3_fn, readmission_prediction_mimic4_fn, readmission_prediction_omop_fn, - ReadmissionPredictionMIMIC3, ) from .sleep_staging import ( sleep_staging_isruc_fn, diff --git a/pyhealth/tasks/readmission_prediction.py b/pyhealth/tasks/readmission_prediction.py index b38595f1b..d87aa506a 100644 --- a/pyhealth/tasks/readmission_prediction.py +++ b/pyhealth/tasks/readmission_prediction.py @@ -9,15 +9,14 @@ class ReadmissionPredictionMIMIC3(BaseTask): #todo: add doc strings #todo: replace examples - #todo: update docs (replace all references to readmission_prediction_mimic3_fn) - #todo: deprecate readmission_prediction_mimic3_fn (make it a wrapper around this) #todo: review other similar tasks for best practices and common patterns #todo: add short-circuits to loop and test sample generation speed difference + #todo: review my chestxray14 PR to make sure I updated all the right places task_name: str = "ReadmissionPredictionMIMIC3" - input_schema: Dict[str, str] = {"diagnoses": "sequence", "prescriptions": "sequence", "procedures": "sequence"} + input_schema: Dict[str, str] = {"conditions": "sequence", "procedures": "sequence", "drugs": "sequence"} output_schema: Dict[str, str] = {"readmission": "binary"} - def __init__(self, window: timedelta) -> None: + def __init__(self, window: timedelta=timedelta(days=15)) -> None: self.window = window def __call__(self, patient: Patient) -> List[Dict]: @@ -53,19 +52,18 @@ def __call__(self, patient: Patient) -> List[Dict]: samples.append( { + "visit_id": admissions[i].hadm_id, "patient_id": patient.patient_id, - "admission_id": admissions[i].hadm_id, - "diagnoses": diagnoses, - "prescriptions": prescriptions, + "conditions": diagnoses, "procedures": procedures, + "drugs": prescriptions, "readmission": readmission, } ) return samples -# TODO: time_window cannot be passed in to base_dataset -def readmission_prediction_mimic3_fn(patient: Patient, time_window=15): + """Processes a single patient for the readmission prediction task. Readmission prediction aims at predicting whether the patient will be readmitted @@ -80,52 +78,7 @@ def readmission_prediction_mimic3_fn(patient: Patient, time_window=15): Returns: samples: a list of samples, each sample is a dict with patient_id, visit_id, and other task-specific attributes as key - - Note that we define the task as a binary classification task. - - Examples: - >>> from pyhealth.datasets import MIMIC3Dataset - >>> mimic3_base = MIMIC3Dataset( - ... root="/srv/local/data/physionet.org/files/mimiciii/1.4", - ... tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], - ... code_mapping={"ICD9CM": "CCSCM"}, - ... ) - >>> from pyhealth.tasks import readmission_prediction_mimic3_fn - >>> mimic3_sample = mimic3_base.set_task(readmission_prediction_mimic3_fn) - >>> mimic3_sample.samples[0] - [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '19', '122', '98', '663', '58', '51']], 'procedures': [['1']], 'label': 1}] """ - samples = [] - - # we will drop the last visit - for i in range(len(patient) - 1): - visit: Visit = patient[i] - next_visit: Visit = patient[i + 1] - - # get time difference between current visit and next visit - time_diff = (next_visit.encounter_time - visit.encounter_time).days - readmission_label = 1 if time_diff < time_window else 0 - - conditions = visit.get_code_list(table="DIAGNOSES_ICD") - procedures = visit.get_code_list(table="PROCEDURES_ICD") - drugs = visit.get_code_list(table="PRESCRIPTIONS") - # exclude: visits without condition, procedure, or drug code - if len(conditions) * len(procedures) * len(drugs) == 0: - continue - # TODO: should also exclude visit with age < 18 - samples.append( - { - "visit_id": visit.visit_id, - "patient_id": patient.patient_id, - "conditions": [conditions], - "procedures": [procedures], - "drugs": [drugs], - "label": readmission_label, - } - ) - # no cohort selection - return samples - def readmission_prediction_mimic4_fn(patient: Patient, time_window=15): """Processes a single patient for the readmission prediction task. @@ -391,19 +344,6 @@ def readmission_prediction_omop_fn(patient: Patient, time_window=15): if __name__ == "__main__": - from pyhealth.datasets import MIMIC3Dataset - - base_dataset = MIMIC3Dataset( - root="/srv/local/data/physionet.org/files/mimiciii/1.4", - tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], - dev=True, - code_mapping={"ICD9CM": "CCSCM", "NDC": "ATC"}, - refresh_cache=False, - ) - sample_dataset = base_dataset.set_task(task_fn=readmission_prediction_mimic3_fn) - sample_dataset.stat() - print(sample_dataset.available_keys) - from pyhealth.datasets import MIMIC4Dataset base_dataset = MIMIC4Dataset( diff --git a/tests/core/test_mimic3_readmission_prediction.py b/tests/core/test_mimic3_readmission_prediction.py index b1d47ce34..79e6be4d8 100644 --- a/tests/core/test_mimic3_readmission_prediction.py +++ b/tests/core/test_mimic3_readmission_prediction.py @@ -25,11 +25,11 @@ def test_patient_with_pos_and_neg_samples(self): self.assertIn("output_schema", vars(ReadmissionPredictionMIMIC3)) for sample in samples: + self.assertIn("visit_id", sample) self.assertIn("patient_id", sample) - self.assertIn("admission_id", sample) - self.assertIn("diagnoses", sample) - self.assertIn("prescriptions", sample) + self.assertIn("conditions", sample) self.assertIn("procedures", sample) + self.assertIn("drugs", sample) self.assertIn("readmission", sample) self.assertEqual(len(samples), 2) @@ -40,10 +40,10 @@ def test_patient_with_pos_and_neg_samples(self): self.assertEqual(len(neg_samples), 1) self.assertEqual(len(pos_samples), 1) - self.assertEqual(neg_samples[0]["admission_id"], str(self.admission1)) - self.assertEqual(pos_samples[0]["admission_id"], str(self.admission2)) + self.assertEqual(neg_samples[0]["visit_id"], str(self.admission1)) + self.assertEqual(pos_samples[0]["visit_id"], str(self.admission2)) - self.assertTrue(all(s["admission_id"] != str(self.admission3) for s in samples)) # Patient's last admission not included + self.assertTrue(all(s["visit_id"] != str(self.admission3) for s in samples)) # Patient's last admission not included def test_patient_with_only_one_visit_is_excluded(self): patient = self.mock.add_patient() @@ -53,7 +53,7 @@ def test_patient_with_only_one_visit_is_excluded(self): samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) self.assertTrue(all(s["patient_id"] != str(patient) for s in samples)) - self.assertTrue(all(s["admission_id"] != str(admission) for s in samples)) + self.assertTrue(all(s["visit_id"] != str(admission) for s in samples)) def test_admissions_without_diagnoses_are_excluded(self): patient1 = self.mock.add_patient() @@ -66,12 +66,12 @@ def test_admissions_without_diagnoses_are_excluded(self): samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) - admission_ids = [int(s["admission_id"]) for s in samples] + visit_ids = [int(s["visit_id"]) for s in samples] - self.assertIn (admission1, admission_ids) - self.assertNotIn(admission2, admission_ids) # Patient's last admission should not be included - self.assertNotIn(admission3, admission_ids) - self.assertNotIn(admission4, admission_ids) # Patient's last admission should not be included + self.assertIn (admission1, visit_ids) + self.assertNotIn(admission2, visit_ids) # Patient's last admission should not be included + self.assertNotIn(admission3, visit_ids) + self.assertNotIn(admission4, visit_ids) # Patient's last admission should not be included def test_admissions_without_prescriptions_are_excluded(self): patient1 = self.mock.add_patient() @@ -84,12 +84,12 @@ def test_admissions_without_prescriptions_are_excluded(self): samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) - admission_ids = [int(s["admission_id"]) for s in samples] + visit_ids = [int(s["visit_id"]) for s in samples] - self.assertIn (admission1, admission_ids) - self.assertNotIn(admission2, admission_ids) # Patient's last admission should not be included - self.assertNotIn(admission3, admission_ids) - self.assertNotIn(admission4, admission_ids) # Patient's last admission should not be included + self.assertIn (admission1, visit_ids) + self.assertNotIn(admission2, visit_ids) # Patient's last admission should not be included + self.assertNotIn(admission3, visit_ids) + self.assertNotIn(admission4, visit_ids) # Patient's last admission should not be included def test_admissions_without_procedures_are_excluded(self): patient1 = self.mock.add_patient() @@ -102,12 +102,12 @@ def test_admissions_without_procedures_are_excluded(self): samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) - admission_ids = [int(s["admission_id"]) for s in samples] + visit_ids = [int(s["visit_id"]) for s in samples] - self.assertIn (admission1, admission_ids) - self.assertNotIn(admission2, admission_ids) # Patient's last admission should not be included - self.assertNotIn(admission3, admission_ids) - self.assertNotIn(admission4, admission_ids) # Patient's last admission should not be included + self.assertIn (admission1, visit_ids) + self.assertNotIn(admission2, visit_ids) # Patient's last admission should not be included + self.assertNotIn(admission3, visit_ids) + self.assertNotIn(admission4, visit_ids) # Patient's last admission should not be included def test_admissions_of_minors_are_excluded(self): patient = self.mock.add_patient(dob="2000-01-01 00:00:00") @@ -118,11 +118,11 @@ def test_admissions_of_minors_are_excluded(self): samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) - admission_ids = [int(s["admission_id"]) for s in samples] + visit_ids = [int(s["visit_id"]) for s in samples] - self.assertNotIn(admission1, admission_ids) - self.assertIn (admission2, admission_ids) - self.assertNotIn(admission3, admission_ids) # Patient's last admission should not be included + self.assertNotIn(admission1, visit_ids) + self.assertIn (admission2, visit_ids) + self.assertNotIn(admission3, visit_ids) # Patient's last admission should not be included class MockMICIC3Dataset: @@ -140,7 +140,7 @@ def __init__(self): self.next_prescription_id = 1 self.next_procedure_id = 1 - def add_patient(self,dob: str = "2000-01-01 00:00:00") -> int: + def add_patient(self,dob: str="2000-01-01 00:00:00") -> int: subject_id = self.next_subject_id self.next_subject_id += 1 self.patients.append(f"{subject_id},{subject_id},,{dob},,,,") @@ -150,9 +150,9 @@ def add_admission(self, subject_id: int, admittime: str, dischtime: str, - add_diagnosis: bool = True, - add_prescription: bool = True, - add_procedure: bool = True + add_diagnosis: bool=True, + add_prescription: bool=True, + add_procedure: bool=True ) -> int: hadm_id = self.next_hadm_id self.next_hadm_id += 1 @@ -162,25 +162,25 @@ def add_admission(self, if add_procedure: self.add_procedure(subject_id, hadm_id) return hadm_id - def add_diagnosis(self, subject_id, hadm_id, seq_num: int=1, icd9_code: str="") -> int: + def add_diagnosis(self, subject_id: int, hadm_id: int, seq_num: int=1, icd9_code: str="") -> int: row_id = self.next_diagnosis_id self.next_diagnosis_id += 1 self.diagnoses.append(f"{row_id},{subject_id},{hadm_id},{seq_num},{icd9_code}") return row_id - def add_prescription(self, subject_id, hadm_id) -> int: + def add_prescription(self, subject_id: int, hadm_id: int) -> int: row_id = self.next_prescription_id self.next_prescription_id += 1 self.prescriptions.append(f"{row_id},{subject_id},{hadm_id},,,,,,,,,,,,,,,,") return row_id - def add_procedure(self, subject_id, hadm_id, seq_num: int=1, icd9_code: str="") -> int: + def add_procedure(self, subject_id: int, hadm_id: int, seq_num: int=1, icd9_code: str="") -> int: row_id = self.next_procedure_id self.next_procedure_id += 1 self.procedures.append(f"{row_id},{subject_id},{hadm_id},{seq_num},{icd9_code}") return row_id - def create(self, tables=["diagnoses_icd", "prescriptions", "procedures_icd"]): + def create(self, tables: list=["diagnoses_icd", "prescriptions", "procedures_icd"]): files = { "PATIENTS.csv": "\n".join(self.patients), "ADMISSIONS.csv": "\n".join(self.admissions), From 7a76dab2f5a38a864fdb77aa9fcdd5575c8fb3a5 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Wed, 24 Dec 2025 03:38:45 +0000 Subject: [PATCH 07/10] Get RNN example working --- examples/readmission_mimic3_fairness.py | 5 +- examples/readmission_mimic3_rnn.py | 15 ++--- pyhealth/tasks/readmission_prediction.py | 44 ++++++++----- .../test_mimic3_readmission_prediction.py | 63 +++++++++++++++---- 4 files changed, 86 insertions(+), 41 deletions(-) diff --git a/examples/readmission_mimic3_fairness.py b/examples/readmission_mimic3_fairness.py index f8be824d4..63cdcce26 100644 --- a/examples/readmission_mimic3_fairness.py +++ b/examples/readmission_mimic3_fairness.py @@ -11,11 +11,10 @@ root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/", tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], ) -base_dataset.stat() +base_dataset.stats() # STEP 2: set task -sample_dataset = base_dataset.set_task(ReadmissionPredictionMIMIC3()) -sample_dataset.stat() +sample_dataset = base_dataset.set_task(ReadmissionPredictionMIMIC3(exclude_minors=False)) # Must include minors to get any readmission samples on the synthetic dataset train_dataset, val_dataset, test_dataset = split_by_patient(sample_dataset, [0.8, 0.1, 0.1]) train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) diff --git a/examples/readmission_mimic3_rnn.py b/examples/readmission_mimic3_rnn.py index 9e2fbb364..18eb0df1c 100644 --- a/examples/readmission_mimic3_rnn.py +++ b/examples/readmission_mimic3_rnn.py @@ -6,17 +6,13 @@ # STEP 1: load data base_dataset = MIMIC3Dataset( - root="/srv/local/data/physionet.org/files/mimiciii/1.4", + root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III", tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], - code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"}, - dev=False, - refresh_cache=True, ) -base_dataset.stat() +base_dataset.stats() # STEP 2: set task -sample_dataset = base_dataset.set_task(ReadmissionPredictionMIMIC3()) -sample_dataset.stat() +sample_dataset = base_dataset.set_task(ReadmissionPredictionMIMIC3(exclude_minors=False)) # Must include minors to get any readmission samples on the synthetic dataset train_dataset, val_dataset, test_dataset = split_by_patient( sample_dataset, [0.8, 0.1, 0.1] @@ -28,9 +24,6 @@ # STEP 3: define model model = RNN( dataset=sample_dataset, - feature_keys=["conditions", "procedures", "drugs"], - label_key="label", - mode="binary", ) # STEP 4: define trainer @@ -38,7 +31,7 @@ trainer.train( train_dataloader=train_dataloader, val_dataloader=val_dataloader, - epochs=50, + epochs=1, monitor="roc_auc", ) diff --git a/pyhealth/tasks/readmission_prediction.py b/pyhealth/tasks/readmission_prediction.py index d87aa506a..02ffc7337 100644 --- a/pyhealth/tasks/readmission_prediction.py +++ b/pyhealth/tasks/readmission_prediction.py @@ -8,21 +8,23 @@ class ReadmissionPredictionMIMIC3(BaseTask): #todo: add doc strings - #todo: replace examples - #todo: review other similar tasks for best practices and common patterns - #todo: add short-circuits to loop and test sample generation speed difference - #todo: review my chestxray14 PR to make sure I updated all the right places task_name: str = "ReadmissionPredictionMIMIC3" input_schema: Dict[str, str] = {"conditions": "sequence", "procedures": "sequence", "drugs": "sequence"} output_schema: Dict[str, str] = {"readmission": "binary"} - def __init__(self, window: timedelta=timedelta(days=15)) -> None: + def __init__(self, window: timedelta=timedelta(days=15), exclude_minors: bool=True) -> None: self.window = window + self.exclude_minors = exclude_minors def __call__(self, patient: Patient) -> List[Dict]: patients: List[Event] = patient.get_events(event_type="patients") assert len(patients) == 1 - dob = datetime.strptime(patients[0].dob, "%Y-%m-%d %H:%M:%S") + + if self.exclude_minors: + try: + dob = datetime.strptime(patients[0].dob, "%Y-%m-%d %H:%M:%S") + except ValueError: + dob = datetime.strptime(patients[0].dob, "%Y-%m-%d") admissions: List[Event] = patient.get_events(event_type="admissions") if len(admissions) < 2: @@ -30,24 +32,34 @@ def __call__(self, patient: Patient) -> List[Dict]: samples = [] for i in range(len(admissions) - 1): # Skip the last admission since we need a "next" admission - age = admissions[i].timestamp.year - dob.year - age = age-1 if ((admissions[i].timestamp.month, admissions[i].timestamp.day) < (dob.month, dob.day)) else age - if age < 18: - continue + if self.exclude_minors: + age = admissions[i].timestamp.year - dob.year + age = age-1 if ((admissions[i].timestamp.month, admissions[i].timestamp.day) < (dob.month, dob.day)) else age + if age < 18: + continue filter = ("hadm_id", "==", admissions[i].hadm_id) - diagnoses = patient.get_events(event_type="diagnoses_icd", filters=[filter]) - procedures = patient.get_events(event_type="procedures_icd", filters=[filter]) - prescriptions = patient.get_events(event_type="prescriptions", filters=[filter]) + diagnoses = patient.get_events(event_type="diagnoses_icd", filters=[filter]) diagnoses = [event.icd9_code for event in diagnoses] + if len(diagnoses) == 0: + continue + + procedures = patient.get_events(event_type="procedures_icd", filters=[filter]) procedures = [event.icd9_code for event in procedures] - prescriptions = [event.drug for event in prescriptions] + if len(procedures) == 0: + continue - if len(diagnoses) * len(procedures) * len(prescriptions) == 0: + prescriptions = patient.get_events(event_type="prescriptions", filters=[filter]) + prescriptions = [event.drug for event in prescriptions] + if len(prescriptions) == 0: continue - discharge_time = datetime.strptime(admissions[i].dischtime, "%Y-%m-%d %H:%M:%S") + try: + discharge_time = datetime.strptime(admissions[i].dischtime, "%Y-%m-%d %H:%M:%S") + except ValueError: + discharge_time = datetime.strptime(admissions[i].dischtime, "%Y-%m-%d") + readmission = int((admissions[i + 1].timestamp - discharge_time) < self.window) samples.append( diff --git a/tests/core/test_mimic3_readmission_prediction.py b/tests/core/test_mimic3_readmission_prediction.py index 79e6be4d8..f7b7b6a3e 100644 --- a/tests/core/test_mimic3_readmission_prediction.py +++ b/tests/core/test_mimic3_readmission_prediction.py @@ -8,22 +8,25 @@ class TestReadmissionPredictionMIMIC3(unittest.TestCase): def setUp(self): - """Seed dataset with neg and pos 5 day readmission examples (min required for sample generation)""" + """Seed dataset with neg and pos 15 day readmission examples (min required for sample generation)""" self.mock = MockMICIC3Dataset() - self.patient1 = self.mock.add_patient() - self.admission1 = self.mock.add_admission(self.patient1, "2020-01-01 00:00:00", "2020-01-01 12:00:00") - self.admission2 = self.mock.add_admission(self.patient1, "2020-01-06 12:00:00", "2020-01-06 12:00:01") # Exactly 5 days later - self.admission3 = self.mock.add_admission(self.patient1, "2020-01-11 12:00:00", "2020-01-11 12:00:01") # 5 days later less 1 second + patient = self.mock.add_patient() + self.admission1 = self.mock.add_admission(patient, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + self.admission2 = self.mock.add_admission(patient, "2020-01-16 12:00:00", "2020-01-16 12:00:01") # Exactly 15 days later + self.admission3 = self.mock.add_admission(patient, "2020-01-31 12:00:00", "2020-01-31 12:00:01") # 15 days later less 1 second def test_patient_with_pos_and_neg_samples(self): dataset = self.mock.create() - samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) + task = ReadmissionPredictionMIMIC3() + samples = dataset.set_task(task) self.assertIn("task_name", vars(ReadmissionPredictionMIMIC3)) self.assertIn("input_schema", vars(ReadmissionPredictionMIMIC3)) self.assertIn("output_schema", vars(ReadmissionPredictionMIMIC3)) + self.assertEqual(task.window, timedelta(days=15)) + for sample in samples: self.assertIn("visit_id", sample) self.assertIn("patient_id", sample) @@ -45,12 +48,32 @@ def test_patient_with_pos_and_neg_samples(self): self.assertTrue(all(s["visit_id"] != str(self.admission3) for s in samples)) # Patient's last admission not included + def test_explicit_time_window(self): + patient = self.mock.add_patient() + admission1 = self.mock.add_admission(patient, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + admission2 = self.mock.add_admission(patient, "2020-01-06 12:00:00", "2020-01-06 12:00:01") # Exactly 5 days later + admission3 = self.mock.add_admission(patient, "2020-01-11 12:00:00", "2020-01-11 12:00:01") # 5 days later less 1 second + dataset = self.mock.create() + + samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) + + visit1 = [s for s in samples if s["visit_id"] == str(admission1)] + visit2 = [s for s in samples if s["visit_id"] == str(admission2)] + + self.assertEqual(len(visit1), 1) + self.assertEqual(len(visit2), 1) + + self.assertEqual(visit1[0]["readmission"], 0) + self.assertEqual(visit2[0]["readmission"], 1) + + self.assertTrue(all(s["visit_id"] != str(admission3) for s in samples)) # Patient's last admission not included + def test_patient_with_only_one_visit_is_excluded(self): patient = self.mock.add_patient() admission = self.mock.add_admission(patient, "2020-01-01 00:00:00", "2020-01-01 12:00:00") dataset = self.mock.create() - samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) + samples = dataset.set_task(ReadmissionPredictionMIMIC3()) self.assertTrue(all(s["patient_id"] != str(patient) for s in samples)) self.assertTrue(all(s["visit_id"] != str(admission) for s in samples)) @@ -64,7 +87,7 @@ def test_admissions_without_diagnoses_are_excluded(self): admission4 = self.mock.add_admission(patient2, "2020-01-01 00:00:00", "2020-01-01 12:00:00") dataset = self.mock.create() - samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) + samples = dataset.set_task(ReadmissionPredictionMIMIC3()) visit_ids = [int(s["visit_id"]) for s in samples] @@ -82,7 +105,7 @@ def test_admissions_without_prescriptions_are_excluded(self): admission4 = self.mock.add_admission(patient2, "2020-01-01 00:00:00", "2020-01-01 12:00:00") dataset = self.mock.create() - samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) + samples = dataset.set_task(ReadmissionPredictionMIMIC3()) visit_ids = [int(s["visit_id"]) for s in samples] @@ -100,7 +123,7 @@ def test_admissions_without_procedures_are_excluded(self): admission4 = self.mock.add_admission(patient2, "2020-01-01 00:00:00", "2020-01-01 12:00:00") dataset = self.mock.create() - samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) + samples = dataset.set_task(ReadmissionPredictionMIMIC3()) visit_ids = [int(s["visit_id"]) for s in samples] @@ -116,7 +139,10 @@ def test_admissions_of_minors_are_excluded(self): admission3 = self.mock.add_admission(patient, admittime="2020-01-01 00:00:00", dischtime="2020-01-01 12:00:00") dataset = self.mock.create() - samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) + task = ReadmissionPredictionMIMIC3() + samples = dataset.set_task(task) + + self.assertTrue(task.exclude_minors) visit_ids = [int(s["visit_id"]) for s in samples] @@ -124,6 +150,21 @@ def test_admissions_of_minors_are_excluded(self): self.assertIn (admission2, visit_ids) self.assertNotIn(admission3, visit_ids) # Patient's last admission should not be included + def test_exclude_minors_flag(self): + patient = self.mock.add_patient(dob="2000-01-01 00:00:00") + admission1 = self.mock.add_admission(patient, admittime="2017-12-31 23:59:59", dischtime="2018-01-01 00:00:00") # Admitted 1 second before turning 18 + admission2 = self.mock.add_admission(patient, admittime="2018-01-01 00:00:00", dischtime="2018-01-01 12:00:00") # Admitted at exactly 18 + admission3 = self.mock.add_admission(patient, admittime="2020-01-01 00:00:00", dischtime="2020-01-01 12:00:00") + dataset = self.mock.create() + + samples = dataset.set_task(ReadmissionPredictionMIMIC3(exclude_minors=False)) + + visit_ids = [int(s["visit_id"]) for s in samples] + + self.assertIn (admission1, visit_ids) + self.assertIn (admission2, visit_ids) + self.assertNotIn(admission3, visit_ids) # Patient's last admission should not be included + class MockMICIC3Dataset: def __init__(self): From b3e908826334cb625592a1d2f73e569d02199b75 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Wed, 24 Dec 2025 04:06:22 +0000 Subject: [PATCH 08/10] Add doc strings --- pyhealth/tasks/readmission_prediction.py | 52 ++++++++++++++++-------- 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/pyhealth/tasks/readmission_prediction.py b/pyhealth/tasks/readmission_prediction.py index 02ffc7337..6de177a41 100644 --- a/pyhealth/tasks/readmission_prediction.py +++ b/pyhealth/tasks/readmission_prediction.py @@ -7,16 +7,50 @@ from pyhealth.tasks import BaseTask class ReadmissionPredictionMIMIC3(BaseTask): - #todo: add doc strings + """ + Readmission prediction on the MIMIC3 dataset. + + This task aims at predicting whether the patient will be readmitted into hospital within + a specified number of days based on clinical information from the current visit. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): The schema for the task input. + output_schema (Dict[str, str]): The schema for the task output. + """ task_name: str = "ReadmissionPredictionMIMIC3" input_schema: Dict[str, str] = {"conditions": "sequence", "procedures": "sequence", "drugs": "sequence"} output_schema: Dict[str, str] = {"readmission": "binary"} def __init__(self, window: timedelta=timedelta(days=15), exclude_minors: bool=True) -> None: + """ + Initializes the task object. + + Args: + window (timedelta): If two admissions are closer than this window, it is considered a readmission. Defaults to 15 days. + exclude_minors (bool): Whether to exclude visits where the patient was under 18 years old. Defaults to True. + """ self.window = window self.exclude_minors = exclude_minors def __call__(self, patient: Patient) -> List[Dict]: + """ + Generates binary classification data samples for a single patient. + + Visits with no conditions OR no procedures OR no drugs are excluded from the output but are still used to calculate readmission for prior visits. + + Args: + patient (Patient): A patient object. + + Returns: + List[Dict]: A list containing a dictionary for each patient visit with: + - 'visit_id': MIMIC3 hadm_id. + - 'patient_id': MIMIC3 subject_id. + - 'conditions': MIMIC3 diagnoses_icd table ICD-9 codes. + - 'procedures': MIMIC3 procedures_icd table ICD-9 codes. + - 'drugs': MIMIC3 prescriptions table drug column entries. + - 'readmission': binary label. + """ patients: List[Event] = patient.get_events(event_type="patients") assert len(patients) == 1 @@ -76,22 +110,6 @@ def __call__(self, patient: Patient) -> List[Dict]: return samples - """Processes a single patient for the readmission prediction task. - - Readmission prediction aims at predicting whether the patient will be readmitted - into hospital within time_window days based on the clinical information from - current visit (e.g., conditions and procedures). - - Args: - patient: a Patient object - time_window: the time window threshold (gap < time_window means label=1 for - the task) - - Returns: - samples: a list of samples, each sample is a dict with patient_id, visit_id, - and other task-specific attributes as key - """ - def readmission_prediction_mimic4_fn(patient: Patient, time_window=15): """Processes a single patient for the readmission prediction task. From 82c158ed7381cf45e2f798457a791183ef139ef3 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Wed, 24 Dec 2025 04:55:00 +0000 Subject: [PATCH 09/10] Remove unused import --- pyhealth/tasks/readmission_prediction.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyhealth/tasks/readmission_prediction.py b/pyhealth/tasks/readmission_prediction.py index 6de177a41..cd8b94b38 100644 --- a/pyhealth/tasks/readmission_prediction.py +++ b/pyhealth/tasks/readmission_prediction.py @@ -1,8 +1,6 @@ from datetime import datetime, timedelta from typing import Dict, List -import polars as pl - from pyhealth.data import Event, Patient from pyhealth.tasks import BaseTask From 5c58b9395feadc4fd025b030e3c89a2bce31c6b1 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Wed, 24 Dec 2025 05:32:46 +0000 Subject: [PATCH 10/10] Don't use __del__ for unit test clean up --- tests/core/test_mimic3_readmission_prediction.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/core/test_mimic3_readmission_prediction.py b/tests/core/test_mimic3_readmission_prediction.py index f7b7b6a3e..103728db4 100644 --- a/tests/core/test_mimic3_readmission_prediction.py +++ b/tests/core/test_mimic3_readmission_prediction.py @@ -15,6 +15,9 @@ def setUp(self): self.admission2 = self.mock.add_admission(patient, "2020-01-16 12:00:00", "2020-01-16 12:00:01") # Exactly 15 days later self.admission3 = self.mock.add_admission(patient, "2020-01-31 12:00:00", "2020-01-31 12:00:01") # 15 days later less 1 second + def tearDown(self): + self.mock.destroy() + def test_patient_with_pos_and_neg_samples(self): dataset = self.mock.create() @@ -236,7 +239,7 @@ def create(self, tables: list=["diagnoses_icd", "prescriptions", "procedures_icd return MIMIC3Dataset(root=".", tables=tables) - def __del__(self): + def destroy(self): if os.path.exists("PATIENTS.csv"): os.remove("PATIENTS.csv") if os.path.exists("ADMISSIONS.csv"): os.remove("ADMISSIONS.csv") if os.path.exists("ICUSTAYS.csv"): os.remove("ICUSTAYS.csv")