Skip to content

Commit

Permalink
use dedicated thread pool for ingestion
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Sep 4, 2024
1 parent a2dabaa commit 993023b
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,26 +64,42 @@ private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIn
connection.setRequestMethod("GET");
connection.setRequestProperty("Authorization", "Bearer " + apiKey);

InputStreamReader inputStreamReader = AccessController
.doPrivileged((PrivilegedExceptionAction<InputStreamReader>) () -> new InputStreamReader(connection.getInputStream()));
BufferedReader reader = new BufferedReader(inputStreamReader);

List<String> linesBuffer = new ArrayList<>();
String line;
int lineCount = 0;
// Atomic counters for tracking success and failure
AtomicInteger successfulBatches = new AtomicInteger(0);
AtomicInteger failedBatches = new AtomicInteger(0);
// List of CompletableFutures to track batch ingestion operations
List<CompletableFuture<Void>> futures = new ArrayList<>();

while ((line = reader.readLine()) != null) {
linesBuffer.add(line);
lineCount++;

// Process every 100 lines
if (lineCount == 100) {
// Create a CompletableFuture that will be completed by the bulkResponseListener
try (
InputStreamReader inputStreamReader = AccessController
.doPrivileged((PrivilegedExceptionAction<InputStreamReader>) () -> new InputStreamReader(connection.getInputStream()));
BufferedReader reader = new BufferedReader(inputStreamReader)
) {
List<String> linesBuffer = new ArrayList<>();
String line;
int lineCount = 0;
// Atomic counters for tracking success and failure
AtomicInteger successfulBatches = new AtomicInteger(0);
AtomicInteger failedBatches = new AtomicInteger(0);
// List of CompletableFutures to track batch ingestion operations
List<CompletableFuture<Void>> futures = new ArrayList<>();

while ((line = reader.readLine()) != null) {
linesBuffer.add(line);
lineCount++;

// Process every 100 lines
if (lineCount % 100 == 0) {
// Create a CompletableFuture that will be completed by the bulkResponseListener
CompletableFuture<Void> future = new CompletableFuture<>();
batchIngest(
linesBuffer,
mlBatchIngestionInput,
getBulkResponseListener(successfulBatches, failedBatches, future),
sourceIndex,
isSoleSource
);

futures.add(future);
linesBuffer.clear();
}
}
// Process any remaining lines in the buffer
if (!linesBuffer.isEmpty()) {
CompletableFuture<Void> future = new CompletableFuture<>();
batchIngest(
linesBuffer,
Expand All @@ -92,32 +108,17 @@ private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIn
sourceIndex,
isSoleSource
);

futures.add(future);
linesBuffer.clear();
lineCount = 0;
}
}
// Process any remaining lines in the buffer
if (!linesBuffer.isEmpty()) {
CompletableFuture<Void> future = new CompletableFuture<>();
batchIngest(
linesBuffer,
mlBatchIngestionInput,
getBulkResponseListener(successfulBatches, failedBatches, future),
sourceIndex,
isSoleSource
);
futures.add(future);
}

reader.close();
// Combine all futures and wait for completion
CompletableFuture<Void> allFutures = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
// Wait for all tasks to complete
allFutures.join();
int totalBatches = successfulBatches.get() + failedBatches.get();
successRate = (double) successfulBatches.get() / totalBatches * 100;
reader.close();
// Combine all futures and wait for completion
CompletableFuture<Void> allFutures = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
// Wait for all tasks to complete
allFutures.join();
int totalBatches = successfulBatches.get() + failedBatches.get();
successRate = (totalBatches == 0) ? 100 : (double) successfulBatches.get() / totalBatches * 100;
}
} catch (PrivilegedActionException e) {
throw new RuntimeException("Failed to read from OpenAI file API: ", e);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ public double ingestSingleSource(
GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(keyName).build();
double successRate = 0;

try {
try (
ResponseInputStream<GetObjectResponse> s3is = AccessController
.doPrivileged((PrivilegedExceptionAction<ResponseInputStream<GetObjectResponse>>) () -> s3.getObject(getObjectRequest));
BufferedReader reader = new BufferedReader(new InputStreamReader(s3is, StandardCharsets.UTF_8));

BufferedReader reader = new BufferedReader(new InputStreamReader(s3is, StandardCharsets.UTF_8))
) {
List<String> linesBuffer = new ArrayList<>();
String line;
int lineCount = 0;
Expand All @@ -100,7 +100,7 @@ public double ingestSingleSource(
lineCount++;

// Process every 100 lines
if (lineCount == 100) {
if (lineCount % 100 == 0) {
// Create a CompletableFuture that will be completed by the bulkResponseListener
CompletableFuture<Void> future = new CompletableFuture<>();
batchIngest(
Expand All @@ -113,7 +113,6 @@ public double ingestSingleSource(

futures.add(future);
linesBuffer.clear();
lineCount = 0;
}
}
// Process any remaining lines in the buffer
Expand All @@ -138,7 +137,7 @@ public double ingestSingleSource(
allFutures.join();

int totalBatches = successfulBatches.get() + failedBatches.get();
successRate = (double) successfulBatches.get() / totalBatches * 100;
successRate = (totalBatches == 0) ? 100 : (double) successfulBatches.get() / totalBatches * 100;
} catch (S3Exception e) {
log.error("Error reading from S3: " + e.awsErrorDetails().errorMessage());
throw e;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.plugin.MachineLearningPlugin.TRAIN_THREAD_POOL;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;

import java.time.Instant;
Expand Down Expand Up @@ -37,6 +38,7 @@
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;
Expand All @@ -50,18 +52,21 @@ public class TransportBatchIngestionAction extends HandledTransportAction<Action
TransportService transportService;
MLTaskManager mlTaskManager;
private final Client client;
private ThreadPool threadPool;

@Inject
public TransportBatchIngestionAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
MLTaskManager mlTaskManager
MLTaskManager mlTaskManager,
ThreadPool threadPool
) {
super(MLBatchIngestionAction.NAME, transportService, actionFilters, MLBatchIngestionRequest::new);
this.transportService = transportService;
this.client = client;
this.mlTaskManager = mlTaskManager;
this.threadPool = threadPool;
}

@Override
Expand All @@ -87,8 +92,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
double successRate = ingestable.ingest(mlBatchIngestionInput);
handleSuccessRate(successRate, taskId);
threadPool.executor(TRAIN_THREAD_POOL).execute(() -> {
double successRate = ingestable.ingest(mlBatchIngestionInput);
handleSuccessRate(successRate, taskId);
});
} catch (Exception ex) {
log.error("Failed in batch ingestion", ex);
mlTaskManager
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
Expand All @@ -62,14 +63,16 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
private Task task;
@Mock
ActionListener<MLBatchIngestionResponse> actionListener;
@Mock
ThreadPool threadPool;

private TransportBatchIngestionAction batchAction;
private MLBatchIngestionInput batchInput;

@Before
public void setup() {
MockitoAnnotations.openMocks(this);
batchAction = new TransportBatchIngestionAction(transportService, actionFilters, client, mlTaskManager);
batchAction = new TransportBatchIngestionAction(transportService, actionFilters, client, mlTaskManager, threadPool);

Map<String, Object> fieldMap = new HashMap<>();
fieldMap.put("input", "$.content");
Expand Down

0 comments on commit 993023b

Please sign in to comment.