from codegate.pipeline.factory import PipelineFactory
from codegate.pipeline.secrets.manager import SecretsManager
from codegate.providers.copilot.mapping import VALIDATED_ROUTES
-from codegate.providers.copilot.pipeline import CopilotFimPipeline
+from codegate.providers.copilot.pipeline import (
+    CopilotChatPipeline,
+    CopilotFimPipeline,
+    CopilotPipeline,
+)

logger = structlog.get_logger("codegate")

@@ -38,6 +42,61 @@ class HttpRequest:
    headers: List[str]
    original_path: str
    target: Optional[str] = None
+    body: Optional[bytes] = None
+
+    def reconstruct(self) -> bytes:
+        """Reconstruct HTTP request from stored details"""
+        headers = "\r\n".join(self.headers)
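+        # path is stored without its leading "/" (see extract_path), so it is re-added here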
+        request_line = f"{self.method} /{self.path} {self.version}\r\n"
+        header_block = f"{request_line}{headers}\r\n\r\n"
+
+        # Convert header block to bytes and combine with body
+        result = header_block.encode("utf-8")
+        if self.body:
+            result += self.body
+
+        return result
+
+
+def extract_path(full_path: str) -> str:
+    """Extract clean path from full URL or path string"""
+    logger.debug(f"Extracting path from {full_path}")
+    if full_path.startswith(("http://", "https://")):
+        parsed = urlparse(full_path)
+        path = parsed.path
+        if parsed.query:
+            path = f"{path}?{parsed.query}"
+        return path.lstrip("/")
+    return full_path.lstrip("/")
+
+
+def http_request_from_bytes(data: bytes) -> Optional[HttpRequest]:
+    """
+    Parse HTTP request details from raw bytes data.
+    TODO: Make this safer by checking for a valid HTTP request format: check
+    that there is a method, that there are headers, etc.
+    """
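+    # A complete header block is required before we attempt to parse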
+ if b"\r \n \r \n " not in data :
80
+ return None
81
+
82
+ headers_end = data .index (b"\r \n \r \n " )
83
+ headers = data [:headers_end ].split (b"\r \n " )
84
+
85
+ request = headers [0 ].decode ("utf-8" )
86
+ method , full_path , version = request .split (" " )
87
+
88
+ body_start = data .index (b"\r \n \r \n " ) + 4
89
+ body = data [body_start :]
90
+
91
+ return HttpRequest (
92
+ method = method ,
93
+ path = extract_path (full_path ),
94
+ version = version ,
95
+ headers = [header .decode ("utf-8" ) for header in headers [1 :]],
96
+ original_path = full_path ,
97
+ target = full_path if method == "CONNECT" else None ,
98
+ body = body ,
99
+ )
41
100
42
101
43
102
class CopilotProvider (asyncio .Protocol ):
@@ -63,20 +122,26 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
        self.pipeline_factory = PipelineFactory(SecretsManager())
        self.context_tracking: Optional[PipelineContext] = None

-    def _select_pipeline(self):
-        if (
-            self.request.method == "POST"
-            and self.request.path == "v1/engines/copilot-codex/completions"
-        ):
+    def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]:
+        if method == "POST" and path == "v1/engines/copilot-codex/completions":
            logger.debug("Selected CopilotFimStrategy")
            return CopilotFimPipeline(self.pipeline_factory)
+        if method == "POST" and path == "chat/completions":
+            logger.debug("Selected CopilotChatStrategy")
+            return CopilotChatPipeline(self.pipeline_factory)

        logger.debug("No pipeline strategy selected")
        return None

-    async def _body_through_pipeline(self, headers: list[str], body: bytes) -> bytes:
+    async def _body_through_pipeline(
+        self,
+        method: str,
+        path: str,
+        headers: list[str],
+        body: bytes,
+    ) -> bytes:
        logger.debug(f"Processing body through pipeline: {len(body)} bytes")
-        strategy = self._select_pipeline()
+        strategy = self._select_pipeline(method, path)
        if strategy is None:
            # if we didn't select any strategy that would change the request
            # let's just pass through the body as-is
@@ -89,7 +154,12 @@ async def _request_to_target(self, headers: list[str], body: bytes):
        ).encode()
        logger.debug(f"Request Line: {request_line}")

-        body = await self._body_through_pipeline(headers, body)
+        body = await self._body_through_pipeline(
+            self.request.method,
+            self.request.path,
+            headers,
+            body,
+        )

        for header in headers:
            if header.lower().startswith("content-length:"):
@@ -113,18 +183,6 @@ def connection_made(self, transport: asyncio.Transport) -> None:
        self.peername = transport.get_extra_info("peername")
        logger.debug(f"Client connected from {self.peername}")

-    @staticmethod
-    def extract_path(full_path: str) -> str:
-        """Extract clean path from full URL or path string"""
-        logger.debug(f"Extracting path from {full_path}")
-        if full_path.startswith(("http://", "https://")):
-            parsed = urlparse(full_path)
-            path = parsed.path
-            if parsed.query:
-                path = f"{path}?{parsed.query}"
-            return path.lstrip("/")
-        return full_path.lstrip("/")
-
    def get_headers_dict(self) -> Dict[str, str]:
        """Convert raw headers to dictionary format"""
        headers_dict = {}
@@ -161,7 +219,7 @@ def parse_headers(self) -> bool:

        self.request = HttpRequest(
            method=method,
-            path=self.extract_path(full_path),
+            path=extract_path(full_path),
            version=version,
            headers=[header.decode("utf-8") for header in headers[1:]],
            original_path=full_path,
@@ -179,9 +237,33 @@ def _check_buffer_size(self, new_data: bytes) -> bool:
        """Check if adding new data would exceed buffer size limit"""
        return len(self.buffer) + len(new_data) <= MAX_BUFFER_SIZE

-    def _forward_data_to_target(self, data: bytes) -> None:
+    async def _forward_data_through_pipeline(self, data: bytes) -> bytes:
+        http_request = http_request_from_bytes(data)
+        if not http_request:
+            # we couldn't parse this into an HTTP request, so we just pass through
+            return data
+
+        http_request.body = await self._body_through_pipeline(
+            http_request.method,
+            http_request.path,
+            http_request.headers,
+            http_request.body,
+        )
+
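+        # Replace any stale Content-Length header so it matches the pipeline-modified body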
+        for header in http_request.headers:
+            if header.lower().startswith("content-length:"):
+                http_request.headers.remove(header)
+                break
+        http_request.headers.append(f"Content-Length: {len(http_request.body)}")
+
+        pipeline_data = http_request.reconstruct()
+
+        return pipeline_data
+
+    async def _forward_data_to_target(self, data: bytes) -> None:
        """Forward data to target if connection is established"""
        if self.target_transport and not self.target_transport.is_closing():
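+            # Run the buffered request through the pipeline before forwarding it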
+            data = await self._forward_data_through_pipeline(data)
            self.target_transport.write(data)

    def data_received(self, data: bytes) -> None:
@@ -201,7 +283,7 @@ def data_received(self, data: bytes) -> None:
                else:
                    asyncio.create_task(self.handle_http_request())
            else:
-                self._forward_data_to_target(data)
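+                # _forward_data_to_target is now a coroutine, so schedule it as a task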
+                asyncio.create_task(self._forward_data_to_target(data))

        except Exception as e:
            logger.error(f"Error processing received data: {e}")