diff --git a/scapy/contrib/websocket.py b/scapy/contrib/websocket.py new file mode 100644 index 00000000000..e99df3f0797 --- /dev/null +++ b/scapy/contrib/websocket.py @@ -0,0 +1,264 @@ +# SPDX-License-Identifier: GPL-2.0-or-later +# This file is part of Scapy +# See https://scapy.net/ for more information +# Copyright (C) 2024 Lucas Drufva + +# scapy.contrib.description = WebSocket +# scapy.contrib.status = loads + +# Based on rfc6455 + +import struct +import base64 +import zlib +from hashlib import sha1 +from scapy.fields import (BitFieldLenField, Field, BitField, BitEnumField, ConditionalField, XNBytesField) +from scapy.layers.http import HTTPRequest, HTTPResponse, HTTP +from scapy.layers.inet import TCP +from scapy.packet import Packet +from scapy.error import Scapy_Exception +import logging + + +class PayloadLenField(BitFieldLenField): + + def __init__(self, name, default, length_of, size=0, tot_size=0, end_tot_size=0): + # Initialize with length_of (like in BitFieldLenField) and lengthFrom (like in BitLenField) + super().__init__(name, default, size, length_of=length_of, tot_size=tot_size, end_tot_size=end_tot_size) + + def getfield(self, pkt, s): + s, _ = s + # Get the 7-bit field (first byte) + length_byte = s[0] & 0x7F + s = s[1:] + + if length_byte <= 125: + # 7-bit length + return s, length_byte + elif length_byte == 126: + # 16-bit length + length = struct.unpack("!H", s[:2])[0] # Read 2 bytes + s = s[2:] + return s, length + elif length_byte == 127: + # 64-bit length + length = struct.unpack("!Q", s[:8])[0] # Read 8 bytes + s = s[8:] + return s, length + + def addfield(self, pkt, s, val): + p_field, p_val = pkt.getfield_and_val(self.length_of) + val = p_field.i2len(pkt, p_val) + + if val <= 125: + self.size = 7 + return super().addfield(pkt, s, val) + elif val <= 0xFFFF: + self.size = 7+16 + s, _, masked = s + return s + struct.pack("!BH", 126 | masked, val) + elif val <= 0xFFFFFFFFFFFFFFFF: + self.size = 7+64 + s, _, masked = s + return s + struct.pack("!BQ", 127 | masked, val) + else: + raise Scapy_Exception("%s: Payload length too large" % + self.__class__.__name__) + + + +class PayloadField(Field): + """ + Field for handling raw byte payloads with dynamic size. + The length of the payload is described by a preceding PayloadLenField. + """ + __slots__ = ["lengthFrom"] + + def __init__(self, name, lengthFrom): + """ + :param name: Field name + :param lengthFrom: Field name that provides the length of the payload + """ + super(PayloadField, self).__init__(name, None) + self.lengthFrom = lengthFrom + + def getfield(self, pkt, s): + # Fetch the length from the field that specifies the length + length = getattr(pkt, self.lengthFrom) + payloadData = s[:length] + + if pkt.mask: + key = struct.pack("I", pkt.maskingKey)[::-1] + data_int = int.from_bytes(payloadData, 'big') + mask_repeated = key * (len(payloadData) // 4) + key[: len(payloadData) % 4] + mask_int = int.from_bytes(mask_repeated, 'big') + payloadData = (data_int ^ mask_int).to_bytes(len(payloadData), 'big') + + if("permessage-deflate" in pkt.extensions): + try: + payloadData = pkt.decoder[0](payloadData + b"\x00\x00\xff\xff") + except Exception: + logging.debug("Failed to decompress payload", payloadData) + + return s[length:], payloadData + + def addfield(self, pkt, s, val): + if pkt.mask: + key = struct.pack("I", pkt.maskingKey)[::-1] + data_int = int.from_bytes(val, 'big') + mask_repeated = key * (len(val) // 4) + key[: len(val) % 4] + mask_int = int.from_bytes(mask_repeated, 'big') + val = (data_int ^ mask_int).to_bytes(len(val), 'big') + + return s + bytes(val) + + def i2len(self, pkt, val): + # Length of the payload in bytes + return len(val) + +class WebSocket(Packet): + __slots__ = ["extensions", "decoder"] + + name = "WebSocket" + fields_desc = [ + BitField("fin", 0, 1), + BitField("rsv", 0, 3), + BitEnumField("opcode", 0, 4, + { + 0x0: "none", + 0x1: "text", + 0x2: "binary", + 0x8: "close", + 0x9: "ping", + 0xA: "pong", + }), + BitField("mask", 0, 1), + PayloadLenField("payloadLen", 0, length_of="wsPayload", size=1), + ConditionalField(XNBytesField("maskingKey", 0, sz=4), lambda pkt: pkt.mask == 1), + PayloadField("wsPayload", lengthFrom="payloadLen") + ] + + def __init__(self, pkt=None, extensions=[], decoder=None, *args, **fields): + self.extensions = extensions + self.decoder = decoder + super().__init__(_pkt=pkt, *args, **fields) + + def extract_padding(self, s): + return '', s + + @classmethod + def tcp_reassemble(cls, data, metadata, session): + # data = the reassembled data from the same request/flow + # metadata = empty dictionary, that can be used to store data + # during TCP reassembly + # session = a dictionary proper to the bidirectional TCP session, + # that can be used to store anything + # [...] + # If the packet is available, return it. Otherwise don't. + # Whenever you return a packet, the buffer will be discarded. + + + HANDSHAKE_STATE_CLIENT_OPEN = 0 + HANDSHAKE_STATE_SERVER_OPEN = 1 + HANDSHAKE_STATE_OPEN = 2 + + if "handshake-state" not in session: + session["handshake-state"] = HANDSHAKE_STATE_CLIENT_OPEN + + if "extensions" not in session: + session["extensions"] = {} + + + if session["handshake-state"] == HANDSHAKE_STATE_CLIENT_OPEN: + http_data = HTTP(data) + if HTTPRequest in http_data: + http_data = http_data[HTTPRequest] + else: + return http_data + + if http_data.Method != b"GET": + return HTTP()/http_data + + if not http_data.Upgrade or http_data.Upgrade.lower() != b"websocket": + return HTTP()/http_data + + if not http_data.Unknown_Headers or b"Sec-WebSocket-Key" not in http_data.Unknown_Headers: + return HTTP()/http_data + + session["handshake-key"] = http_data.Unknown_Headers[b"Sec-WebSocket-Key"] + + if "original" in metadata: + session["server-port"] = metadata["original"][TCP].dport + + session["handshake-state"] = HANDSHAKE_STATE_SERVER_OPEN + + return http_data + + elif session["handshake-state"] == HANDSHAKE_STATE_SERVER_OPEN: + http_data = HTTP(data) + if HTTPResponse in http_data: + http_data = http_data[HTTPResponse] + else: + return http_data + + if not http_data.Upgrade.lower() == b"websocket": + return HTTP()/http_data + + # Verify key-accept handshake: + correct_accept = base64.b64encode(sha1(session["handshake-key"] + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".encode()).digest()) + if not http_data.Unknown_Headers or b"Sec-WebSocket-Accept" not in http_data.Unknown_Headers or http_data.Unknown_Headers[b"Sec-WebSocket-Accept"] != correct_accept: + # TODO: handle or Logg wrong accept key + pass + + if b"Sec-WebSocket-Extensions" in http_data.Unknown_Headers: + session["extensions"] = {} + for extension in http_data.Unknown_Headers[b"Sec-WebSocket-Extensions"].decode().strip().split(";"): + key_value_pair = extension.split("=", 1) + [None] + session["extensions"][key_value_pair[0].strip()] = key_value_pair[1] + + if "permessage-deflate" in session["extensions"]: + def create_decompressor(window_bits): + decoder = zlib.decompressobj(wbits=-window_bits) + def decomp(data): + nonlocal decoder + return decoder.decompress(data, 0) + + def reset(): + nonlocal decoder + nonlocal window_bits + decoder = zlib.decompressobj(wbits=-window_bits) + + return (decomp, reset) + + # Default values + client_wb = 12 + server_wb = 15 + + # Check for new values in extensions header + if "client_max_window_bits" in session["extensions"]: + client_wb = int(session["extensions"]["client_max_window_bits"]) + + if "server_max_window_bits" in session["extensions"]: + server_wb = int(session["extensions"]["server_max_window_bits"]) + + + session["server-decoder"] = create_decompressor(client_wb) + session["client-decoder"] = create_decompressor(server_wb) + + + session["handshake-state"] = HANDSHAKE_STATE_OPEN + + return HTTP()/http_data + + + # Handshake is done: + if "original" not in metadata: + return + + if "permessage-deflate" in session["extensions"]: + is_server = True if metadata["original"][TCP].sport == session["server-port"] else False + ws = WebSocket(bytes(data), extensions=session["extensions"], decoder = session["server-decoder"] if is_server else session["client-decoder"]) + return ws + else: + ws = WebSocket(bytes(data), extensions=session["extensions"]) + return ws \ No newline at end of file diff --git a/scapy/layers/http.py b/scapy/layers/http.py index 394d220e4fd..288d98bb954 100644 --- a/scapy/layers/http.py +++ b/scapy/layers/http.py @@ -529,10 +529,10 @@ def do_dissect(self, s): """From the HTTP packet string, populate the scapy object""" first_line, body = _dissect_headers(self, s) try: - Method, Path, HTTPVersion = re.split(br"\s+", first_line, maxsplit=2) - self.setfieldval('Method', Method) - self.setfieldval('Path', Path) - self.setfieldval('Http_Version', HTTPVersion) + method_path_version = re.split(br"\s+", first_line, maxsplit=2) + [None] + self.setfieldval('Method', method_path_version[0]) + self.setfieldval('Path', method_path_version[1]) + self.setfieldval('Http_Version', method_path_version[2]) except ValueError: pass if body: @@ -573,10 +573,10 @@ def do_dissect(self, s): ''' From the HTTP packet string, populate the scapy object ''' first_line, body = _dissect_headers(self, s) try: - HTTPVersion, Status, Reason = re.split(br"\s+", first_line, maxsplit=2) - self.setfieldval('Http_Version', HTTPVersion) - self.setfieldval('Status_Code', Status) - self.setfieldval('Reason_Phrase', Reason) + version_status_reason = re.split(br"\s+", first_line, maxsplit=2) + [None] + self.setfieldval('Http_Version', version_status_reason[0]) + self.setfieldval('Status_Code', version_status_reason[1]) + self.setfieldval('Reason_Phrase', version_status_reason[2]) except ValueError: pass if body: diff --git a/test/contrib/websocket.uts b/test/contrib/websocket.uts new file mode 100644 index 00000000000..c1141001b1a --- /dev/null +++ b/test/contrib/websocket.uts @@ -0,0 +1,96 @@ +# WebSocket layer unit tests +# Copyright (C) 2024 Lucas Drufva +# +# Type the following command to launch start the tests: +# $ test/run_tests -P "load_contrib('websocket')" -t test/contrib/websocket.uts + ++ Syntax check += Import the WebSocket layer +from scapy.contrib.websocket import * + ++ WebSocket protocol test += Packet instantiation +pkt = WebSocket(wsPayload=b"Hello, world!", opcode="text", mask=True, maskingKey=0x11223344) + +assert pkt.wsPayload == b"Hello, world!" +assert pkt.mask == True +assert pkt.maskingKey == 0x11223344 +assert bytes(pkt) == b'\x01\x8d\x11\x22\x33\x44\x59\x47\x5f\x28\x7e\x0e\x13\x33\x7e\x50\x5f\x20\x30' + += Packet dissection +raw = b'\x01\x0dHello, world!' +pkt = WebSocket(raw) + +assert pkt.fin == 0 +assert pkt.rsv == 0 +assert pkt.opcode == 0x1 +assert pkt.mask == False +assert pkt.payloadLen == 13 +assert pkt.wsPayload == b'Hello, world!' + += Dissect masked packet +raw = b'\x01\x8d\x11\x22\x33\x44\x59\x47\x5f\x28\x7e\x0e\x13\x33\x7e\x50\x5f\x20\x30' +pkt = WebSocket(raw) + +assert pkt.fin == 0 +assert pkt.rsv == 0 +assert pkt.opcode == 0x1 +assert pkt.mask == True +assert pkt.payloadLen == 13 +assert pkt.wsPayload == b'Hello, world!' + += Session with compression + +bind_layers(TCP, WebSocket, dport=5000) +bind_layers(TCP, WebSocket, sport=5000) + +from scapy.sessions import TCPSession + +filename = scapy_path("/test/pcaps/websocket_compressed_session.pcap") +pkts = sniff(offline=filename, session=TCPSession) + +assert len(pkts) == 13 + +assert pkts[7][WebSocket].wsPayload == b'Hello' +assert pkts[8][WebSocket].wsPayload == b'"Hello"' +assert pkts[10][WebSocket].wsPayload == b'Hello2' +assert pkts[11][WebSocket].wsPayload == b'"Hello2"' + += Create packet with long payload +pkt = WebSocket(wsPayload=b"a"*126, opcode="text") + +assert bytes(pkt) == b'\x01\x7e\x00\x7e\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61' + += Dissect packet with long payload +raw = b'\x01\x7e\x00\x7e\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61' +pkt = WebSocket(raw) + +assert pkt.payloadLen == 126 +assert pkt.wsPayload == b'a'*126 + += Create packet with very long payload +pkt = WebSocket(wsPayload=b"a"*65536, opcode="text") + +assert bytes(pkt) == b'\x01\x7f\x00\x00\x00\x00\x00\x01\x00\x00' + b'a'*65536 + += Dissect packet with very long payload +raw = b'\x01\x7f\x00\x00\x00\x00\x00\x01\x00\x00' + b'a'*65536 +pkt = WebSocket(raw) + +assert pkt.payloadLen == 65536 +assert pkt.wsPayload == b'a'*65536 + += Session with invalid parts in upgrade sequence + +bind_layers(TCP, WebSocket, dport=5000) +bind_layers(TCP, WebSocket, sport=5000) + +from scapy.sessions import TCPSession + +filename = scapy_path("/test/pcaps/websocket_invalid_handshakes.pcap") +pkts = sniff(offline=filename, session=TCPSession) + + +assert len(pkts) == 19 + +assert pkts[15][WebSocket].wsPayload == b'Hello, world!' \ No newline at end of file diff --git a/test/pcaps/websocket_compressed_session.pcap b/test/pcaps/websocket_compressed_session.pcap new file mode 100644 index 00000000000..76757e44273 Binary files /dev/null and b/test/pcaps/websocket_compressed_session.pcap differ diff --git a/test/pcaps/websocket_invalid_handshakes.pcap b/test/pcaps/websocket_invalid_handshakes.pcap new file mode 100644 index 00000000000..fe3040b32d6 Binary files /dev/null and b/test/pcaps/websocket_invalid_handshakes.pcap differ