This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit d441951: Fix codegate-version support and pipeline error
Parent: 576c97b

4 files changed: +122 additions, −17 deletions

src/codegate/pipeline/base.py (1 addition, 0 deletions)

```diff
@@ -77,6 +77,7 @@ class PipelineContext:
     prompt_id: Optional[str] = field(default_factory=lambda: None)
     input_request: Optional[Prompt] = field(default_factory=lambda: None)
     output_responses: List[Output] = field(default_factory=list)
+    shortcut_response: bool = False
 
     def add_code_snippet(self, snippet: CodeSnippet):
        self.code_snippets.append(snippet)
```

src/codegate/pipeline/version/version.py (1 addition, 0 deletions)

```diff
@@ -43,6 +43,7 @@ async def process(
         if last_user_message is not None:
             last_user_message_str, _ = last_user_message
             if "codegate-version" in last_user_message_str.lower():
+                context.shortcut_response = True
                 context.add_alert(self.name, trigger_string=last_user_message_str)
                 return PipelineResult(
                     response=PipelineResponse(
```
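
Taken together, these two changes let a pipeline step answer a request by itself: the step sets `shortcut_response` on the shared context and returns its own `PipelineResult`. A minimal, self-contained sketch of that control flow (the dataclasses here are hypothetical stand-ins reduced to the fields this commit touches, not the project's real classes):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PipelineResponse:  # stand-in: the real class carries more fields
    content: str


@dataclass
class PipelineContext:  # stand-in for codegate's PipelineContext
    shortcut_response: bool = False


@dataclass
class PipelineResult:  # stand-in for codegate's PipelineResult
    response: Optional[PipelineResponse] = None


def process(message: str, context: PipelineContext) -> Optional[PipelineResult]:
    # Mirrors version.py above: a magic keyword short-circuits the pipeline.
    if "codegate-version" in message.lower():
        context.shortcut_response = True
        return PipelineResult(
            response=PipelineResponse(content="codegate version: <version>")
        )
    return None  # no shortcut: the request is forwarded to the target LLM
```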

src/codegate/providers/copilot/pipeline.py (29 additions, 2 deletions)

```diff
@@ -1,11 +1,14 @@
 import json
+import time
 from abc import ABC, abstractmethod
 from typing import Dict, Tuple
 
 import structlog
+from litellm import ModelResponse
 from litellm.types.llms.openai import ChatCompletionRequest
+from litellm.types.utils import Delta, StreamingChoices
 
-from codegate.pipeline.base import PipelineContext, SequentialPipelineProcessor
+from codegate.pipeline.base import PipelineContext, PipelineResult, SequentialPipelineProcessor
 from codegate.pipeline.factory import PipelineFactory
 from codegate.providers.normalizer.completion import CompletionNormalizer
@@ -64,6 +67,23 @@ def _get_copilot_headers(headers: Dict[str, str]) -> Dict[str, str]:
 
         return copilot_headers
 
+    @staticmethod
+    def _create_shortcut_response(result: PipelineResult, model: str) -> bytes:
+        response = ModelResponse(
+            choices=[
+                StreamingChoices(
+                    finish_reason="stop",
+                    index=0,
+                    delta=Delta(content=result.response.content, role="assistant"),
+                )
+            ],
+            created=int(time.time()),
+            model=model,
+            stream=True,
+        )
+        body = response.model_dump_json(exclude_none=True, exclude_unset=True).encode()
+        return body
+
     async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, PipelineContext]:
         """Common processing logic for all strategies"""
         try:
@@ -88,7 +108,14 @@ async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, PipelineContext]:
                 is_copilot=True,
             )
 
-            if result.request:
+            if result.context.shortcut_response:
+                # Return the shortcut response to the user
+                body = CopilotPipeline._create_shortcut_response(
+                    result, normalized_body.get("model", "gpt-4o-mini")
+                )
+                logger.info(f"Pipeline created shortcut response: {body}")
+
+            elif result.request:
                 # the pipeline did modify the request, return to the user
                 # in the original LLM format
                 body = self.normalizer.denormalize(result.request)
```
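
For reference, a hedged sketch of what the new helper emits: `ModelResponse` is a pydantic model, so `model_dump_json` yields a single `chat.completion.chunk`-style JSON object. The exact field layout depends on the installed litellm version, and the content string below is made up:

```python
import time

from litellm import ModelResponse
from litellm.types.utils import Delta, StreamingChoices

# Same construction as _create_shortcut_response, with a made-up payload.
response = ModelResponse(
    choices=[
        StreamingChoices(
            finish_reason="stop",
            index=0,
            delta=Delta(content="CodeGate version vX.Y.Z", role="assistant"),
        )
    ],
    created=int(time.time()),
    model="gpt-4o-mini",
    stream=True,
)
body = response.model_dump_json(exclude_none=True, exclude_unset=True).encode()
# body is one bare JSON object; the SSE "data:" prefix and the chunked
# transfer-encoding are added later, in provider.py.
```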

src/codegate/providers/copilot/provider.py (91 additions, 15 deletions)

```diff
@@ -2,7 +2,7 @@
 import re
 import ssl
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 from urllib.parse import unquote, urljoin, urlparse
 
 import structlog
@@ -61,6 +61,30 @@ def reconstruct(self) -> bytes:
         return result
 
 
+@dataclass
+class HttpResponse:
+    """Data class to store HTTP response details"""
+
+    version: str
+    status_code: int
+    reason: str
+    headers: List[str]
+    body: Optional[bytes] = None
+
+    def reconstruct(self) -> bytes:
+        """Reconstruct HTTP response from stored details"""
+        headers = "\r\n".join(self.headers)
+        status_line = f"{self.version} {self.status_code} {self.reason}\r\n"
+        header_block = f"{status_line}{headers}\r\n\r\n"
+
+        # Convert the header block to bytes and append the body
+        result = header_block.encode("utf-8")
+        if self.body:
+            result += self.body
+
+        return result
+
+
 def extract_path(full_path: str) -> str:
     """Extract clean path from full URL or path string"""
     logger.debug(f"Extracting path from {full_path}")
```
```diff
@@ -145,7 +169,7 @@ async def _body_through_pipeline(
     ) -> Tuple[bytes, PipelineContext]:
         logger.debug(f"Processing body through pipeline: {len(body)} bytes")
         strategy = self._select_pipeline(method, path)
-        if strategy is None:
+        if len(body) == 0 or strategy is None:
             # if we didn't select any strategy that would change the request
             # let's just pass through the body as-is
             return body, None
@@ -243,35 +267,87 @@ def _check_buffer_size(self, new_data: bytes) -> bool:
         """Check if adding new data would exceed buffer size limit"""
         return len(self.buffer) + len(new_data) <= MAX_BUFFER_SIZE
 
-    async def _forward_data_through_pipeline(self, data: bytes) -> bytes:
+    async def _forward_data_through_pipeline(
+        self, data: bytes
+    ) -> Union[bytes, HttpRequest, HttpResponse]:
         http_request = http_request_from_bytes(data)
         if not http_request:
             # we couldn't parse this into an HTTP request, so we just pass through
             return data
 
-        http_request.body, context = await self._body_through_pipeline(
+        body, context = await self._body_through_pipeline(
             http_request.method,
             http_request.path,
             http_request.headers,
             http_request.body,
         )
         self.context_tracking = context
 
-        for header in http_request.headers:
-            if header.lower().startswith("content-length:"):
-                http_request.headers.remove(header)
-                break
-        http_request.headers.append(f"Content-Length: {len(http_request.body)}")
+        if context and context.shortcut_response:
+            # Return the shortcut response to the client
+            data_prefix = b"data:"
+            http_response = HttpResponse(
+                http_request.version,
+                200,
+                "OK",
+                [
+                    "server: uvicorn",
+                    "cache-control: no-cache",
+                    "connection: keep-alive",
+                    "Content-Type: application/json",
+                    "Transfer-Encoding: chunked",
+                ],
+                data_prefix + body,
+            )
+            return http_response
 
-        pipeline_data = http_request.reconstruct()
+        else:
+            # Forward the request to the target
+            http_request.body = body
 
-        return pipeline_data
+            for header in http_request.headers:
+                if header.lower().startswith("content-length:"):
+                    http_request.headers.remove(header)
+                    break
+            http_request.headers.append(f"Content-Length: {len(http_request.body)}")
+
+            return http_request
 
     async def _forward_data_to_target(self, data: bytes) -> None:
-        """Forward data to target if connection is established"""
-        if self.target_transport and not self.target_transport.is_closing():
-            data = await self._forward_data_through_pipeline(data)
-            self.target_transport.write(data)
+        """
+        Forward data to the target if a connection is established. In case of a
+        shortcut response, send the response back to the client instead.
+        """
+        pipeline_output = await self._forward_data_through_pipeline(data)
+
+        if isinstance(pipeline_output, HttpResponse):
+            # We need to send the shortcut response to the client
+            if self.transport and not self.transport.is_closing():
+                # First, close target_transport since we won't send any
+                # request to the target
+                self.target_transport.close()
+
+                # Send the shortcut response data in a chunk
+                chunk = pipeline_output.reconstruct()
+                chunk_size = hex(len(chunk))[2:] + "\r\n"
+                self.transport.write(chunk_size.encode())
+                self.transport.write(chunk)
+                self.transport.write(b"\r\n")
+
+                # Send the SSE [DONE] chunk
+                chunk = b"data: [DONE]\n\n"
+                chunk_size = hex(len(chunk))[2:] + "\r\n"
+                self.transport.write(chunk_size.encode())
+                self.transport.write(chunk)
+                self.transport.write(b"\r\n")
+                # Finally, send the zero-length chunk that ends the response
+                self.transport.write(b"0\r\n\r\n")
+        else:
+            if self.target_transport and not self.target_transport.is_closing():
+                if isinstance(pipeline_output, HttpRequest):
+                    pipeline_output = pipeline_output.reconstruct()
+                self.target_transport.write(pipeline_output)
 
     def data_received(self, data: bytes) -> None:
         """Handle received data from client"""
```
