
Commit aa4208a (1 parent: 831617b)

Copilot chats are sent through an input pipeline

File tree

2 files changed: +135 additions, -25 deletions

src/codegate/providers/copilot/pipeline.py

Lines changed: 29 additions & 1 deletion

@@ -86,14 +86,42 @@ def denormalize(self, request_from_pipeline: ChatCompletionRequest) -> bytes:
         return json.dumps(normalized_json_body).encode()
 
 
+class CopilotChatNormalizer:
+    """
+    A custom normalizer for the chat format used by Copilot
+    The requests are already in the OpenAI format, we just need
+    to unmarshall them and marshall them back.
+    """
+
+    def normalize(self, body: bytes) -> ChatCompletionRequest:
+        json_body = json.loads(body)
+        return ChatCompletionRequest(**json_body)
+
+    def denormalize(self, request_from_pipeline: ChatCompletionRequest) -> bytes:
+        return json.dumps(request_from_pipeline).encode()
+
+
 class CopilotFimPipeline(CopilotPipeline):
     """
     A pipeline for the FIM format used by Copilot. Combines the normalizer for the FIM
     format and the FIM pipeline used by all providers.
     """
 
     def _create_normalizer(self):
-        return CopilotFimNormalizer()  # Uses your custom normalizer
+        return CopilotFimNormalizer()
 
     def create_pipeline(self):
         return self.pipeline_factory.create_fim_pipeline()
+
+
+class CopilotChatPipeline(CopilotPipeline):
+    """
+    A pipeline for the chat format used by Copilot. Combines the normalizer for the chat
+    format and the input pipeline used by all providers.
+    """
+
+    def _create_normalizer(self):
+        return CopilotChatNormalizer()
+
+    def create_pipeline(self):
+        return self.pipeline_factory.create_input_pipeline()
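
For context, the new normalizer is deliberately a thin JSON round-trip. A minimal usage sketch (the sample payload and model name are invented for illustration; only CopilotChatNormalizer itself comes from this commit):

```python
import json

from codegate.providers.copilot.pipeline import CopilotChatNormalizer

# Hypothetical Copilot chat body; any OpenAI-style chat completion payload works.
raw_body = json.dumps(
    {
        "model": "gpt-4",
        "messages": [{"role": "user", "content": "Explain this function"}],
    }
).encode()

normalizer = CopilotChatNormalizer()

# normalize() just unmarshals the JSON into a ChatCompletionRequest...
request = normalizer.normalize(raw_body)

# ...and denormalize() marshals it straight back to bytes for the upstream request.
assert json.loads(normalizer.denormalize(request)) == json.loads(raw_body)
```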

src/codegate/providers/copilot/provider.py

Lines changed: 106 additions & 24 deletions

@@ -13,7 +13,11 @@
 from codegate.pipeline.factory import PipelineFactory
 from codegate.pipeline.secrets.manager import SecretsManager
 from codegate.providers.copilot.mapping import VALIDATED_ROUTES
-from codegate.providers.copilot.pipeline import CopilotFimPipeline
+from codegate.providers.copilot.pipeline import (
+    CopilotChatPipeline,
+    CopilotFimPipeline,
+    CopilotPipeline,
+)
 
 logger = structlog.get_logger("codegate")
 
@@ -38,6 +42,61 @@ class HttpRequest:
     headers: List[str]
     original_path: str
     target: Optional[str] = None
+    body: Optional[bytes] = None
+
+    def reconstruct(self) -> bytes:
+        """Reconstruct HTTP request from stored details"""
+        headers = "\r\n".join(self.headers)
+        request_line = f"{self.method} /{self.path} {self.version}\r\n"
+        header_block = f"{request_line}{headers}\r\n\r\n"
+
+        # Convert header block to bytes and combine with body
+        result = header_block.encode("utf-8")
+        if self.body:
+            result += self.body
+
+        return result
+
+
+def extract_path(full_path: str) -> str:
+    """Extract clean path from full URL or path string"""
+    logger.debug(f"Extracting path from {full_path}")
+    if full_path.startswith(("http://", "https://")):
+        parsed = urlparse(full_path)
+        path = parsed.path
+        if parsed.query:
+            path = f"{path}?{parsed.query}"
+        return path.lstrip("/")
+    return full_path.lstrip("/")
+
+
+def http_request_from_bytes(data: bytes) -> Optional[HttpRequest]:
+    """
+    Parse HTTP request details from raw bytes data.
+    TODO: Make safer by checking for valid HTTP request format, check
+    if there is a method if there are headers, etc.
+    """
+    if b"\r\n\r\n" not in data:
+        return None
+
+    headers_end = data.index(b"\r\n\r\n")
+    headers = data[:headers_end].split(b"\r\n")
+
+    request = headers[0].decode("utf-8")
+    method, full_path, version = request.split(" ")
+
+    body_start = data.index(b"\r\n\r\n") + 4
+    body = data[body_start:]
+
+    return HttpRequest(
+        method=method,
+        path=extract_path(full_path),
+        version=version,
+        headers=[header.decode("utf-8") for header in headers[1:]],
+        original_path=full_path,
+        target=full_path if method == "CONNECT" else None,
+        body=body,
+    )
 
 
 class CopilotProvider(asyncio.Protocol):
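
A quick sketch of the parse/reconstruct round-trip these helpers provide (the request below is hand-written for illustration; the helpers are the ones added above):

```python
from codegate.providers.copilot.provider import http_request_from_bytes

# A hand-crafted POST request, just for illustration.
raw = (
    b"POST /chat/completions HTTP/1.1\r\n"
    b"Host: api.example.com\r\n"
    b"Content-Length: 2\r\n"
    b"\r\n"
    b"{}"
)

req = http_request_from_bytes(raw)
assert req is not None
assert req.method == "POST"
assert req.path == "chat/completions"  # extract_path() strips the leading slash
assert req.body == b"{}"

# reconstruct() rebuilds the wire format from the stored pieces; the request
# line re-adds the leading slash, so this particular round-trip is exact.
assert req.reconstruct() == raw
```
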
@@ -63,20 +122,26 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
         self.pipeline_factory = PipelineFactory(SecretsManager())
         self.context_tracking: Optional[PipelineContext] = None
 
-    def _select_pipeline(self):
-        if (
-            self.request.method == "POST"
-            and self.request.path == "v1/engines/copilot-codex/completions"
-        ):
+    def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]:
+        if method == "POST" and path == "v1/engines/copilot-codex/completions":
             logger.debug("Selected CopilotFimStrategy")
             return CopilotFimPipeline(self.pipeline_factory)
+        if method == "POST" and path == "chat/completions":
+            logger.debug("Selected CopilotChatStrategy")
+            return CopilotChatPipeline(self.pipeline_factory)
 
         logger.debug("No pipeline strategy selected")
         return None
 
-    async def _body_through_pipeline(self, headers: list[str], body: bytes) -> bytes:
+    async def _body_through_pipeline(
+        self,
+        method: str,
+        path: str,
+        headers: list[str],
+        body: bytes,
+    ) -> bytes:
         logger.debug(f"Processing body through pipeline: {len(body)} bytes")
-        strategy = self._select_pipeline()
+        strategy = self._select_pipeline(method, path)
         if strategy is None:
             # if we didn't select any strategy that would change the request
             # let's just pass through the body as-is
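
For quick reference, the routing rule _select_pipeline now encodes, restated as data (this table is editorial, not code from the repo; paths are shown after extract_path() has stripped the leading slash):

```python
# (method, path) pairs recognized by _select_pipeline; anything else falls
# through and the body is forwarded unchanged.
ROUTES = {
    ("POST", "v1/engines/copilot-codex/completions"): "CopilotFimPipeline",
    ("POST", "chat/completions"): "CopilotChatPipeline",
}
```
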
@@ -89,7 +154,12 @@ async def _request_to_target(self, headers: list[str], body: bytes):
         ).encode()
         logger.debug(f"Request Line: {request_line}")
 
-        body = await self._body_through_pipeline(headers, body)
+        body = await self._body_through_pipeline(
+            self.request.method,
+            self.request.path,
+            headers,
+            body,
+        )
 
         for header in headers:
             if header.lower().startswith("content-length:"):
@@ -113,18 +183,6 @@ def connection_made(self, transport: asyncio.Transport) -> None:
         self.peername = transport.get_extra_info("peername")
         logger.debug(f"Client connected from {self.peername}")
 
-    @staticmethod
-    def extract_path(full_path: str) -> str:
-        """Extract clean path from full URL or path string"""
-        logger.debug(f"Extracting path from {full_path}")
-        if full_path.startswith(("http://", "https://")):
-            parsed = urlparse(full_path)
-            path = parsed.path
-            if parsed.query:
-                path = f"{path}?{parsed.query}"
-            return path.lstrip("/")
-        return full_path.lstrip("/")
-
     def get_headers_dict(self) -> Dict[str, str]:
         """Convert raw headers to dictionary format"""
         headers_dict = {}
@@ -161,7 +219,7 @@ def parse_headers(self) -> bool:
 
         self.request = HttpRequest(
             method=method,
-            path=self.extract_path(full_path),
+            path=extract_path(full_path),
             version=version,
             headers=[header.decode("utf-8") for header in headers[1:]],
             original_path=full_path,
@@ -179,9 +237,33 @@ def _check_buffer_size(self, new_data: bytes) -> bool:
         """Check if adding new data would exceed buffer size limit"""
         return len(self.buffer) + len(new_data) <= MAX_BUFFER_SIZE
 
-    def _forward_data_to_target(self, data: bytes) -> None:
+    async def _forward_data_through_pipeline(self, data: bytes) -> bytes:
+        http_request = http_request_from_bytes(data)
+        if not http_request:
+            # we couldn't parse this into an HTTP request, so we just pass through
+            return data
+
+        http_request.body = await self._body_through_pipeline(
+            http_request.method,
+            http_request.path,
+            http_request.headers,
+            http_request.body,
+        )
+
+        for header in http_request.headers:
+            if header.lower().startswith("content-length:"):
+                http_request.headers.remove(header)
+                break
+        http_request.headers.append(f"Content-Length: {len(http_request.body)}")
+
+        pipeline_data = http_request.reconstruct()
+
+        return pipeline_data
+
+    async def _forward_data_to_target(self, data: bytes) -> None:
         """Forward data to target if connection is established"""
         if self.target_transport and not self.target_transport.is_closing():
+            data = await self._forward_data_through_pipeline(data)
             self.target_transport.write(data)
 
     def data_received(self, data: bytes) -> None:

@@ -201,7 +283,7 @@ def data_received(self, data: bytes) -> None:
                     else:
                         asyncio.create_task(self.handle_http_request())
             else:
-                self._forward_data_to_target(data)
+                asyncio.create_task(self._forward_data_to_target(data))
 
         except Exception as e:
             logger.error(f"Error processing received data: {e}")
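
One step worth calling out: after the pipeline rewrites the body, any Content-Length header captured from the client is stale, so both _request_to_target and the new _forward_data_through_pipeline replace it before forwarding. A standalone sketch of that fixup with invented sample data (the provider removes the first matching header and appends a fresh one; the filter below is equivalent):

```python
def fix_content_length(headers: list[str], body: bytes) -> list[str]:
    """Drop any stale Content-Length and append one matching the new body."""
    # Header names are case-insensitive, hence the lower() comparison.
    kept = [h for h in headers if not h.lower().startswith("content-length:")]
    kept.append(f"Content-Length: {len(body)}")
    return kept

# The pipeline changed the body, so the length recorded by the client is wrong.
new_body = b'{"messages": [{"role": "user", "content": "REDACTED"}]}'
headers = fix_content_length(["Host: api.example.com", "Content-Length: 42"], new_body)

assert headers[-1] == f"Content-Length: {len(new_body)}"
assert "Content-Length: 42" not in headers
```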
