Skip to content

Commit

Permalink
Merge pull request #16 from mirsazzathossain/dev
Browse files Browse the repository at this point in the history
feat(utils): add bulk image download from catalog
  • Loading branch information
mirsazzathossain authored Sep 15, 2024
2 parents ab48994 + cb436ba commit b022bea
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 0 deletions.
102 changes: 102 additions & 0 deletions rgc/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import numpy as np
import pandas as pd
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astroquery.skyview import SkyView
from astroquery.vizier import Vizier
Expand Down Expand Up @@ -203,6 +205,8 @@ def mask_image(image: Image.Image, mask: Image.Image) -> Image.Image:
:return: A PIL Image object containing the masked image.
:rtype: Image.Image
:raises _ImageMaskDimensionError: If the dimensions of the image and mask do not match.
"""
image_array = np.array(image)
mask_array = np.array(mask)
Expand Down Expand Up @@ -235,6 +239,21 @@ def __init__(self, message: str = "Number of images and masks must match and be


def mask_image_bulk(image_dir: str, mask_dir: str, masked_dir: str) -> None:
"""
Mask a directory of images with a directory of mask images.
:param image_dir: The path to the directory containing the images.
:type image_dir: str
:param mask_dir: The path to the directory containing the mask images.
:type mask_dir: str
:param masked_dir: The path to the directory to save the masked images.
:type masked_dir: str
:raises _FileNotFoundError: If no images or masks are found in the directories.
:raises _ImageMaskCountMismatchError: If the number of images and masks do not match.
"""
image_paths = sorted(Path(image_dir).glob("*.png"))
mask_paths = sorted(Path(mask_dir).glob("*.png"))

Expand Down Expand Up @@ -263,3 +282,86 @@ def mask_image_bulk(image_dir: str, mask_dir: str, masked_dir: str) -> None:
masked_image = mask_image(image, mask)

masked_image.save(Path(masked_dir) / image_path.name)


class _ColumnNotFoundError(Exception):
"""
An exception to be raised when a specified column is not found in the catalog.
"""

def __init__(self, column: str) -> None:
super().__init__(f"Column {column} not found in the catalog.")


def _get_class_labels(catalog: pd.Series, classes: dict, cls_col: str) -> str:
"""
Get the class labels for the celestial objects in the catalog.
:param catalog: A pandas Series representing a row in the catalog of celestial objects.
:type catalog: pd.Series
:param classes: A dictionary containing the classes of the celestial objects.
:type classes: dict
:param cls_col: The name of the column containing the class labels.
:type cls_col: str
:return: Class labels for the celestial objects in the catalog.
:rtype: str
:raises _ColumnNotFoundError: If the specified column is not found in the catalog.
"""
if cls_col not in catalog.index:
raise _ColumnNotFoundError(cls_col)

value = catalog[cls_col]
for key, label in classes.items():
if key in value:
return str(label)

return ""


def celestial_capture_bulk(
catalog: pd.DataFrame, survey: str, img_dir: str, classes: Optional[dict] = None, cls_col: Optional[str] = None
) -> None:
"""
Capture celestial images for a catalog of celestial objects.
:param catalog: A pandas DataFrame containing the catalog of celestial objects.
:type catalog: pd.DataFrame
:param survey: The name of the survey to be used e.g. 'VLA FIRST (1.4 GHz)'.
:type survey: str
:param img_dir: The path to the directory to save the images.
:type img_dir: str
:param classes: A dictionary containing the classes of the celestial objects.
:type classes: dict
:param cls_col: The name of the column containing the class labels.
:raises _InvalidCoordinatesError: If coordinates are invalid.
"""
failed = pd.DataFrame(columns=catalog.columns)
for _, entry in catalog.iterrows():
try:
tag = celestial_tag(entry)
coordinate = SkyCoord(tag, unit=(u.hourangle, u.deg))

right_ascension = coordinate.ra.deg
declination = coordinate.dec.deg

label = _get_class_labels(entry, classes, cls_col) if classes is not None and cls_col is not None else ""

if "filename" in catalog.columns:
filename = f'{img_dir}/{label}_{entry["filename"]}.fits'
else:
filename = f"{img_dir}/{label}_{tag}.fits"

celestial_capture(survey, right_ascension, declination, filename)
except Exception as err:
series = entry.to_frame().T
failed = pd.concat([failed, series], ignore_index=True)
print(f"Failed to capture image. {err}")
54 changes: 54 additions & 0 deletions tests/test_celestial_capture_bulk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from unittest.mock import MagicMock, patch

import pandas as pd

from rgc.utils.data import celestial_capture_bulk


@patch("rgc.utils.data.celestial_tag")
@patch("rgc.utils.data.SkyCoord")
@patch("rgc.utils.data.celestial_capture")
def test_celestial_capture_bulk(mock_celestial_capture, mock_SkyCoord, mock_celestial_tag):
# Mock data
mock_celestial_tag.return_value = "10h00m00s +10d00m00s"
mock_SkyCoord.return_value = MagicMock(ra=MagicMock(deg=10), dec=MagicMock(deg=20))

catalog = pd.DataFrame({"label": ["WAT"], "object_name": ["test"]})
classes = {"WAT": 100, "NAT": 200}
img_dir = "/path/to/images"

# Run the function
celestial_capture_bulk(catalog, "VLA FIRST (1.4 GHz)", img_dir, classes, "label")

# Check that celestial_capture was called with the expected arguments
mock_celestial_capture.assert_called_once_with(
"VLA FIRST (1.4 GHz)", 10, 20, "/path/to/images/100_10h00m00s +10d00m00s.fits"
)

# Test failure handling
mock_celestial_capture.reset_mock()
mock_celestial_tag.side_effect = Exception("Test exception")

with patch("builtins.print") as mock_print:
celestial_capture_bulk(catalog, "VLA FIRST (1.4 GHz)", img_dir, classes, "object_name")
mock_print.assert_called_once_with("Failed to capture image. Test exception")


@patch("rgc.utils.data.celestial_tag")
@patch("rgc.utils.data.SkyCoord")
@patch("rgc.utils.data.celestial_capture")
def test_celestial_capture_bulk_with_filename(mock_celestial_capture, mock_SkyCoord, mock_celestial_tag):
# Mock data
mock_celestial_tag.return_value = "10h00m00s +10d00m00s"
mock_SkyCoord.return_value = MagicMock(ra=MagicMock(deg=10), dec=MagicMock(deg=20))

# Catalog with filename column
catalog = pd.DataFrame({"label": ["WAT"], "filename": ["image1"], "object_name": ["test"]})
classes = {"WAT": 100, "NAT": 200}
img_dir = "/path/to/images"

# Run the function
celestial_capture_bulk(catalog, "VLA FIRST (1.4 GHz)", img_dir, classes, "label")

# Check that celestial_capture was called with the expected filename
mock_celestial_capture.assert_called_once_with("VLA FIRST (1.4 GHz)", 10, 20, "/path/to/images/100_image1.fits")
22 changes: 22 additions & 0 deletions tests/test_get_class_label.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pandas as pd
import pytest

from rgc.utils.data import _ColumnNotFoundError, _get_class_labels


def test_get_class_labels():
# Sample data
catalog = pd.Series({"object_name": "Object1", "class_col": "Galaxy"})
classes = {"Galaxy": "Galactic", "Star": "Stellar"}

# Test with valid column and key
result = _get_class_labels(catalog, classes, "class_col")
assert result == "Galactic", "Should return 'Galactic' for 'Galaxy'"

# Test with invalid column
with pytest.raises(_ColumnNotFoundError):
_get_class_labels(catalog, classes, "invalid_col")

# Test with no matching key
result = _get_class_labels(catalog, classes, "object_name")
assert result == "", "Should return '' if no matching key is found"

0 comments on commit b022bea

Please sign in to comment.