diff --git a/tensorboard/data/server/BUILD b/tensorboard/data/server/BUILD
index 6c4d8a4647..0a58091bc5 100644
--- a/tensorboard/data/server/BUILD
+++ b/tensorboard/data/server/BUILD
@@ -26,6 +26,7 @@ rust_library(
     name = "rustboard_core",
     srcs = [
         "lib.rs",
+        "event_file.rs",
         "masked_crc.rs",
         "reservoir.rs",
         "scripted_reader.rs",
diff --git a/tensorboard/data/server/event_file.rs b/tensorboard/data/server/event_file.rs
new file mode 100644
index 0000000000..0c3acc7bb1
--- /dev/null
+++ b/tensorboard/data/server/event_file.rs
@@ -0,0 +1,229 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+//! Parsing for event files containing a stream of `Event` protos.
+
+use prost::{DecodeError, Message};
+use std::io::Read;
+
+use crate::proto::tensorboard::Event;
+use crate::tf_record::{ChecksumError, ReadRecordError, TfRecordReader};
+
+/// A reader for a stream of `Event` protos framed as TFRecords.
+///
+/// As with [`TfRecordReader`], an event may be read over one or more underlying reads, to support
+/// growing, partially flushed files.
+#[derive(Debug)]
+pub struct EventFileReader<R> {
+    /// Wall time of the record most recently read from this event file, or `None` if no records
+    /// have been read. Used for determining when to consider this file dead and abandon it.
+    last_wall_time: Option<f64>,
+    /// Underlying record reader owned by this event file.
+    reader: TfRecordReader<R>,
+}
+
+/// Error returned by [`EventFileReader::read_event`].
+#[derive(Debug)]
+pub enum ReadEventError {
+    /// The record failed its checksum.
+    InvalidRecord(ChecksumError),
+    /// The record passed its checksum, but the contained protocol buffer is invalid.
+    InvalidProto(DecodeError),
+    /// The record is a valid `Event` proto, but its `wall_time` is `NaN`.
+    NanWallTime(Event),
+    /// An error occurred reading the record. May or may not be fatal.
+    ReadRecordError(ReadRecordError),
+}
+
+impl From<DecodeError> for ReadEventError {
+    fn from(e: DecodeError) -> Self {
+        ReadEventError::InvalidProto(e)
+    }
+}
+
+impl From<ChecksumError> for ReadEventError {
+    fn from(e: ChecksumError) -> Self {
+        ReadEventError::InvalidRecord(e)
+    }
+}
+
+impl From<ReadRecordError> for ReadEventError {
+    fn from(e: ReadRecordError) -> Self {
+        ReadEventError::ReadRecordError(e)
+    }
+}
+
+impl ReadEventError {
+    /// Checks whether this error indicates a truncated record. This is a convenience method, since
+    /// the end of a file always implies a truncation event.
+    pub fn truncated(&self) -> bool {
+        matches!(
+            self,
+            ReadEventError::ReadRecordError(ReadRecordError::Truncated)
+        )
+    }
+}
+
+impl<R: Read> EventFileReader<R> {
+    /// Creates a new `EventFileReader` wrapping the given reader.
+    pub fn new(reader: R) -> Self {
+        Self {
+            last_wall_time: None,
+            reader: TfRecordReader::new(reader),
+        }
+    }
+
+    /// Reads the next event from the file.
+    pub fn read_event(&mut self) -> Result<Event, ReadEventError> {
+        let record = self.reader.read_record()?;
+        record.checksum()?;
+        let event = Event::decode(&record.data[..])?;
+        let wall_time = event.wall_time;
+        if wall_time.is_nan() {
+            return Err(ReadEventError::NanWallTime(event));
+        }
+        self.last_wall_time = Some(wall_time);
+        Ok(event)
+    }
+
+    /// Gets the wall time of the event most recently read from the event file, or `None` if no
+    /// events have yet been read.
+    pub fn last_wall_time(&self) -> &Option<f64> {
+        &self.last_wall_time
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::masked_crc::MaskedCrc;
+    use crate::proto::tensorboard as pb;
+    use crate::scripted_reader::ScriptedReader;
+    use crate::tf_record::TfRecord;
+    use std::io::Cursor;
+
+    /// Encodes an `Event` proto to bytes.
+    fn encode_event(e: &Event) -> Vec<u8> {
+        let mut encoded = Vec::new();
+        Event::encode(&e, &mut encoded).expect("failed to encode event");
+        encoded
+    }
+
+    #[test]
+    fn test() {
+        let good_event = Event {
+            what: Some(pb::event::What::FileVersion("good event".to_string())),
+            wall_time: 1234.5,
+            ..Event::default()
+        };
+        let mut nan_event = Event {
+            what: Some(pb::event::What::FileVersion("bad wall time".to_string())),
+            wall_time: f64::NAN,
+            ..Event::default()
+        };
+        let records = vec![
+            TfRecord::from_data(encode_event(&good_event)),
+            TfRecord::from_data(encode_event(&nan_event)),
+            TfRecord::from_data(b"failed proto, OK record".to_vec()),
+            TfRecord {
+                data: b"failed proto, failed checksum, OK record structure".to_vec(),
+                data_crc: MaskedCrc(0x12345678),
+            },
+            TfRecord {
+                data: encode_event(&good_event),
+                data_crc: MaskedCrc(0x12345678), // OK proto, failed checksum, OK record structure
+            },
+        ];
+        let mut file = Vec::new();
+        for record in records {
+            record.write(&mut file).expect("writing record");
+        }
+        let mut reader = EventFileReader::new(Cursor::new(file));
+
+        assert_eq!(reader.last_wall_time(), &None);
+        assert_eq!(reader.read_event().unwrap(), good_event);
+        assert_eq!(reader.last_wall_time(), &Some(1234.5));
+        match reader.read_event() {
+            Err(ReadEventError::NanWallTime(mut e)) => {
+                // can't just check `e == nan_event` because `NaN != NaN`
+                assert!(e.wall_time.is_nan());
+                e.wall_time = 0.0;
+                nan_event.wall_time = 0.0;
+                assert_eq!(e, nan_event);
+            }
+            other => panic!("{:?}", other),
+        };
+        assert_eq!(reader.last_wall_time(), &Some(1234.5));
+        match reader.read_event() {
+            Err(ReadEventError::InvalidProto(_)) => (),
+            other => panic!("{:?}", other),
+        };
+        assert_eq!(reader.last_wall_time(), &Some(1234.5));
+        match reader.read_event() {
+            Err(ReadEventError::InvalidRecord(ChecksumError {
+                got: _,
+                want: MaskedCrc(0x12345678),
+            })) => (),
+            other => panic!("{:?}", other),
+        };
+        assert_eq!(reader.last_wall_time(), &Some(1234.5));
+        match reader.read_event() {
+            Err(ReadEventError::InvalidRecord(ChecksumError { got, want: _ }))
+                if got == MaskedCrc::compute(&encode_event(&good_event)) =>
+            {
+                ()
+            }
+            other => panic!("{:?}", other),
+        };
+        assert_eq!(reader.last_wall_time(), &Some(1234.5));
+        // After end of file, should get a truncation error.
+        let last = reader.read_event();
+        assert!(last.as_ref().unwrap_err().truncated(), "{:?}", last);
+        assert_eq!(reader.last_wall_time(), &Some(1234.5));
+    }
+
+    #[test]
+    fn test_resume() {
+        let event = Event {
+            what: Some(pb::event::What::FileVersion("good event".to_string())),
+            wall_time: 1234.5,
+            ..Event::default()
+        };
+        let mut file = Cursor::new(Vec::<u8>::new());
+        TfRecord::from_data(encode_event(&event))
+            .write(&mut file)
+            .unwrap();
+        let record_bytes = file.into_inner();
+        let (beginning, end) = record_bytes.split_at(6);
+
+        let sr = ScriptedReader::new(vec![beginning.to_vec(), end.to_vec()]);
+        let mut reader = EventFileReader::new(sr);
+
+        // first read should be truncated
+        let result = reader.read_event();
+        assert!(result.as_ref().unwrap_err().truncated(), "{:?}", result);
+        assert_eq!(reader.last_wall_time(), &None);
+
+        // second read should be the full record
+        let result = reader.read_event();
+        assert_eq!(result.unwrap(), event);
+        assert_eq!(reader.last_wall_time(), &Some(1234.5));
+
+        // further reads should be truncated again
+        let result = reader.read_event();
+        assert!(result.as_ref().unwrap_err().truncated(), "{:?}", result);
+        assert_eq!(reader.last_wall_time(), &Some(1234.5));
+    }
+}
diff --git a/tensorboard/data/server/lib.rs b/tensorboard/data/server/lib.rs
index ffd56b4c6c..555cd72289 100644
--- a/tensorboard/data/server/lib.rs
+++ b/tensorboard/data/server/lib.rs
@@ -15,6 +15,7 @@ limitations under the License.
 
 //! Core functionality for TensorBoard data loading.
 
+pub mod event_file;
 pub mod masked_crc;
 pub mod reservoir;
 pub mod tf_record;
diff --git a/tensorboard/data/server/scripted_reader.rs b/tensorboard/data/server/scripted_reader.rs
index 80c0a464f3..e77165d367 100644
--- a/tensorboard/data/server/scripted_reader.rs
+++ b/tensorboard/data/server/scripted_reader.rs
@@ -20,6 +20,7 @@ use std::io::{self, Cursor, Read};
 
 /// A reader that delegates to a sequence of cursors, reading from each in turn and simulating
 /// EOF after each one.
+#[derive(Debug)]
 pub struct ScriptedReader(VecDeque<Cursor<Vec<u8>>>);
 
 impl ScriptedReader {
diff --git a/tensorboard/data/server/tf_record.rs b/tensorboard/data/server/tf_record.rs
index d4d27c5ff1..aa2fe82ed8 100644
--- a/tensorboard/data/server/tf_record.rs
+++ b/tensorboard/data/server/tf_record.rs
@@ -16,7 +16,8 @@ limitations under the License.
 //! Resumable reading for TFRecord streams.
 
 use byteorder::{ByteOrder, LittleEndian};
-use std::io::{self, Read};
+use std::fmt::{self, Debug};
+use std::io::{self, Read, Write};
 
 use crate::masked_crc::MaskedCrc;
 
@@ -54,11 +55,12 @@ pub struct TfRecordReader<R> {
 
 /// A TFRecord with a data buffer and expected checksum. The checksum may or may not match the
 /// actual contents.
-#[derive(Debug)]
+#[derive(Debug, PartialEq, Eq)]
 pub struct TfRecord {
     /// The payload of the TFRecord.
     pub data: Vec<u8>,
-    data_crc: MaskedCrc,
+    /// The data CRC listed in the record, which may or not actually match the payload.
+    pub data_crc: MaskedCrc,
 }
 
 /// A buffer's checksum was computed, but it did not match the expected value.
@@ -82,6 +84,31 @@ impl TfRecord {
             Err(ChecksumError { got, want })
         }
     }
+
+    /// Creates a TFRecord from a data vector, computing the correct data CRC. Calling `checksum()`
+    /// on this record will always succeed.
+    pub fn from_data(data: Vec<u8>) -> Self {
+        let data_crc = MaskedCrc::compute(&data);
+        TfRecord { data, data_crc }
+    }
+
+    /// Encodes the record to an output stream. The data CRC will be taken from the `TfRecord`
+    /// value, not recomputed from the payload. This means that reading a valid record and writing
+    /// it back out will always produce identical input. It also means that the written data CRC
+    /// may not be valid.
+    ///
+    /// This may call [`Write::write`] multiple times; consider providing a buffered output stream
+    /// if this is an issue.
+    ///
+    /// A record can always be serialized. This method fails only due to underlying I/O errors.
+    pub fn write<W: Write>(&self, mut writer: W) -> io::Result<()> {
+        let len_buf: [u8; 8] = (self.data.len() as u64).to_le_bytes();
+        writer.write_all(&len_buf)?;
+        writer.write_all(&MaskedCrc::compute(&len_buf).0.to_le_bytes())?;
+        writer.write_all(&self.data)?;
+        writer.write_all(&self.data_crc.0.to_le_bytes())?;
+        Ok(())
+    }
 }
 
 /// Error returned by [`TfRecordReader::read_record`].
@@ -112,6 +139,26 @@ impl From<io::Error> for ReadRecordError {
     }
 }
 
+impl<R: Debug> Debug for TfRecordReader<R> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("TfRecordReader")
+            .field(
+                "header",
+                &format_args!("{}/{}", self.header.len(), self.header.capacity()),
+            )
+            .field(
+                "data_plus_footer",
+                &format_args!(
+                    "{}/{}",
+                    self.data_plus_footer.len(),
+                    self.data_plus_footer.capacity()
+                ),
+            )
+            .field("reader", &self.reader)
+            .finish()
+    }
+}
+
 impl<R: Read> TfRecordReader<R> {
     /// Creates an empty `TfRecordReader`, ready to read a stream of TFRecords from its beginning.
     /// The underlying reader should be aligned to the start of a record (usually, this is just the
@@ -128,6 +175,11 @@ impl<R: Read> TfRecordReader<R> {
         }
     }
 
+    /// Consumes this `TfRecordReader`, returning the underlying reader `R`.
+    pub fn into_inner(self) -> R {
+        self.reader
+    }
+
     /// Attempts to read a TFRecord, pausing gracefully in the face of truncations. If the record
     /// is truncated, the result is a `Truncated` error; call `read_record` again once more data
     /// may have been written to resume reading where it left off. If the record is read
@@ -352,4 +404,46 @@ mod tests {
             other => panic!("{:?}", other),
         }
     }
+
+    #[test]
+    fn test_from_data() {
+        let test_cases = vec![
+            b"".to_vec(),
+            b"\x00".to_vec(),
+            b"the quick brown fox jumped over the lazy dog".to_vec(),
+        ];
+        for data in test_cases {
+            TfRecord::from_data(data).checksum().unwrap();
+        }
+    }
+
+    fn test_write_read_roundtrip(record: &TfRecord) {
+        let mut cursor = Cursor::new(Vec::<u8>::new());
+        record.write(&mut cursor).expect("failed to write record");
+        let written_len = cursor.position();
+        cursor.set_position(0);
+        let mut reader = TfRecordReader::new(cursor);
+        let output_record = reader.read_record().expect("read_record");
+        assert_eq!(&output_record, record);
+        assert_eq!(reader.into_inner().position(), written_len); // should have read all the bytes and not more
+    }
+
+    #[test]
+    fn test_write_read_roundtrip_valid_data_crc() {
+        let data = b"hello world".to_vec();
+        let record = TfRecord {
+            data_crc: MaskedCrc::compute(&data),
+            data,
+        };
+        test_write_read_roundtrip(&record);
+    }
+
+    #[test]
+    fn test_write_read_roundtrip_invalid_data_crc() {
+        let record = TfRecord {
+            data: b"hello world".to_vec(),
+            data_crc: MaskedCrc(0x12345678),
+        };
+        test_write_read_roundtrip(&record);
+    }
 }