From 25eb5911ef7773fd9a86b95a9c9c084e9f903bd0 Mon Sep 17 00:00:00 2001 From: shreyaaishi Date: Sun, 19 Apr 2026 13:25:36 -0400 Subject: [PATCH 01/12] Create OMOP knowledge graph --- pyhealth/models/__init__.py | 1 + pyhealth/models/keep_embedding.py | 353 ++++++++++++++++++++++++++++++ 2 files changed, 354 insertions(+) create mode 100644 pyhealth/models/keep_embedding.py diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5233b1726..ed4fbc837 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -44,3 +44,4 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding +from .keep_embedding import KeepEmbedding, N2V \ No newline at end of file diff --git a/pyhealth/models/keep_embedding.py b/pyhealth/models/keep_embedding.py new file mode 100644 index 000000000..1d255228b --- /dev/null +++ b/pyhealth/models/keep_embedding.py @@ -0,0 +1,353 @@ +import os +import sys +import logging +from typing import List, Tuple, Dict, Set +from collections import defaultdict, Counter + +import networkx as nx +from node2vec import Node2Vec +import numpy as np +import pandas as pd +# from scipy import sparse + +from pyhealth.datasets import SampleDataset +from .base_model import BaseModel + +logger = logging.getLogger(__name__) + +class N2V(): + """ + """ + def __init__( + self, + path:str, + domain_type:str, + # depth_limit:int, + embedding_dim:int, + walk_length:int, + num_walks:int + ): + self.path = path + self.domain_type = domain_type + # self.depth_limit = depth_limit + self.embedding_dim = embedding_dim + self.walk_length = walk_length + self.num_walks = num_walks + # Map domain_type to OMOP domain names + self.domain_map = { + "condition": ["Condition"], + "drug": ["Drug"], + "procedure": ["Procedure"], + "all": ["Condition", "Drug", "Procedure"], + } + # If all then no filtering needed + + # Create graph from concept and their relationships data + + def create_graph(self) -> nx.DiGraph: + """ + Create a directed graph from OMOP concept relationships. + + Loads concepts and their relationships from CSV files, filters by domain_type, + and builds a NetworkX DiGraph where nodes are concept IDs and edges are + concept relationships (maps_to). + + Returns: + nx.DiGraph: Directed graph with concept_id as nodes and relationships as edges. + + Raises: + FileNotFoundError: If CSV files are not found. + ValueError: If no concepts found for specified domains. + """ + # Load concept table + concept_path = os.path.join(self.path, "2b_concept.csv") + + # Load concept relationships table + concept_relationship_path = os.path.join(self.path, "2b_concept_relationship.csv") + + # if not os.path.exists(concept_path): + # raise FileNotFoundError(f"2b_concept.csv not found at {concept_path}") + # if not os.path.exists(concept_relationship_path): + # raise FileNotFoundError(f"2b_concept_relationship.csv not found at {concept_relationship_path}") + + # Read CSV files + print(f"Loading concepts from {concept_path}") + concept_df = pd.read_csv(concept_path, dtype=str) + + print(f"Loading concept relationships from {concept_relationship_path}") + concept_rel_df = pd.read_csv(concept_relationship_path, dtype=str) + + target_domains = self.domain_map[self.domain_type] + + # Filter concepts by target domains + # concept_df = concept_df[concept_df["domain_id"].isin(target_domains)].copy() + + # if len(concept_df) == 0: + # raise ValueError(f"No concepts found for domains: {target_domains}") + + # print(f"Filtered to {len(concept_df)} concepts in domains: {target_domains}") + + # # Create set of filtered concept IDs for quick lookup + filtered_concept_ids = set(concept_df["concept_id"].values) + + # # Filter to relationships where both concepts are in our domain set + concept_rel_df = concept_rel_df[ + (concept_rel_df["concept_id_1"].isin(filtered_concept_ids)) & + (concept_rel_df["concept_id_2"].isin(filtered_concept_ids)) + ].copy() + + print(f"Found {len(concept_rel_df)} relationships between filtered concepts") + + # Create directed graph + graph = nx.DiGraph() + + # Add all filtered concepts as nodes + for concept_id, row in concept_df.iterrows(): + graph.add_node( + row["concept_id"], + name=row["concept_name"], + domain=row["domain_id"] + ) + + print(f"Added {len(graph.nodes())} nodes to graph") + + # Add edges from concept relationships + # Typically "maps_to" relationship indicates concept_id_1 maps to concept_id_2 + for _, row in concept_rel_df.iterrows(): + concept_1 = row["concept_id_1"] + concept_2 = row["concept_id_2"] + rel_type = row.get("relationship_id", "maps_to") + + # Add directed edge from concept_1 to concept_2 + graph.add_edge(concept_1, concept_2, relationship=rel_type) + + print(f"Added {len(graph.edges())} edges to graph") + + return graph + + def generate_embeddings(self): + """ + Generate node embeddings using Node2Vec algorithm. + + Creates a graph from OMOP concepts and applies Node2Vec to generate + embeddings for each concept based on its network structure. + + Returns: + gensim.models.Word2Vec: Trained Node2Vec model for concept embeddings. + """ + # Create graph from concepts and relationships + logger.info("Creating concept graph") + graph = self.create_graph() + + logger.info(f"Graph created with {len(graph.nodes())} nodes and {len(graph.edges())} edges") + + if len(graph.nodes()) == 0: + raise ValueError("Graph is empty, cannot generate embeddings") + + # Initialize and fit Node2Vec + logger.info( + f"Initializing Node2Vec with embedding_dim={self.embedding_dim}, " + f"walk_length={self.walk_length}, num_walks={self.num_walks}" + ) + node2vec = Node2Vec( + graph, + dimensions=self.embedding_dim, + walk_length=self.walk_length, + num_walks=self.num_walks, + workers=4 + ) + + # Train the model + logger.info("Training Node2Vec model") + self.model = node2vec.fit(window=10, min_count=1, epochs=1) + + logger.info("Node2Vec training completed") + logger.info(f"Model vocabulary size: {len(self.model.wv)}") + + return self.model + + + +class KeepEmbedding(BaseModel): + """Knowledge-Enhanced Patient Embedding model using OMOP data and node2vec.""" + + def __init__(self, dataset: SampleDataset): + """ + Initialize KeepEmbedding model. + + Args: + dataset: An OMOPDataset instance containing patient clinical data. + """ + super().__init__(dataset=dataset) + + # def build_cooccurrence_matrix( + # self, + # graph: nx.DiGraph, + # domain_type: str = "condition", + # min_occurrences: int = 2, + # ) -> Tuple[sparse.csr_matrix, List[str]]: + # """ + # Build co-occurrence matrix from patient histories using dense roll-up. + + # Iterates through all patients in the dataset, collects concept codes from + # their complete medical history, applies dense roll-up to ancestor concepts, + # and builds a sparse co-occurrence matrix. + + # Args: + # graph (nx.DiGraph): NetworkX graph from Node2Vec.create_graph() + # domain_type (str): Concept domain to include: + # - "condition": condition_occurrence events + # - "drug": drug_exposure events + # - "procedure": procedure_occurrence events + # - "all": All three event types. Default is "condition". + # min_occurrences (int): Minimum number of times a concept must appear + # in a patient's history to be retained (per paper requirement). + # Default is 2. + + # Returns: + # Tuple[sparse.csr_matrix, List[str]]: + # - X: Sparse CSR matrix where X[i,j] = co-occurrence frequency + # between concept_i and concept_j across all patients + # - concept_ids: List of concept IDs corresponding to matrix rows/columns + + # Raises: + # ValueError: If domain_type is invalid or dataset is empty. + # """ + # logger.info(f"Building co-occurrence matrix for domain_type={domain_type}") + + # # Map domain_type to event types and fields + # domain_map = { + # "condition": [("condition_occurrence", "condition_concept_id")], + # "drug": [("drug_exposure", "drug_concept_id")], + # "procedure": [("procedure_occurrence", "procedure_concept_id")], + # "all": [ + # ("condition_occurrence", "condition_concept_id"), + # ("drug_exposure", "drug_concept_id"), + # ("procedure_occurrence", "procedure_concept_id"), + # ], + # } + + # if domain_type not in domain_map: + # raise ValueError( + # f"domain_type must be one of {list(domain_map.keys())}, got {domain_type}" + # ) + + # event_types = domain_map[domain_type] + + # # Extract concept IDs from graph nodes + # concept_ids = sorted(list(graph.nodes())) + # concept_id_to_idx = {cid: idx for idx, cid in enumerate(concept_ids)} + + # logger.info(f"Graph has {len(concept_ids)} concepts") + + # if len(concept_ids) == 0: + # raise ValueError("Graph is empty, cannot build co-occurrence matrix") + + # if self.dataset is None or len(self.dataset.unique_patient_ids) == 0: + # raise ValueError("Dataset is empty, cannot build co-occurrence matrix") + + # # Initialize co-occurrence counter + # cooc_counts = defaultdict(int) + + # # Iterate through all patients + # patient_ids = self.dataset.unique_patient_ids + # logger.info(f"Processing {len(patient_ids)} patients") + + # for patient_id in patient_ids: + # try: + # patient = self.dataset.get_patient(patient_id) + # except Exception as e: + # logger.warning(f"Failed to load patient {patient_id}: {e}") + # continue + + # # Collect all codes from patient's complete history + # all_codes = [] + # for event_type, field in event_types: + # try: + # events = patient.get_events(event_type=event_type) + # codes = [] + # for event in events: + # code = str(getattr(event, field, "")) + # if code and code != "nan": + # codes.append(code) + # all_codes.extend(codes) + # except Exception as e: + # logger.debug(f"Could not get {event_type} for patient {patient_id}: {e}") + # continue + + # if len(all_codes) == 0: + # continue + + # # Count occurrences of each code + # code_counts = Counter(all_codes) + + # # Filter codes with min_occurrences + # retained_codes = [ + # code for code, count in code_counts.items() + # if count >= min_occurrences + # ] + + # if len(retained_codes) == 0: + # continue + + # # Apply dense roll-up: map each code to ALL ancestors in graph + # rolled_codes = set() + # for code in retained_codes: + # rolled_codes.add(code) # Include self + # if code in graph.nodes(): + # # Find all ancestors + # try: + # ancestors = nx.ancestors(graph, code) + # rolled_codes.update(ancestors) + # except Exception as e: + # logger.debug(f"Could not find ancestors for {code}: {e}") + + # # Build co-occurrence pairs (only for codes in graph) + # rolled_codes_in_graph = [c for c in rolled_codes if c in graph.nodes()] + + # if len(rolled_codes_in_graph) > 1: + # # Create all pairs + # for i, code_i in enumerate(rolled_codes_in_graph): + # for code_j in rolled_codes_in_graph[i + 1 :]: + # idx_i = concept_id_to_idx[code_i] + # idx_j = concept_id_to_idx[code_j] + + # # Store symmetric pairs + # if idx_i <= idx_j: + # cooc_counts[(idx_i, idx_j)] += 1 + # else: + # cooc_counts[(idx_j, idx_i)] += 1 + + # logger.info(f"Generated {len(cooc_counts)} unique co-occurrence pairs") + + # # Build sparse matrix + # if len(cooc_counts) == 0: + # logger.warning("No co-occurrences found, returning empty sparse matrix") + # X = sparse.csr_matrix((len(concept_ids), len(concept_ids)), dtype=np.float32) + # return X, concept_ids + + # # Extract rows, columns, and data + # rows, cols, data = [], [], [] + # for (i, j), count in cooc_counts.items(): + # rows.append(i) + # cols.append(j) + # data.append(count) + # # Add symmetric entry + # rows.append(j) + # cols.append(i) + # data.append(count) + + # # Create COO matrix and convert to CSR + # X = sparse.coo_matrix( + # (data, (rows, cols)), + # shape=(len(concept_ids), len(concept_ids)), + # dtype=np.float32, + # ) + # X = X.tocsr() + + # logger.info( + # f"Built sparse co-occurrence matrix: shape={X.shape}, nnz={X.nnz}, " + # f"sparsity={1 - X.nnz / (X.shape[0] * X.shape[1]):.4f}" + # ) + + # return X, concept_ids \ No newline at end of file From 20b528d9ba9c819d016edb68e63e0620c169dd81 Mon Sep 17 00:00:00 2001 From: shreyaaishi Date: Sun, 19 Apr 2026 15:42:52 -0400 Subject: [PATCH 02/12] Finished Node2Vec stage of KEEP framework --- pyhealth/models/keep_embedding.py | 177 ++++++++++++++++++------------ 1 file changed, 106 insertions(+), 71 deletions(-) diff --git a/pyhealth/models/keep_embedding.py b/pyhealth/models/keep_embedding.py index 1d255228b..f3350ad28 100644 --- a/pyhealth/models/keep_embedding.py +++ b/pyhealth/models/keep_embedding.py @@ -21,30 +21,19 @@ class N2V(): def __init__( self, path:str, - domain_type:str, - # depth_limit:int, + domain_type:list[str], embedding_dim:int, walk_length:int, num_walks:int ): self.path = path self.domain_type = domain_type - # self.depth_limit = depth_limit self.embedding_dim = embedding_dim self.walk_length = walk_length self.num_walks = num_walks - # Map domain_type to OMOP domain names - self.domain_map = { - "condition": ["Condition"], - "drug": ["Drug"], - "procedure": ["Procedure"], - "all": ["Condition", "Drug", "Procedure"], - } - # If all then no filtering needed - # Create graph from concept and their relationships data - - def create_graph(self) -> nx.DiGraph: + # Create graph from concept and their relationships data + def _create_graph(self) -> nx.DiGraph: """ Create a directed graph from OMOP concept relationships. @@ -60,71 +49,94 @@ def create_graph(self) -> nx.DiGraph: ValueError: If no concepts found for specified domains. """ # Load concept table - concept_path = os.path.join(self.path, "2b_concept.csv") - - # Load concept relationships table - concept_relationship_path = os.path.join(self.path, "2b_concept_relationship.csv") - - # if not os.path.exists(concept_path): - # raise FileNotFoundError(f"2b_concept.csv not found at {concept_path}") - # if not os.path.exists(concept_relationship_path): - # raise FileNotFoundError(f"2b_concept_relationship.csv not found at {concept_relationship_path}") - - # Read CSV files + concept_path = os.path.join(self.path, "2b_concept.csv") print(f"Loading concepts from {concept_path}") concept_df = pd.read_csv(concept_path, dtype=str) - + + # Load concept relationships table + concept_relationship_path = os.path.join(self.path, "2b_concept_relationship.csv") print(f"Loading concept relationships from {concept_relationship_path}") concept_rel_df = pd.read_csv(concept_relationship_path, dtype=str) - target_domains = self.domain_map[self.domain_type] + print(f"Loaded {len(concept_df)} concepts and {len(concept_rel_df)} relationships") - # Filter concepts by target domains - # concept_df = concept_df[concept_df["domain_id"].isin(target_domains)].copy() + if self.domain_type != ["all"]: + # Filter concepts by target domain + concept_df = concept_df[concept_df["domain_id"].isin(self.domain_type)].copy() - # if len(concept_df) == 0: - # raise ValueError(f"No concepts found for domains: {target_domains}") + print(f"Filtered to {len(concept_df)} concepts in domains: {self.domain_type}") - # print(f"Filtered to {len(concept_df)} concepts in domains: {target_domains}") - - # # Create set of filtered concept IDs for quick lookup + # Create set of filtered concept IDs for quick lookup filtered_concept_ids = set(concept_df["concept_id"].values) + print(f"Created set of {len(filtered_concept_ids)} concept IDs") - # # Filter to relationships where both concepts are in our domain set + # Filter to relationships where both concepts are in our domain set concept_rel_df = concept_rel_df[ (concept_rel_df["concept_id_1"].isin(filtered_concept_ids)) & (concept_rel_df["concept_id_2"].isin(filtered_concept_ids)) ].copy() - print(f"Found {len(concept_rel_df)} relationships between filtered concepts") + print(f"Found {len(concept_rel_df)} relationships between concepts") # Create directed graph graph = nx.DiGraph() # Add all filtered concepts as nodes - for concept_id, row in concept_df.iterrows(): + for _, row in concept_df.iterrows(): graph.add_node( row["concept_id"], name=row["concept_name"], domain=row["domain_id"] ) - print(f"Added {len(graph.nodes())} nodes to graph") - # Add edges from concept relationships - # Typically "maps_to" relationship indicates concept_id_1 maps to concept_id_2 for _, row in concept_rel_df.iterrows(): concept_1 = row["concept_id_1"] concept_2 = row["concept_id_2"] - rel_type = row.get("relationship_id", "maps_to") + rel_type = row.get("relationship_id") # Add directed edge from concept_1 to concept_2 - graph.add_edge(concept_1, concept_2, relationship=rel_type) - - print(f"Added {len(graph.edges())} edges to graph") + if graph.has_edge(concept_1, concept_2): + # Append to existing relationships list + graph[concept_1][concept_2]["relationships"].append(rel_type) + else: + # Create new edge with relationships list + graph.add_edge(concept_1, concept_2, relationships=[rel_type]) return graph + def _build_index_mapping(self, node_embeddings): + """ + Build a dictionary to map concept code to the index in node_embeddings. + + Args: + node_embeddings: Gensim Word2Vec model word vectors + + Returns: + dict: Mapping from concept_id (int) to index in embeddings + """ + return {int(key): i for i, key in enumerate(node_embeddings.index_to_key)} + + def _get_vector_iso(self, code, node_embeddings, index_mapping, mean_vector): + """ + Return concept embedding for the given code or mean vector if not found. + + Args: + code: Concept ID + node_embeddings: Gensim Word2Vec model word vectors + index_mapping: Dictionary mapping concept_id to index + mean_vector: Mean vector to use as fallback + + Returns: + np.ndarray: Embedding vector for the concept + """ + index = index_mapping.get(int(code)) + if index is not None: + return node_embeddings.get_vector(index) + else: + print(f"Code {code} not found, returning mean vector.") + return mean_vector + def generate_embeddings(self): """ Generate node embeddings using Node2Vec algorithm. @@ -136,50 +148,49 @@ def generate_embeddings(self): gensim.models.Word2Vec: Trained Node2Vec model for concept embeddings. """ # Create graph from concepts and relationships - logger.info("Creating concept graph") - graph = self.create_graph() + print("Creating OMOP knowledge graph") + graph = self._create_graph() - logger.info(f"Graph created with {len(graph.nodes())} nodes and {len(graph.edges())} edges") + print(f"Graph created with {len(graph.nodes())} nodes and {len(graph.edges())} edges") if len(graph.nodes()) == 0: raise ValueError("Graph is empty, cannot generate embeddings") # Initialize and fit Node2Vec - logger.info( - f"Initializing Node2Vec with embedding_dim={self.embedding_dim}, " - f"walk_length={self.walk_length}, num_walks={self.num_walks}" - ) + print(f"Initializing Node2Vec with embedding_dim={self.embedding_dim} walk_length={self.walk_length}, num_walks={self.num_walks}") + node2vec = Node2Vec( graph, dimensions=self.embedding_dim, walk_length=self.walk_length, num_walks=self.num_walks, - workers=4 + p=1, q=1, workers=4 ) # Train the model - logger.info("Training Node2Vec model") self.model = node2vec.fit(window=10, min_count=1, epochs=1) - logger.info("Node2Vec training completed") - logger.info(f"Model vocabulary size: {len(self.model.wv)}") + # Extract embeddings from trained model + keys = list(graph.nodes()) + node_embeddings = self.model.wv - return self.model - - - -class KeepEmbedding(BaseModel): - """Knowledge-Enhanced Patient Embedding model using OMOP data and node2vec.""" - - def __init__(self, dataset: SampleDataset): - """ - Initialize KeepEmbedding model. + # Build index mapping for efficient lookup + index_mapping = self._build_index_mapping(node_embeddings) + mean_vector = np.mean(node_embeddings.vectors, axis=0) - Args: - dataset: An OMOPDataset instance containing patient clinical data. - """ - super().__init__(dataset=dataset) - + # Create embedding vectors for all concepts + print(f"Creating embedding vectors for {len(keys)} concepts...") + vectors = [self._get_vector_iso(key, node_embeddings, index_mapping, mean_vector) for key in keys] + + # Stack into matrix + embedding_matrix = np.vstack(vectors) + print(f"Embedding matrix shape: {embedding_matrix.shape}") + + return embedding_matrix + +class GloVe(): + def __init__(self): + pass # def build_cooccurrence_matrix( # self, # graph: nx.DiGraph, @@ -350,4 +361,28 @@ def __init__(self, dataset: SampleDataset): # f"sparsity={1 - X.nnz / (X.shape[0] * X.shape[1]):.4f}" # ) - # return X, concept_ids \ No newline at end of file + # return X, concept_ids + +class KeepEmbedding(BaseModel): + def __init__(self, + dataset: SampleDataset, + path:str, + domain_type:list[str], + embedding_dim:int, + walk_length:int, + num_walks:int + ): + """ + """ + super().__init__(dataset=dataset) + self.n2v = N2V( + path=path, + domain_type=domain_type, + embedding_dim=embedding_dim, + walk_length=walk_length, + num_walks=num_walks + ) + + def test(self): + embedding_matrix = self.n2v.generate_embeddings() + print(f"Created embedding matrix with shape: {embedding_matrix.shape}") \ No newline at end of file From 64e7845bdd8f4a598db9cef393864566f78f7e85 Mon Sep 17 00:00:00 2001 From: shreyaaishi Date: Sun, 19 Apr 2026 18:16:23 -0400 Subject: [PATCH 03/12] Add GloVe stage of KEEP framework --- pyhealth/models/__init__.py | 2 +- pyhealth/models/keep_embedding.py | 374 +++++++++++++++--------------- 2 files changed, 188 insertions(+), 188 deletions(-) diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index ed4fbc837..f285a260c 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -44,4 +44,4 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding -from .keep_embedding import KeepEmbedding, N2V \ No newline at end of file +from .keep_embedding import KeepEmbedding \ No newline at end of file diff --git a/pyhealth/models/keep_embedding.py b/pyhealth/models/keep_embedding.py index f3350ad28..ac035437b 100644 --- a/pyhealth/models/keep_embedding.py +++ b/pyhealth/models/keep_embedding.py @@ -1,20 +1,18 @@ import os -import sys -import logging -from typing import List, Tuple, Dict, Set from collections import defaultdict, Counter import networkx as nx from node2vec import Node2Vec import numpy as np import pandas as pd +import torch +import torch.nn.functional as F # from scipy import sparse +from torch import nn from pyhealth.datasets import SampleDataset from .base_model import BaseModel -logger = logging.getLogger(__name__) - class N2V(): """ """ @@ -189,192 +187,61 @@ def generate_embeddings(self): return embedding_matrix class GloVe(): - def __init__(self): + def __init__(self, + dataset: SampleDataset, + ): + self.dataset = dataset + + def build_cooccurrence_matrix(self): pass - # def build_cooccurrence_matrix( - # self, - # graph: nx.DiGraph, - # domain_type: str = "condition", - # min_occurrences: int = 2, - # ) -> Tuple[sparse.csr_matrix, List[str]]: - # """ - # Build co-occurrence matrix from patient histories using dense roll-up. - - # Iterates through all patients in the dataset, collects concept codes from - # their complete medical history, applies dense roll-up to ancestor concepts, - # and builds a sparse co-occurrence matrix. - - # Args: - # graph (nx.DiGraph): NetworkX graph from Node2Vec.create_graph() - # domain_type (str): Concept domain to include: - # - "condition": condition_occurrence events - # - "drug": drug_exposure events - # - "procedure": procedure_occurrence events - # - "all": All three event types. Default is "condition". - # min_occurrences (int): Minimum number of times a concept must appear - # in a patient's history to be retained (per paper requirement). - # Default is 2. - - # Returns: - # Tuple[sparse.csr_matrix, List[str]]: - # - X: Sparse CSR matrix where X[i,j] = co-occurrence frequency - # between concept_i and concept_j across all patients - # - concept_ids: List of concept IDs corresponding to matrix rows/columns - - # Raises: - # ValueError: If domain_type is invalid or dataset is empty. - # """ - # logger.info(f"Building co-occurrence matrix for domain_type={domain_type}") - - # # Map domain_type to event types and fields - # domain_map = { - # "condition": [("condition_occurrence", "condition_concept_id")], - # "drug": [("drug_exposure", "drug_concept_id")], - # "procedure": [("procedure_occurrence", "procedure_concept_id")], - # "all": [ - # ("condition_occurrence", "condition_concept_id"), - # ("drug_exposure", "drug_concept_id"), - # ("procedure_occurrence", "procedure_concept_id"), - # ], - # } - - # if domain_type not in domain_map: - # raise ValueError( - # f"domain_type must be one of {list(domain_map.keys())}, got {domain_type}" - # ) - - # event_types = domain_map[domain_type] - - # # Extract concept IDs from graph nodes - # concept_ids = sorted(list(graph.nodes())) - # concept_id_to_idx = {cid: idx for idx, cid in enumerate(concept_ids)} - - # logger.info(f"Graph has {len(concept_ids)} concepts") - - # if len(concept_ids) == 0: - # raise ValueError("Graph is empty, cannot build co-occurrence matrix") - - # if self.dataset is None or len(self.dataset.unique_patient_ids) == 0: - # raise ValueError("Dataset is empty, cannot build co-occurrence matrix") - - # # Initialize co-occurrence counter - # cooc_counts = defaultdict(int) - - # # Iterate through all patients - # patient_ids = self.dataset.unique_patient_ids - # logger.info(f"Processing {len(patient_ids)} patients") - - # for patient_id in patient_ids: - # try: - # patient = self.dataset.get_patient(patient_id) - # except Exception as e: - # logger.warning(f"Failed to load patient {patient_id}: {e}") - # continue - - # # Collect all codes from patient's complete history - # all_codes = [] - # for event_type, field in event_types: - # try: - # events = patient.get_events(event_type=event_type) - # codes = [] - # for event in events: - # code = str(getattr(event, field, "")) - # if code and code != "nan": - # codes.append(code) - # all_codes.extend(codes) - # except Exception as e: - # logger.debug(f"Could not get {event_type} for patient {patient_id}: {e}") - # continue - - # if len(all_codes) == 0: - # continue - - # # Count occurrences of each code - # code_counts = Counter(all_codes) - - # # Filter codes with min_occurrences - # retained_codes = [ - # code for code, count in code_counts.items() - # if count >= min_occurrences - # ] - - # if len(retained_codes) == 0: - # continue - - # # Apply dense roll-up: map each code to ALL ancestors in graph - # rolled_codes = set() - # for code in retained_codes: - # rolled_codes.add(code) # Include self - # if code in graph.nodes(): - # # Find all ancestors - # try: - # ancestors = nx.ancestors(graph, code) - # rolled_codes.update(ancestors) - # except Exception as e: - # logger.debug(f"Could not find ancestors for {code}: {e}") - - # # Build co-occurrence pairs (only for codes in graph) - # rolled_codes_in_graph = [c for c in rolled_codes if c in graph.nodes()] - - # if len(rolled_codes_in_graph) > 1: - # # Create all pairs - # for i, code_i in enumerate(rolled_codes_in_graph): - # for code_j in rolled_codes_in_graph[i + 1 :]: - # idx_i = concept_id_to_idx[code_i] - # idx_j = concept_id_to_idx[code_j] - - # # Store symmetric pairs - # if idx_i <= idx_j: - # cooc_counts[(idx_i, idx_j)] += 1 - # else: - # cooc_counts[(idx_j, idx_i)] += 1 - - # logger.info(f"Generated {len(cooc_counts)} unique co-occurrence pairs") - - # # Build sparse matrix - # if len(cooc_counts) == 0: - # logger.warning("No co-occurrences found, returning empty sparse matrix") - # X = sparse.csr_matrix((len(concept_ids), len(concept_ids)), dtype=np.float32) - # return X, concept_ids - - # # Extract rows, columns, and data - # rows, cols, data = [], [], [] - # for (i, j), count in cooc_counts.items(): - # rows.append(i) - # cols.append(j) - # data.append(count) - # # Add symmetric entry - # rows.append(j) - # cols.append(i) - # data.append(count) - - # # Create COO matrix and convert to CSR - # X = sparse.coo_matrix( - # (data, (rows, cols)), - # shape=(len(concept_ids), len(concept_ids)), - # dtype=np.float32, - # ) - # X = X.tocsr() - - # logger.info( - # f"Built sparse co-occurrence matrix: shape={X.shape}, nnz={X.nnz}, " - # f"sparsity={1 - X.nnz / (X.shape[0] * X.shape[1]):.4f}" - # ) - - # return X, concept_ids class KeepEmbedding(BaseModel): + """KEEP Embedding: Fine-tune Node2Vec embeddings using GloVe while penalizing + deviation from original embeddings. + + Balances: + - Co-occurrence structure (GloVe objective) + - Graph structure prior (Node2Vec via regularization) + + Args: + dataset (SampleDataset): The dataset to train the model. + path (str): Path to OMOP data files for graph construction. + domain_type (list[str]): Domain types to include in graph. + embedding_dim (int): Dimension of embeddings. + walk_length (int): Length of random walks for Node2Vec. + num_walks (int): Number of random walks per node for Node2Vec. + lambda_reg (float): Regularization strength for Node2Vec prior. Default: 1.0. + reg_norm (str or float): Norm type for regularization ('cosine' or numeric p-norm). + Default: None (cosine similarity). + log_scale (bool): Whether to apply log scaling to regularization distance. + Default: False. + device (str): Device to use ('cuda' or 'cpu'). Default: 'cpu'. + """ + def __init__(self, dataset: SampleDataset, - path:str, - domain_type:list[str], - embedding_dim:int, - walk_length:int, - num_walks:int + path: str, + domain_type: list[str], + embedding_dim: int, + walk_length: int, + num_walks: int, + lambda_reg: float = 1.0, + reg_norm: str | float = None, + log_scale: bool = False, + device: str = "cpu" ): - """ - """ + """Initialize KEEP Embedding model.""" super().__init__(dataset=dataset) + + self.embedding_dim = embedding_dim + self.lambda_reg = lambda_reg + self.reg_norm = reg_norm + self.log_scale = log_scale + self.device = device + self.mode = "regression" # Set mode for compatibility with BaseModel + + # Generate Node2Vec embeddings + print(f"Initializing Node2Vec with embedding_dim={embedding_dim}...") self.n2v = N2V( path=path, domain_type=domain_type, @@ -382,7 +249,140 @@ def __init__(self, walk_length=walk_length, num_walks=num_walks ) - - def test(self): + embedding_matrix = self.n2v.generate_embeddings() - print(f"Created embedding matrix with shape: {embedding_matrix.shape}") \ No newline at end of file + print(f"Created embedding matrix with shape: {embedding_matrix.shape}") + + num_words = embedding_matrix.shape[0] + + # Create learnable embedding and bias parameters + self.embeddings_v = nn.Embedding(num_words, embedding_dim) + self.embeddings_u = nn.Embedding(num_words, embedding_dim) + self.biases_v = nn.Embedding(num_words, 1) + self.biases_u = nn.Embedding(num_words, 1) + + # Initialize with Node2Vec embeddings + embedding_tensor = torch.from_numpy(embedding_matrix).float() + self.embeddings_v.weight.data.copy_(embedding_tensor) + self.embeddings_u.weight.data.copy_(embedding_tensor) + + # Store initial embeddings for regularization + self.register_buffer( + "initial_embeddings", + embedding_tensor.clone().to(device) + ) + + # Initialize biases to zero + self.biases_v.weight.data.fill_(0) + self.biases_u.weight.data.fill_(0) + + print(f"Initialized KEEP Embedding with {num_words} tokens") + print(f"Embedding dimension: {embedding_dim}") + print(f"Regularization lambda: {lambda_reg}") + print(f"Regularization norm: {reg_norm}") + print(f"Log scaling: {log_scale}") + + def forward(self, + i_indices: torch.Tensor = None, + j_indices: torch.Tensor = None, + counts: torch.Tensor = None, + weights: torch.Tensor = None, + **kwargs) -> dict[str, torch.Tensor]: + """Forward pass for KEEP Embedding. + + Computes GloVe loss with optional Node2Vec regularization. For compatibility + with BaseModel.forward(), returns a dictionary with keys: loss, y_prob, + y_true, logit. + + For training GloVe objective, pass: + - i_indices: Token indices (batch_size,) + - j_indices: Context token indices (batch_size,) + - counts: Co-occurrence counts (batch_size,) + - weights: Weights for each co-occurrence pair (batch_size,) + + Args: + i_indices (torch.Tensor, optional): Token indices. + j_indices (torch.Tensor, optional): Context token indices. + counts (torch.Tensor, optional): Co-occurrence counts. + weights (torch.Tensor, optional): Weights for loss terms. + **kwargs: Additional arguments for compatibility. + + Returns: + dict: Dictionary with keys: + - loss: Total loss (GloVe + regularization if applicable) + - logit: Placeholder tensor (for BaseModel compatibility) + - y_prob: Placeholder tensor (for BaseModel compatibility) + - y_true: Placeholder tensor (for BaseModel compatibility) + - reg_loss: Regularization loss component (if applicable) + """ + + # If no GloVe inputs provided, return dummy output + if i_indices is None or j_indices is None: + dummy_loss = torch.tensor(0.0, device=self.device, requires_grad=True) + return { + "loss": dummy_loss, + "logit": dummy_loss, + "y_prob": dummy_loss, + "y_true": dummy_loss, + } + + # Move inputs to correct device + i_indices = i_indices.to(self.device) + j_indices = j_indices.to(self.device) + counts = counts.to(self.device) + weights = weights.to(self.device) + + # Get embeddings and biases + embedding_i = self.embeddings_v(i_indices) # (batch_size, embedding_dim) + embedding_j = self.embeddings_u(j_indices) # (batch_size, embedding_dim) + bias_i = self.biases_v(i_indices).squeeze(-1) # (batch_size,) + bias_j = self.biases_u(j_indices).squeeze(-1) # (batch_size,) + + # Compute GloVe loss: weighted squared difference + # GloVe objective: w(i,j) * (u_i · v_j + b_i + b_j - log(X_ij))^2 + dot_product = torch.sum(embedding_i * embedding_j, dim=1) # (batch_size,) + glove_target = torch.log(counts + 1e-8) # Avoid log(0) + + squared_diff = (dot_product + bias_i + bias_j - glove_target) ** 2 + glove_loss = torch.sum(weights * squared_diff) + + total_loss = glove_loss + reg_loss = torch.tensor(0.0, device=self.device) + + # Add Node2Vec regularization if lambda > 0 + if self.lambda_reg > 0: + # Average embeddings: (u_i + v_i) / 2 and (u_j + v_j) / 2 + u_plus_v_i = (embedding_i + self.embeddings_u(i_indices)) / 2 + u_plus_v_j = (embedding_j + self.embeddings_v(j_indices)) / 2 + + # Get initial embeddings + initial_i = self.initial_embeddings[i_indices] + initial_j = self.initial_embeddings[j_indices] + + # Compute regularization distance based on norm type + if self.reg_norm is None or self.reg_norm == "cosine": + # Cosine distance: 1 - cosine_similarity + reg_dist_i = 1 - F.cosine_similarity(u_plus_v_i, initial_i, dim=1) + reg_dist_j = 1 - F.cosine_similarity(u_plus_v_j, initial_j, dim=1) + else: + # Lp norm distance + p_norm = float(self.reg_norm) + reg_dist_i = torch.norm(u_plus_v_i - initial_i, p=p_norm, dim=1) + reg_dist_j = torch.norm(u_plus_v_j - initial_j, p=p_norm, dim=1) + + # Apply log scaling if enabled + if self.log_scale: + reg_dist_i = torch.log(reg_dist_i + 1e-8) + reg_dist_j = torch.log(reg_dist_j + 1e-8) + + # Compute regularization loss + reg_loss = self.lambda_reg * (torch.sum(reg_dist_i) + torch.sum(reg_dist_j)) + total_loss = glove_loss + reg_loss + + return { + "loss": total_loss, + "logit": glove_loss.detach(), # Return GloVe component as logit for reference + "y_prob": torch.zeros(i_indices.shape[0], device=self.device), # Placeholder + "y_true": torch.zeros(i_indices.shape[0], device=self.device), # Placeholder + "reg_loss": reg_loss.detach(), # Return regularization loss for monitoring + } From 60db35a50bb4b5d33489bd47ec477f2e6b95a8b5 Mon Sep 17 00:00:00 2001 From: shreyaaishi Date: Sun, 19 Apr 2026 19:43:24 -0400 Subject: [PATCH 04/12] Add GloVe stage of KEEP framework --- examples/keep.ipynb | 92 +++++++++++++++++++++++++++++++ pyhealth/models/keep_embedding.py | 9 --- 2 files changed, 92 insertions(+), 9 deletions(-) create mode 100644 examples/keep.ipynb diff --git a/examples/keep.ipynb b/examples/keep.ipynb new file mode 100644 index 000000000..981d89095 --- /dev/null +++ b/examples/keep.ipynb @@ -0,0 +1,92 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "840a3b7f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\michi\\Workspace\\PyHealth\\.venv\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on device: cpu\n", + "Creating OMOP knowledge graph\n", + "Loading concepts from C:\\Users\\michi\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv\\2b_concept.csv\n", + "Loading concept relationships from C:\\Users\\michi\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv\\2b_concept_relationship.csv\n", + "Loaded 3885 concepts and 7716 relationships\n", + "Created set of 3885 concept IDs\n", + "Found 50 relationships between concepts\n", + "Graph created with 3885 nodes and 26 edges\n", + "Initializing Node2Vec with embedding_dim=100 walk_length=30, num_walks=750\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Computing transition probabilities: 100%|██████████| 3885/3885 [00:00<00:00, 243672.55it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating embedding vectors for 3885 concepts...\n", + "Embedding matrix shape: (3885, 100)\n", + "Generated embedding matrix with shape: (3885, 100)\n" + ] + } + ], + "source": [ + "import torch\n", + "import warnings\n", + "\n", + "from pyhealth.models import KeepEmbedding\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Running on device: {device}\")\n", + "\n", + "keep = KeepEmbedding(\n", + " dataset=None, # Replace with actual dataset if needed\n", + " path=\"C:\\\\Users\\\\michi\\\\Workspace\\\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\\\1_omop_data_csv\",\n", + " domain_type=[\"all\"],\n", + " embedding_dim=100,\n", + " walk_length=30,\n", + " num_walks=750\n", + ")\n", + "\n", + "keep.test()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyhealth/models/keep_embedding.py b/pyhealth/models/keep_embedding.py index ac035437b..a87a73f0a 100644 --- a/pyhealth/models/keep_embedding.py +++ b/pyhealth/models/keep_embedding.py @@ -186,15 +186,6 @@ def generate_embeddings(self): return embedding_matrix -class GloVe(): - def __init__(self, - dataset: SampleDataset, - ): - self.dataset = dataset - - def build_cooccurrence_matrix(self): - pass - class KeepEmbedding(BaseModel): """KEEP Embedding: Fine-tune Node2Vec embeddings using GloVe while penalizing deviation from original embeddings. From 4f98ba8cb9f70c55be54c360a30b684771e73f1c Mon Sep 17 00:00:00 2001 From: tyroney Date: Mon, 20 Apr 2026 20:54:13 -0500 Subject: [PATCH 05/12] add KeepEmbedding API docs, update models index, and document N2V --- docs/api/models.rst | 1 + docs/api/models/pyhealth.models.KeepEmbedding | 18 ++++++++++++++++++ pyhealth/models/keep_embedding.py | 14 +++++++++++++- 3 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 docs/api/models/pyhealth.models.KeepEmbedding diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..845efad8d 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -204,3 +204,4 @@ API Reference models/pyhealth.models.TextEmbedding models/pyhealth.models.BIOT models/pyhealth.models.unified_multimodal_embedding_docs + models/pyhealth.models.KeepEmbedding diff --git a/docs/api/models/pyhealth.models.KeepEmbedding b/docs/api/models/pyhealth.models.KeepEmbedding new file mode 100644 index 000000000..1338eb118 --- /dev/null +++ b/docs/api/models/pyhealth.models.KeepEmbedding @@ -0,0 +1,18 @@ +pyhealth.models.KeepEmbedding +============================= + +KEEP embedding model for ontology-preserving medical code embeddings. + + +Classes +------- + +.. autoclass:: pyhealth.models.KeepEmbedding + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.keep_embedding.N2V + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/pyhealth/models/keep_embedding.py b/pyhealth/models/keep_embedding.py index a87a73f0a..244fa962e 100644 --- a/pyhealth/models/keep_embedding.py +++ b/pyhealth/models/keep_embedding.py @@ -14,7 +14,19 @@ from .base_model import BaseModel class N2V(): - """ + """Node2Vec embeddings for OMOP concepts. + + This class builds a directed knowledge graph from OMOP concept and + concept relationship tables, then trains Node2Vec to generate + ontology-informed embeddings for medical concepts. + + Attributes: + path: Path to the OMOP CSV files. + domain_type: List of OMOP domains used to filter concepts. + embedding_dim: Dimension of the learned embeddings. + walk_length: Length of each random walk. + num_walks: Number of walks generated per node. + graph: Directed graph constructed from OMOP concepts and relations. """ def __init__( self, From 81401aedd99f21aa7683d71adb27d33e0174fcf2 Mon Sep 17 00:00:00 2001 From: Rohan Vasavada Date: Tue, 21 Apr 2026 20:52:05 -0500 Subject: [PATCH 06/12] test cases --- pyhealth/models/keep_embedding.py | 16 +- tests/core/test_keep_embedding.py | 246 ++++++++++++++++++++++++++++++ 2 files changed, 254 insertions(+), 8 deletions(-) create mode 100644 tests/core/test_keep_embedding.py diff --git a/pyhealth/models/keep_embedding.py b/pyhealth/models/keep_embedding.py index 244fa962e..896ab68a8 100644 --- a/pyhealth/models/keep_embedding.py +++ b/pyhealth/models/keep_embedding.py @@ -240,7 +240,7 @@ def __init__(self, self.lambda_reg = lambda_reg self.reg_norm = reg_norm self.log_scale = log_scale - self.device = device + self._device = device self.mode = "regression" # Set mode for compatibility with BaseModel # Generate Node2Vec embeddings @@ -330,10 +330,10 @@ def forward(self, } # Move inputs to correct device - i_indices = i_indices.to(self.device) - j_indices = j_indices.to(self.device) - counts = counts.to(self.device) - weights = weights.to(self.device) + i_indices = i_indices.to(self._device) + j_indices = j_indices.to(self._device) + counts = counts.to(self._device) + weights = weights.to(self._device) # Get embeddings and biases embedding_i = self.embeddings_v(i_indices) # (batch_size, embedding_dim) @@ -350,7 +350,7 @@ def forward(self, glove_loss = torch.sum(weights * squared_diff) total_loss = glove_loss - reg_loss = torch.tensor(0.0, device=self.device) + reg_loss = torch.tensor(0.0, device=self._device) # Add Node2Vec regularization if lambda > 0 if self.lambda_reg > 0: @@ -385,7 +385,7 @@ def forward(self, return { "loss": total_loss, "logit": glove_loss.detach(), # Return GloVe component as logit for reference - "y_prob": torch.zeros(i_indices.shape[0], device=self.device), # Placeholder - "y_true": torch.zeros(i_indices.shape[0], device=self.device), # Placeholder + "y_prob": torch.zeros(i_indices.shape[0], device=self._device), # Placeholder + "y_true": torch.zeros(i_indices.shape[0], device=self._device), # Placeholder "reg_loss": reg_loss.detach(), # Return regularization loss for monitoring } diff --git a/tests/core/test_keep_embedding.py b/tests/core/test_keep_embedding.py new file mode 100644 index 000000000..3e3bc2f37 --- /dev/null +++ b/tests/core/test_keep_embedding.py @@ -0,0 +1,246 @@ +import sys +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np +import torch + +# pyhealth.models.__init__ imports every model in the package, each of which +# pulls in its own optional deps (einops, litdata, polars, rdkit, …). +# Mock out everything except keep_embedding and base_model so Python never +# loads those files, keeping test imports fast and dep-free. +_datasets_mock = MagicMock() +_datasets_mock.SampleDataset = MagicMock +sys.modules.setdefault("pyhealth.datasets", _datasets_mock) +sys.modules.setdefault("pyhealth.processors", MagicMock()) + +for _mod in ( + "pyhealth.models.adacare", + "pyhealth.models.agent", + "pyhealth.models.biot", + "pyhealth.models.cnn", + "pyhealth.models.concare", + "pyhealth.models.contrawr", + "pyhealth.models.deepr", + "pyhealth.models.embedding", + "pyhealth.models.gamenet", + "pyhealth.models.jamba_ehr", + "pyhealth.models.logistic_regression", + "pyhealth.models.gan", + "pyhealth.models.gnn", + "pyhealth.models.graph_torchvision_model", + "pyhealth.models.graphcare", + "pyhealth.models.grasp", + "pyhealth.models.medlink", + "pyhealth.models.micron", + "pyhealth.models.mlp", + "pyhealth.models.molerec", + "pyhealth.models.retain", + "pyhealth.models.rnn", + "pyhealth.models.safedrug", + "pyhealth.models.sparcnet", + "pyhealth.models.stagenet", + "pyhealth.models.stagenet_mha", + "pyhealth.models.tcn", + "pyhealth.models.tfm_tokenizer", + "pyhealth.models.torchvision_model", + "pyhealth.models.transformer", + "pyhealth.models.transformers_model", + "pyhealth.models.ehrmamba", + "pyhealth.models.vae", + "pyhealth.models.vision_embedding", + "pyhealth.models.text_embedding", + "pyhealth.models.sdoh", + "pyhealth.models.unified_embedding", +): + sys.modules.setdefault(_mod, MagicMock()) + +from pyhealth.models.keep_embedding import N2V, KeepEmbedding + + +# Tiny sizes so every test finishes in milliseconds. +NUM_CONCEPTS = 8 +EMBEDDING_DIM = 4 + + +class TestN2VHelpers(unittest.TestCase): + """Test N2V helper methods that require no CSV files or graph construction.""" + + def setUp(self): + """Set up a minimal N2V instance.""" + self.n2v = N2V( + path="/fake", + domain_type=["all"], + embedding_dim=EMBEDDING_DIM, + walk_length=5, + num_walks=5, + ) + + def test_build_index_mapping(self): + """Test that concept string keys are mapped to integer indices.""" + wv = MagicMock() + wv.index_to_key = ["100", "200", "300"] + mapping = self.n2v._build_index_mapping(wv) + self.assertEqual(mapping, {100: 0, 200: 1, 300: 2}) + + def test_get_vector_iso_found(self): + """Test that the correct embedding is returned for a known concept.""" + wv = MagicMock() + wv.index_to_key = ["42"] + vec = np.array([1.0, 2.0]) + wv.get_vector.return_value = vec + result = self.n2v._get_vector_iso("42", wv, {42: 0}, np.zeros(2)) + np.testing.assert_array_equal(result, vec) + + def test_get_vector_iso_missing_returns_mean(self): + """Test that the mean vector is returned for an unknown concept.""" + wv = MagicMock() + wv.index_to_key = [] + mean_vec = np.array([0.5, 0.5]) + result = self.n2v._get_vector_iso("999", wv, {}, mean_vec) + np.testing.assert_array_equal(result, mean_vec) + + +class TestKeepEmbeddingInit(unittest.TestCase): + """Test KeepEmbedding initialization with a mocked N2V embedding matrix.""" + + def setUp(self): + """Set up a KeepEmbedding instance with N2V mocked out.""" + fake_matrix = np.random.randn(NUM_CONCEPTS, EMBEDDING_DIM).astype(np.float32) + with patch.object(N2V, "generate_embeddings", return_value=fake_matrix): + self.model = KeepEmbedding( + dataset=None, + path="/fake/path", + domain_type=["all"], + embedding_dim=EMBEDDING_DIM, + walk_length=5, + num_walks=5, + device="cpu", + ) + + def test_embedding_shapes(self): + """Test that embedding layers are created with the correct dimensions.""" + self.assertEqual(self.model.embeddings_v.num_embeddings, NUM_CONCEPTS) + self.assertEqual(self.model.embeddings_v.embedding_dim, EMBEDDING_DIM) + self.assertEqual(self.model.initial_embeddings.shape, (NUM_CONCEPTS, EMBEDDING_DIM)) + + def test_biases_initialized_to_zero(self): + """Test that bias embeddings are initialized to zero.""" + self.assertTrue(torch.all(self.model.biases_v.weight.data == 0)) + self.assertTrue(torch.all(self.model.biases_u.weight.data == 0)) + + +class TestKeepEmbeddingForward(unittest.TestCase): + """Test KeepEmbedding forward pass across regularization configurations.""" + + def _make_model(self, lambda_reg=1.0, reg_norm=None, log_scale=False): + """Return a KeepEmbedding with N2V mocked out.""" + fake_matrix = np.random.randn(NUM_CONCEPTS, EMBEDDING_DIM).astype(np.float32) + with patch.object(N2V, "generate_embeddings", return_value=fake_matrix): + return KeepEmbedding( + dataset=None, + path="/fake/path", + domain_type=["all"], + embedding_dim=EMBEDDING_DIM, + walk_length=5, + num_walks=5, + lambda_reg=lambda_reg, + reg_norm=reg_norm, + log_scale=log_scale, + device="cpu", + ) + + def _glove_batch(self, batch_size=3): + """Return a minimal GloVe batch with random indices.""" + return { + "i_indices": torch.randint(0, NUM_CONCEPTS, (batch_size,)), + "j_indices": torch.randint(0, NUM_CONCEPTS, (batch_size,)), + "counts": torch.rand(batch_size) * 10 + 1, + "weights": torch.rand(batch_size), + } + + def setUp(self): + """Set up a default KeepEmbedding model.""" + self.model = self._make_model(lambda_reg=1.0) + + def test_output_keys(self): + """Test that the forward pass returns all expected output keys.""" + ret = self.model(**self._glove_batch()) + for key in ("loss", "logit", "y_prob", "y_true", "reg_loss"): + self.assertIn(key, ret) + + def test_loss_is_scalar(self): + """Test that the total loss is a scalar tensor.""" + self.assertEqual(self.model(**self._glove_batch())["loss"].dim(), 0) + + def test_placeholder_shapes(self): + """Test that placeholder output tensors match the batch size.""" + ret = self.model(**self._glove_batch(batch_size=3)) + self.assertEqual(ret["y_prob"].shape[0], 3) + self.assertEqual(ret["y_true"].shape[0], 3) + + def test_no_inputs_returns_zero_loss(self): + """Test that calling forward with no inputs returns zero loss.""" + self.assertEqual(self.model()["loss"].item(), 0.0) + + def test_no_regularization(self): + """Test that lambda_reg=0 produces zero regularization loss.""" + ret = self._make_model(lambda_reg=0.0)(**self._glove_batch()) + self.assertEqual(ret["reg_loss"].item(), 0.0) + + def test_lp_norm_regularization(self): + """Test that Lp norm regularization runs without error.""" + ret = self._make_model(lambda_reg=1.0, reg_norm=2)(**self._glove_batch()) + self.assertEqual(ret["loss"].dim(), 0) + + def test_log_scale_regularization(self): + """Test that log-scale cosine regularization runs without error.""" + ret = self._make_model(lambda_reg=1.0, log_scale=True)(**self._glove_batch()) + self.assertEqual(ret["loss"].dim(), 0) + + +class TestKeepEmbeddingBackward(unittest.TestCase): + """Test that gradients flow correctly through KeepEmbedding.""" + + def _make_model(self, lambda_reg=1.0): + """Return a KeepEmbedding with N2V mocked out.""" + fake_matrix = np.random.randn(NUM_CONCEPTS, EMBEDDING_DIM).astype(np.float32) + with patch.object(N2V, "generate_embeddings", return_value=fake_matrix): + return KeepEmbedding( + dataset=None, + path="/fake/path", + domain_type=["all"], + embedding_dim=EMBEDDING_DIM, + walk_length=5, + num_walks=5, + lambda_reg=lambda_reg, + device="cpu", + ) + + def _glove_batch(self, batch_size=3): + """Return a minimal GloVe batch with random indices.""" + return { + "i_indices": torch.randint(0, NUM_CONCEPTS, (batch_size,)), + "j_indices": torch.randint(0, NUM_CONCEPTS, (batch_size,)), + "counts": torch.rand(batch_size) * 10 + 1, + "weights": torch.rand(batch_size), + } + + def _has_grads(self, model): + return any(p.grad is not None for p in model.parameters() if p.requires_grad) + + def test_gradients_flow_with_regularization(self): + """Test that gradients flow through the GloVe + regularization loss.""" + model = self._make_model(lambda_reg=1.0) + model(**self._glove_batch())["loss"].backward() + self.assertTrue(self._has_grads(model)) + + def test_gradients_flow_without_regularization(self): + """Test that gradients flow through the GloVe-only loss.""" + model = self._make_model(lambda_reg=0.0) + model(**self._glove_batch())["loss"].backward() + self.assertTrue(self._has_grads(model)) + + +if __name__ == "__main__": + unittest.main() From 38a3cf524a64a89a101b6ad61af24a46bc08ce4c Mon Sep 17 00:00:00 2001 From: shreyaaishi Date: Wed, 22 Apr 2026 01:45:22 -0400 Subject: [PATCH 07/12] Ablation study for keep framework --- examples/keep.ipynb | 377 +++++++++++++++++++++++++----- pyhealth/models/__init__.py | 2 +- pyhealth/models/keep_embedding.py | 41 ++-- 3 files changed, 344 insertions(+), 76 deletions(-) diff --git a/examples/keep.ipynb b/examples/keep.ipynb index 981d89095..679c978fd 100644 --- a/examples/keep.ipynb +++ b/examples/keep.ipynb @@ -1,70 +1,343 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "2631719c", + "metadata": {}, + "source": [ + "1. Load PyHealth/sample data\n", + "2. Build Co-Occurrence Matrix\n", + "3. Convert matrix to GloveDataset dataloader that returns i, j, counts, weights\n", + "4. Create Keep Model\n", + "5. Pass data loader and model to trainer" + ] + }, { "cell_type": "code", - "execution_count": 1, - "id": "840a3b7f", + "execution_count": null, + "id": "773a5c88", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\Users\\michi\\Workspace\\PyHealth\\.venv\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running on device: cpu\n", - "Creating OMOP knowledge graph\n", - "Loading concepts from C:\\Users\\michi\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv\\2b_concept.csv\n", - "Loading concept relationships from C:\\Users\\michi\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv\\2b_concept_relationship.csv\n", - "Loaded 3885 concepts and 7716 relationships\n", - "Created set of 3885 concept IDs\n", - "Found 50 relationships between concepts\n", - "Graph created with 3885 nodes and 26 edges\n", - "Initializing Node2Vec with embedding_dim=100 walk_length=30, num_walks=750\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Computing transition probabilities: 100%|██████████| 3885/3885 [00:00<00:00, 243672.55it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Creating embedding vectors for 3885 concepts...\n", - "Embedding matrix shape: (3885, 100)\n", - "Generated embedding matrix with shape: (3885, 100)\n" - ] - } - ], + "outputs": [], "source": [ - "import torch\n", + "from collections import defaultdict, Counter\n", + "import networkx as nx\n", + "import numpy as np\n", + "from pyhealth.datasets import OMOPDataset\n", + "from pyhealth.models import KeepEmbedding, N2V\n", + "from pyhealth.trainer import Trainer\n", "import warnings\n", + "import sys\n", + "import torch\n", + "from torch.utils.data import Dataset, DataLoader\n", "\n", - "from pyhealth.models import KeepEmbedding\n", + "# Add examples to path to import builder\n", + "sys.path.insert(0, '/path/to/examples') # Will be updated below\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "print(f\"Running on device: {device}\")\n", + "print(f\"Running on device: {device}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "efe2f018", + "metadata": {}, + "outputs": [], + "source": [ + "class GloveDataset(Dataset):\n", + " def __init__(self, cooc_matrix, num_words, x_max, alpha):\n", + " super(GloveDataset, self).__init__()\n", + " self.data = []\n", + " for i in range(cooc_matrix.shape[0]):\n", + " for j in range(cooc_matrix.shape[1]):\n", + " if cooc_matrix[i, j] > 0:\n", + " self.data.append((i, j, cooc_matrix[i, j]))\n", + " self.cooc_matrix = cooc_matrix\n", + " self.num_words = num_words\n", + " self.x_max = x_max\n", + " self.alpha = alpha\n", "\n", - "keep = KeepEmbedding(\n", - " dataset=None, # Replace with actual dataset if needed\n", - " path=\"C:\\\\Users\\\\michi\\\\Workspace\\\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\\\1_omop_data_csv\",\n", - " domain_type=[\"all\"],\n", + " def __len__(self):\n", + " return len(self.data)\n", + "\n", + " def __getitem__(self, idx):\n", + " i, j, count = self.data[idx]\n", + " weight = (count / self.x_max) ** self.alpha if count < self.x_max else 1.0\n", + " return torch.tensor(i), torch.tensor(j), torch.tensor(count).float(), torch.tensor(weight).float()\n", + " \n", + "def get_code_and_ancestors(graph, code):\n", + " # Start with the code itself\n", + " codes_set = {code}\n", + " \n", + " # Add all ancestors if code exists in graph\n", + " if code in graph:\n", + " ancestors = nx.ancestors(graph, code)\n", + " codes_set.update(ancestors)\n", + " else:\n", + " # Code not in graph - may be rare/invalid, skip silently\n", + " # print(f\"Code {code} not found in concept graph (may be rare/invalid)\")\n", + " pass\n", + " \n", + " return codes_set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4af28ea6", + "metadata": {}, + "outputs": [], + "source": [ + "# Load OMOPDataset and extract condition codes for all patients\n", + "dataset = OMOPDataset(\n", + " root=r\"C:\\Users\\michi\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv\",\n", + " tables=[\"condition_occurrence\"],\n", + " dataset_name=\"omop\",\n", + " dev=False\n", + ")\n", + "\n", + "dataset.stats()\n", + "\n", + "print(\"Extracting condition codes from all patients...\")\n", + " \n", + "patient_conditions = defaultdict(list)\n", + "\n", + "# Iterate through all patients\n", + "for patient in dataset.iter_patients():\n", + " patient_id = patient.patient_id\n", + " \n", + " # Get all condition events for this patient\n", + " condition_events = patient.get_events(event_type=\"condition_occurrence\")\n", + " \n", + " # Extract condition_concept_id from each event\n", + " for event in condition_events:\n", + " code = event.attr_dict.get(\"condition_concept_id\")\n", + " if code is not None:\n", + " patient_conditions[patient_id].append(code)\n", + "\n", + "print(f\"Extracted conditions for {len(patient_conditions)} patients\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4570ee56", + "metadata": {}, + "outputs": [], + "source": [ + "# Filter out conditions codes with <2 occurrences in patient history\n", + "filtered_conditions = {}\n", + " \n", + "before_count = sum(len(codes) for codes in patient_conditions.values())\n", + "before_unique = len(set(code for codes in patient_conditions.values() for code in codes))\n", + "\n", + "for patient_id, codes in patient_conditions.items():\n", + " # Count occurrences of each code\n", + " code_counts = Counter(codes)\n", + " \n", + " filtered_codes = [code for code, count in code_counts.items() if count >= 2]\n", + " \n", + " if filtered_codes:\n", + " filtered_conditions[patient_id] = filtered_codes\n", + "\n", + "after_count = sum(len(codes) for codes in filtered_conditions.values())\n", + "after_unique = len(set(code for codes in filtered_conditions.values() for code in codes))\n", + "\n", + "print(f\"Before Filtering:\")\n", + "print(f\"Total codes: {before_count}, Unique: {before_unique}\")\n", + "print(\"=\"*70)\n", + "print(f\"After Filtering:\")\n", + "print(f\"Total codes: {after_count}, Unique: {after_unique}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07f90eb1", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize N2V with same parameters as KeepEmbedding will use\n", + "n2v = N2V(\n", + " path=r\"C:\\Users\\michi\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv\",\n", + " domain_type=[\"Condition\"],\n", " embedding_dim=100,\n", " walk_length=30,\n", - " num_walks=750\n", + " num_walks=750, \n", + ")\n", + "\n", + "# Build the concept relationship graph from conditions\n", + "print(\"Building concept relationship graph...\")\n", + "graph = n2v._create_graph()\n", + "\n", + "print(f\"Graph loaded successfully:\")\n", + "print(f\" Nodes (unique concepts): {len(graph.nodes())}\")\n", + "print(f\" Edges (relationships): {len(graph.edges())}\")\n", + "\n", + "print(\"Applying hierarchy roll-up (dense: each code -> itself + all parents)...\")\n", + "\n", + "rolled_up_conditions = {}\n", + "\n", + "# Track statistics\n", + "total_original_codes = 0\n", + "total_rolled_up_codes = 0\n", + "patients_processed = 0\n", + "\n", + "for patient_id, codes in filtered_conditions.items():\n", + " expanded_codes = set()\n", + " \n", + " # For each code, add it and all ancestors\n", + " for code in codes:\n", + " code_and_ancestors = get_code_and_ancestors(graph, code)\n", + " expanded_codes.update(code_and_ancestors)\n", + " \n", + " if expanded_codes:\n", + " rolled_up_conditions[patient_id] = expanded_codes\n", + " total_original_codes += len(codes)\n", + " total_rolled_up_codes += len(expanded_codes)\n", + " patients_processed += 1\n", + "\n", + "# Log statistics\n", + "print(f\"Roll-up complete:\")\n", + "print(f\" Patients processed: {patients_processed}\")\n", + "print(f\" Original codes per patient (avg): {total_original_codes / patients_processed:.2f}\")\n", + "print(f\" Rolled-up codes per patient (avg): {total_rolled_up_codes / patients_processed:.2f}\")\n", + "print(f\" Expansion factor: {total_rolled_up_codes / total_original_codes:.2f}x\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07af9886", + "metadata": {}, + "outputs": [], + "source": [ + "# Construct co-occurrence matrix for rolled-up conditions\n", + "print(\"Building code index from rolled-up conditions...\")\n", + " \n", + "# Collect all unique codes after roll-up\n", + "unique_codes = set()\n", + "for codes in rolled_up_conditions.values():\n", + " unique_codes.update(codes)\n", + "\n", + "# Sort for reproducibility\n", + "unique_codes = sorted(unique_codes)\n", + "\n", + "# Create bidirectional mapping\n", + "code_to_index = {code: idx for idx, code in enumerate(unique_codes)}\n", + "index_to_code = {idx: code for code, idx in code_to_index.items()}\n", + "\n", + "\n", + "print(\"Building co-occurrence matrix...\")\n", + "\n", + "num_codes = len(code_to_index)\n", + "\n", + "# Initialize matrix\n", + "cooc_matrix = np.zeros((num_codes, num_codes), dtype=np.float32)\n", + "\n", + "# Track statistics\n", + "total_pairs = 0\n", + "patients_processed = 0\n", + "\n", + "for patient_id, codes in rolled_up_conditions.items():\n", + " # Convert codes to indices\n", + " code_indices = [code_to_index[code] for code in codes]\n", + " \n", + " # Create all unique pairs\n", + " for i in range(len(code_indices)):\n", + " for j in range(i + 1, len(code_indices)):\n", + " idx_i = code_indices[i]\n", + " idx_j = code_indices[j]\n", + " \n", + " # Increment both matrix[i,j] and matrix[j,i] for symmetry\n", + " cooc_matrix[idx_i, idx_j] += 1.0\n", + " cooc_matrix[idx_j, idx_i] += 1.0\n", + " \n", + " total_pairs += 1\n", + " \n", + " patients_processed += 1\n", + "\n", + "print(f\"Matrix construction complete:\")\n", + "print(f\" Matrix shape: {cooc_matrix.shape}\")\n", + "print(f\" Patients processed: {patients_processed}\")\n", + "print(f\" Total co-occurrence pairs: {total_pairs}\")\n", + "print(f\" Matrix sparsity: {(cooc_matrix == 0).sum() / cooc_matrix.size * 100:.2f}%\")\n", + "print(f\" Matrix sum (total co-occurrences): {cooc_matrix.sum():.0f}\")\n", + "print(f\" Min value: {cooc_matrix.min():.2f}, Max value: {cooc_matrix.max():.2f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c079f22", + "metadata": {}, + "outputs": [], + "source": [ + "# Parameters for GloveDataset (standard GloVe hyperparameters)\n", + "x_max = 100 # Maximum co-occurrence count before weight saturates\n", + "alpha = 0.75 # Weighting factor exponent\n", + "batch_size = 128\n", + "\n", + "# Create GloveDataset from co-occurrence matrix\n", + "print(\"Creating GloveDataset from co-occurrence matrix...\")\n", + "num_words = cooc_matrix.shape[0]\n", + "glove_dataset = GloveDataset(cooc_matrix, num_words, x_max, alpha)\n", + "\n", + "print(f\" GloveDataset created:\")\n", + "print(f\" Num words (diagnosis codes): {num_words}\")\n", + "print(f\" Dataset size (co-occurrence pairs): {len(glove_dataset)}\")\n", + "\n", + "# Create DataLoader\n", + "data_loader = DataLoader(glove_dataset, batch_size=batch_size, shuffle=True)\n", + "print(f\" DataLoader created: batch_size={batch_size}, num_batches={len(data_loader)}\")\n", + "\n", + "# Initialize KeepEmbedding with MATCHING parameters as builder\n", + "print(\"\\nInitializing KeepEmbedding with builder parameters...\")\n", + "keep_model = KeepEmbedding(\n", + " dataset=None, # GloVe training doesn't require full dataset\n", + " graph = graph, # Pass the concept relationship graph for Node2Vec\n", + " num_words=num_words,\n", + " embedding_dim=100,\n", + " lambda_reg=1.0, # Regularization strength (balances GloVe vs graph prior)\n", + " reg_norm=None, # Use cosine similarity for regularization distance\n", + " log_scale=False, # No log scaling on regularization\n", + " device=device\n", + ")\n", + "\n", + "print(f\" KeepEmbedding initialized:\")\n", + "print(f\" Embedding dimension: 100\")\n", + "print(f\" Regularization lambda: 1.0\")\n", + "print(f\" Device: {device}\")\n", + "print(f\"\\nModel ready for training!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "840a3b7f", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Training KeepEmbedding model with GloVe objective + graph regularization...\")\n", + "print(f\" Total batches: {len(data_loader)}\")\n", + "print(f\" Training for 50 epochs\\n\")\n", + "\n", + "trainer = Trainer(model=keep_model)\n", + "\n", + "# Train with GloVe objective + Node2Vec graph regularization\n", + "trainer.train(\n", + " train_dataloader=data_loader,\n", + " val_dataloader=None, # No validation split for embedding training\n", + " epochs=50,\n", + " optimizer_class=torch.optim.Adam,\n", + " learning_rate=0.01,\n", + " monitor=\"loss\", # Monitor training loss\n", + " patience=10, # Early stopping patience\n", ")\n", "\n", - "keep.test()" + "print(\"\\nTraining complete!\")\n", + "print(f\"Learned embeddings shape: {keep_model.embeddings_v.weight.shape}\")\n", + "print(f\"Ready for downstream tasks!\")" ] } ], diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index f285a260c..991850db0 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -44,4 +44,4 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding -from .keep_embedding import KeepEmbedding \ No newline at end of file +from .keep_embedding import N2V, KeepEmbedding \ No newline at end of file diff --git a/pyhealth/models/keep_embedding.py b/pyhealth/models/keep_embedding.py index a87a73f0a..fb8b96ca0 100644 --- a/pyhealth/models/keep_embedding.py +++ b/pyhealth/models/keep_embedding.py @@ -7,7 +7,6 @@ import pandas as pd import torch import torch.nn.functional as F -# from scipy import sparse from torch import nn from pyhealth.datasets import SampleDataset @@ -135,7 +134,7 @@ def _get_vector_iso(self, code, node_embeddings, index_mapping, mean_vector): print(f"Code {code} not found, returning mean vector.") return mean_vector - def generate_embeddings(self): + def generate_embeddings(self, graph): """ Generate node embeddings using Node2Vec algorithm. @@ -211,11 +210,9 @@ class KeepEmbedding(BaseModel): def __init__(self, dataset: SampleDataset, - path: str, - domain_type: list[str], + graph: nx.Graph, + num_words: int, embedding_dim: int, - walk_length: int, - num_walks: int, lambda_reg: float = 1.0, reg_norm: str | float = None, log_scale: bool = False, @@ -228,24 +225,22 @@ def __init__(self, self.lambda_reg = lambda_reg self.reg_norm = reg_norm self.log_scale = log_scale - self.device = device + self._device = device self.mode = "regression" # Set mode for compatibility with BaseModel # Generate Node2Vec embeddings - print(f"Initializing Node2Vec with embedding_dim={embedding_dim}...") - self.n2v = N2V( - path=path, - domain_type=domain_type, - embedding_dim=embedding_dim, - walk_length=walk_length, - num_walks=num_walks - ) - - embedding_matrix = self.n2v.generate_embeddings() + # print(f"Initializing Node2Vec with embedding_dim={embedding_dim}...") + # self.n2v = N2V( + # path=path, + # domain_type=domain_type, + # embedding_dim=embedding_dim, + # walk_length=walk_length, + # num_walks=num_walks + # ) + + embedding_matrix = self.n2v.generate_embeddings(graph) print(f"Created embedding matrix with shape: {embedding_matrix.shape}") - num_words = embedding_matrix.shape[0] - # Create learnable embedding and bias parameters self.embeddings_v = nn.Embedding(num_words, embedding_dim) self.embeddings_u = nn.Embedding(num_words, embedding_dim) @@ -257,16 +252,16 @@ def __init__(self, self.embeddings_v.weight.data.copy_(embedding_tensor) self.embeddings_u.weight.data.copy_(embedding_tensor) + # Initialize biases to zero + self.biases_v.weight.data.fill_(0) + self.biases_u.weight.data.fill_(0) + # Store initial embeddings for regularization self.register_buffer( "initial_embeddings", embedding_tensor.clone().to(device) ) - # Initialize biases to zero - self.biases_v.weight.data.fill_(0) - self.biases_u.weight.data.fill_(0) - print(f"Initialized KEEP Embedding with {num_words} tokens") print(f"Embedding dimension: {embedding_dim}") print(f"Regularization lambda: {lambda_reg}") From af9073afa84aba6aa151c61bf657be45e2520dbb Mon Sep 17 00:00:00 2001 From: shreyaaishi Date: Wed, 22 Apr 2026 02:21:30 -0400 Subject: [PATCH 08/12] Ablation study for keep framework --- examples/keep.ipynb | 1377 ++++++++++++++++++++++++++++- pyhealth/models/keep_embedding.py | 81 +- 2 files changed, 1398 insertions(+), 60 deletions(-) diff --git a/examples/keep.ipynb b/examples/keep.ipynb index 679c978fd..843268aa1 100644 --- a/examples/keep.ipynb +++ b/examples/keep.ipynb @@ -14,10 +14,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "773a5c88", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on device: cpu\n" + ] + } + ], "source": [ "from collections import defaultdict, Counter\n", "import networkx as nx\n", @@ -30,16 +38,13 @@ "import torch\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", - "# Add examples to path to import builder\n", - "sys.path.insert(0, '/path/to/examples') # Will be updated below\n", - "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Running on device: {device}\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "efe2f018", "metadata": {}, "outputs": [], @@ -63,7 +68,13 @@ " def __getitem__(self, idx):\n", " i, j, count = self.data[idx]\n", " weight = (count / self.x_max) ** self.alpha if count < self.x_max else 1.0\n", - " return torch.tensor(i), torch.tensor(j), torch.tensor(count).float(), torch.tensor(weight).float()\n", + " # Return dictionary with keys matching KeepEmbedding.forward() parameters\n", + " return {\n", + " \"i_indices\": torch.tensor(i),\n", + " \"j_indices\": torch.tensor(j),\n", + " \"counts\": torch.tensor(count).float(),\n", + " \"weights\": torch.tensor(weight).float(),\n", + " }\n", " \n", "def get_code_and_ancestors(graph, code):\n", " # Start with the code itself\n", @@ -83,10 +94,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "4af28ea6", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No config path provided, using default OMOP config\n", + "Initializing omop dataset from C:\\Users\\michi\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv (dev mode: False)\n", + "No cache_dir provided. Using default cache dir: C:\\Users\\michi\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\d8eca071-4bf2-5cb6-ba60-95451c941912\n", + "Found cached event dataframe: C:\\Users\\michi\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\d8eca071-4bf2-5cb6-ba60-95451c941912\\global_event_df.parquet\n", + "Dataset: omop\n", + "Dev mode: False\n", + "Number of patients: 100\n", + "Number of events: 17408\n", + "Extracting condition codes from all patients...\n", + "Extracted conditions for 100 patients\n" + ] + } + ], "source": [ "# Load OMOPDataset and extract condition codes for all patients\n", "dataset = OMOPDataset(\n", @@ -120,10 +148,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "4570ee56", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Before Filtering:\n", + "Total codes: 16441, Unique: 979\n", + "======================================================================\n", + "After Filtering:\n", + "Total codes: 799, Unique: 280\n" + ] + } + ], "source": [ "# Filter out conditions codes with <2 occurrences in patient history\n", "filtered_conditions = {}\n", @@ -152,23 +192,42 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "07f90eb1", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Building concept relationship graph...\n", + "Loading concepts from C:\\Users\\michi\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv\\2b_concept.csv\n", + "Loading concept relationships from C:\\Users\\michi\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv\\2b_concept_relationship.csv\n", + "Loaded 3885 concepts and 7716 relationships\n", + "Filtered to 34 concepts in domains: ['Condition']\n", + "Created set of 34 concept IDs\n", + "Found 0 relationships between concepts\n", + "Graph loaded successfully:\n", + " Nodes (unique concepts): 34\n", + " Edges (relationships): 0\n", + "Applying hierarchy roll-up (dense: each code -> itself + all parents)...\n", + "Roll-up complete:\n", + " Patients processed: 100\n", + " Original codes per patient (avg): 7.99\n", + " Rolled-up codes per patient (avg): 7.99\n", + " Expansion factor: 1.00x\n" + ] + } + ], "source": [ - "# Initialize N2V with same parameters as KeepEmbedding will use\n", - "n2v = N2V(\n", - " path=r\"C:\\Users\\michi\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv\",\n", - " domain_type=[\"Condition\"],\n", - " embedding_dim=100,\n", - " walk_length=30,\n", - " num_walks=750, \n", - ")\n", + "n2v = N2V()\n", "\n", "# Build the concept relationship graph from conditions\n", "print(\"Building concept relationship graph...\")\n", - "graph = n2v._create_graph()\n", + "graph = n2v.create_graph(\n", + " path=r\"C:\\Users\\michi\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv\",\n", + " domain_type=[\"Condition\"]\n", + ")\n", "\n", "print(f\"Graph loaded successfully:\")\n", "print(f\" Nodes (unique concepts): {len(graph.nodes())}\")\n", @@ -207,10 +266,26 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "07af9886", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Building code index from rolled-up conditions...\n", + "Building co-occurrence matrix...\n", + "Matrix construction complete:\n", + " Matrix shape: (280, 280)\n", + " Patients processed: 100\n", + " Total co-occurrence pairs: 7978\n", + " Matrix sparsity: 84.11%\n", + " Matrix sum (total co-occurrences): 15956\n", + " Min value: 0.00, Max value: 52.00\n" + ] + } + ], "source": [ "# Construct co-occurrence matrix for rolled-up conditions\n", "print(\"Building code index from rolled-up conditions...\")\n", @@ -268,10 +343,64 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "0c079f22", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating GloveDataset from co-occurrence matrix...\n", + " GloveDataset created:\n", + " Num words (diagnosis codes): 280\n", + " Dataset size (co-occurrence pairs): 12458\n", + " DataLoader created: batch_size=128, num_batches=98\n", + "\n", + "Initializing KeepEmbedding with builder parameters...\n", + "Initializing Node2Vec with embedding_dim=100...\n", + "Graph created with 34 nodes and 0 edges\n", + "Initializing Node2Vec with embedding_dim=100 walk_length=30, num_walks=750\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bbe126854ba544b49b26085b791254a0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Computing transition probabilities: 0%| | 0/34 [00:00\n", + "Optimizer params: {'lr': 0.01}\n", + "Weight decay: 0.0\n", + "Max grad norm: None\n", + "Val dataloader: None\n", + "Monitor: loss\n", + "Monitor criterion: max\n", + "Epochs: 50\n", + "Patience: 10\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "602358d5ea6f4e679ecda85c2b05e58e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Epoch 0 / 50: 0%| | 0/98 [00:00 nx.DiGraph: + def create_graph(self, path, domain_type) -> nx.DiGraph: """ Create a directed graph from OMOP concept relationships. @@ -58,22 +54,22 @@ def _create_graph(self) -> nx.DiGraph: ValueError: If no concepts found for specified domains. """ # Load concept table - concept_path = os.path.join(self.path, "2b_concept.csv") + concept_path = os.path.join(path, "2b_concept.csv") print(f"Loading concepts from {concept_path}") concept_df = pd.read_csv(concept_path, dtype=str) # Load concept relationships table - concept_relationship_path = os.path.join(self.path, "2b_concept_relationship.csv") + concept_relationship_path = os.path.join(path, "2b_concept_relationship.csv") print(f"Loading concept relationships from {concept_relationship_path}") concept_rel_df = pd.read_csv(concept_relationship_path, dtype=str) print(f"Loaded {len(concept_df)} concepts and {len(concept_rel_df)} relationships") - if self.domain_type != ["all"]: + if domain_type != ["all"]: # Filter concepts by target domain - concept_df = concept_df[concept_df["domain_id"].isin(self.domain_type)].copy() + concept_df = concept_df[concept_df["domain_id"].isin(domain_type)].copy() - print(f"Filtered to {len(concept_df)} concepts in domains: {self.domain_type}") + print(f"Filtered to {len(concept_df)} concepts in domains: {domain_type}") # Create set of filtered concept IDs for quick lookup filtered_concept_ids = set(concept_df["concept_id"].values) @@ -154,12 +150,9 @@ def generate_embeddings(self, graph): embeddings for each concept based on its network structure. Returns: - gensim.models.Word2Vec: Trained Node2Vec model for concept embeddings. + tuple: (embedding_matrix, node_ids) where embedding_matrix is the numpy array + of embeddings and node_ids is the list of graph node IDs in order. """ - # Create graph from concepts and relationships - print("Creating OMOP knowledge graph") - graph = self._create_graph() - print(f"Graph created with {len(graph.nodes())} nodes and {len(graph.edges())} edges") if len(graph.nodes()) == 0: @@ -195,7 +188,7 @@ def generate_embeddings(self, graph): embedding_matrix = np.vstack(vectors) print(f"Embedding matrix shape: {embedding_matrix.shape}") - return embedding_matrix + return embedding_matrix, keys class KeepEmbedding(BaseModel): """KEEP Embedding: Fine-tune Node2Vec embeddings using GloVe while penalizing @@ -217,17 +210,22 @@ class KeepEmbedding(BaseModel): Default: None (cosine similarity). log_scale (bool): Whether to apply log scaling to regularization distance. Default: False. + code_to_index (dict, optional): Mapping from concept codes to vocabulary indices. + If provided, embeddings are filtered to only include codes in this mapping. device (str): Device to use ('cuda' or 'cpu'). Default: 'cpu'. """ def __init__(self, dataset: SampleDataset, graph: nx.Graph, + embedding_dim:int, + walk_length:int, + num_walks:int, num_words: int, - embedding_dim: int, lambda_reg: float = 1.0, reg_norm: str | float = None, log_scale: bool = False, + code_to_index: dict = None, device: str = "cpu" ): """Initialize KEEP Embedding model.""" @@ -241,18 +239,41 @@ def __init__(self, self.mode = "regression" # Set mode for compatibility with BaseModel # Generate Node2Vec embeddings - # print(f"Initializing Node2Vec with embedding_dim={embedding_dim}...") - # self.n2v = N2V( - # path=path, - # domain_type=domain_type, - # embedding_dim=embedding_dim, - # walk_length=walk_length, - # num_walks=num_walks - # ) - - embedding_matrix = self.n2v.generate_embeddings(graph) + print(f"Initializing Node2Vec with embedding_dim={embedding_dim}...") + self.n2v = N2V( + embedding_dim=embedding_dim, + walk_length=walk_length, + num_walks=num_walks + ) + + embedding_matrix, node_ids = self.n2v.generate_embeddings(graph) print(f"Created embedding matrix with shape: {embedding_matrix.shape}") + # Filter embeddings if code_to_index mapping is provided + if code_to_index is not None: + print(f"Filtering embeddings from {len(node_ids)} concepts to {num_words} vocabulary items...") + # Create a mapping from node IDs to embeddings + node_to_embedding = {node_id: embedding_matrix[i] for i, node_id in enumerate(node_ids)} + + # Build filtered embedding matrix for only codes in vocabulary + filtered_vectors = [] + missing_count = 0 + for idx in range(num_words): + # Find the concept code for this index + code = next((code for code, code_idx in code_to_index.items() if code_idx == idx), None) + if code is not None and code in node_to_embedding: + filtered_vectors.append(node_to_embedding[code]) + else: + missing_count += 1 + # Use mean of all embeddings as fallback + filtered_vectors.append(np.mean(embedding_matrix, axis=0)) + + if missing_count > 0: + print(f" {missing_count}/{num_words} codes not found in graph, using mean embedding as fallback") + + embedding_matrix = np.vstack(filtered_vectors) + print(f"Filtered embedding matrix shape: {embedding_matrix.shape}") + # Create learnable embedding and bias parameters self.embeddings_v = nn.Embedding(num_words, embedding_dim) self.embeddings_u = nn.Embedding(num_words, embedding_dim) From e09261d7c0613244e671f1421320fd1658468f0e Mon Sep 17 00:00:00 2001 From: shreyaaishi Date: Wed, 22 Apr 2026 21:57:54 -0400 Subject: [PATCH 09/12] Ablation study for keep framework --- examples/keep.ipynb | 1954 ++++++++++--------------------------------- 1 file changed, 439 insertions(+), 1515 deletions(-) diff --git a/examples/keep.ipynb b/examples/keep.ipynb index 843268aa1..58d5c24a4 100644 --- a/examples/keep.ipynb +++ b/examples/keep.ipynb @@ -5,16 +5,18 @@ "id": "2631719c", "metadata": {}, "source": [ - "1. Load PyHealth/sample data\n", - "2. Build Co-Occurrence Matrix\n", - "3. Convert matrix to GloveDataset dataloader that returns i, j, counts, weights\n", - "4. Create Keep Model\n", - "5. Pass data loader and model to trainer" + "## Ablation Study: KeepEmbedding with Different Domain Types\n", + "\n", + "**Objective**: \n", + "\n", + "Evaluate the impact of including different OMOP domain types when creating\n", + "KeepEmbedding models. This helps understand which clinical domains contribute\n", + "most to embedding quality and computational efficiency.\n" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "773a5c88", "metadata": {}, "outputs": [ @@ -30,21 +32,50 @@ "from collections import defaultdict, Counter\n", "import networkx as nx\n", "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import json\n", + "import time\n", "from pyhealth.datasets import OMOPDataset\n", "from pyhealth.models import KeepEmbedding, N2V\n", "from pyhealth.trainer import Trainer\n", - "import warnings\n", - "import sys\n", "import torch\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "print(f\"Running on device: {device}\")" + "print(f\"Running on device: {device}\")\n", + "\n", + "# Define experiment configurations for ablation study\n", + "EXPERIMENTS = {\n", + " \"condition_only\": {\n", + " \"domain_type\": [\"Condition\"],\n", + " \"tables\": [\"condition_occurrence\"],\n", + " \"description\": \"Baseline: Condition codes only\"\n", + " },\n", + " \"condition_drug\": {\n", + " \"domain_type\": [\"Condition\", \"Drug\"],\n", + " \"tables\": [\"condition_occurrence\", \"drug_exposure\"],\n", + " \"description\": \"Condition + Drug exposures\"\n", + " },\n", + " \"condition_drug_measurement\": {\n", + " \"domain_type\": [\"Condition\", \"Drug\", \"Measurement\"],\n", + " \"tables\": [\"condition_occurrence\", \"drug_exposure\", \"measurement\"],\n", + " \"description\": \"Condition + Drug + Measurements\"\n", + " },\n", + " \"full_domains\": {\n", + " \"domain_type\": [\"Condition\", \"Drug\", \"Measurement\", \"Procedure\"],\n", + " \"tables\": [\"condition_occurrence\", \"drug_exposure\", \"measurement\", \"procedure_occurrence\"],\n", + " \"description\": \"All domains: Condition + Drug + Measurement + Procedure\"\n", + " }\n", + "}\n", + "\n", + "# Store results from all experiments\n", + "experiment_results = {}" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "efe2f018", "metadata": {}, "outputs": [], @@ -84,1577 +115,470 @@ " if code in graph:\n", " ancestors = nx.ancestors(graph, code)\n", " codes_set.update(ancestors)\n", - " else:\n", - " # Code not in graph - may be rare/invalid, skip silently\n", - " # print(f\"Code {code} not found in concept graph (may be rare/invalid)\")\n", - " pass\n", " \n", - " return codes_set" + " return codes_set\n", + "\n", + "def extract_events_by_domain(patient, domain_types):\n", + " \"\"\"\n", + " Extract events from a patient for specified domain types.\n", + " Returns a list of codes from all specified domains.\n", + " \"\"\"\n", + " codes = []\n", + " \n", + " # Map domain types to event types in pyhealth\n", + " domain_to_event_type = {\n", + " \"Condition\": \"condition_occurrence\",\n", + " \"Drug\": \"drug_exposure\",\n", + " \"Measurement\": \"measurement\",\n", + " \"Procedure\": \"procedure_occurrence\"\n", + " }\n", + " \n", + " for domain in domain_types:\n", + " if domain in domain_to_event_type:\n", + " event_type = domain_to_event_type[domain]\n", + " events = patient.get_events(event_type=event_type)\n", + " \n", + " for event in events:\n", + " # Extract appropriate code field based on domain\n", + " if domain == \"Condition\":\n", + " code = event.attr_dict.get(\"condition_concept_id\")\n", + " elif domain == \"Drug\":\n", + " code = event.attr_dict.get(\"drug_concept_id\")\n", + " elif domain == \"Measurement\":\n", + " code = event.attr_dict.get(\"measurement_concept_id\")\n", + " elif domain == \"Procedure\":\n", + " code = event.attr_dict.get(\"procedure_concept_id\")\n", + " \n", + " if code is not None:\n", + " codes.append(code)\n", + " \n", + " return codes" ] }, { - "cell_type": "code", - "execution_count": 3, - "id": "4af28ea6", + "cell_type": "markdown", + "id": "b24d86d0", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "No config path provided, using default OMOP config\n", - "Initializing omop dataset from C:\\Users\\michi\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv (dev mode: False)\n", - "No cache_dir provided. Using default cache dir: C:\\Users\\michi\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\d8eca071-4bf2-5cb6-ba60-95451c941912\n", - "Found cached event dataframe: C:\\Users\\michi\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\d8eca071-4bf2-5cb6-ba60-95451c941912\\global_event_df.parquet\n", - "Dataset: omop\n", - "Dev mode: False\n", - "Number of patients: 100\n", - "Number of events: 17408\n", - "Extracting condition codes from all patients...\n", - "Extracted conditions for 100 patients\n" - ] - } - ], "source": [ - "# Load OMOPDataset and extract condition codes for all patients\n", - "dataset = OMOPDataset(\n", - " root=r\"C:\\Users\\michi\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv\",\n", - " tables=[\"condition_occurrence\"],\n", - " dataset_name=\"omop\",\n", - " dev=False\n", - ")\n", - "\n", - "dataset.stats()\n", - "\n", - "print(\"Extracting condition codes from all patients...\")\n", - " \n", - "patient_conditions = defaultdict(list)\n", - "\n", - "# Iterate through all patients\n", - "for patient in dataset.iter_patients():\n", - " patient_id = patient.patient_id\n", - " \n", - " # Get all condition events for this patient\n", - " condition_events = patient.get_events(event_type=\"condition_occurrence\")\n", - " \n", - " # Extract condition_concept_id from each event\n", - " for event in condition_events:\n", - " code = event.attr_dict.get(\"condition_concept_id\")\n", - " if code is not None:\n", - " patient_conditions[patient_id].append(code)\n", - "\n", - "print(f\"Extracted conditions for {len(patient_conditions)} patients\")" + "### Experiment Pipeline\n", + "1. Load PyHealth OMOPDaset with appropiate tables and extract distinct conditions from all patients\n", + "2. Filter out codes with <2 occurrences in entire patients' history\n", + "3. Build OMOP Concepts Knowledge Graph\n", + "4. Apply hierarchy rollup\n", + "5. Build Co-Occurrence Matrix\n", + "6. Load Co-Occurence Matrix as a GloveDataset Dataloader\n", + "7. Initialize KEEP PyHealth model and train with PyHealth Trainer\n", + "8. Collect remaining performance metrics (loss, training time, etc.)" ] }, { "cell_type": "code", - "execution_count": 4, - "id": "4570ee56", + "execution_count": null, + "id": "1b8fc62e", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Before Filtering:\n", - "Total codes: 16441, Unique: 979\n", - "======================================================================\n", - "After Filtering:\n", - "Total codes: 799, Unique: 280\n" - ] - } - ], + "outputs": [], "source": [ - "# Filter out conditions codes with <2 occurrences in patient history\n", - "filtered_conditions = {}\n", - " \n", - "before_count = sum(len(codes) for codes in patient_conditions.values())\n", - "before_unique = len(set(code for codes in patient_conditions.values() for code in codes))\n", - "\n", - "for patient_id, codes in patient_conditions.items():\n", - " # Count occurrences of each code\n", - " code_counts = Counter(codes)\n", + "def run_experiment(exp_name, config, data_root, epochs=300):\n", + " print(f\"\\n{'='*80}\")\n", + " print(f\"EXPERIMENT: {exp_name}\")\n", + " print(f\"Description: {config['description']}\")\n", + " print(f\"Domain types: {config['domain_type']}\")\n", + " print(f\"Tables: {config['tables']}\")\n", + " print(f\"{'='*80}\\n\")\n", " \n", - " filtered_codes = [code for code, count in code_counts.items() if count >= 2]\n", + " start_time = time.time()\n", + " metrics = {\n", + " \"exp_name\": exp_name,\n", + " \"domain_type\": config['domain_type'],\n", + " \"description\": config['description']\n", + " }\n", " \n", - " if filtered_codes:\n", - " filtered_conditions[patient_id] = filtered_codes\n", - "\n", - "after_count = sum(len(codes) for codes in filtered_conditions.values())\n", - "after_unique = len(set(code for codes in filtered_conditions.values() for code in codes))\n", - "\n", - "print(f\"Before Filtering:\")\n", - "print(f\"Total codes: {before_count}, Unique: {before_unique}\")\n", - "print(\"=\"*70)\n", - "print(f\"After Filtering:\")\n", - "print(f\"Total codes: {after_count}, Unique: {after_unique}\")" + " try:\n", + " # Load OMOPDataset with specified tables\n", + " print(f\"Step 1: Loading OMOPDataset with tables {config['tables']}...\")\n", + " dataset = OMOPDataset(\n", + " root=data_root,\n", + " tables=config['tables'],\n", + " dataset_name=\"omop\",\n", + " dev=False\n", + " )\n", + " dataset.stats()\n", + " \n", + " # Extract codes from all patients for specified domains\n", + " print(f\"\\nStep 2: Extracting {config['domain_type']} codes from all patients...\")\n", + " patient_codes = defaultdict(list)\n", + " \n", + " for patient in dataset.iter_patients():\n", + " patient_id = patient.patient_id\n", + " codes = extract_events_by_domain(patient, config['domain_type'])\n", + " if codes:\n", + " patient_codes[patient_id] = codes\n", + " \n", + " print(f\"Extracted codes for {len(patient_codes)} patients\")\n", + " metrics[\"patients_with_codes\"] = len(patient_codes)\n", + " \n", + " # Filter codes with <2 occurrences in patient history\n", + " print(f\"\\nStep 3: Filtering codes (keeping only those with ≥2 occurrences)...\")\n", + " filtered_codes_dict = {}\n", + " \n", + " before_count = sum(len(codes) for codes in patient_codes.values())\n", + " before_unique = len(set(code for codes in patient_codes.values() for code in codes))\n", + " \n", + " for patient_id, codes in patient_codes.items():\n", + " code_counts = Counter(codes)\n", + " filtered = [code for code, count in code_counts.items() if count >= 2]\n", + " if filtered:\n", + " filtered_codes_dict[patient_id] = filtered\n", + " \n", + " after_count = sum(len(codes) for codes in filtered_codes_dict.values())\n", + " after_unique = len(set(code for codes in filtered_codes_dict.values() for code in codes))\n", + " \n", + " print(f\"Before filtering: {before_count} codes, {before_unique} unique\")\n", + " print(f\"After filtering: {after_count} codes, {after_unique} unique\")\n", + " \n", + " metrics[\"before_filtering_count\"] = before_count\n", + " metrics[\"before_filtering_unique\"] = before_unique\n", + " metrics[\"after_filtering_count\"] = after_count\n", + " metrics[\"after_filtering_unique\"] = after_unique\n", + " \n", + " # Build concept graph with specified domain types\n", + " print(f\"\\nStep 4: Building concept relationship graph...\")\n", + " n2v = N2V()\n", + " graph = n2v.create_graph(\n", + " path=data_root,\n", + " domain_type=config['domain_type']\n", + " )\n", + " \n", + " print(f\"Graph loaded:\")\n", + " print(f\" Nodes (unique concepts): {len(graph.nodes())}\")\n", + " print(f\" Edges (relationships): {len(graph.edges())}\")\n", + " \n", + " metrics[\"graph_nodes\"] = len(graph.nodes())\n", + " metrics[\"graph_edges\"] = len(graph.edges())\n", + " \n", + " # Roll up codes with hierarchy\n", + " print(f\"\\nStep 5: Applying hierarchy roll-up...\")\n", + " rolled_up_codes = {}\n", + " \n", + " total_original = 0\n", + " total_rolled_up = 0\n", + " patients_processed = 0\n", + " \n", + " for patient_id, codes in filtered_codes_dict.items():\n", + " expanded = set()\n", + " for code in codes:\n", + " code_and_ancestors = get_code_and_ancestors(graph, code)\n", + " expanded.update(code_and_ancestors)\n", + " \n", + " if expanded:\n", + " rolled_up_codes[patient_id] = expanded\n", + " total_original += len(codes)\n", + " total_rolled_up += len(expanded)\n", + " patients_processed += 1\n", + " \n", + " print(f\"Roll-up complete:\")\n", + " print(f\" Avg original codes/patient: {total_original / patients_processed:.2f}\")\n", + " print(f\" Avg rolled-up codes/patient: {total_rolled_up / patients_processed:.2f}\")\n", + " expansion_factor = total_rolled_up / total_original if total_original > 0 else 0\n", + " print(f\" Expansion factor: {expansion_factor:.2f}x\")\n", + " \n", + " metrics[\"avg_original_codes\"] = total_original / patients_processed if patients_processed > 0 else 0\n", + " metrics[\"avg_rolled_up_codes\"] = total_rolled_up / patients_processed if patients_processed > 0 else 0\n", + " metrics[\"expansion_factor\"] = expansion_factor\n", + " \n", + " # Build co-occurrence matrix\n", + " print(f\"\\nStep 6: Building co-occurrence matrix...\")\n", + " \n", + " unique_codes = set()\n", + " for codes in rolled_up_codes.values():\n", + " unique_codes.update(codes)\n", + " \n", + " unique_codes = sorted(unique_codes)\n", + " code_to_index = {code: idx for idx, code in enumerate(unique_codes)}\n", + " index_to_code = {idx: code for code, idx in code_to_index.items()}\n", + " \n", + " num_codes = len(code_to_index)\n", + " cooc_matrix = np.zeros((num_codes, num_codes), dtype=np.float32)\n", + " \n", + " total_pairs = 0\n", + " for patient_id, codes in rolled_up_codes.items():\n", + " code_indices = [code_to_index[code] for code in codes]\n", + " for i in range(len(code_indices)):\n", + " for j in range(i + 1, len(code_indices)):\n", + " idx_i = code_indices[i]\n", + " idx_j = code_indices[j]\n", + " cooc_matrix[idx_i, idx_j] += 1.0\n", + " cooc_matrix[idx_j, idx_i] += 1.0\n", + " total_pairs += 1\n", + " \n", + " print(f\"Matrix complete:\")\n", + " print(f\" Shape: {cooc_matrix.shape}\")\n", + " print(f\" Sparsity: {(cooc_matrix == 0).sum() / cooc_matrix.size * 100:.2f}%\")\n", + " print(f\" Total co-occurrences: {cooc_matrix.sum():.0f}\")\n", + " \n", + " metrics[\"matrix_shape\"] = cooc_matrix.shape\n", + " metrics[\"matrix_sparsity\"] = (cooc_matrix == 0).sum() / cooc_matrix.size * 100\n", + " metrics[\"total_cooccurrence_pairs\"] = cooc_matrix.sum()\n", + " \n", + " # Create GloveDataset and DataLoader\n", + " print(f\"\\nStep 7: Creating GloveDataset and DataLoader...\")\n", + " x_max = 100\n", + " alpha = 0.75\n", + " batch_size = 1024\n", + " \n", + " glove_dataset = GloveDataset(cooc_matrix, num_codes, x_max, alpha)\n", + " data_loader = DataLoader(glove_dataset, batch_size=batch_size, shuffle=True)\n", + " \n", + " # Initialize KeepEmbedding model\n", + " print(f\"\\nStep 8: Initializing KeepEmbedding model...\")\n", + " keep_model = KeepEmbedding(\n", + " dataset=None,\n", + " graph=graph,\n", + " num_words=num_codes,\n", + " embedding_dim=100,\n", + " walk_length=30,\n", + " num_walks=750,\n", + " lambda_reg=1.0e-3,\n", + " reg_norm=None,\n", + " log_scale=False,\n", + " code_to_index=code_to_index,\n", + " device=device\n", + " )\n", + " \n", + " # Train model\n", + " print(f\"\\nStep 9: Training KeepEmbedding model ({epochs} epochs)...\")\n", + " training_start = time.time()\n", + " \n", + " trainer = Trainer(model=keep_model, enable_logging=False)\n", + " trainer.train(\n", + " train_dataloader=data_loader,\n", + " val_dataloader=None,\n", + " epochs=epochs,\n", + " optimizer_class=torch.optim.Adam,\n", + " optimizer_params={\"lr\": 0.05},\n", + " monitor=\"loss\",\n", + " patience=10,\n", + " )\n", + " \n", + " training_time = time.time() - training_start\n", + " print(f\"Training complete in {training_time:.2f} seconds\")\n", + " \n", + " metrics[\"training_time_seconds\"] = training_time\n", + " metrics[\"embedding_shape\"] = tuple(keep_model.embeddings_v.weight.shape)\n", + " metrics[\"final_loss\"] = trainer.model_state.best_loss if hasattr(trainer, 'model_state') else None\n", + " \n", + " total_time = time.time() - start_time\n", + " metrics[\"total_time_seconds\"] = total_time\n", + " metrics[\"status\"] = \"SUCCESS\"\n", + " \n", + " print(f\"\\n✓ Experiment completed successfully in {total_time:.2f} seconds\")\n", + " \n", + " return metrics" ] }, { - "cell_type": "code", - "execution_count": 5, - "id": "07f90eb1", + "cell_type": "markdown", + "id": "5567958d", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Building concept relationship graph...\n", - "Loading concepts from C:\\Users\\michi\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv\\2b_concept.csv\n", - "Loading concept relationships from C:\\Users\\michi\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv\\2b_concept_relationship.csv\n", - "Loaded 3885 concepts and 7716 relationships\n", - "Filtered to 34 concepts in domains: ['Condition']\n", - "Created set of 34 concept IDs\n", - "Found 0 relationships between concepts\n", - "Graph loaded successfully:\n", - " Nodes (unique concepts): 34\n", - " Edges (relationships): 0\n", - "Applying hierarchy roll-up (dense: each code -> itself + all parents)...\n", - "Roll-up complete:\n", - " Patients processed: 100\n", - " Original codes per patient (avg): 7.99\n", - " Rolled-up codes per patient (avg): 7.99\n", - " Expansion factor: 1.00x\n" - ] - } - ], "source": [ - "n2v = N2V()\n", - "\n", - "# Build the concept relationship graph from conditions\n", - "print(\"Building concept relationship graph...\")\n", - "graph = n2v.create_graph(\n", - " path=r\"C:\\Users\\michi\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv\",\n", - " domain_type=[\"Condition\"]\n", - ")\n", - "\n", - "print(f\"Graph loaded successfully:\")\n", - "print(f\" Nodes (unique concepts): {len(graph.nodes())}\")\n", - "print(f\" Edges (relationships): {len(graph.edges())}\")\n", - "\n", - "print(\"Applying hierarchy roll-up (dense: each code -> itself + all parents)...\")\n", - "\n", - "rolled_up_conditions = {}\n", - "\n", - "# Track statistics\n", - "total_original_codes = 0\n", - "total_rolled_up_codes = 0\n", - "patients_processed = 0\n", - "\n", - "for patient_id, codes in filtered_conditions.items():\n", - " expanded_codes = set()\n", - " \n", - " # For each code, add it and all ancestors\n", - " for code in codes:\n", - " code_and_ancestors = get_code_and_ancestors(graph, code)\n", - " expanded_codes.update(code_and_ancestors)\n", - " \n", - " if expanded_codes:\n", - " rolled_up_conditions[patient_id] = expanded_codes\n", - " total_original_codes += len(codes)\n", - " total_rolled_up_codes += len(expanded_codes)\n", - " patients_processed += 1\n", - "\n", - "# Log statistics\n", - "print(f\"Roll-up complete:\")\n", - "print(f\" Patients processed: {patients_processed}\")\n", - "print(f\" Original codes per patient (avg): {total_original_codes / patients_processed:.2f}\")\n", - "print(f\" Rolled-up codes per patient (avg): {total_rolled_up_codes / patients_processed:.2f}\")\n", - "print(f\" Expansion factor: {total_rolled_up_codes / total_original_codes:.2f}x\")" + "#### Executing Experiment Pipeline for 4 Scenarios\n", + "1. Baseline (Condition only): Minimal feature set\n", + "2. Condition + Drug: Add medication information \n", + "3. Condition + Drug + Measurement: Add lab values\n", + "4. Full (Condition + Drug + Measurement + Procedure): Complete medical context" ] }, { "cell_type": "code", - "execution_count": 6, - "id": "07af9886", + "execution_count": null, + "id": "4af28ea6", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Building code index from rolled-up conditions...\n", - "Building co-occurrence matrix...\n", - "Matrix construction complete:\n", - " Matrix shape: (280, 280)\n", - " Patients processed: 100\n", - " Total co-occurrence pairs: 7978\n", - " Matrix sparsity: 84.11%\n", - " Matrix sum (total co-occurrences): 15956\n", - " Min value: 0.00, Max value: 52.00\n" - ] - } - ], + "outputs": [], "source": [ - "# Construct co-occurrence matrix for rolled-up conditions\n", - "print(\"Building code index from rolled-up conditions...\")\n", - " \n", - "# Collect all unique codes after roll-up\n", - "unique_codes = set()\n", - "for codes in rolled_up_conditions.values():\n", - " unique_codes.update(codes)\n", + "# Set data root path\n", + "data_root = r\"\\Workspace\\mimic-iv-demo-data-in-the-omop-common-data-model-0.9\\1_omop_data_csv\"\n", "\n", - "# Sort for reproducibility\n", - "unique_codes = sorted(unique_codes)\n", + "# Run all ablation experiments\n", + "print(\"STARTING ABLATION STUDY\")\n", + "print(f\"Number of experiments: {len(EXPERIMENTS)}\")\n", + "print()\n", "\n", - "# Create bidirectional mapping\n", - "code_to_index = {code: idx for idx, code in enumerate(unique_codes)}\n", - "index_to_code = {idx: code for code, idx in code_to_index.items()}\n", - "\n", - "\n", - "print(\"Building co-occurrence matrix...\")\n", - "\n", - "num_codes = len(code_to_index)\n", - "\n", - "# Initialize matrix\n", - "cooc_matrix = np.zeros((num_codes, num_codes), dtype=np.float32)\n", - "\n", - "# Track statistics\n", - "total_pairs = 0\n", - "patients_processed = 0\n", - "\n", - "for patient_id, codes in rolled_up_conditions.items():\n", - " # Convert codes to indices\n", - " code_indices = [code_to_index[code] for code in codes]\n", - " \n", - " # Create all unique pairs\n", - " for i in range(len(code_indices)):\n", - " for j in range(i + 1, len(code_indices)):\n", - " idx_i = code_indices[i]\n", - " idx_j = code_indices[j]\n", - " \n", - " # Increment both matrix[i,j] and matrix[j,i] for symmetry\n", - " cooc_matrix[idx_i, idx_j] += 1.0\n", - " cooc_matrix[idx_j, idx_i] += 1.0\n", - " \n", - " total_pairs += 1\n", - " \n", - " patients_processed += 1\n", + "for exp_name, config in EXPERIMENTS.items():\n", + " metrics = run_experiment(exp_name, config, data_root, epochs=50)\n", + " experiment_results[exp_name] = metrics\n", "\n", - "print(f\"Matrix construction complete:\")\n", - "print(f\" Matrix shape: {cooc_matrix.shape}\")\n", - "print(f\" Patients processed: {patients_processed}\")\n", - "print(f\" Total co-occurrence pairs: {total_pairs}\")\n", - "print(f\" Matrix sparsity: {(cooc_matrix == 0).sum() / cooc_matrix.size * 100:.2f}%\")\n", - "print(f\" Matrix sum (total co-occurrences): {cooc_matrix.sum():.0f}\")\n", - "print(f\" Min value: {cooc_matrix.min():.2f}, Max value: {cooc_matrix.max():.2f}\")" + "print(\"\\n\" + \"=\"*80)\n", + "print(\"ALL EXPERIMENTS COMPLETED\")\n", + "print(\"=\"*80)" + ] + }, + { + "cell_type": "markdown", + "id": "6eacbda0", + "metadata": {}, + "source": [ + "#### Summarizing Experiment Results\n", + "**Key Findings**\n", + "Best Performance Across Metrics:\n", + " * Most unique codes: full_domains (1357 codes)\n", + " * Fastest training: condition_only (81.29 seconds)\n", + " * Largest graph: full_domains (2187 nodes)" ] }, { "cell_type": "code", - "execution_count": 7, - "id": "0c079f22", + "execution_count": 17, + "id": "4570ee56", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Creating GloveDataset from co-occurrence matrix...\n", - " GloveDataset created:\n", - " Num words (diagnosis codes): 280\n", - " Dataset size (co-occurrence pairs): 12458\n", - " DataLoader created: batch_size=128, num_batches=98\n", "\n", - "Initializing KeepEmbedding with builder parameters...\n", - "Initializing Node2Vec with embedding_dim=100...\n", - "Graph created with 34 nodes and 0 edges\n", - "Initializing Node2Vec with embedding_dim=100 walk_length=30, num_walks=750\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "bbe126854ba544b49b26085b791254a0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Computing transition probabilities: 0%| | 0/34 [00:00\n", - "Optimizer params: {'lr': 0.01}\n", - "Weight decay: 0.0\n", - "Max grad norm: None\n", - "Val dataloader: None\n", - "Monitor: loss\n", - "Monitor criterion: max\n", - "Epochs: 50\n", - "Patience: 10\n", - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "602358d5ea6f4e679ecda85c2b05e58e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Epoch 0 / 50: 0%| | 0/98 [00:00" ] }, "metadata": {}, "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--- Train epoch-7, step-784 ---\n", - "loss: 0.3647\n", - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d7601acd55ae4021a1b00cf580bb3d62", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Epoch 8 / 50: 0%| | 0/98 [00:00 Date: Wed, 22 Apr 2026 22:25:03 -0400 Subject: [PATCH 10/12] Fix docstrings --- pyhealth/models/keep_embedding.py | 247 ++++++++++++++++-------------- 1 file changed, 129 insertions(+), 118 deletions(-) diff --git a/pyhealth/models/keep_embedding.py b/pyhealth/models/keep_embedding.py index 2f9c8fd98..4d246e449 100644 --- a/pyhealth/models/keep_embedding.py +++ b/pyhealth/models/keep_embedding.py @@ -12,46 +12,39 @@ from pyhealth.datasets import SampleDataset from .base_model import BaseModel -class N2V(): - """Node2Vec embeddings for OMOP concepts. +class N2V: + """Generate Node2Vec embeddings for OMOP concepts. - This class builds a directed knowledge graph from OMOP concept and - concept relationship tables, then trains Node2Vec to generate - ontology-informed embeddings for medical concepts. + Builds a directed knowledge graph from OMOP concept relationship tables + and trains Node2Vec to create graph embeddings. Attributes: - path: Path to the OMOP CSV files. - domain_type: List of OMOP domains used to filter concepts. - embedding_dim: Dimension of the learned embeddings. + embedding_dim: Dimension of learned embeddings. walk_length: Length of each random walk. - num_walks: Number of walks generated per node. - graph: Directed graph constructed from OMOP concepts and relations. + num_walks: Number of walks per node. """ def __init__( - self, - embedding_dim:int=None, - walk_length:int=None, - num_walks:int=None - ): + self, + embedding_dim: int | None = None, + walk_length: int | None = None, + num_walks: int | None = None, + ) -> None: self.embedding_dim = embedding_dim self.walk_length = walk_length self.num_walks = num_walks - # Create graph from concept and their relationships data - def create_graph(self, path, domain_type) -> nx.DiGraph: - """ - Create a directed graph from OMOP concept relationships. - - Loads concepts and their relationships from CSV files, filters by domain_type, - and builds a NetworkX DiGraph where nodes are concept IDs and edges are - concept relationships (maps_to). - + def create_graph( + self, path: str, domain_type: list[str] + ) -> nx.DiGraph: + """Create a Network directed graph from OMOP concept relationships. + + Args: + path: Path to OMOP concept CSV files. + domain_type: List of domain IDs to include. + Returns: - nx.DiGraph: Directed graph with concept_id as nodes and relationships as edges. - - Raises: - FileNotFoundError: If CSV files are not found. - ValueError: If no concepts found for specified domains. + Directed graph with concept IDs as nodes and relationships + as edges. """ # Load concept table concept_path = os.path.join(path, "2b_concept.csv") @@ -110,30 +103,36 @@ def create_graph(self, path, domain_type) -> nx.DiGraph: return graph - def _build_index_mapping(self, node_embeddings): - """ - Build a dictionary to map concept code to the index in node_embeddings. - + def _build_index_mapping( + self, node_embeddings: object + ) -> dict[int, int]: + """Map concept codes to embedding indices. + Args: - node_embeddings: Gensim Word2Vec model word vectors - + node_embeddings: Word2Vec model word vectors. + Returns: - dict: Mapping from concept_id (int) to index in embeddings + Mapping from concept_id to index in embeddings. """ return {int(key): i for i, key in enumerate(node_embeddings.index_to_key)} - def _get_vector_iso(self, code, node_embeddings, index_mapping, mean_vector): - """ - Return concept embedding for the given code or mean vector if not found. - + def _get_vector_iso( + self, + code: int, + node_embeddings: object, + index_mapping: dict, + mean_vector: np.ndarray, + ) -> np.ndarray: + """Get embedding vector for code or mean vector if not found. + Args: - code: Concept ID - node_embeddings: Gensim Word2Vec model word vectors - index_mapping: Dictionary mapping concept_id to index - mean_vector: Mean vector to use as fallback - + code: Concept ID. + node_embeddings: Word2Vec model word vectors. + index_mapping: Concept ID to embedding index mapping. + mean_vector: Fallback vector to use if code not found. + Returns: - np.ndarray: Embedding vector for the concept + Embedding vector for the concept. """ index = index_mapping.get(int(code)) if index is not None: @@ -142,16 +141,18 @@ def _get_vector_iso(self, code, node_embeddings, index_mapping, mean_vector): print(f"Code {code} not found, returning mean vector.") return mean_vector - def generate_embeddings(self, graph): - """ - Generate node embeddings using Node2Vec algorithm. - - Creates a graph from OMOP concepts and applies Node2Vec to generate - embeddings for each concept based on its network structure. - + def generate_embeddings( + self, graph: nx.DiGraph + ) -> tuple[np.ndarray, list]: + """Generate node embeddings using Node2Vec. + + Args: + graph: Directed graph of OMOP concepts and relationships. + Returns: - tuple: (embedding_matrix, node_ids) where embedding_matrix is the numpy array - of embeddings and node_ids is the list of graph node IDs in order. + Tuple of (embedding_matrix, node_ids) where embedding_matrix + is numpy array of embeddings and node_ids is the list of + node IDs in order. """ print(f"Graph created with {len(graph.nodes())} nodes and {len(graph.edges())} edges") @@ -191,43 +192,65 @@ def generate_embeddings(self, graph): return embedding_matrix, keys class KeepEmbedding(BaseModel): - """KEEP Embedding: Fine-tune Node2Vec embeddings using GloVe while penalizing - deviation from original embeddings. - - Balances: - - Co-occurrence structure (GloVe objective) - - Graph structure prior (Node2Vec via regularization) - + """KEEP Embedding Framework + + + Fine-tune Node2Vec embeddings using GloVe with graph regularization. + + Balances co-occurrence structure (GloVe) with graph (Node2Vec) via regularization + to generate medical concept embeddings. + Args: - dataset (SampleDataset): The dataset to train the model. - path (str): Path to OMOP data files for graph construction. - domain_type (list[str]): Domain types to include in graph. - embedding_dim (int): Dimension of embeddings. - walk_length (int): Length of random walks for Node2Vec. - num_walks (int): Number of random walks per node for Node2Vec. - lambda_reg (float): Regularization strength for Node2Vec prior. Default: 1.0. - reg_norm (str or float): Norm type for regularization ('cosine' or numeric p-norm). - Default: None (cosine similarity). - log_scale (bool): Whether to apply log scaling to regularization distance. - Default: False. - code_to_index (dict, optional): Mapping from concept codes to vocabulary indices. - If provided, embeddings are filtered to only include codes in this mapping. - device (str): Device to use ('cuda' or 'cpu'). Default: 'cpu'. + dataset: Dataset to train the model. + graph: Directed graph of concepts and relationships. + embedding_dim: Dimension of embeddings. + walk_length: Length of random walks for Node2Vec. + num_walks: Number of random walks per node. + num_words: Size of vocabulary. + lambda_reg: Regularization strength (default: 1.0). + reg_norm: Norm type ('cosine' or numeric p-norm, default: None). + log_scale: Apply log scaling to regularization distance + (default: False). + code_to_index: Optional mapping from concept codes to indices. + device: Device to use ('cuda' or 'cpu', default: 'cpu'). + + Examples: + >>> from pyhealth.datasets import OMOPDataset + >>> from pyhealth.models import KeepEmbedding + >>> dataset = SampleDataset(num_patients=100, num_visits=10, num_codes=50) + >>> graph = n2v.create_graph() # Build knowledge graph from concept and relationship tables + >>> dataset = OMOPDataset(...) + >>> # Build co-occurrence matrix from dataset + >>> # Load co-occurrence matrix as GloveDatset Dataloader + >>> model = KeepEmbedding( + ... dataset=None, + ... graph=graph, + ... embedding_dim=128, + ... walk_length=10, + ... num_walks=5, + ... num_words=50, + ... lambda_reg=0.5, + ... reg_norm='cosine', + ... log_scale=True, + ... device='cuda' + ... ) + >>> # Use embeddings for with downstream PyHealth models """ - def __init__(self, - dataset: SampleDataset, - graph: nx.Graph, - embedding_dim:int, - walk_length:int, - num_walks:int, - num_words: int, - lambda_reg: float = 1.0, - reg_norm: str | float = None, - log_scale: bool = False, - code_to_index: dict = None, - device: str = "cpu" - ): + def __init__( + self, + dataset: SampleDataset, + graph: nx.Graph, + embedding_dim: int, + walk_length: int, + num_walks: int, + num_words: int, + lambda_reg: float = 1.0, + reg_norm: str | float | None = None, + log_scale: bool = False, + code_to_index: dict | None = None, + device: str = "cpu", + ) -> None: """Initialize KEEP Embedding model.""" super().__init__(dataset=dataset) @@ -301,38 +324,26 @@ def __init__(self, print(f"Regularization norm: {reg_norm}") print(f"Log scaling: {log_scale}") - def forward(self, - i_indices: torch.Tensor = None, - j_indices: torch.Tensor = None, - counts: torch.Tensor = None, - weights: torch.Tensor = None, - **kwargs) -> dict[str, torch.Tensor]: - """Forward pass for KEEP Embedding. - - Computes GloVe loss with optional Node2Vec regularization. For compatibility - with BaseModel.forward(), returns a dictionary with keys: loss, y_prob, - y_true, logit. - - For training GloVe objective, pass: - - i_indices: Token indices (batch_size,) - - j_indices: Context token indices (batch_size,) - - counts: Co-occurrence counts (batch_size,) - - weights: Weights for each co-occurrence pair (batch_size,) - + def forward( + self, + i_indices: torch.Tensor | None = None, + j_indices: torch.Tensor | None = None, + counts: torch.Tensor | None = None, + weights: torch.Tensor | None = None, + **kwargs, + ) -> dict[str, torch.Tensor]: + """Compute GloVe loss with optional Node2Vec regularization. + Args: - i_indices (torch.Tensor, optional): Token indices. - j_indices (torch.Tensor, optional): Context token indices. - counts (torch.Tensor, optional): Co-occurrence counts. - weights (torch.Tensor, optional): Weights for loss terms. + i_indices: Token indices (batch_size,). + j_indices: Context token indices (batch_size,). + counts: Co-occurrence counts (batch_size,). + weights: Weights for loss terms (batch_size,). **kwargs: Additional arguments for compatibility. - + Returns: - dict: Dictionary with keys: - - loss: Total loss (GloVe + regularization if applicable) - - logit: Placeholder tensor (for BaseModel compatibility) - - y_prob: Placeholder tensor (for BaseModel compatibility) - - y_true: Placeholder tensor (for BaseModel compatibility) - - reg_loss: Regularization loss component (if applicable) + Dictionary with keys: loss, logit, y_prob, y_true, + reg_loss. """ # If no GloVe inputs provided, return dummy output From e761b5948876bd97bb0037a386849248e98b4826 Mon Sep 17 00:00:00 2001 From: shreyaaishi Date: Wed, 22 Apr 2026 22:48:05 -0400 Subject: [PATCH 11/12] Update project dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 98f88d47b..d93115e3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "ogb>=1.3.5", "scikit-learn~=1.7.0", "networkx", + "node2vec", "mne~=1.10.0", "urllib3~=2.5.0", "numpy~=2.2.0", From a13ea4e8ab0f53d006fdbe34933e0ef3e3071d0b Mon Sep 17 00:00:00 2001 From: Rohan Vasavada Date: Wed, 22 Apr 2026 21:54:53 -0500 Subject: [PATCH 12/12] fix tests --- tests/core/test_keep_embedding.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/core/test_keep_embedding.py b/tests/core/test_keep_embedding.py index 3e3bc2f37..e579d42f4 100644 --- a/tests/core/test_keep_embedding.py +++ b/tests/core/test_keep_embedding.py @@ -52,6 +52,8 @@ "pyhealth.models.text_embedding", "pyhealth.models.sdoh", "pyhealth.models.unified_embedding", + "pyhealth.models.transformer_deid", + "pyhealth.models.califorest", ): sys.modules.setdefault(_mod, MagicMock()) @@ -69,8 +71,6 @@ class TestN2VHelpers(unittest.TestCase): def setUp(self): """Set up a minimal N2V instance.""" self.n2v = N2V( - path="/fake", - domain_type=["all"], embedding_dim=EMBEDDING_DIM, walk_length=5, num_walks=5, @@ -107,14 +107,15 @@ class TestKeepEmbeddingInit(unittest.TestCase): def setUp(self): """Set up a KeepEmbedding instance with N2V mocked out.""" fake_matrix = np.random.randn(NUM_CONCEPTS, EMBEDDING_DIM).astype(np.float32) - with patch.object(N2V, "generate_embeddings", return_value=fake_matrix): + fake_keys = list(range(NUM_CONCEPTS)) + with patch.object(N2V, "generate_embeddings", return_value=(fake_matrix, fake_keys)): self.model = KeepEmbedding( dataset=None, - path="/fake/path", - domain_type=["all"], + graph=MagicMock(), embedding_dim=EMBEDDING_DIM, walk_length=5, num_walks=5, + num_words=NUM_CONCEPTS, device="cpu", ) @@ -136,14 +137,15 @@ class TestKeepEmbeddingForward(unittest.TestCase): def _make_model(self, lambda_reg=1.0, reg_norm=None, log_scale=False): """Return a KeepEmbedding with N2V mocked out.""" fake_matrix = np.random.randn(NUM_CONCEPTS, EMBEDDING_DIM).astype(np.float32) - with patch.object(N2V, "generate_embeddings", return_value=fake_matrix): + fake_keys = list(range(NUM_CONCEPTS)) + with patch.object(N2V, "generate_embeddings", return_value=(fake_matrix, fake_keys)): return KeepEmbedding( dataset=None, - path="/fake/path", - domain_type=["all"], + graph=MagicMock(), embedding_dim=EMBEDDING_DIM, walk_length=5, num_walks=5, + num_words=NUM_CONCEPTS, lambda_reg=lambda_reg, reg_norm=reg_norm, log_scale=log_scale, @@ -205,14 +207,15 @@ class TestKeepEmbeddingBackward(unittest.TestCase): def _make_model(self, lambda_reg=1.0): """Return a KeepEmbedding with N2V mocked out.""" fake_matrix = np.random.randn(NUM_CONCEPTS, EMBEDDING_DIM).astype(np.float32) - with patch.object(N2V, "generate_embeddings", return_value=fake_matrix): + fake_keys = list(range(NUM_CONCEPTS)) + with patch.object(N2V, "generate_embeddings", return_value=(fake_matrix, fake_keys)): return KeepEmbedding( dataset=None, - path="/fake/path", - domain_type=["all"], + graph=MagicMock(), embedding_dim=EMBEDDING_DIM, walk_length=5, num_walks=5, + num_words=NUM_CONCEPTS, lambda_reg=lambda_reg, device="cpu", )