This repository was archived by the owner on Jun 5, 2025. It is now read-only.

A hotfix for the FIM pipeline #353

Merged · 1 commit · Dec 15, 2024
1 change: 1 addition & 0 deletions src/codegate/pipeline/base.py
@@ -267,6 +267,7 @@ def __init__(
         self.secret_manager = secret_manager
         self.is_fim = is_fim
         self.context = PipelineContext()
+        self.context.metadata["is_fim"] = is_fim
 
     async def process_request(
         self,
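The single added line above is the request-side half of the fix: the FIM flag is stamped into the pipeline context's metadata so the response side (the Copilot provider, next diff) can read it back. A minimal sketch of that round trip, using a stripped-down stand-in for `PipelineContext` (the real codegate class carries more state):

```python
# Hedged sketch: a stand-in PipelineContext with only the metadata dict.
# The real class has more fields; only the flag round trip is shown here.
class PipelineContext:
    def __init__(self) -> None:
        self.metadata: dict = {}

ctx = PipelineContext()
ctx.metadata["is_fim"] = True             # set when the request pipeline is built
assert ctx.metadata.get("is_fim", False)  # read back when response data arrives
```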
38 changes: 30 additions & 8 deletions src/codegate/providers/copilot/provider.py
@@ -566,6 +566,35 @@ def connection_made(self, transport: asyncio.Transport) -> None:
         self.transport = transport
         self.proxy.target_transport = transport
 
+    def _ensure_output_processor(self) -> None:
+        if self.proxy.context_tracking is None:
+            # No context tracking, no need to process the pipeline
+            return
+
+        if self.sse_processor is not None:
+            # Already initialized, no need to reinitialize
+            return
+
+        # This is a hotfix: we short-circuit before selecting the output pipeline
+        # for FIM, because our FIM output pipeline is actually empty as of now.
+        # We should fix this, but there is no immediate need.
+        is_fim = self.proxy.context_tracking.metadata.get("is_fim", False)
+        if is_fim:
+            return
+
+        logger.debug("Tracking context for pipeline processing")
+        self.sse_processor = SSEProcessor()
+        is_fim = self.proxy.context_tracking.metadata.get("is_fim", False)
+        if is_fim:
+            out_pipeline_processor = self.proxy.pipeline_factory.create_fim_output_pipeline()
+        else:
+            out_pipeline_processor = self.proxy.pipeline_factory.create_output_pipeline()
+
+        self.output_pipeline_instance = OutputPipelineInstance(
+            pipeline_steps=out_pipeline_processor.pipeline_steps,
+            input_context=self.proxy.context_tracking,
+        )
+
     async def _process_stream(self):
         try:

@@ -633,14 +662,7 @@ def _proxy_transport_write(self, data: bytes):

     def data_received(self, data: bytes) -> None:
         """Handle data received from target"""
-        if self.proxy.context_tracking is not None and self.sse_processor is None:
-            logger.debug("Tracking context for pipeline processing")
-            self.sse_processor = SSEProcessor()
-            out_pipeline_processor = self.proxy.pipeline_factory.create_output_pipeline()
-            self.output_pipeline_instance = OutputPipelineInstance(
-                pipeline_steps=out_pipeline_processor.pipeline_steps,
-                input_context=self.proxy.context_tracking,
-            )
+        self._ensure_output_processor()
 
         if self.proxy.transport and not self.proxy.transport.is_closing():
             if not self.sse_processor:
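Why the early return matters: `data_received` now funnels through `_ensure_output_processor`, which is idempotent and skips pipeline setup entirely for FIM requests. A minimal sketch (not from the PR) of that short-circuit; `StubProtocol`, its `proxy` attribute, and the `object()` placeholder for `SSEProcessor` are all hypothetical stand-ins, and only the metadata-driven early return mirrors the merged code:

```python
# Hedged sketch of the _ensure_output_processor short-circuit. Everything here
# is a stub; only the guard order (no context -> already init -> FIM shortcut)
# follows the merged hotfix.
from types import SimpleNamespace


class StubProtocol:
    def __init__(self, is_fim: bool) -> None:
        self.sse_processor = None
        self.proxy = SimpleNamespace(
            context_tracking=SimpleNamespace(metadata={"is_fim": is_fim})
        )

    def _ensure_output_processor(self) -> None:
        if self.proxy.context_tracking is None:
            return  # no context tracking, nothing to process
        if self.sse_processor is not None:
            return  # already initialized
        if self.proxy.context_tracking.metadata.get("is_fim", False):
            return  # hotfix shortcut: FIM gets no output pipeline for now
        self.sse_processor = object()  # stands in for SSEProcessor()


fim = StubProtocol(is_fim=True)
fim._ensure_output_processor()
assert fim.sse_processor is None        # FIM: shortcut taken, nothing created

chat = StubProtocol(is_fim=False)
chat._ensure_output_processor()
assert chat.sse_processor is not None   # chat: processor initialized
```

Because the guard runs on every `data_received` call, the idempotence check (`sse_processor is not None`) also keeps repeated stream chunks from re-creating the pipeline.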