Skip to content

Commit aae1e72

Browse files
committed
rust: add event file reading module
Summary: We implement a module for reading a TensorFlow event file. This wraps the TFRecord reading module added in #4307 and parses records as protos. The reader also stores the last-read event wall time, so that we can stop reading from event files after a while (cf. `--reload_multifile`). Along the way, we add a couple utility methods to `TfRecord` that make it easier to use records in test code. Test Plan: Unit tests included. wchargin-branch: rust-event-file-reading wchargin-source: 18183bce70f95cdef820853b67f5a0930076ba15
1 parent 84b9254 commit aae1e72

File tree

4 files changed

+301
-3
lines changed

4 files changed

+301
-3
lines changed

tensorboard/data/server/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ rust_library(
2323
name = "rustboard_core",
2424
srcs = [
2525
"lib.rs",
26+
"event_file.rs",
2627
"masked_crc.rs",
2728
"scripted_reader.rs",
2829
"tf_record.rs",

tensorboard/data/server/event_file.rs

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
//! Parsing for event files containing a stream of `Event` protos.
17+
18+
use prost::{DecodeError, Message};
19+
use std::io::Read;
20+
21+
use crate::proto::tensorboard::Event;
22+
use crate::tf_record::{ChecksumError, ReadRecordError, TfRecordReader};
23+
24+
/// A reader for a stream of `Event` protos framed as TFRecords.
25+
///
26+
/// As with [`TfRecordReader`], an event may be read over one or more underlying reads, to support
27+
/// growing, partially flushed files.
28+
pub struct EventFileReader<R> {
29+
/// Wall time of the record most recently read from this event file, or `None` if no records
30+
/// have been read. Used for determining when to consider this file dead and abandon it.
31+
last_wall_time: Option<f64>,
32+
/// Underlying record reader owned by this event file.
33+
reader: TfRecordReader<R>,
34+
}
35+
36+
/// Error returned by [`EventFileReader::read_event`].
37+
#[derive(Debug)]
38+
pub enum ReadEventError {
39+
/// The record failed its checksum. This may only be detected if the protocol buffer fails to
40+
/// decode: records with bad checksum that still happen to parse as valid protos may be
41+
/// returned silently.
42+
InvalidRecord(ChecksumError),
43+
/// The record passed its checksum, but the contained protocol buffer is invalid.
44+
InvalidProto(DecodeError),
45+
/// The record is a valid `Event` proto, but its `wall_time` is `NaN`.
46+
NanWallTime(Event),
47+
/// An error occurred reading the record. May or may not be fatal.
48+
ReadRecordError(ReadRecordError),
49+
}
50+
51+
/// Wraps a proto decoding failure as [`ReadEventError::InvalidProto`], enabling `?`/`.into()`.
impl From<DecodeError> for ReadEventError {
    fn from(decode_err: DecodeError) -> Self {
        Self::InvalidProto(decode_err)
    }
}
56+
57+
/// Wraps a checksum mismatch as [`ReadEventError::InvalidRecord`], enabling `?`/`.into()`.
impl From<ChecksumError> for ReadEventError {
    fn from(checksum_err: ChecksumError) -> Self {
        Self::InvalidRecord(checksum_err)
    }
}
62+
63+
/// Wraps a record-level read failure as [`ReadEventError::ReadRecordError`], enabling
/// `?`/`.into()`.
impl From<ReadRecordError> for ReadEventError {
    fn from(read_err: ReadRecordError) -> Self {
        Self::ReadRecordError(read_err)
    }
}
68+
69+
impl ReadEventError {
70+
/// Checks whether this error indicates a truncated record. This is a convenience method, since
71+
/// the end of a file always implies a truncation event.
72+
pub fn truncated(&self) -> bool {
73+
matches!(
74+
self,
75+
ReadEventError::ReadRecordError(ReadRecordError::Truncated)
76+
)
77+
}
78+
}
79+
80+
impl<R: Read> EventFileReader<R> {
81+
/// Creates a new `EventFileReader` wrapping the given reader.
82+
pub fn new(reader: R) -> Self {
83+
Self {
84+
last_wall_time: None,
85+
reader: TfRecordReader::new(reader),
86+
}
87+
}
88+
89+
/// Reads the next event from the file.
90+
pub fn read_event(&mut self) -> Result<Event, ReadEventError> {
91+
let record = self.reader.read_record()?;
92+
let event = match Event::decode(&record.data[..]) {
93+
Ok(ev) => ev,
94+
Err(err) => {
95+
// On proto decoding failure, check the record checksum first.
96+
record.checksum()?;
97+
return Err(err.into());
98+
}
99+
};
100+
let wall_time = event.wall_time;
101+
if wall_time.is_nan() {
102+
return Err(ReadEventError::NanWallTime(event));
103+
}
104+
self.last_wall_time = Some(wall_time);
105+
Ok(event)
106+
}
107+
108+
/// Gets the wall time of the event most recently read from the event file, or `None` if no
109+
/// events have yet been read.
110+
pub fn last_wall_time(&self) -> &Option<f64> {
111+
&self.last_wall_time
112+
}
113+
}
114+
115+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::masked_crc::MaskedCrc;
    use crate::proto::tensorboard as pb;
    use crate::scripted_reader::ScriptedReader;
    use crate::tf_record::TfRecord;
    use std::io::Cursor;

    /// Serializes an `Event` proto into a freshly allocated byte vector.
    fn encode_event(e: &Event) -> Vec<u8> {
        let mut buf = Vec::new();
        e.encode(&mut buf).expect("failed to encode event");
        buf
    }

    /// Walks a file holding, in order: a good event, an event with a NaN wall time, a record
    /// whose payload is not a valid proto, and a record with both a bad payload and a bad
    /// checksum. Checks the result of each read and that the last wall time only ever reflects
    /// the good event.
    #[test]
    fn test() {
        let good_event = Event {
            what: Some(pb::event::What::FileVersion("good event".to_string())),
            wall_time: 1234.5,
            ..Event::default()
        };
        let mut nan_event = Event {
            what: Some(pb::event::What::FileVersion("bad wall time".to_string())),
            wall_time: f64::NAN,
            ..Event::default()
        };

        let mut file = Vec::new();
        let records = vec![
            TfRecord::from_data(encode_event(&good_event)),
            TfRecord::from_data(encode_event(&nan_event)),
            TfRecord::from_data(b"failed proto, OK record".to_vec()),
            TfRecord {
                data: b"failed proto, failed checksum, OK record structure".to_vec(),
                data_crc: MaskedCrc(0x12345678),
            },
        ];
        for record in records.iter() {
            record.write(&mut file).expect("writing record");
        }
        let mut reader = EventFileReader::new(Cursor::new(file));

        // Nothing read yet.
        assert_eq!(reader.last_wall_time(), &None);

        // Record 1: decodes cleanly and updates the wall time.
        assert_eq!(reader.read_event().unwrap(), good_event);
        assert_eq!(reader.last_wall_time(), &Some(1234.5));

        // Record 2: decodes, but is rejected for its NaN wall time.
        match reader.read_event() {
            Err(ReadEventError::NanWallTime(mut got)) => {
                // `NaN != NaN`, so normalize both wall times before comparing whole events.
                assert!(got.wall_time.is_nan());
                got.wall_time = 0.0;
                nan_event.wall_time = 0.0;
                assert_eq!(got, nan_event);
            }
            other => panic!("{:?}", other),
        };
        assert_eq!(reader.last_wall_time(), &Some(1234.5));

        // Record 3: checksum is fine, payload is not a proto.
        let result = reader.read_event();
        assert!(
            matches!(result, Err(ReadEventError::InvalidProto(_))),
            "{:?}",
            result
        );
        assert_eq!(reader.last_wall_time(), &Some(1234.5));

        // Record 4: payload is not a proto and the checksum is wrong as well.
        let result = reader.read_event();
        assert!(
            matches!(
                result,
                Err(ReadEventError::InvalidRecord(ChecksumError {
                    got: _,
                    want: MaskedCrc(0x12345678),
                }))
            ),
            "{:?}",
            result
        );
        assert_eq!(reader.last_wall_time(), &Some(1234.5));

        // Past the end of the file, reads report truncation.
        let last = reader.read_event();
        assert!(last.as_ref().unwrap_err().truncated(), "{:?}", last);
        assert_eq!(reader.last_wall_time(), &Some(1234.5));
    }

    /// Feeds a single record to the reader in two pieces and checks that the first read reports
    /// truncation, the second yields the event, and later reads report truncation again.
    #[test]
    fn test_resume() {
        let event = Event {
            what: Some(pb::event::What::FileVersion("good event".to_string())),
            wall_time: 1234.5,
            ..Event::default()
        };
        let mut record_bytes = Vec::<u8>::new();
        TfRecord::from_data(encode_event(&event))
            .write(&mut record_bytes)
            .unwrap();
        let (head, tail) = record_bytes.split_at(6);

        let chunks = vec![head.to_vec(), tail.to_vec()];
        let mut reader = EventFileReader::new(ScriptedReader::new(chunks));

        // Only the first chunk is available: truncated, and no wall time recorded.
        let first = reader.read_event();
        assert!(first.as_ref().unwrap_err().truncated(), "{:?}", first);
        assert_eq!(reader.last_wall_time(), &None);

        // The rest arrives: the whole event comes through.
        let second = reader.read_event();
        assert_eq!(second.unwrap(), event);
        assert_eq!(reader.last_wall_time(), &Some(1234.5));

        // Nothing further to read: truncated again, wall time unchanged.
        let third = reader.read_event();
        assert!(third.as_ref().unwrap_err().truncated(), "{:?}", third);
        assert_eq!(reader.last_wall_time(), &Some(1234.5));
    }
}

tensorboard/data/server/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
//! Core functionality for TensorBoard data loading.
1717
18+
pub mod event_file;
1819
pub mod masked_crc;
1920
pub mod tf_record;
2021

tensorboard/data/server/tf_record.rs

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ limitations under the License.
1616
//! Resumable reading for TFRecord streams.
1717
1818
use byteorder::{ByteOrder, LittleEndian};
19-
use std::io::{self, Read};
19+
use std::io::{self, Read, Write};
2020

2121
use crate::masked_crc::MaskedCrc;
2222

@@ -54,11 +54,12 @@ pub struct TfRecordReader<R> {
5454

5555
/// A TFRecord with a data buffer and expected checksum. The checksum may or may not match the
5656
/// actual contents.
57-
#[derive(Debug)]
57+
#[derive(Debug, PartialEq, Eq)]
5858
pub struct TfRecord {
5959
/// The payload of the TFRecord.
6060
pub data: Vec<u8>,
61-
data_crc: MaskedCrc,
61+
/// The data CRC listed in the record, which may or may not actually match the payload.
62+
pub data_crc: MaskedCrc,
6263
}
6364

6465
/// A buffer's checksum was computed, but it did not match the expected value.
@@ -82,6 +83,31 @@ impl TfRecord {
8283
Err(ChecksumError { got, want })
8384
}
8485
}
86+
87+
/// Creates a TFRecord from a data vector, computing the correct data CRC. Calling `checksum()`
88+
/// on this record will always succeed.
89+
pub fn from_data(data: Vec<u8>) -> Self {
90+
let data_crc = MaskedCrc::compute(&data);
91+
TfRecord { data, data_crc }
92+
}
93+
94+
/// Encodes the record to an output stream. The data CRC will be taken from the `TfRecord`
95+
/// value, not recomputed from the payload. This means that reading a valid record and writing
96+
/// it back out will always produce identical input. It also means that the written data CRC
97+
/// may not be valid.
98+
///
99+
/// This may call [`Write::write`] multiple times; consider providing a buffered output stream
100+
/// if this is an issue.
101+
///
102+
/// A record can always be serialized. This method fails only due to underlying I/O errors.
103+
pub fn write<W: Write>(&self, mut writer: W) -> io::Result<()> {
104+
let len_buf: [u8; 8] = (self.data.len() as u64).to_le_bytes();
105+
writer.write_all(&len_buf)?;
106+
writer.write_all(&MaskedCrc::compute(&len_buf).0.to_le_bytes())?;
107+
writer.write_all(&self.data)?;
108+
writer.write_all(&self.data_crc.0.to_le_bytes())?;
109+
Ok(())
110+
}
85111
}
86112

87113
/// Error returned by [`TfRecordReader::read_record`].
@@ -128,6 +154,11 @@ impl<R: Read> TfRecordReader<R> {
128154
}
129155
}
130156

157+
/// Consumes this `TfRecordReader<R>`, returning the underlying reader `R`.
158+
pub fn into_inner(self) -> R {
159+
self.reader
160+
}
161+
131162
/// Attempts to read a TFRecord, pausing gracefully in the face of truncations. If the record
132163
/// is truncated, the result is a `Truncated` error; call `read_record` again once more data
133164
/// may have been written to resume reading where it left off. If the record is read
@@ -352,4 +383,46 @@ mod tests {
352383
other => panic!("{:?}", other),
353384
}
354385
}
386+
387+
/// Records built with `from_data` must pass their own checksum, for empty, single-byte, and
/// longer payloads alike.
#[test]
fn test_from_data() {
    let payloads: [&[u8]; 3] = [
        b"",
        b"\x00",
        b"the quick brown fox jumped over the lazy dog",
    ];
    for payload in payloads.iter() {
        TfRecord::from_data(payload.to_vec()).checksum().unwrap();
    }
}
398+
399+
fn test_write_read_roundtrip(record: &TfRecord) {
400+
let mut cursor = Cursor::new(Vec::<u8>::new());
401+
record.write(&mut cursor).expect("failed to write record");
402+
let written_len = cursor.position();
403+
cursor.set_position(0);
404+
let mut reader = TfRecordReader::new(cursor);
405+
let output_record = reader.read_record().expect("read_record");
406+
assert_eq!(&output_record, record);
407+
assert_eq!(reader.into_inner().position(), written_len); // should have read all the bytes and not more
408+
}
409+
410+
/// Round-trips a record whose stored data CRC matches its payload.
#[test]
fn test_write_read_roundtrip_valid_data_crc() {
    let payload = b"hello world".to_vec();
    let matching_crc = MaskedCrc::compute(&payload);
    test_write_read_roundtrip(&TfRecord {
        data: payload,
        data_crc: matching_crc,
    });
}
419+
420+
/// Round-trips a record whose stored data CRC is an arbitrary value rather than one computed
/// from the payload; the stored CRC must survive the trip unchanged.
#[test]
fn test_write_read_roundtrip_invalid_data_crc() {
    let arbitrary_crc = MaskedCrc(0x12345678);
    test_write_read_roundtrip(&TfRecord {
        data: b"hello world".to_vec(),
        data_crc: arbitrary_crc,
    });
}
355428
}

0 commit comments

Comments
 (0)