|
2 | 2 | import re
|
3 | 3 | import ssl
|
4 | 4 | from dataclasses import dataclass
|
5 |
| -from typing import Dict, List, Optional, Tuple |
| 5 | +from typing import Dict, List, Optional, Tuple, Union |
6 | 6 | from urllib.parse import unquote, urljoin, urlparse
|
7 | 7 |
|
8 | 8 | import structlog
|
@@ -61,6 +61,30 @@ def reconstruct(self) -> bytes:
|
61 | 61 | return result
|
62 | 62 |
|
63 | 63 |
|
| 64 | +@dataclass |
| 65 | +class HttpResponse: |
| 66 | + """Data class to store HTTP response details""" |
| 67 | + |
| 68 | + version: str |
| 69 | + status_code: int |
| 70 | + reason: str |
| 71 | + headers: List[str] |
| 72 | + body: Optional[bytes] = None |
| 73 | + |
| 74 | + def reconstruct(self) -> bytes: |
| 75 | + """Reconstruct HTTP response from stored details""" |
| 76 | + headers = "\r\n".join(self.headers) |
| 77 | + status_line = f"{self.version} {self.status_code} {self.reason}\r\n" |
| 78 | + header_block = f"{status_line}{headers}\r\n\r\n" |
| 79 | + |
| 80 | + # Convert header block to bytes and combine with body |
| 81 | + result = header_block.encode("utf-8") |
| 82 | + if self.body: |
| 83 | + result += self.body |
| 84 | + |
| 85 | + return result |
| 86 | + |
| 87 | + |
64 | 88 | def extract_path(full_path: str) -> str:
|
65 | 89 | """Extract clean path from full URL or path string"""
|
66 | 90 | logger.debug(f"Extracting path from {full_path}")
|
@@ -145,7 +169,7 @@ async def _body_through_pipeline(
|
145 | 169 | ) -> Tuple[bytes, PipelineContext]:
|
146 | 170 | logger.debug(f"Processing body through pipeline: {len(body)} bytes")
|
147 | 171 | strategy = self._select_pipeline(method, path)
|
148 |
| - if strategy is None: |
| 172 | + if len(body) == 0 or strategy is None: |
149 | 173 | # if we didn't select any strategy that would change the request
|
150 | 174 | # let's just pass through the body as-is
|
151 | 175 | return body, None
|
@@ -243,35 +267,87 @@ def _check_buffer_size(self, new_data: bytes) -> bool:
|
243 | 267 | """Check if adding new data would exceed buffer size limit"""
|
244 | 268 | return len(self.buffer) + len(new_data) <= MAX_BUFFER_SIZE
|
245 | 269 |
|
246 |
| - async def _forward_data_through_pipeline(self, data: bytes) -> bytes: |
| 270 | + async def _forward_data_through_pipeline( |
| 271 | + self, data: bytes |
| 272 | + ) -> Union[HttpRequest, List[HttpResponse]]: |
247 | 273 | http_request = http_request_from_bytes(data)
|
248 | 274 | if not http_request:
|
249 | 275 | # we couldn't parse this into an HTTP request, so we just pass through
|
250 | 276 | return data
|
251 | 277 |
|
252 |
| - http_request.body, context = await self._body_through_pipeline( |
| 278 | + body, context = await self._body_through_pipeline( |
253 | 279 | http_request.method,
|
254 | 280 | http_request.path,
|
255 | 281 | http_request.headers,
|
256 | 282 | http_request.body,
|
257 | 283 | )
|
258 | 284 | self.context_tracking = context
|
259 | 285 |
|
260 |
| - for header in http_request.headers: |
261 |
| - if header.lower().startswith("content-length:"): |
262 |
| - http_request.headers.remove(header) |
263 |
| - break |
264 |
| - http_request.headers.append(f"Content-Length: {len(http_request.body)}") |
| 286 | + if context and context.shortcut_response: |
| 287 | + # Send shortcut response |
| 288 | + data_prefix = b'data:' |
| 289 | + http_response = HttpResponse( |
| 290 | + http_request.version, |
| 291 | + 200, |
| 292 | + "OK", |
| 293 | + [ |
| 294 | + "server: uvicorn", |
| 295 | + "cache-control: no-cache", |
| 296 | + "connection: keep-alive", |
| 297 | + "Content-Type: application/json", |
| 298 | + "Transfer-Encoding: chunked", |
| 299 | + ], |
| 300 | + data_prefix + body |
| 301 | + ) |
| 302 | + return http_response |
265 | 303 |
|
266 |
| - pipeline_data = http_request.reconstruct() |
| 304 | + else: |
| 305 | + # Forward request to target |
| 306 | + http_request.body = body |
| 307 | + |
| 308 | + for header in http_request.headers: |
| 309 | + if header.lower().startswith("content-length:"): |
| 310 | + http_request.headers.remove(header) |
| 311 | + break |
| 312 | + http_request.headers.append(f"Content-Length: {len(http_request.body)}") |
267 | 313 |
|
268 |
| - return pipeline_data |
| 314 | + return http_request |
269 | 315 |
|
270 | 316 | async def _forward_data_to_target(self, data: bytes) -> None:
|
271 |
| - """Forward data to target if connection is established""" |
272 |
| - if self.target_transport and not self.target_transport.is_closing(): |
273 |
| - data = await self._forward_data_through_pipeline(data) |
274 |
| - self.target_transport.write(data) |
| 317 | + """ |
| 318 | + Forward data to target if connection is established. In case of shortcut |
| 319 | + response, send a response to the client |
| 320 | + """ |
| 321 | + pipeline_output = await self._forward_data_through_pipeline(data) |
| 322 | + |
| 323 | + if isinstance(pipeline_output, HttpResponse): |
| 324 | + # We need to send shortcut response |
| 325 | + if self.transport and not self.transport.is_closing(): |
| 326 | + # First, close target_transport since we don't need to send any |
| 327 | + # request to the target |
| 328 | + self.target_transport.close() |
| 329 | + |
| 330 | + # Send the shortcut response data in a chunk |
| 331 | + chunk = pipeline_output.reconstruct() |
| 332 | + chunk_size = hex(len(chunk))[2:] + "\r\n" |
| 333 | + self.transport.write(chunk_size.encode()) |
| 334 | + self.transport.write(chunk) |
| 335 | + self.transport.write(b"\r\n") |
| 336 | + |
| 337 | + # Send data done chunk |
| 338 | + chunk = b"data: [DONE]\n\n" |
| 339 | + # Add chunk size for DONE message |
| 340 | + chunk_size = hex(len(chunk))[2:] + "\r\n" |
| 341 | + self.transport.write(chunk_size.encode()) |
| 342 | + self.transport.write(chunk) |
| 343 | + self.transport.write(b"\r\n") |
| 344 | + # Now send the final chunk with 0 |
| 345 | + self.transport.write(b"0\r\n\r\n") |
| 346 | + else: |
| 347 | + if self.target_transport and not self.target_transport.is_closing(): |
| 348 | + if isinstance(pipeline_output, HttpRequest): |
| 349 | + pipeline_output = pipeline_output.reconstruct() |
| 350 | + self.target_transport.write(pipeline_output) |
275 | 351 |
|
276 | 352 | def data_received(self, data: bytes) -> None:
|
277 | 353 | """Handle received data from client"""
|
|
0 commit comments