Skip to content

Commit 2d15daa

Browse files
committed
Use uninitialized buffer for recvfrom
`smoltcp::socket::tcp::Socket::recv_slice` takes a `&[u8]`, so inline it and refactor. As a bonus, this avoids a copy and clear in the error case.
1 parent b8897bf commit 2d15daa

File tree

3 files changed

+27
-32
lines changed

3 files changed

+27
-32
lines changed

src/fd/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug {
231231

232232
/// receive a message from a socket
233233
#[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))]
234-
async fn recvfrom(&self, _buffer: &mut [u8]) -> io::Result<(usize, Endpoint)> {
234+
async fn recvfrom(&self, _buffer: &mut [MaybeUninit<u8>]) -> io::Result<(usize, Endpoint)> {
235235
Err(io::Error::ENOSYS)
236236
}
237237

src/fd/socket/udp.rs

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -142,24 +142,23 @@ impl Socket {
142142
}
143143
}
144144

145-
async fn recvfrom(&self, buffer: &mut [u8]) -> io::Result<(usize, Endpoint)> {
145+
async fn recvfrom(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<(usize, Endpoint)> {
146146
future::poll_fn(|cx| {
147147
self.with(|socket| {
148148
if socket.is_open() {
149149
if socket.can_recv() {
150-
match socket.recv_slice(buffer) {
151-
Ok((len, meta)) => match self.endpoint {
152-
Some(ep) => {
153-
if meta.endpoint == ep {
154-
Poll::Ready(Ok((len, meta.endpoint)))
155-
} else {
156-
buffer[..len].iter_mut().for_each(|x| *x = 0);
157-
socket.register_recv_waker(cx.waker());
158-
Poll::Pending
159-
}
150+
match socket.recv() {
151+
// Drop the packet when the provided buffer cannot
152+
// fit the payload.
153+
Ok((data, meta)) if data.len() <= buffer.len() => {
154+
if self.endpoint.is_none_or(|ep| meta.endpoint == ep) {
155+
buffer[..data.len()].write_copy_of_slice(data);
156+
Poll::Ready(Ok((data.len(), meta.endpoint)))
157+
} else {
158+
socket.register_recv_waker(cx.waker());
159+
Poll::Pending
160160
}
161-
None => Poll::Ready(Ok((len, meta.endpoint))),
162-
},
161+
}
163162
_ => Poll::Ready(Err(io::Error::EIO)),
164163
}
165164
} else {
@@ -175,24 +174,23 @@ impl Socket {
175174
.map(|(len, endpoint)| (len, Endpoint::Ip(endpoint)))
176175
}
177176

178-
async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
177+
async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
179178
future::poll_fn(|cx| {
180179
self.with(|socket| {
181180
if socket.is_open() {
182181
if socket.can_recv() {
183-
match socket.recv_slice(buffer) {
184-
Ok((len, meta)) => match self.endpoint {
185-
Some(ep) => {
186-
if meta.endpoint == ep {
187-
Poll::Ready(Ok(len))
188-
} else {
189-
buffer[..len].iter_mut().for_each(|x| *x = 0);
190-
socket.register_recv_waker(cx.waker());
191-
Poll::Pending
192-
}
182+
match socket.recv() {
183+
// Drop the packet when the provided buffer cannot
184+
// fit the payload.
185+
Ok((data, meta)) if data.len() <= buffer.len() => {
186+
if self.endpoint.is_none_or(|ep| meta.endpoint == ep) {
187+
buffer[..data.len()].write_copy_of_slice(data);
188+
Poll::Ready(Ok(data.len()))
189+
} else {
190+
socket.register_recv_waker(cx.waker());
191+
Poll::Pending
193192
}
194-
None => Poll::Ready(Ok(len)),
195-
},
193+
}
196194
_ => Poll::Ready(Err(io::Error::EIO)),
197195
}
198196
} else {
@@ -258,14 +256,11 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
258256
self.read().await.sendto(buffer, endpoint).await
259257
}
260258

261-
async fn recvfrom(&self, buffer: &mut [u8]) -> io::Result<(usize, Endpoint)> {
259+
async fn recvfrom(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<(usize, Endpoint)> {
262260
self.read().await.recvfrom(buffer).await
263261
}
264262

265263
async fn read(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
266-
// FIXME
267-
let buffer =
268-
unsafe { core::slice::from_raw_parts_mut(buffer.as_mut_ptr().cast(), buffer.len()) };
269264
self.read().await.read(buffer).await
270265
}
271266

src/syscalls/socket.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ pub unsafe extern "C" fn sys_recvfrom(
962962
addr: *mut sockaddr,
963963
addrlen: *mut socklen_t,
964964
) -> isize {
965-
let slice = unsafe { core::slice::from_raw_parts_mut(buf, len) };
965+
let slice = unsafe { core::slice::from_raw_parts_mut(buf.cast(), len) };
966966
let obj = get_object(fd);
967967
obj.map_or_else(
968968
|e| -num::ToPrimitive::to_isize(&e).unwrap(),

0 commit comments

Comments
 (0)