Skip to content

Commit

Permalink
Modified the wavelength argument of colorsynth.rgb() to accept va…
Browse files Browse the repository at this point in the history
…lues of `None`.
  • Loading branch information
byrdie committed Feb 6, 2024
1 parent 94326e8 commit 0be9bc5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
32 changes: 30 additions & 2 deletions colorsynth/_colorsynth.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ def transform_spd_wavelength(x: np.ndarray, w: u.Quantity):

def rgb(
spd: np.ndarray,
wavelength: u.Quantity,
wavelength: None | u.Quantity = None,
axis: int = -1,
spd_min: None | np.ndarray = None,
spd_max: None | np.ndarray = None,
Expand All @@ -678,7 +678,9 @@ def rgb(
spd
a spectral power distribution to be converted into a RGB array
wavelength
the wavelength array corresponding to the spectral power distribution
The wavelength array corresponding to the spectral power distribution.
If :obj:`None`, the wavelength is assumed to be evenly sampled across
the human visible color range.
axis
the logical axis corresponding to changing wavelength,
or the axis along which to integrate the spectral power distribution
Expand All @@ -700,7 +702,33 @@ def rgb(
wavelength_norm
an optional function to transform the wavelength values before they
are mapped into the human visible color range.
Examples
--------
Colorize a random, 3D numpy array.
.. jupyter-execute::
import numpy as np
import matplotlib.pyplot as plt
import colorsynth
# Create a uniform random 3D numpy array
a = np.random.uniform(low=0, high=1, size=(16, 16, 11))
# Colorize the 3D numpy array
rgb = colorsynth.rgb(a)
# Plot the resulting RGB image
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(rgb);
"""
if wavelength is None:
shape_wavelength = [1] * spd.ndim
shape_wavelength[axis] = -1
wavelength = np.linspace(0, 1, num=spd.shape[axis])
wavelength = wavelength.reshape(shape_wavelength)

spd, wavelength = np.broadcast_arrays(spd, wavelength, subok=True)

transform_spd_wavelength = _transform_spd_wavelength(
Expand Down
1 change: 1 addition & 0 deletions colorsynth/_tests/test_colorsynth.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def test_sRGB(
@pytest.mark.parametrize(
argnames="wavelength",
argvalues=[
None,
np.linspace(380, 780, num=101) * u.nm,
],
)
Expand Down

0 comments on commit 0be9bc5

Please sign in to comment.