Skip to content

Commit

Permalink
net: refactor UDP socket implementation
Browse files Browse the repository at this point in the history
- remove mutable reference to the UDP socket on bind/connect
  • Loading branch information
equation314 committed Jul 19, 2023
1 parent 606e741 commit 88a90a0
Show file tree
Hide file tree
Showing 9 changed files with 197 additions and 137 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions modules/axnet/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ default = ["smoltcp"]
[dependencies]
log = "0.4"
cfg-if = "1.0"
spin = "0.9"
driver_net = { path = "../../crates/driver_net" }
lazy_init = { path = "../../crates/lazy_init" }
axerrno = { path = "../../crates/axerrno" }
Expand Down
7 changes: 6 additions & 1 deletion modules/axnet/src/smoltcp_impl/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ impl TcpSocket {
if self.is_connecting() {
return Err(AxError::WouldBlock);
} else if !self.is_connected() {
return ax_err!(NotConnected, "socket recv() failed");
return ax_err!(NotConnected, "socket send() failed");
}

// SAFETY: `self.handle` should be initialized in a connected socket.
Expand Down Expand Up @@ -427,6 +427,11 @@ impl TcpSocket {
State::SynSent => false, // wait for connection
State::Established => {
self.set_state(STATE_CONNECTED); // connected
debug!(
"TCP socket {}: connected to {}",
handle,
socket.remote_endpoint().unwrap(),
);
true
}
_ => {
Expand Down
251 changes: 144 additions & 107 deletions modules/axnet/src/smoltcp_impl/udp.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
use core::sync::atomic::{AtomicBool, Ordering};

use axerrno::{ax_err, ax_err_type, AxError, AxResult};
use axio::PollState;
use axsync::Mutex;
use spin::RwLock;

use smoltcp::iface::SocketHandle;
use smoltcp::socket::udp::{self, BindError, SendError};
use smoltcp::wire::{IpAddress, IpListenEndpoint};

use super::{SocketSetWrapper, ETH0, SOCKET_SET};
use super::{SocketSetWrapper, SOCKET_SET};
use crate::SocketAddr;

const UNSPECIFIED_IP: IpAddress = IpAddress::v4(0, 0, 0, 0);
const UNSPECIFIED: SocketAddr = SocketAddr::new(UNSPECIFIED_IP, 0);

/// A UDP socket that provides POSIX-like APIs.
pub struct UdpSocket {
handle: SocketHandle,
local_addr: Option<SocketAddr>,
peer_addr: Option<SocketAddr>,
nonblock: bool,
local_addr: RwLock<Option<SocketAddr>>,
peer_addr: RwLock<Option<SocketAddr>>,
nonblock: AtomicBool,
}

impl UdpSocket {
Expand All @@ -27,22 +31,34 @@ impl UdpSocket {
let handle = SOCKET_SET.add(socket);
Self {
handle,
local_addr: None,
peer_addr: None,
nonblock: false,
local_addr: RwLock::new(None),
peer_addr: RwLock::new(None),
nonblock: AtomicBool::new(false),
}
}

/// Returns the local address and port, or
/// [`Err(NotConnected)`](AxError::NotConnected) if not connected.
pub fn local_addr(&self) -> AxResult<SocketAddr> {
self.local_addr.ok_or(AxError::NotConnected)
match self.local_addr.try_read() {
Some(addr) => addr.ok_or(AxError::NotConnected),
None => Err(AxError::NotConnected),
}
}

/// Returns the remote address and port, or
/// [`Err(NotConnected)`](AxError::NotConnected) if not connected.
pub fn peer_addr(&self) -> AxResult<SocketAddr> {
self.peer_addr.ok_or(AxError::NotConnected)
match self.peer_addr.try_read() {
Some(addr) => addr.ok_or(AxError::NotConnected),
None => Err(AxError::NotConnected),
}
}

/// Returns whether this socket is in nonblocking mode.
#[inline]
pub fn is_nonblocking(&self) -> bool {
self.nonblock.load(Ordering::Acquire)
}

/// Moves this UDP socket into or out of nonblocking mode.
Expand All @@ -53,22 +69,25 @@ impl UdpSocket {
/// further action is required. If the IO operation could not be completed
/// and needs to be retried, an error with kind
/// [`Err(WouldBlock)`](AxError::WouldBlock) is returned.
pub fn set_nonblocking(&mut self, nonblocking: bool) {
self.nonblock = nonblocking;
#[inline]
pub fn set_nonblocking(&self, nonblocking: bool) {
self.nonblock.store(nonblocking, Ordering::Release);
}

/// Binds an unbound socket to the given address and port.
///
/// It's must be called before [`send_to`](Self::send_to) and
/// [`recv_from`](Self::recv_from).
pub fn bind(&mut self, mut local_addr: SocketAddr) -> AxResult {
pub fn bind(&self, mut local_addr: SocketAddr) -> AxResult {
let mut self_local_addr = self.local_addr.write();

if local_addr.addr.is_unspecified() && local_addr.addr != UNSPECIFIED_IP {
return ax_err!(InvalidInput, "socket bind() failed: invalid addr");
}
if local_addr.port == 0 {
local_addr.port = get_ephemeral_port()?;
}
if self.local_addr.is_some() {
if self_local_addr.is_some() {
return ax_err!(InvalidInput, "socket bind() failed: already bound");
}

Expand All @@ -82,108 +101,84 @@ impl UdpSocket {
BindError::Unaddressable => ax_err!(InvalidInput, "socket bind() failed"),
})
})?;
self.local_addr = Some(local_addr);
debug!("UDP socket bound on {}", endpoint);

*self_local_addr = Some(local_addr);
debug!("UDP socket {}: bound on {}", self.handle, endpoint);
Ok(())
}

/// Transmits data in the given buffer to the given address.
pub fn send_to(&self, buf: &[u8], addr: SocketAddr) -> AxResult<usize> {
self.block_on(|| {
SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
if !socket.is_open() {
// not bound
ax_err!(NotConnected, "socket send() failed")
} else if socket.can_send() {
// TODO: size
socket.send_slice(buf, addr).map_err(|e| match e {
SendError::BufferFull => AxError::WouldBlock,
SendError::Unaddressable => {
ax_err_type!(ConnectionRefused, "socket send() failed")
}
})?;
Ok(buf.len())
} else {
// tx buffer is full
Err(AxError::WouldBlock)
}
})
})
/// Sends data on the socket to the given address. On success, returns the
/// number of bytes written.
pub fn send_to(&self, buf: &[u8], remote_addr: SocketAddr) -> AxResult<usize> {
if remote_addr.port == 0 || remote_addr.addr.is_unspecified() {
return ax_err!(InvalidInput, "socket send_to() failed: invalid address");
}
self.send_impl(buf, remote_addr)
}

fn recv_impl<F, T>(&self, mut op: F, err: &str) -> AxResult<T>
where
F: FnMut(&mut udp::Socket) -> AxResult<T>,
{
self.block_on(|| {
SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
if !socket.is_open() {
// not connected
ax_err!(NotConnected, err)
} else if socket.can_recv() {
// data available
op(socket)
} else {
// no more data
Err(AxError::WouldBlock)
}
})
/// Receives a single datagram message on the socket. On success, returns
/// the number of bytes read and the origin.
pub fn recv_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
self.recv_impl(|socket| match socket.recv_slice(buf) {
Ok((len, meta)) => Ok((len, meta.endpoint)),
Err(_) => ax_err!(BadState, "socket recv_from() failed"),
})
}

/// Receives data from the socket, stores it in the given buffer.
pub fn recv_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
self.recv_impl(
|socket| match socket.recv_slice(buf) {
Ok((len, meta)) => Ok((len, meta.endpoint)),
Err(_) => Err(AxError::WouldBlock),
},
"socket recv_from() failed",
)
/// Receives a single datagram message on the socket, without removing it from
/// the queue. On success, returns the number of bytes read and the origin.
pub fn peek_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
self.recv_impl(|socket| match socket.peek_slice(buf) {
Ok((len, meta)) => Ok((len, meta.endpoint)),
Err(_) => ax_err!(BadState, "socket recv_from() failed"),
})
}

/// Connects to the given address and port.
/// Connects this UDP socket to a remote address, allowing the `send` and
/// `recv` to be used to send data and also applies filters to only receive
/// data from the specified address.
///
/// The local port will be generated automatically if the socket is not bound.
/// It's must be called before [`send`](Self::send) and
/// [`recv`](Self::recv).
pub fn connect(&mut self, addr: SocketAddr) -> AxResult {
if self.local_addr.is_none() {
self.bind(SocketAddr::new(
ETH0.iface
.lock()
.ipv4_addr()
.ok_or_else(|| ax_err_type!(BadAddress, "No IPv4 address"))?
.into(),
0,
))?;
pub fn connect(&self, addr: SocketAddr) -> AxResult {
let mut self_peer_addr = self.peer_addr.write();

if addr.addr.is_unspecified() && addr.addr != UNSPECIFIED_IP {
return ax_err!(InvalidInput, "socket connect() failed: invalid addr");
}
self.peer_addr = Some(addr);

if self.local_addr.read().is_none() {
self.bind(UNSPECIFIED)?;
}

*self_peer_addr = Some(addr);
debug!("UDP socket {}: connected to {}", self.handle, addr);
Ok(())
}

/// Transmits data in the given buffer to the remote address to which it is connected.
/// Sends data on the socket to the remote address to which it is connected.
pub fn send(&self, buf: &[u8]) -> AxResult<usize> {
self.send_to(buf, self.peer_addr()?)
let remote_addr = self.peer_addr()?;
self.send_impl(buf, remote_addr)
}

/// Recv data in the given buffer from the remote address to which it is connected.
/// Receives a single datagram message on the socket from the remote address
/// to which it is connected. On success, returns the number of bytes read.
pub fn recv(&self, buf: &mut [u8]) -> AxResult<usize> {
let peeraddr = self.peer_addr()?;
self.recv_impl(
|socket| match socket.recv_slice(buf) {
Ok((len, meta)) => {
if meta.endpoint == peeraddr {
// filter data from the remote address to which it is connected.
Ok(len)
} else {
Err(AxError::WouldBlock)
}
}
Err(_) => Err(AxError::WouldBlock),
},
"socket recv() failed",
)
let remote_addr = self.peer_addr()?;
self.recv_impl(|socket| {
let (len, meta) = socket
.recv_slice(buf)
.map_err(|_| ax_err_type!(BadState, "socket recv() failed"))?;
if !remote_addr.addr.is_unspecified() && remote_addr.addr != meta.endpoint.addr {
return Err(AxError::WouldBlock);
}
if remote_addr.port != 0 && remote_addr.port != meta.endpoint.port {
return Err(AxError::WouldBlock);
}
Ok(len)
})
}

/// Close the socket.
Expand All @@ -196,23 +191,65 @@ impl UdpSocket {
Ok(())
}

/// Receives data from the socket, stores it in the given buffer, without removing it from the queue.
pub fn peek_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
self.recv_impl(
|socket| match socket.peek_slice(buf) {
Ok((len, meta)) => Ok((len, meta.endpoint)),
Err(_) => Err(AxError::WouldBlock),
},
"socket peek_from() failed",
)
}

/// Whether the socket is readable or writable.
pub fn poll(&self) -> AxResult<PollState> {
if self.local_addr.read().is_none() {
return Ok(PollState {
readable: false,
writable: false,
});
}
SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
Ok(PollState {
readable: socket.is_open() && socket.can_recv(),
writable: socket.is_open() && socket.can_send(),
readable: socket.can_recv(),
writable: socket.can_send(),
})
})
}
}

/// Private methods
impl UdpSocket {
fn send_impl(&self, buf: &[u8], remote_addr: SocketAddr) -> AxResult<usize> {
if self.local_addr.read().is_none() {
return ax_err!(NotConnected, "socket send() failed");
}

self.block_on(|| {
SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
if socket.can_send() {
socket.send_slice(buf, remote_addr).map_err(|e| match e {
SendError::BufferFull => AxError::WouldBlock,
SendError::Unaddressable => {
ax_err_type!(ConnectionRefused, "socket send() failed")
}
})?;
Ok(buf.len())
} else {
// tx buffer is full
Err(AxError::WouldBlock)
}
})
})
}

fn recv_impl<F, T>(&self, mut op: F) -> AxResult<T>
where
F: FnMut(&mut udp::Socket) -> AxResult<T>,
{
if self.local_addr.read().is_none() {
return ax_err!(NotConnected, "socket send() failed");
}

self.block_on(|| {
SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
if socket.can_recv() {
// data available
op(socket)
} else {
// no more data
Err(AxError::WouldBlock)
}
})
})
}
Expand All @@ -221,7 +258,7 @@ impl UdpSocket {
where
F: FnMut() -> AxResult<T>,
{
if self.nonblock {
if self.is_nonblocking() {
f()
} else {
loop {
Expand Down
2 changes: 1 addition & 1 deletion ulib/c_libax/src/fcntl.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ int posix_fadvise(int __fd, unsigned long __offset, unsigned long __len, int __a
}

// TODO
int sync_file_range(int, off_t, off_t, unsigned)
int sync_file_range(int fd, off_t pos, off_t len, unsigned flags)
{
unimplemented();
return 0;
Expand Down
Loading

0 comments on commit 88a90a0

Please sign in to comment.