Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,10 @@ def get_tag(self):
"plotly>=5.3.0",
"Pillow>=9; platform_python_implementation != 'PyPy'",
],
"test-pyg": [
"torch>=2.0.0; platform_python_implementation != 'PyPy'",
"torch-geometric>=2.0.0; platform_python_implementation != 'PyPy'",
],
# Dependencies needed for testing on Windows ARM64; only those that are either
# pure Python or have Windows ARM64 wheels as we don't want to compile wheels
# in CI
Expand Down
12 changes: 8 additions & 4 deletions src/igraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@
_export_graph_to_networkx,
_construct_graph_from_graph_tool,
_export_graph_to_graph_tool,
_export_graph_to_torch_geometric,
)
from igraph.io.random import (
_construct_random_geometric_graph,
Expand Down Expand Up @@ -463,6 +464,8 @@ def __init__(self, *args, **kwds):
from_graph_tool = classmethod(_construct_graph_from_graph_tool)
to_graph_tool = _export_graph_to_graph_tool

to_torch_geometric = _export_graph_to_torch_geometric

# Files
Read_DIMACS = classmethod(_construct_graph_from_dimacs_file)
write_dimacs = _write_graph_to_dimacs_file
Expand Down Expand Up @@ -708,7 +711,9 @@ def es(self):

###########################
# Paths/traversals
def get_all_simple_paths(self, v, to=None, minlen=0, maxlen=-1, mode="out", max_results=None):
def get_all_simple_paths(
self, v, to=None, minlen=0, maxlen=-1, mode="out", max_results=None
):
"""Calculates all the simple paths from a given node to some other nodes
(or all of them) in a graph.

Expand Down Expand Up @@ -973,15 +978,14 @@ def Incidence(cls, *args, **kwds):
def are_connected(self, *args, **kwds):
"""Deprecated alias to L{Graph.are_adjacent()}."""
deprecated(
"Graph.are_connected() is deprecated; use Graph.are_adjacent() " "instead"
"Graph.are_connected() is deprecated; use Graph.are_adjacent() instead"
)
return self.are_adjacent(*args, **kwds)

def get_incidence(self, *args, **kwds):
"""Deprecated alias to L{Graph.get_biadjacency()}."""
deprecated(
"Graph.get_incidence() is deprecated; use Graph.get_biadjacency() "
"instead"
"Graph.get_incidence() is deprecated; use Graph.get_biadjacency() instead"
)
return self.get_biadjacency(*args, **kwds)

Expand Down
45 changes: 45 additions & 0 deletions src/igraph/io/libraries.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,48 @@ def _construct_graph_from_graph_tool(cls, g):
graph.add_edges(edges, eattr)

return graph


def _export_graph_to_torch_geometric(
graph, vertex_attributes=None, edge_attributes=None
):
"""Converts the graph to torch geometric

Data types: graph-tool only accepts specific data types. See the
following web page for a list:

https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data

@param g: graph-tool Graph
@param vertex_attributes: dictionary of vertex attributes to transfer.
Keys are attributes from the vertices, values are data types (see
below). C{None} means no vertex attributes are transferred.
@param edge_attributes: dictionary of edge attributes to transfer.
Keys are attributes from the edges, values are data types (see
below). C{None} means no vertex attributes are transferred.
"""
import torch
from torch_geometric.data import Data

if vertex_attributes is None:
vertex_attributes = graph.vertex_attributes()
if edge_attributes is None:
edge_attributes = graph.edge_attributes()

# Edge index
edge_index = torch.tensor(graph.get_edgelist(), dtype=torch.long)

# Node attributes
x = torch.tensor([graph.vs[attr] for attr in vertex_attributes])
if x.ndim > 1:
x = x.permute(*torch.arange(x.ndim - 1, -1, -1))

# Edge attributes
edge_attr = torch.tensor([graph.es[attr] for attr in edge_attributes])
if edge_attr.ndim > 1:
edge_attr = edge_attr.permute(*torch.arange(edge_attr.ndim - 1, -1, -1))

# Wrap into correct data structure
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

return data
46 changes: 46 additions & 0 deletions tests/test_foreign.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@
pd = None


try:
import torch
from torch_geometric.data import Data as PyGData
except ImportError:
torch = None
PyGData = None


GRAPHML_EXAMPLE_FILE = """\
<?xml version="1.0" encoding="UTF-8"?>
<graphml xmlns="http://graphml.graphdrawing.org/xmlns"
Expand Down Expand Up @@ -821,6 +829,44 @@ def testGraphGraphTool(self):
self.assertEqual(g.vcount(), g2.vcount())
self.assertEqual(sorted(g.get_edgelist()), sorted(g2.get_edgelist()))

@unittest.skipIf(PyGData is None, "test case depends on torch_geometric")
def testGraphTorchGeometric(self):
# Undirected
g = Graph.Ring(10)
g.vs["vattr"] = list(range(g.vcount()))
g.es["eattr"] = list(range(len(g.es)))

# Go to torch geometric
data_pyg = g.to_torch_geometric()

self.assertEqual(g.vcount(), data_pyg.num_nodes)
self.assertEqual(
sorted([list(x) for x in g.get_edgelist()]),
sorted(data_pyg.edge_index.tolist()),
)

# Test attributes
self.assertEqual(g.vs["vattr"], data_pyg.x[:, 0].tolist())
self.assertEqual(g.es["eattr"], data_pyg.edge_attr[:, 0].tolist())

# Directed
g = Graph.Ring(10, directed=True)
g.vs["vattr"] = list(range(g.vcount()))
g.es["eattr"] = list(range(len(g.es)))

# Go to torch geometric
data_pyg = g.to_torch_geometric()

self.assertEqual(g.vcount(), data_pyg.num_nodes)
self.assertEqual(
sorted([list(x) for x in g.get_edgelist()]),
sorted(data_pyg.edge_index.tolist()),
)

# Test attributes
self.assertEqual(g.vs["vattr"], data_pyg.x[:, 0].tolist())
self.assertEqual(g.es["eattr"], data_pyg.edge_attr[:, 0].tolist())

@unittest.skipIf(gt is None, "test case depends on graph-tool")
def testMultigraphGraphTool(self):
# Undirected
Expand Down
Loading