diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java index 5439d73619..9dbffa918e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.FunctionName.REMOTE; import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING; import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage; +import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG; import java.util.ArrayList; import java.util.Arrays; @@ -51,6 +52,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -68,6 +70,7 @@ public class CreateControllerTransportAction extends HandledTransportAction wrappedListener = ActionListener.runBefore(actionListener, context::restore); mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java index dab8410ad0..552be4c342 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.FunctionName.REMOTE; import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING; import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage; +import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG; import java.util.ArrayList; import java.util.Arrays; @@ -46,6 +47,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -62,6 +64,7 @@ public class UpdateControllerTransportAction extends HandledTransportAction wrappedListener = ActionListener.runBefore(actionListener, context::restore); mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index df9b01b9a9..809029a9d1 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -207,7 +207,6 @@ import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.processor.MLInferenceIngestProcessor; -import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.rest.RestMLCreateConnectorAction; import org.opensearch.ml.rest.RestMLCreateControllerAction; import org.opensearch.ml.rest.RestMLDeleteAgentAction; @@ -370,7 +369,7 @@ public MachineLearningPlugin(Settings settings) { @Override public List> getActions() { - return ImmutableList + return List .of( new ActionHandler<>(MLStatsNodesAction.INSTANCE, MLStatsNodesTransportAction.class), new ActionHandler<>(MLExecuteTaskAction.INSTANCE, TransportExecuteTaskAction.class), @@ -654,7 +653,7 @@ public Collection createComponents( .getClusterSettings() .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> ragSearchPipelineEnabled = it); - return ImmutableList + return List .of( encryptor, mlEngine, @@ -736,9 +735,9 @@ public List getRestHandlers( RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction(); RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction(); RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction(); - RestMLCreateControllerAction restMLCreateControllerAction = new RestMLCreateControllerAction(); + RestMLCreateControllerAction restMLCreateControllerAction = new RestMLCreateControllerAction(mlFeatureEnabledSetting); RestMLGetControllerAction restMLGetControllerAction = new RestMLGetControllerAction(); - RestMLUpdateControllerAction restMLUpdateControllerAction = new RestMLUpdateControllerAction(); + RestMLUpdateControllerAction restMLUpdateControllerAction = new RestMLUpdateControllerAction(mlFeatureEnabledSetting); RestMLDeleteControllerAction restMLDeleteControllerAction = new RestMLDeleteControllerAction(); RestMLGetAgentAction restMLGetAgentAction = new RestMLGetAgentAction(mlFeatureEnabledSetting); RestMLDeleteAgentAction restMLDeleteAgentAction = new RestMLDeleteAgentAction(mlFeatureEnabledSetting); @@ -749,7 +748,7 @@ public List getRestHandlers( RestMLListToolsAction restMLListToolsAction = new RestMLListToolsAction(toolFactories); RestMLGetToolAction restMLGetToolAction = new RestMLGetToolAction(toolFactories); RestMLGetConfigAction restMLGetConfigAction = new RestMLGetConfigAction(); - return ImmutableList + return List .of( restMLStatsAction, restMLTrainingAction, @@ -864,7 +863,7 @@ public List> getExecutorBuilders(Settings settings) { false ); - return ImmutableList + return List .of( generalThreadPool, registerModelThreadPool, @@ -878,7 +877,7 @@ public List> getExecutorBuilders(Settings settings) { @Override public List getNamedXContent() { - return ImmutableList + return List .of( KMeansParams.XCONTENT_REGISTRY, LinearRegressionParams.XCONTENT_REGISTRY, @@ -898,7 +897,7 @@ public List getNamedXContent() { @Override public List> getSettings() { - List> settings = ImmutableList + List> settings = List .of( MLCommonsSettings.ML_COMMONS_TASK_DISPATCH_POLICY, MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE, @@ -932,7 +931,8 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED, MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE, - MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED + MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, + MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateControllerAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateControllerAction.java index 6eb0041edd..8144080e1e 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateControllerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateControllerAction.java @@ -7,6 +7,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; @@ -20,6 +21,7 @@ import org.opensearch.ml.common.controller.MLController; import org.opensearch.ml.common.transport.controller.MLCreateControllerAction; import org.opensearch.ml.common.transport.controller.MLCreateControllerRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -29,11 +31,14 @@ public class RestMLCreateControllerAction extends BaseRestHandler { public final static String ML_CREATE_CONTROLLER_ACTION = "ml_create_controller_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; /** * Constructor */ - public RestMLCreateControllerAction() {} + public RestMLCreateControllerAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -61,6 +66,10 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client * @return MLCreateControllerRequest */ private MLCreateControllerRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isControllerEnabled()) { + throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG); + } + if (!request.hasContent()) { throw new OpenSearchParseException("Create model controller request has empty body"); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateControllerAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateControllerAction.java index 07fa1cc8a9..fd7966f31b 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateControllerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateControllerAction.java @@ -7,6 +7,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; @@ -20,6 +21,7 @@ import org.opensearch.ml.common.controller.MLController; import org.opensearch.ml.common.transport.controller.MLUpdateControllerAction; import org.opensearch.ml.common.transport.controller.MLUpdateControllerRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -29,11 +31,14 @@ public class RestMLUpdateControllerAction extends BaseRestHandler { public final static String ML_UPDATE_CONTROLLER_ACTION = "ml_update_controller_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; /** * Constructor */ - public RestMLUpdateControllerAction() {} + public RestMLUpdateControllerAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -62,6 +67,10 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client * @throws IOException if an error occurs while parsing the request */ private MLUpdateControllerRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isControllerEnabled()) { + throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG); + } + if (!request.hasContent()) { throw new OpenSearchParseException("Update model controller request has empty body"); } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 1e7a569a09..b3cfdf3bc6 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -197,4 +197,7 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED = Setting .boolSetting("plugins.ml_commons.connector.private_ip_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting ML_COMMONS_CONTROLLER_ENABLED = Setting + .boolSetting("plugins.ml_commons.controller_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic); } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java index e393b97d24..90eeb69543 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED; @@ -25,11 +26,15 @@ public class MLFeatureEnabledSetting { private volatile Boolean isLocalModelEnabled; private volatile AtomicBoolean isConnectorPrivateIpEnabled; + private volatile Boolean isModelRateLimiterEnabled; + private volatile Boolean isControllerEnabled; + public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) { isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings); isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings); isLocalModelEnabled = ML_COMMONS_LOCAL_MODEL_ENABLED.get(settings); isConnectorPrivateIpEnabled = new AtomicBoolean(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED.get(settings)); + isControllerEnabled = ML_COMMONS_CONTROLLER_ENABLED.get(settings); clusterService .getClusterSettings() @@ -41,6 +46,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) clusterService .getClusterSettings() .addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, it -> isConnectorPrivateIpEnabled.set(it)); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_CONTROLLER_ENABLED, it -> isControllerEnabled = it); } /** @@ -71,4 +77,12 @@ public AtomicBoolean isConnectorPrivateIpEnabled() { return isConnectorPrivateIpEnabled; } + /** + * Whether the controller feature is enabled. If disabled, APIs in ml-commons will block controller. + * @return whether the controller is enabled. + */ + public Boolean isControllerEnabled() { + return isControllerEnabled; + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java index 68fee24fba..5340edba0f 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java @@ -24,6 +24,8 @@ public class MLExceptionUtils { "Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true."; public static final String AGENT_FRAMEWORK_DISABLED_ERR_MSG = "Agent Framework is currently disabled. To enable it, update the setting \"plugins.ml_commons.agent_framework_enabled\" to true."; + public static final String CONTROLLER_DISABLED_ERR_MSG = + "Controller is currently disabled. To enable it, update the setting \"plugins.ml_commons.controller_enabled\" to true."; public static String getRootCauseMessage(final Throwable throwable) { String message = ExceptionUtils.getRootCauseMessage(throwable); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java index c9a4a1a6d5..f05d0df4cf 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java @@ -57,6 +57,8 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.rest.RestMLUpdateControllerAction; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -104,6 +106,9 @@ public class CreateControllerTransportActionTests extends OpenSearchTestCase { @Mock MLDeployControllerNodesResponse mlDeployControllerNodesResponse; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -138,7 +143,7 @@ public void setup() throws IOException { DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build(); String[] targetNodeIds = new String[] { node1.getId(), node2.getId() }; - + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(true); createControllerTransportAction = spy( new CreateControllerTransportAction( transportService, @@ -148,7 +153,8 @@ public void setup() throws IOException { clusterService, modelAccessControlHelper, mlModelCacheHelper, - mlModelManager + mlModelManager, + mlFeatureEnabledSetting ) ); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java index fd378647e9..ada473d085 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java @@ -58,6 +58,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -105,6 +106,9 @@ public class UpdateControllerTransportActionTests extends OpenSearchTestCase { @Mock MLDeployControllerNodesResponse mlDeployControllerNodesResponse; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -141,7 +145,7 @@ public void setup() throws IOException { DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build(); String[] targetNodeIds = new String[] { node1.getId(), node2.getId() }; - + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(true); updateControllerTransportAction = spy( new UpdateControllerTransportAction( transportService, @@ -150,7 +154,8 @@ public void setup() throws IOException { clusterService, modelAccessControlHelper, mlModelCacheHelper, - mlModelManager + mlModelManager, + mlFeatureEnabledSetting ) ); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateControllerActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateControllerActionTests.java index 42b4cbe92c..94328ed1fe 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateControllerActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateControllerActionTests.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.util.HashMap; import java.util.List; @@ -34,6 +35,7 @@ import org.opensearch.ml.common.controller.MLController; import org.opensearch.ml.common.transport.controller.MLCreateControllerAction; import org.opensearch.ml.common.transport.controller.MLCreateControllerRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -50,6 +52,9 @@ public class RestMLCreateControllerActionTests extends OpenSearchTestCase { private NodeClient client; private ThreadPool threadPool; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock RestChannel channel; @@ -58,7 +63,8 @@ public void setup() { MockitoAnnotations.openMocks(this); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); - restMLCreateControllerAction = new RestMLCreateControllerAction(); + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(true); + restMLCreateControllerAction = new RestMLCreateControllerAction(mlFeatureEnabledSetting); doAnswer(invocation -> { invocation.getArgument(2); return null; @@ -74,7 +80,7 @@ public void tearDown() throws Exception { @Test public void testConstructor() { - RestMLCreateControllerAction CreateModelAction = new RestMLCreateControllerAction(); + RestMLCreateControllerAction CreateModelAction = new RestMLCreateControllerAction(mlFeatureEnabledSetting); assertNotNull(CreateModelAction); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateControllerActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateControllerActionTests.java index 98ab0f1e73..ffc551f153 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateControllerActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateControllerActionTests.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.util.HashMap; import java.util.List; @@ -34,6 +35,7 @@ import org.opensearch.ml.common.controller.MLController; import org.opensearch.ml.common.transport.controller.MLUpdateControllerAction; import org.opensearch.ml.common.transport.controller.MLUpdateControllerRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -53,12 +55,16 @@ public class RestMLUpdateControllerActionTests extends OpenSearchTestCase { @Mock RestChannel channel; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() { MockitoAnnotations.openMocks(this); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); - restMLUpdateControllerAction = new RestMLUpdateControllerAction(); + when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(true); + restMLUpdateControllerAction = new RestMLUpdateControllerAction(mlFeatureEnabledSetting); doAnswer(invocation -> { invocation.getArgument(2); return null; @@ -74,7 +80,7 @@ public void tearDown() throws Exception { @Test public void testConstructor() { - RestMLUpdateControllerAction UpdateModelAction = new RestMLUpdateControllerAction(); + RestMLUpdateControllerAction UpdateModelAction = new RestMLUpdateControllerAction(mlFeatureEnabledSetting); assertNotNull(UpdateModelAction); }