From 6daa4f2987e275a9651d919f64f940c8d16b98ee Mon Sep 17 00:00:00 2001 From: Michael Froh Date: Fri, 24 Jan 2025 15:25:43 -0800 Subject: [PATCH] Stop processing search requests when _msearch is canceled (#17005) Prior to this fix, the _msearch API would keep running search requests even after being canceled. With this change, we explicitly check if the task has been canceled before kicking off subsequent requests. --------- Signed-off-by: Michael Froh --- CHANGELOG.md | 1 + .../gradle/precommit/JarHellTask.java | 2 +- .../gradle/precommit/ThirdPartyAuditTask.java | 2 +- .../gradle/test/TestClasspathUtils.java | 2 +- .../plugins/InstallPluginCommand.java | 2 +- .../{ => common}/bootstrap/JarHell.java | 2 +- .../bootstrap/JdkJarHellCheck.java | 2 +- .../{ => common}/bootstrap/package-info.java | 2 +- .../{ => common}/bootstrap/JarHellTests.java | 2 +- .../ingest/attachment/TikaImpl.java | 2 +- .../search/TransportMultiSearchAction.java | 24 ++++ .../org/opensearch/bootstrap/Bootstrap.java | 1 + .../org/opensearch/bootstrap/Security.java | 1 + .../org/opensearch/plugins/PluginInfo.java | 2 +- .../opensearch/plugins/PluginsService.java | 2 +- .../TransportMultiSearchActionTests.java | 118 ++++++++++++++++++ .../plugins/PluginsServiceTests.java | 2 +- .../bootstrap/BootstrapForTesting.java | 1 + 18 files changed, 158 insertions(+), 12 deletions(-) rename libs/common/src/main/java/org/opensearch/{ => common}/bootstrap/JarHell.java (99%) rename libs/common/src/main/java/org/opensearch/{ => common}/bootstrap/JdkJarHellCheck.java (98%) rename libs/common/src/main/java/org/opensearch/{ => common}/bootstrap/package-info.java (85%) rename libs/common/src/test/java/org/opensearch/{ => common}/bootstrap/JarHellTests.java (99%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 17036473e054d..ec7742b8563bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -111,6 +111,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Fix Shallow copy snapshot failures on closed index ([#16868](https://github.com/opensearch-project/OpenSearch/pull/16868)) - Fix multi-value sort for unsigned long ([#16732](https://github.com/opensearch-project/OpenSearch/pull/16732)) - The `phone-search` analyzer no longer emits the tel/sip prefix, international calling code, extension numbers and unformatted input as a token ([#16993](https://github.com/opensearch-project/OpenSearch/pull/16993)) +- Stop processing search requests when _msearch request is cancelled ([#17005](https://github.com/opensearch-project/OpenSearch/pull/17005)) - Fix GRPC AUX_TRANSPORT_PORT and SETTING_GRPC_PORT settings and remove lingering HTTP terminology ([#17037](https://github.com/opensearch-project/OpenSearch/pull/17037)) - Fix exists queries on nested flat_object fields throws exception ([#16803](https://github.com/opensearch-project/OpenSearch/pull/16803)) diff --git a/buildSrc/src/main/java/org/opensearch/gradle/precommit/JarHellTask.java b/buildSrc/src/main/java/org/opensearch/gradle/precommit/JarHellTask.java index ebe0b25a3a685..47ad8cc524a3b 100644 --- a/buildSrc/src/main/java/org/opensearch/gradle/precommit/JarHellTask.java +++ b/buildSrc/src/main/java/org/opensearch/gradle/precommit/JarHellTask.java @@ -63,7 +63,7 @@ public JarHellTask(Project project) { public void runJarHellCheck() { LoggedExec.javaexec(project, spec -> { spec.environment("CLASSPATH", getClasspath().getAsPath()); - spec.getMainClass().set("org.opensearch.bootstrap.JarHell"); + spec.getMainClass().set("org.opensearch.common.bootstrap.JarHell"); }); } diff --git a/buildSrc/src/main/java/org/opensearch/gradle/precommit/ThirdPartyAuditTask.java b/buildSrc/src/main/java/org/opensearch/gradle/precommit/ThirdPartyAuditTask.java index 2ed801b7fb9c6..70a1ed478ff63 100644 --- a/buildSrc/src/main/java/org/opensearch/gradle/precommit/ThirdPartyAuditTask.java +++ b/buildSrc/src/main/java/org/opensearch/gradle/precommit/ThirdPartyAuditTask.java @@ -94,7 +94,7 @@ public class ThirdPartyAuditTask extends DefaultTask { CliMain.EXIT_VIOLATION, CliMain.EXIT_UNSUPPORTED_JDK ); - private static final String JDK_JAR_HELL_MAIN_CLASS = "org.opensearch.bootstrap.JdkJarHellCheck"; + private static final String JDK_JAR_HELL_MAIN_CLASS = "org.opensearch.common.bootstrap.JdkJarHellCheck"; private Set missingClassExcludes = new TreeSet<>(); diff --git a/buildSrc/src/testFixtures/java/org/opensearch/gradle/test/TestClasspathUtils.java b/buildSrc/src/testFixtures/java/org/opensearch/gradle/test/TestClasspathUtils.java index ec9a5fb157ccc..84362966d7300 100644 --- a/buildSrc/src/testFixtures/java/org/opensearch/gradle/test/TestClasspathUtils.java +++ b/buildSrc/src/testFixtures/java/org/opensearch/gradle/test/TestClasspathUtils.java @@ -48,7 +48,7 @@ public class TestClasspathUtils { public static void setupJarJdkClasspath(File projectRoot) { try { URL originLocation = TestClasspathUtils.class.getClassLoader() - .loadClass("org.opensearch.bootstrap.JdkJarHellCheck") + .loadClass("org.opensearch.common.bootstrap.JdkJarHellCheck") .getProtectionDomain() .getCodeSource() .getLocation(); diff --git a/distribution/tools/plugin-cli/src/main/java/org/opensearch/plugins/InstallPluginCommand.java b/distribution/tools/plugin-cli/src/main/java/org/opensearch/plugins/InstallPluginCommand.java index 511d6974085aa..d5a0102ba86af 100644 --- a/distribution/tools/plugin-cli/src/main/java/org/opensearch/plugins/InstallPluginCommand.java +++ b/distribution/tools/plugin-cli/src/main/java/org/opensearch/plugins/InstallPluginCommand.java @@ -52,12 +52,12 @@ import org.bouncycastle.openpgp.operator.jcajce.JcaPGPContentVerifierBuilderProvider; import org.opensearch.Build; import org.opensearch.Version; -import org.opensearch.bootstrap.JarHell; import org.opensearch.cli.EnvironmentAwareCommand; import org.opensearch.cli.ExitCodes; import org.opensearch.cli.Terminal; import org.opensearch.cli.UserException; import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.bootstrap.JarHell; import org.opensearch.common.collect.Tuple; import org.opensearch.common.hash.MessageDigests; import org.opensearch.common.util.io.IOUtils; diff --git a/libs/common/src/main/java/org/opensearch/bootstrap/JarHell.java b/libs/common/src/main/java/org/opensearch/common/bootstrap/JarHell.java similarity index 99% rename from libs/common/src/main/java/org/opensearch/bootstrap/JarHell.java rename to libs/common/src/main/java/org/opensearch/common/bootstrap/JarHell.java index fc5e364241d12..470b92aaa2fab 100644 --- a/libs/common/src/main/java/org/opensearch/bootstrap/JarHell.java +++ b/libs/common/src/main/java/org/opensearch/common/bootstrap/JarHell.java @@ -30,7 +30,7 @@ * GitHub history for details. */ -package org.opensearch.bootstrap; +package org.opensearch.common.bootstrap; import org.opensearch.common.SuppressForbidden; import org.opensearch.common.io.PathUtils; diff --git a/libs/common/src/main/java/org/opensearch/bootstrap/JdkJarHellCheck.java b/libs/common/src/main/java/org/opensearch/common/bootstrap/JdkJarHellCheck.java similarity index 98% rename from libs/common/src/main/java/org/opensearch/bootstrap/JdkJarHellCheck.java rename to libs/common/src/main/java/org/opensearch/common/bootstrap/JdkJarHellCheck.java index 97b323975db0a..2a25f32b363c6 100644 --- a/libs/common/src/main/java/org/opensearch/bootstrap/JdkJarHellCheck.java +++ b/libs/common/src/main/java/org/opensearch/common/bootstrap/JdkJarHellCheck.java @@ -29,7 +29,7 @@ * GitHub history for details. */ -package org.opensearch.bootstrap; +package org.opensearch.common.bootstrap; import org.opensearch.common.SuppressForbidden; diff --git a/libs/common/src/main/java/org/opensearch/bootstrap/package-info.java b/libs/common/src/main/java/org/opensearch/common/bootstrap/package-info.java similarity index 85% rename from libs/common/src/main/java/org/opensearch/bootstrap/package-info.java rename to libs/common/src/main/java/org/opensearch/common/bootstrap/package-info.java index f522b1bb91444..8d05b614b7f38 100644 --- a/libs/common/src/main/java/org/opensearch/bootstrap/package-info.java +++ b/libs/common/src/main/java/org/opensearch/common/bootstrap/package-info.java @@ -7,4 +7,4 @@ */ /** Contains JarHell Classes */ -package org.opensearch.bootstrap; +package org.opensearch.common.bootstrap; diff --git a/libs/common/src/test/java/org/opensearch/bootstrap/JarHellTests.java b/libs/common/src/test/java/org/opensearch/common/bootstrap/JarHellTests.java similarity index 99% rename from libs/common/src/test/java/org/opensearch/bootstrap/JarHellTests.java rename to libs/common/src/test/java/org/opensearch/common/bootstrap/JarHellTests.java index d1851850e78e1..549c4bd652e2f 100644 --- a/libs/common/src/test/java/org/opensearch/bootstrap/JarHellTests.java +++ b/libs/common/src/test/java/org/opensearch/common/bootstrap/JarHellTests.java @@ -30,7 +30,7 @@ * GitHub history for details. */ -package org.opensearch.bootstrap; +package org.opensearch.common.bootstrap; import org.opensearch.common.io.PathUtils; import org.opensearch.core.common.Strings; diff --git a/plugins/ingest-attachment/src/main/java/org/opensearch/ingest/attachment/TikaImpl.java b/plugins/ingest-attachment/src/main/java/org/opensearch/ingest/attachment/TikaImpl.java index fe783e5ddb675..d999d20537485 100644 --- a/plugins/ingest-attachment/src/main/java/org/opensearch/ingest/attachment/TikaImpl.java +++ b/plugins/ingest-attachment/src/main/java/org/opensearch/ingest/attachment/TikaImpl.java @@ -41,8 +41,8 @@ import org.apache.tika.parser.ParserDecorator; import org.opensearch.SpecialPermission; import org.opensearch.bootstrap.FilePermissionUtils; -import org.opensearch.bootstrap.JarHell; import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.bootstrap.JarHell; import org.opensearch.common.io.PathUtils; import java.io.ByteArrayInputStream; diff --git a/server/src/main/java/org/opensearch/action/search/TransportMultiSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportMultiSearchAction.java index 146b4010af4b3..dcb2ce6eb88da 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportMultiSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportMultiSearchAction.java @@ -44,6 +44,9 @@ import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.tasks.TaskCancelledException; +import org.opensearch.core.tasks.TaskId; +import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -193,6 +196,19 @@ private void handleResponse(final int responseSlot, final MultiSearchResponse.It if (responseCounter.decrementAndGet() == 0) { assert requests.isEmpty(); finish(); + } else if (isCancelled(request.request.getParentTask())) { + // Drain the rest of the queue + SearchRequestSlot request; + while ((request = requests.poll()) != null) { + responses.set( + request.responseSlot, + new MultiSearchResponse.Item(null, new TaskCancelledException("Parent task was cancelled")) + ); + if (responseCounter.decrementAndGet() == 0) { + assert requests.isEmpty(); + finish(); + } + } } else { if (thread == Thread.currentThread()) { // we are on the same thread, we need to fork to another thread to avoid recursive stack overflow on a single thread @@ -220,6 +236,14 @@ private long buildTookInMillis() { }); } + private boolean isCancelled(TaskId taskId) { + if (taskId.isSet()) { + CancellableTask task = taskManager.getCancellableTask(taskId.getId()); + return task != null && task.isCancelled(); + } + return false; + } + /** * Slots a search request * diff --git a/server/src/main/java/org/opensearch/bootstrap/Bootstrap.java b/server/src/main/java/org/opensearch/bootstrap/Bootstrap.java index 4e167d10b99fa..95498f2bcbcd1 100644 --- a/server/src/main/java/org/opensearch/bootstrap/Bootstrap.java +++ b/server/src/main/java/org/opensearch/bootstrap/Bootstrap.java @@ -47,6 +47,7 @@ import org.opensearch.cli.UserException; import org.opensearch.common.PidFile; import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.bootstrap.JarHell; import org.opensearch.common.inject.CreationException; import org.opensearch.common.logging.LogConfigurator; import org.opensearch.common.logging.Loggers; diff --git a/server/src/main/java/org/opensearch/bootstrap/Security.java b/server/src/main/java/org/opensearch/bootstrap/Security.java index 563a026109059..acf2d7ec6a5ac 100644 --- a/server/src/main/java/org/opensearch/bootstrap/Security.java +++ b/server/src/main/java/org/opensearch/bootstrap/Security.java @@ -34,6 +34,7 @@ import org.opensearch.cli.Command; import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.bootstrap.JarHell; import org.opensearch.common.io.PathUtils; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; diff --git a/server/src/main/java/org/opensearch/plugins/PluginInfo.java b/server/src/main/java/org/opensearch/plugins/PluginInfo.java index 4ff699e8017ba..323e061aea567 100644 --- a/server/src/main/java/org/opensearch/plugins/PluginInfo.java +++ b/server/src/main/java/org/opensearch/plugins/PluginInfo.java @@ -36,8 +36,8 @@ import com.fasterxml.jackson.core.json.JsonReadFeature; import org.opensearch.Version; -import org.opensearch.bootstrap.JarHell; import org.opensearch.common.annotation.PublicApi; +import org.opensearch.common.bootstrap.JarHell; import org.opensearch.common.xcontent.json.JsonXContentParser; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; diff --git a/server/src/main/java/org/opensearch/plugins/PluginsService.java b/server/src/main/java/org/opensearch/plugins/PluginsService.java index 9bc1f1334122e..72b8ada94a0d1 100644 --- a/server/src/main/java/org/opensearch/plugins/PluginsService.java +++ b/server/src/main/java/org/opensearch/plugins/PluginsService.java @@ -43,7 +43,7 @@ import org.opensearch.OpenSearchException; import org.opensearch.Version; import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; -import org.opensearch.bootstrap.JarHell; +import org.opensearch.common.bootstrap.JarHell; import org.opensearch.common.collect.Tuple; import org.opensearch.common.inject.Module; import org.opensearch.common.lifecycle.LifecycleComponent; diff --git a/server/src/test/java/org/opensearch/action/search/TransportMultiSearchActionTests.java b/server/src/test/java/org/opensearch/action/search/TransportMultiSearchActionTests.java index 48970e2b96add..45980e7137ce4 100644 --- a/server/src/test/java/org/opensearch/action/search/TransportMultiSearchActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/TransportMultiSearchActionTests.java @@ -49,7 +49,9 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskListener; import org.opensearch.tasks.TaskManager; import org.opensearch.telemetry.tracing.noop.NoopTracer; import org.opensearch.test.OpenSearchTestCase; @@ -62,7 +64,9 @@ import java.util.IdentityHashMap; import java.util.List; import java.util.Set; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -289,4 +293,118 @@ public void testDefaultMaxConcurrentSearches() { assertThat(result, equalTo(1)); } + public void testCancellation() { + // Initialize dependencies of TransportMultiSearchAction + Settings settings = Settings.builder().put("node.name", TransportMultiSearchActionTests.class.getSimpleName()).build(); + ActionFilters actionFilters = mock(ActionFilters.class); + when(actionFilters.filters()).thenReturn(new ActionFilter[0]); + ThreadPool threadPool = new ThreadPool(settings); + TransportService transportService = new TransportService( + Settings.EMPTY, + mock(Transport.class), + threadPool, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + boundAddress -> DiscoveryNode.createLocal(settings, boundAddress.publishAddress(), UUIDs.randomBase64UUID()), + null, + Collections.emptySet(), + NoopTracer.INSTANCE + ) { + @Override + public TaskManager getTaskManager() { + return taskManager; + } + }; + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("test")).build()); + + // Keep track of the number of concurrent searches started by multi search api, + // and if there are more searches than is allowed create an error and remember that. + int maxAllowedConcurrentSearches = 1; // Allow 1 search at a time. + AtomicInteger counter = new AtomicInteger(); + AtomicReference errorHolder = new AtomicReference<>(); + // randomize whether or not requests are executed asynchronously + ExecutorService executorService = threadPool.executor(ThreadPool.Names.GENERIC); + final Set requests = Collections.newSetFromMap(Collections.synchronizedMap(new IdentityHashMap<>())); + CountDownLatch countDownLatch = new CountDownLatch(1); + CancellableTask[] parentTask = new CancellableTask[1]; + NodeClient client = new NodeClient(settings, threadPool) { + @Override + public void search(final SearchRequest request, final ActionListener listener) { + if (parentTask[0] != null && parentTask[0].isCancelled()) { + fail("Should not execute search after parent task is cancelled"); + } + try { + countDownLatch.await(10, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + + requests.add(request); + executorService.execute(() -> { + counter.decrementAndGet(); + listener.onResponse( + new SearchResponse( + InternalSearchResponse.empty(), + null, + 0, + 0, + 0, + 0L, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ) + ); + }); + } + + @Override + public String getLocalNodeId() { + return "local_node_id"; + } + }; + + TransportMultiSearchAction action = new TransportMultiSearchAction( + threadPool, + actionFilters, + transportService, + clusterService, + 10, + System::nanoTime, + client + ); + + // Execute the multi search api and fail if we find an error after executing: + try { + /* + * Allow for a large number of search requests in a single batch as previous implementations could stack overflow if the number + * of requests in a single batch was large + */ + int numSearchRequests = scaledRandomIntBetween(1024, 8192); + MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); + multiSearchRequest.maxConcurrentSearchRequests(maxAllowedConcurrentSearches); + for (int i = 0; i < numSearchRequests; i++) { + multiSearchRequest.add(new SearchRequest()); + } + MultiSearchResponse[] responses = new MultiSearchResponse[1]; + Exception[] exceptions = new Exception[1]; + parentTask[0] = (CancellableTask) action.execute(multiSearchRequest, new TaskListener<>() { + @Override + public void onResponse(Task task, MultiSearchResponse items) { + responses[0] = items; + } + + @Override + public void onFailure(Task task, Exception e) { + exceptions[0] = e; + } + }); + parentTask[0].cancel("Giving up"); + countDownLatch.countDown(); + + assertNull(responses[0]); + assertNull(exceptions[0]); + } finally { + assertTrue(OpenSearchTestCase.terminate(threadPool)); + } + } } diff --git a/server/src/test/java/org/opensearch/plugins/PluginsServiceTests.java b/server/src/test/java/org/opensearch/plugins/PluginsServiceTests.java index f5702fa1a7ade..cb549eafc0d21 100644 --- a/server/src/test/java/org/opensearch/plugins/PluginsServiceTests.java +++ b/server/src/test/java/org/opensearch/plugins/PluginsServiceTests.java @@ -38,7 +38,7 @@ import org.apache.lucene.util.Constants; import org.opensearch.LegacyESVersion; import org.opensearch.Version; -import org.opensearch.bootstrap.JarHell; +import org.opensearch.common.bootstrap.JarHell; import org.opensearch.common.collect.Tuple; import org.opensearch.common.io.PathUtils; import org.opensearch.common.settings.Settings; diff --git a/test/framework/src/main/java/org/opensearch/bootstrap/BootstrapForTesting.java b/test/framework/src/main/java/org/opensearch/bootstrap/BootstrapForTesting.java index 933385dedcf49..76c7ce0628aac 100644 --- a/test/framework/src/main/java/org/opensearch/bootstrap/BootstrapForTesting.java +++ b/test/framework/src/main/java/org/opensearch/bootstrap/BootstrapForTesting.java @@ -39,6 +39,7 @@ import org.apache.lucene.tests.util.LuceneTestCase; import org.opensearch.common.Booleans; import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.bootstrap.JarHell; import org.opensearch.common.io.PathUtils; import org.opensearch.common.network.IfConfig; import org.opensearch.common.network.NetworkAddress;