Forward-porting changes in 2.4 to main branch (refactor model cache and thread pool) #564

Merged
1 commit merged on Nov 18, 2022
@@ -5,6 +5,7 @@

package org.opensearch.ml.common.transport.load;

import lombok.Getter;
import org.opensearch.action.ActionResponse;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
@@ -13,6 +14,7 @@

import java.io.IOException;

@Getter
public class LoadModelResponse extends ActionResponse implements ToXContentObject {
public static final String TASK_ID_FIELD = "task_id";
public static final String STATUS_FIELD = "status";
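
Note: the @Getter annotation added above replaces hand-written accessors on the response; the same change is applied to UploadModelResponse further down. A minimal sketch of what Lombok generates for these two fields (the sketch class is illustrative, not the PR code):

import lombok.Getter;

// With @Getter, Lombok generates a public getter per field at compile time,
// so getTaskId() and getStatus() no longer need to be written by hand.
@Getter
class LoadModelResponseSketch {
    private final String taskId;   // serialized under "task_id"
    private final String status;   // serialized under "status"

    LoadModelResponseSketch(String taskId, String status) {
        this.taskId = taskId;
        this.status = status;
    }
    // Equivalent generated methods:
    // public String getTaskId() { return taskId; }
    // public String getStatus() { return status; }
}
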
@@ -54,13 +54,24 @@ public class MLUploadInput implements ToXContentObject, Writeable {
public MLUploadInput(FunctionName functionName, String modelName, String version, String url, MLModelFormat modelFormat, MLModelConfig modelConfig, boolean loadModel, String[] modelNodeIds) {
if (functionName == null) {
this.functionName = FunctionName.TEXT_EMBEDDING;
} else {
this.functionName = functionName;
}
if (modelName == null) {
throw new IllegalArgumentException("model name is null");
}
if (version == null) {
throw new IllegalArgumentException("model version is null");
}
if (modelFormat == null) {
throw new IllegalArgumentException("model format is null");
}
if (modelConfig == null) {
throw new IllegalArgumentException("model config is null");
}
if (url == null) {
throw new IllegalArgumentException("model file url is null");
}
this.modelName = modelName;
this.version = version;
this.url = url;
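
Note: the constructor now validates every required argument up front instead of only the model name. A hedged, standalone sketch of the same fail-fast pattern (the class and fields here are illustrative, not the exact MLUploadInput code):

// Fail-fast constructor validation mirroring the checks added above.
class UploadInputSketch {
    private final String modelName;
    private final String version;
    private final String url;

    UploadInputSketch(String modelName, String version, String url) {
        if (modelName == null) {
            throw new IllegalArgumentException("model name is null");
        }
        if (version == null) {
            throw new IllegalArgumentException("model version is null");
        }
        if (url == null) {
            throw new IllegalArgumentException("model file url is null");
        }
        this.modelName = modelName;
        this.version = version;
        this.url = url;
    }
}

Rejecting nulls at construction time surfaces a bad upload request immediately, before any task is created or any download starts.
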
@@ -5,6 +5,7 @@

package org.opensearch.ml.common.transport.upload;

import lombok.Getter;
import org.opensearch.action.ActionResponse;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
@@ -13,6 +14,7 @@

import java.io.IOException;

@Getter
public class UploadModelResponse extends ActionResponse implements ToXContentObject {
public static final String TASK_ID_FIELD = "task_id";
public static final String STATUS_FIELD = "status";
@@ -91,9 +91,6 @@ private void verifyModelZipFile(String modelZipFilePath) throws IOException {
hasModelFile = true;
}
if (fileName.equals(TOKENIZER_FILE_NAME)) {
if (hasTokenizerFile) {
throw new IllegalArgumentException("Find multiple tokenizer files");
}
hasTokenizerFile = true;
}
}
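
Note: the duplicate-tokenizer-file guard is removed here, so verification only records whether a tokenizer entry is present at all. A hedged sketch of such a simplified zip scan (the entry-name constants and wrapper class are assumptions, not the plugin's code):

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;

// Walk the zip entries once and remember which required files were seen.
class ModelZipVerifierSketch {
    private static final String MODEL_FILE_NAME = "model.pt";            // assumed name
    private static final String TOKENIZER_FILE_NAME = "tokenizer.json";  // assumed name

    static void verify(Path modelZipFile) throws IOException {
        boolean hasModelFile = false;
        boolean hasTokenizerFile = false;
        try (ZipInputStream zis = new ZipInputStream(Files.newInputStream(modelZipFile))) {
            ZipEntry entry;
            while ((entry = zis.getNextEntry()) != null) {
                String fileName = entry.getName();
                if (fileName.equals(MODEL_FILE_NAME)) {
                    hasModelFile = true;
                }
                if (fileName.equals(TOKENIZER_FILE_NAME)) {
                    hasTokenizerFile = true;   // no duplicate check, matching the diff
                }
            }
        }
        if (!hasModelFile) {
            throw new IllegalArgumentException("Can't find model file in the zip");
        }
        if (!hasTokenizerFile) {
            throw new IllegalArgumentException("Can't find tokenizer file in the zip");
        }
    }
}
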
@@ -9,7 +9,6 @@
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
import org.opensearch.ml.common.FunctionName;
@@ -23,6 +23,7 @@
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Queue;
import java.util.Set;

/**
@@ -92,13 +93,13 @@ public static void write(byte[] data, File destinationFile, boolean append) thro

/**
* Merge files into one big file.
* @param files array of files
* @param files chunk files
* @param mergedFile merged file
*/
public static void mergeFiles(File[] files, File mergedFile) {
public static void mergeFiles(Queue<File> files, File mergedFile) {
boolean failed = false;
for (int i = 0; i< files.length ; i++) {
File f = files[i];
while (!files.isEmpty()) {
File f = files.poll();
try (InputStream inStream = new BufferedInputStream(new FileInputStream(f))) {
if (!failed) {
int fileLength = (int) f.length();
@@ -108,11 +109,11 @@ public static void mergeFiles(File[] files, File mergedFile) {
write(fileContent, mergedFile, true);
}
org.apache.commons.io.FileUtils.deleteQuietly(f);
if (i == files.length - 1) {
if (files.isEmpty()) {
org.apache.commons.io.FileUtils.deleteQuietly(f.getParentFile());
}
} catch (IOException e) {
log.error("Failed to merge file " + f.getAbsolutePath() + " to " + mergedFile.getAbsolutePath());
log.error("Failed to merge file " + f.getAbsolutePath() + " to " + mergedFile.getAbsolutePath(), e);
failed = true;
}
}
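
Note: mergeFiles now takes a Queue<File> of chunk files, polls until the queue is empty, and removes the chunk folder once the last chunk has been merged; the caught IOException is also attached to the error log. A hedged, standalone sketch of the same queue-driven merge loop (paths and the wrapper class are illustrative):

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayDeque;
import java.util.Queue;

// Poll chunks in order, append their bytes to the merged file, delete each chunk,
// and delete the now-empty chunk folder after the last one.
class MergeFilesSketch {

    static void mergeFiles(Queue<File> chunkFiles, File mergedFile) throws IOException {
        try (FileOutputStream out = new FileOutputStream(mergedFile, true)) {      // append mode
            while (!chunkFiles.isEmpty()) {
                File chunk = chunkFiles.poll();
                try (InputStream in = new BufferedInputStream(new FileInputStream(chunk))) {
                    in.transferTo(out);                    // append this chunk's bytes
                }
                File chunkFolder = chunk.getParentFile();
                chunk.delete();                            // the plugin uses FileUtils.deleteQuietly
                if (chunkFiles.isEmpty()) {
                    chunkFolder.delete();                  // succeeds once all chunks are gone
                }
            }
        }
    }

    public static void main(String[] args) throws IOException {
        Queue<File> chunks = new ArrayDeque<>();           // chunks must be enqueued in order
        chunks.add(new File("/tmp/model_chunks/0"));       // example paths, not real ones
        chunks.add(new File("/tmp/model_chunks/1"));
        mergeFiles(chunks, new File("/tmp/model.zip"));
    }
}

A Queue also keeps the call site simple when chunks are produced incrementally, since callers no longer have to materialize a fixed-size array first.
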
@@ -7,7 +7,7 @@

import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.Set;

import lombok.extern.log4j.Log4j2;

@@ -105,7 +105,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
try {
switch (requestType) {
case LOAD_MODEL_DONE:
List<String> workNodes = mlTaskManager.getWorkNodes(taskId);
Set<String> workNodes = mlTaskManager.getWorkNodes(taskId);
if (workNodes != null) {
workNodes.remove(workerNodeId);
}
@@ -114,6 +114,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
mlTaskManager.addNodeError(taskId, workerNodeId, error);
} else {
mlModelManager.addModelWorkerNode(modelId, workerNodeId);
syncModelWorkerNodes(modelId);
}

if (workNodes == null || workNodes.size() == 0) {
@@ -122,23 +123,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
if (mlTaskCache.allNodeFailed()) {
taskState = MLTaskState.FAILED;
} else {
DiscoveryNode[] allNodes = nodeFilter.getAllNodes();
String[] workerNodes = mlModelManager.getWorkerNodes(modelId);
if (allNodes.length > 1 && workerNodes.length > 0) {
log.debug("sync model routing to other nodes. model loaded on nodes: {}", Arrays.toString(workerNodes));
MLSyncUpInput syncUpInput = MLSyncUpInput
.builder()
.addedWorkerNodes(ImmutableMap.of(modelId, workerNodes))
.build();
MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(allNodes, syncUpInput);
client
.execute(
MLSyncUpAction.INSTANCE,
syncUpRequest,
ActionListener
.wrap(r -> { log.debug("Sync up successfully"); }, e -> { log.error("Failed to sync up", e); })
);
}
syncModelWorkerNodes(modelId);
}
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
builder.put(MLTask.STATE_FIELD, taskState);
@@ -181,4 +166,20 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
listener.onFailure(e);
}
}

private void syncModelWorkerNodes(String modelId) {
DiscoveryNode[] allNodes = nodeFilter.getAllNodes();
String[] workerNodes = mlModelManager.getWorkerNodes(modelId);
if (allNodes.length > 1 && workerNodes.length > 0) {
log.debug("Sync to other nodes about worker nodes of model {}: {}", modelId, Arrays.toString(workerNodes));
MLSyncUpInput syncUpInput = MLSyncUpInput.builder().addedWorkerNodes(ImmutableMap.of(modelId, workerNodes)).build();
MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(allNodes, syncUpInput);
client
.execute(
MLSyncUpAction.INSTANCE,
syncUpRequest,
ActionListener.wrap(r -> log.debug("Sync up successfully"), e -> log.error("Failed to sync up", e))
);
}
}
}
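
Note: this refactor switches the remaining-worker bookkeeping from List<String> to Set<String>, records the reporting node as a worker node of the model, and extracts the sync-up broadcast into the private syncModelWorkerNodes(modelId) helper, which now also runs as each worker finishes loading. A hedged sketch of the Set-based bookkeeping (class and method names are illustrative, not the plugin's API):

import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// Track which worker nodes still have to report LOAD_MODEL_DONE for each task.
class LoadTaskTrackerSketch {
    private final Map<String, Set<String>> remainingWorkNodes = new ConcurrentHashMap<>();

    void startTask(String taskId, Set<String> workerNodeIds) {
        Set<String> remaining = ConcurrentHashMap.newKeySet();
        remaining.addAll(workerNodeIds);
        remainingWorkNodes.put(taskId, remaining);
    }

    // Returns true once every worker node has reported, i.e. the task can be finalized.
    boolean onLoadModelDone(String taskId, String workerNodeId) {
        Set<String> remaining = remainingWorkNodes.get(taskId);
        if (remaining != null) {
            remaining.remove(workerNodeId);
        }
        return remaining == null || remaining.isEmpty();
    }
}
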
@@ -5,7 +5,7 @@

package org.opensearch.ml.action.load;

import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL;
import static org.opensearch.ml.plugin.MachineLearningPlugin.LOAD_THREAD_POOL;

import java.time.Instant;
import java.util.ArrayList;
@@ -163,7 +163,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<LoadMo
try {
mlTaskManager.add(mlTask, nodeIds);
listener.onResponse(new LoadModelResponse(taskId, MLTaskState.CREATED.name()));
threadPool.executor(TASK_THREAD_POOL).execute(() -> {
threadPool.executor(LOAD_THREAD_POOL).execute(() -> {
LoadModelInput loadModelInput = new LoadModelInput(
modelId,
taskId,
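
Note: the asynchronous load work moves from the shared TASK_THREAD_POOL onto a dedicated LOAD_THREAD_POOL. A hedged, plain java.util.concurrent sketch of why the pools are separated (in the plugin these are named OpenSearch thread pools, not raw executors; the sizes below are assumptions):

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Keep long-running model loads on their own executor so they cannot starve
// the executor used for ordinary ML tasks.
class ThreadPoolSeparationSketch {
    private final ExecutorService loadPool = Executors.newFixedThreadPool(4);  // assumed size
    private final ExecutorService taskPool = Executors.newFixedThreadPool(4);  // assumed size

    void submitLoad(Runnable loadModelWork) {
        loadPool.execute(loadModelWork);   // analogous to threadPool.executor(LOAD_THREAD_POOL)
    }

    void submitTask(Runnable taskWork) {
        taskPool.execute(taskWork);        // analogous to threadPool.executor(TASK_THREAD_POOL)
    }
}
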
@@ -117,7 +117,10 @@ private MLProfileNodeResponse createMLProfileNodeResponse(MLProfileRequest mlPro
Arrays.stream(mlModelManager.getAllModelIds()).forEach(modelId -> {
if (mlProfileInput.isReturnAllModels() || (!mlProfileInput.emptyModels() && targetModelIds.contains(modelId))) {
log.debug("Runtime model profile is found for model {}", modelId);
mlLocalModels.put(modelId, mlModelManager.getModelProfile(modelId));
MLModelProfile modelProfile = mlModelManager.getModelProfile(modelId);
if (modelProfile != null) {
mlLocalModels.put(modelId, modelProfile);
}
}
});
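
Note: getModelProfile can return null when a model id has no runtime state on this node, and that null must not end up in the profile map. A minimal guard sketch (names are illustrative):

import java.util.HashMap;
import java.util.Map;

// Skip models without a runtime profile instead of putting null into the map.
class ProfileCollectorSketch {
    static void addIfPresent(Map<String, Object> mlLocalModels, String modelId, Object modelProfile) {
        if (modelProfile != null) {
            mlLocalModels.put(modelId, modelProfile);
        }
    }

    public static void main(String[] args) {
        Map<String, Object> profiles = new HashMap<>();
        addIfPresent(profiles, "model-a", null);       // ignored
        addIfPresent(profiles, "model-b", "profile");  // added
        System.out.println(profiles);                  // {model-b=profile}
    }
}
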

@@ -141,7 +141,7 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest loadM
for (Map.Entry<String, Set<String>> entry : modelRoutingTable.entrySet()) {
log.debug("latest routing table for model: {}: {}", entry.getKey(), entry.getValue().toArray(new String[0]));
}
mlModelManager.syncModelRouting(modelRoutingTable);
mlModelManager.syncModelWorkerNodes(modelRoutingTable);
}

if (syncUpInput.isSyncRunningLoadModelTasks()) {
@@ -166,7 +166,9 @@ private void cleanUpLocalCacheFiles() {
Arrays.toString(modelsInCacheFolder.toArray(new String[0]))
);
for (String modelId : modelsInCacheFolder) {
if (!mlTaskManager.contains(modelId) && !mlTaskManager.containsModel(modelId) && !mlModelManager.containsModel(modelId)) {
if (!mlTaskManager.contains(modelId)
&& !mlTaskManager.containsModel(modelId)
&& !mlModelManager.isModelRunningOnNode(modelId)) {
log.info("ML model not in cache. Remove all of its cache files. model id: {}", modelId);
deleteFileCache(modelId);
}
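
Note: besides the syncModelRouting to syncModelWorkerNodes rename above, cache cleanup now asks whether the model is actually running on this node (isModelRunningOnNode) before deleting its files. A hedged sketch of such a sweep (the folder layout and predicate wiring are assumptions, not the plugin's exact code):

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Comparator;
import java.util.Set;
import java.util.stream.Stream;

// Delete cached model folders whose id is neither tracked by a running task
// nor currently running on this node.
class CacheCleanupSketch {

    static void cleanUp(Path cacheRoot, Set<String> runningTaskIds, Set<String> runningModelIds) throws IOException {
        try (Stream<Path> children = Files.list(cacheRoot)) {
            children.filter(Files::isDirectory).forEach(dir -> {
                String modelId = dir.getFileName().toString();
                if (!runningTaskIds.contains(modelId) && !runningModelIds.contains(modelId)) {
                    deleteQuietly(dir);
                }
            });
        }
    }

    private static void deleteQuietly(Path dir) {
        try (Stream<Path> walk = Files.walk(dir)) {
            walk.sorted(Comparator.reverseOrder()).forEach(p -> p.toFile().delete());
        } catch (IOException e) {
            // best-effort cleanup, mirroring Apache Commons deleteQuietly
        }
    }
}
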
@@ -5,7 +5,7 @@

package org.opensearch.ml.cluster;

import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL;
import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS;

import lombok.extern.log4j.Log4j2;
@@ -62,7 +62,7 @@ public void onClusterManager() {
private void startSyncModelRoutingCron() {
if (jobInterval > 0) {
syncModelRoutingCron = threadPool
.scheduleWithFixedDelay(new MLSyncUpCron(client, nodeHelper), TimeValue.timeValueSeconds(jobInterval), TASK_THREAD_POOL);
.scheduleWithFixedDelay(new MLSyncUpCron(client, nodeHelper), TimeValue.timeValueSeconds(jobInterval), GENERAL_THREAD_POOL);
} else {
log.debug("Stop ML syncup job as its interval is: {}", jobInterval);
}
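
Note: the periodic sync-up job now runs on GENERAL_THREAD_POOL instead of TASK_THREAD_POOL. A hedged ScheduledExecutorService analogue of that fixed-delay scheduling (the job body is a placeholder):

import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;

// Plain-Java analogue of threadPool.scheduleWithFixedDelay(...): the next run is
// scheduled only after the previous one finishes, and an interval of 0 disables the job.
class SyncUpSchedulerSketch {
    private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
    private ScheduledFuture<?> syncModelRoutingCron;

    void start(long jobIntervalSeconds, Runnable syncUpJob) {
        if (jobIntervalSeconds > 0) {
            syncModelRoutingCron = scheduler.scheduleWithFixedDelay(
                syncUpJob, jobIntervalSeconds, jobIntervalSeconds, TimeUnit.SECONDS);
        }
    }

    void stop() {
        if (syncModelRoutingCron != null) {
            syncModelRoutingCron.cancel(false);
        }
    }
}
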