@@ -519,9 +519,11 @@ def pop_and_gather_aux_hiddens(self):
519519 self ._aux_hidden_states .clear ()
520520
521521 # Gather aux hidden states on the draft model device
522- aux_h_list = [h .to (self .eagle_module .fc .weight .device ) for h in aux_h_list ]
522+ aux_hiddens = torch .cat (
523+ [h .to (self .eagle_module .fc .weight .device ) for h in aux_h_list ], dim = - 1
524+ )
523525
524- return aux_h_list
526+ return aux_hiddens
525527
526528 def _get_eagle_device (self ):
527529 """Return the device where we should place eagle module."""
@@ -789,7 +791,7 @@ def _base_model_forward(
789791
790792 return EagleBaseModelOutput (
791793 input_embeds = base_input_embeds ,
792- aux_hiddens = torch . cat ( self .pop_and_gather_aux_hiddens (), dim = - 1 ),
794+ aux_hiddens = self .pop_and_gather_aux_hiddens (),
793795 out_hiddens = base_model_hidden_states ,
794796 logits = base_model_logits ,
795797 loss = base_model_loss ,
@@ -936,20 +938,10 @@ def forward(
936938 classification_loss , acc = self ._eagle_loss (
937939 # base model predict +1 tok, while eagle predict +2
938940 # so we shift base model outputs compared to eagle outputs
939- base_outputs .logits [:, 1 + i :],
940- eagle_logit [:, : - (1 + i )],
941941 # additionally, we mask the first n tok of eagle outputs at nth TTT step
942- torch .cat (
943- (
944- torch .zeros (
945- b , ttt_step , dtype = loss_mask .dtype , device = loss_mask .device
946- ),
947- loss_mask [:, 1 + ttt_step :]
948- if i == 0
949- else loss_mask [:, 1 + ttt_step : - i ],
950- ),
951- dim = 1 ,
952- ),
942+ base_outputs .logits [:, 1 + i + ttt_step :],
943+ eagle_logit [:, ttt_step : - (1 + i )],
944+ loss_mask [:, 1 + ttt_step :] if i == 0 else loss_mask [:, 1 + ttt_step : - i ],
953945 )
954946 # Apply loss decay factor to focus on early steps
955947 classification_loss *= self .eagle_loss_decay_factor ** (ttt_step + i )
@@ -1034,11 +1026,7 @@ def pseudo_speculative_generate(
10341026 # EAGLE-3
10351027 # Only the first iteration input_hidden_states are from aux_hidden_state layers
10361028 # Gather _aux_hidden_states from all devices before concatenation
1037- gathered_aux_hidden_states = self .pop_and_gather_aux_hiddens ()
1038- eagle_input_hidden_states = self .eagle_module .fc (
1039- torch .cat (gathered_aux_hidden_states , dim = - 1 )
1040- )
1041-
1029+ eagle_input_hidden_states = self .eagle_module .fc (self .pop_and_gather_aux_hiddens ())
10421030 else :
10431031 eagle_input_hidden_states = base_model_hidden_states
10441032