Skip to content

Commit

Permalink
Add feature enable setting for controller index
Browse files Browse the repository at this point in the history
Signed-off-by: b4sjoo <[email protected]>
  • Loading branch information
b4sjoo committed Jul 16, 2024
1 parent b6618b2 commit fa3a54c
Show file tree
Hide file tree
Showing 12 changed files with 98 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.FunctionName.REMOTE;
import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING;
import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -51,6 +52,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -68,6 +70,7 @@ public class CreateControllerTransportAction extends HandledTransportAction<Acti
ClusterService clusterService;
MLModelCacheHelper mlModelCacheHelper;
ModelAccessControlHelper modelAccessControlHelper;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public CreateControllerTransportAction(
Expand All @@ -78,7 +81,8 @@ public CreateControllerTransportAction(
ClusterService clusterService,
ModelAccessControlHelper modelAccessControlHelper,
MLModelCacheHelper mlModelCacheHelper,
MLModelManager mlModelManager
MLModelManager mlModelManager,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLCreateControllerAction.NAME, transportService, actionFilters, MLCreateControllerRequest::new);
this.mlIndicesHandler = mlIndicesHandler;
Expand All @@ -87,6 +91,7 @@ public CreateControllerTransportAction(
this.clusterService = clusterService;
this.mlModelCacheHelper = mlModelCacheHelper;
this.modelAccessControlHelper = modelAccessControlHelper;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
Expand All @@ -98,6 +103,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLCrea
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}
ActionListener<MLCreateControllerResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.FunctionName.REMOTE;
import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING;
import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -46,6 +47,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -62,6 +64,7 @@ public class UpdateControllerTransportAction extends HandledTransportAction<Acti
MLModelCacheHelper mlModelCacheHelper;
ClusterService clusterService;
ModelAccessControlHelper modelAccessControlHelper;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public UpdateControllerTransportAction(
Expand All @@ -71,14 +74,16 @@ public UpdateControllerTransportAction(
ClusterService clusterService,
ModelAccessControlHelper modelAccessControlHelper,
MLModelCacheHelper mlModelCacheHelper,
MLModelManager mlModelManager
MLModelManager mlModelManager,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLUpdateControllerAction.NAME, transportService, actionFilters, MLUpdateControllerRequest::new);
this.client = client;
this.mlModelManager = mlModelManager;
this.clusterService = clusterService;
this.mlModelCacheHelper = mlModelCacheHelper;
this.modelAccessControlHelper = modelAccessControlHelper;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
Expand All @@ -90,6 +95,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}
ActionListener<UpdateResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.processor.MLInferenceIngestProcessor;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
import org.opensearch.ml.rest.RestMLCreateConnectorAction;
import org.opensearch.ml.rest.RestMLCreateControllerAction;
import org.opensearch.ml.rest.RestMLDeleteAgentAction;
Expand Down Expand Up @@ -370,7 +369,7 @@ public MachineLearningPlugin(Settings settings) {

@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
return ImmutableList
return List
.of(
new ActionHandler<>(MLStatsNodesAction.INSTANCE, MLStatsNodesTransportAction.class),
new ActionHandler<>(MLExecuteTaskAction.INSTANCE, TransportExecuteTaskAction.class),
Expand Down Expand Up @@ -654,7 +653,7 @@ public Collection<Object> createComponents(
.getClusterSettings()
.addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> ragSearchPipelineEnabled = it);

return ImmutableList
return List
.of(
encryptor,
mlEngine,
Expand Down Expand Up @@ -736,9 +735,9 @@ public List<RestHandler> getRestHandlers(
RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction();
RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction();
RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction();
RestMLCreateControllerAction restMLCreateControllerAction = new RestMLCreateControllerAction();
RestMLCreateControllerAction restMLCreateControllerAction = new RestMLCreateControllerAction(mlFeatureEnabledSetting);
RestMLGetControllerAction restMLGetControllerAction = new RestMLGetControllerAction();
RestMLUpdateControllerAction restMLUpdateControllerAction = new RestMLUpdateControllerAction();
RestMLUpdateControllerAction restMLUpdateControllerAction = new RestMLUpdateControllerAction(mlFeatureEnabledSetting);
RestMLDeleteControllerAction restMLDeleteControllerAction = new RestMLDeleteControllerAction();
RestMLGetAgentAction restMLGetAgentAction = new RestMLGetAgentAction(mlFeatureEnabledSetting);
RestMLDeleteAgentAction restMLDeleteAgentAction = new RestMLDeleteAgentAction(mlFeatureEnabledSetting);
Expand All @@ -749,7 +748,7 @@ public List<RestHandler> getRestHandlers(
RestMLListToolsAction restMLListToolsAction = new RestMLListToolsAction(toolFactories);
RestMLGetToolAction restMLGetToolAction = new RestMLGetToolAction(toolFactories);
RestMLGetConfigAction restMLGetConfigAction = new RestMLGetConfigAction();
return ImmutableList
return List
.of(
restMLStatsAction,
restMLTrainingAction,
Expand Down Expand Up @@ -864,7 +863,7 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
false
);

return ImmutableList
return List
.of(
generalThreadPool,
registerModelThreadPool,
Expand All @@ -878,7 +877,7 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {

@Override
public List<NamedXContentRegistry.Entry> getNamedXContent() {
return ImmutableList
return List
.of(
KMeansParams.XCONTENT_REGISTRY,
LinearRegressionParams.XCONTENT_REGISTRY,
Expand All @@ -898,7 +897,7 @@ public List<NamedXContentRegistry.Entry> getNamedXContent() {

@Override
public List<Setting<?>> getSettings() {
List<Setting<?>> settings = ImmutableList
List<Setting<?>> settings = List
.of(
MLCommonsSettings.ML_COMMONS_TASK_DISPATCH_POLICY,
MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE,
Expand Down Expand Up @@ -932,7 +931,8 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED,
MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED,
MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE,
MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED
MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED,
MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED
);
return settings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
import static org.opensearch.ml.utils.RestActionUtils.getParameterId;

Expand All @@ -20,6 +21,7 @@
import org.opensearch.ml.common.controller.MLController;
import org.opensearch.ml.common.transport.controller.MLCreateControllerAction;
import org.opensearch.ml.common.transport.controller.MLCreateControllerRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
Expand All @@ -29,11 +31,14 @@
public class RestMLCreateControllerAction extends BaseRestHandler {

public final static String ML_CREATE_CONTROLLER_ACTION = "ml_create_controller_action";
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

/**
* Constructor
*/
public RestMLCreateControllerAction() {}
public RestMLCreateControllerAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
public String getName() {
Expand Down Expand Up @@ -61,6 +66,10 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
* @return MLCreateControllerRequest
*/
private MLCreateControllerRequest getRequest(RestRequest request) throws IOException {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}

if (!request.hasContent()) {
throw new OpenSearchParseException("Create model controller request has empty body");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
import static org.opensearch.ml.utils.RestActionUtils.getParameterId;

Expand All @@ -20,6 +21,7 @@
import org.opensearch.ml.common.controller.MLController;
import org.opensearch.ml.common.transport.controller.MLUpdateControllerAction;
import org.opensearch.ml.common.transport.controller.MLUpdateControllerRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
Expand All @@ -29,11 +31,14 @@
public class RestMLUpdateControllerAction extends BaseRestHandler {

public final static String ML_UPDATE_CONTROLLER_ACTION = "ml_update_controller_action";
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

/**
* Constructor
*/
public RestMLUpdateControllerAction() {}
public RestMLUpdateControllerAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
public String getName() {
Expand Down Expand Up @@ -62,6 +67,10 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
* @throws IOException if an error occurs while parsing the request
*/
private MLUpdateControllerRequest getRequest(RestRequest request) throws IOException {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}

if (!request.hasContent()) {
throw new OpenSearchParseException("Update model controller request has empty body");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,4 +197,7 @@ private MLCommonsSettings() {}

public static final Setting<Boolean> ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED = Setting
.boolSetting("plugins.ml_commons.connector.private_ip_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Boolean> ML_COMMONS_CONTROLLER_ENABLED = Setting
.boolSetting("plugins.ml_commons.controller_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;

Expand All @@ -25,11 +26,15 @@ public class MLFeatureEnabledSetting {
private volatile Boolean isLocalModelEnabled;
private volatile AtomicBoolean isConnectorPrivateIpEnabled;

private volatile Boolean isModelRateLimiterEnabled;
private volatile Boolean isControllerEnabled;

public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings);
isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings);
isLocalModelEnabled = ML_COMMONS_LOCAL_MODEL_ENABLED.get(settings);
isConnectorPrivateIpEnabled = new AtomicBoolean(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED.get(settings));
isControllerEnabled = ML_COMMONS_CONTROLLER_ENABLED.get(settings);

clusterService
.getClusterSettings()
Expand All @@ -41,6 +46,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, it -> isConnectorPrivateIpEnabled.set(it));
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_CONTROLLER_ENABLED, it -> isControllerEnabled = it);
}

/**
Expand Down Expand Up @@ -71,4 +77,12 @@ public AtomicBoolean isConnectorPrivateIpEnabled() {
return isConnectorPrivateIpEnabled;
}

/**
* Whether the controller feature is enabled. If disabled, APIs in ml-commons will block controller.
* @return whether the controller is enabled.
*/
public Boolean isControllerEnabled() {
return isControllerEnabled;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public class MLExceptionUtils {
"Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true.";
public static final String AGENT_FRAMEWORK_DISABLED_ERR_MSG =
"Agent Framework is currently disabled. To enable it, update the setting \"plugins.ml_commons.agent_framework_enabled\" to true.";
public static final String CONTROLLER_DISABLED_ERR_MSG =
"Controller is currently disabled. To enable it, update the setting \"plugins.ml_commons.controller_enabled\" to true.";

public static String getRootCauseMessage(final Throwable throwable) {
String message = ExceptionUtils.getRootCauseMessage(throwable);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.rest.RestMLUpdateControllerAction;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
Expand Down Expand Up @@ -104,6 +106,9 @@ public class CreateControllerTransportActionTests extends OpenSearchTestCase {
@Mock
MLDeployControllerNodesResponse mlDeployControllerNodesResponse;

@Mock
MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

Expand Down Expand Up @@ -138,7 +143,7 @@ public void setup() throws IOException {

DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build();
String[] targetNodeIds = new String[] { node1.getId(), node2.getId() };

when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(true);
createControllerTransportAction = spy(
new CreateControllerTransportAction(
transportService,
Expand All @@ -148,7 +153,8 @@ public void setup() throws IOException {
clusterService,
modelAccessControlHelper,
mlModelCacheHelper,
mlModelManager
mlModelManager,
mlFeatureEnabledSetting
)
);

Expand Down
Loading

0 comments on commit fa3a54c

Please sign in to comment.