From bc2ee10ba85634831d5a9bf1ed839f317bfba5f8 Mon Sep 17 00:00:00 2001
From: kamalahasiniburra
Date: Mon, 30 Mar 2026 21:45:20 +0530
Subject: [PATCH] Add robust dataset loader for inconsistent .npy formats
 (fixes #178)

---
 dataset_loader.py      | 335 +++++++++++++++++++++++++++++++++++++++++
 test_dataset_loader.py | 181 ++++++++++++++++++++++
 2 files changed, 516 insertions(+)
 create mode 100644 dataset_loader.py
 create mode 100644 test_dataset_loader.py

diff --git a/dataset_loader.py b/dataset_loader.py
new file mode 100644
index 00000000..f6312751
--- /dev/null
+++ b/dataset_loader.py
@@ -0,0 +1,335 @@
+"""
+Robust Dataset Loader for DeepLense.
+
+Addresses Issue #178: Add robust dataset loader utility for inconsistent
+.npy formats. This module provides a flexible dataset loader that handles
+different .npy file formats, shapes, and dtypes commonly found across the
+various DeepLense sub-projects.
+
+Author: Kamala Hasini Burra
+"""
+
+import glob
+import os
+import warnings
+from typing import Dict, Optional, Tuple, Union
+
+import numpy as np
+
+
+class DatasetLoader:
+    """Robust loader for .npy datasets with format validation.
+
+    Handles common inconsistencies found in DeepLense sub-projects:
+    - Mixed dtypes (float32, float64, uint8, int)
+    - Inconsistent shapes: (H, W), (H, W, C), (C, H, W)
+    - Missing or corrupted files
+    - Different normalization ranges ([0, 1], [0, 255], [-1, 1])
+
+    Parameters
+    ----------
+    data_dir : str
+        Root directory containing .npy files.
+    target_shape : tuple, optional
+        Target image shape (H, W). Set to None to keep the original shape.
+    target_dtype : np.dtype, optional
+        Target dtype. Default is np.float32.
+    channel_format : str, optional
+        'channels_last' (H, W, C) or 'channels_first' (C, H, W).
+        Default is 'channels_last'.
+
+    Examples
+    --------
+    >>> loader = DatasetLoader("path/to/data", target_shape=(64, 64))
+    >>> images, labels = loader.load_dataset()
+    >>> print(images.shape, images.dtype)
+    """
+
+    def __init__(
+        self,
+        data_dir: str,
+        target_shape: Optional[Tuple[int, int]] = None,
+        target_dtype: np.dtype = np.float32,
+        channel_format: str = "channels_last",
+    ):
+        # Validate arguments before storing any state.
+        if channel_format not in ("channels_last", "channels_first"):
+            raise ValueError(
+                f"channel_format must be 'channels_last' or 'channels_first', "
+                f"got '{channel_format}'"
+            )
+
+        self.data_dir = data_dir
+        self.target_shape = target_shape
+        self.target_dtype = target_dtype
+        self.channel_format = channel_format
+
+    def load_npy_safe(self, filepath: str) -> Optional[np.ndarray]:
+        """Safely load a .npy file, handling common errors.
+
+        Parameters
+        ----------
+        filepath : str
+            Path to the .npy file.
+
+        Returns
+        -------
+        np.ndarray or None
+            Loaded array, or None if loading failed.
+        """
+        try:
+            data = np.load(filepath, allow_pickle=True)
+
+            # Object arrays appear when files were saved from pickled
+            # Python lists; try to coerce them to a regular numeric array.
+            if data.dtype == object:
+                try:
+                    data = np.array(data.tolist(), dtype=self.target_dtype)
+                except (ValueError, TypeError):
+                    warnings.warn(
+                        f"Could not convert object array in {filepath}. Skipping."
+                    )
+                    return None
+
+            return data.astype(self.target_dtype)
+
+        except Exception as e:
+            warnings.warn(f"Failed to load {filepath}: {e}")
+            return None
+
+    def validate_image_shape(
+        self, image: np.ndarray
+    ) -> Optional[np.ndarray]:
+        """Validate and standardize image shape.
+
+        Parameters
+        ----------
+        image : np.ndarray
+            Input image array.
+
+        Returns
+        -------
+        np.ndarray or None
+            Standardized image, or None if shape is invalid.
+        """
+        ndim = image.ndim
+
+        if ndim == 2:
+            # Grayscale: (H, W) -> (H, W, 1) or (1, H, W).
+            if self.channel_format == "channels_last":
+                image = image[:, :, np.newaxis]
+            else:
+                image = image[np.newaxis, :, :]
+
+        elif ndim == 3:
+            # Heuristic: an axis of size <= 4 is assumed to be the channel
+            # axis; transpose only when the layout disagrees with the
+            # requested channel_format.
+            if image.shape[0] <= 4 and image.shape[1] > 4 and image.shape[2] > 4:
+                # Looks like (C, H, W).
+                if self.channel_format == "channels_last":
+                    image = np.transpose(image, (1, 2, 0))
+            elif image.shape[2] <= 4 and image.shape[0] > 4 and image.shape[1] > 4:
+                # Looks like (H, W, C).
+                if self.channel_format == "channels_first":
+                    image = np.transpose(image, (2, 0, 1))
+        else:
+            warnings.warn(f"Unexpected image ndim={ndim}, shape={image.shape}")
+            return None
+
+        # Resize if a target shape was specified.
+        if self.target_shape is not None:
+            image = self._resize_image(image)
+
+        return image
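+
+    # The shape standardization above is heuristic: any axis of size <= 4
+    # is taken to be a channel axis. Illustration (synthetic arrays,
+    # assuming the default channels_last format):
+    #
+    #     loader = DatasetLoader("data")
+    #     loader.validate_image_shape(np.zeros((64, 64))).shape     # (64, 64, 1)
+    #     loader.validate_image_shape(np.zeros((3, 64, 64))).shape  # (64, 64, 3)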
+
+    def _resize_image(self, image: np.ndarray) -> np.ndarray:
+        """Resize image using simple nearest-neighbor interpolation.
+
+        Uses numpy only (no PIL/cv2 dependency required).
+
+        Parameters
+        ----------
+        image : np.ndarray
+            Input image.
+
+        Returns
+        -------
+        np.ndarray
+            Resized image.
+        """
+        target_h, target_w = self.target_shape
+
+        if self.channel_format == "channels_last":
+            h, w = image.shape[0], image.shape[1]
+        else:
+            h, w = image.shape[1], image.shape[2]
+
+        if h == target_h and w == target_w:
+            return image
+
+        # Nearest-neighbor resize: map each output pixel back to its
+        # nearest source pixel.
+        row_indices = (np.arange(target_h) * h / target_h).astype(int)
+        col_indices = (np.arange(target_w) * w / target_w).astype(int)
+        row_indices = np.clip(row_indices, 0, h - 1)
+        col_indices = np.clip(col_indices, 0, w - 1)
+
+        if self.channel_format == "channels_last":
+            # Open-mesh indexing selects rows x cols and keeps channels.
+            return image[np.ix_(row_indices, col_indices)]
+        else:
+            # Broadcast the row/column indices over the leading channel
+            # axis, yielding shape (C, target_h, target_w).
+            return image[:, row_indices[:, None], col_indices]
+
+    def detect_normalization(self, images: np.ndarray) -> str:
+        """Detect the normalization range of a batch of images.
+
+        Parameters
+        ----------
+        images : np.ndarray
+            Batch of images.
+
+        Returns
+        -------
+        str
+            One of 'uint8' ([0, 255]), 'normalized' ([0, 1]),
+            'centered' ([-1, 1]), or 'unknown'.
+        """
+        vmin, vmax = float(images.min()), float(images.max())
+
+        if vmin >= 0 and vmax > 1 and vmax <= 255:
+            return "uint8"
+        elif vmin >= 0 and vmax <= 1.0:
+            return "normalized"
+        elif vmin >= -1.0 and vmax <= 1.0:
+            return "centered"
+        else:
+            return "unknown"
+
+    def normalize_to_range(
+        self,
+        images: np.ndarray,
+        target_range: str = "normalized",
+    ) -> np.ndarray:
+        """Normalize images to a consistent range.
+
+        Parameters
+        ----------
+        images : np.ndarray
+            Input images.
+        target_range : str
+            Target range: 'normalized' ([0, 1]) or 'centered' ([-1, 1]).
+
+        Returns
+        -------
+        np.ndarray
+            Normalized images.
+        """
+        detected = self.detect_normalization(images)
+
+        if detected == target_range:
+            return images
+
+        # First map everything onto [0, 1].
+        if detected == "uint8":
+            images = images / 255.0
+        elif detected == "centered":
+            images = (images + 1.0) / 2.0
+        elif detected == "unknown":
+            # Fall back to min-max scaling for unrecognized ranges.
+            vmin, vmax = images.min(), images.max()
+            if vmax - vmin > 0:
+                images = (images - vmin) / (vmax - vmin)
+
+        # Then map [0, 1] onto the requested target range.
+        if target_range == "centered":
+            images = images * 2.0 - 1.0
+
+        return images.astype(self.target_dtype)
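+
+    # Typical flow, shown for illustration only (synthetic values, not
+    # DeepLense data): detect the incoming range, then map it onto the
+    # range the model expects.
+    #
+    #     loader = DatasetLoader("data")
+    #     batch = np.random.randint(0, 256, (8, 64, 64)).astype(np.float32)
+    #     loader.detect_normalization(batch)                # -> "uint8"
+    #     out = loader.normalize_to_range(batch, "centered")
+    #     # out is float32 with values in [-1, 1]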
+
+    def load_dataset(
+        self,
+        image_pattern: str = "*.npy",
+        label_file: Optional[str] = None,
+    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+        """Load a complete dataset from a directory.
+
+        Parameters
+        ----------
+        image_pattern : str
+            Glob pattern for image files.
+        label_file : str, optional
+            Name of the labels .npy file inside ``data_dir``. It is
+            excluded from the image glob.
+
+        Returns
+        -------
+        tuple
+            (images, labels) where labels may be None.
+        """
+        pattern = os.path.join(self.data_dir, image_pattern)
+        files = sorted(glob.glob(pattern))
+
+        # Keep the label file out of the image list even when the glob
+        # pattern (e.g. "*.npy") matches it.
+        if label_file:
+            files = [f for f in files if os.path.basename(f) != label_file]
+
+        if not files:
+            raise FileNotFoundError(
+                f"No files matching '{image_pattern}' in {self.data_dir}"
+            )
+
+        images = []
+        skipped = 0
+
+        for f in files:
+            img = self.load_npy_safe(f)
+            if img is None:
+                skipped += 1
+                continue
+
+            # Heuristic batch detection: a 4-D array, or a 3-D array whose
+            # first and last axes both exceed a plausible channel count,
+            # is treated as multiple images in one file rather than a
+            # single (C, H, W) or (H, W, C) image.
+            is_batch = img.ndim == 4 or (
+                img.ndim == 3 and img.shape[0] > 4 and img.shape[-1] > 4
+            )
+            if is_batch:
+                for j in range(img.shape[0]):
+                    validated = self.validate_image_shape(img[j])
+                    if validated is not None:
+                        images.append(validated)
+            else:
+                validated = self.validate_image_shape(img)
+                if validated is not None:
+                    images.append(validated)
+
+        if skipped > 0:
+            warnings.warn(f"Skipped {skipped}/{len(files)} files due to errors.")
+
+        if not images:
+            raise ValueError("No valid images loaded from the dataset.")
+
+        images_array = np.array(images, dtype=self.target_dtype)
+
+        # Load labels if provided.
+        labels = None
+        if label_file:
+            label_path = os.path.join(self.data_dir, label_file)
+            labels = self.load_npy_safe(label_path)
+
+        return images_array, labels
+
+    def get_dataset_info(self) -> Dict[str, Union[str, int, float]]:
+        """Get summary information about the dataset directory.
+
+        Returns
+        -------
+        dict
+            Dictionary with dataset statistics.
+        """
+        npy_files = glob.glob(os.path.join(self.data_dir, "*.npy"))
+        total_size = sum(os.path.getsize(f) for f in npy_files)
+
+        info = {
+            "data_dir": self.data_dir,
+            "num_npy_files": len(npy_files),
+            "total_size_mb": round(total_size / (1024 * 1024), 2),
+        }
+
+        # Sample the first file for shape/dtype info.
+        if npy_files:
+            sample = self.load_npy_safe(npy_files[0])
+            if sample is not None:
+                info["sample_shape"] = str(sample.shape)
+                info["sample_dtype"] = str(sample.dtype)
+                info["normalization"] = self.detect_normalization(sample)
+
+        return info
diff --git a/test_dataset_loader.py b/test_dataset_loader.py
new file mode 100644
index 00000000..72704270
--- /dev/null
+++ b/test_dataset_loader.py
@@ -0,0 +1,181 @@
+"""Tests for the dataset_loader module."""
+
+import os
+import shutil
+import sys
+import tempfile
+
+import numpy as np
+
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+from dataset_loader import DatasetLoader
+
+
+class TestDatasetLoader:
+    """Temporary-directory helper shared by the test functions below."""
+
+    def setup_temp_dir(self):
+        self.temp_dir = tempfile.mkdtemp()
+        return self.temp_dir
+
+    def cleanup(self):
+        shutil.rmtree(self.temp_dir, ignore_errors=True)
+
+
+def test_load_npy_safe_valid():
+    t = TestDatasetLoader()
+    d = t.setup_temp_dir()
+    loader = DatasetLoader(d)
+    arr = np.random.rand(64, 64).astype(np.float32)
+    path = os.path.join(d, "test.npy")
+    np.save(path, arr)
+    loaded = loader.load_npy_safe(path)
+    assert loaded is not None
+    assert loaded.shape == (64, 64)
+    t.cleanup()
+    print("[PASS] test_load_npy_safe_valid")
+
+
+def test_load_npy_safe_invalid():
+    t = TestDatasetLoader()
+    d = t.setup_temp_dir()
+    loader = DatasetLoader(d)
+    path = os.path.join(d, "bad.npy")
+    with open(path, "wb") as f:
+        f.write(b"\x00\x01\x02\x03invalid")
+    loaded = loader.load_npy_safe(path)
+    assert loaded is None
+    t.cleanup()
+    print("[PASS] test_load_npy_safe_invalid")
+
+
+def test_validate_grayscale():
+    t = TestDatasetLoader()
+    d = t.setup_temp_dir()
+    loader = DatasetLoader(d, channel_format="channels_last")
+    img = np.random.rand(64, 64).astype(np.float32)
+    result = loader.validate_image_shape(img)
+    assert result is not None
+    assert result.shape == (64, 64, 1)
+    t.cleanup()
+    print("[PASS] test_validate_grayscale")
+
+
+def test_validate_channels_first_to_last():
+    t = TestDatasetLoader()
+    d = t.setup_temp_dir()
+    loader = DatasetLoader(d, channel_format="channels_last")
+    img = np.random.rand(3, 64, 64).astype(np.float32)
+    result = loader.validate_image_shape(img)
+    assert result is not None
+    assert result.shape == (64, 64, 3)
+    t.cleanup()
+    print("[PASS] test_validate_channels_first_to_last")
+
+
+def test_detect_uint8():
+    t = TestDatasetLoader()
+    d = t.setup_temp_dir()
+    loader = DatasetLoader(d)
+    # randint's upper bound is exclusive, so use 256 to cover [0, 255].
+    imgs = np.random.randint(0, 256, (10, 64, 64)).astype(np.float32)
+    assert loader.detect_normalization(imgs) == "uint8"
+    t.cleanup()
+    print("[PASS] test_detect_uint8")
+
+
+def test_detect_normalized():
+    t = TestDatasetLoader()
+    d = t.setup_temp_dir()
+    loader = DatasetLoader(d)
+    imgs = np.random.rand(10, 64, 64).astype(np.float32)
+    assert loader.detect_normalization(imgs) == "normalized"
+    t.cleanup()
+    print("[PASS] test_detect_normalized")
+
+
+def test_detect_centered():
+    t = TestDatasetLoader()
+    d = t.setup_temp_dir()
+    loader = DatasetLoader(d)
+    imgs = np.random.rand(10, 64, 64).astype(np.float32) * 2 - 1
+    assert loader.detect_normalization(imgs) == "centered"
+    t.cleanup()
+    print("[PASS] test_detect_centered")
+
+
+def test_normalize_uint8_to_01():
+    t = TestDatasetLoader()
+    d = t.setup_temp_dir()
+    loader = DatasetLoader(d)
+    imgs = np.array([0, 128, 255], dtype=np.float32)
+    normed = loader.normalize_to_range(imgs, "normalized")
+    assert abs(normed[0]) < 1e-5
+    assert abs(normed[-1] - 1.0) < 1e-5
+    t.cleanup()
+    print("[PASS] test_normalize_uint8_to_01")
+
+
+def test_load_dataset():
+    t = TestDatasetLoader()
+    d = t.setup_temp_dir()
+    for i in range(5):
+        np.save(os.path.join(d, f"img_{i}.npy"),
+                np.random.rand(64, 64).astype(np.float32))
+    loader = DatasetLoader(d)
+    images, labels = loader.load_dataset()
+    assert images.shape[0] == 5
+    assert labels is None
+    t.cleanup()
+    print("[PASS] test_load_dataset")
+
+
+def test_load_dataset_with_labels():
+    t = TestDatasetLoader()
+    d = t.setup_temp_dir()
+    for i in range(5):
+        np.save(os.path.join(d, f"img_{i}.npy"),
+                np.random.rand(64, 64).astype(np.float32))
+    np.save(os.path.join(d, "labels.npy"), np.array([0, 1, 2, 0, 1]))
+    loader = DatasetLoader(d)
+    images, labels = loader.load_dataset(label_file="labels.npy")
+    assert images.shape[0] == 5
+    assert labels is not None
+    assert len(labels) == 5
+    t.cleanup()
+    print("[PASS] test_load_dataset_with_labels")
+
+
+def test_get_dataset_info():
+    t = TestDatasetLoader()
+    d = t.setup_temp_dir()
+    np.save(os.path.join(d, "test.npy"), np.random.rand(32, 32))
+    loader = DatasetLoader(d)
+    info = loader.get_dataset_info()
+    assert info["num_npy_files"] == 1
+    assert "sample_shape" in info
+    t.cleanup()
+    print("[PASS] test_get_dataset_info")
+
+
+def test_invalid_channel_format():
+    try:
+        DatasetLoader(".", channel_format="invalid")
+        assert False, "Should raise ValueError"
+    except ValueError:
+        pass
+    print("[PASS] test_invalid_channel_format")
+
+
+if __name__ == "__main__":
+    test_load_npy_safe_valid()
+    test_load_npy_safe_invalid()
+    test_validate_grayscale()
+    test_validate_channels_first_to_last()
+    test_detect_uint8()
+    test_detect_normalized()
+    test_detect_centered()
+    test_normalize_uint8_to_01()
+    test_load_dataset()
+    test_load_dataset_with_labels()
+    test_get_dataset_info()
+    test_invalid_channel_format()
+    print("\n=== All 12 tests passed! ===")
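
A minimal end-to-end sketch of the intended workflow (the directory name,
label file, and output shape below are illustrative assumptions, not a
real DeepLense dataset):

    from dataset_loader import DatasetLoader

    loader = DatasetLoader("lens_data", target_shape=(64, 64))
    images, labels = loader.load_dataset(label_file="labels.npy")
    images = loader.normalize_to_range(images, target_range="centered")
    print(images.shape, images.dtype)  # e.g. (N, 64, 64, 1) float32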