Skip to content

Commit

Permalink
add transport interceptor to populate queryGroupId in task headers
Browse files Browse the repository at this point in the history
Signed-off-by: Kaushal Kumar <[email protected]>
  • Loading branch information
kaushalmahi12 committed Jul 23, 2024
1 parent 2e13b79 commit c926996
Show file tree
Hide file tree
Showing 9 changed files with 248 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
import org.opensearch.transport.RemoteTransportException;
import org.opensearch.transport.Transport;
import org.opensearch.transport.TransportService;
import org.opensearch.wlm.QueryGroupConstants;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -444,7 +445,11 @@ private void executeRequest(

// At this point either the QUERY_GROUP_ID header will be present in ThreadContext either via ActionFilter
// or HTTP header (HTTP header will be deprecated once ActionFilter is implemented)
task.addQueryGroupHeaders(threadPool.getThreadContext());
task.addHeader(
QueryGroupConstants.QUERY_GROUP_ID_HEADER,
threadPool.getThreadContext(),
QueryGroupConstants.DEFAULT_QUERY_GROUP_ID_SUPPLIER
);

PipelinedRequest searchRequest;
ActionListener<SearchResponse> listener;
Expand Down
6 changes: 0 additions & 6 deletions server/src/main/java/org/opensearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,6 @@ public void executeDfsPhase(
ActionListener<SearchPhaseResult> listener
) {
final IndexShard shard = getShard(request);
task.addQueryGroupHeaders(threadPool.getThreadContext());
rewriteAndFetchShardRequest(shard, request, new ActionListener<ShardSearchRequest>() {
@Override
public void onResponse(ShardSearchRequest rewritten) {
Expand Down Expand Up @@ -611,7 +610,6 @@ public void executeQueryPhase(
) {
assert request.canReturnNullResponseIfMatchNoDocs() == false || request.numberOfShards() > 1
: "empty responses require more than one shard";
task.addQueryGroupHeaders(threadPool.getThreadContext());
final IndexShard shard = getShard(request);
rewriteAndFetchShardRequest(shard, request, new ActionListener<ShardSearchRequest>() {
@Override
Expand Down Expand Up @@ -721,7 +719,6 @@ public void executeQueryPhase(
freeReaderContext(readerContext.id());
throw e;
}
task.addQueryGroupHeaders(threadPool.getThreadContext());
runAsync(getExecutor(readerContext.indexShard()), () -> {
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(null);
try (
Expand All @@ -748,7 +745,6 @@ public void executeQueryPhase(QuerySearchRequest request, SearchShardTask task,
final ReaderContext readerContext = findReaderContext(request.contextId(), request.shardSearchRequest());
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.shardSearchRequest());
final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest));
task.addQueryGroupHeaders(threadPool.getThreadContext());
runAsync(getExecutor(readerContext.indexShard()), () -> {
readerContext.setAggregatedDfs(request.dfs());
try (
Expand Down Expand Up @@ -799,7 +795,6 @@ public void executeFetchPhase(
) {
final LegacyReaderContext readerContext = (LegacyReaderContext) findReaderContext(request.contextId(), request);
final Releasable markAsUsed;
task.addQueryGroupHeaders(threadPool.getThreadContext());
try {
markAsUsed = readerContext.markAsUsed(getScrollKeepAlive(request.scroll()));
} catch (Exception e) {
Expand Down Expand Up @@ -835,7 +830,6 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A
final ReaderContext readerContext = findReaderContext(request.contextId(), request);
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest());
final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest));
task.addQueryGroupHeaders(threadPool.getThreadContext());
runAsync(getExecutor(readerContext.indexShard()), () -> {
try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false)) {
if (request.lastEmittedDoc() != null) {
Expand Down
13 changes: 7 additions & 6 deletions server/src/main/java/org/opensearch/tasks/Task.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

/**
* Current task information
Expand Down Expand Up @@ -529,20 +530,20 @@ public String getHeader(String header) {
* hence it is not possible to copy this header from request headers. This header is required to group the tasks into queryGroups to account for the QueryGroup level resource footprint
* @param threadContext current thread context
*/
public void addQueryGroupHeaders(final ThreadContext threadContext) {
public void addHeader(final String headerName, final ThreadContext threadContext, final Supplier<String> defaultValueSupplier) {
// For now this header will be coming from HTTP headers but in second phase this header

// We will use this constant from QueryGroup Service once the framework changes are done
final String QUERY_GROUP_ID_HEADER = "queryGroupId";
String requestQueryGroupId = threadContext.getHeader(QUERY_GROUP_ID_HEADER);

if (requestQueryGroupId == null) {
requestQueryGroupId = "DEFAULT_QUERY_GROUP_ID"; // TODO: move this constant either to QueryGroupService or Tracking equivalent
String headerValue = threadContext.getHeader(headerName);

if (headerValue == null) {
headerValue = defaultValueSupplier.get();
}

final Map<String, String> newHeaders = new HashMap<>(headers);

newHeaders.put(QUERY_GROUP_ID_HEADER, requestQueryGroupId);
newHeaders.put(headerName, headerValue);

this.headers = newHeaders;
}
Expand Down
19 changes: 19 additions & 0 deletions server/src/main/java/org/opensearch/wlm/QueryGroupConstants.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.wlm;

import java.util.function.Supplier;

/**
* This class will hold all the QueryGroup related constants
*/
public class QueryGroupConstants {
public static final String QUERY_GROUP_ID_HEADER = "queryGroupId";
public static final Supplier<String> DEFAULT_QUERY_GROUP_ID_SUPPLIER = () -> "DEFAULT_QUERY_GROUP";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.wlm;

import org.opensearch.search.fetch.ShardFetchRequest;
import org.opensearch.search.internal.InternalScrollSearchRequest;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.query.QuerySearchRequest;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportRequestHandler;

/**
* This class is mainly used to populate the queryGroupId header
* @param <T> T is Search related request
*/
public class SearchWorkloadTransportHandler<T extends TransportRequest> implements TransportRequestHandler<T> {

private final ThreadPool threadPool;
TransportRequestHandler<T> actualHandler;

public SearchWorkloadTransportHandler(ThreadPool threadPool, TransportRequestHandler<T> actualHandler) {
this.threadPool = threadPool;
this.actualHandler = actualHandler;
}

@Override
public void messageReceived(T request, TransportChannel channel, Task task) throws Exception {
if (isSearchWorkloadRequest(request)) {
task.addHeader(
QueryGroupConstants.QUERY_GROUP_ID_HEADER,
threadPool.getThreadContext(),
QueryGroupConstants.DEFAULT_QUERY_GROUP_ID_SUPPLIER
);
}
actualHandler.messageReceived(request, channel, task);
}

private boolean isSearchWorkloadRequest(TransportRequest request) {
return (request instanceof ShardSearchRequest)
|| (request instanceof ShardFetchRequest)
|| (request instanceof InternalScrollSearchRequest)
|| (request instanceof QuerySearchRequest);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.wlm;

import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportInterceptor;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportRequestHandler;

/**
* This class is used to intercept search traffic requests and populate the queryGroupId header in task headers
* TODO: We still need to add this interceptor in {@link org.opensearch.node.Node} class to enable,
* leaving it until the feature is tested and done.
*/
public class SearchWorkloadTransportInterceptor implements TransportInterceptor {
private final ThreadPool threadPool;

public SearchWorkloadTransportInterceptor(ThreadPool threadPool) {
this.threadPool = threadPool;
}

@Override
public <T extends TransportRequest> TransportRequestHandler<T> interceptHandler(
String action,
String executor,
boolean forceExecution,
TransportRequestHandler<T> actualHandler
) {
return new SearchWorkloadTransportHandler<T>(threadPool, actualHandler);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.wlm.QueryGroupConstants;

import java.nio.charset.StandardCharsets;
import java.util.Collections;
Expand Down Expand Up @@ -253,9 +254,9 @@ public void testAddQueryGroupHeaders() {

threadPool.getThreadContext().putHeader("queryGroupId", "afakgkagj09532059");

task.addQueryGroupHeaders(threadPool.getThreadContext());
task.addHeader(QueryGroupConstants.QUERY_GROUP_ID_HEADER, threadPool.getThreadContext(), () -> "default_val");

String queryGroupId = task.getHeader("queryGroupId");
String queryGroupId = task.getHeader(QueryGroupConstants.QUERY_GROUP_ID_HEADER);

assertEquals("afakgkagj09532059", queryGroupId);
} finally {
Expand All @@ -275,11 +276,11 @@ public void testAddQueryGroupHeadersWhenHeaderIsNotPresentInThreadContext() {
Collections.emptyMap()
);

task.addQueryGroupHeaders(threadPool.getThreadContext());
task.addHeader(QueryGroupConstants.QUERY_GROUP_ID_HEADER, threadPool.getThreadContext(), () -> "default_val");

String queryGroupId = task.getHeader("queryGroupId");

assertEquals("DEFAULT_QUERY_GROUP_ID", queryGroupId);
assertEquals("default_val", queryGroupId);
} finally {
threadPool.shutdown();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.wlm;

import org.opensearch.action.index.IndexRequest;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportRequestHandler;

import java.util.Collections;

import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

public class SearchWorkloadTransportHandlerTests extends OpenSearchTestCase {
private SearchWorkloadTransportHandler<TransportRequest> sut;
private ThreadPool threadPool;

private TransportRequestHandler<TransportRequest> actualHandler;

public void setUp() throws Exception {
super.setUp();
threadPool = new TestThreadPool(getTestName());
actualHandler = new TestTransportRequestHandler<>();

sut = new SearchWorkloadTransportHandler<>(threadPool, actualHandler);
}

public void tearDown() throws Exception {
super.tearDown();
threadPool.shutdown();
}

public void testMessageReceivedForSearchWorkload() throws Exception {
ShardSearchRequest request = mock(ShardSearchRequest.class);
Task spyTask = getSpyTask();

sut.messageReceived(request, mock(TransportChannel.class), spyTask);

verify(spyTask, times(1)).addHeader(
QueryGroupConstants.QUERY_GROUP_ID_HEADER,
threadPool.getThreadContext(),
QueryGroupConstants.DEFAULT_QUERY_GROUP_ID_SUPPLIER
);
}

public void testMessageReceivedForNonSearchWorkload() throws Exception {
IndexRequest indexRequest = mock(IndexRequest.class);
Task spyTask = getSpyTask();
sut.messageReceived(indexRequest, mock(TransportChannel.class), spyTask);

verify(spyTask, times(0)).addHeader(any(), any(), any());
}

private static Task getSpyTask() {
final Task task = new Task(123, "transport", "Search", "test task", null, Collections.emptyMap());

return spy(task);
}

private static class TestTransportRequestHandler<T extends TransportRequest> implements TransportRequestHandler<T> {
int invokeCount = 0;

@Override
public void messageReceived(TransportRequest request, TransportChannel channel, Task task) throws Exception {
invokeCount += 1;
}

};
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.wlm;

import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportRequestHandler;

import static org.opensearch.threadpool.ThreadPool.Names.SAME;

public class SearchWorkloadTransportInterceptorTests extends OpenSearchTestCase {

private ThreadPool threadPool;
private SearchWorkloadTransportInterceptor sut;

public void setUp() throws Exception {
threadPool = new TestThreadPool(getTestName());
sut = new SearchWorkloadTransportInterceptor(threadPool);
}

public void tearDown() throws Exception {
threadPool.shutdown();
}

public void testInterceptHandler() {
TransportRequestHandler<TransportRequest> requestHandler = sut.interceptHandler("Search", SAME, false, null);
assertTrue(requestHandler instanceof SearchWorkloadTransportHandler);
}
}

0 comments on commit c926996

Please sign in to comment.