From fa924997967d9fa3908f59226a15e8cb60c48b23 Mon Sep 17 00:00:00 2001 From: Fabio Zanini Date: Wed, 25 Feb 2026 13:59:09 +1100 Subject: [PATCH] Export to pytorch geometric --- setup.py | 4 ++++ src/igraph/__init__.py | 12 ++++++---- src/igraph/io/libraries.py | 45 +++++++++++++++++++++++++++++++++++++ tests/test_foreign.py | 46 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 103 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index c919d3507..9c5b7c7f9 100644 --- a/setup.py +++ b/setup.py @@ -1018,6 +1018,10 @@ def get_tag(self): "plotly>=5.3.0", "Pillow>=9; platform_python_implementation != 'PyPy'", ], + "test-pyg": [ + "torch>=2.0.0; platform_python_implementation != 'PyPy'", + "torch-geometric>=2.0.0; platform_python_implementation != 'PyPy'", + ], # Dependencies needed for testing on Windows ARM64; only those that are either # pure Python or have Windows ARM64 wheels as we don't want to compile wheels # in CI diff --git a/src/igraph/__init__.py b/src/igraph/__init__.py index 6a4e189b9..85ad1a47a 100644 --- a/src/igraph/__init__.py +++ b/src/igraph/__init__.py @@ -211,6 +211,7 @@ _export_graph_to_networkx, _construct_graph_from_graph_tool, _export_graph_to_graph_tool, + _export_graph_to_torch_geometric, ) from igraph.io.random import ( _construct_random_geometric_graph, @@ -463,6 +464,8 @@ def __init__(self, *args, **kwds): from_graph_tool = classmethod(_construct_graph_from_graph_tool) to_graph_tool = _export_graph_to_graph_tool + to_torch_geometric = _export_graph_to_torch_geometric + # Files Read_DIMACS = classmethod(_construct_graph_from_dimacs_file) write_dimacs = _write_graph_to_dimacs_file @@ -708,7 +711,9 @@ def es(self): ########################### # Paths/traversals - def get_all_simple_paths(self, v, to=None, minlen=0, maxlen=-1, mode="out", max_results=None): + def get_all_simple_paths( + self, v, to=None, minlen=0, maxlen=-1, mode="out", max_results=None + ): """Calculates all the simple paths from a given node to some other nodes (or all of them) in a graph. @@ -973,15 +978,14 @@ def Incidence(cls, *args, **kwds): def are_connected(self, *args, **kwds): """Deprecated alias to L{Graph.are_adjacent()}.""" deprecated( - "Graph.are_connected() is deprecated; use Graph.are_adjacent() " "instead" + "Graph.are_connected() is deprecated; use Graph.are_adjacent() instead" ) return self.are_adjacent(*args, **kwds) def get_incidence(self, *args, **kwds): """Deprecated alias to L{Graph.get_biadjacency()}.""" deprecated( - "Graph.get_incidence() is deprecated; use Graph.get_biadjacency() " - "instead" + "Graph.get_incidence() is deprecated; use Graph.get_biadjacency() instead" ) return self.get_biadjacency(*args, **kwds) diff --git a/src/igraph/io/libraries.py b/src/igraph/io/libraries.py index f35cc9545..9b06f41f1 100644 --- a/src/igraph/io/libraries.py +++ b/src/igraph/io/libraries.py @@ -270,3 +270,48 @@ def _construct_graph_from_graph_tool(cls, g): graph.add_edges(edges, eattr) return graph + + +def _export_graph_to_torch_geometric( + graph, vertex_attributes=None, edge_attributes=None +): + """Converts the graph to torch geometric + + Data types: graph-tool only accepts specific data types. See the + following web page for a list: + + https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data + + @param g: graph-tool Graph + @param vertex_attributes: dictionary of vertex attributes to transfer. + Keys are attributes from the vertices, values are data types (see + below). C{None} means no vertex attributes are transferred. + @param edge_attributes: dictionary of edge attributes to transfer. + Keys are attributes from the edges, values are data types (see + below). C{None} means no vertex attributes are transferred. + """ + import torch + from torch_geometric.data import Data + + if vertex_attributes is None: + vertex_attributes = graph.vertex_attributes() + if edge_attributes is None: + edge_attributes = graph.edge_attributes() + + # Edge index + edge_index = torch.tensor(graph.get_edgelist(), dtype=torch.long) + + # Node attributes + x = torch.tensor([graph.vs[attr] for attr in vertex_attributes]) + if x.ndim > 1: + x = x.permute(*torch.arange(x.ndim - 1, -1, -1)) + + # Edge attributes + edge_attr = torch.tensor([graph.es[attr] for attr in edge_attributes]) + if edge_attr.ndim > 1: + edge_attr = edge_attr.permute(*torch.arange(edge_attr.ndim - 1, -1, -1)) + + # Wrap into correct data structure + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + + return data diff --git a/tests/test_foreign.py b/tests/test_foreign.py index 83664e5f5..b9caf6b52 100644 --- a/tests/test_foreign.py +++ b/tests/test_foreign.py @@ -23,6 +23,14 @@ pd = None +try: + import torch + from torch_geometric.data import Data as PyGData +except ImportError: + torch = None + PyGData = None + + GRAPHML_EXAMPLE_FILE = """\