From 80555db2cb82ac4ec9a70e9893bb1f756339f0ee Mon Sep 17 00:00:00 2001 From: Baojun Wang Date: Mon, 6 Jun 2022 17:38:13 -0400 Subject: [PATCH] Add Seek instance for std::io::Take --- library/std/src/io/mod.rs | 61 ++++++++++++++++++++++++++++++++-- library/std/src/io/tests.rs | 65 +++++++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 2 deletions(-) diff --git a/library/std/src/io/mod.rs b/library/std/src/io/mod.rs index 94812e3fe3b2c..407572a9718b2 100644 --- a/library/std/src/io/mod.rs +++ b/library/std/src/io/mod.rs @@ -988,7 +988,7 @@ pub trait Read { where Self: Sized, { - Take { inner: self, limit } + Take { inner: self, limit, cursor: 0, seek_once: false } } } @@ -2408,6 +2408,8 @@ impl SizeHint for Chain { pub struct Take { inner: T, limit: u64, + cursor: u64, + seek_once: bool, } impl Take { @@ -2559,6 +2561,7 @@ impl Read for Take { let max = cmp::min(buf.len() as u64, self.limit) as usize; let n = self.inner.read(&mut buf[..max])?; + self.cursor += n as u64; self.limit -= n as u64; Ok(n) } @@ -2600,11 +2603,12 @@ impl Read for Take { } buf.add_filled(filled); - + self.cursor += filled as u64; self.limit -= filled as u64; } else { self.inner.read_buf(buf)?; + self.cursor += buf.filled_len().saturating_sub(prev_filled) as u64; //inner may unfill self.limit -= buf.filled_len().saturating_sub(prev_filled) as u64; } @@ -2623,17 +2627,70 @@ impl BufRead for Take { let buf = self.inner.fill_buf()?; let cap = cmp::min(buf.len() as u64, self.limit) as usize; + self.cursor = cap as u64; Ok(&buf[..cap]) } fn consume(&mut self, amt: usize) { // Don't let callers reset the limit by passing an overlarge value let amt = cmp::min(amt as u64, self.limit) as usize; + self.cursor += amt as u64; self.limit -= amt as u64; self.inner.consume(amt); } } +#[stable(feature = "rust1", since = "1.0.0")] +impl Seek for Take { + fn seek(&mut self, pos: SeekFrom) -> Result { + if !self.seek_once { + let old_pos = self.inner.stream_position()?; + let end = self.inner.seek(SeekFrom::End(0))?; + if end != old_pos { + self.inner.seek(SeekFrom::Start(old_pos))?; + } + self.seek_once = true; + self.limit = cmp::min(self.limit, end - old_pos); + } + let stream_end = self.cursor + self.limit; + let position = match pos { + SeekFrom::Start(k) => Some(cmp::min(k, stream_end)), + SeekFrom::Current(k) if k < 0 => { + if -k as u64 > self.cursor { + None + } else { + Some(self.cursor - (-k as u64)) + } + } + SeekFrom::Current(k) => Some(cmp::min(self.cursor + k as u64, stream_end)), + SeekFrom::End(k) if k >= 0 => Some(stream_end), + SeekFrom::End(k) => { + if -k as u64 > stream_end { + None + } else { + Some(stream_end - (-k) as u64) + } + } + }; + + match position { + None => Err(ErrorKind::InvalidInput.into()), + Some(pos) => { + let rel = pos as i64 - self.cursor as i64; + self.inner.seek(SeekFrom::Current(rel))?; + if rel >= 0 { + self.cursor += rel as u64; + self.limit -= rel as u64; + } else { + self.cursor -= -rel as u64; + self.limit += -rel as u64; + } + Ok(pos) + } + } + } +} + impl SizeHint for Take { #[inline] fn lower_bound(&self) -> usize { diff --git a/library/std/src/io/tests.rs b/library/std/src/io/tests.rs index eb62634856462..1fd01237f9030 100644 --- a/library/std/src/io/tests.rs +++ b/library/std/src/io/tests.rs @@ -602,3 +602,68 @@ fn bench_take_read_buf(b: &mut test::Bencher) { [255; 128].take(64).read_buf(&mut rbuf).unwrap(); }); } + +#[test] +fn test_io_take_seek() { + let mut buf = Cursor::new(b"....0123456789abcdef"); + buf.set_position(4); + { + let mut stream = buf.by_ref().take(8); + assert_eq!(stream.seek(SeekFrom::End(0)).unwrap(), 8); + assert_eq!(stream.seek(SeekFrom::End(4)).unwrap(), 8); + assert_eq!(stream.seek(SeekFrom::End(-8)).unwrap(), 0); + assert_eq!(stream.seek(SeekFrom::End(-9)).unwrap_err().kind(), io::ErrorKind::InvalidInput); + let mut bytes: [u8; 2] = [0; 2]; + assert!(stream.read_exact(&mut bytes).is_ok()); + assert_eq!(bytes, *b"01"); + assert_eq!(stream.stream_position().unwrap(), 2); + assert_eq!(stream.seek(SeekFrom::Current(2)).unwrap(), 4); + assert_eq!( + stream.seek(SeekFrom::Current(-5)).unwrap_err().kind(), + io::ErrorKind::InvalidInput + ); + assert_eq!(stream.seek(SeekFrom::Start(1)).unwrap(), 1); + assert!(stream.read_exact(&mut bytes).is_ok()); + assert_eq!(bytes, *b"12"); + assert_eq!(stream.seek(SeekFrom::Current(3)).unwrap(), 6); + assert!(stream.read_exact(&mut bytes).is_ok()); + assert_eq!(bytes, *b"67"); + assert_eq!(stream.stream_position().unwrap(), 8); + // reached end of file. + assert!(stream.read_exact(&mut bytes).is_err()); + assert_eq!(stream.seek(SeekFrom::Current(-3)).unwrap(), 5); + let mut res = Vec::new(); + assert!(stream.read_to_end(&mut res).is_ok()); + assert_eq!(&res, b"567"); + assert_eq!(stream.stream_position().unwrap(), 8); + } + assert_eq!(buf.stream_position().unwrap(), 12); +} + +#[test] +fn test_io_take_seek_insufficient_bytes() { + let mut buf = Cursor::new(b"....0123456789abcdef"); + buf.set_position(16); + { + // only four bytes are available. + let mut stream = buf.by_ref().take(8); + assert_eq!(stream.seek(SeekFrom::Start(10)).unwrap(), 4); + assert_eq!(stream.seek(SeekFrom::End(-4)).unwrap(), 0); + assert_eq!(stream.seek(SeekFrom::End(1)).unwrap(), 4); + assert!(stream.seek(SeekFrom::Current(-5)).is_err()); + assert_eq!(stream.seek(SeekFrom::Current(-4)).unwrap(), 0); + let mut bytes: [u8; 2] = [0; 2]; + assert!(stream.read_exact(&mut bytes).is_ok()); + assert_eq!(bytes, *b"cd"); + assert_eq!(stream.stream_position().unwrap(), 2); + let mut res = Vec::new(); + assert!(stream.read_to_end(&mut res).is_ok()); + assert_eq!(&res, b"ef"); + assert_eq!(stream.stream_position().unwrap(), 4); + assert!(stream.seek(SeekFrom::Current(-1)).is_ok()); + assert_eq!(stream.read_exact(&mut bytes).unwrap_err().kind(), io::ErrorKind::UnexpectedEof); + assert_eq!(stream.read_to_end(&mut res).unwrap(), 0); + assert_eq!(stream.stream_position().unwrap(), 4); + } + assert_eq!(buf.stream_position().unwrap(), 20); +}