From e3ab9e13acb8c1909ccadfacb0709a4dcfaa6be9 Mon Sep 17 00:00:00 2001 From: Ulf Lilleengen Date: Thu, 6 Jun 2024 15:55:02 +0200 Subject: [PATCH 1/2] refactor l2cap channel lifecycle - l2cap channels are now refcounted like connections - l2cap disconnects are handled in control future - rx future no longer blocked waiting for channel rx on a particular l2cap channel --- host/src/channel_manager.rs | 477 +++++++++++++++++++----------------- host/src/host.rs | 309 +++++++++++++++-------- host/src/l2cap.rs | 64 +++-- host/src/types/l2cap.rs | 6 + 4 files changed, 512 insertions(+), 344 deletions(-) diff --git a/host/src/channel_manager.rs b/host/src/channel_manager.rs index 4f4b4d5..38ad8ba 100644 --- a/host/src/channel_manager.rs +++ b/host/src/channel_manager.rs @@ -9,7 +9,7 @@ use embassy_sync::blocking_mutex::raw::NoopRawMutex; use embassy_sync::channel::Channel; use embassy_sync::waitqueue::WakerRegistration; -use crate::cursor::{ReadCursor, WriteCursor}; +use crate::cursor::WriteCursor; use crate::host::BleHost; use crate::packet_pool::{AllocId, GlobalPacketPool, Packet}; use crate::pdu::Pdu; @@ -17,7 +17,7 @@ use crate::types::l2cap::{ CommandRejectRes, DisconnectionReq, DisconnectionRes, L2capHeader, L2capSignalCode, L2capSignalHeader, LeCreditConnReq, LeCreditConnRes, LeCreditConnResultCode, LeCreditFlowInd, }; -use crate::{BleHostError, Error}; +use crate::{AclSender, BleHostError, Error}; const BASE_ID: u16 = 0x40; @@ -26,6 +26,7 @@ struct State<'d> { channels: &'d mut [ChannelStorage], accept_waker: WakerRegistration, create_waker: WakerRegistration, + disconnect_waker: WakerRegistration, } /// Channel manager for L2CAP channels used directly by clients. @@ -39,6 +40,10 @@ pub(crate) struct PacketChannel { chan: Channel, QLEN>, } +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct ChannelIndex(u8); + impl PacketChannel { pub(crate) const NEW: PacketChannel = PacketChannel { chan: Channel::new() }; @@ -50,9 +55,17 @@ impl PacketChannel { self.chan.send(Some(pdu)).await; } + pub fn try_send(&self, pdu: Pdu) -> Result<(), Error> { + self.chan.try_send(Some(pdu)).map_err(|_| Error::OutOfMemory) + } + pub async fn receive(&self) -> Option { self.chan.receive().await } + + pub fn clear(&self) { + self.chan.clear() + } } impl<'d> State<'d> { @@ -63,6 +76,15 @@ impl<'d> State<'d> { } } } + fn next_request_id(&mut self) -> u8 { + // 0 is an invalid identifier + if self.next_req_id == 0 { + self.next_req_id += 1; + } + let next = self.next_req_id; + self.next_req_id = self.next_req_id.wrapping_add(1); + next + } } impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { @@ -78,88 +100,51 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { channels, accept_waker: WakerRegistration::new(), create_waker: WakerRegistration::new(), + disconnect_waker: WakerRegistration::new(), }), inbound, } } fn next_request_id(&self) -> u8 { - let mut state = self.state.borrow_mut(); - // 0 is an invalid identifier - if state.next_req_id == 0 { - state.next_req_id += 1; - } - let next = state.next_req_id; - state.next_req_id = state.next_req_id.wrapping_add(1); - next + self.state.borrow_mut().next_request_id() } - pub(crate) fn disconnect(&self, cid: u16) -> Result { + pub(crate) fn disconnect(&self, index: ChannelIndex) { self.with_mut(|state| { - for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage.state { - ChannelState::Disconnecting if cid == storage.cid => { - storage.state = ChannelState::Disconnected; - storage.cid = 0; - return Ok(ConnHandle::new(storage.conn)); - } - ChannelState::PeerConnecting(_) if cid == storage.cid => { - storage.state = ChannelState::Disconnecting; - let _ = self.inbound[idx].close(); - return Ok(ConnHandle::new(storage.conn)); - } - ChannelState::Connecting(_) if cid == storage.cid => { - storage.state = ChannelState::Disconnecting; - let _ = self.inbound[idx].close(); - return Ok(ConnHandle::new(storage.conn)); - } - ChannelState::Connected if cid == storage.cid => { - storage.state = ChannelState::Disconnecting; - let _ = self.inbound[idx].close(); - return Ok(ConnHandle::new(storage.conn)); - } - _ => {} - } + let chan = &mut state.channels[index.0 as usize]; + if chan.state == ChannelState::Connected { + chan.state = ChannelState::Disconnecting; + let _ = self.inbound[index.0 as usize].close(); + state.disconnect_waker.wake(); } - trace!("[l2cap][disconnect] channel {} not found", cid); - Err(Error::NotFound) }) } pub(crate) fn disconnected(&self, conn: ConnHandle) -> Result<(), Error> { let mut state = self.state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage.state { - ChannelState::PeerConnecting(_) if conn.raw() == storage.conn => { - storage.state = ChannelState::Disconnecting; - let _ = self.inbound[idx].close(); - } - ChannelState::Connecting(_) if conn.raw() == storage.conn => { - storage.state = ChannelState::Disconnecting; - let _ = self.inbound[idx].close(); - } - ChannelState::Connected if conn.raw() == storage.conn => { - storage.state = ChannelState::Disconnecting; - let _ = self.inbound[idx].close(); - } - _ => {} + if conn.raw() == storage.conn { + let _ = self.inbound[idx].close(); + storage.close(); } - storage.credit_waker.wake(); } state.accept_waker.wake(); state.create_waker.wake(); Ok(()) } - fn alloc(&self, conn: ConnHandle, f: F) -> Result<(), Error> { + fn alloc(&self, conn: ConnHandle, f: F) -> Result { let mut state = self.state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { - if let ChannelState::Disconnected = storage.state { + if ChannelState::Disconnected == storage.state && storage.refcount == 0 { + // Ensure inbound is empty. + self.inbound[idx].clear(); let cid: u16 = BASE_ID + idx as u16; storage.conn = conn.raw(); storage.cid = cid; f(storage); - return Ok(()); + return Ok(ChannelIndex(idx as u8)); } } Err(Error::NoChannelAvailable) @@ -172,12 +157,12 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { mtu: u16, credit_flow: CreditFlowPolicy, initial_credits: Option, - ble: &BleHost<'d, T>, - ) -> Result> { + ble: &BleHost<'_, T>, + ) -> Result> { // Wait until we find a channel for our connection in the connecting state matching our PSM. - let (req_id, mps, mtu, cid, credits) = poll_fn(|cx| { + let (idx, req_id, mps, mtu, cid, credits) = poll_fn(|cx| { let mut state = self.state.borrow_mut(); - for chan in state.channels.iter_mut() { + for (idx, chan) in state.channels.iter_mut().enumerate() { match chan.state { ChannelState::PeerConnecting(req_id) if chan.conn == conn.raw() && psm.contains(&chan.psm) => { chan.mps = chan.mps.min(self.pool.mtu() as u16 - 4); @@ -189,7 +174,14 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { ); chan.state = ChannelState::Connected; - return Poll::Ready((req_id, chan.mps, chan.mtu, chan.cid, chan.flow_control.available())); + return Poll::Ready(( + ChannelIndex(idx as u8), + req_id, + chan.mps, + chan.mtu, + chan.cid, + chan.flow_control.available(), + )); } _ => {} } @@ -215,15 +207,10 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { ) .await?; - // NOTE: This code is disabled as we send the credits in the response request. For some reason the nrf-softdevice doesn't do that, - // so lets keep this around in case we need it. - // Send initial credits - // let next_req_id = self.next_request_id(); - // controller - // .signal(conn, next_req_id, &LeCreditFlowInd { cid, credits }, &mut tx[..]) - // .await?; - // - Ok(cid) + self.with_mut(|state| { + state.channels[idx.0 as usize].refcount = 1; + }); + Ok(idx) } pub(crate) async fn create( @@ -233,15 +220,15 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { mtu: u16, credit_flow: CreditFlowPolicy, initial_credits: Option, - ble: &BleHost<'d, T>, - ) -> Result> { + ble: &BleHost<'_, T>, + ) -> Result> { let req_id = self.next_request_id(); let mut credits = 0; let mut cid: u16 = 0; let mps = self.pool.mtu() as u16 - 4; // Allocate space for our new channel. - self.alloc(conn, |storage| { + let idx = self.alloc(conn, |storage| { cid = storage.cid; credits = initial_credits.unwrap_or(self.pool.min_available(AllocId::from_channel(storage.cid)) as u16); storage.psm = psm; @@ -266,34 +253,26 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { // Wait until a response is accepted. poll_fn(|cx| { let mut state = self.state.borrow_mut(); - for storage in state.channels.iter_mut() { - match storage.state { - ChannelState::Disconnecting if storage.conn == conn.raw() && storage.cid == cid => { - return Poll::Ready(Err(Error::Disconnected)); - } - ChannelState::Connected if storage.conn == conn.raw() && storage.cid == cid => { - return Poll::Ready(Ok(())); - } - _ => {} + state.create_waker.register(cx.waker()); + let storage = &mut state.channels[idx.0 as usize]; + match storage.state { + ChannelState::Disconnecting | ChannelState::PeerDisconnecting => { + return Poll::Ready(Err(Error::Disconnected)); + } + ChannelState::Connected => { + storage.refcount = 1; + return Poll::Ready(Ok(())); } + _ => {} } - state.create_waker.register(cx.waker()); Poll::Pending }) .await?; - - // NOTE: This code is disabled as we send the credits in the response request. For some reason the nrf-softdevice doesn't do that, - // so lets keep this around in case we need it. - // Send initial credits - // let next_req_id = self.next_request_id(); - // let req = controller - // .signal(conn, next_req_id, &LeCreditFlowInd { cid, credits }, &mut tx[..]) - // .await?; - Ok(cid) + Ok(idx) } /// Dispatch an incoming L2CAP packet to the appropriate channel. - pub(crate) async fn dispatch(&self, header: L2capHeader, packet: Packet) -> Result<(), Error> { + pub(crate) fn dispatch(&self, header: L2capHeader, packet: Packet) -> Result<(), Error> { if header.channel < BASE_ID { return Err(Error::InvalidChannelId); } @@ -319,12 +298,11 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { Ok(()) })?; - self.inbound[chan].send(Pdu::new(packet, header.length as usize)).await; - Ok(()) + self.inbound[chan].try_send(Pdu::new(packet, header.length as usize)) } /// Handle incoming L2CAP signal - pub(crate) async fn signal(&self, conn: ConnHandle, data: &[u8]) -> Result<(), Error> { + pub(crate) fn signal(&self, conn: ConnHandle, data: &[u8]) -> Result<(), Error> { let (header, data) = L2capSignalHeader::from_hci_bytes(data)?; //trace!( // "[l2cap][conn = {:?}] received signal (req {}) code {:?}", @@ -349,19 +327,18 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { } L2capSignalCode::CommandRejectRes => { let (reject, _) = CommandRejectRes::from_hci_bytes(data)?; - warn!("Rejected: {:?}", reject); Ok(()) } L2capSignalCode::DisconnectionReq => { let req = DisconnectionReq::from_hci_bytes_complete(data)?; trace!("[l2cap][conn = {:?}, cid = {}] disconnect request", conn, req.dcid); - self.disconnect(req.dcid)?; + self.handle_disconnect_request(req.dcid)?; Ok(()) } L2capSignalCode::DisconnectionRes => { let res = DisconnectionRes::from_hci_bytes_complete(data)?; trace!("[l2cap][conn = {:?}, cid = {}] disconnect response", conn, res.scid); - self.handle_disconnect_response(&res) + self.handle_disconnect_response(res.scid) } _ => Err(Error::NotSupported), } @@ -431,19 +408,26 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { _ => {} } } - trace!("[l2cap][handle_credit_flow] peer channel {} not found", req.cid); + // trace!("[l2cap][handle_credit_flow] peer channel {} not found", req.cid); Err(Error::NotFound) } - fn handle_disconnect_response(&self, res: &DisconnectionRes) -> Result<(), Error> { - let cid = res.scid; + fn handle_disconnect_request(&self, cid: u16) -> Result<(), Error> { let mut state = self.state.borrow_mut(); for storage in state.channels.iter_mut() { if cid == storage.cid { - storage.state = ChannelState::Disconnected; - storage.cid = 0; - storage.peer_cid = 0; - storage.conn = 0; + storage.state = ChannelState::PeerDisconnecting; + break; + } + } + Ok(()) + } + + fn handle_disconnect_response(&self, cid: u16) -> Result<(), Error> { + let mut state = self.state.borrow_mut(); + for storage in state.channels.iter_mut() { + if storage.state == ChannelState::Disconnecting && cid == storage.cid { + storage.close(); break; } } @@ -455,31 +439,28 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { /// The length provided buffer slice must be equal or greater to the agreed MTU. pub(crate) async fn receive( &self, - cid: u16, + chan: ChannelIndex, buf: &mut [u8], ble: &BleHost<'d, T>, ) -> Result> { - let idx = self.connected_channel_index(cid)?; - let mut n_received = 1; - let packet = self.receive_pdu(cid, idx, ble).await?; + let packet = self.receive_pdu(chan, ble).await?; let len = packet.len; - let mut r = ReadCursor::new(packet.as_ref()); - let remaining: u16 = r.read()?; + let (first, data) = packet.as_ref().split_at(2); + let remaining: u16 = u16::from_le_bytes([first[0], first[1]]); - let data = r.remaining(); let to_copy = data.len().min(buf.len()); buf[..to_copy].copy_from_slice(&data[..to_copy]); let mut pos = to_copy; let mut remaining = remaining as usize - data.len(); - self.flow_control(cid, ble, packet.packet).await?; + self.flow_control(chan, ble, packet.packet).await?; // We have some k-frames to reassemble while remaining > 0 { - let packet = self.receive_pdu(cid, idx, ble).await?; + let packet = self.receive_pdu(chan, ble).await?; n_received += 1; let to_copy = packet.len.min(buf.len() - pos); if to_copy > 0 { @@ -487,36 +468,20 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { pos += to_copy; } remaining -= packet.len; - self.flow_control(cid, ble, packet.packet).await?; + self.flow_control(chan, ble, packet.packet).await?; } Ok(pos) } - // Return the array index for a given active channel - fn connected_channel_index(&self, cid: u16) -> Result { - let state = self.state.borrow(); - for (idx, chan) in state.channels.iter().enumerate() { - if chan.cid == cid && chan.state == ChannelState::Connected { - return Ok(idx); - } - } - trace!("[l2cap][connected_channel_index] channel {} closed", cid); - Err(Error::ChannelClosed) - } - async fn receive_pdu( &self, - cid: u16, - idx: usize, - ble: &BleHost<'d, T>, + chan: ChannelIndex, + ble: &BleHost<'_, T>, ) -> Result> { - match self.inbound[idx].receive().await { + match self.inbound[chan.0 as usize].receive().await { Some(pdu) => Ok(pdu), - None => { - self.confirm_disconnected(cid, ble).await?; - Err(Error::ChannelClosed.into()) - } + None => Err(Error::ChannelClosed.into()), } } @@ -527,16 +492,16 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { /// If the channel has been closed or the channel id is not valid, an error is returned. pub(crate) async fn send( &self, - cid: u16, + index: ChannelIndex, buf: &[u8], p_buf: &mut [u8], ble: &BleHost<'d, T>, ) -> Result<(), BleHostError> { - let (conn, mps, peer_cid) = self.connected_channel_params(cid)?; + let (conn, mps, peer_cid) = self.connected_channel_params(index)?; // The number of packets we'll need to send for this payload let n_packets = 1 + ((buf.len() as u16).saturating_sub(mps - 2)).div_ceil(mps); - let mut grant = poll_fn(|cx| self.poll_request_to_send(cid, n_packets, Some(cx))).await?; + let mut grant = poll_fn(|cx| self.poll_request_to_send(index, n_packets, Some(cx))).await?; let mut hci = ble.acl(conn, n_packets).await?; // Segment using mps @@ -564,17 +529,17 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { /// If the channel has been closed or the channel id is not valid, an error is returned. pub(crate) fn try_send( &self, - cid: u16, + index: ChannelIndex, buf: &[u8], p_buf: &mut [u8], ble: &BleHost<'d, T>, ) -> Result<(), BleHostError> { - let (conn, mps, peer_cid) = self.connected_channel_params(cid)?; + let (conn, mps, peer_cid) = self.connected_channel_params(index)?; // The number of packets we'll need to send for this payload let n_packets = ((buf.len() as u16).saturating_add(2)).div_ceil(mps); - let mut grant = match self.poll_request_to_send(cid, n_packets, None) { + let mut grant = match self.poll_request_to_send(index, n_packets, None) { Poll::Ready(res) => res?, Poll::Pending => { return Err(Error::Busy.into()); @@ -601,17 +566,16 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { Ok(()) } - fn connected_channel_params(&self, cid: u16) -> Result<(ConnHandle, u16, u16), Error> { + fn connected_channel_params(&self, index: ChannelIndex) -> Result<(ConnHandle, u16, u16), Error> { let state = self.state.borrow(); - for chan in state.channels.iter() { - match chan.state { - ChannelState::Connected if chan.cid == cid => { - return Ok((ConnHandle::new(chan.conn), chan.mps, chan.peer_cid)); - } - _ => {} + let chan = &state.channels[index.0 as usize]; + match chan.state { + ChannelState::Connected => { + return Ok((ConnHandle::new(chan.conn), chan.mps, chan.peer_cid)); } + _ => {} } - trace!("[l2cap][connected_channel_params] channel {} closed", cid); + //trace!("[l2cap][connected_channel_params] channel {} closed", index); Err(Error::ChannelClosed) } @@ -619,20 +583,16 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { // our policy says so. async fn flow_control( &self, - cid: u16, + chan: ChannelIndex, ble: &BleHost<'d, T>, mut packet: Packet, ) -> Result<(), BleHostError> { - let (conn, credits) = self.with_mut(|state| { - for storage in state.channels.iter_mut() { - match storage.state { - ChannelState::Connected if cid == storage.cid => { - return Ok((storage.conn, storage.flow_control.process())); - } - _ => {} - } + let (conn, cid, credits) = self.with_mut(|state| { + let chan = &mut state.channels[chan.0 as usize]; + if chan.state == ChannelState::Connected { + return Ok((chan.conn, chan.cid, chan.flow_control.process())); } - trace!("[l2cap][flow_control] channel {} not found", cid); + trace!("[l2cap][flow_control] channel {} not found", chan); Err(Error::NotFound) })?; @@ -652,68 +612,68 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { f(&mut state) } - async fn confirm_disconnected( - &self, - cid: u16, - ble: &BleHost<'d, T>, - ) -> Result<(), BleHostError> { - let (handle, dcid, scid) = self.with_mut(|state| { - for storage in state.channels.iter_mut() { - match storage.state { - ChannelState::Disconnecting if cid == storage.cid => { - storage.state = ChannelState::Disconnected; - let scid = storage.cid; - let dcid = storage.peer_cid; - let handle = storage.conn; - storage.cid = 0; - storage.peer_cid = 0; - storage.conn = 0; - return Ok((handle, dcid, scid)); - } - _ => {} - } - } - trace!("[l2cap][confirm_disconnected] channel {} not found", cid); - Err(Error::NotFound) - })?; - - let identifier = self.next_request_id(); - let mut tx = [0; 18]; - let mut hci = ble.acl(ConnHandle::new(handle), 1).await?; - hci.signal(identifier, &DisconnectionRes { dcid, scid }, &mut tx[..]) - .await?; - Ok(()) - } - fn poll_request_to_send( &self, - cid: u16, + index: ChannelIndex, credits: u16, cx: Option<&mut Context<'_>>, ) -> Poll, Error>> { let mut state = self.state.borrow_mut(); - for storage in state.channels.iter_mut() { + let chan = &mut state.channels[index.0 as usize]; + if chan.state == ChannelState::Connected { + if let Some(cx) = cx { + chan.credit_waker.register(cx.waker()); + } + if credits <= chan.peer_credits { + chan.peer_credits -= credits; + return Poll::Ready(Ok(CreditGrant::new(&self.state, index, credits))); + } else { + warn!( + "[l2cap][poll_request_to_send][cid = {}]: not enough credits, requested {} available {}", + chan.cid, credits, chan.peer_credits + ); + return Poll::Pending; + } + } + trace!("[l2cap][pool_request_to_send] channel index {} not found", index); + Poll::Ready(Err(Error::NotFound)) + } + + pub(crate) fn poll_disconnecting<'m>(&'m self, cx: Option<&mut Context<'_>>) -> Poll> { + let mut state = self.state.borrow_mut(); + if let Some(cx) = cx { + state.disconnect_waker.register(cx.waker()); + } + for (idx, storage) in state.channels.iter().enumerate() { match storage.state { - ChannelState::Connected if cid == storage.cid => { - if let Some(cx) = cx { - storage.credit_waker.register(cx.waker()); - } - if credits <= storage.peer_credits { - storage.peer_credits -= credits; - return Poll::Ready(Ok(CreditGrant::new(&self.state, cid, credits))); - } else { - warn!( - "[l2cap][poll_request_to_send][cid = {}]: not enough credits, requested {} available {}", - cid, credits, storage.peer_credits - ); - return Poll::Pending; - } + ChannelState::Disconnecting | ChannelState::PeerDisconnecting => { + return Poll::Ready(DisconnectRequest { + index: ChannelIndex(idx as u8), + handle: ConnHandle::new(storage.conn), + state: &self.state, + }); } _ => {} } } - trace!("[l2cap][pool_request_to_send] channel {} not found", cid); - Poll::Ready(Err(Error::NotFound)) + Poll::Pending + } + + pub(crate) fn inc_ref(&self, index: ChannelIndex) { + self.with_mut(|state| { + let state = &mut state.channels[index.0 as usize]; + state.refcount = unwrap!(state.refcount.checked_add(1), "Too many references to the same channel"); + }); + } + + pub(crate) fn dec_ref(&self, index: ChannelIndex) { + self.with_mut(|state| { + let state = &mut state.channels[index.0 as usize]; + state.refcount = unwrap!(state.refcount.checked_sub(1), "bug: dropping a channel with refcount 0"); + if state.refcount == 0 && state.state == ChannelState::Connected { + state.state = ChannelState::Disconnecting; + } + }); } pub(crate) fn log_status(&self) { @@ -722,6 +682,47 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { } } +pub struct DisconnectRequest<'a, 'd> { + index: ChannelIndex, + handle: ConnHandle, + state: &'a RefCell>, +} + +impl<'a, 'd> DisconnectRequest<'a, 'd> { + pub fn handle(&self) -> ConnHandle { + self.handle + } + + pub async fn send(&self, hci: &mut AclSender<'a, 'd, T>) -> Result<(), BleHostError> { + let (state, conn, identifier, dcid, scid) = { + let mut state = self.state.borrow_mut(); + let identifier = state.next_request_id(); + let chan = &state.channels[self.index.0 as usize]; + (chan.state.clone(), chan.conn, identifier, chan.peer_cid, chan.cid) + }; + + let mut tx = [0; 18]; + match state { + ChannelState::PeerDisconnecting => { + assert_eq!(self.handle.raw(), conn); + hci.signal(identifier, &DisconnectionRes { dcid, scid }, &mut tx[..]) + .await?; + } + ChannelState::Disconnecting => { + assert_eq!(self.handle.raw(), conn); + hci.signal(identifier, &DisconnectionReq { dcid, scid }, &mut tx[..]) + .await?; + } + _ => {} + } + Ok(()) + } + + pub fn confirm(self) { + self.state.borrow_mut().channels[self.index.0 as usize].state = ChannelState::Disconnected; + } +} + fn encode(data: &[u8], packet: &mut [u8], peer_cid: u16, header: Option) -> Result { let mut w = WriteCursor::new(packet); if header.is_some() { @@ -739,6 +740,24 @@ fn encode(data: &[u8], packet: &mut [u8], peer_cid: u16, header: Option) -> Ok(w.len()) } +pub(crate) trait DynamicChannelManager { + fn inc_ref(&self, index: ChannelIndex); + fn dec_ref(&self, index: ChannelIndex); + fn disconnect(&self, index: ChannelIndex); +} + +impl<'d, const RXQ: usize> DynamicChannelManager for ChannelManager<'d, RXQ> { + fn inc_ref(&self, index: ChannelIndex) { + ChannelManager::inc_ref(self, index) + } + fn dec_ref(&self, index: ChannelIndex) { + ChannelManager::dec_ref(self, index) + } + fn disconnect(&self, index: ChannelIndex) { + ChannelManager::disconnect(self, index) + } +} + #[derive(Debug)] pub struct ChannelStorage { state: ChannelState, @@ -748,6 +767,7 @@ pub struct ChannelStorage { mps: u16, mtu: u16, flow_control: CreditFlowControl, + refcount: u8, peer_cid: u16, peer_credits: u16, @@ -785,16 +805,29 @@ impl ChannelStorage { peer_cid: 0, peer_credits: 0, credit_waker: WakerRegistration::new(), + refcount: 0, }; + + fn close(&mut self) { + self.state = ChannelState::Disconnected; + self.cid = 0; + self.conn = 0; + self.mps = 0; + self.mtu = 0; + self.psm = 0; + self.peer_cid = 0; + self.peer_credits = 0; + } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum ChannelState { Disconnected, Connecting(u8), PeerConnecting(u8), Connected, + PeerDisconnecting, Disconnecting, } @@ -867,13 +900,13 @@ impl CreditFlowControl { pub struct CreditGrant<'reference, 'state> { state: &'reference RefCell>, - cid: u16, + index: ChannelIndex, credits: u16, } impl<'reference, 'state> CreditGrant<'reference, 'state> { - fn new(state: &'reference RefCell>, cid: u16, credits: u16) -> Self { - Self { state, cid, credits } + fn new(state: &'reference RefCell>, index: ChannelIndex, credits: u16) -> Self { + Self { state, index, credits } } pub(crate) fn confirm(&mut self, sent: u16) { @@ -893,18 +926,14 @@ impl<'reference, 'state> Drop for CreditGrant<'reference, 'state> { fn drop(&mut self) { if self.credits > 0 { let mut state = self.state.borrow_mut(); - for storage in state.channels.iter_mut() { - match storage.state { - ChannelState::Connected if self.cid == storage.cid => { - storage.peer_credits += self.credits; - storage.credit_waker.wake(); - return; - } - _ => {} - } + let chan = &mut state.channels[self.index.0 as usize]; + if chan.state == ChannelState::Connected { + chan.peer_credits += self.credits; + chan.credit_waker.wake(); + return; } // make it an assert? - warn!("[l2cap][credit grant drop] channel {} not found", self.cid); + // warn!("[l2cap][credit grant drop] channel {} not found", self.index); } } } diff --git a/host/src/host.rs b/host/src/host.rs index b948355..35cfec2 100644 --- a/host/src/host.rs +++ b/host/src/host.rs @@ -25,10 +25,11 @@ use bt_hci::param::{ FilterDuplicates, InitiatingPhy, LeConnRole, LeEventMask, Operation, PhyParams, ScanningPhy, Status, }; use bt_hci::{ControllerToHostPacket, FromHciBytes, WriteHci}; -use embassy_futures::select::{select, select3, Either, Either3}; +use embassy_futures::select::{select, select3, select4, Either, Either3, Either4}; use embassy_sync::blocking_mutex::raw::NoopRawMutex; -use embassy_sync::channel::{Channel, TryReceiveError}; +use embassy_sync::channel::Channel; use embassy_sync::once_lock::OnceLock; +use embassy_sync::waitqueue::WakerRegistration; use futures::pin_mut; use crate::advertise::{Advertisement, AdvertisementParameters, AdvertisementSet, RawAdvertisement}; @@ -52,15 +53,23 @@ use crate::{attribute::AttributeTable, gatt::GattServer}; /// /// The l2cap packet pool is used by the host to handle inbound data, by allocating space for /// incoming packets and dispatching to the appropriate connection and channel. -pub struct BleHostResources { +pub struct BleHostResources< + const CONNS: usize, + const CHANNELS: usize, + const L2CAP_MTU: usize, + const ADV_SETS: usize = 1, +> { rx_pool: PacketPool, connections: [ConnectionStorage; CONNS], channels: [ChannelStorage; CHANNELS], channels_rx: [PacketChannel<{ config::L2CAP_RX_QUEUE_SIZE }>; CHANNELS], sar: [SarType; CONNS], + advertise_handles: [AdvHandleState; ADV_SETS], } -impl BleHostResources { +impl + BleHostResources +{ /// Create a new instance of host resources with the provided QoS requirements for packets. pub fn new(qos: Qos) -> Self { Self { @@ -69,6 +78,7 @@ impl BleHostR sar: [EMPTY_SAR; CONNS], channels: [ChannelStorage::DISCONNECTED; CHANNELS], channels_rx: [PacketChannel::NEW; CHANNELS], + advertise_handles: [AdvHandleState::None; ADV_SETS], } } } @@ -98,9 +108,102 @@ pub struct BleHost<'d, T> { outbound: Channel, pub(crate) scanner: Channel, 1>, - advertise_terminated: Channel, 1>, - advertise_state: CommandState, - connect_state: CommandState, + advertise_state: AdvState<'d>, + advertise_command_state: CommandState, + connect_command_state: CommandState, +} + +#[derive(Clone, Copy)] +pub(crate) enum AdvHandleState { + None, + Advertising(AdvHandle), + Terminated(AdvHandle), +} + +pub(crate) struct AdvInnerState<'d> { + handles: &'d mut [AdvHandleState], + waker: WakerRegistration, +} + +pub(crate) struct AdvState<'d> { + state: RefCell>, +} + +impl<'d> AdvState<'d> { + pub fn new(handles: &'d mut [AdvHandleState]) -> Self { + Self { + state: RefCell::new(AdvInnerState { + handles, + waker: WakerRegistration::new(), + }), + } + } + + pub fn reset(&self) { + let mut state = self.state.borrow_mut(); + for entry in state.handles.iter_mut() { + *entry = AdvHandleState::None; + } + state.waker.wake(); + } + + // Terminate handle + pub fn terminate(&self, handle: AdvHandle) { + let mut state = self.state.borrow_mut(); + for entry in state.handles.iter_mut() { + match entry { + AdvHandleState::Advertising(h) if *h == handle => { + *entry = AdvHandleState::Terminated(handle); + } + _ => {} + } + } + state.waker.wake(); + } + + pub fn len(&self) -> usize { + let state = self.state.borrow(); + state.handles.len() + } + + pub fn start(&self, sets: &[AdvSet]) { + let mut state = self.state.borrow_mut(); + assert_eq!(sets.len(), state.handles.len()); + for (idx, entry) in state.handles.iter_mut().enumerate() { + *entry = AdvHandleState::Advertising(sets[idx].adv_handle); + } + } + + pub async fn wait(&self, sets: &[AdvSet]) { + poll_fn(|cx| { + let mut state = self.state.borrow_mut(); + state.waker.register(cx.waker()); + + let mut terminated = 0; + for entry in state.handles.iter() { + match entry { + AdvHandleState::Terminated(handle) => { + for set in sets.iter() { + if *handle == set.adv_handle { + terminated += 1; + break; + } + } + } + AdvHandleState::None => { + terminated += 1; + } + _ => {} + } + } + if terminated == sets.len() { + Poll::Ready(()) + } else { + Poll::Pending + } + }) + .await; + } } #[derive(Default)] @@ -108,7 +211,6 @@ struct Metrics { connect_events: u32, disconnect_events: u32, rx_errors: u32, - tx_blocked: u32, } impl<'d, T> BleHost<'d, T> @@ -119,9 +221,9 @@ where /// /// The host requires a HCI driver (a particular HCI-compatible controller implementing the required traits), and /// a reference to resources that are created outside the host but which the host is the only accessor of. - pub fn new( + pub fn new( controller: T, - host_resources: &'static mut BleHostResources, + host_resources: &'static mut BleHostResources, ) -> Self { Self { address: None, @@ -138,9 +240,9 @@ where rx_pool: &host_resources.rx_pool, att_inbound: Channel::new(), scanner: Channel::new(), - advertise_terminated: Channel::new(), - advertise_state: CommandState::new(), - connect_state: CommandState::new(), + advertise_state: AdvState::new(&mut host_resources.advertise_handles[..]), + advertise_command_state: CommandState::new(), + connect_command_state: CommandState::new(), outbound: Channel::new(), } } @@ -199,9 +301,9 @@ where } let _drop = OnDrop::new(|| { - self.connect_state.cancel(true); + self.connect_command_state.cancel(true); }); - self.connect_state.request().await; + self.connect_command_state.request().await; self.set_accept_filter(config.scan_config.filter_accept_list).await?; @@ -223,13 +325,13 @@ where match select( self.connections .accept(LeConnRole::Central, config.scan_config.filter_accept_list), - self.connect_state.wait_idle(), + self.connect_command_state.wait_idle(), ) .await { Either::First(conn) => { _drop.defuse(); - self.connect_state.done(); + self.connect_command_state.done(); Ok(conn) } Either::Second(_) => Err(Error::Timeout.into()), @@ -251,9 +353,9 @@ where // Ensure no other connect ongoing. let _drop = OnDrop::new(|| { - self.connect_state.cancel(true); + self.connect_command_state.cancel(true); }); - self.connect_state.request().await; + self.connect_command_state.request().await; self.set_accept_filter(config.scan_config.filter_accept_list).await?; @@ -282,13 +384,13 @@ where match select( self.connections .accept(LeConnRole::Central, config.scan_config.filter_accept_list), - self.connect_state.wait_idle(), + self.connect_command_state.wait_idle(), ) .await { Either::First(conn) => { _drop.defuse(); - self.connect_state.done(); + self.connect_command_state.done(); Ok(conn) } Either::Second(_) => Err(Error::Timeout.into()), @@ -449,12 +551,12 @@ where { // Ensure no other advertise ongoing. let _drop = OnDrop::new(|| { - self.advertise_state.cancel(false); + self.advertise_command_state.cancel(false); }); - self.advertise_state.request().await; + self.advertise_command_state.request().await; // Clear current advertising terminations - while self.advertise_terminated.try_receive() != Err(TryReceiveError::Empty) {} + self.advertise_state.reset(); let data: RawAdvertisement = data.into(); if !data.props.legacy_adv() { @@ -498,14 +600,21 @@ where self.command(LeSetScanResponseData::new(to_copy as u8, buf)).await?; } + let advsets: [AdvSet; 1] = [AdvSet { + adv_handle: AdvHandle::new(0), + duration: bt_hci::param::Duration::from_secs(0), + max_ext_adv_events: 0, + }]; + + self.advertise_state.start(&advsets[..]); self.command(LeSetAdvEnable::new(true)).await?; match select( - self.advertise_terminated.receive(), + self.advertise_state.wait(&advsets[..]), self.connections.accept(LeConnRole::Peripheral, &[]), ) .await { - Either::First(handle) => Err(Error::Timeout.into()), + Either::First(_) => Err(Error::Timeout.into()), Either::Second(conn) => Ok(conn), } } @@ -528,19 +637,21 @@ where + for<'t> ControllerCmdSync>, { // Check host supports the required advertisement sets - let result = self.command(LeReadNumberOfSupportedAdvSets::new()).await?; - if result < N as u8 { - return Err(Error::InsufficientSpace.into()); + { + let result = self.command(LeReadNumberOfSupportedAdvSets::new()).await?; + if result < N as u8 || self.advertise_state.len() < N { + return Err(Error::InsufficientSpace.into()); + } } // Ensure no other advertise ongoing. let _drop = OnDrop::new(|| { - self.advertise_state.cancel(true); + self.advertise_command_state.cancel(true); }); - self.advertise_state.request().await; + self.advertise_command_state.request().await; // Clear current advertising terminations - while self.advertise_terminated.try_receive() != Err(TryReceiveError::Empty) {} + self.advertise_state.reset(); for set in sets { let handle = AdvHandle::new(set.handle); @@ -600,29 +711,19 @@ where }); trace!("[host] enabling advertising"); + self.advertise_state.start(&advset[..]); self.command(LeSetExtAdvEnable::new(true, &advset)).await?; - let mut terminated: [bool; N] = [false; N]; loop { match select( - self.advertise_terminated.receive(), + self.advertise_state.wait(&advset[..]), self.connections.accept(LeConnRole::Peripheral, &[]), ) .await { - Either::First(None) => { + Either::First(_) => { return Err(Error::Timeout.into()); } - Either::First(Some(handle)) => { - for (i, s) in advset.iter().enumerate() { - if s.adv_handle == handle { - terminated[i] = true; - } - } - if !terminated.contains(&false) { - return Err(Error::Timeout.into()); - } - } Either::Second(conn) => return Ok(conn), } } @@ -643,27 +744,19 @@ where } } - async fn handle_connection( + fn handle_connection( &self, status: Status, handle: ConnHandle, peer_addr_kind: AddrKind, peer_addr: BdAddr, role: LeConnRole, - ) where - T: ControllerCmdSync, - { + ) -> bool { match status.to_result() { Ok(_) => { if let Err(err) = self.connections.connect(handle, peer_addr_kind, peer_addr, role) { warn!("Error establishing connection: {:?}", err); - let _ = self - .command(Disconnect::new( - handle, - DisconnectReason::RemoteDeviceTerminatedConnLowResources, - )) - .await; - self.connect_state.canceled(); + return false; } else { trace!( "[host] connection established with handle {:?} to {:?}", @@ -675,20 +768,21 @@ where } } Err(bt_hci::param::Error::ADV_TIMEOUT) => { - self.advertise_terminated.send(None).await; + self.advertise_state.reset(); } Err(bt_hci::param::Error::UNKNOWN_CONN_IDENTIFIER) => { warn!("[host] connect cancelled"); - self.connect_state.canceled(); + self.connect_command_state.canceled(); } Err(e) => { warn!("Error connection complete event: {:?}", e); - self.connect_state.canceled(); + self.connect_command_state.canceled(); } } + true } - async fn handle_acl(&self, acl: AclPacket<'_>) -> Result<(), Error> { + fn handle_acl(&self, acl: AclPacket<'_>) -> Result<(), Error> { if !self.connections.is_handle_connected(acl.handle()) { return Err(Error::Disconnected); } @@ -707,7 +801,7 @@ where // Avoids using the packet buffer for signalling packets if header.channel == L2CAP_CID_LE_U_SIGNAL { assert!(data.len() == header.length as usize); - self.channels.signal(acl.handle(), data).await?; + self.channels.signal(acl.handle(), data)?; return Ok(()); } @@ -759,12 +853,17 @@ where let len = w.len(); w.finish(); - self.outbound.send((acl.handle(), Pdu::new(packet, len))).await; + if let Err(e) = self.outbound.try_send((acl.handle(), Pdu::new(packet, len))) { + return Err(Error::OutOfMemory); + } } else { #[cfg(feature = "gatt")] - self.att_inbound - .send((acl.handle(), Pdu::new(packet, header.length as usize))) - .await; + if let Err(e) = self + .att_inbound + .try_send((acl.handle(), Pdu::new(packet, header.length as usize))) + { + return Err(Error::OutOfMemory); + } #[cfg(not(feature = "gatt"))] return Err(Error::NotSupported); @@ -773,7 +872,7 @@ where L2CAP_CID_LE_U_SIGNAL => { panic!("le signalling channel was fragmented, impossible!"); } - other if other >= L2CAP_CID_DYN_START => match self.channels.dispatch(header, packet).await { + other if other >= L2CAP_CID_DYN_START => match self.channels.dispatch(header, packet) { Ok(_) => {} Err(e) => { warn!("Error dispatching l2cap packet to channel: {:?}", e); @@ -869,35 +968,39 @@ where let _ = self.initialized.init(()); loop { - match select3( + match select4( poll_fn(|cx| self.connections.poll_disconnecting(Some(cx))), - poll_fn(|cx| self.connect_state.poll_cancelled(cx)), - poll_fn(|cx| self.advertise_state.poll_cancelled(cx)), + poll_fn(|cx| self.channels.poll_disconnecting(Some(cx))), + poll_fn(|cx| self.connect_command_state.poll_cancelled(cx)), + poll_fn(|cx| self.advertise_command_state.poll_cancelled(cx)), ) .await { - Either3::First(request) => { - trace!("[host] disconnect request handle {:?}", request.handle()); + Either4::First(request) => { self.command(Disconnect::new(request.handle(), request.reason())) .await?; - trace!("[host] disconnect sent, confirming"); request.confirm(); } - Either3::Second(_) => { + Either4::Second(request) => { + let mut grant = self.acl(request.handle(), 1).await?; + request.send(&mut grant).await?; + request.confirm(); + } + Either4::Third(_) => { // trace!("[host] cancelling create connection"); if let Err(e) = self.command(LeCreateConnCancel::new()).await { // Signal to ensure no one is stuck - self.connect_state.canceled(); + self.connect_command_state.canceled(); } } - Either3::Third(ext) => { + Either4::Fourth(ext) => { // trace!("[host] turning off advertising"); if ext { self.command(LeSetExtAdvEnable::new(false, &[])).await? } else { self.command(LeSetAdvEnable::new(false)).await? } - self.advertise_state.canceled(); + self.advertise_command_state.canceled(); } } } @@ -929,7 +1032,7 @@ where let mut rx = [0u8; MAX_HCI_PACKET_LEN]; let result = self.controller.read(&mut rx).await; match result { - Ok(ControllerToHostPacket::Acl(acl)) => match self.handle_acl(acl).await { + Ok(ControllerToHostPacket::Acl(acl)) => match self.handle_acl(acl) { Ok(_) => {} Err(e) => { trace!("Error processing ACL packet: {:?}", e); @@ -940,28 +1043,42 @@ where Ok(ControllerToHostPacket::Event(event)) => match event { Event::Le(event) => match event { LeEvent::LeConnectionComplete(e) => { - self.handle_connection(e.status, e.handle, e.peer_addr_kind, e.peer_addr, e.role) - .await; + if !self.handle_connection(e.status, e.handle, e.peer_addr_kind, e.peer_addr, e.role) { + let _ = self + .command(Disconnect::new( + e.handle, + DisconnectReason::RemoteDeviceTerminatedConnLowResources, + )) + .await; + self.connect_command_state.canceled(); + } } LeEvent::LeEnhancedConnectionComplete(e) => { - self.handle_connection(e.status, e.handle, e.peer_addr_kind, e.peer_addr, e.role) - .await; + if !self.handle_connection(e.status, e.handle, e.peer_addr_kind, e.peer_addr, e.role) { + let _ = self + .command(Disconnect::new( + e.handle, + DisconnectReason::RemoteDeviceTerminatedConnLowResources, + )) + .await; + self.connect_command_state.canceled(); + } } LeEvent::LeScanTimeout(_) => { - self.scanner.send(None).await; + let _ = self.scanner.try_send(None); } LeEvent::LeAdvertisingSetTerminated(set) => { - self.advertise_terminated.send(Some(set.adv_handle)).await; + self.advertise_state.terminate(set.adv_handle); } LeEvent::LeExtendedAdvertisingReport(data) => { - self.scanner - .send(Some(ScanReport::new(data.reports.num_reports, &data.reports.bytes))) - .await; + let _ = self + .scanner + .try_send(Some(ScanReport::new(data.reports.num_reports, &data.reports.bytes))); } LeEvent::LeAdvertisingReport(data) => { - self.scanner - .send(Some(ScanReport::new(data.reports.num_reports, &data.reports.bytes))) - .await; + let _ = self + .scanner + .try_send(Some(ScanReport::new(data.reports.num_reports, &data.reports.bytes))); } _ => { warn!("Unknown LE event!"); @@ -1014,7 +1131,7 @@ where pub(crate) async fn acl(&self, handle: ConnHandle, n: u16) -> Result, BleHostError> { let grant = poll_fn(|cx| self.connections.poll_request_to_send(handle, n as usize, Some(cx))).await?; Ok(AclSender { - ble: self, + controller: &self.controller, handle, grant, }) @@ -1029,7 +1146,7 @@ where } }; Ok(AclSender { - ble: self, + controller: &self.controller, handle, grant, }) @@ -1040,7 +1157,6 @@ where let m = self.metrics.borrow(); debug!("[host] connect events: {}", m.connect_events); debug!("[host] disconnect events: {}", m.disconnect_events); - debug!("[host] tx blocked: {}", m.tx_blocked); debug!("[host] rx errors: {}", m.rx_errors); self.connections.log_status(); self.channels.log_status(); @@ -1048,7 +1164,7 @@ where } pub struct AclSender<'a, 'd, T: Controller> { - pub(crate) ble: &'a BleHost<'d, T>, + pub(crate) controller: &'a T, pub(crate) handle: ConnHandle, pub(crate) grant: PacketGrant<'a, 'd>, } @@ -1065,14 +1181,12 @@ impl<'a, 'd, T: Controller> AclSender<'a, 'd, T> { pdu, ); // info!("Sent ACL {:?}", acl); - match self.ble.controller.try_write_acl_data(&acl) { + match self.controller.try_write_acl_data(&acl) { Ok(result) => { self.grant.confirm(1); Ok(result) } Err(blocking::TryError::Busy) => { - let mut m = self.ble.metrics.borrow_mut(); - m.tx_blocked = m.tx_blocked.wrapping_add(1); warn!("hci: acl data send busy"); Err(Error::Busy.into()) } @@ -1087,8 +1201,7 @@ impl<'a, 'd, T: Controller> AclSender<'a, 'd, T> { AclBroadcastFlag::PointToPoint, pdu, ); - self.ble - .controller + self.controller .write_acl_data(&acl) .await .map_err(BleHostError::Controller)?; diff --git a/host/src/l2cap.rs b/host/src/l2cap.rs index bb17a14..350ba68 100644 --- a/host/src/l2cap.rs +++ b/host/src/l2cap.rs @@ -2,6 +2,7 @@ use bt_hci::controller::{blocking, Controller}; pub use crate::channel_manager::CreditFlowPolicy; +use crate::channel_manager::{ChannelIndex, DynamicChannelManager}; use crate::connection::Connection; use crate::host::BleHost; use crate::BleHostError; @@ -9,9 +10,25 @@ use crate::BleHostError; pub(crate) mod sar; /// Handle representing an L2CAP channel. -#[derive(Clone)] -pub struct L2capChannel { - cid: u16, +pub struct L2capChannel<'d, const TX_MTU: usize> { + index: ChannelIndex, + manager: &'d dyn DynamicChannelManager, +} + +impl<'d, const TX_MTU: usize> Clone for L2capChannel<'d, TX_MTU> { + fn clone(&self) -> Self { + self.manager.inc_ref(self.index); + Self { + index: self.index, + manager: self.manager, + } + } +} + +impl<'d, const TX_MTU: usize> Drop for L2capChannel<'d, TX_MTU> { + fn drop(&mut self) { + self.manager.dec_ref(self.index); + } } /// Configuration for an L2CAP channel. @@ -34,7 +51,12 @@ impl Default for L2capChannelConfig { } } -impl L2capChannel { +impl<'d, const TX_MTU: usize> L2capChannel<'d, TX_MTU> { + /// Disconnect this channel. + pub fn disconnect(&mut self) { + self.manager.disconnect(self.index); + } + /// Send the provided buffer over this l2cap channel. /// /// The buffer will be segmented to the maximum payload size agreed in the opening handshake. @@ -47,7 +69,7 @@ impl L2capChannel { buf: &[u8], ) -> Result<(), BleHostError> { let mut p_buf = [0u8; TX_MTU]; - ble.channels.send(self.cid, buf, &mut p_buf[..], ble).await + ble.channels.send(self.index, buf, &mut p_buf[..], ble).await } /// Send the provided buffer over this l2cap channel. @@ -62,7 +84,7 @@ impl L2capChannel { buf: &[u8], ) -> Result<(), BleHostError> { let mut p_buf = [0u8; TX_MTU]; - ble.channels.try_send(self.cid, buf, &mut p_buf[..], ble) + ble.channels.try_send(self.index, buf, &mut p_buf[..], ble) } /// Receive data on this channel and copy it into the buffer. @@ -73,41 +95,37 @@ impl L2capChannel { ble: &BleHost<'_, T>, buf: &mut [u8], ) -> Result> { - ble.channels.receive(self.cid, buf, ble).await + ble.channels.receive(self.index, buf, ble).await } /// Await an incoming connection request matching the list of PSM. pub async fn accept( - ble: &BleHost<'_, T>, + ble: &'d BleHost<'_, T>, connection: &Connection<'_>, psm: &[u16], config: &L2capChannelConfig, - ) -> Result, BleHostError> { + ) -> Result> { let handle = connection.handle(); - let cid = ble + let index = ble .channels .accept(handle, psm, config.mtu, config.flow_policy, config.initial_credits, ble) .await?; - - Ok(Self { cid }) - } - - /// Disconnect this channel. - pub fn disconnect(&mut self, ble: &BleHost<'_, T>) -> Result<(), BleHostError> { - ble.channels.disconnect(self.cid)?; - Ok(()) + Ok(Self { + index, + manager: &ble.channels, + }) } /// Create a new connection request with the provided PSM. pub async fn create( - ble: &BleHost<'_, T>, + ble: &'d BleHost<'_, T>, connection: &Connection<'_>, psm: u16, config: &L2capChannelConfig, ) -> Result> where { let handle = connection.handle(); - let cid = ble + let index = ble .channels .create( connection.handle(), @@ -118,7 +136,9 @@ where { ble, ) .await?; - - Ok(Self { cid }) + Ok(Self { + index, + manager: &ble.channels, + }) } } diff --git a/host/src/types/l2cap.rs b/host/src/types/l2cap.rs index 84c94c4..5bf70b3 100644 --- a/host/src/types/l2cap.rs +++ b/host/src/types/l2cap.rs @@ -215,6 +215,12 @@ unsafe impl FixedSizeValue for DisconnectionReq { } } +impl L2capSignal for DisconnectionReq { + fn code() -> L2capSignalCode { + L2capSignalCode::DisconnectionReq + } +} + #[cfg_attr(feature = "defmt", derive(defmt::Format))] #[repr(C)] #[derive(Debug, Clone, Copy)] From 6c797ea45a9ccf6bd87b40db6888c154f270074d Mon Sep 17 00:00:00 2001 From: Ulf Lilleengen Date: Thu, 6 Jun 2024 18:20:29 +0200 Subject: [PATCH 2/2] fix clippy warnings --- host/src/channel_manager.rs | 8 ++------ host/src/host.rs | 20 ++++++++------------ 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/host/src/channel_manager.rs b/host/src/channel_manager.rs index 38ad8ba..e92dd9d 100644 --- a/host/src/channel_manager.rs +++ b/host/src/channel_manager.rs @@ -569,11 +569,8 @@ impl<'d, const RXQ: usize> ChannelManager<'d, RXQ> { fn connected_channel_params(&self, index: ChannelIndex) -> Result<(ConnHandle, u16, u16), Error> { let state = self.state.borrow(); let chan = &state.channels[index.0 as usize]; - match chan.state { - ChannelState::Connected => { - return Ok((ConnHandle::new(chan.conn), chan.mps, chan.peer_cid)); - } - _ => {} + if chan.state == ChannelState::Connected { + return Ok((ConnHandle::new(chan.conn), chan.mps, chan.peer_cid)); } //trace!("[l2cap][connected_channel_params] channel {} closed", index); Err(Error::ChannelClosed) @@ -930,7 +927,6 @@ impl<'reference, 'state> Drop for CreditGrant<'reference, 'state> { if chan.state == ChannelState::Connected { chan.peer_credits += self.credits; chan.credit_waker.wake(); - return; } // make it an assert? // warn!("[l2cap][credit grant drop] channel {} not found", self.index); diff --git a/host/src/host.rs b/host/src/host.rs index 35cfec2..509e4cc 100644 --- a/host/src/host.rs +++ b/host/src/host.rs @@ -714,18 +714,14 @@ where self.advertise_state.start(&advset[..]); self.command(LeSetExtAdvEnable::new(true, &advset)).await?; - loop { - match select( - self.advertise_state.wait(&advset[..]), - self.connections.accept(LeConnRole::Peripheral, &[]), - ) - .await - { - Either::First(_) => { - return Err(Error::Timeout.into()); - } - Either::Second(conn) => return Ok(conn), - } + match select( + self.advertise_state.wait(&advset[..]), + self.connections.accept(LeConnRole::Peripheral, &[]), + ) + .await + { + Either::First(_) => Err(Error::Timeout.into()), + Either::Second(conn) => Ok(conn), } }