|
| 1 | +/* |
| 2 | + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | + |
| 6 | +//! Adapters to use http-body 1.0 bodies with SdkBody & ByteStream |
| 7 | +
|
| 8 | +use std::pin::Pin; |
| 9 | +use std::task::{ready, Context, Poll}; |
| 10 | + |
| 11 | +use bytes::Bytes; |
| 12 | +use http_body_util::BodyExt; |
| 13 | +use pin_project_lite::pin_project; |
| 14 | + |
| 15 | +use crate::body::{Error, SdkBody}; |
| 16 | + |
| 17 | +impl SdkBody { |
| 18 | + /// Construct an `SdkBody` from a type that implements [`http_body_1_0::Body<Data = Bytes>`](http_body_1_0::Body). |
| 19 | + pub fn from_body_1_x<T, E>(body: T) -> Self |
| 20 | + where |
| 21 | + T: http_body_1_0::Body<Data = Bytes, Error = E> + Send + Sync + 'static, |
| 22 | + E: Into<Error> + 'static, |
| 23 | + { |
| 24 | + SdkBody::from_body_0_4_internal(Http1toHttp04::new(body.map_err(Into::into))) |
| 25 | + } |
| 26 | +} |
| 27 | + |
| 28 | +pin_project! { |
| 29 | + struct Http1toHttp04<B> { |
| 30 | + #[pin] |
| 31 | + inner: B, |
| 32 | + trailers: Option<http_1x::HeaderMap>, |
| 33 | + } |
| 34 | +} |
| 35 | + |
| 36 | +impl<B> Http1toHttp04<B> { |
| 37 | + fn new(inner: B) -> Self { |
| 38 | + Self { |
| 39 | + inner, |
| 40 | + trailers: None, |
| 41 | + } |
| 42 | + } |
| 43 | +} |
| 44 | + |
| 45 | +impl<B> http_body_0_4::Body for Http1toHttp04<B> |
| 46 | +where |
| 47 | + B: http_body_1_0::Body, |
| 48 | +{ |
| 49 | + type Data = B::Data; |
| 50 | + type Error = B::Error; |
| 51 | + |
| 52 | + fn poll_data( |
| 53 | + mut self: Pin<&mut Self>, |
| 54 | + cx: &mut Context<'_>, |
| 55 | + ) -> Poll<Option<Result<Self::Data, Self::Error>>> { |
| 56 | + loop { |
| 57 | + let this = self.as_mut().project(); |
| 58 | + match ready!(this.inner.poll_frame(cx)) { |
| 59 | + Some(Ok(frame)) => { |
| 60 | + let frame = match frame.into_data() { |
| 61 | + Ok(data) => return Poll::Ready(Some(Ok(data))), |
| 62 | + Err(frame) => frame, |
| 63 | + }; |
| 64 | + // when we get a trailers frame, store the trailers for the next poll |
| 65 | + if let Ok(trailers) = frame.into_trailers() { |
| 66 | + this.trailers.replace(trailers); |
| 67 | + return Poll::Ready(None); |
| 68 | + }; |
| 69 | + // if the frame type was unknown, discard it. the next one might be something |
| 70 | + // useful |
| 71 | + } |
| 72 | + Some(Err(e)) => return Poll::Ready(Some(Err(e))), |
| 73 | + None => return Poll::Ready(None), |
| 74 | + } |
| 75 | + } |
| 76 | + } |
| 77 | + |
| 78 | + fn poll_trailers( |
| 79 | + self: Pin<&mut Self>, |
| 80 | + _cx: &mut Context<'_>, |
| 81 | + ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> { |
| 82 | + // all of the polling happens in poll_data, once we get to the trailers we've actually |
| 83 | + // already read everything |
| 84 | + let this = self.project(); |
| 85 | + match this.trailers.take() { |
| 86 | + Some(headers) => Poll::Ready(Ok(Some(convert_header_map(headers)))), |
| 87 | + None => Poll::Ready(Ok(None)), |
| 88 | + } |
| 89 | + } |
| 90 | + |
| 91 | + fn is_end_stream(&self) -> bool { |
| 92 | + self.inner.is_end_stream() |
| 93 | + } |
| 94 | + |
| 95 | + fn size_hint(&self) -> http_body_0_4::SizeHint { |
| 96 | + let mut size_hint = http_body_0_4::SizeHint::new(); |
| 97 | + let inner_hint = self.inner.size_hint(); |
| 98 | + if let Some(exact) = inner_hint.exact() { |
| 99 | + size_hint.set_exact(exact); |
| 100 | + } else { |
| 101 | + size_hint.set_lower(inner_hint.lower()); |
| 102 | + if let Some(upper) = inner_hint.upper() { |
| 103 | + size_hint.set_upper(upper); |
| 104 | + } |
| 105 | + } |
| 106 | + size_hint |
| 107 | + } |
| 108 | +} |
| 109 | + |
| 110 | +fn convert_header_map(input: http_1x::HeaderMap) -> http::HeaderMap { |
| 111 | + let mut map = http::HeaderMap::with_capacity(input.capacity()); |
| 112 | + let mut mem: Option<http_1x::HeaderName> = None; |
| 113 | + for (k, v) in input.into_iter() { |
| 114 | + let name = k.or_else(|| mem.clone()).unwrap(); |
| 115 | + map.append( |
| 116 | + http::HeaderName::from_bytes(name.as_str().as_bytes()).expect("already validated"), |
| 117 | + http::HeaderValue::from_bytes(v.as_bytes()).expect("already validated"), |
| 118 | + ); |
| 119 | + mem = Some(name); |
| 120 | + } |
| 121 | + map |
| 122 | +} |
| 123 | + |
| 124 | +#[cfg(test)] |
| 125 | +mod test { |
| 126 | + use std::collections::VecDeque; |
| 127 | + use std::pin::Pin; |
| 128 | + use std::task::{Context, Poll}; |
| 129 | + |
| 130 | + use bytes::Bytes; |
| 131 | + use http::header::{CONTENT_LENGTH as CL0, CONTENT_TYPE as CT0}; |
| 132 | + use http_1x::header::{CONTENT_LENGTH as CL1, CONTENT_TYPE as CT1}; |
| 133 | + use http_1x::{HeaderMap, HeaderName, HeaderValue}; |
| 134 | + use http_body_1_0::Frame; |
| 135 | + |
| 136 | + use crate::body::http_body_1_x::convert_header_map; |
| 137 | + use crate::body::{Error, SdkBody}; |
| 138 | + use crate::byte_stream::ByteStream; |
| 139 | + |
| 140 | + struct TestBody { |
| 141 | + chunks: VecDeque<Chunk>, |
| 142 | + } |
| 143 | + |
| 144 | + enum Chunk { |
| 145 | + Data(&'static str), |
| 146 | + Error(&'static str), |
| 147 | + Trailers(HeaderMap), |
| 148 | + } |
| 149 | + |
| 150 | + impl http_body_1_0::Body for TestBody { |
| 151 | + type Data = Bytes; |
| 152 | + type Error = Error; |
| 153 | + |
| 154 | + fn poll_frame( |
| 155 | + mut self: Pin<&mut Self>, |
| 156 | + _cx: &mut Context<'_>, |
| 157 | + ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { |
| 158 | + let next = self.chunks.pop_front(); |
| 159 | + let mk = |v: Frame<Bytes>| Poll::Ready(Some(Ok(v))); |
| 160 | + |
| 161 | + match next { |
| 162 | + Some(Chunk::Data(s)) => mk(Frame::data(Bytes::from_static(s.as_bytes()))), |
| 163 | + Some(Chunk::Trailers(headers)) => mk(Frame::trailers(headers)), |
| 164 | + Some(Chunk::Error(err)) => Poll::Ready(Some(Err(err.into()))), |
| 165 | + None => Poll::Ready(None), |
| 166 | + } |
| 167 | + } |
| 168 | + } |
| 169 | + |
| 170 | + fn trailers() -> HeaderMap { |
| 171 | + let mut map = HeaderMap::new(); |
| 172 | + map.insert( |
| 173 | + HeaderName::from_static("x-test"), |
| 174 | + HeaderValue::from_static("x-test-value"), |
| 175 | + ); |
| 176 | + map.append( |
| 177 | + HeaderName::from_static("x-test"), |
| 178 | + HeaderValue::from_static("x-test-value-2"), |
| 179 | + ); |
| 180 | + map.append( |
| 181 | + HeaderName::from_static("y-test"), |
| 182 | + HeaderValue::from_static("y-test-value-2"), |
| 183 | + ); |
| 184 | + map |
| 185 | + } |
| 186 | + |
| 187 | + #[tokio::test] |
| 188 | + async fn test_body_with_trailers() { |
| 189 | + let body = TestBody { |
| 190 | + chunks: vec![ |
| 191 | + Chunk::Data("123"), |
| 192 | + Chunk::Data("456"), |
| 193 | + Chunk::Data("789"), |
| 194 | + Chunk::Trailers(trailers()), |
| 195 | + ] |
| 196 | + .into(), |
| 197 | + }; |
| 198 | + let body = SdkBody::from_body_1_x(body); |
| 199 | + let data = ByteStream::new(body); |
| 200 | + assert_eq!(data.collect().await.unwrap().to_vec(), b"123456789"); |
| 201 | + } |
| 202 | + |
| 203 | + #[tokio::test] |
| 204 | + async fn test_read_trailers() { |
| 205 | + let body = TestBody { |
| 206 | + chunks: vec![ |
| 207 | + Chunk::Data("123"), |
| 208 | + Chunk::Data("456"), |
| 209 | + Chunk::Data("789"), |
| 210 | + Chunk::Trailers(trailers()), |
| 211 | + ] |
| 212 | + .into(), |
| 213 | + }; |
| 214 | + let mut body = SdkBody::from_body_1_x(body); |
| 215 | + while let Some(_data) = http_body_0_4::Body::data(&mut body).await {} |
| 216 | + assert_eq!( |
| 217 | + http_body_0_4::Body::trailers(&mut body).await.unwrap(), |
| 218 | + Some(convert_header_map(trailers())) |
| 219 | + ); |
| 220 | + } |
| 221 | + |
| 222 | + #[tokio::test] |
| 223 | + async fn test_errors() { |
| 224 | + let body = TestBody { |
| 225 | + chunks: vec![ |
| 226 | + Chunk::Data("123"), |
| 227 | + Chunk::Data("456"), |
| 228 | + Chunk::Data("789"), |
| 229 | + Chunk::Error("errors!"), |
| 230 | + ] |
| 231 | + .into(), |
| 232 | + }; |
| 233 | + |
| 234 | + let body = SdkBody::from_body_1_x(body); |
| 235 | + let body = ByteStream::new(body); |
| 236 | + body.collect().await.expect_err("body returned an error"); |
| 237 | + } |
| 238 | + #[tokio::test] |
| 239 | + async fn test_no_trailers() { |
| 240 | + let body = TestBody { |
| 241 | + chunks: vec![Chunk::Data("123"), Chunk::Data("456"), Chunk::Data("789")].into(), |
| 242 | + }; |
| 243 | + |
| 244 | + let body = SdkBody::from_body_1_x(body); |
| 245 | + let body = ByteStream::new(body); |
| 246 | + assert_eq!(body.collect().await.unwrap().to_vec(), b"123456789"); |
| 247 | + } |
| 248 | + |
| 249 | + #[test] |
| 250 | + fn test_convert_headers() { |
| 251 | + let mut http1_headermap = http_1x::HeaderMap::new(); |
| 252 | + http1_headermap.append(CT1, HeaderValue::from_static("a")); |
| 253 | + http1_headermap.append(CT1, HeaderValue::from_static("b")); |
| 254 | + http1_headermap.append(CT1, HeaderValue::from_static("c")); |
| 255 | + |
| 256 | + http1_headermap.insert(CL1, HeaderValue::from_static("1234")); |
| 257 | + |
| 258 | + let mut expect = http::HeaderMap::new(); |
| 259 | + expect.append(CT0, http::HeaderValue::from_static("a")); |
| 260 | + expect.append(CT0, http::HeaderValue::from_static("b")); |
| 261 | + expect.append(CT0, http::HeaderValue::from_static("c")); |
| 262 | + |
| 263 | + expect.insert(CL0, http::HeaderValue::from_static("1234")); |
| 264 | + |
| 265 | + assert_eq!(convert_header_map(http1_headermap), expect); |
| 266 | + } |
| 267 | +} |
0 commit comments