diff --git a/common/src/main/java/org/opensearch/ml/common/transport/load/LoadModelResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/load/LoadModelResponse.java index 852bcc22f7..38ca6b09cd 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/load/LoadModelResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/load/LoadModelResponse.java @@ -5,6 +5,7 @@ package org.opensearch.ml.common.transport.load; +import lombok.Getter; import org.opensearch.action.ActionResponse; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; @@ -13,6 +14,7 @@ import java.io.IOException; +@Getter public class LoadModelResponse extends ActionResponse implements ToXContentObject { public static final String TASK_ID_FIELD = "task_id"; public static final String STATUS_FIELD = "status"; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload/MLUploadInput.java b/common/src/main/java/org/opensearch/ml/common/transport/upload/MLUploadInput.java index 2f952d351d..dd640bee60 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload/MLUploadInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload/MLUploadInput.java @@ -54,6 +54,8 @@ public class MLUploadInput implements ToXContentObject, Writeable { public MLUploadInput(FunctionName functionName, String modelName, String version, String url, MLModelFormat modelFormat, MLModelConfig modelConfig, boolean loadModel, String[] modelNodeIds) { if (functionName == null) { this.functionName = FunctionName.TEXT_EMBEDDING; + } else { + this.functionName = functionName; } if (modelName == null) { throw new IllegalArgumentException("model name is null"); @@ -61,6 +63,15 @@ public MLUploadInput(FunctionName functionName, String modelName, String version if (version == null) { throw new IllegalArgumentException("model version is null"); } + if (modelFormat == null) { + throw new IllegalArgumentException("model format is null"); + } + if (modelConfig == null) { + throw new IllegalArgumentException("model config is null"); + } + if (url == null) { + throw new IllegalArgumentException("model file url is null"); + } this.modelName = modelName; this.version = version; this.url = url; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload/UploadModelResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/upload/UploadModelResponse.java index 169dc6450d..50057c2d9c 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload/UploadModelResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload/UploadModelResponse.java @@ -5,6 +5,7 @@ package org.opensearch.ml.common.transport.upload; +import lombok.Getter; import org.opensearch.action.ActionResponse; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; @@ -13,6 +14,7 @@ import java.io.IOException; +@Getter public class UploadModelResponse extends ActionResponse implements ToXContentObject { public static final String TASK_ID_FIELD = "task_id"; public static final String STATUS_FIELD = "status"; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java index 11036385f9..d109dbbef8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java @@ -91,9 +91,6 @@ private void verifyModelZipFile(String modelZipFilePath) throws IOException { hasModelFile = true; } if (fileName.equals(TOKENIZER_FILE_NAME)) { - if (hasTokenizerFile) { - throw new IllegalArgumentException("Find multiple tokenizer files"); - } hasTokenizerFile = true; } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java index 87c230c9ce..8eaa4d347c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java @@ -9,7 +9,6 @@ import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.dataset.DataFrameInputDataset; -import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; import org.opensearch.ml.common.FunctionName; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/FileUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/FileUtils.java index 6a2fec8bc0..fdcac1dc61 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/FileUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/FileUtils.java @@ -23,6 +23,7 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Queue; import java.util.Set; /** @@ -92,13 +93,13 @@ public static void write(byte[] data, File destinationFile, boolean append) thro /** * Merge files into one big file. - * @param files array of files + * @param files chunk files * @param mergedFile merged file */ - public static void mergeFiles(File[] files, File mergedFile) { + public static void mergeFiles(Queue files, File mergedFile) { boolean failed = false; - for (int i = 0; i< files.length ; i++) { - File f = files[i]; + while (!files.isEmpty()) { + File f = files.poll(); try (InputStream inStream = new BufferedInputStream(new FileInputStream(f))) { if (!failed) { int fileLength = (int) f.length(); @@ -108,11 +109,11 @@ public static void mergeFiles(File[] files, File mergedFile) { write(fileContent, mergedFile, true); } org.apache.commons.io.FileUtils.deleteQuietly(f); - if (i == files.length - 1) { + if (files.isEmpty()) { org.apache.commons.io.FileUtils.deleteQuietly(f.getParentFile()); } } catch (IOException e) { - log.error("Failed to merge file " + f.getAbsolutePath() + " to " + mergedFile.getAbsolutePath()); + log.error("Failed to merge file " + f.getAbsolutePath() + " to " + mergedFile.getAbsolutePath(), e); failed = true; } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java b/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java index 1defb734a7..64d697a93b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java @@ -7,7 +7,7 @@ import java.time.Instant; import java.util.Arrays; -import java.util.List; +import java.util.Set; import lombok.extern.log4j.Log4j2; @@ -105,7 +105,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener workNodes = mlTaskManager.getWorkNodes(taskId); + Set workNodes = mlTaskManager.getWorkNodes(taskId); if (workNodes != null) { workNodes.remove(workerNodeId); } @@ -114,6 +114,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener 1 && workerNodes.length > 0) { - log.debug("sync model routing to other nodes. model loaded on nodes: {}", Arrays.toString(workerNodes)); - MLSyncUpInput syncUpInput = MLSyncUpInput - .builder() - .addedWorkerNodes(ImmutableMap.of(modelId, workerNodes)) - .build(); - MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(allNodes, syncUpInput); - client - .execute( - MLSyncUpAction.INSTANCE, - syncUpRequest, - ActionListener - .wrap(r -> { log.debug("Sync up successfully"); }, e -> { log.error("Failed to sync up", e); }) - ); - } + syncModelWorkerNodes(modelId); } ImmutableMap.Builder builder = ImmutableMap.builder(); builder.put(MLTask.STATE_FIELD, taskState); @@ -181,4 +166,20 @@ protected void doExecute(Task task, ActionRequest request, ActionListener 1 && workerNodes.length > 0) { + log.debug("Sync to other nodes about worker nodes of model {}: {}", modelId, Arrays.toString(workerNodes)); + MLSyncUpInput syncUpInput = MLSyncUpInput.builder().addedWorkerNodes(ImmutableMap.of(modelId, workerNodes)).build(); + MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(allNodes, syncUpInput); + client + .execute( + MLSyncUpAction.INSTANCE, + syncUpRequest, + ActionListener.wrap(r -> log.debug("Sync up successfully"), e -> log.error("Failed to sync up", e)) + ); + } + } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java index a28f7a2f54..5fd35db7b7 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java @@ -5,7 +5,7 @@ package org.opensearch.ml.action.load; -import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.LOAD_THREAD_POOL; import java.time.Instant; import java.util.ArrayList; @@ -163,7 +163,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + threadPool.executor(LOAD_THREAD_POOL).execute(() -> { LoadModelInput loadModelInput = new LoadModelInput( modelId, taskId, diff --git a/plugin/src/main/java/org/opensearch/ml/action/profile/MLProfileTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/profile/MLProfileTransportAction.java index 82a71cc65d..7defefed1a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/profile/MLProfileTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/profile/MLProfileTransportAction.java @@ -117,7 +117,10 @@ private MLProfileNodeResponse createMLProfileNodeResponse(MLProfileRequest mlPro Arrays.stream(mlModelManager.getAllModelIds()).forEach(modelId -> { if (mlProfileInput.isReturnAllModels() || (!mlProfileInput.emptyModels() && targetModelIds.contains(modelId))) { log.debug("Runtime model profile is found for model {}", modelId); - mlLocalModels.put(modelId, mlModelManager.getModelProfile(modelId)); + MLModelProfile modelProfile = mlModelManager.getModelProfile(modelId); + if (modelProfile != null) { + mlLocalModels.put(modelId, modelProfile); + } } }); diff --git a/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java b/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java index 251ee3ea30..a4e2210634 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java @@ -141,7 +141,7 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest loadM for (Map.Entry> entry : modelRoutingTable.entrySet()) { log.debug("latest routing table for model: {}: {}", entry.getKey(), entry.getValue().toArray(new String[0])); } - mlModelManager.syncModelRouting(modelRoutingTable); + mlModelManager.syncModelWorkerNodes(modelRoutingTable); } if (syncUpInput.isSyncRunningLoadModelTasks()) { @@ -166,7 +166,9 @@ private void cleanUpLocalCacheFiles() { Arrays.toString(modelsInCacheFolder.toArray(new String[0])) ); for (String modelId : modelsInCacheFolder) { - if (!mlTaskManager.contains(modelId) && !mlTaskManager.containsModel(modelId) && !mlModelManager.containsModel(modelId)) { + if (!mlTaskManager.contains(modelId) + && !mlTaskManager.containsModel(modelId) + && !mlModelManager.isModelRunningOnNode(modelId)) { log.info("ML model not in cache. Remove all of its cache files. model id: {}", modelId); deleteFileCache(modelId); } diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java index 72821f7a72..75dafd94a5 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java @@ -5,7 +5,7 @@ package org.opensearch.ml.cluster; -import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS; import lombok.extern.log4j.Log4j2; @@ -62,7 +62,7 @@ public void onClusterManager() { private void startSyncModelRoutingCron() { if (jobInterval > 0) { syncModelRoutingCron = threadPool - .scheduleWithFixedDelay(new MLSyncUpCron(client, nodeHelper), TimeValue.timeValueSeconds(jobInterval), TASK_THREAD_POOL); + .scheduleWithFixedDelay(new MLSyncUpCron(client, nodeHelper), TimeValue.timeValueSeconds(jobInterval), GENERAL_THREAD_POOL); } else { log.debug("Stop ML syncup job as its interval is: {}", jobInterval); } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java index 6b0896b698..964d7e4f33 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java @@ -5,190 +5,84 @@ package org.opensearch.ml.model; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT; - import java.util.DoubleSummaryStatistics; -import java.util.HashSet; -import java.util.Map; import java.util.Queue; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.stream.DoubleStream; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.Setter; import lombok.extern.log4j.Log4j2; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.engine.Predictable; -import org.opensearch.ml.profile.MLModelProfile; import org.opensearch.ml.profile.MLPredictRequestStats; import com.google.common.math.Quantiles; @Log4j2 public class MLModelCache { - private final ClusterService clusterService; - - /** - * All model state. Contains both build-in algo model and custom model. - */ - private final Map modelStates; - private final Map predictors; - private final Map> modelRoutingTable;// routingTable - private final Map> modelInferenceDuration; - private final Map modelFunctionNames; - private volatile Long maxRequestCount; - - public MLModelCache(ClusterService clusterService, Settings settings) { - this.clusterService = clusterService; - this.modelStates = new ConcurrentHashMap<>(); - this.predictors = new ConcurrentHashMap<>(); - this.modelRoutingTable = new ConcurrentHashMap<>(); - this.modelInferenceDuration = new ConcurrentHashMap<>(); - this.modelFunctionNames = new ConcurrentHashMap<>(); - - maxRequestCount = ML_COMMONS_MONITORING_REQUEST_COUNT.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MONITORING_REQUEST_COUNT, it -> maxRequestCount = it); - } + private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) MLModelState modelState; + private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) FunctionName functionName; + private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Predictable predictor; + private final Set workerNodes; + private final Queue inferenceDurationQueue; - public synchronized boolean hasModel(String modelId) { - return predictors.containsKey(modelId); + public MLModelCache() { + workerNodes = ConcurrentHashMap.newKeySet(); + inferenceDurationQueue = new ConcurrentLinkedQueue<>(); } - public synchronized boolean isModelLoaded(String modelId) { - MLModelState mlModelState = modelStates.get(modelId); - if (mlModelState == MLModelState.LOADED) { - return true; - } - return false; + public void removeWorkerNode(String nodeId) { + workerNodes.remove(nodeId); } - public synchronized void initModelState(String modelId, MLModelState state, FunctionName functionName) { - if (modelStates.containsKey(modelId)) { - throw new IllegalArgumentException("Duplicate model task"); - } - modelStates.put(modelId, state); - modelFunctionNames.put(modelId, functionName); + public void removeWorkerNodes(Set removedNodes) { + workerNodes.removeAll(removedNodes); } - public synchronized void setModelState(String modelId, MLModelState state) { - if (!modelStates.containsKey(modelId)) { - throw new IllegalArgumentException("Model not found in cache"); - } - modelStates.put(modelId, state); + public void addWorkerNode(String nodeId) { + workerNodes.add(nodeId); } - public void removeModelState(String modelId) { - modelStates.remove(modelId); - modelFunctionNames.remove(modelId); + public String[] getWorkerNodes() { + return workerNodes.toArray(new String[0]); } - public void removeWorkNodes(Set removedNodes) { - for (Map.Entry> entry : modelRoutingTable.entrySet()) { - Set nodes = entry.getValue(); - nodes.removeAll(removedNodes); - } + public void syncWorkerNode(Set workerNodes) { + this.workerNodes.clear(); + this.workerNodes.addAll(workerNodes); } - public synchronized void addPredictable(String modelId, Predictable predictable) { - this.predictors.put(modelId, predictable); + public void clearWorkerNodes() { + workerNodes.clear(); } - public synchronized void addNodeToModelRoutingTable(String modelId, String nodeId) { - if (!modelRoutingTable.containsKey(modelId)) { - ConcurrentHashMap map = new ConcurrentHashMap<>(); - Set set = map.newKeySet(); - modelRoutingTable.put(modelId, set); + public void clear() { + modelState = null; + functionName = null; + workerNodes.clear(); + inferenceDurationQueue.clear(); + if (predictor != null) { + predictor.close(); } - log.debug("add node {} to model routing table for model: {}", nodeId, modelId); - modelRoutingTable.get(modelId).add(nodeId); } - public synchronized void removeNodeFromModelRoutingTable(String modelId, String nodeId) { - if (!modelRoutingTable.containsKey(modelId)) { - log.debug("model {} not found in cache", modelId); - return; - } - log.debug("remove node {} from model routing table of model {}", nodeId, modelId); - modelRoutingTable.get(modelId).remove(nodeId); - if (modelRoutingTable.get(modelId).size() == 0) { - log.debug("remove model {} from model routing table as no node running it", modelId); - modelRoutingTable.remove(modelId); + public void addInferenceDuration(double duration, long maxRequestCount) { + while (inferenceDurationQueue.size() >= maxRequestCount) { + inferenceDurationQueue.poll(); } + this.inferenceDurationQueue.add(duration); } - public void removeModel(String modelId) { - this.modelStates.remove(modelId); - this.modelFunctionNames.remove(modelId); - modelInferenceDuration.remove(modelId); - Predictable predictable = this.predictors.remove(modelId); - if (predictable != null) { - predictable.close(); - } - log.debug("remove model state and predictable model {}", modelId); - removeNodeFromModelRoutingTable(modelId, clusterService.localNode().getId()); - } - - public String[] getWorkerNodes(String modelId) { - Set nodes = modelRoutingTable.get(modelId); - if (nodes == null) { - return null; - } - return nodes.toArray(new String[0]); - } - - public Predictable getPredictable(String modelId) { - return predictors.get(modelId); - } - - public synchronized int modelCount() { - return modelStates.size(); - } - - public String[] getLoadedModels() { - return predictors.keySet().toArray(new String[0]); - } - - public void syncModelRouting(Map> modelRoutingTable) { - log.debug("sync model routing for model"); - Set currentModels = new HashSet(this.modelRoutingTable.keySet()); - this.modelRoutingTable.putAll(modelRoutingTable); - currentModels.removeAll(modelRoutingTable.keySet()); - if (currentModels.size() > 0) { - currentModels.forEach(k -> this.modelRoutingTable.remove(k)); - } - } - - public void clearRoutingTable() { - log.debug("clear routing table"); - this.modelRoutingTable.clear(); - } - - public String[] getAllModelIds() { - Set modelIds = new HashSet<>(); - modelIds.addAll(this.modelStates.keySet()); - modelIds.addAll(this.predictors.keySet()); - modelIds.addAll(this.modelRoutingTable.keySet()); - return modelIds.toArray(new String[0]); - } - - public MLModelProfile getModelProfile(String modelId) { - MLModelProfile.MLModelProfileBuilder builder = MLModelProfile.builder().modelState(modelStates.get(modelId)); - Predictable predictable = predictors.get(modelId); - if (predictable != null) { - builder.predictor(predictable.toString()); - } - Set nodes = modelRoutingTable.get(modelId); - if (nodes != null && nodes.size() > 0) { - builder.workerNodes(nodes.toArray(new String[0])); - } - Queue queue = modelInferenceDuration.get(modelId); - if (queue != null && queue.size() > 0) { + public MLPredictRequestStats getInferenceStats() { + if (inferenceDurationQueue.size() > 0) { MLPredictRequestStats.MLPredictRequestStatsBuilder statsBuilder = MLPredictRequestStats.builder(); - DoubleStream doubleStream = queue.stream().mapToDouble(v -> v); + DoubleStream doubleStream = inferenceDurationQueue.stream().mapToDouble(v -> v); DoubleSummaryStatistics doubleSummaryStatistics = doubleStream.summaryStatistics(); statsBuilder.count(doubleSummaryStatistics.getCount()); statsBuilder.max(doubleSummaryStatistics.getMax()); @@ -196,29 +90,16 @@ public MLModelProfile getModelProfile(String modelId) { statsBuilder.average(doubleSummaryStatistics.getAverage()); Quantiles.Scale percentiles = Quantiles.percentiles(); - statsBuilder.p50(percentiles.index(50).compute(queue)); - statsBuilder.p90(percentiles.index(90).compute(queue)); - statsBuilder.p99(percentiles.index(99).compute(queue)); - - builder.predictStats(statsBuilder.build()); - } - return builder.build(); - } + statsBuilder.p50(percentiles.index(50).compute(inferenceDurationQueue)); + statsBuilder.p90(percentiles.index(90).compute(inferenceDurationQueue)); + statsBuilder.p99(percentiles.index(99).compute(inferenceDurationQueue)); - public void addInferenceDuration(String modelId, double duration) { - log.debug("add duration of model {}: {}ms", modelId, duration); - Queue queue = modelInferenceDuration.computeIfAbsent(modelId, it -> new ConcurrentLinkedQueue<>()); - while (queue.size() >= maxRequestCount) { - queue.poll(); + return statsBuilder.build(); } - queue.add(duration); - } - - public FunctionName getModelFunctionName(String modelId) { - return modelFunctionNames.get(modelId); + return null; } - public boolean containsModel(String modelId) { - return modelStates.containsKey(modelId); + public boolean isValidCache() { + return modelState != null || workerNodes.size() > 0; } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java new file mode 100644 index 0000000000..e7006256ee --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -0,0 +1,298 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.model; + +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT; + +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; + +import lombok.extern.log4j.Log4j2; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.engine.Predictable; +import org.opensearch.ml.profile.MLModelProfile; +import org.opensearch.ml.profile.MLPredictRequestStats; + +@Log4j2 +public class MLModelCacheHelper { + private final Map modelCaches; + private volatile Long maxRequestCount; + + public MLModelCacheHelper(ClusterService clusterService, Settings settings) { + this.modelCaches = new ConcurrentHashMap<>(); + + maxRequestCount = ML_COMMONS_MONITORING_REQUEST_COUNT.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MONITORING_REQUEST_COUNT, it -> maxRequestCount = it); + } + + /** + * Initialize model state. + * @param modelId model id + * @param state model state + * @param functionName function name + */ + public synchronized void initModelState(String modelId, MLModelState state, FunctionName functionName) { + if (modelCaches.containsKey(modelId)) { + throw new IllegalArgumentException("Duplicate model task"); + } + log.debug("init model state for model {}, state: {}", modelId, state); + MLModelCache modelCache = new MLModelCache(); + modelCache.setModelState(state); + modelCache.setFunctionName(functionName); + modelCaches.put(modelId, modelCache); + } + + /** + * Set model state + * @param modelId model id + * @param state model state + */ + public synchronized void setModelState(String modelId, MLModelState state) { + log.debug("Updating State of Model {} to state {}", modelId, state); + getExistingModelCache(modelId).setModelState(state); + } + + /** + * Check if model loaded on node. + * @param modelId model id + * @return true if model loaded + */ + public synchronized boolean isModelLoaded(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + return modelCache != null && modelCache.getModelState() == MLModelState.LOADED; + } + + /** + * Get loaded models on node. + * @return array of model id + */ + public String[] getLoadedModels() { + return modelCaches + .entrySet() + .stream() + .filter(entry -> entry.getValue().getModelState() == MLModelState.LOADED) + .map(entry -> entry.getKey()) + .collect(Collectors.toList()) + .toArray(new String[0]); + } + + /** + * Check if model is running on node. + * @param modelId model id + * @return true if model is running on node. + */ + public boolean isModelRunningOnNode(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + return modelCache != null && modelCache.getModelState() != null; + } + + /** + * Set predictor of model. + * @param modelId model id + * @param predictor predictor + */ + public synchronized void setPredictor(String modelId, Predictable predictor) { + MLModelCache modelCache = getExistingModelCache(modelId); + modelCache.setPredictor(predictor); + } + + /** + * Get predictor of model. + * @param modelId model id + * @return predictor + */ + public Predictable getPredictor(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache == null) { + return null; + } + return modelCache.getPredictor(); + } + + /** + * Remove model. + * @param modelId model id + */ + public void removeModel(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache != null) { + log.debug("removing model {} from cache", modelId); + modelCache.clear(); + modelCaches.remove(modelId); + } + } + + /** + * Get all model IDs in model cache. + * @return array of model id + */ + public String[] getAllModels() { + return modelCaches.keySet().toArray(new String[0]); + } + + /** + * Get worker nodes of model. + * @param modelId model id + * @return array of node id; return null if model not exists in cache + */ + public String[] getWorkerNodes(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache == null) { + return null; + } + return modelCache.getWorkerNodes(); + } + + /** + * Add worker node of model. + * @param modelId model id + * @param nodeId node id + */ + public synchronized void addWorkerNode(String modelId, String nodeId) { + log.debug("add node {} to model routing table for model: {}", nodeId, modelId); + MLModelCache modelCache = getOrCreateModelCache(modelId); + modelCache.addWorkerNode(nodeId); + } + + /** + * Remove worker nodes for all models. + * @param removedNodes removed nodes + */ + public void removeWorkerNodes(Set removedNodes) { + Set modelIds = modelCaches.keySet(); + for (String modelId : modelIds) { + MLModelCache modelCache = modelCaches.get(modelId); + log.debug("remove worker nodes of model {} : {}", modelId, removedNodes.toArray(new String[0])); + modelCache.removeWorkerNodes(removedNodes); + if (!modelCache.isValidCache()) { + log.debug("remove model cache {}", modelId); + modelCaches.remove(modelId); + } + } + } + + /** + * Remove worker node of model. + * @param modelId model id + * @param nodeId node id + */ + public void removeWorkerNode(String modelId, String nodeId) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache != null) { + log.debug("remove worker node {} of model {} from cache", nodeId, modelId); + modelCache.removeWorkerNode(nodeId); + if (!modelCache.isValidCache()) { + log.debug("remove model {} from cache as no node running it", modelId); + modelCaches.remove(modelId); + } + } + } + + /** + * Sync worker nodes for all models. + * @param modelWorkerNodes worker nodes of all models + */ + public void syncWorkerNodes(Map> modelWorkerNodes) { + log.debug("sync model worker nodes"); + Set currentModels = new HashSet(this.modelCaches.keySet()); + currentModels.removeAll(modelWorkerNodes.keySet()); + if (currentModels.size() > 0) { + currentModels.forEach(modelId -> clearWorkerNodes(modelId)); + } + modelWorkerNodes.entrySet().forEach(entry -> { + MLModelCache modelCache = getOrCreateModelCache(entry.getKey()); + modelCache.syncWorkerNode(entry.getValue()); + }); + } + + /** + * Clear worker nodes for all models. + */ + public void clearWorkerNodes() { + log.debug("clear all model worker nodes"); + modelCaches.entrySet().forEach(entry -> clearWorkerNodes(entry.getKey())); + } + + /** + * Clear worker node of model. + * @param modelId model id + */ + public void clearWorkerNodes(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache != null) { + log.debug("clear worker nodes of model {}", modelId); + modelCache.clearWorkerNodes(); + if (!modelCache.isValidCache()) { + modelCaches.remove(modelId); + } + } + } + + /** + * Get model profile. + * @param modelId model id + * @return model profile + */ + public MLModelProfile getModelProfile(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache == null) { + return null; + } + + MLModelProfile.MLModelProfileBuilder builder = MLModelProfile.builder(); + builder.modelState(modelCache.getModelState()); + if (modelCache.getPredictor() != null) { + builder.predictor(modelCache.getPredictor().toString()); + } + String[] workerNodes = modelCache.getWorkerNodes(); + if (workerNodes.length > 0) { + builder.workerNodes(workerNodes); + } + MLPredictRequestStats stats = modelCache.getInferenceStats(); + builder.predictStats(stats); + return builder.build(); + } + + /** + * Add model inference duration. + * @param modelId model id + * @param duration time in milliseconds used to run inference. + */ + public void addInferenceDuration(String modelId, double duration) { + MLModelCache modelCache = getOrCreateModelCache(modelId); + modelCache.addInferenceDuration(duration, maxRequestCount); + } + + /** + * Get function name of model + * @param modelId model id + * @return function name + */ + public FunctionName getFunctionName(String modelId) { + MLModelCache modelCache = getExistingModelCache(modelId); + return modelCache.getFunctionName(); + } + + private MLModelCache getExistingModelCache(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache == null) { + throw new IllegalArgumentException("Model not found in cache"); + } + return modelCache; + } + + private MLModelCache getOrCreateModelCache(String modelId) { + return modelCaches.computeIfAbsent(modelId, it -> new MLModelCache()); + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 734861f439..e67d939300 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -5,10 +5,17 @@ package org.opensearch.ml.model; +import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.common.xcontent.XContentType.JSON; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.NOT_FOUND; import static org.opensearch.ml.common.CommonValue.UNLOADED; +import static org.opensearch.ml.common.MLTask.ERROR_FIELD; +import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; +import static org.opensearch.ml.common.MLTask.STATE_FIELD; +import static org.opensearch.ml.common.MLTaskState.COMPLETED; +import static org.opensearch.ml.common.MLTaskState.FAILED; import static org.opensearch.ml.engine.MLEngine.getLoadModelChunkPath; import static org.opensearch.ml.engine.MLEngine.getLoadModelZipPath; import static org.opensearch.ml.engine.MLEngine.getUploadModelPath; @@ -19,9 +26,12 @@ import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingModel.MODEL_ZIP_FILE; import static org.opensearch.ml.engine.utils.FileUtils.calculateFileHash; import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly; -import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.LOAD_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.UPLOAD_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE; +import static org.opensearch.ml.stats.ActionName.UPLOAD; +import static org.opensearch.ml.stats.MLActionLevelStat.ML_ACTION_REQUEST_COUNT; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import java.io.File; @@ -33,6 +43,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -46,31 +57,32 @@ import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.get.GetRequest; import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.IndicesOptions; +import org.opensearch.action.support.ThreadedActionListener; import org.opensearch.action.support.WriteRequest; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.NamedXContentRegistry; -import org.opensearch.common.xcontent.ToXContent; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentParser; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; -import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.breaker.MLCircuitBreakerService; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.transport.load.LoadModelResponse; import org.opensearch.ml.common.transport.load.MLLoadModelAction; import org.opensearch.ml.common.transport.load.MLLoadModelRequest; import org.opensearch.ml.common.transport.upload.MLUploadInput; @@ -101,11 +113,12 @@ public class MLModelManager { public static final int TIMEOUT_IN_MILLIS = 5000; private final Client client; + private final ClusterService clusterService; private ThreadPool threadPool; private NamedXContentRegistry xContentRegistry; private ModelHelper modelHelper; - private final MLModelCache modelCache; + private final MLModelCacheHelper modelCacheHelper; private final MLStats mlStats; private final MLCircuitBreakerService mlCircuitBreakerService; private final MLIndicesHandler mlIndicesHandler; @@ -130,7 +143,8 @@ public MLModelManager( this.threadPool = threadPool; this.xContentRegistry = xContentRegistry; this.modelHelper = modelHelper; - this.modelCache = new MLModelCache(clusterService, settings); + this.clusterService = clusterService; + this.modelCacheHelper = new MLModelCacheHelper(clusterService, settings); this.mlStats = mlStats; this.mlCircuitBreakerService = mlCircuitBreakerService; this.mlIndicesHandler = mlIndicesHandler; @@ -155,79 +169,48 @@ public void uploadMLModel(MLUploadInput uploadInput, MLTask mlTask) { mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); String errorMsg = checkAndAddRunningTask(mlTask, maxUploadTasksPerNode); if (errorMsg != null) { - mlTaskManager - .updateMLTaskDirectly( - mlTask.getTaskId(), - ImmutableMap.of(MLTask.STATE_FIELD, MLTaskState.FAILED, MLTask.ERROR_FIELD, errorMsg) - ); + mlTaskManager.updateMLTaskDirectly(mlTask.getTaskId(), ImmutableMap.of(STATE_FIELD, FAILED, ERROR_FIELD, errorMsg)); throw new MLLimitExceededException(errorMsg); } mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats - .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.UPLOAD, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) - .increment(); try { - if (uploadInput.getUrl() != null) { - uploadModel(uploadInput, mlTask); - } else { - throw new IllegalArgumentException("wrong model file url"); - } - } catch (Exception e) { - mlStats - .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.UPLOAD, MLActionLevelStat.ML_ACTION_FAILURE_COUNT) - .increment(); - throw new MLException("Failed to upload model", e); - } finally { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - } - } - - public String checkAndAddRunningTask(MLTask mlTask, Integer limit) { - String error = mlTaskManager.checkLimitAndAddRunningTask(mlTask, limit); - if (error != null) { - return error; - } - if (mlCircuitBreakerService.isOpen()) { - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT).increment(); - return "Circuit breaker is open, please check your memory and disk usage!"; - } - return null; - } - - private void uploadModel(MLUploadInput mlUploadInput, MLTask mlTask) { - Semaphore semaphore = new Semaphore(1); - String taskId = mlTask.getTaskId(); - - AtomicInteger uploaded = new AtomicInteger(0); - threadPool.executor(TASK_THREAD_POOL).execute(() -> { - String modelName = mlUploadInput.getModelName(); - String version = mlUploadInput.getVersion(); + mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), UPLOAD, ML_ACTION_REQUEST_COUNT).increment(); + String taskId = mlTask.getTaskId(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + String modelName = uploadInput.getModelName(); + String version = uploadInput.getVersion(); mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> { MLModel mlModelMeta = MLModel .builder() .name(modelName) .algorithm(mlTask.getFunctionName()) .version(version) - .modelFormat(mlUploadInput.getModelFormat()) + .modelFormat(uploadInput.getModelFormat()) .modelState(MLModelState.UPLOADING) - .modelConfig(mlUploadInput.getModelConfig()) + .modelConfig(uploadInput.getModelConfig()) .createdTime(Instant.now()) .build(); IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); - indexModelMetaRequest - .source(mlModelMeta.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); + indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS)); indexModelMetaRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(indexModelMetaRequest, ActionListener.wrap(modelMetaRes -> { + // create model meta doc + ActionListener listener = ActionListener.wrap(modelMetaRes -> { String modelId = modelMetaRes.getId(); mlTask.setModelId(modelId); log.info("create new model meta doc {} for upload task {}", modelId, taskId); - modelHelper.downloadAndSplit(modelId, modelName, version, mlUploadInput.getUrl(), ActionListener.wrap(result -> { + modelHelper.downloadAndSplit(modelId, modelName, version, uploadInput.getUrl(), ActionListener.wrap(result -> { Long modelSizeInBytes = (Long) result.get(MODEL_SIZE_IN_BYTES); List chunkFiles = (List) result.get(CHUNK_FILES); String hashValue = (String) result.get(MODEL_FILE_HASH); + Semaphore semaphore = new Semaphore(1); + AtomicInteger uploaded = new AtomicInteger(0); + AtomicBoolean failedToUploadChunk = new AtomicBoolean(false); + // upload chunks for (String name : chunkFiles) { + if (failedToUploadChunk.get()) { + throw new MLException("Failed to save model chunk"); + } semaphore.tryAcquire(10, TimeUnit.SECONDS); File file = new File(name); byte[] bytes = Files.toByteArray(file); @@ -238,84 +221,28 @@ private void uploadModel(MLUploadInput mlUploadInput, MLTask mlTask) { .name(modelName) .algorithm(mlTask.getFunctionName()) .version(version) - .modelFormat(mlUploadInput.getModelFormat()) + .modelFormat(uploadInput.getModelFormat()) .chunkNumber(chunkNum) .totalChunks(chunkFiles.size()) .content(Base64.getEncoder().encodeToString(bytes)) .createdTime(Instant.now()) .build(); IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); - indexRequest.id(getModelChunkId(modelId, chunkNum)); - indexRequest - .source( - mlModel.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS) - ); + String chunkId = getModelChunkId(modelId, chunkNum); + indexRequest.id(chunkId); + indexRequest.source(mlModel.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS)); indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); client.index(indexRequest, ActionListener.wrap(r -> { uploaded.getAndIncrement(); if (uploaded.get() == chunkFiles.size()) { - deleteFileQuietly(getUploadModelPath(modelId)); - updateModel( - modelId, - ImmutableMap - .of( - MLModel.MODEL_STATE_FIELD, - MLModelState.UPLOADED, - MLModel.LAST_UPLOADED_TIME_FIELD, - Instant.now().toEpochMilli(), - MLModel.TOTAL_CHUNKS_FIELD, - chunkFiles.size(), - MLModel.MODEL_CONTENT_HASH_VALUE_FIELD, - hashValue, - MLModel.MODEL_CONTENT_SIZE_IN_BYTES_FIELD, - modelSizeInBytes - ), - ActionListener.wrap(updateResponse -> { - mlTaskManager - .updateMLTask( - taskId, - ImmutableMap - .of(MLTask.STATE_FIELD, MLTaskState.COMPLETED, MLTask.MODEL_ID_FIELD, modelId), - TIMEOUT_IN_MILLIS - ); - mlTaskManager.remove(taskId); - if (mlUploadInput.isLoadModel()) { - String[] modelNodeIds = mlUploadInput.getModelNodeIds(); - log - .debug( - "uploading model done, start loading model {} on nodes: {}", - modelId, - Arrays.toString(modelNodeIds) - ); - MLLoadModelRequest mlLoadModelRequest = new MLLoadModelRequest( - modelId, - modelNodeIds, - false, - true - ); - client - .execute( - MLLoadModelAction.INSTANCE, - mlLoadModelRequest, - ActionListener - .wrap( - response -> { log.info(response); }, - exc -> { exc.printStackTrace(); } - ) - ); - } - }, e -> { - log.error("Failed to index model chunk", e); - handleException(taskId, e); - deleteModel(modelId); - }) - ); + updateModelUpdateStateAsDone(uploadInput, taskId, modelId, modelSizeInBytes, chunkFiles, hashValue); } else { file.delete(); } semaphore.release(); }, e -> { - log.error("Failed to index model chunk", e); + log.error("Failed to index model chunk " + chunkId, e); + failedToUploadChunk.set(true); handleException(taskId, e); file.delete(); // remove model doc as failed to upload model @@ -333,7 +260,9 @@ private void uploadModel(MLUploadInput mlUploadInput, MLTask mlTask) { }, e -> { log.error("Failed to index model meta doc", e); handleException(taskId, e); - })); + }); + + client.index(indexModelMetaRequest, threadedActionListener(UPLOAD_THREAD_POOL, listener)); }, e -> { log.error("Failed to init model index", e); handleException(taskId, e); @@ -342,7 +271,77 @@ private void uploadModel(MLUploadInput mlUploadInput, MLTask mlTask) { log.error("Failed to upload model", e); handleException(taskId, e); } - }); + } catch (Exception e) { + mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), UPLOAD, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); + throw new MLException("Failed to upload model", e); + } finally { + mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); + } + } + + private ThreadedActionListener threadedActionListener(String threadPoolName, ActionListener listener) { + return new ThreadedActionListener<>(log, threadPool, threadPoolName, listener, false); + } + + /** + * Check if exceed running task limit and if circuit breaker is open. + * @param mlTask ML task + * @param runningTaskLimit limit + * @return error message if limit exceeds; otherwise return null + */ + public String checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) { + String error = mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit); + if (error != null) { + return error; + } + if (mlCircuitBreakerService.isOpen()) { + mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT).increment(); + return "Circuit breaker is open, please check your memory and disk usage!"; + } + return null; + } + + private void updateModelUpdateStateAsDone( + MLUploadInput uploadInput, + String taskId, + String modelId, + Long modelSizeInBytes, + List chunkFiles, + String hashValue + ) { + deleteFileQuietly(getUploadModelPath(modelId)); + Map updatedFields = ImmutableMap + .of( + MLModel.MODEL_STATE_FIELD, + MLModelState.UPLOADED, + MLModel.LAST_UPLOADED_TIME_FIELD, + Instant.now().toEpochMilli(), + MLModel.TOTAL_CHUNKS_FIELD, + chunkFiles.size(), + MLModel.MODEL_CONTENT_HASH_VALUE_FIELD, + hashValue, + MLModel.MODEL_CONTENT_SIZE_IN_BYTES_FIELD, + modelSizeInBytes + ); + updateModel(modelId, updatedFields, ActionListener.wrap(updateResponse -> { + mlTaskManager.updateMLTask(taskId, ImmutableMap.of(STATE_FIELD, COMPLETED, MODEL_ID_FIELD, modelId), TIMEOUT_IN_MILLIS); + mlTaskManager.remove(taskId); + if (uploadInput.isLoadModel()) { + loadModelAfterUploading(uploadInput, modelId); + } + }, e -> { + log.error("Failed to update model", e); + handleException(taskId, e); + deleteModel(modelId); + })); + } + + private void loadModelAfterUploading(MLUploadInput uploadInput, String modelId) { + String[] modelNodeIds = uploadInput.getModelNodeIds(); + log.debug("start loading model after uploading {} on nodes: {}", modelId, Arrays.toString(modelNodeIds)); + MLLoadModelRequest request = new MLLoadModelRequest(modelId, modelNodeIds, false, true); + ActionListener listener = ActionListener.wrap(r -> log.info(r), e -> log.error("Failed to load model", e)); + client.execute(MLLoadModelAction.INSTANCE, request, listener); } private void deleteModel(String modelId) { @@ -357,21 +356,9 @@ private void deleteModel(String modelId) { } private void handleException(String taskId, Exception e) { - mlTaskManager - .updateMLTask( - taskId, - ImmutableMap.of(MLTask.ERROR_FIELD, ExceptionUtils.getStackTrace(e), MLTask.STATE_FIELD, MLTaskState.FAILED), - ActionListener - .runAfter( - ActionListener - .wrap( - r -> { log.debug("updated task successfully {}", taskId); }, - ex -> { log.error("failed to update ML task " + taskId, ex); } - ), - () -> mlTaskManager.remove(taskId) - ), - TIMEOUT_IN_MILLIS - ); + mlTaskManager.remove(taskId); + Map updated = ImmutableMap.of(ERROR_FIELD, ExceptionUtils.getStackTrace(e), STATE_FIELD, FAILED); + mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS); } /** @@ -384,74 +371,71 @@ private void handleException(String taskId, Exception e) { * @param listener action listener */ public void loadModel(String modelId, String modelContentHash, FunctionName functionName, ActionListener listener) { - mlStats.createCounterStatIfAbsent(functionName, ActionName.LOAD, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment(); - if (modelCache.isModelLoaded(modelId)) { + mlStats.createCounterStatIfAbsent(functionName, ActionName.LOAD, ML_ACTION_REQUEST_COUNT).increment(); + if (modelCacheHelper.isModelLoaded(modelId)) { listener.onResponse("successful"); return; } - if (modelCache.modelCount() >= maxModelPerNode) { + if (modelCacheHelper.getLoadedModels().length >= maxModelPerNode) { listener.onFailure(new IllegalArgumentException("Exceed max model per node limit")); return; } - modelCache.initModelState(modelId, MLModelState.LOADING, functionName); + modelCacheHelper.initModelState(modelId, MLModelState.LOADING, functionName); + DiscoveryNode node = clusterService.localNode(); try { - threadPool.executor(TASK_THREAD_POOL).execute(() -> { - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - this.getModel(modelId, ActionListener.wrap(mlModel -> { - if (mlModel.getAlgorithm() != FunctionName.TEXT_EMBEDDING) {// load model trained by built-in algorithm like kmeans - Predictable predictable = MLEngine.load(mlModel, null); - modelCache.addPredictable(modelId, predictable); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment(); - modelCache.setModelState(modelId, MLModelState.LOADED); - listener.onResponse("successful"); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + this.getModel(modelId, threadedActionListener(LOAD_THREAD_POOL, ActionListener.wrap(mlModel -> { + if (mlModel.getAlgorithm() != FunctionName.TEXT_EMBEDDING) {// load model trained by built-in algorithm like kmeans + Predictable predictable = MLEngine.load(mlModel, null); + modelCacheHelper.setPredictor(modelId, predictable); + mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment(); + modelCacheHelper.setModelState(modelId, MLModelState.LOADED); + listener.onResponse("successful"); + return; + } + // check circuit breaker before loading custom model chunks + if (mlCircuitBreakerService.isOpen()) { + mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT).increment(); + throw new MLLimitExceededException("Circuit breaker is open, please check your memory and disk usage!"); + } + retrieveModelChunks(mlModel, ActionListener.wrap(modelZipFile -> {// load model trunks + String hash = calculateFileHash(modelZipFile); + if (modelContentHash != null && !modelContentHash.equals(hash)) { + log.error("Model content hash can't match original hash value"); + modelCacheHelper.removeModel(modelId); + modelHelper.deleteFileCache(modelId); + listener.onFailure(new IllegalArgumentException("model content changed")); return; } - // check circuit breaker before loading custom model chunks - if (mlCircuitBreakerService.isOpen()) { - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT).increment(); - throw new MLLimitExceededException("Circuit breaker is open, please check your memory and disk usage!"); - } - retrieveModelChunks(mlModel, ActionListener.wrap(modelZipFile -> {// load model trunks - String hash = calculateFileHash(modelZipFile); - if (modelContentHash != null && !modelContentHash.equals(hash)) { - log.error("Model content hash can't match original hash value"); - modelCache.removeModelState(modelId); - modelHelper.deleteFileCache(modelId); - listener.onFailure(new IllegalArgumentException("model content changed")); - return; - } - log.debug("Model content matches original hash value, continue loading"); - Predictable predictable = MLEngine - .load(mlModel, ImmutableMap.of(MODEL_ZIP_FILE, modelZipFile, MODEL_HELPER, modelHelper)); - modelCache.addPredictable(modelId, predictable); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment(); - modelCache.setModelState(modelId, MLModelState.LOADED); - listener.onResponse("successful"); - }, e -> { - mlStats - .createCounterStatIfAbsent(functionName, ActionName.LOAD, MLActionLevelStat.ML_ACTION_FAILURE_COUNT) - .increment(); - log.error("Failed to retrieve model " + modelId, e); - modelCache.removeModelState(modelId); - listener.onFailure(e); - })); + log.debug("Model content matches original hash value, continue loading"); + Predictable predictable = MLEngine + .load(mlModel, ImmutableMap.of(MODEL_ZIP_FILE, modelZipFile, MODEL_HELPER, modelHelper)); + modelCacheHelper.setPredictor(modelId, predictable); + mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment(); + modelCacheHelper.setModelState(modelId, MLModelState.LOADED); + listener.onResponse("successful"); }, e -> { - log.error("Failed to load model " + modelId, e); mlStats .createCounterStatIfAbsent(functionName, ActionName.LOAD, MLActionLevelStat.ML_ACTION_FAILURE_COUNT) .increment(); - modelCache.removeModelState(modelId); - listener.onFailure(new MLException("Failed to load model " + modelId, e)); + log.error("Failed to retrieve model " + modelId, e); + modelCacheHelper.removeModel(modelId); + listener.onFailure(e); })); - } catch (Exception e) { + }, e -> { + log.error("Failed to load model " + modelId, e); mlStats.createCounterStatIfAbsent(functionName, ActionName.LOAD, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); - modelCache.removeModelState(modelId); - listener.onFailure(e); - } - }); + modelCacheHelper.removeModel(modelId); + listener.onFailure(new MLException("Failed to load model " + modelId, e)); + }))); + } catch (Exception e) { + mlStats.createCounterStatIfAbsent(functionName, ActionName.LOAD, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); + modelCacheHelper.removeModel(modelId); + listener.onFailure(e); + } } catch (Exception e) { mlStats.createCounterStatIfAbsent(functionName, ActionName.LOAD, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); - modelCache.removeModelState(modelId); + modelCacheHelper.removeModel(modelId); listener.onFailure(e); } } @@ -505,34 +489,34 @@ private void retrieveModelChunks(MLModel mlModelMeta, ActionListener liste Semaphore semaphore = new Semaphore(1); AtomicBoolean stopNow = new AtomicBoolean(false); String modelZip = getLoadModelZipPath(modelId, modelName); - File[] chunkFiles = new File[totalChunks]; + ConcurrentLinkedDeque chunkFiles = new ConcurrentLinkedDeque(); AtomicInteger retrievedChunks = new AtomicInteger(0); for (int i = 0; i < totalChunks; i++) { if (stopNow.get()) { listener.onFailure(new MLException("Failed to load model")); return; } - semaphore.tryAcquire(10, TimeUnit.SECONDS); - + semaphore.acquire(); String modelChunkId = this.getModelChunkId(modelId, i); int currentChunk = i; - this.getModel(modelChunkId, ActionListener.wrap(model -> { + this.getModel(modelChunkId, threadedActionListener(LOAD_THREAD_POOL, ActionListener.wrap(model -> { Path chunkPath = getLoadModelChunkPath(modelId, currentChunk); FileUtils.write(Base64.getDecoder().decode(model.getContent()), chunkPath.toString()); - chunkFiles[currentChunk] = new File(chunkPath.toUri()); - semaphore.release(); + chunkFiles.add(new File(chunkPath.toUri())); retrievedChunks.getAndIncrement(); if (retrievedChunks.get() == totalChunks) { File modelZipFile = new File(modelZip); FileUtils.mergeFiles(chunkFiles, modelZipFile); listener.onResponse(modelZipFile); } + semaphore.release(); }, e -> { stopNow.set(true); semaphore.release(); + log.error("Failed to model and chunks", e); listener.onFailure(new MLResourceNotFoundException("Fail to find model chunk " + modelChunkId)); return; - })); + }))); } } @@ -559,7 +543,7 @@ public void updateModel(String modelId, ImmutableMap updatedFiel * @param updatedFields updated fields * @param listener action listener */ - public void updateModel(String modelId, ImmutableMap updatedFields, ActionListener listener) { + public void updateModel(String modelId, Map updatedFields, ActionListener listener) { try { if (updatedFields == null || updatedFields.size() == 0) { listener.onFailure(new IllegalArgumentException("Updated fields is null or empty")); @@ -595,7 +579,7 @@ public String getModelChunkId(String modelId, Integer chunkNumber) { public void addModelWorkerNode(String modelId, String... nodeIds) { if (nodeIds != null) { for (String nodeId : nodeIds) { - modelCache.addNodeToModelRoutingTable(modelId, nodeId); + modelCacheHelper.addWorkerNode(modelId, nodeId); } } } @@ -609,7 +593,7 @@ public void addModelWorkerNode(String modelId, String... nodeIds) { public void removeModelWorkerNode(String modelId, String... nodeIds) { if (nodeIds != null) { for (String nodeId : nodeIds) { - modelCache.removeNodeFromModelRoutingTable(modelId, nodeId); + modelCacheHelper.removeWorkerNode(modelId, nodeId); } } } @@ -620,7 +604,7 @@ public void removeModelWorkerNode(String modelId, String... nodeIds) { * @param removedNodes removed node ids */ public void removeWorkerNodes(Set removedNodes) { - modelCache.removeWorkNodes(removedNodes); + modelCacheHelper.removeWorkerNodes(removedNodes); } /** @@ -634,26 +618,22 @@ public synchronized Map unloadModel(String[] modelIds) { if (modelIds != null && modelIds.length > 0) { log.debug("unload models {}", Arrays.toString(modelIds)); for (String modelId : modelIds) { - if (modelCache.hasModel(modelId)) { + if (modelCacheHelper.isModelLoaded(modelId)) { modelUnloadStatus.put(modelId, UNLOADED); mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).decrement(); } else { modelUnloadStatus.put(modelId, NOT_FOUND); } - mlStats - .createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNLOAD, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) - .increment(); - modelCache.removeModel(modelId); + mlStats.createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNLOAD, ML_ACTION_REQUEST_COUNT).increment(); + modelCacheHelper.removeModel(modelId); } } else { log.debug("unload all models {}", Arrays.toString(getLocalLoadedModels())); for (String modelId : getLocalLoadedModels()) { modelUnloadStatus.put(modelId, UNLOADED); mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).decrement(); - mlStats - .createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNLOAD, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) - .increment(); - modelCache.removeModel(modelId); + mlStats.createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNLOAD, ML_ACTION_REQUEST_COUNT).increment(); + modelCacheHelper.removeModel(modelId); } } return modelUnloadStatus; @@ -666,7 +646,7 @@ public synchronized Map unloadModel(String[] modelIds) { * @return list of worker node ids */ public String[] getWorkerNodes(String modelId) { - return modelCache.getWorkerNodes(modelId); + return modelCacheHelper.getWorkerNodes(modelId); } /** @@ -676,7 +656,7 @@ public String[] getWorkerNodes(String modelId) { * @param predictable predictable instance */ public void addPredictable(String modelId, Predictable predictable) { - modelCache.addPredictable(modelId, predictable); + modelCacheHelper.setPredictor(modelId, predictable); } /** @@ -685,8 +665,8 @@ public void addPredictable(String modelId, Predictable predictable) { * @param modelId * @return */ - public Predictable getPredictable(String modelId) { - return modelCache.getPredictable(modelId); + public Predictable getPredictor(String modelId) { + return modelCacheHelper.getPredictor(modelId); } /** @@ -695,7 +675,7 @@ public Predictable getPredictable(String modelId) { * @return */ public String[] getAllModelIds() { - return modelCache.getAllModelIds(); + return modelCacheHelper.getAllModels(); } /** @@ -704,27 +684,27 @@ public String[] getAllModelIds() { * @return */ public String[] getLocalLoadedModels() { - return modelCache.getLoadedModels(); + return modelCacheHelper.getLoadedModels(); } /** * Sync model routing table. * - * @param modelRoutingTable + * @param modelWorkerNodes */ - public synchronized void syncModelRouting(Map> modelRoutingTable) { - modelCache.syncModelRouting(modelRoutingTable); + public synchronized void syncModelWorkerNodes(Map> modelWorkerNodes) { + modelCacheHelper.syncWorkerNodes(modelWorkerNodes); } /** * */ public void clearRoutingTable() { - modelCache.clearRoutingTable(); + modelCacheHelper.clearWorkerNodes(); } public MLModelProfile getModelProfile(String modelId) { - return modelCache.getModelProfile(modelId); + return modelCacheHelper.getModelProfile(modelId); } public T trackPredictDuration(String modelId, Supplier supplier) { @@ -732,19 +712,15 @@ public T trackPredictDuration(String modelId, Supplier supplier) { T t = supplier.get(); long end = System.nanoTime(); double durationInMs = (end - start) / 1e6; - modelCache.addInferenceDuration(modelId, durationInMs); + modelCacheHelper.addInferenceDuration(modelId, durationInMs); return t; } public FunctionName getModelFunctionName(String modelId) { - return modelCache.getModelFunctionName(modelId); - } - - public void initModelState(String modelId, MLModelState state, FunctionName functionName) { - modelCache.initModelState(modelId, state, functionName); + return modelCacheHelper.getFunctionName(modelId); } - public boolean containsModel(String modelId) { - return modelCache.containsModel(modelId); + public boolean isModelRunningOnNode(String modelId) { + return modelCacheHelper.isModelRunningOnNode(modelId); } } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 916f7a8dc1..bec22b646d 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -9,7 +9,6 @@ import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; import java.util.Collection; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; @@ -28,6 +27,7 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.settings.SettingsFilter; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; @@ -136,7 +136,13 @@ import com.google.common.collect.ImmutableList; public class MachineLearningPlugin extends Plugin implements ActionPlugin { - public static final String TASK_THREAD_POOL = "OPENSEARCH_ML_TASK_THREAD_POOL"; + public static final String ML_THREAD_POOL_PREFIX = "thread_pool.ml_commons."; + public static final String GENERAL_THREAD_POOL = "opensearch_ml_general"; + public static final String EXECUTE_THREAD_POOL = "opensearch_ml_execute"; + public static final String TRAIN_THREAD_POOL = "opensearch_ml_train"; + public static final String PREDICT_THREAD_POOL = "opensearch_ml_predict"; + public static final String UPLOAD_THREAD_POOL = "opensearch_ml_upload"; + public static final String LOAD_THREAD_POOL = "opensearch_ml_load"; public static final String ML_BASE_URI = "/_plugins/_ml"; private MLStats mlStats; @@ -381,9 +387,56 @@ public List getRestHandlers( @Override public List> getExecutorBuilders(Settings settings) { - FixedExecutorBuilder ml = new FixedExecutorBuilder(settings, TASK_THREAD_POOL, 4, 4, "ml.task_thread_pool", false); + FixedExecutorBuilder generalThreadPool = new FixedExecutorBuilder( + settings, + GENERAL_THREAD_POOL, + Math.max(1, OpenSearchExecutors.allocatedProcessors(settings) - 1), + 10, + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL, + false + ); + FixedExecutorBuilder uploadThreadPool = new FixedExecutorBuilder( + settings, + UPLOAD_THREAD_POOL, + Math.max(4, OpenSearchExecutors.allocatedProcessors(settings) - 1), + 10, + ML_THREAD_POOL_PREFIX + UPLOAD_THREAD_POOL, + false + ); + FixedExecutorBuilder loadThreadPool = new FixedExecutorBuilder( + settings, + LOAD_THREAD_POOL, + Math.max(4, OpenSearchExecutors.allocatedProcessors(settings) - 1), + 10, + ML_THREAD_POOL_PREFIX + LOAD_THREAD_POOL, + false + ); + FixedExecutorBuilder executeThreadPool = new FixedExecutorBuilder( + settings, + EXECUTE_THREAD_POOL, + Math.max(1, OpenSearchExecutors.allocatedProcessors(settings) - 1), + 10, + ML_THREAD_POOL_PREFIX + EXECUTE_THREAD_POOL, + false + ); + FixedExecutorBuilder trainThreadPool = new FixedExecutorBuilder( + settings, + TRAIN_THREAD_POOL, + Math.max(1, OpenSearchExecutors.allocatedProcessors(settings) - 1), + 10, + ML_THREAD_POOL_PREFIX + TRAIN_THREAD_POOL, + false + ); + FixedExecutorBuilder predictThreadPool = new FixedExecutorBuilder( + settings, + PREDICT_THREAD_POOL, + Math.max(1, OpenSearchExecutors.allocatedProcessors(settings) - 1), + 10, + ML_THREAD_POOL_PREFIX + PREDICT_THREAD_POOL, + false + ); - return Collections.singletonList(ml); + return ImmutableList.of(generalThreadPool, uploadThreadPool, loadThreadPool, executeThreadPool, trainThreadPool, predictThreadPool); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/profile/MLModelProfile.java b/plugin/src/main/java/org/opensearch/ml/profile/MLModelProfile.java index aa267a3711..f4853f3872 100644 --- a/plugin/src/main/java/org/opensearch/ml/profile/MLModelProfile.java +++ b/plugin/src/main/java/org/opensearch/ml/profile/MLModelProfile.java @@ -8,6 +8,7 @@ import java.io.IOException; import lombok.Builder; +import lombok.Getter; import lombok.extern.log4j.Log4j2; import org.opensearch.common.io.stream.StreamInput; @@ -17,6 +18,7 @@ import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.ml.common.model.MLModelState; +@Getter @Log4j2 public class MLModelProfile implements ToXContentFragment, Writeable { diff --git a/plugin/src/main/java/org/opensearch/ml/profile/MLPredictRequestStats.java b/plugin/src/main/java/org/opensearch/ml/profile/MLPredictRequestStats.java index 02adac9323..4a62862ee6 100644 --- a/plugin/src/main/java/org/opensearch/ml/profile/MLPredictRequestStats.java +++ b/plugin/src/main/java/org/opensearch/ml/profile/MLPredictRequestStats.java @@ -8,6 +8,7 @@ import java.io.IOException; import lombok.Builder; +import lombok.Getter; import lombok.extern.log4j.Log4j2; import org.opensearch.common.io.stream.StreamInput; @@ -16,6 +17,7 @@ import org.opensearch.common.xcontent.ToXContentFragment; import org.opensearch.common.xcontent.XContentBuilder; +@Getter @Log4j2 public class MLPredictRequestStats implements ToXContentFragment, Writeable { diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java index 897e5881c8..50a9cddd21 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java @@ -5,7 +5,7 @@ package org.opensearch.ml.task; -import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.EXECUTE_THREAD_POOL; import lombok.extern.log4j.Log4j2; @@ -77,7 +77,7 @@ protected TransportResponseHandler getResponseHandler(Act */ @Override protected void executeTask(MLExecuteTaskRequest request, ActionListener listener) { - threadPool.executor(TASK_THREAD_POOL).execute(() -> { + threadPool.executor(EXECUTE_THREAD_POOL).execute(() -> { try { mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index ec2bf4d1ec..26928bb885 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -9,7 +9,7 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.permission.AccessController.checkUserPermissions; import static org.opensearch.ml.permission.AccessController.getUserContext; -import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.PREDICT_THREAD_POOL; import java.time.Instant; import java.util.UUID; @@ -172,16 +172,12 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener(log, threadPool, TASK_THREAD_POOL, dataFrameActionListener, false) - ); + mlInputDatasetHandler.parseSearchQueryInput(mlInput.getInputDataset(), threadedActionListener(dataFrameActionListener)); break; case DATA_FRAME: case TEXT_DOCS: default: - threadPool.executor(TASK_THREAD_POOL).execute(() -> { predict(modelId, mlTask, mlInput, listener); }); + threadPool.executor(PREDICT_THREAD_POOL).execute(() -> { predict(modelId, mlTask, mlInput, listener); }); break; } } @@ -201,9 +197,9 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe // run predict if (modelId != null) { try { - Predictable predictable = mlModelManager.getPredictable(modelId); - if (predictable != null) { - MLOutput output = mlModelManager.trackPredictDuration(modelId, () -> predictable.predict(mlInput)); + Predictable predictor = mlModelManager.getPredictor(modelId); + if (predictor != null) { + MLOutput output = mlModelManager.trackPredictDuration(modelId, () -> predictor.predict(mlInput)); if (output instanceof MLPredictionOutput) { ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name()); } @@ -222,7 +218,7 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe // search model by model id. try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) { - ActionListener getResponseListener = ActionListener.wrap(r -> { + ActionListener getModelListener = ActionListener.wrap(r -> { if (r == null || !r.isExists()) { internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId.")); return; @@ -265,7 +261,7 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe handlePredictFailure(mlTask, internalListener, e, true); }); GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, mlTask.getModelId()); - client.get(getRequest, ActionListener.runBefore(getResponseListener, () -> context.restore())); + client.get(getRequest, threadedActionListener(ActionListener.runBefore(getModelListener, () -> context.restore()))); } catch (Exception e) { log.error("Failed to get model " + mlTask.getModelId(), e); handlePredictFailure(mlTask, internalListener, e, true); @@ -277,6 +273,10 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe } } + private ThreadedActionListener threadedActionListener(ActionListener listener) { + return new ThreadedActionListener<>(log, threadPool, PREDICT_THREAD_POOL, listener, false); + } + private void handlePredictFailure(MLTask mlTask, ActionListener listener, Exception e, boolean trackFailure) { if (trackFailure) { mlStats diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java index 5a8fe2556d..b99d8f5ded 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Semaphore; @@ -22,7 +23,7 @@ public class MLTaskCache { // List of worker nodes. // For example when load model on ML nodes, these ML nodes are worker nodes. When model // loaded/failed on some node, the node will be removed from worker nodes. - List workerNodes; + Set workerNodes; Map errors; // This is the original worker node count. It may not equal to size of workerNodes as // worker node may be removed later. @@ -34,8 +35,9 @@ public MLTaskCache(MLTask mlTask, List workerNodes) { if (mlTask.isAsync()) { updateTaskIndexSemaphore = new Semaphore(1); } - this.workerNodes = workerNodes; + this.workerNodes = ConcurrentHashMap.newKeySet(); if (workerNodes != null) { + this.workerNodes.addAll(workerNodes); workerNodeSize = workerNodes.size(); } this.errors = new ConcurrentHashMap<>(); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java index b5a2c05749..d660e5036f 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java @@ -204,7 +204,7 @@ public MLTaskCache getMLTaskCache(String taskId) { return null; } - public List getWorkNodes(String taskId) { + public Set getWorkNodes(String taskId) { if (taskCaches.containsKey(taskId)) { return taskCaches.get(taskId).getWorkerNodes(); } @@ -402,7 +402,7 @@ public String[] getLocalRunningLoadModelTasks() { } public void syncRunningLoadModelTasks(Map> runningLoadModelTasks) { - Instant ttlEndTime = Instant.now().minus(120, ChronoUnit.SECONDS); + Instant ttlEndTime = Instant.now().minus(10, ChronoUnit.MINUTES); Set staleTasks = new HashSet<>(); boolean noRunningTask = runningLoadModelTasks == null || runningLoadModelTasks.size() == 0; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java index 427ad11d8a..effe58002d 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java @@ -5,7 +5,7 @@ package org.opensearch.ml.task; -import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.TRAIN_THREAD_POOL; import java.time.Instant; import java.util.UUID; @@ -113,10 +113,10 @@ protected void executeTask(MLTrainingTaskRequest request, ActionListener(log, threadPool, TASK_THREAD_POOL, dataFrameActionListener, false) + new ThreadedActionListener<>(log, threadPool, TRAIN_THREAD_POOL, dataFrameActionListener, false) ); } else { - threadPool.executor(TASK_THREAD_POOL).execute(() -> { trainAndPredict(mlTask, mlInput, listener); }); + threadPool.executor(TRAIN_THREAD_POOL).execute(() -> { trainAndPredict(mlTask, mlInput, listener); }); } } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java index 89804959a6..cf374e791e 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java @@ -6,7 +6,7 @@ package org.opensearch.ml.task; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; -import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.TRAIN_THREAD_POOL; import java.time.Instant; import java.util.UUID; @@ -158,10 +158,10 @@ private void startTrainingTask(MLTask mlTask, MLInput mlInput, ActionListener(log, threadPool, TASK_THREAD_POOL, dataFrameActionListener, false) + new ThreadedActionListener<>(log, threadPool, TRAIN_THREAD_POOL, dataFrameActionListener, false) ); } else { - threadPool.executor(TASK_THREAD_POOL).execute(() -> { train(mlTask, mlInput, internalListener); }); + threadPool.executor(TRAIN_THREAD_POOL).execute(() -> { train(mlTask, mlInput, internalListener); }); } } catch (Exception e) { log.error("Failed to train " + mlInput.getAlgorithm(), e); diff --git a/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java b/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java index 55b509dd9e..ced7ef6fce 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action; import static org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams.ObjectiveType.LOGMULTICLASS; +import static org.opensearch.ml.utils.RestActionUtils.getAllNodes; import static org.opensearch.ml.utils.TestData.TARGET_FIELD; import static org.opensearch.ml.utils.TestData.TIME_FIELD; @@ -16,9 +17,17 @@ import org.opensearch.action.ActionFuture; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.action.profile.MLProfileAction; +import org.opensearch.ml.action.profile.MLProfileRequest; +import org.opensearch.ml.action.profile.MLProfileResponse; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; @@ -36,12 +45,19 @@ import org.opensearch.ml.common.input.parameter.rcf.FitRCFParams; import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams; import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLPredictionOutput; import org.opensearch.ml.common.output.MLTrainingOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.load.LoadModelResponse; +import org.opensearch.ml.common.transport.load.MLLoadModelAction; +import org.opensearch.ml.common.transport.load.MLLoadModelRequest; import org.opensearch.ml.common.transport.model.MLModelGetAction; import org.opensearch.ml.common.transport.model.MLModelGetRequest; import org.opensearch.ml.common.transport.model.MLModelGetResponse; +import org.opensearch.ml.common.transport.model.MLModelSearchAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.transport.task.MLTaskGetAction; @@ -50,13 +66,22 @@ import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction; +import org.opensearch.ml.common.transport.unload.MLUnloadModelAction; +import org.opensearch.ml.common.transport.unload.UnloadModelNodesRequest; +import org.opensearch.ml.common.transport.unload.UnloadModelNodesResponse; +import org.opensearch.ml.common.transport.upload.MLUploadInput; +import org.opensearch.ml.common.transport.upload.MLUploadModelAction; +import org.opensearch.ml.common.transport.upload.MLUploadModelRequest; +import org.opensearch.ml.common.transport.upload.UploadModelResponse; import org.opensearch.ml.plugin.MachineLearningPlugin; +import org.opensearch.ml.profile.MLProfileInput; import org.opensearch.ml.utils.TestData; import org.opensearch.plugins.Plugin; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchIntegTestCase; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.gson.Gson; public class MLCommonsIntegTestCase extends OpenSearchIntegTestCase { @@ -218,6 +243,68 @@ public String trainModel(FunctionName functionName, MLAlgoParams params, MLInput return id; } + public String uploadModel( + FunctionName functionName, + String modelName, + String version, + MLModelFormat modelFormat, + String modelType, + TextEmbeddingModelConfig.FrameworkType frameworkType, + int dimension, + String allConfig, + String url, + boolean loadModel + ) { + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType(modelType) + .frameworkType(frameworkType) + .embeddingDimension(dimension) + .allConfig(allConfig) + .build(); + MLUploadInput input = MLUploadInput + .builder() + .functionName(functionName) + .modelName(modelName) + .version(version) + .modelFormat(modelFormat) + .modelConfig(modelConfig) + .url(url) + .loadModel(loadModel) + .build(); + MLUploadModelRequest uploadRequest = MLUploadModelRequest.builder().mlUploadInput(input).build(); + ActionFuture actionFuture = client().execute(MLUploadModelAction.INSTANCE, uploadRequest); + UploadModelResponse uploadModelResponse = actionFuture.actionGet(); + String taskId = uploadModelResponse.getTaskId(); + assertNotNull(taskId); + assertFalse(taskId.isEmpty()); + return taskId; + } + + public String loadModel(String modelId) { + MLLoadModelRequest loadRequest = MLLoadModelRequest.builder().modelId(modelId).async(true).dispatchTask(true).build(); + ActionFuture actionFuture = client().execute(MLLoadModelAction.INSTANCE, loadRequest); + LoadModelResponse loadModelResponse = actionFuture.actionGet(); + String taskId = loadModelResponse.getTaskId(); + assertNotNull(taskId); + assertFalse(taskId.isEmpty()); + return taskId; + } + + public MLProfileResponse getModelProfile(String modelId) { + String[] allNodes = getAllNodes(clusterService()); + MLProfileInput profileInput = MLProfileInput + .builder() + .modelIds(ImmutableSet.of(modelId)) + .returnAllModels(true) + .returnAllTasks(true) + .build(); + MLProfileRequest profileRequest = new MLProfileRequest(allNodes, profileInput); + ActionFuture actionFuture = client().execute(MLProfileAction.INSTANCE, profileRequest); + MLProfileResponse response = actionFuture.actionGet(); + return response; + } + public DataFrame predictAndVerify( String modelId, MLInputDataset inputDataset, @@ -235,6 +322,21 @@ public DataFrame predictAndVerify( return predictionResult; } + public MLTaskResponse predict(String modelId, FunctionName functionName, MLInputDataset inputDataset, MLAlgoParams parameters) { + MLInput mlInput = MLInput.builder().algorithm(functionName).inputDataset(inputDataset).parameters(parameters).build(); + MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest(modelId, mlInput); + ActionFuture predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest); + MLTaskResponse predictionResponse = predictionFuture.actionGet(); + return predictionResponse; + } + + public UnloadModelNodesResponse unloadModel(String modelId) { + String[] allNodes = getAllNodes(clusterService()); + UnloadModelNodesRequest unloadRequest = new UnloadModelNodesRequest(allNodes, new String[] { modelId }); + UnloadModelNodesResponse response = client().execute(MLUnloadModelAction.INSTANCE, unloadRequest).actionGet(); + return response; + } + public MLTask getTask(String taskId) { MLTaskGetRequest getRequest = new MLTaskGetRequest(taskId); MLTaskGetResponse response = client().execute(MLTaskGetAction.INSTANCE, getRequest).actionGet(5000); @@ -246,4 +348,14 @@ public MLModel getModel(String modelId) { MLModelGetResponse response = client().execute(MLModelGetAction.INSTANCE, getRequest).actionGet(5000); return response.getMlModel(); } + + public SearchResponse searchModelChunks(String modelId) { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.fetchSource(null, new String[] { MLModel.OLD_MODEL_CONTENT_FIELD, MLModel.MODEL_CONTENT_FIELD }); + QueryBuilder queryBuilder = new TermQueryBuilder(MLModel.MODEL_ID_FIELD, modelId); + searchSourceBuilder.query(queryBuilder); + SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(CommonValue.ML_MODEL_INDEX); + SearchResponse searchResponse = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet(5000); + return searchResponse; + } } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java new file mode 100644 index 0000000000..4ce3ee0391 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java @@ -0,0 +1,276 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.model; + +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT; +import static org.opensearch.ml.utils.TestHelper.clusterSetting; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.MockitoAnnotations; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingModel; +import org.opensearch.ml.profile.MLModelProfile; +import org.opensearch.ml.profile.MLPredictRequestStats; +import org.opensearch.test.OpenSearchTestCase; + +import com.google.common.collect.ImmutableSet; + +public class MLModelCacheHelperTests extends OpenSearchTestCase { + @Rule + public ExpectedException expectedEx = ExpectedException.none(); + + private ClusterService clusterService; + private Settings settings; + + private MLModelCacheHelper cacheHelper; + + private String modelId; + private String nodeId; + private TextEmbeddingModel predictor; + private int maxMonitoringRequests; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + maxMonitoringRequests = 10; + settings = Settings.builder().put(ML_COMMONS_MONITORING_REQUEST_COUNT.getKey(), maxMonitoringRequests).build(); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_MONITORING_REQUEST_COUNT); + clusterService = spy(new ClusterService(settings, clusterSettings, null)); + + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + cacheHelper = new MLModelCacheHelper(clusterService, settings); + + modelId = "model_id1"; + nodeId = "node_id1"; + predictor = spy(new TextEmbeddingModel()); + } + + public void testModelState() { + assertFalse(cacheHelper.isModelLoaded(modelId)); + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + assertFalse(cacheHelper.isModelLoaded(modelId)); + cacheHelper.setModelState(modelId, MLModelState.LOADED); + assertTrue(cacheHelper.isModelLoaded(modelId)); + assertEquals(FunctionName.TEXT_EMBEDDING, cacheHelper.getFunctionName(modelId)); + } + + public void testModelState_DuplicateError() { + expectedEx.expect(IllegalArgumentException.class); + expectedEx.expectMessage("Duplicate model task"); + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + } + + public void testPredictor_NotFoundException() { + expectedEx.expect(IllegalArgumentException.class); + expectedEx.expectMessage("Model not found in cache"); + cacheHelper.setPredictor("modelId1", predictor); + } + + public void testPredictor() { + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + assertNull(cacheHelper.getPredictor(modelId)); + cacheHelper.setPredictor(modelId, predictor); + assertEquals(predictor, cacheHelper.getPredictor(modelId)); + } + + public void testGetAndRemoveModel() { + assertFalse(cacheHelper.isModelRunningOnNode(modelId)); + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + String[] loadedModels = cacheHelper.getLoadedModels(); + assertEquals(0, loadedModels.length); + + assertTrue(cacheHelper.isModelRunningOnNode(modelId)); + + cacheHelper.setModelState(modelId, MLModelState.LOADED); + loadedModels = cacheHelper.getLoadedModels(); + assertArrayEquals(new String[] { modelId }, loadedModels); + + cacheHelper.removeModel(modelId); + loadedModels = cacheHelper.getLoadedModels(); + assertEquals(0, loadedModels.length); + } + + public void testRemoveModel_WrongModelId() { + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + cacheHelper.removeModel("wrong_model_id"); + assertArrayEquals(new String[] { modelId }, cacheHelper.getAllModels()); + } + + public void testModelLoaded() { + cacheHelper.addWorkerNode(modelId, nodeId); + String[] loadedModels = cacheHelper.getLoadedModels(); + assertEquals(0, loadedModels.length); + + String[] allModels = cacheHelper.getAllModels(); + assertArrayEquals(new String[] { modelId }, allModels); + } + + public void testGetWorkerNode() { + String[] workerNodes = cacheHelper.getWorkerNodes(modelId); + assertNull(workerNodes); + cacheHelper.addWorkerNode(modelId, nodeId); + workerNodes = cacheHelper.getWorkerNodes(modelId); + assertArrayEquals(new String[] { nodeId }, workerNodes); + } + + public void testRemoveWorkerNode_NullModelState() { + String nodeId2 = "node_id2"; + cacheHelper.addWorkerNode(modelId, nodeId); + cacheHelper.addWorkerNode(modelId, nodeId2); + assertEquals(2, cacheHelper.getWorkerNodes(modelId).length); + + cacheHelper.removeWorkerNode("wrong_model_id", nodeId); + cacheHelper.removeWorkerNode(modelId, nodeId2); + assertArrayEquals(new String[] { nodeId }, cacheHelper.getWorkerNodes(modelId)); + + cacheHelper.removeWorkerNodes(ImmutableSet.of(nodeId)); + assertNull(cacheHelper.getWorkerNodes(modelId)); + + cacheHelper.addWorkerNode(modelId, nodeId); + assertArrayEquals(new String[] { nodeId }, cacheHelper.getWorkerNodes(modelId)); + cacheHelper.removeWorkerNode(modelId, nodeId); + assertEquals(0, cacheHelper.getAllModels().length); + } + + public void testRemoveWorkerNode_ModelState() { + cacheHelper.addWorkerNode(modelId, nodeId); + cacheHelper.setModelState(modelId, MLModelState.LOADING); + cacheHelper.removeWorkerNodes(ImmutableSet.of(nodeId)); + assertEquals(0, cacheHelper.getWorkerNodes(modelId).length); + assertTrue(cacheHelper.isModelRunningOnNode(modelId)); + + cacheHelper.removeModel(modelId); + assertFalse(cacheHelper.isModelRunningOnNode(modelId)); + } + + public void testRemoveModel_Loaded() { + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + cacheHelper.setModelState(modelId, MLModelState.LOADED); + cacheHelper.setPredictor(modelId, predictor); + cacheHelper.removeModel(modelId); + verify(predictor, times(1)).close(); + } + + public void testClearWorkerNodes_NullModelState() { + String modelId2 = "model_id2"; + cacheHelper.addWorkerNode(modelId, nodeId); + cacheHelper.addWorkerNode(modelId2, nodeId); + cacheHelper.clearWorkerNodes(); + assertEquals(0, cacheHelper.getAllModels().length); + } + + public void testClearWorkerNodes_ModelState() { + cacheHelper.initModelState(modelId, MLModelState.LOADED, FunctionName.TEXT_EMBEDDING); + cacheHelper.addWorkerNode(modelId, nodeId); + cacheHelper.clearWorkerNodes(); + assertArrayEquals(new String[] { modelId }, cacheHelper.getAllModels()); + } + + public void testClearWorkerNodes_WrongModelId() { + cacheHelper.addWorkerNode(modelId, nodeId); + cacheHelper.clearWorkerNodes("wrong_model_id"); + assertArrayEquals(new String[] { modelId }, cacheHelper.getAllModels()); + } + + public void testSyncWorkerNodes_NullModelState() { + String modelId2 = "model_id2"; + cacheHelper.addWorkerNode(modelId, nodeId); + cacheHelper.addWorkerNode(modelId2, nodeId); + + String newNodeId = "new_node_id"; + Map> modelWorkerNodes = new HashMap<>(); + modelWorkerNodes.put(modelId, ImmutableSet.of(newNodeId)); + cacheHelper.syncWorkerNodes(modelWorkerNodes); + assertArrayEquals(new String[] { modelId }, cacheHelper.getAllModels()); + assertArrayEquals(new String[] { newNodeId }, cacheHelper.getWorkerNodes(modelId)); + } + + public void testSyncWorkerNodes_ModelState() { + String modelId2 = "model_id2"; + cacheHelper.initModelState(modelId2, MLModelState.LOADED, FunctionName.TEXT_EMBEDDING); + cacheHelper.addWorkerNode(modelId, nodeId); + cacheHelper.addWorkerNode(modelId2, nodeId); + + String newNodeId = "new_node_id"; + Map> modelWorkerNodes = new HashMap<>(); + modelWorkerNodes.put(modelId, ImmutableSet.of(newNodeId)); + cacheHelper.syncWorkerNodes(modelWorkerNodes); + assertEquals(2, cacheHelper.getAllModels().length); + assertEquals(0, cacheHelper.getWorkerNodes(modelId2).length); + assertArrayEquals(new String[] { newNodeId }, cacheHelper.getWorkerNodes(modelId)); + } + + public void testSyncWorkerNodes_ModelState_NoModelLoaded() { + cacheHelper.addWorkerNode(modelId, nodeId); + + String newModelId = "new_model_id"; + String newNodeId = "new_node_id"; + Map> modelWorkerNodes = new HashMap<>(); + modelWorkerNodes.put(newModelId, ImmutableSet.of(newNodeId)); + cacheHelper.syncWorkerNodes(modelWorkerNodes); + assertArrayEquals(new String[] { newModelId }, cacheHelper.getAllModels()); + assertArrayEquals(new String[] { newNodeId }, cacheHelper.getWorkerNodes(newModelId)); + assertNull(cacheHelper.getWorkerNodes(modelId)); + + cacheHelper.syncWorkerNodes(modelWorkerNodes); + assertArrayEquals(new String[] { newModelId }, cacheHelper.getAllModels()); + assertArrayEquals(new String[] { newNodeId }, cacheHelper.getWorkerNodes(newModelId)); + assertNull(cacheHelper.getWorkerNodes(modelId)); + } + + public void testGetModelProfile_WrongModelId() { + MLModelProfile modelProfile = cacheHelper.getModelProfile(modelId); + assertNull(modelProfile); + } + + public void testGetModelProfile() { + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + cacheHelper.setModelState(modelId, MLModelState.LOADED); + cacheHelper.setPredictor(modelId, predictor); + cacheHelper.addWorkerNode(modelId, nodeId); + MLModelProfile modelProfile = cacheHelper.getModelProfile(modelId); + assertNotNull(modelProfile); + assertTrue(modelProfile.getPredictor().contains("TextEmbeddingModel")); + assertEquals(MLModelState.LOADED, modelProfile.getModelState()); + assertArrayEquals(new String[] { nodeId }, modelProfile.getWorkerNodes()); + assertNull(modelProfile.getPredictStats()); + + for (int i = 1; i <= maxMonitoringRequests * 2; i++) { + cacheHelper.addInferenceDuration(modelId, i); + } + MLPredictRequestStats predictStats = cacheHelper.getModelProfile(modelId).getPredictStats(); + assertNotNull(predictStats); + assertEquals(maxMonitoringRequests + 1, predictStats.getMin(), 1e-5); + assertEquals(maxMonitoringRequests * 2, predictStats.getMax(), 1e-5); + assertEquals((maxMonitoringRequests + 1 + maxMonitoringRequests * 2) / 2.0, predictStats.getAverage(), 1e-5); + assertEquals(maxMonitoringRequests, predictStats.getCount().longValue()); + } + + public void testGetModelProfile_Loading() { + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + MLModelProfile modelProfile = cacheHelper.getModelProfile(modelId); + assertNotNull(modelProfile); + assertEquals(MLModelState.LOADING, modelProfile.getModelState()); + assertNull(modelProfile.getPredictor()); + assertNull(modelProfile.getWorkerNodes()); + assertNull(modelProfile.getPredictStats()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java new file mode 100644 index 0000000000..c1c881d028 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -0,0 +1,204 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.model; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.plugin.MachineLearningPlugin.UPLOAD_THREAD_POOL; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT; +import static org.opensearch.ml.utils.TestHelper.clusterSetting; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.breaker.MLCircuitBreakerService; +import org.opensearch.ml.common.dataset.MLInputDataType; +import org.opensearch.ml.common.exception.MLLimitExceededException; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.transport.upload.MLUploadInput; +import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.stats.MLNodeLevelStat; +import org.opensearch.ml.stats.MLStat; +import org.opensearch.ml.stats.MLStats; +import org.opensearch.ml.stats.suppliers.CounterSupplier; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; + +public class MLModelManagerTests extends OpenSearchTestCase { + + @Rule + public ExpectedException expectedEx = ExpectedException.none(); + + private ClusterService clusterService; + @Mock + private Client client; + @Mock + private ThreadPool threadPool; + private NamedXContentRegistry xContentRegistry; + @Mock + private ModelHelper modelHelper; + private Settings settings; + private MLStats mlStats; + @Mock + private MLCircuitBreakerService mlCircuitBreakerService; + @Mock + private MLIndicesHandler mlIndicesHandler; + @Mock + private MLTaskManager mlTaskManager; + + private MLModelManager modelManager; + + private String modelName; + private String version; + private MLUploadInput uploadInput; + private MLTask mlTask; + @Mock + private ExecutorService taskExecutorService; + private ThreadContext threadContext; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + settings = Settings.builder().put(ML_COMMONS_MAX_MODELS_PER_NODE.getKey(), 10).build(); + settings = Settings.builder().put(ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE.getKey(), 10).build(); + settings = Settings.builder().put(ML_COMMONS_MONITORING_REQUEST_COUNT.getKey(), 10).build(); + ClusterSettings clusterSettings = clusterSetting( + settings, + ML_COMMONS_MAX_MODELS_PER_NODE, + ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE, + ML_COMMONS_MONITORING_REQUEST_COUNT + ); + clusterService = spy(new ClusterService(settings, clusterSettings, null)); + + modelName = "model_name1"; + version = "1.0.0"; + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(384) + .build(); + uploadInput = MLUploadInput + .builder() + .modelName(modelName) + .version(version) + .functionName(FunctionName.TEXT_EMBEDDING) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(modelConfig) + .url("test_url") + .build(); + + Map> stats = new ConcurrentHashMap<>(); + // node level stats + stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); + this.mlStats = new MLStats(stats); + + mlTask = MLTask + .builder() + .taskId("taskId1") + .modelId("modelId1") + .taskType(MLTaskType.UPLOAD_MODEL) + .functionName(FunctionName.TEXT_EMBEDDING) + .state(MLTaskState.CREATED) + .inputType(MLInputDataType.TEXT_DOCS) + .build(); + + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(taskExecutorService).execute(any()); + + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + modelManager = new MLModelManager( + clusterService, + client, + threadPool, + xContentRegistry, + modelHelper, + settings, + mlStats, + mlCircuitBreakerService, + mlIndicesHandler, + mlTaskManager + ); + } + + public void testUploadMLModel_ExceedMaxRunningTask() { + String error = "exceed max running task limit"; + expectedEx.expect(MLLimitExceededException.class); + expectedEx.expectMessage(error); + when(mlTaskManager.checkLimitAndAddRunningTask(any(), any())).thenReturn(error); + modelManager.uploadMLModel(uploadInput, mlTask); + verify(mlTaskManager, never()).updateMLTaskDirectly(eq(mlTask.getTaskId()), any()); + } + + public void testUploadMLModel_CircuitBreakerOpen() { + expectedEx.expect(MLLimitExceededException.class); + expectedEx.expectMessage("Circuit breaker is open, please check your memory and disk usage!"); + when(mlTaskManager.checkLimitAndAddRunningTask(any(), any())).thenReturn(null); + when(mlCircuitBreakerService.isOpen()).thenReturn(true); + modelManager.uploadMLModel(uploadInput, mlTask); + verify(mlTaskManager, never()).updateMLTaskDirectly(eq(mlTask.getTaskId()), any()); + } + + public void testUploadMLModel_InitModelIndexFailure() { + when(mlTaskManager.checkLimitAndAddRunningTask(any(), any())).thenReturn(null); + when(mlCircuitBreakerService.isOpen()).thenReturn(false); + when(threadPool.executor(UPLOAD_THREAD_POOL)).thenReturn(taskExecutorService); + setUpMock_InitModelIndexFailure(); + + modelManager.uploadMLModel(uploadInput, mlTask); + verify(mlTaskManager, times(1)).remove(any()); + verify(modelHelper, never()).downloadAndSplit(any(), any(), any(), any(), any()); + verify(client, never()).index(any(), any()); + } + + private void setUpMock_InitModelIndexFailure() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException("test failure")); + return null; + }).when(mlIndicesHandler).initModelIndexIfAbsent(any()); + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java index b3900504e4..89299610a8 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java @@ -20,8 +20,8 @@ import java.time.Instant; import java.util.Arrays; import java.util.HashMap; -import java.util.List; import java.util.Map; +import java.util.Set; import org.junit.Assert; import org.junit.Before; @@ -339,10 +339,10 @@ public void testMLTaskCache() { assertEquals(task, mlTaskCache.getMlTask()); assertFalse(mlTaskCache.hasError()); - List workNodes = mlTaskManager.getWorkNodes(task.getTaskId()); + Set workNodes = mlTaskManager.getWorkNodes(task.getTaskId()); assertEquals(2, workNodes.size()); - assertEquals(node1, workNodes.get(0)); - assertEquals(node2, workNodes.get(1)); + assertTrue(workNodes.contains(node1)); + assertTrue(workNodes.contains(node2)); String wrongTaskId = "wrong_task_id"; assertNull(mlTaskManager.getWorkNodes(wrongTaskId)); diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index fce55a7c0f..0b32754928 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -25,6 +25,8 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.hc.core5.http.Header; import org.apache.hc.core5.http.HttpEntity; @@ -46,6 +48,7 @@ import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.collect.ImmutableOpenMap; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; @@ -71,6 +74,7 @@ import org.opensearch.test.rest.FakeRestRequest; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Sets; public class TestHelper { @@ -371,4 +375,12 @@ public static ClusterState setupTestClusterState() { false ); } + + public static ClusterSettings clusterSetting(Settings settings, Setting... setting) { + final Set> settingsSet = Stream + .concat(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), Sets.newHashSet(setting).stream()) + .collect(Collectors.toSet()); + ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet); + return clusterSettings; + } }