From fe74150815188d85b27ca3da41251bc52074da96 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Fri, 30 Aug 2024 10:56:28 -0700 Subject: [PATCH] use local_regex as default type for guardrails (#2853) (#2856) (#2860) * use local_regex as default type for guardrails * add UT for model type --------- (cherry picked from commit 7ecff1aaa17d8ee04fc0a408b49384624db0b93b) Co-authored-by: opensearch-trigger-bot[bot] <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> --- .../ml/common/model/Guardrails.java | 3 + .../ml/common/model/GuardrailsTests.java | 55 ++++++++++-- .../ml/rest/RestMLGuardrailsIT.java | 85 +++++++++++++++++++ 3 files changed, 138 insertions(+), 5 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java b/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java index db7558b7cc..c3030d98e6 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java +++ b/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java @@ -122,6 +122,9 @@ public static Guardrails parse(XContentParser parser) throws IOException { break; } } + if (type == null) { + type = "local_regex"; + } if (!validateType(type)) { throw new IllegalArgumentException("The type of guardrails is required, can not be null."); } diff --git a/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java b/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java index a1b589d07c..195e8353e1 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java @@ -5,6 +5,11 @@ package org.opensearch.ml.common.model; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; + import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -18,15 +23,13 @@ import org.opensearch.ml.common.TestHelper; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Collections; -import java.util.List; - public class GuardrailsTests { StopWords stopWords; String[] regex; LocalRegexGuardrail inputLocalRegexGuardrail; LocalRegexGuardrail outputLocalRegexGuardrail; + ModelGuardrail inputModelGuardrail; + ModelGuardrail outputModelGuardrail; @Before public void setUp() { @@ -34,6 +37,8 @@ public void setUp() { regex = List.of("regex1").toArray(new String[0]); inputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); outputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + inputModelGuardrail = new ModelGuardrail(Map.of("model_id", "guardrail_model_id", "response_validation_regex", "accept")); + outputModelGuardrail = new ModelGuardrail(Map.of("model_id", "guardrail_model_id", "response_validation_regex", "accept")); } @Test @@ -75,4 +80,44 @@ public void parse() throws IOException { Assert.assertEquals(guardrails.getInputGuardrail(), inputLocalRegexGuardrail); Assert.assertEquals(guardrails.getOutputGuardrail(), outputLocalRegexGuardrail); } -} \ No newline at end of file + + @Test + public void parseNonType() throws IOException { + String jsonStr = "{" + + "\"input_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}," + + "\"output_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + Guardrails guardrails = Guardrails.parse(parser); + + Assert.assertEquals(guardrails.getType(), "local_regex"); + Assert.assertEquals(guardrails.getInputGuardrail(), inputLocalRegexGuardrail); + Assert.assertEquals(guardrails.getOutputGuardrail(), outputLocalRegexGuardrail); + } + + @Test + public void parseModelType() throws IOException { + String jsonStr = "{\"type\":\"model\"," + + "\"input_guardrail\":{\"model_id\":\"guardrail_model_id\",\"response_validation_regex\":\"accept\"}," + + "\"output_guardrail\":{\"model_id\":\"guardrail_model_id\",\"response_validation_regex\":\"accept\"}}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + Guardrails guardrails = Guardrails.parse(parser); + + Assert.assertEquals(guardrails.getType(), "model"); + Assert.assertEquals(guardrails.getInputGuardrail(), inputModelGuardrail); + Assert.assertEquals(guardrails.getOutputGuardrail(), outputModelGuardrail); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java index 44533d3ae5..d2a0a7daf2 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java @@ -177,6 +177,31 @@ public void testPredictRemoteModelFailed() throws IOException, InterruptedExcept predictRemoteModel(modelId, predictInput); } + public void testPredictRemoteModelFailedNonType() throws IOException, InterruptedException { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } + exceptionRule.expect(ResponseException.class); + exceptionRule.expectMessage("guardrails triggered for user input"); + Response response = createConnector(completionModelConnectorEntity); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelNonTypeGuardrails("openAI-GPT-3.5 completions", connectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test of stop word.\"\n" + " }\n" + "}"; + predictRemoteModel(modelId, predictInput); + } + public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException, InterruptedException { // Skip test if key is null if (OPENAI_KEY == null) { @@ -429,6 +454,66 @@ protected Response registerRemoteModelWithLocalRegexGuardrails(String name, Stri .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); } + protected Response registerRemoteModelNonTypeGuardrails(String name, String connectorId) throws IOException { + String registerModelGroupEntity = "{\n" + + " \"name\": \"remote_model_group\",\n" + + " \"description\": \"This is an example description\"\n" + + "}"; + Response response = TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/model_groups/_register", + null, + TestHelper.toHttpEntity(registerModelGroupEntity), + null + ); + Map responseMap = parseResponseToMap(response); + assertEquals((String) responseMap.get("status"), "CREATED"); + String modelGroupId = (String) responseMap.get("model_group_id"); + + String registerModelEntity = "{\n" + + " \"name\": \"" + + name + + "\",\n" + + " \"function_name\": \"remote\",\n" + + " \"model_group_id\": \"" + + modelGroupId + + "\",\n" + + " \"version\": \"1.0.0\",\n" + + " \"description\": \"test model\",\n" + + " \"connector_id\": \"" + + connectorId + + "\",\n" + + " \"guardrails\": {\n" + + " \"input_guardrail\": {\n" + + " \"stop_words\": [\n" + + " {" + + " \"index_name\": \"stop_words\",\n" + + " \"source_fields\": [\"title\"]\n" + + " }" + + " ],\n" + + " \"regex\": [\"regex1\", \"regex2\"]\n" + + " },\n" + + " \"output_guardrail\": {\n" + + " \"stop_words\": [\n" + + " {" + + " \"index_name\": \"stop_words\",\n" + + " \"source_fields\": [\"title\"]\n" + + " }" + + " ],\n" + + " \"regex\": [\"regex1\", \"regex2\"]\n" + + " }\n" + + "},\n" + + " \"interface\": {\n" + + " \"input\": {},\n" + + " \"output\": {}\n" + + " }\n" + + "}"; + return TestHelper + .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); + } + protected Response registerRemoteModelWithModelGuardrails(String name, String connectorId, String guardrailModelId) throws IOException { String registerModelGroupEntity = "{\n"