From f54a144df3a3b3477515bc75f2fbb85f5c1185c3 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Sat, 2 May 2026 02:57:58 +0430 Subject: [PATCH] fix: gemma-4 inference with Pipeline Parallelism --- aphrodite/model_executor/models/gemma4.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/aphrodite/model_executor/models/gemma4.py b/aphrodite/model_executor/models/gemma4.py index 95e0dec39e..bee98e6fdf 100644 --- a/aphrodite/model_executor/models/gemma4.py +++ b/aphrodite/model_executor/models/gemma4.py @@ -1218,8 +1218,8 @@ def forward( else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - per_layer_inputs = intermediate_tensors.get("per_layer_inputs") + residual = intermediate_tensors.tensors.get("residual") + per_layer_inputs = intermediate_tensors.tensors.get("per_layer_inputs") aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual) for layer_idx, layer in enumerate(islice(self.layers, self.start_layer, self.end_layer)): @@ -1238,13 +1238,12 @@ def forward( ) self._maybe_add_hidden_state(aux_hidden_states, layer_idx + 1, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors( - { - "hidden_states": hidden_states, - "residual": residual, - "per_layer_inputs": per_layer_inputs, - } - ) + tensors = {"hidden_states": hidden_states} + if residual is not None: + tensors["residual"] = residual + if per_layer_inputs is not None: + tensors["per_layer_inputs"] = per_layer_inputs + return IntermediateTensors(tensors) # Gemma4 incorporates residual into hidden_states directly # Apply norm without residual fusion when possible. if residual is None: