-
Notifications
You must be signed in to change notification settings - Fork 1.7k
rust: add event file reading module #4315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,229 @@ | ||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
//! Parsing for event files containing a stream of `Event` protos. | ||
|
||
use prost::{DecodeError, Message}; | ||
use std::io::Read; | ||
|
||
use crate::proto::tensorboard::Event; | ||
use crate::tf_record::{ChecksumError, ReadRecordError, TfRecordReader}; | ||
|
||
/// A reader for a stream of `Event` protos framed as TFRecords. | ||
/// | ||
/// As with [`TfRecordReader`], an event may be read over one or more underlying reads, to support | ||
/// growing, partially flushed files. | ||
#[derive(Debug)] | ||
pub struct EventFileReader<R> { | ||
/// Wall time of the record most recently read from this event file, or `None` if no records | ||
/// have been read. Used for determining when to consider this file dead and abandon it. | ||
last_wall_time: Option<f64>, | ||
/// Underlying record reader owned by this event file. | ||
reader: TfRecordReader<R>, | ||
} | ||
|
||
/// Error returned by [`EventFileReader::read_event`]. | ||
#[derive(Debug)] | ||
pub enum ReadEventError { | ||
/// The record failed its checksum. | ||
InvalidRecord(ChecksumError), | ||
/// The record passed its checksum, but the contained protocol buffer is invalid. | ||
InvalidProto(DecodeError), | ||
/// The record is a valid `Event` proto, but its `wall_time` is `NaN`. | ||
NanWallTime(Event), | ||
/// An error occurred reading the record. May or may not be fatal. | ||
ReadRecordError(ReadRecordError), | ||
} | ||
|
||
impl From<DecodeError> for ReadEventError { | ||
fn from(e: DecodeError) -> Self { | ||
ReadEventError::InvalidProto(e) | ||
} | ||
} | ||
|
||
impl From<ChecksumError> for ReadEventError { | ||
fn from(e: ChecksumError) -> Self { | ||
ReadEventError::InvalidRecord(e) | ||
} | ||
} | ||
|
||
impl From<ReadRecordError> for ReadEventError { | ||
fn from(e: ReadRecordError) -> Self { | ||
ReadEventError::ReadRecordError(e) | ||
} | ||
} | ||
|
||
impl ReadEventError { | ||
/// Checks whether this error indicates a truncated record. This is a convenience method, since | ||
/// the end of a file always implies a truncation event. | ||
pub fn truncated(&self) -> bool { | ||
matches!( | ||
self, | ||
ReadEventError::ReadRecordError(ReadRecordError::Truncated) | ||
) | ||
} | ||
} | ||
|
||
impl<R: Read> EventFileReader<R> { | ||
/// Creates a new `EventFileReader` wrapping the given reader. | ||
pub fn new(reader: R) -> Self { | ||
Self { | ||
last_wall_time: None, | ||
reader: TfRecordReader::new(reader), | ||
} | ||
} | ||
|
||
/// Reads the next event from the file. | ||
pub fn read_event(&mut self) -> Result<Event, ReadEventError> { | ||
let record = self.reader.read_record()?; | ||
record.checksum()?; | ||
let event = Event::decode(&record.data[..])?; | ||
let wall_time = event.wall_time; | ||
if wall_time.is_nan() { | ||
return Err(ReadEventError::NanWallTime(event)); | ||
} | ||
self.last_wall_time = Some(wall_time); | ||
Ok(event) | ||
} | ||
|
||
/// Gets the wall time of the event most recently read from the event file, or `None` if no | ||
/// events have yet been read. | ||
pub fn last_wall_time(&self) -> &Option<f64> { | ||
&self.last_wall_time | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use crate::masked_crc::MaskedCrc; | ||
use crate::proto::tensorboard as pb; | ||
use crate::scripted_reader::ScriptedReader; | ||
use crate::tf_record::TfRecord; | ||
use std::io::Cursor; | ||
|
||
/// Encodes an `Event` proto to bytes. | ||
fn encode_event(e: &Event) -> Vec<u8> { | ||
let mut encoded = Vec::new(); | ||
Event::encode(&e, &mut encoded).expect("failed to encode event"); | ||
encoded | ||
} | ||
|
||
#[test]
fn test() {
    // One valid event, followed by one record for each per-record failure mode.
    let good_event = Event {
        wall_time: 1234.5,
        what: Some(pb::event::What::FileVersion("good event".to_string())),
        ..Event::default()
    };
    let mut nan_event = Event {
        wall_time: f64::NAN,
        what: Some(pb::event::What::FileVersion("bad wall time".to_string())),
        ..Event::default()
    };
    let records = [
        TfRecord::from_data(encode_event(&good_event)),
        TfRecord::from_data(encode_event(&nan_event)),
        TfRecord::from_data(b"failed proto, OK record".to_vec()),
        // failed proto, failed checksum, OK record structure
        TfRecord {
            data: b"failed proto, failed checksum, OK record structure".to_vec(),
            data_crc: MaskedCrc(0x12345678),
        },
        // OK proto, failed checksum, OK record structure
        TfRecord {
            data: encode_event(&good_event),
            data_crc: MaskedCrc(0x12345678),
        },
    ];
    let mut file = Vec::new();
    for record in records.iter() {
        record.write(&mut file).expect("writing record");
    }
    let mut reader = EventFileReader::new(Cursor::new(file));

    // Only the first (valid) event should ever be reflected in `last_wall_time`.
    let good_wall_time = Some(1234.5);

    assert_eq!(reader.last_wall_time(), &None);
    assert_eq!(reader.read_event().unwrap(), good_event);
    assert_eq!(reader.last_wall_time(), &good_wall_time);

    // NaN wall time: the decoded event comes back inside the error.
    match reader.read_event() {
        Err(ReadEventError::NanWallTime(mut actual)) => {
            // can't just check `actual == nan_event` because `NaN != NaN`
            assert!(actual.wall_time.is_nan());
            actual.wall_time = 0.0;
            nan_event.wall_time = 0.0;
            assert_eq!(actual, nan_event);
        }
        other => panic!("{:?}", other),
    };
    assert_eq!(reader.last_wall_time(), &good_wall_time);

    // Checksum passes, but the payload is not a valid proto.
    let result = reader.read_event();
    assert!(
        matches!(result, Err(ReadEventError::InvalidProto(_))),
        "{:?}",
        result
    );
    assert_eq!(reader.last_wall_time(), &good_wall_time);

    // Checksum fails; the expected ("want") CRC is the bogus one stored in the record.
    let result = reader.read_event();
    assert!(
        matches!(
            result,
            Err(ReadEventError::InvalidRecord(ChecksumError {
                got: _,
                want: MaskedCrc(0x12345678),
            }))
        ),
        "{:?}",
        result
    );
    assert_eq!(reader.last_wall_time(), &good_wall_time);

    // Checksum fails on an otherwise valid proto; the computed ("got") CRC is the payload's.
    let payload_crc = MaskedCrc::compute(&encode_event(&good_event));
    let result = reader.read_event();
    assert!(
        matches!(
            &result,
            Err(ReadEventError::InvalidRecord(ChecksumError { got, .. })) if *got == payload_crc
        ),
        "{:?}",
        result
    );
    assert_eq!(reader.last_wall_time(), &good_wall_time);

    // After end of file, should get a truncation error.
    let last = reader.read_event();
    assert!(last.as_ref().unwrap_err().truncated(), "{:?}", last);
    assert_eq!(reader.last_wall_time(), &good_wall_time);
}
|
||
#[test]
fn test_resume() {
    let event = Event {
        wall_time: 1234.5,
        what: Some(pb::event::What::FileVersion("good event".to_string())),
        ..Event::default()
    };

    // Serialize a single record, then cut it in two after the first 6 bytes so that the reader
    // sees it arrive over two separate reads.
    let mut record_bytes = Vec::<u8>::new();
    TfRecord::from_data(encode_event(&event))
        .write(&mut record_bytes)
        .unwrap();
    let (head, tail) = record_bytes.split_at(6);
    let chunks = vec![head.to_vec(), tail.to_vec()];
    let mut reader = EventFileReader::new(ScriptedReader::new(chunks));

    // first read should be truncated
    let first = reader.read_event();
    assert!(first.as_ref().unwrap_err().truncated(), "{:?}", first);
    assert_eq!(reader.last_wall_time(), &None);

    // second read should be the full record
    let second = reader.read_event();
    assert_eq!(second.unwrap(), event);
    assert_eq!(reader.last_wall_time(), &Some(1234.5));

    // further reads should be truncated again
    let third = reader.read_event();
    assert!(third.as_ref().unwrap_err().truncated(), "{:?}", third);
    assert_eq!(reader.last_wall_time(), &Some(1234.5));
}
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,8 @@ limitations under the License. | |
//! Resumable reading for TFRecord streams. | ||
|
||
use byteorder::{ByteOrder, LittleEndian}; | ||
use std::io::{self, Read}; | ||
use std::fmt::{self, Debug}; | ||
use std::io::{self, Read, Write}; | ||
|
||
use crate::masked_crc::MaskedCrc; | ||
|
||
|
@@ -54,11 +55,12 @@ pub struct TfRecordReader<R> { | |
|
||
/// A TFRecord with a data buffer and expected checksum. The checksum may or may not match the | ||
/// actual contents. | ||
#[derive(Debug)] | ||
#[derive(Debug, PartialEq, Eq)] | ||
pub struct TfRecord { | ||
/// The payload of the TFRecord. | ||
pub data: Vec<u8>, | ||
data_crc: MaskedCrc, | ||
/// The data CRC listed in the record, which may or may not actually match the payload. | ||
pub data_crc: MaskedCrc, | ||
} | ||
|
||
/// A buffer's checksum was computed, but it did not match the expected value. | ||
|
@@ -82,6 +84,31 @@ impl TfRecord { | |
Err(ChecksumError { got, want }) | ||
} | ||
} | ||
|
||
/// Creates a TFRecord from a data vector, computing the correct data CRC. Calling `checksum()` | ||
/// on this record will always succeed. | ||
pub fn from_data(data: Vec<u8>) -> Self { | ||
let data_crc = MaskedCrc::compute(&data); | ||
TfRecord { data, data_crc } | ||
} | ||
|
||
/// Encodes the record to an output stream. The data CRC will be taken from the `TfRecord` | ||
/// value, not recomputed from the payload. This means that reading a valid record and writing | ||
/// it back out will always produce identical input. It also means that the written data CRC | ||
/// may not be valid. | ||
/// | ||
/// This may call [`Write::write`] multiple times; consider providing a buffered output stream | ||
/// if this is an issue. | ||
/// | ||
/// A record can always be serialized. This method fails only due to underlying I/O errors. | ||
pub fn write<W: Write>(&self, mut writer: W) -> io::Result<()> { | ||
let len_buf: [u8; 8] = (self.data.len() as u64).to_le_bytes(); | ||
writer.write_all(&len_buf)?; | ||
writer.write_all(&MaskedCrc::compute(&len_buf).0.to_le_bytes())?; | ||
writer.write_all(&self.data)?; | ||
writer.write_all(&self.data_crc.0.to_le_bytes())?; | ||
Ok(()) | ||
} | ||
} | ||
|
||
/// Error returned by [`TfRecordReader::read_record`]. | ||
|
@@ -112,6 +139,26 @@ impl From<io::Error> for ReadRecordError { | |
} | ||
} | ||
|
||
impl<R: Debug> Debug for TfRecordReader<R> { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
f.debug_struct("TfRecordReader") | ||
.field( | ||
"header", | ||
&format_args!("{}/{}", self.header.len(), self.header.capacity()), | ||
) | ||
.field( | ||
"data_plus_footer", | ||
&format_args!( | ||
"{}/{}", | ||
self.data_plus_footer.len(), | ||
self.data_plus_footer.capacity() | ||
), | ||
) | ||
.field("reader", &self.reader) | ||
.finish() | ||
} | ||
} | ||
|
||
impl<R: Read> TfRecordReader<R> { | ||
/// Creates an empty `TfRecordReader`, ready to read a stream of TFRecords from its beginning. | ||
/// The underlying reader should be aligned to the start of a record (usually, this is just the | ||
|
@@ -128,6 +175,11 @@ impl<R: Read> TfRecordReader<R> { | |
} | ||
} | ||
|
||
/// Consumes this `TfRecordReader<R>`, returning the underlying reader `R`. | ||
pub fn into_inner(self) -> R { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No action needed, this comment is mainly to document what I learned from our offline discussion. While we don't expect this fn to be used outside of this crate's tests, it's reasonable to keep as is because
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah. You can imagine something like: // parse a file that has one TFRecord as a header and then just a bunch of data
let mut f = File::open("...")?;
let mut reader = EventFileReader::new(f);
let header = reader.read_event()?;
let mut f = reader.into_inner(); // give me the normal file back
let mut buf = Vec::new();
f.read_to_end(&mut buf)?; |
||
self.reader | ||
} | ||
|
||
/// Attempts to read a TFRecord, pausing gracefully in the face of truncations. If the record | ||
/// is truncated, the result is a `Truncated` error; call `read_record` again once more data | ||
/// may have been written to resume reading where it left off. If the record is read | ||
|
@@ -352,4 +404,46 @@ mod tests { | |
other => panic!("{:?}", other), | ||
} | ||
} | ||
|
||
#[test]
fn test_from_data() {
    // Records built by `from_data` must pass their own checksum, including for empty input.
    let payloads: [&[u8]; 3] = [
        b"",
        b"\x00",
        b"the quick brown fox jumped over the lazy dog",
    ];
    for payload in payloads.iter() {
        TfRecord::from_data(payload.to_vec()).checksum().unwrap();
    }
}
|
||
fn test_write_read_roundtrip(record: &TfRecord) { | ||
let mut cursor = Cursor::new(Vec::<u8>::new()); | ||
record.write(&mut cursor).expect("failed to write record"); | ||
let written_len = cursor.position(); | ||
cursor.set_position(0); | ||
let mut reader = TfRecordReader::new(cursor); | ||
let output_record = reader.read_record().expect("read_record"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I’m not certain what the prevailing style is here, but it renders like
so I think it’s already clear enough. The fact that it’s panicking and If this were in non-test code and I were using |
||
assert_eq!(&output_record, record); | ||
assert_eq!(reader.into_inner().position(), written_len); // should have read all the bytes and not more | ||
} | ||
|
||
#[test]
fn test_write_read_roundtrip_valid_data_crc() {
    // `from_data` stores the CRC computed from the payload, so this record's CRC is valid.
    let record = TfRecord::from_data(b"hello world".to_vec());
    test_write_read_roundtrip(&record);
}
|
||
#[test]
fn test_write_read_roundtrip_invalid_data_crc() {
    // The stored CRC is an arbitrary constant rather than one computed from the payload; the
    // round trip must still preserve it as written.
    let data = b"hello world".to_vec();
    let data_crc = MaskedCrc(0x12345678);
    test_write_read_roundtrip(&TfRecord { data, data_crc });
}
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As per our offline chat, we may want a
#[derive(Debug)]
here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done; thanks.