Skip to content

Commit

Permalink
Improve docs for 3D astra X-ray transform (#544)
Browse files Browse the repository at this point in the history
* Initial work on improved astra 3d docs

* Typo fix

* Improve 3d astra docs

* Reduce space around figures

* Improve figure caption style

* Improve 3d astra docs

* Typo fix

* Move matplotlib figures from submodule

* Add test

* Address PR review comments
  • Loading branch information
bwohlberg authored Jul 22, 2024
1 parent d38ecea commit 3d3d849
Show file tree
Hide file tree
Showing 14 changed files with 565 additions and 11 deletions.
2 changes: 1 addition & 1 deletion data
12 changes: 12 additions & 0 deletions docs/source/_static/scico.css
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,15 @@ div.doctest.highlight-default {
[data-theme=light] dl.py.property > dt {
border-radius: 4px;
}


/* Style for figure captions */

div.figure p.caption span.caption-text,
figcaption span.caption-text {
font-size: var(--font-size--small);
margin-left: 5%;
margin-right: 5%;
display: inline-block;
text-align: justify;
}
46 changes: 46 additions & 0 deletions docs/source/pyfigures/cylindgrad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import numpy as np

import scico.linop as scl
from scico import plot

input_shape = (7, 7, 7)
centre = (np.array(input_shape) - 1) / 2
end = np.array(input_shape) - centre
g0, g1, g2 = np.mgrid[-centre[0] : end[0], -centre[1] : end[1], -centre[2] : end[2]]

cg = scl.CylindricalGradient(input_shape=input_shape)

ang = cg.coord[0]
rad = cg.coord[1]
axi = cg.coord[2]

theta = np.arctan2(g0, g1)
clr = theta
# See https://stackoverflow.com/a/49888126
clr = (clr.ravel() - clr.min()) / np.ptp(clr)
clr = np.concatenate((clr, np.repeat(clr, 2)))
clr = plot.plt.cm.plasma(clr)

plot.plt.rcParams["savefig.transparent"] = True

fig = plot.plt.figure(figsize=(20, 6))
ax = fig.add_subplot(1, 3, 1, projection="3d")
ax.quiver(g0, g1, g2, ang[0], ang[1], ang[2], colors=clr, length=0.9)
ax.set_title("Angular local coordinate axis")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_zlabel("$z$")
ax = fig.add_subplot(1, 3, 2, projection="3d")
ax.quiver(g0, g1, g2, rad[0], rad[1], rad[2], colors=clr, length=0.9)
ax.set_title("Radial local coordinate axis")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_zlabel("$z$")
ax = fig.add_subplot(1, 3, 3, projection="3d")
ax.quiver(g0, g1, g2, axi[0], axi[1], axi[2], colors=clr[0], length=0.9)
ax.set_title("Axial local coordinate axis")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_zlabel("$z$")
fig.tight_layout()
fig.show()
35 changes: 35 additions & 0 deletions docs/source/pyfigures/polargrad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np

import scico.linop as scl
from scico import plot

input_shape = (21, 21)
centre = (np.array(input_shape) - 1) / 2
end = np.array(input_shape) - centre
g0, g1 = np.mgrid[-centre[0] : end[0], -centre[1] : end[1]]

pg = scl.PolarGradient(input_shape=input_shape)

ang = pg.coord[0]
rad = pg.coord[1]

clr = (np.arctan2(ang[1], ang[0]) + np.pi) / (2 * np.pi)

plot.plt.rcParams["image.cmap"] = "plasma"
plot.plt.rcParams["savefig.transparent"] = True

fig, ax = plot.plt.subplots(nrows=1, ncols=2, figsize=(13, 6))
ax[0].quiver(g0, g1, ang[0], ang[1], clr)
ax[0].set_title("Angular local coordinate axis")
ax[0].set_xlabel("$x$")
ax[0].set_ylabel("$y$")
ax[0].xaxis.set_ticks((-10, -5, 0, 5, 10))
ax[0].yaxis.set_ticks((-10, -5, 0, 5, 10))
ax[1].quiver(g0, g1, rad[0], rad[1], clr)
ax[1].set_title("Radial local coordinate axis")
ax[1].set_xlabel("$x$")
ax[1].set_ylabel("$y$")
ax[1].xaxis.set_ticks((-10, -5, 0, 5, 10))
ax[1].yaxis.set_ticks((-10, -5, 0, 5, 10))
fig.tight_layout()
fig.show()
47 changes: 47 additions & 0 deletions docs/source/pyfigures/spheregrad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np

import scico.linop as scl
from scico import plot

input_shape = (7, 7, 7)
centre = (np.array(input_shape) - 1) / 2
end = np.array(input_shape) - centre
g0, g1, g2 = np.mgrid[-centre[0] : end[0], -centre[1] : end[1], -centre[2] : end[2]]

sg = scl.SphericalGradient(input_shape=input_shape)

azi = sg.coord[0]
pol = sg.coord[1]
rad = sg.coord[2]

theta = np.arctan2(g0, g1)
phi = np.arctan2(np.sqrt(g0**2 + g1**2), g2)
clr = theta * phi
# See https://stackoverflow.com/a/49888126
clr = (clr.ravel() - clr.min()) / np.ptp(clr)
clr = np.concatenate((clr, np.repeat(clr, 2)))
clr = plot.plt.cm.plasma(clr)

plot.plt.rcParams["savefig.transparent"] = True

fig = plot.plt.figure(figsize=(20, 6))
ax = fig.add_subplot(1, 3, 1, projection="3d")
ax.quiver(g0, g1, g2, azi[0], azi[1], azi[2], colors=clr, length=0.9)
ax.set_title("Azimuthal local coordinate axis")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_zlabel("$z$")
ax = fig.add_subplot(1, 3, 2, projection="3d")
ax.quiver(g0, g1, g2, pol[0], pol[1], pol[2], colors=clr, length=0.9)
ax.set_title("Polar local coordinate axis")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_zlabel("$z$")
ax = fig.add_subplot(1, 3, 3, projection="3d")
ax.quiver(g0, g1, g2, rad[0], rad[1], rad[2], colors=clr, length=0.9)
ax.set_title("Radial local coordinate axis")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_zlabel("$z$")
fig.tight_layout()
fig.show()
83 changes: 83 additions & 0 deletions docs/source/pyfigures/xray_2d_geom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import numpy as np

import matplotlib.patches as patches
import matplotlib.pyplot as plt

c = 1.0 / np.sqrt(2.0)
e = 1e-2
style = "Simple, tail_width=0.5, head_width=4, head_length=8"
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(21, 7))
for n in range(3):
ax[n].set_aspect(1.0)
ax[n].set_xlim(-1.1, 1.1)
ax[n].set_ylim(-1.1, 1.1)
ax[n].set_xticks(np.linspace(-1.0, 1.0, 5))
ax[n].set_yticks(np.linspace(-1.0, 1.0, 5))
ax[n].tick_params(axis="x", labelsize=12)
ax[n].tick_params(axis="y", labelsize=12)
ax[n].set_xlabel("axis 1", fontsize=14)
ax[n].set_ylabel("axis 0", fontsize=14)

# scico
plist = [
patches.FancyArrowPatch((-1.0, 0.0), (-0.5, 0.0), arrowstyle=style, color="r"),
patches.FancyArrowPatch((-c, -c), (-c / 2.0, -c / 2.0), arrowstyle=style, color="r"),
patches.FancyArrowPatch(
(
0.0,
-1.0,
),
(0.0, -0.5),
arrowstyle=style,
color="r",
),
patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=180, theta2=-45.0, color="b", ls="dotted"),
patches.FancyArrowPatch((c - e, -c - e), (c + e, -c + e), arrowstyle=style, color="b"),
]
for p in plist:
ax[0].add_patch(p)
ax[0].text(-0.88, 0.02, r"$\theta=0$", color="r", fontsize=14)
ax[0].text(-3 * c / 4 - 0.01, -3 * c / 4 - 0.1, r"$\theta=\frac{\pi}{4}$", color="r", fontsize=14)
ax[0].text(0.03, -0.8, r"$\theta=\frac{\pi}{2}$", color="r", fontsize=14)
ax[0].set_title("scico", fontsize=14)

# astra
plist = [
patches.FancyArrowPatch((0.0, -1.0), (0.0, -0.5), arrowstyle=style, color="r"),
patches.FancyArrowPatch((c, -c), (c / 2.0, -c / 2.0), arrowstyle=style, color="r"),
patches.FancyArrowPatch((1.0, 0.0), (0.5, 0.0), arrowstyle=style, color="r"),
patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=-90, theta2=45.0, color="b", ls="dotted"),
patches.FancyArrowPatch((c + e, c - e), (c - e, c + e), arrowstyle=style, color="b"),
]
for p in plist:
ax[1].add_patch(p)
ax[1].text(0.02, -0.75, r"$\theta=0$", color="r", fontsize=14)
ax[1].text(3 * c / 4 + 0.01, -3 * c / 4 + 0.01, r"$\theta=\frac{\pi}{4}$", color="r", fontsize=14)
ax[1].text(0.65, 0.05, r"$\theta=\frac{\pi}{2}$", color="r", fontsize=14)
ax[1].set_title("astra", fontsize=14)

# svmbir
plist = [
patches.FancyArrowPatch((-1.0, 0.0), (-0.5, 0.0), arrowstyle=style, color="r"),
patches.FancyArrowPatch((-c, c), (-c / 2.0, c / 2.0), arrowstyle=style, color="r"),
patches.FancyArrowPatch(
(
0.0,
1.0,
),
(0.0, 0.5),
arrowstyle=style,
color="r",
),
patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=45, theta2=180, color="b", ls="dotted"),
patches.FancyArrowPatch((c - e, c + e), (c + e, c - e), arrowstyle=style, color="b"),
]
for p in plist:
ax[2].add_patch(p)
ax[2].text(-0.88, 0.02, r"$\theta=0$", color="r", fontsize=14)
ax[2].text(-3 * c / 4 + 0.01, 3 * c / 4 + 0.01, r"$\theta=\frac{\pi}{4}$", color="r", fontsize=14)
ax[2].text(0.03, 0.75, r"$\theta=\frac{\pi}{2}$", color="r", fontsize=14)
ax[2].set_title("svmbir", fontsize=14)

fig.tight_layout()
fig.show()
72 changes: 72 additions & 0 deletions docs/source/pyfigures/xray_3d_ang.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import numpy as np

import matplotlib as mpl
import matplotlib.patches as patches
import matplotlib.pyplot as plt

mpl.rcParams["savefig.transparent"] = True


c = 1.0 / np.sqrt(2.0)
e = 1e-2
style = "Simple, tail_width=0.5, head_width=4, head_length=8"
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5))
ax.set_aspect(1.0)
ax.set_xlim(-1.1, 1.1)
ax.set_ylim(-1.1, 1.1)
ax.set_xticks(np.linspace(-1.0, 1.0, 5))
ax.set_yticks(np.linspace(-1.0, 1.0, 5))
ax.tick_params(axis="x", labelsize=12)
ax.tick_params(axis="y", labelsize=12)
ax.set_xlabel("$x$", fontsize=14)
ax.set_ylabel("$y$", fontsize=14)

plist = [
patches.FancyArrowPatch((0.0, -1.0), (0.0, -0.5), arrowstyle=style, color="r"),
patches.FancyArrowPatch((c, -c), (c / 2.0, -c / 2.0), arrowstyle=style, color="r"),
patches.FancyArrowPatch((1.0, 0.0), (0.5, 0.0), arrowstyle=style, color="r"),
patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=-90, theta2=45.0, color="b", ls="dotted"),
patches.FancyArrowPatch((c + e, c - e), (c - e, c + e), arrowstyle=style, color="b"),
]
for p in plist:
ax.add_patch(p)
ax.text(0.02, -0.75, r"$\theta=0$", color="r", fontsize=14)
ax.text(
3 * c / 4 + 0.01,
-3 * c / 4 + 0.01,
r"$\theta=\frac{\pi}{4}$",
color="r",
fontsize=14,
)
ax.text(0.65, 0.05, r"$\theta=\frac{\pi}{2}$", color="r", fontsize=14)

ax.plot((-0.375, 0.375), (1.0, 1.0), color="orange", lw=2)
ax.arrow(
-0.375,
0.94,
0.75,
0.0,
color="orange",
lw=0.5,
ls="--",
head_width=0.03,
length_includes_head=True,
)
ax.text(0.0, 0.82, r"$\theta=0$", color="orange", ha="center", fontsize=14)

ax.plot((-1.0, -1.0), (-0.375, 0.375), color="orange", lw=2)
ax.arrow(
-0.94,
-0.375,
0.0,
0.75,
color="orange",
lw=0.5,
ls="--",
head_width=0.03,
length_includes_head=True,
)
ax.text(-0.9, 0.0, r"$\theta=\frac{\pi}{2}$", color="orange", ha="left", fontsize=14)

fig.tight_layout()
fig.show()
81 changes: 81 additions & 0 deletions docs/source/pyfigures/xray_3d_vec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import numpy as np

import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d

mpl.rcParams["savefig.transparent"] = True


# See https://github.com/matplotlib/matplotlib/issues/21688
class Arrow3D(FancyArrowPatch):
def __init__(self, xs, ys, zs, *args, **kwargs):
FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
self._verts3d = xs, ys, zs

def do_3d_projection(self, renderer=None):
xs3d, ys3d, zs3d = self._verts3d
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))

return np.min(zs)


# Define vector components
πœƒ = 10 * np.pi / 180.0 # angle in x-y plane (azimuth angle)
𝛼 = 70 * np.pi / 180.0 # angle with z axis (zenith angle)
π›₯p, π›₯d = 0.3, 1.0
d = (-π›₯d * np.sin(𝛼) * np.sin(πœƒ), π›₯d * np.sin(𝛼) * np.cos(πœƒ), π›₯d * np.cos(𝛼))
u = (π›₯p * np.cos(πœƒ), π›₯p * np.sin(πœƒ), 0.0)
v = (π›₯p * np.cos(𝛼) * np.sin(πœƒ), -π›₯p * np.cos(𝛼) * np.cos(πœƒ), π›₯p * np.sin(𝛼))

# Location of text labels
d_txtpos = np.array(d) + np.array([0, 0, -0.12])
u_txtpos = np.array(d) + np.array(u) + np.array([0, 0, -0.1])
v_txtpos = np.array(d) + np.array(v) + np.array([0, 0, 0.03])


arrowstyle = "-|>,head_width=2.5,head_length=9"

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# Set view
ax.set_aspect("equal")
ax.elev = 15
ax.azim = -50
ax.set_box_aspect(None, zoom=2)
ax.set_xlim((-1.1, 1.1))
ax.set_ylim((-1.1, 1.1))
ax.set_zlim((-1.1, 1.1))

# Disable shaded 3d axis grids
ax.set_axis_off()

# Draw central x,y,z axes and labels
axis_crds = np.array([[-1, 1], [0, 0], [0, 0]])
axis_lbls = ("$x$", "$y$", "$z$")
for k in range(3):
crd = np.roll(axis_crds, k, axis=0)
ax.add_artist(
Arrow3D(
*crd.tolist(),
lw=1.5,
ls="--",
arrowstyle=arrowstyle,
color="black",
)
)
ax.text(*(1.05 * crd[:, 1]).tolist(), axis_lbls[k], fontsize=12)

# Draw d, u, v and labels
ax.quiver(0, 0, 0, *d, arrow_length_ratio=0.08, lw=2, color="blue")
ax.quiver(*d, *u, arrow_length_ratio=0.08 / π›₯p, lw=2, color="blue")
ax.quiver(*d, *v, arrow_length_ratio=0.08 / π›₯p, lw=2, color="blue")
ax.text(*d_txtpos, r"$\mathbf{d}$", fontsize=12)
ax.text(*u_txtpos, r"$\mathbf{u}$", fontsize=12)
ax.text(*v_txtpos, r"$\mathbf{v}$", fontsize=12)

fig.tight_layout()
fig.subplots_adjust(-0.1, -0.06, 1, 1)
fig.show()
Loading

0 comments on commit 3d3d849

Please sign in to comment.