Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

populate time fields for connectors on return #2922

Merged
merged 8 commits into from
Oct 1, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ public abstract class AbstractConnector implements Connector {
protected User owner;
@Setter
protected AccessMode access;
@Setter
brianf-aws marked this conversation as resolved.
Show resolved Hide resolved
protected Instant createdTime;
@Setter
protected Instant lastUpdateTime;
@Setter
protected ConnectorClientConfig connectorClientConfig;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -44,6 +45,10 @@ public interface Connector extends ToXContentObject, Writeable {

String getProtocol();

void setCreatedTime(Instant createdTime);

void setLastUpdateTime(Instant lastUpdateTime);

User getOwner();

void setOwner(User user);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,22 @@ public void toXContent_InternalConnector() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mlModel.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
assertEquals(
"{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\","
+ "\"algorithm\":\"REMOTE\",\"model_version\":\"1.0.0\",\"description\":\"test model\","
+ "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}}",
mlModelContent
);

String expectedConnectorResponse = "{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\","
+ "\"algorithm\":\"REMOTE\",\"model_version\":\"1.0.0\",\"description\":\"test model\","
+ "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}}";

assertEquals(expectedConnectorResponse, mlModelContent);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ public void toXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
connector.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);

Assert.assertEquals(TEST_CONNECTOR_JSON_STRING, content);
}

@Test
public void constructor_Parser() throws IOException {

XContentParser parser = XContentType.JSON
.xContent()
.createParser(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,21 @@ public void toXContentTest() throws IOException {
mlConnectorGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();
assertEquals(
"{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,"
+ "\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}",
jsonStr
);

String expectedControllerResponse = "{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,"
+ "\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}";

assertEquals(expectedControllerResponse, jsonStr);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ public class MLUpdateModelInputTest {
+ "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":"
+ "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":"
+ "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":"
+ "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":"
+ "\"test-connector_id\",\"last_updated_time\":1}";
+ "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},"
+ "\"connector_id\":\"test-connector_id\",\"last_updated_time\":1}";

private final String expectedOutputStr = "{\"model_id\":null,\"name\":\"name\",\"description\":\"description\",\"model_group_id\":"
+ "\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":"
Expand Down Expand Up @@ -152,6 +152,7 @@ public void readInputStreamSuccessWithNullFields() throws IOException {
@Test
public void testToXContent() throws Exception {
String jsonStr = serializationWithToXContent(updateModelInput);

assertEquals(expectedOutputStrForUpdateRequestDoc, jsonStr);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,40 @@ public void testToXContent() throws Exception {
input.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();
assertEquals(expectedInputStr, jsonStr);

String expectedFunctionInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\","
+ "\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"description\":\"test description\","
+ "\"url\":\"url\",\"model_content_hash_value\":\"hash_value_test\",\"model_format\":\"ONNX\","
+ "\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100,"
+ "\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\","
+ "\\\"field2\\\":\\\"value2\\\"}\"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"],"
+ "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}},\"is_hidden\":false}";

assertEquals(expectedFunctionInputStr, jsonStr);
}

@Test
public void testToXContent_Incomplete() throws Exception {
input.setUrl(null);
input.setModelConfig(null);
input.setModelFormat(null);
input.setModelNodeIds(null);
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
input.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);

String jsonStr = builder.toString();

String expectedIncompleteInputStr = "{\"function_name\":\"LINEAR_REGRESSION\","
+ "\"name\":\"modelName\",\"version\":\"version\",\"model_group_id\":\"modelGroupId\","
+ "\"description\":\"test description\",\"model_content_hash_value\":\"hash_value_test\","
Expand All @@ -186,14 +215,7 @@ public void testToXContent_Incomplete() throws Exception {
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}},\"is_hidden\":false}";
input.setUrl(null);
input.setModelConfig(null);
input.setModelFormat(null);
input.setModelNodeIds(null);
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
input.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();

assertEquals(expectedIncompleteInputStr, jsonStr);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;

import java.time.Instant;
import java.util.HashSet;
import java.util.List;

Expand Down Expand Up @@ -135,6 +136,11 @@ private void indexConnector(Connector connector, ActionListener<MLCreateConnecto
}, listener::onFailure);

IndexRequest indexRequest = new IndexRequest(ML_CONNECTOR_INDEX);

Instant currentTime = Instant.now();
connector.setCreatedTime(currentTime);
connector.setLastUpdateTime(currentTime);

indexRequest.source(connector.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS));
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(indexRequest, ActionListener.runBefore(indexResponseListener, context::restore));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -36,6 +37,7 @@
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest;
import org.opensearch.ml.engine.MLEngine;
Expand Down Expand Up @@ -93,6 +95,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
if (Boolean.TRUE.equals(hasPermission)) {
connector.update(mlUpdateConnectorAction.getUpdateContent(), mlEngine::encrypt);
connector.validateConnectorURL(trustedConnectorEndpointsRegex);

if (connector instanceof HttpConnector && ((HttpConnector) connector).getLastUpdateTime() != null) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why only HttpConnector? Can't we just check connector.getLastUpdateTime() != null

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't want to bloat the connector interface with a getter as it already has two getters. Also AbstractConnector has the time fields not the connector interface

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if getLastUpdateTime is null? Don't we want to update those values also?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the concern if we just add connector.setLastUpdateTime(Instant.now()); without any checks?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO the updates should be built-in (but overridable) defaults on the AbstractConnector. May not need a setter for it as it should generally done automatically.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @dhrubo-os, @dbwiddis

This method will update connectors when it does this my theory is that deserializes the connector from an index and constructs the HTTPConnector object when this happens it is very hard to control the update fields (on instantiation).

If we decide to add the time fields in the abstract Connector constructor we have a problem between old connectors and new ones. This results in a unexpected bug when connectors that didn't have these fields get brought to life; every subsequent update on the connector makes the time fields a new number every time.

Now if we wanted old connectors to not have the time fields (what our users expect when they update a old connector) they wont see the time fields because this method checks if an update should work on time only if they have the field. When a connector thats been created with the code change, they will see time fields and will update accordingly well.

Please let me know if you are still confused I would be happy to explain more in depth :)

connector.setLastUpdateTime(Instant.now());
}

UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, connectorId);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
updateRequest.doc(connector.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@

import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.time.Instant;
import java.util.*;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's generally frowned on to use star imports.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, yeah will fix that


import org.apache.lucene.search.TotalHits;
import org.junit.Before;
Expand Down Expand Up @@ -202,6 +200,116 @@ public void setup() throws IOException {
}).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class));
}

@Test
public void testUpdateConnectorDoesNotUpdateHttpConnectorTimeFields() {
HttpConnector connector = HttpConnector
.builder()
.name("test")
.protocol("http")
.version("1")
.credential(Map.of("api_key", "credential_value"))
.parameters(Map.of("param1", "value1"))
.actions(
Arrays
.asList(
ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("https://api.openai.com/v1/chat/completions")
.headers(Map.of("Authorization", "Bearer ${credential.api_key}"))
.requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }")
.build()
)
)
.build();

assertNull(connector.getCreatedTime());
assertNull(connector.getLastUpdateTime());

doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));

doAnswer(invocation -> {
ActionListener<Connector> listener = invocation.getArgument(2);
listener.onResponse(connector);
return null;
}).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(searchResponse);
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
listener.onResponse(updateResponse);
return null;
}).when(client).update(any(UpdateRequest.class), isA(ActionListener.class));

updateConnectorTransportAction.doExecute(task, updateRequest, actionListener);

assertNull(connector.getLastUpdateTime());
}

@Test
public void testUpdateConnectorUpdatesHttpConnectorTimeFields() {
HttpConnector connector = HttpConnector
.builder()
.name("test")
.protocol("http")
.version("1")
.credential(Map.of("api_key", "credential_value"))
.parameters(Map.of("param1", "value1"))
.actions(
Arrays
.asList(
ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("https://api.openai.com/v1/chat/completions")
.headers(Map.of("Authorization", "Bearer ${credential.api_key}"))
.requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }")
.build()
)
)
.build();

Instant testInitialTime = Instant.now();
connector.setCreatedTime(testInitialTime);
connector.setLastUpdateTime(testInitialTime);

assert (connector.getCreatedTime().toEpochMilli() == connector.getLastUpdateTime().toEpochMilli());

doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));

doAnswer(invocation -> {
ActionListener<Connector> listener = invocation.getArgument(2);
listener.onResponse(connector);
return null;
}).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(searchResponse);
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
listener.onResponse(updateResponse);
return null;
}).when(client).update(any(UpdateRequest.class), isA(ActionListener.class));

updateConnectorTransportAction.doExecute(task, updateRequest, actionListener);

assertTrue(
"Last update time must be bigger than the creation time",
connector.getLastUpdateTime().toEpochMilli() >= connector.getCreatedTime().toEpochMilli()
);
}

@Test
public void testExecuteConnectorAccessControlSuccess() {
doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));
Expand Down
Loading