Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion gigl/distributed/base_dist_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ def create_sampling_config(
batch_size: int = 1,
shuffle: bool = False,
drop_last: bool = False,
seed: Optional[int] = None,
) -> SamplingConfig:
"""Creates a SamplingConfig with patched fanout.

Expand All @@ -352,6 +353,8 @@ def create_sampling_config(
batch_size: How many samples per batch.
shuffle: Whether to shuffle input nodes.
drop_last: Whether to drop the last incomplete batch.
seed: When provided, seeds the sampling RNG so that the same inputs produce the same
batches across runs. When None, sampling is non-deterministic.

Returns:
A fully configured SamplingConfig.
Expand All @@ -371,7 +374,7 @@ def create_sampling_config(
with_neg=False,
with_weight=False,
edge_dir=dataset_schema.edge_dir,
seed=None,
seed=seed,
)

@staticmethod
Expand Down
12 changes: 11 additions & 1 deletion gigl/distributed/dataset_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def _load_and_build_partitioned_dataset(
edge_tf_dataset_options: TFDatasetOptions,
splitter: Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]] = None,
_ssl_positive_label_percentage: Optional[float] = None,
seed: Optional[int] = None,
) -> DistDataset:
"""
Given some information about serialized TFRecords, loads and builds a partitioned dataset into a DistDataset class.
Expand All @@ -82,6 +83,8 @@ def _load_and_build_partitioned_dataset(
splitter (Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]]): Optional splitter to use for splitting the graph data into train, val, and test sets. If not provided (None), no splitting will be performed.
_ssl_positive_label_percentage (Optional[float]): Percentage of edges to select as self-supervised labels. Must be None if supervised edge labels are provided in advance.
Slotted for refactor once this functionality is available in the transductive `splitter` directly
seed (Optional[int]): When provided, seeds the partitioner RNG so that the same inputs produce the same
partition assignment across runs. When None, partitioning is non-deterministic.
Returns:
DistDataset: Initialized dataset with partitioned graph information

Expand Down Expand Up @@ -164,7 +167,8 @@ def _load_and_build_partitioned_dataset(
f"Initializing {partitioner_class.__name__} instance while partitioning edges to its destination node machine"
)
partitioner = partitioner_class(
should_assign_edges_by_src_node=should_assign_edges_by_src_node
should_assign_edges_by_src_node=should_assign_edges_by_src_node,
seed=seed,
)

partitioner.register_node_ids(node_ids=loaded_graph_tensors.node_ids)
Expand Down Expand Up @@ -230,6 +234,7 @@ def _build_dataset_process(
edge_tf_dataset_options: TFDatasetOptions,
splitter: Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]] = None,
_ssl_positive_label_percentage: Optional[float] = None,
seed: Optional[int] = None,
) -> None:
"""
This function is spawned by a single process per machine and is responsible for:
Expand Down Expand Up @@ -311,6 +316,7 @@ def _build_dataset_process(
edge_tf_dataset_options=edge_tf_dataset_options,
splitter=splitter,
_ssl_positive_label_percentage=_ssl_positive_label_percentage,
seed=seed,
)

output_dict["dataset"] = output_dataset
Expand All @@ -336,6 +342,7 @@ def build_dataset(
_dataset_building_port: Optional[
int
] = None, # WARNING: This field will be deprecated in the future
seed: Optional[int] = None,
) -> DistDataset:
"""
Launches a spawned process for building and returning a DistDataset instance provided some
Expand Down Expand Up @@ -368,6 +375,8 @@ def build_dataset(
Slotted for refactor once this functionality is available in the transductive `splitter` directly
_dataset_building_port (deprecated field - will be removed soon) (Optional[int]): Contains information about master port. Defaults to None, in which case it will
be initialized from the current torch.distributed context.
seed (Optional[int]): When provided, seeds the partitioner RNG so that the same inputs produce the same
partition assignment across runs. When None, partitioning is non-deterministic.

Returns:
DistDataset: Built GraphLearn-for-PyTorch Dataset class
Expand Down Expand Up @@ -463,6 +472,7 @@ def build_dataset(
edge_tf_dataset_options,
splitter,
_ssl_positive_label_percentage,
seed,
),
)

Expand Down
5 changes: 5 additions & 0 deletions gigl/distributed/dist_ablp_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
local_process_rank: Optional[int] = None, # TODO: (svij) Deprecate this
local_process_world_size: Optional[int] = None, # TODO: (svij) Deprecate this
non_blocking_transfers: bool = True,
seed: Optional[int] = None,
):
"""
Neighbor loader for Anchor Based Link Prediction (ABLP) tasks.
Expand Down Expand Up @@ -215,6 +216,9 @@ def __init__(
is used instead.
See https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html
for background on pinned memory and non-blocking transfers.
seed (Optional[int]): When provided, seeds the sampling RNG so that the same inputs
produce the same batches across runs. When None, sampling is non-deterministic.
(default: ``None``).
"""

# Set self._shutdowned right away, that way if we throw here, and __del__ is called,
Expand Down Expand Up @@ -349,6 +353,7 @@ def __init__(
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
seed=seed,
)

producer: Optional[DistSamplingProducer] = None
Expand Down
11 changes: 10 additions & 1 deletion gigl/distributed/dist_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def __init__(
Union[torch.Tensor, dict[EdgeType, torch.Tensor]]
] = None,
node_labels: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]] = None,
seed: Optional[int] = None,
):
"""
Initializes the parameters of the partitioner. Also optionally takes in node and edge tensors as arguments and registers them to the partitioner. Registered
Expand All @@ -176,13 +177,16 @@ def __init__(
positive_labels (Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]): Optionally registered positive labels from input. Tensors should be of shape [2, num_pos_labels_on_current_rank]
negative_labels (Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]): Optionally registered negative labels from input. Tensors should be of shape [2, num_neg_labels_on_current_rank]
node_labels (Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]]): Optionally registered node labels from input. Tensors should be of shape [num_nodes_on_current_rank, node_label_dim]
seed (Optional[int]): When provided, seeds the random permutation used during node partitioning so that
the same inputs produce the same partition assignment across runs. When None, partitioning is non-deterministic.
"""

self._world_size: int
self._rank: int
self._partition_mgr: _DistLinkPredicitonPartitionManager

self._is_rpc_initialized: bool = False
self._seed: Optional[int] = seed

self._is_input_homogeneous: Optional[bool] = None
self._should_assign_edges_by_src_node: bool = should_assign_edges_by_src_node
Expand Down Expand Up @@ -847,7 +851,12 @@ def _partition_node(self, node_type: NodeType) -> PartitionBook:
# TODO (mkolodner-sc): Explore other node partitioning strategies here beyond random permutation
def _node_pfn(n_ids, _):
partition_idx = n_ids % self._world_size
rand_order = torch.randperm(len(n_ids))
if self._seed is not None:
generator = torch.Generator()
generator.manual_seed(self._seed)
rand_order = torch.randperm(len(n_ids), generator=generator)
else:
rand_order = torch.randperm(len(n_ids))
return partition_idx[rand_order]

partitioned_results, node_partition_book = self._partition_by_chunk(
Expand Down
5 changes: 5 additions & 0 deletions gigl/distributed/distributed_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(
drop_last: bool = False,
sampler_options: Optional[SamplerOptions] = None,
non_blocking_transfers: bool = True,
seed: Optional[int] = None,
):
"""
Distributed Neighbor Loader.
Expand Down Expand Up @@ -170,6 +171,9 @@ def __init__(
is used instead.
See https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html
for background on pinned memory and non-blocking transfers.
seed (Optional[int]): When provided, seeds the sampling RNG so that the same inputs
produce the same batches across runs. When None, sampling is non-deterministic.
(default: ``None``).
"""

# Set self._shutdowned right away, that way if we throw here, and __del__ is called,
Expand Down Expand Up @@ -263,6 +267,7 @@ def __init__(
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
seed=seed,
)

producer: Optional[DistSamplingProducer] = None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Type, Union
from typing import Optional, Type, Union

import torch
from graphlearn_torch.distributed import init_rpc, init_worker_group
Expand Down Expand Up @@ -31,6 +31,7 @@ def run_distributed_partitioner(
master_port: int,
input_data_strategy: InputDataStrategy,
partitioner_class: Type[DistPartitioner],
seed: Optional[int] = None,
) -> None:
"""
Runs the distributed partitioner on a specific rank.
Expand All @@ -44,6 +45,7 @@ def run_distributed_partitioner(
master_port (int): Master port for initializing rpc for partitioning
input_data_strategy (InputDataStrategy): Strategy for registering inputs to the partitioner
partitioner_class (Type[DistPartitioner]): The class to use for partitioning
seed (Optional[int]): Seed for deterministic node partitioning. When None, partitioning is non-deterministic.
"""

input_graph = rank_to_input_graph[rank]
Expand Down Expand Up @@ -81,6 +83,7 @@ def run_distributed_partitioner(
if input_data_strategy == InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY:
dist_partitioner = partitioner_class(
should_assign_edges_by_src_node=should_assign_edges_by_src_node,
seed=seed,
)
# We call del to mimic the real use case for handling these input tensors
dist_partitioner.register_node_ids(node_ids=node_ids)
Expand Down Expand Up @@ -139,6 +142,7 @@ def run_distributed_partitioner(
elif input_data_strategy == InputDataStrategy.REGISTER_MINIMAL_ENTITIES_SEPARATELY:
dist_partitioner = partitioner_class(
should_assign_edges_by_src_node=should_assign_edges_by_src_node,
seed=seed,
)
# We call del to mimic the real use case for handling these input tensors
dist_partitioner.register_node_ids(node_ids=node_ids)
Expand Down Expand Up @@ -176,6 +180,7 @@ def run_distributed_partitioner(
positive_labels=positive_labels,
negative_labels=negative_labels,
node_labels=node_labels,
seed=seed,
)
# We call del to mimic the real use case for handling these input tensors
del (
Expand Down
112 changes: 112 additions & 0 deletions tests/unit/distributed/distributed_neighborloader_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
from collections.abc import Mapping
from multiprocessing import Manager

import torch
import torch.multiprocessing as mp
Expand All @@ -8,9 +9,11 @@
from parameterized import param, parameterized
from torch_geometric.data import Data, HeteroData

from gigl.distributed.base_dist_loader import BaseDistLoader
from gigl.distributed.dataset_factory import build_dataset
from gigl.distributed.dist_dataset import DistDataset
from gigl.distributed.distributed_neighborloader import DistNeighborLoader
from gigl.distributed.utils.neighborloader import DatasetSchema
from gigl.distributed.utils import get_free_port
from gigl.distributed.utils.serialized_graph_metadata_translator import (
convert_pb_to_serialized_graph_metadata,
Expand Down Expand Up @@ -398,6 +401,27 @@ def _run_cora_supervised_node_classification(
shutdown_rpc()


def _run_seeded_e2e_loader_worker(
    _: int,
    dataset: DistDataset,
    output_list,
    seed: int,
) -> None:
    """Spawned worker: iterate a seeded DistNeighborLoader and record each batch's node tensor.

    Appends a clone of every batch's ``node`` tensor to ``output_list`` (a
    multiprocessing manager list) so the parent process can compare runs.
    """
    create_test_process_group()
    neighbor_loader = DistNeighborLoader(
        dataset=dataset,
        num_neighbors=[1, 1],
        pin_memory_device=torch.device("cpu"),
        batch_size=5,
        shuffle=True,
        seed=seed,
    )
    for batch in neighbor_loader:
        assert isinstance(batch, Data)
        # Clone so the tensor survives the loader/worker teardown.
        output_list.append(batch.node.clone())
    shutdown_rpc()


class DistributedNeighborLoaderTest(TestCase):
def setUp(self):
super().setUp()
Expand Down Expand Up @@ -861,5 +885,93 @@ def test_distributed_neighbor_loader_invalid_inputs_colocated(
DistNeighborLoader(**kwargs)


class CreateSamplingConfigSeedTestCase(TestCase):
    """Tests that create_sampling_config correctly threads the seed into SamplingConfig."""

    def _make_dataset_schema(self) -> DatasetSchema:
        # Minimal homogeneous schema; only the fields required by the factory are set.
        return DatasetSchema(
            is_homogeneous_with_labeled_edge_type=False,
            edge_types=[DEFAULT_HOMOGENEOUS_EDGE_TYPE],
            node_feature_info=None,
            edge_feature_info=None,
            edge_dir="in",
        )

    def test_seed_is_propagated_to_sampling_config(self) -> None:
        sampling_config = BaseDistLoader.create_sampling_config(
            num_neighbors=[5, 5],
            dataset_schema=self._make_dataset_schema(),
            seed=42,
        )
        self.assertEqual(sampling_config.seed, 42)

    def test_no_seed_leaves_config_seed_none(self) -> None:
        # Omitting `seed` must leave the config's seed unset (non-deterministic sampling).
        sampling_config = BaseDistLoader.create_sampling_config(
            num_neighbors=[5, 5],
            dataset_schema=self._make_dataset_schema(),
        )
        self.assertIsNone(sampling_config.seed)


class DistNeighborLoaderE2EDeterminismTestCase(TestCase):
    """Verifies that identical seeds produce identical batch ordering and neighbor samples end-to-end."""

    # Number of nodes in the ring graph fixture.
    _N: int = 50

    def _make_ring_dataset(self) -> DistDataset:
        """Build a single-partition DistDataset over a directed ring i -> (i + 1) mod N."""
        num_nodes = self._N
        node_ids = torch.arange(num_nodes)
        edge_index = torch.stack([node_ids, (node_ids + 1) % num_nodes])
        partition_output = PartitionOutput(
            node_partition_book=torch.zeros(num_nodes, dtype=torch.int64),
            edge_partition_book=None,
            partitioned_edge_index=GraphPartitionData(
                edge_index=edge_index,
                edge_ids=None,
            ),
            partitioned_node_features=FeaturePartitionData(
                feats=torch.zeros(num_nodes, 2), ids=torch.arange(num_nodes)
            ),
            partitioned_edge_features=None,
            partitioned_positive_labels=None,
            partitioned_negative_labels=None,
            partitioned_node_labels=None,
        )
        ring_dataset = DistDataset(rank=0, world_size=1, edge_dir="in")
        ring_dataset.build(partition_output=partition_output)
        return ring_dataset

    def _collect_batches(self, dataset: DistDataset, seed: int) -> list[torch.Tensor]:
        """Run the seeded loader in one spawned process and return the per-batch node tensors."""
        # Keep the manager object alive for the duration of the spawn so its
        # proxy list remains usable.
        state_manager = Manager()
        collected = state_manager.list()
        mp.spawn(
            fn=_run_seeded_e2e_loader_worker,
            args=(dataset, collected, seed),
            nprocs=1,
            join=True,
        )
        return list(collected)

    def test_seeded_sampling_is_e2e_deterministic(self) -> None:
        """Same seed must yield identical batch ordering and neighbor samples on a 50-node ring graph."""
        dataset = self._make_ring_dataset()
        first_run = self._collect_batches(dataset, seed=42)
        second_run = self._collect_batches(dataset, seed=42)

        self.assertEqual(
            len(first_run),
            len(second_run),
            msg="Number of batches differs between runs with identical seed",
        )
        # Lengths are equal at this point, so a plain zip covers every batch.
        for i, (first_batch, second_batch) in enumerate(zip(first_run, second_run)):
            self.assertTrue(
                torch.equal(first_batch, second_batch),
                msg=f"Batch {i} node tensors differ despite identical seed",
            )


if __name__ == "__main__":
absltest.main()
Loading