diff --git a/gigl/common/utils/local_fs.py b/gigl/common/utils/local_fs.py index 56de1c7a0..d0630aa51 100644 --- a/gigl/common/utils/local_fs.py +++ b/gigl/common/utils/local_fs.py @@ -91,7 +91,7 @@ def list_at_path( Args: local_path (LocalUri): The local path to search for files and directories. regex (Optional[str]): Optional regex to match. If not provided then all children will be returned. - entity (Optional[FileSystemEntity]): Optional entity type to filter by. If not provided then all children will be returned. + file_system_entity (Optional[FileSystemEntity]): Optional entity type to filter by. If not provided then all children will be returned. names_only (bool): If True, return only the base names of the files and directories. Defaults to False. e.g /path/to/file.txt -> file.txt Returns: diff --git a/gigl/src/common/types/pb_wrappers/preprocessed_metadata.py b/gigl/src/common/types/pb_wrappers/preprocessed_metadata.py index 5f841c77a..6dfb966b0 100644 --- a/gigl/src/common/types/pb_wrappers/preprocessed_metadata.py +++ b/gigl/src/common/types/pb_wrappers/preprocessed_metadata.py @@ -282,9 +282,10 @@ def __get_feature_to_vocab_list_map( ) -> FeatureVocabDict: if isinstance(transform_fn_assets_uri, LocalUri): list_files_fn = partial( - LocalFsUtils.list_at_path, entity=LocalFsUtils.FileSystemEntity.FILE - ) # type: ignore - read_file_fn = lambda path: open(path, "rb") # type: ignore + LocalFsUtils.list_at_path, + file_system_entity=LocalFsUtils.FileSystemEntity.FILE, + ) + read_file_fn = lambda path: open(path, "rb") elif isinstance(transform_fn_assets_uri, GcsUri): gcs_utils = GcsUtils() list_files_fn = gcs_utils.list_uris_with_gcs_path_pattern # type: ignore @@ -294,7 +295,7 @@ def __get_feature_to_vocab_list_map( f"Invalid uri: {transform_fn_assets_uri}. Must be either {GcsUri.__name__} or {LocalUri.__name__}" ) - assets_file_paths = list_files_fn(transform_fn_assets_uri) # type: ignore + assets_file_paths = list_files_fn(transform_fn_assets_uri) feature_to_vocab_list_map = {} for asset_file_path in assets_file_paths: feature_key = asset_file_path.uri.split("/")[-1] diff --git a/tests/unit/src/common/types/pb_wrappers/preprocessed_metadata_test.py b/tests/unit/src/common/types/pb_wrappers/preprocessed_metadata_test.py index 30570440c..9f2f6cc5d 100644 --- a/tests/unit/src/common/types/pb_wrappers/preprocessed_metadata_test.py +++ b/tests/unit/src/common/types/pb_wrappers/preprocessed_metadata_test.py @@ -1,11 +1,18 @@ +import os +import tempfile + from absl.testing import absltest from gigl.src.common.constants.graph_metadata import DEFAULT_CONDENSED_NODE_TYPE from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.types.pb_wrappers.preprocessed_metadata import ( + PreprocessedMetadataPbWrapper, +) from gigl.src.mocking.lib.versioning import get_mocked_dataset_artifact_metadata from gigl.src.mocking.mocking_assets.mocked_datasets_for_pipeline_tests import ( CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO, ) +from snapchat.research.gbml import preprocessed_metadata_pb2 from tests.test_assets.test_case import TestCase @@ -55,6 +62,37 @@ def test_feature_schema_keys_match_original_keys(self): self.assertEqual(feature_index_keys, original_feature_keys) self.assertEqual(feature_schema_keys, original_feature_keys) + def test_local_uri_transform_fn_assets_branch(self): + """Exercise the LocalUri branch of __get_feature_to_vocab_list_map. + + Locks in the kwarg name forwarded to LocalFsUtils.list_at_path; a + mismatch (e.g. entity= vs file_system_entity=) raises TypeError only + when this branch actually runs, which no production caller does. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + # An empty file is a valid (empty) TFDV Schema textproto. + schema_path = os.path.join(tmp_dir, "schema.pbtxt") + with open(schema_path, "w"): + pass + transform_fn_assets_dir = os.path.join(tmp_dir, "transform_fn_assets") + os.makedirs(transform_fn_assets_dir) + + preprocessed_metadata_pb = preprocessed_metadata_pb2.PreprocessedMetadata() + node_metadata = ( + preprocessed_metadata_pb.condensed_node_type_to_preprocessed_metadata[0] + ) + node_metadata.schema_uri = schema_path + node_metadata.transform_fn_assets_uri = transform_fn_assets_dir + + wrapper = PreprocessedMetadataPbWrapper( + preprocessed_metadata_pb=preprocessed_metadata_pb + ) + + feature_schema = wrapper.condensed_node_type_to_feature_schema_map[ + DEFAULT_CONDENSED_NODE_TYPE + ] + self.assertEqual(feature_schema.feature_vocab, {}) + if __name__ == "__main__": absltest.main()