Skip to content

[OpenVINO] Support Gemma 4#1675

Closed
rkazants wants to merge 2 commits intohuggingface:transformers-v5from
rkazants:support_gemma_4
Closed

[OpenVINO] Support Gemma 4#1675
rkazants wants to merge 2 commits intohuggingface:transformers-v5from
rkazants:support_gemma_4

Conversation

@rkazants
Copy link
Copy Markdown
Collaborator

@rkazants rkazants commented Apr 2, 2026

What does this PR do?

Fixes 182357

Installation instructions:

pip install git+https://github.com/rkazants/optimum-intel.git@support_gemma_4
pip install --pre -U openvino openvino-tokenizers nncf --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
pip install transformers==5.5.0
pip install requests torchvision requests

Exporting cmd-line:

optimum-cli export openvino -m google/gemma-4-E2B-it ov_gemma4_E2Bit --task=image-text-to-text

Inference script:

from transformers import AutoProcessor
import torch
from transformers import AutoProcessor, Gemma3nForConditionalGeneration, Gemma4ForConditionalGeneration
from optimum.intel.openvino import OVModelForVisualCausalLM

#model_id = "google/gemma-4-E2B-it"
model_id = "ov_gemma4_E2Bit"
model = OVModelForVisualCausalLM.from_pretrained(model_id)

processor = AutoProcessor.from_pretrained(model_id, padding_side="left")

url = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
messages = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are a helpful assistant."}
        ]
    },
    {
        "role": "user", "content": [
            {"type": "image", "url": url},
            {"type": "text", "text": "What is shown in this image?"},
        ]
    },
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
)

output = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(output[0, inputs.input_ids.shape[1]: ], skip_special_tokens=True))

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

rkazants added 2 commits April 2, 2026 21:50
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
@sund00bie
Copy link
Copy Markdown

@rkazants thanks for all your work on this (and other models like qwen3.5)

Not all heros wear capes

@SearchSavior
Copy link
Copy Markdown

They left us out of the release...

But @rkazants did not forget!

@savvadesogle
Copy link
Copy Markdown

Omg, thank you!!!

@Schuwi
Copy link
Copy Markdown

Schuwi commented Apr 12, 2026

Are google/gemma-4-31B-it and google/gemma-4-26B-A4B-it intentionally unsupported by this PR or is that a bug?
If the former then PR title/description is a bit misleading imo, if only E2B (and E4B?) variant support is implemented.


My test system was running a docker container with image built from this Dockerfile:

FROM python:3.11-slim

ENV DEBIAN_FRONTEND=noninteractive \
    PIP_NO_CACHE_DIR=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

RUN apt-get update && apt-get install -y \
    git curl build-essential \
    && rm -rf /var/lib/apt/lists/*

RUN python -m pip install --upgrade pip setuptools wheel

# Gemma 4 support path from Optimum-Intel PR #1675
RUN pip install git+https://github.com/rkazants/optimum-intel.git@support_gemma_4 && \
    pip install --pre -U openvino openvino-tokenizers nncf \
      --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly && \
    pip install transformers==5.5.0 && \
    pip install requests torchvision huggingface_hub

WORKDIR /workspace
CMD ["/bin/bash"]

and I ran into these errors during export:

Gemma 4 31B

root@93f3f19ebc7e:/workspace# optimum-cli export openvino -m google/gemma-4-31B-it ov_gemma4_31b_it-int4 --task=image-text-to-text
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
`torch_dtype` is deprecated! Use `dtype` instead!
Fetching 2 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 33156.55it/s]
Download complete: : 0.00B [00:00, ?B/s]                                                                                                                                                                          | 0/2 [00:00<?, ?it/s]
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1188/1188 [00:00<00:00, 5663.98it/s]
`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.
/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py:834: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if num_rotated_channels_per_dim <= 0:
/usr/local/lib/python3.11/site-packages/transformers/integrations/sdpa_attention.py:77: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py:615: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if output_length > hidden_states.shape[1]:
/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py:623: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if hidden_states.shape[1] != output_length:
/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py:590: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  k = int((input_seq_len // length) ** 0.5)
/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py:592: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if k_squared * length != input_seq_len:
/usr/local/lib/python3.11/site-packages/transformers/cache_utils.py:131: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not self.is_initialized or self.keys.numel() == 0:
/usr/local/lib/python3.11/site-packages/transformers/masking_utils.py:192: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:
/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/model_patcher.py:281: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.tensor(0.0, device=mask.device, dtype=dtype),
/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/model_patcher.py:282: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.tensor(torch.finfo(torch.float16).min, device=mask.device, dtype=dtype),
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/site-packages/openvino/frontend/pytorch/ts_decoder.py", line 72, in __init__
    pt_module = self._get_scripted_model(
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/openvino/frontend/pytorch/ts_decoder.py", line 178, in _get_scripted_model
    scripted = torch.jit.trace(
               ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/jit/_trace.py", line 1022, in trace
    traced_func = _trace_impl(
                  ^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/jit/_trace.py", line 707, in _trace_impl
    return trace_module(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/jit/_trace.py", line 1216, in trace_module
    module._c._create_method_from_trace(
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1769, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/model_patcher.py", line 4963, in gemma4_lm_forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1769, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/model_patcher.py", line 4907, in gemma4_language_model_forward
    outputs = self.model.language_model(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1769, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/transformers/utils/generic.py", line 952, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/transformers/utils/output_capturing.py", line 248, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py", line 1605, in forward
    hidden_states = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/transformers/modeling_layers.py", line 93, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1769, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py", line 1361, in forward
    hidden_states, _ = self.self_attn(
                       ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1769, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/model_patcher.py", line 5082, in gemma4_text_attention_forward
    key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/transformers/cache_utils.py", line 937, in update
    keys, values = self.layers[layer_idx].update(key_states, value_states, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/onnx/model_patcher.py", line 531, in patched_dynamic_layer_update
    self.keys = torch.cat([self.keys, key_states], dim=-2)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 16 but got size 4 for tensor number 1 in the list.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/bin/optimum-cli", line 6, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.11/site-packages/optimum/commands/optimum_cli.py", line 219, in main
    service.run()
  File "/usr/local/lib/python3.11/site-packages/optimum/commands/export/openvino.py", line 468, in run
    main_export(
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/__main__.py", line 534, in main_export
    submodel_paths = export_from_model(
                     ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/convert.py", line 773, in export_from_model
    export_models(
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/convert.py", line 536, in export_models
    export(
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/convert.py", line 219, in export
    return export_pytorch(
           ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/convert.py", line 442, in export_pytorch
    ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/openvino/frontend/pytorch/ts_decoder.py", line 84, in __init__
    raise RuntimeError(
RuntimeError: Couldn't get TorchScript module by tracing.
Exception:
Sizes of tensors must match except in dimension 2. Expected size 16 but got size 4 for tensor number 1 in the list.
Please check correctness of provided 'example_input'. Sometimes models can be converted in scripted mode, please try running conversion without 'example_input'.
 You can also provide TorchScript module that you obtained yourself, please refer to PyTorch documentation: https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html.

Gemma 4 26B A4B

root@93f3f19ebc7e:/workspace# optimum-cli export openvino -m google/gemma-4-26B-A4B-it ov_gemma4_26b_A4Bit-8bit --task=image-text-to-text
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
`torch_dtype` is deprecated! Use `dtype` instead!
Fetching 2 files: 100%|████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 16448.25it/s]
Download complete: : 0.00B [00:00, ?B/s]                                                      | 0/2 [00:00<?, ?it/s]
Loading weights: 100%|████████████████████████████████████████████████████████| 1013/1013 [00:00<00:00, 6510.38it/s]
`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.
/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py:834: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if num_rotated_channels_per_dim <= 0:
/usr/local/lib/python3.11/site-packages/transformers/integrations/sdpa_attention.py:77: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py:615: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if output_length > hidden_states.shape[1]:
/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py:623: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if hidden_states.shape[1] != output_length:
/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py:590: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  k = int((input_seq_len // length) ** 0.5)
/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py:592: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if k_squared * length != input_seq_len:
/usr/local/lib/python3.11/site-packages/transformers/cache_utils.py:131: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not self.is_initialized or self.keys.numel() == 0:
/usr/local/lib/python3.11/site-packages/transformers/masking_utils.py:192: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:
/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/model_patcher.py:281: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.tensor(0.0, device=mask.device, dtype=dtype),
/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/model_patcher.py:282: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.tensor(torch.finfo(torch.float16).min, device=mask.device, dtype=dtype),
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/site-packages/openvino/frontend/pytorch/ts_decoder.py", line 72, in __init__
    pt_module = self._get_scripted_model(
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/openvino/frontend/pytorch/ts_decoder.py", line 178, in _get_scripted_model
    scripted = torch.jit.trace(
               ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/jit/_trace.py", line 1022, in trace
    traced_func = _trace_impl(
                  ^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/jit/_trace.py", line 707, in _trace_impl
    return trace_module(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/jit/_trace.py", line 1216, in trace_module
    module._c._create_method_from_trace(
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1769, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/model_patcher.py", line 4963, in gemma4_lm_forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1769, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/model_patcher.py", line 4907, in gemma4_language_model_forward
    outputs = self.model.language_model(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1769, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/transformers/utils/generic.py", line 952, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/transformers/utils/output_capturing.py", line 248, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py", line 1605, in forward
    hidden_states = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/transformers/modeling_layers.py", line 93, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1769, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py", line 1361, in forward
    hidden_states, _ = self.self_attn(
                       ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1769, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/model_patcher.py", line 5082, in gemma4_text_attention_forward
    key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/transformers/cache_utils.py", line 937, in update
    keys, values = self.layers[layer_idx].update(key_states, value_states, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/onnx/model_patcher.py", line 531, in patched_dynamic_layer_update
    self.keys = torch.cat([self.keys, key_states], dim=-2)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 8 but got size 2 for tensor number 1 in the list.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/bin/optimum-cli", line 6, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.11/site-packages/optimum/commands/optimum_cli.py", line 219, in main
    service.run()
  File "/usr/local/lib/python3.11/site-packages/optimum/commands/export/openvino.py", line 468, in run
    main_export(
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/__main__.py", line 534, in main_export
    submodel_paths = export_from_model(
                     ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/convert.py", line 773, in export_from_model
    export_models(
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/convert.py", line 536, in export_models
    export(
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/convert.py", line 219, in export
    return export_pytorch(
           ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/convert.py", line 442, in export_pytorch
    ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/openvino/frontend/pytorch/ts_decoder.py", line 84, in __init__
    raise RuntimeError(
RuntimeError: Couldn't get TorchScript module by tracing.
Exception:
Sizes of tensors must match except in dimension 2. Expected size 8 but got size 2 for tensor number 1 in the list.
Please check correctness of provided 'example_input'. Sometimes models can be converted in scripted mode, please try running conversion without 'example_input'.
 You can also provide TorchScript module that you obtained yourself, please refer to PyTorch documentation: https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html.

For comparison, the E2B export was successful:

Gemma 4 E2B

root@93f3f19ebc7e:/workspace# optimum-cli export openvino -m google/gemma-4-E2B-it ov_gemma4_E2Bit --task=image-text-to-text
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
`torch_dtype` is deprecated! Use `dtype` instead!
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2011/2011 [00:00<00:00, 7010.99it/s]
`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.
/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py:834: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if num_rotated_channels_per_dim <= 0:
/usr/local/lib/python3.11/site-packages/transformers/integrations/sdpa_attention.py:77: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py:615: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if output_length > hidden_states.shape[1]:
/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py:623: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if hidden_states.shape[1] != output_length:
/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py:590: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  k = int((input_seq_len // length) ** 0.5)
/usr/local/lib/python3.11/site-packages/transformers/models/gemma4/modeling_gemma4.py:592: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if k_squared * length != input_seq_len:
/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/model_patcher.py:4861: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if per_layer_projection.shape != per_layer_inputs.shape:
/usr/local/lib/python3.11/site-packages/transformers/cache_utils.py:131: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not self.is_initialized or self.keys.numel() == 0:
/usr/local/lib/python3.11/site-packages/transformers/masking_utils.py:192: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:
/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/model_patcher.py:281: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.tensor(0.0, device=mask.device, dtype=dtype),
/usr/local/lib/python3.11/site-packages/optimum/exporters/openvino/model_patcher.py:282: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.tensor(torch.finfo(torch.float16).min, device=mask.device, dtype=dtype),
INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│ Weight compression mode   │ % all parameters (layers)   │ % ratio-defining parameters (layers)   │
┝━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│ int8_asym, per-channel    │ 100% (277 / 277)            │ 100% (277 / 277)                       │
┕━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙
Applying Weight Compression ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% • 0:00:11 • 0:00:00
INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│ Weight compression mode   │ % all parameters (layers)   │ % ratio-defining parameters (layers)   │
┝━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│ int8_asym, per-channel    │ 100% (1 / 1)                │ 100% (1 / 1)                           │
┕━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙
Applying Weight Compression ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% • 0:00:03 • 0:00:00


optimum-cli env:

- `optimum` version: 2.1.0.dev0
- `transformers` version: 5.5.0
- Platform: Linux-6.18.16-18.222.amzn2023.x86_64-x86_64-with-glibc2.41
- Python version: 3.11.15
- Huggingface_hub version: 1.10.1
- PyTorch version (GPU?): 2.11.0+cu130 (cuda available: False)

@OpenArcBob
Copy link
Copy Markdown

OpenArcBob commented Apr 13, 2026

Are google/gemma-4-31B-it and google/gemma-4-26B-A4B-it intentionally unsupported by this PR or is that a bug? If the former then PR title/description is a bit misleading imo, if only E2B (and E4B?) variant support is implemented.
@Schuwi

I was able to export google/gemma-4-26B-A4B-it using Gemma 4 OpenVINO Notebook: https://github.com/openvinotoolkit/openvino_notebooks/blob/latest/notebooks/gemma4/gemma4.ipynb

It uses different PR to optimum-intel from @aleksandr-mokrov
https://github.com/aleksandr-mokrov/optimum-intel.git@gemma4-moe-fixes
that was submitted to @rkazants in PR: rkazants#5

You can test my export of google/gemma-4-26B-A4B-it on HuggingFace:
https://huggingface.co/OpenArcBob/gemma-4-26B-A4B-it-int4-OpenArc
https://huggingface.co/OpenArcBob/gemma-4-26B-A4B-it-int8-OpenArc

OpenArc @SearchSavior https://github.com/SearchSavior/OpenArc

@echarlaix echarlaix deleted the branch huggingface:transformers-v5 April 15, 2026 07:21
@echarlaix echarlaix closed this Apr 15, 2026
github-merge-queue bot pushed a commit to openvinotoolkit/openvino that referenced this pull request Apr 15, 2026
## Summary

Models like Gemma3n and Gemma4 share one KV-cache across multiple SDPA
blocks (per-layer KV sharing). `StatefulSDPAFusion` partially fuses such
graphs and leaves other SDPA blocks unfused. The resulting inconsistent
model crashes at inference with `null input states` in
`ScaledDotProductAttentionWithKVCache` (scaled_attn.cpp:1982).

This PR makes `StatefulSDPAFusion` skip the fusion only for SDPAs whose
KV-cache is shared with another SDPA, while still fusing the exclusive
ones in the same model.

Inside the `StatefulSDPAFusion` matcher callback, for the matched SDPA
we walk forward from its `past_k` / `past_v` `ReadValue` via
`ov::op::util::visit_path_forward`, stopping at SDPA boundaries, and
count how many `ScaledDotProductAttention` ops are reachable. If more
than one SDPA is reachable from either the K-cache or the V-cache side,
the callback returns `false` and this particular SDPA is left unfused —
other SDPAs in the same model are still fused normally.

## Relation to #35260

#35260 addresses the same crash with a simpler check: it looks at the
direct K/V input nodes of each SDPA (`input_value(1).get_node()`,
`input_value(2).get_node()`) and counts how many SDPAs reference the
same node.

## Why the direct-input check is fragile

It only works when there are no intermediate ops between the shared
source of the KV-cache and the SDPA. In the idealized shape it expects,
the graph looks like:

```
ReadValue → Concat → SDPA1
              │
              └──→ SDPA2      ← same Concat object is the direct K input of both SDPAs ✓
```

Here SDPA1 and SDPA2 literally share the same Concat node pointer as
their K input, so the check sees them as sharing and skips the fusion.

In real graphs (Gemma3n, Gemma4, and other KV-sharing models) — and
depending on which version of transformers / optimum-intel was used to
export the model — almost every path from the shared source to an SDPA
carries some intermediate op: Transpose, Reshape, Convert, Gather,
Broadcast, and so on. The graph typically looks like this instead:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

Now SDPA1's direct K input is `Transpose_A` and SDPA2's is
`Transpose_B`. Even if both Transposes have identical parameters they
are distinct `ov::Node` objects, so the "same direct input" comparison
returns `false`. The sharing is no longer detected, `StatefulSDPAFusion`
runs, partially fuses the graph, and the model crashes at runtime with
`null input states` in `ScaledDotProductAttentionWithKVCache`.

The set of intermediate ops on the K/V path is not stable either — it
depends on which earlier passes (TransposeSinking,
SimplifyGatherShapeOf,
shape-inference rewrites, etc.) have already run, and on small
differences in how the model was exported. This makes the direct-input
check behave differently across otherwise equivalent Gemma-style models.

## The fix: walk forward from the ReadValue, decide per SDPA

Two changes vs. #35260:

1. **Anchor the check on the `ReadValue`, not on direct-input node
   pointers.** Instead of comparing `ov::Node*` pointers of the SDPA's
   direct K/V inputs, we walk the graph forward from the matched SDPA's
   `past_k` / `past_v` `ReadValue` and count how many SDPAs are
   reachable. The traversal passes through any non-SDPA op (Transpose,
   Reshape, Convert, Gather, Concat, Broadcast, …) and stops at SDPA
   boundaries, so intermediate topology does not hide the sharing.

2. **Decide per SDPA, not per model.** Skipping `StatefulSDPAFusion`
   globally as soon as the model contains a single shared-KV-cache SDPA
   is overly conservative: a model can mix SDPAs that share a cache with
   SDPAs that do not, and the exclusive ones would lose their fusion
   unnecessarily. The decision is made per SDPA inside the matcher
   callback, so exclusive SDPAs in the same model are still fused.

## Algorithm

Inside `StatefulSDPAFusion`'s callback, for the matched SDPA:

- Take the matched `past_k` and `past_v` `ReadValue` nodes from the
  pattern map.
- For each of them, call `ov::op::util::visit_path_forward` with a
  `skip_node_predicate` that returns `true` on
  `ScaledDotProductAttention`, so the BFS walks through any non-SDPA op
  and halts at SDPA boundaries.
- Count `ScaledDotProductAttention` nodes in the resulting `visited`
  set.
- If the count is greater than 1 for either `past_k` or `past_v`, the
  matched SDPA shares its cache with at least one other SDPA — return
  `false` from the callback and leave it unfused.

On the problematic graph:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

the BFS from the `ReadValue` walks through Concat and both Transposes,
reaches SDPA1 and SDPA2, counts 2, and leaves this SDPA unfused. An
SDPA on a different, unshared `ReadValue` in the same model is still
fused as before.

In short: #35260 asks "do these SDPAs have the same neighbor?", and the
answer depends on the current shape of the graph. This PR asks "from
this SDPA's KV-cache slot, can I reach another SDPA?" — the answer is
independent of intermediate topology, and the decision is taken per
SDPA, so the fix is minimal in scope.

## Verification

- Unit tests (`ov_cpu_unit_tests --gtest_filter='*StateConcatSDPA*'`):
all 7 tests pass, including a new
`StateConcatSDPAMixedSharedAndExclusive`
  that builds a model with one `ReadValue` feeding two SDPAs (shared)
  and another `ReadValue` feeding a single SDPA (exclusive), and asserts
  that after `SDPASubgraphFusion` exactly one
  `ScaledDotProductAttentionWithKVCache` is produced while the two
  shared SDPAs remain plain `ScaledDotProductAttention`.
- Gemma3n: verified with optimum-intel PR huggingface/optimum-intel#1650
  at commit `698bfcec`, which predates the optimum-side workaround that
  replaces SDPA with matmul. Without this PR the graph still contains
  `ScaledDotProductAttention` ops and the crash reproduces; with this
  PR inference succeeds.
- Gemma4: verified with optimum-intel PR huggingface/optimum-intel#1675.

### Tickets:
 - 183493

---------

Co-authored-by: Mikhail Ryzhov <mikhail.ryzhov@intel.com>
@rkazants rkazants mentioned this pull request Apr 15, 2026
3 tasks
github-merge-queue bot pushed a commit to openvinotoolkit/openvino that referenced this pull request Apr 16, 2026
## Summary

Models like Gemma3n and Gemma4 share one KV-cache across multiple SDPA
blocks (per-layer KV sharing). `StatefulSDPAFusion` partially fuses such
graphs and leaves other SDPA blocks unfused. The resulting inconsistent
model crashes at inference with `null input states` in
`ScaledDotProductAttentionWithKVCache` (scaled_attn.cpp:1982).

This PR makes `StatefulSDPAFusion` skip the fusion only for SDPAs whose
KV-cache is shared with another SDPA, while still fusing the exclusive
ones in the same model.

Inside the `StatefulSDPAFusion` matcher callback, for the matched SDPA
we walk forward from its `past_k` / `past_v` `ReadValue` via
`ov::op::util::visit_path_forward`, stopping at SDPA boundaries, and
count how many `ScaledDotProductAttention` ops are reachable. If more
than one SDPA is reachable from either the K-cache or the V-cache side,
the callback returns `false` and this particular SDPA is left unfused —
other SDPAs in the same model are still fused normally.

## Relation to #35260

#35260 addresses the same crash with a simpler check: it looks at the
direct K/V input nodes of each SDPA (`input_value(1).get_node()`,
`input_value(2).get_node()`) and counts how many SDPAs reference the
same node.

## Why the direct-input check is fragile

It only works when there are no intermediate ops between the shared
source of the KV-cache and the SDPA. In the idealized shape it expects,
the graph looks like:

```
ReadValue → Concat → SDPA1
              │
              └──→ SDPA2      ← same Concat object is the direct K input of both SDPAs ✓
```

Here SDPA1 and SDPA2 literally share the same Concat node pointer as
their K input, so the check sees them as sharing and skips the fusion.

In real graphs (Gemma3n, Gemma4, and other KV-sharing models) — and
depending on which version of transformers / optimum-intel was used to
export the model — almost every path from the shared source to an SDPA
carries some intermediate op: Transpose, Reshape, Convert, Gather,
Broadcast, and so on. The graph typically looks like this instead:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

Now SDPA1's direct K input is `Transpose_A` and SDPA2's is
`Transpose_B`. Even if both Transposes have identical parameters they
are distinct `ov::Node` objects, so the "same direct input" comparison
returns `false`. The sharing is no longer detected, `StatefulSDPAFusion`
runs, partially fuses the graph, and the model crashes at runtime with
`null input states` in `ScaledDotProductAttentionWithKVCache`.

The set of intermediate ops on the K/V path is not stable either — it
depends on which earlier passes (TransposeSinking,
SimplifyGatherShapeOf,
shape-inference rewrites, etc.) have already run, and on small
differences in how the model was exported. This makes the direct-input
check behave differently across otherwise equivalent Gemma-style models.

## The fix: walk forward from the ReadValue, decide per SDPA

Two changes vs. #35260:

1. **Anchor the check on the `ReadValue`, not on direct-input node
   pointers.** Instead of comparing `ov::Node*` pointers of the SDPA's
   direct K/V inputs, we walk the graph forward from the matched SDPA's
   `past_k` / `past_v` `ReadValue` and count how many SDPAs are
   reachable. The traversal passes through any non-SDPA op (Transpose,
   Reshape, Convert, Gather, Concat, Broadcast, …) and stops at SDPA
   boundaries, so intermediate topology does not hide the sharing.

2. **Decide per SDPA, not per model.** Skipping `StatefulSDPAFusion`
   globally as soon as the model contains a single shared-KV-cache SDPA
   is overly conservative: a model can mix SDPAs that share a cache with
   SDPAs that do not, and the exclusive ones would lose their fusion
   unnecessarily. The decision is made per SDPA inside the matcher
   callback, so exclusive SDPAs in the same model are still fused.

## Algorithm

Inside `StatefulSDPAFusion`'s callback, for the matched SDPA:

- Take the matched `past_k` and `past_v` `ReadValue` nodes from the
  pattern map.
- For each of them, call `ov::op::util::visit_path_forward` with a
  `skip_node_predicate` that returns `true` on
  `ScaledDotProductAttention`, so the BFS walks through any non-SDPA op
  and halts at SDPA boundaries.
- Count `ScaledDotProductAttention` nodes in the resulting `visited`
  set.
- If the count is greater than 1 for either `past_k` or `past_v`, the
  matched SDPA shares its cache with at least one other SDPA — return
  `false` from the callback and leave it unfused.

On the problematic graph:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

the BFS from the `ReadValue` walks through Concat and both Transposes,
reaches SDPA1 and SDPA2, counts 2, and leaves this SDPA unfused. An
SDPA on a different, unshared `ReadValue` in the same model is still
fused as before.

In short: #35260 asks "do these SDPAs have the same neighbor?", and the
answer depends on the current shape of the graph. This PR asks "from
this SDPA's KV-cache slot, can I reach another SDPA?" — the answer is
independent of intermediate topology, and the decision is taken per
SDPA, so the fix is minimal in scope.

## Verification

- Unit tests (`ov_cpu_unit_tests --gtest_filter='*StateConcatSDPA*'`):
all 7 tests pass, including a new
`StateConcatSDPAMixedSharedAndExclusive`
  that builds a model with one `ReadValue` feeding two SDPAs (shared)
  and another `ReadValue` feeding a single SDPA (exclusive), and asserts
  that after `SDPASubgraphFusion` exactly one
  `ScaledDotProductAttentionWithKVCache` is produced while the two
  shared SDPAs remain plain `ScaledDotProductAttention`.
- Gemma3n: verified with optimum-intel PR huggingface/optimum-intel#1650
  at commit `698bfcec`, which predates the optimum-side workaround that
  replaces SDPA with matmul. Without this PR the graph still contains
  `ScaledDotProductAttention` ops and the crash reproduces; with this
  PR inference succeeds.
- Gemma4: verified with optimum-intel PR huggingface/optimum-intel#1675.

### Tickets:
 - 183493

---------

Co-authored-by: Mikhail Ryzhov <mikhail.ryzhov@intel.com>
github-merge-queue bot pushed a commit to openvinotoolkit/openvino that referenced this pull request Apr 16, 2026
## Summary

Models like Gemma3n and Gemma4 share one KV-cache across multiple SDPA
blocks (per-layer KV sharing). `StatefulSDPAFusion` partially fuses such
graphs and leaves other SDPA blocks unfused. The resulting inconsistent
model crashes at inference with `null input states` in
`ScaledDotProductAttentionWithKVCache` (scaled_attn.cpp:1982).

This PR makes `StatefulSDPAFusion` skip the fusion only for SDPAs whose
KV-cache is shared with another SDPA, while still fusing the exclusive
ones in the same model.

Inside the `StatefulSDPAFusion` matcher callback, for the matched SDPA
we walk forward from its `past_k` / `past_v` `ReadValue` via
`ov::op::util::visit_path_forward`, stopping at SDPA boundaries, and
count how many `ScaledDotProductAttention` ops are reachable. If more
than one SDPA is reachable from either the K-cache or the V-cache side,
the callback returns `false` and this particular SDPA is left unfused —
other SDPAs in the same model are still fused normally.

## Relation to #35260

#35260 addresses the same crash with a simpler check: it looks at the
direct K/V input nodes of each SDPA (`input_value(1).get_node()`,
`input_value(2).get_node()`) and counts how many SDPAs reference the
same node.

## Why the direct-input check is fragile

It only works when there are no intermediate ops between the shared
source of the KV-cache and the SDPA. In the idealized shape it expects,
the graph looks like:

```
ReadValue → Concat → SDPA1
              │
              └──→ SDPA2      ← same Concat object is the direct K input of both SDPAs ✓
```

Here SDPA1 and SDPA2 literally share the same Concat node pointer as
their K input, so the check sees them as sharing and skips the fusion.

In real graphs (Gemma3n, Gemma4, and other KV-sharing models) — and
depending on which version of transformers / optimum-intel was used to
export the model — almost every path from the shared source to an SDPA
carries some intermediate op: Transpose, Reshape, Convert, Gather,
Broadcast, and so on. The graph typically looks like this instead:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

Now SDPA1's direct K input is `Transpose_A` and SDPA2's is
`Transpose_B`. Even if both Transposes have identical parameters they
are distinct `ov::Node` objects, so the "same direct input" comparison
returns `false`. The sharing is no longer detected, `StatefulSDPAFusion`
runs, partially fuses the graph, and the model crashes at runtime with
`null input states` in `ScaledDotProductAttentionWithKVCache`.

The set of intermediate ops on the K/V path is not stable either — it
depends on which earlier passes (TransposeSinking,
SimplifyGatherShapeOf,
shape-inference rewrites, etc.) have already run, and on small
differences in how the model was exported. This makes the direct-input
check behave differently across otherwise equivalent Gemma-style models.

## The fix: walk forward from the ReadValue, decide per SDPA

Two changes vs. #35260:

1. **Anchor the check on the `ReadValue`, not on direct-input node
   pointers.** Instead of comparing `ov::Node*` pointers of the SDPA's
   direct K/V inputs, we walk the graph forward from the matched SDPA's
   `past_k` / `past_v` `ReadValue` and count how many SDPAs are
   reachable. The traversal passes through any non-SDPA op (Transpose,
   Reshape, Convert, Gather, Concat, Broadcast, …) and stops at SDPA
   boundaries, so intermediate topology does not hide the sharing.

2. **Decide per SDPA, not per model.** Skipping `StatefulSDPAFusion`
   globally as soon as the model contains a single shared-KV-cache SDPA
   is overly conservative: a model can mix SDPAs that share a cache with
   SDPAs that do not, and the exclusive ones would lose their fusion
   unnecessarily. The decision is made per SDPA inside the matcher
   callback, so exclusive SDPAs in the same model are still fused.

## Algorithm

Inside `StatefulSDPAFusion`'s callback, for the matched SDPA:

- Take the matched `past_k` and `past_v` `ReadValue` nodes from the
  pattern map.
- For each of them, call `ov::op::util::visit_path_forward` with a
  `skip_node_predicate` that returns `true` on
  `ScaledDotProductAttention`, so the BFS walks through any non-SDPA op
  and halts at SDPA boundaries.
- Count `ScaledDotProductAttention` nodes in the resulting `visited`
  set.
- If the count is greater than 1 for either `past_k` or `past_v`, the
  matched SDPA shares its cache with at least one other SDPA — return
  `false` from the callback and leave it unfused.

On the problematic graph:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

the BFS from the `ReadValue` walks through Concat and both Transposes,
reaches SDPA1 and SDPA2, counts 2, and leaves this SDPA unfused. An
SDPA on a different, unshared `ReadValue` in the same model is still
fused as before.

In short: #35260 asks "do these SDPAs have the same neighbor?", and the
answer depends on the current shape of the graph. This PR asks "from
this SDPA's KV-cache slot, can I reach another SDPA?" — the answer is
independent of intermediate topology, and the decision is taken per
SDPA, so the fix is minimal in scope.

## Verification

- Unit tests (`ov_cpu_unit_tests --gtest_filter='*StateConcatSDPA*'`):
all 7 tests pass, including a new
`StateConcatSDPAMixedSharedAndExclusive`
  that builds a model with one `ReadValue` feeding two SDPAs (shared)
  and another `ReadValue` feeding a single SDPA (exclusive), and asserts
  that after `SDPASubgraphFusion` exactly one
  `ScaledDotProductAttentionWithKVCache` is produced while the two
  shared SDPAs remain plain `ScaledDotProductAttention`.
- Gemma3n: verified with optimum-intel PR huggingface/optimum-intel#1650
  at commit `698bfcec`, which predates the optimum-side workaround that
  replaces SDPA with matmul. Without this PR the graph still contains
  `ScaledDotProductAttention` ops and the crash reproduces; with this
  PR inference succeeds.
- Gemma4: verified with optimum-intel PR huggingface/optimum-intel#1675.

### Tickets:
 - 183493

---------

Co-authored-by: Mikhail Ryzhov <mikhail.ryzhov@intel.com>
github-merge-queue bot pushed a commit to openvinotoolkit/openvino that referenced this pull request Apr 16, 2026
## Summary

Models like Gemma3n and Gemma4 share one KV-cache across multiple SDPA
blocks (per-layer KV sharing). `StatefulSDPAFusion` partially fuses such
graphs and leaves other SDPA blocks unfused. The resulting inconsistent
model crashes at inference with `null input states` in
`ScaledDotProductAttentionWithKVCache` (scaled_attn.cpp:1982).

This PR makes `StatefulSDPAFusion` skip the fusion only for SDPAs whose
KV-cache is shared with another SDPA, while still fusing the exclusive
ones in the same model.

Inside the `StatefulSDPAFusion` matcher callback, for the matched SDPA
we walk forward from its `past_k` / `past_v` `ReadValue` via
`ov::op::util::visit_path_forward`, stopping at SDPA boundaries, and
count how many `ScaledDotProductAttention` ops are reachable. If more
than one SDPA is reachable from either the K-cache or the V-cache side,
the callback returns `false` and this particular SDPA is left unfused —
other SDPAs in the same model are still fused normally.

## Relation to #35260

#35260 addresses the same crash with a simpler check: it looks at the
direct K/V input nodes of each SDPA (`input_value(1).get_node()`,
`input_value(2).get_node()`) and counts how many SDPAs reference the
same node.

## Why the direct-input check is fragile

It only works when there are no intermediate ops between the shared
source of the KV-cache and the SDPA. In the idealized shape it expects,
the graph looks like:

```
ReadValue → Concat → SDPA1
              │
              └──→ SDPA2      ← same Concat object is the direct K input of both SDPAs ✓
```

Here SDPA1 and SDPA2 literally share the same Concat node pointer as
their K input, so the check sees them as sharing and skips the fusion.

In real graphs (Gemma3n, Gemma4, and other KV-sharing models) — and
depending on which version of transformers / optimum-intel was used to
export the model — almost every path from the shared source to an SDPA
carries some intermediate op: Transpose, Reshape, Convert, Gather,
Broadcast, and so on. The graph typically looks like this instead:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

Now SDPA1's direct K input is `Transpose_A` and SDPA2's is
`Transpose_B`. Even if both Transposes have identical parameters they
are distinct `ov::Node` objects, so the "same direct input" comparison
returns `false`. The sharing is no longer detected, `StatefulSDPAFusion`
runs, partially fuses the graph, and the model crashes at runtime with
`null input states` in `ScaledDotProductAttentionWithKVCache`.

The set of intermediate ops on the K/V path is not stable either — it
depends on which earlier passes (TransposeSinking,
SimplifyGatherShapeOf,
shape-inference rewrites, etc.) have already run, and on small
differences in how the model was exported. This makes the direct-input
check behave differently across otherwise equivalent Gemma-style models.

## The fix: walk forward from the ReadValue, decide per SDPA

Two changes vs. #35260:

1. **Anchor the check on the `ReadValue`, not on direct-input node
   pointers.** Instead of comparing `ov::Node*` pointers of the SDPA's
   direct K/V inputs, we walk the graph forward from the matched SDPA's
   `past_k` / `past_v` `ReadValue` and count how many SDPAs are
   reachable. The traversal passes through any non-SDPA op (Transpose,
   Reshape, Convert, Gather, Concat, Broadcast, …) and stops at SDPA
   boundaries, so intermediate topology does not hide the sharing.

2. **Decide per SDPA, not per model.** Skipping `StatefulSDPAFusion`
   globally as soon as the model contains a single shared-KV-cache SDPA
   is overly conservative: a model can mix SDPAs that share a cache with
   SDPAs that do not, and the exclusive ones would lose their fusion
   unnecessarily. The decision is made per SDPA inside the matcher
   callback, so exclusive SDPAs in the same model are still fused.

## Algorithm

Inside `StatefulSDPAFusion`'s callback, for the matched SDPA:

- Take the matched `past_k` and `past_v` `ReadValue` nodes from the
  pattern map.
- For each of them, call `ov::op::util::visit_path_forward` with a
  `skip_node_predicate` that returns `true` on
  `ScaledDotProductAttention`, so the BFS walks through any non-SDPA op
  and halts at SDPA boundaries.
- Count `ScaledDotProductAttention` nodes in the resulting `visited`
  set.
- If the count is greater than 1 for either `past_k` or `past_v`, the
  matched SDPA shares its cache with at least one other SDPA — return
  `false` from the callback and leave it unfused.

On the problematic graph:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

the BFS from the `ReadValue` walks through Concat and both Transposes,
reaches SDPA1 and SDPA2, counts 2, and leaves this SDPA unfused. An
SDPA on a different, unshared `ReadValue` in the same model is still
fused as before.

In short: #35260 asks "do these SDPAs have the same neighbor?", and the
answer depends on the current shape of the graph. This PR asks "from
this SDPA's KV-cache slot, can I reach another SDPA?" — the answer is
independent of intermediate topology, and the decision is taken per
SDPA, so the fix is minimal in scope.

## Verification

- Unit tests (`ov_cpu_unit_tests --gtest_filter='*StateConcatSDPA*'`):
all 7 tests pass, including a new
`StateConcatSDPAMixedSharedAndExclusive`
  that builds a model with one `ReadValue` feeding two SDPAs (shared)
  and another `ReadValue` feeding a single SDPA (exclusive), and asserts
  that after `SDPASubgraphFusion` exactly one
  `ScaledDotProductAttentionWithKVCache` is produced while the two
  shared SDPAs remain plain `ScaledDotProductAttention`.
- Gemma3n: verified with optimum-intel PR huggingface/optimum-intel#1650
  at commit `698bfcec`, which predates the optimum-side workaround that
  replaces SDPA with matmul. Without this PR the graph still contains
  `ScaledDotProductAttention` ops and the crash reproduces; with this
  PR inference succeeds.
- Gemma4: verified with optimum-intel PR huggingface/optimum-intel#1675.

### Tickets:
 - 183493

---------

Co-authored-by: Mikhail Ryzhov <mikhail.ryzhov@intel.com>
github-merge-queue bot pushed a commit to openvinotoolkit/openvino that referenced this pull request Apr 16, 2026
## Summary

Models like Gemma3n and Gemma4 share one KV-cache across multiple SDPA
blocks (per-layer KV sharing). `StatefulSDPAFusion` partially fuses such
graphs and leaves other SDPA blocks unfused. The resulting inconsistent
model crashes at inference with `null input states` in
`ScaledDotProductAttentionWithKVCache` (scaled_attn.cpp:1982).

This PR makes `StatefulSDPAFusion` skip the fusion only for SDPAs whose
KV-cache is shared with another SDPA, while still fusing the exclusive
ones in the same model.

Inside the `StatefulSDPAFusion` matcher callback, for the matched SDPA
we walk forward from its `past_k` / `past_v` `ReadValue` via
`ov::op::util::visit_path_forward`, stopping at SDPA boundaries, and
count how many `ScaledDotProductAttention` ops are reachable. If more
than one SDPA is reachable from either the K-cache or the V-cache side,
the callback returns `false` and this particular SDPA is left unfused —
other SDPAs in the same model are still fused normally.

## Relation to #35260

#35260 addresses the same crash with a simpler check: it looks at the
direct K/V input nodes of each SDPA (`input_value(1).get_node()`,
`input_value(2).get_node()`) and counts how many SDPAs reference the
same node.

## Why the direct-input check is fragile

It only works when there are no intermediate ops between the shared
source of the KV-cache and the SDPA. In the idealized shape it expects,
the graph looks like:

```
ReadValue → Concat → SDPA1
              │
              └──→ SDPA2      ← same Concat object is the direct K input of both SDPAs ✓
```

Here SDPA1 and SDPA2 literally share the same Concat node pointer as
their K input, so the check sees them as sharing and skips the fusion.

In real graphs (Gemma3n, Gemma4, and other KV-sharing models) — and
depending on which version of transformers / optimum-intel was used to
export the model — almost every path from the shared source to an SDPA
carries some intermediate op: Transpose, Reshape, Convert, Gather,
Broadcast, and so on. The graph typically looks like this instead:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

Now SDPA1's direct K input is `Transpose_A` and SDPA2's is
`Transpose_B`. Even if both Transposes have identical parameters they
are distinct `ov::Node` objects, so the "same direct input" comparison
returns `false`. The sharing is no longer detected, `StatefulSDPAFusion`
runs, partially fuses the graph, and the model crashes at runtime with
`null input states` in `ScaledDotProductAttentionWithKVCache`.

The set of intermediate ops on the K/V path is not stable either — it
depends on which earlier passes (TransposeSinking,
SimplifyGatherShapeOf,
shape-inference rewrites, etc.) have already run, and on small
differences in how the model was exported. This makes the direct-input
check behave differently across otherwise equivalent Gemma-style models.

## The fix: walk forward from the ReadValue, decide per SDPA

Two changes vs. #35260:

1. **Anchor the check on the `ReadValue`, not on direct-input node
   pointers.** Instead of comparing `ov::Node*` pointers of the SDPA's
   direct K/V inputs, we walk the graph forward from the matched SDPA's
   `past_k` / `past_v` `ReadValue` and count how many SDPAs are
   reachable. The traversal passes through any non-SDPA op (Transpose,
   Reshape, Convert, Gather, Concat, Broadcast, …) and stops at SDPA
   boundaries, so intermediate topology does not hide the sharing.

2. **Decide per SDPA, not per model.** Skipping `StatefulSDPAFusion`
   globally as soon as the model contains a single shared-KV-cache SDPA
   is overly conservative: a model can mix SDPAs that share a cache with
   SDPAs that do not, and the exclusive ones would lose their fusion
   unnecessarily. The decision is made per SDPA inside the matcher
   callback, so exclusive SDPAs in the same model are still fused.

## Algorithm

Inside `StatefulSDPAFusion`'s callback, for the matched SDPA:

- Take the matched `past_k` and `past_v` `ReadValue` nodes from the
  pattern map.
- For each of them, call `ov::op::util::visit_path_forward` with a
  `skip_node_predicate` that returns `true` on
  `ScaledDotProductAttention`, so the BFS walks through any non-SDPA op
  and halts at SDPA boundaries.
- Count `ScaledDotProductAttention` nodes in the resulting `visited`
  set.
- If the count is greater than 1 for either `past_k` or `past_v`, the
  matched SDPA shares its cache with at least one other SDPA — return
  `false` from the callback and leave it unfused.

On the problematic graph:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

the BFS from the `ReadValue` walks through Concat and both Transposes,
reaches SDPA1 and SDPA2, counts 2, and leaves this SDPA unfused. An
SDPA on a different, unshared `ReadValue` in the same model is still
fused as before.

In short: #35260 asks "do these SDPAs have the same neighbor?", and the
answer depends on the current shape of the graph. This PR asks "from
this SDPA's KV-cache slot, can I reach another SDPA?" — the answer is
independent of intermediate topology, and the decision is taken per
SDPA, so the fix is minimal in scope.

## Verification

- Unit tests (`ov_cpu_unit_tests --gtest_filter='*StateConcatSDPA*'`):
all 7 tests pass, including a new
`StateConcatSDPAMixedSharedAndExclusive`
  that builds a model with one `ReadValue` feeding two SDPAs (shared)
  and another `ReadValue` feeding a single SDPA (exclusive), and asserts
  that after `SDPASubgraphFusion` exactly one
  `ScaledDotProductAttentionWithKVCache` is produced while the two
  shared SDPAs remain plain `ScaledDotProductAttention`.
- Gemma3n: verified with optimum-intel PR huggingface/optimum-intel#1650
  at commit `698bfcec`, which predates the optimum-side workaround that
  replaces SDPA with matmul. Without this PR the graph still contains
  `ScaledDotProductAttention` ops and the crash reproduces; with this
  PR inference succeeds.
- Gemma4: verified with optimum-intel PR huggingface/optimum-intel#1675.

### Tickets:
 - 183493

---------

Co-authored-by: Mikhail Ryzhov <mikhail.ryzhov@intel.com>
praasz pushed a commit to praasz/openvino that referenced this pull request Apr 20, 2026
…kit#35323)

## Summary

Models like Gemma3n and Gemma4 share one KV-cache across multiple SDPA
blocks (per-layer KV sharing). `StatefulSDPAFusion` partially fuses such
graphs and leaves other SDPA blocks unfused. The resulting inconsistent
model crashes at inference with `null input states` in
`ScaledDotProductAttentionWithKVCache` (scaled_attn.cpp:1982).

This PR makes `StatefulSDPAFusion` skip the fusion only for SDPAs whose
KV-cache is shared with another SDPA, while still fusing the exclusive
ones in the same model.

Inside the `StatefulSDPAFusion` matcher callback, for the matched SDPA
we walk forward from its `past_k` / `past_v` `ReadValue` via
`ov::op::util::visit_path_forward`, stopping at SDPA boundaries, and
count how many `ScaledDotProductAttention` ops are reachable. If more
than one SDPA is reachable from either the K-cache or the V-cache side,
the callback returns `false` and this particular SDPA is left unfused —
other SDPAs in the same model are still fused normally.

## Relation to openvinotoolkit#35260

openvinotoolkit#35260 addresses the same crash with a simpler check: it looks at the
direct K/V input nodes of each SDPA (`input_value(1).get_node()`,
`input_value(2).get_node()`) and counts how many SDPAs reference the
same node.

## Why the direct-input check is fragile

It only works when there are no intermediate ops between the shared
source of the KV-cache and the SDPA. In the idealized shape it expects,
the graph looks like:

```
ReadValue → Concat → SDPA1
              │
              └──→ SDPA2      ← same Concat object is the direct K input of both SDPAs ✓
```

Here SDPA1 and SDPA2 literally share the same Concat node pointer as
their K input, so the check sees them as sharing and skips the fusion.

In real graphs (Gemma3n, Gemma4, and other KV-sharing models) — and
depending on which version of transformers / optimum-intel was used to
export the model — almost every path from the shared source to an SDPA
carries some intermediate op: Transpose, Reshape, Convert, Gather,
Broadcast, and so on. The graph typically looks like this instead:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

Now SDPA1's direct K input is `Transpose_A` and SDPA2's is
`Transpose_B`. Even if both Transposes have identical parameters they
are distinct `ov::Node` objects, so the "same direct input" comparison
returns `false`. The sharing is no longer detected, `StatefulSDPAFusion`
runs, partially fuses the graph, and the model crashes at runtime with
`null input states` in `ScaledDotProductAttentionWithKVCache`.

The set of intermediate ops on the K/V path is not stable either — it
depends on which earlier passes (TransposeSinking,
SimplifyGatherShapeOf,
shape-inference rewrites, etc.) have already run, and on small
differences in how the model was exported. This makes the direct-input
check behave differently across otherwise equivalent Gemma-style models.

## The fix: walk forward from the ReadValue, decide per SDPA

Two changes vs. openvinotoolkit#35260:

1. **Anchor the check on the `ReadValue`, not on direct-input node
   pointers.** Instead of comparing `ov::Node*` pointers of the SDPA's
   direct K/V inputs, we walk the graph forward from the matched SDPA's
   `past_k` / `past_v` `ReadValue` and count how many SDPAs are
   reachable. The traversal passes through any non-SDPA op (Transpose,
   Reshape, Convert, Gather, Concat, Broadcast, …) and stops at SDPA
   boundaries, so intermediate topology does not hide the sharing.

2. **Decide per SDPA, not per model.** Skipping `StatefulSDPAFusion`
   globally as soon as the model contains a single shared-KV-cache SDPA
   is overly conservative: a model can mix SDPAs that share a cache with
   SDPAs that do not, and the exclusive ones would lose their fusion
   unnecessarily. The decision is made per SDPA inside the matcher
   callback, so exclusive SDPAs in the same model are still fused.

## Algorithm

Inside `StatefulSDPAFusion`'s callback, for the matched SDPA:

- Take the matched `past_k` and `past_v` `ReadValue` nodes from the
  pattern map.
- For each of them, call `ov::op::util::visit_path_forward` with a
  `skip_node_predicate` that returns `true` on
  `ScaledDotProductAttention`, so the BFS walks through any non-SDPA op
  and halts at SDPA boundaries.
- Count `ScaledDotProductAttention` nodes in the resulting `visited`
  set.
- If the count is greater than 1 for either `past_k` or `past_v`, the
  matched SDPA shares its cache with at least one other SDPA — return
  `false` from the callback and leave it unfused.

On the problematic graph:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

the BFS from the `ReadValue` walks through Concat and both Transposes,
reaches SDPA1 and SDPA2, counts 2, and leaves this SDPA unfused. An
SDPA on a different, unshared `ReadValue` in the same model is still
fused as before.

In short: openvinotoolkit#35260 asks "do these SDPAs have the same neighbor?", and the
answer depends on the current shape of the graph. This PR asks "from
this SDPA's KV-cache slot, can I reach another SDPA?" — the answer is
independent of intermediate topology, and the decision is taken per
SDPA, so the fix is minimal in scope.

## Verification

- Unit tests (`ov_cpu_unit_tests --gtest_filter='*StateConcatSDPA*'`):
all 7 tests pass, including a new
`StateConcatSDPAMixedSharedAndExclusive`
  that builds a model with one `ReadValue` feeding two SDPAs (shared)
  and another `ReadValue` feeding a single SDPA (exclusive), and asserts
  that after `SDPASubgraphFusion` exactly one
  `ScaledDotProductAttentionWithKVCache` is produced while the two
  shared SDPAs remain plain `ScaledDotProductAttention`.
- Gemma3n: verified with optimum-intel PR huggingface/optimum-intel#1650
  at commit `698bfcec`, which predates the optimum-side workaround that
  replaces SDPA with matmul. Without this PR the graph still contains
  `ScaledDotProductAttention` ops and the crash reproduces; with this
  PR inference succeeds.
- Gemma4: verified with optimum-intel PR huggingface/optimum-intel#1675.

### Tickets:
 - 183493

---------

Co-authored-by: Mikhail Ryzhov <mikhail.ryzhov@intel.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants