-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #16 from mirsazzathossain/dev
feat(utils): add bulk image download from catalog
- Loading branch information
Showing
3 changed files
with
178 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from unittest.mock import MagicMock, patch | ||
|
||
import pandas as pd | ||
|
||
from rgc.utils.data import celestial_capture_bulk | ||
|
||
|
||
@patch("rgc.utils.data.celestial_tag") | ||
@patch("rgc.utils.data.SkyCoord") | ||
@patch("rgc.utils.data.celestial_capture") | ||
def test_celestial_capture_bulk(mock_celestial_capture, mock_SkyCoord, mock_celestial_tag): | ||
# Mock data | ||
mock_celestial_tag.return_value = "10h00m00s +10d00m00s" | ||
mock_SkyCoord.return_value = MagicMock(ra=MagicMock(deg=10), dec=MagicMock(deg=20)) | ||
|
||
catalog = pd.DataFrame({"label": ["WAT"], "object_name": ["test"]}) | ||
classes = {"WAT": 100, "NAT": 200} | ||
img_dir = "/path/to/images" | ||
|
||
# Run the function | ||
celestial_capture_bulk(catalog, "VLA FIRST (1.4 GHz)", img_dir, classes, "label") | ||
|
||
# Check that celestial_capture was called with the expected arguments | ||
mock_celestial_capture.assert_called_once_with( | ||
"VLA FIRST (1.4 GHz)", 10, 20, "/path/to/images/100_10h00m00s +10d00m00s.fits" | ||
) | ||
|
||
# Test failure handling | ||
mock_celestial_capture.reset_mock() | ||
mock_celestial_tag.side_effect = Exception("Test exception") | ||
|
||
with patch("builtins.print") as mock_print: | ||
celestial_capture_bulk(catalog, "VLA FIRST (1.4 GHz)", img_dir, classes, "object_name") | ||
mock_print.assert_called_once_with("Failed to capture image. Test exception") | ||
|
||
|
||
@patch("rgc.utils.data.celestial_tag") | ||
@patch("rgc.utils.data.SkyCoord") | ||
@patch("rgc.utils.data.celestial_capture") | ||
def test_celestial_capture_bulk_with_filename(mock_celestial_capture, mock_SkyCoord, mock_celestial_tag): | ||
# Mock data | ||
mock_celestial_tag.return_value = "10h00m00s +10d00m00s" | ||
mock_SkyCoord.return_value = MagicMock(ra=MagicMock(deg=10), dec=MagicMock(deg=20)) | ||
|
||
# Catalog with filename column | ||
catalog = pd.DataFrame({"label": ["WAT"], "filename": ["image1"], "object_name": ["test"]}) | ||
classes = {"WAT": 100, "NAT": 200} | ||
img_dir = "/path/to/images" | ||
|
||
# Run the function | ||
celestial_capture_bulk(catalog, "VLA FIRST (1.4 GHz)", img_dir, classes, "label") | ||
|
||
# Check that celestial_capture was called with the expected filename | ||
mock_celestial_capture.assert_called_once_with("VLA FIRST (1.4 GHz)", 10, 20, "/path/to/images/100_image1.fits") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import pandas as pd | ||
import pytest | ||
|
||
from rgc.utils.data import _ColumnNotFoundError, _get_class_labels | ||
|
||
|
||
def test_get_class_labels(): | ||
# Sample data | ||
catalog = pd.Series({"object_name": "Object1", "class_col": "Galaxy"}) | ||
classes = {"Galaxy": "Galactic", "Star": "Stellar"} | ||
|
||
# Test with valid column and key | ||
result = _get_class_labels(catalog, classes, "class_col") | ||
assert result == "Galactic", "Should return 'Galactic' for 'Galaxy'" | ||
|
||
# Test with invalid column | ||
with pytest.raises(_ColumnNotFoundError): | ||
_get_class_labels(catalog, classes, "invalid_col") | ||
|
||
# Test with no matching key | ||
result = _get_class_labels(catalog, classes, "object_name") | ||
assert result == "", "Should return '' if no matching key is found" |