Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
a2156a8
move _scan_csv_tsv_gz into ../tmp folder
Logiquo Dec 24, 2025
d7b2edc
clean up global_event_df.parquet if failed, clean up tmp dir
Logiquo Dec 24, 2025
01553c1
support context manager for SampleDataset
Logiquo Dec 24, 2025
44add05
Fix set_task cache_dir
Logiquo Dec 24, 2025
c0cc77d
move function to _event_transform
Logiquo Dec 24, 2025
8d66343
Fix set_task
Logiquo Dec 24, 2025
f1c55ae
Refactor set_task
Logiquo Dec 24, 2025
cc91385
Fix up
Logiquo Dec 24, 2025
decd35e
Fixup
Logiquo Dec 24, 2025
a456f12
rename
Logiquo Dec 24, 2025
ef39415
fix _task_transform_fn
Logiquo Dec 24, 2025
a7a06c1
fixup
Logiquo Dec 24, 2025
b466fa0
Fixup
Logiquo Dec 24, 2025
d08f66f
Fixup
Logiquo Dec 24, 2025
83a45b7
Fix mimic4 cache dir
Logiquo Dec 24, 2025
5430c8f
update memtest
Logiquo Dec 25, 2025
ba4c4f2
more workers
Logiquo Dec 25, 2025
74fce5f
Fix single thread
Logiquo Dec 25, 2025
5ad6ca0
Fix single thread
Logiquo Dec 25, 2025
0604813
Fixup
Logiquo Dec 25, 2025
99da1a9
Fix test
Logiquo Dec 25, 2025
647ce7d
Fix up environ.
Logiquo Dec 25, 2025
5d22c4d
Fix test
Logiquo Dec 25, 2025
d02b3cc
better env_var management
Logiquo Dec 25, 2025
078707d
Fixup
Logiquo Dec 25, 2025
dc36543
correct result scope
Logiquo Dec 25, 2025
01199ba
fix litdata (a newer version of litdata is available: 0.2.59)
Logiquo Dec 25, 2025
95d8b47
Fix incorrect tmpdir cleanup
Logiquo Dec 25, 2025
8b33eed
Add TODO
Logiquo Dec 25, 2025
d6f0cf5
Clear unused code
Logiquo Dec 25, 2025
0439059
rename queue to progress to better reflect its usage.
Logiquo Dec 25, 2025
5ba4c1d
Fix tqdm for single worker
Logiquo Dec 25, 2025
acd6308
Fix busy waiting
Logiquo Dec 25, 2025
71e3f54
delete outdated test
Logiquo Dec 25, 2025
8e18067
Fix single thread context
Logiquo Dec 25, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 59 additions & 72 deletions examples/memtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
4. Training a StageNet model
"""

# %%
if __name__ == "__main__":
from pyhealth.datasets import (
MIMIC4Dataset,
Expand All @@ -20,7 +19,6 @@
from pyhealth.trainer import Trainer
import torch

# %% STEP 1: Load MIMIC-IV base dataset
base_dataset = MIMIC4Dataset(
ehr_root="/home/logic/physionet.org/files/mimiciv/3.1/",
ehr_tables=[
Expand All @@ -31,76 +29,65 @@
"labevents",
],
dev=False,
num_workers=8,
)

# %% # STEP 2: Apply StageNet mortality prediction task
sample_dataset = base_dataset.set_task(
with base_dataset.set_task(
MortalityPredictionStageNetMIMIC4(),
num_workers=4,
)

print(f"Total samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

# %% Inspect a sample
sample = next(iter(sample_dataset))
print("\nSample structure:")
print(f" Patient ID: {sample['patient_id']}")
print(f"ICD Codes: {sample['icd_codes']}")
print(f" Labs shape: {len(sample['labs'][0])} timesteps")
print(f" Mortality: {sample['mortality']}")

# %% STEP 3: Split dataset
train_dataset, val_dataset, test_dataset = split_by_patient(
sample_dataset, [0.8, 0.1, 0.1]
)

# Create dataloaders
train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False)
test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False)

# %% STEP 4: Initialize StageNet model
model = StageNet(
dataset=sample_dataset,
embedding_dim=128,
chunk_size=128,
levels=3,
dropout=0.3,
)

num_params = sum(p.numel() for p in model.parameters())
print(f"\nModel initialized with {num_params} parameters")

# %% STEP 5: Train the model
trainer = Trainer(
model=model,
device="cpu", # or "cuda"
metrics=["pr_auc", "roc_auc", "accuracy", "f1"],
)

trainer.train(
train_dataloader=train_loader,
val_dataloader=val_loader,
epochs=5,
monitor="roc_auc",
optimizer_params={"lr": 1e-5},
)

# %% STEP 6: Evaluate on test set
results = trainer.evaluate(test_loader)
print("\nTest Results:")
for metric, value in results.items():
print(f" {metric}: {value:.4f}")

# %% STEP 7: Inspect model predictions
sample_batch = next(iter(test_loader))
with torch.no_grad():
output = model(**sample_batch)

print("\nSample predictions:")
print(f" Predicted probabilities: {output['y_prob'][:5]}")
print(f" True labels: {output['y_true'][:5]}")

# %%
) as sample_dataset:
print(f"Total samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

sample = next(iter(sample_dataset))
print("\nSample structure:")
print(f" Patient ID: {sample['patient_id']}")
print(f"ICD Codes: {sample['icd_codes']}")
print(f" Labs shape: {len(sample['labs'][0])} timesteps")
print(f" Mortality: {sample['mortality']}")

train_dataset, val_dataset, test_dataset = split_by_patient(
sample_dataset, [0.8, 0.1, 0.1]
)

train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False)
test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False)

model = StageNet(
dataset=sample_dataset,
embedding_dim=128,
chunk_size=128,
levels=3,
dropout=0.3,
)

num_params = sum(p.numel() for p in model.parameters())
print(f"\nModel initialized with {num_params} parameters")

trainer = Trainer(
model=model,
device="cpu", # or "cuda"
metrics=["pr_auc", "roc_auc", "accuracy", "f1"],
)

trainer.train(
train_dataloader=train_loader,
val_dataloader=val_loader,
epochs=5,
monitor="roc_auc",
optimizer_params={"lr": 1e-5},
)

results = trainer.evaluate(test_loader)
print("\nTest Results:")
for metric, value in results.items():
print(f" {metric}: {value:.4f}")

sample_batch = next(iter(test_loader))
with torch.no_grad():
output = model(**sample_batch)

print("\nSample predictions:")
print(f" Predicted probabilities: {output['y_prob'][:5]}")
print(f" True labels: {output['y_true'][:5]}")
Loading