Skip to content

Commit 2df1b46

Browse files
authored
rust: add event file reading module (#4315)
Summary: We implement a module for reading a TensorFlow event file. This wraps the TFRecord reading module added in #4307 and parses records as protos. The reader also stores the last-read event wall time, so that we can stop reading from event files after a while (cf. `--reload_multifile`). Along the way, we add a couple utility methods to `TfRecord` that make it easier to use records in test code. Test Plan: Unit tests included. wchargin-branch: rust-event-file-reading
1 parent 0f43809 commit 2df1b46

File tree

5 files changed

+329
-3
lines changed

5 files changed

+329
-3
lines changed

tensorboard/data/server/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ rust_library(
2626
name = "rustboard_core",
2727
srcs = [
2828
"lib.rs",
29+
"event_file.rs",
2930
"masked_crc.rs",
3031
"reservoir.rs",
3132
"scripted_reader.rs",

tensorboard/data/server/event_file.rs

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
//! Parsing for event files containing a stream of `Event` protos.
17+
18+
use prost::{DecodeError, Message};
19+
use std::io::Read;
20+
21+
use crate::proto::tensorboard::Event;
22+
use crate::tf_record::{ChecksumError, ReadRecordError, TfRecordReader};
23+
24+
/// A reader for a stream of `Event` protos framed as TFRecords.
25+
///
26+
/// As with [`TfRecordReader`], an event may be read over one or more underlying reads, to support
27+
/// growing, partially flushed files.
28+
#[derive(Debug)]
29+
pub struct EventFileReader<R> {
30+
/// Wall time of the record most recently read from this event file, or `None` if no records
31+
/// have been read. Used for determining when to consider this file dead and abandon it.
32+
last_wall_time: Option<f64>,
33+
/// Underlying record reader owned by this event file.
34+
reader: TfRecordReader<R>,
35+
}
36+
37+
/// Error returned by [`EventFileReader::read_event`].
38+
#[derive(Debug)]
39+
pub enum ReadEventError {
40+
/// The record failed its checksum.
41+
InvalidRecord(ChecksumError),
42+
/// The record passed its checksum, but the contained protocol buffer is invalid.
43+
InvalidProto(DecodeError),
44+
/// The record is a valid `Event` proto, but its `wall_time` is `NaN`.
45+
NanWallTime(Event),
46+
/// An error occurred reading the record. May or may not be fatal.
47+
ReadRecordError(ReadRecordError),
48+
}
49+
50+
/// Lets `?` turn a proto decoding failure into [`ReadEventError::InvalidProto`].
impl From<DecodeError> for ReadEventError {
    fn from(err: DecodeError) -> Self {
        Self::InvalidProto(err)
    }
}
55+
56+
/// Lets `?` turn a checksum mismatch into [`ReadEventError::InvalidRecord`].
impl From<ChecksumError> for ReadEventError {
    fn from(err: ChecksumError) -> Self {
        Self::InvalidRecord(err)
    }
}
61+
62+
/// Lets `?` wrap a record-level read failure as [`ReadEventError::ReadRecordError`].
impl From<ReadRecordError> for ReadEventError {
    fn from(err: ReadRecordError) -> Self {
        Self::ReadRecordError(err)
    }
}
67+
68+
impl ReadEventError {
69+
/// Checks whether this error indicates a truncated record. This is a convenience method, since
70+
/// the end of a file always implies a truncation event.
71+
pub fn truncated(&self) -> bool {
72+
matches!(
73+
self,
74+
ReadEventError::ReadRecordError(ReadRecordError::Truncated)
75+
)
76+
}
77+
}
78+
79+
impl<R: Read> EventFileReader<R> {
80+
/// Creates a new `EventFileReader` wrapping the given reader.
81+
pub fn new(reader: R) -> Self {
82+
Self {
83+
last_wall_time: None,
84+
reader: TfRecordReader::new(reader),
85+
}
86+
}
87+
88+
/// Reads the next event from the file.
89+
pub fn read_event(&mut self) -> Result<Event, ReadEventError> {
90+
let record = self.reader.read_record()?;
91+
record.checksum()?;
92+
let event = Event::decode(&record.data[..])?;
93+
let wall_time = event.wall_time;
94+
if wall_time.is_nan() {
95+
return Err(ReadEventError::NanWallTime(event));
96+
}
97+
self.last_wall_time = Some(wall_time);
98+
Ok(event)
99+
}
100+
101+
/// Gets the wall time of the event most recently read from the event file, or `None` if no
102+
/// events have yet been read.
103+
pub fn last_wall_time(&self) -> &Option<f64> {
104+
&self.last_wall_time
105+
}
106+
}
107+
108+
#[cfg(test)]
109+
mod tests {
110+
use super::*;
111+
use crate::masked_crc::MaskedCrc;
112+
use crate::proto::tensorboard as pb;
113+
use crate::scripted_reader::ScriptedReader;
114+
use crate::tf_record::TfRecord;
115+
use std::io::Cursor;
116+
117+
/// Encodes an `Event` proto to bytes.
118+
fn encode_event(e: &Event) -> Vec<u8> {
119+
let mut encoded = Vec::new();
120+
Event::encode(&e, &mut encoded).expect("failed to encode event");
121+
encoded
122+
}
123+
124+
#[test]
fn test() {
    // Fixture events: one well-formed, one whose wall time is NaN.
    let file_version = |text: &str| Some(pb::event::What::FileVersion(text.to_string()));
    let good_event = Event {
        wall_time: 1234.5,
        what: file_version("good event"),
        ..Event::default()
    };
    let nan_event = Event {
        wall_time: f64::NAN,
        what: file_version("bad wall time"),
        ..Event::default()
    };

    // Serialize five records, each exercising a different outcome of `read_event`.
    let records = vec![
        TfRecord::from_data(encode_event(&good_event)),
        TfRecord::from_data(encode_event(&nan_event)),
        TfRecord::from_data(b"failed proto, OK record".to_vec()),
        TfRecord {
            data: b"failed proto, failed checksum, OK record structure".to_vec(),
            data_crc: MaskedCrc(0x12345678),
        },
        // OK proto, failed checksum, OK record structure
        TfRecord {
            data: encode_event(&good_event),
            data_crc: MaskedCrc(0x12345678),
        },
    ];
    let mut file = Vec::<u8>::new();
    for rec in &records {
        rec.write(&mut file).expect("writing record");
    }
    let mut reader = EventFileReader::new(Cursor::new(file));

    // Record 1: a good event, which sets the last wall time.
    assert_eq!(reader.last_wall_time(), &None);
    assert_eq!(reader.read_event().unwrap(), good_event);
    assert_eq!(reader.last_wall_time(), &Some(1234.5));

    // Record 2: NaN wall time; the event comes back inside the error.
    let returned = match reader.read_event() {
        Err(ReadEventError::NanWallTime(e)) => e,
        other => panic!("{:?}", other),
    };
    // can't just check `returned == nan_event` because `NaN != NaN`
    assert!(returned.wall_time.is_nan());
    let returned = Event {
        wall_time: 0.0,
        ..returned
    };
    let expected = Event {
        wall_time: 0.0,
        ..nan_event
    };
    assert_eq!(returned, expected);
    assert_eq!(reader.last_wall_time(), &Some(1234.5));

    // Record 3: valid checksum, but the payload is not a proto.
    match reader.read_event() {
        Err(ReadEventError::InvalidProto(_)) => {}
        other => panic!("{:?}", other),
    }
    assert_eq!(reader.last_wall_time(), &Some(1234.5));

    // Record 4: the checksum failure is reported with the CRC stored in the record.
    match reader.read_event() {
        Err(ReadEventError::InvalidRecord(ChecksumError {
            want: MaskedCrc(0x12345678),
            ..
        })) => {}
        other => panic!("{:?}", other),
    }
    assert_eq!(reader.last_wall_time(), &Some(1234.5));

    // Record 5: a decodable payload still fails; `got` is the CRC of the actual payload.
    match reader.read_event() {
        Err(ReadEventError::InvalidRecord(ChecksumError { got, .. }))
            if got == MaskedCrc::compute(&encode_event(&good_event)) => {}
        other => panic!("{:?}", other),
    }
    assert_eq!(reader.last_wall_time(), &Some(1234.5));

    // After end of file, should get a truncation error.
    let last = reader.read_event();
    assert!(last.as_ref().unwrap_err().truncated(), "{:?}", last);
    assert_eq!(reader.last_wall_time(), &Some(1234.5));
}
196+
197+
#[test]
fn test_resume() {
    let event = Event {
        wall_time: 1234.5,
        what: Some(pb::event::What::FileVersion("good event".to_string())),
        ..Event::default()
    };

    // Serialize one record, then cut it after 6 bytes so that the reader sees it in two pieces.
    let mut record_bytes = Vec::<u8>::new();
    TfRecord::from_data(encode_event(&event))
        .write(&mut record_bytes)
        .unwrap();
    let (head, tail) = record_bytes.split_at(6);
    let script = vec![head.to_vec(), tail.to_vec()];
    let mut reader = EventFileReader::new(ScriptedReader::new(script));

    // first read should be truncated
    let first = reader.read_event();
    assert!(first.as_ref().unwrap_err().truncated(), "{:?}", first);
    assert_eq!(reader.last_wall_time(), &None);

    // second read should be the full record
    let second = reader.read_event();
    assert_eq!(second.unwrap(), event);
    assert_eq!(reader.last_wall_time(), &Some(1234.5));

    // further reads should be truncated again
    let third = reader.read_event();
    assert!(third.as_ref().unwrap_err().truncated(), "{:?}", third);
    assert_eq!(reader.last_wall_time(), &Some(1234.5));
}
229+
}

tensorboard/data/server/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
//! Core functionality for TensorBoard data loading.
1717
18+
pub mod event_file;
1819
pub mod masked_crc;
1920
pub mod reservoir;
2021
pub mod tf_record;

tensorboard/data/server/scripted_reader.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use std::io::{self, Cursor, Read};
2020

2121
/// A reader that delegates to a sequence of cursors, reading from each in turn and simulating
2222
/// EOF after each one.
23+
#[derive(Debug)]
2324
pub struct ScriptedReader(VecDeque<Cursor<Vec<u8>>>);
2425

2526
impl ScriptedReader {

tensorboard/data/server/tf_record.rs

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ limitations under the License.
1616
//! Resumable reading for TFRecord streams.
1717
1818
use byteorder::{ByteOrder, LittleEndian};
19-
use std::io::{self, Read};
19+
use std::fmt::{self, Debug};
20+
use std::io::{self, Read, Write};
2021

2122
use crate::masked_crc::MaskedCrc;
2223

@@ -54,11 +55,12 @@ pub struct TfRecordReader<R> {
5455

5556
/// A TFRecord with a data buffer and expected checksum. The checksum may or may not match the
5657
/// actual contents.
57-
#[derive(Debug)]
58+
#[derive(Debug, PartialEq, Eq)]
5859
pub struct TfRecord {
5960
/// The payload of the TFRecord.
6061
pub data: Vec<u8>,
61-
data_crc: MaskedCrc,
62+
/// The data CRC listed in the record, which may or may not actually match the payload.
63+
pub data_crc: MaskedCrc,
6264
}
6365

6466
/// A buffer's checksum was computed, but it did not match the expected value.
@@ -82,6 +84,31 @@ impl TfRecord {
8284
Err(ChecksumError { got, want })
8385
}
8486
}
87+
88+
/// Creates a TFRecord from a data vector, computing the correct data CRC. Calling `checksum()`
89+
/// on this record will always succeed.
90+
pub fn from_data(data: Vec<u8>) -> Self {
91+
let data_crc = MaskedCrc::compute(&data);
92+
TfRecord { data, data_crc }
93+
}
94+
95+
/// Encodes the record to an output stream. The data CRC will be taken from the `TfRecord`
96+
/// value, not recomputed from the payload. This means that reading a valid record and writing
97+
/// it back out will always produce identical input. It also means that the written data CRC
98+
/// may not be valid.
99+
///
100+
/// This may call [`Write::write`] multiple times; consider providing a buffered output stream
101+
/// if this is an issue.
102+
///
103+
/// A record can always be serialized. This method fails only due to underlying I/O errors.
104+
pub fn write<W: Write>(&self, mut writer: W) -> io::Result<()> {
105+
let len_buf: [u8; 8] = (self.data.len() as u64).to_le_bytes();
106+
writer.write_all(&len_buf)?;
107+
writer.write_all(&MaskedCrc::compute(&len_buf).0.to_le_bytes())?;
108+
writer.write_all(&self.data)?;
109+
writer.write_all(&self.data_crc.0.to_le_bytes())?;
110+
Ok(())
111+
}
85112
}
86113

87114
/// Error returned by [`TfRecordReader::read_record`].
@@ -112,6 +139,26 @@ impl From<io::Error> for ReadRecordError {
112139
}
113140
}
114141

142+
impl<R: Debug> Debug for TfRecordReader<R> {
143+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144+
f.debug_struct("TfRecordReader")
145+
.field(
146+
"header",
147+
&format_args!("{}/{}", self.header.len(), self.header.capacity()),
148+
)
149+
.field(
150+
"data_plus_footer",
151+
&format_args!(
152+
"{}/{}",
153+
self.data_plus_footer.len(),
154+
self.data_plus_footer.capacity()
155+
),
156+
)
157+
.field("reader", &self.reader)
158+
.finish()
159+
}
160+
}
161+
115162
impl<R: Read> TfRecordReader<R> {
116163
/// Creates an empty `TfRecordReader`, ready to read a stream of TFRecords from its beginning.
117164
/// The underlying reader should be aligned to the start of a record (usually, this is just the
@@ -128,6 +175,11 @@ impl<R: Read> TfRecordReader<R> {
128175
}
129176
}
130177

178+
/// Consumes this `TfRecordReader<R>`, returning the underlying reader `R`.
179+
pub fn into_inner(self) -> R {
180+
self.reader
181+
}
182+
131183
/// Attempts to read a TFRecord, pausing gracefully in the face of truncations. If the record
132184
/// is truncated, the result is a `Truncated` error; call `read_record` again once more data
133185
/// may have been written to resume reading where it left off. If the record is read
@@ -352,4 +404,46 @@ mod tests {
352404
other => panic!("{:?}", other),
353405
}
354406
}
407+
408+
#[test]
fn test_from_data() {
    // Records built by `from_data` must pass their own checksum, including for empty payloads.
    let payloads: [&[u8]; 3] = [
        b"",
        b"\x00",
        b"the quick brown fox jumped over the lazy dog",
    ];
    for payload in payloads.iter() {
        TfRecord::from_data(payload.to_vec()).checksum().unwrap();
    }
}
419+
420+
fn test_write_read_roundtrip(record: &TfRecord) {
421+
let mut cursor = Cursor::new(Vec::<u8>::new());
422+
record.write(&mut cursor).expect("failed to write record");
423+
let written_len = cursor.position();
424+
cursor.set_position(0);
425+
let mut reader = TfRecordReader::new(cursor);
426+
let output_record = reader.read_record().expect("read_record");
427+
assert_eq!(&output_record, record);
428+
assert_eq!(reader.into_inner().position(), written_len); // should have read all the bytes and not more
429+
}
430+
431+
#[test]
fn test_write_read_roundtrip_valid_data_crc() {
    // The stored CRC matches the payload.
    let data = b"hello world".to_vec();
    let data_crc = MaskedCrc::compute(&data);
    test_write_read_roundtrip(&TfRecord { data, data_crc });
}
440+
441+
#[test]
fn test_write_read_roundtrip_invalid_data_crc() {
    // The stored CRC is arbitrary; the roundtrip must preserve it rather than recompute it.
    let data_crc = MaskedCrc(0x12345678);
    let data = b"hello world".to_vec();
    test_write_read_roundtrip(&TfRecord { data, data_crc });
}
355449
}

0 commit comments

Comments
 (0)