From 2d5624c39cc2075a921b82b279ddd5a7533a57de Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 12 May 2026 22:36:30 +0000 Subject: [PATCH] Update --- gigl/distributed/base_dist_loader.py | 5 +- gigl/distributed/dataset_factory.py | 12 +- gigl/distributed/dist_ablp_neighborloader.py | 5 + gigl/distributed/dist_partitioner.py | 11 +- .../distributed/distributed_neighborloader.py | 5 + .../run_distributed_partitioner.py | 7 +- .../distributed_neighborloader_test.py | 112 ++++++++++++++++++ .../distributed_partitioner_test.py | 76 ++++++++++++ 8 files changed, 229 insertions(+), 4 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 203c8520d..25f1ef490 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -340,6 +340,7 @@ def create_sampling_config( batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, + seed: Optional[int] = None, ) -> SamplingConfig: """Creates a SamplingConfig with patched fanout. @@ -352,6 +353,8 @@ def create_sampling_config( batch_size: How many samples per batch. shuffle: Whether to shuffle input nodes. drop_last: Whether to drop the last incomplete batch. + seed: When provided, seeds the sampling RNG so that the same inputs produce the same + batches across runs. When None, sampling is non-deterministic. Returns: A fully configured SamplingConfig. @@ -371,7 +374,7 @@ def create_sampling_config( with_neg=False, with_weight=False, edge_dir=dataset_schema.edge_dir, - seed=None, + seed=seed, ) @staticmethod diff --git a/gigl/distributed/dataset_factory.py b/gigl/distributed/dataset_factory.py index ffa13fc71..3327e7319 100644 --- a/gigl/distributed/dataset_factory.py +++ b/gigl/distributed/dataset_factory.py @@ -66,6 +66,7 @@ def _load_and_build_partitioned_dataset( edge_tf_dataset_options: TFDatasetOptions, splitter: Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]] = None, _ssl_positive_label_percentage: Optional[float] = None, + seed: Optional[int] = None, ) -> DistDataset: """ Given some information about serialized TFRecords, loads and builds a partitioned dataset into a DistDataset class. @@ -82,6 +83,8 @@ def _load_and_build_partitioned_dataset( splitter (Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]]): Optional splitter to use for splitting the graph data into train, val, and test sets. If not provided (None), no splitting will be performed. _ssl_positive_label_percentage (Optional[float]): Percentage of edges to select as self-supervised labels. Must be None if supervised edge labels are provided in advance. Slotted for refactor once this functionality is available in the transductive `splitter` directly + seed (Optional[int]): When provided, seeds the partitioner RNG so that the same inputs produce the same + partition assignment across runs. When None, partitioning is non-deterministic. Returns: DistDataset: Initialized dataset with partitioned graph information @@ -164,7 +167,8 @@ def _load_and_build_partitioned_dataset( f"Initializing {partitioner_class.__name__} instance while partitioning edges to its destination node machine" ) partitioner = partitioner_class( - should_assign_edges_by_src_node=should_assign_edges_by_src_node + should_assign_edges_by_src_node=should_assign_edges_by_src_node, + seed=seed, ) partitioner.register_node_ids(node_ids=loaded_graph_tensors.node_ids) @@ -230,6 +234,7 @@ def _build_dataset_process( edge_tf_dataset_options: TFDatasetOptions, splitter: Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]] = None, _ssl_positive_label_percentage: Optional[float] = None, + seed: Optional[int] = None, ) -> None: """ This function is spawned by a single process per machine and is responsible for: @@ -311,6 +316,7 @@ def _build_dataset_process( edge_tf_dataset_options=edge_tf_dataset_options, splitter=splitter, _ssl_positive_label_percentage=_ssl_positive_label_percentage, + seed=seed, ) output_dict["dataset"] = output_dataset @@ -336,6 +342,7 @@ def build_dataset( _dataset_building_port: Optional[ int ] = None, # WARNING: This field will be deprecated in the future + seed: Optional[int] = None, ) -> DistDataset: """ Launches a spawned process for building and returning a DistDataset instance provided some @@ -368,6 +375,8 @@ def build_dataset( Slotted for refactor once this functionality is available in the transductive `splitter` directly _dataset_building_port (deprecated field - will be removed soon) (Optional[int]): Contains information about master port. Defaults to None, in which case it will be initialized from the current torch.distributed context. + seed (Optional[int]): When provided, seeds the partitioner RNG so that the same inputs produce the same + partition assignment across runs. When None, partitioning is non-deterministic. Returns: DistDataset: Built GraphLearn-for-PyTorch Dataset class @@ -463,6 +472,7 @@ def build_dataset( edge_tf_dataset_options, splitter, _ssl_positive_label_percentage, + seed, ), ) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 989907640..c4a5c1041 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -96,6 +96,7 @@ def __init__( local_process_rank: Optional[int] = None, # TODO: (svij) Deprecate this local_process_world_size: Optional[int] = None, # TODO: (svij) Deprecate this non_blocking_transfers: bool = True, + seed: Optional[int] = None, ): """ Neighbor loader for Anchor Based Link Prediction (ABLP) tasks. @@ -215,6 +216,9 @@ def __init__( is used instead. See https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html for background on pinned memory and non-blocking transfers. + seed (Optional[int]): When provided, seeds the sampling RNG so that the same inputs + produce the same batches across runs. When None, sampling is non-deterministic. + (default: ``None``). """ # Set self._shutdowned right away, that way if we throw here, and __del__ is called, @@ -349,6 +353,7 @@ def __init__( batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, + seed=seed, ) producer: Optional[DistSamplingProducer] = None diff --git a/gigl/distributed/dist_partitioner.py b/gigl/distributed/dist_partitioner.py index 514ba5193..83449f0ec 100644 --- a/gigl/distributed/dist_partitioner.py +++ b/gigl/distributed/dist_partitioner.py @@ -161,6 +161,7 @@ def __init__( Union[torch.Tensor, dict[EdgeType, torch.Tensor]] ] = None, node_labels: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]] = None, + seed: Optional[int] = None, ): """ Initializes the parameters of the partitioner. Also optionally takes in node and edge tensors as arguments and registers them to the partitioner. Registered @@ -176,6 +177,8 @@ def __init__( positive_labels (Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]): Optionally registered positive labels from input. Tensors should be of shape [2, num_pos_labels_on_current_rank] negative_labels (Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]): Optionally registered negative labels from input. Tensors should be of shape [2, num_neg_labels_on_current_rank] node_labels (Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]]): Optionally registered node labels from input. Tensors should be of shape [num_nodes_on_current_rank, node_label_dim] + seed (Optional[int]): When provided, seeds the random permutation used during node partitioning so that + the same inputs produce the same partition assignment across runs. When None, partitioning is non-deterministic. """ self._world_size: int @@ -183,6 +186,7 @@ def __init__( self._partition_mgr: _DistLinkPredicitonPartitionManager self._is_rpc_initialized: bool = False + self._seed: Optional[int] = seed self._is_input_homogeneous: Optional[bool] = None self._should_assign_edges_by_src_node: bool = should_assign_edges_by_src_node @@ -847,7 +851,12 @@ def _partition_node(self, node_type: NodeType) -> PartitionBook: # TODO (mkolodner-sc): Explore other node partitioning strategies here beyond random permutation def _node_pfn(n_ids, _): partition_idx = n_ids % self._world_size - rand_order = torch.randperm(len(n_ids)) + if self._seed is not None: + generator = torch.Generator() + generator.manual_seed(self._seed) + rand_order = torch.randperm(len(n_ids), generator=generator) + else: + rand_order = torch.randperm(len(n_ids)) return partition_idx[rand_order] partitioned_results, node_partition_book = self._partition_by_chunk( diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 3d6d5a34b..e17da3c46 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -92,6 +92,7 @@ def __init__( drop_last: bool = False, sampler_options: Optional[SamplerOptions] = None, non_blocking_transfers: bool = True, + seed: Optional[int] = None, ): """ Distributed Neighbor Loader. @@ -170,6 +171,9 @@ def __init__( is used instead. See https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html for background on pinned memory and non-blocking transfers. + seed (Optional[int]): When provided, seeds the sampling RNG so that the same inputs + produce the same batches across runs. When None, sampling is non-deterministic. + (default: ``None``). """ # Set self._shutdowned right away, that way if we throw here, and __del__ is called, @@ -263,6 +267,7 @@ def __init__( batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, + seed=seed, ) producer: Optional[DistSamplingProducer] = None diff --git a/tests/test_assets/distributed/run_distributed_partitioner.py b/tests/test_assets/distributed/run_distributed_partitioner.py index 96cfb25cf..068cfae11 100644 --- a/tests/test_assets/distributed/run_distributed_partitioner.py +++ b/tests/test_assets/distributed/run_distributed_partitioner.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Type, Union +from typing import Optional, Type, Union import torch from graphlearn_torch.distributed import init_rpc, init_worker_group @@ -31,6 +31,7 @@ def run_distributed_partitioner( master_port: int, input_data_strategy: InputDataStrategy, partitioner_class: Type[DistPartitioner], + seed: Optional[int] = None, ) -> None: """ Runs the distributed partitioner on a specific rank. @@ -44,6 +45,7 @@ def run_distributed_partitioner( master_port (int): Master port for initializing rpc for partitioning input_data_strategy (InputDataStrategy): Strategy for registering inputs to the partitioner partitioner_class (Type[DistPartitioner]): The class to use for partitioning + seed (Optional[int]): Seed for deterministic node partitioning. When None, partitioning is non-deterministic. """ input_graph = rank_to_input_graph[rank] @@ -81,6 +83,7 @@ def run_distributed_partitioner( if input_data_strategy == InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY: dist_partitioner = partitioner_class( should_assign_edges_by_src_node=should_assign_edges_by_src_node, + seed=seed, ) # We call del to mimic the real use case for handling these input tensors dist_partitioner.register_node_ids(node_ids=node_ids) @@ -139,6 +142,7 @@ def run_distributed_partitioner( elif input_data_strategy == InputDataStrategy.REGISTER_MINIMAL_ENTITIES_SEPARATELY: dist_partitioner = partitioner_class( should_assign_edges_by_src_node=should_assign_edges_by_src_node, + seed=seed, ) # We call del to mimic the real use case for handling these input tensors dist_partitioner.register_node_ids(node_ids=node_ids) @@ -176,6 +180,7 @@ def run_distributed_partitioner( positive_labels=positive_labels, negative_labels=negative_labels, node_labels=node_labels, + seed=seed, ) # We call del to mimic the real use case for handling these input tensors del ( diff --git a/tests/unit/distributed/distributed_neighborloader_test.py b/tests/unit/distributed/distributed_neighborloader_test.py index 309f047a5..be58e0461 100644 --- a/tests/unit/distributed/distributed_neighborloader_test.py +++ b/tests/unit/distributed/distributed_neighborloader_test.py @@ -1,5 +1,6 @@ import unittest from collections.abc import Mapping +from multiprocessing import Manager import torch import torch.multiprocessing as mp @@ -8,9 +9,11 @@ from parameterized import param, parameterized from torch_geometric.data import Data, HeteroData +from gigl.distributed.base_dist_loader import BaseDistLoader from gigl.distributed.dataset_factory import build_dataset from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.distributed.utils.neighborloader import DatasetSchema from gigl.distributed.utils import get_free_port from gigl.distributed.utils.serialized_graph_metadata_translator import ( convert_pb_to_serialized_graph_metadata, @@ -398,6 +401,27 @@ def _run_cora_supervised_node_classification( shutdown_rpc() +def _run_seeded_e2e_loader_worker( + _: int, + dataset: DistDataset, + output_list, + seed: int, +) -> None: + create_test_process_group() + loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=[1, 1], + pin_memory_device=torch.device("cpu"), + batch_size=5, + shuffle=True, + seed=seed, + ) + for datum in loader: + assert isinstance(datum, Data) + output_list.append(datum.node.clone()) + shutdown_rpc() + + class DistributedNeighborLoaderTest(TestCase): def setUp(self): super().setUp() @@ -861,5 +885,93 @@ def test_distributed_neighbor_loader_invalid_inputs_colocated( DistNeighborLoader(**kwargs) +class CreateSamplingConfigSeedTestCase(TestCase): + """Tests that create_sampling_config correctly threads the seed into SamplingConfig.""" + + def _make_dataset_schema(self) -> DatasetSchema: + return DatasetSchema( + is_homogeneous_with_labeled_edge_type=False, + edge_types=[DEFAULT_HOMOGENEOUS_EDGE_TYPE], + node_feature_info=None, + edge_feature_info=None, + edge_dir="in", + ) + + def test_seed_is_propagated_to_sampling_config(self) -> None: + schema = self._make_dataset_schema() + config = BaseDistLoader.create_sampling_config( + num_neighbors=[5, 5], + dataset_schema=schema, + seed=42, + ) + self.assertEqual(config.seed, 42) + + def test_no_seed_leaves_config_seed_none(self) -> None: + schema = self._make_dataset_schema() + config = BaseDistLoader.create_sampling_config( + num_neighbors=[5, 5], + dataset_schema=schema, + ) + self.assertIsNone(config.seed) + + +class DistNeighborLoaderE2EDeterminismTestCase(TestCase): + """Verifies that identical seeds produce identical batch ordering and neighbor samples end-to-end.""" + + _N: int = 50 + + def _make_ring_dataset(self) -> DistDataset: + N = self._N + src = torch.arange(N) + dst = (torch.arange(N) + 1) % N + edge_index = torch.stack([src, dst]) + partition_output = PartitionOutput( + node_partition_book=torch.zeros(N, dtype=torch.int64), + edge_partition_book=None, + partitioned_edge_index=GraphPartitionData( + edge_index=edge_index, + edge_ids=None, + ), + partitioned_node_features=FeaturePartitionData( + feats=torch.zeros(N, 2), ids=torch.arange(N) + ), + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="in") + dataset.build(partition_output=partition_output) + return dataset + + def _collect_batches(self, dataset: DistDataset, seed: int) -> list[torch.Tensor]: + manager = Manager() + output_list = manager.list() + mp.spawn( + fn=_run_seeded_e2e_loader_worker, + args=(dataset, output_list, seed), + nprocs=1, + join=True, + ) + return list(output_list) + + def test_seeded_sampling_is_e2e_deterministic(self) -> None: + """Same seed must yield identical batch ordering and neighbor samples on a 50-node ring graph.""" + dataset = self._make_ring_dataset() + batches_run1 = self._collect_batches(dataset, seed=42) + batches_run2 = self._collect_batches(dataset, seed=42) + + self.assertEqual( + len(batches_run1), + len(batches_run2), + msg="Number of batches differs between runs with identical seed", + ) + for i, (b1, b2) in enumerate(zip(batches_run1, batches_run2)): + self.assertTrue( + torch.equal(b1, b2), + msg=f"Batch {i} node tensors differ despite identical seed", + ) + + if __name__ == "__main__": absltest.main() diff --git a/tests/unit/distributed/distributed_partitioner_test.py b/tests/unit/distributed/distributed_partitioner_test.py index 1ac56df31..a258a8f88 100644 --- a/tests/unit/distributed/distributed_partitioner_test.py +++ b/tests/unit/distributed/distributed_partitioner_test.py @@ -1419,5 +1419,81 @@ def test_heterogeneous_re_registration(self) -> None: partitioner3.register_node_labels(node_labels=node_labels) +def _run_large_node_partition_worker( + rank: int, + output_dict: MutableMapping[int, Union[torch.Tensor, dict[NodeType, torch.Tensor]]], + master_addr: str, + master_port: int, + seed: Optional[int], + n_nodes_per_rank: int, +) -> None: + """Spawnable worker: partitions a large homogeneous graph and stores the node partition book.""" + init_worker_group(world_size=MOCKED_NUM_PARTITIONS, rank=rank) + init_rpc(master_addr=master_addr, master_port=master_port, num_rpc_threads=4) + + start = rank * n_nodes_per_rank + node_ids = torch.arange(start, start + n_nodes_per_rank, dtype=torch.int64) + + partitioner = DistPartitioner(should_assign_edges_by_src_node=False, seed=seed) + partitioner.register_node_ids(node_ids=node_ids) + node_partition_book = partitioner.partition_node() + + output_dict[rank] = node_partition_book + + +class DistPartitionerDeterminismTestCase(TestCase): + """Tests that seeded partitioning is reproducible.""" + + _N_NODES_PER_RANK: int = 50 + + def setUp(self) -> None: + self._master_ip_address = "localhost" + + def _run_large_partitioner_with_seed( + self, seed: Optional[int] + ) -> dict[int, Union[torch.Tensor, dict[NodeType, torch.Tensor]]]: + """Runs node partitioning on a large graph and returns the node partition books by rank.""" + master_port = get_free_port() + manager = Manager() + output_dict: MutableMapping[ + int, Union[torch.Tensor, dict[NodeType, torch.Tensor]] + ] = manager.dict() + + mp.spawn( + _run_large_node_partition_worker, + args=( + output_dict, + self._master_ip_address, + master_port, + seed, + self._N_NODES_PER_RANK, + ), + nprocs=MOCKED_NUM_PARTITIONS, + join=True, + ) + return dict(output_dict) + + def test_seeded_partitioning_is_deterministic_on_large_graph(self) -> None: + """Same seed must produce identical partition books on a large graph (50 nodes per rank). + + The small mocked graph has too few nodes for a meaningful determinism check since + there are only a handful of distinct partition outcomes. This test uses a larger + graph to provide stronger evidence that seeding is applied correctly. + """ + partition_books_run1 = self._run_large_partitioner_with_seed(seed=42) + partition_books_run2 = self._run_large_partitioner_with_seed(seed=42) + + for rank in range(MOCKED_NUM_PARTITIONS): + pb1 = partition_books_run1[rank] + pb2 = partition_books_run2[rank] + assert isinstance(pb1, torch.Tensor) and isinstance(pb2, torch.Tensor), ( + f"Expected tensor partition books on rank {rank}" + ) + self.assertTrue( + torch.equal(pb1, pb2), + msg=f"Node partition books differ on rank {rank} despite identical seed", + ) + + if __name__ == "__main__": absltest.main()