Skip to content

Commit

Permalink
adding sdkClient for task manager to create task + updating snapshot …
Browse files Browse the repository at this point in the history
…+ making master key id unique for multi_tenancy (#2861)

* adding sdkClient for task manager to create task

Signed-off-by: Dhrubo Saha <[email protected]>

* addressing the comments

Signed-off-by: Dhrubo Saha <[email protected]>

---------

Signed-off-by: Dhrubo Saha <[email protected]>
  • Loading branch information
dhrubo-os authored Aug 29, 2024
1 parent 8d736fd commit 69a495a
Show file tree
Hide file tree
Showing 27 changed files with 655 additions and 503 deletions.
4 changes: 2 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ buildscript {
ext {
opensearch_group = "org.opensearch"
isSnapshot = "true" == System.getProperty("build.snapshot", "true")
opensearch_version = System.getProperty("opensearch.version", "2.15.0-SNAPSHOT")
opensearch_version = System.getProperty("opensearch.version", "2.16.0-SNAPSHOT")
buildVersionQualifier = System.getProperty("build.version_qualifier", "")
mlCommonsBuildVersion = "multi-tenancy-2.15.0-SNAPSHOT"
mlCommonsBuildVersion = "multi-tenancy-2.16.0-SNAPSHOT"

// 2.0.0-rc1-SNAPSHOT -> 2.0.0.0-rc1-SNAPSHOT
version_tokens = opensearch_version.tokenize('-')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public class CommonValue {
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 3;
public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 3;
public static final String ML_CONTROLLER_INDEX = ".plugins-ml-controller";
public static final Integer ML_CONTROLLER_INDEX_SCHEMA_VERSION = 1;
public static final String ML_MAP_RESPONSE_KEY = "response";
Expand Down Expand Up @@ -422,6 +422,9 @@ public class CommonValue {
+ MLConfig.TYPE_FIELD
+ "\" : {\"type\":\"keyword\"},\n"
+ " \""
+ TENANT_ID
+ "\" : {\"type\":\"keyword\"},\n"
+ " \""
+ MLConfig.CONFIGURATION_FIELD
+ "\" : {\"type\": \"flat_object\"},\n"
+ " \""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
import static org.opensearch.ml.common.CommonValue.TENANT_ID;
import static org.opensearch.ml.common.MLConfig.CREATE_TIME_FIELD;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.time.Instant;
import java.util.Base64;
Expand Down Expand Up @@ -43,7 +46,6 @@
import com.amazonaws.encryptionsdk.CommitmentPolicy;
import com.amazonaws.encryptionsdk.CryptoResult;
import com.amazonaws.encryptionsdk.jce.JceMasterKey;
import com.google.common.collect.ImmutableMap;

import lombok.extern.log4j.Log4j2;

Expand Down Expand Up @@ -158,10 +160,14 @@ private void handleInitMLConfigIndexSuccess(AtomicReference<Exception> exception
}

private GetDataObjectRequest createGetDataObjectRequest(String tenantId, FetchSourceContext fetchSourceContext) {
String masterKeyId = MASTER_KEY;
if (tenantId != null) {
masterKeyId = MASTER_KEY + "_" + hashString(tenantId);
}
return GetDataObjectRequest
.builder()
.index(ML_CONFIG_INDEX)
.id(MASTER_KEY)
.id(masterKeyId)
.tenantId(tenantId)
.fetchSourceContext(fetchSourceContext)
.build();
Expand Down Expand Up @@ -247,16 +253,46 @@ private void initializeNewMasterKey(
}

private PutDataObjectRequest createPutDataObjectRequest(String tenantId, String generatedMasterKey) {
String masterKeyId = MASTER_KEY;
if (tenantId != null) {
masterKeyId = MASTER_KEY + "_" + hashString(tenantId);
}
return PutDataObjectRequest
.builder()
.tenantId(tenantId)
.index(ML_CONFIG_INDEX)
.id(MASTER_KEY)
.id(masterKeyId)
.overwriteIfExists(false)
.dataObject(ImmutableMap.of(MASTER_KEY, generatedMasterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli()))
.dataObject(
Map
.of(
MASTER_KEY,
generatedMasterKey,
CREATE_TIME_FIELD,
Instant.now().toEpochMilli(),
TENANT_ID,
Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID)
)
)
.build();
}

private String hashString(String input) {
try {
// Create a MessageDigest instance for SHA-256
MessageDigest digest = MessageDigest.getInstance("SHA-256");

// Perform the hashing and get the byte array
byte[] hashBytes = digest.digest(input.getBytes());

// Convert the byte array to a Base64 encoded string
return Base64.getUrlEncoder().encodeToString(hashBytes);

} catch (NoSuchAlgorithmException e) {
throw new RuntimeException("Error: Unable to compute hash", e);
}
}

private void handlePutDataObjectResponse(
String tenantId,
ThreadContext.StoredContext context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ private void deployModel(
mlTaskManager
.updateMLTask(
taskId,
tenantId,
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)),
TASK_SEMAPHORE_TIMEOUT,
true
Expand Down Expand Up @@ -358,6 +359,7 @@ void deployRemoteModel(
ActionListener<MLDeployModelNodesResponse> actionListener = deployModelNodesResponseListener(
mlTask.getTaskId(),
mlModel.getModelId(),
mlModel.getTenantId(),
mlModel.getIsHidden(),
listener
);
Expand Down Expand Up @@ -387,19 +389,21 @@ void deployRemoteModel(
private ActionListener<MLDeployModelNodesResponse> deployModelNodesResponseListener(
String taskId,
String modelId,
String tenantId,
Boolean isHidden,
ActionListener<MLDeployModelResponse> listener
) {
return ActionListener.wrap(r -> {
if (mlTaskManager.contains(taskId)) {
mlTaskManager.updateMLTask(taskId, Map.of(STATE_FIELD, MLTaskState.RUNNING), TASK_SEMAPHORE_TIMEOUT, false);
mlTaskManager.updateMLTask(taskId, tenantId, Map.of(STATE_FIELD, MLTaskState.RUNNING), TASK_SEMAPHORE_TIMEOUT, false);
}
listener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.COMPLETED.name()));
}, e -> {
log.error("Failed to deploy model {}", modelId, e);
mlTaskManager
.updateMLTask(
taskId,
tenantId,
Map.of(MLTask.ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED),
TASK_SEMAPHORE_TIMEOUT,
true
Expand Down Expand Up @@ -434,13 +438,15 @@ void updateModelDeployStatusAndTriggerOnNodesAction(
);
ActionListener<MLDeployModelNodesResponse> actionListener = ActionListener.wrap(r -> {
if (mlTaskManager.contains(taskId)) {
mlTaskManager.updateMLTask(taskId, Map.of(STATE_FIELD, MLTaskState.RUNNING), TASK_SEMAPHORE_TIMEOUT, false);
mlTaskManager
.updateMLTask(taskId, mlModel.getTenantId(), Map.of(STATE_FIELD, MLTaskState.RUNNING), TASK_SEMAPHORE_TIMEOUT, false);
}
}, e -> {
log.error("Failed to deploy model {}", modelId, e);
mlTaskManager
.updateMLTask(
taskId,
mlModel.getTenantId(),
Map.of(MLTask.ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED),
TASK_SEMAPHORE_TIMEOUT,
true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@
@Log4j2
public class TransportForwardAction extends HandledTransportAction<ActionRequest, MLForwardResponse> {
private final ClusterService clusterService;
MLTaskManager mlTaskManager;
Client client;
MLModelManager mlModelManager;
DiscoveryNodeHelper nodeHelper;
final MLTaskManager mlTaskManager;
final Client client;
final MLModelManager mlModelManager;
final DiscoveryNodeHelper nodeHelper;

private final Settings settings;

Expand Down Expand Up @@ -131,7 +131,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
syncModelWorkerNodes(modelId, functionName);
}

if (workNodes == null || workNodes.size() == 0) {
if (workNodes == null || workNodes.isEmpty()) {
int currentWorkerNodeCount = mlTaskCache.getWorkerNodeSize();
MLTaskState taskState = mlTaskCache.hasError() ? MLTaskState.COMPLETED_WITH_ERROR : MLTaskState.COMPLETED;
if (mlTaskCache.allNodeFailed()) {
Expand All @@ -147,7 +147,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
builder.put(MLTask.ERROR_FIELD, toJsonString(mlTaskCache.getErrors()));
}
boolean clearAutoReDeployRetryTimes = triggerNextModelDeployAndCheckIfRestRetryTimes(workNodes, taskId);
mlTaskManager.updateMLTask(taskId, builder.build(), TASK_SEMAPHORE_TIMEOUT, true);
mlTaskManager.updateMLTask(taskId, null, builder.build(), TASK_SEMAPHORE_TIMEOUT, true);

MLModelState modelState;
if (!mlTaskCache.allNodeFailed()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,29 +206,34 @@ private void updateModelGroup(
source.put(MLModelGroup.DESCRIPTION_FIELD, updateModelGroupInput.getDescription());
}
if (StringUtils.isNotBlank(updateModelGroupInput.getName()) && !updateModelGroupInput.getName().equals(modelGroupName)) {
mlModelGroupManager.validateUniqueModelGroupName(updateModelGroupInput.getName(), ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
for (SearchHit documentFields : modelGroups.getHits()) {
String id = documentFields.getId();
listener
.onFailure(
new IllegalArgumentException(
"The name you provided is already being used by another model with ID: "
+ id
+ ". Please provide a different name"
)
);
}
} else {
source.put(MLModelGroup.MODEL_GROUP_NAME_FIELD, updateModelGroupInput.getName());
updateModelGroup(modelGroupId, source, listener);
}
}, e -> {
log.error("Failed to search model group index", e);
listener.onFailure(e);
}));
mlModelGroupManager
.validateUniqueModelGroupName(
updateModelGroupInput.getName(),
updateModelGroupInput.getTenantId(),
ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
for (SearchHit documentFields : modelGroups.getHits()) {
String id = documentFields.getId();
listener
.onFailure(
new IllegalArgumentException(
"The name you provided is already being used by another model with ID: "
+ id
+ ". Please provide a different name"
)
);
}
} else {
source.put(MLModelGroup.MODEL_GROUP_NAME_FIELD, updateModelGroupInput.getName());
updateModelGroup(modelGroupId, source, listener);
}
}, e -> {
log.error("Failed to search model group index", e);
listener.onFailure(e);
})
);
} else {
updateModelGroup(modelGroupId, source, listener);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public class TransportRegisterModelAction extends HandledTransportAction<ActionR

ConnectorAccessControlHelper connectorAccessControlHelper;
MLModelGroupManager mlModelGroupManager;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public TransportRegisterModelAction(
Expand Down Expand Up @@ -170,20 +170,25 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegi
}
registerModelInput.setIsHidden(RestActionUtils.isSuperAdminUser(clusterService, client));
if (StringUtils.isEmpty(registerModelInput.getModelGroupId())) {
mlModelGroupManager.validateUniqueModelGroupName(registerModelInput.getModelName(), ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
String modelGroupIdOfTheNameProvided = modelGroups.getHits().getAt(0).getId();
registerModelInput.setModelGroupId(modelGroupIdOfTheNameProvided);
checkUserAccess(registerModelInput, listener, true);
} else {
doRegister(registerModelInput, listener);
}
}, e -> {
log.error("Failed to search model group index", e);
listener.onFailure(e);
}));
mlModelGroupManager
.validateUniqueModelGroupName(
registerModelInput.getModelName(),
registerModelInput.getTenantId(),
ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
String modelGroupIdOfTheNameProvided = modelGroups.getHits().getAt(0).getId();
registerModelInput.setModelGroupId(modelGroupIdOfTheNameProvided);
checkUserAccess(registerModelInput, listener, true);
} else {
doRegister(registerModelInput, listener);
}
}, e -> {
log.error("Failed to search model group index", e);
listener.onFailure(e);
})
);
} else {
checkUserAccess(registerModelInput, listener, false);
}
Expand Down Expand Up @@ -377,6 +382,7 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen
mlTaskManager
.updateMLTask(
taskId,
registerModelInput.getTenantId(),
ImmutableMap.of(MLTask.ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex), STATE_FIELD, FAILED),
TASK_SEMAPHORE_TIMEOUT,
true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,12 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest syncU
// and all values in this map is false.
Map<String, Boolean> deployToAllNodes = syncUpInput.getDeployToAllNodes();

if (addedWorkerNodes != null && addedWorkerNodes.size() > 0) {
if (addedWorkerNodes != null && !addedWorkerNodes.isEmpty()) {
for (Map.Entry<String, String[]> entry : addedWorkerNodes.entrySet()) {
mlModelManager.addModelWorkerNode(entry.getKey(), entry.getValue());
}
}
if (removedWorkerNodes != null && removedWorkerNodes.size() > 0) {
if (removedWorkerNodes != null && !removedWorkerNodes.isEmpty()) {
for (Map.Entry<String, String[]> entry : removedWorkerNodes.entrySet()) {
mlModelManager
.removeModelWorkerNode(
Expand Down Expand Up @@ -222,6 +222,7 @@ void cleanUpLocalCache(Map<String, Set<String>> runningDeployModelTasks) {
mlTaskManager
.updateMLTask(
taskId,
null,
ImmutableMap
.of(MLTask.STATE_FIELD, MLTaskState.FAILED, MLTask.ERROR_FIELD, "timeout after " + mlTaskTimeout + " seconds"),
10_000,
Expand All @@ -236,7 +237,7 @@ private void cleanUpLocalCacheFiles() {
Path deployModelRootPath = mlEngine.getDeployModelRootPath();
Path modelCacheRootPath = mlEngine.getModelCacheRootPath();
Set<String> modelsInCacheFolder = FileUtils.getFileNames(registerModelRootPath, deployModelRootPath, modelCacheRootPath);
if (modelsInCacheFolder.size() > 0) {
if (!modelsInCacheFolder.isEmpty()) {
log
.debug(
"Found {} models in cache folder: {}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegi
MLRegisterModelMetaInput mlUploadInput = registerModelMetaRequest.getMlRegisterModelMetaInput();

if (StringUtils.isEmpty(mlUploadInput.getModelGroupId())) {
mlModelGroupManager.validateUniqueModelGroupName(mlUploadInput.getName(), ActionListener.wrap(modelGroups -> {
mlModelGroupManager.validateUniqueModelGroupName(mlUploadInput.getName(), null, ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
Expand Down
Loading

0 comments on commit 69a495a

Please sign in to comment.