diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/Dart.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/Dart.java new file mode 100644 index 000000000000..33e239161ffe --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/Dart.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart; + +import com.google.inject.BindingAnnotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Binding annotation for implementations of interfaces that are Dart (MSQ-on-Broker-and-Historicals) focused. 
+ */ +@Target({ElementType.FIELD, ElementType.PARAMETER, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@BindingAnnotation +public @interface Dart +{ +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/DartResourcePermissionMapper.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/DartResourcePermissionMapper.java new file mode 100644 index 000000000000..038d1b56c72b --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/DartResourcePermissionMapper.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart; + +import com.google.common.collect.ImmutableList; +import org.apache.druid.msq.dart.controller.http.DartSqlResource; +import org.apache.druid.msq.dart.worker.http.DartWorkerResource; +import org.apache.druid.msq.rpc.ResourcePermissionMapper; +import org.apache.druid.msq.rpc.WorkerResource; +import org.apache.druid.server.security.Action; +import org.apache.druid.server.security.Resource; +import org.apache.druid.server.security.ResourceAction; + +import java.util.List; + +/** + * {@link ResourcePermissionMapper} for Dart. Admin and per-query APIs both map to read and write access on + * {@link Resource#STATE_RESOURCE}. + */ +public class DartResourcePermissionMapper implements ResourcePermissionMapper +{ + /** + * Permissions for admin APIs in {@link DartWorkerResource} and {@link WorkerResource}. Note that queries from + * end users go through {@link DartSqlResource}, which wouldn't use these mappings. + */ + @Override + public List<ResourceAction> getAdminPermissions() + { + return ImmutableList.of( + new ResourceAction(Resource.STATE_RESOURCE, Action.READ), + new ResourceAction(Resource.STATE_RESOURCE, Action.WRITE) + ); + } + + /** + * Permissions for per-query APIs in {@link DartWorkerResource} and {@link WorkerResource}. Note that queries from + * end users go through {@link DartSqlResource}, which wouldn't use these mappings. Same as the admin permissions. + */ + @Override + public List<ResourceAction> getQueryPermissions(String queryId) + { + return getAdminPermissions(); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerHolder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerHolder.java new file mode 100644 index 000000000000..9644444dad24 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerHolder.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller; + +import com.google.common.base.Preconditions; +import org.apache.druid.msq.dart.worker.DartWorkerClient; +import org.apache.druid.msq.dart.worker.WorkerId; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.exec.ControllerContext; +import org.apache.druid.msq.exec.QueryListener; +import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.indexing.error.WorkerFailedFault; +import org.apache.druid.server.security.AuthenticationResult; +import org.joda.time.DateTime; + +import java.util.concurrent.atomic.AtomicReference; + +/** + * Holder for {@link Controller}, stored in {@link DartControllerRegistry}. + */ +public class ControllerHolder +{ + public enum State + { + /** + * Query has been accepted, but not yet {@link Controller#run(QueryListener)}. + */ + ACCEPTED, + + /** + * Query has had {@link Controller#run(QueryListener)} called. + */ + RUNNING, + + /** + * Query has been canceled. 
+ */ + CANCELED + } + + private final Controller controller; + private final ControllerContext controllerContext; + private final String sqlQueryId; + private final String sql; + private final AuthenticationResult authenticationResult; + private final DateTime startTime; + private final AtomicReference state = new AtomicReference<>(State.ACCEPTED); + + public ControllerHolder( + final Controller controller, + final ControllerContext controllerContext, + final String sqlQueryId, + final String sql, + final AuthenticationResult authenticationResult, + final DateTime startTime + ) + { + this.controller = Preconditions.checkNotNull(controller, "controller"); + this.controllerContext = controllerContext; + this.sqlQueryId = Preconditions.checkNotNull(sqlQueryId, "sqlQueryId"); + this.sql = sql; + this.authenticationResult = authenticationResult; + this.startTime = Preconditions.checkNotNull(startTime, "startTime"); + } + + public Controller getController() + { + return controller; + } + + public String getSqlQueryId() + { + return sqlQueryId; + } + + public String getSql() + { + return sql; + } + + public AuthenticationResult getAuthenticationResult() + { + return authenticationResult; + } + + public DateTime getStartTime() + { + return startTime; + } + + public State getState() + { + return state.get(); + } + + /** + * Call when a worker has gone offline. Closes its client and sends a {@link Controller#workerError} + * to the controller. + */ + public void workerOffline(final WorkerId workerId) + { + final String workerIdString = workerId.toString(); + + if (controllerContext instanceof DartControllerContext) { + // For DartControllerContext, newWorkerClient() returns the same instance every time. + // This will always be DartControllerContext in production; the instanceof check is here because certain + // tests use a different context class. 
+ ((DartWorkerClient) controllerContext.newWorkerClient()).closeClient(workerId.getHostAndPort()); + } + + if (controller.hasWorker(workerIdString)) { + controller.workerError( + MSQErrorReport.fromFault( + workerIdString, + workerId.getHostAndPort(), + null, + new WorkerFailedFault(workerIdString, "Worker went offline") + ) + ); + } + } + + /** + * Places this holder into {@link State#CANCELED}. Calls {@link Controller#stop()} if it was previously in + * state {@link State#RUNNING}. + */ + public void cancel() + { + if (state.getAndSet(State.CANCELED) == State.RUNNING) { + controller.stop(); + } + } + + /** + * Calls {@link Controller#run(QueryListener)}, and returns true, if this holder was previously in state + * {@link State#ACCEPTED}. Otherwise returns false. + * + * @return whether {@link Controller#run(QueryListener)} was called. + */ + public boolean run(final QueryListener listener) throws Exception + { + if (state.compareAndSet(State.ACCEPTED, State.RUNNING)) { + controller.run(listener); + return true; + } else { + return false; + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerMessageListener.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerMessageListener.java new file mode 100644 index 000000000000..5cedd13baf0d --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerMessageListener.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller; + +import com.google.inject.Inject; +import org.apache.druid.messages.client.MessageListener; +import org.apache.druid.msq.dart.controller.messages.ControllerMessage; +import org.apache.druid.msq.dart.worker.WorkerId; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.server.DruidNode; + +/** + * Listener for worker-to-controller messages. + * Also responsible for calling {@link Controller#workerError(MSQErrorReport)} when a worker server goes away. + */ +public class ControllerMessageListener implements MessageListener +{ + private final DartControllerRegistry controllerRegistry; + + @Inject + public ControllerMessageListener(final DartControllerRegistry controllerRegistry) + { + this.controllerRegistry = controllerRegistry; + } + + @Override + public void messageReceived(ControllerMessage message) + { + final ControllerHolder holder = controllerRegistry.get(message.getQueryId()); + if (holder != null) { + message.handle(holder.getController()); + } + } + + @Override + public void serverAdded(DruidNode node) + { + // Nothing to do. 
+ } + + @Override + public void serverRemoved(DruidNode node) + { + for (final ControllerHolder holder : controllerRegistry.getAllHolders()) { + final Controller controller = holder.getController(); + final WorkerId workerId = WorkerId.fromDruidNode(node, controller.queryId()); + holder.workerOffline(workerId); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContext.java new file mode 100644 index 000000000000..0248e66fd221 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContext.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.controller; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.inject.Injector; +import org.apache.druid.client.BrokerServerView; +import org.apache.druid.error.DruidException; +import org.apache.druid.indexing.common.TaskLockType; +import org.apache.druid.indexing.common.actions.TaskActionClient; +import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.java.util.emitter.service.ServiceEmitter; +import org.apache.druid.java.util.emitter.service.ServiceMetricEvent; +import org.apache.druid.msq.dart.worker.DartWorkerClient; +import org.apache.druid.msq.dart.worker.WorkerId; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.exec.ControllerContext; +import org.apache.druid.msq.exec.ControllerMemoryParameters; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.WorkerFailureListener; +import org.apache.druid.msq.exec.WorkerManager; +import org.apache.druid.msq.indexing.IndexerControllerContext; +import org.apache.druid.msq.indexing.MSQSpec; +import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination; +import org.apache.druid.msq.input.InputSpecSlicer; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig; +import org.apache.druid.msq.querykit.QueryKit; +import org.apache.druid.msq.querykit.QueryKitSpec; +import org.apache.druid.msq.util.MultiStageQueryContext; +import org.apache.druid.query.Query; +import org.apache.druid.query.QueryContext; +import org.apache.druid.server.DruidNode; +import org.apache.druid.server.coordination.DruidServerMetadata; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Dart implementation of {@link ControllerContext}. + * Each instance is scoped to a query. + */ +public class DartControllerContext implements ControllerContext +{ + /** + * Default for {@link ControllerQueryKernelConfig#getMaxConcurrentStages()}. 
+ */ + public static final int DEFAULT_MAX_CONCURRENT_STAGES = 2; + + /** + * Default for {@link MultiStageQueryContext#getTargetPartitionsPerWorkerWithDefault(QueryContext, int)}. + */ + public static final int DEFAULT_TARGET_PARTITIONS_PER_WORKER = 1; + + /** + * Context parameter for maximum number of non-leaf workers. + */ + public static final String CTX_MAX_NON_LEAF_WORKER_COUNT = "maxNonLeafWorkers"; + + /** + * Default to scatter/gather style: fan in to a single worker after the leaf stage(s). + */ + public static final int DEFAULT_MAX_NON_LEAF_WORKER_COUNT = 1; + + private final Injector injector; + private final ObjectMapper jsonMapper; + private final DruidNode selfNode; + private final DartWorkerClient workerClient; + private final BrokerServerView serverView; + private final MemoryIntrospector memoryIntrospector; + private final ServiceMetricEvent.Builder metricBuilder; + private final ServiceEmitter emitter; + + public DartControllerContext( + final Injector injector, + final ObjectMapper jsonMapper, + final DruidNode selfNode, + final DartWorkerClient workerClient, + final MemoryIntrospector memoryIntrospector, + final BrokerServerView serverView, + final ServiceEmitter emitter + ) + { + this.injector = injector; + this.jsonMapper = jsonMapper; + this.selfNode = selfNode; + this.workerClient = workerClient; + this.serverView = serverView; + this.memoryIntrospector = memoryIntrospector; + this.metricBuilder = new ServiceMetricEvent.Builder(); + this.emitter = emitter; + } + + @Override + public ControllerQueryKernelConfig queryKernelConfig( + final String queryId, + final MSQSpec querySpec + ) + { + final List<DruidServerMetadata> servers = serverView.getDruidServerMetadatas(); + + // Lock in the list of workers when creating the kernel config. There is a race here: the serverView itself is + // allowed to float. 
If a segment moves to a new server that isn't part of our list after the WorkerManager is + // created, we won't be able to find a valid server for certain segments. This isn't expected to be a problem, + // since the serverView is referenced shortly after the worker list is created. + final List<String> workerIds = new ArrayList<>(servers.size()); + for (final DruidServerMetadata server : servers) { + workerIds.add(WorkerId.fromDruidServerMetadata(server, queryId).toString()); + } + + // Shuffle workerIds, so we don't bias towards specific servers when running multiple queries concurrently. For any + // given query, lower-numbered workers tend to do more work, because the controller prefers using lower-numbered + // workers when maxWorkerCount for a stage is less than the total number of workers. + Collections.shuffle(workerIds); + + final ControllerMemoryParameters memoryParameters = + ControllerMemoryParameters.createProductionInstance( + memoryIntrospector, + workerIds.size() + ); + + final int maxConcurrentStages = MultiStageQueryContext.getMaxConcurrentStagesWithDefault( + querySpec.getQuery().context(), + DEFAULT_MAX_CONCURRENT_STAGES + ); + + return ControllerQueryKernelConfig + .builder() + .controllerHost(selfNode.getHostAndPortToUse()) + .workerIds(workerIds) + .pipeline(maxConcurrentStages > 1) + .destination(TaskReportMSQDestination.instance()) + .maxConcurrentStages(maxConcurrentStages) + .maxRetainedPartitionSketchBytes(memoryParameters.getPartitionStatisticsMaxRetainedBytes()) + .workerContextMap(IndexerControllerContext.makeWorkerContextMap(querySpec, false, maxConcurrentStages)) + .build(); + } + + @Override + public ObjectMapper jsonMapper() + { + return jsonMapper; + } + + @Override + public Injector injector() + { + return injector; + } + + @Override + public void emitMetric(final String metric, final Number value) + { + emitter.emit(metricBuilder.setMetric(metric, value)); + } + + @Override + public DruidNode selfNode() + { + return selfNode; + } + + 
@Override + public InputSpecSlicer newTableInputSpecSlicer(WorkerManager workerManager) + { + return DartTableInputSpecSlicer.createFromWorkerIds(workerManager.getWorkerIds(), serverView); + } + + @Override + public TaskActionClient taskActionClient() + { + throw new UnsupportedOperationException(); + } + + @Override + public WorkerManager newWorkerManager( + String queryId, + MSQSpec querySpec, + ControllerQueryKernelConfig queryKernelConfig, + WorkerFailureListener workerFailureListener + ) + { + // We're ignoring WorkerFailureListener. Dart worker failures are routed into the controller by + // ControllerMessageListener, which receives a notification when a worker goes offline. + return new DartWorkerManager(queryKernelConfig.getWorkerIds(), workerClient); + } + + /** + * Returns the {@link DartWorkerClient} passed to the constructor; the same instance is returned on every call. + */ + @Override + public DartWorkerClient newWorkerClient() + { + return workerClient; + } + + @Override + public void registerController(Controller controller, Closer closer) + { + closer.register(workerClient); + } + + @Override + public QueryKitSpec makeQueryKitSpec( + final QueryKit<Query<?>> queryKit, + final String queryId, + final MSQSpec querySpec, + final ControllerQueryKernelConfig queryKernelConfig + ) + { + final QueryContext queryContext = querySpec.getQuery().context(); + return new QueryKitSpec( + queryKit, + queryId, + queryKernelConfig.getWorkerIds().size(), + queryContext.getInt( + CTX_MAX_NON_LEAF_WORKER_COUNT, + DEFAULT_MAX_NON_LEAF_WORKER_COUNT + ), + MultiStageQueryContext.getTargetPartitionsPerWorkerWithDefault( + queryContext, + DEFAULT_TARGET_PARTITIONS_PER_WORKER + ) + ); + } + + @Override + public TaskLockType taskLockType() + { + throw DruidException.defensive("TaskLockType is not used with class[%s]", getClass().getName()); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContextFactory.java 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContextFactory.java new file 
mode 100644 index 000000000000..f58eb4bfa68d --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContextFactory.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller; + +import org.apache.druid.msq.dart.controller.sql.DartQueryMaker; +import org.apache.druid.msq.exec.ControllerContext; + +/** + * Factory interface for creating a {@link ControllerContext} for a given query ID in {@link DartQueryMaker}. + */ +public interface DartControllerContextFactory +{ + ControllerContext newContext(String queryId); +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContextFactoryImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContextFactoryImpl.java new file mode 100644 index 000000000000..8cefb6af7ece --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContextFactoryImpl.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.inject.Inject; +import com.google.inject.Injector; +import org.apache.druid.client.BrokerServerView; +import org.apache.druid.guice.annotations.EscalatedGlobal; +import org.apache.druid.guice.annotations.Json; +import org.apache.druid.guice.annotations.Self; +import org.apache.druid.guice.annotations.Smile; +import org.apache.druid.java.util.emitter.service.ServiceEmitter; +import org.apache.druid.msq.dart.worker.DartWorkerClient; +import org.apache.druid.msq.exec.ControllerContext; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.rpc.ServiceClientFactory; +import org.apache.druid.server.DruidNode; + +/** + * Guice-injected implementation of {@link DartControllerContextFactory}. Each call to {@link #newContext(String)} + * builds a {@link DartControllerContext} with a newly created {@link DartWorkerClient} for the given query ID. + */ +public class DartControllerContextFactoryImpl implements DartControllerContextFactory +{ + private final Injector injector; + private final ObjectMapper jsonMapper; + private final ObjectMapper smileMapper; + private final DruidNode selfNode; + private final ServiceClientFactory serviceClientFactory; + private final BrokerServerView serverView; + private final MemoryIntrospector memoryIntrospector; + private final ServiceEmitter emitter; + + @Inject + public DartControllerContextFactoryImpl( + final Injector injector, + 
@Json final ObjectMapper jsonMapper, + @Smile final ObjectMapper smileMapper, + @Self final DruidNode selfNode, + @EscalatedGlobal final ServiceClientFactory serviceClientFactory, + final MemoryIntrospector memoryIntrospector, + final BrokerServerView serverView, + final ServiceEmitter emitter + ) + { + // Assignments follow the constructor's parameter order. + this.injector = injector; + this.jsonMapper = jsonMapper; + this.smileMapper = smileMapper; + this.selfNode = selfNode; + this.serviceClientFactory = serviceClientFactory; + this.memoryIntrospector = memoryIntrospector; + this.serverView = serverView; + this.emitter = emitter; + } + + /** + * Creates a {@link DartControllerContext} for the given query, backed by a {@link DartWorkerClient} created for + * that query. + */ + @Override + public ControllerContext newContext(final String queryId) + { + final String selfHostAndPort = selfNode.getHostAndPortToUse(); + final DartWorkerClient queryWorkerClient = + new DartWorkerClient(queryId, serviceClientFactory, smileMapper, selfHostAndPort); + + return new DartControllerContext( + injector, + jsonMapper, + selfNode, + queryWorkerClient, + memoryIntrospector, + serverView, + emitter + ); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerRegistry.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerRegistry.java new file mode 100644 index 000000000000..847dbf759806 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerRegistry.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller; + +import org.apache.druid.error.DruidException; +import org.apache.druid.msq.exec.Controller; + +import javax.annotation.Nullable; +import java.util.Collection; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Registry for actively-running {@link Controller}. + */ +public class DartControllerRegistry +{ + private final ConcurrentHashMap controllerMap = new ConcurrentHashMap<>(); + + /** + * Add a controller. Throws {@link DruidException} if a controller with the same {@link Controller#queryId()} is + * already registered. + */ + public void register(ControllerHolder holder) + { + if (controllerMap.putIfAbsent(holder.getController().queryId(), holder) != null) { + throw DruidException.defensive("Controller[%s] already registered", holder.getController().queryId()); + } + } + + /** + * Remove a controller from the registry. + */ + public void deregister(ControllerHolder holder) + { + // Remove only if the current mapping for the queryId is this specific controller. + controllerMap.remove(holder.getController().queryId(), holder); + } + + /** + * Return a specific controller holder, or null if it doesn't exist. + */ + @Nullable + public ControllerHolder get(final String queryId) + { + return controllerMap.get(queryId); + } + + /** + * Returns all actively-running {@link Controller}. 
+ */ + public Collection getAllHolders() + { + return controllerMap.values(); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartMessageRelayFactoryImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartMessageRelayFactoryImpl.java new file mode 100644 index 000000000000..7f16a37c9d72 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartMessageRelayFactoryImpl.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.controller; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.inject.Inject; +import org.apache.druid.guice.annotations.EscalatedGlobal; +import org.apache.druid.guice.annotations.Self; +import org.apache.druid.guice.annotations.Smile; +import org.apache.druid.messages.client.MessageRelay; +import org.apache.druid.messages.client.MessageRelayClientImpl; +import org.apache.druid.messages.client.MessageRelayFactory; +import org.apache.druid.msq.dart.controller.messages.ControllerMessage; +import org.apache.druid.msq.dart.worker.http.DartWorkerResource; +import org.apache.druid.rpc.FixedServiceLocator; +import org.apache.druid.rpc.ServiceClient; +import org.apache.druid.rpc.ServiceClientFactory; +import org.apache.druid.rpc.ServiceLocation; +import org.apache.druid.rpc.StandardRetryPolicy; +import org.apache.druid.server.DruidNode; + +/** + * Production implementation of {@link MessageRelayFactory}. + */ +public class DartMessageRelayFactoryImpl implements MessageRelayFactory<ControllerMessage> +{ + private final String clientHost; + private final ControllerMessageListener messageListener; + private final ServiceClientFactory clientFactory; + private final String basePath; + private final ObjectMapper smileMapper; + + @Inject + public DartMessageRelayFactoryImpl( + @Self DruidNode selfNode, + @EscalatedGlobal ServiceClientFactory clientFactory, + @Smile ObjectMapper smileMapper, + ControllerMessageListener messageListener + ) + { + this.clientHost = selfNode.getHostAndPortToUse(); + this.messageListener = messageListener; + this.clientFactory = clientFactory; + this.smileMapper = smileMapper; + this.basePath = DartWorkerResource.PATH + "/relay"; + } + + @Override + public MessageRelay<ControllerMessage> newRelay(DruidNode clientNode) + { + final ServiceLocation location = ServiceLocation.fromDruidNode(clientNode).withBasePath(basePath); + final ServiceClient client = clientFactory.makeClient( + clientNode.getHostAndPortToUse(), + new 
FixedServiceLocator(location), + StandardRetryPolicy.unlimited() + ); + + return new MessageRelay<>( + clientHost, + clientNode, + new MessageRelayClientImpl<>(client, smileMapper, ControllerMessage.class), + messageListener + ); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartMessageRelays.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartMessageRelays.java new file mode 100644 index 000000000000..23accd35ecbe --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartMessageRelays.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller; + +import org.apache.druid.discovery.DruidNodeDiscoveryProvider; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.messages.client.MessageRelayFactory; +import org.apache.druid.messages.client.MessageRelays; +import org.apache.druid.msq.dart.controller.messages.ControllerMessage; + +/** + * Specialized {@link MessageRelays} for Dart controllers. 
+ */ +public class DartMessageRelays extends MessageRelays +{ + public DartMessageRelays( + final DruidNodeDiscoveryProvider discoveryProvider, + final MessageRelayFactory messageRelayFactory + ) + { + super(() -> discoveryProvider.getForNodeRole(NodeRole.HISTORICAL), messageRelayFactory); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartTableInputSpecSlicer.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartTableInputSpecSlicer.java new file mode 100644 index 000000000000..52ecccbc152f --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartTableInputSpecSlicer.java @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.controller; + +import com.google.common.collect.FluentIterable; +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.objects.Object2IntMap; +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; +import org.apache.druid.client.TimelineServerView; +import org.apache.druid.client.selector.QueryableDruidServer; +import org.apache.druid.client.selector.ServerSelector; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.JodaUtils; +import org.apache.druid.msq.dart.worker.DartQueryableSegment; +import org.apache.druid.msq.dart.worker.WorkerId; +import org.apache.druid.msq.exec.SegmentSource; +import org.apache.druid.msq.exec.WorkerManager; +import org.apache.druid.msq.input.InputSlice; +import org.apache.druid.msq.input.InputSpec; +import org.apache.druid.msq.input.InputSpecSlicer; +import org.apache.druid.msq.input.NilInputSlice; +import org.apache.druid.msq.input.table.RichSegmentDescriptor; +import org.apache.druid.msq.input.table.SegmentsInputSlice; +import org.apache.druid.msq.input.table.TableInputSpec; +import org.apache.druid.query.TableDataSource; +import org.apache.druid.query.filter.DimFilterUtils; +import org.apache.druid.server.coordination.DruidServerMetadata; +import org.apache.druid.timeline.DataSegment; +import org.apache.druid.timeline.TimelineLookup; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Set; +import java.util.function.ToIntFunction; + +/** + * Slices {@link TableInputSpec} into {@link SegmentsInputSlice} for persistent servers using + * {@link TimelineServerView}. + */ +public class DartTableInputSpecSlicer implements InputSpecSlicer +{ + private static final int UNKNOWN = -1; + + /** + * Worker host:port -> worker number. This is the reverse of the mapping from {@link WorkerManager#getWorkerIds()}. 
+ */ + private final Object2IntMap workerIdToNumber; + + /** + * Server view for identifying which segments exist and which servers (workers) have which segments. + */ + private final TimelineServerView serverView; + + DartTableInputSpecSlicer(final Object2IntMap workerIdToNumber, final TimelineServerView serverView) + { + this.workerIdToNumber = workerIdToNumber; + this.serverView = serverView; + } + + public static DartTableInputSpecSlicer createFromWorkerIds( + final List workerIds, + final TimelineServerView serverView + ) + { + final Object2IntMap reverseWorkers = new Object2IntOpenHashMap<>(); + reverseWorkers.defaultReturnValue(UNKNOWN); + + for (int i = 0; i < workerIds.size(); i++) { + reverseWorkers.put(WorkerId.fromString(workerIds.get(i)).getHostAndPort(), i); + } + + return new DartTableInputSpecSlicer(reverseWorkers, serverView); + } + + @Override + public boolean canSliceDynamic(final InputSpec inputSpec) + { + return false; + } + + @Override + public List sliceStatic(final InputSpec inputSpec, final int maxNumSlices) + { + final TableInputSpec tableInputSpec = (TableInputSpec) inputSpec; + final TimelineLookup timeline = + serverView.getTimeline(new TableDataSource(tableInputSpec.getDataSource()).getAnalysis()).orElse(null); + + if (timeline == null) { + return Collections.emptyList(); + } + + final Set prunedSegments = + findQueryableDataSegments( + tableInputSpec, + timeline, + serverSelector -> findWorkerForServerSelector(serverSelector, maxNumSlices) + ); + + final List> assignments = new ArrayList<>(maxNumSlices); + while (assignments.size() < maxNumSlices) { + assignments.add(null); + } + + int nextRoundRobinWorker = 0; + for (final DartQueryableSegment segment : prunedSegments) { + final int worker; + if (segment.getWorkerNumber() == UNKNOWN) { + // Segment is not available on any worker. Assign to some worker, round-robin. 
Today, that server will throw + // an error about the segment not being findable, but perhaps one day, it will be able to load the segment + // on demand. + worker = nextRoundRobinWorker; + nextRoundRobinWorker = (nextRoundRobinWorker + 1) % maxNumSlices; + } else { + worker = segment.getWorkerNumber(); + } + + if (assignments.get(worker) == null) { + assignments.set(worker, new ArrayList<>()); + } + + assignments.get(worker).add(segment); + } + + return makeSegmentSlices(tableInputSpec.getDataSource(), assignments); + } + + @Override + public List sliceDynamic( + final InputSpec inputSpec, + final int maxNumSlices, + final int maxFilesPerSlice, + final long maxBytesPerSlice + ) + { + throw new UnsupportedOperationException(); + } + + /** + * Return the worker ID that corresponds to a particular {@link ServerSelector}, or {@link #UNKNOWN} if none does. + * + * @param serverSelector the server selector + * @param maxNumSlices maximum number of worker IDs to use + */ + int findWorkerForServerSelector(final ServerSelector serverSelector, final int maxNumSlices) + { + final QueryableDruidServer server = serverSelector.pick(null); + + if (server == null) { + return UNKNOWN; + } + + final String serverHostAndPort = server.getServer().getHostAndPort(); + final int workerNumber = workerIdToNumber.getInt(serverHostAndPort); + + // The worker number may be UNKNOWN in a race condition, such as the set of Historicals changing while + // the query is being planned. I don't think it can be >= maxNumSlices, but if it is, treat it like UNKNOWN. + if (workerNumber != UNKNOWN && workerNumber < maxNumSlices) { + return workerNumber; + } else { + return UNKNOWN; + } + } + + /** + * Pull the list of {@link DataSegment} that we should query, along with a clipping interval for each one, and + * a worker to get it from. 
+ */ + static Set findQueryableDataSegments( + final TableInputSpec tableInputSpec, + final TimelineLookup timeline, + final ToIntFunction toWorkersFunction + ) + { + final FluentIterable allSegments = + FluentIterable.from(JodaUtils.condenseIntervals(tableInputSpec.getIntervals())) + .transformAndConcat(timeline::lookup) + .transformAndConcat( + holder -> + FluentIterable + .from(holder.getObject()) + .filter(chunk -> shouldIncludeSegment(chunk.getObject())) + .transform(chunk -> { + final ServerSelector serverSelector = chunk.getObject(); + final DataSegment dataSegment = serverSelector.getSegment(); + final int worker = toWorkersFunction.applyAsInt(serverSelector); + return new DartQueryableSegment(dataSegment, holder.getInterval(), worker); + }) + .filter(segment -> !segment.getSegment().isTombstone()) + ); + + return DimFilterUtils.filterShards( + tableInputSpec.getFilter(), + tableInputSpec.getFilterFields(), + allSegments, + segment -> segment.getSegment().getShardSpec(), + new HashMap<>() + ); + } + + /** + * Create a list of {@link SegmentsInputSlice} and {@link NilInputSlice} assignments. 
+ * + * @param dataSource datasource to read + * @param assignments list of assignment lists, one per slice + * + * @return a list of the same length as "assignments" + * + * @throws IllegalStateException if any provided segments do not match the provided datasource + */ + static List makeSegmentSlices( + final String dataSource, + final List> assignments + ) + { + final List retVal = new ArrayList<>(assignments.size()); + + for (final List assignment : assignments) { + if (assignment == null || assignment.isEmpty()) { + retVal.add(NilInputSlice.INSTANCE); + } else { + final List descriptors = new ArrayList<>(); + for (final DartQueryableSegment segment : assignment) { + if (!dataSource.equals(segment.getSegment().getDataSource())) { + throw new ISE("Expected dataSource[%s] but got[%s]", dataSource, segment.getSegment().getDataSource()); + } + + descriptors.add(toRichSegmentDescriptor(segment)); + } + + retVal.add(new SegmentsInputSlice(dataSource, descriptors, ImmutableList.of())); + } + } + + return retVal; + } + + /** + * Returns a {@link RichSegmentDescriptor}, which is used by {@link SegmentsInputSlice}. + */ + static RichSegmentDescriptor toRichSegmentDescriptor(final DartQueryableSegment segment) + { + return new RichSegmentDescriptor( + segment.getSegment().getInterval(), + segment.getInterval(), + segment.getSegment().getVersion(), + segment.getSegment().getShardSpec().getPartitionNum() + ); + } + + /** + * Whether to include a segment from the timeline. Segments are included if they are not tombstones, and are also not + * purely realtime segments. 
+ */ + static boolean shouldIncludeSegment(final ServerSelector serverSelector) + { + if (serverSelector.getSegment().isTombstone()) { + return false; + } + + int numRealtimeServers = 0; + int numOtherServers = 0; + + for (final DruidServerMetadata server : serverSelector.getAllServers()) { + if (SegmentSource.REALTIME.getUsedServerTypes().contains(server.getType())) { + numRealtimeServers++; + } else { + numOtherServers++; + } + } + + return numOtherServers > 0 || (numOtherServers + numRealtimeServers == 0); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartWorkerManager.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartWorkerManager.java new file mode 100644 index 000000000000..54e163862d62 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartWorkerManager.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.controller; + +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.objects.Object2IntMap; +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.error.DruidException; +import org.apache.druid.indexer.TaskState; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.dart.worker.DartWorkerClient; +import org.apache.druid.msq.exec.ControllerContext; +import org.apache.druid.msq.exec.WorkerClient; +import org.apache.druid.msq.exec.WorkerManager; +import org.apache.druid.msq.exec.WorkerStats; +import org.apache.druid.msq.indexing.WorkerCount; +import org.apache.druid.utils.CloseableUtils; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Dart implementation of the {@link WorkerManager} returned by {@link ControllerContext#newWorkerManager}. + * + * This manager does not actually launch workers. The workers are housed on long-lived servers outside of this + * manager's control. This manager merely reports on their existence. 
+ */ +public class DartWorkerManager implements WorkerManager +{ + private static final Logger log = new Logger(DartWorkerManager.class); + + private final List workerIds; + private final DartWorkerClient workerClient; + private final Object2IntMap workerIdToNumber; + private final AtomicReference state = new AtomicReference<>(State.NEW); + private final SettableFuture stopFuture = SettableFuture.create(); + + enum State + { + NEW, + STARTED, + STOPPED + } + + public DartWorkerManager( + final List workerIds, + final DartWorkerClient workerClient + ) + { + this.workerIds = workerIds; + this.workerClient = workerClient; + this.workerIdToNumber = new Object2IntOpenHashMap<>(); + this.workerIdToNumber.defaultReturnValue(UNKNOWN_WORKER_NUMBER); + + for (int i = 0; i < workerIds.size(); i++) { + workerIdToNumber.put(workerIds.get(i), i); + } + } + + @Override + public ListenableFuture start() + { + if (!state.compareAndSet(State.NEW, State.STARTED)) { + throw new ISE("Cannot start from state[%s]", state.get()); + } + + return stopFuture; + } + + @Override + public void launchWorkersIfNeeded(int workerCount) + { + // Nothing to do, just validate the count. + if (workerCount > workerIds.size()) { + throw DruidException.defensive( + "Desired workerCount[%s] must be less than or equal to actual workerCount[%s]", + workerCount, + workerIds.size() + ); + } + } + + @Override + public void waitForWorkers(Set workerNumbers) + { + // Nothing to wait for, just validate the numbers. 
+ for (final int workerNumber : workerNumbers) { + if (workerNumber >= workerIds.size()) { + throw DruidException.defensive( + "Desired workerNumber[%s] must be less than workerCount[%s]", + workerNumber, + workerIds.size() + ); + } + } + } + + @Override + public List getWorkerIds() + { + return workerIds; + } + + @Override + public WorkerCount getWorkerCount() + { + return new WorkerCount(workerIds.size(), 0); + } + + @Override + public int getWorkerNumber(String workerId) + { + return workerIdToNumber.getInt(workerId); + } + + @Override + public boolean isWorkerActive(String workerId) + { + return workerIdToNumber.containsKey(workerId); + } + + @Override + public Map> getWorkerStats() + { + final Int2ObjectMap> retVal = new Int2ObjectAVLTreeMap<>(); + + for (int i = 0; i < workerIds.size(); i++) { + retVal.put(i, Collections.singletonList(new WorkerStats(workerIds.get(i), TaskState.RUNNING, -1, -1))); + } + + return retVal; + } + + /** + * Stop method. Possibly signals workers to stop, but does not actually wait for them to exit. + * + * If "interrupt" is false, does nothing special (other than setting {@link #stopFuture}). The assumption is that + * a previous call to {@link WorkerClient#postFinish} would have caused the worker to exit. + * + * If "interrupt" is true, sends {@link DartWorkerClient#stopWorker(String)} to workers to stop the current query ID. + * + * @param interrupt whether to interrupt currently-running work + */ + @Override + public void stop(boolean interrupt) + { + if (state.compareAndSet(State.STARTED, State.STOPPED)) { + final List> futures = new ArrayList<>(); + + // Send stop commands to all workers. This ensures they exit promptly, and do not get left in a zombie state. + // For this reason, the workerClient uses an unlimited retry policy. If a stop command is lost, a worker + // could get stuck in a zombie state without its controller. This state would persist until the server that + // ran the controller shuts down or restarts. 
At that time, the listener in DartWorkerRunner.BrokerListener calls + // "controllerFailed()" on the Worker, and the zombie worker would exit. + + for (final String workerId : workerIds) { + futures.add(workerClient.stopWorker(workerId)); + } + + // Block until messages are acknowledged, or until the worker we're communicating with has failed. + + try { + FutureUtils.getUnchecked(Futures.successfulAsList(futures), false); + } + catch (Throwable ignored) { + // Suppress errors. + } + + CloseableUtils.closeAndSuppressExceptions(workerClient, e -> log.warn(e, "Failed to close workerClient")); + stopFuture.set(null); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartQueryInfo.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartQueryInfo.java new file mode 100644 index 000000000000..e5f3abb894e1 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartQueryInfo.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.controller.http; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Preconditions; +import org.apache.druid.msq.dart.controller.ControllerHolder; +import org.apache.druid.msq.util.MSQTaskQueryMakerUtils; +import org.apache.druid.query.QueryContexts; +import org.joda.time.DateTime; + +import java.util.Objects; + +/** + * Class included in {@link GetQueriesResponse}. + */ +public class DartQueryInfo +{ + private final String sqlQueryId; + private final String dartQueryId; + private final String sql; + private final String authenticator; + private final String identity; + private final DateTime startTime; + private final String state; + + @JsonCreator + public DartQueryInfo( + @JsonProperty("sqlQueryId") final String sqlQueryId, + @JsonProperty("dartQueryId") final String dartQueryId, + @JsonProperty("sql") final String sql, + @JsonProperty("authenticator") final String authenticator, + @JsonProperty("identity") final String identity, + @JsonProperty("startTime") final DateTime startTime, + @JsonProperty("state") final String state + ) + { + this.sqlQueryId = Preconditions.checkNotNull(sqlQueryId, "sqlQueryId"); + this.dartQueryId = Preconditions.checkNotNull(dartQueryId, "dartQueryId"); + this.sql = sql; + this.authenticator = authenticator; + this.identity = identity; + this.startTime = startTime; + this.state = state; + } + + public static DartQueryInfo fromControllerHolder(final ControllerHolder holder) + { + return new DartQueryInfo( + holder.getSqlQueryId(), + holder.getController().queryId(), + MSQTaskQueryMakerUtils.maskSensitiveJsonKeys(holder.getSql()), + holder.getAuthenticationResult().getAuthenticatedBy(), + holder.getAuthenticationResult().getIdentity(), + holder.getStartTime(), + holder.getState().toString() + ); + } + + /** + * The {@link 
QueryContexts#CTX_SQL_QUERY_ID} provided by the user, or generated by the system. + */ + @JsonProperty + public String getSqlQueryId() + { + return sqlQueryId; + } + + /** + * Dart query ID generated by the system. Globally unique. + */ + @JsonProperty + public String getDartQueryId() + { + return dartQueryId; + } + + /** + * SQL string for this query, masked using {@link MSQTaskQueryMakerUtils#maskSensitiveJsonKeys(String)}. + */ + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_NULL) + public String getSql() + { + return sql; + } + + /** + * Authenticator that authenticated the identity from {@link #getIdentity()}. + */ + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_NULL) + public String getAuthenticator() + { + return authenticator; + } + + /** + * User that issued this query. + */ + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_NULL) + public String getIdentity() + { + return identity; + } + + /** + * Time this query was started. + */ + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_NULL) + public DateTime getStartTime() + { + return startTime; + } + + @JsonProperty + public String getState() + { + return state; + } + + /** + * Returns a copy of this instance with {@link #getAuthenticator()} and {@link #getIdentity()} nulled. 
+ */ + public DartQueryInfo withoutAuthenticationResult() + { + return new DartQueryInfo(sqlQueryId, dartQueryId, sql, null, null, startTime, state); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DartQueryInfo that = (DartQueryInfo) o; + return Objects.equals(sqlQueryId, that.sqlQueryId) + && Objects.equals(dartQueryId, that.dartQueryId) + && Objects.equals(sql, that.sql) + && Objects.equals(authenticator, that.authenticator) + && Objects.equals(identity, that.identity) + && Objects.equals(startTime, that.startTime) + && Objects.equals(state, that.state); + } + + @Override + public int hashCode() + { + return Objects.hash(sqlQueryId, dartQueryId, sql, authenticator, identity, startTime, state); + } + + @Override + public String toString() + { + return "DartQueryInfo{" + + "sqlQueryId='" + sqlQueryId + '\'' + + ", dartQueryId='" + dartQueryId + '\'' + + ", sql='" + sql + '\'' + + ", authenticator='" + authenticator + '\'' + + ", identity='" + identity + '\'' + + ", startTime=" + startTime + + ", state=" + state + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartSqlResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartSqlResource.java new file mode 100644 index 000000000000..37e9f1051318 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartSqlResource.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.http; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.util.concurrent.Futures; +import com.google.inject.Inject; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.guice.annotations.Self; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.dart.Dart; +import org.apache.druid.msq.dart.controller.ControllerHolder; +import org.apache.druid.msq.dart.controller.DartControllerRegistry; +import org.apache.druid.msq.dart.controller.sql.DartSqlClients; +import org.apache.druid.msq.dart.controller.sql.DartSqlEngine; +import org.apache.druid.query.DefaultQueryConfig; +import org.apache.druid.server.DruidNode; +import org.apache.druid.server.ResponseContextConfig; +import org.apache.druid.server.initialization.ServerConfig; +import org.apache.druid.server.security.Access; +import org.apache.druid.server.security.Action; +import org.apache.druid.server.security.AuthenticationResult; +import org.apache.druid.server.security.AuthorizationUtils; +import org.apache.druid.server.security.AuthorizerMapper; +import org.apache.druid.server.security.Resource; +import org.apache.druid.server.security.ResourceAction; +import org.apache.druid.sql.HttpStatement; +import org.apache.druid.sql.SqlLifecycleManager; +import org.apache.druid.sql.SqlStatementFactory; +import org.apache.druid.sql.http.SqlQuery; +import 
org.apache.druid.sql.http.SqlResource; + +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.Consumes; +import javax.ws.rs.DELETE; +import javax.ws.rs.GET; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.UUID; +import java.util.stream.Collectors; + +/** + * Resource for Dart queries. API-compatible with {@link SqlResource}, so clients can be pointed from + * {@code /druid/v2/sql/} to {@code /druid/v2/sql/dart/} without code changes. + */ +@Path(DartSqlResource.PATH + '/') +public class DartSqlResource extends SqlResource +{ + public static final String PATH = "/druid/v2/sql/dart"; + + private static final Logger log = new Logger(DartSqlResource.class); + + private final DartControllerRegistry controllerRegistry; + private final SqlLifecycleManager sqlLifecycleManager; + private final DartSqlClients sqlClients; + private final AuthorizerMapper authorizerMapper; + private final DefaultQueryConfig dartQueryConfig; + + @Inject + public DartSqlResource( + final ObjectMapper jsonMapper, + final AuthorizerMapper authorizerMapper, + @Dart final SqlStatementFactory sqlStatementFactory, + final DartControllerRegistry controllerRegistry, + final SqlLifecycleManager sqlLifecycleManager, + final DartSqlClients sqlClients, + final ServerConfig serverConfig, + final ResponseContextConfig responseContextConfig, + @Self final DruidNode selfNode, + @Dart final DefaultQueryConfig dartQueryConfig + ) + { + super( + jsonMapper, + authorizerMapper, + sqlStatementFactory, + sqlLifecycleManager, + serverConfig, + responseContextConfig, + selfNode + ); + this.controllerRegistry = controllerRegistry; 
+ this.sqlLifecycleManager = sqlLifecycleManager; + this.sqlClients = sqlClients; + this.authorizerMapper = authorizerMapper; + this.dartQueryConfig = dartQueryConfig; + } + + /** + * API that allows callers to check if this resource is installed without actually issuing a query. If installed, + * this call returns 200 OK. If not installed, callers get 404 Not Found. + */ + @GET + @Path("/enabled") + @Produces(MediaType.APPLICATION_JSON) + public Response doGetEnabled(@Context final HttpServletRequest request) + { + AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(request); + return Response.ok(ImmutableMap.of("enabled", true)).build(); + } + + /** + * API to list all running queries. + * + * @param selfOnly if true, return queries running on this server. If false, return queries running on all servers. + * @param req http request + */ + @GET + @Produces(MediaType.APPLICATION_JSON) + public GetQueriesResponse doGetRunningQueries( + @QueryParam("selfOnly") final String selfOnly, + @Context final HttpServletRequest req + ) + { + final AuthenticationResult authenticationResult = AuthorizationUtils.authenticationResultFromRequest(req); + final Access stateReadAccess = AuthorizationUtils.authorizeAllResourceActions( + authenticationResult, + Collections.singletonList(new ResourceAction(Resource.STATE_RESOURCE, Action.READ)), + authorizerMapper + ); + + final List queries = + controllerRegistry.getAllHolders() + .stream() + .map(DartQueryInfo::fromControllerHolder) + .sorted(Comparator.comparing(DartQueryInfo::getStartTime)) + .collect(Collectors.toList()); + + // Add queries from all other servers, if "selfOnly" is not set. 
+ if (selfOnly == null) { + final List otherQueries = FutureUtils.getUnchecked( + Futures.successfulAsList( + Iterables.transform(sqlClients.getAllClients(), client -> client.getRunningQueries(true))), + true + ); + + for (final GetQueriesResponse response : otherQueries) { + if (response != null) { + queries.addAll(response.getQueries()); + } + } + } + + final GetQueriesResponse response; + if (stateReadAccess.isAllowed()) { + // User can READ STATE, so they can see all running queries, as well as authentication details. + response = new GetQueriesResponse(queries); + } else { + // User cannot READ STATE, so they can see only their own queries, without authentication details. + response = new GetQueriesResponse( + queries.stream() + .filter( + query -> + authenticationResult.getAuthenticatedBy() != null + && authenticationResult.getIdentity() != null + && Objects.equals(authenticationResult.getAuthenticatedBy(), query.getAuthenticator()) + && Objects.equals(authenticationResult.getIdentity(), query.getIdentity())) + .map(DartQueryInfo::withoutAuthenticationResult) + .collect(Collectors.toList()) + ); + } + + AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(req); + return response; + } + + /** + * API to issue a query. + */ + @POST + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + @Override + public Response doPost( + final SqlQuery sqlQuery, + @Context final HttpServletRequest req + ) + { + final Map context = new HashMap<>(sqlQuery.getContext()); + + // Default context keys from dartQueryConfig. + for (Map.Entry entry : dartQueryConfig.getContext().entrySet()) { + context.putIfAbsent(entry.getKey(), entry.getValue()); + } + + // Dart queryId must be globally unique; cannot use user-provided sqlQueryId or queryId. 
+ final String dartQueryId = UUID.randomUUID().toString(); + context.put(DartSqlEngine.CTX_DART_QUERY_ID, dartQueryId); + + return super.doPost(sqlQuery.withOverridenContext(context), req); + } + + /** + * API to cancel a query. + */ + @DELETE + @Path("{id}") + @Produces(MediaType.APPLICATION_JSON) + @Override + public Response cancelQuery( + @PathParam("id") String sqlQueryId, + @Context final HttpServletRequest req + ) + { + log.debug("Received cancel request for query[%s]", sqlQueryId); + + List cancelables = sqlLifecycleManager.getAll(sqlQueryId); + if (cancelables.isEmpty()) { + return Response.status(Response.Status.NOT_FOUND).build(); + } + + final Access access = authorizeCancellation(req, cancelables); + + if (access.isAllowed()) { + sqlLifecycleManager.removeAll(sqlQueryId, cancelables); + + // Don't call cancel() on the cancelables. That just cancels native queries, which is useless here. Instead, + // get the controller and stop it. + boolean found = false; + for (SqlLifecycleManager.Cancelable cancelable : cancelables) { + final HttpStatement stmt = (HttpStatement) cancelable; + final Object dartQueryId = stmt.context().get(DartSqlEngine.CTX_DART_QUERY_ID); + if (dartQueryId instanceof String) { + final ControllerHolder holder = controllerRegistry.get((String) dartQueryId); + if (holder != null) { + found = true; + holder.cancel(); + } + } else { + log.warn( + "%s[%s] for query[%s] is not a string, cannot cancel.", + DartSqlEngine.CTX_DART_QUERY_ID, + dartQueryId, + sqlQueryId + ); + } + } + + return Response.status(found ? 
Response.Status.ACCEPTED : Response.Status.NOT_FOUND).build(); + } else { + return Response.status(Response.Status.FORBIDDEN).build(); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponse.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponse.java new file mode 100644 index 000000000000..2d1f87f860c5 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponse.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.http; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import java.util.Objects; + +/** + * Class returned by {@link DartSqlResource#doGetRunningQueries}, the "list all queries" API. 
+ */ +public class GetQueriesResponse +{ + private final List queries; + + @JsonCreator + public GetQueriesResponse(@JsonProperty("queries") List queries) + { + this.queries = queries; + } + + @JsonProperty + public List getQueries() + { + return queries; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + GetQueriesResponse response = (GetQueriesResponse) o; + return Objects.equals(queries, response.queries); + } + + @Override + public int hashCode() + { + return Objects.hashCode(queries); + } + + @Override + public String toString() + { + return "GetQueriesResponse{" + + "queries=" + queries + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/ControllerMessage.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/ControllerMessage.java new file mode 100644 index 000000000000..454e23bbc9c1 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/ControllerMessage.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
/**
 * Messages sent from worker to controller by {@link DartControllerClient}.
 *
 * Messages are polymorphic in JSON: the "type" property carries the subtype name, and each known subtype is
 * registered by name in the {@link JsonSubTypes} list below.
 */
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
@JsonSubTypes({
    @JsonSubTypes.Type(value = PartialKeyStatistics.class, name = "partialKeyStatistics"),
    @JsonSubTypes.Type(value = DoneReadingInput.class, name = "doneReadingInput"),
    @JsonSubTypes.Type(value = ResultsComplete.class, name = "resultsComplete"),
    @JsonSubTypes.Type(value = WorkerError.class, name = "workerError"),
    @JsonSubTypes.Type(value = WorkerWarning.class, name = "workerWarning")
})
public interface ControllerMessage
{
  /**
   * Query ID, to identify the controller that is being contacted.
   */
  String getQueryId();

  /**
   * Handler for this message, which calls an appropriate method on {@link Controller}.
   *
   * @param controller controller that this message should be applied to
   */
  void handle(Controller controller);
}
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.messages; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Preconditions; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.kernel.StageId; + +import java.util.Objects; + +/** + * Message for {@link ControllerClient#postDoneReadingInput}. 
+ */ +public class DoneReadingInput implements ControllerMessage +{ + private final StageId stageId; + private final int workerNumber; + + @JsonCreator + public DoneReadingInput( + @JsonProperty("stage") final StageId stageId, + @JsonProperty("worker") final int workerNumber + ) + { + this.stageId = Preconditions.checkNotNull(stageId, "stageId"); + this.workerNumber = workerNumber; + } + + @Override + public String getQueryId() + { + return stageId.getQueryId(); + } + + @JsonProperty("stage") + public StageId getStageId() + { + return stageId; + } + + @JsonProperty("worker") + public int getWorkerNumber() + { + return workerNumber; + } + + @Override + public void handle(Controller controller) + { + controller.doneReadingInput(stageId.getStageNumber(), workerNumber); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DoneReadingInput that = (DoneReadingInput) o; + return workerNumber == that.workerNumber + && Objects.equals(stageId, that.stageId); + } + + @Override + public int hashCode() + { + return Objects.hash(stageId, workerNumber); + } + + @Override + public String toString() + { + return "DoneReadingInput{" + + "stageId=" + stageId + + ", workerNumber=" + workerNumber + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/PartialKeyStatistics.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/PartialKeyStatistics.java new file mode 100644 index 000000000000..1aa3bcb040e4 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/PartialKeyStatistics.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.messages; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Preconditions; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; + +import java.util.Objects; + +/** + * Message for {@link ControllerClient#postPartialKeyStatistics}. 
+ */ +public class PartialKeyStatistics implements ControllerMessage +{ + private final StageId stageId; + private final int workerNumber; + private final PartialKeyStatisticsInformation payload; + + @JsonCreator + public PartialKeyStatistics( + @JsonProperty("stage") final StageId stageId, + @JsonProperty("worker") final int workerNumber, + @JsonProperty("payload") final PartialKeyStatisticsInformation payload + ) + { + this.stageId = Preconditions.checkNotNull(stageId, "stageId"); + this.workerNumber = workerNumber; + this.payload = payload; + } + + @Override + public String getQueryId() + { + return stageId.getQueryId(); + } + + @JsonProperty("stage") + public StageId getStageId() + { + return stageId; + } + + @JsonProperty("worker") + public int getWorkerNumber() + { + return workerNumber; + } + + @JsonProperty + public PartialKeyStatisticsInformation getPayload() + { + return payload; + } + + + @Override + public void handle(Controller controller) + { + controller.updatePartialKeyStatisticsInformation( + stageId.getStageNumber(), + workerNumber, + payload + ); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PartialKeyStatistics that = (PartialKeyStatistics) o; + return workerNumber == that.workerNumber + && Objects.equals(stageId, that.stageId) + && Objects.equals(payload, that.payload); + } + + @Override + public int hashCode() + { + return Objects.hash(stageId, workerNumber, payload); + } + + @Override + public String toString() + { + return "PartialKeyStatistics{" + + "stageId=" + stageId + + ", workerNumber=" + workerNumber + + ", payload=" + payload + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/ResultsComplete.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/ResultsComplete.java new file mode 100644 index 
000000000000..58822a357265 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/ResultsComplete.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.messages; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Preconditions; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.kernel.StageId; + +import javax.annotation.Nullable; +import java.util.Objects; + +/** + * Message for {@link ControllerClient#postResultsComplete}. 
+ */ +public class ResultsComplete implements ControllerMessage +{ + private final StageId stageId; + private final int workerNumber; + + @Nullable + private final Object resultObject; + + @JsonCreator + public ResultsComplete( + @JsonProperty("stage") final StageId stageId, + @JsonProperty("worker") final int workerNumber, + @Nullable @JsonProperty("result") final Object resultObject + ) + { + this.stageId = Preconditions.checkNotNull(stageId, "stageId"); + this.workerNumber = workerNumber; + this.resultObject = resultObject; + } + + @Override + public String getQueryId() + { + return stageId.getQueryId(); + } + + @JsonProperty("stage") + public StageId getStageId() + { + return stageId; + } + + @JsonProperty("worker") + public int getWorkerNumber() + { + return workerNumber; + } + + @Nullable + @JsonProperty("result") + @JsonInclude(JsonInclude.Include.NON_NULL) + public Object getResultObject() + { + return resultObject; + } + + @Override + public void handle(Controller controller) + { + controller.resultsComplete(stageId.getQueryId(), stageId.getStageNumber(), workerNumber, resultObject); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ResultsComplete that = (ResultsComplete) o; + return workerNumber == that.workerNumber + && Objects.equals(stageId, that.stageId) + && Objects.equals(resultObject, that.resultObject); + } + + @Override + public int hashCode() + { + return Objects.hash(stageId, workerNumber, resultObject); + } + + @Override + public String toString() + { + return "ResultsComplete{" + + "stageId=" + stageId + + ", workerNumber=" + workerNumber + + ", resultObject=" + resultObject + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/WorkerError.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/WorkerError.java new 
file mode 100644 index 000000000000..b89cfb356a36 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/WorkerError.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.messages; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Preconditions; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.indexing.error.MSQErrorReport; + +import java.util.Objects; + +/** + * Message for {@link ControllerClient#postWorkerError}. 
+ */ +public class WorkerError implements ControllerMessage +{ + private final String queryId; + private final MSQErrorReport errorWrapper; + + @JsonCreator + public WorkerError( + @JsonProperty("queryId") String queryId, + @JsonProperty("error") MSQErrorReport errorWrapper + ) + { + this.queryId = Preconditions.checkNotNull(queryId, "queryId"); + this.errorWrapper = Preconditions.checkNotNull(errorWrapper, "error"); + } + + @Override + @JsonProperty + public String getQueryId() + { + return queryId; + } + + @JsonProperty("error") + public MSQErrorReport getErrorWrapper() + { + return errorWrapper; + } + + @Override + public void handle(Controller controller) + { + controller.workerError(errorWrapper); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + WorkerError that = (WorkerError) o; + return Objects.equals(queryId, that.queryId) + && Objects.equals(errorWrapper, that.errorWrapper); + } + + @Override + public int hashCode() + { + return Objects.hash(queryId, errorWrapper); + } + + @Override + public String toString() + { + return "WorkerError{" + + "queryId='" + queryId + '\'' + + ", errorWrapper=" + errorWrapper + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/WorkerWarning.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/WorkerWarning.java new file mode 100644 index 000000000000..aa2ff6643131 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/WorkerWarning.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.messages; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Preconditions; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.indexing.error.MSQErrorReport; + +import java.util.List; +import java.util.Objects; + +/** + * Message for {@link ControllerClient#postWorkerWarning}. 
+ */ +public class WorkerWarning implements ControllerMessage +{ + private final String queryId; + private final List errorWrappers; + + @JsonCreator + public WorkerWarning( + @JsonProperty("queryId") String queryId, + @JsonProperty("errors") List errorWrappers + ) + { + this.queryId = Preconditions.checkNotNull(queryId, "queryId"); + this.errorWrappers = Preconditions.checkNotNull(errorWrappers, "error"); + } + + @Override + @JsonProperty + public String getQueryId() + { + return queryId; + } + + @JsonProperty("errors") + public List getErrorWrappers() + { + return errorWrappers; + } + + @Override + public void handle(Controller controller) + { + controller.workerWarning(errorWrappers); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + WorkerWarning that = (WorkerWarning) o; + return Objects.equals(queryId, that.queryId) && Objects.equals(errorWrappers, that.errorWrappers); + } + + @Override + public int hashCode() + { + return Objects.hash(queryId, errorWrappers); + } + + @Override + public String toString() + { + return "WorkerWarning{" + + "queryId='" + queryId + '\'' + + ", errorWrappers=" + errorWrappers + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartQueryMaker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartQueryMaker.java new file mode 100644 index 000000000000..37ed936a1173 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartQueryMaker.java @@ -0,0 +1,484 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.sql; + +import com.google.common.base.Throwables; +import com.google.common.collect.Iterators; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.io.LimitedOutputStream; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.Either; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.Pair; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.guava.BaseSequence; +import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.java.util.common.jackson.JacksonUtils; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.dart.controller.ControllerHolder; +import org.apache.druid.msq.dart.controller.DartControllerContextFactory; +import org.apache.druid.msq.dart.controller.DartControllerRegistry; +import org.apache.druid.msq.dart.guice.DartControllerConfig; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.exec.ControllerContext; +import org.apache.druid.msq.exec.ControllerImpl; +import org.apache.druid.msq.exec.QueryListener; +import org.apache.druid.msq.exec.ResultsContext; +import org.apache.druid.msq.indexing.MSQSpec; +import org.apache.druid.msq.indexing.TaskReportQueryListener; +import 
org.apache.druid.msq.indexing.destination.TaskReportMSQDestination; +import org.apache.druid.msq.indexing.error.CanceledFault; +import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.indexing.report.MSQResultsReport; +import org.apache.druid.msq.indexing.report.MSQStatusReport; +import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; +import org.apache.druid.msq.sql.MSQTaskQueryMaker; +import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.server.QueryResponse; +import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.DruidQuery; +import org.apache.druid.sql.calcite.run.QueryMaker; +import org.apache.druid.sql.calcite.run.SqlResults; + +import javax.annotation.Nullable; +import java.io.ByteArrayOutputStream; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.stream.Collectors; + +/** + * SQL {@link QueryMaker}. Executes queries in two ways, depending on whether the user asked for a full report. + * + * When including a full report, the controller runs in the SQL planning thread (typically an HTTP thread) using + * the method {@link #runWithReport(ControllerHolder)}. The entire response is buffered in memory, up to + * {@link DartControllerConfig#getMaxQueryReportSize()}. + * + * When not including a full report, the controller runs in {@link #controllerExecutor} and results are streamed + * back to the user through {@link ResultIterator}. There is no limit to the size of the returned results. 
+ */ +public class DartQueryMaker implements QueryMaker +{ + private static final Logger log = new Logger(DartQueryMaker.class); + + private final List> fieldMapping; + private final DartControllerContextFactory controllerContextFactory; + private final PlannerContext plannerContext; + + /** + * Controller registry, used to register and remove controllers as they start and finish. + */ + private final DartControllerRegistry controllerRegistry; + + /** + * Controller config. + */ + private final DartControllerConfig controllerConfig; + + /** + * Executor for {@link #runWithoutReport(ControllerHolder)}. Number of thread is equal to + * {@link DartControllerConfig#getConcurrentQueries()}, which limits the number of concurrent controllers. + */ + private final ExecutorService controllerExecutor; + + public DartQueryMaker( + List> fieldMapping, + DartControllerContextFactory controllerContextFactory, + PlannerContext plannerContext, + DartControllerRegistry controllerRegistry, + DartControllerConfig controllerConfig, + ExecutorService controllerExecutor + ) + { + this.fieldMapping = fieldMapping; + this.controllerContextFactory = controllerContextFactory; + this.plannerContext = plannerContext; + this.controllerRegistry = controllerRegistry; + this.controllerConfig = controllerConfig; + this.controllerExecutor = controllerExecutor; + } + + @Override + public QueryResponse runQuery(DruidQuery druidQuery) + { + final MSQSpec querySpec = MSQTaskQueryMaker.makeQuerySpec( + null, + druidQuery, + fieldMapping, + plannerContext, + null // Only used for DML, which this isn't + ); + final List> types = + MSQTaskQueryMaker.getTypes(druidQuery, fieldMapping, plannerContext); + + final String dartQueryId = druidQuery.getQuery().context().getString(DartSqlEngine.CTX_DART_QUERY_ID); + final ControllerContext controllerContext = controllerContextFactory.newContext(dartQueryId); + final ControllerImpl controller = new ControllerImpl( + dartQueryId, + querySpec, + new ResultsContext( + 
types.stream().map(p -> p.lhs).collect(Collectors.toList()), + SqlResults.Context.fromPlannerContext(plannerContext) + ), + controllerContext + ); + + final ControllerHolder controllerHolder = new ControllerHolder( + controller, + controllerContext, + plannerContext.getSqlQueryId(), + plannerContext.getSql(), + plannerContext.getAuthenticationResult(), + DateTimes.nowUtc() + ); + + final boolean fullReport = druidQuery.getQuery().context().getBoolean( + DartSqlEngine.CTX_FULL_REPORT, + DartSqlEngine.CTX_FULL_REPORT_DEFAULT + ); + + // Register controller before submitting anything to controllerExeuctor, so it shows up in + // "active controllers" lists. + controllerRegistry.register(controllerHolder); + + try { + // runWithReport, runWithoutReport are responsible for calling controllerRegistry.deregister(controllerHolder) + // when their work is done. + final Sequence results = + fullReport ? runWithReport(controllerHolder) : runWithoutReport(controllerHolder); + return QueryResponse.withEmptyContext(results); + } + catch (Throwable e) { + // Error while calling runWithReport or runWithoutReport. Deregister controller immediately. + controllerRegistry.deregister(controllerHolder); + throw e; + } + } + + /** + * Run a query and return the full report, buffered in memory up to + * {@link DartControllerConfig#getMaxQueryReportSize()}. + * + * Arranges for {@link DartControllerRegistry#deregister(ControllerHolder)} to be called upon completion (either + * success or failure). + */ + private Sequence runWithReport(final ControllerHolder controllerHolder) + { + final Future> reportFuture; + + // Run in controllerExecutor. Control doesn't really *need* to be moved to another thread, but we have to + // use the controllerExecutor anyway, to ensure we respect the concurrentQueries configuration. 
+ reportFuture = controllerExecutor.submit(() -> { + final String threadName = Thread.currentThread().getName(); + + try { + Thread.currentThread().setName(nameThread(plannerContext)); + + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final TaskReportQueryListener queryListener = new TaskReportQueryListener( + TaskReportMSQDestination.instance(), + () -> new LimitedOutputStream( + baos, + controllerConfig.getMaxQueryReportSize(), + limit -> StringUtils.format( + "maxQueryReportSize[%,d] exceeded. " + + "Try limiting the result set for your query, or run it with %s[false]", + limit, + DartSqlEngine.CTX_FULL_REPORT + ) + ), + plannerContext.getJsonMapper(), + controllerHolder.getController().queryId(), + Collections.emptyMap() + ); + + if (controllerHolder.run(queryListener)) { + return plannerContext.getJsonMapper() + .readValue(baos.toByteArray(), JacksonUtils.TYPE_REFERENCE_MAP_STRING_OBJECT); + } else { + // Controller was canceled before it ran. + throw MSQErrorReport + .fromFault(controllerHolder.getController().queryId(), null, null, CanceledFault.INSTANCE) + .toDruidException(); + } + } + finally { + controllerRegistry.deregister(controllerHolder); + Thread.currentThread().setName(threadName); + } + }); + + // Return a sequence that reads one row (the report) from reportFuture. + return new BaseSequence<>( + new BaseSequence.IteratorMaker>() + { + @Override + public Iterator make() + { + try { + return Iterators.singletonIterator(new Object[]{reportFuture.get()}); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + catch (ExecutionException e) { + // Unwrap ExecutionExceptions, so errors such as DruidException are serialized properly. + Throwables.throwIfUnchecked(e.getCause()); + throw new RuntimeException(e.getCause()); + } + } + + @Override + public void cleanup(Iterator iterFromMake) + { + // Nothing to do. 
+ } + } + ); + } + + /** + * Run a query and return the results only, streamed back using {@link ResultIteratorMaker}. + * + * Arranges for {@link DartControllerRegistry#deregister(ControllerHolder)} to be called upon completion (either + * success or failure). + */ + private Sequence runWithoutReport(final ControllerHolder controllerHolder) + { + return new BaseSequence<>(new ResultIteratorMaker(controllerHolder)); + } + + /** + * Generate a name for a thread in {@link #controllerExecutor}. + */ + private String nameThread(final PlannerContext plannerContext) + { + return StringUtils.format( + "%s-sqlQueryId[%s]-queryId[%s]", + Thread.currentThread().getName(), + plannerContext.getSqlQueryId(), + plannerContext.queryContext().get(DartSqlEngine.CTX_DART_QUERY_ID) + ); + } + + /** + * Helper for {@link #runWithoutReport(ControllerHolder)}. + */ + class ResultIteratorMaker implements BaseSequence.IteratorMaker + { + private final ControllerHolder controllerHolder; + private final ResultIterator resultIterator = new ResultIterator(); + private boolean made; + + public ResultIteratorMaker(ControllerHolder holder) + { + this.controllerHolder = holder; + submitController(); + } + + /** + * Submits the controller to the executor in the constructor, and remove it from the registry when the + * future resolves. + */ + private void submitController() + { + controllerExecutor.submit(() -> { + final Controller controller = controllerHolder.getController(); + final String threadName = Thread.currentThread().getName(); + + try { + Thread.currentThread().setName(nameThread(plannerContext)); + + if (!controllerHolder.run(resultIterator)) { + // Controller was canceled before it ran. Push a cancellation error to the resultIterator, so the sequence + // returned by "runWithoutReport" can resolve. 
+ resultIterator.pushError( + MSQErrorReport.fromFault(controllerHolder.getController().queryId(), null, null, CanceledFault.INSTANCE) + .toDruidException() + ); + } + } + catch (Exception e) { + log.warn( + e, + "Controller failed for sqlQueryId[%s], controllerHost[%s]", + plannerContext.getSqlQueryId(), + controller.queryId() + ); + } + finally { + controllerRegistry.deregister(controllerHolder); + Thread.currentThread().setName(threadName); + } + }); + } + + @Override + public ResultIterator make() + { + if (made) { + throw new ISE("Cannot call make() more than once"); + } + + made = true; + return resultIterator; + } + + @Override + public void cleanup(final ResultIterator iterFromMake) + { + if (!iterFromMake.complete) { + controllerHolder.cancel(); + } + } + } + + /** + * Helper for {@link ResultIteratorMaker}, which is in turn a helper for {@link #runWithoutReport(ControllerHolder)}. + */ + static class ResultIterator implements Iterator, QueryListener + { + /** + * Number of rows to buffer from {@link #onResultRow(Object[])}. + */ + private static final int BUFFER_SIZE = 128; + + /** + * Empty optional signifies results are complete. + */ + private final BlockingQueue> rowBuffer = new ArrayBlockingQueue<>(BUFFER_SIZE); + + /** + * Only accessed by {@link Iterator} methods, so no need to be thread-safe. 
+ */ + @Nullable + private Either current; + + private volatile boolean complete; + + @Override + public boolean hasNext() + { + return populateAndReturnCurrent().isPresent(); + } + + @Override + public Object[] next() + { + final Object[] retVal = populateAndReturnCurrent().orElseThrow(NoSuchElementException::new); + current = null; + return retVal; + } + + private Optional populateAndReturnCurrent() + { + if (current == null) { + try { + current = rowBuffer.take(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + + if (current.isValue()) { + return Optional.ofNullable(current.valueOrThrow()); + } else { + // Don't use valueOrThrow to throw errors; here we *don't* want the wrapping in RuntimeException + // that Either.valueOrThrow does. We want the original DruidException to be propagated to the user, if + // there is one. + final Throwable e = current.error(); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); + } + } + + @Override + public boolean readResults() + { + return !complete; + } + + @Override + public void onResultsStart( + final List signature, + @Nullable final List sqlTypeNames + ) + { + // Nothing to do. + } + + @Override + public boolean onResultRow(Object[] row) + { + try { + rowBuffer.put(Either.value(row)); + return !complete; + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + + @Override + public void onResultsComplete() + { + // Nothing to do. 
+ } + + @Override + public void onQueryComplete(MSQTaskReportPayload report) + { + try { + complete = true; + + final MSQStatusReport statusReport = report.getStatus(); + + if (statusReport.getStatus().isSuccess()) { + rowBuffer.put(Either.value(null)); + } else { + pushError(statusReport.getErrorReport().toDruidException()); + } + } + catch (InterruptedException e) { + // Can't fix this by pushing an error, because the rowBuffer isn't accepting new entries. + // Give up, allow controllerHolder.run() to fail. + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + + public void pushError(final Throwable e) throws InterruptedException + { + rowBuffer.put(Either.error(e)); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClient.java new file mode 100644 index 000000000000..447da229d05e --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClient.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.controller.sql; + +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.msq.dart.controller.http.DartSqlResource; +import org.apache.druid.msq.dart.controller.http.GetQueriesResponse; + +import javax.servlet.http.HttpServletRequest; + +/** + * Client for the {@link DartSqlResource} resource. + */ +public interface DartSqlClient +{ + /** + * Get information about all currently-running queries on this server. + * + * @param selfOnly true if only queries from this server should be returned; false if queries from all servers + * should be returned + * + * @see DartSqlResource#doGetRunningQueries(String, HttpServletRequest) the server side + */ + ListenableFuture getRunningQueries(boolean selfOnly); +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientFactory.java new file mode 100644 index 000000000000..879cabe6945f --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientFactory.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.sql; + +import org.apache.druid.server.DruidNode; + +/** + * Generates {@link DartSqlClient} given a target Broker node. + */ +public interface DartSqlClientFactory +{ + DartSqlClient makeClient(DruidNode node); +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientFactoryImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientFactoryImpl.java new file mode 100644 index 000000000000..c2355a43e31a --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientFactoryImpl.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.controller.sql; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.inject.Inject; +import org.apache.druid.guice.annotations.EscalatedGlobal; +import org.apache.druid.guice.annotations.Json; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.msq.dart.controller.http.DartSqlResource; +import org.apache.druid.rpc.FixedServiceLocator; +import org.apache.druid.rpc.ServiceClient; +import org.apache.druid.rpc.ServiceClientFactory; +import org.apache.druid.rpc.ServiceLocation; +import org.apache.druid.rpc.StandardRetryPolicy; +import org.apache.druid.server.DruidNode; + +/** + * Production implementation of {@link DartSqlClientFactory}. + */ +public class DartSqlClientFactoryImpl implements DartSqlClientFactory +{ + private final ServiceClientFactory clientFactory; + private final ObjectMapper jsonMapper; + + @Inject + public DartSqlClientFactoryImpl( + @EscalatedGlobal final ServiceClientFactory clientFactory, + @Json final ObjectMapper jsonMapper + ) + { + this.clientFactory = clientFactory; + this.jsonMapper = jsonMapper; + } + + @Override + public DartSqlClient makeClient(DruidNode node) + { + final ServiceClient client = clientFactory.makeClient( + StringUtils.format("%s[dart-sql]", node.getHostAndPortToUse()), + new FixedServiceLocator(ServiceLocation.fromDruidNode(node).withBasePath(DartSqlResource.PATH)), + StandardRetryPolicy.noRetries() + ); + + return new DartSqlClientImpl(client, jsonMapper); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImpl.java new file mode 100644 index 000000000000..aebf7e4b90fa --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImpl.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.sql; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.java.util.common.jackson.JacksonUtils; +import org.apache.druid.java.util.http.client.response.BytesFullResponseHandler; +import org.apache.druid.msq.dart.controller.http.GetQueriesResponse; +import org.apache.druid.rpc.RequestBuilder; +import org.apache.druid.rpc.ServiceClient; +import org.jboss.netty.handler.codec.http.HttpMethod; + +/** + * Production implementation of {@link DartSqlClient}. + */ +public class DartSqlClientImpl implements DartSqlClient +{ + private final ServiceClient client; + private final ObjectMapper jsonMapper; + + public DartSqlClientImpl(final ServiceClient client, final ObjectMapper jsonMapper) + { + this.client = client; + this.jsonMapper = jsonMapper; + } + + @Override + public ListenableFuture getRunningQueries(final boolean selfOnly) + { + return FutureUtils.transform( + client.asyncRequest( + new RequestBuilder(HttpMethod.GET, selfOnly ? 
"/?selfOnly" : "/"), + new BytesFullResponseHandler() + ), + holder -> JacksonUtils.readValue(jsonMapper, holder.getContent(), GetQueriesResponse.class) + ); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClients.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClients.java new file mode 100644 index 000000000000..733f69ee4bf9 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClients.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.controller.sql; + +import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import com.google.inject.Inject; +import org.apache.druid.discovery.DiscoveryDruidNode; +import org.apache.druid.discovery.DruidNodeDiscovery; +import org.apache.druid.discovery.DruidNodeDiscoveryProvider; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.guice.ManageLifecycle; +import org.apache.druid.guice.annotations.Self; +import org.apache.druid.java.util.common.lifecycle.LifecycleStart; +import org.apache.druid.java.util.common.lifecycle.LifecycleStop; +import org.apache.druid.msq.dart.controller.http.DartSqlResource; +import org.apache.druid.server.DruidNode; + +import javax.servlet.http.HttpServletRequest; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Keeps {@link DartSqlClient} for all servers except ourselves. Currently the purpose of this is to power + * the "get all queries" API at {@link DartSqlResource#doGetRunningQueries(String, HttpServletRequest)}. 
+ */ +@ManageLifecycle +public class DartSqlClients implements DruidNodeDiscovery.Listener +{ + @GuardedBy("clients") + private final Map clients = new HashMap<>(); + private final DruidNode selfNode; + private final DruidNodeDiscoveryProvider discoveryProvider; + private final DartSqlClientFactory clientFactory; + + private volatile DruidNodeDiscovery discovery; + + @Inject + public DartSqlClients( + @Self DruidNode selfNode, + DruidNodeDiscoveryProvider discoveryProvider, + DartSqlClientFactory clientFactory + ) + { + this.selfNode = selfNode; + this.discoveryProvider = discoveryProvider; + this.clientFactory = clientFactory; + } + + @LifecycleStart + public void start() + { + discovery = discoveryProvider.getForNodeRole(NodeRole.BROKER); + discovery.registerListener(this); + } + + public List getAllClients() + { + synchronized (clients) { + return ImmutableList.copyOf(clients.values()); + } + } + + @Override + public void nodesAdded(final Collection nodes) + { + synchronized (clients) { + for (final DiscoveryDruidNode node : nodes) { + final DruidNode druidNode = node.getDruidNode(); + if (!selfNode.equals(druidNode)) { + clients.computeIfAbsent(druidNode, clientFactory::makeClient); + } + } + } + } + + @Override + public void nodesRemoved(final Collection nodes) + { + synchronized (clients) { + for (final DiscoveryDruidNode node : nodes) { + clients.remove(node.getDruidNode()); + } + } + } + + @LifecycleStop + public void stop() + { + if (discovery != null) { + discovery.removeListener(this); + discovery = null; + } + + synchronized (clients) { + clients.clear(); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlEngine.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlEngine.java new file mode 100644 index 000000000000..28587e0e791a --- /dev/null +++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlEngine.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.sql; + +import com.google.common.collect.ImmutableList; +import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.msq.dart.controller.DartControllerContextFactory; +import org.apache.druid.msq.dart.controller.DartControllerRegistry; +import org.apache.druid.msq.dart.controller.http.DartSqlResource; +import org.apache.druid.msq.dart.guice.DartControllerConfig; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.sql.MSQTaskSqlEngine; +import org.apache.druid.query.BaseQuery; +import org.apache.druid.query.QueryContext; +import org.apache.druid.query.QueryContexts; +import org.apache.druid.sql.SqlLifecycleManager; +import org.apache.druid.sql.calcite.planner.Calcites; +import 
org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.run.EngineFeature; +import org.apache.druid.sql.calcite.run.QueryMaker; +import org.apache.druid.sql.calcite.run.SqlEngine; +import org.apache.druid.sql.calcite.run.SqlEngines; +import org.apache.druid.sql.destination.IngestDestination; + +import java.util.Map; +import java.util.concurrent.ExecutorService; + +public class DartSqlEngine implements SqlEngine +{ + private static final String NAME = "msq-dart"; + + /** + * Dart queryId must be globally unique, so we cannot use the user-provided {@link QueryContexts#CTX_SQL_QUERY_ID} + * or {@link BaseQuery#QUERY_ID}. Instead we generate a UUID in {@link DartSqlResource#doPost}, overriding whatever + * the user may have provided. This becomes the {@link Controller#queryId()}. + * + * The user-provided {@link QueryContexts#CTX_SQL_QUERY_ID} is still registered with the {@link SqlLifecycleManager} + * for purposes of query cancellation. + * + * The user-provided {@link BaseQuery#QUERY_ID} is ignored. 
+ */ + public static final String CTX_DART_QUERY_ID = "dartQueryId"; + public static final String CTX_FULL_REPORT = "fullReport"; + public static final boolean CTX_FULL_REPORT_DEFAULT = false; + + private final DartControllerContextFactory controllerContextFactory; + private final DartControllerRegistry controllerRegistry; + private final DartControllerConfig controllerConfig; + private final ExecutorService controllerExecutor; + + public DartSqlEngine( + DartControllerContextFactory controllerContextFactory, + DartControllerRegistry controllerRegistry, + DartControllerConfig controllerConfig, + ExecutorService controllerExecutor + ) + { + this.controllerContextFactory = controllerContextFactory; + this.controllerRegistry = controllerRegistry; + this.controllerConfig = controllerConfig; + this.controllerExecutor = controllerExecutor; + } + + @Override + public String name() + { + return NAME; + } + + @Override + public boolean featureAvailable(EngineFeature feature) + { + switch (feature) { + case CAN_SELECT: + case SCAN_ORDER_BY_NON_TIME: + case SCAN_NEEDS_SIGNATURE: + case WINDOW_FUNCTIONS: + case WINDOW_LEAF_OPERATOR: + case UNNEST: + return true; + + case CAN_INSERT: + case CAN_REPLACE: + case READ_EXTERNAL_DATA: + case ALLOW_BINDABLE_PLAN: + case ALLOW_BROADCAST_RIGHTY_JOIN: + case ALLOW_TOP_LEVEL_UNION_ALL: + case TIMESERIES_QUERY: + case TOPN_QUERY: + case TIME_BOUNDARY_QUERY: + case GROUPING_SETS: + case GROUPBY_IMPLICITLY_SORTS: + return false; + + default: + throw new IAE("Unrecognized feature: %s", feature); + } + } + + @Override + public void validateContext(Map queryContext) + { + SqlEngines.validateNoSpecialContextKeys(queryContext, MSQTaskSqlEngine.SYSTEM_CONTEXT_PARAMETERS); + } + + @Override + public RelDataType resultTypeForSelect( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) + { + if (QueryContext.of(queryContext).getBoolean(CTX_FULL_REPORT, CTX_FULL_REPORT_DEFAULT)) { + return 
typeFactory.createStructType( + ImmutableList.of( + Calcites.createSqlType(typeFactory, SqlTypeName.VARCHAR) + ), + ImmutableList.of(CTX_FULL_REPORT) + ); + } else { + return validatedRowType; + } + } + + @Override + public RelDataType resultTypeForInsert( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) + { + // Defensive, because we expect this method will not be called without the CAN_INSERT and CAN_REPLACE features. + throw DruidException.defensive("Cannot execute DML commands with engine[%s]", name()); + } + + @Override + public QueryMaker buildQueryMakerForSelect(RelRoot relRoot, PlannerContext plannerContext) + { + return new DartQueryMaker( + relRoot.fields, + controllerContextFactory, + plannerContext, + controllerRegistry, + controllerConfig, + controllerExecutor + ); + } + + @Override + public QueryMaker buildQueryMakerForInsert( + IngestDestination destination, + RelRoot relRoot, + PlannerContext plannerContext + ) + { + // Defensive, because we expect this method will not be called without the CAN_INSERT and CAN_REPLACE features. + throw DruidException.defensive("Cannot execute DML commands with engine[%s]", name()); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerConfig.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerConfig.java new file mode 100644 index 000000000000..25094f44a79a --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerConfig.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.guice; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Runtime configuration for controllers (which run on Brokers). + */ +public class DartControllerConfig +{ + @JsonProperty("concurrentQueries") + private int concurrentQueries = 1; + + @JsonProperty("maxQueryReportSize") + private int maxQueryReportSize = 100_000_000; + + public int getConcurrentQueries() + { + return concurrentQueries; + } + + public int getMaxQueryReportSize() + { + return maxQueryReportSize; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerMemoryManagementModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerMemoryManagementModule.java new file mode 100644 index 000000000000..95f110ec88be --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerMemoryManagementModule.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.guice; + +import com.google.inject.Binder; +import com.google.inject.Provides; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.guice.annotations.LoadScope; +import org.apache.druid.initialization.DruidModule; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.MemoryIntrospectorImpl; +import org.apache.druid.query.DruidProcessingConfig; +import org.apache.druid.utils.JvmUtils; + +/** + * Memory management module for Brokers. + */ +@LoadScope(roles = {NodeRole.BROKER_JSON_NAME}) +public class DartControllerMemoryManagementModule implements DruidModule +{ + /** + * Allocate up to 15% of memory for the MSQ framework. This accounts for additional overhead due to native queries, + * the segment timeline, and lookups (which aren't accounted for by our {@link MemoryIntrospector}). + */ + public static final double USABLE_MEMORY_FRACTION = 0.15; + + @Override + public void configure(Binder binder) + { + // Nothing to do. 
+ } + + @Provides + public MemoryIntrospector createMemoryIntrospector( + final DruidProcessingConfig processingConfig, + final DartControllerConfig controllerConfig + ) + { + return new MemoryIntrospectorImpl( + JvmUtils.getRuntimeInfo().getMaxHeapSizeBytes(), + USABLE_MEMORY_FRACTION, + controllerConfig.getConcurrentQueries(), + processingConfig.getNumThreads(), + null + ); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerModule.java new file mode 100644 index 000000000000..8a4b73bc9b0f --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerModule.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.guice; + +import com.google.inject.Binder; +import com.google.inject.Inject; +import com.google.inject.Module; +import com.google.inject.Provides; +import org.apache.druid.discovery.DruidNodeDiscoveryProvider; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.guice.Jerseys; +import org.apache.druid.guice.JsonConfigProvider; +import org.apache.druid.guice.LazySingleton; +import org.apache.druid.guice.LifecycleModule; +import org.apache.druid.guice.ManageLifecycle; +import org.apache.druid.guice.annotations.LoadScope; +import org.apache.druid.initialization.DruidModule; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.msq.dart.Dart; +import org.apache.druid.msq.dart.DartResourcePermissionMapper; +import org.apache.druid.msq.dart.controller.ControllerMessageListener; +import org.apache.druid.msq.dart.controller.DartControllerContextFactory; +import org.apache.druid.msq.dart.controller.DartControllerContextFactoryImpl; +import org.apache.druid.msq.dart.controller.DartControllerRegistry; +import org.apache.druid.msq.dart.controller.DartMessageRelayFactoryImpl; +import org.apache.druid.msq.dart.controller.DartMessageRelays; +import org.apache.druid.msq.dart.controller.http.DartSqlResource; +import org.apache.druid.msq.dart.controller.sql.DartSqlClientFactory; +import org.apache.druid.msq.dart.controller.sql.DartSqlClientFactoryImpl; +import org.apache.druid.msq.dart.controller.sql.DartSqlClients; +import org.apache.druid.msq.dart.controller.sql.DartSqlEngine; +import org.apache.druid.msq.rpc.ResourcePermissionMapper; +import org.apache.druid.query.DefaultQueryConfig; +import org.apache.druid.sql.SqlStatementFactory; +import org.apache.druid.sql.SqlToolbox; + +import java.util.Properties; + +/** + * Primary module for Brokers. Checks {@link DartModules#isDartEnabled(Properties)} before installing itself. 
+ */ +@LoadScope(roles = NodeRole.BROKER_JSON_NAME) +public class DartControllerModule implements DruidModule +{ + @Inject + private Properties properties; + + @Override + public void configure(Binder binder) + { + if (DartModules.isDartEnabled(properties)) { + binder.install(new ActualModule()); + } + } + + public static class ActualModule implements Module + { + @Override + public void configure(Binder binder) + { + JsonConfigProvider.bind(binder, DartModules.DART_PROPERTY_BASE + ".controller", DartControllerConfig.class); + JsonConfigProvider.bind(binder, DartModules.DART_PROPERTY_BASE + ".query", DefaultQueryConfig.class, Dart.class); + + Jerseys.addResource(binder, DartSqlResource.class); + + LifecycleModule.register(binder, DartSqlClients.class); + LifecycleModule.register(binder, DartMessageRelays.class); + + binder.bind(ControllerMessageListener.class).in(LazySingleton.class); + binder.bind(DartControllerRegistry.class).in(LazySingleton.class); + binder.bind(DartMessageRelayFactoryImpl.class).in(LazySingleton.class); + binder.bind(DartControllerContextFactory.class) + .to(DartControllerContextFactoryImpl.class) + .in(LazySingleton.class); + binder.bind(DartSqlClientFactory.class) + .to(DartSqlClientFactoryImpl.class) + .in(LazySingleton.class); + binder.bind(ResourcePermissionMapper.class) + .annotatedWith(Dart.class) + .to(DartResourcePermissionMapper.class); + } + + @Provides + @Dart + @LazySingleton + public SqlStatementFactory makeSqlStatementFactory(final DartSqlEngine engine, final SqlToolbox toolbox) + { + return new SqlStatementFactory(toolbox.withEngine(engine)); + } + + @Provides + @ManageLifecycle + public DartMessageRelays makeMessageRelays( + final DruidNodeDiscoveryProvider discoveryProvider, + final DartMessageRelayFactoryImpl messageRelayFactory + ) + { + return new DartMessageRelays(discoveryProvider, messageRelayFactory); + } + + @Provides + @LazySingleton + public DartSqlEngine makeSqlEngine( + DartControllerContextFactory 
controllerContextFactory, + DartControllerRegistry controllerRegistry, + DartControllerConfig controllerConfig + ) + { + return new DartSqlEngine( + controllerContextFactory, + controllerRegistry, + controllerConfig, + Execs.multiThreaded(controllerConfig.getConcurrentQueries(), "dart-controller-%s") + ); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartModules.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartModules.java new file mode 100644 index 000000000000..a8e1a1b65e69 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartModules.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.guice; + +import java.util.Properties; + +/** + * Common utilities for Dart Guice modules. 
/**
 * Shared constants and helpers used by the Dart Guice modules.
 */
public class DartModules
{
  /**
   * Prefix under which all Dart runtime properties live.
   */
  public static final String DART_PROPERTY_BASE = "druid.msq.dart";

  /**
   * Name of the property that switches Dart on or off.
   */
  public static final String DART_ENABLED_PROPERTY = DART_PROPERTY_BASE + ".enabled";

  /**
   * Value assumed for {@link #DART_ENABLED_PROPERTY} when it is not set.
   */
  public static final String DART_ENABLED_DEFAULT = Boolean.toString(false);

  /**
   * Returns whether Dart is enabled according to the given properties.
   *
   * @param properties runtime properties to consult
   *
   * @return true only if {@link #DART_ENABLED_PROPERTY} is set to "true" (case-insensitive); false if it is
   * missing or has any other value
   */
  public static boolean isDartEnabled(final Properties properties)
  {
    final String configured = properties.getProperty(DART_ENABLED_PROPERTY, DART_ENABLED_DEFAULT);
    return Boolean.parseBoolean(configured);
  }
}
+ */ +public class DartWorkerConfig +{ + /** + * By default, allocate up to 35% of memory for the MSQ framework. This accounts for additional overhead due to + * native queries, and lookups (which aren't accounted for by the Dart {@link MemoryIntrospector}). + */ + private static final double DEFAULT_HEAP_FRACTION = 0.35; + + public static final int AUTO = -1; + + @JsonProperty("concurrentQueries") + private int concurrentQueries = AUTO; + + @JsonProperty("heapFraction") + private double heapFraction = DEFAULT_HEAP_FRACTION; + + public int getConcurrentQueries() + { + return concurrentQueries; + } + + public double getHeapFraction() + { + return heapFraction; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerMemoryManagementModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerMemoryManagementModule.java new file mode 100644 index 000000000000..9f51a65152a1 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerMemoryManagementModule.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.guice; + +import com.google.inject.Binder; +import com.google.inject.Provides; +import org.apache.druid.collections.BlockingPool; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.error.DruidException; +import org.apache.druid.guice.LazySingleton; +import org.apache.druid.guice.annotations.LoadScope; +import org.apache.druid.guice.annotations.Merging; +import org.apache.druid.initialization.DruidModule; +import org.apache.druid.msq.dart.Dart; +import org.apache.druid.msq.dart.worker.DartProcessingBuffersProvider; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.MemoryIntrospectorImpl; +import org.apache.druid.msq.exec.ProcessingBuffersProvider; +import org.apache.druid.query.DruidProcessingConfig; +import org.apache.druid.utils.JvmUtils; + +import java.nio.ByteBuffer; + +/** + * Memory management module for Historicals. + */ +@LoadScope(roles = {NodeRole.HISTORICAL_JSON_NAME}) +public class DartWorkerMemoryManagementModule implements DruidModule +{ + @Override + public void configure(Binder binder) + { + // Nothing to do. 
+ } + + @Provides + public MemoryIntrospector createMemoryIntrospector( + final DartWorkerConfig workerConfig, + final DruidProcessingConfig druidProcessingConfig + ) + { + return new MemoryIntrospectorImpl( + JvmUtils.getRuntimeInfo().getMaxHeapSizeBytes(), + workerConfig.getHeapFraction(), + computeConcurrentQueries(workerConfig, druidProcessingConfig), + druidProcessingConfig.getNumThreads(), + null + ); + } + + @Provides + @Dart + @LazySingleton + public ProcessingBuffersProvider createProcessingBuffersProvider( + @Merging final BlockingPool mergeBufferPool, + final DruidProcessingConfig processingConfig + ) + { + return new DartProcessingBuffersProvider(mergeBufferPool, processingConfig.getNumThreads()); + } + + private static int computeConcurrentQueries( + final DartWorkerConfig workerConfig, + final DruidProcessingConfig processingConfig + ) + { + if (workerConfig.getConcurrentQueries() == DartWorkerConfig.AUTO) { + return processingConfig.getNumMergeBuffers(); + } else if (workerConfig.getConcurrentQueries() < 0) { + throw DruidException.forPersona(DruidException.Persona.OPERATOR) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + .build("concurrentQueries[%s] must be positive or -1", workerConfig.getConcurrentQueries()); + } else if (workerConfig.getConcurrentQueries() > processingConfig.getNumMergeBuffers()) { + throw DruidException.forPersona(DruidException.Persona.OPERATOR) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + .build( + "concurrentQueries[%s] must be less than numMergeBuffers[%s]", + workerConfig.getConcurrentQueries(), + processingConfig.getNumMergeBuffers() + ); + } else { + return workerConfig.getConcurrentQueries(); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerModule.java new file mode 100644 index 000000000000..15bc0e652994 --- /dev/null +++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerModule.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.guice; + +import com.google.inject.Binder; +import com.google.inject.Inject; +import com.google.inject.Key; +import com.google.inject.Module; +import com.google.inject.Provides; +import org.apache.druid.discovery.DruidNodeDiscoveryProvider; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.guice.Jerseys; +import org.apache.druid.guice.JsonConfigProvider; +import org.apache.druid.guice.LazySingleton; +import org.apache.druid.guice.LifecycleModule; +import org.apache.druid.guice.ManageLifecycle; +import org.apache.druid.guice.ManageLifecycleAnnouncements; +import org.apache.druid.guice.annotations.LoadScope; +import org.apache.druid.guice.annotations.Self; +import org.apache.druid.initialization.DruidModule; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.messages.server.MessageRelayMonitor; +import org.apache.druid.messages.server.MessageRelayResource; +import 
org.apache.druid.messages.server.Outbox; +import org.apache.druid.messages.server.OutboxImpl; +import org.apache.druid.msq.dart.Dart; +import org.apache.druid.msq.dart.DartResourcePermissionMapper; +import org.apache.druid.msq.dart.controller.messages.ControllerMessage; +import org.apache.druid.msq.dart.worker.DartDataSegmentProvider; +import org.apache.druid.msq.dart.worker.DartWorkerFactory; +import org.apache.druid.msq.dart.worker.DartWorkerFactoryImpl; +import org.apache.druid.msq.dart.worker.DartWorkerRunner; +import org.apache.druid.msq.dart.worker.http.DartWorkerResource; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.querykit.DataSegmentProvider; +import org.apache.druid.msq.rpc.ResourcePermissionMapper; +import org.apache.druid.query.DruidProcessingConfig; +import org.apache.druid.server.DruidNode; +import org.apache.druid.server.security.AuthorizerMapper; + +import java.io.File; +import java.util.Properties; +import java.util.concurrent.ExecutorService; + +/** + * Primary module for workers. Checks {@link DartModules#isDartEnabled(Properties)} before installing itself. 
+ */ +@LoadScope(roles = NodeRole.HISTORICAL_JSON_NAME) +public class DartWorkerModule implements DruidModule +{ + @Inject + private Properties properties; + + @Override + public void configure(Binder binder) + { + if (DartModules.isDartEnabled(properties)) { + binder.install(new ActualModule()); + } + } + + public static class ActualModule implements Module + { + @Override + public void configure(Binder binder) + { + JsonConfigProvider.bind(binder, DartModules.DART_PROPERTY_BASE + ".worker", DartWorkerConfig.class); + Jerseys.addResource(binder, DartWorkerResource.class); + LifecycleModule.register(binder, DartWorkerRunner.class); + LifecycleModule.registerKey(binder, Key.get(MessageRelayMonitor.class, Dart.class)); + + binder.bind(DartWorkerFactory.class) + .to(DartWorkerFactoryImpl.class) + .in(LazySingleton.class); + + binder.bind(DataSegmentProvider.class) + .annotatedWith(Dart.class) + .to(DartDataSegmentProvider.class) + .in(LazySingleton.class); + + binder.bind(ResourcePermissionMapper.class) + .annotatedWith(Dart.class) + .to(DartResourcePermissionMapper.class); + } + + @Provides + @ManageLifecycle + public DartWorkerRunner createWorkerRunner( + @Self final DruidNode selfNode, + final DartWorkerFactory workerFactory, + final DruidNodeDiscoveryProvider discoveryProvider, + final DruidProcessingConfig processingConfig, + @Dart final ResourcePermissionMapper permissionMapper, + final MemoryIntrospector memoryIntrospector, + final AuthorizerMapper authorizerMapper + ) + { + final ExecutorService exec = Execs.multiThreaded(memoryIntrospector.numTasksInJvm(), "dart–worker-%s"); + final File baseTempDir = + new File(processingConfig.getTmpDir(), StringUtils.format("dart_%s", selfNode.getPortToUse())); + return new DartWorkerRunner( + workerFactory, + exec, + discoveryProvider, + permissionMapper, + authorizerMapper, + baseTempDir + ); + } + + @Provides + @Dart + public MessageRelayMonitor createMessageRelayMonitor( + final DruidNodeDiscoveryProvider 
discoveryProvider, + final Outbox outbox + ) + { + return new MessageRelayMonitor(discoveryProvider, outbox, NodeRole.BROKER); + } + + /** + * Create an {@link Outbox}. + * + * This is {@link ManageLifecycleAnnouncements} scoped so {@link OutboxImpl#stop()} gets called before attempting + * to shut down the Jetty server. If this doesn't happen, then server shutdown is delayed by however long it takes + * any currently-in-flight {@link MessageRelayResource#httpGetMessagesFromOutbox} to resolve. + */ + @Provides + @ManageLifecycleAnnouncements + public Outbox createOutbox() + { + return new OutboxImpl<>(); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartControllerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartControllerClient.java new file mode 100644 index 000000000000..23d83d005497 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartControllerClient.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker; + +import org.apache.druid.common.guava.FutureBox; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.error.DruidException; +import org.apache.druid.messages.server.Outbox; +import org.apache.druid.msq.counters.CounterSnapshotsTree; +import org.apache.druid.msq.dart.controller.messages.ControllerMessage; +import org.apache.druid.msq.dart.controller.messages.DoneReadingInput; +import org.apache.druid.msq.dart.controller.messages.PartialKeyStatistics; +import org.apache.druid.msq.dart.controller.messages.ResultsComplete; +import org.apache.druid.msq.dart.controller.messages.WorkerError; +import org.apache.druid.msq.dart.controller.messages.WorkerWarning; +import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; + +import javax.annotation.Nullable; +import java.util.List; + +/** + * Implementation of {@link ControllerClient} that uses an {@link Outbox} to send {@link ControllerMessage} + * to a controller. + */ +public class DartControllerClient implements ControllerClient +{ + private final Outbox outbox; + private final String queryId; + private final String controllerHost; + + /** + * Currently-outstanding futures. These are tracked so they can be canceled in {@link #close()}. 
+ */ + private final FutureBox futureBox = new FutureBox(); + + public DartControllerClient( + final Outbox outbox, + final String queryId, + final String controllerHost + ) + { + this.outbox = outbox; + this.queryId = queryId; + this.controllerHost = controllerHost; + } + + @Override + public void postPartialKeyStatistics( + final StageId stageId, + final int workerNumber, + final PartialKeyStatisticsInformation partialKeyStatisticsInformation + ) + { + validateStage(stageId); + sendMessage(new PartialKeyStatistics(stageId, workerNumber, partialKeyStatisticsInformation)); + } + + @Override + public void postDoneReadingInput(StageId stageId, int workerNumber) + { + validateStage(stageId); + sendMessage(new DoneReadingInput(stageId, workerNumber)); + } + + @Override + public void postResultsComplete(StageId stageId, int workerNumber, @Nullable Object resultObject) + { + validateStage(stageId); + sendMessage(new ResultsComplete(stageId, workerNumber, resultObject)); + } + + @Override + public void postWorkerError(MSQErrorReport errorWrapper) + { + sendMessage(new WorkerError(queryId, errorWrapper)); + } + + @Override + public void postWorkerWarning(List errorWrappers) + { + sendMessage(new WorkerWarning(queryId, errorWrappers)); + } + + @Override + public void postCounters(String workerId, CounterSnapshotsTree snapshotsTree) + { + // Do nothing. Live counters are not sent to the controller in this mode. + } + + @Override + public List getWorkerIds() + { + // Workers are set in advance through the WorkOrder, so this method isn't used. + throw new UnsupportedOperationException(); + } + + @Override + public void close() + { + // Cancel any pending futures. + futureBox.close(); + } + + private void sendMessage(final ControllerMessage message) + { + FutureUtils.getUnchecked(futureBox.register(outbox.sendMessage(controllerHost, message)), true); + } + + /** + * Validate that a {@link StageId} has the expected query ID. 
+ */ + private void validateStage(final StageId stageId) + { + if (!stageId.getQueryId().equals(queryId)) { + throw DruidException.defensive( + "Expected queryId[%s] but got queryId[%s], stageNumber[%s]", + queryId, + stageId.getQueryId(), + stageId.getStageNumber() + ); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartDataSegmentProvider.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartDataSegmentProvider.java new file mode 100644 index 000000000000..0e8a38af90a3 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartDataSegmentProvider.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker; + +import com.google.inject.Inject; +import org.apache.druid.collections.ReferenceCountingResourceHolder; +import org.apache.druid.collections.ResourceHolder; +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.msq.counters.ChannelCounters; +import org.apache.druid.msq.querykit.DataSegmentProvider; +import org.apache.druid.query.TableDataSource; +import org.apache.druid.segment.CompleteSegment; +import org.apache.druid.segment.PhysicalSegmentInspector; +import org.apache.druid.segment.ReferenceCountingSegment; +import org.apache.druid.server.SegmentManager; +import org.apache.druid.timeline.SegmentId; +import org.apache.druid.timeline.VersionedIntervalTimeline; +import org.apache.druid.timeline.partition.PartitionChunk; + +import java.io.Closeable; +import java.util.Optional; +import java.util.function.Supplier; + +/** + * Implementation of {@link DataSegmentProvider} that uses locally-cached segments from a {@link SegmentManager}. 
+ */ +public class DartDataSegmentProvider implements DataSegmentProvider +{ + private final SegmentManager segmentManager; + + @Inject + public DartDataSegmentProvider(SegmentManager segmentManager) + { + this.segmentManager = segmentManager; + } + + @Override + public Supplier> fetchSegment( + SegmentId segmentId, + ChannelCounters channelCounters, + boolean isReindex + ) + { + if (isReindex) { + throw DruidException.defensive("Got isReindex[%s], expected false", isReindex); + } + + return () -> { + final Optional> timeline = + segmentManager.getTimeline(new TableDataSource(segmentId.getDataSource()).getAnalysis()); + + if (!timeline.isPresent()) { + throw segmentNotFound(segmentId); + } + + final PartitionChunk chunk = + timeline.get().findChunk( + segmentId.getInterval(), + segmentId.getVersion(), + segmentId.getPartitionNum() + ); + + if (chunk == null) { + throw segmentNotFound(segmentId); + } + + final ReferenceCountingSegment segment = chunk.getObject(); + final Optional closeable = segment.acquireReferences(); + if (!closeable.isPresent()) { + // Segment has disappeared before we could acquire a reference to it. + throw segmentNotFound(segmentId); + } + + final Closer closer = Closer.create(); + closer.register(closeable.get()); + closer.register(() -> { + final PhysicalSegmentInspector inspector = segment.as(PhysicalSegmentInspector.class); + channelCounters.addFile(inspector != null ? inspector.getNumRows() : 0, 0); + }); + return new ReferenceCountingResourceHolder<>(new CompleteSegment(null, segment), closer); + }; + } + + /** + * Error to throw when a segment that was requested is not found. This can happen due to segment moves, etc. + */ + private static DruidException segmentNotFound(final SegmentId segmentId) + { + return DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + .build("Segment[%s] not found on this server. 
Please retry your query.", segmentId); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java new file mode 100644 index 000000000000..ff7d9fdc4e9f --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.druid.collections.ResourceHolder; +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; +import org.apache.druid.msq.exec.ProcessingBuffers; +import org.apache.druid.msq.exec.WorkerContext; +import org.apache.druid.msq.exec.WorkerMemoryParameters; +import org.apache.druid.msq.exec.WorkerStorageParameters; +import org.apache.druid.msq.kernel.FrameContext; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.querykit.DataSegmentProvider; +import org.apache.druid.query.groupby.GroupingEngine; +import org.apache.druid.segment.IndexIO; +import org.apache.druid.segment.IndexMergerV9; +import org.apache.druid.segment.SegmentWrangler; +import org.apache.druid.segment.incremental.NoopRowIngestionMeters; +import org.apache.druid.segment.incremental.RowIngestionMeters; +import org.apache.druid.segment.loading.DataSegmentPusher; + +import javax.annotation.Nullable; +import java.io.File; + +/** + * Dart implementation of {@link FrameContext}. 
+ */ +public class DartFrameContext implements FrameContext +{ + private final StageId stageId; + private final SegmentWrangler segmentWrangler; + private final GroupingEngine groupingEngine; + private final DataSegmentProvider dataSegmentProvider; + private final WorkerContext workerContext; + @Nullable + private final ResourceHolder processingBuffers; + private final WorkerMemoryParameters memoryParameters; + private final WorkerStorageParameters storageParameters; + + public DartFrameContext( + final StageId stageId, + final WorkerContext workerContext, + final SegmentWrangler segmentWrangler, + final GroupingEngine groupingEngine, + final DataSegmentProvider dataSegmentProvider, + @Nullable ResourceHolder processingBuffers, + final WorkerMemoryParameters memoryParameters, + final WorkerStorageParameters storageParameters + ) + { + this.stageId = stageId; + this.segmentWrangler = segmentWrangler; + this.groupingEngine = groupingEngine; + this.dataSegmentProvider = dataSegmentProvider; + this.workerContext = workerContext; + this.processingBuffers = processingBuffers; + this.memoryParameters = memoryParameters; + this.storageParameters = storageParameters; + } + + @Override + public SegmentWrangler segmentWrangler() + { + return segmentWrangler; + } + + @Override + public GroupingEngine groupingEngine() + { + return groupingEngine; + } + + @Override + public RowIngestionMeters rowIngestionMeters() + { + return new NoopRowIngestionMeters(); + } + + @Override + public DataSegmentProvider dataSegmentProvider() + { + return dataSegmentProvider; + } + + @Override + public File tempDir() + { + return new File(workerContext.tempDir(), stageId.toString()); + } + + @Override + public ObjectMapper jsonMapper() + { + return workerContext.jsonMapper(); + } + + @Override + public IndexIO indexIO() + { + throw new UnsupportedOperationException(); + } + + @Override + public File persistDir() + { + return new File(tempDir(), "persist"); + } + + @Override + public 
DataSegmentPusher segmentPusher() + { + throw DruidException.defensive("Ingestion not implemented"); + } + + @Override + public IndexMergerV9 indexMerger() + { + throw DruidException.defensive("Ingestion not implemented"); + } + + @Override + public ProcessingBuffers processingBuffers() + { + if (processingBuffers != null) { + return processingBuffers.get(); + } else { + throw new ISE("No processing buffers"); + } + } + + @Override + public WorkerMemoryParameters memoryParameters() + { + return memoryParameters; + } + + @Override + public WorkerStorageParameters storageParameters() + { + return storageParameters; + } + + @Override + public DataServerQueryHandlerFactory dataServerQueryHandlerFactory() + { + // We don't query data servers. This factory won't actually be used, because Dart doesn't allow segmentSource to be + // overridden; it always uses SegmentSource.NONE. (If it is called, some wires got crossed somewhere.) + return null; + } + + @Override + public void close() + { + if (processingBuffers != null) { + processingBuffers.close(); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java new file mode 100644 index 000000000000..e2a7b97c4c2a --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker; + +import org.apache.druid.collections.BlockingPool; +import org.apache.druid.collections.QueueNonBlockingPool; +import org.apache.druid.collections.ReferenceCountingResourceHolder; +import org.apache.druid.collections.ResourceHolder; +import org.apache.druid.error.DruidException; +import org.apache.druid.frame.processor.Bouncer; +import org.apache.druid.msq.exec.ProcessingBuffers; +import org.apache.druid.msq.exec.ProcessingBuffersProvider; +import org.apache.druid.msq.exec.ProcessingBuffersSet; +import org.apache.druid.utils.CloseableUtils; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; + +/** + * Production implementation of {@link ProcessingBuffersProvider} that uses the merge buffer pool. Each call + * to {@link #acquire(int)} acquires one merge buffer and slices it up. 
+ */ +public class DartProcessingBuffersProvider implements ProcessingBuffersProvider +{ + private final BlockingPool mergeBufferPool; + private final int processingThreads; + + public DartProcessingBuffersProvider(BlockingPool mergeBufferPool, int processingThreads) + { + this.mergeBufferPool = mergeBufferPool; + this.processingThreads = processingThreads; + } + + @Override + public ResourceHolder acquire(final int poolSize) + { + if (poolSize == 0) { + return new ReferenceCountingResourceHolder<>(ProcessingBuffersSet.EMPTY, () -> {}); + } + + final List> batch = mergeBufferPool.takeBatch(1, 0); + if (batch.isEmpty()) { + throw DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + .build("No merge buffers available, cannot execute query"); + } + + final ReferenceCountingResourceHolder bufferHolder = batch.get(0); + try { + final ByteBuffer buffer = bufferHolder.get().duplicate(); + final int sliceSize = buffer.capacity() / poolSize / processingThreads; + final List pool = new ArrayList<>(poolSize); + + for (int i = 0; i < poolSize; i++) { + final BlockingQueue queue = new ArrayBlockingQueue<>(processingThreads); + for (int j = 0; j < processingThreads; j++) { + final int sliceNum = i * processingThreads + j; + buffer.position(sliceSize * sliceNum).limit(sliceSize * (sliceNum + 1)); + queue.add(buffer.slice()); + } + final ProcessingBuffers buffers = new ProcessingBuffers( + new QueueNonBlockingPool<>(queue), + new Bouncer(processingThreads) + ); + pool.add(buffers); + } + + return new ReferenceCountingResourceHolder<>(new ProcessingBuffersSet(pool), bufferHolder); + } + catch (Throwable e) { + throw CloseableUtils.closeAndWrapInCatch(e, bufferHolder); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartQueryableSegment.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartQueryableSegment.java new file mode 100644 
index 000000000000..574601517b44 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartQueryableSegment.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker; + +import com.google.common.base.Preconditions; +import org.apache.druid.timeline.DataSegment; +import org.joda.time.Interval; + +import java.util.Objects; + +/** + * Represents a segment that is queryable at a specific worker number. 
+ */ +public class DartQueryableSegment +{ + private final DataSegment segment; + private final Interval interval; + private final int workerNumber; + + public DartQueryableSegment(final DataSegment segment, final Interval interval, final int workerNumber) + { + this.segment = Preconditions.checkNotNull(segment, "segment"); + this.interval = Preconditions.checkNotNull(interval, "interval"); + this.workerNumber = workerNumber; + } + + public DataSegment getSegment() + { + return segment; + } + + public Interval getInterval() + { + return interval; + } + + public int getWorkerNumber() + { + return workerNumber; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DartQueryableSegment that = (DartQueryableSegment) o; + return workerNumber == that.workerNumber + && Objects.equals(segment, that.segment) + && Objects.equals(interval, that.interval); + } + + @Override + public int hashCode() + { + return Objects.hash(segment, interval, workerNumber); + } + + @Override + public String toString() + { + return "QueryableDataSegment{" + + "segment=" + segment + + ", interval=" + interval + + ", workerNumber=" + workerNumber + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerClient.java new file mode 100644 index 000000000000..932300de217f --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerClient.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import it.unimi.dsi.fastutil.Pair; +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.java.util.http.client.response.HttpResponseHandler; +import org.apache.druid.msq.dart.controller.DartWorkerManager; +import org.apache.druid.msq.dart.controller.sql.DartSqlEngine; +import org.apache.druid.msq.dart.worker.http.DartWorkerResource; +import org.apache.druid.msq.exec.WorkerClient; +import org.apache.druid.msq.rpc.BaseWorkerClientImpl; +import org.apache.druid.rpc.FixedServiceLocator; +import org.apache.druid.rpc.IgnoreHttpResponseHandler; +import org.apache.druid.rpc.RequestBuilder; +import org.apache.druid.rpc.ServiceClient; +import org.apache.druid.rpc.ServiceClientFactory; +import org.apache.druid.rpc.ServiceLocation; +import org.apache.druid.rpc.ServiceRetryPolicy; +import org.apache.druid.utils.CloseableUtils; +import org.jboss.netty.handler.codec.http.HttpMethod; + +import javax.annotation.Nullable; +import java.io.Closeable; +import java.net.URI; +import java.util.HashMap; +import java.util.Map; + +/** + * Dart 
implementation of {@link WorkerClient}. Uses the same {@link BaseWorkerClientImpl} as the task-based engine. + * Each instance of this class is scoped to a single query. + */ +public class DartWorkerClient extends BaseWorkerClientImpl +{ + private static final Logger log = new Logger(DartWorkerClient.class); + + private final String queryId; + private final ServiceClientFactory clientFactory; + private final ServiceRetryPolicy retryPolicy; + + @Nullable + private final String controllerHost; + + @GuardedBy("clientMap") + private final Map> clientMap = new HashMap<>(); + + /** + * Create a worker client. + * + * @param queryId dart query ID. see {@link DartSqlEngine#CTX_DART_QUERY_ID} + * @param clientFactory service client factor + * @param smileMapper Smile object mapper + * @param controllerHost Controller host (see {@link DartWorkerResource#HEADER_CONTROLLER_HOST}) if this is a + * controller-to-worker client. Null if this is a worker-to-worker client. + */ + public DartWorkerClient( + final String queryId, + final ServiceClientFactory clientFactory, + final ObjectMapper smileMapper, + @Nullable final String controllerHost + ) + { + super(smileMapper, SmileMediaTypes.APPLICATION_JACKSON_SMILE); + this.queryId = queryId; + this.clientFactory = clientFactory; + this.controllerHost = controllerHost; + + if (controllerHost == null) { + // worker -> worker client. Retry HTTP 503 in case worker A starts up before worker B, and needs to + // contact it immediately. + this.retryPolicy = new DartWorkerRetryPolicy(true); + } else { + // controller -> worker client. Do not retry any HTTP error codes. If we retry HTTP 503 for controller -> worker, + // we can get stuck trying to contact workers that have exited. 
+ this.retryPolicy = new DartWorkerRetryPolicy(false); + } + } + + @Override + protected ServiceClient getClient(final String workerIdString) + { + final WorkerId workerId = WorkerId.fromString(workerIdString); + if (!queryId.equals(workerId.getQueryId())) { + throw DruidException.defensive("Unexpected queryId[%s]. Expected queryId[%s]", workerId.getQueryId(), queryId); + } + + synchronized (clientMap) { + return clientMap.computeIfAbsent(workerId.getHostAndPort(), ignored -> makeNewClient(workerId)).left(); + } + } + + /** + * Close a single worker's clients. Used when that worker fails, so we stop trying to contact it. + * + * @param workerHost worker host:port + */ + public void closeClient(final String workerHost) + { + synchronized (clientMap) { + final Pair clientPair = clientMap.remove(workerHost); + if (clientPair != null) { + CloseableUtils.closeAndWrapExceptions(clientPair.right()); + } + } + } + + /** + * Close all outstanding clients. + */ + @Override + public void close() + { + synchronized (clientMap) { + for (Map.Entry> entry : clientMap.entrySet()) { + CloseableUtils.closeAndSuppressExceptions( + entry.getValue().right(), + e -> log.warn(e, "Failed to close client[%s]", entry.getKey()) + ); + } + + clientMap.clear(); + } + } + + /** + * Stops a worker. Dart-only API, used by the {@link DartWorkerManager}. + */ + public ListenableFuture stopWorker(String workerId) + { + return getClient(workerId).asyncRequest( + new RequestBuilder(HttpMethod.POST, "/stop"), + IgnoreHttpResponseHandler.INSTANCE + ); + } + + /** + * Create a new client. Called by {@link #getClient(String)} if a new one is needed. 
+ */ + private Pair makeNewClient(final WorkerId workerId) + { + final URI uri = workerId.toUri(); + final FixedServiceLocator locator = new FixedServiceLocator(ServiceLocation.fromUri(uri)); + final ServiceClient baseClient = + clientFactory.makeClient(workerId.toString(), locator, retryPolicy); + final ServiceClient client; + + if (controllerHost != null) { + client = new ControllerDecoratedClient(baseClient, controllerHost); + } else { + client = baseClient; + } + + return Pair.of(client, locator); + } + + /** + * Service client that adds the {@link DartWorkerResource#HEADER_CONTROLLER_HOST} header. + */ + private static class ControllerDecoratedClient implements ServiceClient + { + private final ServiceClient delegate; + private final String controllerHost; + + ControllerDecoratedClient(final ServiceClient delegate, final String controllerHost) + { + this.delegate = delegate; + this.controllerHost = controllerHost; + } + + @Override + public ListenableFuture asyncRequest( + final RequestBuilder requestBuilder, + final HttpResponseHandler handler + ) + { + return delegate.asyncRequest( + requestBuilder.header(DartWorkerResource.HEADER_CONTROLLER_HOST, controllerHost), + handler + ); + } + + @Override + public ServiceClient withRetryPolicy(final ServiceRetryPolicy retryPolicy) + { + return new ControllerDecoratedClient(delegate.withRetryPolicy(retryPolicy), controllerHost); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java new file mode 100644 index 000000000000..525162fd8ddd --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Preconditions; +import com.google.inject.Injector; +import org.apache.druid.collections.ResourceHolder; +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.messages.server.Outbox; +import org.apache.druid.msq.dart.controller.messages.ControllerMessage; +import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.ProcessingBuffersProvider; +import org.apache.druid.msq.exec.ProcessingBuffersSet; +import org.apache.druid.msq.exec.Worker; +import org.apache.druid.msq.exec.WorkerClient; +import org.apache.druid.msq.exec.WorkerContext; +import org.apache.druid.msq.exec.WorkerMemoryParameters; +import org.apache.druid.msq.exec.WorkerStorageParameters; +import org.apache.druid.msq.kernel.FrameContext; +import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.querykit.DataSegmentProvider; +import org.apache.druid.msq.util.MultiStageQueryContext; +import org.apache.druid.query.DruidProcessingConfig; +import 
org.apache.druid.query.QueryContext; +import org.apache.druid.query.groupby.GroupingEngine; +import org.apache.druid.segment.SegmentWrangler; +import org.apache.druid.server.DruidNode; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; + +import java.io.File; + +/** + * Dart implementation of {@link WorkerContext}. + * Each instance is scoped to a query. + */ +public class DartWorkerContext implements WorkerContext +{ + private final String queryId; + private final String controllerHost; + private final String workerId; + private final DruidNode selfNode; + private final ObjectMapper jsonMapper; + private final Injector injector; + private final DartWorkerClient workerClient; + private final DruidProcessingConfig processingConfig; + private final SegmentWrangler segmentWrangler; + private final GroupingEngine groupingEngine; + private final DataSegmentProvider dataSegmentProvider; + private final MemoryIntrospector memoryIntrospector; + private final ProcessingBuffersProvider processingBuffersProvider; + private final Outbox outbox; + private final File tempDir; + private final QueryContext queryContext; + + /** + * Lazy initialized upon call to {@link #frameContext(WorkOrder)}. 
+ */ + @MonotonicNonNull + private volatile ResourceHolder processingBuffersSet; + + DartWorkerContext( + final String queryId, + final String controllerHost, + final String workerId, + final DruidNode selfNode, + final ObjectMapper jsonMapper, + final Injector injector, + final DartWorkerClient workerClient, + final DruidProcessingConfig processingConfig, + final SegmentWrangler segmentWrangler, + final GroupingEngine groupingEngine, + final DataSegmentProvider dataSegmentProvider, + final MemoryIntrospector memoryIntrospector, + final ProcessingBuffersProvider processingBuffersProvider, + final Outbox outbox, + final File tempDir, + final QueryContext queryContext + ) + { + this.queryId = queryId; + this.controllerHost = controllerHost; + this.workerId = workerId; + this.selfNode = selfNode; + this.jsonMapper = jsonMapper; + this.injector = injector; + this.workerClient = workerClient; + this.processingConfig = processingConfig; + this.segmentWrangler = segmentWrangler; + this.groupingEngine = groupingEngine; + this.dataSegmentProvider = dataSegmentProvider; + this.memoryIntrospector = memoryIntrospector; + this.processingBuffersProvider = processingBuffersProvider; + this.outbox = outbox; + this.tempDir = tempDir; + this.queryContext = Preconditions.checkNotNull(queryContext, "queryContext"); + } + + @Override + public String queryId() + { + return queryId; + } + + @Override + public String workerId() + { + return workerId; + } + + @Override + public ObjectMapper jsonMapper() + { + return jsonMapper; + } + + @Override + public Injector injector() + { + return injector; + } + + @Override + public void registerWorker(Worker worker, Closer closer) + { + closer.register(() -> { + synchronized (this) { + if (processingBuffersSet != null) { + processingBuffersSet.close(); + processingBuffersSet = null; + } + } + + workerClient.close(); + }); + } + + @Override + public int maxConcurrentStages() + { + final int retVal = 
MultiStageQueryContext.getMaxConcurrentStagesWithDefault(queryContext, -1); + if (retVal <= 0) { + throw new IAE("Illegal maxConcurrentStages[%s]", retVal); + } + return retVal; + } + + @Override + public ControllerClient makeControllerClient() + { + return new DartControllerClient(outbox, queryId, controllerHost); + } + + @Override + public WorkerClient makeWorkerClient() + { + return workerClient; + } + + @Override + public File tempDir() + { + return tempDir; + } + + @Override + public FrameContext frameContext(WorkOrder workOrder) + { + if (processingBuffersSet == null) { + synchronized (this) { + if (processingBuffersSet == null) { + processingBuffersSet = processingBuffersProvider.acquire( + workOrder.getQueryDefinition(), + maxConcurrentStages() + ); + } + } + } + + final WorkerMemoryParameters memoryParameters = + WorkerMemoryParameters.createProductionInstance( + workOrder, + memoryIntrospector, + maxConcurrentStages() + ); + + final WorkerStorageParameters storageParameters = WorkerStorageParameters.createInstance(-1, false); + + return new DartFrameContext( + workOrder.getStageDefinition().getId(), + this, + segmentWrangler, + groupingEngine, + dataSegmentProvider, + processingBuffersSet.get().acquireForStage(workOrder.getStageDefinition()), + memoryParameters, + storageParameters + ); + } + + @Override + public int threadCount() + { + return processingConfig.getNumThreads(); + } + + @Override + public DataServerQueryHandlerFactory dataServerQueryHandlerFactory() + { + // We don't query data servers. Return null so this factory is ignored when the main worker code tries + // to close it. + return null; + } + + @Override + public boolean includeAllCounters() + { + // The context parameter "includeAllCounters" is meant to assist with backwards compatibility for versions prior + // to Druid 31. Dart didn't exist prior to Druid 31, so there is no need for it here. Always emit all counters. 
+ return true; + } + + @Override + public DruidNode selfNode() + { + return selfNode; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactory.java new file mode 100644 index 000000000000..429579b2195e --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactory.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker; + +import org.apache.druid.msq.exec.Worker; +import org.apache.druid.query.QueryContext; + +import java.io.File; + +/** + * Used by {@link DartWorkerRunner} to create new {@link Worker} instances. 
+ */ +public interface DartWorkerFactory +{ + Worker build(String queryId, String controllerHost, File tempDir, QueryContext context); +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactoryImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactoryImpl.java new file mode 100644 index 000000000000..eb2b25252f6a --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactoryImpl.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.inject.Inject; +import com.google.inject.Injector; +import org.apache.druid.guice.annotations.EscalatedGlobal; +import org.apache.druid.guice.annotations.Json; +import org.apache.druid.guice.annotations.Self; +import org.apache.druid.guice.annotations.Smile; +import org.apache.druid.messages.server.Outbox; +import org.apache.druid.msq.dart.Dart; +import org.apache.druid.msq.dart.controller.messages.ControllerMessage; +import org.apache.druid.msq.dart.worker.http.DartWorkerResource; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.ProcessingBuffersProvider; +import org.apache.druid.msq.exec.Worker; +import org.apache.druid.msq.exec.WorkerContext; +import org.apache.druid.msq.exec.WorkerImpl; +import org.apache.druid.msq.querykit.DataSegmentProvider; +import org.apache.druid.query.DruidProcessingConfig; +import org.apache.druid.query.QueryContext; +import org.apache.druid.query.groupby.GroupingEngine; +import org.apache.druid.rpc.ServiceClientFactory; +import org.apache.druid.segment.SegmentWrangler; +import org.apache.druid.server.DruidNode; + +import java.io.File; +import java.net.URI; +import java.net.URISyntaxException; + +/** + * Production implementation of {@link DartWorkerFactory}. 
+ */ +public class DartWorkerFactoryImpl implements DartWorkerFactory +{ + private final String id; + private final DruidNode selfNode; + private final ObjectMapper jsonMapper; + private final ObjectMapper smileMapper; + private final Injector injector; + private final ServiceClientFactory serviceClientFactory; + private final DruidProcessingConfig processingConfig; + private final SegmentWrangler segmentWrangler; + private final GroupingEngine groupingEngine; + private final DataSegmentProvider dataSegmentProvider; + private final MemoryIntrospector memoryIntrospector; + private final ProcessingBuffersProvider processingBuffersProvider; + private final Outbox outbox; + + @Inject + public DartWorkerFactoryImpl( + @Self DruidNode selfNode, + @Json ObjectMapper jsonMapper, + @Smile ObjectMapper smileMapper, + Injector injector, + @EscalatedGlobal ServiceClientFactory serviceClientFactory, + DruidProcessingConfig processingConfig, + SegmentWrangler segmentWrangler, + GroupingEngine groupingEngine, + @Dart DataSegmentProvider dataSegmentProvider, + MemoryIntrospector memoryIntrospector, + @Dart ProcessingBuffersProvider processingBuffersProvider, + Outbox outbox + ) + { + this.id = makeWorkerId(selfNode); + this.selfNode = selfNode; + this.jsonMapper = jsonMapper; + this.smileMapper = smileMapper; + this.injector = injector; + this.serviceClientFactory = serviceClientFactory; + this.processingConfig = processingConfig; + this.segmentWrangler = segmentWrangler; + this.groupingEngine = groupingEngine; + this.dataSegmentProvider = dataSegmentProvider; + this.memoryIntrospector = memoryIntrospector; + this.processingBuffersProvider = processingBuffersProvider; + this.outbox = outbox; + } + + @Override + public Worker build(String queryId, String controllerHost, File tempDir, QueryContext queryContext) + { + final WorkerContext workerContext = new DartWorkerContext( + queryId, + controllerHost, + id, + selfNode, + jsonMapper, + injector, + new DartWorkerClient(queryId, 
serviceClientFactory, smileMapper, null), + processingConfig, + segmentWrangler, + groupingEngine, + dataSegmentProvider, + memoryIntrospector, + processingBuffersProvider, + outbox, + tempDir, + queryContext + ); + + return new WorkerImpl(null, workerContext); + } + + private static String makeWorkerId(final DruidNode selfNode) + { + try { + return new URI( + selfNode.getServiceScheme(), + null, + selfNode.getHost(), + selfNode.getPortToUse(), + DartWorkerResource.PATH, + null, + null + ).toString(); + } + catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerRetryPolicy.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerRetryPolicy.java new file mode 100644 index 000000000000..5dbfe98ef0c5 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerRetryPolicy.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker; + +import org.apache.druid.rpc.ServiceRetryPolicy; +import org.apache.druid.rpc.StandardRetryPolicy; +import org.jboss.netty.handler.codec.http.HttpResponse; +import org.jboss.netty.handler.codec.http.HttpResponseStatus; + +/** + * Retry policy for {@link DartWorkerClient}. This is a {@link StandardRetryPolicy#unlimited()} with + * {@link #retryHttpResponse(HttpResponse)} customized to retry fewer HTTP error codes. + */ +public class DartWorkerRetryPolicy implements ServiceRetryPolicy +{ + private final boolean retryOnWorkerUnavailable; + + /** + * Create a retry policy. + * + * @param retryOnWorkerUnavailable whether this policy should retry on {@link HttpResponseStatus#SERVICE_UNAVAILABLE} + */ + public DartWorkerRetryPolicy(boolean retryOnWorkerUnavailable) + { + this.retryOnWorkerUnavailable = retryOnWorkerUnavailable; + } + + @Override + public long maxAttempts() + { + return StandardRetryPolicy.unlimited().maxAttempts(); + } + + @Override + public long minWaitMillis() + { + return StandardRetryPolicy.unlimited().minWaitMillis(); + } + + @Override + public long maxWaitMillis() + { + return StandardRetryPolicy.unlimited().maxWaitMillis(); + } + + @Override + public boolean retryHttpResponse(HttpResponse response) + { + if (retryOnWorkerUnavailable) { + return HttpResponseStatus.SERVICE_UNAVAILABLE.equals(response.getStatus()); + } else { + return false; + } + } + + @Override + public boolean retryThrowable(Throwable t) + { + return StandardRetryPolicy.unlimited().retryThrowable(t); + } + + @Override + public boolean retryLoggable() + { + return false; + } + + @Override + public boolean retryNotAvailable() + { + return false; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerRunner.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerRunner.java new file mode 100644 index 000000000000..ae136196a0fc --- 
/dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerRunner.java @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker; + +import com.google.common.annotations.VisibleForTesting; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import org.apache.druid.discovery.DiscoveryDruidNode; +import org.apache.druid.discovery.DruidNodeDiscovery; +import org.apache.druid.discovery.DruidNodeDiscoveryProvider; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.error.DruidException; +import org.apache.druid.guice.ManageLifecycle; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.FileUtils; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.lifecycle.LifecycleStart; +import org.apache.druid.java.util.common.lifecycle.LifecycleStop; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.dart.worker.http.DartWorkerInfo; +import org.apache.druid.msq.dart.worker.http.GetWorkersResponse; +import org.apache.druid.msq.exec.Worker; +import 
org.apache.druid.msq.indexing.error.CanceledFault; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.rpc.ResourcePermissionMapper; +import org.apache.druid.msq.rpc.WorkerResource; +import org.apache.druid.query.QueryContext; +import org.apache.druid.server.security.AuthorizerMapper; +import org.joda.time.DateTime; + +import javax.annotation.Nullable; +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +@ManageLifecycle +public class DartWorkerRunner +{ + private static final Logger log = new Logger(DartWorkerRunner.class); + + /** + * Set of active controllers. Ignore requests from others. + */ + @GuardedBy("this") + private final Set activeControllerHosts = new HashSet<>(); + + /** + * Query ID -> Worker instance. 
+ */ + @GuardedBy("this") + private final Map workerMap = new HashMap<>(); + private final DartWorkerFactory workerFactory; + private final ExecutorService workerExec; + private final DruidNodeDiscoveryProvider discoveryProvider; + private final ResourcePermissionMapper permissionMapper; + private final AuthorizerMapper authorizerMapper; + private final File baseTempDir; + + public DartWorkerRunner( + final DartWorkerFactory workerFactory, + final ExecutorService workerExec, + final DruidNodeDiscoveryProvider discoveryProvider, + final ResourcePermissionMapper permissionMapper, + final AuthorizerMapper authorizerMapper, + final File baseTempDir + ) + { + this.workerFactory = workerFactory; + this.workerExec = workerExec; + this.discoveryProvider = discoveryProvider; + this.permissionMapper = permissionMapper; + this.authorizerMapper = authorizerMapper; + this.baseTempDir = baseTempDir; + } + + /** + * Start a worker, creating a holder for it. If a worker with this query ID is already started, does nothing. + * Returns the worker. 
+ * + * @throws DruidException if the controllerId does not correspond to a currently-active controller + */ + public Worker startWorker( + final String queryId, + final String controllerHost, + final QueryContext context + ) + { + final WorkerHolder holder; + final boolean newHolder; + + synchronized (this) { + if (!activeControllerHosts.contains(controllerHost)) { + throw DruidException.forPersona(DruidException.Persona.OPERATOR) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + .build("Received startWorker request for unknown controller[%s]", controllerHost); + } + + final WorkerHolder existingHolder = workerMap.get(queryId); + if (existingHolder != null) { + holder = existingHolder; + newHolder = false; + } else { + final Worker worker = workerFactory.build(queryId, controllerHost, baseTempDir, context); + final WorkerResource resource = new WorkerResource(worker, permissionMapper, authorizerMapper); + holder = new WorkerHolder(worker, controllerHost, resource, DateTimes.nowUtc()); + workerMap.put(queryId, holder); + this.notifyAll(); + newHolder = true; + } + } + + if (newHolder) { + workerExec.submit(() -> { + final String originalThreadName = Thread.currentThread().getName(); + try { + Thread.currentThread().setName(StringUtils.format("%s[%s]", originalThreadName, queryId)); + holder.worker.run(); + } + catch (Throwable t) { + if (Thread.interrupted() + || t instanceof MSQException && ((MSQException) t).getFault().getErrorCode().equals(CanceledFault.CODE)) { + log.debug(t, "Canceled, exiting thread."); + } else { + log.warn(t, "Worker for query[%s] failed and stopped.", queryId); + } + } + finally { + synchronized (this) { + workerMap.remove(queryId, holder); + this.notifyAll(); + } + + Thread.currentThread().setName(originalThreadName); + } + }); + } + + return holder.worker; + } + + /** + * Stops a worker. 
+ */ + public void stopWorker(final String queryId) + { + final WorkerHolder holder; + + synchronized (this) { + holder = workerMap.get(queryId); + } + + if (holder != null) { + holder.worker.stop(); + } + } + + /** + * Get the worker resource handler for a query ID if it exists. Returns null if the worker is not running. + */ + @Nullable + public WorkerResource getWorkerResource(final String queryId) + { + synchronized (this) { + final WorkerHolder holder = workerMap.get(queryId); + if (holder != null) { + return holder.resource; + } else { + return null; + } + } + } + + /** + * Returns a {@link GetWorkersResponse} with information about all active workers. + */ + public GetWorkersResponse getWorkersResponse() + { + final List infos = new ArrayList<>(); + + synchronized (this) { + for (final Map.Entry entry : workerMap.entrySet()) { + final String queryId = entry.getKey(); + final WorkerHolder workerHolder = entry.getValue(); + infos.add( + new DartWorkerInfo( + queryId, + WorkerId.fromString(workerHolder.worker.id()), + workerHolder.controllerHost, + workerHolder.acceptTime + ) + ); + } + } + + return new GetWorkersResponse(infos); + } + + @LifecycleStart + public void start() + { + createAndCleanTempDirectory(); + + final DruidNodeDiscovery brokers = discoveryProvider.getForNodeRole(NodeRole.BROKER); + brokers.registerListener(new BrokerListener()); + } + + @LifecycleStop + public void stop() + { + synchronized (this) { + final Collection holders = workerMap.values(); + + for (final WorkerHolder holder : holders) { + holder.worker.stop(); + } + + for (final WorkerHolder holder : holders) { + holder.worker.awaitStop(); + } + } + } + + /** + * Method for testing. Waits for the set of queries to match a given predicate. 
+ */ + @VisibleForTesting + void awaitQuerySet(Predicate> queryIdsPredicate) throws InterruptedException + { + synchronized (this) { + while (!queryIdsPredicate.test(workerMap.keySet())) { + wait(); + } + } + } + + /** + * Creates the {@link #baseTempDir}, and removes any items in it that still exist. + */ + void createAndCleanTempDirectory() + { + try { + FileUtils.mkdirp(baseTempDir); + } + catch (IOException e) { + throw new RuntimeException(e); + } + + final File[] files = baseTempDir.listFiles(); + + if (files != null) { + for (final File file : files) { + if (file.isDirectory()) { + try { + FileUtils.deleteDirectory(file); + log.info("Removed stale query directory[%s].", file); + } + catch (Exception e) { + log.noStackTrace().warn(e, "Could not remove stale query directory[%s], skipping.", file); + } + } + } + } + } + + private static class WorkerHolder + { + private final Worker worker; + private final WorkerResource resource; + private final String controllerHost; + private final DateTime acceptTime; + + public WorkerHolder( + Worker worker, + String controllerHost, + WorkerResource resource, + final DateTime acceptTime + ) + { + this.worker = worker; + this.resource = resource; + this.controllerHost = controllerHost; + this.acceptTime = acceptTime; + } + } + + /** + * Listener that cancels work associated with Brokers that have gone away. 
+ */ + private class BrokerListener implements DruidNodeDiscovery.Listener + { + @Override + public void nodesAdded(Collection nodes) + { + synchronized (DartWorkerRunner.this) { + for (final DiscoveryDruidNode node : nodes) { + activeControllerHosts.add(node.getDruidNode().getHostAndPortToUse()); + } + } + } + + @Override + public void nodesRemoved(Collection nodes) + { + final Set hostsRemoved = + nodes.stream().map(node -> node.getDruidNode().getHostAndPortToUse()).collect(Collectors.toSet()); + + final List workersToNotify = new ArrayList<>(); + + synchronized (DartWorkerRunner.this) { + activeControllerHosts.removeAll(hostsRemoved); + + for (Map.Entry entry : workerMap.entrySet()) { + if (hostsRemoved.contains(entry.getValue().controllerHost)) { + workersToNotify.add(entry.getValue().worker); + } + } + } + + for (final Worker worker : workersToNotify) { + worker.controllerFailed(); + } + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/WorkerId.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/WorkerId.java new file mode 100644 index 000000000000..2bbff7111ca7 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/WorkerId.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; +import com.google.common.base.Joiner; +import com.google.common.base.Preconditions; +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.msq.dart.worker.http.DartWorkerResource; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig; +import org.apache.druid.server.DruidNode; +import org.apache.druid.server.coordination.DruidServerMetadata; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Worker IDs, of the type returned by {@link ControllerQueryKernelConfig#getWorkerIds()}. + * + * Dart workerIds are strings of the form "scheme:host:port:queryId", like + * "https:host1.example.com:8083:2f05528c-a882-4da5-8b7d-2ecafb7f3f4e". 
+ */ +public class WorkerId +{ + private static final Pattern PATTERN = Pattern.compile("^(\\w+):(.+:\\d+):([a-z0-9-]+)$"); + + private final String scheme; + private final String hostAndPort; + private final String queryId; + private final String fullString; + + public WorkerId(final String scheme, final String hostAndPort, final String queryId) + { + this.scheme = Preconditions.checkNotNull(scheme, "scheme"); + this.hostAndPort = Preconditions.checkNotNull(hostAndPort, "hostAndPort"); + this.queryId = Preconditions.checkNotNull(queryId, "queryId"); + this.fullString = Joiner.on(':').join(scheme, hostAndPort, queryId); + } + + @JsonCreator + public static WorkerId fromString(final String s) + { + if (s == null) { + throw new IAE("Missing workerId"); + } + + final Matcher matcher = PATTERN.matcher(s); + if (matcher.matches()) { + return new WorkerId(matcher.group(1), matcher.group(2), matcher.group(3)); + } else { + throw new IAE("Invalid workerId[%s]", s); + } + } + + /** + * Create a worker ID, which is a URL. + */ + public static WorkerId fromDruidNode(final DruidNode node, final String queryId) + { + return new WorkerId( + node.getServiceScheme(), + node.getHostAndPortToUse(), + queryId + ); + } + + /** + * Create a worker ID, which is a URL. + */ + public static WorkerId fromDruidServerMetadata(final DruidServerMetadata server, final String queryId) + { + return new WorkerId( + server.getHostAndTlsPort() != null ? 
"https" : "http", + server.getHost(), + queryId + ); + } + + public String getScheme() + { + return scheme; + } + + public String getHostAndPort() + { + return hostAndPort; + } + + public String getQueryId() + { + return queryId; + } + + public URI toUri() + { + try { + final String path = StringUtils.format( + "%s/workers/%s", + DartWorkerResource.PATH, + StringUtils.urlEncode(queryId) + ); + + return new URI(scheme, hostAndPort, path, null, null); + } + catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + @Override + @JsonValue + public String toString() + { + return fullString; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + WorkerId workerId = (WorkerId) o; + return Objects.equals(fullString, workerId.fullString); + } + + @Override + public int hashCode() + { + return fullString.hashCode(); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/DartWorkerInfo.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/DartWorkerInfo.java new file mode 100644 index 000000000000..3bd14993ded8 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/DartWorkerInfo.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker.http; + +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.druid.msq.dart.controller.http.DartQueryInfo; +import org.apache.druid.msq.dart.worker.WorkerId; +import org.joda.time.DateTime; + +import java.util.Objects; + +/** + * Class included in {@link GetWorkersResponse}. + */ +public class DartWorkerInfo +{ + private final String dartQueryId; + private final WorkerId workerId; + private final String controllerHost; + private final DateTime startTime; + + public DartWorkerInfo( + @JsonProperty("dartQueryId") final String dartQueryId, + @JsonProperty("workerId") final WorkerId workerId, + @JsonProperty("controllerHost") final String controllerHost, + @JsonProperty("startTime") final DateTime startTime + ) + { + this.dartQueryId = dartQueryId; + this.workerId = workerId; + this.controllerHost = controllerHost; + this.startTime = startTime; + } + + /** + * Dart query ID generated by the system. Globally unique. + */ + @JsonProperty + public String getDartQueryId() + { + return dartQueryId; + } + + /** + * Worker ID for this query. + */ + @JsonProperty + public WorkerId getWorkerId() + { + return workerId; + } + + /** + * Controller host:port that manages this query. + */ + @JsonProperty + public String getControllerHost() + { + return controllerHost; + } + + /** + * Time this query was accepted by this worker. May be somewhat later than the {@link DartQueryInfo#getStartTime()} + * on the controller. 
+ */ + @JsonProperty + public DateTime getStartTime() + { + return startTime; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DartWorkerInfo that = (DartWorkerInfo) o; + return Objects.equals(dartQueryId, that.dartQueryId) + && Objects.equals(workerId, that.workerId) + && Objects.equals(controllerHost, that.controllerHost) + && Objects.equals(startTime, that.startTime); + } + + @Override + public int hashCode() + { + return Objects.hash(dartQueryId, workerId, controllerHost, startTime); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/DartWorkerResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/DartWorkerResource.java new file mode 100644 index 000000000000..03fd847cb1af --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/DartWorkerResource.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker.http; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes; +import com.google.inject.Inject; +import org.apache.druid.error.DruidException; +import org.apache.druid.guice.LazySingleton; +import org.apache.druid.guice.annotations.Smile; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.messages.server.MessageRelayResource; +import org.apache.druid.messages.server.Outbox; +import org.apache.druid.msq.dart.Dart; +import org.apache.druid.msq.dart.controller.messages.ControllerMessage; +import org.apache.druid.msq.dart.worker.DartWorkerRunner; +import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.rpc.MSQResourceUtils; +import org.apache.druid.msq.rpc.ResourcePermissionMapper; +import org.apache.druid.msq.rpc.WorkerResource; +import org.apache.druid.server.DruidNode; +import org.apache.druid.server.initialization.jetty.ServiceUnavailableException; +import org.apache.druid.server.security.AuthorizerMapper; + +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.Consumes; +import javax.ws.rs.GET; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.core.Context; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; + +/** + * Subclass of {@link WorkerResource} suitable for usage on a Historical. + * + * Note that this is not the same resource as used by {@link org.apache.druid.msq.indexing.MSQWorkerTask}. + * For that, see {@link org.apache.druid.msq.indexing.client.WorkerChatHandler}. + */ +@LazySingleton +@Path(DartWorkerResource.PATH + '/') +public class DartWorkerResource +{ + /** + * Root of worker APIs. + */ + public static final String PATH = "/druid/dart-worker"; + + /** + * Header containing the controller host:port, from {@link DruidNode#getHostAndPortToUse()}. 
+ */ + public static final String HEADER_CONTROLLER_HOST = "X-Dart-Controller-Host"; + + private final DartWorkerRunner workerRunner; + private final ResourcePermissionMapper permissionMapper; + private final AuthorizerMapper authorizerMapper; + private final MessageRelayResource messageRelayResource; + + @Inject + public DartWorkerResource( + final DartWorkerRunner workerRunner, + @Dart final ResourcePermissionMapper permissionMapper, + @Smile final ObjectMapper smileMapper, + final Outbox outbox, + final AuthorizerMapper authorizerMapper + ) + { + this.workerRunner = workerRunner; + this.permissionMapper = permissionMapper; + this.authorizerMapper = authorizerMapper; + this.messageRelayResource = new MessageRelayResource<>( + outbox, + smileMapper, + ControllerMessage.class + ); + } + + /** + * API for retrieving all currently-running queries. + */ + @GET + @Produces(MediaType.APPLICATION_JSON) + @Path("/workers") + public GetWorkersResponse httpGetWorkers(@Context final HttpServletRequest req) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + return workerRunner.getWorkersResponse(); + } + + /** + * Like {@link WorkerResource#httpPostWorkOrder(WorkOrder, HttpServletRequest)}, but implicitly starts a worker + * when the work order is posted. Shadows {@link WorkerResource#httpPostWorkOrder(WorkOrder, HttpServletRequest)}. 
+ */ + @POST + @Consumes({MediaType.APPLICATION_JSON, SmileMediaTypes.APPLICATION_JACKSON_SMILE}) + @Path("/workers/{queryId}/workOrder") + public Response httpPostWorkOrder( + final WorkOrder workOrder, + @PathParam("queryId") final String queryId, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + final String controllerHost = req.getHeader(HEADER_CONTROLLER_HOST); + if (controllerHost == null) { + throw DruidException.forPersona(DruidException.Persona.DEVELOPER) + .ofCategory(DruidException.Category.INVALID_INPUT) + .build("Missing controllerId[%s]", HEADER_CONTROLLER_HOST); + } + + workerRunner.startWorker(queryId, controllerHost, workOrder.getWorkerContext()) + .postWorkOrder(workOrder); + + return Response.status(Response.Status.ACCEPTED).build(); + } + + /** + * Stops a worker. Returns immediately; does not wait for the worker to actually finish. + */ + @POST + @Path("/workers/{queryId}/stop") + public Response httpPostStopWorker( + @PathParam("queryId") final String queryId, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + workerRunner.stopWorker(queryId); + return Response.status(Response.Status.ACCEPTED).build(); + } + + /** + * Handles all {@link WorkerResource} calls, except {@link WorkerResource#httpPostWorkOrder}, which is handled + * by {@link #httpPostWorkOrder(WorkOrder, String, HttpServletRequest)}. + */ + @Path("/workers/{queryId}") + public Object httpCallWorkerResource( + @PathParam("queryId") final String queryId, + @Context final HttpServletRequest req + ) + { + final WorkerResource resource = workerRunner.getWorkerResource(queryId); + + if (resource != null) { + return resource; + } else { + // Return HTTP 503 (Service Unavailable) so worker -> worker clients can retry. 
When workers are first starting + // up and contacting each other, worker A may contact worker B before worker B has started up. In the future, it + // would be better to do an async wait, with some timeout, for the worker to show up before returning 503. + // That way a retry wouldn't be necessary. + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + throw new ServiceUnavailableException(StringUtils.format("No worker running for query[%s]", queryId)); + } + } + + @Path("/relay") + public Object httpCallMessageRelayServer(@Context final HttpServletRequest req) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + return messageRelayResource; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/GetWorkersResponse.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/GetWorkersResponse.java new file mode 100644 index 000000000000..0fa28a4ef17f --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/GetWorkersResponse.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker.http; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.servlet.http.HttpServletRequest; +import java.util.List; +import java.util.Objects; + +/** + * Response from {@link DartWorkerResource#httpGetWorkers(HttpServletRequest)}, the "get all workers" API. + */ +public class GetWorkersResponse +{ + private final List workers; + + public GetWorkersResponse(@JsonProperty("workers") final List workers) + { + this.workers = workers; + } + + @JsonProperty + public List getWorkers() + { + return workers; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + GetWorkersResponse that = (GetWorkersResponse) o; + return Objects.equals(workers, that.workers); + } + + @Override + public int hashCode() + { + return Objects.hashCode(workers); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java index d316b9b6b0b7..9842de174bb5 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java @@ -22,6 +22,8 @@ import org.apache.druid.indexer.report.TaskReport; import org.apache.druid.msq.counters.CounterSnapshots; import org.apache.druid.msq.counters.CounterSnapshotsTree; +import org.apache.druid.msq.dart.controller.http.DartSqlResource; +import org.apache.druid.msq.dart.controller.sql.DartSqlEngine; import org.apache.druid.msq.indexing.MSQControllerTask; import org.apache.druid.msq.indexing.client.ControllerChatHandler; import org.apache.druid.msq.indexing.error.MSQErrorReport; @@ -42,6 +44,7 @@ public interface Controller * Unique task/query ID for the batch query run by this controller. 
* * Controller IDs must be globally unique. For tasks, this is the task ID from {@link MSQControllerTask#getId()}. + * For Dart, this is {@link DartSqlEngine#CTX_DART_QUERY_ID}, set by {@link DartSqlResource}. */ String queryId(); @@ -121,6 +124,11 @@ void resultsComplete( */ List getWorkerIds(); + /** + * Returns whether this controller has a worker with the given ID. + */ + boolean hasWorker(String workerId); + @Nullable TaskReport.ReportMap liveReports(); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index 96c46635662b..1497bd1022bf 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -1174,6 +1174,16 @@ public List getWorkerIds() return workerManager.getWorkerIds(); } + @Override + public boolean hasWorker(String workerId) + { + if (workerManager == null) { + return false; + } + + return workerManager.getWorkerNumber(workerId) != WorkerManager.UNKNOWN_WORKER_NUMBER; + } + @SuppressWarnings({"unchecked", "rawtypes"}) @Nullable private Int2ObjectMap makeWorkerFactoryInfosForStage( diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java index 7be045542bc8..60325e640e5a 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java @@ -79,6 +79,7 @@ import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.server.DruidNode; +import org.apache.druid.utils.CloseableUtils; import javax.annotation.Nullable; import 
java.io.Closeable; @@ -988,6 +989,11 @@ private void doCancel() controllerClient.close(); } + // Close worker client to cancel any currently in-flight calls to other workers. + if (workerClient != null) { + CloseableUtils.closeAndSuppressExceptions(workerClient, e -> log.warn("Failed to close workerClient")); + } + // Clear the main loop event queue, then throw a CanceledFault into the loop to exit it promptly. kernelManipulationQueue.clear(); kernelManipulationQueue.add( diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java index ebce4821d591..31af0953d2f9 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java @@ -83,8 +83,11 @@ public interface WorkerManager Map> getWorkerStats(); /** - * Blocks until all workers exit. Returns quietly, no matter whether there was an exception associated with the - * future from {@link #start()} or not. + * Stop all workers. + * + * The task-based implementation blocks until all tasks exit. Dart's implementation queues workers for stopping in + * the background, and returns immediately. Either way, this method returns quietly, no matter whether there was an + * exception associated with the future from {@link #start()} or not. 
* * @param interrupt whether to interrupt currently-running work */ diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/TaskReportQueryListener.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/TaskReportQueryListener.java index 4cc4678a58a7..be73a3cbfdd0 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/TaskReportQueryListener.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/TaskReportQueryListener.java @@ -71,6 +71,7 @@ public class TaskReportQueryListener implements QueryListener private JsonGenerator jg; private long numResults; private MSQStatusReport statusReport; + private boolean resultsCurrentlyOpen; public TaskReportQueryListener( final MSQDestination destination, @@ -99,6 +100,7 @@ public void onResultsStart(List signature, @Null { try { openGenerator(); + resultsCurrentlyOpen = true; jg.writeObjectFieldStart(FIELD_RESULTS); writeObjectField(FIELD_RESULTS_SIGNATURE, signature); @@ -118,15 +120,7 @@ public boolean onResultRow(Object[] row) try { JacksonUtils.writeObjectUsingSerializerProvider(jg, serializers, row); numResults++; - - if (rowsInTaskReport == MSQDestination.UNLIMITED || numResults < rowsInTaskReport) { - return true; - } else { - jg.writeEndArray(); - jg.writeBooleanField(FIELD_RESULTS_TRUNCATED, true); - jg.writeEndObject(); - return false; - } + return rowsInTaskReport == MSQDestination.UNLIMITED || numResults < rowsInTaskReport; } catch (IOException e) { throw new RuntimeException(e); @@ -137,6 +131,8 @@ public boolean onResultRow(Object[] row) public void onResultsComplete() { try { + resultsCurrentlyOpen = false; + jg.writeEndArray(); jg.writeBooleanField(FIELD_RESULTS_TRUNCATED, false); jg.writeEndObject(); @@ -150,7 +146,14 @@ public void onResultsComplete() public void onQueryComplete(MSQTaskReportPayload report) { try { - openGenerator(); + if (resultsCurrentlyOpen) { + 
jg.writeEndArray(); + jg.writeBooleanField(FIELD_RESULTS_TRUNCATED, true); + jg.writeEndObject(); + } else { + openGenerator(); + } + statusReport = report.getStatus(); writeObjectField(FIELD_STATUS, report.getStatus()); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/CanceledFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/CanceledFault.java index c81572a88165..2798a3ccfaa6 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/CanceledFault.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/CanceledFault.java @@ -21,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonTypeName; +import org.apache.druid.error.DruidException; @JsonTypeName(CanceledFault.CODE) public class CanceledFault extends BaseMSQFault @@ -38,4 +39,13 @@ public static CanceledFault instance() { return INSTANCE; } + + @Override + public DruidException toDruidException() + { + return DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.CANCELED) + .withErrorCode(getErrorCode()) + .build(MSQFaultUtils.generateMessageWithErrorCode(this)); + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/ColumnNameRestrictedFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/ColumnNameRestrictedFault.java index c2c4617292e0..0ad60bdb0b03 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/ColumnNameRestrictedFault.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/ColumnNameRestrictedFault.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeName; import 
com.google.common.base.Preconditions; +import org.apache.druid.error.DruidException; import org.apache.druid.java.util.common.StringUtils; import java.util.Objects; @@ -51,6 +52,15 @@ public String getColumnName() return columnName; } + @Override + public DruidException toDruidException() + { + return DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.INVALID_INPUT) + .withErrorCode(getErrorCode()) + .build(MSQFaultUtils.generateMessageWithErrorCode(this)); + } + @Override public boolean equals(Object o) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/ColumnTypeNotSupportedFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/ColumnTypeNotSupportedFault.java index 91764b4b3988..2337837785ee 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/ColumnTypeNotSupportedFault.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/ColumnTypeNotSupportedFault.java @@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeName; import com.google.common.base.Preconditions; +import org.apache.druid.error.DruidException; import org.apache.druid.frame.write.UnsupportedColumnTypeException; import org.apache.druid.segment.column.ColumnType; @@ -65,6 +66,15 @@ public ColumnType getColumnType() return columnType; } + @Override + public DruidException toDruidException() + { + return DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.INVALID_INPUT) + .withErrorCode(getErrorCode()) + .build(MSQFaultUtils.generateMessageWithErrorCode(this)); + } + @Override public boolean equals(Object o) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/MSQErrorReport.java 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/MSQErrorReport.java index 8d90bef32ff2..aa515c8b46dc 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/MSQErrorReport.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/MSQErrorReport.java @@ -25,6 +25,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import it.unimi.dsi.fastutil.ints.IntList; +import org.apache.druid.error.DruidException; import org.apache.druid.frame.processor.FrameRowTooLargeException; import org.apache.druid.frame.write.InvalidFieldException; import org.apache.druid.frame.write.InvalidNullByteException; @@ -138,6 +139,31 @@ public String getExceptionStackTrace() return exceptionStackTrace; } + /** + * Returns a {@link DruidException} "equivalent" of this instance. This is useful until such time as we can migrate + * usages of this class to {@link DruidException}. 
+ */ + public DruidException toDruidException() + { + final DruidException druidException = + error.toDruidException() + .withContext("taskId", taskId); + + if (host != null) { + druidException.withContext("host", host); + } + + if (stageNumber != null) { + druidException.withContext("stageNumber", stageNumber); + } + + if (exceptionStackTrace != null) { + druidException.withContext("exceptionStackTrace", exceptionStackTrace); + } + + return druidException; + } + @Override public boolean equals(Object o) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/MSQFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/MSQFault.java index c36157e0ddca..39efce9d2044 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/MSQFault.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/MSQFault.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.indexing.error; import com.fasterxml.jackson.annotation.JsonTypeInfo; +import org.apache.druid.error.DruidException; import javax.annotation.Nullable; @@ -36,4 +37,17 @@ public interface MSQFault @Nullable String getErrorMessage(); + /** + * Returns a {@link DruidException} corresponding to this fault. + * + * The default is a {@link DruidException.Category#RUNTIME_FAILURE} targeting {@link DruidException.Persona#USER}. + * Faults with different personas and categories should override this method. 
+ */ + default DruidException toDruidException() + { + return DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + .withErrorCode(getErrorCode()) + .build(MSQFaultUtils.generateMessageWithErrorCode(this)); + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/QueryNotSupportedFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/QueryNotSupportedFault.java index bba058cd5888..7356cc029092 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/QueryNotSupportedFault.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/QueryNotSupportedFault.java @@ -21,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonTypeName; +import org.apache.druid.error.DruidException; @JsonTypeName(QueryNotSupportedFault.CODE) public class QueryNotSupportedFault extends BaseMSQFault @@ -33,6 +34,15 @@ public class QueryNotSupportedFault extends BaseMSQFault super(CODE); } + @Override + public DruidException toDruidException() + { + return DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.UNSUPPORTED) + .withErrorCode(getErrorCode()) + .build(MSQFaultUtils.generateMessageWithErrorCode(this)); + } + @JsonCreator public static QueryNotSupportedFault instance() { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java index 74f3e780cfea..d3428a6c3428 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java @@ -58,6 +58,8 @@ */ public abstract class 
BaseWorkerClientImpl implements WorkerClient { + private static final Logger log = new Logger(BaseWorkerClientImpl.class); + private final ObjectMapper objectMapper; private final String contentType; @@ -192,8 +194,6 @@ public ListenableFuture getCounters(String workerId) ); } - private static final Logger log = new Logger(BaseWorkerClientImpl.class); - @Override public ListenableFuture fetchChannelData( String workerId, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java index 839defa6bd9c..20758883ddba 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java @@ -56,6 +56,7 @@ import javax.ws.rs.core.StreamingOutput; import java.io.InputStream; import java.io.OutputStream; +import java.util.concurrent.atomic.AtomicBoolean; public class WorkerResource { @@ -104,6 +105,8 @@ public Response httpGetChannelData( worker.readStageOutput(new StageId(queryId, stageNumber), partitionNumber, offset); final AsyncContext asyncContext = req.startAsync(); + final AtomicBoolean responseResolved = new AtomicBoolean(); + asyncContext.setTimeout(GET_CHANNEL_DATA_TIMEOUT); asyncContext.addListener( new AsyncListener() @@ -116,6 +119,10 @@ public void onComplete(AsyncEvent event) @Override public void onTimeout(AsyncEvent event) { + if (!responseResolved.compareAndSet(false, true)) { + return; + } + HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); response.setStatus(HttpServletResponse.SC_OK); event.getAsyncContext().complete(); @@ -144,7 +151,11 @@ public void onStartAsync(AsyncEvent event) @Override public void onSuccess(final InputStream inputStream) { - HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); + if 
(!responseResolved.compareAndSet(false, true)) { + return; + } + + final HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); try (final OutputStream outputStream = response.getOutputStream()) { if (inputStream == null) { @@ -188,7 +199,7 @@ public void onSuccess(final InputStream inputStream) @Override public void onFailure(Throwable e) { - if (!dataFuture.isCancelled()) { + if (responseResolved.compareAndSet(false, true)) { try { HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java index 202c1c591b10..7cf8201c5252 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java @@ -28,6 +28,7 @@ import org.apache.druid.error.DruidException; import org.apache.druid.error.InvalidInput; import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.Granularity; @@ -56,7 +57,6 @@ import org.apache.druid.sql.calcite.parser.DruidSqlIngest; import org.apache.druid.sql.calcite.parser.DruidSqlInsert; import org.apache.druid.sql.calcite.parser.DruidSqlReplace; -import org.apache.druid.sql.calcite.planner.ColumnMapping; import org.apache.druid.sql.calcite.planner.ColumnMappings; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.QueryUtils; @@ -96,7 +96,6 @@ public class MSQTaskQueryMaker implements QueryMaker private final List> fieldMapping; 
private final MSQTerminalStageSpecFactory terminalStageSpecFactory; - MSQTaskQueryMaker( @Nullable final IngestDestination targetDataSource, final OverlordClient overlordClient, @@ -122,6 +121,38 @@ public QueryResponse runQuery(final DruidQuery druidQuery) String taskId = MSQTasks.controllerTaskId(plannerContext.getSqlQueryId()); + final Map taskContext = new HashMap<>(); + taskContext.put(LookupLoadingSpec.CTX_LOOKUP_LOADING_MODE, plannerContext.getLookupLoadingSpec().getMode()); + if (plannerContext.getLookupLoadingSpec().getMode() == LookupLoadingSpec.Mode.ONLY_REQUIRED) { + taskContext.put(LookupLoadingSpec.CTX_LOOKUPS_TO_LOAD, plannerContext.getLookupLoadingSpec().getLookupsToLoad()); + } + + final List> typeList = getTypes(druidQuery, fieldMapping, plannerContext); + + final MSQControllerTask controllerTask = new MSQControllerTask( + taskId, + makeQuerySpec(targetDataSource, druidQuery, fieldMapping, plannerContext, terminalStageSpecFactory), + MSQTaskQueryMakerUtils.maskSensitiveJsonKeys(plannerContext.getSql()), + plannerContext.queryContextMap(), + SqlResults.Context.fromPlannerContext(plannerContext), + typeList.stream().map(typeInfo -> typeInfo.lhs).collect(Collectors.toList()), + typeList.stream().map(typeInfo -> typeInfo.rhs).collect(Collectors.toList()), + taskContext + ); + + FutureUtils.getUnchecked(overlordClient.runTask(taskId, controllerTask), true); + return QueryResponse.withEmptyContext(Sequences.simple(Collections.singletonList(new Object[]{taskId}))); + } + + public static MSQSpec makeQuerySpec( + @Nullable final IngestDestination targetDataSource, + final DruidQuery druidQuery, + final List> fieldMapping, + final PlannerContext plannerContext, + final MSQTerminalStageSpecFactory terminalStageSpecFactory + ) + { + // SQL query context: context provided by the user, and potentially modified by handlers during planning. 
// Does not directly influence task execution, but it does form the basis for the initial native query context, // which *does* influence task execution. @@ -138,23 +169,18 @@ public QueryResponse runQuery(final DruidQuery druidQuery) MSQMode.populateDefaultQueryContext(msqMode, nativeQueryContext); } - Object segmentGranularity; - try { - segmentGranularity = Optional.ofNullable(plannerContext.queryContext() - .get(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)) - .orElse(jsonMapper.writeValueAsString(DEFAULT_SEGMENT_GRANULARITY)); - } - catch (JsonProcessingException e) { - // This would only be thrown if we are unable to serialize the DEFAULT_SEGMENT_GRANULARITY, which we don't expect - // to happen - throw DruidException.defensive() - .build( - e, - "Unable to deserialize the DEFAULT_SEGMENT_GRANULARITY in MSQTaskQueryMaker. " - + "This shouldn't have happened since the DEFAULT_SEGMENT_GRANULARITY object is guaranteed to be " - + "serializable. Please raise an issue in case you are seeing this message while executing a query." - ); - } + Object segmentGranularity = + Optional.ofNullable(plannerContext.queryContext().get(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)) + .orElseGet(() -> { + try { + return plannerContext.getJsonMapper().writeValueAsString(DEFAULT_SEGMENT_GRANULARITY); + } + catch (JsonProcessingException e) { + // This would only be thrown if we are unable to serialize the DEFAULT_SEGMENT_GRANULARITY, + // which we don't expect to happen. 
+ throw DruidException.defensive().build(e, "Unable to serialize DEFAULT_SEGMENT_GRANULARITY"); + } + }); final int maxNumTasks = MultiStageQueryContext.getMaxNumTasks(sqlQueryContext); @@ -170,7 +196,7 @@ public QueryResponse runQuery(final DruidQuery druidQuery) final int rowsPerSegment = MultiStageQueryContext.getRowsPerSegment(sqlQueryContext); final int maxRowsInMemory = MultiStageQueryContext.getRowsInMemory(sqlQueryContext); final Integer maxNumSegments = MultiStageQueryContext.getMaxNumSegments(sqlQueryContext); - final IndexSpec indexSpec = MultiStageQueryContext.getIndexSpec(sqlQueryContext, jsonMapper); + final IndexSpec indexSpec = MultiStageQueryContext.getIndexSpec(sqlQueryContext, plannerContext.getJsonMapper()); final boolean finalizeAggregations = MultiStageQueryContext.isFinalizeAggregations(sqlQueryContext); final List replaceTimeChunks = @@ -193,29 +219,6 @@ public QueryResponse runQuery(final DruidQuery druidQuery) ) .orElse(null); - // For assistance computing return types if !finalizeAggregations. - final Map aggregationIntermediateTypeMap = - finalizeAggregations ? 
null /* Not needed */ : buildAggregationIntermediateTypeMap(druidQuery); - - final List sqlTypeNames = new ArrayList<>(); - final List columnTypeList = new ArrayList<>(); - final List columnMappings = QueryUtils.buildColumnMappings(fieldMapping, druidQuery); - - for (final Entry entry : fieldMapping) { - final String queryColumn = druidQuery.getOutputRowSignature().getColumnName(entry.getKey()); - - final SqlTypeName sqlTypeName; - - if (!finalizeAggregations && aggregationIntermediateTypeMap.containsKey(queryColumn)) { - final ColumnType druidType = aggregationIntermediateTypeMap.get(queryColumn); - sqlTypeName = new RowSignatures.ComplexSqlType(SqlTypeName.OTHER, druidType, true).getSqlTypeName(); - } else { - sqlTypeName = druidQuery.getOutputRowType().getFieldList().get(entry.getKey()).getType().getSqlTypeName(); - } - sqlTypeNames.add(sqlTypeName); - columnTypeList.add(druidQuery.getOutputRowSignature().getColumnType(queryColumn).orElse(ColumnType.STRING)); - } - final MSQDestination destination; if (targetDataSource instanceof ExportDestination) { @@ -229,7 +232,8 @@ public QueryResponse runQuery(final DruidQuery druidQuery) } else if (targetDataSource instanceof TableDestination) { Granularity segmentGranularityObject; try { - segmentGranularityObject = jsonMapper.readValue((String) segmentGranularity, Granularity.class); + segmentGranularityObject = + plannerContext.getJsonMapper().readValue((String) segmentGranularity, Granularity.class); } catch (Exception e) { throw DruidException.defensive() @@ -288,7 +292,7 @@ public QueryResponse runQuery(final DruidQuery druidQuery) final MSQSpec querySpec = MSQSpec.builder() .query(druidQuery.getQuery().withOverriddenContext(nativeQueryContextOverrides)) - .columnMappings(new ColumnMappings(columnMappings)) + .columnMappings(new ColumnMappings(QueryUtils.buildColumnMappings(fieldMapping, druidQuery))) .destination(destination) .assignmentStrategy(MultiStageQueryContext.getAssignmentStrategy(sqlQueryContext)) 
.tuningConfig(new MSQTuningConfig(maxNumWorkers, maxRowsInMemory, rowsPerSegment, maxNumSegments, indexSpec)) @@ -296,25 +300,42 @@ public QueryResponse runQuery(final DruidQuery druidQuery) MSQTaskQueryMakerUtils.validateRealtimeReindex(querySpec); - final Map context = new HashMap<>(); - context.put(LookupLoadingSpec.CTX_LOOKUP_LOADING_MODE, plannerContext.getLookupLoadingSpec().getMode()); - if (plannerContext.getLookupLoadingSpec().getMode() == LookupLoadingSpec.Mode.ONLY_REQUIRED) { - context.put(LookupLoadingSpec.CTX_LOOKUPS_TO_LOAD, plannerContext.getLookupLoadingSpec().getLookupsToLoad()); - } + return querySpec.withOverriddenContext(nativeQueryContext); + } - final MSQControllerTask controllerTask = new MSQControllerTask( - taskId, - querySpec.withOverriddenContext(nativeQueryContext), - MSQTaskQueryMakerUtils.maskSensitiveJsonKeys(plannerContext.getSql()), - plannerContext.queryContextMap(), - SqlResults.Context.fromPlannerContext(plannerContext), - sqlTypeNames, - columnTypeList, - context - ); + public static List> getTypes( + final DruidQuery druidQuery, + final List> fieldMapping, + final PlannerContext plannerContext + ) + { + final boolean finalizeAggregations = MultiStageQueryContext.isFinalizeAggregations(plannerContext.queryContext()); - FutureUtils.getUnchecked(overlordClient.runTask(taskId, controllerTask), true); - return QueryResponse.withEmptyContext(Sequences.simple(Collections.singletonList(new Object[]{taskId}))); + // For assistance computing return types if !finalizeAggregations. + final Map aggregationIntermediateTypeMap = + finalizeAggregations ? 
null /* Not needed */ : buildAggregationIntermediateTypeMap(druidQuery); + + final List> retVal = new ArrayList<>(); + + for (final Entry entry : fieldMapping) { + final String queryColumn = druidQuery.getOutputRowSignature().getColumnName(entry.getKey()); + + final SqlTypeName sqlTypeName; + + if (!finalizeAggregations && aggregationIntermediateTypeMap.containsKey(queryColumn)) { + final ColumnType druidType = aggregationIntermediateTypeMap.get(queryColumn); + sqlTypeName = new RowSignatures.ComplexSqlType(SqlTypeName.OTHER, druidType, true).getSqlTypeName(); + } else { + sqlTypeName = druidQuery.getOutputRowType().getFieldList().get(entry.getKey()).getType().getSqlTypeName(); + } + + final ColumnType columnType = + druidQuery.getOutputRowSignature().getColumnType(queryColumn).orElse(ColumnType.STRING); + + retVal.add(Pair.of(sqlTypeName, columnType)); + } + + return retVal; } private static Map buildAggregationIntermediateTypeMap(final DruidQuery druidQuery) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java index 1964ad3de4ca..31a2f5e5e643 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java @@ -42,6 +42,7 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.Granularity; +import org.apache.druid.msq.dart.controller.sql.DartSqlEngine; import org.apache.druid.msq.indexing.destination.MSQTerminalStageSpecFactory; import org.apache.druid.msq.querykit.QueryKitUtils; import org.apache.druid.msq.util.ArrayIngestMode; @@ -73,6 +74,9 @@ public class MSQTaskSqlEngine implements SqlEngine { + /** + * Context parameters disallowed for all MSQ 
engines: task (this one) as well as {@link DartSqlEngine#toString()}. + */ public static final Set SYSTEM_CONTEXT_PARAMETERS = ImmutableSet.builder() .addAll(NativeSqlEngine.SYSTEM_CONTEXT_PARAMETERS) @@ -113,13 +117,21 @@ public void validateContext(Map queryContext) } @Override - public RelDataType resultTypeForSelect(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForSelect( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { return getMSQStructType(typeFactory); } @Override - public RelDataType resultTypeForInsert(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForInsert( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { return getMSQStructType(typeFactory); } @@ -387,7 +399,11 @@ private static void validateTypeChanges( final ColumnType oldDruidType = Calcites.getColumnTypeForRelDataType(oldSqlTypeField.getType()); final RelDataType newSqlType = rootRel.getRowType().getFieldList().get(columnIndex).getType(); final ColumnType newDruidType = - DimensionSchemaUtils.getDimensionType(columnName, Calcites.getColumnTypeForRelDataType(newSqlType), arrayIngestMode); + DimensionSchemaUtils.getDimensionType( + columnName, + Calcites.getColumnTypeForRelDataType(newSqlType), + arrayIngestMode + ); if (newDruidType.isArray() && oldDruidType.is(ValueType.STRING) || (newDruidType.is(ValueType.STRING) && oldDruidType.isArray())) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MSQTaskQueryMakerUtils.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MSQTaskQueryMakerUtils.java index a30c9bb0aec0..36c90a21f002 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MSQTaskQueryMakerUtils.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MSQTaskQueryMakerUtils.java @@ -28,7 
+28,6 @@ import org.apache.druid.msq.indexing.MSQSpec; import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination; -import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.regex.Matcher; @@ -82,10 +81,8 @@ public static void validateContextSortOrderColumnsExist( final Set allOutputColumns ) { - final Set allOutputColumnsSet = new HashSet<>(allOutputColumns); - for (final String column : contextSortOrder) { - if (!allOutputColumnsSet.contains(column)) { + if (!allOutputColumns.contains(column)) { throw InvalidSqlInput.exception( "Column[%s] from context parameter[%s] does not appear in the query output", column, diff --git a/extensions-core/multi-stage-query/src/main/resources/META-INF/services/org.apache.druid.initialization.DruidModule b/extensions-core/multi-stage-query/src/main/resources/META-INF/services/org.apache.druid.initialization.DruidModule index 92be5604cb8a..1058d5d5f99e 100644 --- a/extensions-core/multi-stage-query/src/main/resources/META-INF/services/org.apache.druid.initialization.DruidModule +++ b/extensions-core/multi-stage-query/src/main/resources/META-INF/services/org.apache.druid.initialization.DruidModule @@ -13,6 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+org.apache.druid.msq.dart.guice.DartControllerMemoryManagementModule +org.apache.druid.msq.dart.guice.DartControllerModule +org.apache.druid.msq.dart.guice.DartWorkerMemoryManagementModule +org.apache.druid.msq.dart.guice.DartWorkerModule org.apache.druid.msq.guice.IndexerMemoryManagementModule org.apache.druid.msq.guice.MSQDurableStorageModule org.apache.druid.msq.guice.MSQExternalDataSourceModule diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartTableInputSpecSlicerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartTableInputSpecSlicerTest.java new file mode 100644 index 000000000000..be67fe860abf --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartTableInputSpecSlicerTest.java @@ -0,0 +1,488 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.controller; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Ordering; +import it.unimi.dsi.fastutil.ints.IntList; +import it.unimi.dsi.fastutil.ints.IntLists; +import org.apache.druid.client.DruidServer; +import org.apache.druid.client.TimelineServerView; +import org.apache.druid.client.selector.HighestPriorityTierSelectorStrategy; +import org.apache.druid.client.selector.QueryableDruidServer; +import org.apache.druid.client.selector.RandomServerSelectorStrategy; +import org.apache.druid.client.selector.ServerSelector; +import org.apache.druid.data.input.StringTuple; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.msq.dart.worker.WorkerId; +import org.apache.druid.msq.input.InputSlice; +import org.apache.druid.msq.input.NilInputSlice; +import org.apache.druid.msq.input.table.RichSegmentDescriptor; +import org.apache.druid.msq.input.table.SegmentsInputSlice; +import org.apache.druid.msq.input.table.TableInputSpec; +import org.apache.druid.query.TableDataSource; +import org.apache.druid.query.filter.EqualityFilter; +import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.server.coordination.DruidServerMetadata; +import org.apache.druid.server.coordination.ServerType; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.apache.druid.timeline.DataSegment; +import org.apache.druid.timeline.VersionedIntervalTimeline; +import org.apache.druid.timeline.partition.DimensionRangeShardSpec; +import org.apache.druid.timeline.partition.NumberedShardSpec; +import org.apache.druid.timeline.partition.TombstoneShardSpec; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; + +import 
java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +public class DartTableInputSpecSlicerTest extends InitializedNullHandlingTest +{ + private static final String QUERY_ID = "abc"; + private static final String DATASOURCE = "test-ds"; + private static final String DATASOURCE_NONEXISTENT = "nonexistent-ds"; + private static final String PARTITION_DIM = "dim"; + private static final long BYTES_PER_SEGMENT = 1000; + + /** + * List of servers, with descending priority, so earlier servers are preferred by the {@link ServerSelector}. + * This makes tests deterministic. + */ + private static final List SERVERS = ImmutableList.of( + new DruidServerMetadata("no", "localhost:1001", null, 1, ServerType.HISTORICAL, "__default", 2), + new DruidServerMetadata("no", "localhost:1002", null, 1, ServerType.HISTORICAL, "__default", 1), + new DruidServerMetadata("no", "localhost:1003", null, 1, ServerType.REALTIME, "__default", 0) + ); + + /** + * Dart {@link WorkerId} derived from {@link #SERVERS}. + */ + private static final List WORKER_IDS = + SERVERS.stream() + .map(server -> new WorkerId("http", server.getHostAndPort(), QUERY_ID).toString()) + .collect(Collectors.toList()); + + /** + * Segment that is one of two in a range-partitioned time chunk. + */ + private static final DataSegment SEGMENT1 = new DataSegment( + DATASOURCE, + Intervals.of("2000/2001"), + "1", + Collections.emptyMap(), + Collections.emptyList(), + Collections.emptyList(), + new DimensionRangeShardSpec(ImmutableList.of(PARTITION_DIM), null, new StringTuple(new String[]{"foo"}), 0, 2), + null, + null, + BYTES_PER_SEGMENT + ); + + /** + * Segment that is one of two in a range-partitioned time chunk. 
+ */ + private static final DataSegment SEGMENT2 = new DataSegment( + DATASOURCE, + Intervals.of("2000/2001"), + "1", + Collections.emptyMap(), + Collections.emptyList(), + Collections.emptyList(), + new DimensionRangeShardSpec(ImmutableList.of("dim"), new StringTuple(new String[]{"foo"}), null, 1, 2), + null, + null, + BYTES_PER_SEGMENT + ); + + /** + * Segment that is alone in a time chunk. It is not served by any server, and such segments are assigned to the + * existing servers round-robin. Because this is the only "not served by any server" segment, it should + * be assigned to the first server. + */ + private static final DataSegment SEGMENT3 = new DataSegment( + DATASOURCE, + Intervals.of("2001/2002"), + "1", + Collections.emptyMap(), + Collections.emptyList(), + Collections.emptyList(), + new NumberedShardSpec(0, 1), + null, + null, + BYTES_PER_SEGMENT + ); + + /** + * Segment that should be ignored because it's a tombstone. + */ + private static final DataSegment SEGMENT4 = new DataSegment( + DATASOURCE, + Intervals.of("2002/2003"), + "1", + Collections.emptyMap(), + Collections.emptyList(), + Collections.emptyList(), + TombstoneShardSpec.INSTANCE, + null, + null, + BYTES_PER_SEGMENT + ); + + /** + * Segment that should be ignored (for now) because it's realtime-only. + */ + private static final DataSegment SEGMENT5 = new DataSegment( + DATASOURCE, + Intervals.of("2003/2004"), + "1", + Collections.emptyMap(), + Collections.emptyList(), + Collections.emptyList(), + new NumberedShardSpec(0, 1), + null, + null, + BYTES_PER_SEGMENT + ); + + /** + * Mapping of segment to servers (indexes in {@link #SERVERS}). + */ + private static final Map SEGMENT_SERVERS = + ImmutableMap.builder() + .put(SEGMENT1, IntList.of(0)) + .put(SEGMENT2, IntList.of(1)) + .put(SEGMENT3, IntLists.emptyList()) + .put(SEGMENT4, IntList.of(1)) + .put(SEGMENT5, IntList.of(2)) + .build(); + + private AutoCloseable mockCloser; + + /** + * Slicer under test. 
Built using {@link #timeline} and {@link #SERVERS}. + */ + private DartTableInputSpecSlicer slicer; + + /** + * Timeline built from {@link #SEGMENT_SERVERS} and {@link #SERVERS}. + */ + private VersionedIntervalTimeline timeline; + + /** + * Server view that uses {@link #timeline}. + */ + @Mock + private TimelineServerView serverView; + + @BeforeEach + void setUp() + { + mockCloser = MockitoAnnotations.openMocks(this); + slicer = DartTableInputSpecSlicer.createFromWorkerIds(WORKER_IDS, serverView); + + // Add all segments to the timeline, round-robin across the two servers. + timeline = new VersionedIntervalTimeline<>(Ordering.natural()); + for (Map.Entry entry : SEGMENT_SERVERS.entrySet()) { + final DataSegment dataSegment = entry.getKey(); + final IntList segmentServers = entry.getValue(); + final ServerSelector serverSelector = new ServerSelector( + dataSegment, + new HighestPriorityTierSelectorStrategy(new RandomServerSelectorStrategy()) + ); + for (int serverNumber : segmentServers) { + final DruidServerMetadata serverMetadata = SERVERS.get(serverNumber); + final DruidServer server = new DruidServer( + serverMetadata.getName(), + serverMetadata.getHostAndPort(), + serverMetadata.getHostAndTlsPort(), + serverMetadata.getMaxSize(), + serverMetadata.getType(), + serverMetadata.getTier(), + serverMetadata.getPriority() + ); + serverSelector.addServerAndUpdateSegment(new QueryableDruidServer<>(server, null), dataSegment); + } + timeline.add( + dataSegment.getInterval(), + dataSegment.getVersion(), + dataSegment.getShardSpec().createChunk(serverSelector) + ); + } + + Mockito.when(serverView.getDruidServerMetadatas()).thenReturn(SERVERS); + Mockito.when(serverView.getTimeline(new TableDataSource(DATASOURCE).getAnalysis())) + .thenReturn(Optional.of(timeline)); + Mockito.when(serverView.getTimeline(new TableDataSource(DATASOURCE_NONEXISTENT).getAnalysis())) + .thenReturn(Optional.empty()); + } + + @AfterEach + void tearDown() throws Exception + { + mockCloser.close(); 
+ } + + @Test + public void test_sliceDynamic() + { + // This slicer cannot sliceDynamic. + + final TableInputSpec inputSpec = new TableInputSpec(DATASOURCE, null, null, null); + Assertions.assertFalse(slicer.canSliceDynamic(inputSpec)); + Assertions.assertThrows( + UnsupportedOperationException.class, + () -> slicer.sliceDynamic(inputSpec, 1, 1, 1) + ); + } + + @Test + public void test_sliceStatic_wholeTable_oneSlice() + { + // When 1 slice is requested, all segments are assigned to one server, even if that server doesn't actually + // currently serve those segments. + + final TableInputSpec inputSpec = new TableInputSpec(DATASOURCE, null, null, null); + final List inputSlices = slicer.sliceStatic(inputSpec, 1); + Assertions.assertEquals( + ImmutableList.of( + new SegmentsInputSlice( + DATASOURCE, + ImmutableList.of( + new RichSegmentDescriptor( + SEGMENT1.getInterval(), + SEGMENT1.getInterval(), + SEGMENT1.getVersion(), + SEGMENT1.getShardSpec().getPartitionNum() + ), + new RichSegmentDescriptor( + SEGMENT2.getInterval(), + SEGMENT2.getInterval(), + SEGMENT2.getVersion(), + SEGMENT2.getShardSpec().getPartitionNum() + ), + new RichSegmentDescriptor( + SEGMENT3.getInterval(), + SEGMENT3.getInterval(), + SEGMENT3.getVersion(), + SEGMENT3.getShardSpec().getPartitionNum() + ) + ), + ImmutableList.of() + ) + ), + inputSlices + ); + } + + @Test + public void test_sliceStatic_wholeTable_twoSlices() + { + // When 2 slices are requested, we assign segments to the servers that have those segments. 
+ + final TableInputSpec inputSpec = new TableInputSpec(DATASOURCE, null, null, null); + final List inputSlices = slicer.sliceStatic(inputSpec, 2); + Assertions.assertEquals( + ImmutableList.of( + new SegmentsInputSlice( + DATASOURCE, + ImmutableList.of( + new RichSegmentDescriptor( + SEGMENT1.getInterval(), + SEGMENT1.getInterval(), + SEGMENT1.getVersion(), + SEGMENT1.getShardSpec().getPartitionNum() + ), + new RichSegmentDescriptor( + SEGMENT3.getInterval(), + SEGMENT3.getInterval(), + SEGMENT3.getVersion(), + SEGMENT3.getShardSpec().getPartitionNum() + ) + ), + ImmutableList.of() + ), + new SegmentsInputSlice( + DATASOURCE, + ImmutableList.of( + new RichSegmentDescriptor( + SEGMENT2.getInterval(), + SEGMENT2.getInterval(), + SEGMENT2.getVersion(), + SEGMENT2.getShardSpec().getPartitionNum() + ) + ), + ImmutableList.of() + ) + ), + inputSlices + ); + } + + @Test + public void test_sliceStatic_wholeTable_threeSlices() + { + // When 3 slices are requested, only 2 are returned, because we only have two workers. 
+ + final TableInputSpec inputSpec = new TableInputSpec(DATASOURCE, null, null, null); + final List inputSlices = slicer.sliceStatic(inputSpec, 3); + Assertions.assertEquals( + ImmutableList.of( + new SegmentsInputSlice( + DATASOURCE, + ImmutableList.of( + new RichSegmentDescriptor( + SEGMENT1.getInterval(), + SEGMENT1.getInterval(), + SEGMENT1.getVersion(), + SEGMENT1.getShardSpec().getPartitionNum() + ), + new RichSegmentDescriptor( + SEGMENT3.getInterval(), + SEGMENT3.getInterval(), + SEGMENT3.getVersion(), + SEGMENT3.getShardSpec().getPartitionNum() + ) + ), + ImmutableList.of() + ), + new SegmentsInputSlice( + DATASOURCE, + ImmutableList.of( + new RichSegmentDescriptor( + SEGMENT2.getInterval(), + SEGMENT2.getInterval(), + SEGMENT2.getVersion(), + SEGMENT2.getShardSpec().getPartitionNum() + ) + ), + ImmutableList.of() + ), + NilInputSlice.INSTANCE + ), + inputSlices + ); + } + + @Test + public void test_sliceStatic_nonexistentTable() + { + final TableInputSpec inputSpec = new TableInputSpec(DATASOURCE_NONEXISTENT, null, null, null); + final List inputSlices = slicer.sliceStatic(inputSpec, 1); + Assertions.assertEquals( + Collections.emptyList(), + inputSlices + ); + } + + @Test + public void test_sliceStatic_dimensionFilter_twoSlices() + { + // Filtered on a dimension that is used for range partitioning in 2000/2001, so one segment gets pruned out. 
+ + final TableInputSpec inputSpec = new TableInputSpec( + DATASOURCE, + null, + new EqualityFilter(PARTITION_DIM, ColumnType.STRING, "abc", null), + null + ); + + final List inputSlices = slicer.sliceStatic(inputSpec, 2); + + Assertions.assertEquals( + ImmutableList.of( + new SegmentsInputSlice( + DATASOURCE, + ImmutableList.of( + new RichSegmentDescriptor( + SEGMENT1.getInterval(), + SEGMENT1.getInterval(), + SEGMENT1.getVersion(), + SEGMENT1.getShardSpec().getPartitionNum() + ), + new RichSegmentDescriptor( + SEGMENT3.getInterval(), + SEGMENT3.getInterval(), + SEGMENT3.getVersion(), + SEGMENT3.getShardSpec().getPartitionNum() + ) + ), + ImmutableList.of() + ), + NilInputSlice.INSTANCE + ), + inputSlices + ); + } + + @Test + public void test_sliceStatic_timeFilter_twoSlices() + { + // Filtered on 2000/2001, so other segments get pruned out. + + final TableInputSpec inputSpec = new TableInputSpec( + DATASOURCE, + Collections.singletonList(Intervals.of("2000/P1Y")), + null, + null + ); + + final List inputSlices = slicer.sliceStatic(inputSpec, 2); + + Assertions.assertEquals( + ImmutableList.of( + new SegmentsInputSlice( + DATASOURCE, + ImmutableList.of( + new RichSegmentDescriptor( + SEGMENT1.getInterval(), + SEGMENT1.getInterval(), + SEGMENT1.getVersion(), + SEGMENT1.getShardSpec().getPartitionNum() + ) + ), + ImmutableList.of() + ), + new SegmentsInputSlice( + DATASOURCE, + ImmutableList.of( + new RichSegmentDescriptor( + SEGMENT2.getInterval(), + SEGMENT2.getInterval(), + SEGMENT2.getVersion(), + SEGMENT2.getShardSpec().getPartitionNum() + ) + ), + ImmutableList.of() + ) + ), + inputSlices + ); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartWorkerManagerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartWorkerManagerTest.java new file mode 100644 index 000000000000..f4441c984e70 --- /dev/null +++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartWorkerManagerTest.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.error.DruidException; +import org.apache.druid.indexer.TaskState; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.dart.worker.DartWorkerClient; +import org.apache.druid.msq.dart.worker.WorkerId; +import org.apache.druid.msq.exec.WorkerManager; +import org.apache.druid.msq.exec.WorkerStats; +import org.junit.Assert; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; + +import java.util.Collections; +import java.util.List; +import 
java.util.Map; + +public class DartWorkerManagerTest +{ + private static final List WORKERS = ImmutableList.of( + new WorkerId("http", "localhost:1001", "abc").toString(), + new WorkerId("http", "localhost:1002", "abc").toString() + ); + + private DartWorkerManager workerManager; + private AutoCloseable mockCloser; + + @Mock + private DartWorkerClient workerClient; + + @BeforeEach + public void setUp() + { + mockCloser = MockitoAnnotations.openMocks(this); + workerManager = new DartWorkerManager(WORKERS, workerClient); + } + + @AfterEach + public void tearDown() throws Exception + { + mockCloser.close(); + } + + @Test + public void test_getWorkerCount() + { + Assertions.assertEquals(0, workerManager.getWorkerCount().getPendingWorkerCount()); + Assertions.assertEquals(2, workerManager.getWorkerCount().getRunningWorkerCount()); + } + + @Test + public void test_getWorkerIds() + { + Assertions.assertEquals(WORKERS, workerManager.getWorkerIds()); + } + + @Test + public void test_getWorkerStats() + { + final Map> stats = workerManager.getWorkerStats(); + Assertions.assertEquals( + ImmutableMap.of( + 0, Collections.singletonList(new WorkerStats(WORKERS.get(0), TaskState.RUNNING, -1, -1)), + 1, Collections.singletonList(new WorkerStats(WORKERS.get(1), TaskState.RUNNING, -1, -1)) + ), + stats + ); + } + + @Test + public void test_getWorkerNumber() + { + Assertions.assertEquals(0, workerManager.getWorkerNumber(WORKERS.get(0))); + Assertions.assertEquals(1, workerManager.getWorkerNumber(WORKERS.get(1))); + Assertions.assertEquals(WorkerManager.UNKNOWN_WORKER_NUMBER, workerManager.getWorkerNumber("nonexistent")); + } + + @Test + public void test_isWorkerActive() + { + Assertions.assertTrue(workerManager.isWorkerActive(WORKERS.get(0))); + Assertions.assertTrue(workerManager.isWorkerActive(WORKERS.get(1))); + Assertions.assertFalse(workerManager.isWorkerActive("nonexistent")); + } + + @Test + public void test_launchWorkersIfNeeded() + { + workerManager.launchWorkersIfNeeded(0); 
// Does nothing, less than WORKERS.size() + workerManager.launchWorkersIfNeeded(1); // Does nothing, less than WORKERS.size() + workerManager.launchWorkersIfNeeded(2); // Does nothing, equal to WORKERS.size() + Assert.assertThrows( + DruidException.class, + () -> workerManager.launchWorkersIfNeeded(3) + ); + } + + @Test + public void test_waitForWorkers() + { + workerManager.launchWorkersIfNeeded(2); + workerManager.waitForWorkers(IntSet.of(0, 1)); // Returns immediately + } + + @Test + public void test_start_stop_noInterrupt() + { + Mockito.when(workerClient.stopWorker(WORKERS.get(0))) + .thenReturn(Futures.immediateFuture(null)); + Mockito.when(workerClient.stopWorker(WORKERS.get(1))) + .thenReturn(Futures.immediateFuture(null)); + + final ListenableFuture future = workerManager.start(); + workerManager.stop(false); + + // Ensure the future from start() resolves. + Assertions.assertNull(FutureUtils.getUnchecked(future, true)); + } + + @Test + public void test_start_stop_interrupt() + { + Mockito.when(workerClient.stopWorker(WORKERS.get(0))) + .thenReturn(Futures.immediateFuture(null)); + Mockito.when(workerClient.stopWorker(WORKERS.get(1))) + .thenReturn(Futures.immediateFuture(null)); + + final ListenableFuture future = workerManager.start(); + workerManager.stop(true); + + // Ensure the future from start() resolves. + Assertions.assertNull(FutureUtils.getUnchecked(future, true)); + } + + @Test + public void test_start_stop_interrupt_clientError() + { + Mockito.when(workerClient.stopWorker(WORKERS.get(0))) + .thenReturn(Futures.immediateFailedFuture(new ISE("stop failure"))); + Mockito.when(workerClient.stopWorker(WORKERS.get(1))) + .thenReturn(Futures.immediateFuture(null)); + + final ListenableFuture future = workerManager.start(); + workerManager.stop(true); + + // Ensure the future from start() resolves. 
+ Assertions.assertNull(FutureUtils.getUnchecked(future, true)); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartQueryInfoTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartQueryInfoTest.java new file mode 100644 index 000000000000..980038723532 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartQueryInfoTest.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.controller.http; + +import nl.jqno.equalsverifier.EqualsVerifier; +import org.junit.jupiter.api.Test; + +public class DartQueryInfoTest +{ + @Test + public void test_equals() + { + EqualsVerifier.forClass(DartQueryInfo.class).usingGetClass().verify(); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartSqlResourceTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartSqlResourceTest.java new file mode 100644 index 000000000000..db3479178724 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartSqlResourceTest.java @@ -0,0 +1,757 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.controller.http; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.util.concurrent.Futures; +import org.apache.druid.indexer.report.TaskReport; +import org.apache.druid.indexing.common.TaskLockType; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.java.util.common.jackson.JacksonUtils; +import org.apache.druid.msq.dart.controller.ControllerHolder; +import org.apache.druid.msq.dart.controller.DartControllerRegistry; +import org.apache.druid.msq.dart.controller.sql.DartQueryMaker; +import org.apache.druid.msq.dart.controller.sql.DartSqlClient; +import org.apache.druid.msq.dart.controller.sql.DartSqlClients; +import org.apache.druid.msq.dart.controller.sql.DartSqlEngine; +import org.apache.druid.msq.dart.guice.DartControllerConfig; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.indexing.error.CanceledFault; +import org.apache.druid.msq.indexing.error.InvalidNullByteFault; +import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.indexing.error.MSQFaultUtils; +import org.apache.druid.msq.indexing.report.MSQTaskReport; +import org.apache.druid.msq.test.MSQTestBase; +import org.apache.druid.msq.test.MSQTestControllerContext; +import org.apache.druid.query.DefaultQueryConfig; +import org.apache.druid.query.QueryContext; +import org.apache.druid.query.QueryContexts; +import org.apache.druid.server.DruidNode; +import org.apache.druid.server.QueryStackTests; +import org.apache.druid.server.ResponseContextConfig; +import 
org.apache.druid.server.initialization.ServerConfig; +import org.apache.druid.server.log.NoopRequestLogger; +import org.apache.druid.server.metrics.NoopServiceEmitter; +import org.apache.druid.server.mocks.MockAsyncContext; +import org.apache.druid.server.mocks.MockHttpServletResponse; +import org.apache.druid.server.security.AuthConfig; +import org.apache.druid.server.security.AuthenticationResult; +import org.apache.druid.server.security.ForbiddenException; +import org.apache.druid.sql.SqlLifecycleManager; +import org.apache.druid.sql.SqlStatementFactory; +import org.apache.druid.sql.SqlToolbox; +import org.apache.druid.sql.calcite.planner.CalciteRulesManager; +import org.apache.druid.sql.calcite.planner.CatalogResolver; +import org.apache.druid.sql.calcite.planner.PlannerConfig; +import org.apache.druid.sql.calcite.planner.PlannerFactory; +import org.apache.druid.sql.calcite.schema.DruidSchemaCatalog; +import org.apache.druid.sql.calcite.schema.NoopDruidSchemaManager; +import org.apache.druid.sql.calcite.util.CalciteTests; +import org.apache.druid.sql.calcite.util.QueryFrameworkUtils; +import org.apache.druid.sql.calcite.view.NoopViewManager; +import org.apache.druid.sql.hook.DruidHookDispatcher; +import org.apache.druid.sql.http.ResultFormat; +import org.apache.druid.sql.http.SqlQuery; +import org.hamcrest.CoreMatchers; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; + +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.core.Response; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import 
static org.hamcrest.MatcherAssert.assertThat; + +/** + * Functional test of {@link DartSqlResource}, {@link DartSqlEngine}, and {@link DartQueryMaker}. + * Other classes are mocked when possible. + */ +public class DartSqlResourceTest extends MSQTestBase +{ + private static final DruidNode SELF_NODE = new DruidNode("none", "localhost", false, 8080, -1, true, false); + private static final String AUTHENTICATOR_NAME = "authn"; + private static final int MAX_CONTROLLERS = 1; + + /** + * A user that is not a superuser. + * See {@link CalciteTests#TEST_AUTHORIZER_MAPPER} for how this user is mapped. + */ + private static final String REGULAR_USER_NAME = "regularUser"; + + /** + * A user that is not a superuser, and is different from {@link #REGULAR_USER_NAME}. + * See {@link CalciteTests#TEST_AUTHORIZER_MAPPER} for how this user is mapped. + */ + private static final String DIFFERENT_REGULAR_USER_NAME = "differentRegularUser"; + + /** + * Latch that cancellation tests can use to determine when a query is added to the {@link DartControllerRegistry}, + * and becomes cancelable. + */ + private final CountDownLatch controllerRegistered = new CountDownLatch(1); + + // Objects created in setUp() below this line. + + private DartSqlResource sqlResource; + private DartControllerRegistry controllerRegistry; + private ExecutorService controllerExecutor; + private AutoCloseable mockCloser; + + // Mocks below this line. + + /** + * Mock for {@link DartSqlClients}, which is used in tests of {@link DartSqlResource#doGetRunningQueries}. + */ + @Mock + private DartSqlClients dartSqlClients; + + /** + * Mock for {@link DartSqlClient}, which is used in tests of {@link DartSqlResource#doGetRunningQueries}. + */ + @Mock + private DartSqlClient dartSqlClient; + + /** + * Mock http request. + */ + @Mock + private HttpServletRequest httpServletRequest; + + /** + * Mock for test cases that need to make two requests. 
+ */ + @Mock + private HttpServletRequest httpServletRequest2; + + @BeforeEach + void setUp() + { + mockCloser = MockitoAnnotations.openMocks(this); + + final DartSqlEngine engine = new DartSqlEngine( + queryId -> new MSQTestControllerContext( + objectMapper, + injector, + null /* not used in this test */, + workerMemoryParameters, + loadedSegmentsMetadata, + TaskLockType.APPEND, + QueryContext.empty() + ), + controllerRegistry = new DartControllerRegistry() + { + @Override + public void register(ControllerHolder holder) + { + super.register(holder); + controllerRegistered.countDown(); + } + }, + objectMapper.convertValue(ImmutableMap.of(), DartControllerConfig.class), + controllerExecutor = Execs.multiThreaded( + MAX_CONTROLLERS, + StringUtils.encodeForFormat(getClass().getSimpleName() + "-controller-exec") + ) + ); + + final DruidSchemaCatalog rootSchema = QueryFrameworkUtils.createMockRootSchema( + CalciteTests.INJECTOR, + queryFramework().conglomerate(), + queryFramework().walker(), + new PlannerConfig(), + new NoopViewManager(), + new NoopDruidSchemaManager(), + CalciteTests.TEST_AUTHORIZER_MAPPER, + CatalogResolver.NULL_RESOLVER + ); + + final PlannerFactory plannerFactory = new PlannerFactory( + rootSchema, + queryFramework().operatorTable(), + queryFramework().macroTable(), + PLANNER_CONFIG_DEFAULT, + CalciteTests.TEST_AUTHORIZER_MAPPER, + objectMapper, + CalciteTests.DRUID_SCHEMA_NAME, + new CalciteRulesManager(ImmutableSet.of()), + CalciteTests.createJoinableFactoryWrapper(), + CatalogResolver.NULL_RESOLVER, + new AuthConfig(), + new DruidHookDispatcher() + ); + + final SqlLifecycleManager lifecycleManager = new SqlLifecycleManager(); + final SqlToolbox toolbox = new SqlToolbox( + engine, + plannerFactory, + new NoopServiceEmitter(), + new NoopRequestLogger(), + QueryStackTests.DEFAULT_NOOP_SCHEDULER, + new DefaultQueryConfig(ImmutableMap.of()), + lifecycleManager + ); + + sqlResource = new DartSqlResource( + objectMapper, + 
CalciteTests.TEST_AUTHORIZER_MAPPER, + new SqlStatementFactory(toolbox), + controllerRegistry, + lifecycleManager, + dartSqlClients, + new ServerConfig() /* currently only used for error transform strategy */, + ResponseContextConfig.newConfig(false), + SELF_NODE, + new DefaultQueryConfig(ImmutableMap.of("foo", "bar")) + ); + + // Setup mocks + Mockito.when(dartSqlClients.getAllClients()).thenReturn(Collections.singletonList(dartSqlClient)); + } + + @AfterEach + void tearDown() throws Exception + { + mockCloser.close(); + + // shutdown(), not shutdownNow(), to ensure controllers stop timely on their own. + controllerExecutor.shutdown(); + + if (!controllerExecutor.awaitTermination(1, TimeUnit.MINUTES)) { + throw new IAE("controllerExecutor.awaitTermination() timed out"); + } + + // Ensure that controllerRegistry has nothing in it at the conclusion of each test. Verifies that controllers + // are fully cleaned up. + Assertions.assertEquals(0, controllerRegistry.getAllHolders().size(), "controllerRegistry.getAllHolders().size()"); + } + + @Test + public void test_getEnabled() + { + final Response response = sqlResource.doGetEnabled(httpServletRequest); + Assertions.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()); + } + + /** + * Test where a superuser calls {@link DartSqlResource#doGetRunningQueries} with selfOnly enabled. 
+ */ + @Test + public void test_getRunningQueries_selfOnly_superUser() + { + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(CalciteTests.TEST_SUPERUSER_NAME)); + + final ControllerHolder holder = setUpMockRunningQuery(REGULAR_USER_NAME); + + Assertions.assertEquals( + new GetQueriesResponse(Collections.singletonList(DartQueryInfo.fromControllerHolder(holder))), + sqlResource.doGetRunningQueries("", httpServletRequest) + ); + + controllerRegistry.deregister(holder); + } + + /** + * Test where {@link #REGULAR_USER_NAME} and {@link #DIFFERENT_REGULAR_USER_NAME} issue queries, and + * {@link #REGULAR_USER_NAME} calls {@link DartSqlResource#doGetRunningQueries} with selfOnly enabled. + */ + @Test + public void test_getRunningQueries_selfOnly_regularUser() + { + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + + final ControllerHolder holder = setUpMockRunningQuery(REGULAR_USER_NAME); + final ControllerHolder holder2 = setUpMockRunningQuery(DIFFERENT_REGULAR_USER_NAME); + + // Regular users can see only their own queries, without authentication details. + Assertions.assertEquals(2, controllerRegistry.getAllHolders().size()); + Assertions.assertEquals( + new GetQueriesResponse( + Collections.singletonList(DartQueryInfo.fromControllerHolder(holder).withoutAuthenticationResult())), + sqlResource.doGetRunningQueries("", httpServletRequest) + ); + + controllerRegistry.deregister(holder); + controllerRegistry.deregister(holder2); + } + + /** + * Test where a superuser calls {@link DartSqlResource#doGetRunningQueries} with selfOnly disabled. 
+ */ + @Test + public void test_getRunningQueries_global_superUser() + { + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(CalciteTests.TEST_SUPERUSER_NAME)); + + // REGULAR_USER_NAME runs a query locally. + final ControllerHolder localHolder = setUpMockRunningQuery(REGULAR_USER_NAME); + + // DIFFERENT_REGULAR_USER_NAME runs a query remotely. + final DartQueryInfo remoteQueryInfo = new DartQueryInfo( + "sid", + "did2", + "SELECT 2", + AUTHENTICATOR_NAME, + DIFFERENT_REGULAR_USER_NAME, + DateTimes.of("2000"), + ControllerHolder.State.RUNNING.toString() + ); + Mockito.when(dartSqlClient.getRunningQueries(true)) + .thenReturn(Futures.immediateFuture(new GetQueriesResponse(Collections.singletonList(remoteQueryInfo)))); + + // With selfOnly = null, the endpoint returns both queries. + Assertions.assertEquals( + new GetQueriesResponse( + ImmutableList.of( + DartQueryInfo.fromControllerHolder(localHolder), + remoteQueryInfo + ) + ), + sqlResource.doGetRunningQueries(null, httpServletRequest) + ); + + controllerRegistry.deregister(localHolder); + } + + /** + * Test where a superuser calls {@link DartSqlResource#doGetRunningQueries} with selfOnly disabled, and where the + * remote server has a problem. + */ + @Test + public void test_getRunningQueries_global_remoteError_superUser() + { + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(CalciteTests.TEST_SUPERUSER_NAME)); + + // REGULAR_USER_NAME runs a query locally. + final ControllerHolder localHolder = setUpMockRunningQuery(REGULAR_USER_NAME); + + // Remote call fails. + Mockito.when(dartSqlClient.getRunningQueries(true)) + .thenReturn(Futures.immediateFailedFuture(new IOException("something went wrong"))); + + // We only see local queries, because the remote call failed. (The entire call doesn't fail; we see what we + // were able to fetch.) 
+ Assertions.assertEquals( + new GetQueriesResponse(ImmutableList.of(DartQueryInfo.fromControllerHolder(localHolder))), + sqlResource.doGetRunningQueries(null, httpServletRequest) + ); + + controllerRegistry.deregister(localHolder); + } + + /** + * Test where {@link #REGULAR_USER_NAME} and {@link #DIFFERENT_REGULAR_USER_NAME} issue queries, and + * {@link #REGULAR_USER_NAME} calls {@link DartSqlResource#doGetRunningQueries} with selfOnly disabled. + */ + @Test + public void test_getRunningQueries_global_regularUser() + { + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + + // REGULAR_USER_NAME runs a query locally. + final ControllerHolder localHolder = setUpMockRunningQuery(REGULAR_USER_NAME); + + // DIFFERENT_REGULAR_USER_NAME runs a query remotely. + final DartQueryInfo remoteQueryInfo = new DartQueryInfo( + "sid", + "did2", + "SELECT 2", + AUTHENTICATOR_NAME, + DIFFERENT_REGULAR_USER_NAME, + DateTimes.of("2000"), + ControllerHolder.State.RUNNING.toString() + ); + Mockito.when(dartSqlClient.getRunningQueries(true)) + .thenReturn(Futures.immediateFuture(new GetQueriesResponse(Collections.singletonList(remoteQueryInfo)))); + + // The endpoint returns only the query issued by REGULAR_USER_NAME. + Assertions.assertEquals( + new GetQueriesResponse( + ImmutableList.of(DartQueryInfo.fromControllerHolder(localHolder).withoutAuthenticationResult())), + sqlResource.doGetRunningQueries(null, httpServletRequest) + ); + + controllerRegistry.deregister(localHolder); + } + + /** + * Test where {@link #REGULAR_USER_NAME} and {@link #DIFFERENT_REGULAR_USER_NAME} issue queries, and + * {@link #DIFFERENT_REGULAR_USER_NAME} calls {@link DartSqlResource#doGetRunningQueries} with selfOnly disabled. 
+ */ + @Test + public void test_getRunningQueries_global_differentRegularUser() + { + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(DIFFERENT_REGULAR_USER_NAME)); + + // REGULAR_USER_NAME runs a query locally. + final ControllerHolder holder = setUpMockRunningQuery(REGULAR_USER_NAME); + + // DIFFERENT_REGULAR_USER_NAME runs a query remotely. + final DartQueryInfo remoteQueryInfo = new DartQueryInfo( + "sid", + "did2", + "SELECT 2", + AUTHENTICATOR_NAME, + DIFFERENT_REGULAR_USER_NAME, + DateTimes.of("2000"), + ControllerHolder.State.RUNNING.toString() + ); + Mockito.when(dartSqlClient.getRunningQueries(true)) + .thenReturn(Futures.immediateFuture(new GetQueriesResponse(Collections.singletonList(remoteQueryInfo)))); + + // The endpoint returns only the query issued by DIFFERENT_REGULAR_USER_NAME. + Assertions.assertEquals( + new GetQueriesResponse(ImmutableList.of(remoteQueryInfo.withoutAuthenticationResult())), + sqlResource.doGetRunningQueries(null, httpServletRequest) + ); + + controllerRegistry.deregister(holder); + } + + @Test + public void test_doPost_regularUser() + { + final MockAsyncContext asyncContext = new MockAsyncContext(); + final MockHttpServletResponse asyncResponse = new MockHttpServletResponse(); + asyncContext.response = asyncResponse; + + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + Mockito.when(httpServletRequest.startAsync()) + .thenReturn(asyncContext); + + final SqlQuery sqlQuery = new SqlQuery( + "SELECT 1 + 1", + ResultFormat.ARRAY, + false, + false, + false, + Collections.emptyMap(), + Collections.emptyList() + ); + + Assertions.assertNull(sqlResource.doPost(sqlQuery, httpServletRequest)); + Assertions.assertEquals(Response.Status.OK.getStatusCode(), asyncResponse.getStatus()); + Assertions.assertEquals("[[2]]\n", 
StringUtils.fromUtf8(asyncResponse.baos.toByteArray())); + } + + @Test + public void test_doPost_regularUser_forbidden() + { + final MockAsyncContext asyncContext = new MockAsyncContext(); + final MockHttpServletResponse asyncResponse = new MockHttpServletResponse(); + asyncContext.response = asyncResponse; + + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + Mockito.when(httpServletRequest.startAsync()) + .thenReturn(asyncContext); + + final SqlQuery sqlQuery = new SqlQuery( + StringUtils.format("SELECT * FROM \"%s\"", CalciteTests.FORBIDDEN_DATASOURCE), + ResultFormat.ARRAY, + false, + false, + false, + Collections.emptyMap(), + Collections.emptyList() + ); + + Assertions.assertThrows( + ForbiddenException.class, + () -> sqlResource.doPost(sqlQuery, httpServletRequest) + ); + } + + @Test + public void test_doPost_regularUser_runtimeError() throws IOException + { + final MockAsyncContext asyncContext = new MockAsyncContext(); + final MockHttpServletResponse asyncResponse = new MockHttpServletResponse(); + asyncContext.response = asyncResponse; + + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + Mockito.when(httpServletRequest.startAsync()) + .thenReturn(asyncContext); + + final SqlQuery sqlQuery = new SqlQuery( + "SELECT U&'\\0000'", + ResultFormat.ARRAY, + false, + false, + false, + Collections.emptyMap(), + Collections.emptyList() + ); + + Assertions.assertNull(sqlResource.doPost(sqlQuery, httpServletRequest)); + Assertions.assertEquals(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), asyncResponse.getStatus()); + + final Map e = objectMapper.readValue( + asyncResponse.baos.toByteArray(), + JacksonUtils.TYPE_REFERENCE_MAP_STRING_OBJECT + ); + + Assertions.assertEquals("InvalidNullByte", e.get("errorCode")); + Assertions.assertEquals("RUNTIME_FAILURE", 
e.get("category")); + assertThat((String) e.get("errorMessage"), CoreMatchers.startsWith("InvalidNullByte: ")); + } + + @Test + public void test_doPost_regularUser_fullReport() throws Exception + { + final MockAsyncContext asyncContext = new MockAsyncContext(); + final MockHttpServletResponse asyncResponse = new MockHttpServletResponse(); + asyncContext.response = asyncResponse; + + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + Mockito.when(httpServletRequest.startAsync()) + .thenReturn(asyncContext); + + final SqlQuery sqlQuery = new SqlQuery( + "SELECT 1 + 1", + ResultFormat.ARRAY, + false, + false, + false, + ImmutableMap.of(DartSqlEngine.CTX_FULL_REPORT, true), + Collections.emptyList() + ); + + Assertions.assertNull(sqlResource.doPost(sqlQuery, httpServletRequest)); + Assertions.assertEquals(Response.Status.OK.getStatusCode(), asyncResponse.getStatus()); + + final List> reportMaps = objectMapper.readValue( + asyncResponse.baos.toByteArray(), + new TypeReference>>() {} + ); + + Assertions.assertEquals(1, reportMaps.size()); + final MSQTaskReport report = + (MSQTaskReport) Iterables.getOnlyElement(Iterables.getOnlyElement(reportMaps)).get(MSQTaskReport.REPORT_KEY); + final List results = report.getPayload().getResults().getResults(); + + Assertions.assertEquals(1, results.size()); + Assertions.assertArrayEquals(new Object[]{2}, results.get(0)); + } + + @Test + public void test_doPost_regularUser_runtimeError_fullReport() throws Exception + { + final MockAsyncContext asyncContext = new MockAsyncContext(); + final MockHttpServletResponse asyncResponse = new MockHttpServletResponse(); + asyncContext.response = asyncResponse; + + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + Mockito.when(httpServletRequest.startAsync()) + .thenReturn(asyncContext); + + final 
SqlQuery sqlQuery = new SqlQuery( + "SELECT U&'\\0000'", + ResultFormat.ARRAY, + false, + false, + false, + ImmutableMap.of(DartSqlEngine.CTX_FULL_REPORT, true), + Collections.emptyList() + ); + + Assertions.assertNull(sqlResource.doPost(sqlQuery, httpServletRequest)); + Assertions.assertEquals(Response.Status.OK.getStatusCode(), asyncResponse.getStatus()); + + final List> reportMaps = objectMapper.readValue( + asyncResponse.baos.toByteArray(), + new TypeReference>>() {} + ); + + Assertions.assertEquals(1, reportMaps.size()); + final MSQTaskReport report = + (MSQTaskReport) Iterables.getOnlyElement(Iterables.getOnlyElement(reportMaps)).get(MSQTaskReport.REPORT_KEY); + final MSQErrorReport errorReport = report.getPayload().getStatus().getErrorReport(); + Assertions.assertNotNull(errorReport); + assertThat(errorReport.getFault(), CoreMatchers.instanceOf(InvalidNullByteFault.class)); + } + + @Test + public void test_doPost_regularUser_thenCancelQuery() throws Exception + { + run_test_doPost_regularUser_fullReport_thenCancelQuery(false); + } + + @Test + public void test_doPost_regularUser_fullReport_thenCancelQuery() throws Exception + { + run_test_doPost_regularUser_fullReport_thenCancelQuery(true); + } + + /** + * Helper for {@link #test_doPost_regularUser_thenCancelQuery()} and + * {@link #test_doPost_regularUser_fullReport_thenCancelQuery()}. We need to do cancellation tests with and + * without the "fullReport" parameter, because {@link DartQueryMaker} has a separate pathway for each one. + */ + private void run_test_doPost_regularUser_fullReport_thenCancelQuery(final boolean fullReport) throws Exception + { + final MockAsyncContext asyncContext = new MockAsyncContext(); + final MockHttpServletResponse asyncResponse = new MockHttpServletResponse(); + asyncContext.response = asyncResponse; + + // POST SQL query request. 
+ Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + Mockito.when(httpServletRequest.startAsync()) + .thenReturn(asyncContext); + + // Cancellation request. + Mockito.when(httpServletRequest2.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + + // Block up the controllerExecutor so the controller runs long enough to cancel it. + final Future sleepFuture = controllerExecutor.submit(() -> { + try { + Thread.sleep(3_600_000); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + final String sqlQueryId = UUID.randomUUID().toString(); + final SqlQuery sqlQuery = new SqlQuery( + "SELECT 1 + 1", + ResultFormat.ARRAY, + false, + false, + false, + ImmutableMap.of(QueryContexts.CTX_SQL_QUERY_ID, sqlQueryId, DartSqlEngine.CTX_FULL_REPORT, fullReport), + Collections.emptyList() + ); + + final ExecutorService doPostExec = Execs.singleThreaded("do-post-exec-%s"); + final Future doPostFuture; + try { + // Run doPost in a separate thread. There are now three threads: + // 1) The controllerExecutor thread, which is blocked up by sleepFuture. + // 2) The doPostExec thread, which has a doPost in there, blocking on controllerExecutor. + // 3) The current main test thread, which continues on and which will issue the cancellation request. + doPostFuture = doPostExec.submit(() -> sqlResource.doPost(sqlQuery, httpServletRequest)); + controllerRegistered.await(); + + // Issue cancellation request. + final Response cancellationResponse = sqlResource.cancelQuery(sqlQueryId, httpServletRequest2); + Assertions.assertEquals(Response.Status.ACCEPTED.getStatusCode(), cancellationResponse.getStatus()); + + // Now that the cancellation request has been accepted, we can cancel the sleepFuture and allow the + // controller to be canceled. 
+ sleepFuture.cancel(true); + doPostExec.shutdown(); + } + catch (Throwable e) { + doPostExec.shutdownNow(); + throw e; + } + + if (!doPostExec.awaitTermination(1, TimeUnit.MINUTES)) { + throw new ISE("doPost timed out"); + } + + // Wait for the SQL POST to come back. + Assertions.assertNull(doPostFuture.get()); + Assertions.assertEquals(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), asyncResponse.getStatus()); + + // Ensure MSQ fault (CanceledFault) is properly translated to a DruidException and then properly serialized. + final Map e = objectMapper.readValue( + asyncResponse.baos.toByteArray(), + JacksonUtils.TYPE_REFERENCE_MAP_STRING_OBJECT + ); + Assertions.assertEquals("Canceled", e.get("errorCode")); + Assertions.assertEquals("CANCELED", e.get("category")); + Assertions.assertEquals( + MSQFaultUtils.generateMessageWithErrorCode(CanceledFault.instance()), + e.get("errorMessage") + ); + } + + @Test + public void test_cancelQuery_regularUser_unknownQuery() + { + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + + final Response cancellationResponse = sqlResource.cancelQuery("nonexistent", httpServletRequest); + Assertions.assertEquals(Response.Status.NOT_FOUND.getStatusCode(), cancellationResponse.getStatus()); + } + + /** + * Add a mock {@link ControllerHolder} to {@link #controllerRegistry}, with a query run by the given user. + * Used by methods that test {@link DartSqlResource#doGetRunningQueries}. 
+ * + * @return the mock holder + */ + private ControllerHolder setUpMockRunningQuery(final String identity) + { + final Controller controller = Mockito.mock(Controller.class); + Mockito.when(controller.queryId()).thenReturn("did_" + identity); + + final AuthenticationResult authenticationResult = makeAuthenticationResult(identity); + final ControllerHolder holder = + new ControllerHolder(controller, null, "sid", "SELECT 1", authenticationResult, DateTimes.of("2000")); + + controllerRegistry.register(holder); + return holder; + } + + /** + * Create an {@link AuthenticationResult} with {@link AuthenticationResult#getAuthenticatedBy()} set to + * {@link #AUTHENTICATOR_NAME}. + */ + private static AuthenticationResult makeAuthenticationResult(final String identity) + { + return new AuthenticationResult(identity, null, AUTHENTICATOR_NAME, Collections.emptyMap()); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponseTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponseTest.java new file mode 100644 index 000000000000..7b43c863c9d1 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponseTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.http; + +import com.fasterxml.jackson.databind.ObjectMapper; +import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.msq.dart.controller.ControllerHolder; +import org.apache.druid.segment.TestHelper; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.Collections; + +public class GetQueriesResponseTest +{ + @Test + public void test_serde() throws Exception + { + final ObjectMapper jsonMapper = TestHelper.JSON_MAPPER; + final GetQueriesResponse response = new GetQueriesResponse( + Collections.singletonList( + new DartQueryInfo( + "xyz", + "abc", + "SELECT 1", + "auth", + "anon", + DateTimes.of("2000"), + ControllerHolder.State.RUNNING.toString() + ) + ) + ); + final GetQueriesResponse response2 = + jsonMapper.readValue(jsonMapper.writeValueAsBytes(response), GetQueriesResponse.class); + Assertions.assertEquals(response, response2); + } + + @Test + public void test_equals() + { + EqualsVerifier.forClass(GetQueriesResponse.class).usingGetClass().verify(); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/messages/ControllerMessageTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/messages/ControllerMessageTest.java new file mode 100644 index 000000000000..427faf4aee6f --- /dev/null +++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/messages/ControllerMessageTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.messages; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.ObjectMapper; +import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.druid.msq.guice.MSQIndexingModule; +import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.indexing.error.UnknownFault; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; +import org.apache.druid.segment.TestHelper; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.Collections; + +public class ControllerMessageTest +{ + private static final StageId STAGE_ID = StageId.fromString("xyz_2"); + private ObjectMapper objectMapper; + + @BeforeEach + public void setUp() + { + objectMapper = TestHelper.JSON_MAPPER.copy(); + objectMapper.enable(JsonParser.Feature.STRICT_DUPLICATE_DETECTION); + 
objectMapper.registerModules(new MSQIndexingModule().getJacksonModules()); + } + + @Test + public void testSerde() throws IOException + { + final PartialKeyStatisticsInformation partialKeyStatisticsInformation = + new PartialKeyStatisticsInformation(Collections.emptySet(), false, 0); + + assertSerde(new PartialKeyStatistics(STAGE_ID, 1, partialKeyStatisticsInformation)); + assertSerde(new DoneReadingInput(STAGE_ID, 1)); + assertSerde(new ResultsComplete(STAGE_ID, 1, "foo")); + assertSerde( + new WorkerError( + STAGE_ID.getQueryId(), + MSQErrorReport.fromFault("task", null, null, UnknownFault.forMessage("oops")) + ) + ); + assertSerde( + new WorkerWarning( + STAGE_ID.getQueryId(), + Collections.singletonList(MSQErrorReport.fromFault("task", null, null, UnknownFault.forMessage("oops"))) + ) + ); + } + + @Test + public void testEqualsAndHashCode() + { + EqualsVerifier.forClass(PartialKeyStatistics.class).usingGetClass().verify(); + EqualsVerifier.forClass(DoneReadingInput.class).usingGetClass().verify(); + EqualsVerifier.forClass(ResultsComplete.class).usingGetClass().verify(); + EqualsVerifier.forClass(WorkerError.class).usingGetClass().verify(); + EqualsVerifier.forClass(WorkerWarning.class).usingGetClass().verify(); + } + + private void assertSerde(final ControllerMessage message) throws IOException + { + final String json = objectMapper.writeValueAsString(message); + final ControllerMessage message2 = objectMapper.readValue(json, ControllerMessage.class); + Assertions.assertEquals(message, message2, json); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImplTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImplTest.java new file mode 100644 index 000000000000..19a4eaf0b151 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImplTest.java @@ -0,0 +1,118 @@ +/* 
+ * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.sql; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.jackson.DefaultObjectMapper; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.msq.dart.controller.ControllerHolder; +import org.apache.druid.msq.dart.controller.http.DartQueryInfo; +import org.apache.druid.msq.dart.controller.http.GetQueriesResponse; +import org.apache.druid.rpc.MockServiceClient; +import org.apache.druid.rpc.RequestBuilder; +import org.jboss.netty.handler.codec.http.HttpMethod; +import org.jboss.netty.handler.codec.http.HttpResponseStatus; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import javax.ws.rs.core.HttpHeaders; +import javax.ws.rs.core.MediaType; + +public class DartSqlClientImplTest +{ + private ObjectMapper jsonMapper; + private MockServiceClient serviceClient; + private DartSqlClient dartSqlClient; + + 
@BeforeEach + public void setup() + { + jsonMapper = new DefaultObjectMapper(); + serviceClient = new MockServiceClient(); + dartSqlClient = new DartSqlClientImpl(serviceClient, jsonMapper); + } + + @AfterEach + public void tearDown() + { + serviceClient.verify(); + } + + @Test + public void test_getMessages_all() throws Exception + { + final GetQueriesResponse getQueriesResponse = new GetQueriesResponse( + ImmutableList.of( + new DartQueryInfo( + "sid", + "did", + "SELECT 1", + "", + "", + DateTimes.of("2000"), + ControllerHolder.State.RUNNING.toString() + ) + ) + ); + + serviceClient.expectAndRespond( + new RequestBuilder(HttpMethod.GET, "/"), + HttpResponseStatus.OK, + ImmutableMap.of(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON), + jsonMapper.writeValueAsBytes(getQueriesResponse) + ); + + final ListenableFuture result = dartSqlClient.getRunningQueries(false); + Assertions.assertEquals(getQueriesResponse, result.get()); + } + + @Test + public void test_getMessages_selfOnly() throws Exception + { + final GetQueriesResponse getQueriesResponse = new GetQueriesResponse( + ImmutableList.of( + new DartQueryInfo( + "sid", + "did", + "SELECT 1", + "", + "", + DateTimes.of("2000"), + ControllerHolder.State.RUNNING.toString() + ) + ) + ); + + serviceClient.expectAndRespond( + new RequestBuilder(HttpMethod.GET, "/?selfOnly"), + HttpResponseStatus.OK, + ImmutableMap.of(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON), + jsonMapper.writeValueAsBytes(getQueriesResponse) + ); + + final ListenableFuture result = dartSqlClient.getRunningQueries(true); + Assertions.assertEquals(getQueriesResponse, result.get()); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartQueryableSegmentTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartQueryableSegmentTest.java new file mode 100644 index 000000000000..b53a397dae81 --- /dev/null +++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartQueryableSegmentTest.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.worker;
+
+import nl.jqno.equalsverifier.EqualsVerifier;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Tests the equals/hashCode contract of {@link DartQueryableSegment}.
+ */
+public class DartQueryableSegmentTest
+{
+  /**
+   * Runs EqualsVerifier against the class, configured for getClass()-based equality.
+   */
+  @Test
+  public void test_equals()
+  {
+    EqualsVerifier.forClass(DartQueryableSegment.class).usingGetClass().verify();
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartWorkerRunnerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartWorkerRunnerTest.java
new file mode 100644
index 000000000000..1f152b74049f
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartWorkerRunnerTest.java
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker; + +import com.google.common.util.concurrent.SettableFuture; +import org.apache.druid.discovery.DiscoveryDruidNode; +import org.apache.druid.discovery.DruidNodeDiscovery; +import org.apache.druid.discovery.DruidNodeDiscoveryProvider; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.FileUtils; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.msq.dart.DartResourcePermissionMapper; +import org.apache.druid.msq.dart.worker.http.GetWorkersResponse; +import org.apache.druid.msq.exec.Worker; +import org.apache.druid.msq.indexing.error.CanceledFault; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.query.QueryContext; +import org.apache.druid.server.DruidNode; +import org.apache.druid.server.security.AuthorizerMapper; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.junit.internal.matchers.ThrowableMessageMatcher; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.io.TempDir; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; 
+import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; + +/** + * Functional test of {@link DartWorkerRunner}. + */ +public class DartWorkerRunnerTest +{ + private static final int MAX_WORKERS = 1; + private static final String QUERY_ID = "abc"; + private static final WorkerId WORKER_ID = new WorkerId("http", "localhost:8282", QUERY_ID); + private static final String CONTROLLER_SERVER_HOST = "localhost:8081"; + private static final DiscoveryDruidNode CONTROLLER_DISCOVERY_NODE = + new DiscoveryDruidNode( + new DruidNode("no", "localhost", false, 8081, -1, true, false), + NodeRole.BROKER, + Collections.emptyMap() + ); + + private final SettableFuture workerRun = SettableFuture.create(); + + private ExecutorService workerExec; + private DartWorkerRunner workerRunner; + private AutoCloseable mockCloser; + + @TempDir + public Path temporaryFolder; + + @Mock + private DartWorkerFactory workerFactory; + + @Mock + private Worker worker; + + @Mock + private DruidNodeDiscoveryProvider discoveryProvider; + + @Mock + private DruidNodeDiscovery discovery; + + @Mock + private AuthorizerMapper authorizerMapper; + + @Captor + private ArgumentCaptor discoveryListener; + + @BeforeEach + public void setUp() + { + mockCloser = MockitoAnnotations.openMocks(this); + workerRunner = new DartWorkerRunner( + workerFactory, + workerExec = Execs.multiThreaded(MAX_WORKERS, "worker-exec-%s"), + discoveryProvider, + new DartResourcePermissionMapper(), + authorizerMapper, + temporaryFolder.toFile() + ); + + // "discoveryProvider" provides "discovery". + Mockito.when(discoveryProvider.getForNodeRole(NodeRole.BROKER)).thenReturn(discovery); + + // "workerFactory" builds "worker". 
+ Mockito.when( + workerFactory.build( + QUERY_ID, + CONTROLLER_SERVER_HOST, + temporaryFolder.toFile(), + QueryContext.empty() + ) + ).thenReturn(worker); + + // "worker.run()" exits when "workerRun" resolves. + Mockito.doAnswer(invocation -> { + workerRun.get(); + return null; + }).when(worker).run(); + + // "worker.stop()" sets "workerRun" to a cancellation error. + Mockito.doAnswer(invocation -> { + workerRun.setException(new MSQException(CanceledFault.instance())); + return null; + }).when(worker).stop(); + + // "worker.controllerFailed()" sets "workerRun" to an error. + Mockito.doAnswer(invocation -> { + workerRun.setException(new ISE("Controller failed")); + return null; + }).when(worker).controllerFailed(); + + // "worker.awaitStop()" waits for "workerRun". It does not throw an exception, just like WorkerImpl.awaitStop. + Mockito.doAnswer(invocation -> { + try { + workerRun.get(); + } + catch (Throwable e) { + // Suppress + } + return null; + }).when(worker).awaitStop(); + + // "worker.id()" returns WORKER_ID. + Mockito.when(worker.id()).thenReturn(WORKER_ID.toString()); + + // Start workerRunner, capture listener in "discoveryListener". 
+ workerRunner.start(); + Mockito.verify(discovery).registerListener(discoveryListener.capture()); + } + + @AfterEach + public void tearDown() throws Exception + { + workerExec.shutdown(); + workerRunner.stop(); + mockCloser.close(); + + if (!workerExec.awaitTermination(1, TimeUnit.MINUTES)) { + throw new ISE("workerExec did not terminate within timeout"); + } + } + + @Test + public void test_getWorkersResponse_empty() + { + final GetWorkersResponse workersResponse = workerRunner.getWorkersResponse(); + Assertions.assertEquals(new GetWorkersResponse(Collections.emptyList()), workersResponse); + } + + @Test + public void test_getWorkerResource_notFound() + { + Assertions.assertNull(workerRunner.getWorkerResource("nonexistent")); + } + + @Test + public void test_createAndCleanTempDirectory() throws IOException + { + workerRunner.stop(); + + // Create an empty directory "x". + FileUtils.mkdirp(new File(temporaryFolder.toFile(), "x")); + Assertions.assertArrayEquals( + new File[]{new File(temporaryFolder.toFile(), "x")}, + temporaryFolder.toFile().listFiles() + ); + + // Run "createAndCleanTempDirectory", which will delete it. + workerRunner.createAndCleanTempDirectory(); + Assertions.assertArrayEquals(new File[]{}, temporaryFolder.toFile().listFiles()); + } + + @Test + public void test_startWorker_controllerNotActive() + { + final DruidException e = Assertions.assertThrows( + DruidException.class, + () -> workerRunner.startWorker("abc", CONTROLLER_SERVER_HOST, QueryContext.empty()) + ); + + MatcherAssert.assertThat( + e, + ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString( + "Received startWorker request for unknown controller")) + ); + } + + @Test + public void test_stopWorker_nonexistent() + { + // Nothing happens when we do this. Just verifying an exception isn't thrown. + workerRunner.stopWorker("nonexistent"); + } + + @Test + public void test_startWorker() + { + // Activate controller. 
+ discoveryListener.getValue().nodesAdded(Collections.singletonList(CONTROLLER_DISCOVERY_NODE)); + + // Start the worker twice (startWorker is idempotent; nothing special happens the second time). + final Worker workerFromStart = workerRunner.startWorker(QUERY_ID, CONTROLLER_SERVER_HOST, QueryContext.empty()); + final Worker workerFromStart2 = workerRunner.startWorker(QUERY_ID, CONTROLLER_SERVER_HOST, QueryContext.empty()); + Assertions.assertSame(worker, workerFromStart); + Assertions.assertSame(worker, workerFromStart2); + + // Worker should enter the GetWorkersResponse. + final GetWorkersResponse workersResponse = workerRunner.getWorkersResponse(); + Assertions.assertEquals(1, workersResponse.getWorkers().size()); + Assertions.assertEquals(QUERY_ID, workersResponse.getWorkers().get(0).getDartQueryId()); + Assertions.assertEquals(CONTROLLER_SERVER_HOST, workersResponse.getWorkers().get(0).getControllerHost()); + Assertions.assertEquals(WORKER_ID, workersResponse.getWorkers().get(0).getWorkerId()); + + // Worker should have a resource. + Assertions.assertNotNull(workerRunner.getWorkerResource(QUERY_ID)); + } + + @Test + @Timeout(value = 1, unit = TimeUnit.MINUTES) + public void test_startWorker_thenRemoveController() throws InterruptedException + { + // Activate controller. + discoveryListener.getValue().nodesAdded(Collections.singletonList(CONTROLLER_DISCOVERY_NODE)); + + // Start the worker. + final Worker workerFromStart = workerRunner.startWorker(QUERY_ID, CONTROLLER_SERVER_HOST, QueryContext.empty()); + Assertions.assertSame(worker, workerFromStart); + Assertions.assertEquals(1, workerRunner.getWorkersResponse().getWorkers().size()); + + // Deactivate controller. + discoveryListener.getValue().nodesRemoved(Collections.singletonList(CONTROLLER_DISCOVERY_NODE)); + + // Worker should go away. 
+ workerRunner.awaitQuerySet(Set::isEmpty); + Assertions.assertEquals(0, workerRunner.getWorkersResponse().getWorkers().size()); + } + + @Test + @Timeout(value = 1, unit = TimeUnit.MINUTES) + public void test_startWorker_thenStopWorker() throws InterruptedException + { + // Activate controller. + discoveryListener.getValue().nodesAdded(Collections.singletonList(CONTROLLER_DISCOVERY_NODE)); + + // Start the worker. + final Worker workerFromStart = workerRunner.startWorker(QUERY_ID, CONTROLLER_SERVER_HOST, QueryContext.empty()); + Assertions.assertSame(worker, workerFromStart); + Assertions.assertEquals(1, workerRunner.getWorkersResponse().getWorkers().size()); + + // Stop that worker. + workerRunner.stopWorker(QUERY_ID); + + // Worker should go away. + workerRunner.awaitQuerySet(Set::isEmpty); + Assertions.assertEquals(0, workerRunner.getWorkersResponse().getWorkers().size()); + } + + @Test + @Timeout(value = 1, unit = TimeUnit.MINUTES) + public void test_startWorker_thenStopRunner() throws InterruptedException + { + // Activate controller. + discoveryListener.getValue().nodesAdded(Collections.singletonList(CONTROLLER_DISCOVERY_NODE)); + + // Start the worker. + final Worker workerFromStart = workerRunner.startWorker(QUERY_ID, CONTROLLER_SERVER_HOST, QueryContext.empty()); + Assertions.assertSame(worker, workerFromStart); + Assertions.assertEquals(1, workerRunner.getWorkersResponse().getWorkers().size()); + + // Stop runner. + workerRunner.stop(); + + // Worker should go away. 
+ workerRunner.awaitQuerySet(Set::isEmpty); + Assertions.assertEquals(0, workerRunner.getWorkersResponse().getWorkers().size()); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/WorkerIdTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/WorkerIdTest.java new file mode 100644 index 000000000000..e4f74a0250f6 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/WorkerIdTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker; + +import com.fasterxml.jackson.databind.ObjectMapper; +import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.druid.segment.TestHelper; +import org.apache.druid.server.DruidNode; +import org.apache.druid.server.coordination.DruidServerMetadata; +import org.apache.druid.server.coordination.ServerType; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.io.IOException; + +public class WorkerIdTest +{ + @Test + public void test_fromString() + { + Assertions.assertEquals( + new WorkerId("https", "local-host:8100", "xyz"), + WorkerId.fromString("https:local-host:8100:xyz") + ); + } + + @Test + public void test_fromDruidNode() + { + Assertions.assertEquals( + new WorkerId("https", "local-host:8100", "xyz"), + WorkerId.fromDruidNode(new DruidNode("none", "local-host", false, 8200, 8100, true, true), "xyz") + ); + } + + @Test + public void test_fromDruidServerMetadata() + { + Assertions.assertEquals( + new WorkerId("https", "local-host:8100", "xyz"), + WorkerId.fromDruidServerMetadata( + new DruidServerMetadata("none", "local-host:8200", "local-host:8100", 1, ServerType.HISTORICAL, "none", 0), + "xyz" + ) + ); + } + + @Test + public void test_toString() + { + Assertions.assertEquals( + "https:local-host:8100:xyz", + new WorkerId("https", "local-host:8100", "xyz").toString() + ); + } + + @Test + public void test_getters() + { + final WorkerId workerId = new WorkerId("https", "local-host:8100", "xyz"); + Assertions.assertEquals("https", workerId.getScheme()); + Assertions.assertEquals("local-host:8100", workerId.getHostAndPort()); + Assertions.assertEquals("xyz", workerId.getQueryId()); + Assertions.assertEquals("https://local-host:8100/druid/dart-worker/workers/xyz", workerId.toUri().toString()); + } + + @Test + public void test_serde() throws IOException + { + final ObjectMapper objectMapper = TestHelper.JSON_MAPPER; + final WorkerId workerId = new WorkerId("https", 
"localhost:8100", "xyz"); + final WorkerId workerId2 = objectMapper.readValue(objectMapper.writeValueAsBytes(workerId), WorkerId.class); + Assertions.assertEquals(workerId, workerId2); + } + + @Test + public void test_equals() + { + EqualsVerifier.forClass(WorkerId.class) + .usingGetClass() + .withNonnullFields("fullString") + .withIgnoredFields("scheme", "hostAndPort", "queryId") + .verify(); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/http/DartWorkerInfoTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/http/DartWorkerInfoTest.java new file mode 100644 index 000000000000..74cd8a28915a --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/http/DartWorkerInfoTest.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker.http; + +import nl.jqno.equalsverifier.EqualsVerifier; +import org.junit.jupiter.api.Test; + +public class DartWorkerInfoTest +{ + @Test + public void test_equals() + { + EqualsVerifier.forClass(DartWorkerInfo.class).usingGetClass().verify(); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/http/GetWorkersResponseTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/http/GetWorkersResponseTest.java new file mode 100644 index 000000000000..f516077a5754 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/http/GetWorkersResponseTest.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker.http; + +import com.fasterxml.jackson.databind.ObjectMapper; +import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.msq.dart.worker.WorkerId; +import org.apache.druid.segment.TestHelper; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.Collections; + +public class GetWorkersResponseTest +{ + @Test + public void test_serde() throws Exception + { + final ObjectMapper jsonMapper = TestHelper.JSON_MAPPER; + final GetWorkersResponse response = new GetWorkersResponse( + Collections.singletonList( + new DartWorkerInfo( + "xyz", + WorkerId.fromString("http:localhost:8100:xyz"), + "localhost:8101", + DateTimes.of("2000") + ) + ) + ); + final GetWorkersResponse response2 = + jsonMapper.readValue(jsonMapper.writeValueAsBytes(response), GetWorkersResponse.class); + Assertions.assertEquals(response, response2); + } + + @Test + public void test_equals() + { + EqualsVerifier.forClass(GetWorkersResponse.class).usingGetClass().verify(); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java index 761a61337ea9..cbaae8f1b8e0 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java @@ -320,6 +320,7 @@ public class MSQTestBase extends BaseCalciteQueryTest protected File localFileStorageDir; protected LocalFileStorageConnector localFileStorageConnector; private static final Logger log = new Logger(MSQTestBase.class); + protected Injector injector; protected ObjectMapper objectMapper; protected MSQTestOverlordServiceClient indexingServiceClient; protected MSQTestTaskActionClient testTaskActionClient; @@ -530,7 +531,7 @@ public String 
getFormatString() binder -> binder.bind(Bouncer.class).toInstance(new Bouncer(1)) ); // adding node role injection to the modules, since CliPeon would also do that through run method - Injector injector = new CoreInjectorBuilder(new StartupInjectorBuilder().build(), ImmutableSet.of(NodeRole.PEON)) + injector = new CoreInjectorBuilder(new StartupInjectorBuilder().build(), ImmutableSet.of(NodeRole.PEON)) .addAll(modules) .build(); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java index 970d873c96c8..4dadeae5bc10 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java @@ -56,7 +56,6 @@ import org.apache.druid.msq.exec.WorkerStorageParameters; import org.apache.druid.msq.indexing.IndexerControllerContext; import org.apache.druid.msq.indexing.IndexerTableInputSpecSlicer; -import org.apache.druid.msq.indexing.MSQControllerTask; import org.apache.druid.msq.indexing.MSQSpec; import org.apache.druid.msq.indexing.MSQWorkerTask; import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher; @@ -108,8 +107,8 @@ public class MSQTestControllerContext implements ControllerContext private Controller controller; private final WorkerMemoryParameters workerMemoryParameters; + private final TaskLockType taskLockType; private final QueryContext queryContext; - private final MSQControllerTask controllerTask; public MSQTestControllerContext( ObjectMapper mapper, @@ -117,7 +116,8 @@ public MSQTestControllerContext( TaskActionClient taskActionClient, WorkerMemoryParameters workerMemoryParameters, List loadedSegments, - MSQControllerTask controllerTask + TaskLockType taskLockType, + QueryContext queryContext ) { this.mapper = mapper; @@ -137,8 +137,8 @@ 
public MSQTestControllerContext( .collect(Collectors.toList()) ); this.workerMemoryParameters = workerMemoryParameters; - this.controllerTask = controllerTask; - this.queryContext = controllerTask.getQuerySpec().getQuery().context(); + this.taskLockType = taskLockType; + this.queryContext = queryContext; } OverlordClient overlordClient = new NoopOverlordClient() @@ -329,7 +329,7 @@ public TaskActionClient taskActionClient() @Override public TaskLockType taskLockType() { - return controllerTask.getTaskLockType(); + return taskLockType; } @Override diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java index 6a7db8aa5b63..b35c074fa060 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java @@ -103,7 +103,8 @@ public ListenableFuture runTask(String taskId, Object taskObject) taskActionClient, workerMemoryParameters, loadedSegmentMetadata, - cTask + cTask.getTaskLockType(), + cTask.getQuerySpec().getQuery().context() ); inMemoryControllerTask.put(cTask.getId(), cTask); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java index ffd7c67ca2d6..4c7ccd72efd0 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java @@ -35,10 +35,12 @@ import java.io.InputStream; import java.util.Arrays; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; public class MSQTestWorkerClient implements 
WorkerClient { private final Map inMemoryWorkers; + private final AtomicBoolean closed = new AtomicBoolean(); public MSQTestWorkerClient(Map inMemoryWorkers) { @@ -141,6 +143,8 @@ public ListenableFuture fetchChannelData( @Override public void close() { - inMemoryWorkers.forEach((k, v) -> v.stop()); + if (closed.compareAndSet(false, true)) { + inMemoryWorkers.forEach((k, v) -> v.stop()); + } } } diff --git a/processing/src/main/java/org/apache/druid/common/guava/FutureBox.java b/processing/src/main/java/org/apache/druid/common/guava/FutureBox.java new file mode 100644 index 000000000000..3e92706aa028 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/common/guava/FutureBox.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.common.guava; + +import com.google.common.collect.Sets; +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.java.util.common.concurrent.Execs; + +import java.io.Closeable; +import java.util.Set; + +/** + * Box for tracking pending futures. Allows cancellation of all pending futures. + */ +public class FutureBox implements Closeable +{ + /** + * Currently-outstanding futures. 
These are tracked so they can be canceled in {@link #close()}. + */ + private final Set> pendingFutures = Sets.newConcurrentHashSet(); + + private volatile boolean stopped; + + /** + * Returns the number of currently-pending futures. + */ + public int pendingCount() + { + return pendingFutures.size(); + } + + /** + * Adds a future to the box. + * If {@link #close()} had previously been called, the future is immediately canceled. + */ + public ListenableFuture register(final ListenableFuture future) + { + pendingFutures.add(future); + future.addListener(() -> pendingFutures.remove(future), Execs.directExecutor()); + + // If "stop" was called while we were creating this future, cancel it prior to returning it. + if (stopped) { + future.cancel(false); + } + + return future; + } + + /** + * Closes the box, canceling all currently-pending futures. + */ + @Override + public void close() + { + stopped = true; + for (ListenableFuture future : pendingFutures) { + future.cancel(false); // Ignore return value + } + } +} diff --git a/processing/src/main/java/org/apache/druid/io/LimitedOutputStream.java b/processing/src/main/java/org/apache/druid/io/LimitedOutputStream.java new file mode 100644 index 000000000000..6d27abb42739 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/io/LimitedOutputStream.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.io; + +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.IOE; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.function.Function; + +/** + * An {@link OutputStream} that limits how many bytes can be written. Throws {@link IOException} if the limit + * is exceeded. + */ +public class LimitedOutputStream extends OutputStream +{ + private final OutputStream out; + private final long limit; + private final Function exceptionMessageFn; + long written; + + /** + * Create a bytes-limited output stream. + * + * @param out output stream to wrap + * @param limit bytes limit + * @param exceptionMessageFn function for generating an exception message for an {@link IOException}, given the limit. 
+ */ + public LimitedOutputStream(OutputStream out, long limit, Function exceptionMessageFn) + { + this.out = out; + this.limit = limit; + this.exceptionMessageFn = exceptionMessageFn; + + if (limit < 0) { + throw DruidException.defensive("Limit[%s] must be greater than or equal to zero", limit); + } + } + + @Override + public void write(int b) throws IOException + { + plus(1); + out.write(b); + } + + @Override + public void write(byte[] b) throws IOException + { + plus(b.length); + out.write(b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException + { + plus(len); + out.write(b, off, len); + } + + @Override + public void flush() throws IOException + { + out.flush(); + } + + @Override + public void close() throws IOException + { + out.close(); + } + + private void plus(final int n) throws IOException + { + written += n; + if (written > limit) { + throw new IOE(exceptionMessageFn.apply(limit)); + } + } +} diff --git a/processing/src/test/java/org/apache/druid/common/guava/FutureBoxTest.java b/processing/src/test/java/org/apache/druid/common/guava/FutureBoxTest.java new file mode 100644 index 000000000000..7428f94fa71a --- /dev/null +++ b/processing/src/test/java/org/apache/druid/common/guava/FutureBoxTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.common.guava; + +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import org.junit.Test; +import org.junit.jupiter.api.Assertions; + +import java.util.concurrent.ExecutionException; + +public class FutureBoxTest +{ + @Test + public void test_immediateFutures() throws Exception + { + try (final FutureBox box = new FutureBox()) { + Assertions.assertEquals("a", box.register(Futures.immediateFuture("a")).get()); + Assertions.assertThrows( + ExecutionException.class, + () -> box.register(Futures.immediateFailedFuture(new RuntimeException())).get() + ); + Assertions.assertTrue(box.register(Futures.immediateCancelledFuture()).isCancelled()); + Assertions.assertEquals(0, box.pendingCount()); + } + } + + @Test + public void test_register_thenStop() + { + final FutureBox box = new FutureBox(); + final SettableFuture settableFuture = SettableFuture.create(); + + final ListenableFuture retVal = box.register(settableFuture); + Assertions.assertSame(retVal, settableFuture); + Assertions.assertEquals(1, box.pendingCount()); + + box.close(); + Assertions.assertEquals(0, box.pendingCount()); + + Assertions.assertTrue(settableFuture.isCancelled()); + } + + @Test + public void test_stop_thenRegister() + { + final FutureBox box = new FutureBox(); + final SettableFuture settableFuture = SettableFuture.create(); + + box.close(); + final ListenableFuture retVal = box.register(settableFuture); + + Assertions.assertSame(retVal, settableFuture); + Assertions.assertEquals(0, box.pendingCount()); + Assertions.assertTrue(settableFuture.isCancelled()); + } +} diff --git a/processing/src/test/java/org/apache/druid/io/LimitedOutputStreamTest.java b/processing/src/test/java/org/apache/druid/io/LimitedOutputStreamTest.java new file mode 100644 
index 000000000000..a11b63149710 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/io/LimitedOutputStreamTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.io; + +import org.apache.druid.java.util.common.StringUtils; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.junit.Assert; +import org.junit.Test; +import org.junit.internal.matchers.ThrowableMessageMatcher; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +public class LimitedOutputStreamTest +{ + @Test + public void test_limitZero() throws IOException + { + try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final OutputStream stream = + new LimitedOutputStream(baos, 0, LimitedOutputStreamTest::makeErrorMessage)) { + final IOException e = Assert.assertThrows( + IOException.class, + () -> stream.write('b') + ); + + MatcherAssert.assertThat(e, ThrowableMessageMatcher.hasMessage(CoreMatchers.equalTo("Limit[0] exceeded"))); + } + } + + @Test + public void test_limitThree() throws IOException + { + try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final OutputStream 
stream = + new LimitedOutputStream(baos, 3, LimitedOutputStreamTest::makeErrorMessage)) { + stream.write('a'); + stream.write(new byte[]{'b'}); + stream.write(new byte[]{'c'}, 0, 1); + final IOException e = Assert.assertThrows( + IOException.class, + () -> stream.write('d') + ); + + MatcherAssert.assertThat(e, ThrowableMessageMatcher.hasMessage(CoreMatchers.equalTo("Limit[3] exceeded"))); + } + } + + private static String makeErrorMessage(final long limit) + { + return StringUtils.format("Limit[%d] exceeded", limit); + } +} diff --git a/server/src/main/java/org/apache/druid/client/BrokerServerView.java b/server/src/main/java/org/apache/druid/client/BrokerServerView.java index 2cb2bec03b59..f2eb62db0208 100644 --- a/server/src/main/java/org/apache/druid/client/BrokerServerView.java +++ b/server/src/main/java/org/apache/druid/client/BrokerServerView.java @@ -44,6 +44,7 @@ import org.apache.druid.timeline.VersionedIntervalTimeline; import org.apache.druid.timeline.partition.PartitionChunk; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -398,6 +399,19 @@ private void runTimelineCallbacks(final Function getDruidServerMetadatas() + { + // Override default implementation for better performance. 
+ final List retVal = new ArrayList<>(clients.size()); + + for (final QueryableDruidServer server : clients.values()) { + retVal.add(server.getServer().getMetadata()); + } + + return retVal; + } + @Override public List getDruidServers() { diff --git a/server/src/main/java/org/apache/druid/client/TimelineServerView.java b/server/src/main/java/org/apache/druid/client/TimelineServerView.java index 9a2b7b767755..9c6ee608e1f4 100644 --- a/server/src/main/java/org/apache/druid/client/TimelineServerView.java +++ b/server/src/main/java/org/apache/druid/client/TimelineServerView.java @@ -27,6 +27,7 @@ import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.TimelineLookup; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.concurrent.Executor; @@ -45,10 +46,23 @@ public interface TimelineServerView extends ServerView * * @throws IllegalStateException if 'analysis' does not represent a scan-based datasource of a single table */ - Optional> getTimeline(DataSourceAnalysis analysis); + > Optional getTimeline(DataSourceAnalysis analysis); /** - * Returns a list of {@link ImmutableDruidServer} + * Returns a snapshot of the current set of server metadata. + */ + default List getDruidServerMetadatas() + { + final List druidServers = getDruidServers(); + final List metadatas = new ArrayList<>(druidServers.size()); + for (final ImmutableDruidServer druidServer : druidServers) { + metadatas.add(druidServer.getMetadata()); + } + return metadatas; + } + + /** + * Returns a snapshot of the current servers, their metadata, and their inventory. 
*/ List getDruidServers(); diff --git a/server/src/main/java/org/apache/druid/discovery/DataServerClient.java b/server/src/main/java/org/apache/druid/discovery/DataServerClient.java index ce7ac325b62b..ce3d62ca91b5 100644 --- a/server/src/main/java/org/apache/druid/discovery/DataServerClient.java +++ b/server/src/main/java/org/apache/druid/discovery/DataServerClient.java @@ -35,7 +35,7 @@ import org.apache.druid.java.util.http.client.response.StatusResponseHolder; import org.apache.druid.query.Query; import org.apache.druid.query.context.ResponseContext; -import org.apache.druid.rpc.FixedSetServiceLocator; +import org.apache.druid.rpc.FixedServiceLocator; import org.apache.druid.rpc.RequestBuilder; import org.apache.druid.rpc.ServiceClient; import org.apache.druid.rpc.ServiceClientFactory; @@ -71,7 +71,7 @@ public DataServerClient( { this.serviceClient = serviceClientFactory.makeClient( serviceLocation.getHost(), - FixedSetServiceLocator.forServiceLocation(serviceLocation), + new FixedServiceLocator(serviceLocation), StandardRetryPolicy.noRetries() ); this.serviceLocation = serviceLocation; diff --git a/server/src/main/java/org/apache/druid/messages/MessageBatch.java b/server/src/main/java/org/apache/druid/messages/MessageBatch.java new file mode 100644 index 000000000000..51209ff6d243 --- /dev/null +++ b/server/src/main/java/org/apache/druid/messages/MessageBatch.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.messages; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.druid.messages.client.MessageRelay; +import org.apache.druid.messages.server.MessageRelayResource; +import org.apache.druid.messages.server.Outbox; + +import java.util.List; +import java.util.Objects; + +/** + * A batch of messages collected by {@link MessageRelay} from a remote {@link Outbox} through + * {@link MessageRelayResource#httpGetMessagesFromOutbox}. + */ +public class MessageBatch +{ + private final List messages; + private final long epoch; + private final long startWatermark; + + @JsonCreator + public MessageBatch( + @JsonProperty("messages") final List messages, + @JsonProperty("epoch") final long epoch, + @JsonProperty("watermark") final long startWatermark + ) + { + this.messages = messages; + this.epoch = epoch; + this.startWatermark = startWatermark; + } + + /** + * The messages. + */ + @JsonProperty + public List getMessages() + { + return messages; + } + + /** + * Epoch, which is associated with a specific instance of {@link Outbox}. + */ + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_DEFAULT) + public long getEpoch() + { + return epoch; + } + + /** + * Watermark, an incrementing message ID that enables clients and servers to stay in sync, and enables + * acknowledging of messages. 
+ */ + @JsonProperty("watermark") + @JsonInclude(JsonInclude.Include.NON_DEFAULT) + public long getStartWatermark() + { + return startWatermark; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + MessageBatch that = (MessageBatch) o; + return epoch == that.epoch && startWatermark == that.startWatermark && Objects.equals(messages, that.messages); + } + + @Override + public int hashCode() + { + return Objects.hash(messages, epoch, startWatermark); + } + + @Override + public String toString() + { + return "MessageBatch{" + + "messages=" + messages + + ", epoch=" + epoch + + ", startWatermark=" + startWatermark + + '}'; + } +} diff --git a/server/src/main/java/org/apache/druid/messages/client/MessageListener.java b/server/src/main/java/org/apache/druid/messages/client/MessageListener.java new file mode 100644 index 000000000000..6711c418f812 --- /dev/null +++ b/server/src/main/java/org/apache/druid/messages/client/MessageListener.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.messages.client; + +import org.apache.druid.server.DruidNode; + +/** + * Listener for messages received by clients. + */ +public interface MessageListener +{ + /** + * Called when a server is added. + * + * @param node server node + */ + void serverAdded(DruidNode node); + + /** + * Called when a message is received. Should not throw exceptions. If this method does throw an exception, + * the exception is logged and the message is acknowledged anyway. + * + * @param message the message that was received + */ + void messageReceived(MessageType message); + + /** + * Called when a server is removed. + * + * @param node server node + */ + void serverRemoved(DruidNode node); +} diff --git a/server/src/main/java/org/apache/druid/messages/client/MessageRelay.java b/server/src/main/java/org/apache/druid/messages/client/MessageRelay.java new file mode 100644 index 000000000000..deda87c7d23d --- /dev/null +++ b/server/src/main/java/org/apache/druid/messages/client/MessageRelay.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.druid.messages.client;
+
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.concurrent.Execs;
+import org.apache.druid.java.util.common.logger.Logger;
+import org.apache.druid.messages.MessageBatch;
+import org.apache.druid.messages.server.MessageRelayResource;
+import org.apache.druid.rpc.ServiceClosedException;
+import org.apache.druid.server.DruidNode;
+
+import java.io.Closeable;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
+
+/**
+ * Relays run on clients, and receive messages from a server.
+ * Uses {@link MessageRelayClient} to communicate with the {@link MessageRelayResource} on a server.
+ * Messages flow upstream: from the (downstream) server to this (upstream) client.
+ */
+public class MessageRelay<MessageType> implements Closeable
+{
+  private static final Logger log = new Logger(MessageRelay.class);
+
+  /**
+   * Value to provide for epoch on the initial call to {@link MessageRelayClient#getMessages(String, long, long)}.
+   */
+  public static final long INIT = -1;
+
+  private final String selfHost;
+  private final DruidNode serverNode;
+  private final MessageRelayClient<MessageType> client;
+  private final AtomicBoolean closed = new AtomicBoolean(false);
+  private final Collector collector;
+
+  public MessageRelay(
+      final String selfHost,
+      final DruidNode serverNode,
+      final MessageRelayClient<MessageType> client,
+      final MessageListener<MessageType> listener
+  )
+  {
+    this.selfHost = selfHost;
+    this.serverNode = serverNode;
+    this.client = client;
+    this.collector = new Collector(listener);
+  }
+
+  /**
+   * Start the {@link Collector}.
+   */
+  public void start()
+  {
+    collector.start();
+  }
+
+  /**
+   * Stop the {@link Collector}. Only the first call has any effect; later calls are no-ops.
+   */
+  @Override
+  public void close()
+  {
+    if (closed.compareAndSet(false, true)) {
+      collector.stop();
+    }
+  }
+
+  /**
+   * Retrieves messages that are being sent to this client and hands them to {@link #listener}.
+   */
+  private class Collector
+  {
+    private final MessageListener<MessageType> listener;
+    private final AtomicLong epoch = new AtomicLong(INIT);
+    private final AtomicLong watermark = new AtomicLong(INIT);
+    private final AtomicReference<ListenableFuture<?>> currentCall = new AtomicReference<>();
+
+    public Collector(final MessageListener<MessageType> listener)
+    {
+      this.listener = listener;
+    }
+
+    private void start()
+    {
+      if (!watermark.compareAndSet(INIT, 0)) {
+        throw new ISE("Already started");
+      }
+
+      listener.serverAdded(serverNode);
+      issueNextGetMessagesCall();
+    }
+
+    private void issueNextGetMessagesCall()
+    {
+      if (closed.get()) {
+        return;
+      }
+
+      final long theEpoch = epoch.get();
+      final long theWatermark = watermark.get();
+
+      log.debug(
+          "Getting messages from server[%s] for client[%s] (current state: epoch[%s] watermark[%s]).",
+          serverNode.getHostAndPortToUse(),
+          selfHost,
+          theEpoch,
+          theWatermark
+      );
+
+      final ListenableFuture<MessageBatch<MessageType>> future = client.getMessages(selfHost, theEpoch, theWatermark);
+
+      if (!currentCall.compareAndSet(null, future)) {
+        log.error(
+            "Fatal error: too many outgoing calls to server[%s] for client[%s] "
+            + "(current state: epoch[%s] watermark[%s]). Closing collector.",
+            serverNode.getHostAndPortToUse(),
+            selfHost,
+            theEpoch,
+            theWatermark
+        );
+
+        close();
+        return;
+      }
+
+      Futures.addCallback(
+          future,
+          new FutureCallback<MessageBatch<MessageType>>()
+          {
+            @Override
+            public void onSuccess(final MessageBatch<MessageType> result)
+            {
+              log.debug("Received message batch: %s", result);
+              currentCall.compareAndSet(future, null);
+              final long endWatermark = result.getStartWatermark() + result.getMessages().size();
+              if (theEpoch == INIT) {
+                epoch.set(result.getEpoch());
+                watermark.set(endWatermark);
+              } else if (epoch.get() != result.getEpoch()
+                         || !watermark.compareAndSet(result.getStartWatermark(), endWatermark)) {
+                // We don't expect to see this unless there is somehow another collector running with the same
+                // clientHost. If the unexpected happens, log it and close the collector. It will stay, doing
+                // nothing, in the MessageRelays map until it is removed by the discovery listener.
+                log.error(
+                    "Incorrect epoch + watermark from server[%s] for client[%s] "
+                    + "(expected[%s:%s] but got[%s:%s]). "
+                    + "Closing collector.",
+                    serverNode.getHostAndPortToUse(),
+                    selfHost,
+                    theEpoch,
+                    theWatermark,
+                    result.getEpoch(),
+                    result.getStartWatermark()
+                );
+
+                close();
+                return;
+              }
+
+              for (final MessageType message : result.getMessages()) {
+                try {
+                  listener.messageReceived(message);
+                }
+                catch (Throwable e) {
+                  log.warn(
+                      e,
+                      "Failed to handle message[%s] from server[%s] for client[%s].",
+                      message,
+                      serverNode.getHostAndPortToUse(),
+                      selfHost
+                  );
+                }
+              }
+
+              issueNextGetMessagesCall();
+            }
+
+            @Override
+            public void onFailure(final Throwable e)
+            {
+              currentCall.compareAndSet(future, null);
+              if (!(e instanceof CancellationException) && !(e instanceof ServiceClosedException)) {
+                // We don't expect to see any other errors, since we use an unlimited retry policy for clients. If the
+                // unexpected happens, log it and close the collector. It will stay, doing nothing, in the
+                // MessageRelays map until it is removed by the discovery listener.
+                log.error(
+                    e,
+                    "Fatal error contacting server[%s] for client[%s] "
+                    + "(current state: epoch[%s] watermark[%s]). "
+                    + "Closing collector.",
+                    serverNode.getHostAndPortToUse(),
+                    selfHost,
+                    theEpoch,
+                    theWatermark
+                );
+              }
+
+              close();
+            }
+          },
+          Execs.directExecutor()
+      );
+    }
+
+    public void stop()
+    {
+      final ListenableFuture<?> future = currentCall.getAndSet(null);
+      if (future != null) {
+        future.cancel(true);
+      }
+
+      try {
+        listener.serverRemoved(serverNode);
+      }
+      catch (Throwable e) {
+        log.warn(e, "Failed to close server[%s]", serverNode.getHostAndPortToUse());
+      }
+    }
+  }
+}
diff --git a/server/src/main/java/org/apache/druid/messages/client/MessageRelayClient.java b/server/src/main/java/org/apache/druid/messages/client/MessageRelayClient.java
new file mode 100644
index 000000000000..fad228f7b5f0
--- /dev/null
+++ b/server/src/main/java/org/apache/druid/messages/client/MessageRelayClient.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.messages.client;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.druid.messages.MessageBatch;
+import org.apache.druid.messages.server.MessageRelayResource;
+
+/**
+ * Client for {@link MessageRelayResource}. {@code MessageType} is the type of message carried in each batch.
+ */
+public interface MessageRelayClient<MessageType>
+{
+  /**
+   * Get the next batch of messages from an outbox.
+   *
+   * @param clientHost     which outbox to retrieve messages from. Each clientHost has its own outbox.
+   * @param epoch          outbox epoch, or {@link MessageRelay#INIT} if this is the first call from the collector.
+   * @param startWatermark outbox message watermark to retrieve from.
+   *
+   * @return future that resolves to the next batch of messages
+   *
+   * @see MessageRelayResource#httpGetMessagesFromOutbox http endpoint this method calls
+   */
+  ListenableFuture<MessageBatch<MessageType>> getMessages(String clientHost, long epoch, long startWatermark);
+}
diff --git a/server/src/main/java/org/apache/druid/messages/client/MessageRelayClientImpl.java b/server/src/main/java/org/apache/druid/messages/client/MessageRelayClientImpl.java
new file mode 100644
index 000000000000..140bd45e1af4
--- /dev/null
+++ b/server/src/main/java/org/apache/druid/messages/client/MessageRelayClientImpl.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.messages.client;
+
+import com.fasterxml.jackson.databind.JavaType;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.druid.common.guava.FutureUtils;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.java.util.common.jackson.JacksonUtils;
+import org.apache.druid.java.util.http.client.response.BytesFullResponseHandler;
+import org.apache.druid.messages.MessageBatch;
+import org.apache.druid.rpc.RequestBuilder;
+import org.apache.druid.rpc.ServiceClient;
+import org.eclipse.jetty.http.HttpStatus;
+import org.jboss.netty.handler.codec.http.HttpMethod;
+
+import java.util.Collections;
+
+/**
+ * Production implementation of {@link MessageRelayClient}. Responses are decoded with the provided Smile mapper.
+ */
+public class MessageRelayClientImpl<MessageType> implements MessageRelayClient<MessageType>
+{
+  private final ServiceClient serviceClient;
+  private final ObjectMapper smileMapper;
+  private final JavaType inMessageBatchType;
+
+  public MessageRelayClientImpl(
+      final ServiceClient serviceClient,
+      final ObjectMapper smileMapper,
+      final Class<MessageType> inMessageClass
+  )
+  {
+    this.serviceClient = serviceClient;
+    this.smileMapper = smileMapper;
+    this.inMessageBatchType = smileMapper.getTypeFactory().constructParametricType(MessageBatch.class, inMessageClass);
+  }
+
+  @Override
+  public ListenableFuture<MessageBatch<MessageType>> getMessages(
+      final String clientHost,
+      final long epoch,
+      final long startWatermark
+  )
+  {
+    final String path = StringUtils.format(
+        "/outbox/%s/messages?epoch=%d&watermark=%d",
+        StringUtils.urlEncode(clientHost),
+        epoch,
+        startWatermark
+    );
+
+    return FutureUtils.transform(
+        serviceClient.asyncRequest(
+            new RequestBuilder(HttpMethod.GET, path),
+            new BytesFullResponseHandler()
+        ),
+        holder -> {
+          if (holder.getResponse().getStatus().getCode() == HttpStatus.NO_CONTENT_204) {
+            return new MessageBatch<>(Collections.emptyList(), epoch, startWatermark);
+          } else {
+            return JacksonUtils.readValue(smileMapper, holder.getContent(), inMessageBatchType);
+          }
+        }
+    );
+  }
+}
diff --git a/server/src/main/java/org/apache/druid/messages/client/MessageRelayFactory.java b/server/src/main/java/org/apache/druid/messages/client/MessageRelayFactory.java
new file mode 100644
index 000000000000..b647b9e4b6a2
--- /dev/null
+++ b/server/src/main/java/org/apache/druid/messages/client/MessageRelayFactory.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.messages.client;
+
+import org.apache.druid.server.DruidNode;
+
+/**
+ * Factory for creating new message relays. Used by {@link MessageRelays}.
+ */
+public interface MessageRelayFactory<MessageType>
+{
+  MessageRelay<MessageType> newRelay(DruidNode druidNode);
+}
diff --git a/server/src/main/java/org/apache/druid/messages/client/MessageRelays.java b/server/src/main/java/org/apache/druid/messages/client/MessageRelays.java
new file mode 100644
index 000000000000..e7af8fc51b55
--- /dev/null
+++ b/server/src/main/java/org/apache/druid/messages/client/MessageRelays.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.messages.client;
+
+import com.google.errorprone.annotations.concurrent.GuardedBy;
+import org.apache.druid.discovery.DiscoveryDruidNode;
+import org.apache.druid.discovery.DruidNodeDiscovery;
+import org.apache.druid.discovery.DruidNodeDiscoveryProvider;
+import org.apache.druid.guice.ManageLifecycle;
+import org.apache.druid.java.util.common.Pair;
+import org.apache.druid.java.util.common.lifecycle.LifecycleStart;
+import org.apache.druid.java.util.common.lifecycle.LifecycleStop;
+import org.apache.druid.java.util.common.logger.Logger;
+import org.apache.druid.server.DruidNode;
+import org.apache.druid.utils.CloseableUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * Manages a fleet of {@link MessageRelay}, one for each server discovered by a {@link DruidNodeDiscoveryProvider}.
+ */
+@ManageLifecycle
+public class MessageRelays<MessageType>
+{
+  private static final Logger log = new Logger(MessageRelays.class);
+
+  @GuardedBy("serverRelays")
+  private final Map<String, MessageRelay<MessageType>> serverRelays = new HashMap<>();
+  private final Supplier<DruidNodeDiscovery> discoverySupplier;
+  private final MessageRelayFactory<MessageType> messageRelayFactory;
+  private final MessageRelaysListener listener;
+
+  private volatile DruidNodeDiscovery discovery;
+
+  public MessageRelays(
+      final Supplier<DruidNodeDiscovery> discoverySupplier,
+      final MessageRelayFactory<MessageType> messageRelayFactory
+  )
+  {
+    this.discoverySupplier = discoverySupplier;
+    this.messageRelayFactory = messageRelayFactory;
+    this.listener = new MessageRelaysListener();
+  }
+
+  @LifecycleStart
+  public void start()
+  {
+    discovery = discoverySupplier.get();
+    discovery.registerListener(listener);
+  }
+
+  @LifecycleStop
+  public void stop()
+  {
+    if (discovery != null) {
+      discovery.removeListener(listener);
+      discovery = null;
+    }
+
+    synchronized (serverRelays) {
+      try {
+        CloseableUtils.closeAll(serverRelays.values());
+      }
+      catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+      finally {
+        serverRelays.clear();
+      }
+    }
+  }
+
+  /**
+   * Discovery listener. Creates and tears down individual host relays, keyed by server host:port.
+   */
+  class MessageRelaysListener implements DruidNodeDiscovery.Listener
+  {
+    @Override
+    public void nodesAdded(final Collection<DiscoveryDruidNode> nodes)
+    {
+      synchronized (serverRelays) {
+        for (final DiscoveryDruidNode node : nodes) {
+          final DruidNode druidNode = node.getDruidNode();
+
+          serverRelays.computeIfAbsent(druidNode.getHostAndPortToUse(), ignored -> {
+            final MessageRelay<MessageType> relay = messageRelayFactory.newRelay(druidNode);
+            relay.start();
+            return relay;
+          });
+        }
+      }
+    }
+
+    @Override
+    public void nodesRemoved(final Collection<DiscoveryDruidNode> nodes)
+    {
+      final List<Pair<String, MessageRelay<MessageType>>> removed = new ArrayList<>();
+
+      synchronized (serverRelays) {
+        for (final DiscoveryDruidNode node : nodes) {
+          final DruidNode druidNode = node.getDruidNode();
+          final String druidHost = druidNode.getHostAndPortToUse();
+          final MessageRelay<MessageType> relay = serverRelays.remove(druidHost);
+          if (relay != null) {
+            removed.add(Pair.of(druidHost, relay));
+          }
+        }
+      }
+
+      for (final Pair<String, MessageRelay<MessageType>> relay : removed) {
+        try {
+          relay.rhs.close();
+        }
+        catch (Throwable e) {
+          log.noStackTrace().warn(e, "Could not close relay for server[%s]. Dropping.", relay.lhs);
+        }
+      }
+    }
+  }
+}
diff --git a/server/src/main/java/org/apache/druid/messages/package-info.java b/server/src/main/java/org/apache/druid/messages/package-info.java
new file mode 100644
index 000000000000..9eb36d1e181c
--- /dev/null
+++ b/server/src/main/java/org/apache/druid/messages/package-info.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Message relays provide a mechanism to send messages from server to client using long polling. The messages are
+ * sent in order, with acknowledgements from client to server when a message has been successfully delivered.
+ *
+ * This is useful when "downstream" servers need to send low-latency messages to an "upstream" server, but
+ * establishing connections from the downstream servers to the upstream servers would not be
+ * desirable. This is typically done when upstream servers want to keep state in-memory that is updated incrementally
+ * by downstream servers, and where there may be lots of instances of downstream servers.
+ *
+ * This structure has two main benefits. First, it prevents upstream servers from being overwhelmed by connections
+ * from downstream servers. Second, it allows upstream servers to drive the updates of their own state, and better
+ * handle events like restarts and leader changes.
+ *
+ * On the downstream (server) side, messages are placed into an {@link org.apache.druid.messages.server.Outbox}
+ * and served by a {@link org.apache.druid.messages.server.MessageRelayResource}.
+ *
+ * On the upstream (client) side, messages are retrieved by {@link org.apache.druid.messages.client.MessageRelays}
+ * using {@link org.apache.druid.messages.client.MessageRelayClient}.
+ * + * This is currently used by Dart (multi-stage-query engine running on Brokers and Historicals) to implement + * worker-to-controller messages. In the future it may also be used to implement + * {@link org.apache.druid.server.coordination.ChangeRequestHttpSyncer}. + */ + +package org.apache.druid.messages; diff --git a/server/src/main/java/org/apache/druid/messages/server/MessageRelayMonitor.java b/server/src/main/java/org/apache/druid/messages/server/MessageRelayMonitor.java new file mode 100644 index 000000000000..1126f273ccaa --- /dev/null +++ b/server/src/main/java/org/apache/druid/messages/server/MessageRelayMonitor.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.messages.server; + +import org.apache.druid.discovery.DiscoveryDruidNode; +import org.apache.druid.discovery.DruidNodeDiscovery; +import org.apache.druid.discovery.DruidNodeDiscoveryProvider; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.java.util.common.lifecycle.LifecycleStart; + +import java.util.Collection; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Code that runs on message servers, to monitor their clients. 
When a client vanishes, its outbox is reset using
+ * {@link Outbox#resetOutbox(String)}.
+ */
+public class MessageRelayMonitor
+{
+  private final DruidNodeDiscoveryProvider discoveryProvider;
+  private final Outbox<?> outbox;
+  private final NodeRole clientRole;
+
+  public MessageRelayMonitor(
+      final DruidNodeDiscoveryProvider discoveryProvider,
+      final Outbox<?> outbox,
+      final NodeRole clientRole
+  )
+  {
+    this.discoveryProvider = discoveryProvider;
+    this.outbox = outbox;
+    this.clientRole = clientRole;
+  }
+
+  @LifecycleStart
+  public void start()
+  {
+    discoveryProvider.getForNodeRole(clientRole).registerListener(new ClientListener());
+  }
+
+  /**
+   * Listener that resets the outboxes of clients that have gone away.
+   */
+  private class ClientListener implements DruidNodeDiscovery.Listener
+  {
+    @Override
+    public void nodesAdded(Collection<DiscoveryDruidNode> nodes)
+    {
+      // Nothing to do. Although, perhaps it would make sense to *set up* an outbox here. (Currently, outboxes are
+      // created on-demand as they receive messages.)
+    }
+
+    @Override
+    public void nodesRemoved(Collection<DiscoveryDruidNode> nodes)
+    {
+      final Set<String> hostsRemoved =
+          nodes.stream().map(node -> node.getDruidNode().getHostAndPortToUse()).collect(Collectors.toSet());
+
+      for (final String clientHost : hostsRemoved) {
+        outbox.resetOutbox(clientHost);
+      }
+    }
+  }
+}
diff --git a/server/src/main/java/org/apache/druid/messages/server/MessageRelayResource.java b/server/src/main/java/org/apache/druid/messages/server/MessageRelayResource.java
new file mode 100644
index 000000000000..f8e771d378c7
--- /dev/null
+++ b/server/src/main/java/org/apache/druid/messages/server/MessageRelayResource.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.messages.server;
+
+import com.fasterxml.jackson.databind.JavaType;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes;
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.druid.java.util.common.concurrent.Execs;
+import org.apache.druid.java.util.common.logger.Logger;
+import org.apache.druid.messages.MessageBatch;
+import org.apache.druid.messages.client.MessageListener;
+import org.apache.druid.messages.client.MessageRelayClient;
+
+import javax.servlet.AsyncContext;
+import javax.servlet.AsyncEvent;
+import javax.servlet.AsyncListener;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import javax.ws.rs.GET;
+import javax.ws.rs.Path;
+import javax.ws.rs.PathParam;
+import javax.ws.rs.QueryParam;
+import javax.ws.rs.core.Context;
+import java.io.IOException;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Server-side resource for message relaying. Wraps an {@link Outbox}; relayed messages end up at a client-side
+ * {@link MessageListener}. The client for this resource is {@link MessageRelayClient}.
+ */
+public class MessageRelayResource<MessageType>
+{
+  private static final Logger log = new Logger(MessageRelayResource.class);
+  private static final long GET_MESSAGES_TIMEOUT = 30_000L;
+
+  /**
+   * Outbox for messages sent from this server.
+   */
+  private final Outbox<MessageType> outbox;
+
+  /**
+   * Message relay protocol uses Smile.
+   */
+  private final ObjectMapper smileMapper;
+
+  /**
+   * Type of {@link MessageBatch} of {@link MessageType}.
+   */
+  private final JavaType batchType;
+
+  public MessageRelayResource(
+      final Outbox<MessageType> outbox,
+      final ObjectMapper smileMapper,
+      final Class<MessageType> messageClass
+  )
+  {
+    this.outbox = outbox;
+    this.smileMapper = smileMapper;
+    this.batchType = smileMapper.getTypeFactory().constructParametricType(MessageBatch.class, messageClass);
+  }
+
+  /**
+   * Retrieve messages from the outbox for a particular client, as a {@link MessageBatch} in Smile format.
+   * The messages are retrieved from {@link Outbox#getMessages(String, long, long)}.
+   *
+   * This is a long-polling async method, using {@link AsyncContext} to wait up to {@link #GET_MESSAGES_TIMEOUT} for
+   * messages to appear in the outbox.
+   *
+   * @return HTTP 200 with Smile response with messages on success; HTTP 204 (No Content) if no messages were put in
+   * the outbox before the timeout {@link #GET_MESSAGES_TIMEOUT} elapsed; HTTP 400 if a parameter is missing
+   *
+   * @see Outbox#getMessages(String, long, long) for more details on the API
+   */
+  @GET
+  @Path("/outbox/{clientHost}/messages")
+  public Void httpGetMessagesFromOutbox(
+      @PathParam("clientHost") final String clientHost,
+      @QueryParam("epoch") final Long epoch,
+      @QueryParam("watermark") final Long watermark,
+      @Context final HttpServletRequest req
+  ) throws IOException
+  {
+    if (epoch == null || watermark == null || clientHost == null || clientHost.isEmpty()) {
+      AsyncContext asyncContext = req.startAsync();
+      HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse();
+      response.sendError(HttpServletResponse.SC_BAD_REQUEST);
+      asyncContext.complete();
+      return null;
+    }
+
+    final AtomicBoolean didRespond = new AtomicBoolean();
+    final ListenableFuture<MessageBatch<MessageType>> batchFuture = outbox.getMessages(clientHost, epoch, watermark);
+    final AsyncContext asyncContext = req.startAsync();
+    asyncContext.setTimeout(GET_MESSAGES_TIMEOUT);
+    asyncContext.addListener(
+        new AsyncListener()
+        {
+          @Override
+          public void onComplete(AsyncEvent event)
+          {
+          }
+
+          @Override
+          public void onTimeout(AsyncEvent event)
+          {
+            if (didRespond.compareAndSet(false, true)) {
+              HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse();
+              response.setStatus(HttpServletResponse.SC_NO_CONTENT);
+              event.getAsyncContext().complete();
+              batchFuture.cancel(true);
+            }
+          }
+
+          @Override
+          public void onError(AsyncEvent event)
+          {
+          }
+
+          @Override
+          public void onStartAsync(AsyncEvent event)
+          {
+          }
+        }
+    );
+
+    // Save these items, since "req" becomes inaccessible in future exception handlers.
+    final String remoteAddr = req.getRemoteAddr();
+    final String requestURI = req.getRequestURI();
+
+    Futures.addCallback(
+        batchFuture,
+        new FutureCallback<MessageBatch<MessageType>>()
+        {
+          @Override
+          public void onSuccess(MessageBatch<MessageType> result)
+          {
+            if (didRespond.compareAndSet(false, true)) {
+              log.debug("Sending message batch: %s", result);
+              try {
+                HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse();
+                response.setStatus(HttpServletResponse.SC_OK);
+                response.setContentType(SmileMediaTypes.APPLICATION_JACKSON_SMILE);
+                smileMapper.writerFor(batchType)
+                           .writeValue(asyncContext.getResponse().getOutputStream(), result);
+                response.getOutputStream().close();
+                asyncContext.complete();
+              }
+              catch (Exception e) {
+                log.noStackTrace().warn(e, "Could not respond to request from[%s] to[%s]", remoteAddr, requestURI);
+              }
+            }
+          }
+
+          @Override
+          public void onFailure(Throwable e)
+          {
+            if (didRespond.compareAndSet(false, true)) {
+              try {
+                HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse();
+                response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
+                asyncContext.complete();
+              }
+              catch (Exception e2) {
+                e.addSuppressed(e2);
+              }
+
+              log.noStackTrace().warn(e, "Request failed from[%s] to[%s]", remoteAddr, requestURI);
+            }
+          }
+        },
+        Execs.directExecutor()
+    );
+
+    return null;
+  }
+}
diff --git a/server/src/main/java/org/apache/druid/messages/server/Outbox.java b/server/src/main/java/org/apache/druid/messages/server/Outbox.java
new file mode 100644
index 000000000000..4fcf130f0a9f
--- /dev/null
+++ b/server/src/main/java/org/apache/druid/messages/server/Outbox.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.messages.server;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.druid.messages.MessageBatch;
+import org.apache.druid.messages.client.MessageRelay;
+
+/**
+ * An outbox for messages sent from servers to clients. Messages are retrieved in the order they are sent.
+ *
+ * @see org.apache.druid.messages package-level javadoc for description of the message relay system
+ */
+public interface Outbox<MessageType>
+{
+  /**
+   * Send a message to a client, through an outbox.
+   *
+   * @param clientHost which outbox to send messages through. Each clientHost has its own outbox.
+   * @param message    message to send
+   *
+   * @return future that resolves successfully when the client has acknowledged the message
+   */
+  ListenableFuture<?> sendMessage(String clientHost, MessageType message);
+
+  /**
+   * Get the next batch of messages for a client, from an outbox. Messages are retrieved in the order they were sent.
+   *
+   * The provided epoch must either be {@link MessageRelay#INIT}, or must match the epoch of the outbox as indicated by
+   * {@link MessageBatch#getEpoch()} returned by previous calls to the same outbox. If the provided epoch does not
+   * match, an empty batch is returned with the correct epoch indicated in {@link MessageBatch#getEpoch()}.
+   *
+   * The provided watermark must be greater than, or equal to, the previous watermark supplied to the same outbox.
+   * Any messages lower than the watermark are acknowledged and removed from the outbox.
+   *
+   * @param clientHost     which outbox to retrieve messages from. Each clientHost has its own outbox.
+   * @param epoch          outbox epoch, or {@link MessageRelay#INIT} if this is the first call from the collector.
+   * @param startWatermark outbox message watermark to retrieve from.
+   *
+   * @return future that resolves to the next batch of messages
+   */
+  ListenableFuture<MessageBatch<MessageType>> getMessages(String clientHost, long epoch, long startWatermark);
+
+  /**
+   * Reset the outbox for a particular client. This removes all messages, cancels all outstanding futures, and
+   * resets the epoch.
+   *
+   * @param clientHost the client host:port
+   */
+  void resetOutbox(String clientHost);
+}
diff --git a/server/src/main/java/org/apache/druid/messages/server/OutboxImpl.java b/server/src/main/java/org/apache/druid/messages/server/OutboxImpl.java
new file mode 100644
index 000000000000..09e19177b945
--- /dev/null
+++ b/server/src/main/java/org/apache/druid/messages/server/OutboxImpl.java
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.messages.server; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import org.apache.druid.common.guava.FutureBox; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.java.util.common.Pair; +import org.apache.druid.java.util.common.lifecycle.LifecycleStop; +import org.apache.druid.messages.MessageBatch; +import org.apache.druid.messages.client.MessageRelay; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ThreadLocalRandom; + +/** + * Production implementation of {@link Outbox}. Each outbox is represented by an {@link OutboxQueue}. 
+ */ +public class OutboxImpl implements Outbox +{ + private static final int MAX_BATCH_SIZE = 8; + + // clientHost -> outgoing message queue + private final ConcurrentHashMap> queues; + private volatile boolean stopped; + + public OutboxImpl() + { + this.queues = new ConcurrentHashMap<>(); + } + + @LifecycleStop + public void stop() + { + stopped = true; + + final Iterator> it = queues.values().iterator(); + while (it.hasNext()) { + it.next().stop(); + it.remove(); + } + } + + @Override + public ListenableFuture sendMessage(String clientHost, MessageType message) + { + if (stopped) { + return Futures.immediateCancelledFuture(); + } + + return queues.computeIfAbsent(clientHost, id -> new OutboxQueue<>()) + .sendMessage(message); + } + + @Override + public ListenableFuture> getMessages(String clientHost, long epoch, long startWatermark) + { + if (stopped) { + return Futures.immediateCancelledFuture(); + } + + final OutboxQueue queue = queues.computeIfAbsent(clientHost, id -> new OutboxQueue<>()); + if (epoch != queue.epoch && epoch != MessageRelay.INIT) { + return Futures.immediateFuture(new MessageBatch<>(Collections.emptyList(), queue.epoch, 0)); + } + + return queue.getMessages(startWatermark); + } + + @Override + public void resetOutbox(final String clientHost) + { + final OutboxQueue queue = queues.remove(clientHost); + if (queue != null) { + queue.stop(); + } + } + + @VisibleForTesting + long getOutboxEpoch(final String clientHost) + { + final OutboxQueue queue = queues.get(clientHost); + return queue != null ? queue.epoch : MessageRelay.INIT; + } + + /** + * Outgoing queue for a specific client. + */ + public static class OutboxQueue + { + /** + * Epoch, set when the outbox is created. Attached to returned batches through {@link MessageBatch#getEpoch()}. + */ + private final long epoch; + + /** + * Currently-outstanding futures. 
+ */ + private final FutureBox pendingFutures = new FutureBox(); + + @GuardedBy("this") + private long startWatermark = 0; + + @GuardedBy("this") + private final Deque, T>> queue = new ArrayDeque<>(); + + @GuardedBy("this") + private SettableFuture messageAvailableFuture = SettableFuture.create(); + + public OutboxQueue() + { + // Random positive number, to differentiate this outbox from a previous version that may have lived + // on the same host. (When the upstream relay connects, it needs to know if this is the "same" outbox + // it was previously listening to.) + this.epoch = ThreadLocalRandom.current().nextLong() & Long.MAX_VALUE; + } + + ListenableFuture sendMessage(final T message) + { + final SettableFuture future = SettableFuture.create(); + + synchronized (this) { + queue.add(Pair.of(future, message)); + if (!messageAvailableFuture.isDone()) { + messageAvailableFuture.set(null); + } + } + + return pendingFutures.register(future); + } + + ListenableFuture> getMessages(final long newStartWatermark) + { + synchronized (this) { + // Ack and drain all messages up to startWatermark. + while (!queue.isEmpty() && startWatermark < newStartWatermark) { + final Pair, T> message = queue.poll(); + startWatermark++; + message.lhs.set(null); + } + + if (queue.isEmpty()) { + // Send next batch when a message is available. 
+ if (messageAvailableFuture.isDone()) { + messageAvailableFuture = SettableFuture.create(); + } + + return pendingFutures.register( + FutureUtils.transform( + Futures.nonCancellationPropagating(messageAvailableFuture), + ignored -> { + synchronized (this) { + return nextBatch(); + } + } + ) + ); + } else { + return pendingFutures.register(Futures.immediateFuture(nextBatch())); + } + } + } + + void stop() + { + pendingFutures.close(); + } + + @GuardedBy("this") + private MessageBatch nextBatch() + { + final List batch = new ArrayList<>(); + final Iterator, T>> it = queue.iterator(); + + while (it.hasNext() && batch.size() < MAX_BATCH_SIZE) { + batch.add(it.next().rhs); + } + + return new MessageBatch<>(batch, epoch, startWatermark); + } + } +} diff --git a/server/src/main/java/org/apache/druid/rpc/FixedServiceLocator.java b/server/src/main/java/org/apache/druid/rpc/FixedServiceLocator.java new file mode 100644 index 000000000000..06e7bd993c18 --- /dev/null +++ b/server/src/main/java/org/apache/druid/rpc/FixedServiceLocator.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.rpc; + +import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; + +/** + * Locator for a fixed set of {@link ServiceLocations}. + */ +public class FixedServiceLocator implements ServiceLocator +{ + private final ServiceLocations locations; + + private volatile boolean closed = false; + + public FixedServiceLocator(final ServiceLocations locations) + { + this.locations = Preconditions.checkNotNull(locations); + } + + public FixedServiceLocator(final ServiceLocation location) + { + this(ServiceLocations.forLocation(location)); + } + + @Override + public ListenableFuture locate() + { + if (closed) { + return Futures.immediateFuture(ServiceLocations.closed()); + } else { + return Futures.immediateFuture(locations); + } + } + + @Override + public void close() + { + closed = true; + } +} diff --git a/server/src/main/java/org/apache/druid/rpc/FixedSetServiceLocator.java b/server/src/main/java/org/apache/druid/rpc/FixedSetServiceLocator.java deleted file mode 100644 index d6f6eff9d7fd..000000000000 --- a/server/src/main/java/org/apache/druid/rpc/FixedSetServiceLocator.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.druid.rpc; - -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import org.apache.druid.server.coordination.DruidServerMetadata; -import org.jboss.netty.util.internal.ThreadLocalRandom; - -import javax.validation.constraints.NotNull; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * Basic implmentation of {@link ServiceLocator} that returns a service location from a static set of locations. Returns - * a random location each time one is requested. - */ -public class FixedSetServiceLocator implements ServiceLocator -{ - private ServiceLocations serviceLocations; - - private FixedSetServiceLocator(ServiceLocations serviceLocations) - { - this.serviceLocations = serviceLocations; - } - - public static FixedSetServiceLocator forServiceLocation(@NotNull ServiceLocation serviceLocation) - { - return new FixedSetServiceLocator(ServiceLocations.forLocation(serviceLocation)); - } - - public static FixedSetServiceLocator forDruidServerMetadata(Set serverMetadataSet) - { - if (serverMetadataSet == null || serverMetadataSet.isEmpty()) { - return new FixedSetServiceLocator(ServiceLocations.closed()); - } else { - Set serviceLocationSet = serverMetadataSet.stream() - .map(ServiceLocation::fromDruidServerMetadata) - .collect(Collectors.toSet()); - - return new FixedSetServiceLocator(ServiceLocations.forLocations(serviceLocationSet)); - } - } - - @Override - public ListenableFuture locate() - { - if (serviceLocations.isClosed() || serviceLocations.getLocations().isEmpty()) { - return Futures.immediateFuture(ServiceLocations.closed()); - } - - Set locationSet = serviceLocations.getLocations(); - int size = locationSet.size(); - if (size == 1) { - return Futures.immediateFuture(ServiceLocations.forLocation(locationSet.stream().findFirst().get())); - } - - return 
Futures.immediateFuture( - ServiceLocations.forLocation( - locationSet.stream() - .skip(ThreadLocalRandom.current().nextInt(size)) - .findFirst() - .orElse(null) - ) - ); - } - - @Override - public void close() - { - serviceLocations = ServiceLocations.closed(); - } -} diff --git a/server/src/main/java/org/apache/druid/rpc/ServiceClientImpl.java b/server/src/main/java/org/apache/druid/rpc/ServiceClientImpl.java index 3178360016ab..172f220fabad 100644 --- a/server/src/main/java/org/apache/druid/rpc/ServiceClientImpl.java +++ b/server/src/main/java/org/apache/druid/rpc/ServiceClientImpl.java @@ -41,7 +41,6 @@ import javax.annotation.Nullable; import java.net.URI; -import java.net.URISyntaxException; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -497,19 +496,7 @@ static long computeBackoffMs(final ServiceRetryPolicy retryPolicy, final long at } /** - * Sanitizes IPv6 address if it has brackets. Eg. host = "[1:2:3:4:5:6:7:8]" will be returned as "1:2:3:4:5:6:7:8" - * after this function - */ - static String sanitizeHost(String host) - { - if (host.charAt(0) == '[') { - return host.substring(1, host.length() - 1); - } - return host; - } - - /** - * Returns a {@link ServiceLocation} without a path component, based on a URI. + * Returns a {@link ServiceLocation} without a path component, based on a URI. Returns null on invalid URIs. */ @Nullable @VisibleForTesting @@ -520,24 +507,17 @@ static ServiceLocation serviceLocationNoPathFromUri(@Nullable final String uriSt } try { - final URI uri = new URI(uriString); - - if (uri.getHost() == null) { - return null; - } - - final String scheme = uri.getScheme(); - final String host = sanitizeHost(uri.getHost()); - - if ("http".equals(scheme)) { - return new ServiceLocation(host, uri.getPort() < 0 ? 80 : uri.getPort(), -1, ""); - } else if ("https".equals(scheme)) { - return new ServiceLocation(host, -1, uri.getPort() < 0 ? 
443 : uri.getPort(), ""); - } else { - return null; - } + final ServiceLocation location = ServiceLocation.fromUri(URI.create(uriString)); + + // Strip path. + return new ServiceLocation( + location.getHost(), + location.getPlaintextPort(), + location.getTlsPort(), + "" + ); } - catch (URISyntaxException e) { + catch (IllegalArgumentException e) { return null; } } @@ -549,8 +529,8 @@ static ServiceLocation serviceLocationNoPathFromUri(@Nullable final String uriSt static boolean serviceLocationMatches(final ServiceLocation left, final ServiceLocation right) { return left.getHost().equals(right.getHost()) - && portMatches(left.getPlaintextPort(), right.getPlaintextPort()) - && portMatches(left.getTlsPort(), right.getTlsPort()); + && portMatches(left.getPlaintextPort(), right.getPlaintextPort()) + && portMatches(left.getTlsPort(), right.getTlsPort()); } static boolean portMatches(int left, int right) diff --git a/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java b/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java index aeaa24318e93..974f09fe89bb 100644 --- a/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java +++ b/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java @@ -22,6 +22,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; +import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; import org.apache.druid.server.DruidNode; import org.apache.druid.server.coordination.DruidServerMetadata; @@ -29,6 +30,7 @@ import javax.annotation.Nullable; import javax.validation.constraints.NotNull; import java.net.MalformedURLException; +import java.net.URI; import java.net.URL; import java.util.Iterator; import java.util.Objects; @@ -40,6 +42,8 @@ public class ServiceLocation { private static final String HTTP_SCHEME = "http"; private static final String HTTPS_SCHEME = "https"; + private static final int 
HTTP_DEFAULT_PORT = 80; + private static final int HTTPS_DEFAULT_PORT = 443; private static final Splitter HOST_SPLITTER = Splitter.on(":").limit(2); private final String host; @@ -72,6 +76,50 @@ public static ServiceLocation fromDruidNode(final DruidNode druidNode) return new ServiceLocation(druidNode.getHost(), druidNode.getPlaintextPort(), druidNode.getTlsPort(), ""); } + /** + * Create a service location based on a {@link URI}. + * + * @throws IllegalArgumentException if the URI cannot be mapped to a service location. + */ + public static ServiceLocation fromUri(final URI uri) + { + if (uri == null || uri.getHost() == null) { + throw new IAE("URI[%s] has no host", uri); + } + + final String scheme = uri.getScheme(); + final String host = stripBrackets(uri.getHost()); + final StringBuilder basePath = new StringBuilder(); + + if (uri.getRawPath() != null) { + if (uri.getRawQuery() == null && uri.getRawFragment() == null && uri.getRawPath().endsWith("/")) { + // Strip trailing slash if the URI has no query or fragment. By convention, this trailing slash is not + // part of the service location. + basePath.append(uri.getRawPath(), 0, uri.getRawPath().length() - 1); + } else { + basePath.append(uri.getRawPath()); + } + } + + if (uri.getRawQuery() != null) { + basePath.append('?').append(uri.getRawQuery()); + } + + if (uri.getRawFragment() != null) { + basePath.append('#').append(uri.getRawFragment()); + } + + if (HTTP_SCHEME.equals(scheme)) { + final int port = uri.getPort() < 0 ? HTTP_DEFAULT_PORT : uri.getPort(); + return new ServiceLocation(host, port, -1, basePath.toString()); + } else if (HTTPS_SCHEME.equals(scheme)) { + final int port = uri.getPort() < 0 ? HTTPS_DEFAULT_PORT : uri.getPort(); + return new ServiceLocation(host, -1, port, basePath.toString()); + } else { + throw new IAE("URI[%s] has invalid scheme[%s]", uri, scheme); + } + } + /** * Create a service location based on a {@link DruidServerMetadata}. 
* @@ -133,6 +181,11 @@ public String getBasePath() return basePath; } + public ServiceLocation withBasePath(final String newBasePath) + { + return new ServiceLocation(host, plaintextPort, tlsPort, newBasePath); + } + public URL toURL(@Nullable final String encodedPathAndQueryString) { final String scheme; @@ -193,4 +246,15 @@ public String toString() '}'; } + /** + * Strips brackets from the host part of a URI, so we can better handle IPv6 addresses. + * e.g. host = "[1:2:3:4:5:6:7:8]" is transformed to "1:2:3:4:5:6:7:8" by this function + */ + static String stripBrackets(String host) + { + if (host.charAt(0) == '[' && host.charAt(host.length() - 1) == ']') { + return host.substring(1, host.length() - 1); + } + return host; + } } diff --git a/server/src/test/java/org/apache/druid/client/BrokerServerViewTest.java b/server/src/test/java/org/apache/druid/client/BrokerServerViewTest.java index 798b55ed7274..fd90ff905a22 100644 --- a/server/src/test/java/org/apache/druid/client/BrokerServerViewTest.java +++ b/server/src/test/java/org/apache/druid/client/BrokerServerViewTest.java @@ -25,6 +25,7 @@ import com.google.common.base.Predicates; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Sets; @@ -65,6 +66,7 @@ import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; +import java.util.stream.Collectors; public class BrokerServerViewTest extends CuratorTestBase { @@ -290,6 +292,12 @@ public void testMultipleServerAndBroker() throws Exception ) ); + // check server metadatas + Assert.assertEquals( + druidServers.stream().map(DruidServer::getMetadata).collect(Collectors.toSet()), + ImmutableSet.copyOf(brokerServerView.getDruidServerMetadatas()) + ); + // unannounce the broker segment should do nothing to announcements 
unannounceSegmentForServer(druidBroker, brokerSegment, zkPathsConfig); Assert.assertTrue(timing.forWaiting().awaitLatch(segmentRemovedLatch)); @@ -593,7 +601,8 @@ private void setupViews() throws Exception setupViews(null, null, true); } - private void setupViews(Set watchedTiers, Set ignoredTiers, boolean watchRealtimeTasks) throws Exception + private void setupViews(Set watchedTiers, Set ignoredTiers, boolean watchRealtimeTasks) + throws Exception { baseView = new BatchServerInventoryView( zkPathsConfig, diff --git a/server/src/test/java/org/apache/druid/messages/MessageBatchTest.java b/server/src/test/java/org/apache/druid/messages/MessageBatchTest.java new file mode 100644 index 000000000000..bcf9fb3423d1 --- /dev/null +++ b/server/src/test/java/org/apache/druid/messages/MessageBatchTest.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.messages; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.druid.segment.TestHelper; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; + +public class MessageBatchTest +{ + @Test + public void test_serde() throws IOException + { + final ObjectMapper objectMapper = TestHelper.JSON_MAPPER; + final MessageBatch batch = new MessageBatch<>(ImmutableList.of("foo", "bar"), 123L, 456L); + final MessageBatch batch2 = + objectMapper.readValue(objectMapper.writeValueAsBytes(batch), new TypeReference>() {}); + Assert.assertEquals(batch, batch2); + } + + @Test + public void test_equals() + { + EqualsVerifier.forClass(MessageBatch.class) + .usingGetClass() + .verify(); + } +} diff --git a/server/src/test/java/org/apache/druid/messages/client/MessageRelayClientImplTest.java b/server/src/test/java/org/apache/druid/messages/client/MessageRelayClientImplTest.java new file mode 100644 index 000000000000..7b8af75c8d41 --- /dev/null +++ b/server/src/test/java/org/apache/druid/messages/client/MessageRelayClientImplTest.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.messages.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.smile.SmileFactory; +import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.ListenableFuture; +import it.unimi.dsi.fastutil.bytes.ByteArrays; +import org.apache.druid.jackson.DefaultObjectMapper; +import org.apache.druid.messages.MessageBatch; +import org.apache.druid.rpc.MockServiceClient; +import org.apache.druid.rpc.RequestBuilder; +import org.jboss.netty.handler.codec.http.HttpMethod; +import org.jboss.netty.handler.codec.http.HttpResponseStatus; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import javax.ws.rs.core.HttpHeaders; +import java.util.Collections; + +public class MessageRelayClientImplTest +{ + private ObjectMapper smileMapper; + private MockServiceClient serviceClient; + private MessageRelayClient messageRelayClient; + + @Before + public void setup() + { + smileMapper = new DefaultObjectMapper(new SmileFactory(), null); + serviceClient = new MockServiceClient(); + messageRelayClient = new MessageRelayClientImpl<>(serviceClient, smileMapper, String.class); + } + + @After + public void tearDown() + { + serviceClient.verify(); + } + + @Test + public void test_getMessages_ok() throws Exception + { + final MessageBatch batch = new MessageBatch<>(ImmutableList.of("foo", "bar"), 123, 0); + + serviceClient.expectAndRespond( + new RequestBuilder(HttpMethod.GET, "/outbox/me/messages?epoch=-1&watermark=0"), + HttpResponseStatus.OK, + ImmutableMap.of(HttpHeaders.CONTENT_TYPE, SmileMediaTypes.APPLICATION_JACKSON_SMILE), + smileMapper.writeValueAsBytes(batch) + ); + + final ListenableFuture> result = 
messageRelayClient.getMessages("me", MessageRelay.INIT, 0); + Assert.assertEquals(batch, result.get()); + } + + @Test + public void test_getMessages_noContent() throws Exception + { + serviceClient.expectAndRespond( + new RequestBuilder(HttpMethod.GET, "/outbox/me/messages?epoch=-1&watermark=0"), + HttpResponseStatus.NO_CONTENT, + ImmutableMap.of(HttpHeaders.CONTENT_TYPE, SmileMediaTypes.APPLICATION_JACKSON_SMILE), + ByteArrays.EMPTY_ARRAY + ); + + final ListenableFuture> result = messageRelayClient.getMessages("me", MessageRelay.INIT, 0); + Assert.assertEquals(new MessageBatch<>(Collections.emptyList(), MessageRelay.INIT, 0), result.get()); + } +} diff --git a/server/src/test/java/org/apache/druid/messages/client/MessageRelaysTest.java b/server/src/test/java/org/apache/druid/messages/client/MessageRelaysTest.java new file mode 100644 index 000000000000..b2014450d81e --- /dev/null +++ b/server/src/test/java/org/apache/druid/messages/client/MessageRelaysTest.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.messages.client; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import org.apache.druid.discovery.DiscoveryDruidNode; +import org.apache.druid.discovery.DruidNodeDiscovery; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.messages.MessageBatch; +import org.apache.druid.messages.server.Outbox; +import org.apache.druid.messages.server.OutboxImpl; +import org.apache.druid.server.DruidNode; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.function.Consumer; + +public class MessageRelaysTest +{ + private static final String MY_HOST = "me"; + private static final DruidNode OUTBOX_NODE = new DruidNode("service", "host", false, 80, -1, true, false); + private static final DiscoveryDruidNode OUTBOX_DISCO_NODE = new DiscoveryDruidNode( + new DruidNode("service", "host", false, 80, -1, true, false), + NodeRole.HISTORICAL, + Collections.emptyMap() + ); + + private Outbox outbox; + private TestMessageListener messageListener; + private TestDiscovery discovery; + private MessageRelays messageRelays; + + @Before + public void setUp() + { + outbox = new OutboxImpl<>(); + messageListener = new TestMessageListener(); + discovery = new TestDiscovery(); + messageRelays = new MessageRelays<>( + () -> discovery, + node -> { + Assert.assertEquals(OUTBOX_NODE, node); + return new MessageRelay<>( + MY_HOST, + node, + new OutboxMessageRelayClient(outbox), + messageListener + ); + } + ); + messageRelays.start(); + } + + @After + public void tearDown() + { + messageRelays.stop(); + Assert.assertEquals(Collections.emptyList(), discovery.getListeners()); + } + + @Test + public void test_serverAdded_thenRemoved() + { + 
discovery.fire(listener -> listener.nodesAdded(Collections.singletonList(OUTBOX_DISCO_NODE))); + discovery.fire(listener -> listener.nodesRemoved(Collections.singletonList(OUTBOX_DISCO_NODE))); + Assert.assertEquals(1, messageListener.getAdds()); + Assert.assertEquals(1, messageListener.getRemoves()); + } + + @Test + public void test_messageListener() + { + discovery.fire(listener -> listener.nodesAdded(Collections.singletonList(OUTBOX_DISCO_NODE))); + Assert.assertEquals(1, messageListener.getAdds()); + Assert.assertEquals(0, messageListener.getRemoves()); + + final ListenableFuture sendFuture = outbox.sendMessage(MY_HOST, "foo"); + Assert.assertEquals(ImmutableList.of("foo"), messageListener.getMessages()); + Assert.assertTrue(sendFuture.isDone()); + + final ListenableFuture sendFuture2 = outbox.sendMessage(MY_HOST, "bar"); + Assert.assertEquals(ImmutableList.of("foo", "bar"), messageListener.getMessages()); + Assert.assertTrue(sendFuture2.isDone()); + } + + /** + * Implementation of {@link MessageListener} that tracks all received messages. + */ + private static class TestMessageListener implements MessageListener + { + @GuardedBy("this") + private long adds; + + @GuardedBy("this") + private long removes; + + @GuardedBy("this") + private final List messages = new ArrayList<>(); + + @Override + public synchronized void serverAdded(DruidNode node) + { + adds++; + } + + @Override + public synchronized void messageReceived(String message) + { + messages.add(message); + } + + @Override + public synchronized void serverRemoved(DruidNode node) + { + removes++; + } + + public synchronized long getAdds() + { + return adds; + } + + public synchronized long getRemoves() + { + return removes; + } + + public synchronized List getMessages() + { + return ImmutableList.copyOf(messages); + } + } + + /** + * Implementation of {@link MessageRelayClient} that directly uses an {@link Outbox}, rather than contacting + * a remote outbox. 
+ */ + private static class OutboxMessageRelayClient implements MessageRelayClient + { + private final Outbox outbox; + + public OutboxMessageRelayClient(final Outbox outbox) + { + this.outbox = outbox; + } + + @Override + public ListenableFuture> getMessages(String clientHost, long epoch, long startWatermark) + { + return outbox.getMessages(clientHost, epoch, startWatermark); + } + } + + /** + * Implementation of {@link DruidNodeDiscovery} that allows firing listeners on command. + */ + private static class TestDiscovery implements DruidNodeDiscovery + { + @GuardedBy("this") + private final List listeners; + + public TestDiscovery() + { + listeners = new ArrayList<>(); + } + + @Override + public Collection getAllNodes() + { + throw new UnsupportedOperationException(); + } + + @Override + public synchronized void registerListener(Listener listener) + { + listeners.add(listener); + } + + @Override + public synchronized void removeListener(Listener listener) + { + listeners.remove(listener); + } + + public synchronized List getListeners() + { + return ImmutableList.copyOf(listeners); + } + + public synchronized void fire(Consumer f) + { + for (final Listener listener : listeners) { + f.accept(listener); + } + } + } +} diff --git a/server/src/test/java/org/apache/druid/messages/server/OutboxImplTest.java b/server/src/test/java/org/apache/druid/messages/server/OutboxImplTest.java new file mode 100644 index 000000000000..727c1c6ee2fc --- /dev/null +++ b/server/src/test/java/org/apache/druid/messages/server/OutboxImplTest.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.messages.server; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.messages.MessageBatch; +import org.apache.druid.messages.client.MessageRelay; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.Collections; +import java.util.concurrent.ExecutionException; + +public class OutboxImplTest +{ + private static final String HOST = "h1"; + + private OutboxImpl outbox; + + @Before + public void setUp() + { + outbox = new OutboxImpl<>(); + } + + @After + public void tearDown() + { + outbox.stop(); + } + + @Test + public void test_normalOperation() throws InterruptedException, ExecutionException + { + // Send first three messages. + final ListenableFuture sendFuture1 = outbox.sendMessage(HOST, "1"); + final ListenableFuture sendFuture2 = outbox.sendMessage(HOST, "2"); + final ListenableFuture sendFuture3 = outbox.sendMessage(HOST, "3"); + + final long outboxEpoch = outbox.getOutboxEpoch(HOST); + + // No messages are acknowledged. + Assert.assertFalse(sendFuture1.isDone()); + Assert.assertFalse(sendFuture2.isDone()); + Assert.assertFalse(sendFuture3.isDone()); + + // Request all three messages (startWatermark = 0). + Assert.assertEquals( + new MessageBatch<>(ImmutableList.of("1", "2", "3"), outboxEpoch, 0), + outbox.getMessages(HOST, MessageRelay.INIT, 0).get() + ); + + // No messages are acknowledged. 
+ Assert.assertFalse(sendFuture1.isDone()); + Assert.assertFalse(sendFuture2.isDone()); + Assert.assertFalse(sendFuture3.isDone()); + + // Request two of those messages again (startWatermark = 1). + Assert.assertEquals( + new MessageBatch<>(ImmutableList.of("2", "3"), outboxEpoch, 1), + outbox.getMessages(HOST, outboxEpoch, 1).get() + ); + + // First message is acknowledged. + Assert.assertTrue(sendFuture1.isDone()); + Assert.assertFalse(sendFuture2.isDone()); + Assert.assertFalse(sendFuture3.isDone()); + + // Request the high watermark (startWatermark = 3). + final ListenableFuture> futureBatch = outbox.getMessages(HOST, outboxEpoch, 3); + + // It's not available yet. + Assert.assertFalse(futureBatch.isDone()); + + // All messages are acknowledged. + Assert.assertTrue(sendFuture1.isDone()); + Assert.assertTrue(sendFuture2.isDone()); + Assert.assertTrue(sendFuture3.isDone()); + + // Send one more message; futureBatch resolves. + final ListenableFuture sendFuture4 = outbox.sendMessage(HOST, "4"); + Assert.assertTrue(futureBatch.isDone()); + + // sendFuture4 is not resolved. + Assert.assertFalse(sendFuture4.isDone()); + } + + @Test + public void test_getMessages_wrongEpoch() throws InterruptedException, ExecutionException + { + final ListenableFuture sendFuture = outbox.sendMessage(HOST, "1"); + final long outboxEpoch = outbox.getOutboxEpoch(HOST); + + // Fetch with the wrong epoch. + final MessageBatch batch = outbox.getMessages(HOST, outboxEpoch + 1, 0).get(); + Assert.assertEquals( + new MessageBatch<>(Collections.emptyList(), outboxEpoch, 0), + batch + ); + + Assert.assertFalse(sendFuture.isDone()); + } + + @Test + public void test_getMessages_nonexistentHost() throws InterruptedException, ExecutionException + { + // Calling getMessages with a nonexistent host creates an outbox. 
+ final String nonexistentHost = "nonexistent"; + final ListenableFuture> batchFuture = outbox.getMessages( + nonexistentHost, + MessageRelay.INIT, + 0 + ); + Assert.assertFalse(batchFuture.isDone()); + + // Check that an outbox was created (it has an epoch). + MatcherAssert.assertThat(outbox.getOutboxEpoch(nonexistentHost), Matchers.greaterThanOrEqualTo(0L)); + + // getMessages future resolves when a message is sent. + final ListenableFuture sendFuture = outbox.sendMessage(nonexistentHost, "foo"); + Assert.assertTrue(batchFuture.isDone()); + Assert.assertEquals( + new MessageBatch<>(ImmutableList.of("foo"), outbox.getOutboxEpoch(nonexistentHost), 0), + batchFuture.get() + ); + + // As usual, sendFuture resolves when the high watermark is requested. + Assert.assertFalse(sendFuture.isDone()); + final ListenableFuture> batchFuture2 = + outbox.getMessages(nonexistentHost, outbox.getOutboxEpoch(nonexistentHost), 1); + + Assert.assertTrue(sendFuture.isDone()); + + outbox.resetOutbox(nonexistentHost); + Assert.assertTrue(batchFuture2.isDone()); + } + + @Test + public void test_stop_cancelsSendMessage() + { + final ListenableFuture sendFuture = outbox.sendMessage(HOST, "1"); + outbox.stop(); + Assert.assertTrue(sendFuture.isCancelled()); + } + + @Test + public void test_stop_cancelsGetMessages() + { + final ListenableFuture> futureBatch = outbox.getMessages(HOST, MessageRelay.INIT, 0); + outbox.stop(); + Assert.assertTrue(futureBatch.isCancelled()); + } + + @Test + public void test_reset_cancelsSendMessage() + { + final ListenableFuture sendFuture = outbox.sendMessage(HOST, "1"); + outbox.resetOutbox(HOST); + Assert.assertTrue(sendFuture.isCancelled()); + } + + @Test + public void test_reset_cancelsGetMessages() + { + final ListenableFuture> futureBatch = outbox.getMessages(HOST, MessageRelay.INIT, 0); + outbox.resetOutbox(HOST); + Assert.assertTrue(futureBatch.isCancelled()); + } + + @Test + public void test_reset_nonexistentHost_doesNothing() + { + 
outbox.resetOutbox("nonexistent"); + } + + @Test + public void test_stop_preventsSendMessage() + { + outbox.stop(); + final ListenableFuture sendFuture = outbox.sendMessage(HOST, "1"); + Assert.assertTrue(sendFuture.isCancelled()); + } + + @Test + public void test_stop_preventsGetMessages() + { + outbox.stop(); + final ListenableFuture> futureBatch = outbox.getMessages(HOST, MessageRelay.INIT, 0); + Assert.assertTrue(futureBatch.isCancelled()); + } +} diff --git a/server/src/test/java/org/apache/druid/rpc/FixedSetServiceLocatorTest.java b/server/src/test/java/org/apache/druid/rpc/FixedServiceLocatorTest.java similarity index 56% rename from server/src/test/java/org/apache/druid/rpc/FixedSetServiceLocatorTest.java rename to server/src/test/java/org/apache/druid/rpc/FixedServiceLocatorTest.java index b0b92f5e271b..c7775bfac833 100644 --- a/server/src/test/java/org/apache/druid/rpc/FixedSetServiceLocatorTest.java +++ b/server/src/test/java/org/apache/druid/rpc/FixedServiceLocatorTest.java @@ -27,7 +27,7 @@ import java.util.concurrent.ExecutionException; -public class FixedSetServiceLocatorTest +public class FixedServiceLocatorTest { public static final DruidServerMetadata DATA_SERVER_1 = new DruidServerMetadata( "TestDataServer", @@ -50,19 +50,24 @@ public class FixedSetServiceLocatorTest ); @Test - public void testLocateNullShouldBeClosed() throws ExecutionException, InterruptedException + public void test_constructor_rejectsNull() { - FixedSetServiceLocator serviceLocator - = FixedSetServiceLocator.forDruidServerMetadata(null); + Assert.assertThrows( + NullPointerException.class, + () -> new FixedServiceLocator((ServiceLocation) null) + ); - Assert.assertTrue(serviceLocator.locate().get().isClosed()); + Assert.assertThrows( + NullPointerException.class, + () -> new FixedServiceLocator((ServiceLocations) null) + ); } @Test - public void testLocateSingleServer() throws ExecutionException, InterruptedException + public void test_locate_singleServer() throws 
ExecutionException, InterruptedException { - FixedSetServiceLocator serviceLocator - = FixedSetServiceLocator.forDruidServerMetadata(ImmutableSet.of(DATA_SERVER_1)); + FixedServiceLocator serviceLocator = + new FixedServiceLocator(ServiceLocation.fromDruidServerMetadata(DATA_SERVER_1)); Assert.assertEquals( ServiceLocations.forLocation(ServiceLocation.fromDruidServerMetadata(DATA_SERVER_1)), @@ -71,16 +76,30 @@ public void testLocateSingleServer() throws ExecutionException, InterruptedExcep } @Test - public void testLocateMultipleServers() throws ExecutionException, InterruptedException + public void test_locate_afterClose() throws ExecutionException, InterruptedException { - FixedSetServiceLocator serviceLocator - = FixedSetServiceLocator.forDruidServerMetadata(ImmutableSet.of(DATA_SERVER_1, DATA_SERVER_2)); + FixedServiceLocator serviceLocator = + new FixedServiceLocator(ServiceLocation.fromDruidServerMetadata(DATA_SERVER_1)); + + serviceLocator.close(); - Assert.assertTrue( + Assert.assertEquals( + ServiceLocations.closed(), + serviceLocator.locate().get() + ); + } + + @Test + public void test_locate_multipleServers() throws ExecutionException, InterruptedException + { + final ServiceLocations locations = ServiceLocations.forLocations( ImmutableSet.of( - ServiceLocations.forLocation(ServiceLocation.fromDruidServerMetadata(DATA_SERVER_1)), - ServiceLocations.forLocation(ServiceLocation.fromDruidServerMetadata(DATA_SERVER_2)) - ).contains(serviceLocator.locate().get()) + ServiceLocation.fromDruidServerMetadata(DATA_SERVER_1), + ServiceLocation.fromDruidServerMetadata(DATA_SERVER_2) + ) ); + + FixedServiceLocator serviceLocator = new FixedServiceLocator(locations); + Assert.assertEquals(locations, serviceLocator.locate().get()); } } diff --git a/server/src/test/java/org/apache/druid/rpc/ServiceClientImplTest.java b/server/src/test/java/org/apache/druid/rpc/ServiceClientImplTest.java index 69cb12e423ca..7346edd5cf6b 100644 --- 
a/server/src/test/java/org/apache/druid/rpc/ServiceClientImplTest.java +++ b/server/src/test/java/org/apache/druid/rpc/ServiceClientImplTest.java @@ -685,14 +685,6 @@ public void test_serviceLocationNoPathFromUri() ); } - @Test - public void test_normalizeHost() - { - Assert.assertEquals("1:2:3:4:5:6:7:8", ServiceClientImpl.sanitizeHost("[1:2:3:4:5:6:7:8]")); - Assert.assertEquals("1:2:3:4:5:6:7:8", ServiceClientImpl.sanitizeHost("1:2:3:4:5:6:7:8")); - Assert.assertEquals("1.2.3.4", ServiceClientImpl.sanitizeHost("1.2.3.4")); - } - @Test public void test_isRedirect() { diff --git a/server/src/test/java/org/apache/druid/rpc/ServiceLocationTest.java b/server/src/test/java/org/apache/druid/rpc/ServiceLocationTest.java index 6aec0e2b6060..8d95e79dd966 100644 --- a/server/src/test/java/org/apache/druid/rpc/ServiceLocationTest.java +++ b/server/src/test/java/org/apache/druid/rpc/ServiceLocationTest.java @@ -25,8 +25,48 @@ import org.junit.Assert; import org.junit.Test; +import java.net.URI; + public class ServiceLocationTest { + @Test + public void test_stripBrackets() + { + Assert.assertEquals("1:2:3:4:5:6:7:8", ServiceLocation.stripBrackets("[1:2:3:4:5:6:7:8]")); + Assert.assertEquals("1:2:3:4:5:6:7:8", ServiceLocation.stripBrackets("1:2:3:4:5:6:7:8")); + Assert.assertEquals("1.2.3.4", ServiceLocation.stripBrackets("1.2.3.4")); + } + + @Test + public void test_fromUri_http() + { + final ServiceLocation location = ServiceLocation.fromUri(URI.create("http://example.com:8100/xyz")); + Assert.assertEquals("example.com", location.getHost()); + Assert.assertEquals(-1, location.getTlsPort()); + Assert.assertEquals(8100, location.getPlaintextPort()); + Assert.assertEquals("/xyz", location.getBasePath()); + } + + @Test + public void test_fromUri_https_defaultPort() + { + final ServiceLocation location = ServiceLocation.fromUri(URI.create("https://example.com/xyz")); + Assert.assertEquals("example.com", location.getHost()); + Assert.assertEquals(443, location.getTlsPort()); + 
Assert.assertEquals(-1, location.getPlaintextPort()); + Assert.assertEquals("/xyz", location.getBasePath()); + } + + @Test + public void test_fromUri_https() + { + final ServiceLocation location = ServiceLocation.fromUri(URI.create("https://example.com:8100/xyz")); + Assert.assertEquals("example.com", location.getHost()); + Assert.assertEquals(8100, location.getTlsPort()); + Assert.assertEquals(-1, location.getPlaintextPort()); + Assert.assertEquals("/xyz", location.getBasePath()); + } + @Test public void test_fromDruidServerMetadata_withPort() { diff --git a/services/src/main/java/org/apache/druid/cli/CliHistorical.java b/services/src/main/java/org/apache/druid/cli/CliHistorical.java index 2e231bcdcc3b..ea8bbd994348 100644 --- a/services/src/main/java/org/apache/druid/cli/CliHistorical.java +++ b/services/src/main/java/org/apache/druid/cli/CliHistorical.java @@ -42,6 +42,7 @@ import org.apache.druid.guice.ManageLifecycle; import org.apache.druid.guice.QueryRunnerFactoryModule; import org.apache.druid.guice.QueryableModule; +import org.apache.druid.guice.SegmentWranglerModule; import org.apache.druid.guice.ServerTypeConfig; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.query.QuerySegmentWalker; @@ -99,6 +100,7 @@ protected List getModules() new DruidProcessingModule(), new QueryableModule(), new QueryRunnerFactoryModule(), + new SegmentWranglerModule(), new JoinableFactoryModule(), new HistoricalServiceModule(), binder -> { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/IngestHandler.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/IngestHandler.java index e0b5ffdb08e3..3569864cae78 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/IngestHandler.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/IngestHandler.java @@ -186,7 +186,8 @@ protected RelDataType returnedRowType() final RelDataTypeFactory typeFactory = rootQueryRel.rel.getCluster().getTypeFactory(); return 
handlerContext.engine().resultTypeForInsert( typeFactory, - rootQueryRel.validatedRowType + rootQueryRel.validatedRowType, + handlerContext.queryContextMap() ); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/QueryHandler.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/QueryHandler.java index a915833cd3ff..b19e83e040e0 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/QueryHandler.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/QueryHandler.java @@ -726,7 +726,8 @@ protected RelDataType returnedRowType() final RelDataTypeFactory typeFactory = rootQueryRel.rel.getCluster().getTypeFactory(); return handlerContext.engine().resultTypeForSelect( typeFactory, - rootQueryRel.validatedRowType + rootQueryRel.validatedRowType, + handlerContext.plannerContext().queryContextMap() ); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeSqlEngine.java b/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeSqlEngine.java index d02d302437b8..4f3d86b1b420 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeSqlEngine.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeSqlEngine.java @@ -83,13 +83,21 @@ public void validateContext(Map queryContext) } @Override - public RelDataType resultTypeForSelect(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForSelect( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { return validatedRowType; } @Override - public RelDataType resultTypeForInsert(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForInsert( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { throw new UnsupportedOperationException(); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlEngine.java b/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlEngine.java index 
fec7660e44ef..1d33b019e684 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlEngine.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlEngine.java @@ -57,8 +57,13 @@ public interface SqlEngine * * @param typeFactory type factory * @param validatedRowType row type from Calcite's validator + * @param queryContext query context, in case that affects the result type */ - RelDataType resultTypeForSelect(RelDataTypeFactory typeFactory, RelDataType validatedRowType); + RelDataType resultTypeForSelect( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ); /** * SQL row type that would be emitted by the {@link QueryMaker} from {@link #buildQueryMakerForInsert}. @@ -66,8 +71,13 @@ public interface SqlEngine * * @param typeFactory type factory * @param validatedRowType row type from Calcite's validator + * @param queryContext query context, in case that affects the result type */ - RelDataType resultTypeForInsert(RelDataTypeFactory typeFactory, RelDataType validatedRowType); + RelDataType resultTypeForInsert( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ); /** * Create a {@link QueryMaker} for a SELECT query. 
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/view/ViewSqlEngine.java b/sql/src/main/java/org/apache/druid/sql/calcite/view/ViewSqlEngine.java index 716fa50b85f1..7563b45d52bc 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/view/ViewSqlEngine.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/view/ViewSqlEngine.java @@ -93,13 +93,21 @@ public void validateContext(Map queryContext) } @Override - public RelDataType resultTypeForSelect(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForSelect( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { return validatedRowType; } @Override - public RelDataType resultTypeForInsert(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForInsert( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { // Can't have views of INSERT or REPLACE statements. throw new UnsupportedOperationException(); diff --git a/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java b/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java index 4adea5d8d84e..d957e7155b5e 100644 --- a/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java +++ b/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java @@ -82,7 +82,7 @@ public class SqlResource private final DruidNode selfNode; @Inject - SqlResource( + protected SqlResource( final ObjectMapper jsonMapper, final AuthorizerMapper authorizerMapper, final @NativeQuery SqlStatementFactory sqlStatementFactory, @@ -140,19 +140,7 @@ public Response cancelQuery( return Response.status(Status.NOT_FOUND).build(); } - // Considers only datasource and table resources; not context - // key resources when checking permissions. This means that a user's - // permission to cancel a query depends on the datasource, not the - // context variables used in the query. 
- Set resources = lifecycles - .stream() - .flatMap(lifecycle -> lifecycle.resources().stream()) - .collect(Collectors.toSet()); - Access access = AuthorizationUtils.authorizeAllResourceActions( - req, - resources, - authorizerMapper - ); + final Access access = authorizeCancellation(req, lifecycles); if (access.isAllowed()) { // should remove only the lifecycles in the snapshot. @@ -341,4 +329,23 @@ public void writeException(Exception ex, OutputStream out) throws IOException out.write(jsonMapper.writeValueAsBytes(ex)); } } + + /** + * Authorize a query cancellation operation. + * + * Considers only datasource and table resources; not context key resources when checking permissions. This means + * that a user's permission to cancel a query depends on the datasource, not the context variables used in the query. + */ + public Access authorizeCancellation(final HttpServletRequest req, final List cancelables) + { + Set resources = cancelables + .stream() + .flatMap(lifecycle -> lifecycle.resources().stream()) + .collect(Collectors.toSet()); + return AuthorizationUtils.authorizeAllResourceActions( + req, + resources, + authorizerMapper + ); + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteScanSignatureTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteScanSignatureTest.java index abab053dd6bb..80a9dde9b4c9 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteScanSignatureTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteScanSignatureTest.java @@ -155,13 +155,21 @@ public void validateContext(Map queryContext) } @Override - public RelDataType resultTypeForSelect(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForSelect( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { return validatedRowType; } @Override - public RelDataType resultTypeForInsert(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + 
public RelDataType resultTypeForInsert( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { throw new UnsupportedOperationException(); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/IngestionTestSqlEngine.java b/sql/src/test/java/org/apache/druid/sql/calcite/IngestionTestSqlEngine.java index 466bd0e390bd..569598af1e4e 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/IngestionTestSqlEngine.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/IngestionTestSqlEngine.java @@ -56,13 +56,21 @@ public void validateContext(Map queryContext) } @Override - public RelDataType resultTypeForSelect(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForSelect( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { throw new UnsupportedOperationException(); } @Override - public RelDataType resultTypeForInsert(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForInsert( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { // Matches the return structure of TestInsertQueryMaker. 
return typeFactory.createStructType( diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/util/TestTimelineServerView.java b/sql/src/test/java/org/apache/druid/sql/calcite/util/TestTimelineServerView.java index bd80aee8cdad..58990e806617 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/util/TestTimelineServerView.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/util/TestTimelineServerView.java @@ -34,7 +34,6 @@ import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.TimelineLookup; -import javax.annotation.Nullable; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -97,7 +96,6 @@ public Optional> getTimeline(Da throw new UnsupportedOperationException(); } - @Nullable @Override public List getDruidServers() {