Skip to content

Commit 297c74c

Browse files
Safer poll timeout
1 parent aad5511 commit 297c74c

File tree

3 files changed

+259
-5
lines changed

3 files changed

+259
-5
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ This project adheres to [Semantic Versioning](https://semver.org/).
4747
([#1870](https://github.com/nix-rust/nix/pull/1870))
4848
- The `length` argument of `sys::mman::mmap` is now of type `NonZeroUsize`.
4949
([#1873](https://github.com/nix-rust/nix/pull/1873))
50+
- The `timeout` argument of `poll::poll` is now of type `poll::PollTimeout`.
51+
([#1876](https://github.com/nix-rust/nix/pull/1876))
5052

5153
### Fixed
5254

src/poll.rs

Lines changed: 254 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
//! Wait for events to trigger on specific file descriptors
2+
use std::convert::TryFrom;
3+
use std::fmt;
24
use std::os::unix::io::{AsRawFd, RawFd};
5+
use std::time::Duration;
36

47
use crate::errno::Errno;
58
use crate::Result;
@@ -112,6 +115,255 @@ libc_bitflags! {
112115
}
113116
}
114117

118+
/// Timeout argument for [`poll`].
119+
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
120+
pub struct PollTimeout(i32);
121+
122+
/// Error type for [`PollTimeout::try_from::<i128>::()`].
123+
#[derive(Debug, Clone, Copy)]
124+
pub enum TryFromI128Error {
125+
/// Value is less than -1.
126+
Underflow(crate::Errno),
127+
/// Value is greater than [`i32::MAX`].
128+
Overflow(<i32 as TryFrom<i128>>::Error),
129+
}
130+
impl fmt::Display for TryFromI128Error {
131+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
132+
match self {
133+
Self::Underflow(err) => write!(f, "Underflow: {}", err),
134+
Self::Overflow(err) => write!(f, "Overflow: {}", err),
135+
}
136+
}
137+
}
138+
impl std::error::Error for TryFromI128Error {}
139+
140+
/// Error type for [`PollTimeout::try_from::<i68>()`].
141+
#[derive(Debug, Clone, Copy)]
142+
pub enum TryFromI64Error {
143+
/// Value is less than -1.
144+
Underflow(crate::Errno),
145+
/// Value is greater than [`i32::MAX`].
146+
Overflow(<i32 as TryFrom<i64>>::Error),
147+
}
148+
impl fmt::Display for TryFromI64Error {
149+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
150+
match self {
151+
Self::Underflow(err) => write!(f, "Underflow: {}", err),
152+
Self::Overflow(err) => write!(f, "Overflow: {}", err),
153+
}
154+
}
155+
}
156+
impl std::error::Error for TryFromI64Error {}
157+
158+
// These cases implement slightly different conversions that make using generics impossible without
159+
// specialization.
160+
impl PollTimeout {
161+
/// Blocks indefinitely.
162+
pub const NONE: Self = Self(1 << 31);
163+
/// Returns immediately.
164+
pub const ZERO: Self = Self(0);
165+
/// Blocks for at most [`std::i32::MAX`] milliseconds.
166+
pub const MAX: Self = Self(i32::MAX);
167+
/// Returns if `self` equals [`PollTimeout::NONE`].
168+
pub fn is_none(&self) -> bool {
169+
*self == Self::NONE
170+
}
171+
/// Returns if `self` does not equal [`PollTimeout::NONE`].
172+
pub fn is_some(&self) -> bool {
173+
!self.is_none()
174+
}
175+
/// Returns the timeout in milliseconds if there is some, otherwise returns `None`.
176+
pub fn timeout(&self) -> Option<i32> {
177+
self.is_some().then(|| self.0)
178+
}
179+
}
180+
impl TryFrom<Duration> for PollTimeout {
181+
type Error = <i32 as TryFrom<u128>>::Error;
182+
fn try_from(x: Duration) -> std::result::Result<Self, Self::Error> {
183+
Ok(Self(i32::try_from(x.as_millis())?))
184+
}
185+
}
186+
impl TryFrom<u128> for PollTimeout {
187+
type Error = <i32 as TryFrom<u128>>::Error;
188+
fn try_from(x: u128) -> std::result::Result<Self, Self::Error> {
189+
Ok(Self(i32::try_from(x)?))
190+
}
191+
}
192+
impl TryFrom<u64> for PollTimeout {
193+
type Error = <i32 as TryFrom<u64>>::Error;
194+
fn try_from(x: u64) -> std::result::Result<Self, Self::Error> {
195+
Ok(Self(i32::try_from(x)?))
196+
}
197+
}
198+
impl TryFrom<u32> for PollTimeout {
199+
type Error = <i32 as TryFrom<u32>>::Error;
200+
fn try_from(x: u32) -> std::result::Result<Self, Self::Error> {
201+
Ok(Self(i32::try_from(x)?))
202+
}
203+
}
204+
impl From<u16> for PollTimeout {
205+
fn from(x: u16) -> Self {
206+
Self(i32::from(x))
207+
}
208+
}
209+
impl From<u8> for PollTimeout {
210+
fn from(x: u8) -> Self {
211+
Self(i32::from(x))
212+
}
213+
}
214+
impl TryFrom<i128> for PollTimeout {
215+
type Error = TryFromI128Error;
216+
fn try_from(x: i128) -> std::result::Result<Self, Self::Error> {
217+
match x {
218+
-1 => Ok(Self::NONE),
219+
millis @ 0.. => Ok(Self(
220+
i32::try_from(millis).map_err(TryFromI128Error::Overflow)?,
221+
)),
222+
// EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative).
223+
_ => Err(TryFromI128Error::Underflow(Errno::EINVAL)),
224+
}
225+
}
226+
}
227+
impl TryFrom<i64> for PollTimeout {
228+
type Error = TryFromI64Error;
229+
fn try_from(x: i64) -> std::result::Result<Self, Self::Error> {
230+
match x {
231+
-1 => Ok(Self::NONE),
232+
millis @ 0.. => Ok(Self(
233+
i32::try_from(millis).map_err(TryFromI64Error::Overflow)?,
234+
)),
235+
// EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative).
236+
_ => Err(TryFromI64Error::Underflow(Errno::EINVAL)),
237+
}
238+
}
239+
}
240+
impl TryFrom<i32> for PollTimeout {
241+
type Error = Errno;
242+
fn try_from(x: i32) -> Result<Self> {
243+
match x {
244+
-1 => Ok(Self::NONE),
245+
millis @ 0.. => Ok(Self(millis)),
246+
// EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative).
247+
_ => Err(Errno::EINVAL),
248+
}
249+
}
250+
}
251+
impl TryFrom<i16> for PollTimeout {
252+
type Error = Errno;
253+
fn try_from(x: i16) -> Result<Self> {
254+
match x {
255+
-1 => Ok(Self::NONE),
256+
millis @ 0.. => Ok(Self(millis.into())),
257+
// EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative).
258+
_ => Err(Errno::EINVAL),
259+
}
260+
}
261+
}
262+
impl TryFrom<i8> for PollTimeout {
263+
type Error = Errno;
264+
fn try_from(x: i8) -> Result<Self> {
265+
match x {
266+
-1 => Ok(Self::NONE),
267+
millis @ 0.. => Ok(Self(millis.into())),
268+
// EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative).
269+
_ => Err(Errno::EINVAL),
270+
}
271+
}
272+
}
273+
impl TryFrom<PollTimeout> for Duration {
274+
type Error = ();
275+
fn try_from(x: PollTimeout) -> std::result::Result<Self, ()> {
276+
match x.timeout() {
277+
// SAFETY: `x.0` is always positive.
278+
Some(millis) => Ok(Duration::from_millis(unsafe {
279+
u64::try_from(millis).unwrap_unchecked()
280+
})),
281+
None => Err(()),
282+
}
283+
}
284+
}
285+
impl TryFrom<PollTimeout> for u128 {
286+
type Error = ();
287+
fn try_from(x: PollTimeout) -> std::result::Result<Self, ()> {
288+
match x.timeout() {
289+
// SAFETY: When `x.timeout()` returns `Some(a)`, `a` is always positive.
290+
Some(millis) => {
291+
Ok(unsafe { Self::try_from(millis).unwrap_unchecked() })
292+
}
293+
None => Err(()),
294+
}
295+
}
296+
}
297+
impl TryFrom<PollTimeout> for u64 {
298+
type Error = ();
299+
fn try_from(x: PollTimeout) -> std::result::Result<Self, ()> {
300+
match x.timeout() {
301+
// SAFETY: When `x.timeout()` returns `Some(a)`, `a` is always positive.
302+
Some(millis) => {
303+
Ok(unsafe { Self::try_from(millis).unwrap_unchecked() })
304+
}
305+
None => Err(()),
306+
}
307+
}
308+
}
309+
impl TryFrom<PollTimeout> for u32 {
310+
type Error = ();
311+
fn try_from(x: PollTimeout) -> std::result::Result<Self, ()> {
312+
match x.timeout() {
313+
// SAFETY: When `x.timeout()` returns `Some(a)`, `a` is always positive.
314+
Some(millis) => {
315+
Ok(unsafe { Self::try_from(millis).unwrap_unchecked() })
316+
}
317+
None => Err(()),
318+
}
319+
}
320+
}
321+
impl TryFrom<PollTimeout> for u16 {
322+
type Error = Option<<Self as TryFrom<i32>>::Error>;
323+
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
324+
match x.timeout() {
325+
Some(millis) => Ok(Self::try_from(millis).map_err(Some)?),
326+
None => Err(None),
327+
}
328+
}
329+
}
330+
impl TryFrom<PollTimeout> for u8 {
331+
type Error = Option<<Self as TryFrom<i32>>::Error>;
332+
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
333+
match x.timeout() {
334+
Some(millis) => Ok(Self::try_from(millis).map_err(Some)?),
335+
None => Err(None),
336+
}
337+
}
338+
}
339+
impl From<PollTimeout> for i128 {
340+
fn from(x: PollTimeout) -> Self {
341+
x.timeout().unwrap_or(-1).into()
342+
}
343+
}
344+
impl From<PollTimeout> for i64 {
345+
fn from(x: PollTimeout) -> Self {
346+
x.timeout().unwrap_or(-1).into()
347+
}
348+
}
349+
impl From<PollTimeout> for i32 {
350+
fn from(x: PollTimeout) -> Self {
351+
x.timeout().unwrap_or(-1)
352+
}
353+
}
354+
impl TryFrom<PollTimeout> for i16 {
355+
type Error = <Self as TryFrom<i32>>::Error;
356+
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
357+
Self::try_from(x.timeout().unwrap_or(-1))
358+
}
359+
}
360+
impl TryFrom<PollTimeout> for i8 {
361+
type Error = <Self as TryFrom<i32>>::Error;
362+
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
363+
Self::try_from(x.timeout().unwrap_or(-1))
364+
}
365+
}
366+
115367
/// `poll` waits for one of a set of file descriptors to become ready to perform I/O.
116368
/// ([`poll(2)`](https://pubs.opengroup.org/onlinepubs/9699919799/functions/poll.html))
117369
///
@@ -132,12 +384,12 @@ libc_bitflags! {
132384
/// in timeout means an infinite timeout. Specifying a timeout of zero
133385
/// causes `poll()` to return immediately, even if no file descriptors are
134386
/// ready.
135-
pub fn poll(fds: &mut [PollFd], timeout: libc::c_int) -> Result<libc::c_int> {
387+
pub fn poll(fds: &mut [PollFd], timeout: PollTimeout) -> Result<libc::c_int> {
136388
let res = unsafe {
137389
libc::poll(
138390
fds.as_mut_ptr() as *mut libc::pollfd,
139391
fds.len() as libc::nfds_t,
140-
timeout,
392+
timeout.into(),
141393
)
142394
};
143395

test/test_poll.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use nix::{
22
errno::Errno,
3-
poll::{poll, PollFd, PollFlags},
3+
poll::{poll, PollFd, PollFlags, PollTimeout},
44
unistd::{pipe, write},
55
};
66

@@ -22,14 +22,14 @@ fn test_poll() {
2222
let mut fds = [PollFd::new(r, PollFlags::POLLIN)];
2323

2424
// Poll an idle pipe. Should timeout
25-
let nfds = loop_while_eintr!(poll(&mut fds, 100));
25+
let nfds = loop_while_eintr!(poll(&mut fds, PollTimeout::from(100u8)));
2626
assert_eq!(nfds, 0);
2727
assert!(!fds[0].revents().unwrap().contains(PollFlags::POLLIN));
2828

2929
write(w, b".").unwrap();
3030

3131
// Poll a readable pipe. Should return an event.
32-
let nfds = poll(&mut fds, 100).unwrap();
32+
let nfds = poll(&mut fds, PollTimeout::from(100u8)).unwrap();
3333
assert_eq!(nfds, 1);
3434
assert!(fds[0].revents().unwrap().contains(PollFlags::POLLIN));
3535
}

0 commit comments

Comments
 (0)