Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #11307 - Explicit demand control in WebSocket endpoints with only onWebSocketFrame #12342

Open
wants to merge 1 commit into
base: jetty-12.1.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,34 @@ public void fail(Throwable x)
};
}

/**
* Creates a nested callback that runs completed after
* completing the nested callback.
*
* @param callback The nested callback
* @param completed The completion to run after the nested callback is completed
* @return a new callback.
*/
static Callback from(Callback callback, Runnable completed)
{
return new Callback()
{
@Override
public void succeed()
{
callback.succeed();
completed.run();
}

@Override
public void fail(Throwable x)
{
callback.fail(x);
completed.run();
}
};
}

/**
* <p>Method to invoke to succeed the callback.</p>
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,111 @@ public String toString()
boolean isRsv2();

boolean isRsv3();

class Wrapper implements Frame
{
private final Frame _frame;

public Wrapper(Frame frame)
{
_frame = frame;
}

@Override
public byte[] getMask()
{
return _frame.getMask();
}

@Override
public byte getOpCode()
{
return _frame.getOpCode();
}

@Override
public ByteBuffer getPayload()
{
return _frame.getPayload();
}

@Override
public int getPayloadLength()
{
return _frame.getPayloadLength();
}

@Override
public Type getType()
{
return _frame.getType();
}

@Override
public boolean hasPayload()
{
return _frame.hasPayload();
}

@Override
public boolean isFin()
{
return _frame.isFin();
}

@Override
public boolean isMasked()
{
return _frame.isMasked();
}

@Override
public boolean isRsv1()
{
return _frame.isRsv1();
}

@Override
public boolean isRsv2()
{
return _frame.isRsv2();
}

@Override
public boolean isRsv3()
{
return _frame.isRsv3();
}
}

static Frame copy(Frame frame)
{
ByteBuffer payloadCopy = copy(frame.getPayload());
return new Frame.Wrapper(frame)
{
@Override
public ByteBuffer getPayload()
{
return payloadCopy;
}

@Override
public int getPayloadLength()
{
return payloadCopy == null ? 0 : payloadCopy.remaining();
}
};
}

private static ByteBuffer copy(ByteBuffer buffer)
{
if (buffer == null)
return null;
int p = buffer.position();
ByteBuffer clone = buffer.isDirect() ? ByteBuffer.allocateDirect(buffer.remaining()) : ByteBuffer.allocate(buffer.remaining());
clone.put(buffer);
clone.flip();
buffer.position(p);
return clone;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -206,18 +206,60 @@ public void onFrame(Frame frame, Callback coreCallback)
coreCallback.failed(new WebSocketException(endpointInstance.getClass().getSimpleName() + " FRAME method error: " + cause.getMessage(), cause));
return;
}

switch (frame.getOpCode())
{
case OpCode.TEXT ->
{
if (textHandle == null)
autoDemand();
}
case OpCode.BINARY ->
{
if (binaryHandle == null)
autoDemand();
}
case OpCode.CONTINUATION ->
{
if (activeMessageSink == null)
autoDemand();
}
case OpCode.PING ->
{
if (pingHandle == null)
autoDemand();
}
case OpCode.PONG ->
{
if (pongHandle == null)
autoDemand();
}
case OpCode.CLOSE ->
{
// Do nothing.
}
default ->
{
coreCallback.failed(new IllegalStateException());
return;
}
};
Comment on lines +210 to +246
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like putting this logic here.... I think similar behaviour can be implemented in the individual methods which can do something like:

        if (xyzHandle != null)
        {
            xyzHandler.invoke(...);
            callback.succeeded();
            autoDemand();
        }
        else 
        {
            callback.succeeded();
            if (frameHandle == null)
                internalDemand();
            else
                autoDemand();
        }

}

Callback.Completable eventCallback = new Callback.Completable();
switch (frame.getOpCode())
{
case OpCode.CLOSE -> onCloseFrame(frame, eventCallback);
case OpCode.PING -> onPingFrame(frame, eventCallback);
case OpCode.PONG -> onPongFrame(frame, eventCallback);
case OpCode.TEXT -> onTextFrame(frame, eventCallback);
case OpCode.BINARY -> onBinaryFrame(frame, eventCallback);
case OpCode.CONTINUATION -> onContinuationFrame(frame, eventCallback);
default -> coreCallback.failed(new IllegalStateException());
case OpCode.PING -> onPingFrame(frame, eventCallback);
case OpCode.PONG -> onPongFrame(frame, eventCallback);
case OpCode.CLOSE -> onCloseFrame(frame, eventCallback);
default ->
{
coreCallback.failed(new IllegalStateException());
return;
}
};

// Combine the callback from the frame handler and the event handler.
Expand Down Expand Up @@ -315,6 +357,13 @@ private void onPingFrame(Frame frame, Callback callback)
}
else
{
// If we have a frameHandler it takes responsibility for handling the ping and demanding.
if (frameHandle != null)
{
callback.succeeded();
return;
}

// Automatically respond.
getSession().sendPong(frame.getPayload(), new org.eclipse.jetty.websocket.api.Callback()
{
Expand Down Expand Up @@ -358,7 +407,10 @@ private void onPongFrame(Frame frame, Callback callback)
}
else
{
internalDemand();
// If we have a frameHandler it takes responsibility for handling the pong and demanding.
callback.succeeded();
if (frameHandle == null)
internalDemand();
}
}

Expand Down Expand Up @@ -387,7 +439,8 @@ private void acceptFrame(Frame frame, Callback callback)
if (activeMessageSink == null)
{
callback.succeeded();
internalDemand();
if (frameHandle == null)
internalDemand();
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,29 @@ public void onMessage(String message) throws IOException
public static class ListenerSocket implements Session.Listener
{
final List<Frame> frames = new CopyOnWriteArrayList<>();
Session session;

@Override
public void onWebSocketOpen(Session session)
{
this.session = session;
session.demand();
}

@Override
public void onWebSocketFrame(Frame frame, Callback callback)
{
frames.add(frame);
frames.add(Frame.copy(frame));

// Because no pingListener is registered, the frameListener is responsible for handling pings.
if (frame.getOpCode() == OpCode.PING)
{
session.sendPong(frame.getPayload(), Callback.from(callback, session::demand));
return;
}

callback.succeed();
session.demand();
}
}

Expand Down Expand Up @@ -109,27 +126,19 @@ public void onWebSocketFrame(Frame frame, Callback callback)
if (frame.getOpCode() == OpCode.TEXT)
textMessages.add(BufferUtil.toString(frame.getPayload()));
callback.succeed();
session.demand();
}
}

@WebSocket(autoDemand = false)
public static class PingSocket extends ListenerSocket
{
Session session;

@Override
public void onWebSocketOpen(Session session)
{
this.session = session;
session.demand();
}

@Override
public void onWebSocketFrame(Frame frame, Callback callback)
{
super.onWebSocketFrame(frame, callback);
if (frame.getType() == Frame.Type.TEXT)
session.sendPing(ByteBuffer.wrap("server-ping".getBytes(StandardCharsets.UTF_8)), Callback.NOOP);
super.onWebSocketFrame(frame, callback);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,18 @@ public void onWebSocketFrame(Frame frame, Callback callback)
{
switch (frame.getOpCode())
{
case OpCode.PING -> pingMessages.add(BufferUtil.copy(frame.getPayload()));
case OpCode.PONG -> pongMessages.add(BufferUtil.copy(frame.getPayload()));
case OpCode.PING ->
{
pingMessages.add(BufferUtil.copy(frame.getPayload()));
session.sendPong(frame.getPayload(), callback);
}
case OpCode.PONG ->
{
pongMessages.add(BufferUtil.copy(frame.getPayload()));
callback.succeed();
}
default -> callback.succeed();
}
callback.succeed();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,12 @@ public static class FrameEndpoint implements Session.Listener
{
public CountDownLatch closeLatch = new CountDownLatch(1);
public LinkedBlockingQueue<String> frameEvents = new LinkedBlockingQueue<>();
public Session session;

@Override
public void onWebSocketOpen(Session session)
{
this.session = session;
session.demand();
}

Expand All @@ -147,6 +149,7 @@ public void onWebSocketFrame(Frame frame, Callback callback)
BufferUtil.toUTF8String(frame.getPayload()),
frame.getPayloadLength()));
callback.succeed();
session.demand();
}

@Override
Expand Down
Loading