diff --git a/pal/src/host/vm-common/kernel_virtio.h b/pal/src/host/vm-common/kernel_virtio.h index cc3e0961..dfacfa83 100644 --- a/pal/src/host/vm-common/kernel_virtio.h +++ b/pal/src/host/vm-common/kernel_virtio.h @@ -362,6 +362,7 @@ int virtio_vsock_connect(int sockfd, const void* addr, size_t addrlen, uint64_t int virtio_vsock_shutdown(int sockfd, enum virtio_vsock_shutdown shutdown); int virtio_vsock_close(int sockfd, uint64_t timeout_us); long virtio_vsock_peek(int sockfd); +bool virtio_vsock_can_write(int sockfd); long virtio_vsock_read(int sockfd, void* buf, size_t count); long virtio_vsock_write(int sockfd, const void* buf, size_t count); int virtio_vsock_getsockname(int sockfd, const void* addr, size_t* addrlen); @@ -369,7 +370,6 @@ int virtio_vsock_set_socket_options(int sockfd, bool ipv6_v6only, bool reuseport int virtio_vsock_isr(void); int virtio_vsock_bottomhalf(void); -bool virtio_vsock_can_write(void); int virtio_vsock_init(struct virtio_pci_regs* pci_regs, struct virtio_vsock_config* pci_config, uint64_t notify_off_addr, uint32_t notify_off_multiplier, uint32_t* interrupt_status_reg); diff --git a/pal/src/host/vm-common/kernel_virtio_vsock.c b/pal/src/host/vm-common/kernel_virtio_vsock.c index cff91084..c7d41507 100644 --- a/pal/src/host/vm-common/kernel_virtio_vsock.c +++ b/pal/src/host/vm-common/kernel_virtio_vsock.c @@ -6,10 +6,6 @@ * * Reference: https://docs.oasis-open.org/virtio/virtio/v1.1/csprd01/virtio-v1.1-csprd01.pdf * - * TODO: - * - Implement guest-side buffer space management via peer_buf_alloc, peer_fwd_cnt, tx_cnt - * (see section 5.10.6.3 in spec) - * * Diagram with flows: * * Bottomhalves thread (CPU0) + App threads (CPU0-CPUn) @@ -418,13 +414,6 @@ static int send_pending_tq_control_packets(void) { return ret; } -bool virtio_vsock_can_write(void) { - spinlock_lock(&g_vsock_transmit_lock); - bool can_write = (g_vsock && g_vsock->tq->free_desc != g_vsock->tq->queue_size); - spinlock_unlock(&g_vsock_transmit_lock); - return can_write; -} - /* called from the bottomhalf thread in normal context (not interrupt context) */ int virtio_vsock_bottomhalf(void) { int handle_rq_ret = handle_rq_with_disabled_notifications(); @@ -574,6 +563,10 @@ static struct virtio_vsock_connection* create_connection(uint64_t host_port, uin conn->host_port = host_port; conn->guest_port = guest_port; + conn->tx_cnt = 0; + conn->peer_fwd_cnt = 0; + conn->peer_buf_alloc = 0; + conn->fwd_cnt = 0; conn->buf_alloc = VSOCK_MAX_PACKETS * VSOCK_MAX_PAYLOAD_SIZE; @@ -742,7 +735,12 @@ static int send_rw_packet(struct virtio_vsock_connection* conn, const char* payl if (!packet) return -PAL_ERROR_NOMEM; - return copy_into_tq_and_free(packet); + int ret = copy_into_tq_and_free(packet); + if (!ret) { + /* successfully queued the packet to be sent, only then update "bytes transmitted" */ + conn->tx_cnt += payload_size; + } + return ret; } /* takes ownership of the packet */ @@ -834,6 +832,15 @@ static int process_packet(struct virtio_vsock_packet* packet) { goto out; } + /* + * Even if packets have malicious fwd_cnt and buf_alloc values, this is benign because it only + * affects buffer space management: (1) Gramine may conclude that the host side can receive + * bytes even though it can't, or (2) Gramine may conclude that the host side can't receive + * bytes even if it can. Former case results in a packet dropped by the host, latter case + * results in Gramine app not sending the packet. Both cases are Denial of Service. + * + * See also virtio_vsock_can_write(). + */ conn->peer_fwd_cnt = packet->header.fwd_cnt; conn->peer_buf_alloc = packet->header.buf_alloc; @@ -1561,6 +1568,36 @@ long virtio_vsock_peek(int sockfd) { return ret; } +bool virtio_vsock_can_write(int sockfd) { + spinlock_lock(&g_vsock_transmit_lock); + bool can_write = (g_vsock && g_vsock->tq->free_desc != g_vsock->tq->queue_size); + spinlock_unlock(&g_vsock_transmit_lock); + + if (can_write) { + /* we can send; additionally check that host can receive (buffer space management) */ + spinlock_lock(&g_vsock_connections_lock); + + struct virtio_vsock_connection* conn = get_connection(sockfd); + if (conn) { + uint32_t bytes_in_flight = conn->tx_cnt - conn->peer_fwd_cnt; + int64_t bytes_avail_in_peer_buf = (int64_t)conn->peer_buf_alloc - bytes_in_flight; + if (bytes_avail_in_peer_buf < 0) { + /* play it safe (peer_fwd_cnt and peer_buf_alloc are potentially malicious) */ + bytes_avail_in_peer_buf = 0; + } + + can_write = (bytes_avail_in_peer_buf > 0); + } else { + /* maybe connection was shutdown, just return that we can't write in this case */ + can_write = false; + } + + spinlock_unlock(&g_vsock_connections_lock); + } + + return can_write; +} + long virtio_vsock_read(int sockfd, void* buf, size_t count) { long ret; diff --git a/pal/src/host/vm-common/pal_common_object.c b/pal/src/host/vm-common/pal_common_object.c index c7ee6a21..bf27bf05 100644 --- a/pal/src/host/vm-common/pal_common_object.c +++ b/pal/src/host/vm-common/pal_common_object.c @@ -112,7 +112,7 @@ static int check_socket_handle(struct pal_handle* handle, pal_wait_flags_t event if ((events & PAL_WAIT_READ) && peeked) revents |= PAL_WAIT_READ; - if ((events & PAL_WAIT_WRITE) && virtio_vsock_can_write()) + if ((events & PAL_WAIT_WRITE) && virtio_vsock_can_write(handle->sock.fd)) revents |= PAL_WAIT_WRITE; *out_events = revents;