Skip to content

Commit

Permalink
feat: add sctp
Browse files Browse the repository at this point in the history
  • Loading branch information
lchenut committed Oct 18, 2024
1 parent 7204cf7 commit 4659924
Show file tree
Hide file tree
Showing 4 changed files with 669 additions and 0 deletions.
329 changes: 329 additions & 0 deletions webrtc/sctp/sctp_connection.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,329 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

import nativesockets, bitops, sequtils
import usrsctp, chronos, chronicles, stew/[ptrops, endians2, byteutils]
import ./sctp_utils, ../errors, ../dtls/dtls_connection

logScope:
topics = "webrtc sctp_connection"

const
SctpConnTracker* = "webrtc.sctp.conn"
IPPROTO_SCTP = 132 # Official IANA number

type
SctpConnOnClose* = proc() {.raises: [], gcsafe.}

SctpState* = enum
SctpConnecting
SctpConnected
SctpClosed

SctpMessageParameters* = object
# This object is used to help manage messages exchanged over SCTP
# within the DataChannel stack.
protocolId*: uint32
# protocolId is used to distinguish different protocols within
# SCTP stream. In WebRTC, this is used to define the type of application
# data being transferred (text data, binary data...).
streamId*: uint16
# streamId identifies the specific SCTP stream. In WebRTC, each
# DataChannel corresponds to a different stream, so the streamId is
# used to map the message to the appropriate DataChannel.
endOfRecord*: bool
# endOfRecord indicates whether the current SCTP message is the
# final part of a record or not. This is related to the
# fragmentation and reassembly of messages.
unordered*: bool
# The unordered flag determines whether the message should be
# delivered in order or not. SCTP allows for both ordered and
# unordered delivery of messages.

SctpMessage* = ref object
data*: seq[byte]
info*: sctp_recvv_rn
params*: SctpMessageParameters

SctpConn* = ref object
conn: DtlsConn # Underlying DTLS Connection
sctpSocket*: ptr socket # Current usrsctp socket

state*: SctpState # Current Sctp State
onClose: seq[SctpConnOnClose] # List of procedure to run while closing a connection

connectEvent*: AsyncEvent # Event fired when the connection is connected
acceptEvent*: AsyncEvent # Event fired when the connection is accepted

# Infinite loop reading on the underlying DTLS Connection.
readLoop: Future[void].Raising([CancelledError, WebRtcError])

dataRecv: AsyncQueue[SctpMessage] # Queue of messages to be read
sendQueue: seq[byte] # Queue of messages to be sent

proc remoteAddress*(self: SctpConn): TransportAddress =
if self.conn.isNil():
raise newException(WebRtcError, "SCTP - Connection not set")
return self.conn.remoteAddress()

template usrsctpAwait(self: SctpConn, body: untyped): untyped =
# usrsctpAwait is template which set `sendQueue` to @[] then calls
# an usrsctp function. If during the synchronous run of the usrsctp function
# `sendQueue` is set, it is sent at the end of the function.
proc trySend(conn: SctpConn) {.async: (raises: [CancelledError]).} =
try:
trace "Send To", address = conn.remoteAddress()
await conn.conn.write(self.sendQueue)
except CatchableError as exc:
trace "Send Failed", exceptionMsg = exc.msg

self.sendQueue = @[]
when type(body) is void:
(body)
if self.sendQueue.len() > 0:
await self.trySend()
else:
let res = (body)
if self.sendQueue.len() > 0:
await self.trySend()
res

# -- usrsctp send and receive callback --

proc recvCallback*(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
# Callback procedure called when we receive data after a connection
# has been established.
let
conn = cast[SctpConn](data)
events = usrsctp_get_events(sock)

trace "Receive callback", events
if bitand(events, SCTP_EVENT_READ) != 0:
var
message = SctpMessage(data: newSeq[byte](4096))
address: Sockaddr_storage
rn: sctp_recvv_rn
addressLen = sizeof(Sockaddr_storage).SockLen
rnLen = sizeof(sctp_recvv_rn).SockLen
infotype: uint
flags: int
let n = sock.usrsctp_recvv(
cast[pointer](addr message.data[0]),
message.data.len.uint,
cast[ptr SockAddr](addr address),
cast[ptr SockLen](addr addressLen),
cast[pointer](addr message.info),
cast[ptr SockLen](addr rnLen),
cast[ptr cuint](addr infotype),
cast[ptr cint](addr flags),
)
if n < 0:
warn "usrsctp_recvv", error = sctpStrerror()
return
elif n > 0:
message.data.delete(n ..< message.data.len())
trace "message info from handle upcall", msginfo = message.info
message.params = SctpMessageParameters(
protocolId: message.info.recvv_rcvinfo.rcv_ppid.swapBytes(),
streamId: message.info.recvv_rcvinfo.rcv_sid,
)
if bitand(flags, MSG_NOTIFICATION) != 0:
let notif = cast[ptr sctp_notification](data)
trace "Notification received", notifType = notif.sn_header.sn_type
else:
try:
conn.dataRecv.addLastNoWait(message)
except AsyncQueueFullError:
trace "Queue full, dropping packet"
elif bitand(events, SCTP_EVENT_WRITE) != 0:
trace "sctp event write in the upcall"
else:
warn "Handle Upcall unexpected event", events

proc sendCallback*(
ctx: pointer, buffer: pointer, length: uint, tos: uint8, set_df: uint8
): cint {.cdecl.} =
# This proc is called by usrsctp everytime usrsctp tries to send data.
let
conn = cast[SctpConn](ctx)
buf = @(buffer.makeOpenArray(byte, int(length)))
trace "sendCallback", sctpPacket = $(buf.getSctpPacket())
proc testSend() {.async: (raises: [CancelledError]).} =
try:
trace "Send To", address = conn.remoteAddress()
await conn.conn.write(buf)
except CatchableError as exc:
trace "Send Failed", message = exc.msg

conn.sendQueue = buf

proc addOnClose*(self: SctpConn, onCloseProc: SctpConnOnClose) =
## Adds a proc to be called when SctpConn is closed
##
self.onClose.add(onCloseProc)

proc readLoopProc(self: SctpConn) {.async: (raises: [CancelledError, WebRtcError]).} =
while true:
var msg = await self.conn.read()
if msg == @[]:
trace "Sctp read loop stopped, DTLS connection closed"
return
trace "Receive data",
remoteAddress = self.conn.remoteAddress(), sctPacket = $(msg.getSctpPacket())
self.usrsctpAwait:
usrsctp_conninput(cast[pointer](self), addr msg[0], uint(msg.len), 0)

proc new*(T: typedesc[SctpConn], conn: DtlsConn): T =
result = T(
conn: conn,
state: SctpConnecting,
connectEvent: AsyncEvent(),
acceptEvent: AsyncEvent(),
dataRecv: newAsyncQueue[SctpMessage](),
)
result.readLoop = result.readLoopProc()
usrsctp_register_address(cast[pointer](result))

proc connect*(self: SctpConn, sctpPort: uint16) {.async: (raises: [CancelledError, WebRtcError]).} =
var sconn: Sockaddr_conn
when compiles(sconn.sconn_len):
sconn.sconn_len = sizeof(sconn).uint8
sconn.sconn_family = AF_CONN
sconn.sconn_port = htons(sctpPort)
sconn.sconn_addr = cast[pointer](self)
let connErr = self.usrsctpAwait: self.sctpSocket.usrsctp_connect(
cast[ptr SockAddr](addr sconn), SockLen(sizeof(sconn))
)
if connErr != 0 and errno != SctpEINPROGRESS:
raise
newException(WebRtcError, "SCTP - Connection failed: " & sctpStrerror())

proc read*(self: SctpConn): Future[SctpMessage] {.async: (raises: [CancelledError, WebRtcError]).} =
# Used by DataChannel, returns SctpMessage in order to get the stream
# and protocol ids
if self.state == SctpClosed:
raise newException(WebRtcError, "Try to read on an already closed SctpConn")
return await self.dataRecv.popFirst()

proc toFlags(params: SctpMessageParameters): uint16 =
if params.endOfRecord:
result = result or SCTP_EOR
if params.unordered:
result = result or SCTP_UNORDERED

proc write*(
self: SctpConn, buf: seq[byte], sendParams = default(SctpMessageParameters)
) {.async: (raises: [CancelledError, WebRtcError]).} =
# Used by DataChannel, writes buf on the Dtls connection.
if self.state == SctpClosed:
raise newException(WebRtcError, "Try to write on an already closed SctpConn")
var cpy = buf
let sendvErr =
if sendParams == default(SctpMessageParameters):
# If writes is called by DataChannel, sendParams should never
# be the default value. This split is useful for testing.
self.usrsctpAwait:
self.sctpSocket.usrsctp_sendv(
cast[pointer](addr cpy[0]),
cpy.len().uint,
nil,
0,
nil,
0,
SCTP_SENDV_NOINFO.cuint,
0,
)
else:
var sendInfo = sctp_sndinfo(
snd_sid: sendParams.streamId,
snd_ppid: sendParams.protocolId.swapBytes(),
snd_flags: sendParams.toFlags(),
)
self.usrsctpAwait:
self.sctpSocket.usrsctp_sendv(
cast[pointer](addr cpy[0]),
cpy.len().uint,
nil,
0,
cast[pointer](addr sendInfo),
sizeof(sendInfo).SockLen,
SCTP_SENDV_SNDINFO.cuint,
0,
)
if sendvErr < 0:
raise newException(WebRtcError, "SCTP - " & sctpStrerror())

proc write*(
self: SctpConn, s: string
) {.async: (raises: [CancelledError, WebRtcError]).} =
await self.write(s.toBytes())

type
# This object is a workaround, srs_stream_list in usrsctp is an
# UncheckedArray, and they're not assignable.
sctp_reset_streams_workaround = object
srs_assoc_id: sctp_assoc_t
srs_flags: uint16
srs_number_streams: uint16
srs_stream_list: array[1, uint16]

proc closeChannel*(self: SctpConn, streamId: uint16) =
## Resets a specific outgoing SCTP stream identified by
## streamId to close the associated DataChannel.
var srs: sctp_reset_streams_workaround
let len = sizeof(srs)

srs.srs_flags = SCTP_STREAM_RESET_OUTGOING
srs.srs_number_streams = 1
srs.srs_stream_list[0] = streamId
let ret = usrsctp_setsockopt(
self.sctpSocket,
IPPROTO_SCTP,
SCTP_RESET_STREAMS,
addr srs,
len.SockLen
)
if ret < 0:
raise newException(WebRtcError, "SCTP - Close channel failed: " & sctpStrerror())

proc closeAllChannels*(self: SctpConn) =
## Resets all outgoing SCTP streams, effectively closing all
## open DataChannels for the current SCTP connection.
var srs: sctp_reset_streams_workaround
let len = sizeof(srs) - sizeof(srs.srs_stream_list)

srs.srs_flags = SCTP_STREAM_RESET_OUTGOING
srs.srs_number_streams = 0 # 0 means all channels
let ret = usrsctp_setsockopt(
self.sctpSocket,
IPPROTO_SCTP,
SCTP_RESET_STREAMS,
addr srs,
len.SockLen
)
if ret < 0:
raise newException(WebRtcError, "SCTP - Close all channels failed: " & sctpStrerror())

proc close*(self: SctpConn) {.async: (raises: [CancelledError, WebRtcError]).} =
## Closes the entire SCTP connection by resetting all channels,
## deregistering the address, stopping the read loop, and cleaning up resources.
if self.state == SctpClosed:
debug "Try to close SctpConn twice"
return
self.closeAllChannels()
usrsctp_deregister_address(cast[pointer](self))
self.usrsctpAwait:
self.sctpSocket.usrsctp_close()
await self.readLoop.cancelAndWait()
self.state = SctpClosed
untrackCounter(SctpConnTracker)
await self.conn.close()
for onCloseProc in self.onClose:
onCloseProc()
self.onClose = @[]
69 changes: 69 additions & 0 deletions webrtc/sctp/sctp_logutils.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

import binary_serialization

# This file defines custom objects and procedures to improve the
# readability and accuracy of logging SCTP messages. The default
# usrsctp logger may not provide sufficient detail or clarity for
# SCTP message analysis, so this implementation creates more structured
# and informative logs. By parsing and formatting SCTP packet headers
# and chunks into human-readable strings, it provides clearer insights
# into the data being transmitted. This aids debugging by offering a
# more descriptive view of SCTP traffic than what is available
# by default.

type
SctpChunk* = object
chunkType*: uint8
flag*: uint8
length* {.bin_value: it.data.len() + 4.}: uint16
data* {.bin_len: it.length - 4.}: seq[byte]

SctpPacketHeader* = object
srcPort*: uint16
dstPort*: uint16
verifTag*: uint32
checksum*: uint32

SctpPacketStructure* = object
header*: SctpPacketHeader
chunks*: seq[SctpChunk]

proc dataToString(data: seq[byte]): string =
if data.len() < 8:
return $data
result = "@["
result &= $data[0] & ", " & $data[1] & ", " & $data[2] & ", " & $data[3] & " ... "
result &= $data[^4] & ", " & $data[^3] & ", " & $data[^2] & ", " & $data[^1] & "]"

proc `$`*(packet: SctpPacketStructure): string =
result = "{header: {srcPort: "
result &= $(packet.header.srcPort) & ", dstPort: "
result &= $(packet.header.dstPort) & "}, chunks: @["
var counter = 0
for chunk in packet.chunks:
result &= "{type: " & $(chunk.chunkType) & ", len: "
result &= $(chunk.length) & ", data: "
result &= chunk.data.dataToString()
counter += 1
if counter < packet.chunks.len():
result &= ", "
result &= "]}"

proc getSctpPacket*(buffer: seq[byte]): SctpPacketStructure =
result.header = Binary.decode(buffer, SctpPacketHeader)
var size = sizeof(SctpPacketHeader)
while size < buffer.len:
let chunk = Binary.decode(buffer[size ..^ 1], SctpChunk)
result.chunks.add(chunk)
size.inc(chunk.length.int)
while size mod 4 != 0:
# padding; could use `size.inc(-size %% 4)` instead but it lacks clarity
size.inc(1)
Loading

0 comments on commit 4659924

Please sign in to comment.