From bde6ab8c33c2fadc012c5c770435ea088c41ac08 Mon Sep 17 00:00:00 2001 From: "Thomas Pfeifer (qwe)" Date: Tue, 25 Nov 2025 14:34:24 -0500 Subject: [PATCH 1/8] reimplment simplified MultisliceCalculator, ensuring old psi_data is compatible, ensuring 18_caching test still passes --- src/pyslice/backend.py | 5 + src/pyslice/multislice/calculators.py | 210 +++++++------------------- tests/18_caching.py | 152 ++++++++++--------- 3 files changed, 142 insertions(+), 225 deletions(-) diff --git a/src/pyslice/backend.py b/src/pyslice/backend.py index 74ba19e..23eb768 100644 --- a/src/pyslice/backend.py +++ b/src/pyslice/backend.py @@ -118,6 +118,11 @@ def fftfreq(n, d, dtype=DEFAULT_FLOAT_DTYPE, device=DEFAULT_DEVICE): else: return xp.fft.fftfreq(n, d, dtype=dtype) +def expand_dims(ary,d): + if xp != np: + return xp.unsqueeze(ary,dim=d) + else: + return np.expand_dims(ary,dim=d) def exp(x): return xp.exp(x) diff --git a/src/pyslice/multislice/calculators.py b/src/pyslice/multislice/calculators.py index 47dcec9..0d14606 100644 --- a/src/pyslice/multislice/calculators.py +++ b/src/pyslice/multislice/calculators.py @@ -34,6 +34,7 @@ from .trajectory import Trajectory from ..postprocessing.wf_data import WFData from .sed import SED +from ..backend import zeros,expand_dims logger = logging.getLogger(__name__) @@ -207,13 +208,9 @@ def setup( self.float_dtype = np.float64 # Storage: [probe, frame, x, y, layer] - matches WFData expected format - n_layers = nz if "slices" in cache_levels else 1 - if TORCH_AVAILABLE and self.device is not None: - self.wavefunction_data = torch.zeros((self.n_probes, self.n_frames, nx, ny, n_layers), + self.n_layers = nz if "slices" in cache_levels else 1 + self.wavefunction_data = zeros((self.n_probes, self.n_frames, nx, ny, self.n_layers), dtype=self.complex_dtype, device=self.device) - else: - self.wavefunction_data = np.zeros((self.n_probes, self.n_frames, nx, ny, n_layers), - dtype=self.complex_dtype) def run(self) -> WFData: @@ -230,35 +227,58 @@ def run(self) -> WFData: with tqdm(total=self.n_frames, desc="Processing frames", unit="frame") as pbar: for frame_idx in range(self.n_frames): cache_file = self.output_dir / f"frame_{frame_idx}.npy" + positions = self.trajectory.positions[frame_idx] atom_types = self.trajectory.atom_types - - args = [ frame_idx, positions, atom_types, self.xs, self.ys, self.zs, - self.aperture, self.voltage_eV, self.base_probe, self.probe_positions, self.element_map, - cache_file, self.cache_levels, self.slice_axis, self.device ] - - # Process frame - if frame_idx == 0 and self.n_frames == 1: - args[0] = -1 - - frame_idx_result, frame_data, was_cached = _process_frame_worker_torch(args) - - # crop frame's diffraction image - frame_data = frame_data[:,self.i1:self.i2,self.j1:self.j2,:,:] - - # Store result - for probe_idx in range(self.n_probes): - if "slices" in self.cache_levels: - # frame_data shape: (n_probes, nx, ny, n_slices, 1) - self.wavefunction_data[probe_idx, frame_idx, :, :, :] = frame_data[probe_idx, :, :, :, 0] + atom_type_names = [] + for atom_type in atom_types: + if atom_type in self.element_map: + atom_type_names.append(self.element_map[atom_type]) else: - self.wavefunction_data[probe_idx, frame_idx, :, :, 0] = frame_data[probe_idx, :, :, 0, 0] - - if was_cached: - frames_cached += 1 + atom_type_names.append(atom_type) + + # frame_data should always be shaped: n_probes,nkx,nky,n_layers,1 (idk why there's a trailing 1) + cache_exists,frame_data = checkCache(cache_file,self.cache_levels) + + if cache_exists: + #print(frame_data.shape) + pass else: - frames_computed += 1 - + potential = Potential(self.xs, self.ys, self.zs, positions, atom_type_names, kind="kirkland", device=self.device, slice_axis=self.slice_axis, progress=(frame_idx==-1), cache_dir=cache_file.parent if "potentials" in self.cache_levels else None, frame_idx = frame_idx) + + n_probes = len(self.probe_positions) + nx, ny = len(self.xs), len(self.ys) + n_slices = len(self.zs) + + batched_probes = create_batched_probes(self.base_probe, self.probe_positions, self.device) + # Propagate returns: [l,p,x,y] where l,p are both optional (if store_all_slices=True, and if n_probes>1) + exit_waves_batch = Propagate(batched_probes, potential, self.device, progress=(frame_idx==-1), onthefly=True, store_all_slices = ("slices" in self.cache_levels) ) + #print(exit_waves_batch.shape) + if n_probes==1 and "slices" not in self.cache_levels: + exit_waves_batch = expand_dims(exit_waves_batch,0) + if "slices" not in self.cache_levels: + exit_waves_batch = expand_dims(exit_waves_batch,0) + #print(exit_waves_batch.shape) + # frame_data is always: p,x,y,l,1 (self.wavefunction_data expects p,t,x,y,l, since we loop time. recall Propagate gave l,p,x,y) + frame_data = zeros((n_probes, nx, ny, self.n_layers,1), dtype=self.complex_dtype, device=self.device) + #print(frame_data.shape) + for layer_idx in range(self.n_layers): + kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)} + exit_waves_k = xp.fft.fft2(exit_waves_batch[layer_idx,:,:,:], **kwarg) # l,p,x,y --> p,x,y + diffraction_patterns = xp.fft.fftshift(exit_waves_k, **kwarg) + cropped = diffraction_patterns[:,self.i1:self.i2,self.j1:self.j2] + frame_data[:,:,:,layer_idx,0] = cropped # load p,x,y --> p,x,y,l,1 indices + + # Convert to CPU numpy array for saving + if TORCH_AVAILABLE and hasattr(frame_data, 'cpu'): + frame_data_cpu = frame_data.cpu().numpy() + else: + frame_data_cpu = frame_data + + if "exitwaves" in self.cache_levels: + np.save(cache_file, frame_data_cpu) + + self.wavefunction_data[:, frame_idx, :, :, :] = frame_data[:, :, :, :, 0] # load p,x,y,l,1 --> p,t,x,y,l indices # Update progress bar for this frame pbar.update(1) @@ -322,137 +342,17 @@ def run(self) -> WFData: # Save if requested - psi files already saved during processing return wf_data - logging_tracker=[] -def _process_frame_worker_torch(args): - frame_idx, positions, atom_types, xs, ys, zs, aperture, eV, probe, probe_positions, element_map, cache_file, cache_levels , slice_axis, device = args - +def checkCache(cache_file,cache_levels): + global logging_tracker if cache_file.exists() and ( "exitwaves" in cache_levels or "slices" in cache_levels ): - global logging_tracker parent = str(cache_file.parent) if "cache_exists-"+parent not in logging_tracker: logging_tracker.append("cache_exists-"+parent) logging.warning("One or more frames reloaded from cache: "+str(cache_file.parent)) - return frame_idx, xp.asarray(np.load(cache_file)), True # if always saving as numpy, then must cast to torch array if re-reading cache file back in - - # Use the device passed from the calculator, or auto-detect if None - if TORCH_AVAILABLE: - if device is not None: - worker_device = device - else: - worker_device = torch.device('cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')) - - # Set dtype based on worker device - if worker_device.type == 'mps': - worker_complex_dtype = torch.complex64 - worker_float_dtype = torch.float32 - else: - worker_complex_dtype = torch.complex128 - worker_float_dtype = torch.float64 - else: - worker_device = None - worker_complex_dtype = np.complex128 - worker_float_dtype = np.float64 - - atom_type_names = [] - for atom_type in atom_types: - if atom_type in element_map: - atom_type_names.append(element_map[atom_type]) - else: - atom_type_names.append(atom_type) - - #try: - potential = Potential(xs, ys, zs, positions, atom_type_names, kind="kirkland", device=worker_device, slice_axis=slice_axis, progress=(frame_idx==-1), cache_dir=cache_file.parent if "potentials" in cache_levels else None, frame_idx = frame_idx) - - n_probes = len(probe_positions) - nx, ny = len(xs), len(ys) - n_slices = len(zs) - - batched_probes = create_batched_probes(probe, probe_positions, worker_device) - exit_waves_batch = Propagate(batched_probes, potential, worker_device, progress=(frame_idx==-1), onthefly=True, store_all_slices = ("slices" in cache_levels) ) - - if "slices" in cache_levels: - # exit_waves_batch shape: (n_slices, n_probes, nx, ny) - if TORCH_AVAILABLE and worker_device is not None: - frame_data = torch.zeros((n_probes, nx, ny, n_slices, 1), dtype=worker_complex_dtype, device=worker_device) - else: - frame_data = np.zeros((n_probes, nx, ny, n_slices, 1), dtype=worker_complex_dtype) - - # Convert all slices to k-space - for slice_idx in range(n_slices): - slice_waves = exit_waves_batch[slice_idx, :, :, :] # (n_probes, nx, ny) - kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)} - waves_k = xp.fft.fft2(slice_waves, **kwarg) - diffraction_patterns = xp.fft.fftshift(waves_k, **kwarg) - - # Store in frame_data - for i in range(n_probes): - frame_data[i, :, :, slice_idx, 0] = diffraction_patterns[i, :, :] - else: - # exit_waves_batch shape: (n_probes, nx, ny) - if TORCH_AVAILABLE and worker_device is not None: - frame_data = torch.zeros((n_probes, nx, ny, 1, 1), dtype=worker_complex_dtype, device=worker_device) - else: - frame_data = np.zeros((n_probes, nx, ny, 1, 1), dtype=worker_complex_dtype) - - # Convert all exit waves to k-space - kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)} - exit_waves_k = xp.fft.fft2(exit_waves_batch, **kwarg) - diffraction_patterns = xp.fft.fftshift(exit_waves_k, **kwarg) - - # Store results - frame_data[:, :, :, 0, 0] = diffraction_patterns #.cpu().numpy() - #else: - # # Fallback to individual processing - # for probe_idx, (px, py) in enumerate(probe_positions): - # shifted_probe = probe.copy() - # - # probe_k = torch.fft.fft2(shifted_probe.array) - # - # kx_shift = torch.exp(2j * torch.pi * shifted_probe.kxs[:, None] * px) - # ky_shift = torch.exp(2j * torch.pi * shifted_probe.kys[None, :] * py) - # probe_k_shifted = probe_k * kx_shift * ky_shift - # - # shifted_probe.array = torch.fft.ifft2(probe_k_shifted) - # - # exit_wave_torch = PropagateTorch(shifted_probe, potential, worker_device) - # - # exit_wave_k = torch.fft.fft2(exit_wave_torch) - # diffraction_pattern = torch.fft.fftshift(exit_wave_k) - # - # frame_data[probe_idx, :, :, 0, 0] = diffraction_pattern.cpu().numpy() - - # Convert to CPU numpy array for saving - if TORCH_AVAILABLE and hasattr(frame_data, 'cpu'): - frame_data_cpu = frame_data.cpu().numpy() - else: - frame_data_cpu = frame_data - - if "exitwaves" in cache_levels or "slices" in cache_levels: - np.save(cache_file, frame_data_cpu) - - return frame_idx, frame_data, False - - #except Exception as e: - # logger.error(f"Error processing frame {frame_idx} with PyTorch: {e}") - # from .potential import Potential - # from .multislice_npy import Probe, Propagate - # - # potential = Potential(xs, ys, zs, positions, atom_type_names, kind="kirkland") - # probe = Probe(xs, ys, aperture, eV) - # - # n_probes = len(probe_positions) - # nx, ny = len(xs), len(ys) - ## frame_data = np.zeros((n_probes, nx, ny, 1, 1), dtype=complex) - # - # for probe_idx, (px, py) in enumerate(probe_positions): - # exit_wave = Propagate(probe, potential) - # diffraction_pattern = np.fft.fftshift(np.fft.fft2(exit_wave)) - # frame_data[probe_idx, :, :, 0, 0] = diffraction_pattern - # - # np.save(cache_file, frame_data) - # return frame_idx, frame_data, False + return True,xp.asarray(np.load(cache_file)) # if always saving as numpy, then must cast to torch array if re-reading cache file back in + return False,0 class SEDCalculator: diff --git a/tests/18_caching.py b/tests/18_caching.py index 2f00f51..efeb46b 100644 --- a/tests/18_caching.py +++ b/tests/18_caching.py @@ -19,6 +19,8 @@ a,b=2.4907733333333337,2.1570729817355123 # cache_level options include: ["exitwaves","slices","potentials"] +tests = [1,2,3,4,5,6,7,8,9,10] +#tests = [9,10] # LOAD TRAJECTORY trajectory=Loader(dump,timestep=dt,atom_mapping=types).load() @@ -26,92 +28,102 @@ trajectory=trajectory.slice_positions([0,10*a],[0,10*b]) # ONE TIMESTEPS, ONE PROBE: -print("1. one timestep, one probe, normal caching") -traj1=trajectory.get_random_timesteps(11,seed=1) -calculator=MultisliceCalculator() -calculator.setup(traj1,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/01-test.npy","01") # p,t,x,y,l indices +if 1 in tests: + print("1. one timestep, one probe, normal caching") + traj1=trajectory.get_random_timesteps(11,seed=1) + calculator=MultisliceCalculator() + calculator.setup(traj1,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/01-test.npy","01") # p,t,x,y,l indices # ONE TIMESTEPS, ONE PROBE: -print("2. one timestep, one probe, cache potentials only") -traj2=trajectory.get_random_timesteps(11,seed=2) -calculator=MultisliceCalculator() -calculator.setup(traj2,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["potentials"]) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/02-test.npy","02") # p,t,x,y,l indices +if 2 in tests: + print("2. one timestep, one probe, cache potentials only") + traj2=trajectory.get_random_timesteps(11,seed=2) + calculator=MultisliceCalculator() + calculator.setup(traj2,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["potentials"]) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/02-test.npy","02") # p,t,x,y,l indices # ONE TIMESTEP, MANY PROBES: -print("3. one timestep, many probes, normal caching") -traj3=trajectory.get_random_timesteps(1,seed=3) -calculator=MultisliceCalculator() -probe_xs = np.linspace(a,3*a,14) -probe_ys = np.linspace(b,3*b,16) -calculator.setup(traj3,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys) -exitwaves = calculator.run() -differ(exitwaves.array[::5,:,::5,::5,:],"outputs/caching/03-test.npy","03") # p,t,x,y,l indices +if 3 in tests: + print("3. one timestep, many probes, normal caching") + traj3=trajectory.get_random_timesteps(1,seed=3) + calculator=MultisliceCalculator() + probe_xs = np.linspace(a,3*a,14) + probe_ys = np.linspace(b,3*b,16) + calculator.setup(traj3,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys) + exitwaves = calculator.run() + differ(exitwaves.array[::5,:,::5,::5,:],"outputs/caching/03-test.npy","03") # p,t,x,y,l indices # MANY TIMESTEPS, ONE PROBE: -print("4. many timesteps, one probe, normal caching") -traj4=trajectory.get_random_timesteps(10,seed=4) -calculator=MultisliceCalculator() -calculator.setup(traj4,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/04-test.npy","04") # p,t,x,y,l indices +if 4 in tests: + print("4. many timesteps, one probe, normal caching") + traj4=trajectory.get_random_timesteps(10,seed=4) + calculator=MultisliceCalculator() + calculator.setup(traj4,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/04-test.npy","04") # p,t,x,y,l indices # MANY TIMESTEPS, MANY PROBES: -print("5. many timesteps, many probes, normal caching") -traj5=trajectory.get_random_timesteps(5,seed=5) -calculator=MultisliceCalculator() -probe_xs = np.linspace(a,3*a,9) -probe_ys = np.linspace(b,3*b,10) -calculator.setup(traj5,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/05-test.npy","05") # p,t,x,y,l indices +if 5 in tests: + print("5. many timesteps, many probes, normal caching") + traj5=trajectory.get_random_timesteps(5,seed=5) + calculator=MultisliceCalculator() + probe_xs = np.linspace(a,3*a,9) + probe_ys = np.linspace(b,3*b,10) + calculator.setup(traj5,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/05-test.npy","05") # p,t,x,y,l indices # CACHING TURNED OFF: -print("6. many timesteps, many probes, no caching") -traj6=trajectory.get_random_timesteps(5,seed=6) -calculator=MultisliceCalculator() -probe_xs = np.linspace(a,3*a,9) -probe_ys = np.linspace(b,3*b,10) -calculator.setup(traj6,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=[]) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/06-test.npy","06") # p,t,x,y,l indices +if 6 in tests: + print("6. many timesteps, many probes, no caching") + traj6=trajectory.get_random_timesteps(5,seed=6) + calculator=MultisliceCalculator() + probe_xs = np.linspace(a,3*a,9) + probe_ys = np.linspace(b,3*b,10) + calculator.setup(traj6,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=[]) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/06-test.npy","06") # p,t,x,y,l indices # OR WITH THE POTENTIAL SAVED OFF ONLY -print("7. many timesteps, one probe, caching potentials only") -traj7=trajectory.get_random_timesteps(10,seed=7) -calculator=MultisliceCalculator() -calculator.setup(traj7,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["potentials"]) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/07-test.npy","07") # p,t,x,y,l indices +if 7 in tests: + print("7. many timesteps, one probe, caching potentials only") + traj7=trajectory.get_random_timesteps(10,seed=7) + calculator=MultisliceCalculator() + calculator.setup(traj7,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["potentials"]) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/07-test.npy","07") # p,t,x,y,l indices # LAYERWISE CACHING -print("8. many timesteps, many probes, layerwise caching") -traj8=trajectory.get_random_timesteps(5,seed=8) -calculator=MultisliceCalculator() -probe_xs = np.linspace(a,3*a,6) -probe_ys = np.linspace(b,3*b,7) -calculator.setup(traj8,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=["slices"]) -exitwaves = calculator.run() -differ(exitwaves.array[:,::3,::20,::20,::5],"outputs/caching/08-test.npy","08") # p,t,x,y,l indices +if 8 in tests: + print("8. many timesteps, many probes, layerwise caching") + traj8=trajectory.get_random_timesteps(5,seed=8) + calculator=MultisliceCalculator() + probe_xs = np.linspace(a,3*a,6) + probe_ys = np.linspace(b,3*b,7) + calculator.setup(traj8,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=["slices"]) + exitwaves = calculator.run() + differ(exitwaves.array[:,::3,::20,::20,::5],"outputs/caching/08-test.npy","08") # p,t,x,y,l indices # LAYERWISE CACHING, WITH ONE PROBE -print("9. many timesteps, one probe, layerwise caching") -traj9=trajectory.get_random_timesteps(5,seed=9) -calculator=MultisliceCalculator() -calculator.setup(traj9,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["slices"]) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::5,::5,::5],"outputs/caching/09-test.npy","09") # p,t,x,y,l indices +if 9 in tests: + print("9. many timesteps, one probe, layerwise caching") + traj9=trajectory.get_random_timesteps(5,seed=9) + calculator=MultisliceCalculator() + calculator.setup(traj9,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["slices"]) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::5,::5,::5],"outputs/caching/09-test.npy","09") # p,t,x,y,l indices # LAYERWISE CACHING OR WITH ONE TIMESTEP -print("10. one timestep, many probes, layerwise caching") -traj10=trajectory.get_random_timesteps(1,seed=10) -calculator=MultisliceCalculator() -probe_xs = np.linspace(a,3*a,9) -probe_ys = np.linspace(b,3*b,10) -calculator.setup(traj10,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=["slices"]) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::10,::10,::5],"outputs/caching/10-test.npy","10") # p,t,x,y,l indices +if 10 in tests: + print("10. one timestep, many probes, layerwise caching") + traj10=trajectory.get_random_timesteps(1,seed=10) + calculator=MultisliceCalculator() + probe_xs = np.linspace(a,3*a,9) + probe_ys = np.linspace(b,3*b,10) + calculator.setup(traj10,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=["slices"]) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::10,::10,::5],"outputs/caching/10-test.npy","10") # p,t,x,y,l indices From e19641de565b529f8d9944018444035675a80148 Mon Sep 17 00:00:00 2001 From: "Thomas Pfeifer (qwe)" Date: Tue, 25 Nov 2025 14:54:21 -0500 Subject: [PATCH 2/8] update backend expand_dims for numpy --- src/pyslice/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pyslice/backend.py b/src/pyslice/backend.py index 23eb768..c657b5c 100644 --- a/src/pyslice/backend.py +++ b/src/pyslice/backend.py @@ -122,7 +122,7 @@ def expand_dims(ary,d): if xp != np: return xp.unsqueeze(ary,dim=d) else: - return np.expand_dims(ary,dim=d) + return np.expand_dims(ary,d) def exp(x): return xp.exp(x) From 7f43b68e8227bb9e30b76a3495b52b331f0e3668 Mon Sep 17 00:00:00 2001 From: "Thomas Pfeifer (qwe)" Date: Tue, 25 Nov 2025 14:57:59 -0500 Subject: [PATCH 3/8] forgot to add back in slices to trigger exitwave caching --- src/pyslice/multislice/calculators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pyslice/multislice/calculators.py b/src/pyslice/multislice/calculators.py index 0d14606..5a2da7a 100644 --- a/src/pyslice/multislice/calculators.py +++ b/src/pyslice/multislice/calculators.py @@ -275,7 +275,7 @@ def run(self) -> WFData: else: frame_data_cpu = frame_data - if "exitwaves" in self.cache_levels: + if "exitwaves" in self.cache_levels or "slices" in self.cache_levels: np.save(cache_file, frame_data_cpu) self.wavefunction_data[:, frame_idx, :, :, :] = frame_data[:, :, :, :, 0] # load p,x,y,l,1 --> p,t,x,y,l indices From 2739b16712e09987e4ce6f400123d6eb09b35c6c Mon Sep 17 00:00:00 2001 From: "Thomas Pfeifer (qwe)" Date: Tue, 25 Nov 2025 15:55:02 -0500 Subject: [PATCH 4/8] apparently potential slice saving was broken for torch gpu --- src/pyslice/multislice/potentials.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pyslice/multislice/potentials.py b/src/pyslice/multislice/potentials.py index 683b553..a3cbc31 100644 --- a/src/pyslice/multislice/potentials.py +++ b/src/pyslice/multislice/potentials.py @@ -356,6 +356,8 @@ def calculateSlice(slice_idx): dy = self.ys[1] - self.ys[0] Z = real / (dx**2 * dy**2) if cache_file is not None: + if TORCH_AVAILABLE and hasattr(Z, 'cpu'): + Z = Z.cpu().numpy() np.save(cache_file,Z) return Z From 3055a734ab94e4a19525c3a8a317ff7ba193bdef Mon Sep 17 00:00:00 2001 From: "Thomas Pfeifer (qwe)" Date: Tue, 25 Nov 2025 16:10:45 -0500 Subject: [PATCH 5/8] potential slice to cpu --- src/pyslice/multislice/potentials.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/pyslice/multislice/potentials.py b/src/pyslice/multislice/potentials.py index a3cbc31..047a57c 100644 --- a/src/pyslice/multislice/potentials.py +++ b/src/pyslice/multislice/potentials.py @@ -357,8 +357,10 @@ def calculateSlice(slice_idx): Z = real / (dx**2 * dy**2) if cache_file is not None: if TORCH_AVAILABLE and hasattr(Z, 'cpu'): - Z = Z.cpu().numpy() - np.save(cache_file,Z) + Z_cpu = Z.cpu().numpy() + else: + Z_cpu = Z + np.save(cache_file,Z_cpu) return Z self.calculateSlice = calculateSlice From e829449bc945f6603d7a7ea938b2ea0b21aec118 Mon Sep 17 00:00:00 2001 From: "Thomas Pfeifer (qwe)" Date: Tue, 25 Nov 2025 16:16:22 -0500 Subject: [PATCH 6/8] retorch on potential reload --- src/pyslice/multislice/potentials.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pyslice/multislice/potentials.py b/src/pyslice/multislice/potentials.py index 047a57c..5aca8b3 100644 --- a/src/pyslice/multislice/potentials.py +++ b/src/pyslice/multislice/potentials.py @@ -2,6 +2,7 @@ from pathlib import Path import logging,os from tqdm import tqdm +from ..backend import zeros try: import torch ; xp = torch @@ -286,7 +287,8 @@ def calculateSlice(slice_idx): if self.cache_dir is not None: cache_file = self.cache_dir / ("potential_"+str(frame_idx)+"_"+str(slice_idx)+".npy") if cache_file is not None and os.path.exists(cache_file): - return np.load(cache_file) + Z = np.load(cache_file) + return zeros(Z.shape) + Z # Initialize slice of potential array using xp with conditional device device_kwargs = {'device': self.device } if self.use_torch else {} From a1793c4519088ce320bdaedb9073c11d739c6c31 Mon Sep 17 00:00:00 2001 From: "Thomas Pfeifer (qwe)" Date: Tue, 25 Nov 2025 16:25:57 -0500 Subject: [PATCH 7/8] hopefully get torch to device working for potential reload from cache --- src/pyslice/multislice/potentials.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pyslice/multislice/potentials.py b/src/pyslice/multislice/potentials.py index 5aca8b3..ac3171b 100644 --- a/src/pyslice/multislice/potentials.py +++ b/src/pyslice/multislice/potentials.py @@ -288,7 +288,9 @@ def calculateSlice(slice_idx): cache_file = self.cache_dir / ("potential_"+str(frame_idx)+"_"+str(slice_idx)+".npy") if cache_file is not None and os.path.exists(cache_file): Z = np.load(cache_file) - return zeros(Z.shape) + Z + if TORCH_AVAILABLE: + return xp.from_numpy(Z).to(device) + return Z # Initialize slice of potential array using xp with conditional device device_kwargs = {'device': self.device } if self.use_torch else {} From c8231cc0cc5d4bae022ad33c759511c7a6bfb5d7 Mon Sep 17 00:00:00 2001 From: h-walk Date: Fri, 9 Jan 2026 10:03:39 -0500 Subject: [PATCH 8/8] fix progress bar display and frame tracking --- src/pyslice/multislice/calculators.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/pyslice/multislice/calculators.py b/src/pyslice/multislice/calculators.py index bc5c04c..2a35b1c 100644 --- a/src/pyslice/multislice/calculators.py +++ b/src/pyslice/multislice/calculators.py @@ -228,6 +228,8 @@ def run(self) -> WFData: with tqdm(total=self.n_frames, desc="Processing frames", unit="frame") as pbar: for frame_idx in range(self.n_frames): cache_file = self.output_dir / f"frame_{frame_idx}.npy" + # Show detailed progress for single-frame runs + show_progress = (frame_idx == 0 and self.n_frames == 1) positions = self.trajectory.positions[frame_idx] atom_types = self.trajectory.atom_types @@ -242,10 +244,9 @@ def run(self) -> WFData: cache_exists,frame_data = checkCache(cache_file,self.cache_levels) if cache_exists: - #print(frame_data.shape) - pass + frames_cached += 1 else: - potential = Potential(self.xs, self.ys, self.zs, positions, atom_type_names, kind="kirkland", device=self.device, slice_axis=self.slice_axis, progress=(frame_idx==-1), cache_dir=cache_file.parent if "potentials" in self.cache_levels else None, frame_idx = frame_idx) + potential = Potential(self.xs, self.ys, self.zs, positions, atom_type_names, kind="kirkland", device=self.device, slice_axis=self.slice_axis, progress=show_progress, cache_dir=cache_file.parent if "potentials" in self.cache_levels else None, frame_idx = frame_idx) n_probes = len(self.probe_positions) nx, ny = len(self.xs), len(self.ys) @@ -253,7 +254,7 @@ def run(self) -> WFData: batched_probes = create_batched_probes(self.base_probe, self.probe_positions, self.device) # Propagate returns: [l,p,x,y] where l,p are both optional (if store_all_slices=True, and if n_probes>1) - exit_waves_batch = Propagate(batched_probes, potential, self.device, progress=(frame_idx==-1), onthefly=True, store_all_slices = ("slices" in self.cache_levels) ) + exit_waves_batch = Propagate(batched_probes, potential, self.device, progress=show_progress, onthefly=True, store_all_slices = ("slices" in self.cache_levels) ) #print(exit_waves_batch.shape) if n_probes==1 and "slices" not in self.cache_levels: exit_waves_batch = expand_dims(exit_waves_batch,0) @@ -278,6 +279,7 @@ def run(self) -> WFData: if "exitwaves" in self.cache_levels or "slices" in self.cache_levels: np.save(cache_file, frame_data_cpu) + frames_computed += 1 self.wavefunction_data[:, frame_idx, :, :, :] = frame_data[:, :, :, :, 0] # load p,x,y,l,1 --> p,t,x,y,l indices # Update progress bar for this frame