Skip to content

Commit

Permalink
Stop processing search requests when _msearch is canceled (opensearch…
Browse files Browse the repository at this point in the history
…-project#17005)

Prior to this fix, the _msearch API would keep running search requests
even after being canceled. With this change, we explicitly check if
the task has been canceled before kicking off subsequent requests.

---------

Signed-off-by: Michael Froh <[email protected]>
  • Loading branch information
msfroh authored and prudhvigodithi committed Jan 25, 2025
1 parent 0d7ac2c commit 6daa4f2
Show file tree
Hide file tree
Showing 18 changed files with 158 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix Shallow copy snapshot failures on closed index ([#16868](https://github.com/opensearch-project/OpenSearch/pull/16868))
- Fix multi-value sort for unsigned long ([#16732](https://github.com/opensearch-project/OpenSearch/pull/16732))
- The `phone-search` analyzer no longer emits the tel/sip prefix, international calling code, extension numbers and unformatted input as a token ([#16993](https://github.com/opensearch-project/OpenSearch/pull/16993))
- Stop processing search requests when _msearch request is cancelled ([#17005](https://github.com/opensearch-project/OpenSearch/pull/17005))
- Fix GRPC AUX_TRANSPORT_PORT and SETTING_GRPC_PORT settings and remove lingering HTTP terminology ([#17037](https://github.com/opensearch-project/OpenSearch/pull/17037))
- Fix exists queries on nested flat_object fields throws exception ([#16803](https://github.com/opensearch-project/OpenSearch/pull/16803))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public JarHellTask(Project project) {
public void runJarHellCheck() {
LoggedExec.javaexec(project, spec -> {
spec.environment("CLASSPATH", getClasspath().getAsPath());
spec.getMainClass().set("org.opensearch.bootstrap.JarHell");
spec.getMainClass().set("org.opensearch.common.bootstrap.JarHell");
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public class ThirdPartyAuditTask extends DefaultTask {
CliMain.EXIT_VIOLATION,
CliMain.EXIT_UNSUPPORTED_JDK
);
private static final String JDK_JAR_HELL_MAIN_CLASS = "org.opensearch.bootstrap.JdkJarHellCheck";
private static final String JDK_JAR_HELL_MAIN_CLASS = "org.opensearch.common.bootstrap.JdkJarHellCheck";

private Set<String> missingClassExcludes = new TreeSet<>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public class TestClasspathUtils {
public static void setupJarJdkClasspath(File projectRoot) {
try {
URL originLocation = TestClasspathUtils.class.getClassLoader()
.loadClass("org.opensearch.bootstrap.JdkJarHellCheck")
.loadClass("org.opensearch.common.bootstrap.JdkJarHellCheck")
.getProtectionDomain()
.getCodeSource()
.getLocation();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@
import org.bouncycastle.openpgp.operator.jcajce.JcaPGPContentVerifierBuilderProvider;
import org.opensearch.Build;
import org.opensearch.Version;
import org.opensearch.bootstrap.JarHell;
import org.opensearch.cli.EnvironmentAwareCommand;
import org.opensearch.cli.ExitCodes;
import org.opensearch.cli.Terminal;
import org.opensearch.cli.UserException;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.bootstrap.JarHell;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.hash.MessageDigests;
import org.opensearch.common.util.io.IOUtils;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
* GitHub history for details.
*/

package org.opensearch.bootstrap;
package org.opensearch.common.bootstrap;

import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.io.PathUtils;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
* GitHub history for details.
*/

package org.opensearch.bootstrap;
package org.opensearch.common.bootstrap;

import org.opensearch.common.SuppressForbidden;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
*/

/** Contains JarHell Classes */
package org.opensearch.bootstrap;
package org.opensearch.common.bootstrap;
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
* GitHub history for details.
*/

package org.opensearch.bootstrap;
package org.opensearch.common.bootstrap;

import org.opensearch.common.io.PathUtils;
import org.opensearch.core.common.Strings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
import org.apache.tika.parser.ParserDecorator;
import org.opensearch.SpecialPermission;
import org.opensearch.bootstrap.FilePermissionUtils;
import org.opensearch.bootstrap.JarHell;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.bootstrap.JarHell;
import org.opensearch.common.io.PathUtils;

import java.io.ByteArrayInputStream;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.tasks.TaskCancelledException;
import org.opensearch.core.tasks.TaskId;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
Expand Down Expand Up @@ -193,6 +196,19 @@ private void handleResponse(final int responseSlot, final MultiSearchResponse.It
if (responseCounter.decrementAndGet() == 0) {
assert requests.isEmpty();
finish();
} else if (isCancelled(request.request.getParentTask())) {
// Drain the rest of the queue
SearchRequestSlot request;
while ((request = requests.poll()) != null) {
responses.set(
request.responseSlot,
new MultiSearchResponse.Item(null, new TaskCancelledException("Parent task was cancelled"))
);
if (responseCounter.decrementAndGet() == 0) {
assert requests.isEmpty();
finish();
}
}
} else {
if (thread == Thread.currentThread()) {
// we are on the same thread, we need to fork to another thread to avoid recursive stack overflow on a single thread
Expand Down Expand Up @@ -220,6 +236,14 @@ private long buildTookInMillis() {
});
}

private boolean isCancelled(TaskId taskId) {
if (taskId.isSet()) {
CancellableTask task = taskManager.getCancellableTask(taskId.getId());
return task != null && task.isCancelled();
}
return false;
}

/**
* Slots a search request
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.opensearch.cli.UserException;
import org.opensearch.common.PidFile;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.bootstrap.JarHell;
import org.opensearch.common.inject.CreationException;
import org.opensearch.common.logging.LogConfigurator;
import org.opensearch.common.logging.Loggers;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import org.opensearch.cli.Command;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.bootstrap.JarHell;
import org.opensearch.common.io.PathUtils;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
import com.fasterxml.jackson.core.json.JsonReadFeature;

import org.opensearch.Version;
import org.opensearch.bootstrap.JarHell;
import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.bootstrap.JarHell;
import org.opensearch.common.xcontent.json.JsonXContentParser;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import org.opensearch.OpenSearchException;
import org.opensearch.Version;
import org.opensearch.action.admin.cluster.node.info.PluginsAndModules;
import org.opensearch.bootstrap.JarHell;
import org.opensearch.common.bootstrap.JarHell;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.inject.Module;
import org.opensearch.common.lifecycle.LifecycleComponent;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskListener;
import org.opensearch.tasks.TaskManager;
import org.opensearch.telemetry.tracing.noop.NoopTracer;
import org.opensearch.test.OpenSearchTestCase;
Expand All @@ -62,7 +64,9 @@
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

Expand Down Expand Up @@ -289,4 +293,118 @@ public void testDefaultMaxConcurrentSearches() {
assertThat(result, equalTo(1));
}

public void testCancellation() {
// Initialize dependencies of TransportMultiSearchAction
Settings settings = Settings.builder().put("node.name", TransportMultiSearchActionTests.class.getSimpleName()).build();
ActionFilters actionFilters = mock(ActionFilters.class);
when(actionFilters.filters()).thenReturn(new ActionFilter[0]);
ThreadPool threadPool = new ThreadPool(settings);
TransportService transportService = new TransportService(
Settings.EMPTY,
mock(Transport.class),
threadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
boundAddress -> DiscoveryNode.createLocal(settings, boundAddress.publishAddress(), UUIDs.randomBase64UUID()),
null,
Collections.emptySet(),
NoopTracer.INSTANCE
) {
@Override
public TaskManager getTaskManager() {
return taskManager;
}
};
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("test")).build());

// Keep track of the number of concurrent searches started by multi search api,
// and if there are more searches than is allowed create an error and remember that.
int maxAllowedConcurrentSearches = 1; // Allow 1 search at a time.
AtomicInteger counter = new AtomicInteger();
AtomicReference<AssertionError> errorHolder = new AtomicReference<>();
// randomize whether or not requests are executed asynchronously
ExecutorService executorService = threadPool.executor(ThreadPool.Names.GENERIC);
final Set<SearchRequest> requests = Collections.newSetFromMap(Collections.synchronizedMap(new IdentityHashMap<>()));
CountDownLatch countDownLatch = new CountDownLatch(1);
CancellableTask[] parentTask = new CancellableTask[1];
NodeClient client = new NodeClient(settings, threadPool) {
@Override
public void search(final SearchRequest request, final ActionListener<SearchResponse> listener) {
if (parentTask[0] != null && parentTask[0].isCancelled()) {
fail("Should not execute search after parent task is cancelled");
}
try {
countDownLatch.await(10, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}

requests.add(request);
executorService.execute(() -> {
counter.decrementAndGet();
listener.onResponse(
new SearchResponse(
InternalSearchResponse.empty(),
null,
0,
0,
0,
0L,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
)
);
});
}

@Override
public String getLocalNodeId() {
return "local_node_id";
}
};

TransportMultiSearchAction action = new TransportMultiSearchAction(
threadPool,
actionFilters,
transportService,
clusterService,
10,
System::nanoTime,
client
);

// Execute the multi search api and fail if we find an error after executing:
try {
/*
* Allow for a large number of search requests in a single batch as previous implementations could stack overflow if the number
* of requests in a single batch was large
*/
int numSearchRequests = scaledRandomIntBetween(1024, 8192);
MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
multiSearchRequest.maxConcurrentSearchRequests(maxAllowedConcurrentSearches);
for (int i = 0; i < numSearchRequests; i++) {
multiSearchRequest.add(new SearchRequest());
}
MultiSearchResponse[] responses = new MultiSearchResponse[1];
Exception[] exceptions = new Exception[1];
parentTask[0] = (CancellableTask) action.execute(multiSearchRequest, new TaskListener<>() {
@Override
public void onResponse(Task task, MultiSearchResponse items) {
responses[0] = items;
}

@Override
public void onFailure(Task task, Exception e) {
exceptions[0] = e;
}
});
parentTask[0].cancel("Giving up");
countDownLatch.countDown();

assertNull(responses[0]);
assertNull(exceptions[0]);
} finally {
assertTrue(OpenSearchTestCase.terminate(threadPool));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import org.apache.lucene.util.Constants;
import org.opensearch.LegacyESVersion;
import org.opensearch.Version;
import org.opensearch.bootstrap.JarHell;
import org.opensearch.common.bootstrap.JarHell;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.io.PathUtils;
import org.opensearch.common.settings.Settings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.lucene.tests.util.LuceneTestCase;
import org.opensearch.common.Booleans;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.bootstrap.JarHell;
import org.opensearch.common.io.PathUtils;
import org.opensearch.common.network.IfConfig;
import org.opensearch.common.network.NetworkAddress;
Expand Down

0 comments on commit 6daa4f2

Please sign in to comment.