145 changes: 140 additions & 5 deletions gigl/common/data/load_torch_tensors.py

@@ -27,7 +27,63 @@
_ID_FMT = "{entity}_ids"
_FEATURE_FMT = "{entity}_features"
_LABEL_FMT = "{entity}_labels"
_EDGE_WEIGHTS_KEY = "edge_weights"
_NODE_KEY = "node"


def _extract_weight_col(
    feat_tensor: torch.Tensor,
    feature_keys: list[str],
    feature_spec: dict,
    col_name: str,
    edge_type: EdgeType,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Slice a named weight column out of a feature tensor.

    Accounts for multi-dim features: each feature key may contribute more than one column
    to ``feat_tensor`` (e.g. ``FixedLenFeature(shape=[16])`` contributes 16 columns).
    The weight feature must be a scalar (width 1).

    Args:
        feat_tensor: Edge feature tensor of shape ``[num_edges, total_feature_cols]``.
        feature_keys: Ordered list of feature names matching the columns of ``feat_tensor``.
        feature_spec: Feature spec dict mapping feature name to its TF feature spec (used to
            determine per-key column widths).
        col_name: Name of the column to extract as weights.
        edge_type: Edge type (used only in error messages).

    Returns:
        A tuple ``(weights, trimmed_features)`` where ``weights`` is a 1-D tensor of shape
        ``[num_edges]`` and ``trimmed_features`` is ``feat_tensor`` with the weight column
        removed, or ``None`` if the weight column was the only feature column.

    Raises:
        ValueError: If ``col_name`` is not in ``feature_keys`` or the weight feature is not
            width 1.
    """
    if col_name not in feature_keys:
        raise ValueError(
            f"weight_edge_feat_name '{col_name}' not found in edge feature keys "
            f"for edge type {edge_type}: {feature_keys}"
        )
    key_idx = feature_keys.index(col_name)
    col_widths = [
        (spec.shape[-1] if spec.shape else 1)
        for spec in (feature_spec[k] for k in feature_keys)
    ]
    weight_width = col_widths[key_idx]
    if weight_width != 1:
        raise ValueError(
            f"weight_edge_feat_name '{col_name}' for edge type {edge_type} must be a scalar "
            f"feature (width 1), but has width {weight_width}."
        )
    col_offset = sum(col_widths[:key_idx])
    weights = feat_tensor[:, col_offset]
    keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_offset]
    trimmed = feat_tensor[:, keep_cols] if keep_cols else None
    return weights, trimmed


_EDGE_KEY = "edge"
_POSITIVE_LABEL_KEY = "positive_label"
_NEGATIVE_LABEL_KEY = "negative_label"
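To make the column arithmetic in `_extract_weight_col` concrete, here is a minimal sketch using stand-in spec objects (anything exposing a `.shape` attribute works for the width calculation; real specs are TF feature specs). The feature names, shapes, and the string passed as the edge type are illustrative, not from the PR:

```python
from types import SimpleNamespace

import torch

feature_keys = ["embedding", "weight", "degree"]
feature_spec = {
    "embedding": SimpleNamespace(shape=[4]),  # contributes 4 columns
    "weight": SimpleNamespace(shape=[]),  # scalar -> 1 column
    "degree": SimpleNamespace(shape=[]),  # scalar -> 1 column
}
# Two edges, six total feature columns: 4 (embedding) + 1 (weight) + 1 (degree).
feat_tensor = torch.arange(12.0).reshape(2, 6)

# "weight" starts at column offset 4, after the four embedding columns.
weights, trimmed = _extract_weight_col(
    feat_tensor=feat_tensor,
    feature_keys=feature_keys,
    feature_spec=feature_spec,
    col_name="weight",
    edge_type="user-to-item",  # only used in error messages, so a str suffices here
)
assert weights.tolist() == [4.0, 10.0]
assert trimmed.shape == (2, 5)  # the weight column is gone
```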
@@ -72,6 +128,7 @@ def _data_loading_process(
    ],
    rank: int,
    tf_dataset_options: TFDatasetOptions = TFDatasetOptions(),
    weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None,
Collaborator: should we update the doc string for the new arg?

Collaborator: Also it seems a little specific to put the weight_edge_feat_name here when _data_loading_process is kind of a generic function? Or do you think that's fine for now? I'm not sure how else we'd address this.

) -> None:
    """
    Spawned multiprocessing.Process which loads homogeneous or heterogeneous information for a specific entity type [node, edge, positive_label, negative_label]
@@ -89,6 +146,11 @@
            Serialized information for current entity
        rank (int): Rank of the current machine
        tf_dataset_options (TFDatasetOptions): The options to use when building the dataset.
        weight_edge_feat_name (Optional[Union[str, dict[EdgeType, str]]]): Only used when
            ``entity_type == _EDGE_KEY``. Name of the edge feature column to extract as
            sampling weights. Ignored for node, positive_label, and negative_label entities.
            Supply a single string for homogeneous graphs or a per-edge-type dict for
            heterogeneous graphs.
    """
    # We add a try-except clause here to ensure that exceptions are properly circulated back to the parent process
    try:
@@ -117,6 +179,7 @@
        ids: dict[Union[NodeType, EdgeType], torch.Tensor] = {}
        features: dict[Union[NodeType, EdgeType], torch.Tensor] = {}
        labels: dict[Union[NodeType, EdgeType], torch.Tensor] = {}
        weights: dict[Union[NodeType, EdgeType], torch.Tensor] = {}
        for (
            graph_type,
            serialized_entity_tf_record_info,
@@ -129,14 +192,13 @@
                raise NotImplementedError(
                    "Label keys are not supported for edge entities"
                )
-            (
-                entity_ids,
-                entity_features,
-                entity_labels,
-            ) = tf_record_dataloader.load_as_torch_tensors(
+            loaded_entity = tf_record_dataloader.load_as_torch_tensors(
                serialized_tf_record_info=serialized_entity_tf_record_info,
                tf_dataset_options=tf_dataset_options,
            )
+            entity_ids = loaded_entity.ids
+            entity_features = loaded_entity.features
+            entity_labels = loaded_entity.labels
            ids[graph_type] = entity_ids
            logger.info(
                f"Rank {rank} finished loading {entity_type} ids of shape {entity_ids.shape} for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}"
@@ -161,6 +223,61 @@
f"Rank {rank} did not detect {entity_type} labels for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}"
)

        # Extract weight column from edge features when weight_edge_feat_name is set.
        # The weight column is sliced out of each edge type's feature tensor and stored
        # separately so it is not duplicated in the feature matrix.
        if weight_edge_feat_name is not None and entity_type == _EDGE_KEY:
            if isinstance(weight_edge_feat_name, str):
                if len(serialized_tf_record_info) != 1 or len(features) != 1:
                    raise ValueError(
                        f"weight_edge_feat_name must be a dict[EdgeType, str] for heterogeneous "
                        f"graphs with multiple edge types ({sorted(serialized_tf_record_info)}). "
                        "Provide an explicit per-edge-type mapping instead of a single string."
                    )
                col_name = weight_edge_feat_name
                edge_type, feat_tensor = next(iter(features.items()))
                assert isinstance(edge_type, EdgeType)
                feature_keys = list(serialized_tf_record_info[edge_type].feature_keys)
                weights[edge_type], trimmed = _extract_weight_col(
                    feat_tensor,
                    feature_keys,
                    serialized_tf_record_info[edge_type].feature_spec,
                    col_name,
                    edge_type,
                )
                if trimmed is not None:
                    features[edge_type] = trimmed
                else:
                    del features[edge_type]
                logger.info(
                    f"Rank {rank} extracted weight column '{col_name}' "
                    f"from {entity_type} features for type {edge_type}"
                )
            else:
                # Iterate the EdgeType-keyed dict directly to stay within EdgeType.
                for edge_type, col_name in weight_edge_feat_name.items():
                    if edge_type not in features:
                        continue
                    feat_tensor = features[edge_type]
                    feature_keys = list(
                        serialized_tf_record_info[edge_type].feature_keys
                    )
                    weights[edge_type], trimmed = _extract_weight_col(
                        feat_tensor,
                        feature_keys,
                        serialized_tf_record_info[edge_type].feature_spec,
                        col_name,
                        edge_type,
                    )
                    if trimmed is not None:
                        features[edge_type] = trimmed
                    else:
                        del features[edge_type]
                    logger.info(
                        f"Rank {rank} extracted weight column '{col_name}' "
                        f"from {entity_type} features for type {edge_type}"
                    )

        logger.info(
            f"Rank {rank} is attempting to share {entity_type} id memory for tfrecord directories: {all_tf_record_uris}"
        )
@@ -180,6 +297,12 @@
            )
            share_memory(labels)

        if weights:
            logger.info(
                f"Rank {rank} is attempting to share {entity_type} weight memory for tfrecord directories: {all_tf_record_uris}"
            )
            share_memory(weights)

        output_dict[_ID_FMT.format(entity=entity_type)] = (
            list(ids.values())[0] if is_input_homogeneous else ids
        )
@@ -191,6 +314,10 @@
        output_dict[_LABEL_FMT.format(entity=entity_type)] = (
            list(labels.values())[0] if is_input_homogeneous else labels
        )
        if weights:
            output_dict[_EDGE_WEIGHTS_KEY] = (
                list(weights.values())[0] if is_input_homogeneous else weights
            )

        logger.info(
            f"Rank {rank} has finished loading {entity_type} data from tfrecord directories: {all_tf_record_uris}, elapsed time: {time.time() - start_time:.2f} seconds"
@@ -207,6 +334,7 @@ def load_torch_tensors_from_tf_record(
    rank: int = 0,
    node_tf_dataset_options: TFDatasetOptions = TFDatasetOptions(),
    edge_tf_dataset_options: TFDatasetOptions = TFDatasetOptions(),
    weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None,
) -> LoadedGraphTensors:
    """
    Loads all torch tensors from a SerializedGraphMetadata object for all entity [node, edge, positive_label, negative_label] and edge / node types.
@@ -222,6 +350,10 @@
        rank (int): Rank on current machine
        node_tf_dataset_options (TFDatasetOptions): The options to use for nodes when building the dataset.
        edge_tf_dataset_options (TFDatasetOptions): The options to use for edges when building the dataset.
        weight_edge_feat_name (Optional[Union[str, dict[EdgeType, str]]]): Name of the edge feature column to extract
            as sampling weights. The column is removed from the edge feature matrix and returned separately via
            ``LoadedGraphTensors.edge_weights``. Supply a single string for homogeneous graphs or a per-edge-type
            dict for heterogeneous graphs.

    Returns:
        loaded_graph_tensors (LoadedGraphTensors): Unpartitioned Graph Tensors
    """
@@ -269,6 +401,7 @@ def load_torch_tensors_from_tf_record(
"serialized_tf_record_info": serialized_graph_metadata.edge_entity_info,
"rank": rank,
"tf_dataset_options": edge_tf_dataset_options,
"weight_edge_feat_name": weight_edge_feat_name,
},
)

@@ -351,6 +484,7 @@ def load_torch_tensors_from_tf_record(

    edge_index = edge_output_dict[_ID_FMT.format(entity=_EDGE_KEY)]
    edge_features = edge_output_dict.get(_FEATURE_FMT.format(entity=_EDGE_KEY), None)
    edge_weights = edge_output_dict.get(_EDGE_WEIGHTS_KEY, None)

    positive_labels = edge_output_dict.get(
        _ID_FMT.format(entity=_POSITIVE_LABEL_KEY), None
@@ -378,4 +512,5 @@
        edge_features=edge_features,
        positive_label=positive_labels,
        negative_label=negative_labels,
        edge_weights=edge_weights,
    )
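For reference, a sketch of how a caller might request weight extraction end to end. It assumes a `SerializedGraphMetadata` object has already been built upstream, and the column name `"edge_weight"` is illustrative:

```python
# Homogeneous graph: a single column name applies to the only edge type.
loaded = load_torch_tensors_from_tf_record(
    serialized_graph_metadata=serialized_graph_metadata,  # assumed prepared upstream
    rank=0,
    weight_edge_feat_name="edge_weight",
)
# The weight column is removed from edge_features and surfaced separately:
# a 1-D tensor for homogeneous graphs, a dict[EdgeType, Tensor] otherwise.
assert loaded.edge_weights is not None
```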
42 changes: 41 additions & 1 deletion gigl/distributed/base_dist_loader.py

@@ -333,13 +333,50 @@ def __init__(
"for graph-store mode."
)

    @staticmethod
    def validate_with_weight(
        with_weight: bool,
        dataset: Union[DistDataset, RemoteDistDataset],
        sampler_options: SamplerOptions,
    ) -> None:
        """Validates the ``with_weight`` parameter against the dataset and sampler.

        Args:
            with_weight: Whether weighted sampling was requested.
            dataset: The dataset being sampled from.
            sampler_options: The sampler options to be used.

        Raises:
            ValueError: If ``with_weight=True`` but no edge weights are registered.
            NotImplementedError: If ``with_weight=True`` and a PPR sampler is requested.
        """
        if not with_weight:
            return
        has_edge_weights = (
            dataset.has_edge_weights
            if isinstance(dataset, DistDataset)
            else dataset.fetch_edge_weights_registered()
        )
        if not has_edge_weights:
            raise ValueError(
                "with_weight=True requires edge weights to be registered in the dataset. "
                "Pass weight_edge_feat_name to build_dataset() to register edge weights."
            )
        # TODO(mkolodner-sc): Implement weight-proportional residual propagation for PPR.
        # with_weight is already known to be True here, so checking the sampler suffices.
        if isinstance(sampler_options, PPRSamplerOptions):
            raise NotImplementedError(
                "Weighted sampling is not yet supported with PPRSamplerOptions. "
                "Weight-proportional residual propagation for PPR is planned but not implemented."
            )

    @staticmethod
    def create_sampling_config(
        num_neighbors: Union[list[int], dict[EdgeType, list[int]]],
        dataset_schema: DatasetSchema,
        batch_size: int = 1,
        shuffle: bool = False,
        drop_last: bool = False,
        with_weight: bool = False,
    ) -> SamplingConfig:
        """Creates a SamplingConfig with patched fanout.

@@ -352,6 +389,9 @@
            batch_size: How many samples per batch.
            shuffle: Whether to shuffle input nodes.
            drop_last: Whether to drop the last incomplete batch.
            with_weight: Whether to use edge weights for sampling. Requires that
                edge weights were registered during dataset construction via
                ``DistPartitioner.register_edge_weights()``.

        Returns:
            A fully configured SamplingConfig.
@@ -369,7 +409,7 @@
            with_edge=True,
            collect_features=True,
            with_neg=False,
-            with_weight=False,
+            with_weight=with_weight,
            edge_dir=dataset_schema.edge_dir,
            seed=None,
        )
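A sketch of how a loader built on these helpers might wire weighted sampling together. `DistLoaderBase` stands in for whatever class hosts these staticmethods (the class name is not visible in this diff), and `dataset`, `schema`, `sampler_options`, and the fanout/batch values are assumed:

```python
# Fail fast if weights were never registered or the sampler cannot use them.
DistLoaderBase.validate_with_weight(
    with_weight=True,
    dataset=dataset,
    sampler_options=sampler_options,
)
# Then build a config whose neighbor sampling is weight-proportional.
config = DistLoaderBase.create_sampling_config(
    num_neighbors=[10, 5],  # illustrative two-hop fanout
    dataset_schema=schema,
    batch_size=512,
    shuffle=True,
    with_weight=True,
)
```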
20 changes: 20 additions & 0 deletions gigl/distributed/dataset_factory.py

@@ -66,6 +66,7 @@ def _load_and_build_partitioned_dataset(
    edge_tf_dataset_options: TFDatasetOptions,
    splitter: Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]] = None,
    _ssl_positive_label_percentage: Optional[float] = None,
    weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None,
) -> DistDataset:
    """
    Given some information about serialized TFRecords, loads and builds a partitioned dataset into a DistDataset class.
@@ -82,6 +83,11 @@
        splitter (Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]]): Optional splitter to use for splitting the graph data into train, val, and test sets. If not provided (None), no splitting will be performed.
        _ssl_positive_label_percentage (Optional[float]): Percentage of edges to select as self-supervised labels. Must be None if supervised edge labels are provided in advance.
            Slotted for refactor once this functionality is available in the transductive `splitter` directly
        weight_edge_feat_name (Optional[Union[str, dict[EdgeType, str]]]): Name of the edge feature column to use as
            sampling weights. The column is extracted from the feature tensor and registered separately via
            ``DistPartitioner.register_edge_weights()``; it is removed from the feature matrix to avoid duplication.
            Supply a single string to use the same column name for all edge types, or a per-edge-type dict.

    Returns:
        DistDataset: Initialized dataset with partitioned graph information

@@ -104,6 +110,7 @@
        rank=rank,
        node_tf_dataset_options=node_tf_dataset_options,
        edge_tf_dataset_options=edge_tf_dataset_options,
        weight_edge_feat_name=weight_edge_feat_name,
    )

    # TODO (mkolodner-sc): Move this code block (from here up to start of partitioning) to transductive splitter once that is ready
@@ -175,6 +182,10 @@
        )
    if loaded_graph_tensors.node_labels is not None:
        partitioner.register_node_labels(node_labels=loaded_graph_tensors.node_labels)
    if loaded_graph_tensors.edge_weights is not None:
        partitioner.register_edge_weights(
            edge_weights=loaded_graph_tensors.edge_weights
        )
    if loaded_graph_tensors.edge_features is not None:
        partitioner.register_edge_features(
            edge_features=loaded_graph_tensors.edge_features
@@ -196,6 +207,7 @@
        loaded_graph_tensors.node_features,
        loaded_graph_tensors.edge_index,
        loaded_graph_tensors.edge_features,
        loaded_graph_tensors.edge_weights,
        loaded_graph_tensors.positive_label,
        loaded_graph_tensors.negative_label,
        loaded_graph_tensors.node_labels,
@@ -230,6 +242,7 @@ def _build_dataset_process(
    edge_tf_dataset_options: TFDatasetOptions,
    splitter: Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]] = None,
    _ssl_positive_label_percentage: Optional[float] = None,
    weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None,
) -> None:
    """
    This function is spawned by a single process per machine and is responsible for:
@@ -311,6 +324,7 @@
        edge_tf_dataset_options=edge_tf_dataset_options,
        splitter=splitter,
        _ssl_positive_label_percentage=_ssl_positive_label_percentage,
        weight_edge_feat_name=weight_edge_feat_name,
    )

    output_dict["dataset"] = output_dataset
@@ -336,6 +350,7 @@ def build_dataset(
    _dataset_building_port: Optional[
        int
    ] = None,  # WARNING: This field will be deprecated in the future
    weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None,
) -> DistDataset:
    """
    Launches a spawned process for building and returning a DistDataset instance provided some
@@ -368,6 +383,10 @@
            Slotted for refactor once this functionality is available in the transductive `splitter` directly
        _dataset_building_port (deprecated field - will be removed soon) (Optional[int]): Contains information about master port. Defaults to None, in which case it will
            be initialized from the current torch.distributed context.
        weight_edge_feat_name (Optional[Union[str, dict[EdgeType, str]]]): Name of the edge feature column to use
            as sampling weights. The column is extracted from the feature tensor and registered separately; it is
            removed from the feature matrix to avoid memory duplication. Supply a single string to apply to all
            edge types, or a per-edge-type dict. (default: ``None``)

    Returns:
        DistDataset: Built GraphLearn-for-PyTorch Dataset class
@@ -463,6 +482,7 @@ def build_dataset(
            edge_tf_dataset_options,
            splitter,
            _ssl_positive_label_percentage,
            weight_edge_feat_name,
        ),
    )

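Putting the pieces together, a sketch of a weighted heterogeneous dataset build. Everything other than `weight_edge_feat_name` here (the metadata and context argument names, and the `EdgeType` constructor form) is an assumption based on the surrounding code rather than something shown in this diff:

```python
# Heterogeneous build: map each weighted edge type to its weight column.
dataset = build_dataset(
    serialized_graph_metadata=serialized_graph_metadata,  # assumed prepared upstream
    distributed_context=distributed_context,  # assumed; name not shown in this diff
    weight_edge_feat_name={
        EdgeType("user", "to", "item"): "interaction_weight",  # illustrative
    },
)
# Downstream, DistDataset.has_edge_weights gates with_weight=True in loaders
# via validate_with_weight() in base_dist_loader.py.
assert dataset.has_edge_weights
```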