From 328da220fe77274b2baf630e5ca24928959203bd Mon Sep 17 00:00:00 2001 From: kamalahasiniburra Date: Mon, 30 Mar 2026 21:28:04 +0530 Subject: [PATCH] Add data preprocessing utilities for image normalization and augmentation --- preprocessing_utils.py | 263 ++++++++++++++++++++++++++++++++++++ test_preprocessing_utils.py | 131 ++++++++++++++++++ 2 files changed, 394 insertions(+) create mode 100644 preprocessing_utils.py create mode 100644 test_preprocessing_utils.py diff --git a/preprocessing_utils.py b/preprocessing_utils.py new file mode 100644 index 00000000..6f1fe0e1 --- /dev/null +++ b/preprocessing_utils.py @@ -0,0 +1,263 @@ +""" +Data Preprocessing Utilities for DeepLense. + +This module provides reusable data preprocessing functions for +gravitational lensing image datasets. It includes image normalization, +augmentation helpers, and dataset splitting utilities that can be used +across different DeepLense sub-projects. + +Author: Kamala Hasini Burra +""" + +import numpy as np +from typing import Dict, List, Optional, Tuple + + +def normalize_images( + images: np.ndarray, + method: str = "minmax", +) -> np.ndarray: + """Normalize a batch of images using the specified method. + + Parameters + ---------- + images : np.ndarray + Batch of images with shape (N, H, W) or (N, H, W, C). + method : str + Normalization method. Options: + - 'minmax': Scale to [0, 1] range + - 'zscore': Zero mean, unit variance + - 'robust': Median-centered with IQR scaling + + Returns + ------- + np.ndarray + Normalized images with the same shape. + + Raises + ------ + ValueError + If an unknown normalization method is specified. + + Examples + -------- + >>> imgs = np.random.randint(0, 255, (10, 64, 64)).astype(float) + >>> normed = normalize_images(imgs, method='minmax') + >>> assert normed.min() >= 0.0 and normed.max() <= 1.0 + """ + images = images.astype(np.float64) + + if method == "minmax": + img_min = images.min() + img_max = images.max() + if img_max - img_min > 0: + return (images - img_min) / (img_max - img_min) + return np.zeros_like(images) + + elif method == "zscore": + mean = images.mean() + std = images.std() + if std > 0: + return (images - mean) / std + return np.zeros_like(images) + + elif method == "robust": + median = np.median(images) + q75 = np.percentile(images, 75) + q25 = np.percentile(images, 25) + iqr = q75 - q25 + if iqr > 0: + return (images - median) / iqr + return np.zeros_like(images) + + else: + raise ValueError( + f"Unknown normalization method '{method}'. " + f"Choose from: 'minmax', 'zscore', 'robust'." + ) + + +def augment_image( + image: np.ndarray, + flip_horizontal: bool = False, + flip_vertical: bool = False, + rotate_90: int = 0, + add_noise: bool = False, + noise_std: float = 0.01, +) -> np.ndarray: + """Apply augmentation transforms to a single image. + + Parameters + ---------- + image : np.ndarray + Input image of shape (H, W) or (H, W, C). + flip_horizontal : bool + Whether to flip horizontally. + flip_vertical : bool + Whether to flip vertically. + rotate_90 : int + Number of 90-degree rotations (0, 1, 2, or 3). + add_noise : bool + Whether to add Gaussian noise. + noise_std : float + Standard deviation of Gaussian noise (only used if add_noise=True). + + Returns + ------- + np.ndarray + Augmented image with the same dtype. 
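+
+    Notes
+    -----
+    When ``add_noise=True`` the Gaussian noise is float-valued, so integer
+    inputs are promoted to float64 rather than keeping their original dtype.
+
+    Examples
+    --------
+    A minimal sketch of combined transforms (output values are illustrative;
+    only the shape is checked here):
+
+    >>> img = np.arange(16, dtype=float).reshape(4, 4)
+    >>> out = augment_image(img, flip_horizontal=True, rotate_90=1)
+    >>> out.shape
+    (4, 4)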
+ """ + result = image.copy() + + if flip_horizontal: + result = np.flip(result, axis=1) + + if flip_vertical: + result = np.flip(result, axis=0) + + if rotate_90 > 0: + axes = (0, 1) + result = np.rot90(result, k=rotate_90, axes=axes) + + if add_noise: + noise = np.random.normal(0, noise_std, result.shape) + result = result + noise + + return np.ascontiguousarray(result) + + +def stratified_split( + labels: np.ndarray, + train_ratio: float = 0.7, + val_ratio: float = 0.15, + test_ratio: float = 0.15, + random_seed: int = 42, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Split dataset indices into train/val/test with stratification. + + Ensures each class is proportionally represented in all splits. + + Parameters + ---------- + labels : np.ndarray + Array of integer class labels. + train_ratio : float + Proportion for training set. + val_ratio : float + Proportion for validation set. + test_ratio : float + Proportion for test set. + random_seed : int + Random seed for reproducibility. + + Returns + ------- + tuple of np.ndarray + (train_indices, val_indices, test_indices) + + Raises + ------ + ValueError + If ratios don't sum to approximately 1.0. + """ + ratio_sum = train_ratio + val_ratio + test_ratio + if abs(ratio_sum - 1.0) > 0.01: + raise ValueError( + f"Ratios must sum to 1.0, got {ratio_sum:.2f}" + ) + + rng = np.random.RandomState(random_seed) + classes = np.unique(labels) + + train_idx: List[int] = [] + val_idx: List[int] = [] + test_idx: List[int] = [] + + for cls in classes: + cls_indices = np.where(labels == cls)[0] + rng.shuffle(cls_indices) + + n = len(cls_indices) + n_train = int(n * train_ratio) + n_val = int(n * val_ratio) + + train_idx.extend(cls_indices[:n_train]) + val_idx.extend(cls_indices[n_train:n_train + n_val]) + test_idx.extend(cls_indices[n_train + n_val:]) + + return ( + np.array(train_idx), + np.array(val_idx), + np.array(test_idx), + ) + + +def compute_class_weights( + labels: np.ndarray, + method: str = "balanced", +) -> Dict[int, float]: + """Compute class weights for handling imbalanced datasets. + + Parameters + ---------- + labels : np.ndarray + Array of integer class labels. + method : str + Weighting method: + - 'balanced': Inversely proportional to class frequency + - 'sqrt': Square root of balanced weights (softer rebalancing) + + Returns + ------- + dict + Mapping from class label to weight. + """ + classes, counts = np.unique(labels, return_counts=True) + n_samples = len(labels) + n_classes = len(classes) + + weights: Dict[int, float] = {} + + for cls, count in zip(classes, counts): + if method == "balanced": + weights[int(cls)] = n_samples / (n_classes * count) + elif method == "sqrt": + weights[int(cls)] = np.sqrt(n_samples / (n_classes * count)) + else: + raise ValueError(f"Unknown method '{method}'") + + return weights + + +def create_image_patches( + image: np.ndarray, + patch_size: int = 32, + stride: int = 16, +) -> np.ndarray: + """Extract overlapping patches from an image. + + Useful for patch-based training of gravitational lens models. + + Parameters + ---------- + image : np.ndarray + Input image of shape (H, W) or (H, W, C). + patch_size : int + Size of each square patch. + stride : int + Step size between consecutive patches. + + Returns + ------- + np.ndarray + Array of patches with shape (N, patch_size, patch_size, ...). 
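+
+    Examples
+    --------
+    A small sketch of the expected patch count: with a 64x64 image,
+    ``patch_size=32`` and ``stride=16`` give ((64 - 32) // 16 + 1)**2 = 9
+    patches.
+
+    >>> img = np.zeros((64, 64))
+    >>> patches = create_image_patches(img, patch_size=32, stride=16)
+    >>> patches.shape
+    (9, 32, 32)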
+    """
+    h, w = image.shape[:2]
+    patches = []
+
+    for y in range(0, h - patch_size + 1, stride):
+        for x in range(0, w - patch_size + 1, stride):
+            patch = image[y:y + patch_size, x:x + patch_size]
+            patches.append(patch)
+
+    return np.array(patches)
diff --git a/test_preprocessing_utils.py b/test_preprocessing_utils.py
new file mode 100644
index 00000000..9f03aa94
--- /dev/null
+++ b/test_preprocessing_utils.py
@@ -0,0 +1,131 @@
+"""Tests for the preprocessing_utils module."""
+
+import numpy as np
+import sys
+import os
+
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+from preprocessing_utils import (
+    normalize_images,
+    augment_image,
+    stratified_split,
+    compute_class_weights,
+    create_image_patches,
+)
+
+
+def test_minmax_normalization():
+    imgs = np.random.randint(0, 255, (5, 64, 64)).astype(float)
+    normed = normalize_images(imgs, method="minmax")
+    assert abs(normed.min()) < 1e-10, f"Min should be ~0, got {normed.min()}"
+    assert abs(normed.max() - 1.0) < 1e-10, f"Max should be ~1, got {normed.max()}"
+    print("[PASS] test_minmax_normalization")
+
+
+def test_zscore_normalization():
+    imgs = np.random.randn(5, 64, 64) * 50 + 100
+    normed = normalize_images(imgs, method="zscore")
+    assert abs(normed.mean()) < 1e-10, f"Mean should be ~0, got {normed.mean()}"
+    assert abs(normed.std() - 1.0) < 1e-10, f"Std should be ~1, got {normed.std()}"
+    print("[PASS] test_zscore_normalization")
+
+
+def test_robust_normalization():
+    imgs = np.random.randn(5, 64, 64) * 10
+    normed = normalize_images(imgs, method="robust")
+    assert normed.shape == imgs.shape
+    print("[PASS] test_robust_normalization")
+
+
+def test_invalid_normalization():
+    imgs = np.random.randn(5, 64, 64)
+    try:
+        normalize_images(imgs, method="invalid")
+        assert False, "Should have raised ValueError"
+    except ValueError:
+        pass
+    print("[PASS] test_invalid_normalization")
+
+
+def test_augment_flip():
+    img = np.arange(16).reshape(4, 4).astype(float)
+    flipped_h = augment_image(img, flip_horizontal=True)
+    assert flipped_h[0, 0] == img[0, 3]
+    flipped_v = augment_image(img, flip_vertical=True)
+    assert flipped_v[0, 0] == img[3, 0]
+    print("[PASS] test_augment_flip")
+
+
+def test_augment_rotate():
+    img = np.arange(16).reshape(4, 4).astype(float)
+    rotated = augment_image(img, rotate_90=1)
+    assert rotated.shape == (4, 4)
+    print("[PASS] test_augment_rotate")
+
+
+def test_augment_noise():
+    img = np.ones((64, 64))
+    noisy = augment_image(img, add_noise=True, noise_std=0.1)
+    assert not np.allclose(img, noisy), "Noisy image should differ"
+    print("[PASS] test_augment_noise")
+
+
+def test_stratified_split():
+    labels = np.array([0]*100 + [1]*100 + [2]*100)
+    train, val, test = stratified_split(labels, 0.7, 0.15, 0.15)
+    assert len(train) + len(val) + len(test) == 300
+    assert len(train) == 210  # 70% of 300
+    print("[PASS] test_stratified_split")
+
+
+def test_stratified_split_invalid_ratios():
+    labels = np.array([0, 1, 2])
+    try:
+        stratified_split(labels, 0.5, 0.5, 0.5)
+        assert False, "Should have raised ValueError"
+    except ValueError:
+        pass
+    print("[PASS] test_stratified_split_invalid_ratios")
+
+
+def test_class_weights_balanced():
+    labels = np.array([0]*10 + [1]*90)
+    weights = compute_class_weights(labels, method="balanced")
+    assert weights[0] > weights[1], "Minority class should have higher weight"
+    print("[PASS] test_class_weights_balanced")
+
+
+def test_class_weights_sqrt():
+    labels = np.array([0]*10 + [1]*90)
+    weights = compute_class_weights(labels, method="sqrt")
+    assert weights[0] > weights[1]
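+    # With 10 vs. 90 samples the balanced weights are 5.0 and ~0.56; taking the
+    # square root softens them to ~2.24 and ~0.75 while keeping the ordering.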
+    balanced = compute_class_weights(labels, method="balanced")
+    assert weights[0] < balanced[0], "sqrt weights should be softer"
+    print("[PASS] test_class_weights_sqrt")
+
+
+def test_create_patches():
+    img = np.random.randn(64, 64)
+    patches = create_image_patches(img, patch_size=32, stride=16)
+    assert patches.shape[1] == 32
+    assert patches.shape[2] == 32
+    expected_n = ((64 - 32) // 16 + 1) ** 2  # 3*3 = 9
+    assert patches.shape[0] == expected_n, f"Expected {expected_n}, got {patches.shape[0]}"
+    print("[PASS] test_create_patches")
+
+
+if __name__ == "__main__":
+    test_minmax_normalization()
+    test_zscore_normalization()
+    test_robust_normalization()
+    test_invalid_normalization()
+    test_augment_flip()
+    test_augment_rotate()
+    test_augment_noise()
+    test_stratified_split()
+    test_stratified_split_invalid_ratios()
+    test_class_weights_balanced()
+    test_class_weights_sqrt()
+    test_create_patches()
+    print("\n=== All 12 tests passed! ===")