Skip to content

Commit

Permalink
Check for conflciting inference models before starting trained model …
Browse files Browse the repository at this point in the history
…deployment
  • Loading branch information
maxhniebergall committed Dec 21, 2023
1 parent 509dfbd commit dfa97fb
Showing 1 changed file with 20 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
Expand All @@ -45,6 +46,7 @@
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
Expand Down Expand Up @@ -292,8 +294,24 @@ protected void masterOperation(

}, listener::onFailure);

GetTrainedModelsAction.Request getModelRequest = new GetTrainedModelsAction.Request(request.getModelId());
client.execute(GetTrainedModelsAction.INSTANCE, getModelRequest, getModelListener);
ActionListener<GetInferenceModelAction.Response> getInferenceModelListener = ActionListener.wrap((getInferenceModelResponse) -> {
if (getInferenceModelResponse.getModels().isEmpty() == false) {
listener.onFailure(
ExceptionsHelper.badRequestException(
"Model IDs must be unique. Requested model ID [{}] matches existing model IDs [{}], but must not.",
request.getModelId(),
getInferenceModelResponse.getModels()
)
);
return;
} else {
GetTrainedModelsAction.Request getModelRequest = new GetTrainedModelsAction.Request(request.getModelId());
client.execute(GetTrainedModelsAction.INSTANCE, getModelRequest, getModelListener);
}
}, listener::onFailure);

GetInferenceModelAction.Request getModelRequest = new GetInferenceModelAction.Request(request.getModelId(), TaskType.ANY);
client.execute(GetInferenceModelAction.INSTANCE, getModelRequest, getInferenceModelListener);
}

private void waitForDeploymentState(
Expand Down

0 comments on commit dfa97fb

Please sign in to comment.