Skip to content

Commit

Permalink
fix error message with unwrapping the root cause (#2458)
Browse files Browse the repository at this point in the history
Signed-off-by: Jing Zhang <[email protected]>
  • Loading branch information
jngz-es authored and dhrubo-os committed Oct 1, 2024
1 parent 77752e2 commit 708bd78
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,7 @@ public static void logException(String errorMessage, Exception e, Logger log) {
}
}

public static Throwable getRootCause(Throwable t) {
return ExceptionUtils.getRootCause(t);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.utils.error;

import org.opensearch.OpenSearchException;
import org.opensearch.ml.utils.MLExceptionUtils;

import lombok.experimental.UtilityClass;

Expand All @@ -23,22 +24,9 @@ public static ErrorMessage createErrorMessage(Throwable e, int status) {
int st = status;
if (t instanceof OpenSearchException) {
st = ((OpenSearchException) t).status().getStatus();
} else {
t = unwrapCause(e);
}
t = MLExceptionUtils.getRootCause(t);

return new ErrorMessage(t, st);
}

protected static Throwable unwrapCause(Throwable t) {
Throwable result = t;
if (result instanceof OpenSearchException) {
return result;
}
if (result.getCause() == null) {
return result;
}
result = unwrapCause(result.getCause());
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
Expand All @@ -44,6 +45,7 @@
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.RemoteTransportException;

public class RestMLExecuteActionTests extends OpenSearchTestCase {

Expand Down Expand Up @@ -206,4 +208,77 @@ public void testPrepareRequest_disabled() {
when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false);
assertThrows(IllegalStateException.class, () -> restMLExecuteAction.handleRequest(request, channel, client));
}

public void testPrepareRequestClientException() throws Exception {
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new IllegalArgumentException("Illegal Argument Exception"));
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
doNothing().when(channel).sendResponse(any());
RestRequest request = getLocalSampleCalculatorRestRequest();
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
Input input = argumentCaptor.getValue().getInput();
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
ArgumentCaptor<RestResponse> restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class);
verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture());
BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue();
assertEquals(RestStatus.BAD_REQUEST, response.status());
String content = response.content().utf8ToString();
String expectedError =
"{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}";
assertEquals(expectedError, response.content().utf8ToString());
}

public void testPrepareRequestClientWrappedException() throws Exception {
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
actionListener
.onFailure(
new RemoteTransportException("Remote Transport Exception", new IllegalArgumentException("Illegal Argument Exception"))
);
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
doNothing().when(channel).sendResponse(any());
RestRequest request = getLocalSampleCalculatorRestRequest();
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
Input input = argumentCaptor.getValue().getInput();
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
ArgumentCaptor<RestResponse> restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class);
verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture());
BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue();
assertEquals(RestStatus.BAD_REQUEST, response.status());
String expectedError =
"{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}";
assertEquals(expectedError, response.content().utf8ToString());
}

public void testPrepareRequestSystemException() throws Exception {
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new RuntimeException("System Exception"));
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
doNothing().when(channel).sendResponse(any());
RestRequest request = getLocalSampleCalculatorRestRequest();
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
Input input = argumentCaptor.getValue().getInput();
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
ArgumentCaptor<RestResponse> restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class);
verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture());
BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue();
assertEquals(RestStatus.INTERNAL_SERVER_ERROR, response.status());
String expectedError =
"{\"error\":{\"reason\":\"System Error\",\"details\":\"System Exception\",\"type\":\"RuntimeException\"},\"status\":500}";
assertEquals(expectedError, response.content().utf8ToString());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,41 @@

package org.opensearch.ml.utils.error;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import org.junit.Test;
import org.opensearch.OpenSearchException;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.transport.RemoteTransportException;

public class ErrorMessageFactoryTests {

private Throwable nonOpenSearchThrowable = new Throwable();
private Throwable openSearchThrowable = new OpenSearchException(nonOpenSearchThrowable);

@Test
public void openSearchExceptionShouldCreateEsErrorMessage() {
Exception exception = new OpenSearchException(nonOpenSearchThrowable);
ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus());
assertTrue(msg.exception instanceof OpenSearchException);
}

@Test
public void nonOpenSearchExceptionShouldCreateGenericErrorMessage() {
Exception exception = new Exception(nonOpenSearchThrowable);
ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus());
assertFalse(msg.exception instanceof OpenSearchException);
public void openSearchExceptionWithoutNestedException() {
Throwable openSearchThrowable = new OpenSearchException("OpenSearch Exception");
ErrorMessage errorMessage = ErrorMessageFactory.createErrorMessage(openSearchThrowable, RestStatus.BAD_REQUEST.getStatus());
assertTrue(errorMessage.exception instanceof OpenSearchException);
assertEquals(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), errorMessage.getStatus());
}

@Test
public void nonOpenSearchExceptionWithWrappedEsExceptionCauseShouldCreateEsErrorMessage() {
Exception exception = (Exception) openSearchThrowable;
ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus());
assertTrue(msg.exception instanceof OpenSearchException);
public void openSearchExceptionWithNestedException() {
Throwable nestedThrowable = new IllegalArgumentException("Illegal Argument Exception");
Throwable openSearchThrowable = new RemoteTransportException("Remote Transport Exception", nestedThrowable);
ErrorMessage errorMessage = ErrorMessageFactory
.createErrorMessage(openSearchThrowable, RestStatus.INTERNAL_SERVER_ERROR.getStatus());
assertTrue(errorMessage.exception instanceof IllegalArgumentException);
assertEquals(RestStatus.BAD_REQUEST.getStatus(), errorMessage.getStatus());
}

@Test
public void nonOpenSearchExceptionWithMultiLayerWrappedEsExceptionCauseShouldCreateEsErrorMessage() {
Exception exception = new Exception(new Throwable(new Throwable(openSearchThrowable)));
ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus());
assertTrue(msg.exception instanceof OpenSearchException);
public void nonOpenSearchExceptionWithNestedException() {
Throwable nestedThrowable = new IllegalArgumentException("Illegal Argument Exception");
Throwable nonOpenSearchThrowable = new Exception("Remote Transport Exception", nestedThrowable);
ErrorMessage errorMessage = ErrorMessageFactory
.createErrorMessage(nonOpenSearchThrowable, RestStatus.INTERNAL_SERVER_ERROR.getStatus());
assertTrue(errorMessage.exception instanceof IllegalArgumentException);
assertEquals(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), errorMessage.getStatus());
}
}

0 comments on commit 708bd78

Please sign in to comment.