Skip to content

Commit

Permalink
Stop batch job monitoring if a batch job is canceled (#537)
Browse files Browse the repository at this point in the history
* check for canceled job

* update existing tests

* add test for catching canceled job

* remove unnecessary assert

* don't raise for status in monitoring job

* update logger trigger for processing status

* fix test due to logging change

* remove unnecessary test

---------

Co-authored-by: Matic Lubej <[email protected]>
  • Loading branch information
mlubej and Matic Lubej authored Aug 2, 2024
1 parent bc9abf1 commit 6634bdd
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
7 changes: 5 additions & 2 deletions sentinelhub/api/batch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,12 @@ def monitor_batch_job(

progress_bar = tqdm(total=batch_request.tile_count, initial=finished_count, desc="Progress rate")
success_bar = tqdm(total=finished_count, initial=success_count, desc="Success rate")

monitoring_status = [BatchRequestStatus.ANALYSIS_DONE, BatchRequestStatus.PROCESSING]
with progress_bar, success_bar:
while finished_count < batch_request.tile_count:
while finished_count < batch_request.tile_count and batch_request.status in monitoring_status:
time.sleep(sleep_time)
batch_request = batch_client.get_request(batch_request)

tiles_per_status = _get_batch_tiles_per_status(batch_request, batch_client)
new_success_count = len(tiles_per_status[BatchTileStatus.PROCESSED])
Expand All @@ -94,8 +97,8 @@ def monitor_batch_job(
if failed_tiles_num:
LOGGER.info("Batch job failed for %d tiles", failed_tiles_num)

LOGGER.info("Waiting on batch job status update.")
while batch_request.status is BatchRequestStatus.PROCESSING:
LOGGER.info("Waiting on batch job status update.")
time.sleep(sleep_time)
batch_request = batch_client.get_request(batch_request)

Expand Down
20 changes: 10 additions & 10 deletions tests/api/batch/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,17 @@ def test_monitor_batch_process_job(
len(tiles) == tile_count for tiles in tiles_sequence
), "There should be the same number of tiles in each step. Fix tile_status_sequence parameter of this test."

batch_request = BatchRequest(
request_id="mocked-request", process_request={}, tile_count=tile_count, status=batch_status
)
batch_kwargs = dict(request_id="mocked-request", process_request={}, tile_count=tile_count)
batch_request = BatchRequest(**batch_kwargs, status=batch_status)
updated_batch_request = BatchRequest(**batch_kwargs, status=BatchRequestStatus.DONE)

monitor_analysis_mock = mocker.patch("sentinelhub.api.batch.utils.monitor_batch_analysis")
monitor_analysis_mock.return_value = batch_request

updated_batch_request = BatchRequest.from_dict({**batch_request.to_dict(), "status": BatchRequestStatus.DONE})
# keep returning the same batch request until the last step (simulate batch job status update)
progress_loop_counts = len(tile_status_sequence) - 1
batch_request_update_mock = mocker.patch("sentinelhub.SentinelHubBatch.get_request")
batch_request_update_mock.return_value = updated_batch_request
batch_request_update_mock.side_effect = progress_loop_counts * [batch_request] + [updated_batch_request]

batch_tiles_mock = mocker.patch("sentinelhub.SentinelHubBatch.iter_tiles")
batch_tiles_mock.side_effect = tiles_sequence
Expand All @@ -92,18 +94,16 @@ def test_monitor_batch_process_job(
assert len(results[tile_status]) == tile_status_sequence[-1].get(tile_status, 0)

assert monitor_analysis_mock.call_count == 1

progress_loop_counts = len(tile_status_sequence) - 1

assert batch_tiles_mock.call_count == progress_loop_counts + 1
assert all(call.args == (batch_request,) and call.kwargs == {} for call in batch_tiles_mock.mock_calls)

assert sleep_mock.call_count == progress_loop_counts + batch_request_update_mock.call_count
additional_calls = 1 if batch_status is BatchRequestStatus.PROCESSING else 0
assert sleep_mock.call_count == progress_loop_counts + additional_calls
assert all(call.args == (sleep_time,) and call.kwargs == {} for call in sleep_mock.mock_calls)

is_processing_logged = batch_status is BatchRequestStatus.PROCESSING
is_failure_logged = BatchTileStatus.FAILED in tile_status_sequence[-1]
assert logging_mock.call_count == int(is_processing_logged) + int(is_failure_logged) + 2
assert logging_mock.call_count == int(is_processing_logged) + int(is_failure_logged) + additional_calls + 1


def _tile_status_counts_to_tiles(tile_status_counts: dict[BatchTileStatus, int]) -> list[dict[str, str]]:
Expand Down

0 comments on commit 6634bdd

Please sign in to comment.