diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java index 7203dbccb8..cd1acd4507 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java @@ -270,11 +270,7 @@ public void updateMLTask(String taskId, Map updatedFields, long log.error("Failed to update ML task {}, status: {}", taskId, response.status()); } }, e -> log.error("Failed to update ML task: " + taskId, e)); - updateMLTask(taskId, updatedFields, ActionListener.runAfter(internalListener, () -> { - if (removeFromCache) { - remove(taskId); - } - }), timeoutInMillis); + updateMLTask(taskId, updatedFields, internalListener, timeoutInMillis, removeFromCache); } /** @@ -283,14 +279,19 @@ public void updateMLTask(String taskId, Map updatedFields, long * @param updatedFields updated field and values * @param listener action listener * @param timeoutInMillis time out waiting for updating task semaphore, zero or negative means don't wait at all + * @param removeFromCache remove ML task from cache */ public void updateMLTask( String taskId, Map updatedFields, ActionListener listener, - long timeoutInMillis + long timeoutInMillis, + boolean removeFromCache ) { MLTaskCache taskCache = taskCaches.get(taskId); + if (removeFromCache) { + taskCaches.remove(taskId); + } if (taskCache == null) { listener.onFailure(new MLResourceNotFoundException("Can't find task")); return; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java index ff57dc410c..b9626dc985 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java @@ -95,11 +95,11 @@ public void testAdd() { public void testUpdateMLTaskWithNullOrEmptyMap() { mlTaskManager.add(mlTask); ActionListener listener = mock(ActionListener.class); - mlTaskManager.updateMLTask(mlTask.getTaskId(), null, listener, 0); + mlTaskManager.updateMLTask(mlTask.getTaskId(), null, listener, 0, false); verify(client, never()).update(any(), any()); verify(listener, times(1)).onFailure(any()); - mlTaskManager.updateMLTask(mlTask.getTaskId(), new HashMap<>(), listener, 0); + mlTaskManager.updateMLTask(mlTask.getTaskId(), new HashMap<>(), listener, 0, false); verify(client, never()).update(any(), any()); verify(listener, times(2)).onFailure(any()); } @@ -107,7 +107,7 @@ public void testUpdateMLTaskWithNullOrEmptyMap() { public void testUpdateMLTask_NonExistingTask() { ActionListener listener = mock(ActionListener.class); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - mlTaskManager.updateMLTask(mlTask.getTaskId(), null, listener, 0); + mlTaskManager.updateMLTask(mlTask.getTaskId(), null, listener, 0, false); verify(client, never()).update(any(), any()); verify(listener, times(1)).onFailure(argumentCaptor.capture()); assertEquals("Can't find task", argumentCaptor.getValue().getMessage()); @@ -128,11 +128,11 @@ public void testUpdateMLTask_NoSemaphore() { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), ImmutableMap.of(MLTask.ERROR_FIELD, "test error"), ActionListener.wrap(r -> { ActionListener listener = mock(ActionListener.class); - mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), null, listener, 0); + mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), null, listener, 0, false); verify(client, times(1)).update(any(), any()); verify(listener, times(1)).onFailure(argumentCaptor.capture()); assertEquals("Other updating request not finished yet", argumentCaptor.getValue().getMessage()); - }, e -> { assertNull(e); }), 0); + }, e -> { assertNull(e); }), 0, false); } public void testUpdateMLTask_FailedToUpdate() { @@ -148,7 +148,7 @@ public void testUpdateMLTask_FailedToUpdate() { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); ActionListener listener = mock(ActionListener.class); - mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), ImmutableMap.of(MLTask.ERROR_FIELD, "test error"), listener, 0); + mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), ImmutableMap.of(MLTask.ERROR_FIELD, "test error"), listener, 0, false); verify(client, times(1)).update(any(), any()); verify(listener, times(1)).onFailure(argumentCaptor.capture()); assertEquals(errorMessage, argumentCaptor.getValue().getMessage()); @@ -163,7 +163,7 @@ public void testUpdateMLTask_ThrowException() { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); ActionListener listener = mock(ActionListener.class); - mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), ImmutableMap.of(MLTask.ERROR_FIELD, "test error"), listener, 0); + mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), ImmutableMap.of(MLTask.ERROR_FIELD, "test error"), listener, 0, true); verify(client, times(1)).update(any(), any()); verify(listener, times(1)).onFailure(argumentCaptor.capture()); assertEquals(errorMessage, argumentCaptor.getValue().getMessage());