Skip to content

chore(deps): upgrade async-io #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ features = ["os-poll", "os-ext"]

[dependencies.async-io]
optional = true
version = "1.13"
version = "2"

[features]
default = []
mio_socket = ["mio"]
tokio_socket = ["tokio", "futures"]
smol_socket = ["async-io","futures"]
smol_socket = ["async-io", "futures"]

[dev-dependencies]
netlink-packet-audit = "0.4.1"
Expand Down
20 changes: 13 additions & 7 deletions src/smol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use std::{
io,
os::unix::io::{AsRawFd, FromRawFd, RawFd},
os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, RawFd},
task::{Context, Poll},
};

Expand Down Expand Up @@ -31,6 +31,12 @@ impl AsRawFd for SmolSocket {
}
}

impl AsFd for SmolSocket {
fn as_fd(&self) -> BorrowedFd<'_> {
self.0.get_ref().as_fd()
}
}

// async_io::Async<..>::{read,write}_with[_mut] functions try IO first,
// and only register context if it would block.
// replicate this in these poll functions:
Expand Down Expand Up @@ -79,7 +85,7 @@ impl AsyncSocket for SmolSocket {

/// Mutable access to underyling [`Socket`]
fn socket_mut(&mut self) -> &mut Socket {
self.0.get_mut()
unsafe { self.0.get_mut() }
}

fn new(protocol: isize) -> io::Result<Self> {
Expand All @@ -92,7 +98,7 @@ impl AsyncSocket for SmolSocket {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.poll_write_with(cx, |this| this.0.get_mut().send(buf, 0))
self.poll_write_with(cx, |this| this.socket_mut().send(buf, 0))
}

fn poll_send_to(
Expand All @@ -101,7 +107,7 @@ impl AsyncSocket for SmolSocket {
buf: &[u8],
addr: &SocketAddr,
) -> Poll<io::Result<usize>> {
self.poll_write_with(cx, |this| this.0.get_mut().send_to(buf, addr, 0))
self.poll_write_with(cx, |this| this.socket_mut().send_to(buf, addr, 0))
}

fn poll_recv<B>(
Expand All @@ -113,7 +119,7 @@ impl AsyncSocket for SmolSocket {
B: bytes::BufMut,
{
self.poll_read_with(cx, |this| {
this.0.get_mut().recv(buf, 0).map(|_len| ())
this.socket_mut().recv(buf, 0).map(|_len| ())
})
}

Expand All @@ -126,7 +132,7 @@ impl AsyncSocket for SmolSocket {
B: bytes::BufMut,
{
self.poll_read_with(cx, |this| {
let x = this.0.get_mut().recv_from(buf, 0);
let x = this.socket_mut().recv_from(buf, 0);
trace!("poll_recv_from: {:?}", x);
x.map(|(_len, addr)| addr)
})
Expand All @@ -136,6 +142,6 @@ impl AsyncSocket for SmolSocket {
&mut self,
cx: &mut Context<'_>,
) -> Poll<io::Result<(Vec<u8>, SocketAddr)>> {
self.poll_read_with(cx, |this| this.0.get_mut().recv_from_full())
self.poll_read_with(cx, |this| this.socket_mut().recv_from_full())
}
}
99 changes: 70 additions & 29 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
use std::{
io::{Error, Result},
mem,
os::unix::io::{AsRawFd, FromRawFd, RawFd},
os::{
fd::{AsFd, BorrowedFd, FromRawFd},
unix::io::{AsRawFd, RawFd},
},
};

use crate::SocketAddr;
Expand Down Expand Up @@ -60,6 +63,12 @@ impl AsRawFd for Socket {
}
}

impl AsFd for Socket {
fn as_fd(&self) -> BorrowedFd<'_> {
unsafe { BorrowedFd::borrow_raw(self.0) }
}
}

impl FromRawFd for Socket {
unsafe fn from_raw_fd(fd: RawFd) -> Self {
Socket(fd)
Expand All @@ -68,7 +77,7 @@ impl FromRawFd for Socket {

impl Drop for Socket {
fn drop(&mut self) {
unsafe { libc::close(self.0) };
unsafe { libc::close(self.as_raw_fd()) };
}
}

Expand All @@ -94,7 +103,7 @@ impl Socket {
/// Bind the socket to the given address
pub fn bind(&mut self, addr: &SocketAddr) -> Result<()> {
let (addr_ptr, addr_len) = addr.as_raw();
let res = unsafe { libc::bind(self.0, addr_ptr, addr_len) };
let res = unsafe { libc::bind(self.as_raw_fd(), addr_ptr, addr_len) };
if res < 0 {
return Err(Error::last_os_error());
}
Expand All @@ -115,7 +124,9 @@ impl Socket {
let (addr_ptr, mut addr_len) = addr.as_raw_mut();
let addr_len_copy = addr_len;
let addr_len_ptr = &mut addr_len as *mut libc::socklen_t;
let res = unsafe { libc::getsockname(self.0, addr_ptr, addr_len_ptr) };
let res = unsafe {
libc::getsockname(self.as_raw_fd(), addr_ptr, addr_len_ptr)
};
if res < 0 {
return Err(Error::last_os_error());
}
Expand All @@ -128,8 +139,9 @@ impl Socket {
/// Make this socket non-blocking
pub fn set_non_blocking(&self, non_blocking: bool) -> Result<()> {
let mut non_blocking = non_blocking as libc::c_int;
let res =
unsafe { libc::ioctl(self.0, libc::FIONBIO, &mut non_blocking) };
let res = unsafe {
libc::ioctl(self.as_raw_fd(), libc::FIONBIO, &mut non_blocking)
};
if res < 0 {
return Err(Error::last_os_error());
}
Expand All @@ -152,7 +164,7 @@ impl Socket {
/// 2. connect it to the kernel with [`Socket::connect`]
/// 3. send a request to the kernel with [`Socket::send`]
/// 4. read the response (which can span over several messages)
/// [`Socket::recv`]
/// [`Socket::recv`]
///
/// ```rust
/// use netlink_sys::{protocols::NETLINK_ROUTE, Socket, SocketAddr};
Expand Down Expand Up @@ -216,7 +228,7 @@ impl Socket {
// - https://stackoverflow.com/a/14046386/1836144
// - https://lists.isc.org/pipermail/bind-users/2009-August/077527.html
let (addr, addr_len) = remote_addr.as_raw();
let res = unsafe { libc::connect(self.0, addr, addr_len) };
let res = unsafe { libc::connect(self.as_raw_fd(), addr, addr_len) };
if res < 0 {
return Err(Error::last_os_error());
}
Expand Down Expand Up @@ -293,7 +305,7 @@ impl Socket {

let res = unsafe {
libc::recvfrom(
self.0,
self.as_raw_fd(),
buf_ptr,
buf_len,
flags,
Expand Down Expand Up @@ -324,7 +336,8 @@ impl Socket {
let buf_ptr = chunk.as_mut_ptr() as *mut libc::c_void;
let buf_len = chunk.len() as libc::size_t;

let res = unsafe { libc::recv(self.0, buf_ptr, buf_len, flags) };
let res =
unsafe { libc::recv(self.as_raw_fd(), buf_ptr, buf_len, flags) };
if res < 0 {
return Err(Error::last_os_error());
} else {
Expand Down Expand Up @@ -367,7 +380,14 @@ impl Socket {
let buf_len = buf.len() as libc::size_t;

let res = unsafe {
libc::sendto(self.0, buf_ptr, buf_len, flags, addr_ptr, addr_len)
libc::sendto(
self.as_raw_fd(),
buf_ptr,
buf_len,
flags,
addr_ptr,
addr_len,
)
};
if res < 0 {
return Err(Error::last_os_error());
Expand All @@ -382,7 +402,8 @@ impl Socket {
let buf_ptr = buf.as_ptr() as *const libc::c_void;
let buf_len = buf.len() as libc::size_t;

let res = unsafe { libc::send(self.0, buf_ptr, buf_len, flags) };
let res =
unsafe { libc::send(self.as_raw_fd(), buf_ptr, buf_len, flags) };
if res < 0 {
return Err(Error::last_os_error());
}
Expand All @@ -391,12 +412,17 @@ impl Socket {

pub fn set_pktinfo(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_PKTINFO, value)
setsockopt(
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_PKTINFO,
value,
)
}

pub fn get_pktinfo(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_PKTINFO,
)?;
Expand All @@ -405,7 +431,7 @@ impl Socket {

pub fn add_membership(&mut self, group: u32) -> Result<()> {
setsockopt(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_ADD_MEMBERSHIP,
group,
Expand All @@ -414,7 +440,7 @@ impl Socket {

pub fn drop_membership(&mut self, group: u32) -> Result<()> {
setsockopt(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_DROP_MEMBERSHIP,
group,
Expand All @@ -434,7 +460,7 @@ impl Socket {
pub fn set_broadcast_error(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_BROADCAST_ERROR,
value,
Expand All @@ -443,7 +469,7 @@ impl Socket {

pub fn get_broadcast_error(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_BROADCAST_ERROR,
)?;
Expand All @@ -454,12 +480,17 @@ impl Socket {
/// unicast and broadcast listeners to avoid receiving `ENOBUFS` errors.
pub fn set_no_enobufs(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_NO_ENOBUFS, value)
setsockopt(
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_NO_ENOBUFS,
value,
)
}

pub fn get_no_enobufs(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_NO_ENOBUFS,
)?;
Expand All @@ -474,7 +505,7 @@ impl Socket {
pub fn set_listen_all_namespaces(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_LISTEN_ALL_NSID,
value,
Expand All @@ -483,7 +514,7 @@ impl Socket {

pub fn get_listen_all_namespaces(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_LISTEN_ALL_NSID,
)?;
Expand All @@ -498,12 +529,17 @@ impl Socket {
/// acknowledgment.
pub fn set_cap_ack(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_CAP_ACK, value)
setsockopt(
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_CAP_ACK,
value,
)
}

pub fn get_cap_ack(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_CAP_ACK,
)?;
Expand All @@ -515,12 +551,17 @@ impl Socket {
/// NLMSG_ERROR and NLMSG_DONE messages.
pub fn set_ext_ack(&mut self, value: bool) -> Result<()> {
let value: libc::c_int = value.into();
setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_EXT_ACK, value)
setsockopt(
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_EXT_ACK,
value,
)
}

pub fn get_ext_ack(&self) -> Result<bool> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_EXT_ACK,
)?;
Expand All @@ -534,13 +575,13 @@ impl Socket {
/// the maximum allowed value is set by the /proc/sys/net/core/rmem_max
/// file. The minimum (doubled) value for this option is 256.
pub fn set_rx_buf_sz<T>(&self, size: T) -> Result<()> {
setsockopt(self.0, libc::SOL_SOCKET, libc::SO_RCVBUF, size)
setsockopt(self.as_raw_fd(), libc::SOL_SOCKET, libc::SO_RCVBUF, size)
}

/// Gets socket receive buffer in bytes
pub fn get_rx_buf_sz(&self) -> Result<usize> {
let res = getsockopt::<libc::c_int>(
self.0,
self.as_raw_fd(),
libc::SOL_SOCKET,
libc::SO_RCVBUF,
)?;
Expand All @@ -552,7 +593,7 @@ impl Socket {
pub fn set_netlink_get_strict_chk(&self, value: bool) -> Result<()> {
let value: u32 = value.into();
setsockopt(
self.0,
self.as_raw_fd(),
libc::SOL_NETLINK,
libc::NETLINK_GET_STRICT_CHK,
value,
Expand Down
Loading