This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 36e7fb3

Deconstruct the streaming reply into chunks and send them back individually
This will allow us to use an output pipeline in Copilot
Parent: 51e42b5

3 files changed: +126 −8 lines
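For context on what the commit deconstructs: the upstream reply is a Server-Sent Events stream wrapped in HTTP chunked transfer encoding, so each frame is a hex length line, the SSE payload, and a trailing CRLF, with a zero-length chunk closing the stream. A minimal sketch of that framing, matching what the new `_process_chunk` below re-emits (the JSON payload is illustrative, not a real Copilot response):

```python
# Illustrative framing of one SSE record inside a chunked HTTP body.
payload = b'data: {"choices": [{"delta": {"content": "hi"}}]}\n\n'
frame = hex(len(payload))[2:].encode() + b"\r\n" + payload + b"\r\n"
terminator = b"0\r\n\r\n"  # zero-length chunk that ends the stream
```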

src/codegate/providers/copilot/pipeline.py (4 additions, 3 deletions)

```diff
@@ -5,6 +5,7 @@
 import structlog
 from litellm.types.llms.openai import ChatCompletionRequest
 
+from codegate.pipeline.base import PipelineContext
 from codegate.providers.normalizer.completion import CompletionNormalizer
 
 logger = structlog.get_logger("codegate")
@@ -62,7 +63,7 @@ def _get_copilot_headers(headers: Dict[str, str]) -> Dict[str, str]:
 
         return copilot_headers
 
-    async def process_body(self, headers: list[str], body: bytes) -> bytes:
+    async def process_body(self, headers: list[str], body: bytes) -> (bytes, PipelineContext):
         """Common processing logic for all strategies"""
         try:
             normalized_body = self.normalizer.normalize(body)
@@ -92,10 +93,10 @@ async def process_body(self, headers: list[str], body: bytes) -> bytes:
             body = self.normalizer.denormalize(result.request)
             logger.info(f"Pipeline processed request: {body}")
 
-            return body
+            return body, result.context
         except Exception as e:
             logger.error(f"Pipeline processing error: {e}")
-            return body
+            return body, None
 
 
 class CopilotFimNormalizer:
```
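The tuple return means every caller now unpacks `(body, context)` and keeps the context only when a pipeline actually ran; on error or pass-through the context is `None`. A minimal sketch of the calling pattern (condensed from the provider changes below):

```python
# `strategy` stands in for the CopilotPipeline selected for this request.
body, context = await strategy.process_body(headers, body)
if context is not None:
    # Hold onto the request context so the response side can feed
    # the same state into the output pipeline.
    self.context_tracking = context
```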

src/codegate/providers/copilot/provider.py (67 additions, 5 deletions)

```diff
@@ -1,4 +1,5 @@
 import asyncio
+import json
 import re
 import ssl
 from dataclasses import dataclass
@@ -11,13 +12,15 @@
 from codegate.config import Config
 from codegate.pipeline.base import PipelineContext
 from codegate.pipeline.factory import PipelineFactory
+from codegate.pipeline.output import OutputPipelineInstance
 from codegate.pipeline.secrets.manager import SecretsManager
 from codegate.providers.copilot.mapping import VALIDATED_ROUTES
 from codegate.providers.copilot.pipeline import (
     CopilotChatPipeline,
     CopilotFimPipeline,
     CopilotPipeline,
 )
+from codegate.providers.copilot.streaming import SSEProcessor
 
 logger = structlog.get_logger("codegate")
 
@@ -139,13 +142,13 @@ async def _body_through_pipeline(
         path: str,
         headers: list[str],
         body: bytes,
-    ) -> bytes:
+    ) -> (bytes, PipelineContext):
         logger.debug(f"Processing body through pipeline: {len(body)} bytes")
         strategy = self._select_pipeline(method, path)
         if strategy is None:
             # if we didn't select any strategy that would change the request
             # let's just pass through the body as-is
-            return body
+            return body, None
         return await strategy.process_body(headers, body)
 
     async def _request_to_target(self, headers: list[str], body: bytes):
@@ -154,13 +157,16 @@ async def _request_to_target(self, headers: list[str], body: bytes):
         ).encode()
         logger.debug(f"Request Line: {request_line}")
 
-        body = await self._body_through_pipeline(
+        body, context = await self._body_through_pipeline(
            self.request.method,
            self.request.path,
            headers,
            body,
         )
 
+        if context:
+            self.context_tracking = context
+
         for header in headers:
             if header.lower().startswith("content-length:"):
                 headers.remove(header)
@@ -243,12 +249,13 @@ async def _forward_data_through_pipeline(self, data: bytes) -> bytes:
             # we couldn't parse this into an HTTP request, so we just pass through
             return data
 
-        http_request.body = await self._body_through_pipeline(
+        http_request.body, context = await self._body_through_pipeline(
             http_request.method,
             http_request.path,
             http_request.headers,
             http_request.body,
         )
+        self.context_tracking = context
 
         for header in http_request.headers:
             if header.lower().startswith("content-length:"):
@@ -549,15 +556,68 @@ def __init__(self, proxy: CopilotProvider):
         self.proxy = proxy
         self.transport: Optional[asyncio.Transport] = None
 
+        self.headers_sent = False
+        self.sse_processor: Optional[SSEProcessor] = None
+        self.output_pipeline_instance: Optional[OutputPipelineInstance] = None
+
     def connection_made(self, transport: asyncio.Transport) -> None:
         """Handle successful connection to target"""
         self.transport = transport
         self.proxy.target_transport = transport
 
+    def _process_chunk(self, chunk: bytes):
+        records = self.sse_processor.process_chunk(chunk)
+
+        for record in records:
+            if record["type"] == "done":
+                sse_data = b"data: [DONE]\n\n"
+                # Add chunk size for DONE message too
+                chunk_size = hex(len(sse_data))[2:] + "\r\n"
+                self._proxy_transport_write(chunk_size.encode())
+                self._proxy_transport_write(sse_data)
+                self._proxy_transport_write(b"\r\n")
+                # Now send the final zero chunk
+                self._proxy_transport_write(b"0\r\n\r\n")
+            else:
+                sse_data = f"data: {json.dumps(record['content'])}\n\n".encode("utf-8")
+                chunk_size = hex(len(sse_data))[2:] + "\r\n"
+                self._proxy_transport_write(chunk_size.encode())
+                self._proxy_transport_write(sse_data)
+                self._proxy_transport_write(b"\r\n")
+
+    def _proxy_transport_write(self, data: bytes):
+        self.proxy.transport.write(data)
+
     def data_received(self, data: bytes) -> None:
         """Handle data received from target"""
+        if self.proxy.context_tracking is not None and self.sse_processor is None:
+            logger.debug("Tracking context for pipeline processing")
+            self.sse_processor = SSEProcessor()
+            out_pipeline_processor = self.proxy.pipeline_factory.create_output_pipeline()
+            self.output_pipeline_instance = OutputPipelineInstance(
+                pipeline_steps=out_pipeline_processor.pipeline_steps,
+                input_context=self.proxy.context_tracking,
+            )
+
         if self.proxy.transport and not self.proxy.transport.is_closing():
-            self.proxy.transport.write(data)
+            if not self.sse_processor:
+                # Pass through non-SSE data unchanged
+                self.proxy.transport.write(data)
+                return
+
+            # Check if this is the first chunk with headers
+            if not self.headers_sent:
+                header_end = data.find(b"\r\n\r\n")
+                if header_end != -1:
+                    self.headers_sent = True
+                    # Send headers first
+                    headers = data[: header_end + 4]
+                    self._proxy_transport_write(headers)
+                    logger.debug(f"Headers sent: {headers}")
+
+                    data = data[header_end + 4 :]
+
+            self._process_chunk(data)
 
     def connection_lost(self, exc: Optional[Exception]) -> None:
         """Handle connection loss to target"""
@@ -570,3 +630,5 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
                 self.proxy.transport.close()
             except Exception as e:
                 logger.error(f"Error closing proxy transport: {e}")
+
+        # todo: clear the context to erase the sensitive data
```
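One subtlety in `data_received`: the first read from the target can carry the HTTP status line and headers in the same buffer as the first body bytes, so the data is split at the `\r\n\r\n` boundary and only the remainder is handed to the SSE processor. A small illustration with synthetic data:

```python
# Synthetic first read: response headers plus the start of the body.
data = (
    b"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\n\r\n"
    b'17\r\ndata: {"choices": []}\n\n\r\n'
)
header_end = data.find(b"\r\n\r\n")
headers, body = data[: header_end + 4], data[header_end + 4 :]
# headers are forwarded verbatim; only `body` goes through _process_chunk
```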
src/codegate/providers/copilot/streaming.py (new file, 55 additions)

```diff
@@ -0,0 +1,55 @@
+import json
+
+import structlog
+
+logger = structlog.get_logger("codegate")
+
+
+class SSEProcessor:
+    def __init__(self):
+        self.buffer = ""
+        self.initial_chunk = True
+        self.chunk_size = None  # Store the original chunk size
+        self.size_written = False
+
+    def process_chunk(self, chunk: bytes) -> list:
+        print("BUFFER AT START")
+        print(self.buffer)
+        print("BUFFER AT START - END")
+        # Skip any chunk size lines (hex number followed by \r\n)
+        try:
+            chunk_str = chunk.decode("utf-8")
+            lines = chunk_str.split("\r\n")
+            for line in lines:
+                if all(c in "0123456789abcdefABCDEF" for c in line.strip()):
+                    continue
+                self.buffer += line
+        except UnicodeDecodeError:
+            print("Failed to decode chunk")
+
+        records = []
+        while True:
+            record_end = self.buffer.find("\n\n")
+            if record_end == -1:
+                print(f"REMAINING BUFFER {self.buffer}")
+                break
+
+            record = self.buffer[:record_end]
+            self.buffer = self.buffer[record_end + 2 :]
+
+            if record.startswith("data: "):
+                data_content = record[6:]
+                if data_content.strip() == "[DONE]":
+                    records.append({"type": "done"})
+                else:
+                    try:
+                        data = json.loads(data_content)
+                        records.append({"type": "data", "content": data})
+                    except json.JSONDecodeError:
+                        print(f"Failed to parse JSON: {data_content}")
+
+        return records
+
+    def get_pending(self):
+        """Return any pending incomplete data in the buffer"""
+        return self.buffer
```
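A quick usage sketch of `SSEProcessor` with synthetic chunked input (payloads are illustrative): hex size lines from the chunked framing are filtered out, and records are split on the blank line that ends each SSE event.

```python
proc = SSEProcessor()
records = proc.process_chunk(b'17\r\ndata: {"choices": []}\n\n\r\n')
# -> [{"type": "data", "content": {"choices": []}}]
records += proc.process_chunk(b"e\r\ndata: [DONE]\n\n\r\n")
# -> [..., {"type": "done"}]
assert proc.get_pending() == ""  # nothing left buffered
```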
