Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gigl/common/utils/local_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def list_at_path(
Args:
local_path (LocalUri): The local path to search for files and directories.
regex (Optional[str]): Optional regex to match. If not provided then all children will be returned.
entity (Optional[FileSystemEntity]): Optional entity type to filter by. If not provided then all children will be returned.
file_system_entity (Optional[FileSystemEntity]): Optional entity type to filter by. If not provided then all children will be returned.
names_only (bool): If True, return only the base names of the files and directories. Defaults to False. e.g /path/to/file.txt -> file.txt
Returns:
Expand Down
9 changes: 5 additions & 4 deletions gigl/src/common/types/pb_wrappers/preprocessed_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,10 @@ def __get_feature_to_vocab_list_map(
) -> FeatureVocabDict:
if isinstance(transform_fn_assets_uri, LocalUri):
list_files_fn = partial(
LocalFsUtils.list_at_path, entity=LocalFsUtils.FileSystemEntity.FILE
) # type: ignore
read_file_fn = lambda path: open(path, "rb") # type: ignore
LocalFsUtils.list_at_path,
file_system_entity=LocalFsUtils.FileSystemEntity.FILE,
)
read_file_fn = lambda path: open(path, "rb")
elif isinstance(transform_fn_assets_uri, GcsUri):
gcs_utils = GcsUtils()
list_files_fn = gcs_utils.list_uris_with_gcs_path_pattern # type: ignore
Expand All @@ -294,7 +295,7 @@ def __get_feature_to_vocab_list_map(
f"Invalid uri: {transform_fn_assets_uri}. Must be either {GcsUri.__name__} or {LocalUri.__name__}"
)

assets_file_paths = list_files_fn(transform_fn_assets_uri) # type: ignore
assets_file_paths = list_files_fn(transform_fn_assets_uri)
feature_to_vocab_list_map = {}
for asset_file_path in assets_file_paths:
feature_key = asset_file_path.uri.split("/")[-1]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import os
import tempfile

from absl.testing import absltest

from gigl.src.common.constants.graph_metadata import DEFAULT_CONDENSED_NODE_TYPE
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.pb_wrappers.preprocessed_metadata import (
PreprocessedMetadataPbWrapper,
)
from gigl.src.mocking.lib.versioning import get_mocked_dataset_artifact_metadata
from gigl.src.mocking.mocking_assets.mocked_datasets_for_pipeline_tests import (
CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO,
)
from snapchat.research.gbml import preprocessed_metadata_pb2
from tests.test_assets.test_case import TestCase


Expand Down Expand Up @@ -55,6 +62,37 @@ def test_feature_schema_keys_match_original_keys(self):
self.assertEqual(feature_index_keys, original_feature_keys)
self.assertEqual(feature_schema_keys, original_feature_keys)

def test_local_uri_transform_fn_assets_branch(self):
"""Exercise the LocalUri branch of __get_feature_to_vocab_list_map.

Locks in the kwarg name forwarded to LocalFsUtils.list_at_path; a
mismatch (e.g. entity= vs file_system_entity=) raises TypeError only
when this branch actually runs, which no production caller does.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
# An empty file is a valid (empty) TFDV Schema textproto.
schema_path = os.path.join(tmp_dir, "schema.pbtxt")
with open(schema_path, "w"):
pass
transform_fn_assets_dir = os.path.join(tmp_dir, "transform_fn_assets")
os.makedirs(transform_fn_assets_dir)

preprocessed_metadata_pb = preprocessed_metadata_pb2.PreprocessedMetadata()
node_metadata = (
preprocessed_metadata_pb.condensed_node_type_to_preprocessed_metadata[0]
)
node_metadata.schema_uri = schema_path
node_metadata.transform_fn_assets_uri = transform_fn_assets_dir

wrapper = PreprocessedMetadataPbWrapper(
preprocessed_metadata_pb=preprocessed_metadata_pb
)

feature_schema = wrapper.condensed_node_type_to_feature_schema_map[
DEFAULT_CONDENSED_NODE_TYPE
]
self.assertEqual(feature_schema.feature_vocab, {})


if __name__ == "__main__":
absltest.main()