Skip to content

[NNCFGraph] Migrate from nx.DiGraph to nx.MultiDiGraph#3843

Merged
AlexanderDokuchaev merged 5 commits intoopenvinotoolkit:developfrom
daniil-lyakhov:dl/nncf_multidigraph
Mar 25, 2026
Merged

[NNCFGraph] Migrate from nx.DiGraph to nx.MultiDiGraph#3843
AlexanderDokuchaev merged 5 commits intoopenvinotoolkit:developfrom
daniil-lyakhov:dl/nncf_multidigraph

Conversation

@daniil-lyakhov
Copy link
Copy Markdown
Collaborator

@daniil-lyakhov daniil-lyakhov commented Jan 15, 2026

Changes

  • NNCFGraph + InsertionPointGraph + PatternMatchingGraph are inhereted from nx.MultiDiGraph instead of nx.DiGraph
  • SDPA Ignored pattern is updated with split operation as a branching metatype for OpenVINO/ONNX backends
  • Hardware fusing pattern conv -> arithmetic is updated to have 1 and 2 degrees edges like on this picture:
image
  • ROPE patern is updated to have several parallel input edges to the concat node
image

Reason for changes

  • To make it possible to represent and compress models with Multi edges in NNCF
  • To quantize YOLO26 SDPA block properly in OpenVINO/ONNX backends
  • To quantize models like a = conv(x); return aa correctly (not quantizing the xx part)
  • To apply the ROPE pattern correctly, as the original pattern has parallel edges

Related tickets

174691
179960

Tests

Jobs

Test WC: https://github.com/openvinotoolkit/nncf/actions/runs/21724379639 - Green
Test examples: https://github.com/openvinotoolkit/nncf/actions/runs/21724373555 - Green
PTQ: NNCF/job/manual/job/post_training_quantization/804/ - Green

@github-actions github-actions Bot added NNCF PT Pull requests that updates NNCF PyTorch NNCF Common Pull request that updates NNCF Common NNCF OpenVINO Pull requests that updates NNCF OpenVINO labels Jan 15, 2026
@daniil-lyakhov daniil-lyakhov force-pushed the dl/nncf_multidigraph branch 3 times, most recently from 9437ebb to bf47f71 Compare January 29, 2026 15:06
@github-actions github-actions Bot added the NNCF ONNX Pull requests that updates NNCF ONNX label Feb 2, 2026
@daniil-lyakhov daniil-lyakhov force-pushed the dl/nncf_multidigraph branch 3 times, most recently from 40c2053 to 8b4c822 Compare February 4, 2026 19:16
Comment thread tests/onnx/quantization/test_transformer_models_graph.py Outdated
@daniil-lyakhov daniil-lyakhov force-pushed the dl/nncf_multidigraph branch 8 times, most recently from 2be12e8 to 51401fc Compare February 10, 2026 13:49
@daniil-lyakhov daniil-lyakhov marked this pull request as ready for review February 10, 2026 14:24
@daniil-lyakhov daniil-lyakhov requested a review from a team as a code owner February 10, 2026 14:24
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR migrates NNCF graph representations from nx.DiGraph to nx.MultiDiGraph to support parallel/multi-edges, and updates quantization/hardware-fusing/ignored-pattern logic (notably SDPA and RoPE) plus corresponding tests and reference graphs.

Changes:

  • Switched core graph utilities, pattern matching, and insertion-point graphs to nx.MultiDiGraph and adapted edge handling APIs (get_edges, keyed edges in .dot files).
  • Updated HW fused patterns (conv -> arithmetic) and ignored patterns (RoPE + SDPA) to account for parallel edges/branching.
  • Added/updated synthetic models, test fixtures, and many .dot/JSON references to validate multi-edge behavior across Torch/FX, OpenVINO, and ONNX.

Reviewed changes

Copilot reviewed 97 out of 114 changed files in this pull request and generated 13 comments.

Show a summary per file
File Description
tools/render_dot_to_svg.py Uses nx.MultiDiGraph when rendering .dot graphs to preserve multi-edges.
tools/clip_dot.py Uses nx.MultiDiGraph for reading and producing clipped graphs with possible parallel edges.
tests/torch/utils.py Updates comparable graph conversion to output nx.MultiDiGraph.
tests/torch/test_graph_analysis.py Updates mock graph to MultiDiGraph and edge construction without parallel-port field.
tests/torch/fx/test_weights_compression.py Updates edge retrieval for multi-edge API and parameterizes RoPE degree.
tests/torch/fx/test_quantizer_config.py Adds fixtures for degree-2 arithmetic patterns and split-based transformer graphs.
tests/torch/fx/test_quantizer.py Removes parallel_input_port_ids usage in normalization helper.
tests/torch/fx/test_models.py Adds YOLO26 attention block to FX model cases.
tests/torch/function_hook/quantization/test_weights_compression.py Parameterizes RoPE model creation with degree.
tests/torch/function_hook/quantization/test_quantizer_config.py Adds fixtures mirroring FX config tests (degree-2 + split transformer).
tests/torch/function_hook/quantization/test_quantized_graphs.py Adds models for parallel edges + YOLO26; updates RoPE to degree=2.
tests/torch/function_hook/quantization/helper.py Adds builder for degree-2 conv->arithmetic graph template.
tests/torch/data/fx/yolo26_attn_block.dot Adds FX reference .dot for YOLO26 attention block.
tests/torch/data/fx/reference_metatypes/yolo26_attn_block.json Adds metatype reference for YOLO26 attention block.
tests/torch/data/function_hook/sparsify_activations/three_linear_sparse_activations.dot Updates .dot format to include multi-edge key= and non-strict header.
tests/torch/data/function_hook/sparsify_activations/three_linear_int8_sym_weights_sparse_activations.dot Same .dot updates for int8 sparse-activations reference.
tests/torch/data/function_hook/sparsify_activations/three_linear_ignore1_sparse_activations.dot Same .dot updates for ignore-scope reference.
tests/torch/data/function_hook/sparsify_activations/three_linear_ignore1_int8_sym_weights_sparse_activations.dot Same .dot updates for ignore-scope + int8 reference.
tests/torch/data/function_hook/sparsify_activations/linear_sparse_activations.dot Same .dot updates for single-linear sparse reference.
tests/torch/data/function_hook/sparsify_activations/linear_int8_sym_weights_sparse_activations.dot Same .dot updates for single-linear + int8 reference.
tests/torch/data/function_hook/quantization/test_quantized_graphs/yolo26_attn_block.dot Adds quantized-graph .dot reference for YOLO26 attention block.
tests/torch/data/function_hook/quantization/test_quantized_graphs/unbind_scaled_dot_product_attention_model.dot Updates .dot to keyed multi-edges.
tests/torch/data/function_hook/quantization/test_quantized_graphs/shared_model.dot Replaces parallel_input_port_ids with explicit parallel edges via key.
tests/torch/data/function_hook/quantization/test_quantized_graphs/scaled_dot_product_attention_model.dot Updates .dot to keyed multi-edges.
tests/torch/data/function_hook/quantization/test_quantized_graphs/rope_model.dot Updates RoPE .dot to represent parallel concat inputs as multi-edges.
tests/torch/data/function_hook/quantization/test_quantized_graphs/parallel_edges_model.dot Adds .dot reference for explicit parallel-edge model.
tests/torch/data/function_hook/quantization/test_quantized_graphs/lenet.dot Updates .dot to keyed multi-edges.
tests/torch/data/function_hook/quantization/test_quantized_graphs/embedding_model.dot Updates .dot to keyed multi-edges.
tests/torch/data/function_hook/pruning_and_quantization/prune_ptq_model.dot Updates .dot to keyed multi-edges.
tests/torch/data/function_hook/nncf_graph/model_graph_with_shared_parameters.dot Updates .dot to keyed multi-edges.
tests/torch/data/function_hook/nncf_graph/model_graph_gru.dot Updates .dot to keyed multi-edges.
tests/torch/data/function_hook/nncf_graph/convert_to_nncf_graph_multi_edges.dot Updates reference from “parallel ports” to explicit multi-edges.
tests/torch/data/function_hook/nncf_graph/convert_to_nncf_graph.dot Updates .dot to keyed multi-edges.
tests/torch/data/function_hook/compress_weights/fq_lora/shared_weights_all_layers_True.dot Updates .dot to keyed multi-edges.
tests/torch/data/function_hook/compress_weights/fq_lora/shared_weights_all_layers_False.dot Updates .dot to keyed multi-edges.
tests/openvino/native/test_nncf_graph_builder.py Removes parallel_input_port_ids from edge fixtures.
tests/openvino/native/test_model_utils.py Updates direct edge-attr access to include MultiDiGraph edge key.
tests/openvino/native/quantization/test_weights_compression.py Parameterizes RoPE model creation with degree.
tests/openvino/native/quantization/test_quantizer_config.py Adds split + degree-2 graph fixtures and includes OVSplitMetatype.
tests/openvino/native/quantization/test_graphs.py Updates transformer model paramization to use partial(RoPEModel, degree=2).
tests/openvino/native/models.py Updates RoPE model to represent concat degree and adds YOLO26 attention block model.
tests/openvino/native/data/2025.4/reference_scales/YOLO26AttentionBlock_performance.json Adds reference scales for YOLO26 attention block.
tests/openvino/native/data/2025.4/reference_scales/YOLO26AttentionBlock_mixed.json Adds mixed reference scales for YOLO26 attention block.
tests/openvino/native/data/2025.4/reference_graphs/quantized/YOLO26AttentionBlock.dot Adds quantized reference .dot for YOLO26 attention block with keyed edges.
tests/openvino/native/data/2025.4/reference_graphs/original_nncf_graph/YOLO26AttentionBlock.dot Adds original-graph reference .dot for YOLO26 attention block with keyed edges.
tests/openvino/native/data/2024.4/reference_graphs/quantized/RoPEModel.dot Updates RoPE reference .dot to keyed multi-edges and renamed nodes.
tests/onnx/test_nncf_graph_builder.py Adds ONNX multi-edge conversion test using Split.
tests/onnx/quantization/test_weights_compression.py Parameterizes RoPE concat degree for ONNX model builder.
tests/onnx/quantization/test_quantizer_config.py Adds split + degree-2 fixtures and includes ONNXSplitMetatype.
tests/onnx/models.py Adds synthetic ONNX model with parallel edges via repeated Concat inputs.
tests/onnx/data/reference_graphs/quantization/synthetic/multi_input_output_parallel_edges_model.dot Adds keyed-edge reference graph for ONNX parallel-edges synthetic model.
tests/onnx/data/reference_graphs/original_nncf_graph/synthetic/multi_input_output_parallel_edges_model.dot Adds keyed-edge original-graph reference for ONNX parallel-edges model.
tests/onnx/common.py Adds ModelBuilder.add_split helper for Split-based models.
tests/cross_fw/test_templates/test_quantizer_config.py Adds template tests/states for degree-2 conv+arith and split-transformer patterns.
tests/cross_fw/test_templates/template_test_weights_compression.py Parameterizes RoPE compression test by concat degree.
tests/cross_fw/test_templates/models.py Adds graph templates for arithmetic degree-2 and split-transformer.
tests/cross_fw/test_templates/helpers.py Updates torch RoPE model to accept degree; adds ParallelEdgesModel + YOLO26AttentionBlock.
tests/cross_fw/shared/nx_graph.py Updates DOT sorting + graph comparison to MultiDiGraph and keyed edges.
tests/common/quantization/test_quantizer_propagation_solver.py Converts mock graphs to MultiDiGraph and updates edge indexing to include keys.
tests/common/quantization/test_quantizer_propagation_graph.py Converts mock graphs to MultiDiGraph and updates edge indexing to include keys.
tests/common/quantization/mock_graphs.py Converts mock graph utilities to MultiDiGraph; ensures in_edges/out_edges use keys.
tests/common/graph/test_nncf_graph.py Adds explicit multi-edge tests and duplicate-edge validation test.
tests/common/graph/test_dot_file_rw.py Updates DOT read/write test graph type to MultiDiGraph.
tests/common/data/reference_graphs/dot_rw_reference.dot Updates DOT reference to non-strict digraph and keyed edge format.
src/nncf/torch/quantization/ignored_patterns.py Reuses shared RoPE pattern builder to support parallel edges.
src/nncf/torch/model_graph_manager.py Updates constant-input resolution for multi-edge graph API.
src/nncf/torch/hardware/fused_patterns.py Updates HW fused patterns to include parallel-edge degree alternatives.
src/nncf/torch/graph/graph.py Updates insertion-point shape extraction using edge-by-port APIs.
src/nncf/torch/function_hook/nncf_graph/nncf_graph_builder.py Emits explicit edges for each meta edge instead of parallel-port encoding.
src/nncf/quantization/passes.py Updates node-removal reconnect logic to handle multi-edges.
src/nncf/quantization/ignored_patterns.py Introduces shared create_rope_pattern supporting concat degree alternatives.
src/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py Updates weight-shape/axes helpers to handle multi-edge APIs and validates edge uniqueness.
src/nncf/quantization/algorithms/weight_compression/torch_backend.py Updates weight/activation port resolution to iterate over multiple edges.
src/nncf/quantization/algorithms/smooth_quant/algorithm.py Updates grouping logic to validate and use a single edge between nodes.
src/nncf/openvino/quantization/ignored_patterns.py Reuses shared RoPE pattern builder and updates SDPA ignored pattern to include split.
src/nncf/openvino/hardware/fused_patterns.py Updates HW fused patterns to include parallel-edge degree alternatives.
src/nncf/openvino/graph/nncf_graph_builder.py Emits explicit edges per OpenVINO input port instead of parallel-port encoding.
src/nncf/onnx/quantization/ignored_patterns.py Reuses shared RoPE pattern builder and updates SDPA ignored pattern to include split.
src/nncf/onnx/hardware/fused_patterns.py Updates HW fused patterns to include parallel-edge degree alternatives.
src/nncf/onnx/graph/onnx_helper.py Updates input-port lookup to return all matching input ports for parallel edges.
src/nncf/onnx/graph/nncf_graph_builder.py Updates ONNX graph builder to create explicit multi-edges and disambiguate repeated inputs.
src/nncf/experimental/torch/sparsify_activations/torch_backend.py Updates activation-port resolution to iterate over multiple edges.
src/nncf/common/utils/dot_file_rw.py Updates DOT reader to return nx.MultiDiGraph.
src/nncf/common/insertion_point_graph.py Migrates insertion-point graph to MultiDiGraph and updates edge/key handling.
src/nncf/common/graph/patterns/patterns.py Migrates GraphPattern to MultiDiGraph and adds join_patterns_parallel.
src/nncf/common/graph/graph_matching.py Uses MultiDiGraphMatcher for pattern matching on MultiDiGraph graphs.
src/nncf/common/graph/graph.py Migrates core NNCFGraph to MultiDiGraph; replaces parallel-port encoding with explicit multi-edges.
Comments suppressed due to low confidence (1)

src/nncf/common/utils/dot_file_rw.py:1

  • read_dot_graph now returns nx.MultiDiGraph, but write_dot_graph still type-annotates nx.DiGraph. Consider widening the annotation (e.g., nx.MultiDiGraph or a union of graph types) to reflect actual usage and prevent confusion in downstream call sites.
# Copyright (c) 2026 Intel Corporation

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.


for u, v in self._nx_graph.edges:
edge = self._nx_graph.edges[u, v]
for u, v, k in self._nx_graph.edges:
Copy link

Copilot AI Feb 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Iterating over MultiDiGraph.edges without keys=True yields (u, v) pairs, so unpacking into (u, v, k) will raise at runtime. Use self._nx_graph.edges(keys=True) (or edges(keys=True, data=True)) when you need edge keys, and keep access consistent with that.

Suggested change
for u, v, k in self._nx_graph.edges:
for u, v, k in self._nx_graph.edges(keys=True):

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it a valid comment?”

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an invalid comment. There is a difference betwen the property edges and the function edges in the netrokx: the property always returns the tupple of 3 for the MultiDigraph, so this change won't fix anything.
The property call: https://github.com/networkx/networkx/blob/main/networkx/classes/multidigraph.py#L679
OutMultiEdgeView iter: https://github.com/networkx/networkx/blob/main/networkx/classes/reportviews.py#L1444-L1448

Comment thread src/nncf/common/graph/graph.py
Comment thread src/nncf/common/graph/graph.py
Comment thread src/nncf/common/graph/graph.py
Comment thread src/nncf/common/graph/graph.py
Comment thread src/nncf/quantization/algorithms/smooth_quant/algorithm.py Outdated
Comment thread tests/common/graph/test_nncf_graph.py Outdated
Comment thread tests/cross_fw/test_templates/test_quantizer_config.py Outdated
Comment thread tests/torch/utils.py Outdated
Comment thread tests/torch/utils.py Outdated
Comment thread src/nncf/common/insertion_point_graph.py
@daniil-lyakhov daniil-lyakhov force-pushed the dl/nncf_multidigraph branch 2 times, most recently from 485ee5a to 8337ef2 Compare February 19, 2026 14:46
@daniil-lyakhov daniil-lyakhov force-pushed the dl/nncf_multidigraph branch 5 times, most recently from 026b0c6 to 0065810 Compare March 18, 2026 17:02
"""
self._unite_with_copy_of_graph(other.graph)

def join_patterns_parallel(self, other: "GraphPattern", degree: int) -> None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add tests of new method
And test in test_graph_matching.py to matching graphs with parallel edges

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment on lines 651 to +652

for u, v in self._nx_graph.edges:
edge = self._nx_graph.edges[u, v]
for u, v, k in self._nx_graph.edges:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's can be used as

        for u, v, k, data in self._nx_graph.edges(keys=True, data=True):

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please provide an example when my approach won't work


pattern.add_edge(matmul_node, transpose_node)
pattern.add_edge(transpose_node, concat_node)
for _ in range(cat_degree):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we are discused rope pattern is broken,
Is it really need to change it now?

And mannually create two edge is shorter and easy to read

pattern.add_edge(transpose_node, concat_node)
pattern.add_edge(transpose_node, concat_node)

Than

    cat_degree = 2
    ...
    for _ in range(cat_degree):
       pattern.add_edge(transpose_node, concat_node)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept them for ONNX/Torch/TorchFX the same but skipped the test, please tell me if you want it in an outher way

Comment thread tests/common/graph/test_nncf_graph.py Outdated
for node in "abc":
nodes.append(nncf_graph.add_nncf_node(node, f"type_{node}", f"metatype_{node}"))

nncf_graph.add_edge_between_nncf_nodes(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like edge 0-0 is not used in this test

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed

INPUT_SIZE = [1, 10]

def __init__(self):
def __init__(self, degree: int):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we was discuesed you should ouse real implementation of rope module
WIth degree=1 it is not valide

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed

return x1, x2


class RoPEWCModel(nn.Module):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same it's not valide RoPE module
Please fix it another PR

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed

args = parser.parse_args(args=argv)

graph = nx.DiGraph(read_dot_graph(args.input_file))
graph = nx.MultiDiGraph(read_dot_graph(args.input_file))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you check script from tools/other?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, found a bug, fixed and tested manualy, thanks

@daniil-lyakhov daniil-lyakhov force-pushed the dl/nncf_multidigraph branch 3 times, most recently from 2f3fad8 to b6ab0bc Compare March 23, 2026 17:03
@AlexanderDokuchaev AlexanderDokuchaev merged commit 91e6c16 into openvinotoolkit:develop Mar 25, 2026
18 checks passed
AlexanderDokuchaev pushed a commit that referenced this pull request Mar 27, 2026
### Changes

Postprocessing of the ssd300_vgg model from the PTQ torch example is not
traced to skip quantization of the last multiplies

### Reason for changes

The #3843 fixed the issue with output ports ids:
https://github.com/openvinotoolkit/nncf/blob/develop/src/nncf/quantization/algorithms/min_max/algorithm.py#L792-L801

This lead to additional fqs in the ssd block hence fail of example test
https://github.com/openvinotoolkit/nncf/actions/runs/23580354118/job/68661703894

Before the multidigraph:
<img width="887" height="736" alt="image"
src="https://github.com/user-attachments/assets/0bc088bc-7e6f-486a-a4d4-7aa0b29696ce"
/>

After the multidigraph:
<img width="1299" height="269" alt="image"
src="https://github.com/user-attachments/assets/9301e3ff-a394-4962-a1fa-bdb159a0e03f"
/>

All the multiplies are being quantized in comparison with the old
approach - output_port == 0 and only one FQ was inserted

### Related tickets

183762

### Tests

Test examples:
https://github.com/openvinotoolkit/nncf/actions/runs/23602708146
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Code Freeze NNCF Common Pull request that updates NNCF Common NNCF ONNX Pull requests that updates NNCF ONNX NNCF OpenVINO Pull requests that updates NNCF OpenVINO NNCF PT Pull requests that updates NNCF PyTorch

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants