Skip to content

Commit

Permalink
[ML] Download and write model parts using multiple streams (elastic…
Browse files Browse the repository at this point in the history
…#111684) (elastic#112859)

Uses the range header to split the model download into multiple streams
using a separate thread for each stream
  • Loading branch information
davidkyle authored Sep 13, 2024
1 parent 53ff0ac commit 4fe2851
Show file tree
Hide file tree
Showing 11 changed files with 847 additions and 166 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/111684.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 111684
summary: Write downloaded model parts async
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.xpack.core.ml.packageloader.action.GetTrainedModelPackageConfigAction;
import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction;
import org.elasticsearch.xpack.ml.packageloader.action.ModelDownloadTask;
import org.elasticsearch.xpack.ml.packageloader.action.ModelImporter;
import org.elasticsearch.xpack.ml.packageloader.action.TransportGetTrainedModelPackageConfigAction;
import org.elasticsearch.xpack.ml.packageloader.action.TransportLoadTrainedModelPackage;

Expand All @@ -44,16 +49,15 @@ public class MachineLearningPackageLoader extends Plugin implements ActionPlugin
Setting.Property.Dynamic
);

// re-using thread pool setup by the ml plugin
public static final String UTILITY_THREAD_POOL_NAME = "ml_utility";

// This link will be invalid for serverless, but serverless will never be
// air-gapped, so this message should never be needed.
private static final String MODEL_REPOSITORY_DOCUMENTATION_LINK = format(
"https://www.elastic.co/guide/en/machine-learning/%s/ml-nlp-elser.html#air-gapped-install",
Build.current().version().replaceFirst("^(\\d+\\.\\d+).*", "$1")
);

public static final String MODEL_DOWNLOAD_THREADPOOL_NAME = "model_download";

public MachineLearningPackageLoader() {}

@Override
Expand Down Expand Up @@ -81,6 +85,24 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
);
}

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
    // Register the dedicated model-download thread pool with the node's thread pool.
    var downloadExecutor = modelDownloadExecutor(settings);
    return List.of(downloadExecutor);
}

/**
 * Builds the fixed-size executor used to download model definition files,
 * with one thread per download stream (see {@code ModelImporter.NUMBER_OF_STREAMS}).
 *
 * @param settings node settings passed through to the executor builder
 * @return the builder for the {@code model_download} thread pool
 */
public static FixedExecutorBuilder modelDownloadExecutor(Settings settings) {
    final int fixedThreadCount = ModelImporter.NUMBER_OF_STREAMS;
    final int unboundedQueue = -1; // negative size means an unbounded queue
    return new FixedExecutorBuilder(
        settings,
        MODEL_DOWNLOAD_THREADPOOL_NAME,
        fixedThreadCount,
        unboundedQueue,
        "xpack.ml.model_download_thread_pool",
        EsExecutors.TaskTrackingConfig.DO_NOT_TRACK
    );
}

@Override
public List<BootstrapCheck> getBootstrapChecks() {
return List.of(new BootstrapCheck() {
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentParser;
Expand All @@ -34,16 +35,20 @@
import java.security.AccessController;
import java.security.MessageDigest;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;

import static java.net.HttpURLConnection.HTTP_MOVED_PERM;
import static java.net.HttpURLConnection.HTTP_MOVED_TEMP;
import static java.net.HttpURLConnection.HTTP_NOT_FOUND;
import static java.net.HttpURLConnection.HTTP_OK;
import static java.net.HttpURLConnection.HTTP_PARTIAL;
import static java.net.HttpURLConnection.HTTP_SEE_OTHER;

/**
Expand All @@ -61,6 +66,73 @@ final class ModelLoaderUtils {

record VocabularyParts(List<String> vocab, List<String> merges, List<Double> scores) {}

// An inclusive byte range of the model download, plus the model parts it covers.
record RequestRange(long rangeStart, long rangeEnd, int startPart, int numParts) {
    /** Formats this range as an HTTP {@code Range} header value (inclusive bounds). */
    public String bytesRange() {
        StringBuilder header = new StringBuilder("bytes=");
        header.append(rangeStart).append('-').append(rangeEnd);
        return header.toString();
    }
}

/**
 * Reads an HTTP(S) response stream (optionally a byte-range request) and
 * slices it into fixed-size chunks, tracking which model part number each
 * chunk corresponds to. Counters use atomics so progress can be observed
 * from another thread while the download runs.
 */
static class HttpStreamChunker {

    // A chunk of bytes paired with the model part index it belongs to.
    record BytesAndPartIndex(BytesArray bytes, int partIndex) {}

    private final InputStream inputStream;
    private final int chunkSize; // bytes per model part
    private final AtomicLong totalBytesRead = new AtomicLong();
    private final AtomicInteger currentPart; // next part number to emit
    private final int lastPartNumber; // exclusive upper bound on part numbers

    HttpStreamChunker(URI uri, RequestRange range, int chunkSize) {
        // Opens the connection eagerly; the Range header limits the response
        // to this chunker's slice of the file.
        var inputStream = getHttpOrHttpsInputStream(uri, range);
        this.inputStream = inputStream;
        this.chunkSize = chunkSize;
        this.lastPartNumber = range.startPart() + range.numParts();
        this.currentPart = new AtomicInteger(range.startPart());
    }

    // This ctor exists for testing purposes only.
    HttpStreamChunker(InputStream inputStream, RequestRange range, int chunkSize) {
        this.inputStream = inputStream;
        this.chunkSize = chunkSize;
        this.lastPartNumber = range.startPart() + range.numParts();
        this.currentPart = new AtomicInteger(range.startPart());
    }

    /** True while parts remain to be read from this range. */
    public boolean hasNext() {
        return currentPart.get() < lastPartNumber;
    }

    /**
     * Reads the next chunk of up to {@code chunkSize} bytes and advances the
     * part counter. The final chunk of the stream may be shorter than
     * {@code chunkSize}.
     */
    public BytesAndPartIndex next() throws IOException {
        int bytesRead = 0;
        byte[] buf = new byte[chunkSize];

        // A single read() may return fewer bytes than requested, so loop
        // until the buffer is full or the stream is exhausted.
        while (bytesRead < chunkSize) {
            int read = inputStream.read(buf, bytesRead, chunkSize - bytesRead);
            // EOF??
            if (read == -1) {
                break;
            }
            bytesRead += read;
        }

        if (bytesRead > 0) {
            totalBytesRead.addAndGet(bytesRead);
            return new BytesAndPartIndex(new BytesArray(buf, 0, bytesRead), currentPart.getAndIncrement());
        } else {
            // NOTE(review): on immediate EOF no bytes are read and currentPart is
            // not advanced, so hasNext() stays true for a truncated stream —
            // confirm callers treat an empty BytesAndPartIndex as a stop signal.
            return new BytesAndPartIndex(BytesArray.EMPTY, currentPart.get());
        }
    }

    /** Total bytes consumed from the stream so far. */
    public long getTotalBytesRead() {
        return totalBytesRead.get();
    }

    /** The part number the next successful read will be assigned. */
    public int getCurrentPart() {
        return currentPart.get();
    }
}

static class InputStreamChunker {

private final InputStream inputStream;
Expand Down Expand Up @@ -101,21 +173,26 @@ public int getTotalBytesRead() {
}
}

static InputStream getInputStreamFromModelRepository(URI uri) throws IOException {
static InputStream getInputStreamFromModelRepository(URI uri) {
String scheme = uri.getScheme().toLowerCase(Locale.ROOT);

// if you add a scheme here, also add it to the bootstrap check in {@link MachineLearningPackageLoader#validateModelRepository}
switch (scheme) {
case "http":
case "https":
return getHttpOrHttpsInputStream(uri);
return getHttpOrHttpsInputStream(uri, null);
case "file":
return getFileInputStream(uri);
default:
throw new IllegalArgumentException("unsupported scheme");
}
}

/**
 * @return true when {@code uri} uses the {@code file} scheme (case-insensitive)
 */
static boolean uriIsFile(URI uri) {
    return "file".equals(uri.getScheme().toLowerCase(Locale.ROOT));
}

static VocabularyParts loadVocabulary(URI uri) {
if (uri.getPath().endsWith(".json")) {
try (InputStream vocabInputStream = getInputStreamFromModelRepository(uri)) {
Expand Down Expand Up @@ -174,7 +251,7 @@ private ModelLoaderUtils() {}

@SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ")
@SuppressForbidden(reason = "we need socket connection to download")
private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException {
private static InputStream getHttpOrHttpsInputStream(URI uri, @Nullable RequestRange range) {

assert uri.getUserInfo() == null : "URI's with credentials are not supported";

Expand All @@ -186,18 +263,30 @@ private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException
PrivilegedAction<InputStream> privilegedHttpReader = () -> {
try {
HttpURLConnection conn = (HttpURLConnection) uri.toURL().openConnection();
if (range != null) {
conn.setRequestProperty("Range", range.bytesRange());
}
switch (conn.getResponseCode()) {
case HTTP_OK:
case HTTP_PARTIAL:
return conn.getInputStream();

case HTTP_MOVED_PERM:
case HTTP_MOVED_TEMP:
case HTTP_SEE_OTHER:
throw new IllegalStateException("redirects aren't supported yet");
case HTTP_NOT_FOUND:
throw new ResourceNotFoundException("{} not found", uri);
case 416: // Range not satisfiable, for some reason not in the list of constants
throw new IllegalStateException("Invalid request range [" + range.bytesRange() + "]");
default:
int responseCode = conn.getResponseCode();
throw new ElasticsearchStatusException("error during downloading {}", RestStatus.fromCode(responseCode), uri);
throw new ElasticsearchStatusException(
"error during downloading {}. Got response code {}",
RestStatus.fromCode(responseCode),
uri,
responseCode
);
}
} catch (IOException e) {
throw new UncheckedIOException(e);
Expand All @@ -209,7 +298,7 @@ private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException

@SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ")
@SuppressForbidden(reason = "we need load model data from a file")
private static InputStream getFileInputStream(URI uri) {
static InputStream getFileInputStream(URI uri) {

SecurityManager sm = System.getSecurityManager();
if (sm != null) {
Expand All @@ -232,4 +321,53 @@ private static InputStream getFileInputStream(URI uri) {
return AccessController.doPrivileged(privilegedFileReader);
}

/**
 * Split a stream of size {@code sizeInBytes} into {@code numberOfStreams} +1
 * ranges aligned on {@code chunkSizeBytes} boundaries. Each range contains a
 * whole number of chunks.
 * The first {@code numberOfStreams} ranges will be split evenly (in terms of
 * number of chunks not the byte size), the final range split
 * is for the single final chunk and will be no more than {@code chunkSizeBytes}
 * in size. The separate range for the final chunk is because when streaming and
 * uploading a large model definition, writing the last part has to be handled
 * as a special case.
 * @param sizeInBytes The total size of the stream
 * @param numberOfStreams Divide the bulk of the size into this many streams.
 * @param chunkSizeBytes The size of each chunk
 * @return List of {@code numberOfStreams} + 1 ranges.
 */
static List<RequestRange> split(long sizeInBytes, int numberOfStreams, long chunkSizeBytes) {
    // Ceiling division: a partial trailing chunk still counts as a chunk.
    int numberOfChunks = (int) ((sizeInBytes + chunkSizeBytes - 1) / chunkSizeBytes);

    var ranges = new ArrayList<RequestRange>();

    // Distribute chunks evenly; the first `remainder` streams take one extra chunk.
    // NOTE(review): if numberOfChunks < numberOfStreams, baseChunksPerStream is 0 and
    // some ranges will cover zero chunks — presumably callers size streams so this
    // does not happen; confirm at the call site.
    int baseChunksPerStream = numberOfChunks / numberOfStreams;
    int remainder = numberOfChunks % numberOfStreams;
    long startOffset = 0;
    int startChunkIndex = 0;

    // All streams except the last: whole-chunk-aligned, contiguous ranges.
    for (int i = 0; i < numberOfStreams - 1; i++) {
        int numChunksInStream = (i < remainder) ? baseChunksPerStream + 1 : baseChunksPerStream;
        long rangeEnd = startOffset + (numChunksInStream * chunkSizeBytes) - 1; // range index is 0 based
        ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksInStream));
        startOffset = rangeEnd + 1; // range is inclusive start and end
        startChunkIndex += numChunksInStream;
    }

    // Want the final range request to be a single chunk
    if (baseChunksPerStream > 1) {
        // Emit the last stream's chunks minus one, so exactly one chunk remains
        // for the dedicated final range below.
        int numChunksExcludingFinal = baseChunksPerStream - 1;
        long rangeEnd = startOffset + (numChunksExcludingFinal * chunkSizeBytes) - 1;
        ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksExcludingFinal));

        startOffset = rangeEnd + 1;
        startChunkIndex += numChunksExcludingFinal;
    }

    // The final range is a single chunk the end of which should not exceed sizeInBytes
    long rangeEnd = Math.min(sizeInBytes, startOffset + (baseChunksPerStream * chunkSizeBytes)) - 1;
    ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, 1));

    return ranges;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ protected void masterOperation(Task task, Request request, ClusterState state, A
String packagedModelId = request.getPackagedModelId();
logger.debug(() -> format("Fetch package manifest for [%s] from [%s]", packagedModelId, repository));

threadPool.executor(MachineLearningPackageLoader.UTILITY_THREAD_POOL_NAME).execute(() -> {
threadPool.executor(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME).execute(() -> {
try {
URI uri = ModelLoaderUtils.resolvePackageLocation(repository, packagedModelId + ModelLoaderUtils.METADATA_FILE_EXTENSION);
InputStream inputStream = ModelLoaderUtils.getInputStreamFromModelRepository(uri);
Expand Down
Loading

0 comments on commit 4fe2851

Please sign in to comment.