Skip to content

Commit

Permalink
add previous key for PPL tool (#131)
Browse files Browse the repository at this point in the history
* add previous key for PPL tool

Signed-off-by: xinyual <[email protected]>

* remove useless log

Signed-off-by: xinyual <[email protected]>

* add UT for new logic

Signed-off-by: xinyual <[email protected]>

* change previous key logic

Signed-off-by: xinyual <[email protected]>

* change error message

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* remove useless

Signed-off-by: xinyual <[email protected]>

* fix ut

Signed-off-by: xinyual <[email protected]>

* fix IT

Signed-off-by: xinyual <[email protected]>

* fix logic

Signed-off-by: xinyual <[email protected]>

* update default prompt

Signed-off-by: xinyual <[email protected]>

* remove useless logic

Signed-off-by: xinyual <[email protected]>

---------

Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual authored Feb 5, 2024
1 parent 9174e4c commit c2a1fed
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 9 deletions.
51 changes: 46 additions & 5 deletions src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ public class PPLTool implements Tool {
@Setter
private Client client;

private static final String DEFAULT_DESCRIPTION = "Use this tool to generate PPL and execute.";
// Default description text for this tool. It tells the caller when to use the tool and how to shape the
// 'question' / 'index' input parameters, so spelling matters: the text is read verbatim at runtime.
private static final String DEFAULT_DESCRIPTION =
"Use this tool when user ask question based on the data in the cluster or parse user statement about which index to use in a conversation.\nAlso use this tool when question only contains index information.\n1. If user question contain both question and index name, the input parameters are {'question': UserQuestion, 'index': IndexName}.\n2. If user question contain only question, the input parameter is {'question': UserQuestion}.\n3. If user question contain only index name, find the original human input from the conversation history and formulate parameter as {'question': UserQuestion, 'index': IndexName}\nThe index name should be exactly as stated in user's input.";

@Setter
@Getter
Expand All @@ -91,6 +92,8 @@ public class PPLTool implements Tool {

private PPLModelType pplModelType;

// Key prefix used as a fallback source for the index name: when the "index" parameter is blank,
// getIndexNameFromParameters reads the value stored under "<previousToolKey>.output" instead.
private String previousToolKey;

private static Gson gson = new Gson();

private static Map<String, String> DEFAULT_PROMPT_DICT;
Expand Down Expand Up @@ -143,7 +146,7 @@ public static PPLModelType from(String value) {

}

public PPLTool(Client client, String modelId, String contextPrompt, String pplModelType, boolean execute) {
public PPLTool(Client client, String modelId, String contextPrompt, String pplModelType, String previousToolKey, boolean execute) {
this.client = client;
this.modelId = modelId;
this.pplModelType = PPLModelType.from(pplModelType);
Expand All @@ -152,13 +155,19 @@ public PPLTool(Client client, String modelId, String contextPrompt, String pplMo
} else {
this.contextPrompt = contextPrompt;
}
this.previousToolKey = previousToolKey;
this.execute = execute;
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
parameters = extractFromChatParameters(parameters);
String indexName = parameters.get("index");
String indexName = getIndexNameFromParameters(parameters);
if (StringUtils.isBlank(indexName)) {
throw new IllegalArgumentException(
"Return this final answer to human directly and do not use other tools: 'Please provide index name'. Please try to directly send this message to human to ask for index name"
);
}
String question = parameters.get("question");
if (StringUtils.isBlank(indexName) || StringUtils.isBlank(question)) {
throw new IllegalArgumentException("Parameter index and question can not be null or empty.");
Expand All @@ -181,7 +190,7 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
client.search(searchRequest, ActionListener.<SearchResponse>wrap(searchResponse -> {
SearchHit[] searchHits = searchResponse.getHits().getHits();
String tableInfo = constructTableInfo(searchHits, mappings);
String prompt = constructPrompt(tableInfo, question, indexName);
String prompt = constructPrompt(tableInfo, question.strip(), indexName);
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
.builder()
.parameters(Collections.singletonMap("prompt", prompt))
Expand Down Expand Up @@ -236,7 +245,17 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
));
}, e -> {
log.info("fail to get mapping: " + e);
listener.onFailure(e);
// Throwable.getMessage() may return null (exceptions built without a detail message), so guard before
// matching; otherwise an NPE here would escape the failure callback and the listener would never be notified.
String errorMessage = e.getMessage();
if (errorMessage != null && errorMessage.contains("no such index")) {
    // A missing index is reported as an instruction to ask the human for the index name
    // rather than surfacing the raw exception.
    listener
        .onFailure(
            new IllegalArgumentException(
                "Return this final answer to human directly and do not use other tools: 'Please provide index name'. Please try to directly send this message to human to ask for index name"
            )
        );
} else {
    // Any other failure (including one with no message) is propagated unchanged.
    listener.onFailure(e);
}
}));
}

Expand Down Expand Up @@ -287,6 +306,7 @@ public PPLTool create(Map<String, Object> map) {
(String) map.get("model_id"),
(String) map.getOrDefault("prompt", ""),
(String) map.getOrDefault("model_type", ""),
(String) map.getOrDefault("previous_tool_name", ""),
Boolean.valueOf((String) map.getOrDefault("execute", "true"))
);
}
Expand Down Expand Up @@ -436,6 +456,7 @@ private String parseOutput(String llmOutput, String indexName) {
ppl = matcher.group(1).replaceAll("[\\r\\n]", "").replaceAll("ISNOTNULL", "isnotnull").trim();
} else { // logic for only ppl returned
int sourceIndex = llmOutput.indexOf("source=");
int describeIndex = llmOutput.indexOf("describe ");
if (sourceIndex != -1) {
llmOutput = llmOutput.substring(sourceIndex);

Expand All @@ -447,6 +468,17 @@ private String parseOutput(String llmOutput, String indexName) {
lists[0] = "source=" + indexName;
}

// Joining the string back together
ppl = String.join("|", lists);
} else if (describeIndex != -1) {
llmOutput = llmOutput.substring(describeIndex);
String[] lists = llmOutput.split("\\|");

// Modifying the first element
if (lists.length > 0) {
lists[0] = "describe " + indexName;
}

// Joining the string back together
ppl = String.join("|", lists);
} else {
Expand All @@ -458,6 +490,15 @@ private String parseOutput(String llmOutput, String indexName) {
return ppl;
}

/**
 * Resolves the index name to query from the tool parameters.
 *
 * <p>The explicit {@code "index"} parameter is used when it is not blank. Otherwise, if a previous
 * tool key is configured, the value stored under {@code "<previousToolKey>.output"} is returned.
 *
 * @param parameters the tool input parameters
 * @return the value found for the index name, or an empty string when the consulted key is absent
 */
private String getIndexNameFromParameters(Map<String, String> parameters) {
    String explicitIndex = parameters.getOrDefault("index", "");
    boolean hasPreviousKey = !StringUtils.isBlank(this.previousToolKey);
    if (hasPreviousKey && StringUtils.isBlank(explicitIndex)) {
        // No usable explicit index: fall back to the output recorded under the previous key.
        return parameters.getOrDefault(this.previousToolKey + ".output", "");
    }
    return explicitIndex;
}

private static Map<String, String> loadDefaultPromptDict() throws IOException {
InputStream searchResponseIns = PPLTool.class.getResourceAsStream("PPLDefaultPrompt.json");
if (searchResponseIns != null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"CLAUDE": "\n\nHuman:You will be given a question about some metrics from a user.\nUse context provided to write a PPL query that can be used to retrieve the information.\n\nHere is a sample PPL query:\nsource=\\`<index>\\` | where \\`<field>\\` = '\\`<value>\\`'\n\nHere are some sample questions and the PPL query to retrieve the information. The format for fields is\n\\`\\`\\`\n- field_name: field_type (sample field value)\n\\`\\`\\`\n\nFor example, below is a field called \\`timestamp\\`, it has a field type of \\`date\\`, and a sample value of it could look like \\`1686000665919\\`.\n\\`\\`\\`\n- timestamp: date (1686000665919)\n\\`\\`\\`\n----------------\n\nThe following text contains fields and questions/answers for the 'accounts' index\n\nFields:\n- account_number: long (101)\n- address: text ('880 Holmes Lane')\n- age: long (32)\n- balance: long (39225)\n- city: text ('Brogan')\n- email: text ('[email protected]')\n- employer: text ('Pyrami')\n- firstname: text ('Amber')\n- gender: text ('M')\n- lastname: text ('Duke')\n- state: text ('IL')\n- registered_at: date (1686000665919)\n\nQuestion: Give me some documents in index 'accounts'\nPPL: source=\\`accounts\\` | head\n\nQuestion: Give me 5 oldest people in index 'accounts'\nPPL: source=\\`accounts\\` | sort -age | head 5\n\nQuestion: Give me first names of 5 youngest people in index 'accounts'\nPPL: source=\\`accounts\\` | sort +age | head 5 | fields \\`firstname\\`\n\nQuestion: Give me some addresses in index 'accounts'\nPPL: source=\\`accounts\\` | fields \\`address\\`\n\nQuestion: Find the documents in index 'accounts' where firstname is 'Hattie'\nPPL: source=\\`accounts\\` | where \\`firstname\\` = 'Hattie'\n\nQuestion: Find the emails where firstname is 'Hattie' or lastname is 'Frank' in index 'accounts'\nPPL: source=\\`accounts\\` | where \\`firstname\\` = 'Hattie' OR \\`lastname\\` = 'frank' | fields \\`email\\`\n\nQuestion: Find the documents in index 'accounts' where firstname is not 'Hattie' and 
lastname is not 'Frank'\nPPL: source=\\`accounts\\` | where \\`firstname\\` != 'Hattie' AND \\`lastname\\` != 'frank'\n\nQuestion: Find the emails that contain '.com' in index 'accounts'\nPPL: source=\\`accounts\\` | where QUERY_STRING(['email'], '.com') | fields \\`email\\`\n\nQuestion: Find the documents in index 'accounts' where there is an email\nPPL: source=\\`accounts\\` | where ISNOTNULL(\\`email\\`)\n\nQuestion: Count the number of documents in index 'accounts'\nPPL: source=\\`accounts\\` | stats COUNT() AS \\`count\\`\n\nQuestion: Count the number of people with firstname 'Amber' in index 'accounts'\nPPL: source=\\`accounts\\` | where \\`firstname\\` ='Amber' | stats COUNT() AS \\`count\\`\n\nQuestion: How many people are older than 33? index is 'accounts'\nPPL: source=\\`accounts\\` | where \\`age\\` > 33 | stats COUNT() AS \\`count\\`\n\nQuestion: How many distinct ages? index is 'accounts'\nPPL: source=\\`accounts\\` | stats DISTINCT_COUNT(age) AS \\`distinct_count\\`\n\nQuestion: How many males and females in index 'accounts'?\nPPL: source=\\`accounts\\` | stats COUNT() AS \\`count\\` BY \\`gender\\`\n\nQuestion: What is the average, minimum, maximum age in 'accounts' index?\nPPL: source=\\`accounts\\` | stats AVG(\\`age\\`) AS \\`avg_age\\`, MIN(\\`age\\`) AS \\`min_age\\`, MAX(\\`age\\`) AS \\`max_age\\`\n\nQuestion: Show all states sorted by average balance. 
index is 'accounts'\nPPL: source=\\`accounts\\` | stats AVG(\\`balance\\`) AS \\`avg_balance\\` BY \\`state\\` | sort +avg_balance\n\n----------------\n\nThe following text contains fields and questions/answers for the 'ecommerce' index\n\nFields:\n- category: text ('Men's Clothing')\n- currency: keyword ('EUR')\n- customer_birth_date: date (null)\n- customer_first_name: text ('Eddie')\n- customer_full_name: text ('Eddie Underwood')\n- customer_gender: keyword ('MALE')\n- customer_id: keyword ('38')\n- customer_last_name: text ('Underwood')\n- customer_phone: keyword ('')\n- day_of_week: keyword ('Monday')\n- day_of_week_i: integer (0)\n- email: keyword ('[email protected]')\n- event.dataset: keyword ('sample_ecommerce')\n- geoip.city_name: keyword ('Cairo')\n- geoip.continent_name: keyword ('Africa')\n- geoip.country_iso_code: keyword ('EG')\n- geoip.location: geo_point ([object Object])\n- geoip.region_name: keyword ('Cairo Governorate')\n- manufacturer: text ('Elitelligence,Oceanavigations')\n- order_date: date (2023-06-05T09:28:48+00:00)\n- order_id: keyword ('584677')\n- products._id: text (null)\n- products.base_price: half_float (null)\n- products.base_unit_price: half_float (null)\n- products.category: text (null)\n- products.created_on: date (null)\n- products.discount_amount: half_float (null)\n- products.discount_percentage: half_float (null)\n- products.manufacturer: text (null)\n- products.min_price: half_float (null)\n- products.price: half_float (null)\n- products.product_id: long (null)\n- products.product_name: text (null)\n- products.quantity: integer (null)\n- products.sku: keyword (null)\n- products.tax_amount: half_float (null)\n- products.taxful_price: half_float (null)\n- products.taxless_price: half_float (null)\n- products.unit_discount_amount: half_float (null)\n- sku: keyword ('ZO0549605496,ZO0299602996')\n- taxful_total_price: half_float (36.98)\n- taxless_total_price: half_float (36.98)\n- total_quantity: integer (2)\n- 
total_unique_products: integer (2)\n- type: keyword ('order')\n- user: keyword ('eddie')\n\nQuestion: What is the average price of products in clothing category ordered in the last 7 days? index is 'ecommerce'\nPPL: source=\\`ecommerce\\` | where QUERY_STRING(['category'], 'clothing') AND \\`order_date\\` > DATE_SUB(NOW(), INTERVAL 7 DAY) | stats AVG(\\`taxful_total_price\\`) AS \\`avg_price\\`\n\nQuestion: What is the average price of products in each city ordered today by every 2 hours? index is 'ecommerce'\nPPL: source=\\`ecommerce\\` | where \\`order_date\\` > DATE_SUB(NOW(), INTERVAL 24 HOUR) | stats AVG(\\`taxful_total_price\\`) AS \\`avg_price\\` by SPAN(\\`order_date\\`, 2h) AS \\`span\\`, \\`geoip.city_name\\`\n\nQuestion: What is the total revenue of shoes each day in this week? index is 'ecommerce'\nPPL: source=\\`ecommerce\\` | where QUERY_STRING(['category'], 'shoes') AND \\`order_date\\` > DATE_SUB(NOW(), INTERVAL 1 WEEK) | stats SUM(\\`taxful_total_price\\`) AS \\`revenue\\` by SPAN(\\`order_date\\`, 1d) AS \\`span\\`\n\n----------------\n\nThe following text contains fields and questions/answers for the 'events' index\nFields:\n- timestamp: long (1686000665919)\n- attributes.data_stream.dataset: text ('nginx.access')\n- attributes.data_stream.namespace: text ('production')\n- attributes.data_stream.type: text ('logs')\n- body: text ('172.24.0.1 - - [02/Jun/2023:23:09:27 +0000] 'GET / HTTP/1.1' 200 4955 '-' 'Mozilla/5.0 zgrab/0.x'')\n- communication.source.address: text ('127.0.0.1')\n- communication.source.ip: text ('172.24.0.1')\n- container_id: text (null)\n- container_name: text (null)\n- event.category: text ('web')\n- event.domain: text ('nginx.access')\n- event.kind: text ('event')\n- event.name: text ('access')\n- event.result: text ('success')\n- event.type: text ('access')\n- http.flavor: text ('1.1')\n- http.request.method: text ('GET')\n- http.response.bytes: long (4955)\n- http.response.status_code: keyword ('200')\n- http.url: text 
('/')\n- log: text (null)\n- observerTime: date (1686000665919)\n- source: text (null)\n- span_id: text ('abcdef1010')\n- trace_id: text ('102981ABCD2901')\n\nQuestion: What are recent logs with errors and contains word 'test'? index is 'events'\nPPL: source=\\`events\\` | where QUERY_STRING(['http.response.status_code'], '4* OR 5*') AND QUERY_STRING(['body'], 'test') AND \\`observerTime\\` > DATE_SUB(NOW(), INTERVAL 5 MINUTE)\n\nQuestion: What is the total number of log with a status code other than 200 in 2023 Feburary? index is 'events'\nPPL: source=\\`events\\` | where QUERY_STRING(['http.response.status_code'], '!200') AND \\`observerTime\\` >= '2023-03-01 00:00:00' AND \\`observerTime\\` < '2023-04-01 00:00:00' | stats COUNT() AS \\`count\\`\n\nQuestion: Count the number of business days that have web category logs last week? index is 'events'\nPPL: source=\\`events\\` | where \\`category\\` = 'web' AND \\`observerTime\\` > DATE_SUB(NOW(), INTERVAL 1 WEEK) AND DAY_OF_WEEK(\\`observerTime\\`) >= 2 AND DAY_OF_WEEK(\\`observerTime\\`) <= 6 | stats DISTINCT_COUNT(DATE_FORMAT(\\`observerTime\\`, 'yyyy-MM-dd')) AS \\`distinct_count\\`\n\nQuestion: What are the top traces with largest bytes? index is 'events'\nPPL: source=\\`events\\` | stats SUM(\\`http.response.bytes\\`) AS \\`sum_bytes\\` by \\`trace_id\\` | sort -sum_bytes | head\n\nQuestion: Give me log patterns? index is 'events'\nPPL: source=\\`events\\` | patterns \\`body\\` | stats take(\\`body\\`, 1) AS \\`sample_pattern\\` by \\`patterns_field\\` | fields \\`sample_pattern\\`\n\nQuestion: Give me log patterns for logs with errors? index is 'events'\nPPL: source=\\`events\\` | where QUERY_STRING(['http.response.status_code'], '4* OR 5*') | patterns \\`body\\` | stats take(\\`body\\`, 1) AS \\`sample_pattern\\` by \\`patterns_field\\` | fields \\`sample_pattern\\`\n\n----------------\n\nUse the following steps to generate the PPL query:\n\nStep 1. Find all field entities in the question.\n\nStep 2. 
Pick the fields that are relevant to the question from the provided fields list using entities. Rules:\n#01 Consider the field name, the field type, and the sample value when picking relevant fields. For example, if you need to filter flights departed from 'JFK', look for a \\`text\\` or \\`keyword\\` field with a field name such as 'departedAirport', and the sample value should be a 3 letter IATA airport code. Similarly, if you need a date field, look for a relevant field name with type \\`date\\` and not \\`long\\`.\n#02 You must pick a field with \\`date\\` type when filtering on date/time.\n#03 You must pick a field with \\`date\\` type when aggregating by time interval.\n#04 You must not use the sample value in PPL query, unless it is relevant to the question.\n#05 You must only pick fields that are relevant, and must pick the whole field name from the fields list.\n#06 You must not use fields that are not in the fields list.\n#07 You must not use the sample values unless relevant to the question.\n#08 You must pick the field that contains a log line when asked about log patterns. Usually it is one of \\`log\\`, \\`body\\`, \\`message\\`.\n\nStep 3. Use the choosen fields to write the PPL query. Rules:\n#01 Always use comparisons to filter date/time, eg. 'where \\`timestamp\\` > DATE_SUB(NOW(), INTERVAL 1 DAY)'; or by absolute time: 'where \\`timestamp\\` > 'yyyy-MM-dd HH:mm:ss'', eg. 'where \\`timestamp\\` < '2023-01-01 00:00:00''. Do not use \\`DATE_FORMAT()\\`.\n#02 Only use PPL syntax and keywords appeared in the question or in the examples.\n#03 If user asks for current or recent status, filter the time field for last 5 minutes.\n#04 The field used in 'SPAN(\\`<field>\\`, <interval>)' must have type \\`date\\`, not \\`long\\`.\n#05 When aggregating by \\`SPAN\\` and another field, put \\`SPAN\\` after \\`by\\` and before the other field, eg. 
'stats COUNT() AS \\`count\\` by SPAN(\\`timestamp\\`, 1d) AS \\`span\\`, \\`category\\`'.\n#06 You must put values in quotes when filtering fields with \\`text\\` or \\`keyword\\` field type.\n#07 To find documents that contain certain phrases in string fields, use \\`QUERY_STRING\\` which supports multiple fields and wildcard, eg. 'where QUERY_STRING(['field1', 'field2'], 'prefix*')'.\n#08 To find 4xx and 5xx errors using status code, if the status code field type is numberic (eg. \\`integer\\`), then use 'where \\`status_code\\` >= 400'; if the field is a string (eg. \\`text\\` or \\`keyword\\`), then use 'where QUERY_STRING(['status_code'], '4* OR 5*')'.\n\n----------------\nPut your PPL query in <ppl> tags.\n----------------\nQuestion: ${indexInfo.question}? index is \\`${indexInfo.indexName}\\`\nFields:\n${indexInfo.mappingInfo}\n\nAssistant:",
"FINETUNE": "Below is an instruction that describes a task, paired with the index and corresponding fields that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nI have an opensearch index with fields in the following. Now I have a question: ${indexInfo.question} Can you help me generate a PPL for that?\n\n### Index:\n${indexInfo.indexName}\n\n### Fields:\n${indexInfo.mappingInfo}\n\n### Response:\n"
"FINETUNE": "Below is an instruction that describes a task, paired with the index and corresponding fields that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nI have an opensearch index with fields in the following. Now I have a question: ${indexInfo.question}. Can you help me generate a PPL for that?\n\n### Index:\n${indexInfo.indexName}\n\n### Fields:\n${indexInfo.mappingInfo}\n\n### Response:\n"
}
Loading

0 comments on commit c2a1fed

Please sign in to comment.