Skip to content

Commit 01985ac

Browse files
karlmcdowallsylvestre
authored andcommitted
Head: ensure stdin input stream is correct on exit
Fix issue #7028 Head tool now ensures that stdin is set to the last character that was output by the tool. This ensures that if any subsequent tools are run from the same input stream they will start at the correct point in the stream.
1 parent b448722 commit 01985ac

File tree

6 files changed

+345
-45
lines changed

6 files changed

+345
-45
lines changed

src/uu/head/src/head.rs

Lines changed: 73 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@
77

88
use clap::{crate_version, Arg, ArgAction, ArgMatches, Command};
99
use std::ffi::OsString;
10+
#[cfg(unix)]
11+
use std::fs::File;
1012
use std::io::{self, ErrorKind, Read, Seek, SeekFrom, Write};
1113
use std::num::TryFromIntError;
14+
#[cfg(unix)]
15+
use std::os::fd::{AsRawFd, FromRawFd};
1216
use thiserror::Error;
1317
use uucore::display::Quotable;
1418
use uucore::error::{FromIo, UError, UResult};
@@ -239,7 +243,7 @@ impl HeadOptions {
239243
}
240244
}
241245

242-
fn read_n_bytes<R>(input: R, n: u64) -> std::io::Result<()>
246+
fn read_n_bytes<R>(input: R, n: u64) -> std::io::Result<u64>
243247
where
244248
R: Read,
245249
{
@@ -250,31 +254,31 @@ where
250254
let stdout = std::io::stdout();
251255
let mut stdout = stdout.lock();
252256

253-
io::copy(&mut reader, &mut stdout)?;
257+
let bytes_copied = io::copy(&mut reader, &mut stdout)?;
254258

255259
// Make sure we finish writing everything to the target before
256260
// exiting. Otherwise, when Rust is implicitly flushing, any
257261
// error will be silently ignored.
258262
stdout.flush()?;
259263

260-
Ok(())
264+
Ok(bytes_copied)
261265
}
262266

263-
fn read_n_lines(input: &mut impl std::io::BufRead, n: u64, separator: u8) -> std::io::Result<()> {
267+
fn read_n_lines(input: &mut impl std::io::BufRead, n: u64, separator: u8) -> std::io::Result<u64> {
264268
// Read the first `n` lines from the `input` reader.
265269
let mut reader = take_lines(input, n, separator);
266270

267271
// Write those bytes to `stdout`.
268272
let mut stdout = std::io::stdout();
269273

270-
io::copy(&mut reader, &mut stdout)?;
274+
let bytes_copied = io::copy(&mut reader, &mut stdout)?;
271275

272276
// Make sure we finish writing everything to the target before
273277
// exiting. Otherwise, when Rust is implicitly flushing, any
274278
// error will be silently ignored.
275279
stdout.flush()?;
276280

277-
Ok(())
281+
Ok(bytes_copied)
278282
}
279283

280284
fn catch_too_large_numbers_in_backwards_bytes_or_lines(n: u64) -> Option<usize> {
@@ -288,7 +292,7 @@ fn catch_too_large_numbers_in_backwards_bytes_or_lines(n: u64) -> Option<usize>
288292
}
289293

290294
/// Print to stdout all but the last `n` bytes from the given reader.
291-
fn read_but_last_n_bytes(input: &mut impl std::io::BufRead, n: u64) -> std::io::Result<()> {
295+
fn read_but_last_n_bytes(input: &mut impl std::io::BufRead, n: u64) -> std::io::Result<u64> {
292296
if n == 0 {
293297
//prints everything
294298
return read_n_bytes(input, u64::MAX);
@@ -302,6 +306,7 @@ fn read_but_last_n_bytes(input: &mut impl std::io::BufRead, n: u64) -> std::io::
302306

303307
let mut buffer = [0u8; BUF_SIZE];
304308
let mut total_read = 0;
309+
let mut total_written = 0;
305310

306311
loop {
307312
let read = match input.read(&mut buffer) {
@@ -322,29 +327,38 @@ fn read_but_last_n_bytes(input: &mut impl std::io::BufRead, n: u64) -> std::io::
322327
} else {
323328
// Write the ring buffer and the part of the buffer that exceeds n
324329
stdout.write_all(&ring_buffer)?;
325-
stdout.write_all(&buffer[..read - n + ring_buffer.len()])?;
330+
let buffer_bytes_to_write = read - n + ring_buffer.len();
331+
stdout.write_all(&buffer[..buffer_bytes_to_write])?;
332+
// Track our total bytes written.
333+
total_written += ring_buffer.len();
334+
total_written += buffer_bytes_to_write;
335+
326336
ring_buffer.clear();
327337
ring_buffer.extend_from_slice(&buffer[read - n + ring_buffer.len()..read]);
328338
}
329339
}
340+
return Ok(u64::try_from(total_written).unwrap());
330341
}
331342

332-
Ok(())
343+
Ok(0)
333344
}
334345

335346
fn read_but_last_n_lines(
336347
input: impl std::io::BufRead,
337348
n: u64,
338349
separator: u8,
339-
) -> std::io::Result<()> {
350+
) -> std::io::Result<u64> {
351+
let mut bytes_read: u64 = 0;
340352
if let Some(n) = catch_too_large_numbers_in_backwards_bytes_or_lines(n) {
341353
let stdout = std::io::stdout();
342354
let mut stdout = stdout.lock();
343355
for bytes in take_all_but(lines(input, separator), n) {
344-
stdout.write_all(&bytes?)?;
356+
let bytes = bytes?;
357+
bytes_read += u64::try_from(bytes.len()).unwrap();
358+
stdout.write_all(&bytes)?;
345359
}
346360
}
347-
Ok(())
361+
Ok(bytes_read)
348362
}
349363

350364
/// Return the index in `input` just after the `n`th line from the end.
@@ -425,61 +439,58 @@ fn is_seekable(input: &mut std::fs::File) -> bool {
425439
&& input.seek(SeekFrom::Start(current_pos.unwrap())).is_ok()
426440
}
427441

428-
fn head_backwards_file(input: &mut std::fs::File, options: &HeadOptions) -> std::io::Result<()> {
442+
fn head_backwards_file(input: &mut std::fs::File, options: &HeadOptions) -> std::io::Result<u64> {
429443
let st = input.metadata()?;
430444
let seekable = is_seekable(input);
431445
let blksize_limit = uucore::fs::sane_blksize::sane_blksize_from_metadata(&st);
432446
if !seekable || st.len() <= blksize_limit {
433-
return head_backwards_without_seek_file(input, options);
447+
head_backwards_without_seek_file(input, options)
448+
} else {
449+
head_backwards_on_seekable_file(input, options)
434450
}
435-
436-
head_backwards_on_seekable_file(input, options)
437451
}
438452

439453
fn head_backwards_without_seek_file(
440454
input: &mut std::fs::File,
441455
options: &HeadOptions,
442-
) -> std::io::Result<()> {
456+
) -> std::io::Result<u64> {
443457
let reader = &mut std::io::BufReader::with_capacity(BUF_SIZE, &*input);
444458

445459
match options.mode {
446-
Mode::AllButLastBytes(n) => read_but_last_n_bytes(reader, n)?,
447-
Mode::AllButLastLines(n) => read_but_last_n_lines(reader, n, options.line_ending.into())?,
460+
Mode::AllButLastBytes(n) => read_but_last_n_bytes(reader, n),
461+
Mode::AllButLastLines(n) => read_but_last_n_lines(reader, n, options.line_ending.into()),
448462
_ => unreachable!(),
449463
}
450-
451-
Ok(())
452464
}
453465

454466
fn head_backwards_on_seekable_file(
455467
input: &mut std::fs::File,
456468
options: &HeadOptions,
457-
) -> std::io::Result<()> {
469+
) -> std::io::Result<u64> {
458470
match options.mode {
459471
Mode::AllButLastBytes(n) => {
460472
let size = input.metadata()?.len();
461473
if n >= size {
462-
return Ok(());
474+
Ok(0)
463475
} else {
464476
read_n_bytes(
465477
&mut std::io::BufReader::with_capacity(BUF_SIZE, input),
466478
size - n,
467-
)?;
479+
)
468480
}
469481
}
470482
Mode::AllButLastLines(n) => {
471483
let found = find_nth_line_from_end(input, n, options.line_ending.into())?;
472484
read_n_bytes(
473485
&mut std::io::BufReader::with_capacity(BUF_SIZE, input),
474486
found,
475-
)?;
487+
)
476488
}
477489
_ => unreachable!(),
478490
}
479-
Ok(())
480491
}
481492

482-
fn head_file(input: &mut std::fs::File, options: &HeadOptions) -> std::io::Result<()> {
493+
fn head_file(input: &mut std::fs::File, options: &HeadOptions) -> std::io::Result<u64> {
483494
match options.mode {
484495
Mode::FirstBytes(n) => {
485496
read_n_bytes(&mut std::io::BufReader::with_capacity(BUF_SIZE, input), n)
@@ -506,16 +517,41 @@ fn uu_head(options: &HeadOptions) -> UResult<()> {
506517
println!("==> standard input <==");
507518
}
508519
let stdin = std::io::stdin();
509-
let mut stdin = stdin.lock();
510-
511-
match options.mode {
512-
Mode::FirstBytes(n) => read_n_bytes(&mut stdin, n),
513-
Mode::AllButLastBytes(n) => read_but_last_n_bytes(&mut stdin, n),
514-
Mode::FirstLines(n) => read_n_lines(&mut stdin, n, options.line_ending.into()),
515-
Mode::AllButLastLines(n) => {
516-
read_but_last_n_lines(&mut stdin, n, options.line_ending.into())
520+
521+
#[cfg(unix)]
522+
{
523+
let stdin_raw_fd = stdin.as_raw_fd();
524+
let mut stdin_file = unsafe { File::from_raw_fd(stdin_raw_fd) };
525+
let current_pos = stdin_file.stream_position();
526+
if let Ok(current_pos) = current_pos {
527+
// We have a seekable file. Ensure we set the input stream to the
528+
// last byte read so that any tools that parse the remainder of
529+
// the stdin stream read from the correct place.
530+
531+
let bytes_read = head_file(&mut stdin_file, options)?;
532+
stdin_file.seek(SeekFrom::Start(current_pos + bytes_read))?;
533+
} else {
534+
let _bytes_read = head_file(&mut stdin_file, options)?;
517535
}
518536
}
537+
538+
#[cfg(not(unix))]
539+
{
540+
let mut stdin = stdin.lock();
541+
542+
match options.mode {
543+
Mode::FirstBytes(n) => read_n_bytes(&mut stdin, n),
544+
Mode::AllButLastBytes(n) => read_but_last_n_bytes(&mut stdin, n),
545+
Mode::FirstLines(n) => {
546+
read_n_lines(&mut stdin, n, options.line_ending.into())
547+
}
548+
Mode::AllButLastLines(n) => {
549+
read_but_last_n_lines(&mut stdin, n, options.line_ending.into())
550+
}
551+
}?;
552+
}
553+
554+
Ok(())
519555
}
520556
(name, false) => {
521557
let mut file = match std::fs::File::open(name) {
@@ -534,7 +570,8 @@ fn uu_head(options: &HeadOptions) -> UResult<()> {
534570
}
535571
println!("==> {name} <==");
536572
}
537-
head_file(&mut file, options)
573+
head_file(&mut file, options)?;
574+
Ok(())
538575
}
539576
};
540577
if let Err(e) = res {

0 commit comments

Comments
 (0)