Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Add FIM processing for Copilot #311

Merged
merged 3 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/codegate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from codegate.codegate_logging import LogFormat, LogLevel, setup_logging
from codegate.config import Config, ConfigurationError
from codegate.db.connection import init_db_sync
from codegate.pipeline.factory import PipelineFactory
from codegate.pipeline.secrets.manager import SecretsManager
from codegate.providers.copilot.provider import CopilotProvider
from codegate.server import init_app
from codegate.storage.utils import restore_storage_backup
Expand Down Expand Up @@ -307,7 +309,11 @@ def serve(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

app = init_app()
# Initialize secrets manager and pipeline factory
secrets_manager = SecretsManager()
pipeline_factory = PipelineFactory(secrets_manager)

app = init_app(pipeline_factory)

# Run the server
try:
Expand Down
56 changes: 56 additions & 0 deletions src/codegate/pipeline/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import List

from codegate.config import Config
from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor
from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever
from codegate.pipeline.extract_snippets.extract_snippets import CodeSnippetExtractor
from codegate.pipeline.extract_snippets.output import CodeCommentStep
from codegate.pipeline.output import OutputPipelineProcessor, OutputPipelineStep
from codegate.pipeline.secrets.manager import SecretsManager
from codegate.pipeline.secrets.secrets import (
CodegateSecrets,
SecretRedactionNotifier,
SecretUnredactionStep,
)
from codegate.pipeline.system_prompt.codegate import SystemPrompt
from codegate.pipeline.version.version import CodegateVersion


class PipelineFactory:
    """Builds the input/output pipelines (chat and FIM) around a shared secrets manager."""

    def __init__(self, secrets_manager: SecretsManager):
        # Shared across every pipeline this factory creates, so redaction state
        # is consistent between input and output processing.
        self.secrets_manager = secrets_manager

    def create_input_pipeline(self) -> SequentialPipelineProcessor:
        """Build the full input pipeline used for chat requests."""
        # CodegateSecrets must stay first in the list: the steps after it may
        # send the request to an LLM for analysis, and the secrets have to be
        # obfuscated before that happens or they would leak.
        steps: List[PipelineStep] = [
            CodegateSecrets(),
            CodegateVersion(),
            CodeSnippetExtractor(),
            SystemPrompt(Config.get_config().prompts.default_chat),
            CodegateContextRetriever(),
        ]
        return SequentialPipelineProcessor(steps, self.secrets_manager)

    def create_fim_pipeline(self) -> SequentialPipelineProcessor:
        """Build the reduced input pipeline used for fill-in-the-middle requests."""
        steps: List[PipelineStep] = [CodegateSecrets()]
        return SequentialPipelineProcessor(steps, self.secrets_manager)

    def create_output_pipeline(self) -> OutputPipelineProcessor:
        """Build the output pipeline used for chat responses."""
        steps: List[OutputPipelineStep] = [
            SecretRedactionNotifier(),
            SecretUnredactionStep(),
            CodeCommentStep(),
        ]
        return OutputPipelineProcessor(steps)

    def create_fim_output_pipeline(self) -> OutputPipelineProcessor:
        """Build the output pipeline used for FIM responses (currently a no-op)."""
        steps: List[OutputPipelineStep] = [
            # temporarily disabled
            # SecretUnredactionStep(),
        ]
        return OutputPipelineProcessor(steps)
99 changes: 99 additions & 0 deletions src/codegate/providers/copilot/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import json
from abc import ABC, abstractmethod

import structlog
from litellm.types.llms.openai import ChatCompletionRequest

from codegate.providers.normalizer.completion import CompletionNormalizer

logger = structlog.get_logger("codegate")


class CopilotPipeline(ABC):
    """
    A CopilotPipeline puts together a normalizer to be able to pass
    a request to the pipeline in a normalized format, and a pipeline
    factory to create the pipeline itself and run the request.
    """

    def __init__(self, pipeline_factory):
        self.pipeline_factory = pipeline_factory
        self.normalizer = self._create_normalizer()
        self.provider_name = "copilot"

    @abstractmethod
    def _create_normalizer(self):
        """Each strategy defines which normalizer to use"""
        pass

    @abstractmethod
    def create_pipeline(self):
        """Each strategy defines which pipeline to create"""
        pass

    @staticmethod
    def _request_id(headers: list[str]) -> str:
        """Extract the request ID from the raw header lines.

        Returns an empty string when no ``x-request-id`` header is present.
        """
        for header in headers:
            # HTTP field names are case-insensitive (RFC 9110), so match on a
            # lowered copy; use logger instead of print for consistency with
            # the rest of the file.
            if header.lower().startswith("x-request-id"):
                logger.debug(f"Request ID found in headers: {header}")
                # Split only on the first colon so a value that itself
                # contains ":" is not truncated.
                return header.split(":", 1)[1].strip()
        logger.debug("No request ID found in headers")
        return ""

    async def process_body(self, headers: list[str], body: bytes) -> bytes:
        """Common processing logic for all strategies.

        Normalizes the body, runs it through the strategy's pipeline, and
        denormalizes the (possibly modified) request back to bytes. On any
        failure the original body is returned unmodified so the proxy keeps
        working even when the pipeline breaks.
        """
        try:
            normalized_body = self.normalizer.normalize(body)

            pipeline = self.create_pipeline()
            result = await pipeline.process_request(
                request=normalized_body,
                provider=self.provider_name,
                prompt_id=self._request_id(headers),
                model=normalized_body.get("model", ""),
                api_key=None,
            )

            if result.request:
                # the pipeline did modify the request, return to the user
                # in the original LLM format
                body = self.normalizer.denormalize(result.request)
                logger.info(f"Pipeline processed request: {body}")

            return body
        except Exception as e:
            logger.error(f"Pipeline processing error: {e}")
            return body


class CopilotFimNormalizer:
    """
    A custom normalizer for the FIM format used by Copilot
    We reuse the normalizer for "prompt" format, but we need to
    load the body first and then encode on the way back.
    """

    def __init__(self):
        self._completion_normalizer = CompletionNormalizer()

    def normalize(self, body: bytes) -> ChatCompletionRequest:
        """Decode the raw JSON body, then delegate to the completion normalizer."""
        decoded_body = json.loads(body)
        return self._completion_normalizer.normalize(decoded_body)

    def denormalize(self, request_from_pipeline: ChatCompletionRequest) -> bytes:
        """Delegate to the completion normalizer, then re-encode the result as JSON bytes."""
        denormalized_body = self._completion_normalizer.denormalize(request_from_pipeline)
        return json.dumps(denormalized_body).encode()


class CopilotFimPipeline(CopilotPipeline):
    """
    A pipeline for the FIM format used by Copilot. Combines the normalizer for the FIM
    format and the FIM pipeline used by all providers.
    """

    def _create_normalizer(self):
        """Use the Copilot-specific FIM normalizer."""
        return CopilotFimNormalizer()

    def create_pipeline(self):
        """Build the FIM pipeline shared by all providers from the factory."""
        return self.pipeline_factory.create_fim_pipeline()
66 changes: 51 additions & 15 deletions src/codegate/providers/copilot/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@

from codegate.ca.codegate_ca import CertificateAuthority
from codegate.config import Config
from codegate.pipeline.base import PipelineContext
from codegate.pipeline.factory import PipelineFactory
from codegate.pipeline.secrets.manager import SecretsManager
from codegate.providers.copilot.mapping import VALIDATED_ROUTES
from codegate.providers.copilot.pipeline import CopilotFimPipeline

logger = structlog.get_logger("codegate")

Expand Down Expand Up @@ -56,6 +60,52 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
self.proxy_ep: Optional[str] = None
self.ca = CertificateAuthority.get_instance()
self._closing = False
self.pipeline_factory = PipelineFactory(SecretsManager())
self.context_tracking: Optional[PipelineContext] = None

def _select_pipeline(self):
if (
self.request.method == "POST"
and self.request.path == "v1/engines/copilot-codex/completions"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

request.path or request.original_path?

):
logger.debug("Selected CopilotFimStrategy")
return CopilotFimPipeline(self.pipeline_factory)

logger.debug("No pipeline strategy selected")
return None

async def _body_through_pipeline(self, headers: list[str], body: bytes) -> bytes:
logger.debug(f"Processing body through pipeline: {len(body)} bytes")
strategy = self._select_pipeline()
if strategy is None:
# if we didn't select any strategy that would change the request
# let's just pass through the body as-is
return body
return await strategy.process_body(headers, body)

async def _request_to_target(self, headers: list[str], body: bytes):
request_line = (
f"{self.request.method} /{self.request.path} {self.request.version}\r\n"
).encode()
logger.debug(f"Request Line: {request_line}")

body = await self._body_through_pipeline(headers, body)

for header in headers:
if header.lower().startswith("content-length:"):
headers.remove(header)
break
headers.append(f"Content-Length: {len(body)}")

header_block = "\r\n".join(headers).encode()
headers_request_block = request_line + header_block + b"\r\n\r\n"
logger.debug("=" * 40)
self.target_transport.write(headers_request_block)
logger.debug("=" * 40)

for i in range(0, len(body), CHUNK_SIZE):
chunk = body[i : i + CHUNK_SIZE]
self.target_transport.write(chunk)

def connection_made(self, transport: asyncio.Transport) -> None:
"""Handle new client connection"""
Expand Down Expand Up @@ -192,24 +242,10 @@ async def handle_http_request(self) -> None:
if not has_host:
new_headers.append(f"Host: {self.target_host}")

request_line = (
f"{self.request.method} /{self.request.path} {self.request.version}\r\n"
).encode()
logger.debug(f"Request Line: {request_line}")
header_block = "\r\n".join(new_headers).encode()
headers = request_line + header_block + b"\r\n\r\n"

if self.target_transport:
logger.debug("=" * 40)
self.target_transport.write(headers)
logger.debug("=" * 40)

body_start = self.buffer.index(b"\r\n\r\n") + 4
body = self.buffer[body_start:]

for i in range(0, len(body), CHUNK_SIZE):
chunk = body[i : i + CHUNK_SIZE]
self.target_transport.write(chunk)
await self._request_to_target(new_headers, body)
else:
logger.debug("=" * 40)
logger.error("Target transport not available")
Expand Down
Loading