diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/OpenAIDataIngestion.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/OpenAIDataIngestion.java index 70d4120d2b..8dc94894ef 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/OpenAIDataIngestion.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/OpenAIDataIngestion.java @@ -64,26 +64,42 @@ private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIn connection.setRequestMethod("GET"); connection.setRequestProperty("Authorization", "Bearer " + apiKey); - InputStreamReader inputStreamReader = AccessController - .doPrivileged((PrivilegedExceptionAction) () -> new InputStreamReader(connection.getInputStream())); - BufferedReader reader = new BufferedReader(inputStreamReader); - - List linesBuffer = new ArrayList<>(); - String line; - int lineCount = 0; - // Atomic counters for tracking success and failure - AtomicInteger successfulBatches = new AtomicInteger(0); - AtomicInteger failedBatches = new AtomicInteger(0); - // List of CompletableFutures to track batch ingestion operations - List> futures = new ArrayList<>(); - - while ((line = reader.readLine()) != null) { - linesBuffer.add(line); - lineCount++; - - // Process every 100 lines - if (lineCount == 100) { - // Create a CompletableFuture that will be completed by the bulkResponseListener + try ( + InputStreamReader inputStreamReader = AccessController + .doPrivileged((PrivilegedExceptionAction) () -> new InputStreamReader(connection.getInputStream())); + BufferedReader reader = new BufferedReader(inputStreamReader) + ) { + List linesBuffer = new ArrayList<>(); + String line; + int lineCount = 0; + // Atomic counters for tracking success and failure + AtomicInteger successfulBatches = new AtomicInteger(0); + AtomicInteger failedBatches = new AtomicInteger(0); + // List of CompletableFutures to track batch ingestion operations + List> futures = new ArrayList<>(); + + while ((line = reader.readLine()) != null) { + linesBuffer.add(line); + lineCount++; + + // Process every 100 lines + if (lineCount % 100 == 0) { + // Create a CompletableFuture that will be completed by the bulkResponseListener + CompletableFuture future = new CompletableFuture<>(); + batchIngest( + linesBuffer, + mlBatchIngestionInput, + getBulkResponseListener(successfulBatches, failedBatches, future), + sourceIndex, + isSoleSource + ); + + futures.add(future); + linesBuffer.clear(); + } + } + // Process any remaining lines in the buffer + if (!linesBuffer.isEmpty()) { CompletableFuture future = new CompletableFuture<>(); batchIngest( linesBuffer, @@ -92,32 +108,17 @@ private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIn sourceIndex, isSoleSource ); - futures.add(future); - linesBuffer.clear(); - lineCount = 0; } - } - // Process any remaining lines in the buffer - if (!linesBuffer.isEmpty()) { - CompletableFuture future = new CompletableFuture<>(); - batchIngest( - linesBuffer, - mlBatchIngestionInput, - getBulkResponseListener(successfulBatches, failedBatches, future), - sourceIndex, - isSoleSource - ); - futures.add(future); - } - reader.close(); - // Combine all futures and wait for completion - CompletableFuture allFutures = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])); - // Wait for all tasks to complete - allFutures.join(); - int totalBatches = successfulBatches.get() + failedBatches.get(); - successRate = (double) successfulBatches.get() / totalBatches * 100; + reader.close(); + // Combine all futures and wait for completion + CompletableFuture allFutures = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])); + // Wait for all tasks to complete + allFutures.join(); + int totalBatches = successfulBatches.get() + failedBatches.get(); + successRate = (totalBatches == 0) ? 100 : (double) successfulBatches.get() / totalBatches * 100; + } } catch (PrivilegedActionException e) { throw new RuntimeException("Failed to read from OpenAI file API: ", e); } catch (Exception e) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java index 4306bb1bf3..b6fb3e1226 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java @@ -81,11 +81,11 @@ public double ingestSingleSource( GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(keyName).build(); double successRate = 0; - try { + try ( ResponseInputStream s3is = AccessController .doPrivileged((PrivilegedExceptionAction>) () -> s3.getObject(getObjectRequest)); - BufferedReader reader = new BufferedReader(new InputStreamReader(s3is, StandardCharsets.UTF_8)); - + BufferedReader reader = new BufferedReader(new InputStreamReader(s3is, StandardCharsets.UTF_8)) + ) { List linesBuffer = new ArrayList<>(); String line; int lineCount = 0; @@ -100,7 +100,7 @@ public double ingestSingleSource( lineCount++; // Process every 100 lines - if (lineCount == 100) { + if (lineCount % 100 == 0) { // Create a CompletableFuture that will be completed by the bulkResponseListener CompletableFuture future = new CompletableFuture<>(); batchIngest( @@ -113,7 +113,6 @@ public double ingestSingleSource( futures.add(future); linesBuffer.clear(); - lineCount = 0; } } // Process any remaining lines in the buffer @@ -138,7 +137,7 @@ public double ingestSingleSource( allFutures.join(); int totalBatches = successfulBatches.get() + failedBatches.get(); - successRate = (double) successfulBatches.get() / totalBatches * 100; + successRate = (totalBatches == 0) ? 100 : (double) successfulBatches.get() / totalBatches * 100; } catch (S3Exception e) { log.error("Error reading from S3: " + e.awsErrorDetails().errorMessage()); throw e; diff --git a/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java b/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java index facc27e845..cf03d0f11a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTaskState.COMPLETED; import static org.opensearch.ml.common.MLTaskState.FAILED; +import static org.opensearch.ml.plugin.MachineLearningPlugin.TRAIN_THREAD_POOL; import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; import java.time.Instant; @@ -37,6 +38,7 @@ import org.opensearch.ml.task.MLTaskManager; import org.opensearch.ml.utils.MLExceptionUtils; import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; import lombok.extern.log4j.Log4j2; @@ -50,18 +52,21 @@ public class TransportBatchIngestionAction extends HandledTransportAction { + double successRate = ingestable.ingest(mlBatchIngestionInput); + handleSuccessRate(successRate, taskId); + }); } catch (Exception ex) { log.error("Failed in batch ingestion", ex); mlTaskManager diff --git a/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java index d6d4a99e6f..7b3766dadf 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java @@ -45,6 +45,7 @@ import org.opensearch.ml.task.MLTaskManager; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; public class TransportBatchIngestionActionTests extends OpenSearchTestCase { @@ -62,6 +63,8 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase { private Task task; @Mock ActionListener actionListener; + @Mock + ThreadPool threadPool; private TransportBatchIngestionAction batchAction; private MLBatchIngestionInput batchInput; @@ -69,7 +72,7 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - batchAction = new TransportBatchIngestionAction(transportService, actionFilters, client, mlTaskManager); + batchAction = new TransportBatchIngestionAction(transportService, actionFilters, client, mlTaskManager, threadPool); Map fieldMap = new HashMap<>(); fieldMap.put("input", "$.content");