Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 263 additions & 0 deletions preprocessing_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
"""
Data Preprocessing Utilities for DeepLense.

This module provides reusable data preprocessing functions for
gravitational lensing image datasets. It includes image normalization,
augmentation helpers, and dataset splitting utilities that can be used
across different DeepLense sub-projects.

Author: Kamala Hasini Burra
"""

import numpy as np
from typing import Dict, List, Optional, Tuple


def normalize_images(
    images: np.ndarray,
    method: str = "minmax",
) -> np.ndarray:
    """Normalize a batch of images with one of several global strategies.

    Statistics are computed over the entire batch (not per image), so
    relative brightness between images is preserved.

    Parameters
    ----------
    images : np.ndarray
        Batch of images with shape (N, H, W) or (N, H, W, C).
    method : str
        One of:
        - 'minmax': rescale so values span [0, 1]
        - 'zscore': subtract mean, divide by standard deviation
        - 'robust': subtract median, divide by interquartile range

    Returns
    -------
    np.ndarray
        Normalized float64 array with the same shape. If the batch is
        constant (zero spread), an all-zeros array is returned instead
        of dividing by zero.

    Raises
    ------
    ValueError
        If `method` is not one of the supported options.
    """
    data = images.astype(np.float64)

    if method == "minmax":
        lo, hi = data.min(), data.max()
        span = hi - lo
        return (data - lo) / span if span > 0 else np.zeros_like(data)

    if method == "zscore":
        mu, sigma = data.mean(), data.std()
        return (data - mu) / sigma if sigma > 0 else np.zeros_like(data)

    if method == "robust":
        center = np.median(data)
        q25, q75 = np.percentile(data, [25, 75])
        spread = q75 - q25
        return (data - center) / spread if spread > 0 else np.zeros_like(data)

    raise ValueError(
        f"Unknown normalization method '{method}'. "
        f"Choose from: 'minmax', 'zscore', 'robust'."
    )


def augment_image(
    image: np.ndarray,
    flip_horizontal: bool = False,
    flip_vertical: bool = False,
    rotate_90: int = 0,
    add_noise: bool = False,
    noise_std: float = 0.01,
) -> np.ndarray:
    """Apply augmentation transforms to a single image.

    Transforms are applied in order: horizontal flip, vertical flip,
    rotation, then noise.

    Parameters
    ----------
    image : np.ndarray
        Input image of shape (H, W) or (H, W, C).
    flip_horizontal : bool
        Whether to flip horizontally (reverse columns).
    flip_vertical : bool
        Whether to flip vertically (reverse rows).
    rotate_90 : int
        Number of counter-clockwise 90-degree rotations (0, 1, 2, or 3).
        Values <= 0 are ignored.
    add_noise : bool
        Whether to add Gaussian noise drawn from the global NumPy RNG.
    noise_std : float
        Standard deviation of Gaussian noise (only used if add_noise=True).

    Returns
    -------
    np.ndarray
        Augmented image. The dtype matches the input unless
        ``add_noise=True``, in which case adding float noise promotes
        integer images to float64.
    """
    result = image.copy()

    if flip_horizontal:
        result = np.flip(result, axis=1)

    if flip_vertical:
        result = np.flip(result, axis=0)

    if rotate_90 > 0:
        # Rotate in the spatial plane only; a trailing channel axis,
        # if present, is left untouched.
        result = np.rot90(result, k=rotate_90, axes=(0, 1))

    if add_noise:
        # Float noise: this is where integer inputs get promoted.
        result = result + np.random.normal(0, noise_std, result.shape)

    # Flips/rotations return views with negative strides; make the
    # result contiguous for downstream consumers (e.g. torch.from_numpy).
    return np.ascontiguousarray(result)


def stratified_split(
    labels: np.ndarray,
    train_ratio: float = 0.7,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    random_seed: int = 42,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Split dataset indices into train/val/test with stratification.

    Each class's indices are shuffled independently and carved up so
    every class appears proportionally in each split. Leftover samples
    from integer truncation land in the test split.

    Parameters
    ----------
    labels : np.ndarray
        Array of integer class labels.
    train_ratio : float
        Proportion for training set.
    val_ratio : float
        Proportion for validation set.
    test_ratio : float
        Proportion for test set.
    random_seed : int
        Random seed for reproducibility.

    Returns
    -------
    tuple of np.ndarray
        (train_indices, val_indices, test_indices)

    Raises
    ------
    ValueError
        If the three ratios don't sum to approximately 1.0.
    """
    total = train_ratio + val_ratio + test_ratio
    if abs(total - 1.0) > 0.01:
        raise ValueError(
            f"Ratios must sum to 1.0, got {total:.2f}"
        )

    rng = np.random.RandomState(random_seed)

    train_parts: List[int] = []
    val_parts: List[int] = []
    test_parts: List[int] = []

    for label in np.unique(labels):
        members = np.where(labels == label)[0]
        rng.shuffle(members)

        count = len(members)
        cut_train = int(count * train_ratio)
        cut_val = cut_train + int(count * val_ratio)

        train_parts.extend(members[:cut_train])
        val_parts.extend(members[cut_train:cut_val])
        test_parts.extend(members[cut_val:])

    return (
        np.array(train_parts),
        np.array(val_parts),
        np.array(test_parts),
    )


def compute_class_weights(
    labels: np.ndarray,
    method: str = "balanced",
) -> Dict[int, float]:
    """Compute class weights for handling imbalanced datasets.

    Parameters
    ----------
    labels : np.ndarray
        Array of integer class labels.
    method : str
        Weighting method:
        - 'balanced': Inversely proportional to class frequency
        - 'sqrt': Square root of balanced weights (softer rebalancing)

    Returns
    -------
    dict
        Mapping from class label (int) to weight (float).

    Raises
    ------
    ValueError
        If `method` is not one of the supported options. Raised even
        when `labels` is empty (the check runs before iterating).
    """
    # Validate up front: the original in-loop check silently returned {}
    # for an invalid method whenever `labels` was empty.
    if method not in ("balanced", "sqrt"):
        raise ValueError(f"Unknown method '{method}'")

    classes, counts = np.unique(labels, return_counts=True)
    n_samples = len(labels)
    n_classes = len(classes)

    weights: Dict[int, float] = {}

    for cls, count in zip(classes, counts):
        weight = n_samples / (n_classes * count)
        if method == "sqrt":
            weight = np.sqrt(weight)
        # Cast so the dict holds plain Python floats, as documented,
        # rather than np.float64 scalars.
        weights[int(cls)] = float(weight)

    return weights


def create_image_patches(
    image: np.ndarray,
    patch_size: int = 32,
    stride: int = 16,
) -> np.ndarray:
    """Extract overlapping patches from an image.

    Useful for patch-based training of gravitational lens models.

    Parameters
    ----------
    image : np.ndarray
        Input image of shape (H, W) or (H, W, C).
    patch_size : int
        Size of each square patch.
    stride : int
        Step size between consecutive patches.

    Returns
    -------
    np.ndarray
        Array of patches with shape (N, patch_size, patch_size, ...).
        If the image is smaller than `patch_size` in either dimension,
        an empty array with shape (0, patch_size, patch_size, ...) and
        the input's dtype is returned.
    """
    h, w = image.shape[:2]
    patches = [
        image[y:y + patch_size, x:x + patch_size]
        for y in range(0, h - patch_size + 1, stride)
        for x in range(0, w - patch_size + 1, stride)
    ]

    if not patches:
        # np.array([]) would have shape (0,), breaking downstream code
        # that indexes patch dimensions; return a correctly shaped empty.
        return np.empty(
            (0, patch_size, patch_size) + image.shape[2:], dtype=image.dtype
        )

    return np.array(patches)
131 changes: 131 additions & 0 deletions test_preprocessing_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""Tests for the preprocessing_utils module."""

import numpy as np
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from preprocessing_utils import (
normalize_images,
augment_image,
stratified_split,
compute_class_weights,
create_image_patches,
)


def test_minmax_normalization():
    # Random uint batch: after minmax scaling the extremes must hit 0 and 1.
    normed = normalize_images(
        np.random.randint(0, 255, (5, 64, 64)).astype(float), method="minmax"
    )
    assert abs(normed.min()) < 1e-10, f"Min should be ~0, got {normed.min()}"
    assert abs(normed.max() - 1.0) < 1e-10, f"Max should be ~1, got {normed.max()}"
    print("[PASS] test_minmax_normalization")


def test_zscore_normalization():
    # Gaussian batch with nonzero mean/scale: zscore must recenter it.
    normed = normalize_images(np.random.randn(5, 64, 64) * 50 + 100, method="zscore")
    assert abs(normed.mean()) < 1e-10, f"Mean should be ~0, got {normed.mean()}"
    assert abs(normed.std() - 1.0) < 1e-10, f"Std should be ~1, got {normed.std()}"
    print("[PASS] test_zscore_normalization")


def test_robust_normalization():
    # Robust normalization is shape-preserving; that's all we pin here.
    sample = np.random.randn(5, 64, 64) * 10
    assert normalize_images(sample, method="robust").shape == sample.shape
    print("[PASS] test_robust_normalization")


def test_invalid_normalization():
    # An unrecognized method name must be rejected with ValueError.
    raised = False
    try:
        normalize_images(np.random.randn(5, 64, 64), method="invalid")
    except ValueError:
        raised = True
    assert raised, "Should have raised ValueError"
    print("[PASS] test_invalid_normalization")


def test_augment_flip():
    # Check one corner pixel for each flip direction.
    img = np.arange(16).reshape(4, 4).astype(float)
    assert augment_image(img, flip_horizontal=True)[0, 0] == img[0, 3]
    assert augment_image(img, flip_vertical=True)[0, 0] == img[3, 0]
    print("[PASS] test_augment_flip")


def test_augment_rotate():
    # A square image keeps its shape under a quarter-turn.
    img = np.arange(16).reshape(4, 4).astype(float)
    assert augment_image(img, rotate_90=1).shape == (4, 4)
    print("[PASS] test_augment_rotate")


def test_augment_noise():
    # With noise enabled, the output should not match the input exactly.
    img = np.ones((64, 64))
    assert not np.allclose(
        img, augment_image(img, add_noise=True, noise_std=0.1)
    ), "Noisy image should differ"
    print("[PASS] test_augment_noise")


def test_stratified_split():
    # Three balanced classes of 100: splits must partition all 300 indices.
    labels = np.repeat([0, 1, 2], 100)
    train, val, test = stratified_split(labels, 0.7, 0.15, 0.15)
    assert len(train) + len(val) + len(test) == 300
    assert len(train) == 210  # 70% of 300
    print("[PASS] test_stratified_split")


def test_stratified_split_invalid_ratios():
    # Ratios summing to 1.5 must be rejected with ValueError.
    raised = False
    try:
        stratified_split(np.array([0, 1, 2]), 0.5, 0.5, 0.5)
    except ValueError:
        raised = True
    assert raised, "Should have raised ValueError"
    print("[PASS] test_stratified_split_invalid_ratios")


def test_class_weights_balanced():
    # 10/90 imbalance: the rare class must be weighted more heavily.
    weights = compute_class_weights(
        np.array([0] * 10 + [1] * 90), method="balanced"
    )
    assert weights[0] > weights[1], "Minority class should have higher weight"
    print("[PASS] test_class_weights_balanced")


def test_class_weights_sqrt():
    # sqrt weighting keeps the ordering but compresses the gap.
    labels = np.array([0] * 10 + [1] * 90)
    weights = compute_class_weights(labels, method="sqrt")
    balanced = compute_class_weights(labels, method="balanced")
    assert weights[0] > weights[1]
    assert weights[0] < balanced[0], "sqrt weights should be softer"
    print("[PASS] test_class_weights_sqrt")


def test_create_patches():
    # 64x64 image, 32px patches, stride 16 -> a 3x3 grid of patches.
    patches = create_image_patches(np.random.randn(64, 64), patch_size=32, stride=16)
    expected_n = ((64 - 32) // 16 + 1) ** 2  # 3*3 = 9
    assert patches.shape[1] == 32
    assert patches.shape[2] == 32
    assert patches.shape[0] == expected_n, f"Expected {expected_n}, got {patches.shape[0]}"
    print("[PASS] test_create_patches")


if __name__ == "__main__":
    # Run every test in declaration order; each prints its own [PASS] line.
    for check in (
        test_minmax_normalization,
        test_zscore_normalization,
        test_robust_normalization,
        test_invalid_normalization,
        test_augment_flip,
        test_augment_rotate,
        test_augment_noise,
        test_stratified_split,
        test_stratified_split_invalid_ratios,
        test_class_weights_balanced,
        test_class_weights_sqrt,
        test_create_patches,
    ):
        check()
    print("\n=== All 12 tests passed! ===")