diff --git a/plugin/guesslang/client.py b/plugin/guesslang/client.py index 261fa9e4..811a8cd1 100644 --- a/plugin/guesslang/client.py +++ b/plugin/guesslang/client.py @@ -1,7 +1,7 @@ from __future__ import annotations import threading -from typing import Protocol +from typing import Any, Protocol import sublime @@ -25,10 +25,13 @@ def on_close(self, ws: websocket.WebSocketApp, close_status_code: int, close_msg class NullTransportCallbacks: - on_open = None - on_message = None - on_error = None - on_close = None + def _(*args: Any) -> None: + pass + + on_open = _ + on_message = _ + on_error = _ + on_close = _ class GuesslangClient: diff --git a/plugin/libs/websocket/__init__.py b/plugin/libs/websocket/__init__.py index a9fa4634..c186ace8 100644 --- a/plugin/libs/websocket/__init__.py +++ b/plugin/libs/websocket/__init__.py @@ -2,7 +2,7 @@ __init__.py websocket - WebSocket client library for Python -Copyright 2021 engn33r +Copyright 2023 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,10 +17,10 @@ limitations under the License. """ from ._abnf import * -from ._app import WebSocketApp +from ._app import WebSocketApp, setReconnect from ._core import * from ._exceptions import * from ._logging import * from ._socket import * -__version__ = "1.2.1" +__version__ = "1.6.4" diff --git a/plugin/libs/websocket/_abnf.py b/plugin/libs/websocket/_abnf.py index 6a4d4907..a1c6f5a6 100644 --- a/plugin/libs/websocket/_abnf.py +++ b/plugin/libs/websocket/_abnf.py @@ -1,12 +1,19 @@ -""" +import array +import os +import struct +import sys -""" +from threading import Lock +from typing import Callable, Union + +from ._exceptions import * +from ._utils import validate_utf8 """ _abnf.py websocket - WebSocket client library for Python -Copyright 2021 engn33r +Copyright 2023 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,14 +27,6 @@ See the License for the specific language governing permissions and limitations under the License. """ -import array -import os -import struct -import sys - -from ._exceptions import * -from ._utils import validate_utf8 -from threading import Lock try: # If wsaccel is available, use compiled routines to mask data. @@ -36,18 +35,18 @@ # Note that wsaccel is unmaintained. from wsaccel.xormask import XorMaskerSimple - def _mask(_m, _d): + def _mask(_m, _d) -> bytes: return XorMaskerSimple(_m).process(_d) except ImportError: # wsaccel is not available, use websocket-client _mask() native_byteorder = sys.byteorder - def _mask(mask_value, data_value): + def _mask(mask_value: array.array, data_value: array.array) -> bytes: datalen = len(data_value) - data_value = int.from_bytes(data_value, native_byteorder) - mask_value = int.from_bytes(mask_value * (datalen // 4) + mask_value[: datalen % 4], native_byteorder) - return (data_value ^ mask_value).to_bytes(datalen, native_byteorder) + int_data_value = int.from_bytes(data_value, native_byteorder) + int_mask_value = int.from_bytes(mask_value * (datalen // 4) + mask_value[: datalen % 4], native_byteorder) + return (int_data_value ^ int_mask_value).to_bytes(datalen, native_byteorder) __all__ = [ @@ -79,6 +78,8 @@ def _mask(mask_value, data_value): STATUS_MESSAGE_TOO_BIG = 1009 STATUS_INVALID_EXTENSION = 1010 STATUS_UNEXPECTED_CONDITION = 1011 +STATUS_SERVICE_RESTART = 1012 +STATUS_TRY_AGAIN_LATER = 1013 STATUS_BAD_GATEWAY = 1014 STATUS_TLS_HANDSHAKE_ERROR = 1015 @@ -92,11 +93,13 @@ def _mask(mask_value, data_value): STATUS_MESSAGE_TOO_BIG, STATUS_INVALID_EXTENSION, STATUS_UNEXPECTED_CONDITION, + STATUS_SERVICE_RESTART, + STATUS_TRY_AGAIN_LATER, STATUS_BAD_GATEWAY, ) -class ABNF(object): +class ABNF: """ ABNF frame class. See http://tools.ietf.org/html/rfc5234 @@ -130,8 +133,8 @@ class ABNF(object): LENGTH_16 = 1 << 16 LENGTH_63 = 1 << 63 - def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0, - opcode=OPCODE_TEXT, mask=1, data=""): + def __init__(self, fin: int = 0, rsv1: int = 0, rsv2: int = 0, rsv3: int = 0, + opcode: int = OPCODE_TEXT, mask: int = 1, data: Union[str, bytes] = "") -> None: """ Constructor for ABNF. Please check RFC for arguments. """ @@ -146,7 +149,7 @@ def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0, self.data = data self.get_mask_key = os.urandom - def validate(self, skip_utf8_validation=False): + def validate(self, skip_utf8_validation: bool = False) -> None: """ Validate the ABNF frame. @@ -174,31 +177,31 @@ def validate(self, skip_utf8_validation=False): code = 256 * self.data[0] + self.data[1] if not self._is_valid_close_status(code): - raise WebSocketProtocolException("Invalid close opcode.") + raise WebSocketProtocolException("Invalid close opcode %r", code) @staticmethod - def _is_valid_close_status(code): + def _is_valid_close_status(code: int) -> bool: return code in VALID_CLOSE_STATUS or (3000 <= code < 5000) - def __str__(self): + def __str__(self) -> str: return "fin=" + str(self.fin) \ + " opcode=" + str(self.opcode) \ + " data=" + str(self.data) @staticmethod - def create_frame(data, opcode, fin=1): + def create_frame(data: Union[bytes, str], opcode: int, fin: int = 1) -> 'ABNF': """ Create frame to send text, binary and other data. Parameters ---------- - data: + data: str data to send. This is string value(byte array). If opcode is OPCODE_TEXT and this value is unicode, data value is converted into unicode string, automatically. - opcode: - operation code. please see OPCODE_XXX. - fin: + opcode: int + operation code. please see OPCODE_MAP. + fin: int fin flag. if set to 0, create continue fragmentation. """ if opcode == ABNF.OPCODE_TEXT and isinstance(data, str): @@ -206,7 +209,7 @@ def create_frame(data, opcode, fin=1): # mask must be set if send data from client return ABNF(fin, 0, 0, 0, opcode, 1, data) - def format(self): + def format(self) -> bytes: """ Format this object to string(byte array) to send data to server. """ @@ -236,7 +239,7 @@ def format(self): mask_key = self.get_mask_key(4) return frame_header + self._get_masked(mask_key) - def _get_masked(self, mask_key): + def _get_masked(self, mask_key: Union[str, bytes]) -> bytes: s = ABNF.mask(mask_key, self.data) if isinstance(mask_key, str): @@ -245,15 +248,15 @@ def _get_masked(self, mask_key): return mask_key + s @staticmethod - def mask(mask_key, data): + def mask(mask_key: Union[str, bytes], data: Union[str, bytes]) -> bytes: """ Mask or unmask data. Just do xor for each byte Parameters ---------- - mask_key: - 4 byte string. - data: + mask_key: bytes or str + 4 byte mask. + data: bytes or str data to mask/unmask. """ if data is None: @@ -268,11 +271,11 @@ def mask(mask_key, data): return _mask(array.array("B", mask_key), array.array("B", data)) -class frame_buffer(object): +class frame_buffer: _HEADER_MASK_INDEX = 5 _HEADER_LENGTH_INDEX = 6 - def __init__(self, recv_fn, skip_utf8_validation): + def __init__(self, recv_fn: Callable[[int], int], skip_utf8_validation: bool) -> None: self.recv = recv_fn self.skip_utf8_validation = skip_utf8_validation # Buffers over the packets from the layer beneath until desired amount @@ -281,15 +284,15 @@ def __init__(self, recv_fn, skip_utf8_validation): self.clear() self.lock = Lock() - def clear(self): + def clear(self) -> None: self.header = None self.length = None self.mask = None - def has_received_header(self): + def has_received_header(self) -> bool: return self.header is None - def recv_header(self): + def recv_header(self) -> None: header = self.recv_strict(2) b1 = header[0] fin = b1 >> 7 & 1 @@ -303,15 +306,15 @@ def recv_header(self): self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits) - def has_mask(self): + def has_mask(self) -> Union[bool, int]: if not self.header: return False return self.header[frame_buffer._HEADER_MASK_INDEX] - def has_received_length(self): + def has_received_length(self) -> bool: return self.length is None - def recv_length(self): + def recv_length(self) -> None: bits = self.header[frame_buffer._HEADER_LENGTH_INDEX] length_bits = bits & 0x7f if length_bits == 0x7e: @@ -323,13 +326,13 @@ def recv_length(self): else: self.length = length_bits - def has_received_mask(self): + def has_received_mask(self) -> bool: return self.mask is None - def recv_mask(self): + def recv_mask(self) -> None: self.mask = self.recv_strict(4) if self.has_mask() else "" - def recv_frame(self): + def recv_frame(self) -> ABNF: with self.lock: # Header @@ -360,7 +363,7 @@ def recv_frame(self): return frame - def recv_strict(self, bufsize): + def recv_strict(self, bufsize: int) -> bytes: shortage = bufsize - sum(map(len, self.recv_buffer)) while shortage > 0: # Limit buffer size that we pass to socket.recv() to avoid @@ -373,7 +376,7 @@ def recv_strict(self, bufsize): self.recv_buffer.append(bytes_) shortage -= len(bytes_) - unified = bytes("", 'utf-8').join(self.recv_buffer) + unified = b"".join(self.recv_buffer) if shortage == 0: self.recv_buffer = [] @@ -383,22 +386,22 @@ def recv_strict(self, bufsize): return unified[:bufsize] -class continuous_frame(object): +class continuous_frame: - def __init__(self, fire_cont_frame, skip_utf8_validation): + def __init__(self, fire_cont_frame: bool, skip_utf8_validation: bool) -> None: self.fire_cont_frame = fire_cont_frame self.skip_utf8_validation = skip_utf8_validation self.cont_data = None self.recving_frames = None - def validate(self, frame): + def validate(self, frame: ABNF) -> None: if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT: raise WebSocketProtocolException("Illegal frame") if self.recving_frames and \ frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): raise WebSocketProtocolException("Illegal frame") - def add(self, frame): + def add(self, frame: ABNF) -> None: if self.cont_data: self.cont_data[1] += frame.data else: @@ -409,10 +412,10 @@ def add(self, frame): if frame.fin: self.recving_frames = None - def is_fire(self, frame): + def is_fire(self, frame: ABNF) -> Union[bool, int]: return frame.fin or self.fire_cont_frame - def extract(self, frame): + def extract(self, frame: ABNF) -> list: data = self.cont_data self.cont_data = None frame.data = data[1] diff --git a/plugin/libs/websocket/_app.py b/plugin/libs/websocket/_app.py index 61925bad..13f8bd56 100644 --- a/plugin/libs/websocket/_app.py +++ b/plugin/libs/websocket/_app.py @@ -1,12 +1,24 @@ -""" +import inspect +import selectors +import socket +import sys +import threading +import time +import traceback -""" +from typing import Any, Callable, Optional, Union + +from . import _logging +from ._abnf import ABNF +from ._url import parse_url +from ._core import WebSocket, getdefaulttimeout +from ._exceptions import * """ _app.py websocket - WebSocket client library for Python -Copyright 2021 engn33r +Copyright 2023 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,84 +32,121 @@ See the License for the specific language governing permissions and limitations under the License. """ -import selectors -import sys -import threading -import time -import traceback -from ._abnf import ABNF -from ._core import WebSocket, getdefaulttimeout -from ._exceptions import * -from . import _logging - __all__ = ["WebSocketApp"] +RECONNECT = 0 + + +def setReconnect(reconnectInterval: int) -> None: + global RECONNECT + RECONNECT = reconnectInterval + -class Dispatcher: +class DispatcherBase: """ - Dispatcher + DispatcherBase """ - def __init__(self, app, ping_timeout): + def __init__(self, app: Any, ping_timeout: float) -> None: self.app = app self.ping_timeout = ping_timeout - def read(self, sock, read_callback, check_callback): - while self.app.keep_running: - sel = selectors.DefaultSelector() - sel.register(self.app.sock.sock, selectors.EVENT_READ) + def timeout(self, seconds: int, callback: Callable) -> None: + time.sleep(seconds) + callback() + + def reconnect(self, seconds: int, reconnector: Callable) -> None: + try: + _logging.info("reconnect() - retrying in {seconds_count} seconds [{frame_count} frames in stack]".format( + seconds_count=seconds, frame_count=len(inspect.stack()))) + time.sleep(seconds) + reconnector(reconnecting=True) + except KeyboardInterrupt as e: + _logging.info("User exited {err}".format(err=e)) + raise e - r = sel.select(self.ping_timeout) - if r: - if not read_callback(): - break - check_callback() + +class Dispatcher(DispatcherBase): + """ + Dispatcher + """ + def read(self, sock: socket.socket, read_callback: Callable, check_callback: Callable) -> None: + sel = selectors.DefaultSelector() + sel.register(self.app.sock.sock, selectors.EVENT_READ) + try: + while self.app.keep_running: + r = sel.select(self.ping_timeout) + if r: + if not read_callback(): + break + check_callback() + finally: sel.close() -class SSLDispatcher: +class SSLDispatcher(DispatcherBase): """ SSLDispatcher """ - def __init__(self, app, ping_timeout): - self.app = app - self.ping_timeout = ping_timeout - - def read(self, sock, read_callback, check_callback): - while self.app.keep_running: - r = self.select() - if r: - if not read_callback(): - break - check_callback() + def read(self, sock: socket.socket, read_callback: Callable, check_callback: Callable) -> None: + sock = self.app.sock.sock + sel = selectors.DefaultSelector() + sel.register(sock, selectors.EVENT_READ) + try: + while self.app.keep_running: + r = self.select(sock, sel) + if r: + if not read_callback(): + break + check_callback() + finally: + sel.close() - def select(self): + def select(self, sock, sel:selectors.DefaultSelector): sock = self.app.sock.sock if sock.pending(): return [sock,] - sel = selectors.DefaultSelector() - sel.register(sock, selectors.EVENT_READ) - r = sel.select(self.ping_timeout) - sel.close() if len(r) > 0: return r[0][0] -class WebSocketApp(object): +class WrappedDispatcher: + """ + WrappedDispatcher + """ + def __init__(self, app, ping_timeout: float, dispatcher: Dispatcher) -> None: + self.app = app + self.ping_timeout = ping_timeout + self.dispatcher = dispatcher + dispatcher.signal(2, dispatcher.abort) # keyboard interrupt + + def read(self, sock: socket.socket, read_callback: Callable, check_callback: Callable) -> None: + self.dispatcher.read(sock, read_callback) + self.ping_timeout and self.timeout(self.ping_timeout, check_callback) + + def timeout(self, seconds: int, callback: Callable) -> None: + self.dispatcher.timeout(seconds, callback) + + def reconnect(self, seconds: int, reconnector: Callable) -> None: + self.timeout(seconds, reconnector) + + +class WebSocketApp: """ Higher level of APIs are provided. The interface is like JavaScript WebSocket object. """ - def __init__(self, url, header=None, - on_open=None, on_message=None, on_error=None, - on_close=None, on_ping=None, on_pong=None, - on_cont_message=None, - keep_running=True, get_mask_key=None, cookie=None, - subprotocols=None, - on_data=None): + def __init__(self, url: str, header: Union[list, dict, Callable] = None, + on_open: Callable = None, on_message: Callable = None, on_error: Callable = None, + on_close: Callable = None, on_ping: Callable = None, on_pong: Callable = None, + on_cont_message: Callable = None, + keep_running: bool = True, get_mask_key: Callable = None, cookie: str = None, + subprotocols: list = None, + on_data: Callable = None, + socket: socket.socket = None) -> None: """ WebSocketApp initialization @@ -105,8 +154,11 @@ def __init__(self, url, header=None, ---------- url: str Websocket url. - header: list or dict + header: list or dict or Callable Custom header for websocket handshake. + If the parameter is a callable object, it is called just before the connection attempt. + The returned dict or list is used as custom header value. + This could be useful in order to properly setup timestamp dependent headers. on_open: function Callback object which is called at opening websocket. on_open has one argument. @@ -153,6 +205,8 @@ def __init__(self, url, header=None, Cookie value. subprotocols: list List of available sub protocols. Default is None. + socket: socket + Pre-initialized stream socket. """ self.url = url self.header = header if header is not None else [] @@ -171,9 +225,18 @@ def __init__(self, url, header=None, self.sock = None self.last_ping_tm = 0 self.last_pong_tm = 0 + self.ping_thread = None + self.stop_ping = None + self.ping_interval = 0 + self.ping_timeout = None + self.ping_payload = "" self.subprotocols = subprotocols + self.prepared_socket = socket + self.has_errored = False + self.has_done_teardown = False + self.has_done_teardown_lock = threading.Lock() - def send(self, data, opcode=ABNF.OPCODE_TEXT): + def send(self, data: str, opcode: int = ABNF.OPCODE_TEXT) -> None: """ send message @@ -190,7 +253,7 @@ def send(self, data, opcode=ABNF.OPCODE_TEXT): raise WebSocketConnectionClosedException( "Connection is already closed.") - def close(self, **kwargs): + def close(self, **kwargs) -> None: """ Close websocket connection. """ @@ -199,24 +262,41 @@ def close(self, **kwargs): self.sock.close(**kwargs) self.sock = None - def _send_ping(self, interval, event, payload): - while not event.wait(interval): - self.last_ping_tm = time.time() + def _start_ping_thread(self) -> None: + self.last_ping_tm = self.last_pong_tm = 0 + self.stop_ping = threading.Event() + self.ping_thread = threading.Thread(target=self._send_ping) + self.ping_thread.daemon = True + self.ping_thread.start() + + def _stop_ping_thread(self) -> None: + if self.stop_ping: + self.stop_ping.set() + if self.ping_thread and self.ping_thread.is_alive(): + self.ping_thread.join(3) + self.last_ping_tm = self.last_pong_tm = 0 + + def _send_ping(self) -> None: + if self.stop_ping.wait(self.ping_interval) or self.keep_running is False: + return + while not self.stop_ping.wait(self.ping_interval) and self.keep_running is True: if self.sock: + self.last_ping_tm = time.time() try: - self.sock.ping(payload) - except Exception as ex: - _logging.warning("send_ping routine terminated: {}".format(ex)) - break - - def run_forever(self, sockopt=None, sslopt=None, - ping_interval=0, ping_timeout=None, - ping_payload="", - http_proxy_host=None, http_proxy_port=None, - http_no_proxy=None, http_proxy_auth=None, - skip_utf8_validation=False, - host=None, origin=None, dispatcher=None, - suppress_origin=False, proxy_type=None): + _logging.debug("Sending ping") + self.sock.ping(self.ping_payload) + except Exception as e: + _logging.debug("Failed to send ping: {err}".format(err=e)) + + def run_forever(self, sockopt: tuple = None, sslopt: dict = None, + ping_interval: float = 0, ping_timeout: Optional[float] = None, + ping_payload: str = "", + http_proxy_host: str = None, http_proxy_port: Union[int, str] = None, + http_no_proxy: list = None, http_proxy_auth: tuple = None, + http_proxy_timeout: float = None, + skip_utf8_validation: bool = False, + host: str = None, origin: str = None, dispatcher: Dispatcher = None, + suppress_origin: bool = False, proxy_type: str = None, reconnect: int = None) -> bool: """ Run event loop for WebSocket framework. @@ -244,6 +324,10 @@ def run_forever(self, sockopt=None, sslopt=None, HTTP proxy port. If not set, set to 80. http_no_proxy: list Whitelisted host names that don't use the proxy. + http_proxy_timeout: int or float + HTTP proxy timeout, default is 60 sec as per python-socks. + http_proxy_auth: tuple + HTTP proxy auth information. tuple of username and password. Default is None. skip_utf8_validation: bool skip utf8 validation. host: str @@ -254,13 +338,21 @@ def run_forever(self, sockopt=None, sslopt=None, customize reading data from socket. suppress_origin: bool suppress outputting origin header. + proxy_type: str + type of proxy from: http, socks4, socks4a, socks5, socks5h + reconnect: int + delay interval when reconnecting Returns ------- teardown: bool - False if caught KeyboardInterrupt, True if other exception was raised during a loop + False if the `WebSocketApp` is closed or caught KeyboardInterrupt, + True if any other exception was raised during a loop. """ + if reconnect is None: + reconnect = RECONNECT + if ping_timeout is not None and ping_timeout <= 0: raise WebSocketException("Ensure ping_timeout > 0") if ping_interval is not None and ping_interval < 0: @@ -273,12 +365,13 @@ def run_forever(self, sockopt=None, sslopt=None, sslopt = {} if self.sock: raise WebSocketException("socket is already opened") - thread = None + + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.ping_payload = ping_payload self.keep_running = True - self.last_ping_tm = 0 - self.last_pong_tm = 0 - def teardown(close_frame=None): + def teardown(close_frame: ABNF = None): """ Tears down the connection. @@ -289,9 +382,14 @@ def teardown(close_frame=None): with the statusCode and reason from the provided frame. """ - if thread and thread.is_alive(): - event.set() - thread.join() + # teardown() is called in many code paths to ensure resources are cleaned up and on_close is fired. + # To ensure the work is only done once, we use this bool and lock. + with self.has_done_teardown_lock: + if self.has_done_teardown: + return + self.has_done_teardown = True + + self._stop_ping_thread() self.keep_running = False if self.sock: self.sock.close() @@ -302,87 +400,135 @@ def teardown(close_frame=None): # Finally call the callback AFTER all teardown is complete self._callback(self.on_close, close_status_code, close_reason) - try: + def setSock(reconnecting: bool = False) -> None: + if reconnecting and self.sock: + self.sock.shutdown() + self.sock = WebSocket( self.get_mask_key, sockopt=sockopt, sslopt=sslopt, fire_cont_frame=self.on_cont_message is not None, skip_utf8_validation=skip_utf8_validation, enable_multithread=True) + self.sock.settimeout(getdefaulttimeout()) - self.sock.connect( - self.url, header=self.header, cookie=self.cookie, - http_proxy_host=http_proxy_host, - http_proxy_port=http_proxy_port, http_no_proxy=http_no_proxy, - http_proxy_auth=http_proxy_auth, subprotocols=self.subprotocols, - host=host, origin=origin, suppress_origin=suppress_origin, - proxy_type=proxy_type) - if not dispatcher: - dispatcher = self.create_dispatcher(ping_timeout) - - self._callback(self.on_open) - - if ping_interval: - event = threading.Event() - thread = threading.Thread( - target=self._send_ping, args=(ping_interval, event, ping_payload)) - thread.daemon = True - thread.start() - - def read(): - if not self.keep_running: - return teardown() + try: + + header = self.header() if callable(self.header) else self.header + + self.sock.connect( + self.url, header=header, cookie=self.cookie, + http_proxy_host=http_proxy_host, + http_proxy_port=http_proxy_port, http_no_proxy=http_no_proxy, + http_proxy_auth=http_proxy_auth, http_proxy_timeout=http_proxy_timeout, + subprotocols=self.subprotocols, + host=host, origin=origin, suppress_origin=suppress_origin, + proxy_type=proxy_type, socket=self.prepared_socket) + + _logging.info("Websocket connected") + + if self.ping_interval: + self._start_ping_thread() + self._callback(self.on_open) + + dispatcher.read(self.sock.sock, read, check) + except (WebSocketConnectionClosedException, ConnectionRefusedError, KeyboardInterrupt, SystemExit, Exception) as e: + handleDisconnect(e, reconnecting) + + def read() -> bool: + if not self.keep_running: + return teardown() + + try: op_code, frame = self.sock.recv_data_frame(True) - if op_code == ABNF.OPCODE_CLOSE: - return teardown(frame) - elif op_code == ABNF.OPCODE_PING: - self._callback(self.on_ping, frame.data) - elif op_code == ABNF.OPCODE_PONG: - self.last_pong_tm = time.time() - self._callback(self.on_pong, frame.data) - elif op_code == ABNF.OPCODE_CONT and self.on_cont_message: - self._callback(self.on_data, frame.data, - frame.opcode, frame.fin) - self._callback(self.on_cont_message, - frame.data, frame.fin) + except (WebSocketConnectionClosedException, KeyboardInterrupt) as e: + if custom_dispatcher: + return handleDisconnect(e) else: - data = frame.data - if op_code == ABNF.OPCODE_TEXT: - data = data.decode("utf-8") - self._callback(self.on_data, data, frame.opcode, True) - self._callback(self.on_message, data) - - return True - - def check(): - if (ping_timeout): - has_timeout_expired = time.time() - self.last_ping_tm > ping_timeout - has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0 - has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > ping_timeout - - if (self.last_ping_tm and - has_timeout_expired and - (has_pong_not_arrived_after_last_ping or has_pong_arrived_too_late)): - raise WebSocketTimeoutException("ping/pong timed out") - return True - - dispatcher.read(self.sock.sock, read, check) - except (Exception, KeyboardInterrupt, SystemExit) as e: - self._callback(self.on_error, e) - if isinstance(e, SystemExit): - # propagate SystemExit further + raise e + + if op_code == ABNF.OPCODE_CLOSE: + return teardown(frame) + elif op_code == ABNF.OPCODE_PING: + self._callback(self.on_ping, frame.data) + elif op_code == ABNF.OPCODE_PONG: + self.last_pong_tm = time.time() + self._callback(self.on_pong, frame.data) + elif op_code == ABNF.OPCODE_CONT and self.on_cont_message: + self._callback(self.on_data, frame.data, + frame.opcode, frame.fin) + self._callback(self.on_cont_message, + frame.data, frame.fin) + else: + data = frame.data + if op_code == ABNF.OPCODE_TEXT and not skip_utf8_validation: + data = data.decode("utf-8") + self._callback(self.on_data, data, frame.opcode, True) + self._callback(self.on_message, data) + + return True + + def check() -> bool: + if (self.ping_timeout): + has_timeout_expired = time.time() - self.last_ping_tm > self.ping_timeout + has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0 + has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > self.ping_timeout + + if (self.last_ping_tm and + has_timeout_expired and + (has_pong_not_arrived_after_last_ping or has_pong_arrived_too_late)): + raise WebSocketTimeoutException("ping/pong timed out") + return True + + def handleDisconnect(e: Exception, reconnecting: bool = False) -> bool: + self.has_errored = True + self._stop_ping_thread() + if not reconnecting: + self._callback(self.on_error, e) + + if isinstance(e, (KeyboardInterrupt, SystemExit)): + teardown() + # Propagate further raise + + if reconnect: + _logging.info("{err} - reconnect".format(err=e)) + if custom_dispatcher: + _logging.debug("Calling custom dispatcher reconnect [{frame_count} frames in stack]".format(frame_count=len(inspect.stack()))) + dispatcher.reconnect(reconnect, setSock) + else: + _logging.error("{err} - goodbye".format(err=e)) + teardown() + + custom_dispatcher = bool(dispatcher) + dispatcher = self.create_dispatcher(ping_timeout, dispatcher, parse_url(self.url)[3]) + + try: + setSock() + if not custom_dispatcher and reconnect: + while self.keep_running: + _logging.debug("Calling dispatcher reconnect [{frame_count} frames in stack]".format(frame_count=len(inspect.stack()))) + dispatcher.reconnect(reconnect, setSock) + except (KeyboardInterrupt, Exception) as e: + _logging.info("tearing down on exception {err}".format(err=e)) teardown() - return not isinstance(e, KeyboardInterrupt) + finally: + if not custom_dispatcher: + # Ensure teardown was called before returning from run_forever + teardown() + + return self.has_errored - def create_dispatcher(self, ping_timeout): + def create_dispatcher(self, ping_timeout: int, dispatcher: Dispatcher = None, is_ssl: bool = False) -> DispatcherBase: + if dispatcher: # If custom dispatcher is set, use WrappedDispatcher + return WrappedDispatcher(self, ping_timeout, dispatcher) timeout = ping_timeout or 10 - if self.sock.is_ssl(): + if is_ssl: return SSLDispatcher(self, timeout) return Dispatcher(self, timeout) - def _get_close_args(self, close_frame): + def _get_close_args(self, close_frame: ABNF) -> list: """ _get_close_args extracts the close code and reason from the close body if it exists (RFC6455 says WebSocket Connection Close Code is optional) @@ -401,12 +547,12 @@ def _get_close_args(self, close_frame): # Most likely reached this because len(close_frame_data.data) < 2 return [None, None] - def _callback(self, callback, *args): + def _callback(self, callback, *args) -> None: if callback: try: callback(self, *args) except Exception as e: - _logging.error("error from callback {}: {}".format(callback, e)) + _logging.error("error from callback {callback}: {err}".format(callback=callback, err=e)) if self.on_error: self.on_error(self, e) diff --git a/plugin/libs/websocket/_cookiejar.py b/plugin/libs/websocket/_cookiejar.py index dcf5031a..bf907d6b 100644 --- a/plugin/libs/websocket/_cookiejar.py +++ b/plugin/libs/websocket/_cookiejar.py @@ -1,12 +1,12 @@ -""" +import http.cookies -""" +from typing import Optional """ _cookiejar.py websocket - WebSocket client library for Python -Copyright 2021 engn33r +Copyright 2023 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,14 +20,13 @@ See the License for the specific language governing permissions and limitations under the License. """ -import http.cookies -class SimpleCookieJar(object): - def __init__(self): +class SimpleCookieJar: + def __init__(self) -> None: self.jar = dict() - def add(self, set_cookie): + def add(self, set_cookie: Optional[str]) -> None: if set_cookie: simpleCookie = http.cookies.SimpleCookie(set_cookie) @@ -40,7 +39,7 @@ def add(self, set_cookie): cookie.update(simpleCookie) self.jar[domain.lower()] = cookie - def set(self, set_cookie): + def set(self, set_cookie: str) -> None: if set_cookie: simpleCookie = http.cookies.SimpleCookie(set_cookie) @@ -51,7 +50,7 @@ def set(self, set_cookie): domain = "." + domain self.jar[domain.lower()] = simpleCookie - def get(self, host): + def get(self, host: str) -> str: if not host: return "" diff --git a/plugin/libs/websocket/_core.py b/plugin/libs/websocket/_core.py index f92f8a60..fea2b6d4 100644 --- a/plugin/libs/websocket/_core.py +++ b/plugin/libs/websocket/_core.py @@ -1,14 +1,25 @@ -""" -_core.py -==================================== -WebSocket Python client -""" +import socket +import struct +import threading +import time + +from typing import Optional, Union + +# websocket modules +from ._abnf import * +from ._exceptions import * +from ._handshake import * +from ._http import * +from ._logging import * +from ._socket import * +from ._ssl_compat import * +from ._utils import * """ _core.py websocket - WebSocket client library for Python -Copyright 2021 engn33r +Copyright 2023 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,25 +33,11 @@ See the License for the specific language governing permissions and limitations under the License. """ -import socket -import struct -import threading -import time - -# websocket modules -from ._abnf import * -from ._exceptions import * -from ._handshake import * -from ._http import * -from ._logging import * -from ._socket import * -from ._ssl_compat import * -from ._utils import * __all__ = ['WebSocket', 'create_connection'] -class WebSocket(object): +class WebSocket: """ Low level WebSocket interface. @@ -51,8 +48,11 @@ class WebSocket(object): >>> import websocket >>> ws = websocket.WebSocket() - >>> ws.connect("ws://echo.websocket.org") + >>> ws.connect("ws://echo.websocket.events") + >>> ws.recv() + 'echo.websocket.events sponsored by Lob.com' >>> ws.send("Hello, Server") + 19 >>> ws.recv() 'Hello, Server' >>> ws.close() @@ -66,7 +66,7 @@ class WebSocket(object): Values for socket.setsockopt. sockopt must be tuple and each element is argument of sock.setsockopt. sslopt: dict - Optional dict object for ssl socket options. + Optional dict object for ssl socket options. See FAQ for details. fire_cont_frame: bool Fire recv event for each cont frame. Default is False. enable_multithread: bool @@ -76,15 +76,15 @@ class WebSocket(object): """ def __init__(self, get_mask_key=None, sockopt=None, sslopt=None, - fire_cont_frame=False, enable_multithread=True, - skip_utf8_validation=False, **_): + fire_cont_frame: bool = False, enable_multithread: bool = True, + skip_utf8_validation: bool = False, **_): """ Initialize WebSocket object. Parameters ---------- sslopt: dict - Optional dict object for ssl socket options. + Optional dict object for ssl socket options. See FAQ for details. """ self.sock_opt = sock_opt(sockopt, sslopt) self.handshake_response = None @@ -135,7 +135,7 @@ def set_mask_key(self, func): """ self.get_mask_key = func - def gettimeout(self): + def gettimeout(self) -> float: """ Get the websocket timeout (in seconds) as an int or float @@ -146,7 +146,7 @@ def gettimeout(self): """ return self.sock_opt.timeout - def settimeout(self, timeout): + def settimeout(self, timeout: Optional[float]): """ Set the timeout to the websocket. @@ -208,7 +208,7 @@ def connect(self, url, **options): If you set "header" list object, you can set your own custom header. >>> ws = WebSocket() - >>> ws.connect("ws://echo.websocket.org/", + >>> ws.connect("ws://echo.websocket.events", ... header=["User-Agent: MyProgram", ... "x-custom: header"]) @@ -238,6 +238,8 @@ def connect(self, url, **options): Whitelisted host names that don't use the proxy. http_proxy_auth: tuple HTTP proxy auth information. Tuple of username and password. Default is None. + http_proxy_timeout: int or float + HTTP proxy timeout, default is 60 sec as per python-socks. redirect_limit: int Number of redirects to follow. subprotocols: list @@ -250,14 +252,14 @@ def connect(self, url, **options): options.pop('socket', None)) try: - self.handshake_response = handshake(self.sock, *addrs, **options) + self.handshake_response = handshake(self.sock, url, *addrs, **options) for attempt in range(options.pop('redirect_limit', 3)): if self.handshake_response.status in SUPPORTED_REDIRECT_STATUSES: url = self.handshake_response.headers['location'] self.sock.close() self.sock, addrs = connect(url, self.sock_opt, proxy_info(**options), options.pop('socket', None)) - self.handshake_response = handshake(self.sock, *addrs, **options) + self.handshake_response = handshake(self.sock, url, *addrs, **options) self.connected = True except: if self.sock: @@ -265,7 +267,7 @@ def connect(self, url, **options): self.sock = None raise - def send(self, payload, opcode=ABNF.OPCODE_TEXT): + def send(self, payload: Union[bytes, str], opcode: int = ABNF.OPCODE_TEXT) -> int: """ Send the data as string. @@ -282,11 +284,11 @@ def send(self, payload, opcode=ABNF.OPCODE_TEXT): frame = ABNF.create_frame(payload, opcode) return self.send_frame(frame) - def send_frame(self, frame): + def send_frame(self, frame) -> int: """ Send the data frame. - >>> ws = create_connection("ws://echo.websocket.org/") + >>> ws = create_connection("ws://echo.websocket.events") >>> frame = ABNF.create_frame("Hello", ABNF.OPCODE_TEXT) >>> ws.send_frame(frame) >>> cont_frame = ABNF.create_frame("My name is ", ABNF.OPCODE_CONT, 0) @@ -313,10 +315,18 @@ def send_frame(self, frame): return length - def send_binary(self, payload): + def send_binary(self, payload: bytes) -> int: + """ + Send a binary message (OPCODE_BINARY). + + Parameters + ---------- + payload: bytes + payload of message to send. + """ return self.send(payload, ABNF.OPCODE_BINARY) - def ping(self, payload=""): + def ping(self, payload: Union[str, bytes] = ""): """ Send ping data. @@ -329,7 +339,7 @@ def ping(self, payload=""): payload = payload.encode("utf-8") self.send(payload, ABNF.OPCODE_PING) - def pong(self, payload=""): + def pong(self, payload: Union[str, bytes] = ""): """ Send pong data. @@ -342,7 +352,7 @@ def pong(self, payload=""): payload = payload.encode("utf-8") self.send(payload, ABNF.OPCODE_PONG) - def recv(self): + def recv(self) -> Union[str, bytes]: """ Receive string data(byte array) from the server. @@ -359,7 +369,7 @@ def recv(self): else: return '' - def recv_data(self, control_frame=False): + def recv_data(self, control_frame: bool = False) -> tuple: """ Receive data with operation code. @@ -377,10 +387,12 @@ def recv_data(self, control_frame=False): opcode, frame = self.recv_data_frame(control_frame) return opcode, frame.data - def recv_data_frame(self, control_frame=False): + def recv_data_frame(self, control_frame: bool = False): """ Receive data with operation code. + If a valid ping message is received, a pong response is sent. + Parameters ---------- control_frame: bool @@ -401,7 +413,7 @@ def recv_data_frame(self, control_frame=False): # handle error: # 'NoneType' object has no attribute 'opcode' raise WebSocketProtocolException( - "Not a valid frame %s" % frame) + "Not a valid frame {frame}".format(frame=frame)) elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY, ABNF.OPCODE_CONT): self.cont_frame.validate(frame) self.cont_frame.add(frame) @@ -434,7 +446,7 @@ def recv_frame(self): """ return self.frame_buffer.recv_frame() - def send_close(self, status=STATUS_NORMAL, reason=bytes('', encoding='utf-8')): + def send_close(self, status: int = STATUS_NORMAL, reason: bytes = b""): """ Send close data to the server. @@ -443,23 +455,23 @@ def send_close(self, status=STATUS_NORMAL, reason=bytes('', encoding='utf-8')): status: int Status code to send. See STATUS_XXX. reason: str or bytes - The reason to close. This must be string or bytes. + The reason to close. This must be string or UTF-8 bytes. """ if status < 0 or status >= ABNF.LENGTH_16: raise ValueError("code is invalid range") self.connected = False self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE) - def close(self, status=STATUS_NORMAL, reason=bytes('', encoding='utf-8'), timeout=3): + def close(self, status: int = STATUS_NORMAL, reason: bytes = b"", timeout: float = 3): """ Close Websocket object Parameters ---------- status: int - Status code to send. See STATUS_XXX. + Status code to send. See VALID_CLOSE_STATUS in ABNF. reason: bytes - The reason to close. + The reason to close in UTF-8. timeout: int or float Timeout until receive a close frame. If None, it will wait forever until receive a close frame. @@ -511,7 +523,7 @@ def shutdown(self): self.sock = None self.connected = False - def _send(self, data): + def _send(self, data: Union[str, bytes]): return send(self.sock, data) def _recv(self, bufsize): @@ -525,7 +537,7 @@ def _recv(self, bufsize): raise -def create_connection(url, timeout=None, class_=WebSocket, **options): +def create_connection(url: str, timeout=None, class_=WebSocket, **options): """ Connect to url and return websocket object. @@ -536,7 +548,7 @@ def create_connection(url, timeout=None, class_=WebSocket, **options): You can customize using 'options'. If you set "header" list object, you can set your own custom header. - >>> conn = create_connection("ws://echo.websocket.org/", + >>> conn = create_connection("ws://echo.websocket.events", ... header=["User-Agent: MyProgram", ... "x-custom: header"]) @@ -567,6 +579,8 @@ class to instantiate when creating the connection. It has to implement Whitelisted host names that don't use the proxy. http_proxy_auth: tuple HTTP proxy auth information. tuple of username and password. Default is None. + http_proxy_timeout: int or float + HTTP proxy timeout, default is 60 sec as per python-socks. enable_multithread: bool Enable lock for multithread. redirect_limit: int @@ -575,7 +589,7 @@ class to instantiate when creating the connection. It has to implement Values for socket.setsockopt. sockopt must be a tuple and each element is an argument of sock.setsockopt. sslopt: dict - Optional dict object for ssl socket options. + Optional dict object for ssl socket options. See FAQ for details. subprotocols: list List of available subprotocols. Default is None. skip_utf8_validation: bool diff --git a/plugin/libs/websocket/_exceptions.py b/plugin/libs/websocket/_exceptions.py index 2d5b0535..48f40a07 100644 --- a/plugin/libs/websocket/_exceptions.py +++ b/plugin/libs/websocket/_exceptions.py @@ -1,12 +1,8 @@ -""" -Define WebSocket exceptions -""" - """ _exceptions.py websocket - WebSocket client library for Python -Copyright 2021 engn33r +Copyright 2023 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -70,11 +66,11 @@ class WebSocketBadStatusException(WebSocketException): WebSocketBadStatusException will be raised when we get bad handshake status code. """ - def __init__(self, message, status_code, status_message=None, resp_headers=None): - msg = message % (status_code, status_message) - super(WebSocketBadStatusException, self).__init__(msg) + def __init__(self, message: str, status_code: int, status_message=None, resp_headers=None, resp_body=None): + super().__init__(message) self.status_code = status_code self.resp_headers = resp_headers + self.resp_body = resp_body class WebSocketAddressException(WebSocketException): diff --git a/plugin/libs/websocket/_handshake.py b/plugin/libs/websocket/_handshake.py index da1a8d44..a94d3030 100644 --- a/plugin/libs/websocket/_handshake.py +++ b/plugin/libs/websocket/_handshake.py @@ -2,7 +2,7 @@ _handshake.py websocket - WebSocket client library for Python -Copyright 2021 engn33r +Copyright 2023 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,23 +32,23 @@ # websocket supported version. VERSION = 13 -SUPPORTED_REDIRECT_STATUSES = (HTTPStatus.MOVED_PERMANENTLY, HTTPStatus.FOUND, HTTPStatus.SEE_OTHER,) +SUPPORTED_REDIRECT_STATUSES = (HTTPStatus.MOVED_PERMANENTLY, HTTPStatus.FOUND, HTTPStatus.SEE_OTHER, HTTPStatus.TEMPORARY_REDIRECT, HTTPStatus.PERMANENT_REDIRECT) SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,) CookieJar = SimpleCookieJar() -class handshake_response(object): +class handshake_response: - def __init__(self, status, headers, subprotocol): + def __init__(self, status: int, headers: dict, subprotocol): self.status = status self.headers = headers self.subprotocol = subprotocol CookieJar.add(headers.get("set-cookie")) -def handshake(sock, hostname, port, resource, **options): - headers, key = _get_handshake_headers(resource, hostname, port, options) +def handshake(sock, url: str, hostname: str, port: int, resource: str, **options): + headers, key = _get_handshake_headers(resource, url, hostname, port, options) header_str = "\r\n".join(headers) send(sock, header_str) @@ -64,7 +64,7 @@ def handshake(sock, hostname, port, resource, **options): return handshake_response(status, resp, subproto) -def _pack_hostname(hostname): +def _pack_hostname(hostname: str) -> str: # IPv6 address if ':' in hostname: return '[' + hostname + ']' @@ -72,49 +72,53 @@ def _pack_hostname(hostname): return hostname -def _get_handshake_headers(resource, host, port, options): +def _get_handshake_headers(resource: str, url: str, host: str, port: int, options: dict): headers = [ - "GET %s HTTP/1.1" % resource, + "GET {resource} HTTP/1.1".format(resource=resource), "Upgrade: websocket" ] if port == 80 or port == 443: hostport = _pack_hostname(host) else: - hostport = "%s:%d" % (_pack_hostname(host), port) - if "host" in options and options["host"] is not None: - headers.append("Host: %s" % options["host"]) + hostport = "{h}:{p}".format(h=_pack_hostname(host), p=port) + if options.get("host"): + headers.append("Host: {h}".format(h=options["host"])) else: - headers.append("Host: %s" % hostport) + headers.append("Host: {hp}".format(hp=hostport)) - if "suppress_origin" not in options or not options["suppress_origin"]: + # scheme indicates whether http or https is used in Origin + # The same approach is used in parse_url of _url.py to set default port + scheme, url = url.split(":", 1) + if not options.get("suppress_origin"): if "origin" in options and options["origin"] is not None: - headers.append("Origin: %s" % options["origin"]) + headers.append("Origin: {origin}".format(origin=options["origin"])) + elif scheme == "wss": + headers.append("Origin: https://{hp}".format(hp=hostport)) else: - headers.append("Origin: http://%s" % hostport) + headers.append("Origin: http://{hp}".format(hp=hostport)) key = _create_sec_websocket_key() # Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified - if 'header' not in options or 'Sec-WebSocket-Key' not in options['header']: - key = _create_sec_websocket_key() - headers.append("Sec-WebSocket-Key: %s" % key) + if not options.get('header') or 'Sec-WebSocket-Key' not in options['header']: + headers.append("Sec-WebSocket-Key: {key}".format(key=key)) else: key = options['header']['Sec-WebSocket-Key'] - if 'header' not in options or 'Sec-WebSocket-Version' not in options['header']: - headers.append("Sec-WebSocket-Version: %s" % VERSION) + if not options.get('header') or 'Sec-WebSocket-Version' not in options['header']: + headers.append("Sec-WebSocket-Version: {version}".format(version=VERSION)) - if 'connection' not in options or options['connection'] is None: + if not options.get('connection'): headers.append('Connection: Upgrade') else: headers.append(options['connection']) subprotocols = options.get("subprotocols") if subprotocols: - headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols)) + headers.append("Sec-WebSocket-Protocol: {protocols}".format(protocols=",".join(subprotocols))) - if "header" in options: - header = options["header"] + header = options.get("header") + if header: if isinstance(header, dict): header = [ ": ".join([k, v]) @@ -129,18 +133,21 @@ def _get_handshake_headers(resource, host, port, options): cookie = "; ".join(filter(None, [server_cookie, client_cookie])) if cookie: - headers.append("Cookie: %s" % cookie) - - headers.append("") - headers.append("") + headers.append("Cookie: {cookie}".format(cookie=cookie)) + headers.extend(("", "")) return headers, key -def _get_resp_headers(sock, success_statuses=SUCCESS_STATUSES): +def _get_resp_headers(sock, success_statuses: tuple = SUCCESS_STATUSES) -> tuple: status, resp_headers, status_message = read_headers(sock) if status not in success_statuses: - raise WebSocketBadStatusException("Handshake status %d %s", status, status_message, resp_headers) + content_len = resp_headers.get('content-length') + if content_len: + response_body = sock.recv(int(content_len)) # read the body of the HTTP error message response and include it in the exception + else: + response_body = None + raise WebSocketBadStatusException("Handshake status {status} {message} -+-+- {headers} -+-+- {body}".format(status=status, message=status_message, headers=resp_headers, body=response_body), status, status_message, resp_headers, response_body) return status, resp_headers @@ -150,7 +157,7 @@ def _get_resp_headers(sock, success_statuses=SUCCESS_STATUSES): } -def _validate(headers, key, subprotocols): +def _validate(headers, key: str, subprotocols): subproto = None for k, v in _HEADERS_TO_CHECK.items(): r = headers.get(k, None) @@ -185,6 +192,6 @@ def _validate(headers, key, subprotocols): return False, None -def _create_sec_websocket_key(): +def _create_sec_websocket_key() -> str: randomness = os.urandom(16) return base64encode(randomness).decode('utf-8').strip() diff --git a/plugin/libs/websocket/_http.py b/plugin/libs/websocket/_http.py index 9ddf01d0..13183b20 100644 --- a/plugin/libs/websocket/_http.py +++ b/plugin/libs/websocket/_http.py @@ -2,7 +2,7 @@ _http.py websocket - WebSocket client library for Python -Copyright 2021 engn33r +Copyright 2023 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,11 +19,10 @@ import errno import os import socket -import sys from ._exceptions import * from ._logging import * -from ._socket import* +from ._socket import * from ._ssl_compat import * from ._url import * @@ -49,7 +48,7 @@ class ProxyConnectionError(Exception): pass -class proxy_info(object): +class proxy_info: def __init__(self, **options): self.proxy_host = options.get("http_proxy_host", None) @@ -59,7 +58,7 @@ def __init__(self, **options): self.no_proxy = options.get("http_no_proxy", None) self.proxy_protocol = options.get("proxy_type", "http") # Note: If timeout not specified, default python-socks timeout is 60 seconds - self.proxy_timeout = options.get("timeout", None) + self.proxy_timeout = options.get("http_proxy_timeout", None) if self.proxy_protocol not in ['http', 'socks4', 'socks4a', 'socks5', 'socks5h']: raise ProxyError("Only http, socks4, socks5 proxy protocols are supported") else: @@ -69,7 +68,7 @@ def __init__(self, **options): self.proxy_protocol = "http" -def _start_proxied_socket(url, options, proxy): +def _start_proxied_socket(url: str, options, proxy): if not HAVE_PYTHON_SOCKS: raise WebSocketException("Python Socks is needed for SOCKS proxying but is not available") @@ -107,29 +106,29 @@ def _start_proxied_socket(url, options, proxy): return sock, (hostname, port, resource) -def connect(url, options, proxy, socket): +def connect(url: str, options, proxy, socket): # Use _start_proxied_socket() only for socks4 or socks5 proxy # Use _tunnel() for http proxy # TODO: Use python-socks for http protocol also, to standardize flow if proxy.proxy_host and not socket and not (proxy.proxy_protocol == "http"): return _start_proxied_socket(url, options, proxy) - hostname, port, resource, is_secure = parse_url(url) + hostname, port_from_url, resource, is_secure = parse_url(url) if socket: - return socket, (hostname, port, resource) + return socket, (hostname, port_from_url, resource) addrinfo_list, need_tunnel, auth = _get_addrinfo_list( - hostname, port, is_secure, proxy) + hostname, port_from_url, is_secure, proxy) if not addrinfo_list: raise WebSocketException( - "Host not found.: " + hostname + ":" + str(port)) + "Host not found.: " + hostname + ":" + str(port_from_url)) sock = None try: sock = _open_socket(addrinfo_list, options.sockopt, options.timeout) if need_tunnel: - sock = _tunnel(sock, hostname, port, auth) + sock = _tunnel(sock, hostname, port_from_url, auth) if is_secure: if HAVE_SSL: @@ -137,7 +136,7 @@ def connect(url, options, proxy, socket): else: raise WebSocketException("SSL not available.") - return sock, (hostname, port, resource) + return sock, (hostname, port_from_url, resource) except: if sock: sock.close() @@ -184,19 +183,16 @@ def _open_socket(addrinfo_list, sockopt, timeout): try: sock.connect(address) except socket.error as error: + sock.close() error.remote_ip = str(address[0]) try: - eConnRefused = (errno.ECONNREFUSED, errno.WSAECONNREFUSED) - except: - eConnRefused = (errno.ECONNREFUSED, ) - if error.errno == errno.EINTR: - continue - elif error.errno in eConnRefused: + eConnRefused = (errno.ECONNREFUSED, errno.WSAECONNREFUSED, errno.ENETUNREACH) + except AttributeError: + eConnRefused = (errno.ECONNREFUSED, errno.ENETUNREACH) + if error.errno in eConnRefused: err = error continue else: - if sock: - sock.close() raise error else: break @@ -211,33 +207,46 @@ def _open_socket(addrinfo_list, sockopt, timeout): def _wrap_sni_socket(sock, sslopt, hostname, check_hostname): - context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_TLS)) - - if sslopt.get('cert_reqs', ssl.CERT_NONE) != ssl.CERT_NONE: - cafile = sslopt.get('ca_certs', None) - capath = sslopt.get('ca_cert_path', None) - if cafile or capath: - context.load_verify_locations(cafile=cafile, capath=capath) - elif hasattr(context, 'load_default_certs'): - context.load_default_certs(ssl.Purpose.SERVER_AUTH) - if sslopt.get('certfile', None): - context.load_cert_chain( - sslopt['certfile'], - sslopt.get('keyfile', None), - sslopt.get('password', None), - ) - # see - # https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153 - context.verify_mode = sslopt['cert_reqs'] - if HAVE_CONTEXT_CHECK_HOSTNAME: - context.check_hostname = check_hostname - if 'ciphers' in sslopt: - context.set_ciphers(sslopt['ciphers']) - if 'cert_chain' in sslopt: - certfile, keyfile, password = sslopt['cert_chain'] - context.load_cert_chain(certfile, keyfile, password) - if 'ecdh_curve' in sslopt: - context.set_ecdh_curve(sslopt['ecdh_curve']) + context = sslopt.get('context', None) + if not context: + context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_TLS_CLIENT)) + # Non default context need to manually enable SSLKEYLOGFILE support by setting the keylog_filename attribute. + # For more details see also: + # * https://docs.python.org/3.8/library/ssl.html?highlight=sslkeylogfile#context-creation + # * https://docs.python.org/3.8/library/ssl.html?highlight=sslkeylogfile#ssl.SSLContext.keylog_filename + context.keylog_filename = os.environ.get("SSLKEYLOGFILE", None) + + if sslopt.get('cert_reqs', ssl.CERT_NONE) != ssl.CERT_NONE: + cafile = sslopt.get('ca_certs', None) + capath = sslopt.get('ca_cert_path', None) + if cafile or capath: + context.load_verify_locations(cafile=cafile, capath=capath) + elif hasattr(context, 'load_default_certs'): + context.load_default_certs(ssl.Purpose.SERVER_AUTH) + if sslopt.get('certfile', None): + context.load_cert_chain( + sslopt['certfile'], + sslopt.get('keyfile', None), + sslopt.get('password', None), + ) + + # Python 3.10 switch to PROTOCOL_TLS_CLIENT defaults to "cert_reqs = ssl.CERT_REQUIRED" and "check_hostname = True" + # If both disabled, set check_hostname before verify_mode + # see https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153 + if sslopt.get('cert_reqs', ssl.CERT_NONE) == ssl.CERT_NONE and not sslopt.get('check_hostname', False): + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + else: + context.check_hostname = sslopt.get('check_hostname', True) + context.verify_mode = sslopt.get('cert_reqs', ssl.CERT_REQUIRED) + + if 'ciphers' in sslopt: + context.set_ciphers(sslopt['ciphers']) + if 'cert_chain' in sslopt: + certfile, keyfile, password = sslopt['cert_chain'] + context.load_cert_chain(certfile, keyfile, password) + if 'ecdh_curve' in sslopt: + context.set_ecdh_curve(sslopt['ecdh_curve']) return context.wrap_socket( sock, @@ -262,20 +271,16 @@ def _ssl_socket(sock, user_sslopt, hostname): if sslopt.get('server_hostname', None): hostname = sslopt['server_hostname'] - check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop( - 'check_hostname', True) + check_hostname = sslopt.get('check_hostname', True) sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname) - if not HAVE_CONTEXT_CHECK_HOSTNAME and check_hostname: - match_hostname(sock.getpeercert(), hostname) - return sock def _tunnel(sock, host, port, auth): debug("Connecting proxy...") - connect_header = "CONNECT %s:%d HTTP/1.1\r\n" % (host, port) - connect_header += "Host: %s:%d\r\n" % (host, port) + connect_header = "CONNECT {h}:{p} HTTP/1.1\r\n".format(h=host, p=port) + connect_header += "Host: {h}:{p}\r\n".format(h=host, p=port) # TODO: support digest auth. if auth and auth[0]: @@ -283,7 +288,7 @@ def _tunnel(sock, host, port, auth): if auth[1]: auth_str += ":" + auth[1] encoded_str = base64encode(auth_str.encode()).strip().decode().replace('\n', '') - connect_header += "Proxy-Authorization: Basic %s\r\n" % encoded_str + connect_header += "Proxy-Authorization: Basic {str}\r\n".format(str=encoded_str) connect_header += "\r\n" dump("request header", connect_header) @@ -296,7 +301,7 @@ def _tunnel(sock, host, port, auth): if status != 200: raise WebSocketProxyException( - "failed CONNECT via proxy status: %r" % status) + "failed CONNECT via proxy status: {status}".format(status=status)) return sock diff --git a/plugin/libs/websocket/_logging.py b/plugin/libs/websocket/_logging.py index 480d43b0..806de4d4 100644 --- a/plugin/libs/websocket/_logging.py +++ b/plugin/libs/websocket/_logging.py @@ -1,12 +1,10 @@ -""" - -""" +import logging """ _logging.py websocket - WebSocket client library for Python -Copyright 2021 engn33r +Copyright 2023 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,14 +18,13 @@ See the License for the specific language governing permissions and limitations under the License. """ -import logging _logger = logging.getLogger('websocket') try: from logging import NullHandler except ImportError: class NullHandler(logging.Handler): - def emit(self, record): + def emit(self, record) -> None: pass _logger.addHandler(NullHandler()) @@ -38,7 +35,9 @@ def emit(self, record): "isEnabledForError", "isEnabledForDebug", "isEnabledForTrace"] -def enableTrace(traceable, handler=logging.StreamHandler()): +def enableTrace(traceable: bool, + handler: logging.StreamHandler = logging.StreamHandler(), + level: str = "DEBUG") -> None: """ Turn on/off the traceability. @@ -51,40 +50,44 @@ def enableTrace(traceable, handler=logging.StreamHandler()): _traceEnabled = traceable if traceable: _logger.addHandler(handler) - _logger.setLevel(logging.DEBUG) + _logger.setLevel(getattr(logging, level)) -def dump(title, message): +def dump(title: str, message: str) -> None: if _traceEnabled: _logger.debug("--- " + title + " ---") _logger.debug(message) _logger.debug("-----------------------") -def error(msg): +def error(msg: str) -> None: _logger.error(msg) -def warning(msg): +def warning(msg: str) -> None: _logger.warning(msg) -def debug(msg): +def debug(msg: str) -> None: _logger.debug(msg) -def trace(msg): +def info(msg: str) -> None: + _logger.info(msg) + + +def trace(msg: str) -> None: if _traceEnabled: _logger.debug(msg) -def isEnabledForError(): +def isEnabledForError() -> bool: return _logger.isEnabledFor(logging.ERROR) -def isEnabledForDebug(): +def isEnabledForDebug() -> bool: return _logger.isEnabledFor(logging.DEBUG) -def isEnabledForTrace(): +def isEnabledForTrace() -> bool: return _traceEnabled diff --git a/plugin/libs/websocket/_socket.py b/plugin/libs/websocket/_socket.py index eb573d4e..1575a0c0 100644 --- a/plugin/libs/websocket/_socket.py +++ b/plugin/libs/websocket/_socket.py @@ -1,12 +1,18 @@ -""" +import errno +import selectors +import socket -""" +from typing import Union + +from ._exceptions import * +from ._ssl_compat import * +from ._utils import * """ _socket.py websocket - WebSocket client library for Python -Copyright 2021 engn33r +Copyright 2023 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,13 +26,6 @@ See the License for the specific language governing permissions and limitations under the License. """ -import errno -import selectors -import socket - -from ._exceptions import * -from ._ssl_compat import * -from ._utils import * DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)] if hasattr(socket, "SO_KEEPALIVE"): @@ -44,9 +43,9 @@ "recv", "recv_line", "send"] -class sock_opt(object): +class sock_opt: - def __init__(self, sockopt, sslopt): + def __init__(self, sockopt: list, sslopt: dict) -> None: if sockopt is None: sockopt = [] if sslopt is None: @@ -56,7 +55,7 @@ def __init__(self, sockopt, sslopt): self.timeout = None -def setdefaulttimeout(timeout): +def setdefaulttimeout(timeout: Union[int, float, None]) -> None: """ Set the global timeout setting to connect. @@ -69,7 +68,7 @@ def setdefaulttimeout(timeout): _default_timeout = timeout -def getdefaulttimeout(): +def getdefaulttimeout() -> Union[int, float, None]: """ Get default timeout @@ -81,7 +80,7 @@ def getdefaulttimeout(): return _default_timeout -def recv(sock, bufsize): +def recv(sock: socket.socket, bufsize: int) -> bytes: if not sock: raise WebSocketConnectionClosedException("socket is already closed.") @@ -92,9 +91,7 @@ def _recv(): pass except socket.error as exc: error_code = extract_error_code(exc) - if error_code is None: - raise - if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK: + if error_code != errno.EAGAIN and error_code != errno.EWOULDBLOCK: raise sel = selectors.DefaultSelector() @@ -111,6 +108,8 @@ def _recv(): bytes_ = sock.recv(bufsize) else: bytes_ = _recv() + except TimeoutError: + raise WebSocketTimeoutException("Connection timed out") except socket.timeout as e: message = extract_err_message(e) raise WebSocketTimeoutException(message) @@ -128,7 +127,7 @@ def _recv(): return bytes_ -def recv_line(sock): +def recv_line(sock: socket.socket) -> bytes: line = [] while True: c = recv(sock, 1) @@ -138,7 +137,7 @@ def recv_line(sock): return b''.join(line) -def send(sock, data): +def send(sock: socket.socket, data: Union[bytes, str]) -> int: if isinstance(data, str): data = data.encode('utf-8') @@ -154,7 +153,7 @@ def _send(): error_code = extract_error_code(exc) if error_code is None: raise - if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK: + if error_code != errno.EAGAIN and error_code != errno.EWOULDBLOCK: raise sel = selectors.DefaultSelector() diff --git a/plugin/libs/websocket/_ssl_compat.py b/plugin/libs/websocket/_ssl_compat.py index 9e5460c2..b2eba387 100644 --- a/plugin/libs/websocket/_ssl_compat.py +++ b/plugin/libs/websocket/_ssl_compat.py @@ -2,7 +2,7 @@ _ssl_compat.py websocket - WebSocket client library for Python -Copyright 2021 engn33r +Copyright 2023 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,11 +23,6 @@ from ssl import SSLError from ssl import SSLWantReadError from ssl import SSLWantWriteError - HAVE_CONTEXT_CHECK_HOSTNAME = False - if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'): - HAVE_CONTEXT_CHECK_HOSTNAME = True - - __all__.append("HAVE_CONTEXT_CHECK_HOSTNAME") HAVE_SSL = True except ImportError: # dummy class of SSLError for environment without ssl support diff --git a/plugin/libs/websocket/_url.py b/plugin/libs/websocket/_url.py index f2a55019..a3306154 100644 --- a/plugin/libs/websocket/_url.py +++ b/plugin/libs/websocket/_url.py @@ -1,11 +1,15 @@ -""" +import os +import socket +import struct + +from typing import Optional +from urllib.parse import unquote, urlparse -""" """ _url.py websocket - WebSocket client library for Python -Copyright 2021 engn33r +Copyright 2023 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,17 +24,10 @@ limitations under the License. """ -import os -import socket -import struct - -from urllib.parse import unquote, urlparse - - __all__ = ["parse_url", "get_proxy_info"] -def parse_url(url): +def parse_url(url: str) -> tuple: """ parse url and the result is tuple of (hostname, port, resource path and the flag of secure mode) @@ -79,7 +76,7 @@ def parse_url(url): DEFAULT_NO_PROXY_HOST = ["localhost", "127.0.0.1"] -def _is_ip_address(addr): +def _is_ip_address(addr: str) -> bool: try: socket.inet_aton(addr) except socket.error: @@ -88,7 +85,7 @@ def _is_ip_address(addr): return True -def _is_subnet_address(hostname): +def _is_subnet_address(hostname: str) -> bool: try: addr, netmask = hostname.split("/") return _is_ip_address(addr) and 0 <= int(netmask) < 32 @@ -96,7 +93,7 @@ def _is_subnet_address(hostname): return False -def _is_address_in_network(ip, net): +def _is_address_in_network(ip: str, net: str) -> bool: ipaddr = struct.unpack('!I', socket.inet_aton(ip))[0] netaddr, netmask = net.split('/') netaddr = struct.unpack('!I', socket.inet_aton(netaddr))[0] @@ -105,7 +102,7 @@ def _is_address_in_network(ip, net): return ipaddr & netmask == netaddr -def _is_no_proxy_host(hostname, no_proxy): +def _is_no_proxy_host(hostname: str, no_proxy: Optional[list]) -> bool: if not no_proxy: v = os.environ.get("no_proxy", os.environ.get("NO_PROXY", "")).replace(" ", "") if v: @@ -126,8 +123,8 @@ def _is_no_proxy_host(hostname, no_proxy): def get_proxy_info( - hostname, is_secure, proxy_host=None, proxy_port=0, proxy_auth=None, - no_proxy=None, proxy_type='http'): + hostname: str, is_secure: bool, proxy_host: Optional[str] = None, proxy_port: int = 0, proxy_auth: Optional[tuple] = None, + no_proxy: Optional[list] = None, proxy_type: str = 'http') -> tuple: """ Try to retrieve proxy host and port from environment if not provided in options. @@ -141,14 +138,14 @@ def get_proxy_info( Websocket server name. is_secure: bool Is the connection secure? (wss) looks for "https_proxy" in env - before falling back to "http_proxy" + instead of "http_proxy" proxy_host: str http proxy host name. - http_proxy_port: str or int + proxy_port: str or int http proxy port. - http_no_proxy: list + no_proxy: list Whitelisted host names that don't use the proxy. - http_proxy_auth: tuple + proxy_auth: tuple HTTP proxy auth information. Tuple of username and password. Default is None. proxy_type: str Specify the proxy protocol (http, socks4, socks4a, socks5, socks5h). Default is "http". @@ -162,15 +159,11 @@ def get_proxy_info( auth = proxy_auth return proxy_host, port, auth - env_keys = ["http_proxy"] - if is_secure: - env_keys.insert(0, "https_proxy") - - for key in env_keys: - value = os.environ.get(key, os.environ.get(key.upper(), "")).replace(" ", "") - if value: - proxy = urlparse(value) - auth = (unquote(proxy.username), unquote(proxy.password)) if proxy.username else None - return proxy.hostname, proxy.port, auth + env_key = "https_proxy" if is_secure else "http_proxy" + value = os.environ.get(env_key, os.environ.get(env_key.upper(), "")).replace(" ", "") + if value: + proxy = urlparse(value) + auth = (unquote(proxy.username), unquote(proxy.password)) if proxy.username else None + return proxy.hostname, proxy.port, auth return None, 0, None diff --git a/plugin/libs/websocket/_utils.py b/plugin/libs/websocket/_utils.py index feed027e..62ba0b01 100644 --- a/plugin/libs/websocket/_utils.py +++ b/plugin/libs/websocket/_utils.py @@ -1,8 +1,10 @@ +from typing import Union + """ _url.py websocket - WebSocket client library for Python -Copyright 2021 engn33r +Copyright 2023 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,12 +21,12 @@ __all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"] -class NoLock(object): +class NoLock: - def __enter__(self): + def __enter__(self) -> None: pass - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> None: pass @@ -33,7 +35,7 @@ def __exit__(self, exc_type, exc_value, traceback): # strings. from wsaccel.utf8validator import Utf8Validator - def _validate_utf8(utfbytes): + def _validate_utf8(utfbytes: bytes) -> bool: return Utf8Validator().validate(utfbytes)[0] except ImportError: @@ -63,7 +65,7 @@ def _validate_utf8(utfbytes): 12,12,12,12,12,12,12,36,12,36,12,12, 12,36,12,12,12,12,12,36,12,36,12,12, 12,36,12,12,12,12,12,12,12,12,12,12, ] - def _decode(state, codep, ch): + def _decode(state: int, codep: int, ch: int) -> tuple: tp = _UTF8D[ch] codep = (ch & 0x3f) | (codep << 6) if ( @@ -72,7 +74,7 @@ def _decode(state, codep, ch): return state, codep - def _validate_utf8(utfbytes): + def _validate_utf8(utfbytes: Union[str, bytes]) -> bool: state = _UTF8_ACCEPT codep = 0 for i in utfbytes: @@ -83,7 +85,7 @@ def _validate_utf8(utfbytes): return True -def validate_utf8(utfbytes): +def validate_utf8(utfbytes: Union[str, bytes]) -> bool: """ validate utf8 byte string. utfbytes: utf byte string to check. @@ -92,13 +94,13 @@ def validate_utf8(utfbytes): return _validate_utf8(utfbytes) -def extract_err_message(exception): +def extract_err_message(exception: Exception) -> Union[str, None]: if exception.args: return exception.args[0] else: return None -def extract_error_code(exception): +def extract_error_code(exception: Exception) -> Union[int, None]: if exception.args and len(exception.args) > 1: return exception.args[0] if isinstance(exception.args[0], int) else None diff --git a/plugin/libs/websocket/_wsdump.py b/plugin/libs/websocket/_wsdump.py new file mode 100644 index 00000000..d637ce2b --- /dev/null +++ b/plugin/libs/websocket/_wsdump.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python3 + +""" +wsdump.py +websocket - WebSocket client library for Python + +Copyright 2023 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import argparse +import code +import sys +import threading +import time +import ssl +import gzip +import zlib +from urllib.parse import urlparse + +import websocket + +try: + import readline +except ImportError: + pass + + +def get_encoding() -> str: + encoding = getattr(sys.stdin, "encoding", "") + if not encoding: + return "utf-8" + else: + return encoding.lower() + + +OPCODE_DATA = (websocket.ABNF.OPCODE_TEXT, websocket.ABNF.OPCODE_BINARY) +ENCODING = get_encoding() + + +class VAction(argparse.Action): + + def __call__(self, parser: argparse.Namespace, args: tuple, values: str, option_string: str = None) -> None: + if values is None: + values = "1" + try: + values = int(values) + except ValueError: + values = values.count("v") + 1 + setattr(args, self.dest, values) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="WebSocket Simple Dump Tool") + parser.add_argument("url", metavar="ws_url", + help="websocket url. ex. ws://echo.websocket.events/") + parser.add_argument("-p", "--proxy", + help="proxy url. ex. http://127.0.0.1:8080") + parser.add_argument("-v", "--verbose", default=0, nargs='?', action=VAction, + dest="verbose", + help="set verbose mode. If set to 1, show opcode. " + "If set to 2, enable to trace websocket module") + parser.add_argument("-n", "--nocert", action='store_true', + help="Ignore invalid SSL cert") + parser.add_argument("-r", "--raw", action="store_true", + help="raw output") + parser.add_argument("-s", "--subprotocols", nargs='*', + help="Set subprotocols") + parser.add_argument("-o", "--origin", + help="Set origin") + parser.add_argument("--eof-wait", default=0, type=int, + help="wait time(second) after 'EOF' received.") + parser.add_argument("-t", "--text", + help="Send initial text") + parser.add_argument("--timings", action="store_true", + help="Print timings in seconds") + parser.add_argument("--headers", + help="Set custom headers. Use ',' as separator") + + return parser.parse_args() + + +class RawInput: + + def raw_input(self, prompt: str = "") -> str: + line = input(prompt) + + if ENCODING and ENCODING != "utf-8" and not isinstance(line, str): + line = line.decode(ENCODING).encode("utf-8") + elif isinstance(line, str): + line = line.encode("utf-8") + + return line + + +class InteractiveConsole(RawInput, code.InteractiveConsole): + + def write(self, data: str) -> None: + sys.stdout.write("\033[2K\033[E") + # sys.stdout.write("\n") + sys.stdout.write("\033[34m< " + data + "\033[39m") + sys.stdout.write("\n> ") + sys.stdout.flush() + + def read(self) -> str: + return self.raw_input("> ") + + +class NonInteractive(RawInput): + + def write(self, data: str) -> None: + sys.stdout.write(data) + sys.stdout.write("\n") + sys.stdout.flush() + + def read(self) -> str: + return self.raw_input("") + + +def main() -> None: + start_time = time.time() + args = parse_args() + if args.verbose > 1: + websocket.enableTrace(True) + options = {} + if args.proxy: + p = urlparse(args.proxy) + options["http_proxy_host"] = p.hostname + options["http_proxy_port"] = p.port + if args.origin: + options["origin"] = args.origin + if args.subprotocols: + options["subprotocols"] = args.subprotocols + opts = {} + if args.nocert: + opts = {"cert_reqs": ssl.CERT_NONE, "check_hostname": False} + if args.headers: + options['header'] = list(map(str.strip, args.headers.split(','))) + ws = websocket.create_connection(args.url, sslopt=opts, **options) + if args.raw: + console = NonInteractive() + else: + console = InteractiveConsole() + print("Press Ctrl+C to quit") + + def recv() -> tuple: + try: + frame = ws.recv_frame() + except websocket.WebSocketException: + return websocket.ABNF.OPCODE_CLOSE, "" + if not frame: + raise websocket.WebSocketException("Not a valid frame {frame}".format(frame=frame)) + elif frame.opcode in OPCODE_DATA: + return frame.opcode, frame.data + elif frame.opcode == websocket.ABNF.OPCODE_CLOSE: + ws.send_close() + return frame.opcode, "" + elif frame.opcode == websocket.ABNF.OPCODE_PING: + ws.pong(frame.data) + return frame.opcode, frame.data + + return frame.opcode, frame.data + + def recv_ws() -> None: + while True: + opcode, data = recv() + msg = None + if opcode == websocket.ABNF.OPCODE_TEXT and isinstance(data, bytes): + data = str(data, "utf-8") + if isinstance(data, bytes) and len(data) > 2 and data[:2] == b'\037\213': # gzip magick + try: + data = "[gzip] " + str(gzip.decompress(data), "utf-8") + except: + pass + elif isinstance(data, bytes): + try: + data = "[zlib] " + str(zlib.decompress(data, -zlib.MAX_WBITS), "utf-8") + except: + pass + + if isinstance(data, bytes): + data = repr(data) + + if args.verbose: + msg = "{opcode}: {data}".format(opcode=websocket.ABNF.OPCODE_MAP.get(opcode), data=data) + else: + msg = data + + if msg is not None: + if args.timings: + console.write(str(time.time() - start_time) + ": " + msg) + else: + console.write(msg) + + if opcode == websocket.ABNF.OPCODE_CLOSE: + break + + thread = threading.Thread(target=recv_ws) + thread.daemon = True + thread.start() + + if args.text: + ws.send(args.text) + + while True: + try: + message = console.read() + ws.send(message) + except KeyboardInterrupt: + return + except EOFError: + time.sleep(args.eof_wait) + return + + +if __name__ == "__main__": + try: + main() + except Exception as e: + print(e)