diff --git a/Cargo.toml b/Cargo.toml index f8295e7..5ae039e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,13 +33,13 @@ features = ["os-poll", "os-ext"] [dependencies.async-io] optional = true -version = "1.13" +version = "2" [features] default = [] mio_socket = ["mio"] tokio_socket = ["tokio", "futures"] -smol_socket = ["async-io","futures"] +smol_socket = ["async-io", "futures"] [dev-dependencies] netlink-packet-audit = "0.4.1" diff --git a/src/smol.rs b/src/smol.rs index ca30203..8a767aa 100644 --- a/src/smol.rs +++ b/src/smol.rs @@ -2,7 +2,7 @@ use std::{ io, - os::unix::io::{AsRawFd, FromRawFd, RawFd}, + os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, RawFd}, task::{Context, Poll}, }; @@ -31,6 +31,12 @@ impl AsRawFd for SmolSocket { } } +impl AsFd for SmolSocket { + fn as_fd(&self) -> BorrowedFd<'_> { + self.0.get_ref().as_fd() + } +} + // async_io::Async<..>::{read,write}_with[_mut] functions try IO first, // and only register context if it would block. // replicate this in these poll functions: @@ -79,7 +85,7 @@ impl AsyncSocket for SmolSocket { /// Mutable access to underyling [`Socket`] fn socket_mut(&mut self) -> &mut Socket { - self.0.get_mut() + unsafe { self.0.get_mut() } } fn new(protocol: isize) -> io::Result { @@ -92,7 +98,7 @@ impl AsyncSocket for SmolSocket { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - self.poll_write_with(cx, |this| this.0.get_mut().send(buf, 0)) + self.poll_write_with(cx, |this| this.socket_mut().send(buf, 0)) } fn poll_send_to( @@ -101,7 +107,7 @@ impl AsyncSocket for SmolSocket { buf: &[u8], addr: &SocketAddr, ) -> Poll> { - self.poll_write_with(cx, |this| this.0.get_mut().send_to(buf, addr, 0)) + self.poll_write_with(cx, |this| this.socket_mut().send_to(buf, addr, 0)) } fn poll_recv( @@ -113,7 +119,7 @@ impl AsyncSocket for SmolSocket { B: bytes::BufMut, { self.poll_read_with(cx, |this| { - this.0.get_mut().recv(buf, 0).map(|_len| ()) + this.socket_mut().recv(buf, 0).map(|_len| ()) }) } @@ -126,7 +132,7 @@ impl AsyncSocket for SmolSocket { B: bytes::BufMut, { self.poll_read_with(cx, |this| { - let x = this.0.get_mut().recv_from(buf, 0); + let x = this.socket_mut().recv_from(buf, 0); trace!("poll_recv_from: {:?}", x); x.map(|(_len, addr)| addr) }) @@ -136,6 +142,6 @@ impl AsyncSocket for SmolSocket { &mut self, cx: &mut Context<'_>, ) -> Poll, SocketAddr)>> { - self.poll_read_with(cx, |this| this.0.get_mut().recv_from_full()) + self.poll_read_with(cx, |this| this.socket_mut().recv_from_full()) } } diff --git a/src/socket.rs b/src/socket.rs index a846f03..9b63959 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -3,7 +3,10 @@ use std::{ io::{Error, Result}, mem, - os::unix::io::{AsRawFd, FromRawFd, RawFd}, + os::{ + fd::{AsFd, BorrowedFd, FromRawFd}, + unix::io::{AsRawFd, RawFd}, + }, }; use crate::SocketAddr; @@ -60,6 +63,12 @@ impl AsRawFd for Socket { } } +impl AsFd for Socket { + fn as_fd(&self) -> BorrowedFd<'_> { + unsafe { BorrowedFd::borrow_raw(self.0) } + } +} + impl FromRawFd for Socket { unsafe fn from_raw_fd(fd: RawFd) -> Self { Socket(fd) @@ -68,7 +77,7 @@ impl FromRawFd for Socket { impl Drop for Socket { fn drop(&mut self) { - unsafe { libc::close(self.0) }; + unsafe { libc::close(self.as_raw_fd()) }; } } @@ -94,7 +103,7 @@ impl Socket { /// Bind the socket to the given address pub fn bind(&mut self, addr: &SocketAddr) -> Result<()> { let (addr_ptr, addr_len) = addr.as_raw(); - let res = unsafe { libc::bind(self.0, addr_ptr, addr_len) }; + let res = unsafe { libc::bind(self.as_raw_fd(), addr_ptr, addr_len) }; if res < 0 { return Err(Error::last_os_error()); } @@ -115,7 +124,9 @@ impl Socket { let (addr_ptr, mut addr_len) = addr.as_raw_mut(); let addr_len_copy = addr_len; let addr_len_ptr = &mut addr_len as *mut libc::socklen_t; - let res = unsafe { libc::getsockname(self.0, addr_ptr, addr_len_ptr) }; + let res = unsafe { + libc::getsockname(self.as_raw_fd(), addr_ptr, addr_len_ptr) + }; if res < 0 { return Err(Error::last_os_error()); } @@ -128,8 +139,9 @@ impl Socket { /// Make this socket non-blocking pub fn set_non_blocking(&self, non_blocking: bool) -> Result<()> { let mut non_blocking = non_blocking as libc::c_int; - let res = - unsafe { libc::ioctl(self.0, libc::FIONBIO, &mut non_blocking) }; + let res = unsafe { + libc::ioctl(self.as_raw_fd(), libc::FIONBIO, &mut non_blocking) + }; if res < 0 { return Err(Error::last_os_error()); } @@ -152,7 +164,7 @@ impl Socket { /// 2. connect it to the kernel with [`Socket::connect`] /// 3. send a request to the kernel with [`Socket::send`] /// 4. read the response (which can span over several messages) - /// [`Socket::recv`] + /// [`Socket::recv`] /// /// ```rust /// use netlink_sys::{protocols::NETLINK_ROUTE, Socket, SocketAddr}; @@ -216,7 +228,7 @@ impl Socket { // - https://stackoverflow.com/a/14046386/1836144 // - https://lists.isc.org/pipermail/bind-users/2009-August/077527.html let (addr, addr_len) = remote_addr.as_raw(); - let res = unsafe { libc::connect(self.0, addr, addr_len) }; + let res = unsafe { libc::connect(self.as_raw_fd(), addr, addr_len) }; if res < 0 { return Err(Error::last_os_error()); } @@ -293,7 +305,7 @@ impl Socket { let res = unsafe { libc::recvfrom( - self.0, + self.as_raw_fd(), buf_ptr, buf_len, flags, @@ -324,7 +336,8 @@ impl Socket { let buf_ptr = chunk.as_mut_ptr() as *mut libc::c_void; let buf_len = chunk.len() as libc::size_t; - let res = unsafe { libc::recv(self.0, buf_ptr, buf_len, flags) }; + let res = + unsafe { libc::recv(self.as_raw_fd(), buf_ptr, buf_len, flags) }; if res < 0 { return Err(Error::last_os_error()); } else { @@ -367,7 +380,14 @@ impl Socket { let buf_len = buf.len() as libc::size_t; let res = unsafe { - libc::sendto(self.0, buf_ptr, buf_len, flags, addr_ptr, addr_len) + libc::sendto( + self.as_raw_fd(), + buf_ptr, + buf_len, + flags, + addr_ptr, + addr_len, + ) }; if res < 0 { return Err(Error::last_os_error()); @@ -382,7 +402,8 @@ impl Socket { let buf_ptr = buf.as_ptr() as *const libc::c_void; let buf_len = buf.len() as libc::size_t; - let res = unsafe { libc::send(self.0, buf_ptr, buf_len, flags) }; + let res = + unsafe { libc::send(self.as_raw_fd(), buf_ptr, buf_len, flags) }; if res < 0 { return Err(Error::last_os_error()); } @@ -391,12 +412,17 @@ impl Socket { pub fn set_pktinfo(&mut self, value: bool) -> Result<()> { let value: libc::c_int = value.into(); - setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_PKTINFO, value) + setsockopt( + self.as_raw_fd(), + libc::SOL_NETLINK, + libc::NETLINK_PKTINFO, + value, + ) } pub fn get_pktinfo(&self) -> Result { let res = getsockopt::( - self.0, + self.as_raw_fd(), libc::SOL_NETLINK, libc::NETLINK_PKTINFO, )?; @@ -405,7 +431,7 @@ impl Socket { pub fn add_membership(&mut self, group: u32) -> Result<()> { setsockopt( - self.0, + self.as_raw_fd(), libc::SOL_NETLINK, libc::NETLINK_ADD_MEMBERSHIP, group, @@ -414,7 +440,7 @@ impl Socket { pub fn drop_membership(&mut self, group: u32) -> Result<()> { setsockopt( - self.0, + self.as_raw_fd(), libc::SOL_NETLINK, libc::NETLINK_DROP_MEMBERSHIP, group, @@ -434,7 +460,7 @@ impl Socket { pub fn set_broadcast_error(&mut self, value: bool) -> Result<()> { let value: libc::c_int = value.into(); setsockopt( - self.0, + self.as_raw_fd(), libc::SOL_NETLINK, libc::NETLINK_BROADCAST_ERROR, value, @@ -443,7 +469,7 @@ impl Socket { pub fn get_broadcast_error(&self) -> Result { let res = getsockopt::( - self.0, + self.as_raw_fd(), libc::SOL_NETLINK, libc::NETLINK_BROADCAST_ERROR, )?; @@ -454,12 +480,17 @@ impl Socket { /// unicast and broadcast listeners to avoid receiving `ENOBUFS` errors. pub fn set_no_enobufs(&mut self, value: bool) -> Result<()> { let value: libc::c_int = value.into(); - setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_NO_ENOBUFS, value) + setsockopt( + self.as_raw_fd(), + libc::SOL_NETLINK, + libc::NETLINK_NO_ENOBUFS, + value, + ) } pub fn get_no_enobufs(&self) -> Result { let res = getsockopt::( - self.0, + self.as_raw_fd(), libc::SOL_NETLINK, libc::NETLINK_NO_ENOBUFS, )?; @@ -474,7 +505,7 @@ impl Socket { pub fn set_listen_all_namespaces(&mut self, value: bool) -> Result<()> { let value: libc::c_int = value.into(); setsockopt( - self.0, + self.as_raw_fd(), libc::SOL_NETLINK, libc::NETLINK_LISTEN_ALL_NSID, value, @@ -483,7 +514,7 @@ impl Socket { pub fn get_listen_all_namespaces(&self) -> Result { let res = getsockopt::( - self.0, + self.as_raw_fd(), libc::SOL_NETLINK, libc::NETLINK_LISTEN_ALL_NSID, )?; @@ -498,12 +529,17 @@ impl Socket { /// acknowledgment. pub fn set_cap_ack(&mut self, value: bool) -> Result<()> { let value: libc::c_int = value.into(); - setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_CAP_ACK, value) + setsockopt( + self.as_raw_fd(), + libc::SOL_NETLINK, + libc::NETLINK_CAP_ACK, + value, + ) } pub fn get_cap_ack(&self) -> Result { let res = getsockopt::( - self.0, + self.as_raw_fd(), libc::SOL_NETLINK, libc::NETLINK_CAP_ACK, )?; @@ -515,12 +551,17 @@ impl Socket { /// NLMSG_ERROR and NLMSG_DONE messages. pub fn set_ext_ack(&mut self, value: bool) -> Result<()> { let value: libc::c_int = value.into(); - setsockopt(self.0, libc::SOL_NETLINK, libc::NETLINK_EXT_ACK, value) + setsockopt( + self.as_raw_fd(), + libc::SOL_NETLINK, + libc::NETLINK_EXT_ACK, + value, + ) } pub fn get_ext_ack(&self) -> Result { let res = getsockopt::( - self.0, + self.as_raw_fd(), libc::SOL_NETLINK, libc::NETLINK_EXT_ACK, )?; @@ -534,13 +575,13 @@ impl Socket { /// the maximum allowed value is set by the /proc/sys/net/core/rmem_max /// file. The minimum (doubled) value for this option is 256. pub fn set_rx_buf_sz(&self, size: T) -> Result<()> { - setsockopt(self.0, libc::SOL_SOCKET, libc::SO_RCVBUF, size) + setsockopt(self.as_raw_fd(), libc::SOL_SOCKET, libc::SO_RCVBUF, size) } /// Gets socket receive buffer in bytes pub fn get_rx_buf_sz(&self) -> Result { let res = getsockopt::( - self.0, + self.as_raw_fd(), libc::SOL_SOCKET, libc::SO_RCVBUF, )?; @@ -552,7 +593,7 @@ impl Socket { pub fn set_netlink_get_strict_chk(&self, value: bool) -> Result<()> { let value: u32 = value.into(); setsockopt( - self.0, + self.as_raw_fd(), libc::SOL_NETLINK, libc::NETLINK_GET_STRICT_CHK, value,