Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

improve TLS handling with SNI support and cert caching #432

Merged: 4 commits were merged on Dec 20, 2024.
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
735 changes: 479 additions & 256 deletions src/codegate/ca/codegate_ca.py

Large diffs are not rendered by default.

29 changes: 20 additions & 9 deletions src/codegate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,12 @@ def serve(
# Check certificates and create CA if necessary
logger.info("Checking certificates and creating CA if needed")
ca = CertificateAuthority.get_instance()
ca.ensure_certificates_exist()

certs_check = ca.check_and_ensure_certificates()
if certs_check:
click.echo("New Certificates generated successfully.")
else:
click.echo("Existing Certificates are already present.")

# Initialize secrets manager and pipeline factory
secrets_manager = SecretsManager()
Expand Down Expand Up @@ -452,7 +457,10 @@ def restore_backup(backup_path: Path, backup_name: str) -> None:
"--force-certs",
is_flag=True,
default=False,
help="Force the generation of certificates even if they already exist.",
help=(
"Force the generation of certificates even if they already exist. "
"Warning: this will overwrite existing certificates."
),
)
@click.option(
"--log-level",
Expand Down Expand Up @@ -488,17 +496,20 @@ def generate_certs(
cli_log_format=log_format,
)
setup_logging(cfg.log_level, cfg.log_format)
logger = structlog.get_logger("codegate").bind(origin="cli")

ca = CertificateAuthority.get_instance()
should_generate = force_certs or not ca.check_certificates_exist()

if should_generate:
ca.generate_certificates()
click.echo("Certificates generated successfully.")
click.echo(f"Certificates saved to: {cfg.certs_dir}")
click.echo("Make sure to add the new CA certificate to the operating system trust store.")
# Remove and regenerate certificates if forced; otherwise, just ensure they exist
logger.info("Checking certificates and creating certs if needed")
if force_certs:
ca.remove_certificates()

certs_check = ca.check_and_ensure_certificates()
if certs_check:
logger.info("New Certificates generated successfully.")
else:
click.echo("Certificates already exist. Skipping generation...")
logger.info("Existing Certificates are already present.")


def main() -> None:
Expand Down
6 changes: 3 additions & 3 deletions src/codegate/codegate_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def _missing_(cls, value: str) -> Optional["LogFormat"]:

def add_origin(logger, log_method, event_dict):
# Add 'origin' if it's bound to the logger but not explicitly in the event dict
if 'origin' not in event_dict and hasattr(logger, '_context'):
origin = logger._context.get('origin')
if "origin" not in event_dict and hasattr(logger, "_context"):
origin = logger._context.get("origin")
if origin:
event_dict['origin'] = origin
event_dict["origin"] = origin
return event_dict


Expand Down
3 changes: 2 additions & 1 deletion src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option
"""
)
recorded_request = await self._insert_pydantic_model(prompt_params, sql)
logger.debug(f"Recorded request: {recorded_request}")
# Uncomment to debug the recorded request
# logger.debug(f"Recorded request: {recorded_request}")
return recorded_request

async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
Expand Down
3 changes: 2 additions & 1 deletion src/codegate/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def add_input_request(
type="fim" if is_fim_request else "chat",
request=request_str,
)
logger.debug(f"Added input request to context: {self.input_request}")
# Uncomment the below to debug the input
# logger.debug(f"Added input request to context: {self.input_request}")
except Exception as e:
logger.warning(f"Failed to serialize input request: {normalized_request}", error=str(e))

Expand Down
3 changes: 2 additions & 1 deletion src/codegate/providers/copilot/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, Pi
# the pipeline did modify the request, return to the user
# in the original LLM format
body = self.normalizer.denormalize(result.request)
logger.debug(f"Pipeline processed request: {body}")
# Uncomment the below to debug the request
# logger.debug(f"Pipeline processed request: {body}")

return body, result.context
except Exception as e:
Expand Down
34 changes: 13 additions & 21 deletions src/codegate/providers/copilot/provider.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import asyncio
import re
import ssl
from src.codegate.codegate_logging import setup_logging
import structlog
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from urllib.parse import unquote, urljoin, urlparse

import structlog
from litellm.types.utils import Delta, ModelResponse, StreamingChoices

from codegate.ca.codegate_ca import CertificateAuthority
from codegate.ca.codegate_ca import CertificateAuthority, TLSCertDomainManager
from codegate.config import Config
from codegate.pipeline.base import PipelineContext
from codegate.pipeline.factory import PipelineFactory
Expand All @@ -22,6 +21,7 @@
CopilotPipeline,
)
from codegate.providers.copilot.streaming import SSEProcessor
from src.codegate.codegate_logging import setup_logging

setup_logging()
logger = structlog.get_logger("codegate").bind(origin="copilot_proxy")
Expand Down Expand Up @@ -147,6 +147,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
self.ssl_context: Optional[ssl.SSLContext] = None
self.proxy_ep: Optional[str] = None
self.ca = CertificateAuthority.get_instance()
self.cert_manager = TLSCertDomainManager(self.ca)
self._closing = False
self.pipeline_factory = PipelineFactory(SecretsManager())
self.context_tracking: Optional[PipelineContext] = None
Expand Down Expand Up @@ -206,7 +207,7 @@ async def _request_to_target(self, headers: list[str], body: bytes):
logger.debug("=" * 40)

for i in range(0, len(body), CHUNK_SIZE):
chunk = body[i: i + CHUNK_SIZE]
chunk = body[i : i + CHUNK_SIZE]
self.target_transport.write(chunk)

def connection_made(self, transport: asyncio.Transport) -> None:
Expand Down Expand Up @@ -269,9 +270,7 @@ def _check_buffer_size(self, new_data: bytes) -> bool:
"""Check if adding new data would exceed buffer size limit"""
return len(self.buffer) + len(new_data) <= MAX_BUFFER_SIZE

async def _forward_data_through_pipeline(
self, data: bytes
) -> Union[HttpRequest, HttpResponse]:
async def _forward_data_through_pipeline(self, data: bytes) -> Union[HttpRequest, HttpResponse]:
http_request = http_request_from_bytes(data)
if not http_request:
# we couldn't parse this into an HTTP request, so we just pass through
Expand All @@ -287,7 +286,7 @@ async def _forward_data_through_pipeline(

if context and context.shortcut_response:
# Send shortcut response
data_prefix = b'data:'
data_prefix = b"data:"
http_response = HttpResponse(
http_request.version,
200,
Expand All @@ -299,7 +298,7 @@ async def _forward_data_through_pipeline(
"Content-Type: application/json",
"Transfer-Encoding: chunked",
],
data_prefix + body
data_prefix + body,
)
return http_response

Expand Down Expand Up @@ -496,8 +495,8 @@ def handle_connect(self) -> None:
self.target_host, port = path.split(":")
self.target_port = int(port)

cert_path, key_path = self.ca.get_domain_certificate(self.target_host)
self.ssl_context = self._create_ssl_context(cert_path, key_path)
# Get SSL context through the TLS handler
self.ssl_context = self.cert_manager.get_domain_context(self.target_host)

self.is_connect = True
asyncio.create_task(self.connect_to_target())
Expand All @@ -507,13 +506,6 @@ def handle_connect(self) -> None:
logger.error(f"Error handling CONNECT: {e}")
self.send_error_response(502, str(e).encode())

def _create_ssl_context(self, cert_path: str, key_path: str) -> ssl.SSLContext:
"""Create SSL context for CONNECT tunneling"""
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(cert_path, key_path)
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
return ssl_context

async def connect_to_target(self) -> None:
"""Establish connection to target for CONNECT requests"""
try:
Expand Down Expand Up @@ -618,7 +610,7 @@ async def run_proxy_server(cls) -> None:
"""Run the proxy server"""
try:
ca = CertificateAuthority.get_instance()
ssl_context = ca.create_ssl_context()
ssl_context = ca.create_server_ssl_context()
config = Config.get_config()
server = await cls.create_proxy_server(config.host, config.proxy_port, ssl_context)

Expand All @@ -639,7 +631,7 @@ async def get_target_url(path: str) -> Optional[str]:
# Check for prefix match
for route in VALIDATED_ROUTES:
# For prefix matches, keep the rest of the path
remaining_path = path[len(route.path):]
remaining_path = path[len(route.path) :]
logger.debug(f"Remaining path: {remaining_path}")
# Make sure we don't end up with double slashes
if remaining_path and remaining_path.startswith("/"):
Expand Down Expand Up @@ -793,7 +785,7 @@ def data_received(self, data: bytes) -> None:
self._proxy_transport_write(headers)
logger.debug(f"Headers sent: {headers}")

data = data[header_end + 4:]
data = data[header_end + 4 :]

self._process_chunk(data)

Expand Down