From 797003373ba9baf9b3e08cdb9ef0fc509a88beae Mon Sep 17 00:00:00 2001 From: Haoran Ma <150745856+HaoranLMaoMao@users.noreply.github.com> Date: Thu, 7 May 2026 13:33:13 +0200 Subject: [PATCH 1/3] feat: add cache_layer_indices to reduce wavefunction storage Added cache_layer_indices parameter to control which slice layers are stored. Updated related documentation and logic to handle selective layer storage. --- src/pyslice/multislice/calculators.py | 151 +++++++++++--------------- 1 file changed, 65 insertions(+), 86 deletions(-) diff --git a/src/pyslice/multislice/calculators.py b/src/pyslice/multislice/calculators.py index 661e3c3..03636ca 100644 --- a/src/pyslice/multislice/calculators.py +++ b/src/pyslice/multislice/calculators.py @@ -125,6 +125,7 @@ def setup( cleanup_temp_files: bool = False, slice_axis: int = 2, cache_levels: list = ["exitwaves"], # options include: exitwaves, slices, potentials (this replaces store_all_slices) + cache_layer_indices: Optional[List[int]] = None, # NEW: subset of slice indices to store; None = store all layers max_kx = np.inf, max_ky = np.inf, use_memmap = False, @@ -151,6 +152,11 @@ def setup( save_path: Optional path to save wave function data cleanup_temp_files: Whether to delete temp files after loading store_all_slices: If True, store wavefunction at each slice for 3D visualization + cache_layer_indices: Optional list of slice-layer indices (0-based) to record when + cache_levels includes "slices". If None (default), all nz layers are stored. + Specifying a small subset (e.g. the 6 depths needed for EELS thickness series) + can reduce disk usage by >98% without affecting propagation accuracy. + Example: cache_layer_indices=[44, 88, 176, 264, 352, 440] """ self.trajectory = trajectory @@ -166,6 +172,7 @@ def setup( self.cleanup_temp_files = cleanup_temp_files self.slice_axis = slice_axis self.cache_levels = cache_levels + self.cache_layer_indices = cache_layer_indices # NEW: store for use in run() self.max_kx = max_kx self.max_ky = max_ky self.use_memmap = use_memmap # bool: frame_data (p,x,y,l,1) and wavefunction_data (p,t,x,y,l) will be memmapped instead of held in RAM @@ -199,7 +206,6 @@ def setup( self.keep_kxs_indices = xp.arange(self.nx)[kx_mask==1][::self.kth] self.keep_kys_indices = xp.arange(self.ny)[ky_mask==1][::self.kth] self.nx = len(self.keep_kxs_indices) ; self.ny = len(self.keep_kys_indices) - #print("kxs",len(self.kxs),"kys",len(self.kys),"nx",self.nx,"ny",self.ny,"keepkx",len(self.keep_kxs_indices),"keepky",len(self.keep_kys_indices)) # Preferred to pass probe_xs and probe_ys from which we will define a grid if self.probe_xs is not None and self.probe_ys is not None: @@ -230,7 +236,6 @@ def setup( # Initialize storage for results self.n_frames = trajectory.n_frames - #self.n_probes = len(self.base_probe.probe_positions) # Set dtype based on the actual device we're using if TORCH_AVAILABLE and self.device is not None: @@ -250,9 +255,7 @@ def setup( self.slice_thickness, self.sampling, self.probe_positions, self.base_probe.spatial_decoherence, self.base_probe.temporal_decoherence, self.base_probe._array) - #print(cache_key) self.output_dir = Path("psi_data/" + ("torch" if TORCH_AVAILABLE else "numpy") + "_"+self.cache_key) - #self.output_dir.mkdir(parents=True, exist_ok=True) def preview_probes(self): positions = self.trajectory.positions[0] @@ -271,7 +274,6 @@ def preview_probes(self): array = np.absolute(to_cpu(potential.array))[:,::-1,0].T # imshow convention: y,x. our convention: x,y, and flip y (0,0 upper-left) xs = to_cpu(potential.xs) ; ys = to_cpu(potential.ys) extent = (np.amin(xs),np.amax(xs),np.amin(ys),np.amax(ys)) - #print(extent) ax.imshow(array, cmap="inferno", extent=extent) ax.set_xlabel("x ($\\AA$)") ; ax.set_ylabel("y ($\\AA$)") pp = np.asarray(self.base_probe.probe_positions) @@ -289,10 +291,44 @@ def run(self) -> WFData: self.base_probe._array) if self.cache_key != cache_key: self.cache_key = cache_key - #print(cache_key) self.output_dir = Path("psi_data/" + ("torch" if TORCH_AVAILABLE else "numpy") + "_"+cache_key) self.output_dir.mkdir(parents=True, exist_ok=True) + # ── Resolve which layers to store ──────────────────────────────────── + # NEW: if cache_layer_indices is set, only those layers are FFT'd and + # written to disk; the propagation itself still runs through all nz + # slices (physically required). cache_layer_indices=None keeps the + # original behaviour of storing every layer. + if "slices" in self.cache_levels and self.cache_layer_indices is not None: + # Validate and clip indices to [0, nz-1] + _requested = sorted(set(int(i) for i in self.cache_layer_indices)) + _dropped = [i for i in _requested if not (0 <= i < self.nz)] + _active_layers = [i for i in _requested if 0 <= i < self.nz] + if _dropped: + logger.warning( + f"cache_layer_indices: dropped out-of-range indices {_dropped} " + f"(nz={self.nz})" + ) + if not _active_layers: + raise ValueError( + "cache_layer_indices produced no valid layer indices after " + f"clipping to [0, {self.nz-1}]." + ) + logger.info( + f"Selective layer storage: recording {len(_active_layers)}/{self.nz} " + f"layers -> {_active_layers}" + ) + print( + f"[MultisliceCalculator] cache_layer_indices: storing " + f"{len(_active_layers)}/{self.nz} layers: {_active_layers}", + flush=True + ) + else: + # Default: store every layer (original behaviour) + _active_layers = list(range(self.nz)) if "slices" in self.cache_levels else [0] + + self._active_layers = _active_layers # expose for inspection / post-processing + # ───────────────────────────────────────────────────────────────────── # if probes are over vacuum (e.g. nanoparticles), we don't need to propagate them? self.probe_indices = xp.arange(len(self.probe_positions)) @@ -314,19 +350,19 @@ def run(self) -> WFData: nc,npt,nx,ny = self.base_probe._array.shape self.n_probes = nc*len(self.probe_positions) + # Storage: [probe, frame, x, y, layer] - matches WFData expected format - self.n_layers = self.nz if "slices" in self.cache_levels else 1 + # CHANGED: n_layers is now len(_active_layers) instead of always self.nz + self.n_layers = len(_active_layers) + if self.store_full: fd_nx = self.nx ; fd_ny = self.ny ; fd_npt = self.n_probes - #if self.base_probe.cropping: - # fd_nx = self.base_probe.cropping ; fd_ny = fd_nx if self.use_memmap: self.wavefunction_data = memmap((fd_npt, self.n_frames, fd_nx, fd_ny, self.n_layers), dtype=self.complex_dtype, filename = self.output_dir / "wdf_memmap.npy" ) else: self.wavefunction_data = zeros((fd_npt, self.n_frames, fd_nx, fd_ny, self.n_layers), dtype=self.complex_dtype, device=self.device) - #print("wavefunction_data",self.wavefunction_data.shape) # Process frames with caching and multiprocessing total_start_time = time.time() @@ -355,8 +391,6 @@ def run(self) -> WFData: # Process frames one at a time with tqdm progress tracking with tqdm(total=self.n_frames, desc="Processing frames", unit="frame") as pbar: for frame_idx in range(self.n_frames): - #if sum(absolute(self.wavefunction_data[:,frame_idx,:,:,:]),axis=(0,1,2,3))>0: # p,t,x,y,l indices - #continue cache_file = self.output_dir / f"frame_{frame_idx}.npy" # Show detailed progress for single-frame runs show_progress = (frame_idx == 0 and self.n_frames == 1 and not self.loop_probes) @@ -393,40 +427,25 @@ def run(self) -> WFData: if cache_exists: frames_cached += 1 else: - #print("create potential") ; start = time.time() potential = Potential(self.xs, self.ys, self.zs, positions, atom_type_names, kind="kirkland", device=self.device, slice_axis=self.slice_axis, progress=show_progress, cache_dir=cache_file.parent if "potentials" in self.cache_levels else None, frame_idx = frame_idx) - #print("(done)",time.time()-start) - #n_probes = nc*npt nc,npt,nx,ny = self.base_probe._array.shape ; npt = len(self.base_probe.probe_positions) n_slices = len(self.zs) n_waves = len(self.base_probe.probe_positions) - #n_waves = - #if self.base_probe.cropping: - # nx,ny = self.base_probe.cropping,self.base_probe.cropping - #print("create frame_data") ; start = time.time() # frame_data is always: p,x,y,l,1 (self.wavefunction_data expects p,t,x,y,l, since we loop time. recall Propagate gave l,p,x,y) + # CHANGED: last dim is self.n_layers = len(_active_layers), not nz if self.store_full or self.prism: fd_nx = self.nx ; fd_ny = self.ny ; fd_npt = self.n_probes - #if self.base_probe.cropping: - # fd_nx = self.base_probe.cropping ; fd_ny = fd_nx # ceil(/self.kth) ; if self.use_memmap: frame_data = memmap((n_waves, fd_nx, fd_ny, self.n_layers,1), dtype=self.complex_dtype, filename = cache_file ) else: frame_data = zeros((n_waves, fd_nx, fd_ny, self.n_layers,1), dtype=self.complex_dtype, device=self.device) - #print("frame_data",frame_data.shape) - #print("(done)",time.time()-start) - #batched_probes = create_batched_probes(self.base_probe, self.probe_positions, self.device) - # Propagate returns: [l,p,x,y] where l,p are both optional (if store_all_slices=True, and if n_probes>1) - #print("chunking") ; start = time.time() chunks = [] if self.loop_probes: chunksize = self.loop_probes if isinstance(self.loop_probes,int) else 1 for i in range(10000000): chunk = xp.arange(i*chunksize,(i+1)*chunksize) - #chunk = chunk[chunknpt: break @@ -437,7 +456,6 @@ def run(self) -> WFData: else: chunks.append( xp.arange(npt) ) pbar2 = None - #print("(done)",time.time()-start) for selected in chunks: if len(selected)==npt: @@ -445,57 +463,36 @@ def run(self) -> WFData: else: probe = self.base_probe.copy(selected_probes=selected) probe.applyShifts() - # propagate single probe - #print("propagate") ; start = time.time() + # Propagate returns shape (nz, n_probes, nx, ny) when store_all_slices=True exit_waves_single = Propagate(probe, potential, self.device, progress=show_progress, onthefly=True, store_all_slices = ("slices" in self.cache_levels) ) # [l],p,x,y indices - #print("(done)",time.time()-start) # expand out to fixed l,p,x,y indices exit_waves_single = expand_dims(exit_waves_single,0) if len(exit_waves_single.shape)==3 else exit_waves_single + # FFT and load into frame_data kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)} - for layer_idx in range(self.n_layers): - exit_waves_k = xp.fft.fft2(exit_waves_single[layer_idx,:,:,:], **kwarg) # l,p,x,y --> p,x,y + + # CHANGED: iterate over (out_idx, real_layer_idx) pairs instead of + # range(self.n_layers). When cache_layer_indices=None, _active_layers + # is range(nz) so out_idx == real_layer_idx and behaviour is identical + # to the original code. + for out_idx, real_layer_idx in enumerate(_active_layers): + exit_waves_k = xp.fft.fft2(exit_waves_single[real_layer_idx,:,:,:], **kwarg) # l,p,x,y --> p,x,y diffraction_patterns = xp.fft.fftshift(exit_waves_k, **kwarg) - #if not self.prism: diffraction_patterns = diffraction_patterns[:,self.keep_kxs_indices,:][:,:,self.keep_kys_indices]*self.kth**2 if self.use_memmap: diffraction_patterns = to_cpu(diffraction_patterns) selected = to_cpu(selected) if self.store_full or self.prism: - frame_data[selected,:,:,layer_idx,0] = diffraction_patterns # load p,x,y --> p,x,y,l,1 indices + # CHANGED: write to compact slot out_idx, not real_layer_idx + frame_data[selected,:,:,out_idx,0] = diffraction_patterns # load p,x,y --> p,x,y,l,1 indices if self.ADF and not self.prism: - #print(self.ADF._wf_array[0,:,:,0,0,0,0]) intensities = einsum('pxy,xy->p',absolute(diffraction_patterns[:,:,:])**2,self.ADFmask) for i,pp in zip(intensities,selected): self.ADF._array[self.ADFindex==pp] += i + if pbar2 is not None: pbar2.update(int(max(selected))-pbar2.n) -# else: -# # simultaneously propagate all probes at once, [l],p,x,y -# exit_waves_batch = Propagate(self.base_probe, potential, self.device, progress=show_progress, onthefly=True, store_all_slices = ("slices" in self.cache_levels) ) -# # expand out to fixed l,p,x,y indices -# exit_waves_batch = expand_dims(exit_waves_batch,0) if len(exit_waves_batch.shape)==3 else exit_waves_batch -# # FFT and load into frame_data -# for layer_idx in range(self.n_layers): -# kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)} -# exit_waves_k = xp.fft.fft2(exit_waves_batch[layer_idx,:,:,:], **kwarg) # l,p,x,y --> p,x,y -# diffraction_patterns = xp.fft.fftshift(exit_waves_k, **kwarg) -# #cropped = diffraction_patterns[:,self.i1:self.i2,self.j1:self.j2] -# #if not self.prism: -# diffraction_patterns = diffraction_patterns[:,self.keep_kxs_indices,:][:,:,self.keep_kys_indices]*self.kth**2 -# if self.use_memmap: -# diffraction_patterns = to_cpu(diffraction_patterns) -# if self.store_full or self.prism: -# frame_data[:,:,:,layer_idx,0] = diffraction_patterns # load p,x,y --> p,x,y,l,1 indices -# if self.ADF and not self.prism: -# intensities = einsum('pxy,xy->p',absolute(diffraction_patterns)**2,self.ADFmask) -# #print(type(self.ADF._array),type(intensities),type(self.ADFindex)) -# if self.use_memmap: -# intensities = asarray(intensities,device=self.device) -# self.ADF._array += intensities[self.ADFindex] -# #self.ADF._array = einsum('pxyln,'frame_data - if not self.use_memmap and ( "exitwaves" in self.cache_levels or "slices" in self.cache_levels ) and (self.store_full or self.prism): # Convert to CPU numpy array for saving @@ -503,12 +500,8 @@ def run(self) -> WFData: np.save(cache_file, frame_data_cpu) frames_computed += 1 - #print(frame_data.shape,self.wavefunction_data.shape) if self.store_full or self.prism: cropped = frame_data[:,:,:,:,0] - #print(cropped.shape) - #if self.use_memmap: - # cropped = to_cpu(cropped) if self.prism: # Recall: Prism algorithm passes a series of sinusoids through the sample (fourier components shared by all real-space probes), so now for each real-space probe, we need to calculate the exitwaves from components @@ -541,22 +534,17 @@ def run(self) -> WFData: 'calculator': 'MultisliceCalculator' } - # Create coordinate arrays for output - # Note: WFData expects (probe_positions, time, kx, ky, layer) format - # Create k-space coordinates to match expected format (same as AbTem) - # TWP: If we're not going to also provide a shifted/etc reciprocal_array, we shouldn't shift the kxs - #kxs = xp.fft.fftfreq(self.nx, d=self.dx) - #kys = xp.fft.fftfreq(self.ny, d=self.dy) - #kxs = xp.fft.fftshift(xp.fft.fftfreq(self.nx, self.sampling)) # k-space in 1/Å MOVING TO INIT SO WE CAN CROP ON-THE-FLY - #kys = xp.fft.fftshift(xp.fft.fftfreq(self.ny, self.sampling)) # k-space in 1/Å time_array = np.arange(self.n_frames) * self.trajectory.timestep # Time array in ps - layer_array = np.arange(self.nz) if "slices" in self.cache_levels else np.array([0]) # Layer indices + + # CHANGED: layer_array now reflects the actual stored layer indices. + # When cache_layer_indices=None, _active_layers == list(range(nz)) so + # layer_array == np.arange(nz), identical to the original behaviour. + layer_array = np.array(_active_layers) if "slices" in self.cache_levels else np.array([0]) # Package results array = zeros((self.n_probes,1,1,1,1),dtype=self.complex_dtype) if self.store_full: array = self.wavefunction_data - #print(array.shape,self.kxs.shape,self.kys.shape) wf_data = WFData( probe_positions=self.probe_positions, probe_xs=self.probe_xs, @@ -586,8 +574,6 @@ def run(self) -> WFData: else: logger.info(f"Cache files saved in: {self.output_dir}") - # Save if requested - psi files already saved during processing - if self.ADF: self.ADF._array /= self.n_frames # haadf_data divides by nc,nt,nl (from _wf_array's c,x,y,t,kx,ky,l) return wf_data,self.ADF @@ -646,14 +632,9 @@ def run(self) -> WFData: self.ws = ws/self.trajectory.timestep - def plot(self,w,filename=None): # TODO MAYBE "RUN" SHOULD RETURN A TACAW OBJECT SO WE CAN REUSE TACAW PLOTTING/POSTPROCESSING FUNCTIONALITY?? + def plot(self,w,filename=None): import matplotlib.pyplot as plt - #fig, ax = plt.subplots() - #extent = ( np.amin(kxs) , np.amax(kxs) , np.amin(ws) , np.amax(ws) ) - #ax.imshow((Zx[::-1,:,0]+Zy[::-1,:,0]+Zz[::-1,:,0])**.25, cmap="inferno", extent=extent,aspect="auto") - #plt.show() - i=np.argmin(np.absolute(self.ws-w)) extent = ( np.amin(self.kxs) , np.amax(self.kxs) , np.amin(self.kys) , np.amax(self.kys) ) @@ -666,5 +647,3 @@ def plot(self,w,filename=None): # TODO MAYBE "RUN" SHOULD RETURN A TACAW OBJECT plt.savefig(filename) else: plt.show() - - From bd31aa298d9d57ac9ef1c8f74e2d71a576ae1449 Mon Sep 17 00:00:00 2001 From: Haoran Ma <150745856+HaoranLMaoMao@users.noreply.github.com> Date: Thu, 7 May 2026 14:11:31 +0200 Subject: [PATCH 2/3] Update calculators.py --- src/pyslice/multislice/calculators.py | 91 ++++++++++++++++++++++++--- 1 file changed, 82 insertions(+), 9 deletions(-) diff --git a/src/pyslice/multislice/calculators.py b/src/pyslice/multislice/calculators.py index 03636ca..0918419 100644 --- a/src/pyslice/multislice/calculators.py +++ b/src/pyslice/multislice/calculators.py @@ -206,6 +206,7 @@ def setup( self.keep_kxs_indices = xp.arange(self.nx)[kx_mask==1][::self.kth] self.keep_kys_indices = xp.arange(self.ny)[ky_mask==1][::self.kth] self.nx = len(self.keep_kxs_indices) ; self.ny = len(self.keep_kys_indices) + #print("kxs",len(self.kxs),"kys",len(self.kys),"nx",self.nx,"ny",self.ny,"keepkx",len(self.keep_kxs_indices),"keepky",len(self.keep_kys_indices)) # Preferred to pass probe_xs and probe_ys from which we will define a grid if self.probe_xs is not None and self.probe_ys is not None: @@ -236,6 +237,7 @@ def setup( # Initialize storage for results self.n_frames = trajectory.n_frames + #self.n_probes = len(self.base_probe.probe_positions) # Set dtype based on the actual device we're using if TORCH_AVAILABLE and self.device is not None: @@ -255,7 +257,9 @@ def setup( self.slice_thickness, self.sampling, self.probe_positions, self.base_probe.spatial_decoherence, self.base_probe.temporal_decoherence, self.base_probe._array) + #print(cache_key) self.output_dir = Path("psi_data/" + ("torch" if TORCH_AVAILABLE else "numpy") + "_"+self.cache_key) + #self.output_dir.mkdir(parents=True, exist_ok=True) def preview_probes(self): positions = self.trajectory.positions[0] @@ -274,6 +278,7 @@ def preview_probes(self): array = np.absolute(to_cpu(potential.array))[:,::-1,0].T # imshow convention: y,x. our convention: x,y, and flip y (0,0 upper-left) xs = to_cpu(potential.xs) ; ys = to_cpu(potential.ys) extent = (np.amin(xs),np.amax(xs),np.amin(ys),np.amax(ys)) + #print(extent) ax.imshow(array, cmap="inferno", extent=extent) ax.set_xlabel("x ($\\AA$)") ; ax.set_ylabel("y ($\\AA$)") pp = np.asarray(self.base_probe.probe_positions) @@ -291,6 +296,7 @@ def run(self) -> WFData: self.base_probe._array) if self.cache_key != cache_key: self.cache_key = cache_key + #print(cache_key) self.output_dir = Path("psi_data/" + ("torch" if TORCH_AVAILABLE else "numpy") + "_"+cache_key) self.output_dir.mkdir(parents=True, exist_ok=True) @@ -350,19 +356,20 @@ def run(self) -> WFData: nc,npt,nx,ny = self.base_probe._array.shape self.n_probes = nc*len(self.probe_positions) - # Storage: [probe, frame, x, y, layer] - matches WFData expected format # CHANGED: n_layers is now len(_active_layers) instead of always self.nz self.n_layers = len(_active_layers) - if self.store_full: fd_nx = self.nx ; fd_ny = self.ny ; fd_npt = self.n_probes + #if self.base_probe.cropping: + # fd_nx = self.base_probe.cropping ; fd_ny = fd_nx if self.use_memmap: self.wavefunction_data = memmap((fd_npt, self.n_frames, fd_nx, fd_ny, self.n_layers), dtype=self.complex_dtype, filename = self.output_dir / "wdf_memmap.npy" ) else: self.wavefunction_data = zeros((fd_npt, self.n_frames, fd_nx, fd_ny, self.n_layers), dtype=self.complex_dtype, device=self.device) + #print("wavefunction_data",self.wavefunction_data.shape) # Process frames with caching and multiprocessing total_start_time = time.time() @@ -391,6 +398,8 @@ def run(self) -> WFData: # Process frames one at a time with tqdm progress tracking with tqdm(total=self.n_frames, desc="Processing frames", unit="frame") as pbar: for frame_idx in range(self.n_frames): + #if sum(absolute(self.wavefunction_data[:,frame_idx,:,:,:]),axis=(0,1,2,3))>0: # p,t,x,y,l indices + #continue cache_file = self.output_dir / f"frame_{frame_idx}.npy" # Show detailed progress for single-frame runs show_progress = (frame_idx == 0 and self.n_frames == 1 and not self.loop_probes) @@ -427,25 +436,41 @@ def run(self) -> WFData: if cache_exists: frames_cached += 1 else: + #print("create potential") ; start = time.time() potential = Potential(self.xs, self.ys, self.zs, positions, atom_type_names, kind="kirkland", device=self.device, slice_axis=self.slice_axis, progress=show_progress, cache_dir=cache_file.parent if "potentials" in self.cache_levels else None, frame_idx = frame_idx) + #print("(done)",time.time()-start) + #n_probes = nc*npt nc,npt,nx,ny = self.base_probe._array.shape ; npt = len(self.base_probe.probe_positions) n_slices = len(self.zs) n_waves = len(self.base_probe.probe_positions) + #n_waves = + #if self.base_probe.cropping: + # nx,ny = self.base_probe.cropping,self.base_probe.cropping + #print("create frame_data") ; start = time.time() # frame_data is always: p,x,y,l,1 (self.wavefunction_data expects p,t,x,y,l, since we loop time. recall Propagate gave l,p,x,y) # CHANGED: last dim is self.n_layers = len(_active_layers), not nz if self.store_full or self.prism: fd_nx = self.nx ; fd_ny = self.ny ; fd_npt = self.n_probes + #if self.base_probe.cropping: + # fd_nx = self.base_probe.cropping ; fd_ny = fd_nx # ceil(/self.kth) ; if self.use_memmap: frame_data = memmap((n_waves, fd_nx, fd_ny, self.n_layers,1), dtype=self.complex_dtype, filename = cache_file ) else: frame_data = zeros((n_waves, fd_nx, fd_ny, self.n_layers,1), dtype=self.complex_dtype, device=self.device) + #print("frame_data",frame_data.shape) + #print("(done)",time.time()-start) + #batched_probes = create_batched_probes(self.base_probe, self.probe_positions, self.device) + # Propagate returns: [l,p,x,y] where l,p are both optional (if store_all_slices=True, and if n_probes>1) + #print("chunking") ; start = time.time() chunks = [] if self.loop_probes: chunksize = self.loop_probes if isinstance(self.loop_probes,int) else 1 for i in range(10000000): chunk = xp.arange(i*chunksize,(i+1)*chunksize) + #chunk = chunk[chunknpt: break @@ -456,6 +481,7 @@ def run(self) -> WFData: else: chunks.append( xp.arange(npt) ) pbar2 = None + #print("(done)",time.time()-start) for selected in chunks: if len(selected)==npt: @@ -463,22 +489,24 @@ def run(self) -> WFData: else: probe = self.base_probe.copy(selected_probes=selected) probe.applyShifts() - # Propagate returns shape (nz, n_probes, nx, ny) when store_all_slices=True + # propagate single probe + #print("propagate") ; start = time.time() exit_waves_single = Propagate(probe, potential, self.device, progress=show_progress, onthefly=True, store_all_slices = ("slices" in self.cache_levels) ) # [l],p,x,y indices + #print("(done)",time.time()-start) # expand out to fixed l,p,x,y indices exit_waves_single = expand_dims(exit_waves_single,0) if len(exit_waves_single.shape)==3 else exit_waves_single - # FFT and load into frame_data kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)} # CHANGED: iterate over (out_idx, real_layer_idx) pairs instead of # range(self.n_layers). When cache_layer_indices=None, _active_layers - # is range(nz) so out_idx == real_layer_idx and behaviour is identical - # to the original code. + # is list(range(nz)) so out_idx == real_layer_idx and behaviour is + # identical to the original code. for out_idx, real_layer_idx in enumerate(_active_layers): exit_waves_k = xp.fft.fft2(exit_waves_single[real_layer_idx,:,:,:], **kwarg) # l,p,x,y --> p,x,y diffraction_patterns = xp.fft.fftshift(exit_waves_k, **kwarg) + #if not self.prism: diffraction_patterns = diffraction_patterns[:,self.keep_kxs_indices,:][:,:,self.keep_kys_indices]*self.kth**2 if self.use_memmap: diffraction_patterns = to_cpu(diffraction_patterns) @@ -487,12 +515,37 @@ def run(self) -> WFData: # CHANGED: write to compact slot out_idx, not real_layer_idx frame_data[selected,:,:,out_idx,0] = diffraction_patterns # load p,x,y --> p,x,y,l,1 indices if self.ADF and not self.prism: + #print(self.ADF._wf_array[0,:,:,0,0,0,0]) intensities = einsum('pxy,xy->p',absolute(diffraction_patterns[:,:,:])**2,self.ADFmask) for i,pp in zip(intensities,selected): self.ADF._array[self.ADFindex==pp] += i - if pbar2 is not None: pbar2.update(int(max(selected))-pbar2.n) +# else: +# # simultaneously propagate all probes at once, [l],p,x,y +# exit_waves_batch = Propagate(self.base_probe, potential, self.device, progress=show_progress, onthefly=True, store_all_slices = ("slices" in self.cache_levels) ) +# # expand out to fixed l,p,x,y indices +# exit_waves_batch = expand_dims(exit_waves_batch,0) if len(exit_waves_batch.shape)==3 else exit_waves_batch +# # FFT and load into frame_data +# for layer_idx in range(self.n_layers): +# kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)} +# exit_waves_k = xp.fft.fft2(exit_waves_batch[layer_idx,:,:,:], **kwarg) # l,p,x,y --> p,x,y +# diffraction_patterns = xp.fft.fftshift(exit_waves_k, **kwarg) +# #cropped = diffraction_patterns[:,self.i1:self.i2,self.j1:self.j2] +# #if not self.prism: +# diffraction_patterns = diffraction_patterns[:,self.keep_kxs_indices,:][:,:,self.keep_kys_indices]*self.kth**2 +# if self.use_memmap: +# diffraction_patterns = to_cpu(diffraction_patterns) +# if self.store_full or self.prism: +# frame_data[:,:,:,layer_idx,0] = diffraction_patterns # load p,x,y --> p,x,y,l,1 indices +# if self.ADF and not self.prism: +# intensities = einsum('pxy,xy->p',absolute(diffraction_patterns)**2,self.ADFmask) +# #print(type(self.ADF._array),type(intensities),type(self.ADFindex)) +# if self.use_memmap: +# intensities = asarray(intensities,device=self.device) +# self.ADF._array += intensities[self.ADFindex] +# #self.ADF._array = einsum('pxyln,'frame_data + if not self.use_memmap and ( "exitwaves" in self.cache_levels or "slices" in self.cache_levels ) and (self.store_full or self.prism): # Convert to CPU numpy array for saving @@ -500,8 +553,12 @@ def run(self) -> WFData: np.save(cache_file, frame_data_cpu) frames_computed += 1 + #print(frame_data.shape,self.wavefunction_data.shape) if self.store_full or self.prism: cropped = frame_data[:,:,:,:,0] + #print(cropped.shape) + #if self.use_memmap: + # cropped = to_cpu(cropped) if self.prism: # Recall: Prism algorithm passes a series of sinusoids through the sample (fourier components shared by all real-space probes), so now for each real-space probe, we need to calculate the exitwaves from components @@ -534,17 +591,26 @@ def run(self) -> WFData: 'calculator': 'MultisliceCalculator' } + # Create coordinate arrays for output + # Note: WFData expects (probe_positions, time, kx, ky, layer) format + # Create k-space coordinates to match expected format (same as AbTem) + # TWP: If we're not going to also provide a shifted/etc reciprocal_array, we shouldn't shift the kxs + #kxs = xp.fft.fftfreq(self.nx, d=self.dx) + #kys = xp.fft.fftfreq(self.ny, d=self.dy) + #kxs = xp.fft.fftshift(xp.fft.fftfreq(self.nx, self.sampling)) # k-space in 1/Å MOVING TO INIT SO WE CAN CROP ON-THE-FLY + #kys = xp.fft.fftshift(xp.fft.fftfreq(self.ny, self.sampling)) # k-space in 1/Å time_array = np.arange(self.n_frames) * self.trajectory.timestep # Time array in ps # CHANGED: layer_array now reflects the actual stored layer indices. # When cache_layer_indices=None, _active_layers == list(range(nz)) so # layer_array == np.arange(nz), identical to the original behaviour. - layer_array = np.array(_active_layers) if "slices" in self.cache_levels else np.array([0]) + layer_array = np.array(_active_layers) if "slices" in self.cache_levels else np.array([0]) # Layer indices # Package results array = zeros((self.n_probes,1,1,1,1),dtype=self.complex_dtype) if self.store_full: array = self.wavefunction_data + #print(array.shape,self.kxs.shape,self.kys.shape) wf_data = WFData( probe_positions=self.probe_positions, probe_xs=self.probe_xs, @@ -574,6 +640,8 @@ def run(self) -> WFData: else: logger.info(f"Cache files saved in: {self.output_dir}") + # Save if requested - psi files already saved during processing + if self.ADF: self.ADF._array /= self.n_frames # haadf_data divides by nc,nt,nl (from _wf_array's c,x,y,t,kx,ky,l) return wf_data,self.ADF @@ -632,9 +700,14 @@ def run(self) -> WFData: self.ws = ws/self.trajectory.timestep - def plot(self,w,filename=None): + def plot(self,w,filename=None): # TODO MAYBE "RUN" SHOULD RETURN A TACAW OBJECT SO WE CAN REUSE TACAW PLOTTING/POSTPROCESSING FUNCTIONALITY?? import matplotlib.pyplot as plt + #fig, ax = plt.subplots() + #extent = ( np.amin(kxs) , np.amax(kxs) , np.amin(ws) , np.amax(ws) ) + #ax.imshow((Zx[::-1,:,0]+Zy[::-1,:,0]+Zz[::-1,:,0])**.25, cmap="inferno", extent=extent,aspect="auto") + #plt.show() + i=np.argmin(np.absolute(self.ws-w)) extent = ( np.amin(self.kxs) , np.amax(self.kxs) , np.amin(self.kys) , np.amax(self.kys) ) From c4aa67cb2bd3945a484402436a155b493a7a3148 Mon Sep 17 00:00:00 2001 From: Haoran Ma <150745856+HaoranLMaoMao@users.noreply.github.com> Date: Thu, 7 May 2026 14:29:26 +0200 Subject: [PATCH 3/3] Refine comments in calculators.py Updated comments for clarity and consistency. --- src/pyslice/multislice/calculators.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pyslice/multislice/calculators.py b/src/pyslice/multislice/calculators.py index 0918419..ee51669 100644 --- a/src/pyslice/multislice/calculators.py +++ b/src/pyslice/multislice/calculators.py @@ -300,7 +300,7 @@ def run(self) -> WFData: self.output_dir = Path("psi_data/" + ("torch" if TORCH_AVAILABLE else "numpy") + "_"+cache_key) self.output_dir.mkdir(parents=True, exist_ok=True) - # ── Resolve which layers to store ──────────────────────────────────── + # ── Resolve which layers to store ── # NEW: if cache_layer_indices is set, only those layers are FFT'd and # written to disk; the propagation itself still runs through all nz # slices (physically required). cache_layer_indices=None keeps the @@ -334,7 +334,7 @@ def run(self) -> WFData: _active_layers = list(range(self.nz)) if "slices" in self.cache_levels else [0] self._active_layers = _active_layers # expose for inspection / post-processing - # ───────────────────────────────────────────────────────────────────── + # if probes are over vacuum (e.g. nanoparticles), we don't need to propagate them? self.probe_indices = xp.arange(len(self.probe_positions)) @@ -720,3 +720,4 @@ def plot(self,w,filename=None): # TODO MAYBE "RUN" SHOULD RETURN A TACAW OBJECT plt.savefig(filename) else: plt.show() +