Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/pyslice/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,23 @@ def zeros(dims, dtype=DEFAULT_FLOAT_DTYPE, device=DEFAULT_DEVICE):
array = xp.zeros(dims, dtype=dtype)
return array

def ones(dims, dtype=DEFAULT_FLOAT_DTYPE, device=DEFAULT_DEVICE):
if xp != np:
return xp.ones(dims, dtype=dtype, device=device)
else:
return xp.ones(dims, dtype=dtype)

def fftfreq(n, d, dtype=DEFAULT_FLOAT_DTYPE, device=DEFAULT_DEVICE):
if xp != np:
return xp.fft.fftfreq(n, d, dtype=dtype, device=device)
else:
return xp.fft.fftfreq(n, d, dtype=dtype)

def expand_dims(ary,d):
if xp != np:
return xp.unsqueeze(ary,dim=d)
else:
return np.expand_dims(ary,d)

def exp(x):
return xp.exp(x)
Expand Down
258 changes: 81 additions & 177 deletions src/pyslice/multislice/calculators.py

Large diffs are not rendered by default.

230 changes: 168 additions & 62 deletions src/pyslice/multislice/multislice.py

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions src/pyslice/multislice/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,10 @@ def calculateSlice(slice_idx):
if self.cache_dir is not None:
cache_file = self.cache_dir / ("potential_"+str(frame_idx)+"_"+str(slice_idx)+".npy")
if cache_file is not None and os.path.exists(cache_file):
return np.load(cache_file)
Z = np.load(cache_file)
if TORCH_AVAILABLE:
return xp.from_numpy(Z).to(device)
return Z

# Initialize slice of potential array using xp with conditional device
device_kwargs = {'device': self.device } if self.use_torch else {}
Expand Down Expand Up @@ -356,7 +359,11 @@ def calculateSlice(slice_idx):
dy = self.ys[1] - self.ys[0]
Z = real / (dx**2 * dy**2)
if cache_file is not None:
np.save(cache_file,Z)
if TORCH_AVAILABLE and hasattr(Z, 'cpu'):
Z_cpu = Z.cpu().numpy()
else:
Z_cpu = Z
np.save(cache_file, Z_cpu)
return Z

self.calculateSlice = calculateSlice
Expand Down
10 changes: 5 additions & 5 deletions src/pyslice/postprocessing/haadf_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(self, wf_data: WFData) -> None:
self.cache_dir = wf_data.cache_dir

# Store reference to source WFData array for ADF calculation
self._wf_array = wf_data.reshaped # x,y,t,kx,ky,l indices
self._wf_array = wf_data.reshaped # nprobes,x,y,t,kx,ky,l indices

# Initialize ADF as None, will be computed by calculateADF
self._array = None
Expand Down Expand Up @@ -186,16 +186,16 @@ def calculateADF(self, inner_mrad: float = 45, outer_mrad: float = 150, preview:
# self._array[i, j] = collected


# recall self._wf_array is reshaped: p,t,kx,ky,l --> x,y,t,kx,ky,l
# recall self._wf_array is reshaped: p,t,kx,ky,l --> c,x,y,t,kx,ky,l
if preview:
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
preview_data = xp.mean(xp.absolute(self._wf_array),axis=(0,1,2,5))**.2 * (1 - mask)
preview_data = xp.mean(xp.absolute(self._wf_array),axis=(0,1,2,3,6))**.2 * (1 - mask)
ax.imshow(np.asarray(preview_data), cmap="inferno")
plt.show()

collected = self._wf_array * mask[None,None,None,:,:,None] # x,y,t,kx,ky,l indices, mask is kx,ky
self._array = xp.mean(xp.sum(xp.absolute(collected),axis=(3,4)),axis=(2,3)) # sum over kx,ky, mean over t,l
collected = self._wf_array * mask[None,None,None,:,:,None] # c,x,y,t,kx,ky,l indices, mask is kx,ky
self._array = xp.mean(xp.sum(xp.absolute(collected),axis=(4,5)),axis=(0,3,4)) # sum over kx,ky, mean over c,t,l

# Update dimensions with computed xs, ys
def to_numpy(x):
Expand Down
57 changes: 41 additions & 16 deletions src/pyslice/postprocessing/wf_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ..multislice.multislice import Probe,aberrationFunction
from ..data import Signal, Dimensions, Dimension, GeneralMetadata
from pathlib import Path
from ..backend import mean
from ..backend import mean,ones

try:
import torch ; xp = torch
Expand Down Expand Up @@ -154,10 +154,11 @@ def array(self):

@property
def reshaped(self): # where self._array is indices probe,time,kx,ky,layer, we reshape to probe_x,probe_y,time,kx,ky,layer
npt,nt,nkx,nky,nl = self._array.shape
nx=len(self.probe_xs)
ny=len(self.probe_ys)
return xp.reshape(self._array,(nx,ny,nt,nkx,nky,nl))
nc,nptp,nx,ny = self.probe._array.shape # recall: decoherence creates duplicate probes: num_copies,num_positions,x,y indices
npta,nt,nkx,nky,nl = self._array.shape # recall, Propagate flattens the first two, and adds time,layers: nc*npt,num_frames,x,y,nl indice
intermediate = xp.reshape(self._array,(nc,nptp,nt,nkx,nky,nl))
nx,ny = len(self.probe.probe_xs),len(self.probe.probe_ys)
return xp.reshape(intermediate,(nc,nx,ny,nt,nkx,nky,nl))

@array.setter
def array(self, value):
Expand Down Expand Up @@ -234,7 +235,7 @@ def plot_reciprocal(self,filename=None,whichProbe="mean",whichTimestep="mean",po
else:
array = array[whichProbe]

if isinstance(whichTimestep,str) and whichProbe=="mean":
if isinstance(whichTimestep,str) and whichTimestep=="mean":
array = mean(array,axis=0) # t,kx,ky --> kx,ky
else:
array = array[whichTimestep]
Expand Down Expand Up @@ -358,22 +359,25 @@ def plot_phase(self,filename=None,whichProbe=0,whichTimestep=0,extent=None,avg=F
else:
plt.show()

def plot_realspace(self,whichProbe=0,whichTimestep=0,extent=None,avg=False,filename=None):
def plot_realspace(self,whichProbe="mean",whichTimestep="mean",extent=None,filename=None):
import matplotlib.pyplot as plt
fig, ax = plt.subplots()

# Get array (with or without averaging)
if avg:
array = self._array[whichProbe,:,:,:,-1] # Shape: (time, kx, ky)
if hasattr(array, 'mean'): # torch tensor
array = array.mean(dim=0) # Average over time dimension
else: # numpy array
array = np.mean(array, axis=0)
array = xp.fft.ifft2(self._array[:,:,:,:,-1])

array = xp.absolute(array) # probe, time, kx, ky, layer --> p,t,kx,ky

if isinstance(whichProbe,str) and whichProbe=="mean":
array = mean(abs(array),axis=0) # p,t,kx,ky --> t,kx,ky
else:
array = self._array[whichProbe,whichTimestep,:,:,-1]
array = array[whichProbe]

if isinstance(whichTimestep,str) and whichTimestep=="mean":
array = mean(array,axis=0) # t,kx,ky --> kx,ky
else:
array = array[whichTimestep]

array = array.T # imshow convention: y,x. our convention: x,y
array = xp.fft.ifft2(array)

# Use provided extent or calculate from data
if extent is None:
Expand Down Expand Up @@ -403,6 +407,27 @@ def propagate_free_space(self,dz): # UNITS OF ANGSTROM
#if dz>0:
self._array = P[None,None,:,:,None] * self._array

def addSpatialDecoherence(self,sigma_dz,N):
dzs = np.linspace(-2*sigma_dz,2*sigma_dz,N) # suppose N=25
amplitudes = np.exp(-dzs**2/sigma_dz**2)
self._array = self._array[:,None,:,:,:,:] * ones(N)[None,:,None,None,None,None] # n_probes,nt,nx,ny,nl -->
nc,npt,nt,nx,ny,nl = self._array.shape # suppose nc=10 (addTemporalDecoherence created 10 wavelengths)
kx_grid, ky_grid = xp.meshgrid(self._kxs, self._kys, indexing='ij')
k_squared = kx_grid**2 + ky_grid**2
for i in range(N):
inner = xp.pi * self.probe.wavelength * dzs[i] * k_squared
P = xp.exp( -1j * inner ) # not sure why, but combining this and previous line triggers a "ComplexWarning: Casting complex values to real discards the imaginary part" in python 2.9.1 but not 2.2.2
self._array[:,i,:,:,:,:] *= amplitudes[i]*P[None,None,:,:,None]
self._array = self._array.reshape(nc*npt,nt,nx,ny,nl)
#self.defocus(dzs) # defocus starts with 25,10,npt,nx,ny --reshapes--> 250,npt,nx,ny
#for i in range(N): # reshape to flatten loops first index last: [[0,1],[2,3]] --> [0,1,2,3]
# for j in range(nc):
# self._array[i*nc+j] *= amplitudes[i]
#nc,npt,nx,ny = self._array.shape
#if npt==1:
# self.applyShifts()


def applyMask(self, radius, realOrReciprocal="reciprocal"):
if realOrReciprocal == "reciprocal":
radii = xp.sqrt( self._kxs[:,None]**2 + self._kys[None,:]**2 )
Expand Down
166 changes: 96 additions & 70 deletions tests/18_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,99 +19,125 @@
a,b=2.4907733333333337,2.1570729817355123

# cache_level options include: ["exitwaves","slices","potentials"]
tests = [1,2,3,4,5,5.5,6,7,8,9,10]
#tests = [9,10]
#tests=[5,5.5]

# LOAD TRAJECTORY
trajectory=Loader(dump,timestep=dt,atom_mapping=types).load()
# TRIM TO 10x10 UC
trajectory=trajectory.slice_positions([0,10*a],[0,10*b])

# ONE TIMESTEPS, ONE PROBE:
print("1. one timestep, one probe, normal caching")
traj1=trajectory.get_random_timesteps(11,seed=1)
calculator=MultisliceCalculator()
calculator.setup(traj1,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5)
exitwaves = calculator.run()
differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/01-test.npy","01") # p,t,x,y,l indices
if 1 in tests:
print("1. one timestep, one probe, normal caching")
traj1=trajectory.get_random_timesteps(11,seed=1)
calculator=MultisliceCalculator()
calculator.setup(traj1,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5)
exitwaves = calculator.run()
differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/01-test.npy","01") # p,t,x,y,l indices

# ONE TIMESTEPS, ONE PROBE:
print("2. one timestep, one probe, cache potentials only")
traj2=trajectory.get_random_timesteps(11,seed=2)
calculator=MultisliceCalculator()
calculator.setup(traj2,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["potentials"])
exitwaves = calculator.run()
differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/02-test.npy","02") # p,t,x,y,l indices
if 2 in tests:
print("2. one timestep, one probe, cache potentials only")
traj2=trajectory.get_random_timesteps(11,seed=2)
calculator=MultisliceCalculator()
calculator.setup(traj2,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["potentials"])
exitwaves = calculator.run()
differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/02-test.npy","02") # p,t,x,y,l indices

# ONE TIMESTEP, MANY PROBES:
print("3. one timestep, many probes, normal caching")
traj3=trajectory.get_random_timesteps(1,seed=3)
calculator=MultisliceCalculator()
probe_xs = np.linspace(a,3*a,14)
probe_ys = np.linspace(b,3*b,16)
calculator.setup(traj3,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys)
exitwaves = calculator.run()
differ(exitwaves.array[::5,:,::5,::5,:],"outputs/caching/03-test.npy","03") # p,t,x,y,l indices
if 3 in tests:
print("3. one timestep, many probes, normal caching")
traj3=trajectory.get_random_timesteps(1,seed=3)
calculator=MultisliceCalculator()
probe_xs = np.linspace(a,3*a,14)
probe_ys = np.linspace(b,3*b,16)
calculator.setup(traj3,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys)
exitwaves = calculator.run()
differ(exitwaves.array[::5,:,::5,::5,:],"outputs/caching/03-test.npy","03") # p,t,x,y,l indices

# MANY TIMESTEPS, ONE PROBE:
print("4. many timesteps, one probe, normal caching")
traj4=trajectory.get_random_timesteps(10,seed=4)
calculator=MultisliceCalculator()
calculator.setup(traj4,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5)
exitwaves = calculator.run()
differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/04-test.npy","04") # p,t,x,y,l indices
if 4 in tests:
print("4. many timesteps, one probe, normal caching")
traj4=trajectory.get_random_timesteps(10,seed=4)
calculator=MultisliceCalculator()
calculator.setup(traj4,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5)
exitwaves = calculator.run()
differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/04-test.npy","04") # p,t,x,y,l indices

# MANY TIMESTEPS, MANY PROBES:
print("5. many timesteps, many probes, normal caching")
traj5=trajectory.get_random_timesteps(5,seed=5)
calculator=MultisliceCalculator()
probe_xs = np.linspace(a,3*a,9)
probe_ys = np.linspace(b,3*b,10)
calculator.setup(traj5,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys)
exitwaves = calculator.run()
differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/05-test.npy","05") # p,t,x,y,l indices
if 5 in tests:
print("5. many timesteps, many probes, normal caching")
traj5=trajectory.get_random_timesteps(5,seed=5)
calculator=MultisliceCalculator()
probe_xs = np.linspace(a,3*a,9)
probe_ys = np.linspace(b,3*b,10)
calculator.setup(traj5,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys)
exitwaves = calculator.run()
differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/05-test.npy","05") # p,t,x,y,l indices

# MANY TIMESTEPS, MANY PROBES:
if 5.5 in tests:
print("5.5 many timesteps, many probes, normal caching, with decoherence added")
traj55=trajectory.get_random_timesteps(5,seed=5)
calculator=MultisliceCalculator()
probe_xs = np.linspace(a,3*a,9)
probe_ys = np.linspace(b,3*b,10)
calculator.setup(traj55,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys)
calculator.base_probe.addSpatialDecoherence(1000,7)
exitwaves = calculator.run()
#differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/05-test.npy","05") # p,t,x,y,l indices


# CACHING TURNED OFF:
print("6. many timesteps, many probes, no caching")
traj6=trajectory.get_random_timesteps(5,seed=6)
calculator=MultisliceCalculator()
probe_xs = np.linspace(a,3*a,9)
probe_ys = np.linspace(b,3*b,10)
calculator.setup(traj6,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=[])
exitwaves = calculator.run()
differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/06-test.npy","06") # p,t,x,y,l indices
if 6 in tests:
print("6. many timesteps, many probes, no caching")
traj6=trajectory.get_random_timesteps(5,seed=6)
calculator=MultisliceCalculator()
probe_xs = np.linspace(a,3*a,9)
probe_ys = np.linspace(b,3*b,10)
calculator.setup(traj6,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=[])
exitwaves = calculator.run()
differ(exitwaves.array[:,:,::10,::10,:],"outputs/caching/06-test.npy","06") # p,t,x,y,l indices

# OR WITH THE POTENTIAL SAVED OFF ONLY
print("7. many timesteps, one probe, caching potentials only")
traj7=trajectory.get_random_timesteps(10,seed=7)
calculator=MultisliceCalculator()
calculator.setup(traj7,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["potentials"])
exitwaves = calculator.run()
differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/07-test.npy","07") # p,t,x,y,l indices
if 7 in tests:
print("7. many timesteps, one probe, caching potentials only")
traj7=trajectory.get_random_timesteps(10,seed=7)
calculator=MultisliceCalculator()
calculator.setup(traj7,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["potentials"])
exitwaves = calculator.run()
differ(exitwaves.array[:,:,::5,::5,:],"outputs/caching/07-test.npy","07") # p,t,x,y,l indices

# LAYERWISE CACHING
print("8. many timesteps, many probes, layerwise caching")
traj8=trajectory.get_random_timesteps(5,seed=8)
calculator=MultisliceCalculator()
probe_xs = np.linspace(a,3*a,6)
probe_ys = np.linspace(b,3*b,7)
calculator.setup(traj8,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=["slices"])
exitwaves = calculator.run()
differ(exitwaves.array[:,::3,::20,::20,::5],"outputs/caching/08-test.npy","08") # p,t,x,y,l indices
if 8 in tests:
print("8. many timesteps, many probes, layerwise caching")
traj8=trajectory.get_random_timesteps(5,seed=8)
calculator=MultisliceCalculator()
probe_xs = np.linspace(a,3*a,6)
probe_ys = np.linspace(b,3*b,7)
calculator.setup(traj8,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=["slices"])
exitwaves = calculator.run()
differ(exitwaves.array[:,::3,::20,::20,::5],"outputs/caching/08-test.npy","08") # p,t,x,y,l indices

# LAYERWISE CACHING, WITH ONE PROBE
print("9. many timesteps, one probe, layerwise caching")
traj9=trajectory.get_random_timesteps(5,seed=9)
calculator=MultisliceCalculator()
calculator.setup(traj9,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["slices"])
exitwaves = calculator.run()
differ(exitwaves.array[:,:,::5,::5,::5],"outputs/caching/09-test.npy","09") # p,t,x,y,l indices
if 9 in tests:
print("9. many timesteps, one probe, layerwise caching")
traj9=trajectory.get_random_timesteps(5,seed=9)
calculator=MultisliceCalculator()
calculator.setup(traj9,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,cache_levels=["slices"])
exitwaves = calculator.run()
differ(exitwaves.array[:,:,::5,::5,::5],"outputs/caching/09-test.npy","09") # p,t,x,y,l indices

# LAYERWISE CACHING OR WITH ONE TIMESTEP
print("10. one timestep, many probes, layerwise caching")
traj10=trajectory.get_random_timesteps(1,seed=10)
calculator=MultisliceCalculator()
probe_xs = np.linspace(a,3*a,9)
probe_ys = np.linspace(b,3*b,10)
calculator.setup(traj10,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=["slices"])
exitwaves = calculator.run()
differ(exitwaves.array[:,:,::10,::10,::5],"outputs/caching/10-test.npy","10") # p,t,x,y,l indices
if 10 in tests:
print("10. one timestep, many probes, layerwise caching")
traj10=trajectory.get_random_timesteps(1,seed=10)
calculator=MultisliceCalculator()
probe_xs = np.linspace(a,3*a,9)
probe_ys = np.linspace(b,3*b,10)
calculator.setup(traj10,aperture=30,voltage_eV=100e3,sampling=.1,slice_thickness=.5,probe_xs=probe_xs,probe_ys=probe_ys,cache_levels=["slices"])
exitwaves = calculator.run()
differ(exitwaves.array[:,:,::10,::10,::5],"outputs/caching/10-test.npy","10") # p,t,x,y,l indices

Loading