Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pages/advanced-algorithms/available-algorithms.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ If you want to know more and learn how this affects you, read our [announcement]

| Algorithms | Lang | Description |
|---------------------------------------------------------------------------------------------------|--------|-----------------------------------------------------------------------------------------------------------------------------------------------------------|
| [gnn](/advanced-algorithms/available-algorithms/gnn) | Python | Export and import graph data in PyTorch Geometric (PyG) and TensorFlow GNN (TF-GNN) formats for GNN training pipelines. |
| [link_prediction with GNN](/advanced-algorithms/available-algorithms/gnn_link_prediction) | Python | Module for predicting links in the graph by using graph neural networks. |
| [node_classification with GNN](/advanced-algorithms/available-algorithms/gnn_node_classification) | Python | Graph neural network-based node classification module |
| [node2vec](/advanced-algorithms/available-algorithms/node2vec) | Python | An algorithm for calculating node embeddings on static graph. |
Expand Down
1 change: 1 addition & 0 deletions pages/advanced-algorithms/available-algorithms/_meta.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ export default {
"elasticsearch_synchronization": "elasticsearch_synchronization",
"embeddings": "embeddings",
"export_util": "export_util",
"gnn": "gnn",
"gnn_link_prediction": "gnn_link_prediction",
"gnn_node_classification": "gnn_node_classification",
"graph_analyzer": "graph_analyzer",
Expand Down
285 changes: 285 additions & 0 deletions pages/advanced-algorithms/available-algorithms/gnn.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
---
title: gnn
description: Export and import graph data in PyTorch Geometric (PyG) and TensorFlow GNN (TF-GNN) formats. Enables GNN training pipelines by converting Memgraph graphs to framework-native JSON representations and writing back inference results.
---

import { Callout } from 'nextra/components'
import { Cards } from 'nextra/components'
import GitHub from '/components/icons/GitHub'

# gnn

GNN integration module for Memgraph. Provides export/import procedures for
**[PyTorch Geometric (PyG)](https://pytorch-geometric.readthedocs.io/en/latest/)** and **[TensorFlow GNN (TF-GNN)](https://github.com/tensorflow/gnn)** formats. All
exports produce a single JSON string that can be deserialized on the client side
and fed into the respective framework.

Typical workflow:

1. **Export** – extract the graph (or a subgraph) from Memgraph into a
JSON representation that PyG or TF-GNN can consume directly.

2. **Train / Infer** – use the exported data in your GNN pipeline outside
Memgraph.

3. **Import** – write new nodes and relationships back into Memgraph from
the framework's output, or update existing nodes with inference results.

<Cards>
<Cards.Card
icon={<GitHub />}
title="Source code"
href="https://github.com/memgraph/memgraph/blob/master/mage/python/gnn.py"
/>
</Cards>

| Trait | Value |
| ------------------- | ---------- |
| **Module type** | module |
| **Implementation** | Python |
| **Parallelism** | sequential |

## Procedures

### `pyg_export()`

Exports the current graph to a JSON string in **PyTorch Geometric** format.

The JSON payload contains:
- `edge_index` – source and destination index arrays.
- `x` – node feature matrix (when `node_property_names` is provided).
- `edge_attr` – edge feature matrix (when `edge_property_names` is provided).
- `y` – node labels (when `node_label_property` is provided).
- `num_nodes` – total number of nodes.
- `node_id_mapping` / `idx_to_node_id` – bidirectional mapping between
Memgraph internal IDs and PyG indices (used for write-back).
- `labels` – original node labels.
- `edge_types` – original relationship types.

{<h4 className="custom-header"> Input: </h4>}

- `node_property_names: List[string] (default = null)` ➡ Node properties to
include in the feature matrix `x`. Numeric properties are cast to floats;
list properties are flattened.
- `edge_property_names: List[string] (default = null)` ➡ Edge properties to
include in `edge_attr`.
- `node_label_property: string (default = null)` ➡ Node property to use as
the target label vector `y`.

{<h4 className="custom-header"> Output: </h4>}

- `json_data: string` ➡ A JSON string representing the graph in PyG format.

{<h4 className="custom-header"> Usage: </h4>}

Export features `feat` and edge attribute `weight`, with `class` as the
target label:

```cypher
CALL gnn.pyg_export(["feat"], ["weight"], "class")
YIELD json_data
RETURN json_data;
```

Export with no features (topology only):

```cypher
CALL gnn.pyg_export()
YIELD json_data
RETURN json_data;
```

---

### `pyg_import()`

Imports data from a PyG JSON string into Memgraph. Supports two modes:

- **Create mode** (default) – creates new nodes and relationships.
- **Update mode** (`update_existing = true`) – uses the `idx_to_node_id`
mapping in the JSON payload to find existing Memgraph nodes and sets
properties on them. This is the typical *export → inference → write-back*
workflow.

{<h4 className="custom-header"> Input: </h4>}

- `json_data: string` ➡ JSON string previously produced by `pyg_export()` (or
any compatible PyG-format JSON).
- `default_node_label: string (default = "PyGNode")` ➡ Label assigned to
created nodes when no label information is present in the JSON.
- `default_edge_type: string (default = "CONNECTS")` ➡ Relationship type
assigned to created relationships when no type information is present.
- `node_property_names: List[string] (default = null)` ➡ Names to assign
to individual feature columns when importing the feature matrix `x`.
- `edge_property_names: List[string] (default = null)` ➡ Names to assign
to individual edge-attribute columns when importing `edge_attr`.
- `update_existing: boolean (default = false)` ➡ When `true`, existing nodes
are updated instead of creating new ones.

{<h4 className="custom-header"> Output: </h4>}

- `nodes_created: integer` ➡ Number of nodes created (0 in update mode).
- `edges_created: integer` ➡ Number of relationships created (0 in update mode).
- `nodes_updated: integer` ➡ Number of existing nodes updated (0 in create mode).

{<h4 className="custom-header"> Usage: </h4>}

**Roundtrip example** – export from the graph and import as new nodes:

```cypher
CALL gnn.pyg_export(["feat"], ["weight"], "class")
YIELD json_data
WITH json_data
CALL gnn.pyg_import(json_data, "Imported", "IMP", ["feat"], ["weight"])
YIELD nodes_created, edges_created
RETURN nodes_created, edges_created;
```

**Write-back example** – update existing nodes with predictions after
inference:

```cypher
CALL gnn.pyg_import($json_data, "Node", "EDGE", ["prediction"], null, true)
YIELD nodes_updated
RETURN nodes_updated;
```

<Callout type="info">
In update mode the procedure uses the `idx_to_node_id` mapping inside the JSON
payload to look up existing vertices by their Memgraph internal ID. Make sure
the JSON was originally exported from the same database.
</Callout>

---

### `tf_export()`

Exports the current graph to a JSON string in **TF-GNN** format.

The JSON payload contains:
- `schema` – describes node sets, edge sets and their feature schemas
(dtypes and shapes), matching the TF-GNN `GraphSchema` structure.
- `graph` – the actual graph data with feature values, sizes and adjacency
information.

{<h4 className="custom-header"> Input: </h4>}

- `node_property_names: List[string] (default = null)` ➡ Node properties to
include as node-set features.
- `edge_property_names: List[string] (default = null)` ➡ Edge properties to
include as edge-set features.
- `node_set_name: string (default = "node")` ➡ Name of the node set in the
TF-GNN schema.
- `edge_set_name: string (default = "edge")` ➡ Name of the edge set in the
TF-GNN schema.

{<h4 className="custom-header"> Output: </h4>}

- `json_data: string` ➡ A JSON string representing the graph in TF-GNN format.

{<h4 className="custom-header"> Usage: </h4>}

Export with node property `score` and edge property `weight`:

```cypher
CALL gnn.tf_export(["score"], ["weight"])
YIELD json_data
RETURN json_data;
```

Specify custom set names:

```cypher
CALL gnn.tf_export(["score"], ["weight"], "items", "similarities")
YIELD json_data
RETURN json_data;
```

---

### `tf_import()`

Imports data from a TF-GNN JSON string into Memgraph, creating new nodes
and relationships.

{<h4 className="custom-header"> Input: </h4>}

- `json_data: string` ➡ JSON string previously produced by `tf_export()` (or
any compatible TF-GNN-format JSON).
- `default_node_label: string (default = "TfGnnNode")` ➡ Label assigned to
created nodes when no label information is present.
- `default_edge_type: string (default = "CONNECTS")` ➡ Relationship type
assigned to created relationships when no type information is present.

{<h4 className="custom-header"> Output: </h4>}

- `nodes_created: integer` ➡ Number of nodes created.
- `edges_created: integer` ➡ Number of relationships created.

{<h4 className="custom-header"> Usage: </h4>}

**Roundtrip example** – export and re-import:

```cypher
CALL gnn.tf_export(["score"], ["weight"])
YIELD json_data
WITH json_data
CALL gnn.tf_import(json_data, "TfNode", "TF_EDGE")
YIELD nodes_created, edges_created
RETURN nodes_created, edges_created;
```

---

## Example

The following end-to-end example shows how to move graph data through a PyG
training pipeline.

**1. Create sample data:**

```cypher
CREATE (a:Person {feat: [1.0, 2.0], age: 30, class: 0})
-[:KNOWS {weight: 0.5}]->
(b:Person {feat: [3.0, 4.0], age: 25, class: 1})
-[:KNOWS {weight: 0.8}]->
(c:Person {feat: [5.0, 6.0], age: 35, class: 0});
```

**2. Export to PyG format:**

```cypher
CALL gnn.pyg_export(["feat"], ["weight"], "class")
YIELD json_data
RETURN json_data;
```

**3. Use the JSON payload in Python (client-side):**

```python
import json
import torch
from torch_geometric.data import Data

# result is the json_data string returned by Memgraph
pyg_dict = json.loads(result)

data = Data(
x=torch.tensor(pyg_dict["x"], dtype=torch.float),
edge_index=torch.tensor(pyg_dict["edge_index"], dtype=torch.long),
edge_attr=torch.tensor(pyg_dict["edge_attr"], dtype=torch.float),
y=torch.tensor(pyg_dict["y"], dtype=torch.long),
)
# Train your model ...
```

**4. Write predictions back to Memgraph:**

After inference, update your JSON payload with the predictions and call
`pyg_import` with `update_existing` set to `true`:

```cypher
CALL gnn.pyg_import($updated_json, "Person", "KNOWS", ["prediction"], null, true)
YIELD nodes_updated
RETURN nodes_updated;
```