diff --git a/pyhealth/models/stagenet_mha.py b/pyhealth/models/stagenet_mha.py
new file mode 100644
index 00000000..1609cc02
--- /dev/null
+++ b/pyhealth/models/stagenet_mha.py
@@ -0,0 +1,703 @@
+import warnings
+from typing import Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from pyhealth.datasets import SampleDataset
+from pyhealth.models import BaseModel
+from pyhealth.models.utils import get_last_visit
+
+from .embedding import EmbeddingModel
+
+
+class StageNetLayer(nn.Module):
+    """StageNet layer with multi-head attention.
+
+    Paper: Stagenet: Stage-aware neural networks for health risk prediction. WWW 2020.
+
+    This layer is used in the StageAttentionNet model, but it can also be used
+    as a standalone layer.
+
+    Args:
+        input_dim: dynamic feature size.
+        chunk_size: the chunk size for the StageNet layer. Default is 128.
+        levels: the number of levels for the StageNet layer. levels * chunk_size
+            = hidden_dim in the RNN. A smaller chunk size and more levels can
+            capture more detailed patient status variations. Default is 3.
+        conv_size: the size of the convolutional kernel. Default is 10.
+        dropconnect: the DropConnect rate applied to the gate projections.
+            Default is 0.3.
+        dropout: the dropout rate applied to the output. Default is 0.3.
+        dropres: the dropout rate for the residual connection. Default is 0.3.
+        num_heads: number of heads in the multi-head attention inserted between
+            the SA-LSTM and the stage-adaptive CNN. Default is 8.
+        attn_dropout: dropout rate applied to attention weights. Default is 0.1.
+
+    Examples:
+        >>> from pyhealth.models.stagenet_mha import StageNetLayer
+        >>> input = torch.randn(3, 128, 64)  # [batch size, sequence len, feature_size]
+        >>> layer = StageNetLayer(64)
+        >>> c, _, _ = layer(input)
+        >>> c.shape
+        torch.Size([3, 384])
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        chunk_size: int = 128,
+        conv_size: int = 10,
+        levels: int = 3,
+        dropconnect: float = 0.3,
+        dropout: float = 0.3,
+        dropres: float = 0.3,
+        num_heads: int = 8,
+        attn_dropout: float = 0.1,
+    ):
+        super(StageNetLayer, self).__init__()
+
+        self.dropout = dropout
+        self.dropconnect = dropconnect
+        self.dropres = dropres
+        self.input_dim = input_dim
+        self.hidden_dim = chunk_size * levels
+        self.conv_dim = self.hidden_dim
+        self.conv_size = conv_size
+        self.levels = levels
+        self.chunk_size = chunk_size
+
+        self.kernel = nn.Linear(
+            int(input_dim + 1), int(self.hidden_dim * 4 + levels * 2)
+        )
+        nn.init.xavier_uniform_(self.kernel.weight)
+        nn.init.zeros_(self.kernel.bias)
+        self.recurrent_kernel = nn.Linear(
+            int(self.hidden_dim + 1), int(self.hidden_dim * 4 + levels * 2)
+        )
+        nn.init.orthogonal_(self.recurrent_kernel.weight)
+        nn.init.zeros_(self.recurrent_kernel.bias)
+
+        self.nn_scale = nn.Linear(int(self.hidden_dim), int(self.hidden_dim // 6))
+        self.nn_rescale = nn.Linear(int(self.hidden_dim // 6), int(self.hidden_dim))
+        self.nn_conv = nn.Conv1d(
+            int(self.hidden_dim), int(self.conv_dim), int(conv_size), 1
+        )
+        if self.hidden_dim % num_heads != 0:
+            raise ValueError(
+                f"hidden_dim ({self.hidden_dim}) must be divisible by num_heads ({num_heads})"
+            )
+        self.mha = nn.MultiheadAttention(
+            embed_dim=self.hidden_dim,
+            num_heads=num_heads,
+            dropout=attn_dropout,
+            batch_first=False,
+        )
+        self.attn_norm = nn.LayerNorm(self.hidden_dim)
+
+        if self.dropconnect:
+            self.nn_dropconnect = nn.Dropout(p=dropconnect)
+            self.nn_dropconnect_r = nn.Dropout(p=dropconnect)
+        if self.dropout:
+            self.nn_dropout = nn.Dropout(p=dropout)
+        self.nn_dropres = nn.Dropout(p=dropres)
+
+        # Hooks for interpretability (e.g., DeepLIFT) default to None
+        self._activation_hooks = None
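+
+    # The fused projections above emit, per time step, a vector of size
+    # hidden_dim * 4 + levels * 2, which ``step`` slices as:
+    #     [: levels]             master forget gate (cumax, left-to-right)
+    #     [levels : 2 * levels]  master input gate (cumax, right-to-left)
+    #     [2 * levels :]         reshaped to [levels * 4, chunk_size] for the
+    #                            f/i/o gates and the candidate cell state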
+
+    def set_activation_hooks(self, hooks) -> None:
+        """Registers activation hooks for interpretability methods.
+
+        Args:
+            hooks: Object exposing ``apply(name, tensor, **kwargs)``. When
+                provided, activation functions inside the layer will be
+                routed through ``hooks`` instead of raw torch ops. Passing
+                ``None`` disables the hooks.
+        """
+
+        self._activation_hooks = hooks
+
+    def _apply_activation(self, name: str, tensor: torch.Tensor, **kwargs) -> torch.Tensor:
+        if self._activation_hooks is not None and hasattr(self._activation_hooks, "apply"):
+            return self._activation_hooks.apply(name, tensor, **kwargs)
+        fn = getattr(torch, name)
+        return fn(tensor, **kwargs)
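+
+    # Example hooks object (a sketch; any object with a matching ``apply``
+    # method works):
+    #
+    #     class PassthroughHooks:
+    #         def apply(self, name, tensor, **kwargs):
+    #             # record or transform activations here before returning
+    #             return getattr(torch, name)(tensor, **kwargs)
+    #
+    #     layer.set_activation_hooks(PassthroughHooks())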
+ """ + # rnn will only apply dropout between layers + batch_size, time_step, feature_dim = x.size() + device = x.device + if time == None: + time = torch.ones(batch_size, time_step, device=device) + time = time.reshape(batch_size, time_step) + c_out = torch.zeros(batch_size, self.hidden_dim, device=device) + h_out = torch.zeros(batch_size, self.hidden_dim, device=device) + + hidden_states = [] + distance = [] + for t in range(time_step): + out, c_out, h_out = self.step(x[:, t, :], c_out, h_out, time[:, t], device) + cur_distance = 1 - torch.mean( + out[..., self.hidden_dim : self.hidden_dim + self.levels], -1 + ) + hidden_states.append(out[..., : self.hidden_dim]) + distance.append(cur_distance) + + # shape: [time, batch, hidden_dim] + hidden_seq = torch.stack(hidden_states) + distance = torch.stack(distance) + + key_padding_mask = None + if mask is not None: + key_padding_mask = (mask == 0).to(device=device, dtype=torch.bool) + + attn_output, _ = self.mha( + hidden_seq, hidden_seq, hidden_seq, key_padding_mask=key_padding_mask + ) + attn_output = self.attn_norm(attn_output + hidden_seq) + + tmp_h = torch.zeros( + (self.conv_size, batch_size, self.hidden_dim), device=device + ) + tmp_dis = torch.zeros((self.conv_size, batch_size), device=device) + conv_outputs = [] + for t in range(time_step): + cur_h = attn_output[t] + cur_distance = distance[t] + + tmp_h = torch.cat((tmp_h[1:], cur_h.unsqueeze(0)), 0) + tmp_dis = torch.cat((tmp_dis[1:], cur_distance.unsqueeze(0)), 0) + + # Re-weighted convolution operation + local_dis = tmp_dis.permute(1, 0) + local_dis = torch.cumsum(local_dis, dim=1) + local_dis = self._apply_activation("softmax", local_dis, dim=1) + local_h = tmp_h.permute(1, 2, 0) + local_h = local_h * local_dis.unsqueeze(1) + + # Re-calibrate Progression patterns + local_theme = torch.mean(local_h, dim=-1) + local_theme = self.nn_scale(local_theme) + local_theme = self._apply_activation("relu", local_theme) + local_theme = self.nn_rescale(local_theme) + local_theme = self._apply_activation("sigmoid", local_theme) + + local_h = self.nn_conv(local_h).squeeze(-1) + local_h = local_theme * local_h + conv_outputs.append(local_h) + + origin_h = attn_output.permute(1, 0, 2) + rnn_outputs = torch.stack(conv_outputs).permute(1, 0, 2) + if self.dropres > 0.0: + origin_h = self.nn_dropres(origin_h) + rnn_outputs = rnn_outputs + origin_h + rnn_outputs = rnn_outputs.contiguous().view(-1, rnn_outputs.size(-1)) + if self.dropout > 0.0: + rnn_outputs = self.nn_dropout(rnn_outputs) + + output = rnn_outputs.contiguous().view(batch_size, time_step, self.hidden_dim) + last_output = get_last_visit(output, mask) + + return last_output, output, distance + + +class StageAttentionNet(BaseModel): + """StageNet model. + + Paper: Junyi Gao et al. Stagenet: Stage-aware neural networks for health + risk prediction. WWW 2020. But with Multi-Head Attention (MHA) between + the SA-LSTM and the SA-CNN. + + This model uses the StageNetProcessor which expects inputs in the format: + {"value": [...], "time": [...]} + + The processor handles various input types: + - Code sequences (with/without time intervals) + - Nested code sequences (with/without time intervals) + - Numeric feature vectors (with/without time intervals) + + Time intervals are optional and represent inter-event delays. If not + provided, all events are treated as having uniform time intervals. + + Args: + dataset: the dataset to train the model. It is used to query certain + information such as the set of all tokens. 
+
+
+class StageAttentionNet(BaseModel):
+    """StageNet model with multi-head attention.
+
+    Paper: Junyi Gao et al. Stagenet: Stage-aware neural networks for health
+    risk prediction. WWW 2020. This variant inserts multi-head attention (MHA)
+    between the SA-LSTM and the SA-CNN.
+
+    This model uses the StageNetProcessor which expects inputs in the format:
+    {"value": [...], "time": [...]}
+
+    The processor handles various input types:
+    - Code sequences (with/without time intervals)
+    - Nested code sequences (with/without time intervals)
+    - Numeric feature vectors (with/without time intervals)
+
+    Time intervals are optional and represent inter-event delays. If not
+    provided, all events are treated as having uniform time intervals.
+
+    Args:
+        dataset: the dataset to train the model. It is used to query certain
+            information such as the set of all tokens.
+        embedding_dim: the embedding dimension. Default is 128.
+        chunk_size: the chunk size for the StageNet layer. Default is 128.
+        levels: the number of levels for the StageNet layer.
+            levels * chunk_size = hidden_dim in the RNN. A smaller chunk_size
+            and more levels can capture more detailed patient status
+            variations. Default is 3.
+        **kwargs: other parameters for the StageNet layer.
+
+    Examples:
+        >>> from pyhealth.datasets import SampleDataset
+        >>> samples = [
+        ...     {
+        ...         "patient_id": "patient-0",
+        ...         "visit_id": "visit-0",
+        ...         "codes": {
+        ...             "value": ["505800458", "50580045810", "50580045811"],
+        ...             "time": [0.0, 2.0, 1.3],
+        ...         },
+        ...         "procedures": {
+        ...             "value": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],
+        ...             "time": [0.0, 1.5],
+        ...         },
+        ...         "label": 1,
+        ...     },
+        ...     {
+        ...         "patient_id": "patient-0",
+        ...         "visit_id": "visit-1",
+        ...         "codes": {
+        ...             "value": ["55154191800", "551541928", "55154192800"],
+        ...             "time": [0.0, 2.0, 1.3],
+        ...         },
+        ...         "procedures": {
+        ...             "value": [["A04A", "B035", "C129"]],
+        ...             "time": [0.0],
+        ...         },
+        ...         "label": 0,
+        ...     },
+        ... ]
+        >>>
+        >>> # dataset
+        >>> dataset = SampleDataset(
+        ...     samples=samples,
+        ...     input_schema={
+        ...         "codes": "stagenet",
+        ...         "procedures": "stagenet",
+        ...     },
+        ...     output_schema={"label": "binary"},
+        ...     dataset_name="test"
+        ... )
+        >>>
+        >>> # data loader
+        >>> from pyhealth.datasets import get_dataloader
+        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
+        >>>
+        >>> # model
+        >>> model = StageAttentionNet(dataset=dataset)
+        >>>
+        >>> # data batch
+        >>> data_batch = next(iter(train_loader))
+        >>>
+        >>> # try the model
+        >>> ret = model(**data_batch)
+        >>> print(ret)
+        {
+            'loss': tensor(...),
+            'y_prob': tensor(...),
+            'y_true': tensor(...),
+            'logit': tensor(...)
+        }
+        >>>
+
+    """
+
+    def __init__(
+        self,
+        dataset: SampleDataset,
+        embedding_dim: int = 128,
+        chunk_size: int = 128,
+        levels: int = 3,
+        **kwargs,
+    ):
+        super(StageAttentionNet, self).__init__(
+            dataset=dataset,
+        )
+        self.embedding_dim = embedding_dim
+        self.chunk_size = chunk_size
+        self.levels = levels
+
+        # validate kwargs for the StageNet layer
+        if "input_dim" in kwargs:
+            raise ValueError("input_dim is determined by embedding_dim")
+
+        assert len(self.label_keys) == 1, "Only one label key is supported"
+        self.label_key = self.label_keys[0]
+        self.mode = self.dataset.output_schema[self.label_key]
+
+        # Use EmbeddingModel for unified embedding handling
+        self.embedding_model = EmbeddingModel(dataset, embedding_dim)
+
+        # Create a StageNet layer for each feature
+        self.stagenet = nn.ModuleDict()
+        for feature_key in self.feature_keys:
+            self.stagenet[feature_key] = StageNetLayer(
+                input_dim=embedding_dim,
+                chunk_size=self.chunk_size,
+                levels=self.levels,
+                **kwargs,
+            )
+
+        output_size = self.get_output_size()
+        self.fc = nn.Linear(
+            len(self.feature_keys) * self.chunk_size * self.levels, output_size
+        )
+
+        self._deeplift_hooks = None
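+
+    # Note: the layer's hidden_dim = chunk_size * levels must be divisible by
+    # num_heads (the defaults chunk_size=128, levels=3 give hidden_dim=384,
+    # which pairs with the default num_heads=8); otherwise StageNetLayer
+    # raises a ValueError at construction.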
+ """ + + self._deeplift_hooks = hooks + for layer in self.stagenet.values(): + if hasattr(layer, "set_activation_hooks"): + layer.set_activation_hooks(hooks) + + def clear_deeplift_hooks(self) -> None: + """Remove previously registered interpretability hooks.""" + + self._deeplift_hooks = None + for layer in self.stagenet.values(): + if hasattr(layer, "set_activation_hooks"): + layer.set_activation_hooks(None) + + def forward_from_embedding( + self, + feature_embeddings: Dict[str, torch.Tensor], + time_info: Optional[Dict[str, torch.Tensor]] = None, + **kwargs, + ) -> Dict[str, torch.Tensor]: + """Forward pass starting from feature embeddings. + + This method bypasses the embedding layers but still performs + temporal processing through StageNet layers. This is useful for + interpretability methods like Integrated Gradients that need to + interpolate in embedding space. + + Args: + feature_embeddings: Dictionary mapping feature keys to their + embedded representations. Each tensor should have shape + [batch_size, seq_len, embedding_dim]. + time_info: Optional dictionary mapping feature keys to their + time information tensors of shape [batch_size, seq_len]. + If None, uniform time intervals are assumed. + **kwargs: Additional keyword arguments, must include the label + key for loss computation. + + Returns: + A dictionary with the following keys: + loss: a scalar tensor representing the final loss. + y_prob: a tensor of predicted probabilities. + y_true: a tensor representing the true labels. + logit: the raw logits before activation. + embed: (if embed=True in kwargs) the patient embedding. + """ + patient_emb = [] + distance = [] + + for feature_key in self.feature_keys: + # Get embedded feature + x = feature_embeddings[feature_key].to(self.device) + # x: [batch, seq_len, embedding_dim] or 4D nested + + # Handle nested sequences (4D) by pooling over inner dim + # This matches forward() processing for consistency + if x.dim() == 4: # [batch, seq_len, inner_len, embedding_dim] + # Sum pool over inner dimension + x = x.sum(dim=2) # [batch, seq_len, embedding_dim] + + # Get time information if available + time = None + if time_info is not None and feature_key in time_info: + if time_info[feature_key] is not None: + time = time_info[feature_key].to(self.device) + # Ensure time is 2D [batch, seq_len] + if time.dim() == 1: + time = time.unsqueeze(0) + + # Create mask from embedded values + mask = (x.sum(dim=-1) != 0).int() # [batch, seq_len] + + # Pass through StageNet layer with embedded features + last_output, _, cur_dis = self.stagenet[feature_key]( + x, time=time, mask=mask + ) + + patient_emb.append(last_output) + distance.append(cur_dis) + + # Concatenate all feature embeddings + patient_emb = torch.cat(patient_emb, dim=1) + + # Register hook if needed for gradient tracking + if patient_emb.requires_grad: + patient_emb.register_hook(lambda grad: grad) + + # Pass through final classification layer + logits = self.fc(patient_emb) + + # Obtain y_true, loss, y_prob + y_true = kwargs[self.label_key].to(self.device) + loss = self.get_loss_function()(logits, y_true) + + y_prob = self.prepare_y_prob(logits) + results = { + "loss": loss, + "y_prob": y_prob, + "y_true": y_true, + "logit": logits, + } + + # Optionally return embeddings + if kwargs.get("embed", False): + results["embed"] = patient_emb + + return results + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward propagation. + + The label `kwargs[self.label_key]` is a list of labels for each + patient. 
+
+    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
+        """Forward propagation.
+
+        The label `kwargs[self.label_key]` is a list of labels for each
+        patient.
+
+        Args:
+            **kwargs: keyword arguments for the model. The keys must contain
+                all the feature keys and the label key. Feature keys should
+                contain tuples of (time, values) from temporal processors.
+
+        Returns:
+            A dictionary with the following keys:
+                loss: a scalar tensor representing the final loss.
+                distance: a list of stage-variation tensors, one per feature.
+                y_prob: a tensor of predicted probabilities.
+                y_true: a tensor representing the true labels.
+                logit: the raw logits before activation.
+        """
+        patient_emb = []
+        distance = []
+
+        for feature_key in self.feature_keys:
+            # Extract (time, values) tuple
+            feature = kwargs[feature_key]
+
+            # Get value and time tensors from the tuple
+            if isinstance(feature, tuple) and len(feature) == 2:
+                time, x = feature  # Unpack (time, values)
+                # x: [batch, seq_len] or [batch, seq_len, dim]
+                # time: [batch, seq_len] or None
+
+                # Warn if time information is missing
+                if time is None:
+                    warnings.warn(
+                        f"Feature '{feature_key}' does not have time "
+                        f"intervals. StageNet's temporal modeling "
+                        f"capabilities will be limited. Consider using "
+                        f"the StageNet format with time intervals for "
+                        f"better performance.",
+                        UserWarning,
+                    )
+            else:
+                # Fallback for backward compatibility
+                warnings.warn(
+                    f"Feature '{feature_key}' is not a temporal tuple. "
+                    f"Using fallback mode without time intervals. "
+                    f"The model may not learn temporal patterns properly. "
+                    f"Please use 'stagenet' or 'stagenet_tensor' "
+                    f"processors in your input schema.",
+                    UserWarning,
+                )
+                x = feature
+                time = None
+
+            # Embed the values using EmbeddingModel
+            # (passed as a dict, as EmbeddingModel expects)
+            embedded = self.embedding_model({feature_key: x})
+            x = embedded[feature_key]  # [batch, seq_len, embedding_dim]
+            # Handle nested sequences (2D codes -> pool over the inner dim)
+            if x.dim() == 4:  # [batch, seq_len, inner_len, embedding_dim]
+                # Sum pool over inner dimension
+                x = x.sum(dim=2)  # [batch, seq_len, embedding_dim]
+
+            # Create mask from embedded values
+            mask = (x.sum(dim=-1) != 0).int()  # [batch, seq_len]
+
+            # Move time to the correct device if present
+            if time is not None:
+                time = time.to(self.device)
+                # Ensure time is 2D [batch, seq_len]
+                if time.dim() == 1:
+                    time = time.unsqueeze(0)
+
+            # Pass through StageNet layer
+            last_output, _, cur_dis = self.stagenet[feature_key](
+                x, time=time, mask=mask
+            )
+
+            patient_emb.append(last_output)
+            distance.append(cur_dis)
+
+        patient_emb = torch.cat(patient_emb, dim=1)
+        # (patient, label_size)
+        logits = self.fc(patient_emb)
+
+        # obtain y_true, loss, y_prob
+        y_true = kwargs[self.label_key].to(self.device)
+        loss = self.get_loss_function()(logits, y_true)
+
+        y_prob = self.prepare_y_prob(logits)
+        results = {
+            "loss": loss,
+            "distance": distance,
+            "y_prob": y_prob,
+            "y_true": y_true,
+            "logit": logits,
+        }
+        if kwargs.get("embed", False):
+            results["embed"] = patient_emb
+        return results
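+
+
+# A hedged end-to-end training sketch (assumes the standard pyhealth Trainer
+# API; adjust names to your setup):
+#
+#     from pyhealth.trainer import Trainer
+#     trainer = Trainer(model=model)
+#     trainer.train(train_dataloader=train_loader, epochs=5)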
+
+
+if __name__ == "__main__":
+    from pyhealth.datasets import SampleDataset
+
+    samples = [
+        {
+            "patient_id": "patient-0",
+            "visit_id": "visit-0",
+            "codes": (
+                [0.0, 2.0, 1.3],
+                ["505800458", "50580045810", "50580045811"],
+            ),
+            "procedures": (
+                [0.0, 1.5],
+                [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],
+            ),
+            "label": 1,
+        },
+        {
+            "patient_id": "patient-0",
+            "visit_id": "visit-1",
+            "codes": (
+                [0.0, 2.0, 1.3, 1.0, 2.0],
+                [
+                    "55154191800",
+                    "551541928",
+                    "55154192800",
+                    "705182798",
+                    "70518279800",
+                ],
+            ),
+            "procedures": (
+                [0.0],
+                [["A04A", "B035", "C129"]],
+            ),
+            "label": 0,
+        },
+    ]
+
+    # dataset
+    dataset = SampleDataset(
+        samples=samples,
+        input_schema={
+            "codes": "stagenet",
+            "procedures": "stagenet",
+        },
+        output_schema={"label": "binary"},
+        dataset_name="test",
+    )
+
+    # data loader
+    from pyhealth.datasets import get_dataloader
+
+    train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
+
+    # model
+    model = StageAttentionNet(dataset=dataset)
+
+    # data batch
+    data_batch = next(iter(train_loader))
+
+    # try the model
+    ret = model(**data_batch)
+    print(ret)
+
+    # try loss backward
+    ret["loss"].backward()
diff --git a/tests/core/test_stagenet_mha.py b/tests/core/test_stagenet_mha.py
new file mode 100644
index 00000000..a503690b
--- /dev/null
+++ b/tests/core/test_stagenet_mha.py
@@ -0,0 +1,117 @@
+import unittest
+import torch
+
+from pyhealth.datasets import create_sample_dataset, get_dataloader
+from pyhealth.models.stagenet_mha import StageAttentionNet as StageNetMHA
+
+
+class TestStageNetMHA(unittest.TestCase):
+    """Tests for the StageNet variant with MHA inserted after the SA-LSTM."""
+
+    def setUp(self):
+        # Mixed input types to exercise masking and time handling
+        self.samples = [
+            {
+                "patient_id": "patient-0",
+                "visit_id": "visit-0",
+                "codes": ([0.0, 1.0, 2.0], ["A1", "A2", "A3"]),
+                "procedures": (
+                    [0.0, 1.5],
+                    [["P1", "P2", "P3"], ["P4", "P5"]],
+                ),
+                "label": 1,
+            },
+            {
+                "patient_id": "patient-1",
+                "visit_id": "visit-1",
+                "codes": ([0.0, 2.0], ["B1", "B2"]),
+                "procedures": ([0.0], [["P6"]]),
+                "label": 0,
+            },
+        ]
+
+        self.input_schema = {"codes": "stagenet", "procedures": "stagenet"}
+        self.output_schema = {"label": "binary"}
+
+        self.dataset = create_sample_dataset(
+            samples=self.samples,
+            input_schema=self.input_schema,
+            output_schema=self.output_schema,
+            dataset_name="test_stagenet_mha",
+        )
+
+        # hidden_dim = chunk_size * levels = 2 * 3 = 6, so pick num_heads=3
+        self.model = StageNetMHA(
+            dataset=self.dataset, chunk_size=2, levels=3, num_heads=3
+        )
+
+    def test_forward_pass(self):
+        """Forward pass returns expected keys and shapes."""
+        loader = get_dataloader(self.dataset, batch_size=2, shuffle=False)
+        batch = next(iter(loader))
+
+        with torch.no_grad():
+            ret = self.model(**batch)
+
+        self.assertIn("loss", ret)
+        self.assertIn("y_prob", ret)
+        self.assertIn("y_true", ret)
+        self.assertIn("logit", ret)
+        self.assertEqual(ret["logit"].shape[0], 2)
+        self.assertEqual(ret["y_prob"].shape[0], 2)
+        self.assertEqual(ret["loss"].dim(), 0)
+
+    def test_backward_pass(self):
+        """Backward pass produces gradients through the MHA variant."""
+        loader = get_dataloader(self.dataset, batch_size=2, shuffle=False)
+        batch = next(iter(loader))
+
+        ret = self.model(**batch)
+        ret["loss"].backward()
+
+        has_grad = any(
+            p.requires_grad and p.grad is not None for p in self.model.parameters()
+        )
+        self.assertTrue(has_grad, "No gradients found after backward pass")
+
+    def test_time_none_support(self):
+        """MHA variant works when time intervals are absent (mask-driven)."""
+        samples_no_time = [
+            {
+                "patient_id": "p0",
+                "visit_id": "v0",
+                "codes": (None, ["X1", "X2"]),
+                "label": 1,
+            },
+            {
+                "patient_id": "p1",
+                "visit_id": "v1",
+                "codes": (None, ["Y1"]),
+                "label": 0,
+            },
+        ]
+        dataset_no_time = create_sample_dataset(
+            samples=samples_no_time,
+            input_schema={"codes": "stagenet"},
+            output_schema=self.output_schema,
+            dataset_name="test_stagenet_mha_no_time",
+        )
+        model_no_time = StageNetMHA(
+            dataset=dataset_no_time, chunk_size=2, levels=2, num_heads=2
+        )
+
+        loader = get_dataloader(dataset_no_time, batch_size=2, shuffle=False)
+        batch = next(iter(loader))
+        time, _ = batch["codes"]
+        self.assertIsNone(time)
+
+        with torch.no_grad():
+            ret = model_no_time(**batch)
+
+        self.assertIn("loss", ret)
+        self.assertIn("y_prob", ret)
+        self.assertEqual(ret["y_prob"].shape[0], 2)
+
+
+if __name__ == "__main__":
+    unittest.main()