Skip to content

Commit

Permalink
🔀 Merge pull request Pincer-org#160 from mwath/main
Browse files Browse the repository at this point in the history
✨ Support Discord Gateway Event Compression
  • Loading branch information
Arthurdw committed Oct 25, 2021
2 parents 82d873f + 9894aac commit c79b305
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 9 deletions.
23 changes: 16 additions & 7 deletions pincer/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Full MIT License can be found in `LICENSE` at the project root.

from dataclasses import dataclass
from typing import Optional


@dataclass
Expand All @@ -13,16 +14,24 @@ class GatewayConfig:
socket_base_url: str = "wss://gateway.discord.gg/"
version: int = 9
encoding: str = "json"
compression: str = "zlib-stream"
compression: Optional[str] = "zlib-stream"

@staticmethod
def uri() -> str:
@classmethod
def uri(cls) -> str:
"""
:return uri:
The GatewayConfig's uri.
"""
return (
f"{GatewayConfig.socket_base_url}"
f"?v={GatewayConfig.version}"
f"&encoding={GatewayConfig.encoding}"
)
f"{cls.socket_base_url}"
f"?v={cls.version}"
f"&encoding={cls.encoding}"
) + f"&compress={cls.compression}" * cls.compressed()

@classmethod
def compressed(cls) -> bool:
"""
:return compressed:
Whether the Gateway should compress payloads or not.
"""
return cls.compression in ["zlib-stream", "zlib-payload"]
26 changes: 24 additions & 2 deletions pincer/core/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import logging
import zlib
from asyncio import get_event_loop, AbstractEventLoop, ensure_future
from platform import system
from typing import Dict, Callable, Awaitable, Optional
Expand All @@ -23,6 +24,8 @@
)
from ..objects import Intents

ZLIB_SUFFIX = b'\x00\x00\xff\xff'

Handler = Callable[[WebSocketClientProtocol, GatewayDispatch], Awaitable[None]]
_log = logging.getLogger(__package__)

Expand Down Expand Up @@ -133,7 +136,8 @@ def __hello_socket(self) -> str:
"$os": system(),
"$browser": __package__,
"$device": __package__
}
},
"compress": GatewayConfig.compressed()
}
)
)
Expand Down Expand Up @@ -216,12 +220,30 @@ async def __dispatcher(self, loop: AbstractEventLoop):
GatewayConfig.uri()
)

if GatewayConfig.compression == "zlib-stream":
# Create an inflator for compressed data as defined in
# https://discord.com/developers/docs/topics/gateway
inflator = zlib.decompressobj()

while self.__keep_alive:
try:
_log.debug("Waiting for new event.")
msg = await socket.recv()

if isinstance(msg, bytes):
if GatewayConfig.compression == "zlib-payload":
msg = zlib.decompress(msg)
else:
buffer = bytearray(msg)

while not buffer.endswith(ZLIB_SUFFIX):
buffer.extend(await socket.recv())

msg = inflator.decompress(buffer).decode('utf-8')

await self.__handler_manager(
socket,
GatewayDispatch.from_string(await socket.recv()),
GatewayDispatch.from_string(msg),
loop
)

Expand Down

0 comments on commit c79b305

Please sign in to comment.