diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java index 80beb5fa9b..5340edba0f 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java @@ -62,4 +62,7 @@ public static void logException(String errorMessage, Exception e, Logger log) { } } + public static Throwable getRootCause(Throwable t) { + return ExceptionUtils.getRootCause(t); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessageFactory.java b/plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessageFactory.java index 30aace4be3..69a3c94abe 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessageFactory.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessageFactory.java @@ -6,6 +6,7 @@ package org.opensearch.ml.utils.error; import org.opensearch.OpenSearchException; +import org.opensearch.ml.utils.MLExceptionUtils; import lombok.experimental.UtilityClass; @@ -23,22 +24,9 @@ public static ErrorMessage createErrorMessage(Throwable e, int status) { int st = status; if (t instanceof OpenSearchException) { st = ((OpenSearchException) t).status().getStatus(); - } else { - t = unwrapCause(e); } + t = MLExceptionUtils.getRootCause(t); return new ErrorMessage(t, st); } - - protected static Throwable unwrapCause(Throwable t) { - Throwable result = t; - if (result instanceof OpenSearchException) { - return result; - } - if (result.getCause() == null) { - return result; - } - result = unwrapCause(result.getCause()); - return result; - } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java index 597ae57a8a..ac570a6a4d 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java @@ -30,6 +30,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; +import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; @@ -44,6 +45,7 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.RemoteTransportException; public class RestMLExecuteActionTests extends OpenSearchTestCase { @@ -206,4 +208,77 @@ public void testPrepareRequest_disabled() { when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); assertThrows(IllegalStateException.class, () -> restMLExecuteAction.handleRequest(request, channel, client)); } + + public void testPrepareRequestClientException() throws Exception { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new IllegalArgumentException("Illegal Argument Exception")); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); + doNothing().when(channel).sendResponse(any()); + RestRequest request = getLocalSampleCalculatorRestRequest(); + restMLExecuteAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); + Input input = argumentCaptor.getValue().getInput(); + assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName()); + ArgumentCaptor restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class); + verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture()); + BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue(); + assertEquals(RestStatus.BAD_REQUEST, response.status()); + String content = response.content().utf8ToString(); + String expectedError = + "{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}"; + assertEquals(expectedError, response.content().utf8ToString()); + } + + public void testPrepareRequestClientWrappedException() throws Exception { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener + .onFailure( + new RemoteTransportException("Remote Transport Exception", new IllegalArgumentException("Illegal Argument Exception")) + ); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); + doNothing().when(channel).sendResponse(any()); + RestRequest request = getLocalSampleCalculatorRestRequest(); + restMLExecuteAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); + Input input = argumentCaptor.getValue().getInput(); + assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName()); + ArgumentCaptor restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class); + verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture()); + BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue(); + assertEquals(RestStatus.BAD_REQUEST, response.status()); + String expectedError = + "{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}"; + assertEquals(expectedError, response.content().utf8ToString()); + } + + public void testPrepareRequestSystemException() throws Exception { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("System Exception")); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); + doNothing().when(channel).sendResponse(any()); + RestRequest request = getLocalSampleCalculatorRestRequest(); + restMLExecuteAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); + Input input = argumentCaptor.getValue().getInput(); + assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName()); + ArgumentCaptor restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class); + verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture()); + BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue(); + assertEquals(RestStatus.INTERNAL_SERVER_ERROR, response.status()); + String expectedError = + "{\"error\":{\"reason\":\"System Error\",\"details\":\"System Exception\",\"type\":\"RuntimeException\"},\"status\":500}"; + assertEquals(expectedError, response.content().utf8ToString()); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/utils/error/ErrorMessageFactoryTests.java b/plugin/src/test/java/org/opensearch/ml/utils/error/ErrorMessageFactoryTests.java index 00f3da1b01..5acdb847be 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/error/ErrorMessageFactoryTests.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/error/ErrorMessageFactoryTests.java @@ -5,43 +5,41 @@ package org.opensearch.ml.utils.error; -import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import org.junit.Test; import org.opensearch.OpenSearchException; import org.opensearch.core.rest.RestStatus; +import org.opensearch.transport.RemoteTransportException; public class ErrorMessageFactoryTests { - private Throwable nonOpenSearchThrowable = new Throwable(); - private Throwable openSearchThrowable = new OpenSearchException(nonOpenSearchThrowable); - - @Test - public void openSearchExceptionShouldCreateEsErrorMessage() { - Exception exception = new OpenSearchException(nonOpenSearchThrowable); - ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus()); - assertTrue(msg.exception instanceof OpenSearchException); - } - @Test - public void nonOpenSearchExceptionShouldCreateGenericErrorMessage() { - Exception exception = new Exception(nonOpenSearchThrowable); - ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus()); - assertFalse(msg.exception instanceof OpenSearchException); + public void openSearchExceptionWithoutNestedException() { + Throwable openSearchThrowable = new OpenSearchException("OpenSearch Exception"); + ErrorMessage errorMessage = ErrorMessageFactory.createErrorMessage(openSearchThrowable, RestStatus.BAD_REQUEST.getStatus()); + assertTrue(errorMessage.exception instanceof OpenSearchException); + assertEquals(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), errorMessage.getStatus()); } @Test - public void nonOpenSearchExceptionWithWrappedEsExceptionCauseShouldCreateEsErrorMessage() { - Exception exception = (Exception) openSearchThrowable; - ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus()); - assertTrue(msg.exception instanceof OpenSearchException); + public void openSearchExceptionWithNestedException() { + Throwable nestedThrowable = new IllegalArgumentException("Illegal Argument Exception"); + Throwable openSearchThrowable = new RemoteTransportException("Remote Transport Exception", nestedThrowable); + ErrorMessage errorMessage = ErrorMessageFactory + .createErrorMessage(openSearchThrowable, RestStatus.INTERNAL_SERVER_ERROR.getStatus()); + assertTrue(errorMessage.exception instanceof IllegalArgumentException); + assertEquals(RestStatus.BAD_REQUEST.getStatus(), errorMessage.getStatus()); } @Test - public void nonOpenSearchExceptionWithMultiLayerWrappedEsExceptionCauseShouldCreateEsErrorMessage() { - Exception exception = new Exception(new Throwable(new Throwable(openSearchThrowable))); - ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus()); - assertTrue(msg.exception instanceof OpenSearchException); + public void nonOpenSearchExceptionWithNestedException() { + Throwable nestedThrowable = new IllegalArgumentException("Illegal Argument Exception"); + Throwable nonOpenSearchThrowable = new Exception("Remote Transport Exception", nestedThrowable); + ErrorMessage errorMessage = ErrorMessageFactory + .createErrorMessage(nonOpenSearchThrowable, RestStatus.INTERNAL_SERVER_ERROR.getStatus()); + assertTrue(errorMessage.exception instanceof IllegalArgumentException); + assertEquals(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), errorMessage.getStatus()); } }