
Merge pull request #182 from ahoust17/main
fixed atom_tools
gduscher authored Dec 15, 2024
2 parents 496417a + ce48099 commit 180b44b
Showing 5 changed files with 284 additions and 23 deletions.
158 changes: 158 additions & 0 deletions notebooks/4Dstem_File_Reader.ipynb
@@ -0,0 +1,158 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Basics of reading an MRC file with 4D STEM data from the Spectra300 at UTK\n",
"## By Austin Houston\n",
"### Last updated 2024-09-14"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"\n",
"%matplotlib ipympl\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"sys.path.insert(0, '/Users/austin/Documents/GitHub/SciFiReaders/')\n",
"import SciFiReaders\n",
"\n",
"sys.path.insert(0, '/Users/austin/Documents/GitHub/pyTEMlib/')\n",
"import pyTEMlib\n",
"import pyTEMlib.file_tools as ft\n",
"\n",
"print(\"SciFiReaders version: \", SciFiReaders.__version__)\n",
"print(\"pyTEMlib version: \", pyTEMlib.__version__)\n",
"\n",
"# for beginning analysis\n",
"from sklearn.cluster import KMeans\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.cluster import KMeans\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mrc_filepath = '/Users/austin/Dropbox/GaTech_colabs/SnSe_MgO/2024_06_19_data/4D_STEM/'\n",
"\n",
"files = os.listdir(mrc_filepath)\n",
"files = [f for f in files if f.endswith('.mrc')]\n",
"\n",
"# Load the first file\n",
"dset = ft.open_file(mrc_filepath + files[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = dset['Channel_000']\n",
"\n",
"view = data.plot()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mrc_array = np.array(data)\n",
"N, M, height, width = data.shape\n",
"datacube_flat = mrc_array.reshape(N * M, -1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Perform KMeans clustering\n",
"clusters = 3 \n",
"kmeans = KMeans(n_clusters=clusters, random_state=0).fit(datacube_flat)\n",
"labels = kmeans.labels_\n",
"cluster_centers = kmeans.cluster_centers_\n",
"\n",
"# Reduce the data to 3D using PCA\n",
"pca = PCA(n_components=3)\n",
"datacube_reduced = pca.fit_transform(datacube_flat)\n",
"cluster_centers_reduced = pca.transform(cluster_centers)\n",
"\n",
"# Create a 3D plot\n",
"fig = plt.figure()\n",
"ax = fig.add_subplot(111, projection='3d')\n",
"scatter = ax.scatter(datacube_reduced[:, 0], datacube_reduced[:, 1], datacube_reduced[:, 2], c=labels, cmap='viridis', marker='o')\n",
"ax.set_xlabel('PCA Component 1')\n",
"ax.set_ylabel('PCA Component 2')\n",
"ax.set_zlabel('PCA Component 3')\n",
"ax.set_xticks([])\n",
"ax.set_yticks([])\n",
"ax.set_zticks([])\n",
"plt.show()\n",
"\n",
"\n",
"label_image = labels.reshape((M, N))\n",
"\n",
"plt.figure()\n",
"plt.imshow(label_image, cmap='viridis')\n",
"plt.colorbar()\n",
"plt.show()\n",
"\n",
"# Reshape cluster centers back to original image dimensions\n",
"cluster_center_images = cluster_centers.reshape((kmeans.n_clusters, height, width))\n",
"\n",
"# Plot the average images\n",
"fig, axes = plt.subplots(1, kmeans.n_clusters, figsize=(15, 5))\n",
"\n",
"for i, ax in enumerate(axes):\n",
" ax.imshow(cluster_center_images[i], cmap='viridis')\n",
" ax.set_title(f'Cluster Center {i+1}')\n",
" ax.axis('off')\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pytemlib",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
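The notebook's core pattern is worth isolating: flatten the 4D datacube from (N, M, height, width) to (N*M, height*width), cluster the rows, then reshape the labels back onto the (N, M) scan grid. Below is a minimal sketch of that pattern on synthetic data; it is illustrative only and not part of the commit (the machine-specific paths above are not needed).

import numpy as np
from sklearn.cluster import KMeans

N, M, height, width = 4, 5, 16, 16                # scan grid x detector size
datacube = np.random.rand(N, M, height, width)    # stand-in for the MRC data

datacube_flat = datacube.reshape(N * M, -1)       # one row per diffraction pattern
labels = KMeans(n_clusters=3, random_state=0).fit(datacube_flat).labels_
label_image = labels.reshape(N, M)                # row-major: N varies slowest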
2 changes: 1 addition & 1 deletion pyTEMlib/atom_tools.py
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ def find_atoms(image, atom_size=0.1, threshold=0.):
     if not isinstance(threshold, float):
         raise TypeError('threshold parameter has to be a float number')
 
-    scale_x = ft.get_slope(image.dim_0)
+    scale_x = np.unique(np.gradient(image.dim_0.values))[0]
     im = np.array(image-image.min())
     im = im/im.max()
     if threshold <= 0.:
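The one-line fix above recovers the pixel size from the dimension vector itself: on a uniformly spaced axis, np.gradient returns the constant spacing at every position, so np.unique collapses it to a single value. A quick check of the idiom (not from the commit; it assumes the axis really is evenly spaced):

import numpy as np

dim = np.arange(0., 5., 0.5)               # uniformly spaced dimension axis
scale_x = np.unique(np.gradient(dim))[0]   # -> 0.5, the pixel spacing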
12 changes: 8 additions & 4 deletions pyTEMlib/file_tools.py
@@ -44,8 +44,9 @@
 
 Dimension = sidpy.Dimension
 
-get_slope = sidpy.base.num_utils.get_slope
-__version__ = '2022.3.3'
+# Austin commented the line below - it is not used anywhere in the code, and it gives import errors 9-14-2024
+# get_slope = sidpy.base.num_utils.get_slopes
+__version__ = '2024.9.14'
 
 from traitlets import Unicode, Bool, validate, TraitError
 import ipywidgets
@@ -787,7 +788,7 @@ def h5_group_to_dict(group, group_dict={}):
 
 
 def open_file(filename=None, h5_group=None, write_hdf_file=False, sum_frames=False):  # save_file=False,
-    """Opens a file if the extension is .hf5, .ndata, .dm3 or .dm4
+    """Opens a file if the extension is .emd, .mrc, .hf5, .ndata, .dm3 or .dm4
     If no filename is provided the QT open_file windows opens (if QT_available==True)
     Everything will be stored in a NSID style hf5 file.
@@ -850,7 +851,7 @@ def open_file(filename=None, h5_group=None, write_hdf_file=False, sum_frames=False):
         if not write_hdf_file:
             file.close()
         return dataset_dict
-    elif extension in ['.dm3', '.dm4', '.ndata', '.ndata1', '.h5', '.emd', '.emi', '.edaxh5']:
+    elif extension in ['.dm3', '.dm4', '.ndata', '.ndata1', '.h5', '.emd', '.emi', '.edaxh5', '.mrc']:
         # tags = open_file(filename)
         if extension in ['.dm3', '.dm4']:
             reader = SciFiReaders.DMReader(filename)
@@ -886,6 +887,9 @@ def open_file(filename=None, h5_group=None, write_hdf_file=False, sum_frames=False):
         elif extension in ['.ndata', '.h5']:
             reader = SciFiReaders.NionReader(filename)
 
+        elif extension in ['.mrc']:
+            reader = SciFiReaders.MRCReader(filename)
+
         else:
             raise NotImplementedError('extension not supported')
 
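With '.mrc' in the list of recognized extensions, 4D STEM stacks go through the same entry point as every other supported format. A usage sketch (the filename is hypothetical; the 'Channel_000' key follows the notebook above):

import pyTEMlib.file_tools as ft

datasets = ft.open_file('scan.mrc')   # now dispatches to SciFiReaders.MRCReader
data = datasets['Channel_000']        # sidpy dataset holding the 4D scan
view = data.plot()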
103 changes: 90 additions & 13 deletions pyTEMlib/image_tools.py
@@ -55,6 +55,13 @@
 from scipy.optimize import leastsq
 from sklearn.cluster import DBSCAN
 
+from ase.build import fcc110
+from pyTEMlib import probe_tools
+
+from scipy.ndimage import rotate
+from scipy.interpolate import RegularGridInterpolator
+from scipy.signal import fftconvolve
+
 
 _SimpleITK_present = True
 try:
@@ -68,6 +75,72 @@
                       'install with: conda install -c simpleitk simpleitk ')
 
 
+def get_atomic_pseudo_potential(fov, atoms, size=512, rotation=0):
+    # Big assumption: the atoms are not near the edge of the unit cell.
+    # If any atoms are close to the edge (e.g. [0, 0]), the potential will be clipped;
+    # before calling the function, shift the atoms to the center of the unit cell.
+
+    pixel_size = fov / size
+    max_size = int(size * np.sqrt(2) + 1)  # maximum size to accommodate rotation
+
+    # Create unit cell potential
+    positions = atoms.get_positions()[:, :2]
+    atomic_numbers = atoms.get_atomic_numbers()
+    unit_cell_size = atoms.cell.cellpar()[:2]
+
+    unit_cell_potential = np.zeros((max_size, max_size))
+    for pos, atomic_number in zip(positions, atomic_numbers):
+        x = pos[0] / pixel_size
+        y = pos[1] / pixel_size
+        atom_width = 0.5  # Angstrom
+        gauss_width = atom_width / pixel_size  # important for images at various fov; room for improvement with theory
+        gauss = probe_tools.make_gauss(max_size, max_size, width=gauss_width, x0=x, y0=y)
+        unit_cell_potential += gauss * atomic_number  # gauss is already normalized to 1
+
+    # Create interpolation function for unit cell potential
+    x_grid = np.linspace(0, fov * max_size / size, max_size)
+    y_grid = np.linspace(0, fov * max_size / size, max_size)
+    interpolator = RegularGridInterpolator((x_grid, y_grid), unit_cell_potential, bounds_error=False, fill_value=0)
+
+    # Vectorized computation of the full potential map, tiling the unit cell periodically
+    x_coords, y_coords = np.meshgrid(np.linspace(0, fov, max_size), np.linspace(0, fov, max_size), indexing="ij")
+    xtal_x = x_coords % unit_cell_size[0]
+    xtal_y = y_coords % unit_cell_size[1]
+    potential_map = interpolator((xtal_x.ravel(), xtal_y.ravel())).reshape(max_size, max_size)
+
+    # Rotate, then crop back to the requested size
+    potential_map = rotate(potential_map, rotation, reshape=False)
+    center = potential_map.shape[0] // 2
+    potential_map = potential_map[center - size // 2:center + size // 2, center - size // 2:center + size // 2]
+
+    potential_map = scipy.ndimage.gaussian_filter(potential_map, 3)
+
+    return potential_map
+
+
+def convolve_probe(ab, potential):
+    # the pixel size should be exactly the same as in the potential
+    final_sizes = potential.shape
+
+    # Pad, then perform the FFT-based convolution
+    pad_height = pad_width = potential.shape[0] // 2
+    potential = np.pad(potential, ((pad_height, pad_height), (pad_width, pad_width)), mode='constant')
+
+    probe, A_k, chi = probe_tools.get_probe(ab, potential.shape[0], potential.shape[1], scale='mrad', verbose=False)
+
+    convolved = fftconvolve(potential, probe, mode='same')
+
+    # Crop back to the original potential size
+    start_row = pad_height
+    start_col = pad_width
+    end_row = start_row + final_sizes[0]
+    end_col = start_col + final_sizes[1]
+
+    image = convolved[start_row:end_row, start_col:end_col]
+
+    return probe, image
+
+
 # Wavelength in 1/nm
 def get_wavelength(e0):
     """
@@ -280,20 +353,21 @@ def diffractogram_spots(dset, spot_threshold, return_center=True, eps=0.1):
     return spots, center
 
 
-def center_diffractogram(dset, return_plot=True, histogram_factor=None, smoothing=1, min_samples=100):
+def center_diffractogram(dset, return_plot=True, smoothing=1, min_samples=10, beamstop_size=0.1):
     try:
         diff = np.array(dset).T.astype(np.float16)
         diff[diff < 0] = 0
-        if histogram_factor is not None:
-            hist, bins = np.histogram(np.ravel(diff), bins=256, range=(0, 1), density=True)
-            threshold = threshold_otsu(diff, hist=hist * histogram_factor)
-        else:
-            threshold = threshold_otsu(diff)
+        threshold = threshold_otsu(diff)
         binary = (diff > threshold).astype(float)
         smoothed_image = ndimage.gaussian_filter(binary, sigma=smoothing)  # Smooth before edge detection
         smooth_threshold = threshold_otsu(smoothed_image)
         smooth_binary = (smoothed_image > smooth_threshold).astype(float)
 
+        # add a circle to mask the beamstop
+        x, y = np.meshgrid(np.arange(dset.shape[0]), np.arange(dset.shape[1]))
+        circle = (x - dset.shape[0] / 2) ** 2 + (y - dset.shape[1] / 2) ** 2 < (beamstop_size * dset.shape[0] / 2) ** 2
+        smooth_binary[circle] = 1
+
         # Find the edges using the Sobel operator
         edges = sobel(smooth_binary)
         edge_points = np.argwhere(edges)
@@ -322,18 +396,21 @@ def calc_distance(c, x, y):
 
     finally:
         if return_plot:
-            fig, ax = plt.subplots(1, 4, figsize=(10, 4))
+            fig, ax = plt.subplots(1, 5, figsize=(14, 4), sharex=True, sharey=True)
             ax[0].set_title('Diffractogram')
             ax[0].imshow(dset.T, cmap='viridis')
             ax[1].set_title('Otsu Binary Image')
             ax[1].imshow(binary, cmap='gray')
             ax[2].set_title('Smoothed Binary Image')
-            ax[2].imshow(smooth_binary, cmap='gray')
-            ax[3].set_title('Edge Detection and Fitting')
-            ax[3].imshow(edges, cmap='gray')
-            ax[3].scatter(center[0], center[1], c='r', s=10)
+            ax[2].imshow(smoothed_image, cmap='gray')
+
+            ax[3].set_title('Smoothed Binary Image')
+            ax[3].imshow(smooth_binary, cmap='gray')
+            ax[4].set_title('Edge Detection and Fitting')
+            ax[4].imshow(edges, cmap='gray')
+            ax[4].scatter(center[0], center[1], c='r', s=10)
             circle = plt.Circle(center, mean_radius, color='red', fill=False)
-            ax[3].add_artist(circle)
+            ax[4].add_artist(circle)
             for axis in ax:
                 axis.axis('off')
             fig.tight_layout()
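The reworked center_diffractogram drops the histogram_factor heuristic in favor of a plain Otsu threshold plus an explicit beamstop mask. A call now follows the new signature (a sketch: dset is a single diffraction pattern, and the return value is assumed to be the fitted center that the scatter/circle overlay draws):

center = center_diffractogram(dset, return_plot=True, smoothing=1, beamstop_size=0.1)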