rust: add event file reading module #4315

Merged
merged 4 commits into from
Nov 13, 2020
1 change: 1 addition & 0 deletions tensorboard/data/server/BUILD
Expand Up @@ -26,6 +26,7 @@ rust_library(
    name = "rustboard_core",
    srcs = [
        "lib.rs",
        "event_file.rs",
        "masked_crc.rs",
        "reservoir.rs",
        "scripted_reader.rs",
Expand Down
229 changes: 229 additions & 0 deletions tensorboard/data/server/event_file.rs
@@ -0,0 +1,229 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

//! Parsing for event files containing a stream of `Event` protos.

use prost::{DecodeError, Message};
use std::io::Read;

use crate::proto::tensorboard::Event;
use crate::tf_record::{ChecksumError, ReadRecordError, TfRecordReader};

/// A reader for a stream of `Event` protos framed as TFRecords.
///
/// As with [`TfRecordReader`], an event may be read over one or more underlying reads, to support
/// growing, partially flushed files.
#[derive(Debug)]
pub struct EventFileReader<R> {
Contributor

As per our offline chat, we may want a #[derive(Debug)] here.

Contributor Author

Done; thanks.

    /// Wall time of the record most recently read from this event file, or `None` if no records
    /// have been read. Used for determining when to consider this file dead and abandon it.
    last_wall_time: Option<f64>,
    /// Underlying record reader owned by this event file.
    reader: TfRecordReader<R>,
}

/// Error returned by [`EventFileReader::read_event`].
#[derive(Debug)]
pub enum ReadEventError {
    /// The record failed its checksum.
    InvalidRecord(ChecksumError),
    /// The record passed its checksum, but the contained protocol buffer is invalid.
    InvalidProto(DecodeError),
    /// The record is a valid `Event` proto, but its `wall_time` is `NaN`.
    NanWallTime(Event),
    /// An error occurred reading the record. May or may not be fatal.
    ReadRecordError(ReadRecordError),
}

impl From<DecodeError> for ReadEventError {
    fn from(e: DecodeError) -> Self {
        ReadEventError::InvalidProto(e)
    }
}

impl From<ChecksumError> for ReadEventError {
    fn from(e: ChecksumError) -> Self {
        ReadEventError::InvalidRecord(e)
    }
}

impl From<ReadRecordError> for ReadEventError {
    fn from(e: ReadRecordError) -> Self {
        ReadEventError::ReadRecordError(e)
    }
}

impl ReadEventError {
    /// Checks whether this error indicates a truncated record. This is a convenience method, since
    /// the end of a file always implies a truncation event.
    pub fn truncated(&self) -> bool {
        matches!(
            self,
            ReadEventError::ReadRecordError(ReadRecordError::Truncated)
        )
    }
}

impl<R: Read> EventFileReader<R> {
    /// Creates a new `EventFileReader` wrapping the given reader.
    pub fn new(reader: R) -> Self {
        Self {
            last_wall_time: None,
            reader: TfRecordReader::new(reader),
        }
    }

    /// Reads the next event from the file.
    pub fn read_event(&mut self) -> Result<Event, ReadEventError> {
        let record = self.reader.read_record()?;
        record.checksum()?;
        let event = Event::decode(&record.data[..])?;
        let wall_time = event.wall_time;
        if wall_time.is_nan() {
            return Err(ReadEventError::NanWallTime(event));
        }
        self.last_wall_time = Some(wall_time);
        Ok(event)
    }

    /// Gets the wall time of the event most recently read from the event file, or `None` if no
    /// events have yet been read.
    pub fn last_wall_time(&self) -> &Option<f64> {
        &self.last_wall_time
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::masked_crc::MaskedCrc;
    use crate::proto::tensorboard as pb;
    use crate::scripted_reader::ScriptedReader;
    use crate::tf_record::TfRecord;
    use std::io::Cursor;

    /// Encodes an `Event` proto to bytes.
    fn encode_event(e: &Event) -> Vec<u8> {
        let mut encoded = Vec::new();
        Event::encode(&e, &mut encoded).expect("failed to encode event");
        encoded
    }

    #[test]
    fn test() {
        let good_event = Event {
            what: Some(pb::event::What::FileVersion("good event".to_string())),
            wall_time: 1234.5,
            ..Event::default()
        };
        let mut nan_event = Event {
            what: Some(pb::event::What::FileVersion("bad wall time".to_string())),
            wall_time: f64::NAN,
            ..Event::default()
        };
        let records = vec![
            TfRecord::from_data(encode_event(&good_event)),
            TfRecord::from_data(encode_event(&nan_event)),
            TfRecord::from_data(b"failed proto, OK record".to_vec()),
            TfRecord {
                data: b"failed proto, failed checksum, OK record structure".to_vec(),
                data_crc: MaskedCrc(0x12345678),
            },
            TfRecord {
                data: encode_event(&good_event),
                data_crc: MaskedCrc(0x12345678), // OK proto, failed checksum, OK record structure
            },
        ];
        let mut file = Vec::new();
        for record in records {
            record.write(&mut file).expect("writing record");
        }
        let mut reader = EventFileReader::new(Cursor::new(file));

        assert_eq!(reader.last_wall_time(), &None);
        assert_eq!(reader.read_event().unwrap(), good_event);
        assert_eq!(reader.last_wall_time(), &Some(1234.5));
        match reader.read_event() {
            Err(ReadEventError::NanWallTime(mut e)) => {
                // can't just check `e == nan_event` because `NaN != NaN`
                assert!(e.wall_time.is_nan());
                e.wall_time = 0.0;
                nan_event.wall_time = 0.0;
                assert_eq!(e, nan_event);
            }
            other => panic!("{:?}", other),
        };
        assert_eq!(reader.last_wall_time(), &Some(1234.5));
        match reader.read_event() {
            Err(ReadEventError::InvalidProto(_)) => (),
            other => panic!("{:?}", other),
        };
        assert_eq!(reader.last_wall_time(), &Some(1234.5));
        match reader.read_event() {
            Err(ReadEventError::InvalidRecord(ChecksumError {
                got: _,
                want: MaskedCrc(0x12345678),
            })) => (),
            other => panic!("{:?}", other),
        };
        assert_eq!(reader.last_wall_time(), &Some(1234.5));
        match reader.read_event() {
            Err(ReadEventError::InvalidRecord(ChecksumError { got, want: _ }))
                if got == MaskedCrc::compute(&encode_event(&good_event)) =>
            {
                ()
            }
            other => panic!("{:?}", other),
        };
        assert_eq!(reader.last_wall_time(), &Some(1234.5));
        // After end of file, should get a truncation error.
        let last = reader.read_event();
        assert!(last.as_ref().unwrap_err().truncated(), "{:?}", last);
        assert_eq!(reader.last_wall_time(), &Some(1234.5));
    }

    #[test]
    fn test_resume() {
        let event = Event {
            what: Some(pb::event::What::FileVersion("good event".to_string())),
            wall_time: 1234.5,
            ..Event::default()
        };
        let mut file = Cursor::new(Vec::<u8>::new());
        TfRecord::from_data(encode_event(&event))
            .write(&mut file)
            .unwrap();
        let record_bytes = file.into_inner();
        let (beginning, end) = record_bytes.split_at(6);

        let sr = ScriptedReader::new(vec![beginning.to_vec(), end.to_vec()]);
        let mut reader = EventFileReader::new(sr);

        // first read should be truncated
        let result = reader.read_event();
        assert!(result.as_ref().unwrap_err().truncated(), "{:?}", result);
        assert_eq!(reader.last_wall_time(), &None);

        // second read should be the full record
        let result = reader.read_event();
        assert_eq!(result.unwrap(), event);
        assert_eq!(reader.last_wall_time(), &Some(1234.5));

        // further reads should be truncated again
        let result = reader.read_event();
        assert!(result.as_ref().unwrap_err().truncated(), "{:?}", result);
        assert_eq!(reader.last_wall_time(), &Some(1234.5));
    }
}
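
The `From` impls in `event_file.rs` are what allow `read_event` to apply `?` to several different error types, and its tests check that `last_wall_time` only advances when an event is fully valid. The following is a minimal, self-contained sketch of that pattern, not the PR's code: `LineError`, `WallTimeTracker`, and `accept` are hypothetical names, and parsing a decimal string stands in for decoding an `Event` proto.

```rust
use std::num::ParseFloatError;

/// Hypothetical error type mirroring the shape of `ReadEventError`.
#[derive(Debug)]
enum LineError {
    /// The input could not be parsed (stands in for `InvalidProto`).
    Parse(ParseFloatError),
    /// The input parsed, but the wall time is NaN (stands in for `NanWallTime`).
    NanWallTime,
}

impl From<ParseFloatError> for LineError {
    fn from(e: ParseFloatError) -> Self {
        LineError::Parse(e)
    }
}

/// Hypothetical tracker that remembers the last valid wall time it accepted.
#[derive(Debug, Default)]
struct WallTimeTracker {
    last_wall_time: Option<f64>,
}

impl WallTimeTracker {
    /// Parses a wall time. `?` converts the parse error through the `From` impl.
    /// The stored wall time is updated only after every check has passed.
    fn accept(&mut self, raw: &str) -> Result<f64, LineError> {
        let wall_time: f64 = raw.trim().parse()?;
        if wall_time.is_nan() {
            return Err(LineError::NanWallTime);
        }
        self.last_wall_time = Some(wall_time);
        Ok(wall_time)
    }
}

fn main() {
    let mut tracker = WallTimeTracker::default();
    println!("{:?}", tracker.accept("1234.5"));
    println!("{:?}", tracker.accept("NaN"));
    println!("{:?}", tracker.accept("not a number"));
    println!("{:?}", tracker.last_wall_time);
}
```

Because the assignment to `last_wall_time` is the last step before `Ok`, any early return leaves the previous value in place, which matches what the tests assert after each failed read.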
1 change: 1 addition & 0 deletions tensorboard/data/server/lib.rs
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

//! Core functionality for TensorBoard data loading.

pub mod event_file;
pub mod masked_crc;
pub mod reservoir;
pub mod tf_record;
Expand Down
1 change: 1 addition & 0 deletions tensorboard/data/server/scripted_reader.rs
Expand Up @@ -20,6 +20,7 @@ use std::io::{self, Cursor, Read};

/// A reader that delegates to a sequence of cursors, reading from each in turn and simulating
/// EOF after each one.
#[derive(Debug)]
pub struct ScriptedReader(VecDeque<Cursor<Vec<u8>>>);

impl ScriptedReader {
Expand Down
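
The body of `ScriptedReader` is collapsed in this diff, so here is a rough sketch of how a reader matching its doc comment could behave. `ChunkedReader` and `read_until_eof` are hypothetical names, and the `read` logic is an assumption based only on the description "reading from each in turn and simulating EOF after each one".

```rust
use std::collections::VecDeque;
use std::io::{self, Cursor, Read};

/// Hypothetical stand-in for `ScriptedReader`: serves a fixed script of chunks,
/// reporting EOF once at the end of each chunk.
#[derive(Debug)]
struct ChunkedReader(VecDeque<Cursor<Vec<u8>>>);

impl ChunkedReader {
    fn new<I: IntoIterator<Item = Vec<u8>>>(chunks: I) -> Self {
        ChunkedReader(chunks.into_iter().map(Cursor::new).collect())
    }
}

impl Read for ChunkedReader {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let front = match self.0.front_mut() {
            Some(front) => front,
            // The script is exhausted: every further read is EOF.
            None => return Ok(0),
        };
        let n = front.read(buf)?;
        if n == 0 {
            // The current chunk is drained. Report EOF for this call and
            // switch to the next chunk for the following call.
            self.0.pop_front();
        }
        Ok(n)
    }
}

/// Reads everything available up to the next simulated EOF.
fn read_until_eof(reader: &mut ChunkedReader) -> io::Result<Vec<u8>> {
    let mut out = Vec::new();
    reader.read_to_end(&mut out)?;
    Ok(out)
}

fn main() -> io::Result<()> {
    let mut reader = ChunkedReader::new(vec![b"abc".to_vec(), b"de".to_vec()]);
    println!("{:?}", read_until_eof(&mut reader)?);
    println!("{:?}", read_until_eof(&mut reader)?);
    println!("{:?}", read_until_eof(&mut reader)?);
    Ok(())
}
```

A reader like this makes it possible to test resumption deterministically: each simulated EOF plays the role of a file that has been only partially flushed at the time of the read.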
100 changes: 97 additions & 3 deletions tensorboard/data/server/tf_record.rs
Expand Up @@ -16,7 +16,8 @@ limitations under the License.
//! Resumable reading for TFRecord streams.

use byteorder::{ByteOrder, LittleEndian};
-use std::io::{self, Read};
+use std::fmt::{self, Debug};
+use std::io::{self, Read, Write};

use crate::masked_crc::MaskedCrc;

Expand Down Expand Up @@ -54,11 +55,12 @@ pub struct TfRecordReader<R> {

/// A TFRecord with a data buffer and expected checksum. The checksum may or may not match the
/// actual contents.
-#[derive(Debug)]
+#[derive(Debug, PartialEq, Eq)]
pub struct TfRecord {
    /// The payload of the TFRecord.
    pub data: Vec<u8>,
-    data_crc: MaskedCrc,
+    /// The data CRC listed in the record, which may or may not actually match the payload.
+    pub data_crc: MaskedCrc,
}

/// A buffer's checksum was computed, but it did not match the expected value.
Expand All @@ -82,6 +84,31 @@ impl TfRecord {
            Err(ChecksumError { got, want })
        }
    }

    /// Creates a TFRecord from a data vector, computing the correct data CRC. Calling `checksum()`
    /// on this record will always succeed.
    pub fn from_data(data: Vec<u8>) -> Self {
        let data_crc = MaskedCrc::compute(&data);
        TfRecord { data, data_crc }
    }

    /// Encodes the record to an output stream. The data CRC will be taken from the `TfRecord`
    /// value, not recomputed from the payload. This means that reading a valid record and writing
    /// it back out will always reproduce the input byte for byte. It also means that the written
    /// data CRC may not be valid.
    ///
    /// This may call [`Write::write`] multiple times; consider providing a buffered output stream
    /// if this is an issue.
    ///
    /// A record can always be serialized. This method fails only due to underlying I/O errors.
    pub fn write<W: Write>(&self, mut writer: W) -> io::Result<()> {
        let len_buf: [u8; 8] = (self.data.len() as u64).to_le_bytes();
        writer.write_all(&len_buf)?;
        writer.write_all(&MaskedCrc::compute(&len_buf).0.to_le_bytes())?;
        writer.write_all(&self.data)?;
        writer.write_all(&self.data_crc.0.to_le_bytes())?;
        Ok(())
    }
}

/// Error returned by [`TfRecordReader::read_record`].
Expand Down Expand Up @@ -112,6 +139,26 @@ impl From<io::Error> for ReadRecordError {
    }
}

impl<R: Debug> Debug for TfRecordReader<R> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TfRecordReader")
            .field(
                "header",
                &format_args!("{}/{}", self.header.len(), self.header.capacity()),
            )
            .field(
                "data_plus_footer",
                &format_args!(
                    "{}/{}",
                    self.data_plus_footer.len(),
                    self.data_plus_footer.capacity()
                ),
            )
            .field("reader", &self.reader)
            .finish()
    }
}

impl<R: Read> TfRecordReader<R> {
    /// Creates an empty `TfRecordReader`, ready to read a stream of TFRecords from its beginning.
    /// The underlying reader should be aligned to the start of a record (usually, this is just the
Expand All @@ -128,6 +175,11 @@ impl<R: Read> TfRecordReader<R> {
        }
    }

    /// Consumes this `TfRecordReader<R>`, returning the underlying reader `R`.
    pub fn into_inner(self) -> R {
Contributor

No action needed, this comment is mainly to document what I learned from our offline discussion.

While we don't expect this fn to be used outside of this crate's tests, it's reasonable to keep as is because

  • public into_inner is a fairly common pattern in Rust's std library
  • it must be called with self, which means the caller already has ownership of this TfRecordReader

Contributor Author

Yeah. You can imagine something like:

// parse a file that has one TFRecord as a header and then just a bunch of data
let f = File::open("...")?;
let mut reader = TfRecordReader::new(f);
let header = reader.read_record()?;
let mut f = reader.into_inner(); // give me the normal file back
let mut buf = Vec::new();
f.read_to_end(&mut buf)?;

        self.reader
    }
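
As a self-contained illustration of the `into_inner` pattern discussed in the thread above, here is a small sketch that does not depend on this crate. `HeaderReader`, `read_header`, and `split_header` are hypothetical, and the one-byte length prefix is an invented framing chosen only to keep the example short.

```rust
use std::io::{self, Cursor, Read};

/// Hypothetical wrapper that reads one length-prefixed header from a stream.
struct HeaderReader<R> {
    reader: R,
}

impl<R: Read> HeaderReader<R> {
    fn new(reader: R) -> Self {
        HeaderReader { reader }
    }

    /// Reads a header framed as a one-byte length followed by that many bytes.
    fn read_header(&mut self) -> io::Result<Vec<u8>> {
        let mut len = [0u8; 1];
        self.reader.read_exact(&mut len)?;
        let mut header = vec![0u8; len[0] as usize];
        self.reader.read_exact(&mut header)?;
        Ok(header)
    }

    /// Consumes the wrapper and hands the underlying reader back to the caller.
    fn into_inner(self) -> R {
        self.reader
    }
}

/// Splits a byte stream into its header and everything that follows it.
fn split_header(bytes: Vec<u8>) -> io::Result<(Vec<u8>, Vec<u8>)> {
    let mut reader = HeaderReader::new(Cursor::new(bytes));
    let header = reader.read_header()?;
    let mut inner = reader.into_inner(); // the plain reader, positioned after the header
    let mut rest = Vec::new();
    inner.read_to_end(&mut rest)?;
    Ok((header, rest))
}

fn main() -> io::Result<()> {
    let (header, rest) = split_header(vec![3, b'a', b'b', b'c', 1, 2])?;
    println!("header = {:?}, rest = {:?}", header, rest);
    Ok(())
}
```

This only works cleanly when the wrapper never reads past the end of the frame it returns. The round-trip test in this PR checks exactly that property for `TfRecordReader`, by comparing the cursor position from `into_inner()` with the number of bytes written.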

    /// Attempts to read a TFRecord, pausing gracefully in the face of truncations. If the record
    /// is truncated, the result is a `Truncated` error; call `read_record` again once more data
    /// may have been written to resume reading where it left off. If the record is read
Expand Down Expand Up @@ -352,4 +404,46 @@ mod tests {
            other => panic!("{:?}", other),
        }
    }

    #[test]
    fn test_from_data() {
        let test_cases = vec![
            b"".to_vec(),
            b"\x00".to_vec(),
            b"the quick brown fox jumped over the lazy dog".to_vec(),
        ];
        for data in test_cases {
            TfRecord::from_data(data).checksum().unwrap();
        }
    }

    fn test_write_read_roundtrip(record: &TfRecord) {
        let mut cursor = Cursor::new(Vec::<u8>::new());
        record.write(&mut cursor).expect("failed to write record");
        let written_len = cursor.position();
        cursor.set_position(0);
        let mut reader = TfRecordReader::new(cursor);
        let output_record = reader.read_record().expect("read_record");
Contributor

nit: ...expect("read_record"); --> expect("failed to read record"); ?

Contributor Author

I’m not certain what the prevailing style is here, but it renders like
this:

thread '<test_name>' panicked at 'read_record: BadLengthCrc(Truncated)', tensorboard/data/server/tf_record.rs:123:45

so I think it’s already clear enough. The fact that it’s panicking and
the test fails means that it’s failed.

If this were in non-test code and I were using expect for a condition
that should never happen (i.e., an assertion), I’d be more careful to
provide clear information and context. But for a test that’s completely
self-contained, I think that this is fine.

        assert_eq!(&output_record, record);
        assert_eq!(reader.into_inner().position(), written_len); // should have read all the bytes and not more
    }
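
To make the rendering described in the thread above concrete, here is a small sketch showing how `Result::expect` builds its panic message from the label and the `Debug` form of the error. `ReadError`, `payload_message`, and `expect_failure_message` are hypothetical helpers written for this illustration, and the sketch assumes the default unwinding panic strategy.

```rust
use std::any::Any;
use std::panic;

/// Hypothetical error type, standing in for the crate's read errors.
#[derive(Debug)]
enum ReadError {
    Truncated,
}

/// Extracts the message from a panic payload, which is normally a `String`
/// or a `&'static str`.
fn payload_message(payload: Box<dyn Any + Send>) -> String {
    if let Some(s) = payload.downcast_ref::<String>() {
        s.clone()
    } else if let Some(s) = payload.downcast_ref::<&'static str>() {
        s.to_string()
    } else {
        String::new()
    }
}

/// Calls `expect(label)` on an `Err` value and returns the resulting panic message.
fn expect_failure_message(label: &'static str) -> String {
    let outcome = panic::catch_unwind(move || {
        let result: Result<(), ReadError> = Err(ReadError::Truncated);
        result.expect(label);
    });
    payload_message(outcome.expect_err("closure should have panicked"))
}

fn main() {
    println!("{}", expect_failure_message("read_record"));
    println!("{}", expect_failure_message("failed to read record"));
}
```

Either label ends up followed by a colon and the error's `Debug` output, so the choice between them is mostly a matter of how much context the surrounding test already provides.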

    #[test]
    fn test_write_read_roundtrip_valid_data_crc() {
        let data = b"hello world".to_vec();
        let record = TfRecord {
            data_crc: MaskedCrc::compute(&data),
            data,
        };
        test_write_read_roundtrip(&record);
    }

    #[test]
    fn test_write_read_roundtrip_invalid_data_crc() {
        let record = TfRecord {
            data: b"hello world".to_vec(),
            data_crc: MaskedCrc(0x12345678),
        };
        test_write_read_roundtrip(&record);
    }
}
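
For readers unfamiliar with the framing that `TfRecord::write` emits, here is a standalone sketch of the same byte layout: little-endian `u64` length, masked CRC of the length bytes, payload, masked CRC of the payload. The write order comes from the diff above. The CRC-32C routine and the masking step (rotate right by 15 bits, add `0xA282EAD8`) are assumptions based on the TFRecord format as documented by TensorFlow, since the `MaskedCrc` implementation is not part of this diff. `encode_record` and `decode_record` are hypothetical names.

```rust
use std::convert::{TryFrom, TryInto};

/// Bitwise CRC-32C (Castagnoli), using the reflected polynomial 0x82F63B78.
fn crc32c(data: &[u8]) -> u32 {
    let mut crc = !0u32;
    for &byte in data {
        crc ^= byte as u32;
        for _ in 0..8 {
            // All ones if the low bit is set, all zeros otherwise.
            let mask = (crc & 1).wrapping_neg();
            crc = (crc >> 1) ^ (0x82F6_3B78 & mask);
        }
    }
    !crc
}

/// Masking step assumed from the TFRecord format: rotate, then add a constant.
fn masked_crc32c(data: &[u8]) -> u32 {
    crc32c(data).rotate_right(15).wrapping_add(0xA282_EAD8)
}

/// Frames a payload as length, length CRC, payload, payload CRC.
fn encode_record(data: &[u8]) -> Vec<u8> {
    let len_buf = (data.len() as u64).to_le_bytes();
    let mut out = Vec::with_capacity(8 + 4 + data.len() + 4);
    out.extend_from_slice(&len_buf);
    out.extend_from_slice(&masked_crc32c(&len_buf).to_le_bytes());
    out.extend_from_slice(data);
    out.extend_from_slice(&masked_crc32c(data).to_le_bytes());
    out
}

/// Validates one complete record and returns its payload, or `None` if the
/// framing or either checksum is wrong. Unlike the PR's reader, this sketch
/// does not support resuming after a truncation.
fn decode_record(bytes: &[u8]) -> Option<&[u8]> {
    if bytes.len() < 12 {
        return None;
    }
    let (header, rest) = bytes.split_at(12);
    let len_buf: [u8; 8] = header[..8].try_into().ok()?;
    let len_crc = u32::from_le_bytes(header[8..12].try_into().ok()?);
    if masked_crc32c(&len_buf) != len_crc {
        return None;
    }
    let len = usize::try_from(u64::from_le_bytes(len_buf)).ok()?;
    if rest.len() != len.checked_add(4)? {
        return None;
    }
    let (data, footer) = rest.split_at(len);
    let data_crc = u32::from_le_bytes(footer.try_into().ok()?);
    if masked_crc32c(data) != data_crc {
        return None;
    }
    Some(data)
}

fn main() {
    let encoded = encode_record(b"hello world");
    println!("{} bytes on the wire", encoded.len());
    println!("{:?}", decode_record(&encoded));
}
```

The fixed overhead is 16 bytes per record, which is why the resumption test in `event_file.rs` can split a record after 6 bytes and be sure the cut falls inside the length header.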