Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/async_impl/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,11 @@ impl Future for Pending {
#[cfg(feature = "zstd")]
DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(Box::pin(
FramedRead::new(
ZstdDecoder::new(StreamReader::new(_body)),
{
let mut d = ZstdDecoder::new(StreamReader::new(_body));
d.multiple_members(true);
d
},
BytesCodec::new(),
)
.fuse(),
Expand Down
110 changes: 110 additions & 0 deletions tests/zstd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,116 @@ async fn test_non_chunked_non_fragmented_response() {
assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT);
}

// Big response can have multiple ZSTD frames in it
#[tokio::test]
async fn test_non_chunked_non_fragmented_multiple_frames_response() {
let server = server::low_level_with_response(|_raw_request, client_socket| {
Box::new(async move {
// Split the content into two parts
let content_bytes = RESPONSE_CONTENT.as_bytes();
let mid = content_bytes.len() / 2;
// Compress each part separately to create multiple ZSTD frames
let compressed_part1 = zstd_crate::encode_all(&content_bytes[0..mid], 3).unwrap();
let compressed_part2 = zstd_crate::encode_all(&content_bytes[mid..], 3).unwrap();
// Concatenate the compressed frames
let mut zstded_content = compressed_part1;
zstded_content.extend_from_slice(&compressed_part2);
// Set Content-Length to the total length of the concatenated frames
let content_length_header =
format!("Content-Length: {}\r\n\r\n", zstded_content.len()).into_bytes();
let response = [
COMPRESSED_RESPONSE_HEADERS,
&content_length_header,
&zstded_content,
]
.concat();

client_socket
.write_all(response.as_slice())
.await
.expect("response write_all failed");
client_socket.flush().await.expect("response flush failed");
})
});

let res = reqwest::Client::new()
.get(format!("http://{}/", server.addr()))
.send()
.await
.expect("response");

assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT);
}

#[tokio::test]
async fn test_chunked_fragmented_multiple_frames_in_one_chunk() {
    // Delay inserted between the two response fragments, and the slack allowed
    // when asserting that the client actually waited for the delayed tail.
    const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration =
        tokio::time::Duration::from_millis(1000); // 1-second delay
    const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); // Margin for timing assertions

    // Low-level server: a single HTTP chunk whose payload carries two ZSTD
    // frames, delivered in two fragments separated by a delay.
    let server = server::low_level_with_response(|_raw_request, client_socket| {
        Box::new(async move {
            // Build two independent ZSTD frames, one per half of the content.
            let split_at = RESPONSE_CONTENT.len() / 2;
            let frame_one = zstd_compress(RESPONSE_CONTENT[0..split_at].as_bytes());
            let frame_two = zstd_compress(RESPONSE_CONTENT[split_at..].as_bytes());

            // Both frames travel inside one chunk's payload.
            let chunk_payload = [frame_one.as_slice(), frame_two.as_slice()].concat();

            // First fragment: response headers, chunked-encoding marker, the
            // chunk size in hex, and the chunk payload itself.
            let mut first_fragment = [
                COMPRESSED_RESPONSE_HEADERS, // e.g., "HTTP/1.1 200 OK\r\nContent-Encoding: zstd\r\n"
                b"Transfer-Encoding: chunked\r\n\r\n", // Indicate chunked encoding
                format!("{:x}\r\n", chunk_payload.len()).as_bytes(), // Chunk size in hex
            ]
            .concat();
            first_fragment.extend_from_slice(&chunk_payload);

            client_socket
                .write_all(first_fragment.as_slice())
                .await
                .expect("write_all failed");
            client_socket.flush().await.expect("flush failed");

            // Simulate network fragmentation before completing the response.
            tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await;

            // Second fragment: chunk terminator plus the final zero-size chunk.
            client_socket
                .write_all(b"\r\n0\r\n\r\n")
                .await
                .expect("write_all failed");
            client_socket.flush().await.expect("flush failed");
        })
    });

    // Measure from just before the request so we can verify the client really
    // blocked on the delayed second fragment.
    let start = tokio::time::Instant::now();

    let res = reqwest::Client::new()
        .get(format!("http://{}/", server.addr()))
        .send()
        .await
        .expect("Failed to get response");

    // Decompressed body must equal the original content, and the elapsed time
    // must reflect the server-side delay (minus the timing margin).
    assert_eq!(
        res.text().await.expect("Failed to read text"),
        RESPONSE_CONTENT
    );
    assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN);
}

#[tokio::test]
async fn test_chunked_fragmented_response_1() {
const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration =
Expand Down
Loading