Skip to content

Commit 2a5071f

Browse files
feat: implement Framed::map_codec (#4427)
1 parent 621790e commit 2a5071f

File tree

6 files changed

+196
-4
lines changed

6 files changed

+196
-4
lines changed

tokio-util/src/codec/framed.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,26 @@ impl<T, U> Framed<T, U> {
204204
&mut self.inner.codec
205205
}
206206

207+
/// Maps the codec `U` to `C`, preserving the read and write buffers
208+
/// wrapped by `Framed`.
209+
///
210+
/// Note that care should be taken to not tamper with the underlying codec
211+
/// as it may corrupt the stream of frames otherwise being worked with.
212+
pub fn map_codec<C, F>(self, map: F) -> Framed<T, C>
213+
where
214+
F: FnOnce(U) -> C,
215+
{
216+
// This could be potentially simplified once rust-lang/rust#86555 hits stable
217+
let parts = self.into_parts();
218+
Framed::from_parts(FramedParts {
219+
io: parts.io,
220+
codec: map(parts.codec),
221+
read_buf: parts.read_buf,
222+
write_buf: parts.write_buf,
223+
_priv: (),
224+
})
225+
}
226+
207227
/// Returns a mutable reference to the underlying codec wrapped by
208228
/// `Framed`.
209229
///

tokio-util/src/codec/framed_read.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,27 @@ impl<T, D> FramedRead<T, D> {
108108
&mut self.inner.codec
109109
}
110110

111+
/// Maps the decoder `D` to `C`, preserving the read buffer
112+
/// wrapped by `Framed`.
113+
pub fn map_decoder<C, F>(self, map: F) -> FramedRead<T, C>
114+
where
115+
F: FnOnce(D) -> C,
116+
{
117+
// This could be potentially simplified once rust-lang/rust#86555 hits stable
118+
let FramedImpl {
119+
inner,
120+
state,
121+
codec,
122+
} = self.inner;
123+
FramedRead {
124+
inner: FramedImpl {
125+
inner,
126+
state,
127+
codec: map(codec),
128+
},
129+
}
130+
}
131+
111132
/// Returns a mutable reference to the underlying decoder.
112133
pub fn decoder_pin_mut(self: Pin<&mut Self>) -> &mut D {
113134
self.project().inner.project().codec

tokio-util/src/codec/framed_write.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,27 @@ impl<T, E> FramedWrite<T, E> {
8888
&mut self.inner.codec
8989
}
9090

91+
/// Maps the encoder `E` to `C`, preserving the write buffer
92+
/// wrapped by `Framed`.
93+
pub fn map_encoder<C, F>(self, map: F) -> FramedWrite<T, C>
94+
where
95+
F: FnOnce(E) -> C,
96+
{
97+
// This could be potentially simplified once rust-lang/rust#86555 hits stable
98+
let FramedImpl {
99+
inner,
100+
state,
101+
codec,
102+
} = self.inner;
103+
FramedWrite {
104+
inner: FramedImpl {
105+
inner,
106+
state,
107+
codec: map(codec),
108+
},
109+
}
110+
}
111+
91112
/// Returns a mutable reference to the underlying encoder.
92113
pub fn encoder_pin_mut(self: Pin<&mut Self>) -> &mut E {
93114
self.project().inner.project().codec

tokio-util/tests/framed.rs

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ use std::task::{Context, Poll};
1212
const INITIAL_CAPACITY: usize = 8 * 1024;
1313

1414
/// Encode and decode u32 values.
15-
struct U32Codec;
15+
#[derive(Default)]
16+
struct U32Codec {
17+
read_bytes: usize,
18+
}
1619

1720
impl Decoder for U32Codec {
1821
type Item = u32;
@@ -24,6 +27,7 @@ impl Decoder for U32Codec {
2427
}
2528

2629
let n = buf.split_to(4).get_u32();
30+
self.read_bytes += 4;
2731
Ok(Some(n))
2832
}
2933
}
@@ -39,6 +43,38 @@ impl Encoder<u32> for U32Codec {
3943
}
4044
}
4145

46+
/// Encode and decode u64 values.
47+
#[derive(Default)]
48+
struct U64Codec {
49+
read_bytes: usize,
50+
}
51+
52+
impl Decoder for U64Codec {
53+
type Item = u64;
54+
type Error = io::Error;
55+
56+
fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u64>> {
57+
if buf.len() < 8 {
58+
return Ok(None);
59+
}
60+
61+
let n = buf.split_to(8).get_u64();
62+
self.read_bytes += 8;
63+
Ok(Some(n))
64+
}
65+
}
66+
67+
impl Encoder<u64> for U64Codec {
68+
type Error = io::Error;
69+
70+
fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> {
71+
// Reserve space
72+
dst.reserve(8);
73+
dst.put_u64(item);
74+
Ok(())
75+
}
76+
}
77+
4278
/// This value should never be used
4379
struct DontReadIntoThis;
4480

@@ -63,18 +99,39 @@ impl tokio::io::AsyncRead for DontReadIntoThis {
6399

64100
#[tokio::test]
65101
async fn can_read_from_existing_buf() {
66-
let mut parts = FramedParts::new(DontReadIntoThis, U32Codec);
102+
let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default());
67103
parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]);
68104

69105
let mut framed = Framed::from_parts(parts);
70106
let num = assert_ok!(framed.next().await.unwrap());
71107

72108
assert_eq!(num, 42);
109+
assert_eq!(framed.codec().read_bytes, 4);
110+
}
111+
112+
#[tokio::test]
113+
async fn can_read_from_existing_buf_after_codec_changed() {
114+
let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default());
115+
parts.read_buf = BytesMut::from(&[0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 84][..]);
116+
117+
let mut framed = Framed::from_parts(parts);
118+
let num = assert_ok!(framed.next().await.unwrap());
119+
120+
assert_eq!(num, 42);
121+
assert_eq!(framed.codec().read_bytes, 4);
122+
123+
let mut framed = framed.map_codec(|codec| U64Codec {
124+
read_bytes: codec.read_bytes,
125+
});
126+
let num = assert_ok!(framed.next().await.unwrap());
127+
128+
assert_eq!(num, 84);
129+
assert_eq!(framed.codec().read_bytes, 12);
73130
}
74131

75132
#[test]
76133
fn external_buf_grows_to_init() {
77-
let mut parts = FramedParts::new(DontReadIntoThis, U32Codec);
134+
let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default());
78135
parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]);
79136

80137
let framed = Framed::from_parts(parts);
@@ -85,7 +142,7 @@ fn external_buf_grows_to_init() {
85142

86143
#[test]
87144
fn external_buf_does_not_shrink() {
88-
let mut parts = FramedParts::new(DontReadIntoThis, U32Codec);
145+
let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default());
89146
parts.read_buf = BytesMut::from(&vec![0; INITIAL_CAPACITY * 2][..]);
90147

91148
let framed = Framed::from_parts(parts);

tokio-util/tests/framed_read.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,22 @@ impl Decoder for U32Decoder {
5050
}
5151
}
5252

53+
struct U64Decoder;
54+
55+
impl Decoder for U64Decoder {
56+
type Item = u64;
57+
type Error = io::Error;
58+
59+
fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u64>> {
60+
if buf.len() < 8 {
61+
return Ok(None);
62+
}
63+
64+
let n = buf.split_to(8).get_u64();
65+
Ok(Some(n))
66+
}
67+
}
68+
5369
#[test]
5470
fn read_multi_frame_in_packet() {
5571
let mut task = task::spawn(());
@@ -84,6 +100,24 @@ fn read_multi_frame_across_packets() {
84100
});
85101
}
86102

103+
#[test]
104+
fn read_multi_frame_in_packet_after_codec_changed() {
105+
let mut task = task::spawn(());
106+
let mock = mock! {
107+
Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()),
108+
};
109+
let mut framed = FramedRead::new(mock, U32Decoder);
110+
111+
task.enter(|cx, _| {
112+
assert_read!(pin!(framed).poll_next(cx), 0x04);
113+
114+
let mut framed = framed.map_decoder(|_| U64Decoder);
115+
assert_read!(pin!(framed).poll_next(cx), 0x08);
116+
117+
assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
118+
});
119+
}
120+
87121
#[test]
88122
fn read_not_ready() {
89123
let mut task = task::spawn(());

tokio-util/tests/framed_write.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,19 @@ impl Encoder<u32> for U32Encoder {
3939
}
4040
}
4141

42+
struct U64Encoder;
43+
44+
impl Encoder<u64> for U64Encoder {
45+
type Error = io::Error;
46+
47+
fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> {
48+
// Reserve space
49+
dst.reserve(8);
50+
dst.put_u64(item);
51+
Ok(())
52+
}
53+
}
54+
4255
#[test]
4356
fn write_multi_frame_in_packet() {
4457
let mut task = task::spawn(());
@@ -65,6 +78,32 @@ fn write_multi_frame_in_packet() {
6578
});
6679
}
6780

81+
#[test]
82+
fn write_multi_frame_after_codec_changed() {
83+
let mut task = task::spawn(());
84+
let mock = mock! {
85+
Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()),
86+
};
87+
let mut framed = FramedWrite::new(mock, U32Encoder);
88+
89+
task.enter(|cx, _| {
90+
assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
91+
assert!(pin!(framed).start_send(0x04).is_ok());
92+
93+
let mut framed = framed.map_encoder(|_| U64Encoder);
94+
assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
95+
assert!(pin!(framed).start_send(0x08).is_ok());
96+
97+
// Nothing written yet
98+
assert_eq!(1, framed.get_ref().calls.len());
99+
100+
// Flush the writes
101+
assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok());
102+
103+
assert_eq!(0, framed.get_ref().calls.len());
104+
});
105+
}
106+
68107
#[test]
69108
fn write_hits_backpressure() {
70109
const ITER: usize = 2 * 1024;

0 commit comments

Comments
 (0)