diff --git a/src/pyslice/multislice/calculators.py b/src/pyslice/multislice/calculators.py index 661e3c3..ee51669 100644 --- a/src/pyslice/multislice/calculators.py +++ b/src/pyslice/multislice/calculators.py @@ -125,6 +125,7 @@ def setup( cleanup_temp_files: bool = False, slice_axis: int = 2, cache_levels: list = ["exitwaves"], # options include: exitwaves, slices, potentials (this replaces store_all_slices) + cache_layer_indices: Optional[List[int]] = None, # NEW: subset of slice indices to store; None = store all layers max_kx = np.inf, max_ky = np.inf, use_memmap = False, @@ -151,6 +152,11 @@ def setup( save_path: Optional path to save wave function data cleanup_temp_files: Whether to delete temp files after loading store_all_slices: If True, store wavefunction at each slice for 3D visualization + cache_layer_indices: Optional list of slice-layer indices (0-based) to record when + cache_levels includes "slices". If None (default), all nz layers are stored. + Specifying a small subset (e.g. the 6 depths needed for EELS thickness series) + can reduce disk usage by >98% without affecting propagation accuracy. + Example: cache_layer_indices=[44, 88, 176, 264, 352, 440] """ self.trajectory = trajectory @@ -166,6 +172,7 @@ def setup( self.cleanup_temp_files = cleanup_temp_files self.slice_axis = slice_axis self.cache_levels = cache_levels + self.cache_layer_indices = cache_layer_indices # NEW: store for use in run() self.max_kx = max_kx self.max_ky = max_ky self.use_memmap = use_memmap # bool: frame_data (p,x,y,l,1) and wavefunction_data (p,t,x,y,l) will be memmapped instead of held in RAM @@ -293,6 +300,41 @@ def run(self) -> WFData: self.output_dir = Path("psi_data/" + ("torch" if TORCH_AVAILABLE else "numpy") + "_"+cache_key) self.output_dir.mkdir(parents=True, exist_ok=True) + # ── Resolve which layers to store ── + # NEW: if cache_layer_indices is set, only those layers are FFT'd and + # written to disk; the propagation itself still runs through all nz + # slices (physically required). cache_layer_indices=None keeps the + # original behaviour of storing every layer. + if "slices" in self.cache_levels and self.cache_layer_indices is not None: + # Validate and clip indices to [0, nz-1] + _requested = sorted(set(int(i) for i in self.cache_layer_indices)) + _dropped = [i for i in _requested if not (0 <= i < self.nz)] + _active_layers = [i for i in _requested if 0 <= i < self.nz] + if _dropped: + logger.warning( + f"cache_layer_indices: dropped out-of-range indices {_dropped} " + f"(nz={self.nz})" + ) + if not _active_layers: + raise ValueError( + "cache_layer_indices produced no valid layer indices after " + f"clipping to [0, {self.nz-1}]." + ) + logger.info( + f"Selective layer storage: recording {len(_active_layers)}/{self.nz} " + f"layers -> {_active_layers}" + ) + print( + f"[MultisliceCalculator] cache_layer_indices: storing " + f"{len(_active_layers)}/{self.nz} layers: {_active_layers}", + flush=True + ) + else: + # Default: store every layer (original behaviour) + _active_layers = list(range(self.nz)) if "slices" in self.cache_levels else [0] + + self._active_layers = _active_layers # expose for inspection / post-processing + # if probes are over vacuum (e.g. nanoparticles), we don't need to propagate them? self.probe_indices = xp.arange(len(self.probe_positions)) @@ -315,7 +357,8 @@ def run(self) -> WFData: nc,npt,nx,ny = self.base_probe._array.shape self.n_probes = nc*len(self.probe_positions) # Storage: [probe, frame, x, y, layer] - matches WFData expected format - self.n_layers = self.nz if "slices" in self.cache_levels else 1 + # CHANGED: n_layers is now len(_active_layers) instead of always self.nz + self.n_layers = len(_active_layers) if self.store_full: fd_nx = self.nx ; fd_ny = self.ny ; fd_npt = self.n_probes #if self.base_probe.cropping: @@ -406,6 +449,7 @@ def run(self) -> WFData: #print("create frame_data") ; start = time.time() # frame_data is always: p,x,y,l,1 (self.wavefunction_data expects p,t,x,y,l, since we loop time. recall Propagate gave l,p,x,y) + # CHANGED: last dim is self.n_layers = len(_active_layers), not nz if self.store_full or self.prism: fd_nx = self.nx ; fd_ny = self.ny ; fd_npt = self.n_probes #if self.base_probe.cropping: @@ -454,8 +498,13 @@ def run(self) -> WFData: exit_waves_single = expand_dims(exit_waves_single,0) if len(exit_waves_single.shape)==3 else exit_waves_single # FFT and load into frame_data kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)} - for layer_idx in range(self.n_layers): - exit_waves_k = xp.fft.fft2(exit_waves_single[layer_idx,:,:,:], **kwarg) # l,p,x,y --> p,x,y + + # CHANGED: iterate over (out_idx, real_layer_idx) pairs instead of + # range(self.n_layers). When cache_layer_indices=None, _active_layers + # is list(range(nz)) so out_idx == real_layer_idx and behaviour is + # identical to the original code. + for out_idx, real_layer_idx in enumerate(_active_layers): + exit_waves_k = xp.fft.fft2(exit_waves_single[real_layer_idx,:,:,:], **kwarg) # l,p,x,y --> p,x,y diffraction_patterns = xp.fft.fftshift(exit_waves_k, **kwarg) #if not self.prism: diffraction_patterns = diffraction_patterns[:,self.keep_kxs_indices,:][:,:,self.keep_kys_indices]*self.kth**2 @@ -463,7 +512,8 @@ def run(self) -> WFData: diffraction_patterns = to_cpu(diffraction_patterns) selected = to_cpu(selected) if self.store_full or self.prism: - frame_data[selected,:,:,layer_idx,0] = diffraction_patterns # load p,x,y --> p,x,y,l,1 indices + # CHANGED: write to compact slot out_idx, not real_layer_idx + frame_data[selected,:,:,out_idx,0] = diffraction_patterns # load p,x,y --> p,x,y,l,1 indices if self.ADF and not self.prism: #print(self.ADF._wf_array[0,:,:,0,0,0,0]) intensities = einsum('pxy,xy->p',absolute(diffraction_patterns[:,:,:])**2,self.ADFmask) @@ -550,7 +600,11 @@ def run(self) -> WFData: #kxs = xp.fft.fftshift(xp.fft.fftfreq(self.nx, self.sampling)) # k-space in 1/Å MOVING TO INIT SO WE CAN CROP ON-THE-FLY #kys = xp.fft.fftshift(xp.fft.fftfreq(self.ny, self.sampling)) # k-space in 1/Å time_array = np.arange(self.n_frames) * self.trajectory.timestep # Time array in ps - layer_array = np.arange(self.nz) if "slices" in self.cache_levels else np.array([0]) # Layer indices + + # CHANGED: layer_array now reflects the actual stored layer indices. + # When cache_layer_indices=None, _active_layers == list(range(nz)) so + # layer_array == np.arange(nz), identical to the original behaviour. + layer_array = np.array(_active_layers) if "slices" in self.cache_levels else np.array([0]) # Layer indices # Package results array = zeros((self.n_probes,1,1,1,1),dtype=self.complex_dtype) @@ -667,4 +721,3 @@ def plot(self,w,filename=None): # TODO MAYBE "RUN" SHOULD RETURN A TACAW OBJECT else: plt.show() -