# [OpenVINO] Support Gemma 4 (#1675)

rkazants wants to merge 2 commits into huggingface:transformers-v5
Conversation
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
@rkazants thanks for all your work on this (and other models like Qwen3.5). Not all heroes wear capes!
They left us out of the release... But @rkazants did not forget!
Omg, thank you!!!
My test system was running a Docker container with an image built from this, and I ran into these errors during export:

- Gemma 4 31B
- Gemma 4 26B A4B

For comparison, the E2B export was successful: Gemma 4 E2B
I was able to export. It uses a different PR to optimum-intel from @aleksandr-mokrov. You can test my export with @SearchSavior's OpenArc: https://github.com/SearchSavior/OpenArc
## Summary

Models like Gemma3n and Gemma4 share one KV-cache across multiple SDPA blocks (per-layer KV sharing). `StatefulSDPAFusion` partially fuses such graphs and leaves other SDPA blocks unfused. The resulting inconsistent model crashes at inference with `null input states` in `ScaledDotProductAttentionWithKVCache` (scaled_attn.cpp:1982).

This PR makes `StatefulSDPAFusion` skip the fusion only for SDPAs whose KV-cache is shared with another SDPA, while still fusing the exclusive ones in the same model. Inside the `StatefulSDPAFusion` matcher callback, for the matched SDPA we walk forward from its `past_k` / `past_v` `ReadValue` via `ov::op::util::visit_path_forward`, stopping at SDPA boundaries, and count how many `ScaledDotProductAttention` ops are reachable. If more than one SDPA is reachable from either the K-cache or the V-cache side, the callback returns `false` and this particular SDPA is left unfused; other SDPAs in the same model are still fused normally.

## Relation to openvinotoolkit#35260

openvinotoolkit#35260 addresses the same crash with a simpler check: it looks at the direct K/V input nodes of each SDPA (`input_value(1).get_node()`, `input_value(2).get_node()`) and counts how many SDPAs reference the same node.

## Why the direct-input check is fragile

It only works when there are no intermediate ops between the shared source of the KV-cache and the SDPA. In the idealized shape it expects, the graph looks like:

```
ReadValue → Concat ──→ SDPA1
              │
              └──────→ SDPA2   ← same Concat object is the direct K input of both SDPAs ✓
```

Here SDPA1 and SDPA2 literally share the same Concat node pointer as their K input, so the check sees them as sharing and skips the fusion.

In real graphs (Gemma3n, Gemma4, and other KV-sharing models), and depending on which version of transformers / optimum-intel was used to export the model, almost every path from the shared source to an SDPA carries some intermediate op: Transpose, Reshape, Convert, Gather, Broadcast, and so on. The graph typically looks like this instead:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

Now SDPA1's direct K input is `Transpose_A` and SDPA2's is `Transpose_B`. Even if both Transposes have identical parameters they are distinct `ov::Node` objects, so the "same direct input" comparison returns `false`. The sharing is no longer detected, `StatefulSDPAFusion` runs, partially fuses the graph, and the model crashes at runtime with `null input states` in `ScaledDotProductAttentionWithKVCache`.

The set of intermediate ops on the K/V path is not stable either: it depends on which earlier passes (TransposeSinking, SimplifyGatherShapeOf, shape-inference rewrites, etc.) have already run, and on small differences in how the model was exported. This makes the direct-input check behave differently across otherwise equivalent Gemma-style models.

## The fix: walk forward from the ReadValue, decide per SDPA

Two changes vs. openvinotoolkit#35260:

1. **Anchor the check on the `ReadValue`, not on direct-input node pointers.** Instead of comparing `ov::Node*` pointers of the SDPA's direct K/V inputs, we walk the graph forward from the matched SDPA's `past_k` / `past_v` `ReadValue` and count how many SDPAs are reachable. The traversal passes through any non-SDPA op (Transpose, Reshape, Convert, Gather, Concat, Broadcast, …) and stops at SDPA boundaries, so intermediate topology does not hide the sharing.
2. **Decide per SDPA, not per model.** Skipping `StatefulSDPAFusion` globally as soon as the model contains a single shared-KV-cache SDPA is overly conservative: a model can mix SDPAs that share a cache with SDPAs that do not, and the exclusive ones would lose their fusion unnecessarily. The decision is made per SDPA inside the matcher callback, so exclusive SDPAs in the same model are still fused.

## Algorithm

Inside `StatefulSDPAFusion`'s callback, for the matched SDPA:

- Take the matched `past_k` and `past_v` `ReadValue` nodes from the pattern map.
- For each of them, call `ov::op::util::visit_path_forward` with a `skip_node_predicate` that returns `true` on `ScaledDotProductAttention`, so the BFS walks through any non-SDPA op and halts at SDPA boundaries.
- Count `ScaledDotProductAttention` nodes in the resulting `visited` set.
- If the count is greater than 1 for either `past_k` or `past_v`, the matched SDPA shares its cache with at least one other SDPA: return `false` from the callback and leave it unfused.

On the problematic graph:

```
ReadValue → Concat ──┬──→ Transpose_A → SDPA1
                     │
                     └──→ Transpose_B → SDPA2
```

the BFS from the `ReadValue` walks through Concat and both Transposes, reaches SDPA1 and SDPA2, counts 2, and leaves this SDPA unfused. An SDPA on a different, unshared `ReadValue` in the same model is still fused as before.

In short: openvinotoolkit#35260 asks "do these SDPAs have the same neighbor?", and the answer depends on the current shape of the graph. This PR asks "from this SDPA's KV-cache slot, can I reach another SDPA?"; the answer is independent of intermediate topology, and the decision is taken per SDPA, so the fix is minimal in scope.

## Verification

- Unit tests (`ov_cpu_unit_tests --gtest_filter='*StateConcatSDPA*'`): all 7 tests pass, including a new `StateConcatSDPAMixedSharedAndExclusive` that builds a model with one `ReadValue` feeding two SDPAs (shared) and another `ReadValue` feeding a single SDPA (exclusive), and asserts that after `SDPASubgraphFusion` exactly one `ScaledDotProductAttentionWithKVCache` is produced while the two shared SDPAs remain plain `ScaledDotProductAttention`.
- Gemma3n: verified with optimum-intel PR huggingface/optimum-intel#1650 at commit `698bfcec`, which predates the optimum-side workaround that replaces SDPA with matmul. Without this PR the graph still contains `ScaledDotProductAttention` ops and the crash reproduces; with this PR inference succeeds.
- Gemma4: verified with optimum-intel PR huggingface/optimum-intel#1675.
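For illustration, the reachability check described above can be sketched in Python on a toy graph. The `Node` class and `reachable_sdpas` helper here are hypothetical stand-ins for this write-up only; the real pass is C++ and uses `ov::op::util::visit_path_forward` with a `skip_node_predicate`.

```python
from collections import deque

def reachable_sdpas(start, is_sdpa):
    """Walk forward (BFS) from `start` through consumers.
    Non-SDPA nodes are traversed; SDPA nodes are recorded as boundaries
    and not walked past, mirroring the skip_node_predicate behavior."""
    seen, found = {id(start)}, set()
    queue = deque([start])
    while queue:
        node = queue.popleft()
        for consumer in node.consumers:
            if id(consumer) in seen:
                continue
            seen.add(id(consumer))
            if is_sdpa(consumer):
                found.add(consumer)   # boundary: count it, stop here
            else:
                queue.append(consumer)
    return found

class Node:
    """Toy graph node (hypothetical, not the OpenVINO ov::Node API)."""
    def __init__(self, kind):
        self.kind = kind
        self.consumers = []
    def feeds(self, other):
        self.consumers.append(other)
        return other

# Gemma-style shape: one ReadValue (KV-cache slot) reaching two SDPAs
# through distinct Transpose nodes, so direct-input pointer comparison
# (the openvinotoolkit#35260 check) would not see the sharing.
read_value = Node("ReadValue")
concat = read_value.feeds(Node("Concat"))
sdpa1 = concat.feeds(Node("Transpose")).feeds(Node("SDPA"))
sdpa2 = concat.feeds(Node("Transpose")).feeds(Node("SDPA"))

is_sdpa = lambda n: n.kind == "SDPA"
shared = reachable_sdpas(read_value, is_sdpa)
print(len(shared))  # prints 2: the cache is shared, so skip the fusion
```

Because the walk stops at SDPA boundaries, an SDPA that consumes another SDPA's output is not counted against this cache slot; only SDPAs fed (possibly through intermediate ops) by the same `ReadValue` are.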
### Tickets

- 183493

---

Co-authored-by: Mikhail Ryzhov <mikhail.ryzhov@intel.com>
## Verification - Unit tests (`ov_cpu_unit_tests --gtest_filter='*StateConcatSDPA*'`): all 7 tests pass, including a new `StateConcatSDPAMixedSharedAndExclusive` that builds a model with one `ReadValue` feeding two SDPAs (shared) and another `ReadValue` feeding a single SDPA (exclusive), and asserts that after `SDPASubgraphFusion` exactly one `ScaledDotProductAttentionWithKVCache` is produced while the two shared SDPAs remain plain `ScaledDotProductAttention`. - Gemma3n: verified with optimum-intel PR huggingface/optimum-intel#1650 at commit `698bfcec`, which predates the optimum-side workaround that replaces SDPA with matmul. Without this PR the graph still contains `ScaledDotProductAttention` ops and the crash reproduces; with this PR inference succeeds. - Gemma4: verified with optimum-intel PR huggingface/optimum-intel#1675. ### Tickets: - 183493 --------- Co-authored-by: Mikhail Ryzhov <mikhail.ryzhov@intel.com>
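The per-SDPA decision described above can be sketched without any OpenVINO dependencies. The following is a minimal Python model of the traversal logic only — `Node`, `kind`, and the helper names are illustrative inventions, not the OpenVINO API; the real pass operates on `ov::Node` graphs via `ov::op::util::visit_path_forward` inside the C++ matcher callback:

```python
# Dependency-free sketch of the shared-KV-cache check (names are hypothetical).
from collections import deque

class Node:
    def __init__(self, kind, name):
        self.kind = kind        # e.g. "ReadValue", "Concat", "Transpose", "SDPA"
        self.name = name
        self.consumers = []     # forward edges: this node's outputs -> consumers

    def connect(self, *consumers):
        self.consumers.extend(consumers)
        return self

def reachable_sdpa_count(read_value):
    """Forward BFS from a ReadValue: walk through any non-SDPA op,
    halt at SDPA boundaries, and count the SDPAs reached."""
    count, seen, queue = 0, {read_value}, deque([read_value])
    while queue:
        for nxt in queue.popleft().consumers:
            if nxt in seen:
                continue
            seen.add(nxt)
            if nxt.kind == "SDPA":
                count += 1      # boundary: do not traverse past an SDPA
            else:
                queue.append(nxt)
    return count

def should_fuse(past_k, past_v):
    # Skip fusion only when either cache slot feeds more than one SDPA.
    return reachable_sdpa_count(past_k) <= 1 and reachable_sdpa_count(past_v) <= 1

# Shared case: ReadValue -> Concat -> two Transposes -> two SDPAs.
shared_rv = Node("ReadValue", "past_k_shared")
concat = Node("Concat", "concat")
t_a, t_b = Node("Transpose", "t_a"), Node("Transpose", "t_b")
sdpa1, sdpa2 = Node("SDPA", "sdpa1"), Node("SDPA", "sdpa2")
shared_rv.connect(concat)
concat.connect(t_a, t_b)
t_a.connect(sdpa1)
t_b.connect(sdpa2)

# Exclusive case: a separate ReadValue feeding a single SDPA.
excl_rv = Node("ReadValue", "past_k_excl")
sdpa3 = Node("SDPA", "sdpa3")
excl_rv.connect(Node("Concat", "c2").connect(sdpa3))
```

The two toy graphs mirror the `StateConcatSDPAMixedSharedAndExclusive` scenario: the shared `ReadValue` reaches two SDPAs through the intermediate Transposes, so its SDPAs would be left unfused, while the exclusive `ReadValue` reaches exactly one SDPA and stays eligible for fusion.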
What does this PR do?
Fixes 182357
Installation instructions:
Exporting cmd-line:
```
optimum-cli export openvino -m google/gemma-4-E2B-it ov_gemma4_E2Bit --task=image-text-to-text
```

Inference script: