diff --git a/tornado/websocket.py b/tornado/websocket.py index d2c6a427aa..879dc92e97 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -310,7 +310,8 @@ def max_message_size(self) -> int: ) def write_message( - self, message: Union[bytes, str, Dict[str, Any]], binary: bool = False + self, message: Union[bytes, str, Dict[str, Any]], binary: bool = False, + return_future: bool = True ) -> "Future[None]": """Sends the given message to the client of this Web Socket. @@ -337,7 +338,8 @@ def write_message( raise WebSocketClosedError() if isinstance(message, dict): message = tornado.escape.json_encode(message) - return self.ws_connection.write_message(message, binary=binary) + return self.ws_connection.write_message(message, binary=binary, + return_future=return_future) def select_subprotocol(self, subprotocols: List[str]) -> Optional[str]: """Override to implement subprotocol negotiation. @@ -680,7 +682,8 @@ async def accept_connection(self, handler: WebSocketHandler) -> None: @abc.abstractmethod def write_message( - self, message: Union[str, bytes], binary: bool = False + self, message: Union[str, bytes], binary: bool = False, + return_future: bool = True ) -> "Future[None]": raise NotImplementedError() @@ -1072,7 +1075,8 @@ def _write_frame( return self.stream.write(frame) def write_message( - self, message: Union[str, bytes], binary: bool = False + self, message: Union[str, bytes], binary: bool = False, + return_future: bool = True ) -> "Future[None]": """Sends the given message to the client of this Web Socket.""" if binary: @@ -1096,13 +1100,14 @@ def write_message( except StreamClosedError: raise WebSocketClosedError() - async def wrapper() -> None: - try: - await fut - except StreamClosedError: - raise WebSocketClosedError() + if return_future: + async def wrapper() -> None: + try: + await fut + except StreamClosedError: + raise WebSocketClosedError() - return asyncio.ensure_future(wrapper()) + return asyncio.ensure_future(wrapper()) def write_ping(self, data: bytes) -> None: """Send ping frame.""" @@ -1493,7 +1498,8 @@ async def headers_received( future_set_result_unless_cancelled(self.connect_future, self) def write_message( - self, message: Union[str, bytes], binary: bool = False + self, message: Union[str, bytes], binary: bool = False, + return_future: bool = True, ) -> "Future[None]": """Sends a message to the WebSocket server. @@ -1504,7 +1510,8 @@ def write_message( Exception raised on a closed stream changed from `.StreamClosedError` to `WebSocketClosedError`. """ - return self.protocol.write_message(message, binary=binary) + return self.protocol.write_message(message, binary=binary, + return_future=return_future) def read_message( self,