Fix codegate-version support and pipeline error #396

Merged: 1 commit, Dec 17, 2024
1 change: 1 addition & 0 deletions src/codegate/pipeline/base.py
@@ -77,6 +77,7 @@ class PipelineContext:
     prompt_id: Optional[str] = field(default_factory=lambda: None)
     input_request: Optional[Prompt] = field(default_factory=lambda: None)
     output_responses: List[Output] = field(default_factory=list)
+    shortcut_response: bool = False

     def add_code_snippet(self, snippet: CodeSnippet):
         self.code_snippets.append(snippet)
1 change: 1 addition & 0 deletions src/codegate/pipeline/version/version.py
@@ -43,6 +43,7 @@ async def process(
         if last_user_message is not None:
            last_user_message_str, _ = last_user_message
            if "codegate-version" in last_user_message_str.lower():
+               context.shortcut_response = True
                context.add_alert(self.name, trigger_string=last_user_message_str)
                return PipelineResult(
                    response=PipelineResponse(
31 changes: 29 additions & 2 deletions src/codegate/providers/copilot/pipeline.py
@@ -1,11 +1,14 @@
 import json
+import time
 from abc import ABC, abstractmethod
 from typing import Dict, Tuple

 import structlog
+from litellm import ModelResponse
 from litellm.types.llms.openai import ChatCompletionRequest
+from litellm.types.utils import Delta, StreamingChoices

-from codegate.pipeline.base import PipelineContext, SequentialPipelineProcessor
+from codegate.pipeline.base import PipelineContext, PipelineResult, SequentialPipelineProcessor
 from codegate.pipeline.factory import PipelineFactory
 from codegate.providers.normalizer.completion import CompletionNormalizer
@@ -64,6 +67,23 @@ def _get_copilot_headers(headers: Dict[str, str]) -> Dict[str, str]:

         return copilot_headers

+    @staticmethod
+    def _create_shortcut_response(result: PipelineResult, model: str) -> bytes:
+        response = ModelResponse(
+            choices=[
+                StreamingChoices(
+                    finish_reason="stop",
+                    index=0,
+                    delta=Delta(content=result.response.content, role="assistant"),
+                )
+            ],
+            created=int(time.time()),
+            model=model,
+            stream=True,
+        )
+        body = response.model_dump_json(exclude_none=True, exclude_unset=True).encode()
+        return body
+
     async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, PipelineContext]:
         """Common processing logic for all strategies"""
         try:
@@ -88,7 +108,14 @@ async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, PipelineContext]:
                 is_copilot=True,
             )

-            if result.request:
+            if result.context.shortcut_response:
+                # Return shortcut response to the user
+                body = CopilotPipeline._create_shortcut_response(
+                    result, normalized_body.get("model", "gpt-4o-mini")
+                )
+                logger.info(f"Pipeline created shortcut response: {body}")
+
+            elif result.request:
                 # the pipeline did modify the request, return to the user
                 # in the original LLM format
                 body = self.normalizer.denormalize(result.request)
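
Note: the body built by `_create_shortcut_response` is a single OpenAI-style streaming chunk serialized by litellm. A minimal standalone sketch of the same payload (assuming litellm is installed; the content string and model name are illustrative, not CodeGate's actual reply):

import time

from litellm import ModelResponse
from litellm.types.utils import Delta, StreamingChoices

# Build the same single-chunk streaming response the pipeline emits.
response = ModelResponse(
    choices=[
        StreamingChoices(
            finish_reason="stop",
            index=0,
            delta=Delta(content="codegate version 1.2.3", role="assistant"),  # illustrative
        )
    ],
    created=int(time.time()),
    model="gpt-4o-mini",
    stream=True,
)
# Serializes to an OpenAI-style chat.completion.chunk JSON object.
print(response.model_dump_json(exclude_none=True, exclude_unset=True))
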
106 changes: 91 additions & 15 deletions src/codegate/providers/copilot/provider.py
@@ -2,7 +2,7 @@
 import re
 import ssl
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 from urllib.parse import unquote, urljoin, urlparse

 import structlog
@@ -61,6 +61,30 @@ def reconstruct(self) -> bytes:
         return result


+@dataclass
+class HttpResponse:
+    """Data class to store HTTP response details"""
+
+    version: str
+    status_code: int
+    reason: str
+    headers: List[str]
+    body: Optional[bytes] = None
+
+    def reconstruct(self) -> bytes:
+        """Reconstruct HTTP response from stored details"""
+        headers = "\r\n".join(self.headers)
+        status_line = f"{self.version} {self.status_code} {self.reason}\r\n"
+        header_block = f"{status_line}{headers}\r\n\r\n"
+
+        # Convert header block to bytes and combine with body
+        result = header_block.encode("utf-8")
+        if self.body:
+            result += self.body
+
+        return result
+
+
 def extract_path(full_path: str) -> str:
     """Extract clean path from full URL or path string"""
     logger.debug(f"Extracting path from {full_path}")
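
For illustration, a minimal sketch of what the new dataclass serializes to (header values here are arbitrary):

resp = HttpResponse(
    version="HTTP/1.1",
    status_code=200,
    reason="OK",
    headers=["Content-Type: application/json", "Transfer-Encoding: chunked"],
    body=b"data: {}",
)
# Produces the raw wire bytes:
# b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nTransfer-Encoding: chunked\r\n\r\ndata: {}'
print(resp.reconstruct())
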
@@ -145,7 +169,7 @@ async def _body_through_pipeline(
     ) -> Tuple[bytes, PipelineContext]:
         logger.debug(f"Processing body through pipeline: {len(body)} bytes")
         strategy = self._select_pipeline(method, path)
-        if strategy is None:
+        if len(body) == 0 or strategy is None:
            # if we didn't select any strategy that would change the request
            # let's just pass through the body as-is
            return body, None
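
The added `len(body) == 0` check short-circuits empty request bodies so they pass through untouched; this looks to be the "pipeline error" fix referenced in the PR title, since an empty body cannot be parsed as a chat request.
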
@@ -243,35 +267,87 @@ def _check_buffer_size(self, new_data: bytes) -> bool:
         """Check if adding new data would exceed buffer size limit"""
         return len(self.buffer) + len(new_data) <= MAX_BUFFER_SIZE

-    async def _forward_data_through_pipeline(self, data: bytes) -> bytes:
+    async def _forward_data_through_pipeline(
+        self, data: bytes
+    ) -> Union[HttpRequest, HttpResponse]:
         http_request = http_request_from_bytes(data)
         if not http_request:
             # we couldn't parse this into an HTTP request, so we just pass through
             return data

-        http_request.body, context = await self._body_through_pipeline(
+        body, context = await self._body_through_pipeline(
             http_request.method,
             http_request.path,
             http_request.headers,
             http_request.body,
         )
         self.context_tracking = context

-        for header in http_request.headers:
-            if header.lower().startswith("content-length:"):
-                http_request.headers.remove(header)
-                break
-        http_request.headers.append(f"Content-Length: {len(http_request.body)}")
+        if context and context.shortcut_response:
+            # Send shortcut response
+            data_prefix = b'data:'
+            http_response = HttpResponse(
+                http_request.version,
+                200,
+                "OK",
+                [
+                    "server: uvicorn",
+                    "cache-control: no-cache",
+                    "connection: keep-alive",
+                    "Content-Type: application/json",
+                    "Transfer-Encoding: chunked",
+                ],
+                data_prefix + body
+            )
+            return http_response

-        pipeline_data = http_request.reconstruct()
+        else:
+            # Forward request to target
+            http_request.body = body
+
+            for header in http_request.headers:
+                if header.lower().startswith("content-length:"):
+                    http_request.headers.remove(header)
+                    break
+            http_request.headers.append(f"Content-Length: {len(http_request.body)}")

-        return pipeline_data
+        return http_request

     async def _forward_data_to_target(self, data: bytes) -> None:
-        """Forward data to target if connection is established"""
-        if self.target_transport and not self.target_transport.is_closing():
-            data = await self._forward_data_through_pipeline(data)
-            self.target_transport.write(data)
+        """
+        Forward data to target if connection is established. In case of shortcut
+        response, send a response to the client
+        """
+        pipeline_output = await self._forward_data_through_pipeline(data)
+
+        if isinstance(pipeline_output, HttpResponse):
+            # We need to send shortcut response
+            if self.transport and not self.transport.is_closing():
+                # First, close target_transport since we don't need to send any
+                # request to the target
+                self.target_transport.close()
+
+                # Send the shortcut response data in a chunk
+                chunk = pipeline_output.reconstruct()
+                chunk_size = hex(len(chunk))[2:] + "\r\n"
+                self.transport.write(chunk_size.encode())
+                self.transport.write(chunk)
+                self.transport.write(b"\r\n")
+
+                # Send data done chunk
+                chunk = b"data: [DONE]\n\n"
+                # Add chunk size for DONE message
+                chunk_size = hex(len(chunk))[2:] + "\r\n"
+                self.transport.write(chunk_size.encode())
+                self.transport.write(chunk)
+                self.transport.write(b"\r\n")
+                # Now send the final chunk with 0
+                self.transport.write(b"0\r\n\r\n")
+        else:
+            if self.target_transport and not self.target_transport.is_closing():
+                if isinstance(pipeline_output, HttpRequest):
+                    pipeline_output = pipeline_output.reconstruct()
+                self.target_transport.write(pipeline_output)

     def data_received(self, data: bytes) -> None:
         """Handle received data from client"""