Skip to content
5 changes: 5 additions & 0 deletions src/pyslice/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ def fftfreq(n, d, dtype=DEFAULT_FLOAT_DTYPE, device=DEFAULT_DEVICE):
else:
return xp.fft.fftfreq(n, d, dtype=dtype)

def expand_dims(ary, d):
    """Insert a length-1 axis into *ary* at position *d*.

    Dispatches to the active array backend ``xp``: NumPy's
    ``expand_dims`` when ``xp`` is NumPy, otherwise the torch
    equivalent ``unsqueeze``.
    """
    if xp is np:
        return np.expand_dims(ary, d)
    # torch backend: unsqueeze is torch's name for expand_dims
    return xp.unsqueeze(ary, dim=d)

def exp(x):
    """Element-wise exponential via the active array backend ``xp``."""
    result = xp.exp(x)
    return result
Expand Down
208 changes: 55 additions & 153 deletions src/pyslice/multislice/calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .trajectory import Trajectory
from ..postprocessing.wf_data import WFData
from .sed import SED
from ..backend import zeros,expand_dims

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -208,13 +209,9 @@ def setup(
self.float_dtype = np.float64

# Storage: [probe, frame, x, y, layer] - matches WFData expected format
n_layers = nz if "slices" in cache_levels else 1
if TORCH_AVAILABLE and self.device is not None:
self.wavefunction_data = torch.zeros((self.n_probes, self.n_frames, nx, ny, n_layers),
self.n_layers = nz if "slices" in cache_levels else 1
self.wavefunction_data = zeros((self.n_probes, self.n_frames, nx, ny, self.n_layers),
dtype=self.complex_dtype, device=self.device)
else:
self.wavefunction_data = np.zeros((self.n_probes, self.n_frames, nx, ny, n_layers),
dtype=self.complex_dtype)

def run(self) -> WFData:

Expand All @@ -231,35 +228,60 @@ def run(self) -> WFData:
with tqdm(total=self.n_frames, desc="Processing frames", unit="frame") as pbar:
for frame_idx in range(self.n_frames):
cache_file = self.output_dir / f"frame_{frame_idx}.npy"
# Show detailed progress for single-frame runs
show_progress = (frame_idx == 0 and self.n_frames == 1)

positions = self.trajectory.positions[frame_idx]
atom_types = self.trajectory.atom_types

args = [ frame_idx, positions, atom_types, self.xs, self.ys, self.zs,
self.aperture, self.voltage_eV, self.base_probe, self.probe_positions, self.element_map,
cache_file, self.cache_levels, self.slice_axis, self.device ]

# Process frame
if frame_idx == 0 and self.n_frames == 1:
args[0] = -1

frame_idx_result, frame_data, was_cached = _process_frame_worker_torch(args)

# crop frame's diffraction image
frame_data = frame_data[:,self.i1:self.i2,self.j1:self.j2,:,:]

# Store result
for probe_idx in range(self.n_probes):
if "slices" in self.cache_levels:
# frame_data shape: (n_probes, nx, ny, n_slices, 1)
self.wavefunction_data[probe_idx, frame_idx, :, :, :] = frame_data[probe_idx, :, :, :, 0]
atom_type_names = []
for atom_type in atom_types:
if atom_type in self.element_map:
atom_type_names.append(self.element_map[atom_type])
else:
self.wavefunction_data[probe_idx, frame_idx, :, :, 0] = frame_data[probe_idx, :, :, 0, 0]

if was_cached:
atom_type_names.append(atom_type)

# frame_data should always be shaped: n_probes,nkx,nky,n_layers,1 (idk why there's a trailing 1)
cache_exists,frame_data = checkCache(cache_file,self.cache_levels)

if cache_exists:
frames_cached += 1
else:
potential = Potential(self.xs, self.ys, self.zs, positions, atom_type_names, kind="kirkland", device=self.device, slice_axis=self.slice_axis, progress=show_progress, cache_dir=cache_file.parent if "potentials" in self.cache_levels else None, frame_idx = frame_idx)

n_probes = len(self.probe_positions)
nx, ny = len(self.xs), len(self.ys)
n_slices = len(self.zs)

batched_probes = create_batched_probes(self.base_probe, self.probe_positions, self.device)
# Propagate returns: [l,p,x,y] where l,p are both optional (if store_all_slices=True, and if n_probes>1)
exit_waves_batch = Propagate(batched_probes, potential, self.device, progress=show_progress, onthefly=True, store_all_slices = ("slices" in self.cache_levels) )
#print(exit_waves_batch.shape)
if n_probes==1 and "slices" not in self.cache_levels:
exit_waves_batch = expand_dims(exit_waves_batch,0)
if "slices" not in self.cache_levels:
exit_waves_batch = expand_dims(exit_waves_batch,0)
#print(exit_waves_batch.shape)
# frame_data is always: p,x,y,l,1 (self.wavefunction_data expects p,t,x,y,l, since we loop time. recall Propagate gave l,p,x,y)
frame_data = zeros((n_probes, nx, ny, self.n_layers,1), dtype=self.complex_dtype, device=self.device)
#print(frame_data.shape)
for layer_idx in range(self.n_layers):
kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)}
exit_waves_k = xp.fft.fft2(exit_waves_batch[layer_idx,:,:,:], **kwarg) # l,p,x,y --> p,x,y
diffraction_patterns = xp.fft.fftshift(exit_waves_k, **kwarg)
cropped = diffraction_patterns[:,self.i1:self.i2,self.j1:self.j2]
frame_data[:,:,:,layer_idx,0] = cropped # load p,x,y --> p,x,y,l,1 indices

# Convert to CPU numpy array for saving
if TORCH_AVAILABLE and hasattr(frame_data, 'cpu'):
frame_data_cpu = frame_data.cpu().numpy()
else:
frame_data_cpu = frame_data

if "exitwaves" in self.cache_levels or "slices" in self.cache_levels:
np.save(cache_file, frame_data_cpu)
frames_computed += 1


self.wavefunction_data[:, frame_idx, :, :, :] = frame_data[:, :, :, :, 0] # load p,x,y,l,1 --> p,t,x,y,l indices
# Update progress bar for this frame
pbar.update(1)

Expand Down Expand Up @@ -323,137 +345,17 @@ def run(self) -> WFData:
# Save if requested - psi files already saved during processing

return wf_data


logging_tracker=[]
def _process_frame_worker_torch(args):
frame_idx, positions, atom_types, xs, ys, zs, aperture, eV, probe, probe_positions, element_map, cache_file, cache_levels , slice_axis, device = args

def checkCache(cache_file,cache_levels):
global logging_tracker
if cache_file.exists() and ( "exitwaves" in cache_levels or "slices" in cache_levels ):
global logging_tracker
parent = str(cache_file.parent)
if "cache_exists-"+parent not in logging_tracker:
logging_tracker.append("cache_exists-"+parent)
logging.warning("One or more frames reloaded from cache: "+str(cache_file.parent))
return frame_idx, xp.asarray(np.load(cache_file)), True # if always saving as numpy, then must cast to torch array if re-reading cache file back in

# Use the device passed from the calculator, or auto-detect if None
if TORCH_AVAILABLE:
if device is not None:
worker_device = device
else:
worker_device = torch.device('cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))

# Set dtype based on worker device
if worker_device.type == 'mps':
worker_complex_dtype = torch.complex64
worker_float_dtype = torch.float32
else:
worker_complex_dtype = torch.complex128
worker_float_dtype = torch.float64
else:
worker_device = None
worker_complex_dtype = np.complex128
worker_float_dtype = np.float64

atom_type_names = []
for atom_type in atom_types:
if atom_type in element_map:
atom_type_names.append(element_map[atom_type])
else:
atom_type_names.append(atom_type)

#try:
potential = Potential(xs, ys, zs, positions, atom_type_names, kind="kirkland", device=worker_device, slice_axis=slice_axis, progress=(frame_idx==-1), cache_dir=cache_file.parent if "potentials" in cache_levels else None, frame_idx = frame_idx)

n_probes = len(probe_positions)
nx, ny = len(xs), len(ys)
n_slices = len(zs)

batched_probes = create_batched_probes(probe, probe_positions, worker_device)
exit_waves_batch = Propagate(batched_probes, potential, worker_device, progress=(frame_idx==-1), onthefly=True, store_all_slices = ("slices" in cache_levels) )

if "slices" in cache_levels:
# exit_waves_batch shape: (n_slices, n_probes, nx, ny)
if TORCH_AVAILABLE and worker_device is not None:
frame_data = torch.zeros((n_probes, nx, ny, n_slices, 1), dtype=worker_complex_dtype, device=worker_device)
else:
frame_data = np.zeros((n_probes, nx, ny, n_slices, 1), dtype=worker_complex_dtype)

# Convert all slices to k-space
for slice_idx in range(n_slices):
slice_waves = exit_waves_batch[slice_idx, :, :, :] # (n_probes, nx, ny)
kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)}
waves_k = xp.fft.fft2(slice_waves, **kwarg)
diffraction_patterns = xp.fft.fftshift(waves_k, **kwarg)

# Store in frame_data
for i in range(n_probes):
frame_data[i, :, :, slice_idx, 0] = diffraction_patterns[i, :, :]
else:
# exit_waves_batch shape: (n_probes, nx, ny)
if TORCH_AVAILABLE and worker_device is not None:
frame_data = torch.zeros((n_probes, nx, ny, 1, 1), dtype=worker_complex_dtype, device=worker_device)
else:
frame_data = np.zeros((n_probes, nx, ny, 1, 1), dtype=worker_complex_dtype)

# Convert all exit waves to k-space
kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)}
exit_waves_k = xp.fft.fft2(exit_waves_batch, **kwarg)
diffraction_patterns = xp.fft.fftshift(exit_waves_k, **kwarg)

# Store results
frame_data[:, :, :, 0, 0] = diffraction_patterns #.cpu().numpy()
#else:
# # Fallback to individual processing
# for probe_idx, (px, py) in enumerate(probe_positions):
# shifted_probe = probe.copy()
#
# probe_k = torch.fft.fft2(shifted_probe.array)
#
# kx_shift = torch.exp(2j * torch.pi * shifted_probe.kxs[:, None] * px)
# ky_shift = torch.exp(2j * torch.pi * shifted_probe.kys[None, :] * py)
# probe_k_shifted = probe_k * kx_shift * ky_shift
#
# shifted_probe.array = torch.fft.ifft2(probe_k_shifted)
#
# exit_wave_torch = PropagateTorch(shifted_probe, potential, worker_device)
#
# exit_wave_k = torch.fft.fft2(exit_wave_torch)
# diffraction_pattern = torch.fft.fftshift(exit_wave_k)
#
# frame_data[probe_idx, :, :, 0, 0] = diffraction_pattern.cpu().numpy()

# Convert to CPU numpy array for saving
if TORCH_AVAILABLE and hasattr(frame_data, 'cpu'):
frame_data_cpu = frame_data.cpu().numpy()
else:
frame_data_cpu = frame_data

if "exitwaves" in cache_levels or "slices" in cache_levels:
np.save(cache_file, frame_data_cpu)

return frame_idx, frame_data, False

#except Exception as e:
# logger.error(f"Error processing frame {frame_idx} with PyTorch: {e}")
# from .potential import Potential
# from .multislice_npy import Probe, Propagate
#
# potential = Potential(xs, ys, zs, positions, atom_type_names, kind="kirkland")
# probe = Probe(xs, ys, aperture, eV)
#
# n_probes = len(probe_positions)
# nx, ny = len(xs), len(ys)
## frame_data = np.zeros((n_probes, nx, ny, 1, 1), dtype=complex)
#
# for probe_idx, (px, py) in enumerate(probe_positions):
# exit_wave = Propagate(probe, potential)
# diffraction_pattern = np.fft.fftshift(np.fft.fft2(exit_wave))
# frame_data[probe_idx, :, :, 0, 0] = diffraction_pattern
#
# np.save(cache_file, frame_data)
# return frame_idx, frame_data, False
return True,xp.asarray(np.load(cache_file)) # if always saving as numpy, then must cast to torch array if re-reading cache file back in
return False,0


class SEDCalculator:
Expand Down
12 changes: 10 additions & 2 deletions src/pyslice/multislice/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path
import logging,os
from tqdm import tqdm
from ..backend import zeros

try:
import torch ; xp = torch
Expand Down Expand Up @@ -286,7 +287,10 @@ def calculateSlice(slice_idx):
if self.cache_dir is not None:
cache_file = self.cache_dir / ("potential_"+str(frame_idx)+"_"+str(slice_idx)+".npy")
if cache_file is not None and os.path.exists(cache_file):
return np.load(cache_file)
Z = np.load(cache_file)
if TORCH_AVAILABLE:
return xp.from_numpy(Z).to(device)
return Z

# Initialize slice of potential array using xp with conditional device
device_kwargs = {'device': self.device } if self.use_torch else {}
Expand Down Expand Up @@ -356,7 +360,11 @@ def calculateSlice(slice_idx):
dy = self.ys[1] - self.ys[0]
Z = real / (dx**2 * dy**2)
if cache_file is not None:
np.save(cache_file,Z)
if TORCH_AVAILABLE and hasattr(Z, 'cpu'):
Z_cpu = Z.cpu().numpy()
else:
Z_cpu = Z
np.save(cache_file,Z_cpu)
return Z

self.calculateSlice = calculateSlice
Expand Down
Loading