diff --git a/src/pyslice/backend.py b/src/pyslice/backend.py index 74ba19e..61b1500 100644 --- a/src/pyslice/backend.py +++ b/src/pyslice/backend.py @@ -112,12 +112,23 @@ def zeros(dims, dtype=DEFAULT_FLOAT_DTYPE, device=DEFAULT_DEVICE): array = xp.zeros(dims, dtype=dtype) return array +def ones(dims, dtype=DEFAULT_FLOAT_DTYPE, device=DEFAULT_DEVICE): + if xp != np: + return xp.ones(dims, dtype=dtype, device=device) + else: + return xp.ones(dims, dtype=dtype) + def fftfreq(n, d, dtype=DEFAULT_FLOAT_DTYPE, device=DEFAULT_DEVICE): if xp != np: return xp.fft.fftfreq(n, d, dtype=dtype, device=device) else: return xp.fft.fftfreq(n, d, dtype=dtype) +def expand_dims(ary,d): + if xp != np: + return xp.unsqueeze(ary,dim=d) + else: + return np.expand_dims(ary,d) def exp(x): return xp.exp(x) diff --git a/src/pyslice/multislice/calculators.py b/src/pyslice/multislice/calculators.py index 47dcec9..cc4e63a 100644 --- a/src/pyslice/multislice/calculators.py +++ b/src/pyslice/multislice/calculators.py @@ -34,6 +34,7 @@ from .trajectory import Trajectory from ..postprocessing.wf_data import WFData from .sed import SED +from ..backend import zeros,expand_dims logger = logging.getLogger(__name__) @@ -79,7 +80,8 @@ def __init__(self, device=None, force_cpu=False): } def _generate_cache_key(self, trajectory, aperture, voltage_eV, - slice_thickness, sampling, probe_positions): + slice_thickness, sampling, probe_positions, + spatial_decoherence, temporal_decoherence): """Generate unique cache key for simulation parameters.""" firstNAtoms = [ str(np.round(v,4)) for v in trajectory.positions[0,:100,0] ] # first timestep's first 10 atom's x positions params = { @@ -95,6 +97,10 @@ def _generate_cache_key(self, trajectory, aperture, voltage_eV, 'probe_positions': probe_positions, 'backend': 'pytorch' if TORCH_AVAILABLE else 'numpy', } + if spatial_decoherence is not None: + params['spatial_decoherence'] = spatial_decoherence + if temporal_decoherence is not None: + params['temporal_decoherence'] = temporal_decoherence param_str = str(sorted(params.items())) return hashlib.md5(param_str.encode()).hexdigest()[:12] @@ -149,13 +155,6 @@ def setup( self.cache_levels = cache_levels self.max_kx = max_kx self.max_ky = max_ky - - # Generate cache key and setup output directory - cache_key = self._generate_cache_key(trajectory, aperture, voltage_eV, - slice_thickness, sampling, probe_positions) - #print(cache_key) - self.output_dir = Path("psi_data/" + ("torch" if TORCH_AVAILABLE else "numpy") + "_"+cache_key) - self.output_dir.mkdir(parents=True, exist_ok=True) # Set up spatial grids xs,ys,zs,lx,ly,lz=gridFromTrajectory(trajectory,sampling=sampling,slice_thickness=slice_thickness) @@ -176,23 +175,13 @@ def setup( self.kys = self.kys[self.j1:self.j2] self.nx = self.i2 - self.i1 ; self.ny = self.j2 - self.j1 ; nx = self.nx ; ny = self.ny - # Preferred to pass probe_xs and probe_ys from which we will define a grid - if self.probe_xs is not None and self.probe_ys is not None: - x,y = np.meshgrid(self.probe_xs,self.probe_ys,indexing='ij') - self.probe_positions = np.asarray(list(zip(x.flat,y.flat))) - lx = len(self.probe_xs) ; ly = len(self.probe_ys) - - # Set up default probe position if not provided - if self.probe_positions is None: - self.probe_positions = [(lx/2, ly/2)] # Center probe - self.probe_xs = [lx/2] ; self.probe_ys = [ly/2] - # Create probe on the correct device from the start - self.base_probe = Probe(xs, ys, self.aperture, self.voltage_eV, device=self.device) + self.base_probe = Probe(xs, ys, self.aperture, self.voltage_eV, device=self.device, probe_xs=self.probe_xs, probe_ys=self.probe_ys, probe_positions=self.probe_positions) + self.base_probe.applyShifts() # Initialize storage for results self.n_frames = trajectory.n_frames - self.n_probes = len(self.probe_positions) + #self.n_probes = len(self.base_probe.probe_positions) # Set dtype based on the actual device we're using if TORCH_AVAILABLE and self.device is not None: @@ -206,17 +195,26 @@ def setup( self.complex_dtype = np.complex128 self.float_dtype = np.float64 - # Storage: [probe, frame, x, y, layer] - matches WFData expected format - n_layers = nz if "slices" in cache_levels else 1 - if TORCH_AVAILABLE and self.device is not None: - self.wavefunction_data = torch.zeros((self.n_probes, self.n_frames, nx, ny, n_layers), - dtype=self.complex_dtype, device=self.device) - else: - self.wavefunction_data = np.zeros((self.n_probes, self.n_frames, nx, ny, n_layers), - dtype=self.complex_dtype) def run(self) -> WFData: + + # Generate cache key and setup output directory + cache_key = self._generate_cache_key(self.trajectory, self.aperture, self.voltage_eV, + self.slice_thickness, self.sampling, self.probe_positions, + self.base_probe.spatial_decoherence, self.base_probe.temporal_decoherence ) + #print(cache_key) + self.output_dir = Path("psi_data/" + ("torch" if TORCH_AVAILABLE else "numpy") + "_"+cache_key) + self.output_dir.mkdir(parents=True, exist_ok=True) + + + nc,npt,nx,ny = self.base_probe._array.shape + self.n_probes = nc*npt + # Storage: [probe, frame, x, y, layer] - matches WFData expected format + self.n_layers = self.nz if "slices" in self.cache_levels else 1 + self.wavefunction_data = zeros((self.n_probes, self.n_frames, nx, ny, self.n_layers), + dtype=self.complex_dtype, device=self.device) + # Process frames with caching and multiprocessing total_start_time = time.time() frames_computed = 0 @@ -230,35 +228,61 @@ def run(self) -> WFData: with tqdm(total=self.n_frames, desc="Processing frames", unit="frame") as pbar: for frame_idx in range(self.n_frames): cache_file = self.output_dir / f"frame_{frame_idx}.npy" + # Show detailed progress for single-frame runs + show_progress = (frame_idx == 0 and self.n_frames == 1) + positions = self.trajectory.positions[frame_idx] atom_types = self.trajectory.atom_types - - args = [ frame_idx, positions, atom_types, self.xs, self.ys, self.zs, - self.aperture, self.voltage_eV, self.base_probe, self.probe_positions, self.element_map, - cache_file, self.cache_levels, self.slice_axis, self.device ] - - # Process frame - if frame_idx == 0 and self.n_frames == 1: - args[0] = -1 - - frame_idx_result, frame_data, was_cached = _process_frame_worker_torch(args) - - # crop frame's diffraction image - frame_data = frame_data[:,self.i1:self.i2,self.j1:self.j2,:,:] - - # Store result - for probe_idx in range(self.n_probes): - if "slices" in self.cache_levels: - # frame_data shape: (n_probes, nx, ny, n_slices, 1) - self.wavefunction_data[probe_idx, frame_idx, :, :, :] = frame_data[probe_idx, :, :, :, 0] + atom_type_names = [] + for atom_type in atom_types: + if atom_type in self.element_map: + atom_type_names.append(self.element_map[atom_type]) else: - self.wavefunction_data[probe_idx, frame_idx, :, :, 0] = frame_data[probe_idx, :, :, 0, 0] - - if was_cached: + atom_type_names.append(atom_type) + + # frame_data should always be shaped: n_probes,nkx,nky,n_layers,1 (idk why there's a trailing 1) + cache_exists,frame_data = checkCache(cache_file,self.cache_levels) + + if cache_exists: frames_cached += 1 else: + potential = Potential(self.xs, self.ys, self.zs, positions, atom_type_names, kind="kirkland", device=self.device, slice_axis=self.slice_axis, progress=show_progress, cache_dir=cache_file.parent if "potentials" in self.cache_levels else None, frame_idx = frame_idx) + + #n_probes = nc*npt + nc,npt,nx,ny = self.base_probe._array.shape + n_slices = len(self.zs) + + #batched_probes = create_batched_probes(self.base_probe, self.probe_positions, self.device) + # Propagate returns: [l,p,x,y] where l,p are both optional (if store_all_slices=True, and if n_probes>1) + exit_waves_batch = Propagate(self.base_probe, potential, self.device, progress=show_progress, onthefly=True, store_all_slices = ("slices" in self.cache_levels) ) + #print(exit_waves_batch.shape) + #print(exit_waves_batch.shape) + #if n_probes==1 and "slices" not in self.cache_levels: + # exit_waves_batch = expand_dims(exit_waves_batch,0) + if "slices" not in self.cache_levels: + exit_waves_batch = expand_dims(exit_waves_batch,0) + #print(exit_waves_batch.shape) + # frame_data is always: p,x,y,l,1 (self.wavefunction_data expects p,t,x,y,l, since we loop time. recall Propagate gave l,p,x,y) + frame_data = zeros((self.n_probes, nx, ny, self.n_layers,1), dtype=self.complex_dtype, device=self.device) + #print(frame_data.shape) + for layer_idx in range(self.n_layers): + kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)} + exit_waves_k = xp.fft.fft2(exit_waves_batch[layer_idx,:,:,:], **kwarg) # l,p,x,y --> p,x,y + diffraction_patterns = xp.fft.fftshift(exit_waves_k, **kwarg) + cropped = diffraction_patterns[:,self.i1:self.i2,self.j1:self.j2] + frame_data[:,:,:,layer_idx,0] = cropped # load p,x,y --> p,x,y,l,1 indices + + # Convert to CPU numpy array for saving + if TORCH_AVAILABLE and hasattr(frame_data, 'cpu'): + frame_data_cpu = frame_data.cpu().numpy() + else: + frame_data_cpu = frame_data + + if "exitwaves" in self.cache_levels or "slices" in self.cache_levels: + np.save(cache_file, frame_data_cpu) frames_computed += 1 - + + self.wavefunction_data[:, frame_idx, :, :, :] = frame_data[:, :, :, :, 0] # load p,x,y,l,1 --> p,t,x,y,l indices # Update progress bar for this frame pbar.update(1) @@ -291,7 +315,7 @@ def run(self) -> WFData: # Package results wf_data = WFData( - probe_positions=self.probe_positions, + probe_positions=self.base_probe.probe_positions, probe_xs=self.probe_xs, probe_ys=self.probe_ys, time=time_array, @@ -322,137 +346,17 @@ def run(self) -> WFData: # Save if requested - psi files already saved during processing return wf_data - logging_tracker=[] -def _process_frame_worker_torch(args): - frame_idx, positions, atom_types, xs, ys, zs, aperture, eV, probe, probe_positions, element_map, cache_file, cache_levels , slice_axis, device = args - +def checkCache(cache_file,cache_levels): + global logging_tracker if cache_file.exists() and ( "exitwaves" in cache_levels or "slices" in cache_levels ): - global logging_tracker parent = str(cache_file.parent) if "cache_exists-"+parent not in logging_tracker: logging_tracker.append("cache_exists-"+parent) logging.warning("One or more frames reloaded from cache: "+str(cache_file.parent)) - return frame_idx, xp.asarray(np.load(cache_file)), True # if always saving as numpy, then must cast to torch array if re-reading cache file back in - - # Use the device passed from the calculator, or auto-detect if None - if TORCH_AVAILABLE: - if device is not None: - worker_device = device - else: - worker_device = torch.device('cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')) - - # Set dtype based on worker device - if worker_device.type == 'mps': - worker_complex_dtype = torch.complex64 - worker_float_dtype = torch.float32 - else: - worker_complex_dtype = torch.complex128 - worker_float_dtype = torch.float64 - else: - worker_device = None - worker_complex_dtype = np.complex128 - worker_float_dtype = np.float64 - - atom_type_names = [] - for atom_type in atom_types: - if atom_type in element_map: - atom_type_names.append(element_map[atom_type]) - else: - atom_type_names.append(atom_type) - - #try: - potential = Potential(xs, ys, zs, positions, atom_type_names, kind="kirkland", device=worker_device, slice_axis=slice_axis, progress=(frame_idx==-1), cache_dir=cache_file.parent if "potentials" in cache_levels else None, frame_idx = frame_idx) - - n_probes = len(probe_positions) - nx, ny = len(xs), len(ys) - n_slices = len(zs) - - batched_probes = create_batched_probes(probe, probe_positions, worker_device) - exit_waves_batch = Propagate(batched_probes, potential, worker_device, progress=(frame_idx==-1), onthefly=True, store_all_slices = ("slices" in cache_levels) ) - - if "slices" in cache_levels: - # exit_waves_batch shape: (n_slices, n_probes, nx, ny) - if TORCH_AVAILABLE and worker_device is not None: - frame_data = torch.zeros((n_probes, nx, ny, n_slices, 1), dtype=worker_complex_dtype, device=worker_device) - else: - frame_data = np.zeros((n_probes, nx, ny, n_slices, 1), dtype=worker_complex_dtype) - - # Convert all slices to k-space - for slice_idx in range(n_slices): - slice_waves = exit_waves_batch[slice_idx, :, :, :] # (n_probes, nx, ny) - kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)} - waves_k = xp.fft.fft2(slice_waves, **kwarg) - diffraction_patterns = xp.fft.fftshift(waves_k, **kwarg) - - # Store in frame_data - for i in range(n_probes): - frame_data[i, :, :, slice_idx, 0] = diffraction_patterns[i, :, :] - else: - # exit_waves_batch shape: (n_probes, nx, ny) - if TORCH_AVAILABLE and worker_device is not None: - frame_data = torch.zeros((n_probes, nx, ny, 1, 1), dtype=worker_complex_dtype, device=worker_device) - else: - frame_data = np.zeros((n_probes, nx, ny, 1, 1), dtype=worker_complex_dtype) - - # Convert all exit waves to k-space - kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)} - exit_waves_k = xp.fft.fft2(exit_waves_batch, **kwarg) - diffraction_patterns = xp.fft.fftshift(exit_waves_k, **kwarg) - - # Store results - frame_data[:, :, :, 0, 0] = diffraction_patterns #.cpu().numpy() - #else: - # # Fallback to individual processing - # for probe_idx, (px, py) in enumerate(probe_positions): - # shifted_probe = probe.copy() - # - # probe_k = torch.fft.fft2(shifted_probe.array) - # - # kx_shift = torch.exp(2j * torch.pi * shifted_probe.kxs[:, None] * px) - # ky_shift = torch.exp(2j * torch.pi * shifted_probe.kys[None, :] * py) - # probe_k_shifted = probe_k * kx_shift * ky_shift - # - # shifted_probe.array = torch.fft.ifft2(probe_k_shifted) - # - # exit_wave_torch = PropagateTorch(shifted_probe, potential, worker_device) - # - # exit_wave_k = torch.fft.fft2(exit_wave_torch) - # diffraction_pattern = torch.fft.fftshift(exit_wave_k) - # - # frame_data[probe_idx, :, :, 0, 0] = diffraction_pattern.cpu().numpy() - - # Convert to CPU numpy array for saving - if TORCH_AVAILABLE and hasattr(frame_data, 'cpu'): - frame_data_cpu = frame_data.cpu().numpy() - else: - frame_data_cpu = frame_data - - if "exitwaves" in cache_levels or "slices" in cache_levels: - np.save(cache_file, frame_data_cpu) - - return frame_idx, frame_data, False - - #except Exception as e: - # logger.error(f"Error processing frame {frame_idx} with PyTorch: {e}") - # from .potential import Potential - # from .multislice_npy import Probe, Propagate - # - # potential = Potential(xs, ys, zs, positions, atom_type_names, kind="kirkland") - # probe = Probe(xs, ys, aperture, eV) - # - # n_probes = len(probe_positions) - # nx, ny = len(xs), len(ys) - ## frame_data = np.zeros((n_probes, nx, ny, 1, 1), dtype=complex) - # - # for probe_idx, (px, py) in enumerate(probe_positions): - # exit_wave = Propagate(probe, potential) - # diffraction_pattern = np.fft.fftshift(np.fft.fft2(exit_wave)) - # frame_data[probe_idx, :, :, 0, 0] = diffraction_pattern - # - # np.save(cache_file, frame_data) - # return frame_idx, frame_data, False + return True,xp.asarray(np.load(cache_file)) # if always saving as numpy, then must cast to torch array if re-reading cache file back in + return False,0 class SEDCalculator: diff --git a/src/pyslice/multislice/multislice.py b/src/pyslice/multislice/multislice.py index 8409d1a..145b06a 100644 --- a/src/pyslice/multislice/multislice.py +++ b/src/pyslice/multislice/multislice.py @@ -1,6 +1,7 @@ import numpy as np from tqdm import tqdm import logging +from ..backend import zeros,mean,ones try: import torch ; xp = torch @@ -49,7 +50,7 @@ class Probe: Significant speedup for large grid sizes through GPU-accelerated FFT operations. """ - def __init__(self, xs, ys, mrad, eV, array=None, device=None, gaussianVOA=0, preview=False): + def __init__(self, xs, ys, mrad, eV, array=None, device=None, gaussianVOA=0, preview=False, probe_xs=None, probe_ys=None, probe_positions=None): """ Initialize GPU-accelerated probe wavefunction. @@ -59,6 +60,7 @@ def __init__(self, xs, ys, mrad, eV, array=None, device=None, gaussianVOA=0, pre eV: Electron energy in eV device: PyTorch device (None for auto-detection) """ + # TORCH DEVICES AND DTYPES if TORCH_AVAILABLE: # Auto-detect device if not specified (same logic as Potential class) if device is None: @@ -84,10 +86,7 @@ def __init__(self, xs, ys, mrad, eV, array=None, device=None, gaussianVOA=0, pre self.dtype = np.float64 self.complex_dtype = np.complex128 - self.mrad = mrad - self.eV = eV - self.wavelength = wavelength(eV) - + # SET UP SPATIAL GRIDS # Convert coordinate arrays to tensors if using torch (same as Potential class) if self.use_torch: # Use as_tensor to avoid copy warning when input is already a tensor @@ -97,10 +96,40 @@ def __init__(self, xs, ys, mrad, eV, array=None, device=None, gaussianVOA=0, pre self.xs = xs self.ys = ys - nx = len(xs) - ny = len(ys) - dx = xs[1] - xs[0] - dy = ys[1] - ys[0] + nx = len(xs) ; ny = len(ys) + dx = xs[1] - xs[0] ; dy = ys[1] - ys[0] + lx = nx*dx ; ly = ny*dy + self.nx = nx ; self.dx = dx ; self.lx = lx + self.ny = ny ; self.dy = dy ; self.ly = ly + + # HANDLE PROBE POSTIONS + self.probe_xs = probe_xs + self.probe_ys = probe_ys + self.probe_positions = probe_positions + + # Preferred to pass probe_xs and probe_ys from which we will define a grid + if self.probe_xs is not None and self.probe_ys is not None: + x,y = np.meshgrid(self.probe_xs,self.probe_ys,indexing='ij') + self.probe_positions = np.asarray(list(zip(x.flat,y.flat))) + + # Set up default probe position if not provided + if self.probe_positions is None: + self.probe_positions = [(lx/2, ly/2)] # Center probe + self.probe_xs = [lx/2] ; self.probe_ys = [ly/2] + + # HANDLE BEAM PARAMS + self.mrad = mrad + #if isinstance(eV,(float,int)): + # n = 1 if array is None else len(array) + # eV = [ eV ]*n + self.eV = eV ; self.wavelength=wavelength(eV) + self.eVs = np.asarray([eV]) + if self.use_torch: + self.eVs = torch.as_tensor(self.eVs, dtype=self.dtype, device=self.device) + self.wavelengths = wavelength(self.eVs) + self.temporal_decoherence = None + self.spatial_decoherence = None + self.gaussianVOA = gaussianVOA # Set up device kwargs for unified xp interface (same as Potential class) device_kwargs = {'device': self.device, 'dtype': self.dtype} if self.use_torch else {} @@ -113,36 +142,40 @@ def __init__(self, xs, ys, mrad, eV, array=None, device=None, gaussianVOA=0, pre self._array = array.to(device=self.device, dtype=self.complex_dtype) else: self._array = xp.asarray(array) - return - - device_kwargs = {'device': self.device, 'dtype': self.dtype} if self.use_torch else {} + else: + #self._array = zeros((len(self.eV),1,nx,ny)) + #for i,w in enumerate(self.wavelength): + # self._array[i,0,:,:] = self.generate_single_probe(mrad,w,gaussianVOA,preview=preview) + self._array = zeros((1,1,nx,ny),dtype=complex_dtype) + self._array[0,0,:,:] = self.generate_single_probe(mrad,self.wavelength,self.gaussianVOA,preview=preview) + + def generate_single_probe(self,mrad,wavelength,gaussianVOA,preview=False): + nx,ny = len(self.kxs) , len(self.kys) if mrad == 0: - self._array = xp.ones((nx, ny), **device_kwargs) + return zeros((nx, ny))+1 + + reciprocal = zeros((nx, ny)) + radius = (mrad * 1e-3) / wavelength # Convert mrad to reciprocal space units + kx_grid, ky_grid = xp.meshgrid(self.kxs, self.kys, indexing='ij') + radii = xp.sqrt(kx_grid**2 + ky_grid**2) + + if gaussianVOA == 0: + mask = radii < radius + reciprocal[mask] = 1.0 else: - reciprocal = xp.zeros((nx, ny), **device_kwargs) - radius = (mrad * 1e-3) / self.wavelength # Convert mrad to reciprocal space units - - kx_grid, ky_grid = xp.meshgrid(self.kxs, self.kys, indexing='ij') - radii = xp.sqrt(kx_grid**2 + ky_grid**2) - - if gaussianVOA == 0: - mask = radii < radius - reciprocal[mask] = 1.0 - else: - from scipy.special import erf - reciprocal = 1-erf((radii-radius)/(gaussianVOA*radius)) - - if preview: - import matplotlib.pyplot as plt - fig, ax = plt.subplots() ; print(radius) - extent = (xp.min(self.kxs), xp.max(self.kxs), xp.min(self.kys), xp.max(self.kys)) - ax.imshow(xp.fft.fftshift(reciprocal.T), cmap="inferno",extent=extent) - ax.set_xlabel("kx ($\\AA^{-1}$)") - ax.set_ylabel("ky ($\\AA^{-1}$)") - plt.show() - - self._array = xp.fft.ifftshift(xp.fft.ifft2(reciprocal)) - + from scipy.special import erf + reciprocal = 1-erf((radii-radius)/(gaussianVOA*radius)) + + if preview: + import matplotlib.pyplot as plt + fig, ax = plt.subplots() ; print(radius) + extent = (xp.min(self.kxs), xp.max(self.kxs), xp.min(self.kys), xp.max(self.kys)) + ax.imshow(xp.fft.fftshift(reciprocal.T), cmap="inferno",extent=extent) + ax.set_xlabel("kx ($\\AA^{-1}$)") + ax.set_ylabel("ky ($\\AA^{-1}$)") + plt.show() + + return xp.fft.ifftshift(xp.fft.ifft2(reciprocal)) #self.array_numpy = self.array.cpu().numpy() def copy(self): @@ -192,15 +225,13 @@ def to_device(self, device): self.complex_dtype = complex_dtype return self - def plot(self,filename=None): + def plot(self,filename=None,title=None): import matplotlib.pyplot as plt fig, ax = plt.subplots() - array = self.array.T # imshow convention: y,x. our convention: x,y - - # Convert array to CPU if on GPU/MPS device + # calling self.array should convert to CPU/numpy + array = np.mean(np.absolute(self.array[:,:,:,:]),axis=0)[0,:,:] # summable,positional,x,y indices + array=array.T # imshow convention: y,x. our convention: x,y plot_array = np.absolute(array)**.25 - #if hasattr(plot_array, 'cpu'): - # plot_array = plot_array.cpu() # Convert extent values to CPU if needed (use xp for torch/numpy compatibility) xs_min = xp.amin(self.xs) @@ -218,6 +249,8 @@ def plot(self,filename=None): ax.imshow(plot_array, cmap="inferno",extent=extent) ax.set_xlabel("x ($\\AA$)") ax.set_ylabel("y ($\\AA$)") + if title is not None: + ax.set_title(title) if filename is not None: plt.savefig(filename) @@ -225,13 +258,82 @@ def plot(self,filename=None): plt.show() def defocus(self,dz): # POSITIVE DEFOCUS PUTS BEAM WAIST ABOVE SAMPLE, UNITS OF ANGSTROM + if isinstance(dz,(int,float)): + dz = zeros(len(self._array))+dz kx_grid, ky_grid = xp.meshgrid(self.kxs, self.kys, indexing='ij') k_squared = kx_grid**2 + ky_grid**2 - P = xp.exp(-1j * xp.pi * self.wavelength * dz * k_squared) - #if dz>0: - self._array = xp.fft.ifft2( P * xp.fft.fft2( self._array ) ) - #if dz<0: - # self.array = xp.fft.ifft2( xp.fft.fft2( self.array ) / P ) + P = xp.exp(-1j * xp.pi * self.wavelength * dz[:,None,None] * k_squared[None,:,:]) + nz = len(dz) ; nc,npt,nx,ny = self._array.shape #; print("nc,npt,nx,ny",nc,npt,nx,ny) + self._array = xp.fft.ifft2( P[:,None,None,:,:] * xp.fft.fft2( self._array )[None,:,:,:,:] ) + self._array = self._array.reshape((nz*nc,npt,nx,ny)) + #print("defocus",dz,"new shape",self._array.shape) + + # ORDER OF OPERATIONS IS IMPORTANT. MUST DO: addTemporalDecoherence, addSpatialDecoherence, create_batched_probes + # addTemporalDecoherence - creates new standard probes (must come first) + # addSpatialDecoherence - applies defocus (applies to existing probe(s)) + # create_batched_probes - applied shift to each probe + def addTemporalDecoherence(self,sigma_eV,N): + nc,npt,nx,ny = self._array.shape #; print("addTemporalDecoherence shape was",nc,npt,nx,ny) + if self.temporal_decoherence is not None: + print("WARNING: calling addTemporalDecoherence twice will overwrite previous") + self.temporal_decoherence = (sigma_eV,N) + eV = self.eV + self.eVs = np.linspace(eV-2*sigma_eV,eV+2*sigma_eV,N) + if self.use_torch: + self.eVs = torch.as_tensor(self.eVs, dtype=self.dtype, device=self.device) + self.wavelengths = wavelength(self.eVs) + amplitudes = np.exp(-(eV-self.eVs)**2/sigma_eV**2) + self._array = zeros((N,1,nx,ny)) + for n,eV in enumerate(self.eVs): + self._array[n,0,:,:] = amplitudes[n] * self.generate_single_probe(self.mrad,wavelength(eV),self.gaussianVOA) + nc,npt,nx,ny = self._array.shape #; print("addTemporalDecoherence expands to",nc,npt,nx,ny) + if self.spatial_decoherence is not None: + self.addSpatialDecoherence(*self.spatial_decoherence) + nc,npt,nx,ny = self._array.shape + if npt==1: + self.applyShifts() + + def addSpatialDecoherence(self,sigma_dz,N): + nc,npt,nx,ny = self._array.shape #; print("addSpatialDecoherence shape was",nc,npt,nx,ny) + if self.temporal_decoherence is not None: + print("WARNING: calling addSpatialDecoherence twice will overwrite previous") + self.spatial_decoherence = (sigma_dz,N) + dzs = np.linspace(-2*sigma_dz,2*sigma_dz,N) # suppose N=25 + amplitudes = np.exp(-dzs**2/sigma_dz**2) + nc,npt,nx,ny = self._array.shape # suppose nc=10 (addTemporalDecoherence created 10 wavelengths) + if self.use_torch: + dzs = torch.as_tensor(dzs, dtype=self.dtype, device=self.device) + self.defocus(dzs) # defocus starts with 25,10,npt,nx,ny --reshapes--> 250,npt,nx,ny + for i in range(N): # reshape to flatten loops first index last: [[0,1],[2,3]] --> [0,1,2,3] + for j in range(nc): + self._array[i*nc+j] *= amplitudes[i] + nc,npt,nx,ny = self._array.shape #; print("addSpatialDecoherence expands to",nc,npt,nx,ny) + self.eVs = ones(N)[:,None]*self.eVs[None,:] # defocus expands into nz,nc then flattens to nz*nc + self.eVs = self.eVs.reshape(nc) + self.wavelengths = ones(N)[:,None]*self.wavelengths[None,:] + self.wavelengths = self.wavelengths.reshape(nc) + if npt==1: + self.applyShifts() + + def applyShifts(self): + nc,npt,nx,ny = self._array.shape #; print("applyShifts shape was",nc,npt,nx,ny) + if npt>1: # TODO ALSO NEED SOMETHING TO DETERMINE IF SHIFTS HAVE ALREADY BEEN APPLIED. EG A LIST WHICH IS ALWAYS UPDATED WHEN ARRAY IS RESET? + return + self._array = self._array[:,0,None,:,:] * ones(len(self.probe_positions))[None,:,None,None] + for i, (px,py) in enumerate(self.probe_positions): + if px-self.lx/2 == 0 and py-self.ly/2 == 0: + continue + # Create shifted probe using phase ramp in k-space + probe_k = xp.fft.fft2(self._array[:,i,:,:]) # summable,positional,x,y + + # Apply phase ramp for spatial shift + kx_shift = xp.exp(2j * xp.pi * self.kxs[None,:, None] * (px-self.lx/2) ) + ky_shift = xp.exp(2j * xp.pi * self.kys[None,None, :] * (py-self.ly/2) ) + probe_k_shifted = probe_k * kx_shift * ky_shift + + # Convert back to real space + self._array[:,i,:,:] = xp.fft.ifft2(probe_k_shifted) + nc,npt,nx,ny = self._array.shape #; print("applyShifts expands to",nc,npt,nx,ny) def aberrate(self,aberrations): dP = aberrationFunction(self.kxs,self.kys,self.wavelength,aberrations) @@ -368,14 +470,19 @@ def Propagate(probe, potential, device=None, progress=False, onthefly=True, stor if device is not None and not TORCH_AVAILABLE: raise ImportError("PyTorch not available. Please install PyTorch.") + # Initialize wavefunction with probe(s) - shape: (n_probes, nx, ny) + nc,npt,nx,ny = probe._array.shape #; print("nc,npt,nx,ny",nc,npt,nx,ny) + array = probe._array.reshape((nc*npt,nx,ny)) # "flatten" first two indices + probe_wavelengths = probe.wavelengths[:,None]*ones(npt)[None,:] # also expand wavelengths and eVs arrays to cover all probe positions npt + probe_wavelengths = probe_wavelengths.reshape(nc*npt) + probe_eVs = probe.eVs[:,None]*ones(npt)[None,:] + probe_eVs = probe_eVs.reshape(nc*npt) - if len(probe._array.shape) == 2: - probe._array = probe._array[None,:,:] - # Calculate interaction parameter (Kirkland Eq 5.6) E0_eV = m_electron * c_light**2 / q_electron - sigma = (2 * np.pi) / (probe.wavelength * probe.eV) * \ - (E0_eV + probe.eV) / (2 * E0_eV + probe.eV) + sigma = (2 * np.pi) / (probe_wavelengths * probe_eVs) * \ + (E0_eV + probe_eVs) / (2 * E0_eV + probe_eVs) # wavelength and eVs now have length of n_probes + if TORCH_AVAILABLE: #sigma_dtype = torch.float32 if device.type == 'mps' else torch.float64 sigma = torch.tensor(sigma, dtype=float_dtype, device=device) @@ -383,14 +490,11 @@ def Propagate(probe, potential, device=None, progress=False, onthefly=True, stor # Get slice thickness dz = potential.zs[1] - potential.zs[0] if len(potential.zs) > 1 else 0.5 - # Initialize wavefunction with probe(s) - shape: (n_probes, nx, ny) - array = probe._array #.clone() - # Pre-compute propagation operator in k-space (Fresnel propagation) # All tensors should already be on the correct device from creation kx_grid, ky_grid = xp.meshgrid(potential.kxs, potential.kys, indexing='ij') k_squared = kx_grid**2 + ky_grid**2 - P = xp.exp(-1j * xp.pi * probe.wavelength * dz * k_squared) + P = xp.exp(-1j * xp.pi * probe_wavelengths[:,None,None] * dz * k_squared[None,:,:]) if progress: localtqdm = tqdm @@ -413,11 +517,11 @@ def localtqdm(iterator): potential_slice = potential.calculateSlice(z) else: potential_slice = potential._array[:, :, z] - t = xp.exp(1j * sigma * potential_slice) + t = xp.exp(1j * sigma[:,None,None] * potential_slice[None,:,:]) # Apply transmission to all probes: ψ' = t × ψ # Broadcasting: t[nx,ny] * array[n_probes,nx,ny] = array[n_probes,nx,ny] - array = t[None, :, :] * array + array = t * array # Store wavefunction at this slice if requested (after transmission) if store_all_slices: @@ -432,7 +536,7 @@ def localtqdm(iterator): # Vectorized FFT over spatial dimensions for all probes kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)} fft_array = xp.fft.fft2(array, **kwarg) - propagated_fft = P[None, :, :] * fft_array + propagated_fft = P * fft_array array = xp.fft.ifft2(propagated_fft, **kwarg) # Return results based on what was requested @@ -444,8 +548,10 @@ def localtqdm(iterator): else: return xp.stack(slice_wavefunctions, axis=0) + #array = array.reshape((nc,npt,nx,ny)) + # Return single probe result if input was single, otherwise return batch - if array.shape[0] == 1: - return array.squeeze(0) + #if array.shape[0] == 1: + # return array.squeeze(0) return array # okay for Propagate to return a Tensor. we probably don't want to move things off-gpu yet diff --git a/src/pyslice/multislice/potentials.py b/src/pyslice/multislice/potentials.py index 683b553..48e8134 100644 --- a/src/pyslice/multislice/potentials.py +++ b/src/pyslice/multislice/potentials.py @@ -286,7 +286,10 @@ def calculateSlice(slice_idx): if self.cache_dir is not None: cache_file = self.cache_dir / ("potential_"+str(frame_idx)+"_"+str(slice_idx)+".npy") if cache_file is not None and os.path.exists(cache_file): - return np.load(cache_file) + Z = np.load(cache_file) + if TORCH_AVAILABLE: + return xp.from_numpy(Z).to(device) + return Z # Initialize slice of potential array using xp with conditional device device_kwargs = {'device': self.device } if self.use_torch else {} @@ -356,7 +359,11 @@ def calculateSlice(slice_idx): dy = self.ys[1] - self.ys[0] Z = real / (dx**2 * dy**2) if cache_file is not None: - np.save(cache_file,Z) + if TORCH_AVAILABLE and hasattr(Z, 'cpu'): + Z_cpu = Z.cpu().numpy() + else: + Z_cpu = Z + np.save(cache_file, Z_cpu) return Z self.calculateSlice = calculateSlice diff --git a/src/pyslice/postprocessing/haadf_data.py b/src/pyslice/postprocessing/haadf_data.py index 1b91364..671f544 100644 --- a/src/pyslice/postprocessing/haadf_data.py +++ b/src/pyslice/postprocessing/haadf_data.py @@ -64,7 +64,7 @@ def __init__(self, wf_data: WFData) -> None: self.cache_dir = wf_data.cache_dir # Store reference to source WFData array for ADF calculation - self._wf_array = wf_data.reshaped # x,y,t,kx,ky,l indices + self._wf_array = wf_data.reshaped # nprobes,x,y,t,kx,ky,l indices # Initialize ADF as None, will be computed by calculateADF self._array = None @@ -186,16 +186,16 @@ def calculateADF(self, inner_mrad: float = 45, outer_mrad: float = 150, preview: # self._array[i, j] = collected - # recall self._wf_array is reshaped: p,t,kx,ky,l --> x,y,t,kx,ky,l + # recall self._wf_array is reshaped: p,t,kx,ky,l --> c,x,y,t,kx,ky,l if preview: import matplotlib.pyplot as plt fig, ax = plt.subplots() - preview_data = xp.mean(xp.absolute(self._wf_array),axis=(0,1,2,5))**.2 * (1 - mask) + preview_data = xp.mean(xp.absolute(self._wf_array),axis=(0,1,2,3,6))**.2 * (1 - mask) ax.imshow(np.asarray(preview_data), cmap="inferno") plt.show() - collected = self._wf_array * mask[None,None,None,:,:,None] # x,y,t,kx,ky,l indices, mask is kx,ky - self._array = xp.mean(xp.sum(xp.absolute(collected),axis=(3,4)),axis=(2,3)) # sum over kx,ky, mean over t,l + collected = self._wf_array * mask[None,None,None,:,:,None] # c,x,y,t,kx,ky,l indices, mask is kx,ky + self._array = xp.mean(xp.sum(xp.absolute(collected),axis=(4,5)),axis=(0,3,4)) # sum over kx,ky, mean over c,t,l # Update dimensions with computed xs, ys def to_numpy(x): diff --git a/src/pyslice/postprocessing/wf_data.py b/src/pyslice/postprocessing/wf_data.py index 33ed0ef..65365e0 100644 --- a/src/pyslice/postprocessing/wf_data.py +++ b/src/pyslice/postprocessing/wf_data.py @@ -6,7 +6,7 @@ from ..multislice.multislice import Probe,aberrationFunction from ..data import Signal, Dimensions, Dimension, GeneralMetadata from pathlib import Path -from ..backend import mean +from ..backend import mean,ones try: import torch ; xp = torch @@ -154,10 +154,11 @@ def array(self): @property def reshaped(self): # where self._array is indices probe,time,kx,ky,layer, we reshape to probe_x,probe_y,time,kx,ky,layer - npt,nt,nkx,nky,nl = self._array.shape - nx=len(self.probe_xs) - ny=len(self.probe_ys) - return xp.reshape(self._array,(nx,ny,nt,nkx,nky,nl)) + nc,nptp,nx,ny = self.probe._array.shape # recall: decoherence creates duplicate probes: num_copies,num_positions,x,y indices + npta,nt,nkx,nky,nl = self._array.shape # recall, Propagate flattens the first two, and adds time,layers: nc*npt,num_frames,x,y,nl indice + intermediate = xp.reshape(self._array,(nc,nptp,nt,nkx,nky,nl)) + nx,ny = len(self.probe.probe_xs),len(self.probe.probe_ys) + return xp.reshape(intermediate,(nc,nx,ny,nt,nkx,nky,nl)) @array.setter def array(self, value): @@ -234,7 +235,7 @@ def plot_reciprocal(self,filename=None,whichProbe="mean",whichTimestep="mean",po else: array = array[whichProbe] - if isinstance(whichTimestep,str) and whichProbe=="mean": + if isinstance(whichTimestep,str) and whichTimestep=="mean": array = mean(array,axis=0) # t,kx,ky --> kx,ky else: array = array[whichTimestep] @@ -358,22 +359,25 @@ def plot_phase(self,filename=None,whichProbe=0,whichTimestep=0,extent=None,avg=F else: plt.show() - def plot_realspace(self,whichProbe=0,whichTimestep=0,extent=None,avg=False,filename=None): + def plot_realspace(self,whichProbe="mean",whichTimestep="mean",extent=None,filename=None): import matplotlib.pyplot as plt fig, ax = plt.subplots() - # Get array (with or without averaging) - if avg: - array = self._array[whichProbe,:,:,:,-1] # Shape: (time, kx, ky) - if hasattr(array, 'mean'): # torch tensor - array = array.mean(dim=0) # Average over time dimension - else: # numpy array - array = np.mean(array, axis=0) + array = xp.fft.ifft2(self._array[:,:,:,:,-1]) + + array = xp.absolute(array) # probe, time, kx, ky, layer --> p,t,kx,ky + + if isinstance(whichProbe,str) and whichProbe=="mean": + array = mean(abs(array),axis=0) # p,t,kx,ky --> t,kx,ky else: - array = self._array[whichProbe,whichTimestep,:,:,-1] + array = array[whichProbe] + + if isinstance(whichTimestep,str) and whichTimestep=="mean": + array = mean(array,axis=0) # t,kx,ky --> kx,ky + else: + array = array[whichTimestep] array = array.T # imshow convention: y,x. our convention: x,y - array = xp.fft.ifft2(array) # Use provided extent or calculate from data if extent is None: @@ -403,6 +407,27 @@ def propagate_free_space(self,dz): # UNITS OF ANGSTROM #if dz>0: self._array = P[None,None,:,:,None] * self._array + def addSpatialDecoherence(self,sigma_dz,N): + dzs = np.linspace(-2*sigma_dz,2*sigma_dz,N) # suppose N=25 + amplitudes = np.exp(-dzs**2/sigma_dz**2) + self._array = self._array[:,None,:,:,:,:] * ones(N)[None,:,None,None,None,None] # n_probes,nt,nx,ny,nl --> + nc,npt,nt,nx,ny,nl = self._array.shape # suppose nc=10 (addTemporalDecoherence created 10 wavelengths) + kx_grid, ky_grid = xp.meshgrid(self._kxs, self._kys, indexing='ij') + k_squared = kx_grid**2 + ky_grid**2 + for i in range(N): + inner = xp.pi * self.probe.wavelength * dzs[i] * k_squared + P = xp.exp( -1j * inner ) # not sure why, but combining this and previous line triggers a "ComplexWarning: Casting complex values to real discards the imaginary part" in python 2.9.1 but not 2.2.2 + self._array[:,i,:,:,:,:] *= amplitudes[i]*P[None,None,:,:,None] + self._array = self._array.reshape(nc*npt,nt,nx,ny,nl) + #self.defocus(dzs) # defocus starts with 25,10,npt,nx,ny --reshapes--> 250,npt,nx,ny + #for i in range(N): # reshape to flatten loops first index last: [[0,1],[2,3]] --> [0,1,2,3] + # for j in range(nc): + # self._array[i*nc+j] *= amplitudes[i] + #nc,npt,nx,ny = self._array.shape + #if npt==1: + # self.applyShifts() + + def applyMask(self, radius, realOrReciprocal="reciprocal"): if realOrReciprocal == "reciprocal": radii = xp.sqrt( self._kxs[:,None]**2 + self._kys[None,:]**2 ) diff --git a/tests/18_caching.py b/tests/18_caching.py index 2f00f51..9d0a0a5 100644 --- a/tests/18_caching.py +++ b/tests/18_caching.py @@ -19,6 +19,9 @@ a,b=2.4907733333333337,2.1570729817355123 # cache_level options include: ["exitwaves","slices","potentials"] +tests = [1,2,3,4,5,5.5,6,7,8,9,10] +#tests = [9,10] +#tests=[5,5.5] # LOAD TRAJECTORY trajectory=Loader(dump,timestep=dt,atom_mapping=types).load() @@ -26,92 +29,115 @@ trajectory=trajectory.slice_positions([0,10*a],[0,10*b]) # ONE TIMESTEPS, ONE PROBE: -print("1. one timestep, one probe, normal caching") -traj1=trajectory.get_random_timesteps(11,seed=1) -calculator=MultisliceCalculator() -calculator.setup(traj1,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/01-test.npy","01") # p,t,x,y,l indices +if 1 in tests: + print("1. one timestep, one probe, normal caching") + traj1=trajectory.get_random_timesteps(11,seed=1) + calculator=MultisliceCalculator() + calculator.setup(traj1,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/01-test.npy","01") # p,t,x,y,l indices # ONE TIMESTEPS, ONE PROBE: -print("2. one timestep, one probe, cache potentials only") -traj2=trajectory.get_random_timesteps(11,seed=2) -calculator=MultisliceCalculator() -calculator.setup(traj2,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["potentials"]) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/02-test.npy","02") # p,t,x,y,l indices +if 2 in tests: + print("2. one timestep, one probe, cache potentials only") + traj2=trajectory.get_random_timesteps(11,seed=2) + calculator=MultisliceCalculator() + calculator.setup(traj2,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["potentials"]) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/02-test.npy","02") # p,t,x,y,l indices # ONE TIMESTEP, MANY PROBES: -print("3. one timestep, many probes, normal caching") -traj3=trajectory.get_random_timesteps(1,seed=3) -calculator=MultisliceCalculator() -probe_xs = np.linspace(a,3*a,14) -probe_ys = np.linspace(b,3*b,16) -calculator.setup(traj3,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys) -exitwaves = calculator.run() -differ(exitwaves.array[::5,:,::5,::5,:],"outputs/caching/03-test.npy","03") # p,t,x,y,l indices +if 3 in tests: + print("3. one timestep, many probes, normal caching") + traj3=trajectory.get_random_timesteps(1,seed=3) + calculator=MultisliceCalculator() + probe_xs = np.linspace(a,3*a,14) + probe_ys = np.linspace(b,3*b,16) + calculator.setup(traj3,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys) + exitwaves = calculator.run() + differ(exitwaves.array[::5,:,::5,::5,:],"outputs/caching/03-test.npy","03") # p,t,x,y,l indices # MANY TIMESTEPS, ONE PROBE: -print("4. many timesteps, one probe, normal caching") -traj4=trajectory.get_random_timesteps(10,seed=4) -calculator=MultisliceCalculator() -calculator.setup(traj4,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/04-test.npy","04") # p,t,x,y,l indices +if 4 in tests: + print("4. many timesteps, one probe, normal caching") + traj4=trajectory.get_random_timesteps(10,seed=4) + calculator=MultisliceCalculator() + calculator.setup(traj4,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/04-test.npy","04") # p,t,x,y,l indices # MANY TIMESTEPS, MANY PROBES: -print("5. many timesteps, many probes, normal caching") -traj5=trajectory.get_random_timesteps(5,seed=5) -calculator=MultisliceCalculator() -probe_xs = np.linspace(a,3*a,9) -probe_ys = np.linspace(b,3*b,10) -calculator.setup(traj5,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/05-test.npy","05") # p,t,x,y,l indices +if 5 in tests: + print("5. many timesteps, many probes, normal caching") + traj5=trajectory.get_random_timesteps(5,seed=5) + calculator=MultisliceCalculator() + probe_xs = np.linspace(a,3*a,9) + probe_ys = np.linspace(b,3*b,10) + calculator.setup(traj5,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/05-test.npy","05") # p,t,x,y,l indices + +# MANY TIMESTEPS, MANY PROBES: +if 5.5 in tests: + print("5.5 many timesteps, many probes, normal caching, with decoherence added") + traj55=trajectory.get_random_timesteps(5,seed=5) + calculator=MultisliceCalculator() + probe_xs = np.linspace(a,3*a,9) + probe_ys = np.linspace(b,3*b,10) + calculator.setup(traj55,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys) + calculator.base_probe.addSpatialDecoherence(1000,7) + exitwaves = calculator.run() + #differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/05-test.npy","05") # p,t,x,y,l indices + # CACHING TURNED OFF: -print("6. many timesteps, many probes, no caching") -traj6=trajectory.get_random_timesteps(5,seed=6) -calculator=MultisliceCalculator() -probe_xs = np.linspace(a,3*a,9) -probe_ys = np.linspace(b,3*b,10) -calculator.setup(traj6,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=[]) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/06-test.npy","06") # p,t,x,y,l indices +if 6 in tests: + print("6. many timesteps, many probes, no caching") + traj6=trajectory.get_random_timesteps(5,seed=6) + calculator=MultisliceCalculator() + probe_xs = np.linspace(a,3*a,9) + probe_ys = np.linspace(b,3*b,10) + calculator.setup(traj6,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=[]) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/06-test.npy","06") # p,t,x,y,l indices # OR WITH THE POTENTIAL SAVED OFF ONLY -print("7. many timesteps, one probe, caching potentials only") -traj7=trajectory.get_random_timesteps(10,seed=7) -calculator=MultisliceCalculator() -calculator.setup(traj7,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["potentials"]) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/07-test.npy","07") # p,t,x,y,l indices +if 7 in tests: + print("7. many timesteps, one probe, caching potentials only") + traj7=trajectory.get_random_timesteps(10,seed=7) + calculator=MultisliceCalculator() + calculator.setup(traj7,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["potentials"]) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/07-test.npy","07") # p,t,x,y,l indices # LAYERWISE CACHING -print("8. many timesteps, many probes, layerwise caching") -traj8=trajectory.get_random_timesteps(5,seed=8) -calculator=MultisliceCalculator() -probe_xs = np.linspace(a,3*a,6) -probe_ys = np.linspace(b,3*b,7) -calculator.setup(traj8,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=["slices"]) -exitwaves = calculator.run() -differ(exitwaves.array[:,::3,::20,::20,::5],"outputs/caching/08-test.npy","08") # p,t,x,y,l indices +if 8 in tests: + print("8. many timesteps, many probes, layerwise caching") + traj8=trajectory.get_random_timesteps(5,seed=8) + calculator=MultisliceCalculator() + probe_xs = np.linspace(a,3*a,6) + probe_ys = np.linspace(b,3*b,7) + calculator.setup(traj8,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=["slices"]) + exitwaves = calculator.run() + differ(exitwaves.array[:,::3,::20,::20,::5],"outputs/caching/08-test.npy","08") # p,t,x,y,l indices # LAYERWISE CACHING, WITH ONE PROBE -print("9. many timesteps, one probe, layerwise caching") -traj9=trajectory.get_random_timesteps(5,seed=9) -calculator=MultisliceCalculator() -calculator.setup(traj9,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["slices"]) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::5,::5,::5],"outputs/caching/09-test.npy","09") # p,t,x,y,l indices +if 9 in tests: + print("9. many timesteps, one probe, layerwise caching") + traj9=trajectory.get_random_timesteps(5,seed=9) + calculator=MultisliceCalculator() + calculator.setup(traj9,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["slices"]) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::5,::5,::5],"outputs/caching/09-test.npy","09") # p,t,x,y,l indices # LAYERWISE CACHING OR WITH ONE TIMESTEP -print("10. one timestep, many probes, layerwise caching") -traj10=trajectory.get_random_timesteps(1,seed=10) -calculator=MultisliceCalculator() -probe_xs = np.linspace(a,3*a,9) -probe_ys = np.linspace(b,3*b,10) -calculator.setup(traj10,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=["slices"]) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::10,::10,::5],"outputs/caching/10-test.npy","10") # p,t,x,y,l indices +if 10 in tests: + print("10. one timestep, many probes, layerwise caching") + traj10=trajectory.get_random_timesteps(1,seed=10) + calculator=MultisliceCalculator() + probe_xs = np.linspace(a,3*a,9) + probe_ys = np.linspace(b,3*b,10) + calculator.setup(traj10,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=["slices"]) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::10,::10,::5],"outputs/caching/10-test.npy","10") # p,t,x,y,l indices diff --git a/tests/19_coherence.py b/tests/19_coherence.py new file mode 100644 index 0000000..98f8a6a --- /dev/null +++ b/tests/19_coherence.py @@ -0,0 +1,115 @@ +import sys,os,time +try: + import pyslice +except ModuleNotFoundError: + print("import failed, falling back to relative paths") + sys.path.insert(0, '../src') +start=time.time() +from pyslice import Probe,Loader,MultisliceCalculator,HAADFData +import numpy as np +import matplotlib.pyplot as plt + +run = "TEM" + +dump="inputs/hBN_truncated.lammpstrj" +dt=.005 +types={1:"B",2:"N"} +a,b=2.4907733333333337,2.1570729817355123 + +# CHANGES SO FAR: +# probe.array is now fixed at summable,positional,x,y indices. previously, it was positional,x,y with positional optional +# calling Probe creates a base_probe +# probe_xs,probe_ys or probe_positions now passable to Probe (instead of this being set up inside MultisliceCalculator) +# Probe.__init__ calls applyShifts (which is like create_batched_probes from before) which expands the positional axis +# additional Probe internal variables Probe.eVs and Probe.wavelengths (plural) are arrays, with length matching the summable axis +# new function addTemporalDecoherence: expands summable axis by creating fresh base_probes of varying eV and varying wavelength +# Propagate is updated to handle varying eV and varying wavelength across multiple types of probes (summable axis) +# new function addSpatialDecoherence: expands summable axis (N*M for N samples spatial and M samples temporal) to defocus each existing probe +# Propagate also collapses probe._array from summable,positions,x,y --> summable*positions,x,y so all the same math, caching, etc, still works +# p.s., i decided to make eV NOT passable as a list. the user should use addTemporalDecoherence. this simplifies the logic, since addTemporalDecoherence is creating fresh probes and would then need to re-call addSpatialDecoherence if it was called previously +# TODO: +# caching excludes decoherence effects (since the calculator cache is calculated based on base_probe's n_probes) + +if run == "probes": + # Generate a few dummy probes + xs=np.linspace(0,50,501) + ys=np.linspace(0,49,491) + + probe = Probe(xs,ys,mrad=30,eV=100e3,gaussianVOA=.1,preview=True) + probe.plot(title="gauss, 30mrad") + + probe = Probe(xs,ys,mrad=30,eV=100e3) + + # temporal decoherence: a range of energies + eVs = np.linspace(80e3,120e3,25) ; amplitudes = np.exp(-(100e3-eVs)**2/10e3**2) + #plt.plot(eVs,amplitudes) ; plt.show() + # manually stack a list of probes' arrays + probes = [ Probe(xs,ys,mrad=30,eV=eV) for eV in eVs ] + probe._array = np.mean([ np.absolute(a*p._array) for a,p in zip(amplitudes,probes)],axis=0)[None,:,:,:] + probe.plot(title="manual stack eV") + # or, do it automatically: + probe = Probe(xs,ys,mrad=30,eV=100e3) + probe.addTemporalDecoherence(10e3,25) ; print(probe.array.shape) + probe.plot(title="auto stack eV") + + # spatial decoherence: a range of defocuses? + dZ = np.linspace(-400,400,27) ; amplitudes = np.exp(-(dZ)**2/200**2) + #plt.plot(dZ,amplitudes) ; plt.show() + probes = [ Probe(xs,ys,mrad=30,eV=100e3) for i in range(50) ] + #[ p.aberrate({"C10":d}) for p,d in zip(probes,dZ) ] + [ p.defocus(d) for p,d in zip(probes,dZ) ] + probe._array = np.mean([np.absolute(a*p._array) for a,p in zip(amplitudes,probes)],axis=0) + probe.plot(title="manual stack dZ") + # or, do it automatically: + probe = Probe(xs,ys,mrad=30,eV=100e3) + probe.addSpatialDecoherence(200,10) ; print(probe.array.shape) + probe.plot(title="auto stack dZ") + + # Or both, and let's check that order doesn't matter + probe = Probe(xs,ys,mrad=30,eV=100e3) + probe.addSpatialDecoherence(100,11) + probe.addTemporalDecoherence(10e3,9) + probe.plot(title="auto decohere eV and dZ") + +if run == "STEM": + trajectory=Loader(dump,timestep=dt,atom_mapping=types).load() + # TRIM TO 10x10 UC + trajectory=trajectory.slice_positions([0,10*a],[0,10*b]) + # SELECT 10 "RANDOM" TIMESTEPS (use seed for reproducibility) + trajectory=trajectory.get_random_timesteps(3,seed=5) + # CREATE CALCULATOR OBJECT + calculator=MultisliceCalculator() + # SET UP GRID OF HAADF SCAN POINTS + probe_xs = np.linspace(a,3*a,10) + probe_ys = np.linspace(b,3*b,9) + calculator.setup(trajectory,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys) + #calculator.base_probe.addTemporalDecoherence(10e3,9) + calculator.base_probe.addSpatialDecoherence(200,10) + # RUN MULTISLICE + exitwaves = calculator.run() + haadf=HAADFData(exitwaves) + ary=haadf.calculateADF(preview=False) # use preview=True to view the collection angles of the ADF detector in reciprocal space + haadf.plot() + +if run == "TEM": + # LOAD TRAJECTORY + trajectory=Loader(dump,timestep=dt,atom_mapping=types).load() + # TRIM TO 10x10 UC + trajectory=trajectory.slice_positions([0,10*a],[0,10*b]) + # SELECT 10 "RANDOM" TIMESTEPS (use seed for reproducibility) + trajectory=trajectory.get_random_timesteps(3,seed=5) + # CREATE CALCULATOR OBJECT + calculator=MultisliceCalculator() + calculator.setup(trajectory,aperture=0,voltage_eV=100e3,sampling=.1,slice_thickness=.5) + #calculator.base_probe.addTemporalDecoherence(30e3,25) + print(calculator.base_probe.eV) + #calculator.base_probe.addSpatialDecoherence(100,27) + # RUN MULTISLICE + exitwaves = calculator.run() + print(exitwaves.array.shape) + #exitwaves.propagate_free_space(10) + exitwaves.addSpatialDecoherence(10,27) + print(exitwaves.array.shape) + + exitwaves.plot_realspace()#filename="outputs/figs/04_haadf_cbed.png") +