Skip to content

Commit

Permalink
fix running task when reload loaded model on single node cluster (#561)…
Browse files Browse the repository at this point in the history
… (#617)

Signed-off-by: Yaliang Wu <[email protected]>

Signed-off-by: Yaliang Wu <[email protected]>
Signed-off-by: Sicheng Song <[email protected]>

Signed-off-by: Yaliang Wu <[email protected]>
Signed-off-by: Sicheng Song <[email protected]>
Co-authored-by: Yaliang Wu <[email protected]>
  • Loading branch information
b4sjoo and ylwu-amzn authored Dec 5, 2022
1 parent 7b2f38f commit 11100ff
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
13 changes: 7 additions & 6 deletions plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,7 @@ public void updateMLTask(String taskId, Map<String, Object> updatedFields, long
log.error("Failed to update ML task {}, status: {}", taskId, response.status());
}
}, e -> log.error("Failed to update ML task: " + taskId, e));
updateMLTask(taskId, updatedFields, ActionListener.runAfter(internalListener, () -> {
if (removeFromCache) {
remove(taskId);
}
}), timeoutInMillis);
updateMLTask(taskId, updatedFields, internalListener, timeoutInMillis, removeFromCache);
}

/**
Expand All @@ -283,14 +279,19 @@ public void updateMLTask(String taskId, Map<String, Object> updatedFields, long
* @param updatedFields updated field and values
* @param listener action listener
* @param timeoutInMillis time out waiting for updating task semaphore, zero or negative means don't wait at all
* @param removeFromCache remove ML task from cache
*/
public void updateMLTask(
String taskId,
Map<String, Object> updatedFields,
ActionListener<UpdateResponse> listener,
long timeoutInMillis
long timeoutInMillis,
boolean removeFromCache
) {
MLTaskCache taskCache = taskCaches.get(taskId);
if (removeFromCache) {
taskCaches.remove(taskId);
}
if (taskCache == null) {
listener.onFailure(new MLResourceNotFoundException("Can't find task"));
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,19 @@ public void testAdd() {
public void testUpdateMLTaskWithNullOrEmptyMap() {
mlTaskManager.add(mlTask);
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
mlTaskManager.updateMLTask(mlTask.getTaskId(), null, listener, 0);
mlTaskManager.updateMLTask(mlTask.getTaskId(), null, listener, 0, false);
verify(client, never()).update(any(), any());
verify(listener, times(1)).onFailure(any());

mlTaskManager.updateMLTask(mlTask.getTaskId(), new HashMap<>(), listener, 0);
mlTaskManager.updateMLTask(mlTask.getTaskId(), new HashMap<>(), listener, 0, false);
verify(client, never()).update(any(), any());
verify(listener, times(2)).onFailure(any());
}

public void testUpdateMLTask_NonExistingTask() {
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
mlTaskManager.updateMLTask(mlTask.getTaskId(), null, listener, 0);
mlTaskManager.updateMLTask(mlTask.getTaskId(), null, listener, 0, false);
verify(client, never()).update(any(), any());
verify(listener, times(1)).onFailure(argumentCaptor.capture());
assertEquals("Can't find task", argumentCaptor.getValue().getMessage());
Expand All @@ -128,11 +128,11 @@ public void testUpdateMLTask_NoSemaphore() {
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), ImmutableMap.of(MLTask.ERROR_FIELD, "test error"), ActionListener.wrap(r -> {
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), null, listener, 0);
mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), null, listener, 0, false);
verify(client, times(1)).update(any(), any());
verify(listener, times(1)).onFailure(argumentCaptor.capture());
assertEquals("Other updating request not finished yet", argumentCaptor.getValue().getMessage());
}, e -> { assertNull(e); }), 0);
}, e -> { assertNull(e); }), 0, false);
}

public void testUpdateMLTask_FailedToUpdate() {
Expand All @@ -148,7 +148,7 @@ public void testUpdateMLTask_FailedToUpdate() {

ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), ImmutableMap.of(MLTask.ERROR_FIELD, "test error"), listener, 0);
mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), ImmutableMap.of(MLTask.ERROR_FIELD, "test error"), listener, 0, false);
verify(client, times(1)).update(any(), any());
verify(listener, times(1)).onFailure(argumentCaptor.capture());
assertEquals(errorMessage, argumentCaptor.getValue().getMessage());
Expand All @@ -163,7 +163,7 @@ public void testUpdateMLTask_ThrowException() {

ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), ImmutableMap.of(MLTask.ERROR_FIELD, "test error"), listener, 0);
mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), ImmutableMap.of(MLTask.ERROR_FIELD, "test error"), listener, 0, true);
verify(client, times(1)).update(any(), any());
verify(listener, times(1)).onFailure(argumentCaptor.capture());
assertEquals(errorMessage, argumentCaptor.getValue().getMessage());
Expand Down

0 comments on commit 11100ff

Please sign in to comment.