diff --git a/src/pyslice/backend.py b/src/pyslice/backend.py index 74ba19e..c657b5c 100644 --- a/src/pyslice/backend.py +++ b/src/pyslice/backend.py @@ -118,6 +118,11 @@ def fftfreq(n, d, dtype=DEFAULT_FLOAT_DTYPE, device=DEFAULT_DEVICE): else: return xp.fft.fftfreq(n, d, dtype=dtype) +def expand_dims(ary,d): + if xp != np: + return xp.unsqueeze(ary,dim=d) + else: + return np.expand_dims(ary,d) def exp(x): return xp.exp(x) diff --git a/src/pyslice/multislice/calculators.py b/src/pyslice/multislice/calculators.py index dc3952a..2a35b1c 100644 --- a/src/pyslice/multislice/calculators.py +++ b/src/pyslice/multislice/calculators.py @@ -34,6 +34,7 @@ from .trajectory import Trajectory from ..postprocessing.wf_data import WFData from .sed import SED +from ..backend import zeros,expand_dims logger = logging.getLogger(__name__) @@ -208,13 +209,9 @@ def setup( self.float_dtype = np.float64 # Storage: [probe, frame, x, y, layer] - matches WFData expected format - n_layers = nz if "slices" in cache_levels else 1 - if TORCH_AVAILABLE and self.device is not None: - self.wavefunction_data = torch.zeros((self.n_probes, self.n_frames, nx, ny, n_layers), + self.n_layers = nz if "slices" in cache_levels else 1 + self.wavefunction_data = zeros((self.n_probes, self.n_frames, nx, ny, self.n_layers), dtype=self.complex_dtype, device=self.device) - else: - self.wavefunction_data = np.zeros((self.n_probes, self.n_frames, nx, ny, n_layers), - dtype=self.complex_dtype) def run(self) -> WFData: @@ -231,35 +228,60 @@ def run(self) -> WFData: with tqdm(total=self.n_frames, desc="Processing frames", unit="frame") as pbar: for frame_idx in range(self.n_frames): cache_file = self.output_dir / f"frame_{frame_idx}.npy" + # Show detailed progress for single-frame runs + show_progress = (frame_idx == 0 and self.n_frames == 1) + positions = self.trajectory.positions[frame_idx] atom_types = self.trajectory.atom_types - - args = [ frame_idx, positions, atom_types, self.xs, self.ys, self.zs, - self.aperture, self.voltage_eV, self.base_probe, self.probe_positions, self.element_map, - cache_file, self.cache_levels, self.slice_axis, self.device ] - - # Process frame - if frame_idx == 0 and self.n_frames == 1: - args[0] = -1 - - frame_idx_result, frame_data, was_cached = _process_frame_worker_torch(args) - - # crop frame's diffraction image - frame_data = frame_data[:,self.i1:self.i2,self.j1:self.j2,:,:] - - # Store result - for probe_idx in range(self.n_probes): - if "slices" in self.cache_levels: - # frame_data shape: (n_probes, nx, ny, n_slices, 1) - self.wavefunction_data[probe_idx, frame_idx, :, :, :] = frame_data[probe_idx, :, :, :, 0] + atom_type_names = [] + for atom_type in atom_types: + if atom_type in self.element_map: + atom_type_names.append(self.element_map[atom_type]) else: - self.wavefunction_data[probe_idx, frame_idx, :, :, 0] = frame_data[probe_idx, :, :, 0, 0] - - if was_cached: + atom_type_names.append(atom_type) + + # frame_data should always be shaped: n_probes,nkx,nky,n_layers,1 (idk why there's a trailing 1) + cache_exists,frame_data = checkCache(cache_file,self.cache_levels) + + if cache_exists: frames_cached += 1 else: + potential = Potential(self.xs, self.ys, self.zs, positions, atom_type_names, kind="kirkland", device=self.device, slice_axis=self.slice_axis, progress=show_progress, cache_dir=cache_file.parent if "potentials" in self.cache_levels else None, frame_idx = frame_idx) + + n_probes = len(self.probe_positions) + nx, ny = len(self.xs), len(self.ys) + n_slices = len(self.zs) + + batched_probes = create_batched_probes(self.base_probe, self.probe_positions, self.device) + # Propagate returns: [l,p,x,y] where l,p are both optional (if store_all_slices=True, and if n_probes>1) + exit_waves_batch = Propagate(batched_probes, potential, self.device, progress=show_progress, onthefly=True, store_all_slices = ("slices" in self.cache_levels) ) + #print(exit_waves_batch.shape) + if n_probes==1 and "slices" not in self.cache_levels: + exit_waves_batch = expand_dims(exit_waves_batch,0) + if "slices" not in self.cache_levels: + exit_waves_batch = expand_dims(exit_waves_batch,0) + #print(exit_waves_batch.shape) + # frame_data is always: p,x,y,l,1 (self.wavefunction_data expects p,t,x,y,l, since we loop time. recall Propagate gave l,p,x,y) + frame_data = zeros((n_probes, nx, ny, self.n_layers,1), dtype=self.complex_dtype, device=self.device) + #print(frame_data.shape) + for layer_idx in range(self.n_layers): + kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)} + exit_waves_k = xp.fft.fft2(exit_waves_batch[layer_idx,:,:,:], **kwarg) # l,p,x,y --> p,x,y + diffraction_patterns = xp.fft.fftshift(exit_waves_k, **kwarg) + cropped = diffraction_patterns[:,self.i1:self.i2,self.j1:self.j2] + frame_data[:,:,:,layer_idx,0] = cropped # load p,x,y --> p,x,y,l,1 indices + + # Convert to CPU numpy array for saving + if TORCH_AVAILABLE and hasattr(frame_data, 'cpu'): + frame_data_cpu = frame_data.cpu().numpy() + else: + frame_data_cpu = frame_data + + if "exitwaves" in self.cache_levels or "slices" in self.cache_levels: + np.save(cache_file, frame_data_cpu) frames_computed += 1 - + + self.wavefunction_data[:, frame_idx, :, :, :] = frame_data[:, :, :, :, 0] # load p,x,y,l,1 --> p,t,x,y,l indices # Update progress bar for this frame pbar.update(1) @@ -323,137 +345,17 @@ def run(self) -> WFData: # Save if requested - psi files already saved during processing return wf_data - logging_tracker=[] -def _process_frame_worker_torch(args): - frame_idx, positions, atom_types, xs, ys, zs, aperture, eV, probe, probe_positions, element_map, cache_file, cache_levels , slice_axis, device = args - +def checkCache(cache_file,cache_levels): + global logging_tracker if cache_file.exists() and ( "exitwaves" in cache_levels or "slices" in cache_levels ): - global logging_tracker parent = str(cache_file.parent) if "cache_exists-"+parent not in logging_tracker: logging_tracker.append("cache_exists-"+parent) logging.warning("One or more frames reloaded from cache: "+str(cache_file.parent)) - return frame_idx, xp.asarray(np.load(cache_file)), True # if always saving as numpy, then must cast to torch array if re-reading cache file back in - - # Use the device passed from the calculator, or auto-detect if None - if TORCH_AVAILABLE: - if device is not None: - worker_device = device - else: - worker_device = torch.device('cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')) - - # Set dtype based on worker device - if worker_device.type == 'mps': - worker_complex_dtype = torch.complex64 - worker_float_dtype = torch.float32 - else: - worker_complex_dtype = torch.complex128 - worker_float_dtype = torch.float64 - else: - worker_device = None - worker_complex_dtype = np.complex128 - worker_float_dtype = np.float64 - - atom_type_names = [] - for atom_type in atom_types: - if atom_type in element_map: - atom_type_names.append(element_map[atom_type]) - else: - atom_type_names.append(atom_type) - - #try: - potential = Potential(xs, ys, zs, positions, atom_type_names, kind="kirkland", device=worker_device, slice_axis=slice_axis, progress=(frame_idx==-1), cache_dir=cache_file.parent if "potentials" in cache_levels else None, frame_idx = frame_idx) - - n_probes = len(probe_positions) - nx, ny = len(xs), len(ys) - n_slices = len(zs) - - batched_probes = create_batched_probes(probe, probe_positions, worker_device) - exit_waves_batch = Propagate(batched_probes, potential, worker_device, progress=(frame_idx==-1), onthefly=True, store_all_slices = ("slices" in cache_levels) ) - - if "slices" in cache_levels: - # exit_waves_batch shape: (n_slices, n_probes, nx, ny) - if TORCH_AVAILABLE and worker_device is not None: - frame_data = torch.zeros((n_probes, nx, ny, n_slices, 1), dtype=worker_complex_dtype, device=worker_device) - else: - frame_data = np.zeros((n_probes, nx, ny, n_slices, 1), dtype=worker_complex_dtype) - - # Convert all slices to k-space - for slice_idx in range(n_slices): - slice_waves = exit_waves_batch[slice_idx, :, :, :] # (n_probes, nx, ny) - kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)} - waves_k = xp.fft.fft2(slice_waves, **kwarg) - diffraction_patterns = xp.fft.fftshift(waves_k, **kwarg) - - # Store in frame_data - for i in range(n_probes): - frame_data[i, :, :, slice_idx, 0] = diffraction_patterns[i, :, :] - else: - # exit_waves_batch shape: (n_probes, nx, ny) - if TORCH_AVAILABLE and worker_device is not None: - frame_data = torch.zeros((n_probes, nx, ny, 1, 1), dtype=worker_complex_dtype, device=worker_device) - else: - frame_data = np.zeros((n_probes, nx, ny, 1, 1), dtype=worker_complex_dtype) - - # Convert all exit waves to k-space - kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)} - exit_waves_k = xp.fft.fft2(exit_waves_batch, **kwarg) - diffraction_patterns = xp.fft.fftshift(exit_waves_k, **kwarg) - - # Store results - frame_data[:, :, :, 0, 0] = diffraction_patterns #.cpu().numpy() - #else: - # # Fallback to individual processing - # for probe_idx, (px, py) in enumerate(probe_positions): - # shifted_probe = probe.copy() - # - # probe_k = torch.fft.fft2(shifted_probe.array) - # - # kx_shift = torch.exp(2j * torch.pi * shifted_probe.kxs[:, None] * px) - # ky_shift = torch.exp(2j * torch.pi * shifted_probe.kys[None, :] * py) - # probe_k_shifted = probe_k * kx_shift * ky_shift - # - # shifted_probe.array = torch.fft.ifft2(probe_k_shifted) - # - # exit_wave_torch = PropagateTorch(shifted_probe, potential, worker_device) - # - # exit_wave_k = torch.fft.fft2(exit_wave_torch) - # diffraction_pattern = torch.fft.fftshift(exit_wave_k) - # - # frame_data[probe_idx, :, :, 0, 0] = diffraction_pattern.cpu().numpy() - - # Convert to CPU numpy array for saving - if TORCH_AVAILABLE and hasattr(frame_data, 'cpu'): - frame_data_cpu = frame_data.cpu().numpy() - else: - frame_data_cpu = frame_data - - if "exitwaves" in cache_levels or "slices" in cache_levels: - np.save(cache_file, frame_data_cpu) - - return frame_idx, frame_data, False - - #except Exception as e: - # logger.error(f"Error processing frame {frame_idx} with PyTorch: {e}") - # from .potential import Potential - # from .multislice_npy import Probe, Propagate - # - # potential = Potential(xs, ys, zs, positions, atom_type_names, kind="kirkland") - # probe = Probe(xs, ys, aperture, eV) - # - # n_probes = len(probe_positions) - # nx, ny = len(xs), len(ys) - ## frame_data = np.zeros((n_probes, nx, ny, 1, 1), dtype=complex) - # - # for probe_idx, (px, py) in enumerate(probe_positions): - # exit_wave = Propagate(probe, potential) - # diffraction_pattern = np.fft.fftshift(np.fft.fft2(exit_wave)) - # frame_data[probe_idx, :, :, 0, 0] = diffraction_pattern - # - # np.save(cache_file, frame_data) - # return frame_idx, frame_data, False + return True,xp.asarray(np.load(cache_file)) # if always saving as numpy, then must cast to torch array if re-reading cache file back in + return False,0 class SEDCalculator: diff --git a/src/pyslice/multislice/potentials.py b/src/pyslice/multislice/potentials.py index 683b553..ac3171b 100644 --- a/src/pyslice/multislice/potentials.py +++ b/src/pyslice/multislice/potentials.py @@ -2,6 +2,7 @@ from pathlib import Path import logging,os from tqdm import tqdm +from ..backend import zeros try: import torch ; xp = torch @@ -286,7 +287,10 @@ def calculateSlice(slice_idx): if self.cache_dir is not None: cache_file = self.cache_dir / ("potential_"+str(frame_idx)+"_"+str(slice_idx)+".npy") if cache_file is not None and os.path.exists(cache_file): - return np.load(cache_file) + Z = np.load(cache_file) + if TORCH_AVAILABLE: + return xp.from_numpy(Z).to(device) + return Z # Initialize slice of potential array using xp with conditional device device_kwargs = {'device': self.device } if self.use_torch else {} @@ -356,7 +360,11 @@ def calculateSlice(slice_idx): dy = self.ys[1] - self.ys[0] Z = real / (dx**2 * dy**2) if cache_file is not None: - np.save(cache_file,Z) + if TORCH_AVAILABLE and hasattr(Z, 'cpu'): + Z_cpu = Z.cpu().numpy() + else: + Z_cpu = Z + np.save(cache_file,Z_cpu) return Z self.calculateSlice = calculateSlice diff --git a/tests/18_caching.py b/tests/18_caching.py index 2f00f51..efeb46b 100644 --- a/tests/18_caching.py +++ b/tests/18_caching.py @@ -19,6 +19,8 @@ a,b=2.4907733333333337,2.1570729817355123 # cache_level options include: ["exitwaves","slices","potentials"] +tests = [1,2,3,4,5,6,7,8,9,10] +#tests = [9,10] # LOAD TRAJECTORY trajectory=Loader(dump,timestep=dt,atom_mapping=types).load() @@ -26,92 +28,102 @@ trajectory=trajectory.slice_positions([0,10*a],[0,10*b]) # ONE TIMESTEPS, ONE PROBE: -print("1. one timestep, one probe, normal caching") -traj1=trajectory.get_random_timesteps(11,seed=1) -calculator=MultisliceCalculator() -calculator.setup(traj1,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/01-test.npy","01") # p,t,x,y,l indices +if 1 in tests: + print("1. one timestep, one probe, normal caching") + traj1=trajectory.get_random_timesteps(11,seed=1) + calculator=MultisliceCalculator() + calculator.setup(traj1,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/01-test.npy","01") # p,t,x,y,l indices # ONE TIMESTEPS, ONE PROBE: -print("2. one timestep, one probe, cache potentials only") -traj2=trajectory.get_random_timesteps(11,seed=2) -calculator=MultisliceCalculator() -calculator.setup(traj2,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["potentials"]) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/02-test.npy","02") # p,t,x,y,l indices +if 2 in tests: + print("2. one timestep, one probe, cache potentials only") + traj2=trajectory.get_random_timesteps(11,seed=2) + calculator=MultisliceCalculator() + calculator.setup(traj2,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["potentials"]) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/02-test.npy","02") # p,t,x,y,l indices # ONE TIMESTEP, MANY PROBES: -print("3. one timestep, many probes, normal caching") -traj3=trajectory.get_random_timesteps(1,seed=3) -calculator=MultisliceCalculator() -probe_xs = np.linspace(a,3*a,14) -probe_ys = np.linspace(b,3*b,16) -calculator.setup(traj3,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys) -exitwaves = calculator.run() -differ(exitwaves.array[::5,:,::5,::5,:],"outputs/caching/03-test.npy","03") # p,t,x,y,l indices +if 3 in tests: + print("3. one timestep, many probes, normal caching") + traj3=trajectory.get_random_timesteps(1,seed=3) + calculator=MultisliceCalculator() + probe_xs = np.linspace(a,3*a,14) + probe_ys = np.linspace(b,3*b,16) + calculator.setup(traj3,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys) + exitwaves = calculator.run() + differ(exitwaves.array[::5,:,::5,::5,:],"outputs/caching/03-test.npy","03") # p,t,x,y,l indices # MANY TIMESTEPS, ONE PROBE: -print("4. many timesteps, one probe, normal caching") -traj4=trajectory.get_random_timesteps(10,seed=4) -calculator=MultisliceCalculator() -calculator.setup(traj4,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/04-test.npy","04") # p,t,x,y,l indices +if 4 in tests: + print("4. many timesteps, one probe, normal caching") + traj4=trajectory.get_random_timesteps(10,seed=4) + calculator=MultisliceCalculator() + calculator.setup(traj4,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/04-test.npy","04") # p,t,x,y,l indices # MANY TIMESTEPS, MANY PROBES: -print("5. many timesteps, many probes, normal caching") -traj5=trajectory.get_random_timesteps(5,seed=5) -calculator=MultisliceCalculator() -probe_xs = np.linspace(a,3*a,9) -probe_ys = np.linspace(b,3*b,10) -calculator.setup(traj5,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/05-test.npy","05") # p,t,x,y,l indices +if 5 in tests: + print("5. many timesteps, many probes, normal caching") + traj5=trajectory.get_random_timesteps(5,seed=5) + calculator=MultisliceCalculator() + probe_xs = np.linspace(a,3*a,9) + probe_ys = np.linspace(b,3*b,10) + calculator.setup(traj5,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/05-test.npy","05") # p,t,x,y,l indices # CACHING TURNED OFF: -print("6. many timesteps, many probes, no caching") -traj6=trajectory.get_random_timesteps(5,seed=6) -calculator=MultisliceCalculator() -probe_xs = np.linspace(a,3*a,9) -probe_ys = np.linspace(b,3*b,10) -calculator.setup(traj6,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=[]) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/06-test.npy","06") # p,t,x,y,l indices +if 6 in tests: + print("6. many timesteps, many probes, no caching") + traj6=trajectory.get_random_timesteps(5,seed=6) + calculator=MultisliceCalculator() + probe_xs = np.linspace(a,3*a,9) + probe_ys = np.linspace(b,3*b,10) + calculator.setup(traj6,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=[]) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/06-test.npy","06") # p,t,x,y,l indices # OR WITH THE POTENTIAL SAVED OFF ONLY -print("7. many timesteps, one probe, caching potentials only") -traj7=trajectory.get_random_timesteps(10,seed=7) -calculator=MultisliceCalculator() -calculator.setup(traj7,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["potentials"]) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/07-test.npy","07") # p,t,x,y,l indices +if 7 in tests: + print("7. many timesteps, one probe, caching potentials only") + traj7=trajectory.get_random_timesteps(10,seed=7) + calculator=MultisliceCalculator() + calculator.setup(traj7,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["potentials"]) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/07-test.npy","07") # p,t,x,y,l indices # LAYERWISE CACHING -print("8. many timesteps, many probes, layerwise caching") -traj8=trajectory.get_random_timesteps(5,seed=8) -calculator=MultisliceCalculator() -probe_xs = np.linspace(a,3*a,6) -probe_ys = np.linspace(b,3*b,7) -calculator.setup(traj8,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=["slices"]) -exitwaves = calculator.run() -differ(exitwaves.array[:,::3,::20,::20,::5],"outputs/caching/08-test.npy","08") # p,t,x,y,l indices +if 8 in tests: + print("8. many timesteps, many probes, layerwise caching") + traj8=trajectory.get_random_timesteps(5,seed=8) + calculator=MultisliceCalculator() + probe_xs = np.linspace(a,3*a,6) + probe_ys = np.linspace(b,3*b,7) + calculator.setup(traj8,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=["slices"]) + exitwaves = calculator.run() + differ(exitwaves.array[:,::3,::20,::20,::5],"outputs/caching/08-test.npy","08") # p,t,x,y,l indices # LAYERWISE CACHING, WITH ONE PROBE -print("9. many timesteps, one probe, layerwise caching") -traj9=trajectory.get_random_timesteps(5,seed=9) -calculator=MultisliceCalculator() -calculator.setup(traj9,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["slices"]) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::5,::5,::5],"outputs/caching/09-test.npy","09") # p,t,x,y,l indices +if 9 in tests: + print("9. many timesteps, one probe, layerwise caching") + traj9=trajectory.get_random_timesteps(5,seed=9) + calculator=MultisliceCalculator() + calculator.setup(traj9,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["slices"]) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::5,::5,::5],"outputs/caching/09-test.npy","09") # p,t,x,y,l indices # LAYERWISE CACHING OR WITH ONE TIMESTEP -print("10. one timestep, many probes, layerwise caching") -traj10=trajectory.get_random_timesteps(1,seed=10) -calculator=MultisliceCalculator() -probe_xs = np.linspace(a,3*a,9) -probe_ys = np.linspace(b,3*b,10) -calculator.setup(traj10,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=["slices"]) -exitwaves = calculator.run() -differ(exitwaves.array[:,:,::10,::10,::5],"outputs/caching/10-test.npy","10") # p,t,x,y,l indices +if 10 in tests: + print("10. one timestep, many probes, layerwise caching") + traj10=trajectory.get_random_timesteps(1,seed=10) + calculator=MultisliceCalculator() + probe_xs = np.linspace(a,3*a,9) + probe_ys = np.linspace(b,3*b,10) + calculator.setup(traj10,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=["slices"]) + exitwaves = calculator.run() + differ(exitwaves.array[:,:,::10,::10,::5],"outputs/caching/10-test.npy","10") # p,t,x,y,l indices