diff --git a/experiments/compression/corpus.py b/experiments/compression/corpus.py
index da5661df..56e26211 100644
--- a/experiments/compression/corpus.py
+++ b/experiments/compression/corpus.py
@@ -47,6 +47,6 @@ def main(corpus):
 
 if __name__ == "__main__":
     if len(sys.argv) < 2:
-        print(f"Usage: {sys.argv[0]} [directory]")
+        print(f"Usage: {sys.argv[0]} <directory>")
         sys.exit(2)
     main(pathlib.Path(sys.argv[1]))
diff --git a/experiments/profiling/compression.py b/experiments/profiling/compression.py
new file mode 100644
index 00000000..1ece1f10
--- /dev/null
+++ b/experiments/profiling/compression.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python
+
+"""
+Profile the permessage-deflate extension.
+
+Usage::
+    $ pip install line_profiler
+    $ python experiments/compression/corpus.py experiments/compression/corpus
+    $ PYTHONPATH=src python -m kernprof \
+        --line-by-line \
+        --prof-mod src/websockets/extensions/permessage_deflate.py \
+        --view \
+        experiments/profiling/compression.py experiments/compression/corpus 12 5 6
+
+"""
+
+import pathlib
+import sys
+
+from websockets.extensions.permessage_deflate import PerMessageDeflate
+from websockets.frames import OP_TEXT, Frame
+
+
+def compress_and_decompress(corpus, max_window_bits, memory_level, level):
+    """Encode then decode every item of corpus with one shared extension."""
+    extension = PerMessageDeflate(
+        remote_no_context_takeover=False,
+        local_no_context_takeover=False,
+        remote_max_window_bits=max_window_bits,
+        local_max_window_bits=max_window_bits,
+        compress_settings={"memLevel": memory_level, "level": level},
+    )
+    for data in corpus:
+        frame = Frame(OP_TEXT, data)
+        frame = extension.encode(frame)
+        frame = extension.decode(frame)
+
+
+if __name__ == "__main__":
+    if len(sys.argv) < 2 or not pathlib.Path(sys.argv[1]).is_dir():
+        print(
+            f"Usage: {sys.argv[0]} <directory> "
+            "[<max_window_bits>] [<memory_level>] [<level>]"
+        )
+        # Stop here: the code below needs a valid corpus directory.
+        sys.exit(2)
+    corpus = [file.read_bytes() for file in pathlib.Path(sys.argv[1]).iterdir()]
+    max_window_bits = int(sys.argv[2]) if len(sys.argv) > 2 else 12
+    memory_level = int(sys.argv[3]) if len(sys.argv) > 3 else 5
+    level = int(sys.argv[4]) if len(sys.argv) > 4 else 6
+    compress_and_decompress(corpus, max_window_bits, memory_level, level)
diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py
index f962b65f..21df804f 100644
--- a/src/websockets/extensions/permessage_deflate.py
+++ b/src/websockets/extensions/permessage_deflate.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import dataclasses
 import zlib
 from collections.abc import Sequence
 from typing import Any
@@ -120,7 +119,6 @@ def decode(
         else:
             if not frame.rsv1:
                 return frame
-            frame = dataclasses.replace(frame, rsv1=False)
             if not frame.fin:
                 self.decode_cont_data = True
 
@@ -146,7 +144,15 @@ def decode(
         if frame.fin and self.remote_no_context_takeover:
             del self.decoder
 
-        return dataclasses.replace(frame, data=data)
+        return frames.Frame(
+            frame.opcode,
+            data,
+            frame.fin,
+            # Unset the rsv1 flag on the first frame of a compressed message.
+            False,
+            frame.rsv2,
+            frame.rsv3,
+        )
 
     def encode(self, frame: frames.Frame) -> frames.Frame:
         """
@@ -161,8 +167,6 @@ def encode(self, frame: frames.Frame) -> frames.Frame:
         # data" flag similar to "decode continuation data" at this time.
 
         if frame.opcode is not frames.OP_CONT:
-            # Set the rsv1 flag on the first frame of a compressed message.
-            frame = dataclasses.replace(frame, rsv1=True)
             # Re-initialize per-message decoder.
             if self.local_no_context_takeover:
                 self.encoder = zlib.compressobj(
@@ -172,14 +176,25 @@ def encode(self, frame: frames.Frame) -> frames.Frame:
 
         # Compress data.
         data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH)
-        if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK):
+        if frame.fin and data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK:
+            # Making a copy is faster than memoryview(a)[:-4] until about 2kB.
+            # On larger messages, it's slower but profiling shows that it's
+            # marginal compared to compress() and flush(). Keep it simple.
             data = data[:-4]
 
         # Allow garbage collection of the encoder if it won't be reused.
         if frame.fin and self.local_no_context_takeover:
             del self.encoder
 
-        return dataclasses.replace(frame, data=data)
+        return frames.Frame(
+            frame.opcode,
+            data,
+            frame.fin,
+            # Set the rsv1 flag on the first frame of a compressed message.
+            frame.opcode is not frames.OP_CONT,
+            frame.rsv2,
+            frame.rsv3,
+        )
 
 
 def _build_parameters(