@@ -1,4 +1,5 @@
 import asyncio
+import json
 import re
 import ssl
 from dataclasses import dataclass
@@ -11,12 +12,14 @@
 from codegate.config import Config
 from codegate.pipeline.base import PipelineContext
 from codegate.pipeline.factory import PipelineFactory
+from codegate.pipeline.output import OutputPipelineInstance
 from codegate.pipeline.secrets.manager import SecretsManager
 from codegate.providers.copilot.mapping import VALIDATED_ROUTES
 from codegate.providers.copilot.pipeline import (
     CopilotChatPipeline,
     CopilotFimPipeline,
     CopilotPipeline,
 )
+from codegate.providers.copilot.streaming import SSEProcessor
 
 logger = structlog.get_logger("codegate")
@@ -139,13 +142,13 @@ async def _body_through_pipeline(
         path: str,
         headers: list[str],
         body: bytes,
-    ) -> bytes:
+    ) -> (bytes, PipelineContext):
         logger.debug(f"Processing body through pipeline: {len(body)} bytes")
         strategy = self._select_pipeline(method, path)
         if strategy is None:
             # if we didn't select any strategy that would change the request
             # let's just pass through the body as-is
-            return body
+            return body, None
         return await strategy.process_body(headers, body)
 
     async def _request_to_target(self, headers: list[str], body: bytes):
@@ -154,13 +157,16 @@ async def _request_to_target(self, headers: list[str], body: bytes):
         ).encode()
         logger.debug(f"Request Line: {request_line}")
 
-        body = await self._body_through_pipeline(
+        body, context = await self._body_through_pipeline(
            self.request.method,
            self.request.path,
            headers,
            body,
        )
 
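+        # Keep the pipeline context so the response handler can run the output pipeline on it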
+        if context:
+            self.context_tracking = context
+
         for header in headers:
             if header.lower().startswith("content-length:"):
                 headers.remove(header)
@@ -243,12 +249,13 @@ async def _forward_data_through_pipeline(self, data: bytes) -> bytes:
             # we couldn't parse this into an HTTP request, so we just pass through
             return data
 
-        http_request.body = await self._body_through_pipeline(
+        http_request.body, context = await self._body_through_pipeline(
             http_request.method,
             http_request.path,
             http_request.headers,
             http_request.body,
         )
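+        # Track the context (possibly None) so the response side can pick it up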
+        self.context_tracking = context
 
         for header in http_request.headers:
             if header.lower().startswith("content-length:"):
@@ -549,15 +556,68 @@ def __init__(self, proxy: CopilotProvider):
         self.proxy = proxy
         self.transport: Optional[asyncio.Transport] = None
 
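+        # Streaming state; the SSE processor and output pipeline are created lazily in data_received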
+        self.headers_sent = False
+        self.sse_processor: Optional[SSEProcessor] = None
+        self.output_pipeline_instance: Optional[OutputPipelineInstance] = None
+
     def connection_made(self, transport: asyncio.Transport) -> None:
         """Handle successful connection to target"""
         self.transport = transport
         self.proxy.target_transport = transport
 
+    def _process_chunk(self, chunk: bytes):
+        records = self.sse_processor.process_chunk(chunk)
+
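+        # Re-frame each record for HTTP/1.1 chunked transfer encoding:
+        # a hex chunk-size line, CRLF, the SSE payload, then a trailing CRLF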
+        for record in records:
+            if record["type"] == "done":
+                sse_data = b"data: [DONE]\n\n"
+                # Add chunk size for DONE message too
+                chunk_size = hex(len(sse_data))[2:] + "\r\n"
+                self._proxy_transport_write(chunk_size.encode())
+                self._proxy_transport_write(sse_data)
+                self._proxy_transport_write(b"\r\n")
+                # Now send the final zero chunk
+                self._proxy_transport_write(b"0\r\n\r\n")
+            else:
+                sse_data = f"data: {json.dumps(record['content'])}\n\n".encode("utf-8")
+                chunk_size = hex(len(sse_data))[2:] + "\r\n"
+                self._proxy_transport_write(chunk_size.encode())
+                self._proxy_transport_write(sse_data)
+                self._proxy_transport_write(b"\r\n")
+
+    def _proxy_transport_write(self, data: bytes):
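+        # Processed output goes back to the client over the proxy's transport, not the target's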
+        self.proxy.transport.write(data)
+
     def data_received(self, data: bytes) -> None:
         """Handle data received from target"""
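+        # Set up streaming state the first time we see a response with a pipeline context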
+        if self.proxy.context_tracking is not None and self.sse_processor is None:
+            logger.debug("Tracking context for pipeline processing")
+            self.sse_processor = SSEProcessor()
+            out_pipeline_processor = self.proxy.pipeline_factory.create_output_pipeline()
+            self.output_pipeline_instance = OutputPipelineInstance(
+                pipeline_steps=out_pipeline_processor.pipeline_steps,
+                input_context=self.proxy.context_tracking,
+            )
+
         if self.proxy.transport and not self.proxy.transport.is_closing():
-            self.proxy.transport.write(data)
+            if not self.sse_processor:
+                # Pass through non-SSE data unchanged
+                self.proxy.transport.write(data)
+                return
+
+            # Check if this is the first chunk with headers
+            if not self.headers_sent:
+                header_end = data.find(b"\r\n\r\n")
+                if header_end != -1:
+                    self.headers_sent = True
+                    # Send headers first
+                    headers = data[: header_end + 4]
+                    self._proxy_transport_write(headers)
+                    logger.debug(f"Headers sent: {headers}")
+
+                    data = data[header_end + 4 :]
+
+            self._process_chunk(data)
 
     def connection_lost(self, exc: Optional[Exception]) -> None:
         """Handle connection loss to target"""
@@ -570,3 +630,5 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
             self.proxy.transport.close()
         except Exception as e:
             logger.error(f"Error closing proxy transport: {e}")
+
+        # TODO: clear the context to erase the sensitive data