Skip to content

Commit

Permalink
pw_rpc: Increment call_ids for java client
Browse files Browse the repository at this point in the history
Change-Id: Ia4fd675adc4da9a62cac98e8a2d63d195cf3e750
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/258792
Commit-Queue: Kieran Cyphus <[email protected]>
Lint: Lint 🤖 <[email protected]>
Reviewed-by: Wyatt Hepler <[email protected]>
Presubmit-Verified: CQ Bot Account <[email protected]>
  • Loading branch information
kierancyphus authored and CQ Bot Account committed Jan 9, 2025
1 parent 8a5fc59 commit 15d4ae5
Show file tree
Hide file tree
Showing 12 changed files with 102 additions and 27 deletions.
4 changes: 1 addition & 3 deletions pw_rpc/design.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ In ``pw_rpc``, an RPC begins when the client sends an initial packet. The server
receives the packet, looks up the relevant service method, then calls into the
RPC function. The RPC is considered active until the server sends a status to
finish the RPC. The client may terminate an ongoing RPC by cancelling it.
Multiple concurrent RPC requests to the same method may be made simultaneously
(Note: Concurrent requests are not yet possible using the Java client. See
`Issue 237418397 <https://issues.pigweed.dev/issues/237418397>`_).
Multiple concurrent RPC requests to the same method may be made simultaneously.

Depending the type of RPC, the client and server exchange zero or more protobuf
request or response payloads. There are four RPC types:
Expand Down
3 changes: 3 additions & 0 deletions pw_rpc/java/main/dev/pigweed/pw_rpc/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ java_library(
"//third_party/google_auto:value",
"@com_google_protobuf//java/core",
artifact("com.google.code.findbugs:jsr305"),
artifact("com.google.errorprone:error_prone_annotations"),
artifact("com.google.guava:guava"),
],
)
Expand All @@ -69,6 +70,7 @@ java_library(
"//third_party/google_auto:value",
"@com_google_protobuf//java/lite",
artifact("com.google.code.findbugs:jsr305"),
artifact("com.google.errorprone:error_prone_annotations"),
artifact("com.google.guava:guava"),
],
)
Expand All @@ -84,6 +86,7 @@ android_library(
"//third_party/google_auto:value",
"@com_google_protobuf//java/lite",
artifact("com.google.code.findbugs:jsr305"),
artifact("com.google.errorprone:error_prone_annotations"),
artifact("com.google.guava:guava"),
],
)
35 changes: 28 additions & 7 deletions pw_rpc/java/main/dev/pigweed/pw_rpc/Endpoint.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package dev.pigweed.pw_rpc;

import com.google.errorprone.annotations.concurrent.GuardedBy;
import com.google.protobuf.ByteString;
import com.google.protobuf.MessageLite;
import dev.pigweed.pw_log.Logger;
Expand Down Expand Up @@ -42,12 +43,24 @@
class Endpoint {
private static final Logger logger = Logger.forClass(Endpoint.class);

// Call IDs are varint encoded. Limit the varint size to 2 bytes (14 usable bits).
private static final int MAX_CALL_ID = 1 << 14;

private final Map<Integer, Channel> channels;
private final Map<PendingRpc, AbstractCall<?, ?>> pending = new HashMap<>();
private final BlockingQueue<Runnable> callUpdates = new LinkedBlockingQueue<>();
private final int maxCallId;

@GuardedBy("this") private int nextCallId = 1;

public Endpoint(List<Channel> channels) {
this(channels, MAX_CALL_ID);
}

/** Create endpoint with {@code maxCallId} possible call_ids for testing purposes */
Endpoint(List<Channel> channels, int maxCallId) {
this.channels = channels.stream().collect(Collectors.toMap(Channel::id, c -> c));
this.maxCallId = maxCallId;
}

/**
Expand Down Expand Up @@ -99,15 +112,10 @@ public Endpoint(List<Channel> channels) {
throw InvalidRpcChannelException.unknown(channelId);
}

return createCall.apply(this, PendingRpc.create(channel, method));
return createCall.apply(this, PendingRpc.create(channel, method, getNewCallId()));
}

private void registerCall(AbstractCall<?, ?> call) {
// TODO(hepler): Use call_id to support simultaneous calls for the same RPC on one channel.
//
// Originally, only one call per service/method/channel was supported. With this restriction,
// the original call should have been aborted here, but was not. The client will be updated to
// support multiple simultaneous calls instead of aborting the call.
pending.put(call.rpc(), call);
}

Expand Down Expand Up @@ -254,7 +262,7 @@ public boolean processClientPacket(@Nullable Method method, RpcPacket packet) {
return true; // true since the packet was handled, even though it was invalid.
}

PendingRpc rpc = PendingRpc.create(channel, method);
PendingRpc rpc = PendingRpc.create(channel, method, packet.getCallId());
if (!updateCall(packet, rpc)) {
logger.atFine().log("Ignoring packet for %s, which isn't pending", rpc);
sendError(channel, packet, Status.FAILED_PRECONDITION);
Expand Down Expand Up @@ -308,4 +316,17 @@ private static Status decodeStatus(RpcPacket packet) {
}
return status;
}

/** Gets the next available call id and increments internal count for next call. */
private synchronized int getNewCallId() {
int callId = nextCallId;
nextCallId = (nextCallId + 1) % maxCallId;

// Skip call_id `0` to avoid confusion with legacy servers which use call_id `0` as
// an open call id or which do not provide call_id at all.
if (nextCallId == 0) {
nextCallId = 1;
}
return callId;
}
}
5 changes: 5 additions & 0 deletions pw_rpc/java/main/dev/pigweed/pw_rpc/Packets.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ private Packets() {}
public static byte[] request(PendingRpc rpc, MessageLite payload) {
RpcPacket.Builder builder = RpcPacket.newBuilder()
.setType(PacketType.REQUEST)
.setCallId(rpc.callId())
.setChannelId(rpc.channel().id())
.setServiceId(rpc.service().id())
.setMethodId(rpc.method().id());
Expand All @@ -37,6 +38,7 @@ public static byte[] request(PendingRpc rpc, MessageLite payload) {
public static byte[] cancel(PendingRpc rpc) {
return RpcPacket.newBuilder()
.setType(PacketType.CLIENT_ERROR)
.setCallId(rpc.callId())
.setChannelId(rpc.channel().id())
.setServiceId(rpc.service().id())
.setMethodId(rpc.method().id())
Expand All @@ -48,6 +50,7 @@ public static byte[] cancel(PendingRpc rpc) {
public static byte[] error(RpcPacket packet, Status status) {
return RpcPacket.newBuilder()
.setType(PacketType.CLIENT_ERROR)
.setCallId(packet.getCallId())
.setChannelId(packet.getChannelId())
.setServiceId(packet.getServiceId())
.setMethodId(packet.getMethodId())
Expand All @@ -59,6 +62,7 @@ public static byte[] error(RpcPacket packet, Status status) {
public static byte[] clientStream(PendingRpc rpc, MessageLite payload) {
return RpcPacket.newBuilder()
.setType(PacketType.CLIENT_STREAM)
.setCallId(rpc.callId())
.setChannelId(rpc.channel().id())
.setServiceId(rpc.service().id())
.setMethodId(rpc.method().id())
Expand All @@ -70,6 +74,7 @@ public static byte[] clientStream(PendingRpc rpc, MessageLite payload) {
public static byte[] clientStreamEnd(PendingRpc rpc) {
return RpcPacket.newBuilder()
.setType(PacketType.CLIENT_REQUEST_COMPLETION)
.setCallId(rpc.callId())
.setChannelId(rpc.channel().id())
.setServiceId(rpc.service().id())
.setMethodId(rpc.method().id())
Expand Down
18 changes: 10 additions & 8 deletions pw_rpc/java/main/dev/pigweed/pw_rpc/PendingRpc.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@
import com.google.auto.value.AutoValue;
import java.util.Locale;

/**
* Represents an active RPC invocation: channel + service + method.
*
* TODO(hepler): Use call ID to support multiple simultaneous calls to the same RPC on one channel.
*/
/** Represents an active RPC invocation: channel + service + method + call id. */
@AutoValue
abstract class PendingRpc {
static PendingRpc create(Channel channel, Method method) {
return new AutoValue_PendingRpc(channel, method);
// The default call id should always be 1 since it is the first id that is chosen by the endpoint.
static final int DEFAULT_CALL_ID = 1;

static PendingRpc create(Channel channel, Method method, int callId) {
return new AutoValue_PendingRpc(channel, method, callId);
}

public abstract Channel channel();
Expand All @@ -36,8 +35,11 @@ public final Service service() {

public abstract Method method();

public abstract int callId();

@Override
public final String toString() {
return String.format(Locale.ENGLISH, "PendingRpc[%s|channel=%d]", method(), channel().id());
return String.format(
Locale.ENGLISH, "PendingRpc[%s|channel=%d|callId=%d]", method(), channel().id(), callId());
}
}
1 change: 1 addition & 0 deletions pw_rpc/java/test/dev/pigweed/pw_rpc/ClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ private static byte[] serverReply(
private static RpcPacket.Builder packetBuilder(String service, String method) {
return RpcPacket.newBuilder()
.setChannelId(CHANNEL_ID)
.setCallId(PendingRpc.DEFAULT_CALL_ID)
.setServiceId(Ids.calculate(service))
.setMethodId(Ids.calculate(method));
}
Expand Down
40 changes: 38 additions & 2 deletions pw_rpc/java/test/dev/pigweed/pw_rpc/EndpointTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
import com.google.protobuf.MessageLite;
import dev.pigweed.pw_rpc.internal.Packet.PacketType;
import dev.pigweed.pw_rpc.internal.Packet.RpcPacket;
import dev.pigweed.pw_rpc.internal.Packet.RpcPacket.Builder;
import java.util.ArrayList;
import java.util.List;
import org.junit.Rule;
import org.junit.Test;
import org.mockito.Mock;
Expand All @@ -52,13 +55,21 @@ public final class EndpointTest {
private static final AnotherMessage RESPONSE_PAYLOAD =
AnotherMessage.newBuilder().setPayload("hello").build();
private static final int CHANNEL_ID = 555;
private static final int DEFAULT_CALL_ID = 1;
private static final int MAX_CALL_ID = 3;

@Mock private Channel.Output mockOutput;
@Mock private StreamObserver<MessageLite> callEvents;

private final Channel channel = new Channel(CHANNEL_ID, bytes -> mockOutput.send(bytes));
private final Endpoint endpoint = new Endpoint(ImmutableList.of(channel));

private final List<byte[]> sentPackets = new ArrayList<>();
private final Channel channelWithRecord =
new Channel(CHANNEL_ID, bytes -> sentPackets.add(bytes));
private final Endpoint endpointWithRecord =
new Endpoint(ImmutableList.of(channelWithRecord), MAX_CALL_ID);

private static byte[] request(MessageLite payload) {
return packetBuilder()
.setType(PacketType.REQUEST)
Expand All @@ -78,6 +89,7 @@ private static byte[] cancel() {
private static RpcPacket.Builder packetBuilder() {
return RpcPacket.newBuilder()
.setChannelId(CHANNEL_ID)
.setCallId(PendingRpc.DEFAULT_CALL_ID)
.setServiceId(SERVICE.id())
.setMethodId(METHOD.id());
}
Expand All @@ -95,6 +107,30 @@ public void start_succeeds_rpcIsPending() throws Exception {
assertThat(endpoint.abandon(call)).isTrue();
}

@Test
public void start_succeeds_callIdIsIncreasing() throws Exception {
AbstractCall<MessageLite, MessageLite> call1 =
endpointWithRecord.invokeRpc(CHANNEL_ID, METHOD, this::createCall, REQUEST_PAYLOAD);
AbstractCall<MessageLite, MessageLite> call2 =
endpointWithRecord.invokeRpc(CHANNEL_ID, METHOD, this::createCall, REQUEST_PAYLOAD);
RpcPacket packet1 = RpcPacket.parseFrom(sentPackets.get(0));
RpcPacket packet2 = RpcPacket.parseFrom(sentPackets.get(1));

assertThat(packet1.getCallId()).isLessThan(packet2.getCallId());
}

@Test
public void start_succeeds_callIdWraps() throws Exception {
for (int i = 0; i < MAX_CALL_ID; i++) {
endpointWithRecord.invokeRpc(CHANNEL_ID, METHOD, this::createCall, REQUEST_PAYLOAD);
}

RpcPacket firstPacket = RpcPacket.parseFrom(sentPackets.get(0));
RpcPacket lastPacket = RpcPacket.parseFrom(sentPackets.get(MAX_CALL_ID - 1));

assertThat(firstPacket.getCallId()).isEqualTo(lastPacket.getCallId());
}

@Test
public void start_sendingFails_callsHandleError() throws Exception {
doThrow(new ChannelOutputException()).when(mockOutput).send(any());
Expand Down Expand Up @@ -155,7 +191,7 @@ public void open_sendsNoPacketsButRpcIsPending() {
@Test
public void ignoresActionsIfCallIsNotPending() throws Exception {
AbstractCall<MessageLite, MessageLite> call =
createCall(endpoint, PendingRpc.create(channel, METHOD));
createCall(endpoint, PendingRpc.create(channel, METHOD, DEFAULT_CALL_ID));

assertThat(endpoint.cancel(call)).isFalse();
assertThat(endpoint.abandon(call)).isFalse();
Expand All @@ -166,7 +202,7 @@ public void ignoresActionsIfCallIsNotPending() throws Exception {
@Test
public void ignoresPacketsIfCallIsNotPending() throws Exception {
AbstractCall<MessageLite, MessageLite> call =
createCall(endpoint, PendingRpc.create(channel, METHOD));
createCall(endpoint, PendingRpc.create(channel, METHOD, DEFAULT_CALL_ID));

assertThat(endpoint.cancel(call)).isFalse();
assertThat(endpoint.abandon(call)).isFalse();
Expand Down
3 changes: 2 additions & 1 deletion pw_rpc/java/test/dev/pigweed/pw_rpc/FutureCallTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,13 @@ public final class FutureCallTest {
"SomeBidirectional", SomeMessage.parser(), AnotherMessage.parser()));
private static final Method METHOD = SERVICE.method("SomeUnary");
private static final int CHANNEL_ID = 555;
private static final int DEFAULT_CALL_ID = 1;

@Mock private Channel.Output mockOutput;

private final Channel channel = new Channel(CHANNEL_ID, packet -> mockOutput.send(packet));
private final Endpoint endpoint = new Endpoint(ImmutableList.of(channel));
private final PendingRpc rpc = PendingRpc.create(channel, METHOD);
private final PendingRpc rpc = PendingRpc.create(channel, METHOD, DEFAULT_CALL_ID);

@Test
public void unaryFuture_response_setsValue() throws Exception {
Expand Down
6 changes: 4 additions & 2 deletions pw_rpc/java/test/dev/pigweed/pw_rpc/PacketsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ public final class PacketsTest {
private static final Service SERVICE = new Service(
"Greetings", Service.unaryMethod("Hello", RpcPacket.parser(), RpcPacket.parser()));

private static final PendingRpc RPC =
PendingRpc.create(new Channel(123, null), SERVICE.method("Hello"));
private static final PendingRpc RPC = PendingRpc.create(
new Channel(123, null), SERVICE.method("Hello"), PendingRpc.DEFAULT_CALL_ID);

private static final RpcPacket PACKET = RpcPacket.newBuilder()
.setChannelId(123)
.setCallId(PendingRpc.DEFAULT_CALL_ID)
.setServiceId(RPC.service().id())
.setMethodId(RPC.method().id())
.build();
Expand Down Expand Up @@ -62,6 +63,7 @@ public void error() throws Exception {
private static RpcPacket.Builder packet() {
return RpcPacket.newBuilder()
.setChannelId(123)
.setCallId(PendingRpc.DEFAULT_CALL_ID)
.setServiceId(Ids.calculate("Greetings"))
.setMethodId(Ids.calculate("Hello"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ private static byte[] cancel() {
private static RpcPacket.Builder packetBuilder() {
return RpcPacket.newBuilder()
.setChannelId(CHANNEL_ID)
.setCallId(PendingRpc.DEFAULT_CALL_ID)
.setServiceId(SERVICE.id())
.setMethodId(METHOD.id());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ public final class StreamObserverMethodClientTest {
Service.bidirectionalStreamingMethod(
"SomeBidirectionalStreaming", SomeMessage.parser(), AnotherMessage.parser()));

private static final int DEFAULT_CALL_ID = 1;

@Rule public final MockitoRule mockito = MockitoJUnit.rule();

@Mock private StreamObserver<MessageLite> defaultObserver;
Expand All @@ -51,13 +53,14 @@ public final class StreamObserverMethodClientTest {
// Wrap Channel.Output since channelOutput will be null when the channel is initialized.
private final Channel channel = new Channel(1, bytes -> channelOutput.send(bytes));

private final PendingRpc unary_rpc = PendingRpc.create(channel, SERVICE.method("SomeUnary"));
private final PendingRpc unary_rpc =
PendingRpc.create(channel, SERVICE.method("SomeUnary"), DEFAULT_CALL_ID);
private final PendingRpc server_streaming_rpc =
PendingRpc.create(channel, SERVICE.method("SomeServerStreaming"));
PendingRpc.create(channel, SERVICE.method("SomeServerStreaming"), DEFAULT_CALL_ID);
private final PendingRpc client_streaming_rpc =
PendingRpc.create(channel, SERVICE.method("SomeClientStreaming"));
PendingRpc.create(channel, SERVICE.method("SomeClientStreaming"), DEFAULT_CALL_ID);
private final PendingRpc bidirectional_streaming_rpc =
PendingRpc.create(channel, SERVICE.method("SomeBidirectionalStreaming"));
PendingRpc.create(channel, SERVICE.method("SomeBidirectionalStreaming"), DEFAULT_CALL_ID);

private final Client client = Client.create(ImmutableList.of(channel), ImmutableList.of(SERVICE));
private MethodClient unaryMethodClient;
Expand Down Expand Up @@ -242,6 +245,7 @@ public void invalidMethod_throwsException() {
private static byte[] responsePacket(PendingRpc rpc, MessageLite payload) {
return RpcPacket.newBuilder()
.setChannelId(1)
.setCallId(PendingRpc.DEFAULT_CALL_ID)
.setServiceId(rpc.service().id())
.setMethodId(rpc.method().id())
.setType(PacketType.RESPONSE)
Expand Down
1 change: 1 addition & 0 deletions pw_rpc/java/test/dev/pigweed/pw_rpc/TestClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ private void processPacket(RpcPacket.Builder packet) {
private static RpcPacket.Builder startPacket(String service, String method, PacketType type) {
return RpcPacket.newBuilder()
.setType(type)
.setCallId(PendingRpc.DEFAULT_CALL_ID)
.setChannelId(CHANNEL_ID)
.setServiceId(Ids.calculate(service))
.setMethodId(Ids.calculate(method));
Expand Down

0 comments on commit 15d4ae5

Please sign in to comment.