Skip to content

Commit

Permalink
Add java client connection layer. (#670)
Browse files Browse the repository at this point in the history
* Add java client connection layer.

Signed-off-by: Yury-Fridlyand <[email protected]>
  • Loading branch information
Yury-Fridlyand authored Dec 19, 2023
1 parent 825c97c commit 524c084
Show file tree
Hide file tree
Showing 9 changed files with 383 additions and 16 deletions.
16 changes: 13 additions & 3 deletions java/client/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,25 @@ dependencies {
implementation group: 'io.netty', name: 'netty-transport-native-epoll', version: '4.1.100.Final', classifier: 'linux-x86_64'
implementation group: 'io.netty', name: 'netty-transport-native-kqueue', version: '4.1.100.Final', classifier: 'osx-x86_64'
implementation group: 'io.netty', name: 'netty-transport-native-kqueue', version: '4.1.100.Final', classifier: 'osx-aarch_64'

//lombok
compileOnly 'org.projectlombok:lombok:1.18.30'
annotationProcessor 'org.projectlombok:lombok:1.18.30'
testCompileOnly 'org.projectlombok:lombok:1.18.30'
testAnnotationProcessor 'org.projectlombok:lombok:1.18.30'

// junit
testImplementation('org.junit.jupiter:junit-jupiter:5.6.2')
testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '3.12.4'
}

tasks.register('protobuf', Exec) {
doFirst {
project.mkdir(Paths.get(project.projectDir.path, 'src/main/java/babushka/protobuf').toString())
project.mkdir(Paths.get(project.projectDir.path, 'src/main/java/babushka/models/protobuf').toString())
}
commandLine 'protoc',
'-Iprotobuf=babushka-core/src/protobuf/',
'--java_out=java/client/src/main/java/babushka/protobuf',
'--java_out=java/client/src/main/java/babushka/models/protobuf',
'babushka-core/src/protobuf/connection_request.proto',
'babushka-core/src/protobuf/redis_request.proto',
'babushka-core/src/protobuf/response.proto'
Expand All @@ -35,7 +45,7 @@ tasks.register('protobuf', Exec) {

tasks.register('cleanProtobuf') {
doFirst {
project.delete(Paths.get(project.projectDir.path, 'src/main/java/babushka/protobuf').toString())
project.delete(Paths.get(project.projectDir.path, 'src/main/java/babushka/models/protobuf').toString())
}
}

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package babushka.connectors.handlers;

import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.lang3.tuple.Pair;
import response.ResponseOuterClass.Response;

/** Holder for resources required to dispatch responses and used by {@link ReadHandler}. */
public class CallbackDispatcher {
/** Unique request ID (callback ID). Thread-safe. */
private final AtomicInteger requestId = new AtomicInteger(0);

/**
* Storage of Futures to handle responses. Map key is callback id, which starts from 1.<br>
* Each future is a promise for every submitted by user request.
*/
private final Map<Integer, CompletableFuture<Response>> responses = new ConcurrentHashMap<>();

/**
* Storage for connection request similar to {@link #responses}. Unfortunately, connection
* requests can't be stored in the same storage, because callback ID = 0 is hardcoded for
* connection requests.
*/
private final CompletableFuture<Response> connectionPromise = new CompletableFuture<>();

/**
* Register a new request to be sent. Once response received, the given future completes with it.
*
* @return A pair of unique callback ID which should set into request and a client promise for
* response.
*/
public Pair<Integer, CompletableFuture<Response>> registerRequest() {
int callbackId = requestId.incrementAndGet();
var future = new CompletableFuture<Response>();
responses.put(callbackId, future);
return Pair.of(callbackId, future);
}

public CompletableFuture<Response> registerConnection() {
return connectionPromise;
}

/**
* Complete the corresponding client promise and free resources.
*
* @param response A response received
*/
public void completeRequest(Response response) {
int callbackId = response.getCallbackIdx();
if (callbackId == 0) {
connectionPromise.completeAsync(() -> response);
} else {
responses.get(callbackId).completeAsync(() -> response);
responses.remove(callbackId);
}
}

public void shutdownGracefully() {
connectionPromise.cancel(false);
responses.values().forEach(future -> future.cancel(false));
responses.clear();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package babushka.connectors.handlers;

import babushka.connectors.resources.Platform;
import connection_request.ConnectionRequestOuterClass.ConnectionRequest;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.unix.DomainSocketAddress;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import redis_request.RedisRequestOuterClass.RedisRequest;
import response.ResponseOuterClass.Response;

/**
* Class responsible for handling calls to/from a netty.io {@link Channel}.<br>
* Uses a {@link CallbackDispatcher} to record callbacks of every request sent.
*/
public class ChannelHandler {

private static final String THREAD_POOL_NAME = "babushka-channel";

private final Channel channel;
private final CallbackDispatcher callbackDispatcher;

/** Open a new channel for a new client. */
public ChannelHandler(CallbackDispatcher callbackDispatcher, String socketPath) {
channel =
new Bootstrap()
// TODO let user specify the thread pool or pool size as an option
.group(Platform.createNettyThreadPool(THREAD_POOL_NAME, Optional.empty()))
.channel(Platform.getClientUdsNettyChannelType())
.handler(new ProtobufSocketChannelInitializer(callbackDispatcher))
.connect(new DomainSocketAddress(socketPath))
// TODO call here .sync() if needed or remove this comment
.channel();
this.callbackDispatcher = callbackDispatcher;
}

/**
* Complete a protobuf message and write it to the channel (to UDS).
*
* @param request Incomplete request, function completes it by setting callback ID
* @param flush True to flush immediately
* @return A response promise
*/
public CompletableFuture<Response> write(RedisRequest.Builder request, boolean flush) {
var commandId = callbackDispatcher.registerRequest();
request.setCallbackIdx(commandId.getKey());

if (flush) {
channel.writeAndFlush(request.build());
} else {
channel.write(request.build());
}
return commandId.getValue();
}

/**
* Write a protobuf message to the channel (to UDS).
*
* @param request A connection request
* @return A connection promise
*/
public CompletableFuture<Response> connect(ConnectionRequest request) {
channel.writeAndFlush(request);
return callbackDispatcher.registerConnection();
}

private final AtomicBoolean closed = new AtomicBoolean(false);

/** Closes the UDS connection and frees corresponding resources. */
public void close() {
if (closed.compareAndSet(false, true)) {
channel.close();
callbackDispatcher.shutdownGracefully();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package babushka.connectors.handlers;

import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.unix.UnixChannel;
import io.netty.handler.codec.protobuf.ProtobufDecoder;
import io.netty.handler.codec.protobuf.ProtobufEncoder;
import io.netty.handler.codec.protobuf.ProtobufVarint32FrameDecoder;
import io.netty.handler.codec.protobuf.ProtobufVarint32LengthFieldPrepender;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import response.ResponseOuterClass.Response;

/** Builder for the channel used by {@link ChannelHandler}. */
@RequiredArgsConstructor
public class ProtobufSocketChannelInitializer extends ChannelInitializer<UnixChannel> {

private final CallbackDispatcher callbackDispatcher;

@Override
public void initChannel(@NonNull UnixChannel ch) {
ch.pipeline()
// https://netty.io/4.1/api/io/netty/handler/codec/protobuf/ProtobufEncoder.html
.addLast("frameDecoder", new ProtobufVarint32FrameDecoder())
.addLast("frameEncoder", new ProtobufVarint32LengthFieldPrepender())
.addLast("protobufDecoder", new ProtobufDecoder(Response.getDefaultInstance()))
.addLast("protobufEncoder", new ProtobufEncoder())
.addLast(new ReadHandler(callbackDispatcher))
.addLast(new ChannelOutboundHandlerAdapter());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package babushka.connectors.handlers;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import response.ResponseOuterClass.Response;

/** Handler for inbound traffic though UDS. Used by Netty. */
@RequiredArgsConstructor
public class ReadHandler extends ChannelInboundHandlerAdapter {

private final CallbackDispatcher callbackDispatcher;

/** Submit responses from babushka to an instance {@link CallbackDispatcher} to handle them. */
@Override
public void channelRead(@NonNull ChannelHandlerContext ctx, @NonNull Object msg) {
callbackDispatcher.completeRequest((Response) msg);
}

/** Handles uncaught exceptions from {@link #channelRead(ChannelHandlerContext, Object)}. */
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
System.out.printf("=== exceptionCaught %s %s %n", ctx, cause);
cause.printStackTrace(System.err);
super.exceptionCaught(ctx, cause);
}
}
139 changes: 139 additions & 0 deletions java/client/src/main/java/babushka/connectors/resources/Platform.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package babushka.connectors.resources;

import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.epoll.EpollDomainSocketChannel;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.kqueue.KQueue;
import io.netty.channel.kqueue.KQueueDomainSocketChannel;
import io.netty.channel.kqueue.KQueueEventLoopGroup;
import io.netty.channel.unix.DomainSocketChannel;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.UtilityClass;

/**
* An auxiliary class purposed to detect platform (OS + JVM) {@link Capabilities} and allocate
* corresponding resources.
*/
@UtilityClass
public class Platform {

@Getter
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@ToString
private static class Capabilities {
private final boolean isKQueueAvailable;
private final boolean isEPollAvailable;
// TODO support IO-Uring and NIO
private final boolean isIOUringAvailable;
// At the moment, Windows is not supported
// Probably we should use NIO (NioEventLoopGroup) for Windows.
private final boolean isNIOAvailable;
}

/** Detected platform (OS + JVM) capabilities. Not supposed to be changed in runtime. */
@Getter
private static final Capabilities capabilities =
new Capabilities(isKQueueAvailable(), isEPollAvailable(), false, false);

/**
* Thread pools supplied to <em>Netty</em> to perform all async IO.<br>
* Map key is supposed to be pool name + thread count as a string concat product.
*/
private static final Map<String, EventLoopGroup> groups = new ConcurrentHashMap<>();

/** Detect <em>kqueue</em> availability. */
private static boolean isKQueueAvailable() {
try {
Class.forName("io.netty.channel.kqueue.KQueue");
return KQueue.isAvailable();
} catch (ClassNotFoundException e) {
return false;
}
}

/** Detect <em>epoll</em> availability. */
private static boolean isEPollAvailable() {
try {
Class.forName("io.netty.channel.epoll.Epoll");
return Epoll.isAvailable();
} catch (ClassNotFoundException e) {
return false;
}
}

/**
* Allocate Netty thread pool required to manage connection. A thread pool could be shared across
* multiple connections.
*
* @return A new thread pool.
*/
public static EventLoopGroup createNettyThreadPool(String prefix, Optional<Integer> threadLimit) {
int threadCount = threadLimit.orElse(Runtime.getRuntime().availableProcessors());
if (capabilities.isKQueueAvailable()) {
var name = prefix + "-kqueue-elg";
return getOrCreate(
name + threadCount,
() -> new KQueueEventLoopGroup(threadCount, new DefaultThreadFactory(name, true)));
} else if (capabilities.isEPollAvailable()) {
var name = prefix + "-epoll-elg";
return getOrCreate(
name + threadCount,
() -> new EpollEventLoopGroup(threadCount, new DefaultThreadFactory(name, true)));
}
// TODO support IO-Uring and NIO

throw new RuntimeException("Current platform supports no known thread pool types");
}

/**
* Get a cached thread pool from {@link #groups} or create a new one by given lambda and cache.
*/
private static EventLoopGroup getOrCreate(String name, Supplier<EventLoopGroup> supplier) {
if (groups.containsKey(name)) {
return groups.get(name);
}
EventLoopGroup group = supplier.get();
groups.put(name, group);
return group;
}

/**
* Get a channel class required by Netty to open a client UDS channel.
*
* @return Return a class supported by the current platform.
*/
public static Class<? extends DomainSocketChannel> getClientUdsNettyChannelType() {
if (capabilities.isKQueueAvailable()) {
return KQueueDomainSocketChannel.class;
}
if (capabilities.isEPollAvailable()) {
return EpollDomainSocketChannel.class;
}
throw new RuntimeException("Current platform supports no known socket types");
}

/**
* A JVM shutdown hook to be registered. It is responsible for closing connection and freeing
* resources. It is recommended to use a class instead of lambda to ensure that it is called.<br>
* See {@link Runtime#addShutdownHook}.
*/
private static class ShutdownHook implements Runnable {
@Override
public void run() {
groups.values().forEach(EventLoopGroup::shutdownGracefully);
}
}

static {
Runtime.getRuntime().addShutdownHook(new Thread(new ShutdownHook(), "Babushka-shutdown-hook"));
}
}
Loading

0 comments on commit 524c084

Please sign in to comment.