Skip to content

Commit 3c8bbac

Browse files
committed
debug
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent d18eabe commit 3c8bbac

1 file changed

Lines changed: 9 additions & 21 deletions

File tree

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -519,9 +519,11 @@ def pop_and_gather_aux_hiddens(self):
519519
self._aux_hidden_states.clear()
520520

521521
# Gather aux hidden states on the draft model device
522-
aux_h_list = [h.to(self.eagle_module.fc.weight.device) for h in aux_h_list]
522+
aux_hiddens = torch.cat(
523+
[h.to(self.eagle_module.fc.weight.device) for h in aux_h_list], dim=-1
524+
)
523525

524-
return aux_h_list
526+
return aux_hiddens
525527

526528
def _get_eagle_device(self):
527529
"""Return the device where we should place eagle module."""
@@ -789,7 +791,7 @@ def _base_model_forward(
789791

790792
return EagleBaseModelOutput(
791793
input_embeds=base_input_embeds,
792-
aux_hiddens=torch.cat(self.pop_and_gather_aux_hiddens(), dim=-1),
794+
aux_hiddens=self.pop_and_gather_aux_hiddens(),
793795
out_hiddens=base_model_hidden_states,
794796
logits=base_model_logits,
795797
loss=base_model_loss,
@@ -936,20 +938,10 @@ def forward(
936938
classification_loss, acc = self._eagle_loss(
937939
# base model predict +1 tok, while eagle predict +2
938940
# so we shift base model outputs compared to eagle outputs
939-
base_outputs.logits[:, 1 + i :],
940-
eagle_logit[:, : -(1 + i)],
941941
# additionally, we mask the first n tok of eagle outputs at nth TTT step
942-
torch.cat(
943-
(
944-
torch.zeros(
945-
b, ttt_step, dtype=loss_mask.dtype, device=loss_mask.device
946-
),
947-
loss_mask[:, 1 + ttt_step :]
948-
if i == 0
949-
else loss_mask[:, 1 + ttt_step : -i],
950-
),
951-
dim=1,
952-
),
942+
base_outputs.logits[:, 1 + i + ttt_step :],
943+
eagle_logit[:, ttt_step : -(1 + i)],
944+
loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i],
953945
)
954946
# Apply loss decay factor to focus on early steps
955947
classification_loss *= self.eagle_loss_decay_factor ** (ttt_step + i)
@@ -1034,11 +1026,7 @@ def pseudo_speculative_generate(
10341026
# EAGLE-3
10351027
# Only the first iteration input_hidden_states are from aux_hidden_state layers
10361028
# Gather _aux_hidden_states from all devices before concatenation
1037-
gathered_aux_hidden_states = self.pop_and_gather_aux_hiddens()
1038-
eagle_input_hidden_states = self.eagle_module.fc(
1039-
torch.cat(gathered_aux_hidden_states, dim=-1)
1040-
)
1041-
1029+
eagle_input_hidden_states = self.eagle_module.fc(self.pop_and_gather_aux_hiddens())
10421030
else:
10431031
eagle_input_hidden_states = base_model_hidden_states
10441032

0 commit comments

Comments (0)