diff --git a/gigl/common/data/load_torch_tensors.py b/gigl/common/data/load_torch_tensors.py index 9b311fd6d..1594bff30 100644 --- a/gigl/common/data/load_torch_tensors.py +++ b/gigl/common/data/load_torch_tensors.py @@ -27,7 +27,63 @@ _ID_FMT = "{entity}_ids" _FEATURE_FMT = "{entity}_features" _LABEL_FMT = "{entity}_labels" +_EDGE_WEIGHTS_KEY = "edge_weights" _NODE_KEY = "node" + + +def _extract_weight_col( + feat_tensor: torch.Tensor, + feature_keys: list[str], + feature_spec: dict, + col_name: str, + edge_type: EdgeType, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Slice a named weight column out of a feature tensor. + + Accounts for multi-dim features: each feature key may contribute more than one column + to ``feat_tensor`` (e.g. ``FixedLenFeature(shape=[16])`` contributes 16 columns). + The weight feature must be a scalar (width 1). + + Args: + feat_tensor: Edge feature tensor of shape ``[num_edges, total_feature_cols]``. + feature_keys: Ordered list of feature names matching the columns of ``feat_tensor``. + feature_spec: Feature spec dict mapping feature name to its TF feature spec (used to + determine per-key column widths). + col_name: Name of the column to extract as weights. + edge_type: Edge type (used only in error messages). + + Returns: + A tuple ``(weights, trimmed_features)`` where ``weights`` is a 1-D tensor of shape + ``[num_edges]`` and ``trimmed_features`` is ``feat_tensor`` with the weight column + removed. + + Raises: + ValueError: If ``col_name`` is not in ``feature_keys`` or the weight feature is not + width 1. + """ + if col_name not in feature_keys: + raise ValueError( + f"weight_edge_feat_name '{col_name}' not found in edge feature keys " + f"for edge type {edge_type}: {feature_keys}" + ) + key_idx = feature_keys.index(col_name) + col_widths = [ + (spec.shape[-1] if spec.shape else 1) + for spec in (feature_spec[k] for k in feature_keys) + ] + weight_width = col_widths[key_idx] + if weight_width != 1: + raise ValueError( + f"weight_edge_feat_name '{col_name}' for edge type {edge_type} must be a scalar " + f"feature (width 1), but has width {weight_width}." + ) + col_offset = sum(col_widths[:key_idx]) + weights = feat_tensor[:, col_offset] + keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_offset] + trimmed = feat_tensor[:, keep_cols] if keep_cols else None + return weights, trimmed + + _EDGE_KEY = "edge" _POSITIVE_LABEL_KEY = "positive_label" _NEGATIVE_LABEL_KEY = "negative_label" @@ -72,6 +128,7 @@ def _data_loading_process( ], rank: int, tf_dataset_options: TFDatasetOptions = TFDatasetOptions(), + weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None, ) -> None: """ Spawned multiprocessing.Process which loads homogeneous or heterogeneous information for a specific entity type [node, edge, positive_label, negative_label] @@ -89,6 +146,11 @@ def _data_loading_process( Serialized information for current entity rank (int): Rank of the current machine tf_dataset_options (TFDatasetOptions): The options to use when building the dataset. + weight_edge_feat_name (Optional[Union[str, dict[EdgeType, str]]]): Only used when + ``entity_type == _EDGE_KEY``. Name of the edge feature column to extract as + sampling weights. Ignored for node, positive_label, and negative_label entities. + Supply a single string for homogeneous graphs or a per-edge-type dict for + heterogeneous graphs. 
""" # We add a try - except clause here to ensure that exceptions are properly circulated back to the parent process try: @@ -117,6 +179,7 @@ def _data_loading_process( ids: dict[Union[NodeType, EdgeType], torch.Tensor] = {} features: dict[Union[NodeType, EdgeType], torch.Tensor] = {} labels: dict[Union[NodeType, EdgeType], torch.Tensor] = {} + weights: dict[Union[NodeType, EdgeType], torch.Tensor] = {} for ( graph_type, serialized_entity_tf_record_info, @@ -129,14 +192,13 @@ def _data_loading_process( raise NotImplementedError( "Label keys are not supported for edge entities" ) - ( - entity_ids, - entity_features, - entity_labels, - ) = tf_record_dataloader.load_as_torch_tensors( + loaded_entity = tf_record_dataloader.load_as_torch_tensors( serialized_tf_record_info=serialized_entity_tf_record_info, tf_dataset_options=tf_dataset_options, ) + entity_ids = loaded_entity.ids + entity_features = loaded_entity.features + entity_labels = loaded_entity.labels ids[graph_type] = entity_ids logger.info( f"Rank {rank} finished loading {entity_type} ids of shape {entity_ids.shape} for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" @@ -161,6 +223,61 @@ def _data_loading_process( f"Rank {rank} did not detect {entity_type} labels for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}" ) + # Extract weight column from edge features when weight_edge_feat_name is set. + # The weight column is sliced out of each edge type's feature tensor and stored + # separately so it is not duplicated in the feature matrix. + if weight_edge_feat_name is not None and entity_type == _EDGE_KEY: + if isinstance(weight_edge_feat_name, str): + if len(serialized_tf_record_info) != 1 or len(features) != 1: + raise ValueError( + f"weight_edge_feat_name must be a dict[EdgeType, str] for heterogeneous " + f"graphs with multiple edge types ({sorted(serialized_tf_record_info)}). " + "Provide an explicit per-edge-type mapping instead of a single string." + ) + col_name = weight_edge_feat_name + edge_type, feat_tensor = next(iter(features.items())) + assert isinstance(edge_type, EdgeType) + feature_keys = list(serialized_tf_record_info[edge_type].feature_keys) + weights[edge_type], trimmed = _extract_weight_col( + feat_tensor, + feature_keys, + serialized_tf_record_info[edge_type].feature_spec, + col_name, + edge_type, + ) + if trimmed is not None: + features[edge_type] = trimmed + else: + del features[edge_type] + logger.info( + f"Rank {rank} extracted weight column '{col_name}' " + f"from {entity_type} features for type {edge_type}" + ) + else: + # Iterate the EdgeType-keyed dict directly to stay within EdgeType. 
+ for edge_type, col_name in weight_edge_feat_name.items(): + if edge_type not in features: + continue + feat_tensor = features[edge_type] + feature_keys = list( + serialized_tf_record_info[edge_type].feature_keys + ) + weights[edge_type], trimmed = _extract_weight_col( + feat_tensor, + feature_keys, + serialized_tf_record_info[edge_type].feature_spec, + col_name, + edge_type, + ) + if trimmed is not None: + features[edge_type] = trimmed + else: + del features[edge_type] + logger.info( + f"Rank {rank} extracted weight column '{col_name}' " + f"from {entity_type} features for type {edge_type}" + ) + logger.info( f"Rank {rank} is attempting to share {entity_type} id memory for tfrecord directories: {all_tf_record_uris}" ) @@ -180,6 +297,12 @@ def _data_loading_process( ) share_memory(labels) + if weights: + logger.info( + f"Rank {rank} is attempting to share {entity_type} weight memory for tfrecord directories: {all_tf_record_uris}" + ) + share_memory(weights) + output_dict[_ID_FMT.format(entity=entity_type)] = ( list(ids.values())[0] if is_input_homogeneous else ids ) @@ -191,6 +314,10 @@ def _data_loading_process( output_dict[_LABEL_FMT.format(entity=entity_type)] = ( list(labels.values())[0] if is_input_homogeneous else labels ) + if weights: + output_dict[_EDGE_WEIGHTS_KEY] = ( + list(weights.values())[0] if is_input_homogeneous else weights + ) logger.info( f"Rank {rank} has finished loading {entity_type} data from tfrecord directories: {all_tf_record_uris}, elapsed time: {time.time() - start_time:.2f} seconds" @@ -207,6 +334,7 @@ def load_torch_tensors_from_tf_record( rank: int = 0, node_tf_dataset_options: TFDatasetOptions = TFDatasetOptions(), edge_tf_dataset_options: TFDatasetOptions = TFDatasetOptions(), + weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None, ) -> LoadedGraphTensors: """ Loads all torch tensors from a SerializedGraphMetadata object for all entity [node, edge, positive_label, negative_label] and edge / node types. @@ -222,6 +350,10 @@ def load_torch_tensors_from_tf_record( rank (int): Rank on current machine node_tf_dataset_options (TFDatasetOptions): The options to use for nodes when building the dataset. edge_tf_dataset_options (TFDatasetOptions): The options to use for edges when building the dataset. + weight_edge_feat_name (Optional[Union[str, dict[EdgeType, str]]]): Name of the edge feature column to extract + as sampling weights. The column is removed from the edge feature matrix and returned separately via + ``LoadedGraphTensors.edge_weights``. Supply a single string for homogeneous graphs or a per-edge-type + dict for heterogeneous graphs. 
Returns: loaded_graph_tensors (LoadedGraphTensors): Unpartitioned Graph Tensors """ @@ -269,6 +401,7 @@ def load_torch_tensors_from_tf_record( "serialized_tf_record_info": serialized_graph_metadata.edge_entity_info, "rank": rank, "tf_dataset_options": edge_tf_dataset_options, + "weight_edge_feat_name": weight_edge_feat_name, }, ) @@ -351,6 +484,7 @@ def load_torch_tensors_from_tf_record( edge_index = edge_output_dict[_ID_FMT.format(entity=_EDGE_KEY)] edge_features = edge_output_dict.get(_FEATURE_FMT.format(entity=_EDGE_KEY), None) + edge_weights = edge_output_dict.get(_EDGE_WEIGHTS_KEY, None) positive_labels = edge_output_dict.get( _ID_FMT.format(entity=_POSITIVE_LABEL_KEY), None @@ -378,4 +512,5 @@ def load_torch_tensors_from_tf_record( edge_features=edge_features, positive_label=positive_labels, negative_label=negative_labels, + edge_weights=edge_weights, ) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 203c8520d..f286a1481 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -333,6 +333,42 @@ def __init__( "for graph-store mode." ) + @staticmethod + def validate_with_weight( + with_weight: bool, + dataset: Union[DistDataset, RemoteDistDataset], + sampler_options: SamplerOptions, + ) -> None: + """Validates the ``with_weight`` parameter against the dataset and sampler. + + Args: + with_weight: Whether weighted sampling was requested. + dataset: The dataset being sampled from. + sampler_options: The sampler to be used. + + Raises: + ValueError: If ``with_weight=True`` but no edge weights are registered. + NotImplementedError: If ``with_weight=True`` and a PPR sampler is requested. + """ + if not with_weight: + return + has_edge_weights = ( + dataset.has_edge_weights + if isinstance(dataset, DistDataset) + else dataset.fetch_edge_weights_registered() + ) + if not has_edge_weights: + raise ValueError( + "with_weight=True requires edge weights to be registered in the dataset. " + "Pass weight_edge_feat_name to build_dataset() to register edge weights." + ) + # TODO(mkolodner-sc): Implement weight-proportional residual propagation for PPR. + if with_weight and isinstance(sampler_options, PPRSamplerOptions): + raise NotImplementedError( + "Weighted sampling is not yet supported with PPRSamplerOptions. " + "Weight-proportional residual propagation for PPR is planned but not implemented." + ) + @staticmethod def create_sampling_config( num_neighbors: Union[list[int], dict[EdgeType, list[int]]], @@ -340,6 +376,7 @@ def create_sampling_config( batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, + with_weight: bool = False, ) -> SamplingConfig: """Creates a SamplingConfig with patched fanout. @@ -352,6 +389,9 @@ def create_sampling_config( batch_size: How many samples per batch. shuffle: Whether to shuffle input nodes. drop_last: Whether to drop the last incomplete batch. + with_weight: Whether to use edge weights for sampling. Requires that + edge weights were registered during dataset construction via + ``DistPartitioner.register_edge_weights()``. Returns: A fully configured SamplingConfig. 
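Aside on the extraction logic above: `_extract_weight_col` locates the weight column by summing the widths of all preceding feature keys rather than using the key's index directly, which matters once any feature is multi-dimensional. A minimal standalone sketch of that arithmetic, with illustrative tensors and widths:

```python
# Sketch of the column-offset arithmetic in _extract_weight_col (toy data).
# "embedding" occupies 2 columns, so "edge_weight" lives at column offset 2,
# not at index 1 of feature_keys.
import torch

feature_keys = ["embedding", "edge_weight"]  # ordered as in the feature tensor
col_widths = [2, 1]                          # per-key widths from the feature spec

feat_tensor = torch.tensor(
    [[1.0, 2.0, 0.1],
     [3.0, 4.0, 0.2],
     [5.0, 6.0, 0.3]]
)  # shape [num_edges, sum(col_widths)]

key_idx = feature_keys.index("edge_weight")
col_offset = sum(col_widths[:key_idx])       # -> 2
weights = feat_tensor[:, col_offset]         # tensor([0.1, 0.2, 0.3]), shape [num_edges]
keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_offset]
trimmed = feat_tensor[:, keep_cols]          # [num_edges, 2]: the embedding columns
```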
@@ -369,7 +409,7 @@ def create_sampling_config( with_edge=True, collect_features=True, with_neg=False, - with_weight=False, + with_weight=with_weight, edge_dir=dataset_schema.edge_dir, seed=None, ) diff --git a/gigl/distributed/dataset_factory.py b/gigl/distributed/dataset_factory.py index ffa13fc71..7f74cc996 100644 --- a/gigl/distributed/dataset_factory.py +++ b/gigl/distributed/dataset_factory.py @@ -66,6 +66,7 @@ def _load_and_build_partitioned_dataset( edge_tf_dataset_options: TFDatasetOptions, splitter: Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]] = None, _ssl_positive_label_percentage: Optional[float] = None, + weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None, ) -> DistDataset: """ Given some information about serialized TFRecords, loads and builds a partitioned dataset into a DistDataset class. @@ -82,6 +83,11 @@ def _load_and_build_partitioned_dataset( splitter (Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]]): Optional splitter to use for splitting the graph data into train, val, and test sets. If not provided (None), no splitting will be performed. _ssl_positive_label_percentage (Optional[float]): Percentage of edges to select as self-supervised labels. Must be None if supervised edge labels are provided in advance. Slotted for refactor once this functionality is available in the transductive `splitter` directly + weight_edge_feat_name (Optional[Union[str, dict[EdgeType, str]]]): Name of the edge feature column to use as + sampling weights. The column is extracted from the feature tensor and registered separately via + ``DistPartitioner.register_edge_weights()``; it is removed from the feature matrix to avoid duplication. + Supply a single string to use the same column name for all edge types, or a per-edge-type dict. 
+ Returns: DistDataset: Initialized dataset with partitioned graph information @@ -104,6 +110,7 @@ def _load_and_build_partitioned_dataset( rank=rank, node_tf_dataset_options=node_tf_dataset_options, edge_tf_dataset_options=edge_tf_dataset_options, + weight_edge_feat_name=weight_edge_feat_name, ) # TODO (mkolodner-sc): Move this code block (from here up to start of partitioning) to transductive splitter once that is ready @@ -175,6 +182,10 @@ def _load_and_build_partitioned_dataset( ) if loaded_graph_tensors.node_labels is not None: partitioner.register_node_labels(node_labels=loaded_graph_tensors.node_labels) + if loaded_graph_tensors.edge_weights is not None: + partitioner.register_edge_weights( + edge_weights=loaded_graph_tensors.edge_weights + ) if loaded_graph_tensors.edge_features is not None: partitioner.register_edge_features( edge_features=loaded_graph_tensors.edge_features @@ -196,6 +207,7 @@ def _load_and_build_partitioned_dataset( loaded_graph_tensors.node_features, loaded_graph_tensors.edge_index, loaded_graph_tensors.edge_features, + loaded_graph_tensors.edge_weights, loaded_graph_tensors.positive_label, loaded_graph_tensors.negative_label, loaded_graph_tensors.node_labels, @@ -230,6 +242,7 @@ def _build_dataset_process( edge_tf_dataset_options: TFDatasetOptions, splitter: Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]] = None, _ssl_positive_label_percentage: Optional[float] = None, + weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None, ) -> None: """ This function is spawned by a single process per machine and is responsible for: @@ -311,6 +324,7 @@ def _build_dataset_process( edge_tf_dataset_options=edge_tf_dataset_options, splitter=splitter, _ssl_positive_label_percentage=_ssl_positive_label_percentage, + weight_edge_feat_name=weight_edge_feat_name, ) output_dict["dataset"] = output_dataset @@ -336,6 +350,7 @@ def build_dataset( _dataset_building_port: Optional[ int ] = None, # WARNING: This field will be deprecated in the future + weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None, ) -> DistDataset: """ Launches a spawned process for building and returning a DistDataset instance provided some @@ -368,6 +383,10 @@ def build_dataset( Slotted for refactor once this functionality is available in the transductive `splitter` directly _dataset_building_port (deprecated field - will be removed soon) (Optional[int]): Contains information about master port. Defaults to None, in which case it will be initialized from the current torch.distributed context. + weight_edge_feat_name (Optional[Union[str, dict[EdgeType, str]]]): Name of the edge feature column to use + as sampling weights. The column is extracted from the feature tensor and registered separately; it is + removed from the feature matrix to avoid memory duplication. Supply a single string to apply to all + edge types, or a per-edge-type dict. 
(default: ``None``) Returns: DistDataset: Built GraphLearn-for-PyTorch Dataset class @@ -463,6 +482,7 @@ def build_dataset( edge_tf_dataset_options, splitter, _ssl_positive_label_percentage, + weight_edge_feat_name, ), ) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 989907640..c6bc6f2f7 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -91,6 +91,7 @@ def __init__( num_cpu_threads: Optional[int] = None, shuffle: bool = False, drop_last: bool = False, + with_weight: bool = False, sampler_options: Optional[SamplerOptions] = None, context: Optional[DistributedContext] = None, # TODO: (svij) Deprecate this local_process_rank: Optional[int] = None, # TODO: (svij) Deprecate this @@ -201,6 +202,10 @@ def __init__( Defaults to `2` if set to `None` when using cpu training/inference. shuffle (bool): Whether to shuffle the input nodes. (default: ``False``). drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``). + with_weight (bool): Whether to use edge weights for neighbor sampling. + Requires edge weights to have been provided via + ``build_dataset(weight_edge_feat_name=...)`` during dataset construction. + Defaults to ``False``. sampler_options (Optional[SamplerOptions]): Controls which sampler class is instantiated. Defaults to `KHopNeighborSamplerOptions`, which will use the num_neighbors argument to instantiate the sampler. @@ -261,6 +266,8 @@ def __init__( ) del context, local_process_rank, local_process_world_size + BaseDistLoader.validate_with_weight(with_weight, dataset, sampler_options) + device = ( pin_memory_device if pin_memory_device @@ -349,6 +356,7 @@ def __init__( batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, + with_weight=with_weight, ) producer: Optional[DistSamplingProducer] = None diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index b40f2969a..4db0b02fc 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -84,6 +84,7 @@ def __init__( Union[torch.Tensor, dict[EdgeType, torch.Tensor]] ] = None, max_labels_per_anchor_node: Optional[int] = None, + has_edge_weights: bool = False, ) -> None: """ Initializes the fields of the DistDataset class. This function is called upon each serialization of the DistDataset instance. @@ -111,6 +112,9 @@ def __init__( degree_tensor: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Pre-computed degree tensor. Lazily computed on first access via the degree_tensor property. max_labels_per_anchor_node (Optional[int]): Optional cap for how many labels to materialize per anchor node for ABLP label fetching. + has_edge_weights (bool): Whether edge weights were registered during dataset + construction via ``DistPartitioner.register_edge_weights()``. Automatically + set during ``build()``; do not pass this manually. (default: ``False``) """ self._rank: int = rank self._world_size: int = world_size @@ -150,6 +154,7 @@ def __init__( Union[torch.Tensor, dict[EdgeType, torch.Tensor]] ] = degree_tensor self._max_labels_per_anchor_node = max_labels_per_anchor_node + self._has_edge_weights: bool = has_edge_weights # TODO (mkolodner-sc): Modify so that we don't need to rely on GLT's base variable naming (i.e. partition_idx, num_partitions) in favor of more clear # naming (i.e. rank, world_size). 
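Taken together, the pieces above give the intended end-to-end flow: extract the weight column at load time, register it with the partitioner, flag it on the dataset, and opt in at the loader. A rough usage sketch; import paths are inferred from the file paths in this diff, `serialized_graph_metadata` is assumed to exist, and the remaining required arguments of both calls are elided:

```python
# Hedged end-to-end sketch: both calls take additional required arguments
# (distributed context, splitter, dataset options, fanout config) omitted here.
from gigl.distributed.dataset_factory import build_dataset
from gigl.distributed.distributed_neighborloader import DistNeighborLoader

dataset = build_dataset(
    serialized_graph_metadata=serialized_graph_metadata,  # produced upstream
    weight_edge_feat_name="edge_weight",  # column sliced out and registered as weights
    # ... remaining build_dataset arguments omitted
)
assert dataset.has_edge_weights  # set during build() once weights are registered

loader = DistNeighborLoader(
    dataset=dataset,
    num_neighbors=[10, 5],
    with_weight=True,  # raises ValueError if no edge weights were registered
    # ... remaining loader arguments omitted
)
```

Note that `validate_with_weight` also rejects `with_weight=True` combined with `PPRSamplerOptions` until weight-proportional residual propagation lands.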
@@ -336,6 +341,13 @@ def degree_tensor( self._degree_tensor = compute_and_broadcast_degree_tensor(self.graph) return self._degree_tensor + @property + def has_edge_weights(self) -> bool: + """True if edge weights were registered during dataset construction via + ``DistPartitioner.register_edge_weights()``. + """ + return self._has_edge_weights + @property def max_labels_per_anchor_node(self) -> Optional[int]: return self._max_labels_per_anchor_node @@ -559,11 +571,15 @@ def _initialize_graph( GraphPartitionData, dict[EdgeType, GraphPartitionData] ], ) -> None: - """ - Initializes the graph structure with edge index and edge IDs from partition output. + """Initializes the graph structure from partition output. + + Sets up the GLT graph with edge index, edge IDs, and optional edge weights. + For heterogeneous graphs with weights registered on only some edge types, a + warning is logged and unweighted edge types fall back to uniform sampling. Args: - partitioned_edge_index(Union[GraphPartitionData, dict[EdgeType, GraphPartitionData]]): The partitioned graph data + partitioned_edge_index: Partitioned graph data per edge type (heterogeneous) + or a single partition (homogeneous). """ # Edge Index refers to the [2, num_edges] tensor representing pairs of nodes connecting each edge @@ -575,6 +591,9 @@ def _initialize_graph( edge_ids: Union[ Optional[torch.Tensor], dict[EdgeType, Optional[torch.Tensor]] ] = partitioned_edge_index.edge_ids + edge_weights: Optional[ + Union[torch.Tensor, dict[EdgeType, torch.Tensor]] + ] = partitioned_edge_index.weights else: edge_index = { edge_type: graph_partition_data.edge_index @@ -584,12 +603,34 @@ def _initialize_graph( edge_type: graph_partition_data.edge_ids for edge_type, graph_partition_data in partitioned_edge_index.items() } + weights_by_type = { + edge_type: graph_partition_data.weights + for edge_type, graph_partition_data in partitioned_edge_index.items() + if graph_partition_data.weights is not None + } + if weights_by_type: + missing = set(partitioned_edge_index.keys()) - set( + weights_by_type.keys() + ) + if missing: + logger.warning( + f"Edge weights are registered for {set(weights_by_type.keys())} but " + f"not for {missing}. Filling missing edge types with uniform weights " + f"(all 1s) so GLT does not segfault on partial-weight heterogeneous graphs." 
+ ) + for edge_type in missing: + n_edges = partitioned_edge_index[edge_type].edge_index.size(1) + weights_by_type[edge_type] = torch.ones(n_edges) + edge_weights = weights_by_type if weights_by_type else None + + self._has_edge_weights = edge_weights is not None self.init_graph( edge_index=edge_index, edge_ids=edge_ids, graph_mode="CPU", directed=True, + edge_weights=edge_weights, ) if isinstance(partitioned_edge_index, Mapping): @@ -876,6 +917,7 @@ def share_ipc( Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]], Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], Optional[int], + bool, ]: """ Serializes the member variables of the DistDatasetClass @@ -899,6 +941,7 @@ def share_ipc( Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Edge feature dim and its data type, will be a dict if heterogeneous Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Degree tensors, will be a dict if heterogeneous Optional[int]: Optional per-anchor label cap for ABLP label fetching + bool: Whether edge weights were registered during dataset construction """ # TODO (mkolodner-sc): Investigate moving share_memory calls to the build() function @@ -928,6 +971,7 @@ def share_ipc( self._edge_feature_info, # Additional field unique to DistDataset class self._degree_tensor, # Additional field unique to DistDataset class self._max_labels_per_anchor_node, # Additional field unique to DistDataset class + self._has_edge_weights, # Additional field unique to DistDataset class ) return ipc_handle @@ -1185,6 +1229,7 @@ def _rebuild_distributed_dataset( ], # Edge feature dim and its data type Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], # Degree tensors Optional[int], # Optional per-anchor label cap for ABLP label fetching + bool, # has_edge_weights ], ): dataset = DistDataset.from_ipc_handle(ipc_handle) diff --git a/gigl/distributed/dist_partitioner.py b/gigl/distributed/dist_partitioner.py index 514ba5193..c887cd54b 100644 --- a/gigl/distributed/dist_partitioner.py +++ b/gigl/distributed/dist_partitioner.py @@ -202,6 +202,7 @@ def __init__( self._edge_ids: Optional[dict[EdgeType, tuple[int, int]]] = None self._edge_feat: Optional[dict[EdgeType, torch.Tensor]] = None self._edge_feat_dim: Optional[dict[EdgeType, int]] = None + self._edge_weights: Optional[dict[EdgeType, torch.Tensor]] = None # TODO (mkolodner-sc): Deprecate the need for explicitly storing labels are part of this class, leveraging # heterogeneous support instead @@ -623,6 +624,63 @@ def register_edge_features( for edge_type in input_edge_features: self._edge_feat_dim[edge_type] = input_edge_features[edge_type].shape[1] + def register_edge_weights( + self, edge_weights: Union[torch.Tensor, dict[EdgeType, torch.Tensor]] + ) -> None: + """Registers per-edge sampling weights to the partitioner. + + Weights must be a 1-D float tensor of shape ``[num_edges]``, one scalar per edge. + Must be called before ``partition()``. May be called before or after + ``register_edge_index()``; shape validation against the edge index is performed + only when ``register_edge_index()`` has already been called. + + Weights are kept separate from edge features — do not include the weight column + in edge features passed to ``register_edge_features()``. + + For optimal memory management, delete the reference to ``edge_weights`` after + calling this function. + + Args: + edge_weights: Per-edge weights, either a ``[num_edges]`` tensor if homogeneous + or a ``dict[EdgeType, Tensor]`` if heterogeneous. 
+ """ + self._assert_and_get_rpc_setup() + + if self._edge_weights is not None: + raise ValueError( + "Edge weights have already been registered. Cannot re-register edge weight data." + ) + + logger.info("Registering Edge Weights ...") + + input_edge_weights = self._convert_edge_entity_to_heterogeneous_format( + input_edge_entity=edge_weights + ) + + if not input_edge_weights: + raise ValueError( + "Edge weights is an empty dictionary. Please provide edge weights to register." + ) + + for edge_type, weight_tensor in input_edge_weights.items(): + if weight_tensor.ndim != 1: + raise ValueError( + f"Edge weights for edge type {edge_type} must be a 1-D tensor of shape " + f"[num_edges], got shape {tuple(weight_tensor.shape)}." + ) + # Shape validation is best-effort: only fires when register_edge_index() + # has already been called for this edge type. + if self._edge_index is not None and edge_type in self._edge_index: + local_num_edges = self._edge_index[edge_type].shape[1] + if weight_tensor.shape[0] != local_num_edges: + raise ValueError( + f"Edge weights for edge type {edge_type} have length " + f"{weight_tensor.shape[0]} but the registered edge index has " + f"{local_num_edges} edges on this rank." + ) + + self._edge_weights = convert_to_tensor(input_edge_weights, dtype=torch.float32) + def register_labels( self, label_edge_index: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], @@ -1076,9 +1134,12 @@ def _partition_edge_index_and_edge_features( and self._num_edges is not None ), "Must have registered edges prior to partitioning them" - should_skip_edge_feats = ( - self._edge_feat is None or edge_type not in self._edge_feat + has_edge_feats = self._edge_feat is not None and edge_type in self._edge_feat + has_weights_for_edge_type = ( + self._edge_weights is not None and edge_type in self._edge_weights ) + # Need a partition book if we have features or weights to reindex. + should_generate_partition_book = has_edge_feats or has_weights_for_edge_type # Partitioning Edge Indices @@ -1100,13 +1161,12 @@ def _edge_pfn(_, chunk_range): chunk_target_indices = index_select(target_indices, chunk_range) return target_node_partition_book[chunk_target_indices] - # TODO (mkolodner-sc): Investigate partitioning edge features as part of this input_data tuple edge_res_list, edge_partition_book = self._partition_by_chunk( input_data=(edge_index[0], edge_index[1], edge_ids), rank_indices=edge_ids, partition_function=_edge_pfn, total_val_size=num_edges, - generate_pb=not should_skip_edge_feats, + generate_pb=should_generate_partition_book, ) del edge_index, target_indices @@ -1117,9 +1177,9 @@ def _edge_pfn(_, chunk_range): gc.collect() - if len(edge_res_list) == 0: + had_zero_edges = len(edge_res_list) == 0 + if had_zero_edges: partitioned_edge_index = torch.empty((2, 0)) - partitioned_edge_ids = torch.empty(0) else: partitioned_edge_index = torch.stack( ( @@ -1128,85 +1188,138 @@ def _edge_pfn(_, chunk_range): ), dim=0, ) - if should_skip_edge_feats: - partitioned_edge_ids = None - else: - partitioned_edge_ids = torch.cat([r[2] for r in edge_res_list]) - - current_graph_part = GraphPartitionData( - edge_index=partitioned_edge_index, - edge_ids=partitioned_edge_ids, - ) edge_res_list.clear() gc.collect() - # Partitioning Edge Features + # Partition edge features and weights together in a single pass, + # mirroring how node features and labels are co-partitioned. 
+ # Input tuple layout: (edge_feat?, edge_weights?, edge_ids) + # IDs are always at r[-1]; features at r[0]; weights at r[1] when + # features are also present, else r[0]. + current_feat_part: Optional[FeaturePartitionData] = None + partitioned_weights: Optional[torch.Tensor] = None + partitioned_edge_ids: Optional[torch.Tensor] = None - if should_skip_edge_feats: + if not should_generate_partition_book: logger.info( f"No edge features detected for edge type {edge_type}, will only partition edge indices for this edge type." ) - current_feat_part = None + if had_zero_edges: + partitioned_edge_ids = torch.empty(0) del edge_ids del self._edge_ids[edge_type] if len(self._edge_ids) == 0: self._edge_ids = None gc.collect() else: - assert self._edge_feat_dim is not None and edge_type in self._edge_feat_dim - assert self._edge_feat is not None and edge_type in self._edge_feat assert edge_partition_book is not None - edge_feat = self._edge_feat[edge_type] - edge_feat_dim = self._edge_feat_dim[edge_type] - def _edge_feature_pfn(edge_feature_ids, _): + edge_feat: Optional[torch.Tensor] = None + edge_feat_dim: Optional[int] = None + edge_weights_tensor: Optional[torch.Tensor] = None + if has_edge_feats: + assert self._edge_feat is not None and edge_type in self._edge_feat + assert ( + self._edge_feat_dim is not None and edge_type in self._edge_feat_dim + ) + edge_feat = self._edge_feat[edge_type] + edge_feat_dim = self._edge_feat_dim[edge_type] + if has_weights_for_edge_type: + assert self._edge_weights is not None + edge_weights_tensor = self._edge_weights[edge_type] + + input_parts: list[torch.Tensor] = [] + if edge_feat is not None: + input_parts.append(edge_feat) + if edge_weights_tensor is not None: + input_parts.append(edge_weights_tensor) + input_parts.append(edge_ids) + + # Positional indices: features first, weights next, ids always last. + feat_idx: Optional[int] = 0 if has_edge_feats else None + weight_idx: Optional[int] = ( + (1 if has_edge_feats else 0) if has_weights_for_edge_type else None + ) + + def _edge_feat_weight_pfn( + ids_chunk: torch.Tensor, _: object + ) -> torch.Tensor: assert edge_partition_book is not None - return edge_partition_book[edge_feature_ids] + return edge_partition_book[ids_chunk] - # partitioned_results is a list of tuples. Each tuple correpsonds - # to a chunk of data. A tuple contains edge features and edge ids. - edge_feat_res_list, _ = self._partition_by_chunk( - input_data=(edge_feat, edge_ids), + # Each result tuple contains (edge_feat?, edge_weights?, edge_ids). 
+ feat_weight_res_list, _ = self._partition_by_chunk( + input_data=tuple(input_parts), rank_indices=edge_ids, - partition_function=_edge_feature_pfn, + partition_function=_edge_feat_weight_pfn, total_val_size=num_edges, generate_pb=False, ) - del edge_feat, edge_ids - del ( - self._edge_feat[edge_type], - self._edge_feat_dim[edge_type], - self._edge_ids[edge_type], - ) + + del edge_ids + del self._edge_ids[edge_type] if len(self._edge_ids) == 0: self._edge_ids = None - if len(self._edge_feat) == 0 and len(self._edge_feat_dim) == 0: - self._edge_feat = None - self._edge_feat_dim = None + if has_edge_feats: + assert edge_feat is not None + assert self._edge_feat is not None and self._edge_feat_dim is not None + del edge_feat + del self._edge_feat[edge_type], self._edge_feat_dim[edge_type] + if len(self._edge_feat) == 0 and len(self._edge_feat_dim) == 0: + self._edge_feat = None + self._edge_feat_dim = None + if has_weights_for_edge_type: + assert edge_weights_tensor is not None + assert self._edge_weights is not None + del edge_weights_tensor + del self._edge_weights[edge_type] + if len(self._edge_weights) == 0: + self._edge_weights = None gc.collect() - if len(edge_feat_res_list) == 0: - partitioned_edge_features = torch.empty(0, edge_feat_dim) - partitioned_edge_feat_ids = torch.empty(0) + + if len(feat_weight_res_list) == 0: + partitioned_edge_ids = torch.empty(0) + if has_edge_feats: + assert edge_feat_dim is not None + current_feat_part = FeaturePartitionData( + feats=torch.empty(0, edge_feat_dim), + ids=partitioned_edge_ids, + ) + if has_weights_for_edge_type: + partitioned_weights = torch.empty(0) else: - partitioned_edge_features = torch.cat( - [r[0] for r in edge_feat_res_list] - ) - partitioned_edge_feat_ids = torch.cat( - [r[1] for r in edge_feat_res_list] - ) + partitioned_edge_ids = torch.cat([r[-1] for r in feat_weight_res_list]) + if has_edge_feats: + assert feat_idx is not None and edge_feat_dim is not None + current_feat_part = FeaturePartitionData( + feats=torch.cat([r[feat_idx] for r in feat_weight_res_list]), + ids=partitioned_edge_ids, + ) + if has_weights_for_edge_type: + assert weight_idx is not None + partitioned_weights = torch.cat( + [r[weight_idx] for r in feat_weight_res_list] + ) + + feat_weight_res_list.clear() + gc.collect() - current_feat_part = FeaturePartitionData( - feats=partitioned_edge_features, ids=partitioned_edge_feat_ids - ) - logger.info( - f"Got edge tensor-based partition book for edge type {edge_type} on rank {self._rank} of shape {edge_partition_book.shape}" - ) + if has_edge_feats: + logger.info( + f"Got edge tensor-based partition book for edge type {edge_type} on rank {self._rank} of shape {edge_partition_book.shape}" + ) logger.info( f"Edge Index and Feature Partitioning for edge type {edge_type} finished, took {time.time() - start_time:.3f}s" ) + current_graph_part = GraphPartitionData( + edge_index=partitioned_edge_index, + edge_ids=partitioned_edge_ids, + weights=partitioned_weights, + ) + return current_graph_part, current_feat_part, edge_partition_book def _partition_label_edge_index( diff --git a/gigl/distributed/dist_range_partitioner.py b/gigl/distributed/dist_range_partitioner.py index 42110e983..faa1cfb57 100644 --- a/gigl/distributed/dist_range_partitioner.py +++ b/gigl/distributed/dist_range_partitioner.py @@ -217,21 +217,41 @@ def _partition_edge_index_and_edge_features( ) edge_index = self._edge_index[edge_type] + has_edge_feats = self._edge_feat is not None and edge_type in self._edge_feat + has_edge_weights = ( + 
self._edge_weights is not None and edge_type in self._edge_weights + ) - input_data: tuple[torch.Tensor, ...] - - if self._edge_feat is None or edge_type not in self._edge_feat: + if not has_edge_feats: logger.info( f"No edge features detected for edge type {edge_type}, will only partition edge indices for this edge type." ) - edge_feat = None - edge_feat_dim = None - input_data = (edge_index[0], edge_index[1]) - else: - assert self._edge_feat_dim is not None and edge_type in self._edge_feat_dim + + edge_feat: Optional[torch.Tensor] = None + edge_feat_dim: Optional[int] = None + edge_weights_tensor: Optional[torch.Tensor] = None + + if has_edge_feats: + assert self._edge_feat is not None and self._edge_feat_dim is not None + assert edge_type in self._edge_feat_dim edge_feat = self._edge_feat[edge_type] edge_feat_dim = self._edge_feat_dim[edge_type] - input_data = (edge_index[0], edge_index[1], edge_feat) + if has_edge_weights: + assert self._edge_weights is not None + edge_weights_tensor = self._edge_weights[edge_type] + + # Build input_data tuple: (src, dst[, feat][, weights]) + # Track the index of each optional tensor so we can unpack res_list correctly. + input_parts: list[torch.Tensor] = [edge_index[0], edge_index[1]] + feat_idx: Optional[int] = None + weight_idx: Optional[int] = None + if edge_feat is not None: + feat_idx = len(input_parts) + input_parts.append(edge_feat) + if edge_weights_tensor is not None: + weight_idx = len(input_parts) + input_parts.append(edge_weights_tensor) + input_data: tuple[torch.Tensor, ...] = tuple(input_parts) if self._should_assign_edges_by_src_node: target_node_partition_book = node_partition_book[edge_type.src_node_type] @@ -249,10 +269,13 @@ def edge_partition_fn(rank_indices, _): partition_function=edge_partition_fn, ) - del input_data, edge_index, target_indices, edge_feat + del input_data, edge_index, target_indices, edge_feat, edge_weights_tensor del self._edge_index[edge_type] if self._edge_feat is not None and edge_type in self._edge_feat: - del self._edge_feat[edge_type] + assert self._edge_feat_dim is not None + del self._edge_feat[edge_type], self._edge_feat_dim[edge_type] + if self._edge_weights is not None and edge_type in self._edge_weights: + del self._edge_weights[edge_type] # We check if edge_index or edge_feat dict is empty after deleting the tensor. If so, we set these fields to None. 
if not self._edge_index: @@ -260,11 +283,17 @@ def edge_partition_fn(rank_indices, _): if not self._edge_feat and not self._edge_feat_dim: self._edge_feat = None self._edge_feat_dim = None + if self._edge_weights is not None and not self._edge_weights: + self._edge_weights = None gc.collect() if len(res_list) == 0: partitioned_edge_index = torch.empty((2, 0)) + partitioned_edge_features = ( + torch.empty(0, edge_feat_dim) if edge_feat_dim is not None else None + ) + partitioned_weights = torch.empty(0) if has_edge_weights else None else: partitioned_edge_index = torch.stack( ( @@ -273,19 +302,22 @@ def edge_partition_fn(rank_indices, _): ), dim=0, ) - - if edge_feat_dim is not None: - if len(res_list) == 0: - partitioned_edge_features = torch.empty(0, edge_feat_dim) - else: - partitioned_edge_features = torch.cat([r[2] for r in res_list]) + partitioned_edge_features = ( + torch.cat([r[feat_idx] for r in res_list]) + if feat_idx is not None + else None + ) + partitioned_weights = ( + torch.cat([r[weight_idx] for r in res_list]) + if weight_idx is not None + else None + ) res_list.clear() - gc.collect() - # Generating edge partition book - + # Generate range-based edge partition book and infer edge IDs. + # Only needed when edge features are present — weights use positional IDs. num_edges_on_each_rank: list[tuple[int, int]] = sorted( all_gather((self._rank, partitioned_edge_index.size(1))).values(), key=lambda x: x[0], @@ -304,10 +336,11 @@ def edge_partition_fn(rank_indices, _): partitioned_edge_ids = get_ids_on_rank( partition_book=edge_partition_book, rank=self._rank ) - + assert partitioned_edge_features is not None current_graph_part = GraphPartitionData( edge_index=partitioned_edge_index, edge_ids=partitioned_edge_ids, + weights=partitioned_weights, ) current_feat_part = FeaturePartitionData( feats=partitioned_edge_features, ids=None @@ -320,6 +353,7 @@ def edge_partition_fn(rank_indices, _): current_graph_part = GraphPartitionData( edge_index=partitioned_edge_index, edge_ids=None, + weights=partitioned_weights, ) edge_partition_book = None diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 3d6d5a34b..7dd4fd73f 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -90,6 +90,7 @@ def __init__( num_cpu_threads: Optional[int] = None, shuffle: bool = False, drop_last: bool = False, + with_weight: bool = False, sampler_options: Optional[SamplerOptions] = None, non_blocking_transfers: bool = True, ): @@ -158,6 +159,10 @@ def __init__( Defaults to `2` if set to `None` when using cpu training/inference. shuffle (bool): Whether to shuffle the input nodes. (default: ``False``). drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``). + with_weight (bool): Whether to use edge weights for neighbor sampling. + Requires edge weights to have been provided via + ``build_dataset(weight_edge_feat_name=...)`` during dataset construction. + Defaults to ``False``. sampler_options (Optional[SamplerOptions]): Controls which sampler class is instantiated. Pass ``KHopNeighborSamplerOptions`` to use the built-in sampler, or ``CustomSamplerOptions`` to dynamically import a custom sampler class. 
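For the lower-level path, `register_edge_weights` can also be driven directly against a partitioner, mirroring how the test assets below wire it up. A single-rank sketch; the port and group name are illustrative, and the RPC setup follows tests/test_assets/distributed/run_distributed_partitioner.py:

```python
# Single-rank sketch of registering weights directly (toy tensors).
import torch
from graphlearn_torch.distributed import init_rpc, init_worker_group
from gigl.distributed import DistPartitioner

# register_* calls assert RPC readiness, so initialize the worker group first.
init_worker_group(world_size=1, rank=0, group_name="weights_demo")
init_rpc(master_addr="localhost", master_port=29500, num_rpc_threads=4)

partitioner = DistPartitioner(should_assign_edges_by_src_node=True)
partitioner.register_edge_index(edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]))
# Weights must be 1-D with one scalar per edge; registration is rejected if
# the tensor is multi-dimensional, already registered, or mismatched in
# length against a previously registered edge index.
partitioner.register_edge_weights(edge_weights=torch.tensor([0.5, 1.0, 1.5]))
```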
@@ -184,6 +189,8 @@ def __init__( ) del context, local_process_rank, local_process_world_size + BaseDistLoader.validate_with_weight(with_weight, dataset, sampler_options) + # Determine mode if isinstance(dataset, RemoteDistDataset): self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE @@ -263,6 +270,7 @@ def __init__( batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, + with_weight=with_weight, ) producer: Optional[DistSamplingProducer] = None diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 9a2ca23dc..636e8c332 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -425,6 +425,14 @@ def get_edge_dir(self) -> Literal["in", "out"]: """ return self.dataset.edge_dir + def get_edge_weights_registered(self) -> bool: + """Return whether edge weights were registered in the dataset. + + Returns: + True if edge weights were registered via ``DistPartitioner.register_edge_weights()``. + """ + return self.dataset.has_edge_weights + def get_node_ids( self, request: FetchNodesRequest, diff --git a/gigl/distributed/graph_store/remote_dist_dataset.py b/gigl/distributed/graph_store/remote_dist_dataset.py index cb03934af..e58784dc0 100644 --- a/gigl/distributed/graph_store/remote_dist_dataset.py +++ b/gigl/distributed/graph_store/remote_dist_dataset.py @@ -606,3 +606,14 @@ def fetch_node_types(self) -> Optional[list[NodeType]]: 0, DistServer.get_node_types, ) + + def fetch_edge_weights_registered(self) -> bool: + """Fetch whether edge weights were registered in the remote dataset. + + Returns: + True if edge weights were registered via ``DistPartitioner.register_edge_weights()``. + """ + return request_server( + 0, + DistServer.get_edge_weights_registered, + ) diff --git a/gigl/types/graph.py b/gigl/types/graph.py index 00ee69e33..3fc4958fc 100644 --- a/gigl/types/graph.py +++ b/gigl/types/graph.py @@ -149,6 +149,8 @@ class LoadedGraphTensors: positive_label: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]] # Unpartitioned Negative Edge Label negative_label: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]] + # Unpartitioned Edge Weights (per-edge sampling weights, one scalar per edge) + edge_weights: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]] = None def treat_labels_as_edges(self, edge_dir: Literal["in", "out"]) -> None: """ @@ -249,6 +251,7 @@ def treat_labels_as_edges(self, edge_dir: Literal["in", "out"]) -> None: self.node_features = to_heterogeneous_node(self.node_features) self.edge_index = edge_index_with_labels self.edge_features = to_heterogeneous_edge(self.edge_features) + self.edge_weights = to_heterogeneous_edge(self.edge_weights) self.positive_label = None self.negative_label = None gc.collect() diff --git a/tests/test_assets/distributed/run_distributed_partitioner.py b/tests/test_assets/distributed/run_distributed_partitioner.py index 96cfb25cf..3bbc3406c 100644 --- a/tests/test_assets/distributed/run_distributed_partitioner.py +++ b/tests/test_assets/distributed/run_distributed_partitioner.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Type, Union +from typing import Optional, Type, Union import torch from graphlearn_torch.distributed import init_rpc, init_worker_group @@ -19,6 +19,9 @@ class InputDataStrategy(Enum): REGISTER_ALL_ENTITIES_SEPARATELY = "REGISTER_ALL_ENTITIES_SEPARATELY" REGISTER_ALL_ENTITIES_TOGETHER = "REGISTER_ALL_ENTITIES_TOGETHER" REGISTER_MINIMAL_ENTITIES_SEPARATELY = 
"REGISTER_MINIMAL_ENTITIES_SEPARATELY" + REGISTER_EDGE_WEIGHTS_WITHOUT_EDGE_FEATURES = ( + "REGISTER_EDGE_WEIGHTS_WITHOUT_EDGE_FEATURES" + ) def run_distributed_partitioner( @@ -31,6 +34,9 @@ def run_distributed_partitioner( master_port: int, input_data_strategy: InputDataStrategy, partitioner_class: Type[DistPartitioner], + rank_to_edge_weights: Optional[ + dict[int, Union[torch.Tensor, dict[EdgeType, torch.Tensor]]] + ] = None, ) -> None: """ Runs the distributed partitioner on a specific rank. @@ -44,6 +50,7 @@ def run_distributed_partitioner( master_port (int): Master port for initializing rpc for partitioning input_data_strategy (InputDataStrategy): Strategy for registering inputs to the partitioner partitioner_class (Type[DistPartitioner]): The class to use for partitioning + rank_to_edge_weights (Optional[dict[int, Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]]): Optional mapping of rank to 1D edge weight tensor (or dict per EdgeType for heterogeneous). Only supported with REGISTER_ALL_ENTITIES_SEPARATELY strategy. """ input_graph = rank_to_input_graph[rank] @@ -78,7 +85,10 @@ def run_distributed_partitioner( init_rpc(master_addr=master_addr, master_port=master_port, num_rpc_threads=4) dist_partitioner: DistPartitioner - if input_data_strategy == InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY: + if input_data_strategy in ( + InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY, + InputDataStrategy.REGISTER_EDGE_WEIGHTS_WITHOUT_EDGE_FEATURES, + ): dist_partitioner = partitioner_class( should_assign_edges_by_src_node=should_assign_edges_by_src_node, ) @@ -88,9 +98,14 @@ def run_distributed_partitioner( output_node_partition_book = dist_partitioner.partition_node() dist_partitioner.register_edge_index(edge_index=edge_index) - dist_partitioner.register_edge_features(edge_features=edge_features) del edge_index + if input_data_strategy == InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY: + dist_partitioner.register_edge_features(edge_features=edge_features) del edge_features + if rank_to_edge_weights is not None and rank in rank_to_edge_weights: + dist_partitioner.register_edge_weights( + edge_weights=rank_to_edge_weights[rank] + ) ( output_edge_index, output_edge_features, diff --git a/tests/test_assets/distributed/test_dataset.py b/tests/test_assets/distributed/test_dataset.py index bab8afdf4..eb0b852c5 100644 --- a/tests/test_assets/distributed/test_dataset.py +++ b/tests/test_assets/distributed/test_dataset.py @@ -67,6 +67,7 @@ def create_homogeneous_dataset( node_features: Optional[torch.Tensor] = None, edge_features: Optional[torch.Tensor] = None, node_labels: Optional[torch.Tensor] = None, + edge_weights: Optional[torch.Tensor] = None, rank: int = 0, world_size: int = 1, edge_dir: Literal["in", "out"] = "out", @@ -74,13 +75,14 @@ def create_homogeneous_dataset( """Create a homogeneous test dataset. Creates a single-partition DistDataset with the specified edge index, node features, - edge features, and node labels. + edge features, node labels, and optional per-edge sampling weights. Args: edge_index: COO format edge index [2, num_edges]. node_features: Node feature tensor [num_nodes, feature_dim], or None. edge_features: Edge feature tensor [num_edges, feature_dim], or None. node_labels: Node label tensor [num_nodes, label_dim], or None. + edge_weights: 1D per-edge sampling weight tensor [num_edges], or None. rank: Rank of the current process. Defaults to 0. world_size: Total number of processes. Defaults to 1. edge_dir: Edge direction ("in" or "out"). Defaults to "out". 
@@ -129,6 +131,7 @@ def create_homogeneous_dataset( partitioned_edge_index=GraphPartitionData( edge_index=edge_index, edge_ids=None, + weights=edge_weights, ), partitioned_node_features=partitioned_node_features, partitioned_edge_features=partitioned_edge_features, diff --git a/tests/unit/common/data/dataloaders_test.py b/tests/unit/common/data/dataloaders_test.py index 569ee800a..bb7bad002 100644 --- a/tests/unit/common/data/dataloaders_test.py +++ b/tests/unit/common/data/dataloaders_test.py @@ -16,6 +16,10 @@ TFRecordDataLoader, _get_labels_from_features, ) +from gigl.common.data.load_torch_tensors import ( + SerializedGraphMetadata, + load_torch_tensors_from_tf_record, +) from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.data_preprocessor.lib.types import FeatureSpecDict from gigl.src.mocking.lib.versioning import ( @@ -487,6 +491,301 @@ def test_load_labels_from_pb(self): self.assertEqual(feature_tensor.size(1), node_metadata.feature_dim) self.assertEqual(label_tensor.size(1), len(node_metadata.label_keys)) + def test_load_edge_weights_from_tf_record(self): + """Edge weight column is extracted from edge features and returned separately. + + Mirrors test_load_labels_from_pb for edge sampling weights. + """ + n_edges = 5 + edge_feature_vals = [float(i * 10) for i in range(n_edges)] + edge_weight_vals = [float(i + 1) for i in range(n_edges)] + + node_dir = tempfile.TemporaryDirectory() + edge_dir = tempfile.TemporaryDirectory() + self.addCleanup(node_dir.cleanup) + self.addCleanup(edge_dir.cleanup) + + with tf.io.TFRecordWriter( + str(Path(node_dir.name) / "nodes.tfrecord") + ) as writer: + for i in range(3): + writer.write( + tf.train.Example( + features=tf.train.Features( + feature={ + "node_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[i]) + ), + } + ) + ).SerializeToString() + ) + + with tf.io.TFRecordWriter( + str(Path(edge_dir.name) / "edges.tfrecord") + ) as writer: + for i in range(n_edges): + writer.write( + tf.train.Example( + features=tf.train.Features( + feature={ + "src_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[i % 3]) + ), + "dst_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[(i + 1) % 3]) + ), + "edge_feature": tf.train.Feature( + float_list=tf.train.FloatList( + value=[edge_feature_vals[i]] + ) + ), + "edge_weight": tf.train.Feature( + float_list=tf.train.FloatList( + value=[edge_weight_vals[i]] + ) + ), + } + ) + ).SerializeToString() + ) + + loader = TFRecordDataLoader(rank=0, world_size=1) + loaded = load_torch_tensors_from_tf_record( + tf_record_dataloader=loader, + serialized_graph_metadata=SerializedGraphMetadata( + node_entity_info=SerializedTFRecordInfo( + tfrecord_uri_prefix=UriFactory.create_uri(node_dir.name), + feature_spec={"node_id": tf.io.FixedLenFeature([], tf.int64)}, + feature_keys=[], + feature_dim=0, + entity_key="node_id", + tfrecord_uri_pattern="nodes.tfrecord", + ), + edge_entity_info=SerializedTFRecordInfo( + tfrecord_uri_prefix=UriFactory.create_uri(edge_dir.name), + feature_spec={ + "src_id": tf.io.FixedLenFeature([], tf.int64), + "dst_id": tf.io.FixedLenFeature([], tf.int64), + "edge_feature": tf.io.FixedLenFeature([], tf.float32), + "edge_weight": tf.io.FixedLenFeature([], tf.float32), + }, + feature_keys=["edge_feature", "edge_weight"], + feature_dim=2, + entity_key=("src_id", "dst_id"), + tfrecord_uri_pattern="edges.tfrecord", + ), + ), + should_load_tensors_in_parallel=False, + edge_tf_dataset_options=TFDatasetOptions(deterministic=True), + 
weight_edge_feat_name="edge_weight", + ) + + self.assertIsNotNone(loaded.edge_weights) + self.assertIsNotNone(loaded.edge_features) + assert isinstance(loaded.edge_weights, torch.Tensor) + assert isinstance(loaded.edge_features, torch.Tensor) + # Weight column extracted as 1D tensor; feature matrix has one column remaining + self.assertEqual(loaded.edge_weights.shape, torch.Size([n_edges])) + self.assertEqual(loaded.edge_features.shape, torch.Size([n_edges, 1])) + # Values match the written data (sort since loading order is not guaranteed) + assert_close( + loaded.edge_weights.sort().values, + torch.tensor(sorted(edge_weight_vals), dtype=torch.float32), + ) + assert_close( + loaded.edge_features[:, 0].sort().values, + torch.tensor(sorted(edge_feature_vals), dtype=torch.float32), + ) + + def test_load_edge_weights_multidim_feature(self): + """Weight column offset is correct when a preceding feature key is multi-dimensional. + + A preceding feature with shape=[2] contributes 2 columns to the concatenated tensor, + so the weight column lives at offset 2, not offset 1. + """ + n_edges = 3 + node_dir = tempfile.TemporaryDirectory() + edge_dir = tempfile.TemporaryDirectory() + self.addCleanup(node_dir.cleanup) + self.addCleanup(edge_dir.cleanup) + + with tf.io.TFRecordWriter( + str(Path(node_dir.name) / "nodes.tfrecord") + ) as writer: + for i in range(3): + writer.write( + tf.train.Example( + features=tf.train.Features( + feature={ + "node_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[i]) + ) + } + ) + ).SerializeToString() + ) + + # Each edge has: src_id, dst_id, a 2-dim embedding, and a scalar weight. + embedding_vals = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] + weight_vals = [0.1, 0.2, 0.3] + with tf.io.TFRecordWriter( + str(Path(edge_dir.name) / "edges.tfrecord") + ) as writer: + for i in range(n_edges): + writer.write( + tf.train.Example( + features=tf.train.Features( + feature={ + "src_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[i % 3]) + ), + "dst_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[(i + 1) % 3]) + ), + "embedding": tf.train.Feature( + float_list=tf.train.FloatList( + value=embedding_vals[i] + ) + ), + "edge_weight": tf.train.Feature( + float_list=tf.train.FloatList( + value=[weight_vals[i]] + ) + ), + } + ) + ).SerializeToString() + ) + + loader = TFRecordDataLoader(rank=0, world_size=1) + loaded = load_torch_tensors_from_tf_record( + tf_record_dataloader=loader, + serialized_graph_metadata=SerializedGraphMetadata( + node_entity_info=SerializedTFRecordInfo( + tfrecord_uri_prefix=UriFactory.create_uri(node_dir.name), + feature_spec={"node_id": tf.io.FixedLenFeature([], tf.int64)}, + feature_keys=[], + feature_dim=0, + entity_key="node_id", + tfrecord_uri_pattern="nodes.tfrecord", + ), + edge_entity_info=SerializedTFRecordInfo( + tfrecord_uri_prefix=UriFactory.create_uri(edge_dir.name), + feature_spec={ + "src_id": tf.io.FixedLenFeature([], tf.int64), + "dst_id": tf.io.FixedLenFeature([], tf.int64), + "embedding": tf.io.FixedLenFeature([2], tf.float32), + "edge_weight": tf.io.FixedLenFeature([], tf.float32), + }, + feature_keys=["embedding", "edge_weight"], + feature_dim=3, + entity_key=("src_id", "dst_id"), + tfrecord_uri_pattern="edges.tfrecord", + ), + ), + should_load_tensors_in_parallel=False, + edge_tf_dataset_options=TFDatasetOptions(deterministic=True), + weight_edge_feat_name="edge_weight", + ) + + assert isinstance(loaded.edge_weights, torch.Tensor) + assert isinstance(loaded.edge_features, torch.Tensor) + # Weight is a 1-D 
scalar per edge; embedding columns remain + self.assertEqual(loaded.edge_weights.shape, torch.Size([n_edges])) + self.assertEqual(loaded.edge_features.shape, torch.Size([n_edges, 2])) + assert_close( + loaded.edge_weights.sort().values, + torch.tensor(sorted(weight_vals), dtype=torch.float32), + ) + + def test_load_edge_weights_only_feature(self): + """When the weight column is the only edge feature, edge_features is None after extraction.""" + n_edges = 3 + node_dir = tempfile.TemporaryDirectory() + edge_dir = tempfile.TemporaryDirectory() + self.addCleanup(node_dir.cleanup) + self.addCleanup(edge_dir.cleanup) + + with tf.io.TFRecordWriter( + str(Path(node_dir.name) / "nodes.tfrecord") + ) as writer: + for i in range(3): + writer.write( + tf.train.Example( + features=tf.train.Features( + feature={ + "node_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[i]) + ) + } + ) + ).SerializeToString() + ) + + weight_vals = [0.5, 1.0, 1.5] + with tf.io.TFRecordWriter( + str(Path(edge_dir.name) / "edges.tfrecord") + ) as writer: + for i in range(n_edges): + writer.write( + tf.train.Example( + features=tf.train.Features( + feature={ + "src_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[i % 3]) + ), + "dst_id": tf.train.Feature( + int64_list=tf.train.Int64List(value=[(i + 1) % 3]) + ), + "edge_weight": tf.train.Feature( + float_list=tf.train.FloatList( + value=[weight_vals[i]] + ) + ), + } + ) + ).SerializeToString() + ) + + loader = TFRecordDataLoader(rank=0, world_size=1) + loaded = load_torch_tensors_from_tf_record( + tf_record_dataloader=loader, + serialized_graph_metadata=SerializedGraphMetadata( + node_entity_info=SerializedTFRecordInfo( + tfrecord_uri_prefix=UriFactory.create_uri(node_dir.name), + feature_spec={"node_id": tf.io.FixedLenFeature([], tf.int64)}, + feature_keys=[], + feature_dim=0, + entity_key="node_id", + tfrecord_uri_pattern="nodes.tfrecord", + ), + edge_entity_info=SerializedTFRecordInfo( + tfrecord_uri_prefix=UriFactory.create_uri(edge_dir.name), + feature_spec={ + "src_id": tf.io.FixedLenFeature([], tf.int64), + "dst_id": tf.io.FixedLenFeature([], tf.int64), + "edge_weight": tf.io.FixedLenFeature([], tf.float32), + }, + feature_keys=["edge_weight"], + feature_dim=1, + entity_key=("src_id", "dst_id"), + tfrecord_uri_pattern="edges.tfrecord", + ), + ), + should_load_tensors_in_parallel=False, + edge_tf_dataset_options=TFDatasetOptions(deterministic=True), + weight_edge_feat_name="edge_weight", + ) + + assert isinstance(loaded.edge_weights, torch.Tensor) + self.assertIsNone(loaded.edge_features) + self.assertEqual(loaded.edge_weights.shape, torch.Size([n_edges])) + assert_close( + loaded.edge_weights.sort().values, + torch.tensor(sorted(weight_vals), dtype=torch.float32), + ) + @parameterized.expand( [ param( diff --git a/tests/unit/distributed/distributed_weighted_sampling_test.py b/tests/unit/distributed/distributed_weighted_sampling_test.py new file mode 100644 index 000000000..f51f1bf53 --- /dev/null +++ b/tests/unit/distributed/distributed_weighted_sampling_test.py @@ -0,0 +1,852 @@ +"""Distributed integration tests for weighted edge sampling. + +Covers two surfaces: + 1. DistPartitioner correctly partitions registered edge weights (weights land on + the right rank and match the expected values). + 2. DistNeighborLoader with with_weight=True never traverses weight=0 edges — + verified by encoding node type into features (hub=2.0, good=1.0, bad=0.0) + and asserting no bad node appears in any sampled subgraph. 
+""" + +from collections.abc import Mapping +from typing import MutableMapping + +import torch +import torch.multiprocessing as mp +from absl.testing import absltest +from graphlearn_torch.distributed import shutdown_rpc +from torch.multiprocessing import Manager +from torch_geometric.data import Data, HeteroData + +from gigl.distributed import DistPartitioner +from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.dist_range_partitioner import DistRangePartitioner +from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.distributed.utils.networking import get_free_port +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from gigl.types.graph import ( + FeaturePartitionData, + GraphPartitionData, + PartitionOutput, +) +from tests.test_assets.distributed.constants import ( + MOCKED_NUM_PARTITIONS, + MOCKED_U2U_EDGE_INDEX_ON_RANK_ONE, + MOCKED_U2U_EDGE_INDEX_ON_RANK_ZERO, + RANK_TO_MOCKED_GRAPH, + USER_TO_ITEM_EDGE_TYPE, + USER_TO_USER_EDGE_TYPE, +) +from tests.test_assets.distributed.run_distributed_partitioner import ( + InputDataStrategy, + run_distributed_partitioner, +) +from tests.test_assets.distributed.utils import create_test_process_group +from tests.test_assets.test_case import TestCase + +_USER = NodeType("user") +_ITEM = NodeType("item") +_USER_TO_ITEM = EdgeType(_USER, Relation("to"), _ITEM) +_ITEM_TO_USER = EdgeType(_ITEM, Relation("to"), _USER) + + +# --------------------------------------------------------------------------- +# Graph builders +# --------------------------------------------------------------------------- + + +def _build_homogeneous_bipartite_weight_graph() -> tuple[ + PartitionOutput, int, int, int +]: + """Build a homogeneous graph with hub, good, and bad nodes. + + Graph structure: + - 10 hub nodes (0..9): used as seed nodes; feature value = 2.0 + - 50 good nodes (10..59): reachable from hubs via weight=1 edges; feature = 1.0 + - 40 bad nodes (60..99): reachable from hubs via weight=0 edges; feature = 0.0 + - Each good node also has 5 outgoing weight=1 edges to nearby good nodes + (ring topology, for 2nd-hop sampling). + + With weighted sampling only good nodes should ever appear as sampled + neighbors — weight=0 edges to bad nodes must never be traversed. 
+ + Returns: + (partition_output, n_hub, n_good, n_bad) + """ + n_hub = 10 + n_good = 50 + n_bad = 40 + n = n_hub + n_good + n_bad # 100 + + hub_ids = torch.arange(n_hub) + good_ids = torch.arange(n_hub, n_hub + n_good) + bad_ids = torch.arange(n_hub + n_good, n) + + # Hub → Good: weight=1 + hub_good_src = hub_ids.repeat_interleave(n_good) + hub_good_dst = good_ids.repeat(n_hub) + hub_good_w = torch.ones(n_hub * n_good) + + # Hub → Bad: weight=0 + hub_bad_src = hub_ids.repeat_interleave(n_bad) + hub_bad_dst = bad_ids.repeat(n_hub) + hub_bad_w = torch.zeros(n_hub * n_bad) + + # Good → Good: ring with 5 outgoing edges per node, weight=1 (2nd-hop targets) + connections_per_good = 5 + good_src = good_ids.repeat_interleave(connections_per_good) + # Row i of [connections_per_good, n_good].T gives neighbors of good_ids[i] + good_dst = torch.stack( + [torch.roll(good_ids, -j) for j in range(1, connections_per_good + 1)] + ).T.reshape(-1) + good_w = torch.ones(n_good * connections_per_good) + + edge_src = torch.cat([hub_good_src, hub_bad_src, good_src]) + edge_dst = torch.cat([hub_good_dst, hub_bad_dst, good_dst]) + weights = torch.cat([hub_good_w, hub_bad_w, good_w]) + edge_index = torch.stack([edge_src, edge_dst]) + n_edges = edge_src.shape[0] + + # Feature encodes node type: hub=2.0, good=1.0, bad=0.0 + node_feats = torch.cat( + [ + torch.full((n_hub, 1), 2.0), + torch.full((n_good, 1), 1.0), + torch.full((n_bad, 1), 0.0), + ] + ) + + partition_output = PartitionOutput( + node_partition_book=torch.zeros(n), + edge_partition_book=torch.zeros(n_edges), + partitioned_edge_index=GraphPartitionData( + edge_index=edge_index, + edge_ids=None, + weights=weights, + ), + partitioned_node_features=FeaturePartitionData( + feats=node_feats, + ids=torch.arange(n), + ), + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + return partition_output, n_hub, n_good, n_bad + + +def _build_heterogeneous_bipartite_weight_graph() -> tuple[ + PartitionOutput, int, int, int +]: + """Build a heterogeneous (user/item) graph with good and bad item nodes. + + Graph structure: + - 10 user nodes (0..9): seed nodes; user feature = 2.0 + - 60 item nodes total: + - Items 0..39: good, reachable from users via weight=1 edges; feature = 1.0 + - Items 40..59: bad, reachable from users via weight=0 edges; feature = 0.0 + - Good items also have weight=1 edges back to all users (for 2nd-hop). + + With weighted sampling only good item nodes should ever appear as sampled + item neighbors. 
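+
+    For example, user 0 connects to good items 0..39 with weight 1 and to bad
+    items 40..59 with weight 0, so every item sampled from user 0 must fall in
+    the good range.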
+ + Returns: + (partition_output, n_user, n_good_item, n_bad_item) + """ + n_user = 10 + n_good_item = 40 + n_bad_item = 20 + n_item = n_good_item + n_bad_item # 60 + + user_ids = torch.arange(n_user) + good_item_ids = torch.arange(n_good_item) + bad_item_ids = torch.arange(n_good_item, n_item) + + # User → Good Item: weight=1 + u2gi_src = user_ids.repeat_interleave(n_good_item) + u2gi_dst = good_item_ids.repeat(n_user) + u2gi_w = torch.ones(n_user * n_good_item) + + # User → Bad Item: weight=0 + u2bi_src = user_ids.repeat_interleave(n_bad_item) + u2bi_dst = bad_item_ids.repeat(n_user) + u2bi_w = torch.zeros(n_user * n_bad_item) + + # Good Item → User: weight=1 (2nd-hop back to users) + gi2u_src = good_item_ids.repeat_interleave(n_user) + gi2u_dst = user_ids.repeat(n_good_item) + gi2u_w = torch.ones(n_good_item * n_user) + + u2i_src = torch.cat([u2gi_src, u2bi_src]) + u2i_dst = torch.cat([u2gi_dst, u2bi_dst]) + u2i_w = torch.cat([u2gi_w, u2bi_w]) + n_u2i_edges = u2i_src.shape[0] + + user_feats = torch.full((n_user, 1), 2.0) + # Item feature encodes type: good=1.0, bad=0.0 + item_feats = torch.cat( + [ + torch.full((n_good_item, 1), 1.0), + torch.full((n_bad_item, 1), 0.0), + ] + ) + + partition_output = PartitionOutput( + node_partition_book={ + _USER: torch.zeros(n_user), + _ITEM: torch.zeros(n_item), + }, + edge_partition_book={ + _USER_TO_ITEM: torch.zeros(n_u2i_edges), + _ITEM_TO_USER: torch.zeros(gi2u_src.shape[0]), + }, + partitioned_edge_index={ + _USER_TO_ITEM: GraphPartitionData( + edge_index=torch.stack([u2i_src, u2i_dst]), + edge_ids=None, + weights=u2i_w, + ), + _ITEM_TO_USER: GraphPartitionData( + edge_index=torch.stack([gi2u_src, gi2u_dst]), + edge_ids=None, + weights=gi2u_w, + ), + }, + partitioned_node_features={ + _USER: FeaturePartitionData(feats=user_feats, ids=torch.arange(n_user)), + _ITEM: FeaturePartitionData(feats=item_feats, ids=torch.arange(n_item)), + }, + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + return partition_output, n_user, n_good_item, n_bad_item + + +def _build_heterogeneous_bipartite_partial_weight_graph() -> tuple[ + PartitionOutput, int, int, int +]: + """Same graph as _build_heterogeneous_bipartite_weight_graph but ITEM_TO_USER is unweighted. + + This validates that partial-weight heterogeneous graphs work correctly: + the weighted U2I edge type still respects weights (bad items are unreachable) + while the unweighted I2U edge type samples uniformly without crashing. 
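+
+    Concretely, the only change from the fully weighted builder is
+    weights=None on the ITEM_TO_USER GraphPartitionData below, which is how an
+    unweighted edge type is represented in PartitionOutput.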
+ """ + n_user = 10 + n_good_item = 40 + n_bad_item = 20 + n_item = n_good_item + n_bad_item + + user_ids = torch.arange(n_user) + good_item_ids = torch.arange(n_good_item) + bad_item_ids = torch.arange(n_good_item, n_item) + + u2gi_src = user_ids.repeat_interleave(n_good_item) + u2gi_dst = good_item_ids.repeat(n_user) + u2gi_w = torch.ones(n_user * n_good_item) + + u2bi_src = user_ids.repeat_interleave(n_bad_item) + u2bi_dst = bad_item_ids.repeat(n_user) + u2bi_w = torch.zeros(n_user * n_bad_item) + + gi2u_src = good_item_ids.repeat_interleave(n_user) + gi2u_dst = user_ids.repeat(n_good_item) + + u2i_src = torch.cat([u2gi_src, u2bi_src]) + u2i_dst = torch.cat([u2gi_dst, u2bi_dst]) + u2i_w = torch.cat([u2gi_w, u2bi_w]) + n_u2i_edges = u2i_src.shape[0] + + user_feats = torch.full((n_user, 1), 2.0) + item_feats = torch.cat( + [ + torch.full((n_good_item, 1), 1.0), + torch.full((n_bad_item, 1), 0.0), + ] + ) + + partition_output = PartitionOutput( + node_partition_book={ + _USER: torch.zeros(n_user), + _ITEM: torch.zeros(n_item), + }, + edge_partition_book={ + _USER_TO_ITEM: torch.zeros(n_u2i_edges), + _ITEM_TO_USER: torch.zeros(gi2u_src.shape[0]), + }, + partitioned_edge_index={ + _USER_TO_ITEM: GraphPartitionData( + edge_index=torch.stack([u2i_src, u2i_dst]), + edge_ids=None, + weights=u2i_w, + ), + _ITEM_TO_USER: GraphPartitionData( + edge_index=torch.stack([gi2u_src, gi2u_dst]), + edge_ids=None, + weights=None, # unweighted — samples uniformly + ), + }, + partitioned_node_features={ + _USER: FeaturePartitionData(feats=user_feats, ids=torch.arange(n_user)), + _ITEM: FeaturePartitionData(feats=item_feats, ids=torch.arange(n_item)), + }, + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + return partition_output, n_user, n_good_item, n_bad_item + + +# --------------------------------------------------------------------------- +# Subprocess functions — must accept local_rank as first arg (mp.spawn) +# --------------------------------------------------------------------------- + + +def _run_weighted_sampling_correctness_homogeneous( + _: int, + dataset: DistDataset, + n_hub: int, +) -> None: + """Subprocess: verifies weight=0 edges are never traversed in homogeneous graph. + + Seeds are the hub nodes only. Node features encode the type: + hub=2.0, good=1.0, bad=0.0. Any subgraph batch containing a bad node + (feature==0.0) means a weight=0 edge was sampled — a test failure. + """ + create_test_process_group() + loader = DistNeighborLoader( + dataset=dataset, + input_nodes=torch.arange(n_hub), + num_neighbors=[10, 5], + with_weight=True, + pin_memory_device=torch.device("cpu"), + ) + count = 0 + for datum in loader: + assert isinstance(datum, Data), f"Expected Data, got {type(datum)}" + assert datum.x is not None, "Node features missing from sampled subgraph" + # Bad nodes have feature 0.0; hub and good nodes have feature > 0. + bad_mask = datum.x[:, 0] == 0.0 + assert not bad_mask.any(), ( + f"weight=0 edge was sampled: bad node(s) found in subgraph. " + f"Features of bad nodes: {datum.x[bad_mask].squeeze().tolist()}" + ) + count += 1 + assert count == n_hub, f"Expected {n_hub} batches (one per hub seed), got {count}" + shutdown_rpc() + + +def _run_weighted_sampling_correctness_heterogeneous( + _: int, + dataset: DistDataset, + n_user: int, +) -> None: + """Subprocess: verifies weight=0 edges are never traversed in heterogeneous graph. + + Seeds are all user nodes. 
Item features encode type: good=1.0, bad=0.0. + Any batch containing a bad item node means a weight=0 edge was sampled. + """ + create_test_process_group() + assert isinstance(dataset.node_ids, Mapping) + loader = DistNeighborLoader( + dataset=dataset, + input_nodes=(_USER, dataset.node_ids[_USER]), + num_neighbors=[10, 5], + with_weight=True, + pin_memory_device=torch.device("cpu"), + ) + count = 0 + for datum in loader: + assert isinstance(datum, HeteroData), f"Expected HeteroData, got {type(datum)}" + if _ITEM in datum.node_types: + item_x = datum[_ITEM].x + assert item_x is not None, "Item features missing from sampled subgraph" + bad_mask = item_x[:, 0] == 0.0 + assert not bad_mask.any(), ( + f"weight=0 edge was sampled: bad item node(s) found. " + f"Features of bad items: {item_x[bad_mask].squeeze().tolist()}" + ) + count += 1 + assert count == n_user, ( + f"Expected {n_user} batches (one per user seed), got {count}" + ) + shutdown_rpc() + + +# --------------------------------------------------------------------------- +# Test classes +# --------------------------------------------------------------------------- + + +class WeightedEdgePartitionerTestCase(TestCase): + """Tests that DistPartitioner correctly partitions registered edge weights.""" + + def setUp(self) -> None: + self._master_ip_address = "localhost" + + def test_homogeneous_weights_partitioned_correctly(self) -> None: + """Edge weights (= src_node_id / 10.0) land on the correct rank after partitioning. + + The mocked graph has edges with source nodes 0–3 on rank 0 and 4–7 on rank 1. + Weights are set to src_node_id / 10.0, mirroring the existing edge-feature + convention. After partitioning by source node each rank should hold only its + own weights, and each weight should equal the corresponding global edge ID * 0.1. + """ + master_port = get_free_port() + manager = Manager() + output_dict: MutableMapping[int, PartitionOutput] = manager.dict() + + rank_to_edge_weights = { + 0: MOCKED_U2U_EDGE_INDEX_ON_RANK_ZERO[0].float() / 10.0, + 1: MOCKED_U2U_EDGE_INDEX_ON_RANK_ONE[0].float() / 10.0, + } + + mp.spawn( + run_distributed_partitioner, + args=( + output_dict, + False, # is_heterogeneous + RANK_TO_MOCKED_GRAPH, + True, # should_assign_edges_by_src_node + self._master_ip_address, + master_port, + InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY, + DistPartitioner, + rank_to_edge_weights, + ), + nprocs=MOCKED_NUM_PARTITIONS, + join=True, + ) + + for rank, partition_output in output_dict.items(): + partitioned_edge_index = partition_output.partitioned_edge_index + self.assertIsInstance(partitioned_edge_index, GraphPartitionData) + assert isinstance(partitioned_edge_index, GraphPartitionData) + + weights = partitioned_edge_index.weights + self.assertIsNotNone( + weights, + msg=f"Rank {rank}: expected weights in GraphPartitionData, got None", + ) + assert weights is not None + + edge_ids = partitioned_edge_index.edge_ids + self.assertIsNotNone( + edge_ids, + msg=f"Rank {rank}: edge_ids must be present when weights are registered", + ) + assert edge_ids is not None + + self.assertEqual( + weights.shape, + edge_ids.shape, + msg=f"Rank {rank}: weights and edge_ids must have the same length", + ) + + # weight for each edge == its global edge_id * 0.1 (i.e. src_node_id / 10.0). + # Sort both so the comparison is order-independent. 
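+            # Worked example (per the comment above): the edge with global
+            # edge_id 5 should carry weight 5 * 0.1 == 0.5, i.e. its
+            # src_node_id 5 divided by 10.0.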
+ expected_weights = edge_ids.float() * 0.1 + torch.testing.assert_close( + weights.sort().values, + expected_weights.sort().values, + msg=f"Rank {rank}: partitioned weights do not match expected src_node_id / 10.0", + ) + + def test_features_only_weights_are_none(self) -> None: + """Features only (no weights): weights must be None, edge_ids and features present.""" + master_port = get_free_port() + manager = Manager() + output_dict: MutableMapping[int, PartitionOutput] = manager.dict() + + mp.spawn( + run_distributed_partitioner, + args=( + output_dict, + False, + RANK_TO_MOCKED_GRAPH, + True, + self._master_ip_address, + master_port, + InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY, + DistPartitioner, + None, + ), + nprocs=MOCKED_NUM_PARTITIONS, + join=True, + ) + + for rank, partition_output in output_dict.items(): + gpd = partition_output.partitioned_edge_index + self.assertIsInstance(gpd, GraphPartitionData) + assert isinstance(gpd, GraphPartitionData) + self.assertIsNone( + gpd.weights, + msg=f"Rank {rank}: weights must be None when not registered", + ) + self.assertIsNotNone( + gpd.edge_ids, + msg=f"Rank {rank}: edge_ids must be present when features are registered", + ) + self.assertIsNotNone( + partition_output.partitioned_edge_features, + msg=f"Rank {rank}: expected features to be partitioned", + ) + + def test_weights_only_no_features_partitioned_correctly(self) -> None: + """Weights without features: feature part is None, weights have correct values.""" + master_port = get_free_port() + manager = Manager() + output_dict: MutableMapping[int, PartitionOutput] = manager.dict() + + rank_to_edge_weights = { + 0: MOCKED_U2U_EDGE_INDEX_ON_RANK_ZERO[0].float() / 10.0, + 1: MOCKED_U2U_EDGE_INDEX_ON_RANK_ONE[0].float() / 10.0, + } + + mp.spawn( + run_distributed_partitioner, + args=( + output_dict, + False, + RANK_TO_MOCKED_GRAPH, + True, + self._master_ip_address, + master_port, + InputDataStrategy.REGISTER_EDGE_WEIGHTS_WITHOUT_EDGE_FEATURES, + DistPartitioner, + rank_to_edge_weights, + ), + nprocs=MOCKED_NUM_PARTITIONS, + join=True, + ) + + for rank, partition_output in output_dict.items(): + gpd = partition_output.partitioned_edge_index + self.assertIsInstance(gpd, GraphPartitionData) + assert isinstance(gpd, GraphPartitionData) + self.assertIsNone( + partition_output.partitioned_edge_features, + msg=f"Rank {rank}: features must be None when not registered", + ) + + weights = gpd.weights + self.assertIsNotNone( + weights, msg=f"Rank {rank}: expected weights to be partitioned" + ) + assert weights is not None + + edge_ids = gpd.edge_ids + self.assertIsNotNone( + edge_ids, + msg=f"Rank {rank}: edge_ids must be present when weights are registered", + ) + assert edge_ids is not None + + self.assertEqual(weights.shape, edge_ids.shape) + expected_weights = edge_ids.float() * 0.1 + torch.testing.assert_close( + weights.sort().values, + expected_weights.sort().values, + msg=f"Rank {rank}: weights do not match src_node_id / 10.0", + ) + + def test_neither_features_nor_weights_gives_none_edge_ids(self) -> None: + """No features, no weights: edge_ids and weights must both be None.""" + master_port = get_free_port() + manager = Manager() + output_dict: MutableMapping[int, PartitionOutput] = manager.dict() + + mp.spawn( + run_distributed_partitioner, + args=( + output_dict, + False, + RANK_TO_MOCKED_GRAPH, + True, + self._master_ip_address, + master_port, + InputDataStrategy.REGISTER_MINIMAL_ENTITIES_SEPARATELY, + DistPartitioner, + ), + nprocs=MOCKED_NUM_PARTITIONS, + join=True, + ) + + for 
rank, partition_output in output_dict.items(): + gpd = partition_output.partitioned_edge_index + self.assertIsInstance(gpd, GraphPartitionData) + assert isinstance(gpd, GraphPartitionData) + self.assertIsNone( + gpd.weights, + msg=f"Rank {rank}: weights must be None when not registered", + ) + self.assertIsNone( + gpd.edge_ids, + msg=f"Rank {rank}: edge_ids must be None when neither features nor weights are registered", + ) + self.assertIsNone( + partition_output.partitioned_edge_features, + msg=f"Rank {rank}: features must be None", + ) + + def test_features_and_weights_produce_consistent_edge_ids(self) -> None: + """Both registered: GraphPartitionData.edge_ids must equal FeaturePartitionData.ids.""" + master_port = get_free_port() + manager = Manager() + output_dict: MutableMapping[int, PartitionOutput] = manager.dict() + + rank_to_edge_weights = { + 0: MOCKED_U2U_EDGE_INDEX_ON_RANK_ZERO[0].float() / 10.0, + 1: MOCKED_U2U_EDGE_INDEX_ON_RANK_ONE[0].float() / 10.0, + } + + mp.spawn( + run_distributed_partitioner, + args=( + output_dict, + False, + RANK_TO_MOCKED_GRAPH, + True, + self._master_ip_address, + master_port, + InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY, + DistPartitioner, + rank_to_edge_weights, + ), + nprocs=MOCKED_NUM_PARTITIONS, + join=True, + ) + + for rank, partition_output in output_dict.items(): + gpd = partition_output.partitioned_edge_index + assert isinstance(gpd, GraphPartitionData) + feat_part = partition_output.partitioned_edge_features + self.assertIsInstance(feat_part, FeaturePartitionData) + assert isinstance(feat_part, FeaturePartitionData) + + self.assertIsNotNone(gpd.edge_ids) + self.assertIsNotNone(gpd.weights) + assert gpd.edge_ids is not None + + torch.testing.assert_close( + gpd.edge_ids, + feat_part.ids, + msg=f"Rank {rank}: GraphPartitionData.edge_ids must equal FeaturePartitionData.ids", + ) + + def test_heterogeneous_partial_weights_by_edge_type(self) -> None: + """Heterogeneous: edge type with weights has them; edge type without weights has None.""" + master_port = get_free_port() + manager = Manager() + output_dict: MutableMapping[int, PartitionOutput] = manager.dict() + + rank_to_edge_weights = { + 0: { + USER_TO_USER_EDGE_TYPE: MOCKED_U2U_EDGE_INDEX_ON_RANK_ZERO[0].float() + / 10.0 + }, + 1: { + USER_TO_USER_EDGE_TYPE: MOCKED_U2U_EDGE_INDEX_ON_RANK_ONE[0].float() + / 10.0 + }, + } + + mp.spawn( + run_distributed_partitioner, + args=( + output_dict, + True, # is_heterogeneous + RANK_TO_MOCKED_GRAPH, + True, + self._master_ip_address, + master_port, + InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY, + DistPartitioner, + rank_to_edge_weights, + ), + nprocs=MOCKED_NUM_PARTITIONS, + join=True, + ) + + for rank, partition_output in output_dict.items(): + partitioned_edge_index = partition_output.partitioned_edge_index + self.assertIsInstance(partitioned_edge_index, dict) + assert isinstance(partitioned_edge_index, dict) + + u2u_gpd = partitioned_edge_index[USER_TO_USER_EDGE_TYPE] + u2i_gpd = partitioned_edge_index[USER_TO_ITEM_EDGE_TYPE] + + self.assertIsNotNone( + u2u_gpd.weights, + msg=f"Rank {rank}: U2U edge type should have weights", + ) + self.assertIsNone( + u2i_gpd.weights, + msg=f"Rank {rank}: U2I edge type should not have weights", + ) + # U2U edge_ids present (has both features and weights). + self.assertIsNotNone( + u2u_gpd.edge_ids, + msg=f"Rank {rank}: U2U edge_ids must be present", + ) + # U2I has neither features nor weights, so edge_ids is None. 
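+            # (Contrast with test_features_and_weights_produce_consistent_edge_ids
+            # above, where registering both features and weights makes edge_ids
+            # line up with FeaturePartitionData.ids.)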
+ self.assertIsNone( + u2i_gpd.edge_ids, + msg=f"Rank {rank}: U2I edge_ids must be None", + ) + + def test_range_partitioner_homogeneous_weights_partitioned_correctly(self) -> None: + """DistRangePartitioner: edge weights land on the correct rank after range-based partitioning. + + Mirrors test_homogeneous_weights_partitioned_correctly but uses DistRangePartitioner. + With range-based partitioning, edge_ids are sequential per-rank (0..3 on rank 0, + 4..7 on rank 1), and the registered weights (src_node_id / 10.0) equal + edge_id * 0.1 for this test graph. + """ + master_port = get_free_port() + manager = Manager() + output_dict: MutableMapping[int, PartitionOutput] = manager.dict() + + rank_to_edge_weights = { + 0: MOCKED_U2U_EDGE_INDEX_ON_RANK_ZERO[0].float() / 10.0, + 1: MOCKED_U2U_EDGE_INDEX_ON_RANK_ONE[0].float() / 10.0, + } + + mp.spawn( + run_distributed_partitioner, + args=( + output_dict, + False, # is_heterogeneous + RANK_TO_MOCKED_GRAPH, + True, # should_assign_edges_by_src_node + self._master_ip_address, + master_port, + InputDataStrategy.REGISTER_ALL_ENTITIES_SEPARATELY, + DistRangePartitioner, + rank_to_edge_weights, + ), + nprocs=MOCKED_NUM_PARTITIONS, + join=True, + ) + + for rank, partition_output in output_dict.items(): + partitioned_edge_index = partition_output.partitioned_edge_index + self.assertIsInstance(partitioned_edge_index, GraphPartitionData) + assert isinstance(partitioned_edge_index, GraphPartitionData) + + weights = partitioned_edge_index.weights + self.assertIsNotNone( + weights, + msg=f"Rank {rank}: expected weights in GraphPartitionData, got None", + ) + assert weights is not None + + edge_ids = partitioned_edge_index.edge_ids + self.assertIsNotNone( + edge_ids, + msg=f"Rank {rank}: edge_ids must be present when features are registered", + ) + assert edge_ids is not None + + self.assertEqual( + weights.shape, + edge_ids.shape, + msg=f"Rank {rank}: weights and edge_ids must have the same length", + ) + + expected_weights = edge_ids.float() * 0.1 + torch.testing.assert_close( + weights.sort().values, + expected_weights.sort().values, + msg=f"Rank {rank}: partitioned weights do not match expected src_node_id / 10.0", + ) + + +class DistributedWeightedSamplingTest(TestCase): + """End-to-end correctness tests for DistNeighborLoader with with_weight=True. + + Each test builds a graph with two classes of neighbors: + - "good" neighbors connected by weight=1 edges + - "bad" neighbors connected by weight=0 edges + + Node features encode the class (good=1.0, bad=0.0). After weighted sampling, + no bad node should ever appear in a sampled subgraph — if it does, a weight=0 + edge was traversed, indicating a bug. + + Graph sizes are chosen so that the fanout is strictly smaller than the number + of available good neighbors, ensuring the sampler actively makes choices + (rather than returning all neighbors) and the test is non-trivial. + """ + + def setUp(self) -> None: + super().setUp() + self._world_size = 1 + + def tearDown(self) -> None: + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + super().tearDown() + + def test_weighted_sampling_never_traverses_zero_weight_edges_homogeneous( + self, + ) -> None: + """Homogeneous: weight=0 edges to bad nodes are never traversed. + + Graph: 10 hub seeds, each with 50 good neighbors (weight=1) and 40 bad + neighbors (weight=0). Good nodes have 5 further weight=1 edges for 2nd-hop + sampling. 
Fanout [10, 5] samples fewer neighbors than available good ones, + so the weighted sampler actively selects from the pool each hop. + """ + partition_output, n_hub, _, _ = _build_homogeneous_bipartite_weight_graph() + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + + mp.spawn( + fn=_run_weighted_sampling_correctness_homogeneous, + args=(dataset, n_hub), + ) + + def test_weighted_sampling_never_traverses_zero_weight_edges_heterogeneous( + self, + ) -> None: + """Heterogeneous: weight=0 user→item edges to bad items are never traversed. + + Graph: 10 user seeds, each with 40 good items (weight=1) and 20 bad items + (weight=0). Good items have weight=1 back-edges to all users for 2nd-hop. + Fanout [10, 5] is smaller than the 40 available good items, so the sampler + actively selects. + """ + partition_output, n_user, _, _ = _build_heterogeneous_bipartite_weight_graph() + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + + mp.spawn( + fn=_run_weighted_sampling_correctness_heterogeneous, + args=(dataset, n_user), + ) + + def test_weighted_sampling_partial_weights_heterogeneous(self) -> None: + """Partial weights: weighted U2I respects weights; unweighted I2U samples uniformly. + + U2I is weighted (good items weight=1, bad items weight=0) — bad items must + never appear. I2U has no weights registered, so it uses uniform sampling. + Verifies that mixing weighted and unweighted edge types in one heterogeneous + graph does not crash and that weighted edges still behave correctly. + """ + partition_output, n_user, _, _ = ( + _build_heterogeneous_bipartite_partial_weight_graph() + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + + mp.spawn( + fn=_run_weighted_sampling_correctness_heterogeneous, + args=(dataset, n_user), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/unit/types_tests/graph_test.py b/tests/unit/types_tests/graph_test.py index fa68667f0..0c1e701f3 100644 --- a/tests/unit/types_tests/graph_test.py +++ b/tests/unit/types_tests/graph_test.py @@ -306,6 +306,37 @@ def test_treat_supervision_edges_as_graph_edges( graph_tensors.edge_index[edge_type], expected_tensor ) + def test_treat_labels_as_edges_converts_homogeneous_edge_weights(self) -> None: + """treat_labels_as_edges() must convert homogeneous edge_weights to heterogeneous form. + + Regression: edge_weights was previously left as a raw tensor while all other fields + were converted to heterogeneous dicts. DistPartitioner rejects that mixed-mode input, + so weighted homogeneous ABLP dataset construction would fail. 
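+
+        After the call, edge_weights should be keyed by edge type, e.g.
+        {DEFAULT_HOMOGENEOUS_EDGE_TYPE: tensor([0.5, 1.0])}, which is exactly
+        what the assertions below check.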
+ """ + weights = torch.tensor([0.5, 1.0]) + graph_tensors = LoadedGraphTensors( + node_ids=torch.tensor([0, 1, 2]), + node_features=torch.tensor([[1.0], [2.0], [3.0]]), + node_labels=None, + edge_index=torch.tensor([[0, 1], [1, 2]]), + edge_features=None, + positive_label=torch.tensor([[0], [2]]), + negative_label=None, + edge_weights=weights, + ) + graph_tensors.treat_labels_as_edges(edge_dir="out") + + self.assertIsInstance( + graph_tensors.edge_weights, + dict, + "edge_weights must be a heterogeneous dict after treat_labels_as_edges()", + ) + assert isinstance(graph_tensors.edge_weights, dict) + self.assertIn(DEFAULT_HOMOGENEOUS_EDGE_TYPE, graph_tensors.edge_weights) + torch.testing.assert_close( + graph_tensors.edge_weights[DEFAULT_HOMOGENEOUS_EDGE_TYPE], weights + ) + def test_select_label_edge_types(self): message_passing_edge_type = DEFAULT_HOMOGENEOUS_EDGE_TYPE edge_types = [