Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Dataset configuration for Physics-Informed Super Resolution.
# Externalizes parameters that are otherwise hardcoded in training notebooks.
# Physics/model parameters live in config.json; dataset parameters live here.

# One entry per simulated telescope (see README Section 4).
models:
  model_1:
    data_dir: "../Simulations/data_model_1/"
    eval_dir: "../Simulations/data_model_1/"
    resolution: 0.05
    alpha_t_scaling: 0.0025
  model_2:
    data_dir: "../Simulations/data_model_2/"
    eval_dir: "../Simulations/data_model_2/"
    resolution: 0.101
    alpha_t_scaling: 0.0025
  model_3:
    data_dir: "../Simulations/data_model_3/"
    eval_dir: "../Simulations/data_model_3/"
    resolution: 0.08
    alpha_t_scaling: 0.0025

# Dark matter substructure classes used for training and evaluation.
classes:
  train: ["no_sub", "axion", "cdm"]
  eval_lr: ["no_sub", "axion", "cdm"]
  eval_hr: ["no_sub_HR", "axion_HR", "cdm_HR"]

# Dataset sizing.
num_samples: 5000
train_val_split: 0.8

# Training defaults.
batch_size: 5
epochs: 150
learning_rate: 1.0e-5
num_workers: 15

# Image properties (must stay consistent with config.json).
image_shape: 75
magnification: 2
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
"""
Validates dataset integrity before training.

Checks that .npy files exist for every expected class and index,
that LR and HR image counts match, and that array shapes are
consistent with the configured image dimensions and scale factor.

Run from the module root:
python validate_dataset.py --config dataset_config.yaml

Exits with code 0 if all checks pass, 1 otherwise.
"""

import argparse
import os
import sys
from collections import Counter

import numpy as np
import yaml


def load_config(path):
    """Parse and return the YAML dataset configuration at *path*."""
    with open(path) as handle:
        cfg = yaml.safe_load(handle)
    return cfg


def check_directory(directory, label):
    """Report whether *directory* exists; print a FAIL line when it does not."""
    exists = os.path.isdir(directory)
    if not exists:
        print(f"[FAIL] {label}: directory not found: {directory}")
    return exists


def check_class_files(data_dir, class_name, num_samples, aux="_sim_"):
    """
    Verify that {class_name}/{class_name}{aux}{i}.npy exists for every
    i in [0, num_samples) and that all loadable arrays share one shape.

    Parameters
    ----------
    data_dir : str
        Root directory containing one sub-directory per class.
    class_name : str
        Class sub-directory and filename prefix (e.g. "no_sub").
    num_samples : int
        Expected number of files per class.
    aux : str
        Filename infix between the class name and index ("_sim_" for
        images, "_alpha_" for deflection angles).

    Returns
    -------
    (missing_count, shape_mismatch_count, expected_shape) where
    expected_shape is the most common shape among loaded arrays, or
    None when nothing could be loaded.
    """
    missing = []
    shapes = []
    class_dir = os.path.join(data_dir, class_name)

    if not os.path.isdir(class_dir):
        # Whole class directory absent: everything counts as missing.
        print(f"  [FAIL] class directory missing: {class_dir}")
        return num_samples, 0, None

    for i in range(num_samples):
        filepath = os.path.join(class_dir, f"{class_name}{aux}{i}.npy")
        if not os.path.isfile(filepath):
            missing.append(i)
            continue
        try:
            shapes.append(np.load(filepath).shape)
        except Exception as e:
            # Unreadable/corrupt files count as missing so the totals
            # still reflect how many usable samples exist.
            print(f"  [WARN] could not load {filepath}: {e}")
            missing.append(i)

    if not shapes:
        return len(missing), 0, None

    # The modal shape is taken as the expected one; Counter.most_common
    # breaks ties by first-encountered order, matching the previous
    # hand-rolled frequency count.
    expected_shape = Counter(shapes).most_common(1)[0][0]
    mismatches = sum(1 for s in shapes if s != expected_shape)
    return len(missing), mismatches, expected_shape


def check_hr_lr_consistency(
    data_dir, lr_classes, hr_classes, num_samples
):
    """
    Verify that HR and LR evaluation sets have matching .npy file counts
    per class pair (e.g. no_sub <-> no_sub_HR).

    Fix: `num_samples` was previously accepted but never used; it is now
    the expected per-class file count, and a deviation from it is printed
    as a WARN. Only LR/HR count mismatches contribute to the returned
    error count, so the function's return semantics are unchanged.

    Returns the number of LR/HR count mismatches.
    """

    def count_npy(directory):
        # A missing directory simply counts as zero files.
        if not os.path.isdir(directory):
            return 0
        return sum(1 for f in os.listdir(directory) if f.endswith(".npy"))

    errors = 0
    for lr_cls, hr_cls in zip(lr_classes, hr_classes):
        lr_count = count_npy(os.path.join(data_dir, lr_cls))
        hr_count = count_npy(os.path.join(data_dir, hr_cls))

        if lr_count != hr_count:
            print(
                f"  [FAIL] count mismatch: {lr_cls}={lr_count} vs "
                f"{hr_cls}={hr_count}"
            )
            errors += 1
        else:
            print(f"  [OK] {lr_cls}={lr_count}, {hr_cls}={hr_count}")
            if lr_count != num_samples:
                print(
                    f"  [WARN] {lr_cls}: expected {num_samples} files, "
                    f"found {lr_count}"
                )
    return errors


def check_shape_vs_config(shape, expected_dim, label):
    """
    Return 1 (and print a FAIL line) when the spatial dimension of
    *shape* differs from *expected_dim*; otherwise return 0.

    A None shape returns 0 because a missing class was already reported
    by the file-existence checks.
    """
    if shape is None:
        return 0
    # Images are square, so the last axis carries the spatial dimension.
    spatial = shape[-1]
    if spatial == expected_dim:
        return 0
    print(
        f"  [FAIL] {label}: shape {shape}, expected spatial dim {expected_dim}"
    )
    return 1


def main():
    """Validate every configured model's dataset; exit 0 on success, 1 on failure."""
    parser = argparse.ArgumentParser(
        description="Validate dataset before training"
    )
    parser.add_argument(
        "--config",
        default="dataset_config.yaml",
        help="path to dataset config file",
    )
    args = parser.parse_args()

    cfg = load_config(args.config)
    # Dataset-wide parameters shared by every model.
    num_samples = cfg["num_samples"]
    image_shape = cfg["image_shape"]
    magnification = cfg["magnification"]
    # HR images are the LR spatial dimension scaled by the magnification factor.
    hr_shape = image_shape * magnification

    total_errors = 0

    for model_name, model_cfg in cfg["models"].items():
        data_dir = model_cfg["data_dir"]
        print(f"\n{'='*50}")
        print(f"Validating {model_name}: {data_dir}")
        print(f"{'='*50}")

        # A missing data directory makes all further checks pointless.
        if not check_directory(data_dir, model_name):
            total_errors += 1
            continue

        # check training classes (LR simulations)
        print(f"\n  Training classes (simulations):")
        for cls in cfg["classes"]["train"]:
            missing, mismatches, shape = check_class_files(
                data_dir, cls, num_samples, aux="_sim_"
            )
            status = "OK" if (missing == 0 and mismatches == 0) else "FAIL"
            print(
                f"    [{status}] {cls}: {num_samples - missing}/{num_samples} "
                f"files, {mismatches} shape mismatches, shape={shape}"
            )
            # Missing files and shape mismatches each count as one issue.
            total_errors += int(missing > 0) + int(mismatches > 0)

            # validate spatial dimensions against config
            total_errors += check_shape_vs_config(
                shape, image_shape, f"{cls} LR"
            )

        # check deflection angle files
        # NOTE(review): alpha arrays are not compared against image_shape here —
        # presumably their dimensions may differ from the images; confirm.
        print(f"\n  Deflection angles:")
        for cls in cfg["classes"]["train"]:
            missing, mismatches, shape = check_class_files(
                data_dir, cls, num_samples, aux="_alpha_"
            )
            status = "OK" if (missing == 0 and mismatches == 0) else "FAIL"
            print(
                f"    [{status}] {cls}: {num_samples - missing}/{num_samples} "
                f"files, {mismatches} shape mismatches, shape={shape}"
            )
            total_errors += int(missing > 0) + int(mismatches > 0)

        # check eval HR classes and LR/HR count consistency
        # eval_dir falls back to data_dir when not configured for the model.
        eval_dir = model_cfg.get("eval_dir", data_dir)
        print(f"\n  Evaluation HR/LR consistency ({eval_dir}):")
        total_errors += check_hr_lr_consistency(
            eval_dir,
            cfg["classes"]["eval_lr"],
            cfg["classes"]["eval_hr"],
            num_samples,
        )

        # check HR image shapes match expected upscaled dimensions
        print(f"\n  HR shape check (expected spatial dim: {hr_shape}):")
        for cls in cfg["classes"]["eval_hr"]:
            # Only the shape is of interest here; file counts were already
            # covered by the consistency check above.
            _, _, shape = check_class_files(
                eval_dir, cls, num_samples, aux="_sim_"
            )
            total_errors += check_shape_vs_config(
                shape, hr_shape, f"{cls} HR"
            )

    print(f"\n{'='*50}")
    if total_errors == 0:
        print("All checks passed.")
    else:
        print(f"{total_errors} issue(s) found.")
    print(f"{'='*50}")
    # Exit code signals pass/fail to CI pipelines and shell callers.
    sys.exit(0 if total_errors == 0 else 1)


# Script entry point: run validation only when executed directly, not on import.
if __name__ == "__main__":
    main()