diff --git a/ripplemapper/analyse.py b/ripplemapper/analyse.py index bb9c011..fcc2855 100644 --- a/ripplemapper/analyse.py +++ b/ripplemapper/analyse.py @@ -1,4 +1,5 @@ """Mostly a collection of functions to help instantiate a list of image classes and add some contours""" + import warnings import cv2 @@ -13,14 +14,31 @@ def add_boundary_contours(ripple_images: list[RippleImage] | RippleImage | RippleImageSeries, overwrite: bool = False, level=None, **kwargs) -> list[RippleImage]: - """Add boundary contours to a list of RippleImage objects.""" + """ + Add boundary contours to a list of RippleImage objects. + + Parameters + ---------- + ripple_images : list[RippleImage] | RippleImage | RippleImageSeries + A list of RippleImage objects or a single RippleImage or RippleImageSeries. + overwrite : bool, optional + Whether to overwrite existing boundary contours, by default False. + level : optional + Contour level parameter for the find_contours function, by default None. + **kwargs + Additional keyword arguments for the find_contours function. + + Returns + ------- + list[RippleImage] + The list of RippleImage objects with added boundary contours. + """ if isinstance(ripple_images, RippleImageSeries): ripple_images = ripple_images.images if isinstance(ripple_images, RippleImage): ripple_images = [ripple_images] for ripple_image in ripple_images: if len(ripple_image.contours) > 0: - # TODO: refactor to use new get_contour method indexes = [] for i in range(len(ripple_image.contours)): if ripple_image.contours[i].method == 'Upper Boundary': @@ -34,20 +52,40 @@ def add_boundary_contours(ripple_images: list[RippleImage] | RippleImage | Rippl ripple_image.contours.pop(indexes[0]) if len(indexes) == 2: ripple_image.contours.pop(indexes[0]) - # they have now moved by 1. - ripple_image.contours.pop(indexes[1]-1) + ripple_image.contours.pop(indexes[1] - 1) else: warnings.warn(f"Boundary contours already exist, skipping image: {ripple_image.source_file}") continue edges = detect_edges(ripple_image.image) processed_edges = process_edges(edges) contours = find_contours(processed_edges, level=level) - ripple_image.add_contour(np.array([contours[0][:,0],contours[0][:,1]]), 'Upper Boundary') - ripple_image.add_contour(np.array([contours[1][:,0],contours[1][:,1]]), 'Lower Boundary') + ripple_image.add_contour(np.array([contours[0][:, 0], contours[0][:, 1]]), 'Upper Boundary') + ripple_image.add_contour(np.array([contours[1][:, 0], contours[1][:, 1]]), 'Lower Boundary') + +def add_a_star_contours(ripple_images: list[RippleImage] | RippleImage | RippleImageSeries, contour_index: list[int] = [0, 1], overwrite: bool = False) -> list[RippleImage]: + """ + Add A* contours to a list of RippleImage objects. -def add_a_star_contours(ripple_images: list[RippleImage] | RippleImage | RippleImageSeries, contour_index: list[int]=[0,1], overwrite:bool=False) -> list[RippleImage]: - """Add A* contours to a list of RippleImage objects.""" + Parameters + ---------- + ripple_images : list[RippleImage] | RippleImage | RippleImageSeries + A list of RippleImage objects or a single RippleImage or RippleImageSeries. + contour_index : list[int], optional + List of two integers indicating which contours to use, by default [0, 1]. + overwrite : bool, optional + Whether to overwrite existing A* contours, by default False. + + Returns + ------- + list[RippleImage] + The list of RippleImage objects with added A* contours. + + Raises + ------ + ValueError + If contour_index does not have exactly two integers. + """ if len(contour_index) != 2: raise ValueError("contour_index must be a list of two integers.") if isinstance(ripple_images, RippleImageSeries): @@ -58,17 +96,13 @@ def add_a_star_contours(ripple_images: list[RippleImage] | RippleImage | RippleI if len(ripple_image.contours) < 2: warnings.warn(f"RippleImage object must have at least two contours, skipping image: {ripple_image.source_file}") continue - # TODO: refactor to use new get_contour method methods = [contour.method for contour in ripple_image.contours] if 'A* traversal' in methods: if overwrite: warnings.warn(f"Overwriting A* contour for image: {ripple_image.source_file}") - # find me the method index that matches 'A* traversal' for contour in ripple_image.contours: - print(contour.method) if contour.method == 'A* traversal': ripple_image.contours.remove(contour) - print(ripple_image.contours) else: warnings.warn(f"A* contour already exists, skipping image: {ripple_image.source_file}") continue @@ -79,22 +113,38 @@ def add_a_star_contours(ripple_images: list[RippleImage] | RippleImage | RippleI bounded_img = np.zeros(ripple_image.image.shape, dtype=np.uint8) bounded_img = cv2.drawContours(bounded_img, [contour], 0, (255, 255, 255), -1) d_map = distance_map(bounded_img) - start = (np.argmax(d_map[:,0]), 0) # Pixel with the highest value on the left. - goal = (np.argmax(d_map[:, -1]), d_map.shape[1] - 1) # Highest value on the right. + start = (np.argmax(d_map[:, 0]), 0) + goal = (np.argmax(d_map[:, -1]), d_map.shape[1] - 1) path = a_star(start, goal, d_map) - # return type from my a_star function is a list of tuples, need to convert it to a numpy array - path = np.flip(np.array(path), axis=0).T # the path output has insane shape, need to flip it + path = np.flip(np.array(path), axis=0).T ripple_image.add_contour(path, 'A* traversal') + def add_chan_vese_contours(ripple_images: list[RippleImage] | RippleImage | RippleImageSeries, overwrite: bool = False, use_gradients=False, **kwargs): - """Add Chan-Vese contours to a list of RippleImage objects.""" + """ + Add Chan-Vese contours to a list of RippleImage objects. + + Parameters + ---------- + ripple_images : list[RippleImage] | RippleImage | RippleImageSeries + A list of RippleImage objects or a single RippleImage or RippleImageSeries. + overwrite : bool, optional + Whether to overwrite existing Chan-Vese contours, by default False. + use_gradients : bool, optional + Whether to use image gradients, by default False. + **kwargs + Additional keyword arguments for the cv_segmentation function. + + Returns + ------- + None + """ if isinstance(ripple_images, RippleImageSeries): ripple_images = ripple_images.images if isinstance(ripple_images, RippleImage): ripple_images = [ripple_images] for ripple_image in ripple_images: if len(ripple_image.contours) > 0: - # TODO: refactor to use new get_contour method methods = [contour.method for contour in ripple_image.contours] if 'Chan-Vese' in methods: if overwrite: @@ -107,19 +157,49 @@ def add_chan_vese_contours(ripple_images: list[RippleImage] | RippleImage | Ripp continue if use_gradients: grad = np.sum(np.abs(np.gradient(ripple_image.image)), axis=0) - img = cv2.GaussianBlur(grad / np.max(grad), (7,7), 0)+(1-(ripple_image.image/np.max(ripple_image.image))) + img = cv2.GaussianBlur(grad / np.max(grad), (7, 7), 0) + (1 - (ripple_image.image / np.max(ripple_image.image))) cv = cv_segmentation(img, **kwargs) else: cv = cv_segmentation(ripple_image.image, **kwargs) contours = find_contours(cv) - ripple_image.add_contour(np.array([contours[0][:,0],contours[0][:,1]]), 'Chan-Vese') + ripple_image.add_contour(np.array([contours[0][:, 0], contours[0][:, 1]]), 'Chan-Vese') + def remove_small_bumps(contour: RippleContour, **kwargs) -> RippleContour: - """Remove small bumps from a RippleContour object.""" + """ + Remove small bumps from a RippleContour object. + + Parameters + ---------- + contour : RippleContour + A RippleContour object to be smoothed. + **kwargs + Additional keyword arguments for the smooth_bumps function. + + Returns + ------- + RippleContour + The smoothed RippleContour object. + """ return smooth_bumps(contour, **kwargs) + def remove_small_bumps_from_images(ripple_images: list[RippleImage] | RippleImage, **kwargs) -> list[RippleImage]: - """Remove small bumps from a list of RippleImage objects.""" + """ + Remove small bumps from a list of RippleImage objects. + + Parameters + ---------- + ripple_images : list[RippleImage] | RippleImage + A list of RippleImage objects or a single RippleImage. + **kwargs + Additional keyword arguments for the remove_small_bumps function. + + Returns + ------- + list[RippleImage] + The list of RippleImage objects with smoothed contours. + """ if isinstance(ripple_images, RippleImageSeries): ripple_images = ripple_images.images if isinstance(ripple_images, RippleImage): diff --git a/ripplemapper/classes.py b/ripplemapper/classes.py index 57780d3..1c4e53b 100644 --- a/ripplemapper/classes.py +++ b/ripplemapper/classes.py @@ -19,7 +19,23 @@ class RippleContour: """Dataclass for ripple contours.""" - def __init__(self, *args, image=None): # we do not type image to prevent crossover typing + def __init__(self, *args, image=None): + """ + Initialize a RippleContour instance. + + Parameters + ---------- + *args : tuple + Arguments to initialize the contour. Can be a file path to a contour file, + or a numpy array of contour values and a method string. + image : optional + The parent image associated with the contour, by default None. + + Raises + ------ + ValueError + If the input is not a valid file path or a values-method pair. + """ if len(args) == 1 and isinstance(args[0], str) and str(args[0]).endswith('.txt'): self._load(args[0], image) elif len(args) == 2 and isinstance(args[0], np.ndarray) and isinstance(args[1], str): @@ -30,11 +46,18 @@ def __init__(self, *args, image=None): # we do not type image to prevent crosso raise ValueError("Invalid input, expected a file path to a contour file or a values, method pair.") def to_physical(self): - """Converts the contour to physical units.""" + """Convert the contour to physical units.""" return - def save(self, fname: str=False): - """Write the contour to a file.""" + def save(self, fname: str = False): + """ + Write the contour to a file. + + Parameters + ---------- + fname : str, optional + File name to save the contour. If not provided, a default name is generated. + """ if not fname: fname = f"{self.parent_image.source_file}_{self.method}.txt" with open(fname, 'w') as f: @@ -44,7 +67,16 @@ def save(self, fname: str=False): }, f) def plot(self, *args, **kwargs): - """Plot the image with contours.""" + """ + Plot the image with contours. + + Parameters + ---------- + *args : tuple + Additional positional arguments for the plot function. + **kwargs : dict + Additional keyword arguments for the plot function. + """ plot_contours(self, *args, **kwargs) if self.parent_image: plt.title(f"{self.parent_image.source_file} - Contour: {self.method}") @@ -53,12 +85,28 @@ def plot(self, *args, **kwargs): return def smooth(self, **kwargs): - """Smooth the contour.""" + """ + Smooth the contour. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments for the smooth_bumps function. + """ smooth_bumps(self, **kwargs) return def _load(self, file: str, image): - """Load a contour from a file.""" + """ + Load a contour from a file. + + Parameters + ---------- + file : str + File path to load the contour from. + image : + The parent image associated with the contour. + """ with open(file) as f: data = json.load(f) self.values = np.array(data["values"]) @@ -70,8 +118,25 @@ class RippleImage: """Class for ripple images.""" def __init__(self, *args, roi_x: list[int] = False, roi_y: list[int] = False): + """ + Initialize a RippleImage instance. + + Parameters + ---------- + *args : tuple + Arguments to initialize the image. Can be a file path to an image file, + or a file name and image data pair. + roi_x : list[int], optional + Region of interest in the x-dimension, by default False. + roi_y : list[int], optional + Region of interest in the y-dimension, by default False. + + Raises + ------ + ValueError + If the input is not a valid file path or a file name and image data pair. + """ self.contours: list[RippleContour] = [] - # Handle loading from file if the file extension is .rimg if len(args) == 1 and isinstance(args[0], str) and str(args[0]).endswith('.rimg'): self._load(args[0]) return @@ -94,15 +159,41 @@ def __init__(self, *args, roi_x: list[int] = False, roi_y: list[int] = False): self.image = preprocess_image(self.image, roi_x=roi_x, roi_y=roi_y) def __repr__(self) -> str: + """Return a string representation of the RippleImage instance.""" return f"RippleImage: {self.source_file.split('/')[-1]}" def add_contour(self, *args): - """Add a contour to the RippleImage object.""" + """ + Add a contour to the RippleImage object. + + Parameters + ---------- + *args : tuple + Arguments to initialize the contour. Can be a file path to a contour file, + or a numpy array of contour values and a method string. + """ contour = RippleContour(*args, image=self) self.contours.append(contour) def get_contour(self, contour: str | int): - """Return a given contour for the image.""" + """ + Return a given contour for the image. + + Parameters + ---------- + contour : str | int + Contour identifier. Can be an integer index or a method string. + + Returns + ------- + RippleContour + The corresponding RippleContour object. + + Raises + ------ + ValueError + If the input is not a valid integer or method string. + """ if isinstance(contour, int): return self.contours[contour] elif isinstance(contour, str): @@ -112,24 +203,54 @@ def get_contour(self, contour: str | int): else: raise ValueError("Invalid input, expected an integer or method string") - def smooth_contours(self, **kwargs): - """Smooth all the contours in the image.""" + """ + Smooth all the contours in the image. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments for the smooth method of RippleContour. + """ self.contours = [contour.smooth(**kwargs) for contour in self.contours] return def plot(self, include_contours: bool = True, *args, **kwargs): - """Plot the image with optional contours.""" + """ + Plot the image with optional contours. + + Parameters + ---------- + include_contours : bool, optional + Whether to include contours in the plot, by default True. + *args : tuple + Additional positional arguments for the plot function. + **kwargs : dict + Additional keyword arguments for the plot function. + """ plot_image(self, include_contours=include_contours, *args, **kwargs) plt.title("RippleImage: " + self.source_file.split('/')[-1]) return def save(self, fname: str = False, save_image_data: bool = False): - """Save the image and contours to a file.""" + """ + Save the image and contours to a file. + + Parameters + ---------- + fname : str, optional + File name to save the image and contours. If not provided, a default name is generated. + save_image_data : bool, optional + Whether to save the image data, by default False. + + Returns + ------- + str + The file name the image and contours were saved to. + """ if not fname: fname = self.source_file.replace('.tif', '.rimg') - # Save the metadata and contours data = { "source_file": self.source_file, "contours": [{ @@ -143,7 +264,14 @@ def save(self, fname: str = False, save_image_data: bool = False): return fname def _load(self, file: str): - """Load the image and contours from a file.""" + """ + Load the image and contours from a file. + + Parameters + ---------- + file : str + File path to load the image and contours from. + """ with gzip.open(file, 'rb') as f: data = pickle.load(f) @@ -166,6 +294,20 @@ class RippleImageSeries: """Class for a series of ripple images.""" def __init__(self, *args): + """ + Initialize a RippleImageSeries instance. + + Parameters + ---------- + *args : tuple + Arguments to initialize the series. Can be a file path to a .rimgs file, + or a list of RippleImage objects. + + Raises + ------ + ValueError + If the input is not a valid file path or a list of RippleImage objects. + """ if len(args) == 1 and isinstance(args[0], str) and str(args[0]).endswith('.rimgs'): self._load(args[0]) elif len(args) == 1 and isinstance(args[0], list) and all(isinstance(img, RippleImage) for img in args[0]): @@ -174,26 +316,66 @@ def __init__(self, *args): raise ValueError("Invalid input, expected a file path to a .rimgs file or a list of RippleImage objects.") def __repr__(self) -> str: + """Return a string representation of the RippleImageSeries instance.""" return f"RippleImageSeries: {len(self.images)} images" def animate(self, fig=plt.figure(figsize=(12,8)), fname: str = None, **kwargs): - """Animate the images.""" + """ + Animate the images. + + Parameters + ---------- + fig : matplotlib.figure.Figure, optional + The figure to animate, by default plt.figure(figsize=(12,8)). + fname : str, optional + File name to save the animation. If not provided, the animation is not saved. + **kwargs : dict + Additional keyword arguments for the plot function. + + Returns + ------- + FuncAnimation + The created animation. + """ ani = FuncAnimation(fig, partial(self.update, **kwargs), - frames=range(len(self.images)), interval=200, repeat=False) + frames=range(len(self.images)), interval=200, repeat=False) if fname: ani.save(fname, writer='ffmpeg') return ani def update(self, frame, **kwargs): + """ + Update the plot for animation. + + Parameters + ---------- + frame : int + The frame index to update. + **kwargs : dict + Additional keyword arguments for the plot function. + """ plt.clf() self.images[frame].plot(**kwargs) def save(self, fname: str = False, save_image_data: bool = False): - """Save the image series to a file.""" + """ + Save the image series to a file. + + Parameters + ---------- + fname : str, optional + File name to save the image series. If not provided, a default name is generated. + save_image_data : bool, optional + Whether to save the image data, by default False. + + Returns + ------- + str + The file name the image series was saved to. + """ if not fname: fname = 'image_series.rimgs' - # Save the metadata and contours data = [image.source_file.replace('.tif', '.rimg') for image in self.images] with gzip.open(fname, 'wb') as f: pickle.dump(data, f) @@ -208,15 +390,30 @@ def save(self, fname: str = False, save_image_data: bool = False): return fname def timeseries(self, contour: str | int = 0, **kwargs): - """Plot a timeseries of the same contour.""" + """ + Plot a timeseries of the same contour. + + Parameters + ---------- + contour : str | int, optional + The contour identifier to plot the timeseries for, by default 0. + **kwargs : dict + Additional keyword arguments for the plot_timeseries function. + """ contours = [img.get_contour(contour) for img in self.images] labels = [img.source_file.split('/')[-1] for img in self.images] - plot_timeseries(contours, labels) - + plot_timeseries(contours, labels, **kwargs) def _load(self, file: str): - """Load the image series from a file.""" + """ + Load the image series from a file. + + Parameters + ---------- + file : str + File path to load the image series from. + """ with gzip.open(file, 'rb') as f: image_files = pickle.load(f) base_path = Path(file).parent - self.images = [RippleImage(str(base_path / image_file.split("/")[-1])) for image_file in image_files] # TODO Path ojets should be accepted by RippleImage, etc. + self.images = [RippleImage(str(base_path / image_file.split("/")[-1])) for image_file in image_files] diff --git a/ripplemapper/contour.py b/ripplemapper/contour.py index e2fe06b..8002bab 100644 --- a/ripplemapper/contour.py +++ b/ripplemapper/contour.py @@ -1,4 +1,5 @@ """Ripplemapper contours module.""" + import heapq import cv2 @@ -10,19 +11,23 @@ __all__ = ["find_contours", "compute_recursive_midpoints", "extend_contour", "combine_contours", "smooth_contour", "distance_map", "neighbors", "a_star", "get_next_node", "find_boundaries", "find_bump_limits", "smooth_bumps", "average_boundaries"] -def find_contours(edges_cleaned: np.ndarray, level: float=0.5) -> np.ndarray: +def find_contours(edges_cleaned: np.ndarray, level: float = 0.5) -> np.ndarray: """ Find contours in the edge image and approximate them to simplify. - Parameters: - edges_cleaned (numpy.ndarray): Processed edge image. - tolerance (int): Tolerance value for approximating contours. + Parameters + ---------- + edges_cleaned : np.ndarray + Processed edge image. + level : float, optional + Contour level parameter for the find_contours function, by default 0.5. - Returns: - numpy.ndarray: Approximated contour vertices. + Returns + ------- + np.ndarray + Approximated contour vertices. """ contours = measure.find_contours(edges_cleaned, level=level) - # sort contours by length contours = sorted(contours, key=lambda x: len(x), reverse=True) return contours @@ -30,95 +35,112 @@ def compute_recursive_midpoints(poly_a: np.ndarray, poly_b: np.ndarray, iteratio """ Compute midpoints between two contours recursively, addressing the shape mismatch. - Parameters: - poly_a (numpy.ndarray): Vertices of the first contour. - poly_b (numpy.ndarray): Vertices of the second contour. - iterations (int): Number of iterations for recursion. + Parameters + ---------- + poly_a : np.ndarray + Vertices of the first contour. + poly_b : np.ndarray + Vertices of the second contour. + iterations : int + Number of iterations for recursion. - Returns: - numpy.ndarray: Midpoint vertices after the final iteration. + Returns + ------- + np.ndarray + Midpoint vertices after the final iteration. """ if iterations == 0: - return poly_a, poly_b # Or return poly_b, or an average if you prefer. + return poly_a, poly_b - # Initialize KD-Trees for each set of points tree_a = cKDTree(poly_a) tree_b = cKDTree(poly_b) - # New sets for midpoints midpoints_a = np.empty_like(poly_a) midpoints_b = np.empty_like(poly_b) - # Compute midpoints from a to b for i, point in enumerate(poly_a): dist, index = tree_b.query(point) nearest_point = poly_b[index] midpoints_a[i] = (point + nearest_point) / 2 - # Compute midpoints from b to a for i, point in enumerate(poly_b): dist, index = tree_a.query(point) nearest_point = poly_a[index] midpoints_b[i] = (point + nearest_point) / 2 - # Recursively refine the midpoints return compute_recursive_midpoints(midpoints_a, midpoints_b, iterations - 1) -def extend_contour(contour, shape): - """Extends the contour to the edges of the image region defined by shape. +def extend_contour(contour: np.ndarray, shape: tuple) -> np.ndarray: + """ + Extend the contour to the edges of the image region defined by shape. Parameters ---------- - contour : _type_ - points marking the vertices of the contour. + contour : np.ndarray + Points marking the vertices of the contour. shape : tuple - len, width of the image. - """ - # make a new first point and prepend it to the array, the new first point should have x=0 and y= the same as the second point + Length and width of the image. - if contour[0][1] > shape[1]/2: + Returns + ------- + np.ndarray + Extended contour vertices. + """ + if contour[0][1] > shape[1] / 2: new_first_point = [contour[0][1], shape[1]] - new_last_point = [contour[-1][1], 0] + new_last_point = [contour[-1][1], 0] else: new_first_point = [contour[0][1], 0] - new_last_point = [contour[-1][1], shape[0]] - + new_last_point = [contour[-1][1], shape[0]] contour = np.vstack([new_first_point, contour, new_last_point]) - - - return contour -def combine_contours(contour1, contour2): - """Combines two contours into one. +def combine_contours(contour1: np.ndarray, contour2: np.ndarray) -> np.ndarray: + """ + Combine two contours into one. Parameters ---------- - contour1 : _type_ - points marking the vertices of the first contour. - contour2 : _type_ - points marking the vertices of the second contour. + contour1 : np.ndarray + Points marking the vertices of the first contour. + contour2 : np.ndarray + Points marking the vertices of the second contour. + + Returns + ------- + np.ndarray + Combined contour vertices. """ - # we need contour one to run from low to high and contour 2 to run from high to low if contour1[0][1] > contour1[0][-1]: contour1 = np.flip(contour1) if contour2[0][1] < contour2[0][-1]: contour2 = np.flip(contour2) - #stitch them together contour = np.vstack([contour1, contour2]) return contour -def smooth_contour(contour: np.ndarray, window: int=3): +def smooth_contour(contour: np.ndarray, window: int = 3) -> np.ndarray: """ Smooth a contour by convolving with a small window. + + Parameters + ---------- + contour : np.ndarray + Points marking the vertices of the contour. + window : int, optional + Size of the smoothing window, by default 3. + + Returns + ------- + np.ndarray + Smoothed contour vertices. """ - x = np.convolve(contour[:, 0], np.ones(window)/window, mode='valid') - y = np.convolve(contour[:, 1], np.ones(window)/window, mode='valid') + x = np.convolve(contour[:, 0], np.ones(window) / window, mode='valid') + y = np.convolve(contour[:, 1], np.ones(window) / window, mode='valid') return np.vstack([x, y]).T -def distance_map(binary_map): +def distance_map(binary_map: np.ndarray) -> np.ndarray: """ Compute the distance map of a binary image. @@ -126,33 +148,59 @@ def distance_map(binary_map): ---------- binary_map : np.ndarray Binary image with interiors marked as 1's and exteriors as 0's. - """ - # Assuming `binary_map` is your binary image with interiors marked as 1's and exteriors as 0's - # First, ensure the binary_map is of type uint8 + Returns + ------- + np.ndarray + Normalized distance map. + """ binary_map = binary_map.astype(np.uint8) - - # Apply the distance transform distance_map = cv2.distanceTransform(binary_map, cv2.DIST_L2, cv2.DIST_MASK_PRECISE) ** 2 - # Optionally, normalize the distance map for visualization norm_distance_map = cv2.normalize(distance_map, None, 0, 1.0, cv2.NORM_MINMAX) return norm_distance_map -def neighbors(node, grid_shape): - """Generate neighbors for a given node.""" - # 8-connected grid +def neighbors(node: tuple, grid_shape: tuple) -> tuple: + """ + Generate neighbors for a given node. + + Parameters + ---------- + node : tuple + The current node coordinates. + grid_shape : tuple + Shape of the grid. + + Yields + ------ + tuple + Neighboring node coordinates. + """ directions = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)] for dx, dy in directions: - if get_next_node(node, dx, dy) is None: - continue nx, ny = node[0] + dx, node[1] + dy if 0 <= nx < grid_shape[0] and 0 <= ny < grid_shape[1]: yield (nx, ny) -def a_star(start, goal, grid): - """A simple A* algorithm.""" +def a_star(start: tuple, goal: tuple, grid: np.ndarray) -> list: + """ + A simple A* algorithm for pathfinding. + + Parameters + ---------- + start : tuple + Starting node coordinates. + goal : tuple + Goal node coordinates. + grid : np.ndarray + Grid representing the map. + + Returns + ------- + list + Path from start to goal. + """ open_set = [] - heapq.heappush(open_set, (0, start)) # (cost, node) + heapq.heappush(open_set, (0, start)) came_from = {} g_score = {start: 0} f_score = {start: np.linalg.norm(np.array(start) - np.array(goal))} @@ -165,10 +213,10 @@ def a_star(start, goal, grid): while current in came_from: path.append(current) current = came_from[current] - return path[::-1] # Return reversed path + return path[::-1] for neighbor in neighbors(current, grid.shape): - tentative_g_score = g_score[current] + 1 / (grid[neighbor] + 0.01) # Avoid division by zero + tentative_g_score = g_score[current] + 1 / (grid[neighbor] + 0.01) if neighbor not in g_score or tentative_g_score < g_score[neighbor]: came_from[neighbor] = current g_score[neighbor] = tentative_g_score @@ -177,16 +225,43 @@ def a_star(start, goal, grid): return [] -def get_next_node(node, dx, dy): +def get_next_node(node: tuple, dx: int, dy: int) -> tuple: + """ + Get the next node coordinates by moving in the given direction. + + Parameters + ---------- + node : tuple + Current node coordinates. + dx : int + Change in x-coordinate. + dy : int + Change in y-coordinate. + + Returns + ------- + tuple + Next node coordinates, or None if out of bounds. + """ try: nx, ny = node[0] + dx, node[1] + dy except IndexError: return None return nx, ny -def find_boundaries(gray_image: np.ndarray) -> np.ndarray: +def find_boundaries(gray_image: np.ndarray) -> tuple: """ Find the upper and lower boundaries of the edge region. + + Parameters + ---------- + gray_image : np.ndarray + Grayscale image. + + Returns + ------- + tuple + Upper and lower boundaries as numpy arrays. """ edges_gray = detect_edges(gray_image) edges_cleaned = process_edges(edges_gray) @@ -196,13 +271,11 @@ def find_boundaries(gray_image: np.ndarray) -> np.ndarray: upper, lower = contours[0], contours[1] else: upper, lower = contours[1], contours[0] - return (upper, lower) + return upper, lower -def find_bump_limits(large_changes: np.array, current: int = 0, max_size: int = 10, bumps:list[tuple[int, int]] = []): +def find_bump_limits(large_changes: np.array, current: int = 0, max_size: int = 10, bumps: list[tuple[int, int]] = []) -> list[tuple[int, int]]: """ Recursive function to find the limits of "small" bumps in the data. - Small bumps are defined as those where there are multiple large changes in a row, representing a rapid increase - followed by a rapid decrease in the data. Parameters ---------- @@ -225,7 +298,7 @@ def find_bump_limits(large_changes: np.array, current: int = 0, max_size: int = start = large_changes[(large_changes > current)][0] end = False for i in np.arange(1, max_size): - if start+i > large_changes[-1]: + if start + i > large_changes[-1]: return bumps if start + i in large_changes: end = start + i @@ -240,63 +313,57 @@ def smooth_bumps(contour, max_size: int = 40, std_factor: float = 2.0): """ Function to smooth out bumps in the contour data. - If there is an area of the contour where the gradient rapidly changes and then rapidly changes again, - this can be a jump between contours rather than the continual following of one contour. - Parameters ---------- contour : RippleContour Contour object containing the data to be smoothed. max_size : int, optional - Maximum size of a bump, by default 10. + Maximum size of a bump, by default 40. std_factor : float, optional - Standard deviation factor for identifying large changes, by default 3.0. + Standard deviation factor for identifying large changes, by default 2.0. Returns ------- - None - The function modifies the contour values in-place. + RippleContour + The smoothed RippleContour object. """ - # moving average - moving_avg = np.convolve(contour.values[0, :], np.ones(100)/100, mode='valid') + moving_avg = np.convolve(contour.values[0, :], np.ones(100) / 100, mode='valid') diffs = contour.values[0, :len(moving_avg)] - moving_avg gradients = np.gradient(diffs) - # find large changes, grater than the std_factor*std of the gradients versus the moving average - large_changes = np.where(np.abs(gradients) > std_factor*np.std(gradients))[0] - # find any small bumps, i.e. those where there are multiple large changes in a row + large_changes = np.where(np.abs(gradients) > std_factor * np.std(gradients))[0] if len(large_changes) == 0: return contour bumps = find_bump_limits(large_changes, max_size=max_size, bumps=[]) - # unroll each bump into all indices contained within lims indices = [] for bump in bumps: - indices += list(np.arange(bump[0],bump[1])) + indices += list(np.arange(bump[0], bump[1])) indices = np.array(indices) print("num removed", indices.shape) contour.values = np.delete(contour.values, indices[indices < contour.values.shape[1]], axis=1) return contour - -def average_boundaries(self, contour_a = None, contour_b = None, iterations: int=3, save_both: bool=True): - """Average the two contours to get a more accurate representation of the interface. +def average_boundaries(self, contour_a = None, contour_b = None, iterations: int = 3, save_both: bool = True) -> np.ndarray: + """ + Average the two contours to get a more accurate representation of the interface. Parameters ---------- contour_a : RippleContour, optional - The first contour to average, by default self.contours[0] + The first contour to average, by default self.contours[0]. contour_b : RippleContour, optional - The second contour to average, by default self.contours[1] + The second contour to average, by default self.contours[1]. iterations : int, optional - The number of iterations to average the contours over, by default 3 + The number of iterations to average the contours over, by default 3. + save_both : bool, optional + Whether to save both averaged contours, by default True. Returns ------- np.ndarray - The averaged contour + The averaged contour. """ from ripplemapper.classes import RippleContour - # if no contours passed then we use the first two in the list if not contour_a or not contour_b: try: contour_a = self.contours[0] diff --git a/ripplemapper/image.py b/ripplemapper/image.py index cd2eb53..1a29025 100644 --- a/ripplemapper/image.py +++ b/ripplemapper/image.py @@ -1,61 +1,82 @@ """Ripplemapper images module.""" + import numpy as np from skimage import color, feature, filters, morphology from skimage.segmentation import chan_vese __all__ = ["preprocess_image", "cv_segmentation", "detect_edges", "process_edges", "threshold_image"] -def threshold_image(image: np.ndarray, level=0.8) -> np.ndarray: +def threshold_image(image: np.ndarray, level: float = 0.8) -> np.ndarray: """ Threshold the image to make any pixel above the level equal to the max. - Parameters: - image numpy.ndarray: Input image. - level float: Threshold value. - - Returns: - numpy.ndarray: Binary image. + Parameters + ---------- + image : np.ndarray + Input image. + level : float, optional + Threshold value, by default 0.8. + + Returns + ------- + np.ndarray + Binary image. """ prev_max = np.max(image) - image = image/prev_max + image = image / prev_max image[image > level] = 1 image *= prev_max return image - -def preprocess_image(image: np.ndarray, roi_x: list[int]=False, roi_y: list[int]=False, sigma: float=1) -> np.ndarray: +def preprocess_image(image: np.ndarray, roi_x: list[int] = False, roi_y: list[int] = False, sigma: float = 1) -> np.ndarray: """ Preprocess the image by converting it to grayscale and applying Gaussian blur. - Parameters: - image (numpy.ndarray): Input image. - - Returns: - numpy.ndarray: Preprocessed image. + Parameters + ---------- + image : np.ndarray + Input image. + roi_x : list[int], optional + Region of interest in the x-dimension, by default False. + roi_y : list[int], optional + Region of interest in the y-dimension, by default False. + sigma : float, optional + Sigma value for Gaussian blur, by default 1. + + Returns + ------- + np.ndarray + Preprocessed image. """ if len(image.shape) == 3 and image.shape[2] == 3: gray_image = color.rgb2gray(image) else: gray_image = image blurred_gray_image = filters.gaussian(gray_image, sigma=sigma) - # crop the image + if roi_x: blurred_gray_image = blurred_gray_image[roi_x[0]:roi_x[1], :] if roi_y: blurred_gray_image = blurred_gray_image[:, roi_y[0]:roi_y[1]] + return blurred_gray_image def cv_segmentation(image: np.ndarray, **kwargs) -> np.ndarray: """ Perform image segmentation using skimage chan-vese method. - Parameters: - image (numpy.ndarray): Input image. - - Returns: - numpy.ndarray: Segmented image. + Parameters + ---------- + image : np.ndarray + Input image. + **kwargs + Additional keyword arguments for chan-vese method. + + Returns + ------- + np.ndarray + Segmented image. """ - # Define default values for kwargs in the wrapper default_kwargs = { 'mu': 0.2, 'dt': 0.5, @@ -65,43 +86,56 @@ def cv_segmentation(image: np.ndarray, **kwargs) -> np.ndarray: 'tol': 1e-3, } - # Update kwargs with default values if they are not already set for key, value in default_kwargs.items(): kwargs.setdefault(key, value) - cv = chan_vese( - image, - **kwargs - ) + cv = chan_vese(image, **kwargs) return cv -def detect_edges(image: np.ndarray, sigma=1, low_threshold: np.float32=0.1, high_threshold: np.float32=0.5) -> np.ndarray: +def detect_edges(image: np.ndarray, sigma: float = 1, low_threshold: float = 0.1, high_threshold: float = 0.5) -> np.ndarray: """ Detect edges in the image using Canny edge detection. - Parameters: - image (numpy.ndarray): Input image. - - Returns: - numpy.ndarray: Binary edge image. + Parameters + ---------- + image : np.ndarray + Input image. + sigma : float, optional + Sigma value for Gaussian filter, by default 1. + low_threshold : float, optional + Low threshold for hysteresis, by default 0.1. + high_threshold : float, optional + High threshold for hysteresis, by default 0.5. + + Returns + ------- + np.ndarray + Binary edge image. """ edges_gray = feature.canny(image, sigma=sigma, low_threshold=low_threshold, high_threshold=high_threshold) return edges_gray -def process_edges(edges_gray: np.ndarray, sigma: float=0) -> np.ndarray: +def process_edges(edges_gray: np.ndarray, sigma: float = 0) -> np.ndarray: """ Process the binary edge image by performing morphological operations. - Parameters: - edges_gray (numpy.ndarray): Binary edge image. - - Returns: - numpy.ndarray: Processed edge image. + Parameters + ---------- + edges_gray : np.ndarray + Binary edge image. + sigma : float, optional + Sigma value for Gaussian filter, by default 0. + + Returns + ------- + np.ndarray + Processed edge image. """ edges_dilated = morphology.binary_dilation(edges_gray, footprint=np.ones((5, 5))) edges_closed = morphology.binary_closing(edges_dilated, footprint=np.ones((5, 5))) edges_cleaned = morphology.remove_small_objects(edges_closed, min_size=64) - # optionally blur the edges + if sigma > 0: edges_cleaned = filters.gaussian(edges_cleaned, sigma=sigma) + return edges_closed diff --git a/ripplemapper/io.py b/ripplemapper/io.py index 49ad36b..b4a1348 100644 --- a/ripplemapper/io.py +++ b/ripplemapper/io.py @@ -1,4 +1,5 @@ """This module is for input/output functions.""" + import os from pathlib import PosixPath, WindowsPath @@ -8,7 +9,24 @@ __all__ = ["load_image", "load_tif", "load_dir", "load_dir_to_obj", "load"] def load(file: str | PosixPath | WindowsPath): - """Load a file into a ripplemapper object based on file extension.""" + """ + Load a file into a ripplemapper object based on file extension. + + Parameters + ---------- + file : str | PosixPath | WindowsPath + File path to load. + + Returns + ------- + RippleContour, RippleImage, or RippleImageSeries + Loaded ripplemapper object based on file extension. + + Raises + ------ + ValueError + If the file type is unsupported. + """ from ripplemapper.classes import (RippleContour, RippleImage, RippleImageSeries) @@ -23,12 +41,25 @@ def load(file: str | PosixPath | WindowsPath): else: raise ValueError(f"Unsupported file type: {file}") - -# TODO (ADW): Add support for other image file types just use load_tif for now. -# should probably be looping in this function rather than the dispatched functions but... it's fine for now. def load_image(file: str | PosixPath | WindowsPath) -> np.ndarray: - """Load an image file based on file extension.""" - # TODO (ADW): this needs to be refactored to allow lists. + """ + Load an image file based on file extension. + + Parameters + ---------- + file : str | PosixPath | WindowsPath + File path to load. + + Returns + ------- + np.ndarray + Loaded image data. + + Raises + ------ + ValueError + If the file type is unsupported. + """ if isinstance(file, PosixPath) | isinstance(file, WindowsPath): file = str(file.resolve()) if file.endswith('.tif') or file.endswith('.tiff'): @@ -37,10 +68,20 @@ def load_image(file: str | PosixPath | WindowsPath) -> np.ndarray: raise ValueError(f"Unsupported file type: {file}") return img_data[0] - def load_tif(files: str | list[str]) -> list[np.ndarray]: - """Load an array of tif files and return numpy.ndarray.""" + """ + Load an array of tif files and return numpy.ndarray. + + Parameters + ---------- + files : str | list[str] + File path or list of file paths to load. + Returns + ------- + list[np.ndarray] + List of loaded image data arrays. + """ if isinstance(files, str): files = [files] @@ -51,22 +92,32 @@ def load_tif(files: str | list[str]) -> list[np.ndarray]: return files, img_data -def load_dir(directory: str | PosixPath, pattern: str | bool = False, skip: int = 1, start: int=0, end: int | bool=None) -> tuple[list[np.ndarray], list[str]]: - """Load all tif files found in directory and return the data in a list of numpy.ndarray. +def load_dir(directory: str | PosixPath, pattern: str | bool = False, skip: int = 1, start: int = 0, end: int | bool = None) -> tuple[list[np.ndarray], list[str]]: + """ + Load all tif files found in directory and return the data in a list of numpy.ndarray. Parameters ---------- - directory : str - directory path to load tif files from - pattern : str, optional - optional pattern to match file names, by default False + directory : str | PosixPath + Directory path to load tif files from. + pattern : str | bool, optional + Optional pattern to match file names, by default False. skip : int, optional - number of files to skip, by default False + Number of files to skip, by default 1. + start : int, optional + Starting index, by default 0. + end : int | bool, optional + Ending index, by default None. Returns ------- tuple[list[np.ndarray], list[str]] - list of the data arrays extracted from the tif files. + List of the data arrays extracted from the tif files and their corresponding file names. + + Raises + ------ + FileNotFoundError + If no tif files are found in the directory. """ if isinstance(directory, PosixPath): directory = str(directory.resolve()) @@ -86,19 +137,30 @@ def load_dir(directory: str | PosixPath, pattern: str | bool = False, skip: int files, img_data = load_tif([os.path.join(directory, file) for file in files]) return files, img_data -def load_dir_to_obj(directory: str | PosixPath, pattern: str | bool = False, skip: int = 1, start: int=0, end: int=None, **kwargs) -> list: - """Load all tif files found in directory and return the data in a list of Ripple Image objects. +def load_dir_to_obj(directory: str | PosixPath, pattern: str | bool = False, skip: int = 1, start: int = 0, end: int = None, **kwargs) -> list: + """ + Load all tif files found in directory and return the data in a list of RippleImage objects. Parameters ---------- - + directory : str | PosixPath + Directory path to load tif files from. + pattern : str | bool, optional + Optional pattern to match file names, by default False. + skip : int, optional + Number of files to skip, by default 1. + start : int, optional + Starting index, by default 0. + end : int, optional + Ending index, by default None. + **kwargs + Additional keyword arguments for the RippleImage initialization. Returns ------- list[RippleImage] - list of the data arrays extracted from the tif files. + List of RippleImage objects initialized from the tif files. """ - # prevent circular import from ripplemapper.classes import RippleImage files, img_data = load_dir(directory, pattern, skip=skip, start=start, end=end) return [RippleImage(file, img_data, **kwargs) for file, img_data in zip(files, img_data)] diff --git a/ripplemapper/visualisation.py b/ripplemapper/visualisation.py index 200e4bf..bbc7cba 100644 --- a/ripplemapper/visualisation.py +++ b/ripplemapper/visualisation.py @@ -6,7 +6,18 @@ __all__ = ['plot_contours', 'plot_image', 'plot_timeseries'] def plot_contours(ripple_contours, *args, **kwargs): - """Plot the contour.""" + """ + Plot the contour. + + Parameters + ---------- + ripple_contours : RippleContour or list of RippleContour + The contour or list of contours to plot. + *args : tuple + Additional positional arguments for the plot function. + **kwargs : dict + Additional keyword arguments for the plot function. + """ if not isinstance(ripple_contours, list): ripple_contours = [ripple_contours] for contour in ripple_contours: @@ -14,20 +25,25 @@ def plot_contours(ripple_contours, *args, **kwargs): label = contour.parent_image.source_file.split('/')[-1] + ' : ' + contour.method else: label = contour.method - plt.plot(contour.values[1], contour.values[0], label=label, *args, **kwargs) - # set y axis to be high to low + plt.plot(contour.values[1], contour.values[0], label=label, *args, **kwargs) + # set y-axis to match images (inverted) ax = plt.gca() ax.set_ylim((np.max(ax.get_ylim()), np.min(ax.get_ylim()))) -def plot_image(ripple_image, include_contours: bool=True, cmap: str='gray', **kwargs): - """Plot a RippleImage object. +def plot_image(ripple_image, include_contours: bool = True, cmap: str = 'gray', **kwargs): + """ + Plot a RippleImage object. Parameters ---------- ripple_image : RippleImage The RippleImage object to plot. include_contours : bool, optional - whether to include all the RippleContours on the plot, by default True + Whether to include all the RippleContours on the plot, by default True. + cmap : str, optional + Colormap to use for plotting the image, by default 'gray'. + **kwargs : dict + Additional keyword arguments for the imshow function. """ if ripple_image.image is None: if ripple_image.contours is None: @@ -58,9 +74,19 @@ def plot_image(ripple_image, include_contours: bool=True, cmap: str='gray', **k plt.plot(contour.values[:][1], contour.values[:][0], label=contour.method) plt.legend() - def plot_timeseries(contours, labels, **kwargs): - """Plot a timeseries of contours.""" + """ + Plot a timeseries of contours. + + Parameters + ---------- + contours : list of RippleContour + List of contours to plot. + labels : list of str + List of labels corresponding to each contour. + **kwargs : dict + Additional keyword arguments for the plot function. + """ for i, contour in enumerate(contours): plt.plot(contour.values[1], contour.values[0], label=labels[i], **kwargs) plt.gca().invert_yaxis()