681 changes: 681 additions & 0 deletions examples/medlink_mimic3.ipynb
Collaborator

Some quick thoughts:

  • Can we move the medlink task into the pyhealth.tasks module too? It would also be really helpful to add detailed documentation around the query/document identifiers, and to tie them back to the original paper's task of mapping records to a known master patient record (see the sketch below).
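
A rough sketch of what that could look like, assuming the task-class style this PR already uses elsewhere (e.g., MortalityPredictionEICU2). Every class, attribute, and table name below is hypothetical, not taken from this PR:

from typing import Any, Dict, List

class MedLinkMIMIC3:
    """Record linkage: map an incoming record (the query) to a known master
    patient record (the document), following the original MedLink setup."""

    task_name: str = "MedLinkMIMIC3"

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        samples = []
        # Assumes the patient object exposes its visits as a mapping.
        for visit in patient.visits.values():
            samples.append({
                # query_id: identifier of the incoming admission record
                "query_id": f"{patient.patient_id}_{visit.visit_id}",
                # document_id: identifier of the master record to link to
                "document_id": str(patient.patient_id),
                "conditions": visit.get_code_list(table="DIAGNOSES_ICD"),
            })
        return samples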

Collaborator

It would also be nice to have it in docs/, as that would be genuinely useful for anyone working on record linkage problems.

Large diffs are not rendered by default.

25 changes: 14 additions & 11 deletions examples/test_eICU_addition.py
@@ -1,22 +1,25 @@
from pyhealth.datasets import eICUDataset
-from pyhealth.tasks import mortality_prediction_eicu_fn, mortality_prediction_eicu_fn2
+from pyhealth.tasks import MortalityPredictionEICU, MortalityPredictionEICU2

-base_dataset = eICUDataset(
-    root="/srv/local/data/physionet.org/files/eicu-crd/2.0",
-    tables=["diagnosis", "admissionDx", "treatment"],
-    dev=False,
-    refresh_cache=False,
-)
-sample_dataset = base_dataset.set_task(task_fn=mortality_prediction_eicu_fn2)
-sample_dataset.stat()
-print(sample_dataset.available_keys)
+if __name__ == "__main__":
+    base_dataset = eICUDataset(
+        root="/srv/local/data/physionet.org/files/eicu-crd/2.0",
+        tables=["diagnosis", "admissionDx", "treatment"],
+        dev=False,
+        refresh_cache=False,
+    )
+    task = MortalityPredictionEICU2()
+    sample_dataset = base_dataset.set_task(task=task)
+    sample_dataset.stat()
+    print(sample_dataset.available_keys)

# base_dataset = eICUDataset(
# root="/srv/local/data/physionet.org/files/eicu-crd/2.0",
# tables=["diagnosis", "admissionDx", "treatment"],
# dev=True,
# refresh_cache=False,
# )
-# sample_dataset = base_dataset.set_task(task_fn=mortality_prediction_eicu_fn2)
+# task = MortalityPredictionEICU2()
+# sample_dataset = base_dataset.set_task(task=task)
# sample_dataset.stat()
# print(sample_dataset.available_keys)
1 change: 1 addition & 0 deletions pyhealth/__init__.py
@@ -18,3 +18,4 @@
formatter = logging.Formatter("%(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)

3 changes: 2 additions & 1 deletion pyhealth/models/__init__.py
@@ -26,4 +26,5 @@
from .transformer import Transformer, TransformerLayer
from .transformers_model import TransformersModel
from .vae import VAE
-from .sdoh import SdohClassifier
+from .sdoh import SdohClassifier
+from .medlink import MedLink
118 changes: 116 additions & 2 deletions pyhealth/models/embedding.py
@@ -1,4 +1,7 @@
-from typing import Dict
+from __future__ import annotations
+
+from typing import Dict, Any, Optional, Union
+import os

import torch
import torch.nn as nn
@@ -18,6 +21,94 @@
)
from .base_model import BaseModel


def _iter_text_vectors(
    path: str,
    embedding_dim: int,
    wanted_tokens: set[str],
    encoding: str = "utf-8",
) -> Dict[str, torch.Tensor]:
    """Loads word vectors from a text file (e.g., GloVe) for a subset of tokens.

    Expected format: one token per line followed by embedding_dim floats.

    This function reads the file line-by-line and only retains vectors for
    tokens present in `wanted_tokens`.
    """

    if not os.path.exists(path):
        raise FileNotFoundError(f"pretrained embedding file not found: {path}")

    vectors: Dict[str, torch.Tensor] = {}
    with open(path, "r", encoding=encoding) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split()
            # token + embedding_dim values
            if len(parts) < embedding_dim + 1:
                continue
            token = parts[0]
            if token not in wanted_tokens:
                continue
            try:
                vec = torch.tensor(
                    [float(x) for x in parts[1 : embedding_dim + 1]],
                    dtype=torch.float,
                )
            except ValueError:
                continue
            vectors[token] = vec
    return vectors
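
# Illustrative usage (not part of this diff; the filename and tokens are made
# up for the example):
#
#     wanted = {"428.0", "401.9"}
#     vecs = _iter_text_vectors("glove.medical.128d.txt", 128, wanted)
#     # -> {token: FloatTensor of shape (128,)} for each token found in the file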


def init_embedding_with_pretrained(
    embedding: nn.Embedding,
    code_vocab: Dict[Any, int],
    pretrained_path: str,
    embedding_dim: int,
    pad_token: str = "<pad>",
    unk_token: str = "<unk>",
    normalize: bool = False,
    freeze: bool = False,
) -> int:
    """Initializes an nn.Embedding from a pretrained text-vector file.

    Tokens not found in the pretrained file are left as the module's existing
    random initialization.

    Returns:
        int: number of tokens successfully initialized from the file.
    """

    # Build wanted token set (stringified)
    vocab_tokens = {str(t) for t in code_vocab.keys()}
    vectors = _iter_text_vectors(pretrained_path, embedding_dim, vocab_tokens)

    loaded = 0
    with torch.no_grad():
        for tok, idx in code_vocab.items():
            tok_s = str(tok)
            if tok_s in vectors:
                vec = vectors[tok_s]
                if normalize:
                    vec = vec / (vec.norm(p=2) + 1e-12)
                embedding.weight[idx].copy_(vec)
                loaded += 1

        # Ensure pad row is zero (kept under no_grad so the in-place edit
        # does not error on a leaf tensor that requires grad)
        if pad_token in code_vocab:
            embedding.weight[code_vocab[pad_token]].zero_()
        # If embedding has a padding_idx, keep it consistent
        if embedding.padding_idx is not None:
            embedding.weight[embedding.padding_idx].zero_()

    if freeze:
        embedding.weight.requires_grad_(False)

    return loaded
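
# Illustrative usage (not part of this diff): initialize an embedding table
# from a pretrained text file, given a vocab mapping tokens to row indices.
#
#     emb = nn.Embedding(len(vocab), 128, padding_idx=0)
#     n_loaded = init_embedding_with_pretrained(
#         emb, vocab, "glove.medical.128d.txt", embedding_dim=128,
#         normalize=True,
#     )
#     # n_loaded counts the vocab tokens that received pretrained vectors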

class EmbeddingModel(BaseModel):
"""
EmbeddingModel is responsible for creating embedding layers for different types of input data.
@@ -46,7 +137,14 @@ class EmbeddingModel(BaseModel):
    - MultiHotProcessor: nn.Linear over multi-hot vector
    """

-    def __init__(self, dataset: SampleDataset, embedding_dim: int = 128):
+    def __init__(
+        self,
+        dataset: SampleDataset,
+        embedding_dim: int = 128,
+        pretrained_emb_path: Optional[Union[str, Dict[str, str]]] = None,
+        freeze_pretrained: bool = False,
+        normalize_pretrained: bool = False,
+    ):
        super().__init__(dataset)
        self.embedding_dim = embedding_dim
        self.embedding_layers = nn.ModuleDict()
@@ -81,6 +179,22 @@ def __init__(self, dataset: SampleDataset, embedding_dim: int = 128):
                    padding_idx=0,
                )

                # Optional pretrained initialization (e.g., GloVe).
                if pretrained_emb_path is not None:
                    if isinstance(pretrained_emb_path, str):
                        path = pretrained_emb_path
                    else:
                        path = pretrained_emb_path.get(field_name)
                    if path:
                        init_embedding_with_pretrained(
                            self.embedding_layers[field_name],
                            processor.code_vocab,
                            path,
                            embedding_dim=embedding_dim,
                            normalize=normalize_pretrained,
                            freeze=freeze_pretrained,
                        )

            # Numeric features (including deep nested floats) -> nn.Linear over last dim
            elif isinstance(
                processor,
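
A minimal usage sketch of the new EmbeddingModel options (the sample_dataset variable and the file path are illustrative; the three keyword arguments are the ones this diff adds):

# Assumes sample_dataset is a SampleDataset whose "conditions" field is
# handled by a code processor (i.e., it exposes a code_vocab).
model = EmbeddingModel(
    sample_dataset,
    embedding_dim=128,
    pretrained_emb_path={"conditions": "/data/glove.medical.128d.txt"},
    freeze_pretrained=True,      # freeze the whole table of any initialized field
    normalize_pretrained=True,   # L2-normalize each loaded vector
)

Passing a single string instead of a dict applies the same file to every code field; tokens missing from the file keep their random initialization.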