Skip to content

Commit

Permalink
fix: update search index tool parse logic (#100) (#107)
Browse files Browse the repository at this point in the history
* update search index tool parse logic



* Update src/main/java/org/opensearch/agent/tools/SearchIndexTool.java




* update parsing logic and unit tests



* Update src/main/java/org/opensearch/agent/tools/SearchIndexTool.java




* import classes in search index tool



---------




(cherry picked from commit da31992)

Signed-off-by: yuye-aws <[email protected]>
Signed-off-by: Yuye Zhu <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: zane-neo <[email protected]>
  • Loading branch information
3 people authored Jan 5, 2024
1 parent f5e15ac commit c73fc7a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 33 deletions.
26 changes: 9 additions & 17 deletions src/main/java/org/opensearch/agent/tools/SearchIndexTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.security.PrivilegedExceptionAction;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
Expand All @@ -30,6 +31,7 @@
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;

import lombok.Getter;
Expand All @@ -48,7 +50,7 @@ public class SearchIndexTool implements Tool {

public static final String TYPE = "SearchIndexTool";
private static final String DEFAULT_DESCRIPTION =
"Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query.";
"Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query. Only use this tool when a DSL query is available.";

private String name = TYPE;

Expand Down Expand Up @@ -90,23 +92,13 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
try {
String input = parameters.get(INPUT_FIELD);
JsonObject jsonObject = StringUtils.gson.fromJson(input, JsonObject.class);
String index = jsonObject.get(INDEX_FIELD).getAsString();
String query = jsonObject.get(QUERY_FIELD).toString();

SearchRequest searchRequest;
try {
searchRequest = getSearchRequest(index, query);
} catch (Exception e1) {
try {
// try different json parsing method
query = jsonObject.get(QUERY_FIELD).getAsString();
searchRequest = getSearchRequest(index, query);
} catch (Exception e2) {
// try wrapped query
query = "{\"query\": " + query + "}";
searchRequest = getSearchRequest(index, query);
}
String index = Optional.ofNullable(jsonObject).map(x -> x.get(INDEX_FIELD)).map(JsonElement::getAsString).orElse(null);
String query = Optional.ofNullable(jsonObject).map(x -> x.get(QUERY_FIELD)).map(JsonElement::toString).orElse(null);
if (index == null || query == null) {
listener.onFailure(new IllegalArgumentException("SearchIndexTool's two parameter: index and query are required!"));
return;
}
SearchRequest searchRequest = getSearchRequest(index, query);

ActionListener<SearchResponse> actionListener = ActionListener.<SearchResponse>wrap(r -> {
SearchHit[] hits = r.getHits().getHits();
Expand Down
21 changes: 5 additions & 16 deletions src/test/java/org/opensearch/agent/tools/SearchIndexToolTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public void testRunWithNormalIndex() {

@Test
public void testRunWithConnectorIndex() {
String inputString = "{\"index\": \".plugins-ml-connector\", \"query\": {\"match_all\": {}}}";
String inputString = "{\"index\": \".plugins-ml-connector\", \"query\": {\"query\": {\"match_all\": {}}}}";
Map<String, String> parameters = Map.of("input", inputString);
mockedSearchIndexTool.run(parameters, null);
Mockito.verify(client, never()).search(any(), any());
Expand All @@ -103,7 +103,7 @@ public void testRunWithConnectorIndex() {

@Test
public void testRunWithModelIndex() {
String inputString = "{\"index\": \".plugins-ml-model\", \"query\": {\"match_all\": {}}}";
String inputString = "{\"index\": \".plugins-ml-model\", \"query\": {\"query\": {\"match_all\": {}}}}";
Map<String, String> parameters = Map.of("input", inputString);
mockedSearchIndexTool.run(parameters, null);
Mockito.verify(client, never()).search(any(), any());
Expand All @@ -112,7 +112,7 @@ public void testRunWithModelIndex() {

@Test
public void testRunWithModelGroupIndex() {
String inputString = "{\"index\": \".plugins-ml-model-group\", \"query\": {\"match_all\": {}}}";
String inputString = "{\"index\": \".plugins-ml-model-group\", \"query\": {\"query\": {\"match_all\": {}}}}";
Map<String, String> parameters = Map.of("input", inputString);
mockedSearchIndexTool.run(parameters, null);
Mockito.verify(client, never()).search(any(), any());
Expand All @@ -133,7 +133,7 @@ public void testRunWithSearchResults() {
return null;
}).when(client).search(any(), any());

String inputString = "{\"index\": \"test-index\", \"query\": {\"match_all\": {}}}";
String inputString = "{\"index\": \"test-index\", \"query\": {\"query\": {\"match_all\": {}}}}";
final CompletableFuture<String> future = new CompletableFuture<>();
ActionListener<String> listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); });
Map<String, String> parameters = Map.of("input", inputString);
Expand Down Expand Up @@ -168,24 +168,13 @@ public void testRunWithInvalidQuery() {

@Test
public void testRunWithEmptyQueryBody() {
// this empty query should be parsed with jsonObject.get(QUERY_FIELD).getAsString();
String inputString = "{\"index\": \"test-index\", \"query\": \"{}\"}";
String inputString = "{\"index\": \"test-index\", \"query\": {}}";
Map<String, String> parameters = Map.of("input", inputString);
mockedSearchIndexTool.run(parameters, null);
Mockito.verify(client, times(1)).search(any(), any());
Mockito.verify(client, Mockito.never()).execute(any(), any(), any());
}

@Test
public void testRunWithWrappedQuery() {
// this query should be wrapped liked "{\"query\": " + query + "}"
String inputString = "{\"index\": \".plugins-ml-model\", \"query\": {\"match_all\": {}}}";
Map<String, String> parameters = Map.of("input", inputString);
mockedSearchIndexTool.run(parameters, null);
Mockito.verify(client, never()).search(any(), any());
Mockito.verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), any(), any());
}

@Test
public void testFactory() {
SearchIndexTool searchIndexTool = SearchIndexTool.Factory.getInstance().create(Collections.emptyMap());
Expand Down

0 comments on commit c73fc7a

Please sign in to comment.