From 9ba159be7cd1ff1d3b5a00269267c2fa2dc15183 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 12 May 2026 21:00:32 +0000 Subject: [PATCH 01/13] Initial commit' --- gigl/distributed/base_dist_loader.py | 6 +- gigl/distributed/dataset_factory.py | 112 +++++++- gigl/distributed/dist_ablp_neighborloader.py | 29 ++ gigl/distributed/dist_dataset.py | 37 +++ gigl/distributed/dist_partitioner.py | 93 +++++- .../distributed/distributed_neighborloader.py | 29 ++ .../run_distributed_partitioner.py | 10 +- tests/test_assets/distributed/test_dataset.py | 5 +- .../distributed_weighted_sampling_test.py | 272 ++++++++++++++++++ 9 files changed, 582 insertions(+), 11 deletions(-) create mode 100644 tests/unit/distributed/distributed_weighted_sampling_test.py diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index c5046ca7d..52eab6b77 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -340,6 +340,7 @@ def create_sampling_config( batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, + with_weight: bool = False, ) -> SamplingConfig: """Creates a SamplingConfig with patched fanout. @@ -352,6 +353,9 @@ def create_sampling_config( batch_size: How many samples per batch. shuffle: Whether to shuffle input nodes. drop_last: Whether to drop the last incomplete batch. + with_weight: Whether to use edge weights for sampling. Requires that + edge weights were registered during dataset construction via + ``DistPartitioner.register_edge_weights()``. Returns: A fully configured SamplingConfig. @@ -369,7 +373,7 @@ def create_sampling_config( with_edge=True, collect_features=True, with_neg=False, - with_weight=False, + with_weight=with_weight, edge_dir=dataset_schema.edge_dir, seed=None, ) diff --git a/gigl/distributed/dataset_factory.py b/gigl/distributed/dataset_factory.py index ffa13fc71..72d9a55b9 100644 --- a/gigl/distributed/dataset_factory.py +++ b/gigl/distributed/dataset_factory.py @@ -20,7 +20,7 @@ ) from gigl.common import Uri, UriFactory -from gigl.common.data.dataloaders import TFRecordDataLoader +from gigl.common.data.dataloaders import SerializedTFRecordInfo, TFRecordDataLoader from gigl.common.data.load_torch_tensors import ( SerializedGraphMetadata, TFDatasetOptions, @@ -43,6 +43,7 @@ ) from gigl.src.common.types.graph_data import EdgeType from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.types.graph import DEFAULT_HOMOGENEOUS_EDGE_TYPE from gigl.src.common.types.pb_wrappers.task_metadata import TaskMetadataType from gigl.utils.data_splitters import ( DistNodeAnchorLinkSplitter, @@ -56,6 +57,93 @@ logger = Logger() +def _extract_weight_column( + edge_features: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], + weight_edge_feat_name: Union[str, dict[EdgeType, str]], + edge_entity_info: Union[SerializedTFRecordInfo, dict[EdgeType, SerializedTFRecordInfo]], +) -> Tuple[ + dict[EdgeType, torch.Tensor], + Union[torch.Tensor, dict[EdgeType, torch.Tensor]], +]: + """Extracts a weight column from edge features, removing it from the feature tensor. + + Returns a tuple of (edge_weights_by_type, trimmed_edge_features). The weight column + is removed from the feature tensor so it is not duplicated in memory. + + Args: + edge_features: Edge feature tensor(s). + weight_edge_feat_name: Name of the weight feature column, either a single string + (applied to all edge types) or a per-edge-type mapping. + edge_entity_info: SerializedTFRecordInfo carrying ordered ``feature_keys`` used + to resolve the column name to an index. + + Returns: + Tuple of (weights_by_type, trimmed_features) where ``weights_by_type`` maps each + edge type to its extracted 1-D weight tensor and ``trimmed_features`` has the + weight column removed. + """ + # Normalise edge_features to heterogeneous format for uniform processing + is_homogeneous_input = isinstance(edge_features, torch.Tensor) + if is_homogeneous_input: + assert isinstance(edge_features, torch.Tensor) + edge_features_dict: dict[EdgeType, torch.Tensor] = { + DEFAULT_HOMOGENEOUS_EDGE_TYPE: edge_features + } + else: + assert isinstance(edge_features, dict) + edge_features_dict = edge_features + + # Normalise edge_entity_info to heterogeneous format + if isinstance(edge_entity_info, SerializedTFRecordInfo): + entity_info_dict: dict[EdgeType, SerializedTFRecordInfo] = { + DEFAULT_HOMOGENEOUS_EDGE_TYPE: edge_entity_info + } + else: + entity_info_dict = edge_entity_info + + # Normalise weight_edge_feat_name to per-edge-type mapping + if isinstance(weight_edge_feat_name, str): + weight_name_dict: dict[EdgeType, str] = { + et: weight_edge_feat_name for et in edge_features_dict + } + else: + weight_name_dict = weight_edge_feat_name + + edge_weights_by_type: dict[EdgeType, torch.Tensor] = {} + trimmed_features_dict: dict[EdgeType, torch.Tensor] = {} + + for edge_type, feat_tensor in edge_features_dict.items(): + if edge_type not in weight_name_dict: + trimmed_features_dict[edge_type] = feat_tensor + continue + + feat_name = weight_name_dict[edge_type] + info = entity_info_dict[edge_type] + feature_keys = list(info.feature_keys) + + if feat_name not in feature_keys: + raise ValueError( + f"weight_edge_feat_name '{feat_name}' not found in edge feature keys " + f"for edge type {edge_type}: {feature_keys}" + ) + + col_idx = feature_keys.index(feat_name) + edge_weights_by_type[edge_type] = feat_tensor[:, col_idx] + + # Remove the weight column from the feature tensor + keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_idx] + trimmed_features_dict[edge_type] = feat_tensor[:, keep_cols] + + if is_homogeneous_input: + trimmed_features: Union[torch.Tensor, dict[EdgeType, torch.Tensor]] = ( + trimmed_features_dict[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + ) + else: + trimmed_features = trimmed_features_dict + + return edge_weights_by_type, trimmed_features + + @tf_on_cpu def _load_and_build_partitioned_dataset( serialized_graph_metadata: SerializedGraphMetadata, @@ -66,6 +154,7 @@ def _load_and_build_partitioned_dataset( edge_tf_dataset_options: TFDatasetOptions, splitter: Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]] = None, _ssl_positive_label_percentage: Optional[float] = None, + weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None, ) -> DistDataset: """ Given some information about serialized TFRecords, loads and builds a partitioned dataset into a DistDataset class. @@ -82,6 +171,10 @@ def _load_and_build_partitioned_dataset( splitter (Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]]): Optional splitter to use for splitting the graph data into train, val, and test sets. If not provided (None), no splitting will be performed. _ssl_positive_label_percentage (Optional[float]): Percentage of edges to select as self-supervised labels. Must be None if supervised edge labels are provided in advance. Slotted for refactor once this functionality is available in the transductive `splitter` directly + weight_edge_feat_name (Optional[Union[str, dict[EdgeType, str]]]): Name of the edge feature column to use as + sampling weights. The column is extracted from the feature tensor and registered separately via + ``DistPartitioner.register_edge_weights()``; it is removed from the feature matrix to avoid duplication. + Supply a single string to use the same column name for all edge types, or a per-edge-type dict. Returns: DistDataset: Initialized dataset with partitioned graph information @@ -176,6 +269,15 @@ def _load_and_build_partitioned_dataset( if loaded_graph_tensors.node_labels is not None: partitioner.register_node_labels(node_labels=loaded_graph_tensors.node_labels) if loaded_graph_tensors.edge_features is not None: + if weight_edge_feat_name is not None: + edge_weights_by_type, trimmed_edge_features = _extract_weight_column( + edge_features=loaded_graph_tensors.edge_features, + weight_edge_feat_name=weight_edge_feat_name, + edge_entity_info=serialized_graph_metadata.edge_entity_info, + ) + loaded_graph_tensors.edge_features = trimmed_edge_features + if edge_weights_by_type: + partitioner.register_edge_weights(edge_weights=edge_weights_by_type) partitioner.register_edge_features( edge_features=loaded_graph_tensors.edge_features ) @@ -230,6 +332,7 @@ def _build_dataset_process( edge_tf_dataset_options: TFDatasetOptions, splitter: Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]] = None, _ssl_positive_label_percentage: Optional[float] = None, + weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None, ) -> None: """ This function is spawned by a single process per machine and is responsible for: @@ -311,6 +414,7 @@ def _build_dataset_process( edge_tf_dataset_options=edge_tf_dataset_options, splitter=splitter, _ssl_positive_label_percentage=_ssl_positive_label_percentage, + weight_edge_feat_name=weight_edge_feat_name, ) output_dict["dataset"] = output_dataset @@ -336,6 +440,7 @@ def build_dataset( _dataset_building_port: Optional[ int ] = None, # WARNING: This field will be deprecated in the future + weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None, ) -> DistDataset: """ Launches a spawned process for building and returning a DistDataset instance provided some @@ -368,6 +473,10 @@ def build_dataset( Slotted for refactor once this functionality is available in the transductive `splitter` directly _dataset_building_port (deprecated field - will be removed soon) (Optional[int]): Contains information about master port. Defaults to None, in which case it will be initialized from the current torch.distributed context. + weight_edge_feat_name (Optional[Union[str, dict[EdgeType, str]]]): Name of the edge feature column to use + as sampling weights. The column is extracted from the feature tensor and registered separately; it is + removed from the feature matrix to avoid memory duplication. Supply a single string to apply to all + edge types, or a per-edge-type dict. (default: ``None``) Returns: DistDataset: Built GraphLearn-for-PyTorch Dataset class @@ -463,6 +572,7 @@ def build_dataset( edge_tf_dataset_options, splitter, _ssl_positive_label_percentage, + weight_edge_feat_name, ), ) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 989907640..f49c3a703 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -91,6 +91,7 @@ def __init__( num_cpu_threads: Optional[int] = None, shuffle: bool = False, drop_last: bool = False, + with_weight: Optional[bool] = None, sampler_options: Optional[SamplerOptions] = None, context: Optional[DistributedContext] = None, # TODO: (svij) Deprecate this local_process_rank: Optional[int] = None, # TODO: (svij) Deprecate this @@ -201,6 +202,13 @@ def __init__( Defaults to `2` if set to `None` when using cpu training/inference. shuffle (bool): Whether to shuffle the input nodes. (default: ``False``). drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``). + with_weight (Optional[bool]): Whether to use edge weights for neighbor + sampling. When ``None`` (default), inferred automatically from the + dataset: ``True`` if edge weights were registered via + ``DistPartitioner.register_edge_weights()``, ``False`` otherwise. + Pass ``True`` or ``False`` explicitly to override the inference. + Only auto-inferred for colocated mode; graph-store mode defaults to + ``False`` when ``None``. sampler_options (Optional[SamplerOptions]): Controls which sampler class is instantiated. Defaults to `KHopNeighborSamplerOptions`, which will use the num_neighbors argument to instantiate the sampler. @@ -261,6 +269,26 @@ def __init__( ) del context, local_process_rank, local_process_world_size + # Validate explicit with_weight overrides and infer when not specified. + if isinstance(dataset, DistDataset): + if with_weight is True and not dataset.has_edge_weights: + raise ValueError( + "with_weight=True requires edge weights to be registered in the dataset. " + "Pass weight_edge_feat_name to build_dataset() to register edge weights, " + "or omit with_weight to have it inferred automatically." + ) + if with_weight is False and dataset.has_edge_weights: + logger.warning( + "with_weight=False explicitly set but the dataset has edge weights registered. " + "Weighted sampling will be disabled. Omit with_weight to enable it automatically." + ) + if with_weight is None: + with_weight = ( + dataset.has_edge_weights + if isinstance(dataset, DistDataset) + else False + ) + device = ( pin_memory_device if pin_memory_device @@ -349,6 +377,7 @@ def __init__( batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, + with_weight=with_weight, ) producer: Optional[DistSamplingProducer] = None diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index b40f2969a..69e328854 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -84,6 +84,7 @@ def __init__( Union[torch.Tensor, dict[EdgeType, torch.Tensor]] ] = None, max_labels_per_anchor_node: Optional[int] = None, + has_edge_weights: bool = False, ) -> None: """ Initializes the fields of the DistDataset class. This function is called upon each serialization of the DistDataset instance. @@ -111,6 +112,9 @@ def __init__( degree_tensor: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Pre-computed degree tensor. Lazily computed on first access via the degree_tensor property. max_labels_per_anchor_node (Optional[int]): Optional cap for how many labels to materialize per anchor node for ABLP label fetching. + has_edge_weights (bool): Whether edge weights were registered during dataset + construction via ``DistPartitioner.register_edge_weights()``. Automatically + set during ``build()``; do not pass this manually. (default: ``False``) """ self._rank: int = rank self._world_size: int = world_size @@ -150,6 +154,7 @@ def __init__( Union[torch.Tensor, dict[EdgeType, torch.Tensor]] ] = degree_tensor self._max_labels_per_anchor_node = max_labels_per_anchor_node + self._has_edge_weights: bool = has_edge_weights # TODO (mkolodner-sc): Modify so that we don't need to rely on GLT's base variable naming (i.e. partition_idx, num_partitions) in favor of more clear # naming (i.e. rank, world_size). @@ -336,6 +341,15 @@ def degree_tensor( self._degree_tensor = compute_and_broadcast_degree_tensor(self.graph) return self._degree_tensor + @property + def has_edge_weights(self) -> bool: + """True if edge weights were registered during dataset construction. + + Used by loaders to automatically infer ``with_weight`` when it is not + explicitly specified. + """ + return self._has_edge_weights + @property def max_labels_per_anchor_node(self) -> Optional[int]: return self._max_labels_per_anchor_node @@ -575,6 +589,9 @@ def _initialize_graph( edge_ids: Union[ Optional[torch.Tensor], dict[EdgeType, Optional[torch.Tensor]] ] = partitioned_edge_index.edge_ids + edge_weights: Optional[ + Union[torch.Tensor, dict[EdgeType, torch.Tensor]] + ] = partitioned_edge_index.weights else: edge_index = { edge_type: graph_partition_data.edge_index @@ -584,12 +601,29 @@ def _initialize_graph( edge_type: graph_partition_data.edge_ids for edge_type, graph_partition_data in partitioned_edge_index.items() } + weights_by_type = { + edge_type: graph_partition_data.weights + for edge_type, graph_partition_data in partitioned_edge_index.items() + if graph_partition_data.weights is not None + } + edge_weights = weights_by_type if weights_by_type else None + if weights_by_type: + missing = set(partitioned_edge_index.keys()) - set(weights_by_type.keys()) + if missing: + logger.warning( + f"Edge weights are registered for {set(weights_by_type.keys())} but " + f"not for {missing}. When with_weight=True, edge types without weights " + f"will fall back to uniform sampling." + ) + + self._has_edge_weights = edge_weights is not None self.init_graph( edge_index=edge_index, edge_ids=edge_ids, graph_mode="CPU", directed=True, + edge_weights=edge_weights, ) if isinstance(partitioned_edge_index, Mapping): @@ -876,6 +910,7 @@ def share_ipc( Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]], Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], Optional[int], + bool, ]: """ Serializes the member variables of the DistDatasetClass @@ -928,6 +963,7 @@ def share_ipc( self._edge_feature_info, # Additional field unique to DistDataset class self._degree_tensor, # Additional field unique to DistDataset class self._max_labels_per_anchor_node, # Additional field unique to DistDataset class + self._has_edge_weights, # Additional field unique to DistDataset class ) return ipc_handle @@ -1185,6 +1221,7 @@ def _rebuild_distributed_dataset( ], # Edge feature dim and its data type Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], # Degree tensors Optional[int], # Optional per-anchor label cap for ABLP label fetching + bool, # has_edge_weights ], ): dataset = DistDataset.from_ipc_handle(ipc_handle) diff --git a/gigl/distributed/dist_partitioner.py b/gigl/distributed/dist_partitioner.py index 514ba5193..cba668035 100644 --- a/gigl/distributed/dist_partitioner.py +++ b/gigl/distributed/dist_partitioner.py @@ -202,6 +202,7 @@ def __init__( self._edge_ids: Optional[dict[EdgeType, tuple[int, int]]] = None self._edge_feat: Optional[dict[EdgeType, torch.Tensor]] = None self._edge_feat_dim: Optional[dict[EdgeType, int]] = None + self._edge_weights: Optional[dict[EdgeType, torch.Tensor]] = None # TODO (mkolodner-sc): Deprecate the need for explicitly storing labels are part of this class, leveraging # heterogeneous support instead @@ -623,6 +624,43 @@ def register_edge_features( for edge_type in input_edge_features: self._edge_feat_dim[edge_type] = input_edge_features[edge_type].shape[1] + def register_edge_weights( + self, edge_weights: Union[torch.Tensor, dict[EdgeType, torch.Tensor]] + ) -> None: + """Registers per-edge sampling weights to the partitioner. + + Weights must be a 1-D float tensor of shape ``[num_edges]``, one scalar per edge. + Must be called after ``register_edge_index()`` and before ``partition()``. + + Weights are kept separate from edge features — do not include the weight column + in edge features passed to ``register_edge_features()``. + + For optimal memory management, delete the reference to ``edge_weights`` after + calling this function. + + Args: + edge_weights: Per-edge weights, either a ``[num_edges]`` tensor if homogeneous + or a ``dict[EdgeType, Tensor]`` if heterogeneous. + """ + self._assert_and_get_rpc_setup() + + if self._edge_weights is not None: + raise ValueError( + "Edge weights have already been registered. Cannot re-register edge weight data." + ) + + logger.info("Registering Edge Weights ...") + + input_edge_weights = self._convert_edge_entity_to_heterogeneous_format( + input_edge_entity=edge_weights + ) + + assert input_edge_weights, ( + "Edge weights is an empty dictionary. Please provide edge weights to register." + ) + + self._edge_weights = convert_to_tensor(input_edge_weights, dtype=torch.float32) + def register_labels( self, label_edge_index: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], @@ -1079,6 +1117,13 @@ def _partition_edge_index_and_edge_features( should_skip_edge_feats = ( self._edge_feat is None or edge_type not in self._edge_feat ) + has_weights_for_edge_type = ( + self._edge_weights is not None and edge_type in self._edge_weights + ) + # Need a partition book if we have features or weights to reindex. + should_generate_partition_book = ( + not should_skip_edge_feats or has_weights_for_edge_type + ) # Partitioning Edge Indices @@ -1100,13 +1145,12 @@ def _edge_pfn(_, chunk_range): chunk_target_indices = index_select(target_indices, chunk_range) return target_node_partition_book[chunk_target_indices] - # TODO (mkolodner-sc): Investigate partitioning edge features as part of this input_data tuple edge_res_list, edge_partition_book = self._partition_by_chunk( input_data=(edge_index[0], edge_index[1], edge_ids), rank_indices=edge_ids, partition_function=_edge_pfn, total_val_size=num_edges, - generate_pb=not should_skip_edge_feats, + generate_pb=should_generate_partition_book, ) del edge_index, target_indices @@ -1128,20 +1172,55 @@ def _edge_pfn(_, chunk_range): ), dim=0, ) - if should_skip_edge_feats: + if should_skip_edge_feats and not has_weights_for_edge_type: partitioned_edge_ids = None else: partitioned_edge_ids = torch.cat([r[2] for r in edge_res_list]) + edge_res_list.clear() + + gc.collect() + + # Partitioning Edge Weights (before creating GraphPartitionData so weights + # can be included at construction time; frozen dataclass cannot be mutated). + partitioned_weights: Optional[torch.Tensor] = None + if has_weights_for_edge_type: + assert edge_partition_book is not None, ( + "edge_partition_book must be populated when edge weights are registered" + ) + assert self._edge_weights is not None + edge_weights_tensor = self._edge_weights[edge_type] + + def _edge_weight_pfn(weight_ids, _): + assert edge_partition_book is not None + return edge_partition_book[weight_ids] + + weight_res_list, _ = self._partition_by_chunk( + input_data=(edge_weights_tensor, edge_ids), + rank_indices=edge_ids, + partition_function=_edge_weight_pfn, + total_val_size=num_edges, + generate_pb=False, + ) + del edge_weights_tensor + del self._edge_weights[edge_type] + if len(self._edge_weights) == 0: + self._edge_weights = None + gc.collect() + + if len(weight_res_list) == 0: + partitioned_weights = torch.empty(0) + else: + partitioned_weights = torch.cat([r[0] for r in weight_res_list]) + weight_res_list.clear() + gc.collect() + current_graph_part = GraphPartitionData( edge_index=partitioned_edge_index, edge_ids=partitioned_edge_ids, + weights=partitioned_weights, ) - edge_res_list.clear() - - gc.collect() - # Partitioning Edge Features if should_skip_edge_feats: diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 3d6d5a34b..f3cd94440 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -90,6 +90,7 @@ def __init__( num_cpu_threads: Optional[int] = None, shuffle: bool = False, drop_last: bool = False, + with_weight: Optional[bool] = None, sampler_options: Optional[SamplerOptions] = None, non_blocking_transfers: bool = True, ): @@ -158,6 +159,13 @@ def __init__( Defaults to `2` if set to `None` when using cpu training/inference. shuffle (bool): Whether to shuffle the input nodes. (default: ``False``). drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``). + with_weight (Optional[bool]): Whether to use edge weights for neighbor + sampling. When ``None`` (default), inferred automatically from the + dataset: ``True`` if edge weights were registered via + ``DistPartitioner.register_edge_weights()``, ``False`` otherwise. + Pass ``True`` or ``False`` explicitly to override the inference. + Only auto-inferred for colocated mode; graph-store mode defaults to + ``False`` when ``None``. sampler_options (Optional[SamplerOptions]): Controls which sampler class is instantiated. Pass ``KHopNeighborSamplerOptions`` to use the built-in sampler, or ``CustomSamplerOptions`` to dynamically import a custom sampler class. @@ -184,6 +192,26 @@ def __init__( ) del context, local_process_rank, local_process_world_size + # Validate explicit with_weight overrides and infer when not specified. + if isinstance(dataset, DistDataset): + if with_weight is True and not dataset.has_edge_weights: + raise ValueError( + "with_weight=True requires edge weights to be registered in the dataset. " + "Pass weight_edge_feat_name to build_dataset() to register edge weights, " + "or omit with_weight to have it inferred automatically." + ) + if with_weight is False and dataset.has_edge_weights: + logger.warning( + "with_weight=False explicitly set but the dataset has edge weights registered. " + "Weighted sampling will be disabled. Omit with_weight to enable it automatically." + ) + if with_weight is None: + with_weight = ( + dataset.has_edge_weights + if isinstance(dataset, DistDataset) + else False + ) + # Determine mode if isinstance(dataset, RemoteDistDataset): self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE @@ -263,6 +291,7 @@ def __init__( batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, + with_weight=with_weight, ) producer: Optional[DistSamplingProducer] = None diff --git a/tests/test_assets/distributed/run_distributed_partitioner.py b/tests/test_assets/distributed/run_distributed_partitioner.py index 96cfb25cf..9e76d107b 100644 --- a/tests/test_assets/distributed/run_distributed_partitioner.py +++ b/tests/test_assets/distributed/run_distributed_partitioner.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Type, Union +from typing import Optional, Type, Union import torch from graphlearn_torch.distributed import init_rpc, init_worker_group @@ -31,6 +31,9 @@ def run_distributed_partitioner( master_port: int, input_data_strategy: InputDataStrategy, partitioner_class: Type[DistPartitioner], + rank_to_edge_weights: Optional[ + dict[int, Union[torch.Tensor, dict[EdgeType, torch.Tensor]]] + ] = None, ) -> None: """ Runs the distributed partitioner on a specific rank. @@ -44,6 +47,7 @@ def run_distributed_partitioner( master_port (int): Master port for initializing rpc for partitioning input_data_strategy (InputDataStrategy): Strategy for registering inputs to the partitioner partitioner_class (Type[DistPartitioner]): The class to use for partitioning + rank_to_edge_weights (Optional[dict[int, Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]]): Optional mapping of rank to 1D edge weight tensor (or dict per EdgeType for heterogeneous). Only supported with REGISTER_ALL_ENTITIES_SEPARATELY strategy. """ input_graph = rank_to_input_graph[rank] @@ -91,6 +95,10 @@ def run_distributed_partitioner( dist_partitioner.register_edge_features(edge_features=edge_features) del edge_index del edge_features + if rank_to_edge_weights is not None and rank in rank_to_edge_weights: + dist_partitioner.register_edge_weights( + edge_weights=rank_to_edge_weights[rank] + ) ( output_edge_index, output_edge_features, diff --git a/tests/test_assets/distributed/test_dataset.py b/tests/test_assets/distributed/test_dataset.py index bab8afdf4..eb0b852c5 100644 --- a/tests/test_assets/distributed/test_dataset.py +++ b/tests/test_assets/distributed/test_dataset.py @@ -67,6 +67,7 @@ def create_homogeneous_dataset( node_features: Optional[torch.Tensor] = None, edge_features: Optional[torch.Tensor] = None, node_labels: Optional[torch.Tensor] = None, + edge_weights: Optional[torch.Tensor] = None, rank: int = 0, world_size: int = 1, edge_dir: Literal["in", "out"] = "out", @@ -74,13 +75,14 @@ def create_homogeneous_dataset( """Create a homogeneous test dataset. Creates a single-partition DistDataset with the specified edge index, node features, - edge features, and node labels. + edge features, node labels, and optional per-edge sampling weights. Args: edge_index: COO format edge index [2, num_edges]. node_features: Node feature tensor [num_nodes, feature_dim], or None. edge_features: Edge feature tensor [num_edges, feature_dim], or None. node_labels: Node label tensor [num_nodes, label_dim], or None. + edge_weights: 1D per-edge sampling weight tensor [num_edges], or None. rank: Rank of the current process. Defaults to 0. world_size: Total number of processes. Defaults to 1. edge_dir: Edge direction ("in" or "out"). Defaults to "out". @@ -129,6 +131,7 @@ def create_homogeneous_dataset( partitioned_edge_index=GraphPartitionData( edge_index=edge_index, edge_ids=None, + weights=edge_weights, ), partitioned_node_features=partitioned_node_features, partitioned_edge_features=partitioned_edge_features, diff --git a/tests/unit/distributed/distributed_weighted_sampling_test.py b/tests/unit/distributed/distributed_weighted_sampling_test.py new file mode 100644 index 000000000..8f371ec6c --- /dev/null +++ b/tests/unit/distributed/distributed_weighted_sampling_test.py @@ -0,0 +1,272 @@ +"""Distributed integration tests for weighted edge sampling. + +Covers two surfaces: + 1. DistPartitioner correctly partitions registered edge weights (weights land on + the right rank and match the expected values). + 2. DistNeighborLoader runs end-to-end without errors when with_weight=True is set + and the dataset carries edge weights — both homogeneous and heterogeneous. +""" + +from collections.abc import Mapping +from typing import MutableMapping + +import torch +import torch.multiprocessing as mp +from absl.testing import absltest +from graphlearn_torch.distributed import shutdown_rpc +from torch.multiprocessing import Manager +from torch_geometric.data import Data, HeteroData + +from gigl.distributed import DistPartitioner +from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.distributed.utils.networking import get_free_port +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from gigl.types.graph import ( + FeaturePartitionData, + GraphPartitionData, + PartitionOutput, +) +from tests.test_assets.distributed.constants import ( + MOCKED_NUM_PARTITIONS, + MOCKED_U2U_EDGE_INDEX_ON_RANK_ONE, + MOCKED_U2U_EDGE_INDEX_ON_RANK_ZERO, + RANK_TO_MOCKED_GRAPH, +) +from tests.test_assets.distributed.run_distributed_partitioner import ( + InputDataStrategy, + run_distributed_partitioner, +) +from tests.test_assets.distributed.utils import create_test_process_group +from tests.test_assets.test_case import TestCase + +_USER = NodeType("user") +_STORY = NodeType("story") +_USER_TO_STORY = EdgeType(_USER, Relation("to"), _STORY) +_STORY_TO_USER = EdgeType(_STORY, Relation("to"), _USER) + + +# --------------------------------------------------------------------------- +# Subprocess functions — must accept local_rank as first arg (mp.spawn) +# --------------------------------------------------------------------------- + + +def _run_distributed_weighted_neighbor_loader_homogeneous( + _: int, + dataset: DistDataset, + expected_data_count: int, +) -> None: + """Subprocess: iterates a weighted homogeneous loader and checks batch count and type.""" + create_test_process_group() + loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=[2, 2], + with_weight=True, + pin_memory_device=torch.device("cpu"), + ) + count = 0 + for datum in loader: + assert isinstance(datum, Data), ( + f"Subgraph should be Data for homogeneous datasets, got {type(datum)}" + ) + count += 1 + assert count == expected_data_count, ( + f"Expected {expected_data_count} batches, got {count}" + ) + shutdown_rpc() + + +def _run_distributed_weighted_neighbor_loader_heterogeneous( + _: int, + dataset: DistDataset, + expected_data_count: int, +) -> None: + """Subprocess: iterates a weighted heterogeneous loader and checks batch count and type.""" + create_test_process_group() + assert isinstance(dataset.node_ids, Mapping) + loader = DistNeighborLoader( + dataset=dataset, + input_nodes=(_USER, dataset.node_ids[_USER]), + num_neighbors=[2, 2], + with_weight=True, + pin_memory_device=torch.device("cpu"), + ) + count = 0 + for datum in loader: + assert isinstance(datum, HeteroData), ( + f"Subgraph should be HeteroData for heterogeneous datasets, got {type(datum)}" + ) + count += 1 + assert count == expected_data_count, ( + f"Expected {expected_data_count} batches, got {count}" + ) + shutdown_rpc() + + +# --------------------------------------------------------------------------- +# Test classes +# --------------------------------------------------------------------------- + + +class WeightedEdgePartitionerTestCase(TestCase): + """Tests that DistPartitioner correctly partitions registered edge weights.""" + + def setUp(self) -> None: + self._master_ip_address = "localhost" + + def test_homogeneous_weights_partitioned_correctly(self) -> None: + """Edge weights (= src_node_id / 10.0) land on the correct rank after partitioning. + + The mocked graph has edges with source nodes 0–3 on rank 0 and 4–7 on rank 1. + Weights are set to src_node_id / 10.0, mirroring the existing edge-feature + convention. After partitioning by source node each rank should hold only its + own weights, and each weight should equal the corresponding global edge ID * 0.1. + """ + master_port = get_free_port() + manager = Manager() + output_dict: MutableMapping[int, PartitionOutput] = manager.dict() + + rank_to_edge_weights = { + 0: MOCKED_U2U_EDGE_INDEX_ON_RANK_ZERO[0].float() / 10.0, + 1: MOCKED_U2U_EDGE_INDEX_ON_RANK_ONE[0].float() / 10.0, + } + + mp.spawn( + run_distributed_partitioner, + args=( + output_dict, + False, # is_heterogeneous + RANK_TO_MOCKED_GRAPH, + True, # should_assign_edges_by_src_node + self._master_ip_address, + master_port, + InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY, + DistPartitioner, + rank_to_edge_weights, + ), + nprocs=MOCKED_NUM_PARTITIONS, + join=True, + ) + + for rank, partition_output in output_dict.items(): + partitioned_edge_index = partition_output.partitioned_edge_index + self.assertIsInstance(partitioned_edge_index, GraphPartitionData) + assert isinstance(partitioned_edge_index, GraphPartitionData) + + weights = partitioned_edge_index.weights + self.assertIsNotNone( + weights, + msg=f"Rank {rank}: expected weights in GraphPartitionData, got None", + ) + assert weights is not None + + edge_ids = partitioned_edge_index.edge_ids + self.assertIsNotNone( + edge_ids, + msg=f"Rank {rank}: edge_ids must be present when weights are registered", + ) + assert edge_ids is not None + + self.assertEqual( + weights.shape, + edge_ids.shape, + msg=f"Rank {rank}: weights and edge_ids must have the same length", + ) + + # weight for each edge == its global edge_id * 0.1 (i.e. src_node_id / 10.0). + # Sort both so the comparison is order-independent. + expected_weights = edge_ids.float() * 0.1 + torch.testing.assert_close( + weights.sort().values, + expected_weights.sort().values, + msg=f"Rank {rank}: partitioned weights do not match expected src_node_id / 10.0", + ) + + +class DistributedWeightedSamplingTest(TestCase): + """End-to-end dataloading tests for DistNeighborLoader with with_weight=True.""" + + def setUp(self) -> None: + super().setUp() + self._world_size = 1 + + def tearDown(self) -> None: + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + super().tearDown() + + def test_distributed_neighbor_loader_with_weights_homogeneous(self) -> None: + """Homogeneous loader with with_weight=True iterates all nodes without error.""" + n = 5 + partition_output = PartitionOutput( + node_partition_book=torch.zeros(n), + edge_partition_book=torch.zeros(n), + partitioned_edge_index=GraphPartitionData( + edge_index=torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]]), + edge_ids=None, + weights=torch.ones(n, dtype=torch.float32), + ), + partitioned_node_features=FeaturePartitionData( + feats=torch.zeros(n, 2), ids=torch.arange(n) + ), + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + + mp.spawn( + fn=_run_distributed_weighted_neighbor_loader_homogeneous, + args=(dataset, n), + ) + + def test_distributed_neighbor_loader_with_weights_heterogeneous(self) -> None: + """Heterogeneous loader with with_weight=True iterates all seed nodes without error.""" + n = 5 + partition_output = PartitionOutput( + node_partition_book={ + _USER: torch.zeros(n), + _STORY: torch.zeros(n), + }, + edge_partition_book={ + _USER_TO_STORY: torch.zeros(n), + _STORY_TO_USER: torch.zeros(n), + }, + partitioned_edge_index={ + _USER_TO_STORY: GraphPartitionData( + edge_index=torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]), + edge_ids=None, + weights=torch.ones(n, dtype=torch.float32), + ), + _STORY_TO_USER: GraphPartitionData( + edge_index=torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]), + edge_ids=None, + weights=torch.ones(n, dtype=torch.float32), + ), + }, + partitioned_node_features={ + _USER: FeaturePartitionData( + feats=torch.zeros(n, 2), ids=torch.arange(n) + ), + _STORY: FeaturePartitionData( + feats=torch.zeros(n, 2), ids=torch.arange(n) + ), + }, + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + + mp.spawn( + fn=_run_distributed_weighted_neighbor_loader_heterogeneous, + args=(dataset, n), + ) + + +if __name__ == "__main__": + absltest.main() From e8b33c2230bd22df287fe3ff8e5ea0aefced8b64 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 13 May 2026 21:55:30 +0000 Subject: [PATCH 02/13] Add TODO for weighted sampling with PPR --- gigl/distributed/dist_ablp_neighborloader.py | 6 ++++++ gigl/distributed/distributed_neighborloader.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index f49c3a703..362561d04 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -289,6 +289,12 @@ def __init__( else False ) + if with_weight and isinstance(sampler_options, PPRSamplerOptions): + raise NotImplementedError( + "Weighted sampling is not yet supported with PPRSamplerOptions. " + "Weight-proportional residual propagation for PPR is planned but not implemented." + ) + device = ( pin_memory_device if pin_memory_device diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index f3cd94440..41df8f371 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -212,6 +212,12 @@ def __init__( else False ) + if with_weight and isinstance(sampler_options, PPRSamplerOptions): + raise NotImplementedError( + "Weighted sampling is not yet supported with PPRSamplerOptions. " + "Weight-proportional residual propagation for PPR is planned but not implemented." + ) + # Determine mode if isinstance(dataset, RemoteDistDataset): self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE From 49462984bb7610f87952cc36a1f8909023c01ab8 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 13 May 2026 22:03:59 +0000 Subject: [PATCH 03/13] cleanup --- gigl/distributed/base_dist_loader.py | 32 +++++++++++++++ gigl/distributed/dataset_factory.py | 6 ++- gigl/distributed/dist_ablp_neighborloader.py | 39 +++---------------- gigl/distributed/dist_dataset.py | 20 ++++++---- gigl/distributed/dist_partitioner.py | 4 +- .../distributed/distributed_neighborloader.py | 39 +++---------------- 6 files changed, 62 insertions(+), 78 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 2199af57f..3931b4175 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -333,6 +333,38 @@ def __init__( "for graph-store mode." ) + @staticmethod + def validate_with_weight( + with_weight: bool, + dataset: Union[DistDataset, RemoteDistDataset], + sampler_options: SamplerOptions, + ) -> None: + """Validates the ``with_weight`` parameter against the dataset and sampler. + + Args: + with_weight: Whether weighted sampling was requested. + dataset: The dataset being sampled from. + sampler_options: The sampler to be used. + + Raises: + ValueError: If ``with_weight=True`` but no edge weights are registered. + NotImplementedError: If ``with_weight=True`` and a PPR sampler is requested. + """ + if ( + with_weight + and isinstance(dataset, DistDataset) + and not dataset.has_edge_weights + ): + raise ValueError( + "with_weight=True requires edge weights to be registered in the dataset. " + "Pass weight_edge_feat_name to build_dataset() to register edge weights." + ) + if with_weight and isinstance(sampler_options, PPRSamplerOptions): + raise NotImplementedError( + "Weighted sampling is not yet supported with PPRSamplerOptions. " + "Weight-proportional residual propagation for PPR is planned but not implemented." + ) + @staticmethod def create_sampling_config( num_neighbors: Union[list[int], dict[EdgeType, list[int]]], diff --git a/gigl/distributed/dataset_factory.py b/gigl/distributed/dataset_factory.py index 72d9a55b9..258c9da1f 100644 --- a/gigl/distributed/dataset_factory.py +++ b/gigl/distributed/dataset_factory.py @@ -43,8 +43,8 @@ ) from gigl.src.common.types.graph_data import EdgeType from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper -from gigl.types.graph import DEFAULT_HOMOGENEOUS_EDGE_TYPE from gigl.src.common.types.pb_wrappers.task_metadata import TaskMetadataType +from gigl.types.graph import DEFAULT_HOMOGENEOUS_EDGE_TYPE from gigl.utils.data_splitters import ( DistNodeAnchorLinkSplitter, DistNodeSplitter, @@ -60,7 +60,9 @@ def _extract_weight_column( edge_features: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], weight_edge_feat_name: Union[str, dict[EdgeType, str]], - edge_entity_info: Union[SerializedTFRecordInfo, dict[EdgeType, SerializedTFRecordInfo]], + edge_entity_info: Union[ + SerializedTFRecordInfo, dict[EdgeType, SerializedTFRecordInfo] + ], ) -> Tuple[ dict[EdgeType, torch.Tensor], Union[torch.Tensor, dict[EdgeType, torch.Tensor]], diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 362561d04..5e9f6a3df 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -91,7 +91,7 @@ def __init__( num_cpu_threads: Optional[int] = None, shuffle: bool = False, drop_last: bool = False, - with_weight: Optional[bool] = None, + with_weight: bool = False, sampler_options: Optional[SamplerOptions] = None, context: Optional[DistributedContext] = None, # TODO: (svij) Deprecate this local_process_rank: Optional[int] = None, # TODO: (svij) Deprecate this @@ -202,13 +202,10 @@ def __init__( Defaults to `2` if set to `None` when using cpu training/inference. shuffle (bool): Whether to shuffle the input nodes. (default: ``False``). drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``). - with_weight (Optional[bool]): Whether to use edge weights for neighbor - sampling. When ``None`` (default), inferred automatically from the - dataset: ``True`` if edge weights were registered via - ``DistPartitioner.register_edge_weights()``, ``False`` otherwise. - Pass ``True`` or ``False`` explicitly to override the inference. - Only auto-inferred for colocated mode; graph-store mode defaults to - ``False`` when ``None``. + with_weight (bool): Whether to use edge weights for neighbor sampling. + Requires edge weights to be registered via + ``DistPartitioner.register_edge_weights()`` during dataset construction. + Defaults to ``False``. sampler_options (Optional[SamplerOptions]): Controls which sampler class is instantiated. Defaults to `KHopNeighborSamplerOptions`, which will use the num_neighbors argument to instantiate the sampler. @@ -269,31 +266,7 @@ def __init__( ) del context, local_process_rank, local_process_world_size - # Validate explicit with_weight overrides and infer when not specified. - if isinstance(dataset, DistDataset): - if with_weight is True and not dataset.has_edge_weights: - raise ValueError( - "with_weight=True requires edge weights to be registered in the dataset. " - "Pass weight_edge_feat_name to build_dataset() to register edge weights, " - "or omit with_weight to have it inferred automatically." - ) - if with_weight is False and dataset.has_edge_weights: - logger.warning( - "with_weight=False explicitly set but the dataset has edge weights registered. " - "Weighted sampling will be disabled. Omit with_weight to enable it automatically." - ) - if with_weight is None: - with_weight = ( - dataset.has_edge_weights - if isinstance(dataset, DistDataset) - else False - ) - - if with_weight and isinstance(sampler_options, PPRSamplerOptions): - raise NotImplementedError( - "Weighted sampling is not yet supported with PPRSamplerOptions. " - "Weight-proportional residual propagation for PPR is planned but not implemented." - ) + BaseDistLoader.validate_with_weight(with_weight, dataset, sampler_options) device = ( pin_memory_device diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index 69e328854..dd2682e74 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -343,10 +343,8 @@ def degree_tensor( @property def has_edge_weights(self) -> bool: - """True if edge weights were registered during dataset construction. - - Used by loaders to automatically infer ``with_weight`` when it is not - explicitly specified. + """True if edge weights were registered during dataset construction via + ``DistPartitioner.register_edge_weights()``. """ return self._has_edge_weights @@ -573,11 +571,15 @@ def _initialize_graph( GraphPartitionData, dict[EdgeType, GraphPartitionData] ], ) -> None: - """ - Initializes the graph structure with edge index and edge IDs from partition output. + """Initializes the graph structure from partition output. + + Sets up the GLT graph with edge index, edge IDs, and optional edge weights. + For heterogeneous graphs with weights registered on only some edge types, a + warning is logged and unweighted edge types fall back to uniform sampling. Args: - partitioned_edge_index(Union[GraphPartitionData, dict[EdgeType, GraphPartitionData]]): The partitioned graph data + partitioned_edge_index: Partitioned graph data per edge type (heterogeneous) + or a single partition (homogeneous). """ # Edge Index refers to the [2, num_edges] tensor representing pairs of nodes connecting each edge @@ -608,7 +610,9 @@ def _initialize_graph( } edge_weights = weights_by_type if weights_by_type else None if weights_by_type: - missing = set(partitioned_edge_index.keys()) - set(weights_by_type.keys()) + missing = set(partitioned_edge_index.keys()) - set( + weights_by_type.keys() + ) if missing: logger.warning( f"Edge weights are registered for {set(weights_by_type.keys())} but " diff --git a/gigl/distributed/dist_partitioner.py b/gigl/distributed/dist_partitioner.py index cba668035..d5df4c712 100644 --- a/gigl/distributed/dist_partitioner.py +++ b/gigl/distributed/dist_partitioner.py @@ -1191,9 +1191,9 @@ def _edge_pfn(_, chunk_range): assert self._edge_weights is not None edge_weights_tensor = self._edge_weights[edge_type] - def _edge_weight_pfn(weight_ids, _): + def _edge_weight_pfn(edge_ids_chunk, _): assert edge_partition_book is not None - return edge_partition_book[weight_ids] + return edge_partition_book[edge_ids_chunk] weight_res_list, _ = self._partition_by_chunk( input_data=(edge_weights_tensor, edge_ids), diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 41df8f371..b99d3a528 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -90,7 +90,7 @@ def __init__( num_cpu_threads: Optional[int] = None, shuffle: bool = False, drop_last: bool = False, - with_weight: Optional[bool] = None, + with_weight: bool = False, sampler_options: Optional[SamplerOptions] = None, non_blocking_transfers: bool = True, ): @@ -159,13 +159,10 @@ def __init__( Defaults to `2` if set to `None` when using cpu training/inference. shuffle (bool): Whether to shuffle the input nodes. (default: ``False``). drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``). - with_weight (Optional[bool]): Whether to use edge weights for neighbor - sampling. When ``None`` (default), inferred automatically from the - dataset: ``True`` if edge weights were registered via - ``DistPartitioner.register_edge_weights()``, ``False`` otherwise. - Pass ``True`` or ``False`` explicitly to override the inference. - Only auto-inferred for colocated mode; graph-store mode defaults to - ``False`` when ``None``. + with_weight (bool): Whether to use edge weights for neighbor sampling. + Requires edge weights to be registered via + ``DistPartitioner.register_edge_weights()`` during dataset construction. + Defaults to ``False``. sampler_options (Optional[SamplerOptions]): Controls which sampler class is instantiated. Pass ``KHopNeighborSamplerOptions`` to use the built-in sampler, or ``CustomSamplerOptions`` to dynamically import a custom sampler class. @@ -192,31 +189,7 @@ def __init__( ) del context, local_process_rank, local_process_world_size - # Validate explicit with_weight overrides and infer when not specified. - if isinstance(dataset, DistDataset): - if with_weight is True and not dataset.has_edge_weights: - raise ValueError( - "with_weight=True requires edge weights to be registered in the dataset. " - "Pass weight_edge_feat_name to build_dataset() to register edge weights, " - "or omit with_weight to have it inferred automatically." - ) - if with_weight is False and dataset.has_edge_weights: - logger.warning( - "with_weight=False explicitly set but the dataset has edge weights registered. " - "Weighted sampling will be disabled. Omit with_weight to enable it automatically." - ) - if with_weight is None: - with_weight = ( - dataset.has_edge_weights - if isinstance(dataset, DistDataset) - else False - ) - - if with_weight and isinstance(sampler_options, PPRSamplerOptions): - raise NotImplementedError( - "Weighted sampling is not yet supported with PPRSamplerOptions. " - "Weight-proportional residual propagation for PPR is planned but not implemented." - ) + BaseDistLoader.validate_with_weight(with_weight, dataset, sampler_options) # Determine mode if isinstance(dataset, RemoteDistDataset): From 56d23552e8aa4942ccc6f0020efb38f805328819 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 13 May 2026 22:56:02 +0000 Subject: [PATCH 04/13] update partitioning to be done in 1 pass --- gigl/distributed/dist_partitioner.py | 181 +++--- .../run_distributed_partitioner.py | 66 ++ .../distributed_weighted_sampling_test.py | 587 +++++++++++++++--- 3 files changed, 668 insertions(+), 166 deletions(-) diff --git a/gigl/distributed/dist_partitioner.py b/gigl/distributed/dist_partitioner.py index d5df4c712..d9da0caa5 100644 --- a/gigl/distributed/dist_partitioner.py +++ b/gigl/distributed/dist_partitioner.py @@ -1163,7 +1163,6 @@ def _edge_pfn(_, chunk_range): if len(edge_res_list) == 0: partitioned_edge_index = torch.empty((2, 0)) - partitioned_edge_ids = torch.empty(0) else: partitioned_edge_index = torch.stack( ( @@ -1172,120 +1171,138 @@ def _edge_pfn(_, chunk_range): ), dim=0, ) - if should_skip_edge_feats and not has_weights_for_edge_type: - partitioned_edge_ids = None - else: - partitioned_edge_ids = torch.cat([r[2] for r in edge_res_list]) edge_res_list.clear() gc.collect() - # Partitioning Edge Weights (before creating GraphPartitionData so weights - # can be included at construction time; frozen dataclass cannot be mutated). + # Partition edge features and weights together in a single pass, + # mirroring how node features and labels are co-partitioned. + # Input tuple layout: (edge_feat?, edge_weights?, edge_ids) + # IDs are always at r[-1]; features at r[0]; weights at r[1] when + # features are also present, else r[0]. + current_feat_part: Optional[FeaturePartitionData] = None partitioned_weights: Optional[torch.Tensor] = None - if has_weights_for_edge_type: - assert edge_partition_book is not None, ( - "edge_partition_book must be populated when edge weights are registered" - ) - assert self._edge_weights is not None - edge_weights_tensor = self._edge_weights[edge_type] - - def _edge_weight_pfn(edge_ids_chunk, _): - assert edge_partition_book is not None - return edge_partition_book[edge_ids_chunk] + partitioned_edge_ids: Optional[torch.Tensor] = None - weight_res_list, _ = self._partition_by_chunk( - input_data=(edge_weights_tensor, edge_ids), - rank_indices=edge_ids, - partition_function=_edge_weight_pfn, - total_val_size=num_edges, - generate_pb=False, - ) - del edge_weights_tensor - del self._edge_weights[edge_type] - if len(self._edge_weights) == 0: - self._edge_weights = None - gc.collect() - - if len(weight_res_list) == 0: - partitioned_weights = torch.empty(0) - else: - partitioned_weights = torch.cat([r[0] for r in weight_res_list]) - weight_res_list.clear() - gc.collect() - - current_graph_part = GraphPartitionData( - edge_index=partitioned_edge_index, - edge_ids=partitioned_edge_ids, - weights=partitioned_weights, - ) - - # Partitioning Edge Features - - if should_skip_edge_feats: + if not should_generate_partition_book: logger.info( f"No edge features detected for edge type {edge_type}, will only partition edge indices for this edge type." ) - current_feat_part = None del edge_ids del self._edge_ids[edge_type] if len(self._edge_ids) == 0: self._edge_ids = None gc.collect() else: - assert self._edge_feat_dim is not None and edge_type in self._edge_feat_dim - assert self._edge_feat is not None and edge_type in self._edge_feat assert edge_partition_book is not None - edge_feat = self._edge_feat[edge_type] - edge_feat_dim = self._edge_feat_dim[edge_type] - def _edge_feature_pfn(edge_feature_ids, _): + edge_feat: Optional[torch.Tensor] = None + edge_feat_dim: Optional[int] = None + edge_weights_tensor: Optional[torch.Tensor] = None + if not should_skip_edge_feats: + assert self._edge_feat is not None and edge_type in self._edge_feat + assert ( + self._edge_feat_dim is not None and edge_type in self._edge_feat_dim + ) + edge_feat = self._edge_feat[edge_type] + edge_feat_dim = self._edge_feat_dim[edge_type] + if has_weights_for_edge_type: + assert self._edge_weights is not None + edge_weights_tensor = self._edge_weights[edge_type] + + input_parts: list[torch.Tensor] = [] + if edge_feat is not None: + input_parts.append(edge_feat) + if edge_weights_tensor is not None: + input_parts.append(edge_weights_tensor) + input_parts.append(edge_ids) + + # Positional indices: features first, weights next, ids always last. + feat_idx: Optional[int] = 0 if not should_skip_edge_feats else None + weight_idx: Optional[int] = ( + (1 if not should_skip_edge_feats else 0) + if has_weights_for_edge_type + else None + ) + + def _edge_feat_weight_pfn( + ids_chunk: torch.Tensor, _: object + ) -> torch.Tensor: assert edge_partition_book is not None - return edge_partition_book[edge_feature_ids] + return edge_partition_book[ids_chunk] - # partitioned_results is a list of tuples. Each tuple correpsonds - # to a chunk of data. A tuple contains edge features and edge ids. - edge_feat_res_list, _ = self._partition_by_chunk( - input_data=(edge_feat, edge_ids), + # Each result tuple contains (edge_feat?, edge_weights?, edge_ids). + feat_weight_res_list, _ = self._partition_by_chunk( + input_data=tuple(input_parts), rank_indices=edge_ids, - partition_function=_edge_feature_pfn, + partition_function=_edge_feat_weight_pfn, total_val_size=num_edges, generate_pb=False, ) - del edge_feat, edge_ids - del ( - self._edge_feat[edge_type], - self._edge_feat_dim[edge_type], - self._edge_ids[edge_type], - ) + + del edge_ids + del self._edge_ids[edge_type] if len(self._edge_ids) == 0: self._edge_ids = None - if len(self._edge_feat) == 0 and len(self._edge_feat_dim) == 0: - self._edge_feat = None - self._edge_feat_dim = None + if not should_skip_edge_feats: + assert edge_feat is not None + assert self._edge_feat is not None and self._edge_feat_dim is not None + del edge_feat + del self._edge_feat[edge_type], self._edge_feat_dim[edge_type] + if len(self._edge_feat) == 0 and len(self._edge_feat_dim) == 0: + self._edge_feat = None + self._edge_feat_dim = None + if has_weights_for_edge_type: + assert edge_weights_tensor is not None + assert self._edge_weights is not None + del edge_weights_tensor + del self._edge_weights[edge_type] + if len(self._edge_weights) == 0: + self._edge_weights = None gc.collect() - if len(edge_feat_res_list) == 0: - partitioned_edge_features = torch.empty(0, edge_feat_dim) - partitioned_edge_feat_ids = torch.empty(0) + + if len(feat_weight_res_list) == 0: + partitioned_edge_ids = torch.empty(0) + if not should_skip_edge_feats: + assert edge_feat_dim is not None + current_feat_part = FeaturePartitionData( + feats=torch.empty(0, edge_feat_dim), + ids=partitioned_edge_ids, + ) + if has_weights_for_edge_type: + partitioned_weights = torch.empty(0) else: - partitioned_edge_features = torch.cat( - [r[0] for r in edge_feat_res_list] - ) - partitioned_edge_feat_ids = torch.cat( - [r[1] for r in edge_feat_res_list] - ) + partitioned_edge_ids = torch.cat([r[-1] for r in feat_weight_res_list]) + if not should_skip_edge_feats: + assert feat_idx is not None and edge_feat_dim is not None + current_feat_part = FeaturePartitionData( + feats=torch.cat([r[feat_idx] for r in feat_weight_res_list]), + ids=partitioned_edge_ids, + ) + if has_weights_for_edge_type: + assert weight_idx is not None + partitioned_weights = torch.cat( + [r[weight_idx] for r in feat_weight_res_list] + ) + + feat_weight_res_list.clear() + gc.collect() - current_feat_part = FeaturePartitionData( - feats=partitioned_edge_features, ids=partitioned_edge_feat_ids - ) - logger.info( - f"Got edge tensor-based partition book for edge type {edge_type} on rank {self._rank} of shape {edge_partition_book.shape}" - ) + if not should_skip_edge_feats: + logger.info( + f"Got edge tensor-based partition book for edge type {edge_type} on rank {self._rank} of shape {edge_partition_book.shape}" + ) logger.info( f"Edge Index and Feature Partitioning for edge type {edge_type} finished, took {time.time() - start_time:.3f}s" ) + current_graph_part = GraphPartitionData( + edge_index=partitioned_edge_index, + edge_ids=partitioned_edge_ids, + weights=partitioned_weights, + ) + return current_graph_part, current_feat_part, edge_partition_book def _partition_label_edge_index( diff --git a/tests/test_assets/distributed/run_distributed_partitioner.py b/tests/test_assets/distributed/run_distributed_partitioner.py index 9e76d107b..b06177fec 100644 --- a/tests/test_assets/distributed/run_distributed_partitioner.py +++ b/tests/test_assets/distributed/run_distributed_partitioner.py @@ -19,6 +19,9 @@ class InputDataStrategy(Enum): REGISTER_ALL_ENTITIES_SEPARATELY = "REGISTER_ALL_ENTITIES_SEPARATELY" REGISTER_ALL_ENTITIES_TOGETHER = "REGISTER_ALL_ENTITIES_TOGETHER" REGISTER_MINIMAL_ENTITIES_SEPARATELY = "REGISTER_MINIMAL_ENTITIES_SEPARATELY" + REGISTER_EDGE_WEIGHTS_WITHOUT_EDGE_FEATURES = ( + "REGISTER_EDGE_WEIGHTS_WITHOUT_EDGE_FEATURES" + ) def run_distributed_partitioner( @@ -134,6 +137,69 @@ def run_distributed_partitioner( node_partition_book=output_node_partition_book, is_positive=False ) + partition_output = PartitionOutput( + node_partition_book=output_node_partition_book, + edge_partition_book=output_edge_partition_book, + partitioned_edge_index=output_edge_index, + partitioned_node_features=output_node_features, + partitioned_node_labels=output_node_labels, + partitioned_edge_features=output_edge_features, + partitioned_positive_labels=output_positive_labels, + partitioned_negative_labels=output_negative_labels, + ) + elif ( + input_data_strategy + == InputDataStrategy.REGISTER_EDGE_WEIGHTS_WITHOUT_EDGE_FEATURES + ): + # Same as REGISTER_ALL_ENTITIES_SEPARATELY but skips edge feature registration, + # so weights can be tested independently of features. + dist_partitioner = partitioner_class( + should_assign_edges_by_src_node=should_assign_edges_by_src_node, + ) + dist_partitioner.register_node_ids(node_ids=node_ids) + del node_ids + output_node_partition_book = dist_partitioner.partition_node() + + dist_partitioner.register_edge_index(edge_index=edge_index) + del edge_index, edge_features + if rank_to_edge_weights is not None and rank in rank_to_edge_weights: + dist_partitioner.register_edge_weights( + edge_weights=rank_to_edge_weights[rank] + ) + ( + output_edge_index, + output_edge_features, + output_edge_partition_book, + ) = dist_partitioner.partition_edge_index_and_edge_features( + node_partition_book=output_node_partition_book + ) + + dist_partitioner.register_node_features(node_features=node_features) + dist_partitioner.register_node_labels(node_labels=node_labels) + del node_labels, node_features + ( + output_node_features, + output_node_labels, + ) = dist_partitioner.partition_node_features_and_labels( + node_partition_book=output_node_partition_book + ) + + dist_partitioner.register_labels( + label_edge_index=positive_labels, is_positive=True + ) + del positive_labels + output_positive_labels = dist_partitioner.partition_labels( + node_partition_book=output_node_partition_book, is_positive=True + ) + + dist_partitioner.register_labels( + label_edge_index=negative_labels, is_positive=False + ) + del negative_labels + output_negative_labels = dist_partitioner.partition_labels( + node_partition_book=output_node_partition_book, is_positive=False + ) + partition_output = PartitionOutput( node_partition_book=output_node_partition_book, edge_partition_book=output_edge_partition_book, diff --git a/tests/unit/distributed/distributed_weighted_sampling_test.py b/tests/unit/distributed/distributed_weighted_sampling_test.py index 8f371ec6c..7d5107c47 100644 --- a/tests/unit/distributed/distributed_weighted_sampling_test.py +++ b/tests/unit/distributed/distributed_weighted_sampling_test.py @@ -3,8 +3,9 @@ Covers two surfaces: 1. DistPartitioner correctly partitions registered edge weights (weights land on the right rank and match the expected values). - 2. DistNeighborLoader runs end-to-end without errors when with_weight=True is set - and the dataset carries edge weights — both homogeneous and heterogeneous. + 2. DistNeighborLoader with with_weight=True never traverses weight=0 edges — + verified by encoding node type into features (hub=2.0, good=1.0, bad=0.0) + and asserting no bad node appears in any sampled subgraph. """ from collections.abc import Mapping @@ -32,6 +33,8 @@ MOCKED_U2U_EDGE_INDEX_ON_RANK_ONE, MOCKED_U2U_EDGE_INDEX_ON_RANK_ZERO, RANK_TO_MOCKED_GRAPH, + USER_TO_ITEM_EDGE_TYPE, + USER_TO_USER_EDGE_TYPE, ) from tests.test_assets.distributed.run_distributed_partitioner import ( InputDataStrategy, @@ -41,9 +44,184 @@ from tests.test_assets.test_case import TestCase _USER = NodeType("user") -_STORY = NodeType("story") -_USER_TO_STORY = EdgeType(_USER, Relation("to"), _STORY) -_STORY_TO_USER = EdgeType(_STORY, Relation("to"), _USER) +_ITEM = NodeType("item") +_USER_TO_ITEM = EdgeType(_USER, Relation("to"), _ITEM) +_ITEM_TO_USER = EdgeType(_ITEM, Relation("to"), _USER) + + +# --------------------------------------------------------------------------- +# Graph builders +# --------------------------------------------------------------------------- + + +def _build_homogeneous_bipartite_weight_graph() -> tuple[ + PartitionOutput, int, int, int +]: + """Build a homogeneous graph with hub, good, and bad nodes. + + Graph structure: + - 10 hub nodes (0..9): used as seed nodes; feature value = 2.0 + - 50 good nodes (10..59): reachable from hubs via weight=1 edges; feature = 1.0 + - 40 bad nodes (60..99): reachable from hubs via weight=0 edges; feature = 0.0 + - Each good node also has 5 outgoing weight=1 edges to nearby good nodes + (ring topology, for 2nd-hop sampling). + + With weighted sampling only good nodes should ever appear as sampled + neighbors — weight=0 edges to bad nodes must never be traversed. + + Returns: + (partition_output, n_hub, n_good, n_bad) + """ + n_hub = 10 + n_good = 50 + n_bad = 40 + n = n_hub + n_good + n_bad # 100 + + hub_ids = torch.arange(n_hub) + good_ids = torch.arange(n_hub, n_hub + n_good) + bad_ids = torch.arange(n_hub + n_good, n) + + # Hub → Good: weight=1 + hub_good_src = hub_ids.repeat_interleave(n_good) + hub_good_dst = good_ids.repeat(n_hub) + hub_good_w = torch.ones(n_hub * n_good) + + # Hub → Bad: weight=0 + hub_bad_src = hub_ids.repeat_interleave(n_bad) + hub_bad_dst = bad_ids.repeat(n_hub) + hub_bad_w = torch.zeros(n_hub * n_bad) + + # Good → Good: ring with 5 outgoing edges per node, weight=1 (2nd-hop targets) + connections_per_good = 5 + good_src = good_ids.repeat_interleave(connections_per_good) + # Row i of [connections_per_good, n_good].T gives neighbors of good_ids[i] + good_dst = torch.stack( + [torch.roll(good_ids, -j) for j in range(1, connections_per_good + 1)] + ).T.reshape(-1) + good_w = torch.ones(n_good * connections_per_good) + + edge_src = torch.cat([hub_good_src, hub_bad_src, good_src]) + edge_dst = torch.cat([hub_good_dst, hub_bad_dst, good_dst]) + weights = torch.cat([hub_good_w, hub_bad_w, good_w]) + edge_index = torch.stack([edge_src, edge_dst]) + n_edges = edge_src.shape[0] + + # Feature encodes node type: hub=2.0, good=1.0, bad=0.0 + node_feats = torch.cat( + [ + torch.full((n_hub, 1), 2.0), + torch.full((n_good, 1), 1.0), + torch.full((n_bad, 1), 0.0), + ] + ) + + partition_output = PartitionOutput( + node_partition_book=torch.zeros(n), + edge_partition_book=torch.zeros(n_edges), + partitioned_edge_index=GraphPartitionData( + edge_index=edge_index, + edge_ids=None, + weights=weights, + ), + partitioned_node_features=FeaturePartitionData( + feats=node_feats, + ids=torch.arange(n), + ), + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + return partition_output, n_hub, n_good, n_bad + + +def _build_heterogeneous_bipartite_weight_graph() -> tuple[ + PartitionOutput, int, int, int +]: + """Build a heterogeneous (user/item) graph with good and bad item nodes. + + Graph structure: + - 10 user nodes (0..9): seed nodes; user feature = 2.0 + - 60 item nodes total: + - Items 0..39: good, reachable from users via weight=1 edges; feature = 1.0 + - Items 40..59: bad, reachable from users via weight=0 edges; feature = 0.0 + - Good items also have weight=1 edges back to all users (for 2nd-hop). + + With weighted sampling only good item nodes should ever appear as sampled + item neighbors. + + Returns: + (partition_output, n_user, n_good_item, n_bad_item) + """ + n_user = 10 + n_good_item = 40 + n_bad_item = 20 + n_item = n_good_item + n_bad_item # 60 + + user_ids = torch.arange(n_user) + good_item_ids = torch.arange(n_good_item) + bad_item_ids = torch.arange(n_good_item, n_item) + + # User → Good Item: weight=1 + u2gi_src = user_ids.repeat_interleave(n_good_item) + u2gi_dst = good_item_ids.repeat(n_user) + u2gi_w = torch.ones(n_user * n_good_item) + + # User → Bad Item: weight=0 + u2bi_src = user_ids.repeat_interleave(n_bad_item) + u2bi_dst = bad_item_ids.repeat(n_user) + u2bi_w = torch.zeros(n_user * n_bad_item) + + # Good Item → User: weight=1 (2nd-hop back to users) + gi2u_src = good_item_ids.repeat_interleave(n_user) + gi2u_dst = user_ids.repeat(n_good_item) + gi2u_w = torch.ones(n_good_item * n_user) + + u2i_src = torch.cat([u2gi_src, u2bi_src]) + u2i_dst = torch.cat([u2gi_dst, u2bi_dst]) + u2i_w = torch.cat([u2gi_w, u2bi_w]) + n_u2i_edges = u2i_src.shape[0] + + user_feats = torch.full((n_user, 1), 2.0) + # Item feature encodes type: good=1.0, bad=0.0 + item_feats = torch.cat( + [ + torch.full((n_good_item, 1), 1.0), + torch.full((n_bad_item, 1), 0.0), + ] + ) + + partition_output = PartitionOutput( + node_partition_book={ + _USER: torch.zeros(n_user), + _ITEM: torch.zeros(n_item), + }, + edge_partition_book={ + _USER_TO_ITEM: torch.zeros(n_u2i_edges), + _ITEM_TO_USER: torch.zeros(gi2u_src.shape[0]), + }, + partitioned_edge_index={ + _USER_TO_ITEM: GraphPartitionData( + edge_index=torch.stack([u2i_src, u2i_dst]), + edge_ids=None, + weights=u2i_w, + ), + _ITEM_TO_USER: GraphPartitionData( + edge_index=torch.stack([gi2u_src, gi2u_dst]), + edge_ids=None, + weights=gi2u_w, + ), + }, + partitioned_node_features={ + _USER: FeaturePartitionData(feats=user_feats, ids=torch.arange(n_user)), + _ITEM: FeaturePartitionData(feats=item_feats, ids=torch.arange(n_item)), + }, + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + return partition_output, n_user, n_good_item, n_bad_item # --------------------------------------------------------------------------- @@ -51,54 +229,73 @@ # --------------------------------------------------------------------------- -def _run_distributed_weighted_neighbor_loader_homogeneous( +def _run_weighted_sampling_correctness_homogeneous( _: int, dataset: DistDataset, - expected_data_count: int, + n_hub: int, ) -> None: - """Subprocess: iterates a weighted homogeneous loader and checks batch count and type.""" + """Subprocess: verifies weight=0 edges are never traversed in homogeneous graph. + + Seeds are the hub nodes only. Node features encode the type: + hub=2.0, good=1.0, bad=0.0. Any subgraph batch containing a bad node + (feature==0.0) means a weight=0 edge was sampled — a test failure. + """ create_test_process_group() loader = DistNeighborLoader( dataset=dataset, - num_neighbors=[2, 2], + input_nodes=torch.arange(n_hub), + num_neighbors=[10, 5], with_weight=True, pin_memory_device=torch.device("cpu"), ) count = 0 for datum in loader: - assert isinstance(datum, Data), ( - f"Subgraph should be Data for homogeneous datasets, got {type(datum)}" + assert isinstance(datum, Data), f"Expected Data, got {type(datum)}" + assert datum.x is not None, "Node features missing from sampled subgraph" + # Bad nodes have feature 0.0; hub and good nodes have feature > 0. + bad_mask = datum.x[:, 0] == 0.0 + assert not bad_mask.any(), ( + f"weight=0 edge was sampled: bad node(s) found in subgraph. " + f"Features of bad nodes: {datum.x[bad_mask].squeeze().tolist()}" ) count += 1 - assert count == expected_data_count, ( - f"Expected {expected_data_count} batches, got {count}" - ) + assert count == n_hub, f"Expected {n_hub} batches (one per hub seed), got {count}" shutdown_rpc() -def _run_distributed_weighted_neighbor_loader_heterogeneous( +def _run_weighted_sampling_correctness_heterogeneous( _: int, dataset: DistDataset, - expected_data_count: int, + n_user: int, ) -> None: - """Subprocess: iterates a weighted heterogeneous loader and checks batch count and type.""" + """Subprocess: verifies weight=0 edges are never traversed in heterogeneous graph. + + Seeds are all user nodes. Item features encode type: good=1.0, bad=0.0. + Any batch containing a bad item node means a weight=0 edge was sampled. + """ create_test_process_group() assert isinstance(dataset.node_ids, Mapping) loader = DistNeighborLoader( dataset=dataset, input_nodes=(_USER, dataset.node_ids[_USER]), - num_neighbors=[2, 2], + num_neighbors=[10, 5], with_weight=True, pin_memory_device=torch.device("cpu"), ) count = 0 for datum in loader: - assert isinstance(datum, HeteroData), ( - f"Subgraph should be HeteroData for heterogeneous datasets, got {type(datum)}" - ) + assert isinstance(datum, HeteroData), f"Expected HeteroData, got {type(datum)}" + if _ITEM in datum.node_types: + item_x = datum[_ITEM].x + assert item_x is not None, "Item features missing from sampled subgraph" + bad_mask = item_x[:, 0] == 0.0 + assert not bad_mask.any(), ( + f"weight=0 edge was sampled: bad item node(s) found. " + f"Features of bad items: {item_x[bad_mask].squeeze().tolist()}" + ) count += 1 - assert count == expected_data_count, ( - f"Expected {expected_data_count} batches, got {count}" + assert count == n_user, ( + f"Expected {n_user} batches (one per user seed), got {count}" ) shutdown_rpc() @@ -182,9 +379,265 @@ def test_homogeneous_weights_partitioned_correctly(self) -> None: msg=f"Rank {rank}: partitioned weights do not match expected src_node_id / 10.0", ) + def test_features_only_weights_are_none(self) -> None: + """Features only (no weights): weights must be None, edge_ids and features present.""" + master_port = get_free_port() + manager = Manager() + output_dict: MutableMapping[int, PartitionOutput] = manager.dict() + + mp.spawn( + run_distributed_partitioner, + args=( + output_dict, + False, + RANK_TO_MOCKED_GRAPH, + True, + self._master_ip_address, + master_port, + InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY, + DistPartitioner, + None, + ), + nprocs=MOCKED_NUM_PARTITIONS, + join=True, + ) + + for rank, partition_output in output_dict.items(): + gpd = partition_output.partitioned_edge_index + self.assertIsInstance(gpd, GraphPartitionData) + assert isinstance(gpd, GraphPartitionData) + self.assertIsNone( + gpd.weights, + msg=f"Rank {rank}: weights must be None when not registered", + ) + self.assertIsNotNone( + gpd.edge_ids, + msg=f"Rank {rank}: edge_ids must be present when features are registered", + ) + self.assertIsNotNone( + partition_output.partitioned_edge_features, + msg=f"Rank {rank}: expected features to be partitioned", + ) + + def test_weights_only_no_features_partitioned_correctly(self) -> None: + """Weights without features: feature part is None, weights have correct values.""" + master_port = get_free_port() + manager = Manager() + output_dict: MutableMapping[int, PartitionOutput] = manager.dict() + + rank_to_edge_weights = { + 0: MOCKED_U2U_EDGE_INDEX_ON_RANK_ZERO[0].float() / 10.0, + 1: MOCKED_U2U_EDGE_INDEX_ON_RANK_ONE[0].float() / 10.0, + } + + mp.spawn( + run_distributed_partitioner, + args=( + output_dict, + False, + RANK_TO_MOCKED_GRAPH, + True, + self._master_ip_address, + master_port, + InputDataStrategy.REGISTER_EDGE_WEIGHTS_WITHOUT_EDGE_FEATURES, + DistPartitioner, + rank_to_edge_weights, + ), + nprocs=MOCKED_NUM_PARTITIONS, + join=True, + ) + + for rank, partition_output in output_dict.items(): + gpd = partition_output.partitioned_edge_index + self.assertIsInstance(gpd, GraphPartitionData) + assert isinstance(gpd, GraphPartitionData) + self.assertIsNone( + partition_output.partitioned_edge_features, + msg=f"Rank {rank}: features must be None when not registered", + ) + + weights = gpd.weights + self.assertIsNotNone( + weights, msg=f"Rank {rank}: expected weights to be partitioned" + ) + assert weights is not None + + edge_ids = gpd.edge_ids + self.assertIsNotNone( + edge_ids, + msg=f"Rank {rank}: edge_ids must be present when weights are registered", + ) + assert edge_ids is not None + + self.assertEqual(weights.shape, edge_ids.shape) + expected_weights = edge_ids.float() * 0.1 + torch.testing.assert_close( + weights.sort().values, + expected_weights.sort().values, + msg=f"Rank {rank}: weights do not match src_node_id / 10.0", + ) + + def test_neither_features_nor_weights_gives_none_edge_ids(self) -> None: + """No features, no weights: edge_ids and weights must both be None.""" + master_port = get_free_port() + manager = Manager() + output_dict: MutableMapping[int, PartitionOutput] = manager.dict() + + mp.spawn( + run_distributed_partitioner, + args=( + output_dict, + False, + RANK_TO_MOCKED_GRAPH, + True, + self._master_ip_address, + master_port, + InputDataStrategy.REGISTER_MINIMAL_ENTITIES_SEPARATELY, + DistPartitioner, + ), + nprocs=MOCKED_NUM_PARTITIONS, + join=True, + ) + + for rank, partition_output in output_dict.items(): + gpd = partition_output.partitioned_edge_index + self.assertIsInstance(gpd, GraphPartitionData) + assert isinstance(gpd, GraphPartitionData) + self.assertIsNone( + gpd.weights, + msg=f"Rank {rank}: weights must be None when not registered", + ) + self.assertIsNone( + gpd.edge_ids, + msg=f"Rank {rank}: edge_ids must be None when neither features nor weights are registered", + ) + self.assertIsNone( + partition_output.partitioned_edge_features, + msg=f"Rank {rank}: features must be None", + ) + + def test_features_and_weights_produce_consistent_edge_ids(self) -> None: + """Both registered: GraphPartitionData.edge_ids must equal FeaturePartitionData.ids.""" + master_port = get_free_port() + manager = Manager() + output_dict: MutableMapping[int, PartitionOutput] = manager.dict() + + rank_to_edge_weights = { + 0: MOCKED_U2U_EDGE_INDEX_ON_RANK_ZERO[0].float() / 10.0, + 1: MOCKED_U2U_EDGE_INDEX_ON_RANK_ONE[0].float() / 10.0, + } + + mp.spawn( + run_distributed_partitioner, + args=( + output_dict, + False, + RANK_TO_MOCKED_GRAPH, + True, + self._master_ip_address, + master_port, + InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY, + DistPartitioner, + rank_to_edge_weights, + ), + nprocs=MOCKED_NUM_PARTITIONS, + join=True, + ) + + for rank, partition_output in output_dict.items(): + gpd = partition_output.partitioned_edge_index + assert isinstance(gpd, GraphPartitionData) + feat_part = partition_output.partitioned_edge_features + self.assertIsInstance(feat_part, FeaturePartitionData) + assert isinstance(feat_part, FeaturePartitionData) + + self.assertIsNotNone(gpd.edge_ids) + self.assertIsNotNone(gpd.weights) + assert gpd.edge_ids is not None + + torch.testing.assert_close( + gpd.edge_ids, + feat_part.ids, + msg=f"Rank {rank}: GraphPartitionData.edge_ids must equal FeaturePartitionData.ids", + ) + + def test_heterogeneous_partial_weights_by_edge_type(self) -> None: + """Heterogeneous: edge type with weights has them; edge type without weights has None.""" + master_port = get_free_port() + manager = Manager() + output_dict: MutableMapping[int, PartitionOutput] = manager.dict() + + rank_to_edge_weights = { + 0: { + USER_TO_USER_EDGE_TYPE: MOCKED_U2U_EDGE_INDEX_ON_RANK_ZERO[0].float() + / 10.0 + }, + 1: { + USER_TO_USER_EDGE_TYPE: MOCKED_U2U_EDGE_INDEX_ON_RANK_ONE[0].float() + / 10.0 + }, + } + + mp.spawn( + run_distributed_partitioner, + args=( + output_dict, + True, # is_heterogeneous + RANK_TO_MOCKED_GRAPH, + True, + self._master_ip_address, + master_port, + InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY, + DistPartitioner, + rank_to_edge_weights, + ), + nprocs=MOCKED_NUM_PARTITIONS, + join=True, + ) + + for rank, partition_output in output_dict.items(): + partitioned_edge_index = partition_output.partitioned_edge_index + self.assertIsInstance(partitioned_edge_index, dict) + assert isinstance(partitioned_edge_index, dict) + + u2u_gpd = partitioned_edge_index[USER_TO_USER_EDGE_TYPE] + u2i_gpd = partitioned_edge_index[USER_TO_ITEM_EDGE_TYPE] + + self.assertIsNotNone( + u2u_gpd.weights, + msg=f"Rank {rank}: U2U edge type should have weights", + ) + self.assertIsNone( + u2i_gpd.weights, + msg=f"Rank {rank}: U2I edge type should not have weights", + ) + # U2U edge_ids present (has both features and weights). + self.assertIsNotNone( + u2u_gpd.edge_ids, + msg=f"Rank {rank}: U2U edge_ids must be present", + ) + # U2I has neither features nor weights, so edge_ids is None. + self.assertIsNone( + u2i_gpd.edge_ids, + msg=f"Rank {rank}: U2I edge_ids must be None", + ) + class DistributedWeightedSamplingTest(TestCase): - """End-to-end dataloading tests for DistNeighborLoader with with_weight=True.""" + """End-to-end correctness tests for DistNeighborLoader with with_weight=True. + + Each test builds a graph with two classes of neighbors: + - "good" neighbors connected by weight=1 edges + - "bad" neighbors connected by weight=0 edges + + Node features encode the class (good=1.0, bad=0.0). After weighted sampling, + no bad node should ever appear in a sampled subgraph — if it does, a weight=0 + edge was traversed, indicating a bug. + + Graph sizes are chosen so that the fanout is strictly smaller than the number + of available good neighbors, ensuring the sampler actively makes choices + (rather than returning all neighbors) and the test is non-trivial. + """ def setUp(self) -> None: super().setUp() @@ -195,76 +648,42 @@ def tearDown(self) -> None: torch.distributed.destroy_process_group() super().tearDown() - def test_distributed_neighbor_loader_with_weights_homogeneous(self) -> None: - """Homogeneous loader with with_weight=True iterates all nodes without error.""" - n = 5 - partition_output = PartitionOutput( - node_partition_book=torch.zeros(n), - edge_partition_book=torch.zeros(n), - partitioned_edge_index=GraphPartitionData( - edge_index=torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]]), - edge_ids=None, - weights=torch.ones(n, dtype=torch.float32), - ), - partitioned_node_features=FeaturePartitionData( - feats=torch.zeros(n, 2), ids=torch.arange(n) - ), - partitioned_edge_features=None, - partitioned_positive_labels=None, - partitioned_negative_labels=None, - partitioned_node_labels=None, - ) + def test_weighted_sampling_never_traverses_zero_weight_edges_homogeneous( + self, + ) -> None: + """Homogeneous: weight=0 edges to bad nodes are never traversed. + + Graph: 10 hub seeds, each with 50 good neighbors (weight=1) and 40 bad + neighbors (weight=0). Good nodes have 5 further weight=1 edges for 2nd-hop + sampling. Fanout [10, 5] samples fewer neighbors than available good ones, + so the weighted sampler actively selects from the pool each hop. + """ + partition_output, n_hub, _, _ = _build_homogeneous_bipartite_weight_graph() dataset = DistDataset(rank=0, world_size=1, edge_dir="out") dataset.build(partition_output=partition_output) mp.spawn( - fn=_run_distributed_weighted_neighbor_loader_homogeneous, - args=(dataset, n), + fn=_run_weighted_sampling_correctness_homogeneous, + args=(dataset, n_hub), ) - def test_distributed_neighbor_loader_with_weights_heterogeneous(self) -> None: - """Heterogeneous loader with with_weight=True iterates all seed nodes without error.""" - n = 5 - partition_output = PartitionOutput( - node_partition_book={ - _USER: torch.zeros(n), - _STORY: torch.zeros(n), - }, - edge_partition_book={ - _USER_TO_STORY: torch.zeros(n), - _STORY_TO_USER: torch.zeros(n), - }, - partitioned_edge_index={ - _USER_TO_STORY: GraphPartitionData( - edge_index=torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]), - edge_ids=None, - weights=torch.ones(n, dtype=torch.float32), - ), - _STORY_TO_USER: GraphPartitionData( - edge_index=torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]), - edge_ids=None, - weights=torch.ones(n, dtype=torch.float32), - ), - }, - partitioned_node_features={ - _USER: FeaturePartitionData( - feats=torch.zeros(n, 2), ids=torch.arange(n) - ), - _STORY: FeaturePartitionData( - feats=torch.zeros(n, 2), ids=torch.arange(n) - ), - }, - partitioned_edge_features=None, - partitioned_positive_labels=None, - partitioned_negative_labels=None, - partitioned_node_labels=None, - ) + def test_weighted_sampling_never_traverses_zero_weight_edges_heterogeneous( + self, + ) -> None: + """Heterogeneous: weight=0 user→item edges to bad items are never traversed. + + Graph: 10 user seeds, each with 40 good items (weight=1) and 20 bad items + (weight=0). Good items have weight=1 back-edges to all users for 2nd-hop. + Fanout [10, 5] is smaller than the 40 available good items, so the sampler + actively selects. + """ + partition_output, n_user, _, _ = _build_heterogeneous_bipartite_weight_graph() dataset = DistDataset(rank=0, world_size=1, edge_dir="out") dataset.build(partition_output=partition_output) mp.spawn( - fn=_run_distributed_weighted_neighbor_loader_heterogeneous, - args=(dataset, n), + fn=_run_weighted_sampling_correctness_heterogeneous, + args=(dataset, n_user), ) From f541ae842fb09d23fcb856b93a57969ffd919448 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 13 May 2026 23:36:00 +0000 Subject: [PATCH 05/13] Address smells --- gigl/distributed/base_dist_loader.py | 11 +-- gigl/distributed/dataset_factory.py | 11 ++- gigl/distributed/dist_dataset.py | 1 + gigl/distributed/dist_partitioner.py | 27 +++++-- gigl/distributed/graph_store/dist_server.py | 8 +++ .../graph_store/remote_dist_dataset.py | 11 +++ .../run_distributed_partitioner.py | 71 ++----------------- 7 files changed, 64 insertions(+), 76 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 3931b4175..52fe9cc9f 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -350,11 +350,12 @@ def validate_with_weight( ValueError: If ``with_weight=True`` but no edge weights are registered. NotImplementedError: If ``with_weight=True`` and a PPR sampler is requested. """ - if ( - with_weight - and isinstance(dataset, DistDataset) - and not dataset.has_edge_weights - ): + has_edge_weights = ( + dataset.has_edge_weights + if isinstance(dataset, DistDataset) + else dataset.fetch_has_edge_weights() + ) + if with_weight and not has_edge_weights: raise ValueError( "with_weight=True requires edge weights to be registered in the dataset. " "Pass weight_edge_feat_name to build_dataset() to register edge weights." diff --git a/gigl/distributed/dataset_factory.py b/gigl/distributed/dataset_factory.py index 258c9da1f..ec1eae101 100644 --- a/gigl/distributed/dataset_factory.py +++ b/gigl/distributed/dataset_factory.py @@ -63,7 +63,7 @@ def _extract_weight_column( edge_entity_info: Union[ SerializedTFRecordInfo, dict[EdgeType, SerializedTFRecordInfo] ], -) -> Tuple[ +) -> tuple[ dict[EdgeType, torch.Tensor], Union[torch.Tensor, dict[EdgeType, torch.Tensor]], ]: @@ -120,7 +120,13 @@ def _extract_weight_column( continue feat_name = weight_name_dict[edge_type] - info = entity_info_dict[edge_type] + info = entity_info_dict.get(edge_type) + if info is None: + raise ValueError( + f"weight_edge_feat_name specifies edge type {edge_type} but no " + f"SerializedTFRecordInfo is available for that type. " + f"Available edge types: {list(entity_info_dict.keys())}" + ) feature_keys = list(info.feature_keys) if feat_name not in feature_keys: @@ -177,6 +183,7 @@ def _load_and_build_partitioned_dataset( sampling weights. The column is extracted from the feature tensor and registered separately via ``DistPartitioner.register_edge_weights()``; it is removed from the feature matrix to avoid duplication. Supply a single string to use the same column name for all edge types, or a per-edge-type dict. + Returns: DistDataset: Initialized dataset with partitioned graph information diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index dd2682e74..46017a079 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -938,6 +938,7 @@ def share_ipc( Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Edge feature dim and its data type, will be a dict if heterogeneous Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Degree tensors, will be a dict if heterogeneous Optional[int]: Optional per-anchor label cap for ABLP label fetching + bool: Whether edge weights were registered during dataset construction """ # TODO (mkolodner-sc): Investigate moving share_memory calls to the build() function diff --git a/gigl/distributed/dist_partitioner.py b/gigl/distributed/dist_partitioner.py index d9da0caa5..ec095202b 100644 --- a/gigl/distributed/dist_partitioner.py +++ b/gigl/distributed/dist_partitioner.py @@ -655,9 +655,25 @@ def register_edge_weights( input_edge_entity=edge_weights ) - assert input_edge_weights, ( - "Edge weights is an empty dictionary. Please provide edge weights to register." - ) + if not input_edge_weights: + raise ValueError( + "Edge weights is an empty dictionary. Please provide edge weights to register." + ) + + for edge_type, weight_tensor in input_edge_weights.items(): + if weight_tensor.ndim != 1: + raise ValueError( + f"Edge weights for edge type {edge_type} must be a 1-D tensor of shape " + f"[num_edges], got shape {tuple(weight_tensor.shape)}." + ) + if self._edge_index is not None and edge_type in self._edge_index: + local_num_edges = self._edge_index[edge_type].shape[1] + if weight_tensor.shape[0] != local_num_edges: + raise ValueError( + f"Edge weights for edge type {edge_type} have length " + f"{weight_tensor.shape[0]} but the registered edge index has " + f"{local_num_edges} edges on this rank." + ) self._edge_weights = convert_to_tensor(input_edge_weights, dtype=torch.float32) @@ -1161,7 +1177,8 @@ def _edge_pfn(_, chunk_range): gc.collect() - if len(edge_res_list) == 0: + had_zero_edges = len(edge_res_list) == 0 + if had_zero_edges: partitioned_edge_index = torch.empty((2, 0)) else: partitioned_edge_index = torch.stack( @@ -1189,6 +1206,8 @@ def _edge_pfn(_, chunk_range): logger.info( f"No edge features detected for edge type {edge_type}, will only partition edge indices for this edge type." ) + if had_zero_edges: + partitioned_edge_ids = torch.empty(0) del edge_ids del self._edge_ids[edge_type] if len(self._edge_ids) == 0: diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 9a2ca23dc..6e17b8b0e 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -425,6 +425,14 @@ def get_edge_dir(self) -> Literal["in", "out"]: """ return self.dataset.edge_dir + def get_has_edge_weights(self) -> bool: + """Return whether edge weights were registered in the dataset. + + Returns: + True if edge weights were registered via ``DistPartitioner.register_edge_weights()``. + """ + return self.dataset.has_edge_weights + def get_node_ids( self, request: FetchNodesRequest, diff --git a/gigl/distributed/graph_store/remote_dist_dataset.py b/gigl/distributed/graph_store/remote_dist_dataset.py index cb03934af..c4e4a7b02 100644 --- a/gigl/distributed/graph_store/remote_dist_dataset.py +++ b/gigl/distributed/graph_store/remote_dist_dataset.py @@ -606,3 +606,14 @@ def fetch_node_types(self) -> Optional[list[NodeType]]: 0, DistServer.get_node_types, ) + + def fetch_has_edge_weights(self) -> bool: + """Fetch whether edge weights were registered in the remote dataset. + + Returns: + True if edge weights were registered via ``DistPartitioner.register_edge_weights()``. + """ + return request_server( + 0, + DistServer.get_has_edge_weights, + ) diff --git a/tests/test_assets/distributed/run_distributed_partitioner.py b/tests/test_assets/distributed/run_distributed_partitioner.py index b06177fec..3bbc3406c 100644 --- a/tests/test_assets/distributed/run_distributed_partitioner.py +++ b/tests/test_assets/distributed/run_distributed_partitioner.py @@ -85,7 +85,10 @@ def run_distributed_partitioner( init_rpc(master_addr=master_addr, master_port=master_port, num_rpc_threads=4) dist_partitioner: DistPartitioner - if input_data_strategy == InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY: + if input_data_strategy in ( + InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY, + InputDataStrategy.REGISTER_EDGE_WEIGHTS_WITHOUT_EDGE_FEATURES, + ): dist_partitioner = partitioner_class( should_assign_edges_by_src_node=should_assign_edges_by_src_node, ) @@ -95,8 +98,9 @@ def run_distributed_partitioner( output_node_partition_book = dist_partitioner.partition_node() dist_partitioner.register_edge_index(edge_index=edge_index) - dist_partitioner.register_edge_features(edge_features=edge_features) del edge_index + if input_data_strategy == InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY: + dist_partitioner.register_edge_features(edge_features=edge_features) del edge_features if rank_to_edge_weights is not None and rank in rank_to_edge_weights: dist_partitioner.register_edge_weights( @@ -137,69 +141,6 @@ def run_distributed_partitioner( node_partition_book=output_node_partition_book, is_positive=False ) - partition_output = PartitionOutput( - node_partition_book=output_node_partition_book, - edge_partition_book=output_edge_partition_book, - partitioned_edge_index=output_edge_index, - partitioned_node_features=output_node_features, - partitioned_node_labels=output_node_labels, - partitioned_edge_features=output_edge_features, - partitioned_positive_labels=output_positive_labels, - partitioned_negative_labels=output_negative_labels, - ) - elif ( - input_data_strategy - == InputDataStrategy.REGISTER_EDGE_WEIGHTS_WITHOUT_EDGE_FEATURES - ): - # Same as REGISTER_ALL_ENTITIES_SEPARATELY but skips edge feature registration, - # so weights can be tested independently of features. - dist_partitioner = partitioner_class( - should_assign_edges_by_src_node=should_assign_edges_by_src_node, - ) - dist_partitioner.register_node_ids(node_ids=node_ids) - del node_ids - output_node_partition_book = dist_partitioner.partition_node() - - dist_partitioner.register_edge_index(edge_index=edge_index) - del edge_index, edge_features - if rank_to_edge_weights is not None and rank in rank_to_edge_weights: - dist_partitioner.register_edge_weights( - edge_weights=rank_to_edge_weights[rank] - ) - ( - output_edge_index, - output_edge_features, - output_edge_partition_book, - ) = dist_partitioner.partition_edge_index_and_edge_features( - node_partition_book=output_node_partition_book - ) - - dist_partitioner.register_node_features(node_features=node_features) - dist_partitioner.register_node_labels(node_labels=node_labels) - del node_labels, node_features - ( - output_node_features, - output_node_labels, - ) = dist_partitioner.partition_node_features_and_labels( - node_partition_book=output_node_partition_book - ) - - dist_partitioner.register_labels( - label_edge_index=positive_labels, is_positive=True - ) - del positive_labels - output_positive_labels = dist_partitioner.partition_labels( - node_partition_book=output_node_partition_book, is_positive=True - ) - - dist_partitioner.register_labels( - label_edge_index=negative_labels, is_positive=False - ) - del negative_labels - output_negative_labels = dist_partitioner.partition_labels( - node_partition_book=output_node_partition_book, is_positive=False - ) - partition_output = PartitionOutput( node_partition_book=output_node_partition_book, edge_partition_book=output_edge_partition_book, From 090363a98bac95c5408b8baf83b424f62199556e Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 14 May 2026 17:45:37 +0000 Subject: [PATCH 06/13] small fixes --- gigl/distributed/base_dist_loader.py | 3 ++- gigl/distributed/dataset_factory.py | 17 ++++++++++++++--- gigl/distributed/graph_store/dist_server.py | 2 +- .../graph_store/remote_dist_dataset.py | 4 ++-- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 52fe9cc9f..483fc08f2 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -353,13 +353,14 @@ def validate_with_weight( has_edge_weights = ( dataset.has_edge_weights if isinstance(dataset, DistDataset) - else dataset.fetch_has_edge_weights() + else dataset.fetch_edge_weights_registered() ) if with_weight and not has_edge_weights: raise ValueError( "with_weight=True requires edge weights to be registered in the dataset. " "Pass weight_edge_feat_name to build_dataset() to register edge weights." ) + # TODO(mkolodner-sc): Implement weight-proportional residual propagation for PPR. if with_weight and isinstance(sampler_options, PPRSamplerOptions): raise NotImplementedError( "Weighted sampling is not yet supported with PPRSamplerOptions. " diff --git a/gigl/distributed/dataset_factory.py b/gigl/distributed/dataset_factory.py index ec1eae101..1c7a62c06 100644 --- a/gigl/distributed/dataset_factory.py +++ b/gigl/distributed/dataset_factory.py @@ -74,8 +74,10 @@ def _extract_weight_column( Args: edge_features: Edge feature tensor(s). - weight_edge_feat_name: Name of the weight feature column, either a single string - (applied to all edge types) or a per-edge-type mapping. + weight_edge_feat_name: Name of the weight feature column. A single string is + only valid for single-edge-type (homogeneous) graphs; heterogeneous graphs + must supply a ``dict[EdgeType, str]`` to be explicit about which edge + type(s) carry the weight column. edge_entity_info: SerializedTFRecordInfo carrying ordered ``feature_keys`` used to resolve the column name to an index. @@ -103,8 +105,17 @@ def _extract_weight_column( else: entity_info_dict = edge_entity_info - # Normalise weight_edge_feat_name to per-edge-type mapping + # Normalise weight_edge_feat_name to per-edge-type mapping. + # A bare string is only unambiguous for single-edge-type graphs; for + # heterogeneous graphs a dict[EdgeType, str] must be provided so that + # the caller is explicit about which edge type(s) carry the weight column. if isinstance(weight_edge_feat_name, str): + if len(edge_features_dict) > 1: + raise ValueError( + f"weight_edge_feat_name must be a dict[EdgeType, str] for heterogeneous graphs " + f"with multiple edge types ({sorted(edge_features_dict)}). " + "Provide an explicit per-edge-type mapping instead of a single string." + ) weight_name_dict: dict[EdgeType, str] = { et: weight_edge_feat_name for et in edge_features_dict } diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 6e17b8b0e..636e8c332 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -425,7 +425,7 @@ def get_edge_dir(self) -> Literal["in", "out"]: """ return self.dataset.edge_dir - def get_has_edge_weights(self) -> bool: + def get_edge_weights_registered(self) -> bool: """Return whether edge weights were registered in the dataset. Returns: diff --git a/gigl/distributed/graph_store/remote_dist_dataset.py b/gigl/distributed/graph_store/remote_dist_dataset.py index c4e4a7b02..e58784dc0 100644 --- a/gigl/distributed/graph_store/remote_dist_dataset.py +++ b/gigl/distributed/graph_store/remote_dist_dataset.py @@ -607,7 +607,7 @@ def fetch_node_types(self) -> Optional[list[NodeType]]: DistServer.get_node_types, ) - def fetch_has_edge_weights(self) -> bool: + def fetch_edge_weights_registered(self) -> bool: """Fetch whether edge weights were registered in the remote dataset. Returns: @@ -615,5 +615,5 @@ def fetch_has_edge_weights(self) -> bool: """ return request_server( 0, - DistServer.get_has_edge_weights, + DistServer.get_edge_weights_registered, ) From 2bc11b17164ae28bf3bb4ccebf8d3c1dfb1e5ef9 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 14 May 2026 18:40:45 +0000 Subject: [PATCH 07/13] Update --- gigl/common/data/load_torch_tensors.py | 66 ++++++++++++- gigl/distributed/dataset_factory.py | 124 ++----------------------- gigl/types/graph.py | 2 + 3 files changed, 70 insertions(+), 122 deletions(-) diff --git a/gigl/common/data/load_torch_tensors.py b/gigl/common/data/load_torch_tensors.py index 9b311fd6d..dbda3da53 100644 --- a/gigl/common/data/load_torch_tensors.py +++ b/gigl/common/data/load_torch_tensors.py @@ -27,6 +27,7 @@ _ID_FMT = "{entity}_ids" _FEATURE_FMT = "{entity}_features" _LABEL_FMT = "{entity}_labels" +_WEIGHT_FMT = "{entity}_weights" _NODE_KEY = "node" _EDGE_KEY = "edge" _POSITIVE_LABEL_KEY = "positive_label" @@ -72,6 +73,7 @@ def _data_loading_process( ], rank: int, tf_dataset_options: TFDatasetOptions = TFDatasetOptions(), + weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None, ) -> None: """ Spawned multiprocessing.Process which loads homogeneous or heterogeneous information for a specific entity type [node, edge, positive_label, negative_label] @@ -117,6 +119,7 @@ def _data_loading_process( ids: dict[Union[NodeType, EdgeType], torch.Tensor] = {} features: dict[Union[NodeType, EdgeType], torch.Tensor] = {} labels: dict[Union[NodeType, EdgeType], torch.Tensor] = {} + weights: dict[Union[NodeType, EdgeType], torch.Tensor] = {} for ( graph_type, serialized_entity_tf_record_info, @@ -129,14 +132,13 @@ def _data_loading_process( raise NotImplementedError( "Label keys are not supported for edge entities" ) - ( - entity_ids, - entity_features, - entity_labels, - ) = tf_record_dataloader.load_as_torch_tensors( + loaded_entity = tf_record_dataloader.load_as_torch_tensors( serialized_tf_record_info=serialized_entity_tf_record_info, tf_dataset_options=tf_dataset_options, ) + entity_ids = loaded_entity.ids + entity_features = loaded_entity.features + entity_labels = loaded_entity.labels ids[graph_type] = entity_ids logger.info( f"Rank {rank} finished loading {entity_type} ids of shape {entity_ids.shape} for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" @@ -161,6 +163,42 @@ def _data_loading_process( f"Rank {rank} did not detect {entity_type} labels for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) + # Extract weight column from edge features when weight_edge_feat_name is set. + # The weight column is sliced out of each edge type's feature tensor and stored + # separately so it is not duplicated in the feature matrix. + if weight_edge_feat_name is not None and entity_type == _EDGE_KEY: + if isinstance(weight_edge_feat_name, str): + if len(serialized_tf_record_info) > 1: + raise ValueError( + f"weight_edge_feat_name must be a dict[EdgeType, str] for heterogeneous " + f"graphs with multiple edge types ({sorted(serialized_tf_record_info)}). " + "Provide an explicit per-edge-type mapping instead of a single string." + ) + weight_name_dict: dict[Union[NodeType, EdgeType], str] = { + et: weight_edge_feat_name for et in serialized_tf_record_info + } + else: + weight_name_dict = weight_edge_feat_name # type: ignore[assignment] + + for graph_type, feat_tensor in list(features.items()): + if graph_type not in weight_name_dict: + continue + col_name = weight_name_dict[graph_type] + feature_keys = list(serialized_tf_record_info[graph_type].feature_keys) + if col_name not in feature_keys: + raise ValueError( + f"weight_edge_feat_name '{col_name}' not found in edge feature keys " + f"for edge type {graph_type}: {feature_keys}" + ) + col_idx = feature_keys.index(col_name) + weights[graph_type] = feat_tensor[:, col_idx] + keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_idx] + features[graph_type] = feat_tensor[:, keep_cols] + logger.info( + f"Rank {rank} extracted weight column '{col_name}' (col {col_idx}) " + f"from {entity_type} features for graph type {graph_type}" + ) + logger.info( f"Rank {rank} is attempting to share {entity_type} id memory for tfrecord directories: {all_tf_record_uris}" ) @@ -180,6 +218,12 @@ def _data_loading_process( ) share_memory(labels) + if weights: + logger.info( + f"Rank {rank} is attempting to share {entity_type} weight memory for tfrecord directories: {all_tf_record_uris}" + ) + share_memory(weights) + output_dict[_ID_FMT.format(entity=entity_type)] = ( list(ids.values())[0] if is_input_homogeneous else ids ) @@ -191,6 +235,10 @@ def _data_loading_process( output_dict[_LABEL_FMT.format(entity=entity_type)] = ( list(labels.values())[0] if is_input_homogeneous else labels ) + if weights: + output_dict[_WEIGHT_FMT.format(entity=entity_type)] = ( + list(weights.values())[0] if is_input_homogeneous else weights + ) logger.info( f"Rank {rank} has finished loading {entity_type} data from tfrecord directories: {all_tf_record_uris}, elapsed time: {time.time() - start_time:.2f} seconds" @@ -207,6 +255,7 @@ def load_torch_tensors_from_tf_record( rank: int = 0, node_tf_dataset_options: TFDatasetOptions = TFDatasetOptions(), edge_tf_dataset_options: TFDatasetOptions = TFDatasetOptions(), + weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None, ) -> LoadedGraphTensors: """ Loads all torch tensors from a SerializedGraphMetadata object for all entity [node, edge, positive_label, negative_label] and edge / node types. @@ -222,6 +271,10 @@ def load_torch_tensors_from_tf_record( rank (int): Rank on current machine node_tf_dataset_options (TFDatasetOptions): The options to use for nodes when building the dataset. edge_tf_dataset_options (TFDatasetOptions): The options to use for edges when building the dataset. + weight_edge_feat_name (Optional[Union[str, dict[EdgeType, str]]]): Name of the edge feature column to extract + as sampling weights. The column is removed from the edge feature matrix and returned separately via + ``LoadedGraphTensors.edge_weights``. Supply a single string for homogeneous graphs or a per-edge-type + dict for heterogeneous graphs. Returns: loaded_graph_tensors (LoadedGraphTensors): Unpartitioned Graph Tensors """ @@ -269,6 +322,7 @@ def load_torch_tensors_from_tf_record( "serialized_tf_record_info": serialized_graph_metadata.edge_entity_info, "rank": rank, "tf_dataset_options": edge_tf_dataset_options, + "weight_edge_feat_name": weight_edge_feat_name, }, ) @@ -351,6 +405,7 @@ def load_torch_tensors_from_tf_record( edge_index = edge_output_dict[_ID_FMT.format(entity=_EDGE_KEY)] edge_features = edge_output_dict.get(_FEATURE_FMT.format(entity=_EDGE_KEY), None) + edge_weights = edge_output_dict.get(_WEIGHT_FMT.format(entity=_EDGE_KEY), None) positive_labels = edge_output_dict.get( _ID_FMT.format(entity=_POSITIVE_LABEL_KEY), None @@ -378,4 +433,5 @@ def load_torch_tensors_from_tf_record( edge_features=edge_features, positive_label=positive_labels, negative_label=negative_labels, + edge_weights=edge_weights, ) diff --git a/gigl/distributed/dataset_factory.py b/gigl/distributed/dataset_factory.py index 1c7a62c06..7f74cc996 100644 --- a/gigl/distributed/dataset_factory.py +++ b/gigl/distributed/dataset_factory.py @@ -20,7 +20,7 @@ ) from gigl.common import Uri, UriFactory -from gigl.common.data.dataloaders import SerializedTFRecordInfo, TFRecordDataLoader +from gigl.common.data.dataloaders import TFRecordDataLoader from gigl.common.data.load_torch_tensors import ( SerializedGraphMetadata, TFDatasetOptions, @@ -44,7 +44,6 @@ from gigl.src.common.types.graph_data import EdgeType from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.types.pb_wrappers.task_metadata import TaskMetadataType -from gigl.types.graph import DEFAULT_HOMOGENEOUS_EDGE_TYPE from gigl.utils.data_splitters import ( DistNodeAnchorLinkSplitter, DistNodeSplitter, @@ -57,112 +56,6 @@ logger = Logger() -def _extract_weight_column( - edge_features: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], - weight_edge_feat_name: Union[str, dict[EdgeType, str]], - edge_entity_info: Union[ - SerializedTFRecordInfo, dict[EdgeType, SerializedTFRecordInfo] - ], -) -> tuple[ - dict[EdgeType, torch.Tensor], - Union[torch.Tensor, dict[EdgeType, torch.Tensor]], -]: - """Extracts a weight column from edge features, removing it from the feature tensor. - - Returns a tuple of (edge_weights_by_type, trimmed_edge_features). The weight column - is removed from the feature tensor so it is not duplicated in memory. - - Args: - edge_features: Edge feature tensor(s). - weight_edge_feat_name: Name of the weight feature column. A single string is - only valid for single-edge-type (homogeneous) graphs; heterogeneous graphs - must supply a ``dict[EdgeType, str]`` to be explicit about which edge - type(s) carry the weight column. - edge_entity_info: SerializedTFRecordInfo carrying ordered ``feature_keys`` used - to resolve the column name to an index. - - Returns: - Tuple of (weights_by_type, trimmed_features) where ``weights_by_type`` maps each - edge type to its extracted 1-D weight tensor and ``trimmed_features`` has the - weight column removed. - """ - # Normalise edge_features to heterogeneous format for uniform processing - is_homogeneous_input = isinstance(edge_features, torch.Tensor) - if is_homogeneous_input: - assert isinstance(edge_features, torch.Tensor) - edge_features_dict: dict[EdgeType, torch.Tensor] = { - DEFAULT_HOMOGENEOUS_EDGE_TYPE: edge_features - } - else: - assert isinstance(edge_features, dict) - edge_features_dict = edge_features - - # Normalise edge_entity_info to heterogeneous format - if isinstance(edge_entity_info, SerializedTFRecordInfo): - entity_info_dict: dict[EdgeType, SerializedTFRecordInfo] = { - DEFAULT_HOMOGENEOUS_EDGE_TYPE: edge_entity_info - } - else: - entity_info_dict = edge_entity_info - - # Normalise weight_edge_feat_name to per-edge-type mapping. - # A bare string is only unambiguous for single-edge-type graphs; for - # heterogeneous graphs a dict[EdgeType, str] must be provided so that - # the caller is explicit about which edge type(s) carry the weight column. - if isinstance(weight_edge_feat_name, str): - if len(edge_features_dict) > 1: - raise ValueError( - f"weight_edge_feat_name must be a dict[EdgeType, str] for heterogeneous graphs " - f"with multiple edge types ({sorted(edge_features_dict)}). " - "Provide an explicit per-edge-type mapping instead of a single string." - ) - weight_name_dict: dict[EdgeType, str] = { - et: weight_edge_feat_name for et in edge_features_dict - } - else: - weight_name_dict = weight_edge_feat_name - - edge_weights_by_type: dict[EdgeType, torch.Tensor] = {} - trimmed_features_dict: dict[EdgeType, torch.Tensor] = {} - - for edge_type, feat_tensor in edge_features_dict.items(): - if edge_type not in weight_name_dict: - trimmed_features_dict[edge_type] = feat_tensor - continue - - feat_name = weight_name_dict[edge_type] - info = entity_info_dict.get(edge_type) - if info is None: - raise ValueError( - f"weight_edge_feat_name specifies edge type {edge_type} but no " - f"SerializedTFRecordInfo is available for that type. " - f"Available edge types: {list(entity_info_dict.keys())}" - ) - feature_keys = list(info.feature_keys) - - if feat_name not in feature_keys: - raise ValueError( - f"weight_edge_feat_name '{feat_name}' not found in edge feature keys " - f"for edge type {edge_type}: {feature_keys}" - ) - - col_idx = feature_keys.index(feat_name) - edge_weights_by_type[edge_type] = feat_tensor[:, col_idx] - - # Remove the weight column from the feature tensor - keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_idx] - trimmed_features_dict[edge_type] = feat_tensor[:, keep_cols] - - if is_homogeneous_input: - trimmed_features: Union[torch.Tensor, dict[EdgeType, torch.Tensor]] = ( - trimmed_features_dict[DEFAULT_HOMOGENEOUS_EDGE_TYPE] - ) - else: - trimmed_features = trimmed_features_dict - - return edge_weights_by_type, trimmed_features - - @tf_on_cpu def _load_and_build_partitioned_dataset( serialized_graph_metadata: SerializedGraphMetadata, @@ -217,6 +110,7 @@ def _load_and_build_partitioned_dataset( rank=rank, node_tf_dataset_options=node_tf_dataset_options, edge_tf_dataset_options=edge_tf_dataset_options, + weight_edge_feat_name=weight_edge_feat_name, ) # TODO (mkolodner-sc): Move this code block (from here up to start of partitioning) to transductive splitter once that is ready @@ -288,16 +182,11 @@ def _load_and_build_partitioned_dataset( ) if loaded_graph_tensors.node_labels is not None: partitioner.register_node_labels(node_labels=loaded_graph_tensors.node_labels) + if loaded_graph_tensors.edge_weights is not None: + partitioner.register_edge_weights( + edge_weights=loaded_graph_tensors.edge_weights + ) if loaded_graph_tensors.edge_features is not None: - if weight_edge_feat_name is not None: - edge_weights_by_type, trimmed_edge_features = _extract_weight_column( - edge_features=loaded_graph_tensors.edge_features, - weight_edge_feat_name=weight_edge_feat_name, - edge_entity_info=serialized_graph_metadata.edge_entity_info, - ) - loaded_graph_tensors.edge_features = trimmed_edge_features - if edge_weights_by_type: - partitioner.register_edge_weights(edge_weights=edge_weights_by_type) partitioner.register_edge_features( edge_features=loaded_graph_tensors.edge_features ) @@ -318,6 +207,7 @@ def _load_and_build_partitioned_dataset( loaded_graph_tensors.node_features, loaded_graph_tensors.edge_index, loaded_graph_tensors.edge_features, + loaded_graph_tensors.edge_weights, loaded_graph_tensors.positive_label, loaded_graph_tensors.negative_label, loaded_graph_tensors.node_labels, diff --git a/gigl/types/graph.py b/gigl/types/graph.py index 00ee69e33..a45770e86 100644 --- a/gigl/types/graph.py +++ b/gigl/types/graph.py @@ -149,6 +149,8 @@ class LoadedGraphTensors: positive_label: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]] # Unpartitioned Negative Edge Label negative_label: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]] + # Unpartitioned Edge Weights (per-edge sampling weights, one scalar per edge) + edge_weights: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]] = None def treat_labels_as_edges(self, edge_dir: Literal["in", "out"]) -> None: """ From 505277b1c8a9da8788167acf3d88d5194595f2a1 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 14 May 2026 18:51:02 +0000 Subject: [PATCH 08/13] Update --- gigl/common/data/load_torch_tensors.py | 79 ++++++++++++++++---------- 1 file changed, 48 insertions(+), 31 deletions(-) diff --git a/gigl/common/data/load_torch_tensors.py b/gigl/common/data/load_torch_tensors.py index dbda3da53..9a6dc084d 100644 --- a/gigl/common/data/load_torch_tensors.py +++ b/gigl/common/data/load_torch_tensors.py @@ -121,7 +121,7 @@ def _data_loading_process( labels: dict[Union[NodeType, EdgeType], torch.Tensor] = {} weights: dict[Union[NodeType, EdgeType], torch.Tensor] = {} for ( - graph_type, + edge_type, serialized_entity_tf_record_info, ) in serialized_tf_record_info.items(): # We currently do not support training with labels for edge entities @@ -139,28 +139,28 @@ def _data_loading_process( entity_ids = loaded_entity.ids entity_features = loaded_entity.features entity_labels = loaded_entity.labels - ids[graph_type] = entity_ids + ids[edge_type] = entity_ids logger.info( - f"Rank {rank} finished loading {entity_type} ids of shape {entity_ids.shape} for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" + f"Rank {rank} finished loading {entity_type} ids of shape {entity_ids.shape} for type {edge_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) if entity_features is not None: - features[graph_type] = entity_features + features[edge_type] = entity_features logger.info( - f"Rank {rank} finished loading {entity_type} features of shape {entity_features.shape} for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" + f"Rank {rank} finished loading {entity_type} features of shape {entity_features.shape} for type {edge_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) else: logger.info( - f"Rank {rank} did not detect {entity_type} features for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" + f"Rank {rank} did not detect {entity_type} features for type {edge_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) if entity_labels is not None: - labels[graph_type] = entity_labels + labels[edge_type] = entity_labels logger.info( - f"Rank {rank} finished loading {entity_type} labels of shape {entity_labels.shape} for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" + f"Rank {rank} finished loading {entity_type} labels of shape {entity_labels.shape} for type {edge_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) else: logger.info( - f"Rank {rank} did not detect {entity_type} labels for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" + f"Rank {rank} did not detect {entity_type} labels for type {edge_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) # Extract weight column from edge features when weight_edge_feat_name is set. @@ -174,30 +174,47 @@ def _data_loading_process( f"graphs with multiple edge types ({sorted(serialized_tf_record_info)}). " "Provide an explicit per-edge-type mapping instead of a single string." ) - weight_name_dict: dict[Union[NodeType, EdgeType], str] = { - et: weight_edge_feat_name for et in serialized_tf_record_info - } + # Single column name applies to all loaded edge types. + for edge_type, feat_tensor in list(features.items()): + col_name = weight_edge_feat_name + feature_keys = list( + serialized_tf_record_info[edge_type].feature_keys + ) + if col_name not in feature_keys: + raise ValueError( + f"weight_edge_feat_name '{col_name}' not found in edge feature keys " + f"for edge type {edge_type}: {feature_keys}" + ) + col_idx = feature_keys.index(col_name) + weights[edge_type] = feat_tensor[:, col_idx] + keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_idx] + features[edge_type] = feat_tensor[:, keep_cols] + logger.info( + f"Rank {rank} extracted weight column '{col_name}' (col {col_idx}) " + f"from {entity_type} features for type {edge_type}" + ) else: - weight_name_dict = weight_edge_feat_name # type: ignore[assignment] - - for graph_type, feat_tensor in list(features.items()): - if graph_type not in weight_name_dict: - continue - col_name = weight_name_dict[graph_type] - feature_keys = list(serialized_tf_record_info[graph_type].feature_keys) - if col_name not in feature_keys: - raise ValueError( - f"weight_edge_feat_name '{col_name}' not found in edge feature keys " - f"for edge type {graph_type}: {feature_keys}" + # Iterate the EdgeType-keyed dict directly to stay within EdgeType. + for edge_type, col_name in weight_edge_feat_name.items(): + if edge_type not in features: + continue + feat_tensor = features[edge_type] + feature_keys = list( + serialized_tf_record_info[edge_type].feature_keys + ) + if col_name not in feature_keys: + raise ValueError( + f"weight_edge_feat_name '{col_name}' not found in edge feature keys " + f"for edge type {edge_type}: {feature_keys}" + ) + col_idx = feature_keys.index(col_name) + weights[edge_type] = feat_tensor[:, col_idx] + keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_idx] + features[edge_type] = feat_tensor[:, keep_cols] + logger.info( + f"Rank {rank} extracted weight column '{col_name}' (col {col_idx}) " + f"from {entity_type} features for type {edge_type}" ) - col_idx = feature_keys.index(col_name) - weights[graph_type] = feat_tensor[:, col_idx] - keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_idx] - features[graph_type] = feat_tensor[:, keep_cols] - logger.info( - f"Rank {rank} extracted weight column '{col_name}' (col {col_idx}) " - f"from {entity_type} features for graph type {graph_type}" - ) logger.info( f"Rank {rank} is attempting to share {entity_type} id memory for tfrecord directories: {all_tf_record_uris}" From daf54dba7e6519a0226dddc408a2a84ddfad603f Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 14 May 2026 18:52:55 +0000 Subject: [PATCH 09/13] Change back --- gigl/common/data/load_torch_tensors.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/gigl/common/data/load_torch_tensors.py b/gigl/common/data/load_torch_tensors.py index 9a6dc084d..222814b79 100644 --- a/gigl/common/data/load_torch_tensors.py +++ b/gigl/common/data/load_torch_tensors.py @@ -121,7 +121,7 @@ def _data_loading_process( labels: dict[Union[NodeType, EdgeType], torch.Tensor] = {} weights: dict[Union[NodeType, EdgeType], torch.Tensor] = {} for ( - edge_type, + graph_type, serialized_entity_tf_record_info, ) in serialized_tf_record_info.items(): # We currently do not support training with labels for edge entities @@ -139,28 +139,28 @@ def _data_loading_process( entity_ids = loaded_entity.ids entity_features = loaded_entity.features entity_labels = loaded_entity.labels - ids[edge_type] = entity_ids + ids[graph_type] = entity_ids logger.info( - f"Rank {rank} finished loading {entity_type} ids of shape {entity_ids.shape} for type {edge_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" + f"Rank {rank} finished loading {entity_type} ids of shape {entity_ids.shape} for type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) if entity_features is not None: - features[edge_type] = entity_features + features[graph_type] = entity_features logger.info( - f"Rank {rank} finished loading {entity_type} features of shape {entity_features.shape} for type {edge_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" + f"Rank {rank} finished loading {entity_type} features of shape {entity_features.shape} for type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) else: logger.info( - f"Rank {rank} did not detect {entity_type} features for type {edge_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" + f"Rank {rank} did not detect {entity_type} features for type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) if entity_labels is not None: - labels[edge_type] = entity_labels + labels[graph_type] = entity_labels logger.info( - f"Rank {rank} finished loading {entity_type} labels of shape {entity_labels.shape} for type {edge_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" + f"Rank {rank} finished loading {entity_type} labels of shape {entity_labels.shape} for type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) else: logger.info( - f"Rank {rank} did not detect {entity_type} labels for type {edge_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" + f"Rank {rank} did not detect {entity_type} labels for type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) # Extract weight column from edge features when weight_edge_feat_name is set. From 783648464c73cabd5f910fa410917559e32784d5 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 14 May 2026 18:58:34 +0000 Subject: [PATCH 10/13] Fixes --- gigl/common/data/load_torch_tensors.py | 51 ++++++++++++-------------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/gigl/common/data/load_torch_tensors.py b/gigl/common/data/load_torch_tensors.py index 222814b79..a2122a6d9 100644 --- a/gigl/common/data/load_torch_tensors.py +++ b/gigl/common/data/load_torch_tensors.py @@ -27,7 +27,7 @@ _ID_FMT = "{entity}_ids" _FEATURE_FMT = "{entity}_features" _LABEL_FMT = "{entity}_labels" -_WEIGHT_FMT = "{entity}_weights" +_EDGE_WEIGHTS_KEY = "edge_weights" _NODE_KEY = "node" _EDGE_KEY = "edge" _POSITIVE_LABEL_KEY = "positive_label" @@ -141,26 +141,26 @@ def _data_loading_process( entity_labels = loaded_entity.labels ids[graph_type] = entity_ids logger.info( - f"Rank {rank} finished loading {entity_type} ids of shape {entity_ids.shape} for type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" + f"Rank {rank} finished loading {entity_type} ids of shape {entity_ids.shape} for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) if entity_features is not None: features[graph_type] = entity_features logger.info( - f"Rank {rank} finished loading {entity_type} features of shape {entity_features.shape} for type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" + f"Rank {rank} finished loading {entity_type} features of shape {entity_features.shape} for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) else: logger.info( - f"Rank {rank} did not detect {entity_type} features for type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" + f"Rank {rank} did not detect {entity_type} features for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) if entity_labels is not None: labels[graph_type] = entity_labels logger.info( - f"Rank {rank} finished loading {entity_type} labels of shape {entity_labels.shape} for type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" + f"Rank {rank} finished loading {entity_type} labels of shape {entity_labels.shape} for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) else: logger.info( - f"Rank {rank} did not detect {entity_type} labels for type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" + f"Rank {rank} did not detect {entity_type} labels for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) # Extract weight column from edge features when weight_edge_feat_name is set. @@ -168,31 +168,28 @@ def _data_loading_process( # separately so it is not duplicated in the feature matrix. if weight_edge_feat_name is not None and entity_type == _EDGE_KEY: if isinstance(weight_edge_feat_name, str): - if len(serialized_tf_record_info) > 1: + if len(serialized_tf_record_info) != 1 or len(features) != 1: raise ValueError( f"weight_edge_feat_name must be a dict[EdgeType, str] for heterogeneous " f"graphs with multiple edge types ({sorted(serialized_tf_record_info)}). " "Provide an explicit per-edge-type mapping instead of a single string." ) - # Single column name applies to all loaded edge types. - for edge_type, feat_tensor in list(features.items()): - col_name = weight_edge_feat_name - feature_keys = list( - serialized_tf_record_info[edge_type].feature_keys - ) - if col_name not in feature_keys: - raise ValueError( - f"weight_edge_feat_name '{col_name}' not found in edge feature keys " - f"for edge type {edge_type}: {feature_keys}" - ) - col_idx = feature_keys.index(col_name) - weights[edge_type] = feat_tensor[:, col_idx] - keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_idx] - features[edge_type] = feat_tensor[:, keep_cols] - logger.info( - f"Rank {rank} extracted weight column '{col_name}' (col {col_idx}) " - f"from {entity_type} features for type {edge_type}" + col_name = weight_edge_feat_name + edge_type, feat_tensor = next(iter(features.items())) + feature_keys = list(serialized_tf_record_info[edge_type].feature_keys) + if col_name not in feature_keys: + raise ValueError( + f"weight_edge_feat_name '{col_name}' not found in edge feature keys " + f"for edge type {edge_type}: {feature_keys}" ) + col_idx = feature_keys.index(col_name) + weights[edge_type] = feat_tensor[:, col_idx] + keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_idx] + features[edge_type] = feat_tensor[:, keep_cols] + logger.info( + f"Rank {rank} extracted weight column '{col_name}' (col {col_idx}) " + f"from {entity_type} features for type {edge_type}" + ) else: # Iterate the EdgeType-keyed dict directly to stay within EdgeType. for edge_type, col_name in weight_edge_feat_name.items(): @@ -253,7 +250,7 @@ def _data_loading_process( list(labels.values())[0] if is_input_homogeneous else labels ) if weights: - output_dict[_WEIGHT_FMT.format(entity=entity_type)] = ( + output_dict[_EDGE_WEIGHTS_KEY] = ( list(weights.values())[0] if is_input_homogeneous else weights ) @@ -422,7 +419,7 @@ def load_torch_tensors_from_tf_record( edge_index = edge_output_dict[_ID_FMT.format(entity=_EDGE_KEY)] edge_features = edge_output_dict.get(_FEATURE_FMT.format(entity=_EDGE_KEY), None) - edge_weights = edge_output_dict.get(_WEIGHT_FMT.format(entity=_EDGE_KEY), None) + edge_weights = edge_output_dict.get(_EDGE_WEIGHTS_KEY, None) positive_labels = edge_output_dict.get( _ID_FMT.format(entity=_POSITIVE_LABEL_KEY), None From c9d717dc57aab4ec7a6695803b1b4612412a85f4 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 14 May 2026 19:59:27 +0000 Subject: [PATCH 11/13] Add test for changes --- tests/unit/common/data/dataloaders_test.py | 111 +++++++++++++++++++++ 1 file changed, 111 insertions(+) diff --git a/tests/unit/common/data/dataloaders_test.py b/tests/unit/common/data/dataloaders_test.py index 569ee800a..e27ee1381 100644 --- a/tests/unit/common/data/dataloaders_test.py +++ b/tests/unit/common/data/dataloaders_test.py @@ -16,6 +16,10 @@ TFRecordDataLoader, _get_labels_from_features, ) +from gigl.common.data.load_torch_tensors import ( + SerializedGraphMetadata, + load_torch_tensors_from_tf_record, +) from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.data_preprocessor.lib.types import FeatureSpecDict from gigl.src.mocking.lib.versioning import ( @@ -487,6 +491,113 @@ def test_load_labels_from_pb(self): self.assertEqual(feature_tensor.size(1), node_metadata.feature_dim) self.assertEqual(label_tensor.size(1), len(node_metadata.label_keys)) + def test_load_edge_weights_from_tf_record(self): + """Edge weight column is extracted from edge features and returned separately. + + Mirrors test_load_labels_from_pb for edge sampling weights. + """ + n_edges = 5 + edge_feature_vals = [float(i * 10) for i in range(n_edges)] + edge_weight_vals = [float(i + 1) for i in range(n_edges)] + + node_dir = tempfile.TemporaryDirectory() + edge_dir = tempfile.TemporaryDirectory() + self.addCleanup(node_dir.cleanup) + self.addCleanup(edge_dir.cleanup) + + with tf.io.TFRecordWriter( + str(Path(node_dir.name) / "nodes.tfrecord") + ) as writer: + for i in range(3): + writer.write( + tf.train.Example( + features=tf.train.Features( + feature={ + "node_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[i]) + ), + } + ) + ).SerializeToString() + ) + + with tf.io.TFRecordWriter( + str(Path(edge_dir.name) / "edges.tfrecord") + ) as writer: + for i in range(n_edges): + writer.write( + tf.train.Example( + features=tf.train.Features( + feature={ + "src_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[i % 3]) + ), + "dst_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[(i + 1) % 3]) + ), + "edge_feature": tf.train.Feature( + float_list=tf.train.FloatList( + value=[edge_feature_vals[i]] + ) + ), + "edge_weight": tf.train.Feature( + float_list=tf.train.FloatList( + value=[edge_weight_vals[i]] + ) + ), + } + ) + ).SerializeToString() + ) + + loader = TFRecordDataLoader(rank=0, world_size=1) + loaded = load_torch_tensors_from_tf_record( + tf_record_dataloader=loader, + serialized_graph_metadata=SerializedGraphMetadata( + node_entity_info=SerializedTFRecordInfo( + tfrecord_uri_prefix=UriFactory.create_uri(node_dir.name), + feature_spec={"node_id": tf.io.FixedLenFeature([], tf.int64)}, + feature_keys=[], + feature_dim=0, + entity_key="node_id", + tfrecord_uri_pattern="nodes.tfrecord", + ), + edge_entity_info=SerializedTFRecordInfo( + tfrecord_uri_prefix=UriFactory.create_uri(edge_dir.name), + feature_spec={ + "src_id": tf.io.FixedLenFeature([], tf.int64), + "dst_id": tf.io.FixedLenFeature([], tf.int64), + "edge_feature": tf.io.FixedLenFeature([], tf.float32), + "edge_weight": tf.io.FixedLenFeature([], tf.float32), + }, + feature_keys=["edge_feature", "edge_weight"], + feature_dim=2, + entity_key=("src_id", "dst_id"), + tfrecord_uri_pattern="edges.tfrecord", + ), + ), + should_load_tensors_in_parallel=False, + edge_tf_dataset_options=TFDatasetOptions(deterministic=True), + weight_edge_feat_name="edge_weight", + ) + + self.assertIsNotNone(loaded.edge_weights) + self.assertIsNotNone(loaded.edge_features) + assert isinstance(loaded.edge_weights, torch.Tensor) + assert isinstance(loaded.edge_features, torch.Tensor) + # Weight column extracted as 1D tensor; feature matrix has one column remaining + self.assertEqual(loaded.edge_weights.shape, torch.Size([n_edges])) + self.assertEqual(loaded.edge_features.shape, torch.Size([n_edges, 1])) + # Values match the written data (sort since loading order is not guaranteed) + assert_close( + loaded.edge_weights.sort().values, + torch.tensor(sorted(edge_weight_vals), dtype=torch.float32), + ) + assert_close( + loaded.edge_features[:, 0].sort().values, + torch.tensor(sorted(edge_feature_vals), dtype=torch.float32), + ) + @parameterized.expand( [ param( From ff960b3c8b2a385f481657a62b8dba5358d61ccf Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 14 May 2026 23:02:32 +0000 Subject: [PATCH 12/13] address comments --- gigl/common/data/load_torch_tensors.py | 105 ++++++++-- gigl/distributed/dist_ablp_neighborloader.py | 4 +- gigl/distributed/dist_dataset.py | 9 +- gigl/distributed/dist_partitioner.py | 30 ++- gigl/distributed/dist_range_partitioner.py | 76 +++++-- .../distributed/distributed_neighborloader.py | 4 +- gigl/types/graph.py | 1 + tests/unit/common/data/dataloaders_test.py | 188 ++++++++++++++++++ .../distributed_weighted_sampling_test.py | 161 +++++++++++++++ tests/unit/types_tests/graph_test.py | 31 +++ 10 files changed, 545 insertions(+), 64 deletions(-) diff --git a/gigl/common/data/load_torch_tensors.py b/gigl/common/data/load_torch_tensors.py index a2122a6d9..1594bff30 100644 --- a/gigl/common/data/load_torch_tensors.py +++ b/gigl/common/data/load_torch_tensors.py @@ -29,6 +29,61 @@ _LABEL_FMT = "{entity}_labels" _EDGE_WEIGHTS_KEY = "edge_weights" _NODE_KEY = "node" + + +def _extract_weight_col( + feat_tensor: torch.Tensor, + feature_keys: list[str], + feature_spec: dict, + col_name: str, + edge_type: EdgeType, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Slice a named weight column out of a feature tensor. + + Accounts for multi-dim features: each feature key may contribute more than one column + to ``feat_tensor`` (e.g. ``FixedLenFeature(shape=[16])`` contributes 16 columns). + The weight feature must be a scalar (width 1). + + Args: + feat_tensor: Edge feature tensor of shape ``[num_edges, total_feature_cols]``. + feature_keys: Ordered list of feature names matching the columns of ``feat_tensor``. + feature_spec: Feature spec dict mapping feature name to its TF feature spec (used to + determine per-key column widths). + col_name: Name of the column to extract as weights. + edge_type: Edge type (used only in error messages). + + Returns: + A tuple ``(weights, trimmed_features)`` where ``weights`` is a 1-D tensor of shape + ``[num_edges]`` and ``trimmed_features`` is ``feat_tensor`` with the weight column + removed. + + Raises: + ValueError: If ``col_name`` is not in ``feature_keys`` or the weight feature is not + width 1. + """ + if col_name not in feature_keys: + raise ValueError( + f"weight_edge_feat_name '{col_name}' not found in edge feature keys " + f"for edge type {edge_type}: {feature_keys}" + ) + key_idx = feature_keys.index(col_name) + col_widths = [ + (spec.shape[-1] if spec.shape else 1) + for spec in (feature_spec[k] for k in feature_keys) + ] + weight_width = col_widths[key_idx] + if weight_width != 1: + raise ValueError( + f"weight_edge_feat_name '{col_name}' for edge type {edge_type} must be a scalar " + f"feature (width 1), but has width {weight_width}." + ) + col_offset = sum(col_widths[:key_idx]) + weights = feat_tensor[:, col_offset] + keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_offset] + trimmed = feat_tensor[:, keep_cols] if keep_cols else None + return weights, trimmed + + _EDGE_KEY = "edge" _POSITIVE_LABEL_KEY = "positive_label" _NEGATIVE_LABEL_KEY = "negative_label" @@ -91,6 +146,11 @@ def _data_loading_process( Serialized information for current entity rank (int): Rank of the current machine tf_dataset_options (TFDatasetOptions): The options to use when building the dataset. + weight_edge_feat_name (Optional[Union[str, dict[EdgeType, str]]]): Only used when + ``entity_type == _EDGE_KEY``. Name of the edge feature column to extract as + sampling weights. Ignored for node, positive_label, and negative_label entities. + Supply a single string for homogeneous graphs or a per-edge-type dict for + heterogeneous graphs. """ # We add a try - except clause here to ensure that exceptions are properly circulated back to the parent process try: @@ -176,18 +236,21 @@ def _data_loading_process( ) col_name = weight_edge_feat_name edge_type, feat_tensor = next(iter(features.items())) + assert isinstance(edge_type, EdgeType) feature_keys = list(serialized_tf_record_info[edge_type].feature_keys) - if col_name not in feature_keys: - raise ValueError( - f"weight_edge_feat_name '{col_name}' not found in edge feature keys " - f"for edge type {edge_type}: {feature_keys}" - ) - col_idx = feature_keys.index(col_name) - weights[edge_type] = feat_tensor[:, col_idx] - keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_idx] - features[edge_type] = feat_tensor[:, keep_cols] + weights[edge_type], trimmed = _extract_weight_col( + feat_tensor, + feature_keys, + serialized_tf_record_info[edge_type].feature_spec, + col_name, + edge_type, + ) + if trimmed is not None: + features[edge_type] = trimmed + else: + del features[edge_type] logger.info( - f"Rank {rank} extracted weight column '{col_name}' (col {col_idx}) " + f"Rank {rank} extracted weight column '{col_name}' " f"from {entity_type} features for type {edge_type}" ) else: @@ -199,17 +262,19 @@ def _data_loading_process( feature_keys = list( serialized_tf_record_info[edge_type].feature_keys ) - if col_name not in feature_keys: - raise ValueError( - f"weight_edge_feat_name '{col_name}' not found in edge feature keys " - f"for edge type {edge_type}: {feature_keys}" - ) - col_idx = feature_keys.index(col_name) - weights[edge_type] = feat_tensor[:, col_idx] - keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_idx] - features[edge_type] = feat_tensor[:, keep_cols] + weights[edge_type], trimmed = _extract_weight_col( + feat_tensor, + feature_keys, + serialized_tf_record_info[edge_type].feature_spec, + col_name, + edge_type, + ) + if trimmed is not None: + features[edge_type] = trimmed + else: + del features[edge_type] logger.info( - f"Rank {rank} extracted weight column '{col_name}' (col {col_idx}) " + f"Rank {rank} extracted weight column '{col_name}' " f"from {entity_type} features for type {edge_type}" ) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 5e9f6a3df..c6bc6f2f7 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -203,8 +203,8 @@ def __init__( shuffle (bool): Whether to shuffle the input nodes. (default: ``False``). drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``). with_weight (bool): Whether to use edge weights for neighbor sampling. - Requires edge weights to be registered via - ``DistPartitioner.register_edge_weights()`` during dataset construction. + Requires edge weights to have been provided via + ``build_dataset(weight_edge_feat_name=...)`` during dataset construction. Defaults to ``False``. sampler_options (Optional[SamplerOptions]): Controls which sampler class is instantiated. Defaults to `KHopNeighborSamplerOptions`, which will use the num_neighbors argument diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index 46017a079..4db0b02fc 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -608,7 +608,6 @@ def _initialize_graph( for edge_type, graph_partition_data in partitioned_edge_index.items() if graph_partition_data.weights is not None } - edge_weights = weights_by_type if weights_by_type else None if weights_by_type: missing = set(partitioned_edge_index.keys()) - set( weights_by_type.keys() @@ -616,9 +615,13 @@ def _initialize_graph( if missing: logger.warning( f"Edge weights are registered for {set(weights_by_type.keys())} but " - f"not for {missing}. When with_weight=True, edge types without weights " - f"will fall back to uniform sampling." + f"not for {missing}. Filling missing edge types with uniform weights " + f"(all 1s) so GLT does not segfault on partial-weight heterogeneous graphs." ) + for edge_type in missing: + n_edges = partitioned_edge_index[edge_type].edge_index.size(1) + weights_by_type[edge_type] = torch.ones(n_edges) + edge_weights = weights_by_type if weights_by_type else None self._has_edge_weights = edge_weights is not None diff --git a/gigl/distributed/dist_partitioner.py b/gigl/distributed/dist_partitioner.py index ec095202b..c887cd54b 100644 --- a/gigl/distributed/dist_partitioner.py +++ b/gigl/distributed/dist_partitioner.py @@ -630,7 +630,9 @@ def register_edge_weights( """Registers per-edge sampling weights to the partitioner. Weights must be a 1-D float tensor of shape ``[num_edges]``, one scalar per edge. - Must be called after ``register_edge_index()`` and before ``partition()``. + Must be called before ``partition()``. May be called before or after + ``register_edge_index()``; shape validation against the edge index is performed + only when ``register_edge_index()`` has already been called. Weights are kept separate from edge features — do not include the weight column in edge features passed to ``register_edge_features()``. @@ -666,6 +668,8 @@ def register_edge_weights( f"Edge weights for edge type {edge_type} must be a 1-D tensor of shape " f"[num_edges], got shape {tuple(weight_tensor.shape)}." ) + # Shape validation is best-effort: only fires when register_edge_index() + # has already been called for this edge type. if self._edge_index is not None and edge_type in self._edge_index: local_num_edges = self._edge_index[edge_type].shape[1] if weight_tensor.shape[0] != local_num_edges: @@ -1130,16 +1134,12 @@ def _partition_edge_index_and_edge_features( and self._num_edges is not None ), "Must have registered edges prior to partitioning them" - should_skip_edge_feats = ( - self._edge_feat is None or edge_type not in self._edge_feat - ) + has_edge_feats = self._edge_feat is not None and edge_type in self._edge_feat has_weights_for_edge_type = ( self._edge_weights is not None and edge_type in self._edge_weights ) # Need a partition book if we have features or weights to reindex. - should_generate_partition_book = ( - not should_skip_edge_feats or has_weights_for_edge_type - ) + should_generate_partition_book = has_edge_feats or has_weights_for_edge_type # Partitioning Edge Indices @@ -1219,7 +1219,7 @@ def _edge_pfn(_, chunk_range): edge_feat: Optional[torch.Tensor] = None edge_feat_dim: Optional[int] = None edge_weights_tensor: Optional[torch.Tensor] = None - if not should_skip_edge_feats: + if has_edge_feats: assert self._edge_feat is not None and edge_type in self._edge_feat assert ( self._edge_feat_dim is not None and edge_type in self._edge_feat_dim @@ -1238,11 +1238,9 @@ def _edge_pfn(_, chunk_range): input_parts.append(edge_ids) # Positional indices: features first, weights next, ids always last. - feat_idx: Optional[int] = 0 if not should_skip_edge_feats else None + feat_idx: Optional[int] = 0 if has_edge_feats else None weight_idx: Optional[int] = ( - (1 if not should_skip_edge_feats else 0) - if has_weights_for_edge_type - else None + (1 if has_edge_feats else 0) if has_weights_for_edge_type else None ) def _edge_feat_weight_pfn( @@ -1264,7 +1262,7 @@ def _edge_feat_weight_pfn( del self._edge_ids[edge_type] if len(self._edge_ids) == 0: self._edge_ids = None - if not should_skip_edge_feats: + if has_edge_feats: assert edge_feat is not None assert self._edge_feat is not None and self._edge_feat_dim is not None del edge_feat @@ -1283,7 +1281,7 @@ def _edge_feat_weight_pfn( if len(feat_weight_res_list) == 0: partitioned_edge_ids = torch.empty(0) - if not should_skip_edge_feats: + if has_edge_feats: assert edge_feat_dim is not None current_feat_part = FeaturePartitionData( feats=torch.empty(0, edge_feat_dim), @@ -1293,7 +1291,7 @@ def _edge_feat_weight_pfn( partitioned_weights = torch.empty(0) else: partitioned_edge_ids = torch.cat([r[-1] for r in feat_weight_res_list]) - if not should_skip_edge_feats: + if has_edge_feats: assert feat_idx is not None and edge_feat_dim is not None current_feat_part = FeaturePartitionData( feats=torch.cat([r[feat_idx] for r in feat_weight_res_list]), @@ -1308,7 +1306,7 @@ def _edge_feat_weight_pfn( feat_weight_res_list.clear() gc.collect() - if not should_skip_edge_feats: + if has_edge_feats: logger.info( f"Got edge tensor-based partition book for edge type {edge_type} on rank {self._rank} of shape {edge_partition_book.shape}" ) diff --git a/gigl/distributed/dist_range_partitioner.py b/gigl/distributed/dist_range_partitioner.py index 42110e983..faa1cfb57 100644 --- a/gigl/distributed/dist_range_partitioner.py +++ b/gigl/distributed/dist_range_partitioner.py @@ -217,21 +217,41 @@ def _partition_edge_index_and_edge_features( ) edge_index = self._edge_index[edge_type] + has_edge_feats = self._edge_feat is not None and edge_type in self._edge_feat + has_edge_weights = ( + self._edge_weights is not None and edge_type in self._edge_weights + ) - input_data: tuple[torch.Tensor, ...] - - if self._edge_feat is None or edge_type not in self._edge_feat: + if not has_edge_feats: logger.info( f"No edge features detected for edge type {edge_type}, will only partition edge indices for this edge type." ) - edge_feat = None - edge_feat_dim = None - input_data = (edge_index[0], edge_index[1]) - else: - assert self._edge_feat_dim is not None and edge_type in self._edge_feat_dim + + edge_feat: Optional[torch.Tensor] = None + edge_feat_dim: Optional[int] = None + edge_weights_tensor: Optional[torch.Tensor] = None + + if has_edge_feats: + assert self._edge_feat is not None and self._edge_feat_dim is not None + assert edge_type in self._edge_feat_dim edge_feat = self._edge_feat[edge_type] edge_feat_dim = self._edge_feat_dim[edge_type] - input_data = (edge_index[0], edge_index[1], edge_feat) + if has_edge_weights: + assert self._edge_weights is not None + edge_weights_tensor = self._edge_weights[edge_type] + + # Build input_data tuple: (src, dst[, feat][, weights]) + # Track the index of each optional tensor so we can unpack res_list correctly. + input_parts: list[torch.Tensor] = [edge_index[0], edge_index[1]] + feat_idx: Optional[int] = None + weight_idx: Optional[int] = None + if edge_feat is not None: + feat_idx = len(input_parts) + input_parts.append(edge_feat) + if edge_weights_tensor is not None: + weight_idx = len(input_parts) + input_parts.append(edge_weights_tensor) + input_data: tuple[torch.Tensor, ...] = tuple(input_parts) if self._should_assign_edges_by_src_node: target_node_partition_book = node_partition_book[edge_type.src_node_type] @@ -249,10 +269,13 @@ def edge_partition_fn(rank_indices, _): partition_function=edge_partition_fn, ) - del input_data, edge_index, target_indices, edge_feat + del input_data, edge_index, target_indices, edge_feat, edge_weights_tensor del self._edge_index[edge_type] if self._edge_feat is not None and edge_type in self._edge_feat: - del self._edge_feat[edge_type] + assert self._edge_feat_dim is not None + del self._edge_feat[edge_type], self._edge_feat_dim[edge_type] + if self._edge_weights is not None and edge_type in self._edge_weights: + del self._edge_weights[edge_type] # We check if edge_index or edge_feat dict is empty after deleting the tensor. If so, we set these fields to None. if not self._edge_index: @@ -260,11 +283,17 @@ def edge_partition_fn(rank_indices, _): if not self._edge_feat and not self._edge_feat_dim: self._edge_feat = None self._edge_feat_dim = None + if self._edge_weights is not None and not self._edge_weights: + self._edge_weights = None gc.collect() if len(res_list) == 0: partitioned_edge_index = torch.empty((2, 0)) + partitioned_edge_features = ( + torch.empty(0, edge_feat_dim) if edge_feat_dim is not None else None + ) + partitioned_weights = torch.empty(0) if has_edge_weights else None else: partitioned_edge_index = torch.stack( ( @@ -273,19 +302,22 @@ def edge_partition_fn(rank_indices, _): ), dim=0, ) - - if edge_feat_dim is not None: - if len(res_list) == 0: - partitioned_edge_features = torch.empty(0, edge_feat_dim) - else: - partitioned_edge_features = torch.cat([r[2] for r in res_list]) + partitioned_edge_features = ( + torch.cat([r[feat_idx] for r in res_list]) + if feat_idx is not None + else None + ) + partitioned_weights = ( + torch.cat([r[weight_idx] for r in res_list]) + if weight_idx is not None + else None + ) res_list.clear() - gc.collect() - # Generating edge partition book - + # Generate range-based edge partition book and infer edge IDs. + # Only needed when edge features are present — weights use positional IDs. num_edges_on_each_rank: list[tuple[int, int]] = sorted( all_gather((self._rank, partitioned_edge_index.size(1))).values(), key=lambda x: x[0], @@ -304,10 +336,11 @@ def edge_partition_fn(rank_indices, _): partitioned_edge_ids = get_ids_on_rank( partition_book=edge_partition_book, rank=self._rank ) - + assert partitioned_edge_features is not None current_graph_part = GraphPartitionData( edge_index=partitioned_edge_index, edge_ids=partitioned_edge_ids, + weights=partitioned_weights, ) current_feat_part = FeaturePartitionData( feats=partitioned_edge_features, ids=None @@ -320,6 +353,7 @@ def edge_partition_fn(rank_indices, _): current_graph_part = GraphPartitionData( edge_index=partitioned_edge_index, edge_ids=None, + weights=partitioned_weights, ) edge_partition_book = None diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index b99d3a528..7dd4fd73f 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -160,8 +160,8 @@ def __init__( shuffle (bool): Whether to shuffle the input nodes. (default: ``False``). drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``). with_weight (bool): Whether to use edge weights for neighbor sampling. - Requires edge weights to be registered via - ``DistPartitioner.register_edge_weights()`` during dataset construction. + Requires edge weights to have been provided via + ``build_dataset(weight_edge_feat_name=...)`` during dataset construction. Defaults to ``False``. sampler_options (Optional[SamplerOptions]): Controls which sampler class is instantiated. Pass ``KHopNeighborSamplerOptions`` to use the built-in sampler, diff --git a/gigl/types/graph.py b/gigl/types/graph.py index a45770e86..3fc4958fc 100644 --- a/gigl/types/graph.py +++ b/gigl/types/graph.py @@ -251,6 +251,7 @@ def treat_labels_as_edges(self, edge_dir: Literal["in", "out"]) -> None: self.node_features = to_heterogeneous_node(self.node_features) self.edge_index = edge_index_with_labels self.edge_features = to_heterogeneous_edge(self.edge_features) + self.edge_weights = to_heterogeneous_edge(self.edge_weights) self.positive_label = None self.negative_label = None gc.collect() diff --git a/tests/unit/common/data/dataloaders_test.py b/tests/unit/common/data/dataloaders_test.py index e27ee1381..bb7bad002 100644 --- a/tests/unit/common/data/dataloaders_test.py +++ b/tests/unit/common/data/dataloaders_test.py @@ -598,6 +598,194 @@ def test_load_edge_weights_from_tf_record(self): torch.tensor(sorted(edge_feature_vals), dtype=torch.float32), ) + def test_load_edge_weights_multidim_feature(self): + """Weight column offset is correct when a preceding feature key is multi-dimensional. + + A preceding feature with shape=[2] contributes 2 columns to the concatenated tensor, + so the weight column lives at offset 2, not offset 1. + """ + n_edges = 3 + node_dir = tempfile.TemporaryDirectory() + edge_dir = tempfile.TemporaryDirectory() + self.addCleanup(node_dir.cleanup) + self.addCleanup(edge_dir.cleanup) + + with tf.io.TFRecordWriter( + str(Path(node_dir.name) / "nodes.tfrecord") + ) as writer: + for i in range(3): + writer.write( + tf.train.Example( + features=tf.train.Features( + feature={ + "node_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[i]) + ) + } + ) + ).SerializeToString() + ) + + # Each edge has: src_id, dst_id, a 2-dim embedding, and a scalar weight. + embedding_vals = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] + weight_vals = [0.1, 0.2, 0.3] + with tf.io.TFRecordWriter( + str(Path(edge_dir.name) / "edges.tfrecord") + ) as writer: + for i in range(n_edges): + writer.write( + tf.train.Example( + features=tf.train.Features( + feature={ + "src_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[i % 3]) + ), + "dst_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[(i + 1) % 3]) + ), + "embedding": tf.train.Feature( + float_list=tf.train.FloatList( + value=embedding_vals[i] + ) + ), + "edge_weight": tf.train.Feature( + float_list=tf.train.FloatList( + value=[weight_vals[i]] + ) + ), + } + ) + ).SerializeToString() + ) + + loader = TFRecordDataLoader(rank=0, world_size=1) + loaded = load_torch_tensors_from_tf_record( + tf_record_dataloader=loader, + serialized_graph_metadata=SerializedGraphMetadata( + node_entity_info=SerializedTFRecordInfo( + tfrecord_uri_prefix=UriFactory.create_uri(node_dir.name), + feature_spec={"node_id": tf.io.FixedLenFeature([], tf.int64)}, + feature_keys=[], + feature_dim=0, + entity_key="node_id", + tfrecord_uri_pattern="nodes.tfrecord", + ), + edge_entity_info=SerializedTFRecordInfo( + tfrecord_uri_prefix=UriFactory.create_uri(edge_dir.name), + feature_spec={ + "src_id": tf.io.FixedLenFeature([], tf.int64), + "dst_id": tf.io.FixedLenFeature([], tf.int64), + "embedding": tf.io.FixedLenFeature([2], tf.float32), + "edge_weight": tf.io.FixedLenFeature([], tf.float32), + }, + feature_keys=["embedding", "edge_weight"], + feature_dim=3, + entity_key=("src_id", "dst_id"), + tfrecord_uri_pattern="edges.tfrecord", + ), + ), + should_load_tensors_in_parallel=False, + edge_tf_dataset_options=TFDatasetOptions(deterministic=True), + weight_edge_feat_name="edge_weight", + ) + + assert isinstance(loaded.edge_weights, torch.Tensor) + assert isinstance(loaded.edge_features, torch.Tensor) + # Weight is a 1-D scalar per edge; embedding columns remain + self.assertEqual(loaded.edge_weights.shape, torch.Size([n_edges])) + self.assertEqual(loaded.edge_features.shape, torch.Size([n_edges, 2])) + assert_close( + loaded.edge_weights.sort().values, + torch.tensor(sorted(weight_vals), dtype=torch.float32), + ) + + def test_load_edge_weights_only_feature(self): + """When the weight column is the only edge feature, edge_features is None after extraction.""" + n_edges = 3 + node_dir = tempfile.TemporaryDirectory() + edge_dir = tempfile.TemporaryDirectory() + self.addCleanup(node_dir.cleanup) + self.addCleanup(edge_dir.cleanup) + + with tf.io.TFRecordWriter( + str(Path(node_dir.name) / "nodes.tfrecord") + ) as writer: + for i in range(3): + writer.write( + tf.train.Example( + features=tf.train.Features( + feature={ + "node_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[i]) + ) + } + ) + ).SerializeToString() + ) + + weight_vals = [0.5, 1.0, 1.5] + with tf.io.TFRecordWriter( + str(Path(edge_dir.name) / "edges.tfrecord") + ) as writer: + for i in range(n_edges): + writer.write( + tf.train.Example( + features=tf.train.Features( + feature={ + "src_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[i % 3]) + ), + "dst_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[(i + 1) % 3]) + ), + "edge_weight": tf.train.Feature( + float_list=tf.train.FloatList( + value=[weight_vals[i]] + ) + ), + } + ) + ).SerializeToString() + ) + + loader = TFRecordDataLoader(rank=0, world_size=1) + loaded = load_torch_tensors_from_tf_record( + tf_record_dataloader=loader, + serialized_graph_metadata=SerializedGraphMetadata( + node_entity_info=SerializedTFRecordInfo( + tfrecord_uri_prefix=UriFactory.create_uri(node_dir.name), + feature_spec={"node_id": tf.io.FixedLenFeature([], tf.int64)}, + feature_keys=[], + feature_dim=0, + entity_key="node_id", + tfrecord_uri_pattern="nodes.tfrecord", + ), + edge_entity_info=SerializedTFRecordInfo( + tfrecord_uri_prefix=UriFactory.create_uri(edge_dir.name), + feature_spec={ + "src_id": tf.io.FixedLenFeature([], tf.int64), + "dst_id": tf.io.FixedLenFeature([], tf.int64), + "edge_weight": tf.io.FixedLenFeature([], tf.float32), + }, + feature_keys=["edge_weight"], + feature_dim=1, + entity_key=("src_id", "dst_id"), + tfrecord_uri_pattern="edges.tfrecord", + ), + ), + should_load_tensors_in_parallel=False, + edge_tf_dataset_options=TFDatasetOptions(deterministic=True), + weight_edge_feat_name="edge_weight", + ) + + assert isinstance(loaded.edge_weights, torch.Tensor) + self.assertIsNone(loaded.edge_features) + self.assertEqual(loaded.edge_weights.shape, torch.Size([n_edges])) + assert_close( + loaded.edge_weights.sort().values, + torch.tensor(sorted(weight_vals), dtype=torch.float32), + ) + @parameterized.expand( [ param( diff --git a/tests/unit/distributed/distributed_weighted_sampling_test.py b/tests/unit/distributed/distributed_weighted_sampling_test.py index 7d5107c47..f51f1bf53 100644 --- a/tests/unit/distributed/distributed_weighted_sampling_test.py +++ b/tests/unit/distributed/distributed_weighted_sampling_test.py @@ -20,6 +20,7 @@ from gigl.distributed import DistPartitioner from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.dist_range_partitioner import DistRangePartitioner from gigl.distributed.distributed_neighborloader import DistNeighborLoader from gigl.distributed.utils.networking import get_free_port from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation @@ -224,6 +225,81 @@ def _build_heterogeneous_bipartite_weight_graph() -> tuple[ return partition_output, n_user, n_good_item, n_bad_item +def _build_heterogeneous_bipartite_partial_weight_graph() -> tuple[ + PartitionOutput, int, int, int +]: + """Same graph as _build_heterogeneous_bipartite_weight_graph but ITEM_TO_USER is unweighted. + + This validates that partial-weight heterogeneous graphs work correctly: + the weighted U2I edge type still respects weights (bad items are unreachable) + while the unweighted I2U edge type samples uniformly without crashing. + """ + n_user = 10 + n_good_item = 40 + n_bad_item = 20 + n_item = n_good_item + n_bad_item + + user_ids = torch.arange(n_user) + good_item_ids = torch.arange(n_good_item) + bad_item_ids = torch.arange(n_good_item, n_item) + + u2gi_src = user_ids.repeat_interleave(n_good_item) + u2gi_dst = good_item_ids.repeat(n_user) + u2gi_w = torch.ones(n_user * n_good_item) + + u2bi_src = user_ids.repeat_interleave(n_bad_item) + u2bi_dst = bad_item_ids.repeat(n_user) + u2bi_w = torch.zeros(n_user * n_bad_item) + + gi2u_src = good_item_ids.repeat_interleave(n_user) + gi2u_dst = user_ids.repeat(n_good_item) + + u2i_src = torch.cat([u2gi_src, u2bi_src]) + u2i_dst = torch.cat([u2gi_dst, u2bi_dst]) + u2i_w = torch.cat([u2gi_w, u2bi_w]) + n_u2i_edges = u2i_src.shape[0] + + user_feats = torch.full((n_user, 1), 2.0) + item_feats = torch.cat( + [ + torch.full((n_good_item, 1), 1.0), + torch.full((n_bad_item, 1), 0.0), + ] + ) + + partition_output = PartitionOutput( + node_partition_book={ + _USER: torch.zeros(n_user), + _ITEM: torch.zeros(n_item), + }, + edge_partition_book={ + _USER_TO_ITEM: torch.zeros(n_u2i_edges), + _ITEM_TO_USER: torch.zeros(gi2u_src.shape[0]), + }, + partitioned_edge_index={ + _USER_TO_ITEM: GraphPartitionData( + edge_index=torch.stack([u2i_src, u2i_dst]), + edge_ids=None, + weights=u2i_w, + ), + _ITEM_TO_USER: GraphPartitionData( + edge_index=torch.stack([gi2u_src, gi2u_dst]), + edge_ids=None, + weights=None, # unweighted — samples uniformly + ), + }, + partitioned_node_features={ + _USER: FeaturePartitionData(feats=user_feats, ids=torch.arange(n_user)), + _ITEM: FeaturePartitionData(feats=item_feats, ids=torch.arange(n_item)), + }, + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + return partition_output, n_user, n_good_item, n_bad_item + + # --------------------------------------------------------------------------- # Subprocess functions — must accept local_rank as first arg (mp.spawn) # --------------------------------------------------------------------------- @@ -622,6 +698,72 @@ def test_heterogeneous_partial_weights_by_edge_type(self) -> None: msg=f"Rank {rank}: U2I edge_ids must be None", ) + def test_range_partitioner_homogeneous_weights_partitioned_correctly(self) -> None: + """DistRangePartitioner: edge weights land on the correct rank after range-based partitioning. + + Mirrors test_homogeneous_weights_partitioned_correctly but uses DistRangePartitioner. + With range-based partitioning, edge_ids are sequential per-rank (0..3 on rank 0, + 4..7 on rank 1), and the registered weights (src_node_id / 10.0) equal + edge_id * 0.1 for this test graph. + """ + master_port = get_free_port() + manager = Manager() + output_dict: MutableMapping[int, PartitionOutput] = manager.dict() + + rank_to_edge_weights = { + 0: MOCKED_U2U_EDGE_INDEX_ON_RANK_ZERO[0].float() / 10.0, + 1: MOCKED_U2U_EDGE_INDEX_ON_RANK_ONE[0].float() / 10.0, + } + + mp.spawn( + run_distributed_partitioner, + args=( + output_dict, + False, # is_heterogeneous + RANK_TO_MOCKED_GRAPH, + True, # should_assign_edges_by_src_node + self._master_ip_address, + master_port, + InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY, + DistRangePartitioner, + rank_to_edge_weights, + ), + nprocs=MOCKED_NUM_PARTITIONS, + join=True, + ) + + for rank, partition_output in output_dict.items(): + partitioned_edge_index = partition_output.partitioned_edge_index + self.assertIsInstance(partitioned_edge_index, GraphPartitionData) + assert isinstance(partitioned_edge_index, GraphPartitionData) + + weights = partitioned_edge_index.weights + self.assertIsNotNone( + weights, + msg=f"Rank {rank}: expected weights in GraphPartitionData, got None", + ) + assert weights is not None + + edge_ids = partitioned_edge_index.edge_ids + self.assertIsNotNone( + edge_ids, + msg=f"Rank {rank}: edge_ids must be present when features are registered", + ) + assert edge_ids is not None + + self.assertEqual( + weights.shape, + edge_ids.shape, + msg=f"Rank {rank}: weights and edge_ids must have the same length", + ) + + expected_weights = edge_ids.float() * 0.1 + torch.testing.assert_close( + weights.sort().values, + expected_weights.sort().values, + msg=f"Rank {rank}: partitioned weights do not match expected src_node_id / 10.0", + ) + class DistributedWeightedSamplingTest(TestCase): """End-to-end correctness tests for DistNeighborLoader with with_weight=True. @@ -686,6 +828,25 @@ def test_weighted_sampling_never_traverses_zero_weight_edges_heterogeneous( args=(dataset, n_user), ) + def test_weighted_sampling_partial_weights_heterogeneous(self) -> None: + """Partial weights: weighted U2I respects weights; unweighted I2U samples uniformly. + + U2I is weighted (good items weight=1, bad items weight=0) — bad items must + never appear. I2U has no weights registered, so it uses uniform sampling. + Verifies that mixing weighted and unweighted edge types in one heterogeneous + graph does not crash and that weighted edges still behave correctly. + """ + partition_output, n_user, _, _ = ( + _build_heterogeneous_bipartite_partial_weight_graph() + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + + mp.spawn( + fn=_run_weighted_sampling_correctness_heterogeneous, + args=(dataset, n_user), + ) + if __name__ == "__main__": absltest.main() diff --git a/tests/unit/types_tests/graph_test.py b/tests/unit/types_tests/graph_test.py index fa68667f0..0c1e701f3 100644 --- a/tests/unit/types_tests/graph_test.py +++ b/tests/unit/types_tests/graph_test.py @@ -306,6 +306,37 @@ def test_treat_supervision_edges_as_graph_edges( graph_tensors.edge_index[edge_type], expected_tensor ) + def test_treat_labels_as_edges_converts_homogeneous_edge_weights(self) -> None: + """treat_labels_as_edges() must convert homogeneous edge_weights to heterogeneous form. + + Regression: edge_weights was previously left as a raw tensor while all other fields + were converted to heterogeneous dicts. DistPartitioner rejects that mixed-mode input, + so weighted homogeneous ABLP dataset construction would fail. + """ + weights = torch.tensor([0.5, 1.0]) + graph_tensors = LoadedGraphTensors( + node_ids=torch.tensor([0, 1, 2]), + node_features=torch.tensor([[1.0], [2.0], [3.0]]), + node_labels=None, + edge_index=torch.tensor([[0, 1], [1, 2]]), + edge_features=None, + positive_label=torch.tensor([[0], [2]]), + negative_label=None, + edge_weights=weights, + ) + graph_tensors.treat_labels_as_edges(edge_dir="out") + + self.assertIsInstance( + graph_tensors.edge_weights, + dict, + "edge_weights must be a heterogeneous dict after treat_labels_as_edges()", + ) + assert isinstance(graph_tensors.edge_weights, dict) + self.assertIn(DEFAULT_HOMOGENEOUS_EDGE_TYPE, graph_tensors.edge_weights) + torch.testing.assert_close( + graph_tensors.edge_weights[DEFAULT_HOMOGENEOUS_EDGE_TYPE], weights + ) + def test_select_label_edge_types(self): message_passing_edge_type = DEFAULT_HOMOGENEOUS_EDGE_TYPE edge_types = [ From a11154ccf0a19431925d7905436988a01a10e491 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 15 May 2026 06:08:46 +0000 Subject: [PATCH 13/13] Update --- gigl/distributed/base_dist_loader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 483fc08f2..f286a1481 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -350,12 +350,14 @@ def validate_with_weight( ValueError: If ``with_weight=True`` but no edge weights are registered. NotImplementedError: If ``with_weight=True`` and a PPR sampler is requested. """ + if not with_weight: + return has_edge_weights = ( dataset.has_edge_weights if isinstance(dataset, DistDataset) else dataset.fetch_edge_weights_registered() ) - if with_weight and not has_edge_weights: + if not has_edge_weights: raise ValueError( "with_weight=True requires edge weights to be registered in the dataset. " "Pass weight_edge_feat_name to build_dataset() to register edge weights."