7 changes: 4 additions & 3 deletions src/codegate/providers/copilot/pipeline.py
@@ -5,6 +5,7 @@
 import structlog
 from litellm.types.llms.openai import ChatCompletionRequest

+from codegate.pipeline.base import PipelineContext
 from codegate.providers.normalizer.completion import CompletionNormalizer

 logger = structlog.get_logger("codegate")
@@ -62,7 +63,7 @@ def _get_copilot_headers(headers: Dict[str, str]) -> Dict[str, str]:

         return copilot_headers

-    async def process_body(self, headers: list[str], body: bytes) -> bytes:
+    async def process_body(self, headers: list[str], body: bytes) -> (bytes, PipelineContext):
         """Common processing logic for all strategies"""
         try:
             normalized_body = self.normalizer.normalize(body)
@@ -92,10 +93,10 @@ async def process_body(self, headers: list[str], body: bytes) -> bytes:
             body = self.normalizer.denormalize(result.request)
             logger.info(f"Pipeline processed request: {body}")

-            return body
+            return body, result.context
         except Exception as e:
             logger.error(f"Pipeline processing error: {e}")
-            return body
+            return body, None


 class CopilotFimNormalizer:
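A note on the new return annotation: `-> (bytes, PipelineContext)` is parsed by Python as a tuple expression rather than a typing construct, so static type checkers will reject it, and it does not express that the context can be `None` on the error path. A more conventional spelling would be the sketch below; only the names visible in the diff are real, and the body is elided:

    from typing import Optional, Tuple

    async def process_body(
        self, headers: list[str], body: bytes
    ) -> Tuple[bytes, Optional[PipelineContext]]:
        """Return the (possibly rewritten) body and the pipeline context, if any."""
        ...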
72 changes: 67 additions & 5 deletions src/codegate/providers/copilot/provider.py
@@ -1,4 +1,5 @@
 import asyncio
+import json
 import re
 import ssl
 from dataclasses import dataclass
@@ -11,13 +12,15 @@
 from codegate.config import Config
 from codegate.pipeline.base import PipelineContext
 from codegate.pipeline.factory import PipelineFactory
+from codegate.pipeline.output import OutputPipelineInstance
 from codegate.pipeline.secrets.manager import SecretsManager
 from codegate.providers.copilot.mapping import VALIDATED_ROUTES
 from codegate.providers.copilot.pipeline import (
     CopilotChatPipeline,
     CopilotFimPipeline,
     CopilotPipeline,
 )
+from codegate.providers.copilot.streaming import SSEProcessor

 logger = structlog.get_logger("codegate")

@@ -139,13 +142,13 @@ async def _body_through_pipeline(
         path: str,
         headers: list[str],
         body: bytes,
-    ) -> bytes:
+    ) -> (bytes, PipelineContext):
         logger.debug(f"Processing body through pipeline: {len(body)} bytes")
         strategy = self._select_pipeline(method, path)
         if strategy is None:
             # if we didn't select any strategy that would change the request
             # let's just pass through the body as-is
-            return body
+            return body, None
         return await strategy.process_body(headers, body)

     async def _request_to_target(self, headers: list[str], body: bytes):
@@ -154,13 +157,16 @@ async def _request_to_target(self, headers: list[str], body: bytes):
         ).encode()
         logger.debug(f"Request Line: {request_line}")

-        body = await self._body_through_pipeline(
+        body, context = await self._body_through_pipeline(
             self.request.method,
             self.request.path,
             headers,
             body,
         )

+        if context:
+            self.context_tracking = context
+
         for header in headers:
             if header.lower().startswith("content-length:"):
                 headers.remove(header)
@@ -243,12 +249,13 @@ async def _forward_data_through_pipeline(self, data: bytes) -> bytes:
             # we couldn't parse this into an HTTP request, so we just pass through
             return data

-        http_request.body = await self._body_through_pipeline(
+        http_request.body, context = await self._body_through_pipeline(
             http_request.method,
             http_request.path,
             http_request.headers,
             http_request.body,
         )
+        self.context_tracking = context

         for header in http_request.headers:
             if header.lower().startswith("content-length:"):
@@ -549,15 +556,68 @@ def __init__(self, proxy: CopilotProvider):
         self.proxy = proxy
         self.transport: Optional[asyncio.Transport] = None

+        self.headers_sent = False
+        self.sse_processor: Optional[SSEProcessor] = None
+        self.output_pipeline_instance: Optional[OutputPipelineInstance] = None
+
     def connection_made(self, transport: asyncio.Transport) -> None:
         """Handle successful connection to target"""
         self.transport = transport
         self.proxy.target_transport = transport

+    def _process_chunk(self, chunk: bytes):
+        records = self.sse_processor.process_chunk(chunk)
+
+        for record in records:
+            if record["type"] == "done":
+                sse_data = b"data: [DONE]\n\n"
+                # Add a chunk size for the DONE message too
+                chunk_size = hex(len(sse_data))[2:] + "\r\n"
+                self._proxy_transport_write(chunk_size.encode())
+                self._proxy_transport_write(sse_data)
+                self._proxy_transport_write(b"\r\n")
+                # Now send the final zero-length chunk
+                self._proxy_transport_write(b"0\r\n\r\n")
+            else:
+                sse_data = f"data: {json.dumps(record['content'])}\n\n".encode("utf-8")
+                chunk_size = hex(len(sse_data))[2:] + "\r\n"
+                self._proxy_transport_write(chunk_size.encode())
+                self._proxy_transport_write(sse_data)
+                self._proxy_transport_write(b"\r\n")
+
+    def _proxy_transport_write(self, data: bytes):
+        self.proxy.transport.write(data)
+
     def data_received(self, data: bytes) -> None:
         """Handle data received from target"""
-        if self.proxy.transport and not self.proxy.transport.is_closing():
-            self.proxy.transport.write(data)
+        if self.proxy.context_tracking is not None and self.sse_processor is None:
+            logger.debug("Tracking context for pipeline processing")
+            self.sse_processor = SSEProcessor()
+            out_pipeline_processor = self.proxy.pipeline_factory.create_output_pipeline()
+            self.output_pipeline_instance = OutputPipelineInstance(
+                pipeline_steps=out_pipeline_processor.pipeline_steps,
+                input_context=self.proxy.context_tracking,
+            )
+
+        if not self.sse_processor:
+            # Pass through non-SSE data unchanged
+            self.proxy.transport.write(data)
+            return
+
+        # Check if this is the first chunk with headers
+        if not self.headers_sent:
+            header_end = data.find(b"\r\n\r\n")
+            if header_end != -1:
+                self.headers_sent = True
+                # Send headers first
+                headers = data[: header_end + 4]
+                self._proxy_transport_write(headers)
+                logger.debug(f"Headers sent: {headers}")
+
+                data = data[header_end + 4 :]
+
+        self._process_chunk(data)

     def connection_lost(self, exc: Optional[Exception]) -> None:
         """Handle connection loss to target"""
@@ -570,3 +630,5 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
                 self.proxy.transport.close()
             except Exception as e:
                 logger.error(f"Error closing proxy transport: {e}")
+
+        # TODO: clear the context to erase the sensitive data
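`_process_chunk` re-frames each SSE record using HTTP/1.1 chunked transfer encoding, which is how the upstream streams its response once Content-Length is absent: every chunk is the payload length in hex, then CRLF, the payload itself, then CRLF, and a zero-length chunk (`0\r\n\r\n`) terminates the stream. A minimal sketch of that framing, with a hypothetical `write_chunk` helper standing in for the `_proxy_transport_write` calls above:

    def write_chunk(transport, payload: bytes) -> None:
        # HTTP/1.1 chunked framing: <hex size>\r\n<payload>\r\n
        transport.write(hex(len(payload))[2:].encode() + b"\r\n")
        transport.write(payload)
        transport.write(b"\r\n")

    # One SSE record, then the stream terminator:
    # write_chunk(transport, b'data: {"choices": []}\n\n')
    # transport.write(b"0\r\n\r\n")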
55 changes: 55 additions & 0 deletions src/codegate/providers/copilot/streaming.py
@@ -0,0 +1,55 @@
+import json
+
+import structlog
+
+logger = structlog.get_logger("codegate")
+
+
+class SSEProcessor:
+    def __init__(self):
+        self.buffer = ""
+        self.initial_chunk = True
+        self.chunk_size = None  # Store the original chunk size
+        self.size_written = False
+
+    def process_chunk(self, chunk: bytes) -> list:
+        logger.debug(f"SSE buffer at start of chunk: {self.buffer}")
+        # Skip any chunk-size lines (a hex number followed by \r\n)
+        try:
+            chunk_str = chunk.decode("utf-8")
+            lines = chunk_str.split("\r\n")
+            for line in lines:
+                if all(c in "0123456789abcdefABCDEF" for c in line.strip()):
+                    continue
+                self.buffer += line
+        except UnicodeDecodeError:
+            logger.error("Failed to decode chunk")
+
+        records = []
+        while True:
+            record_end = self.buffer.find("\n\n")
+            if record_end == -1:
+                logger.debug(f"Remaining buffer: {self.buffer}")
+                break
+
+            record = self.buffer[:record_end]
+            self.buffer = self.buffer[record_end + 2 :]
+
+            if record.startswith("data: "):
+                data_content = record[6:]
+                if data_content.strip() == "[DONE]":
+                    records.append({"type": "done"})
+                else:
+                    try:
+                        data = json.loads(data_content)
+                        records.append({"type": "data", "content": data})
+                    except json.JSONDecodeError:
+                        logger.error(f"Failed to parse JSON: {data_content}")
+
+        return records
+
+    def get_pending(self):
+        """Return any pending incomplete data in the buffer"""
+        return self.buffer
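For reference, driving `SSEProcessor` the way `data_received` does might look like the sketch below; the chunk bytes are invented, but the record shapes match `process_chunk` above. Note how a record split across two reads sits in the buffer until its terminating blank line arrives:

    processor = SSEProcessor()

    # A data record split across two reads, then a [DONE] record and the
    # zero-length chunk that ends the stream.
    records = processor.process_chunk(b'22\r\ndata: {"id": "1", "choices"')
    records += processor.process_chunk(b': []}\n\n\r\ne\r\ndata: [DONE]\n\n\r\n0\r\n\r\n')

    for record in records:
        if record["type"] == "done":
            print("stream finished")        # {"type": "done"}
        else:
            print(record["content"]["id"])  # {"type": "data", "content": {...}}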