Skip to content

Use uninitialized buffers for read and recvfrom #1606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/fd/eventfd.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use alloc::boxed::Box;
use alloc::collections::vec_deque::VecDeque;
use core::future::{self, Future};
use core::mem;
use core::mem::{self, MaybeUninit};
use core::task::{Poll, Waker, ready};

use async_lock::Mutex;
Expand Down Expand Up @@ -45,7 +45,7 @@ impl EventFd {

#[async_trait]
impl ObjectInterface for EventFd {
async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
let len = mem::size_of::<u64>();

if buf.len() < len {
Expand All @@ -58,8 +58,7 @@ impl ObjectInterface for EventFd {
let mut guard = ready!(pinned.as_mut().poll(cx));
if guard.counter > 0 {
guard.counter -= 1;
let tmp = u64::to_ne_bytes(1);
buf[..len].copy_from_slice(&tmp);
buf[..len].write_copy_of_slice(&u64::to_ne_bytes(1));
if let Some(cx) = guard.write_queue.pop_front() {
cx.wake_by_ref();
}
Expand All @@ -74,7 +73,7 @@ impl ObjectInterface for EventFd {
let tmp = guard.counter;
if tmp > 0 {
guard.counter = 0;
buf[..len].copy_from_slice(&u64::to_ne_bytes(tmp));
buf[..len].write_copy_of_slice(&u64::to_ne_bytes(tmp));
if let Some(cx) = guard.read_queue.pop_front() {
cx.wake_by_ref();
}
Expand Down
7 changes: 4 additions & 3 deletions src/fd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use alloc::boxed::Box;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::future::{self, Future};
use core::mem::MaybeUninit;
use core::task::Poll::{Pending, Ready};
use core::time::Duration;

Expand Down Expand Up @@ -152,7 +153,7 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug {

/// `async_read` attempts to read `len` bytes from the object references
/// by the descriptor
async fn read(&self, _buf: &mut [u8]) -> io::Result<usize> {
async fn read(&self, _buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
Err(io::Error::ENOSYS)
}

Expand Down Expand Up @@ -230,7 +231,7 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug {

/// receive a message from a socket
#[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))]
async fn recvfrom(&self, _buffer: &mut [u8]) -> io::Result<(usize, Endpoint)> {
async fn recvfrom(&self, _buffer: &mut [MaybeUninit<u8>]) -> io::Result<(usize, Endpoint)> {
Err(io::Error::ENOSYS)
}

Expand Down Expand Up @@ -264,7 +265,7 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug {
}
}

pub(crate) fn read(fd: FileDescriptor, buf: &mut [u8]) -> io::Result<usize> {
pub(crate) fn read(fd: FileDescriptor, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
let obj = get_object(fd)?;

if buf.is_empty() {
Expand Down
7 changes: 4 additions & 3 deletions src/fd/socket/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use alloc::boxed::Box;
use alloc::collections::BTreeSet;
use alloc::sync::Arc;
use core::future;
use core::mem::MaybeUninit;
use core::sync::atomic::{AtomicU16, Ordering};
use core::task::Poll;

Expand Down Expand Up @@ -171,7 +172,7 @@ impl Socket {
.await
}

async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
future::poll_fn(|cx| {
self.with(|socket| {
let state = socket.state();
Expand All @@ -187,7 +188,7 @@ impl Socket {
socket
.recv(|data| {
let len = core::cmp::min(buffer.len(), data.len());
buffer[..len].copy_from_slice(&data[..len]);
buffer[..len].write_copy_of_slice(&data[..len]);
(len, len)
})
.map_err(|_| io::Error::EIO),
Expand Down Expand Up @@ -468,7 +469,7 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
self.read().await.poll(event).await
}

async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
self.read().await.read(buffer).await
}

Expand Down
55 changes: 27 additions & 28 deletions src/fd/socket/udp.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use alloc::boxed::Box;
use core::future;
use core::mem::MaybeUninit;
use core::task::Poll;

use async_trait::async_trait;
Expand Down Expand Up @@ -141,24 +142,23 @@ impl Socket {
}
}

async fn recvfrom(&self, buffer: &mut [u8]) -> io::Result<(usize, Endpoint)> {
async fn recvfrom(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<(usize, Endpoint)> {
future::poll_fn(|cx| {
self.with(|socket| {
if socket.is_open() {
if socket.can_recv() {
match socket.recv_slice(buffer) {
Ok((len, meta)) => match self.endpoint {
Some(ep) => {
if meta.endpoint == ep {
Poll::Ready(Ok((len, meta.endpoint)))
} else {
buffer[..len].iter_mut().for_each(|x| *x = 0);
socket.register_recv_waker(cx.waker());
Poll::Pending
}
match socket.recv() {
// Drop the packet when the provided buffer cannot
// fit the payload.
Ok((data, meta)) if data.len() <= buffer.len() => {
if self.endpoint.is_none_or(|ep| meta.endpoint == ep) {
buffer[..data.len()].write_copy_of_slice(data);
Poll::Ready(Ok((data.len(), meta.endpoint)))
} else {
socket.register_recv_waker(cx.waker());
Poll::Pending
}
None => Poll::Ready(Ok((len, meta.endpoint))),
},
}
_ => Poll::Ready(Err(io::Error::EIO)),
}
} else {
Expand All @@ -174,24 +174,23 @@ impl Socket {
.map(|(len, endpoint)| (len, Endpoint::Ip(endpoint)))
}

async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
future::poll_fn(|cx| {
self.with(|socket| {
if socket.is_open() {
if socket.can_recv() {
match socket.recv_slice(buffer) {
Ok((len, meta)) => match self.endpoint {
Some(ep) => {
if meta.endpoint == ep {
Poll::Ready(Ok(len))
} else {
buffer[..len].iter_mut().for_each(|x| *x = 0);
socket.register_recv_waker(cx.waker());
Poll::Pending
}
match socket.recv() {
// Drop the packet when the provided buffer cannot
// fit the payload.
Ok((data, meta)) if data.len() <= buffer.len() => {
if self.endpoint.is_none_or(|ep| meta.endpoint == ep) {
buffer[..data.len()].write_copy_of_slice(data);
Poll::Ready(Ok(data.len()))
} else {
socket.register_recv_waker(cx.waker());
Poll::Pending
}
None => Poll::Ready(Ok(len)),
},
}
_ => Poll::Ready(Err(io::Error::EIO)),
}
} else {
Expand Down Expand Up @@ -257,11 +256,11 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
self.read().await.sendto(buffer, endpoint).await
}

async fn recvfrom(&self, buffer: &mut [u8]) -> io::Result<(usize, Endpoint)> {
async fn recvfrom(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<(usize, Endpoint)> {
self.read().await.recvfrom(buffer).await
}

async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
self.read().await.read(buffer).await
}

Expand Down
9 changes: 5 additions & 4 deletions src/fd/socket/vsock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use alloc::boxed::Box;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::future;
use core::mem::MaybeUninit;
use core::task::Poll;

use async_trait::async_trait;
Expand Down Expand Up @@ -312,7 +313,7 @@ impl Socket {
}
}

async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
let port = self.port;
future::poll_fn(|cx| {
let mut guard = VSOCK_MAP.lock();
Expand All @@ -331,7 +332,7 @@ impl Socket {
}
} else {
let tmp: Vec<_> = raw.buffer.drain(..len).collect();
buffer[..len].copy_from_slice(tmp.as_slice());
buffer[..len].write_copy_of_slice(tmp.as_slice());

Poll::Ready(Ok(len))
}
Expand All @@ -343,7 +344,7 @@ impl Socket {
Poll::Ready(Ok(0))
} else {
let tmp: Vec<_> = raw.buffer.drain(..len).collect();
buffer[..len].copy_from_slice(tmp.as_slice());
buffer[..len].write_copy_of_slice(tmp.as_slice());

Poll::Ready(Ok(len))
}
Expand Down Expand Up @@ -424,7 +425,7 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
self.read().await.poll(event).await
}

async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
self.read().await.read(buffer).await
}

Expand Down
5 changes: 3 additions & 2 deletions src/fd/stdio.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use alloc::boxed::Box;
use core::future;
use core::mem::MaybeUninit;
use core::task::Poll;

use async_trait::async_trait;
Expand Down Expand Up @@ -27,7 +28,7 @@ impl ObjectInterface for GenericStdin {
Ok(event & available)
}

async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
future::poll_fn(|cx| {
let mut read_bytes = 0;
let mut guard = CONSOLE.lock();
Expand All @@ -36,7 +37,7 @@ impl ObjectInterface for GenericStdin {
let c = unsafe { char::from_u32_unchecked(byte.into()) };
guard.write(c.as_bytes());

buf[read_bytes] = byte;
buf[read_bytes].write(byte);
read_bytes += 1;

if read_bytes >= buf.len() {
Expand Down
7 changes: 4 additions & 3 deletions src/fs/fuse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use alloc::ffi::CString;
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::mem::MaybeUninit;
use core::sync::atomic::{AtomicU64, Ordering};
use core::task::Poll;
use core::{future, mem};
Expand Down Expand Up @@ -629,7 +630,7 @@ impl FuseFileHandleInner {
.await
}

fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
fn read(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
let mut len = buf.len();
if len > MAX_READ_LEN {
debug!("Reading longer than max_read_len: {}", len);
Expand All @@ -651,7 +652,7 @@ impl FuseFileHandleInner {
};
self.offset += len;

buf[..len].copy_from_slice(&rsp.payload.unwrap()[..len]);
buf[..len].write_copy_of_slice(&rsp.payload.unwrap()[..len]);

Ok(len)
} else {
Expand Down Expand Up @@ -767,7 +768,7 @@ impl ObjectInterface for FuseFileHandle {
self.0.lock().await.poll(event).await
}

async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
self.0.lock().await.read(buf)
}

Expand Down
11 changes: 6 additions & 5 deletions src/fs/mem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use alloc::collections::BTreeMap;
use alloc::string::{String, ToString};
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::mem::MaybeUninit;

use async_lock::{Mutex, RwLock};
use async_trait::async_trait;
Expand Down Expand Up @@ -59,7 +60,7 @@ impl ObjectInterface for RomFileInterface {
Ok(ret)
}

async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
{
let microseconds = arch::kernel::systemtime::now_micros();
let t = timespec::from_usec(microseconds as i64);
Expand All @@ -81,7 +82,7 @@ impl ObjectInterface for RomFileInterface {
buf.len()
};

buf[0..len].clone_from_slice(&vec[pos..pos + len]);
buf[..len].write_copy_of_slice(&vec[pos..pos + len]);
*pos_guard = pos + len;

Ok(len)
Expand Down Expand Up @@ -170,7 +171,7 @@ impl ObjectInterface for RamFileInterface {
Ok(event & available)
}

async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
{
let microseconds = arch::kernel::systemtime::now_micros();
let t = timespec::from_usec(microseconds as i64);
Expand All @@ -192,7 +193,7 @@ impl ObjectInterface for RamFileInterface {
buf.len()
};

buf[0..len].clone_from_slice(&guard.data[pos..pos + len]);
buf[..len].write_copy_of_slice(&guard.data[pos..pos + len]);
*pos_guard = pos + len;

Ok(len)
Expand All @@ -214,7 +215,7 @@ impl ObjectInterface for RamFileInterface {
guard.attr.st_mtim = t;
guard.attr.st_ctim = t;

guard.data[pos..pos + buf.len()].clone_from_slice(buf);
guard.data[pos..pos + buf.len()].copy_from_slice(buf);
*pos_guard = pos + buf.len();

Ok(buf.len())
Expand Down
1 change: 1 addition & 0 deletions src/fs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ impl File {

impl crate::io::Read for File {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let buf = unsafe { core::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) };
fd::read(self.fd, buf)
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/fs/uhyve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use alloc::ffi::CString;
use alloc::string::{String, ToString};
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::mem::MaybeUninit;

use async_lock::Mutex;
use async_trait::async_trait;
Expand All @@ -29,7 +30,7 @@ impl UhyveFileHandleInner {
Self(fd)
}

fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
fn read(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
let mut read_params = ReadParams {
fd: self.0,
buf: GuestVirtAddr::new(buf.as_mut_ptr() as u64),
Expand Down Expand Up @@ -94,7 +95,7 @@ impl UhyveFileHandle {

#[async_trait]
impl ObjectInterface for UhyveFileHandle {
async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
async fn read(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
self.0.lock().await.read(buf)
}

Expand Down
Loading