Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import numpy as np

import os
class LensingDataset(torch.utils.data.Dataset):
def __init__(self, directory, classes, num_samples):
"""
Expand All @@ -20,15 +20,26 @@ def __len__(self):
"""
return self.num_samples*len(self.classes)





def __getitem__(self, index):
    """
    Supplies LR images.

    :param index: Flat index in the dataset to look for; decomposed into
        a class (index // num_samples) and a sample within that class
        (index % num_samples).
    :return: LR image tensor of shape (1, H, W), min-max normalized to
        [0, 1]. A constant (blank) image is returned unnormalized to
        avoid a 0/0 division producing NaNs.
    """
    # Map the flat index onto (class, sample-within-class).
    selected_class = self.classes[index // self.num_samples]
    class_index = index % self.num_samples

    # Use os.path.join for portability across OS path separators.
    file_path = os.path.join(self.directory, selected_class, 'sim_%d.npy' % class_index)

    # Wrapping in np.array([...]) adds a leading channel axis: (1, H, W).
    image = torch.tensor(np.array([np.load(file_path)]))

    # Guarded min-max normalization: skip when the image is flat so the
    # denominator (img_max - img_min) is never zero.
    img_min = torch.min(image)
    img_max = torch.max(image)
    if img_max > img_min:
        image = (image - img_min) / (img_max - img_min)

    return image
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def parse_args():
help='weight of the vdl loss')

# Performance / architecture options
parser.add_argument('--resolution', type=float,
help='arcsecond per pixel resolution the images are captured in')
parser.add_argument('--resolution', type=float, default=0.1,
help='arcsecond per pixel resolution the images are captured in')
parser.add_argument('--magnification', type=int, default=2,
help='magnification value achieved by the SR network')
parser.add_argument('--n-mag', type=int, default=1,
Expand Down Expand Up @@ -137,18 +137,22 @@ def parse_args():
'hyperparameters',
'|param|value|\n|-|-|\n%s'%('\n'.join([f'|{key}|{value}' for key, value in vars(args).items()])),
)

# --- load precomputed sparse mappings and maps --------------------------
# These files must exist in working dir. They are moved to args.device below.
cross_grid_to_log = torch.load('scatter_to_log_128.pt').to(args.device)
cross_grid_forward_from_log = torch.load('forward_from_log_128.pt').to(args.device)
cross_grid_from_log = torch.load('scatter_from_log_128.pt').to(args.device)
cross_grid_backward = torch.load('sparse_grid_fracs_euclid_backward.pt').to(args.device)

# convergence maps
source_convergence_map = torch.load('source_convergence_map.pt').to(args.device)
image_convergence_map = torch.load('image_convergence_map.pt').to(args.device)

# Get the script directory
script_dir = os.path.dirname(os.path.abspath(__file__))

# Define the folder where the grids are kept
grid_dir = os.path.join(script_dir, 'grid_matrices')

# Load the files from that specific folder
cross_grid_to_log = torch.load(os.path.join(grid_dir, 'scatter_to_log_128.pt')).to(args.device)
cross_grid_forward_from_log = torch.load(os.path.join(grid_dir, 'forward_from_log_128.pt')).to(args.device)
cross_grid_from_log = torch.load(os.path.join(grid_dir, 'scatter_from_log_128.pt')).to(args.device)
cross_grid_backward = torch.load(os.path.join(grid_dir, 'sparse_grid_fracs_euclid_backward.pt')).to(args.device)

# Check if these maps are also in grid_matrices or the root
source_convergence_map = torch.load(os.path.join(grid_dir, 'source_convergence_map.pt')).to(args.device)
image_convergence_map = torch.load(os.path.join(grid_dir, 'image_convergence_map.pt')).to(args.device)

# --- PSF kernel setup ---------------------------------------------------
# gaussian_kernel returns (Z, X, Y) as numpy arrays in the previous file.
# Converting to torch.tensor is fine but be mindful of dtype/device.
Expand Down