diff --git a/src/main/java/org/opensearch/agent/tools/SearchIndexTool.java b/src/main/java/org/opensearch/agent/tools/SearchIndexTool.java index ff34718e..c7577501 100644 --- a/src/main/java/org/opensearch/agent/tools/SearchIndexTool.java +++ b/src/main/java/org/opensearch/agent/tools/SearchIndexTool.java @@ -12,6 +12,7 @@ import java.security.PrivilegedExceptionAction; import java.util.Map; import java.util.Objects; +import java.util.Optional; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -30,6 +31,7 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; +import com.google.gson.JsonElement; import com.google.gson.JsonObject; import lombok.Getter; @@ -48,7 +50,7 @@ public class SearchIndexTool implements Tool { public static final String TYPE = "SearchIndexTool"; private static final String DEFAULT_DESCRIPTION = - "Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query."; + "Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query. Only use this tool when a DSL query is available."; private String name = TYPE; @@ -90,23 +92,13 @@ public void run(Map parameters, ActionListener listener) try { String input = parameters.get(INPUT_FIELD); JsonObject jsonObject = StringUtils.gson.fromJson(input, JsonObject.class); - String index = jsonObject.get(INDEX_FIELD).getAsString(); - String query = jsonObject.get(QUERY_FIELD).toString(); - - SearchRequest searchRequest; - try { - searchRequest = getSearchRequest(index, query); - } catch (Exception e1) { - try { - // try different json parsing method - query = jsonObject.get(QUERY_FIELD).getAsString(); - searchRequest = getSearchRequest(index, query); - } catch (Exception e2) { - // try wrapped query - query = "{\"query\": " + query + "}"; - searchRequest = getSearchRequest(index, query); - } + String index = Optional.ofNullable(jsonObject).map(x -> x.get(INDEX_FIELD)).map(JsonElement::getAsString).orElse(null); + String query = Optional.ofNullable(jsonObject).map(x -> x.get(QUERY_FIELD)).map(JsonElement::toString).orElse(null); + if (index == null || query == null) { + listener.onFailure(new IllegalArgumentException("SearchIndexTool's two parameter: index and query are required!")); + return; } + SearchRequest searchRequest = getSearchRequest(index, query); ActionListener actionListener = ActionListener.wrap(r -> { SearchHit[] hits = r.getHits().getHits(); diff --git a/src/test/java/org/opensearch/agent/tools/SearchIndexToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchIndexToolTests.java index 6de9fbfc..d228c0cb 100644 --- a/src/test/java/org/opensearch/agent/tools/SearchIndexToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/SearchIndexToolTests.java @@ -94,7 +94,7 @@ public void testRunWithNormalIndex() { @Test public void testRunWithConnectorIndex() { - String inputString = "{\"index\": \".plugins-ml-connector\", \"query\": {\"match_all\": {}}}"; + String inputString = "{\"index\": \".plugins-ml-connector\", \"query\": {\"query\": {\"match_all\": {}}}}"; Map parameters = Map.of("input", inputString); mockedSearchIndexTool.run(parameters, null); Mockito.verify(client, never()).search(any(), any()); @@ -103,7 +103,7 @@ public void testRunWithConnectorIndex() { @Test public void testRunWithModelIndex() { - String inputString = "{\"index\": \".plugins-ml-model\", \"query\": {\"match_all\": {}}}"; + String inputString = "{\"index\": \".plugins-ml-model\", \"query\": {\"query\": {\"match_all\": {}}}}"; Map parameters = Map.of("input", inputString); mockedSearchIndexTool.run(parameters, null); Mockito.verify(client, never()).search(any(), any()); @@ -112,7 +112,7 @@ public void testRunWithModelIndex() { @Test public void testRunWithModelGroupIndex() { - String inputString = "{\"index\": \".plugins-ml-model-group\", \"query\": {\"match_all\": {}}}"; + String inputString = "{\"index\": \".plugins-ml-model-group\", \"query\": {\"query\": {\"match_all\": {}}}}"; Map parameters = Map.of("input", inputString); mockedSearchIndexTool.run(parameters, null); Mockito.verify(client, never()).search(any(), any()); @@ -133,7 +133,7 @@ public void testRunWithSearchResults() { return null; }).when(client).search(any(), any()); - String inputString = "{\"index\": \"test-index\", \"query\": {\"match_all\": {}}}"; + String inputString = "{\"index\": \"test-index\", \"query\": {\"query\": {\"match_all\": {}}}}"; final CompletableFuture future = new CompletableFuture<>(); ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); Map parameters = Map.of("input", inputString); @@ -168,24 +168,13 @@ public void testRunWithInvalidQuery() { @Test public void testRunWithEmptyQueryBody() { - // this empty query should be parsed with jsonObject.get(QUERY_FIELD).getAsString(); - String inputString = "{\"index\": \"test-index\", \"query\": \"{}\"}"; + String inputString = "{\"index\": \"test-index\", \"query\": {}}"; Map parameters = Map.of("input", inputString); mockedSearchIndexTool.run(parameters, null); Mockito.verify(client, times(1)).search(any(), any()); Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); } - @Test - public void testRunWithWrappedQuery() { - // this query should be wrapped liked "{\"query\": " + query + "}" - String inputString = "{\"index\": \".plugins-ml-model\", \"query\": {\"match_all\": {}}}"; - Map parameters = Map.of("input", inputString); - mockedSearchIndexTool.run(parameters, null); - Mockito.verify(client, never()).search(any(), any()); - Mockito.verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), any(), any()); - } - @Test public void testFactory() { SearchIndexTool searchIndexTool = SearchIndexTool.Factory.getInstance().create(Collections.emptyMap());