
Commit 189aee9

Merge pull request #102 from stacklok/add-fim-pipeline
Add a FIM pipeline to Providers
2 parents f005ecc + 8bb074c, commit 189aee9

File tree: 9 files changed, +293 −24 lines
Lines changed: 52 additions & 0 deletions (new file)

@@ -0,0 +1,52 @@
+from litellm import ChatCompletionRequest
+
+from codegate.codegate_logging import setup_logging
+from codegate.pipeline.base import PipelineContext, PipelineResponse, PipelineResult, PipelineStep
+
+logger = setup_logging()
+
+
+class SecretAnalyzer(PipelineStep):
+    """Pipeline step that handles analyzing secrets in the FIM pipeline."""
+
+    message_blocked = """
+    ⚠️ CodeGate Security Warning! Analysis Report ⚠️
+    Potential leak of sensitive credentials blocked
+
+    Recommendations:
+    - Use environment variables for secrets
+    """
+
+    @property
+    def name(self) -> str:
+        """
+        Returns the name of this pipeline step.
+
+        Returns:
+            str: The identifier 'fim-secret-analyzer'
+        """
+        return "fim-secret-analyzer"
+
+    async def process(
+        self,
+        request: ChatCompletionRequest,
+        context: PipelineContext
+    ) -> PipelineResult:
+        # Here we should call the secrets-blocking module to check whether the request messages contain secrets:
+        # messages_contain_secrets = [analyze_msg_secrets(msg) for msg in request.messages]
+        # message_with_secrets = any(messages_contain_secrets)
+
+        # For the moment, to test short-circuiting, treat all messages as if they contain no secrets.
+        message_with_secrets = False
+        if message_with_secrets:
+            logger.info('Blocking message with secrets.')
+            return PipelineResult(
+                response=PipelineResponse(
+                    step_name=self.name,
+                    content=self.message_blocked,
+                    model=request["model"],
+                ),
+            )
+
+        # No messages with secrets, execute the rest of the pipeline
+        return PipelineResult(request=request)
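
For context, a minimal sketch (not part of this diff) of how the new step could be exercised through the pipeline machinery added in base.py below. The request payload, model name, and standalone invocation are illustrative assumptions, not taken from the commit:

import asyncio

from codegate.pipeline.base import SequentialPipelineProcessor


async def try_fim_step():
    # Build a one-step FIM pipeline containing only the new analyzer.
    fim_pipeline = SequentialPipelineProcessor([SecretAnalyzer()])
    request = {
        "model": "some-model",  # placeholder model name
        "messages": [{"role": "user", "content": "<QUERY>def connect():</QUERY>"}],
    }
    result = await fim_pipeline.process_request(request)
    # result.response is only set when a step short-circuits (blocks) the request;
    # otherwise result.request carries on to the provider.
    print(result.response.content if result.response else "passed through")

asyncio.run(try_fim_step())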
Lines changed: 33 additions & 0 deletions (new file)

@@ -0,0 +1,33 @@
+from typing import AsyncIterator, Optional, Union
+
+from litellm import ChatCompletionRequest, ModelResponse
+
+from codegate.providers.litellmshim import LiteLLmShim
+
+
+class AnthropicCompletion(LiteLLmShim):
+    """
+    AnthropicCompletion is used by the Anthropic provider to execute completions.
+    """
+
+    async def execute_completion(
+        self,
+        request: ChatCompletionRequest,
+        api_key: Optional[str],
+        stream: bool = False,
+    ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
+        """
+        Ensures the model name is prefixed with 'anthropic/' to explicitly route to Anthropic's API.
+
+        LiteLLM automatically maps most model names, but prepending 'anthropic/' forces the request
+        to Anthropic. This avoids issues with unrecognized names like 'claude-3-5-sonnet-latest',
+        which LiteLLM doesn't accept as a valid Anthropic model. This safeguard may be unnecessary
+        but ensures compatibility.
+
+        For more details, refer to the
+        [LiteLLM Documentation](https://docs.litellm.ai/docs/providers/anthropic).
+        """
+        model_in_request = request['model']
+        if not model_in_request.startswith('anthropic/'):
+            request['model'] = f'anthropic/{model_in_request}'
+        return await super().execute_completion(request, api_key, stream)
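
A quick, self-contained illustration of the prefixing rule the docstring describes (hypothetical model names, no network call; the helper below is not part of the diff):

def route_to_anthropic(model: str) -> str:
    # Mirrors the check in execute_completion above.
    return model if model.startswith('anthropic/') else f'anthropic/{model}'

assert route_to_anthropic('claude-3-5-sonnet-latest') == 'anthropic/claude-3-5-sonnet-latest'
assert route_to_anthropic('anthropic/claude-3-opus') == 'anthropic/claude-3-opus'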

src/codegate/providers/anthropic/provider.py

Lines changed: 13 additions & 5 deletions

@@ -1,20 +1,27 @@
 import json
+from typing import Optional

 from fastapi import Header, HTTPException, Request

 from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer
-from codegate.providers.base import BaseProvider
-from codegate.providers.litellmshim import LiteLLmShim, anthropic_stream_generator
+from codegate.providers.anthropic.completion_handler import AnthropicCompletion
+from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
+from codegate.providers.litellmshim import anthropic_stream_generator


 class AnthropicProvider(BaseProvider):
-    def __init__(self, pipeline_processor=None):
-        completion_handler = LiteLLmShim(stream_generator=anthropic_stream_generator)
+    def __init__(
+        self,
+        pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+        fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None
+    ):
+        completion_handler = AnthropicCompletion(stream_generator=anthropic_stream_generator)
         super().__init__(
             AnthropicInputNormalizer(),
             AnthropicOutputNormalizer(),
             completion_handler,
             pipeline_processor,
+            fim_pipeline_processor
         )

     @property
@@ -39,5 +46,6 @@ async def create_message(
             body = await request.body()
             data = json.loads(body)

-            stream = await self.complete(data, x_api_key)
+            is_fim_request = self._is_fim_request(request, data)
+            stream = await self.complete(data, x_api_key, is_fim_request)
             return self._completion_handler.create_streaming_response(stream)

src/codegate/providers/base.py

Lines changed: 65 additions & 9 deletions

@@ -1,15 +1,17 @@
 from abc import ABC, abstractmethod
 from typing import Any, AsyncIterator, Callable, Dict, Optional, Union

-from fastapi import APIRouter
+from fastapi import APIRouter, Request
 from litellm import ModelResponse
 from litellm.types.llms.openai import ChatCompletionRequest

+from codegate.codegate_logging import setup_logging
 from codegate.pipeline.base import PipelineResult, SequentialPipelineProcessor
 from codegate.providers.completion.base import BaseCompletionHandler
 from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter
 from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer

+logger = setup_logging()
 StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]]


@@ -25,12 +27,14 @@ def __init__(
         output_normalizer: ModelOutputNormalizer,
         completion_handler: BaseCompletionHandler,
         pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+        fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
     ):
         self.router = APIRouter()
         self._completion_handler = completion_handler
         self._input_normalizer = input_normalizer
         self._output_normalizer = output_normalizer
         self._pipeline_processor = pipeline_processor
+        self._fim_pipeline_processor = fim_pipeline_processor

         self._pipeline_response_formatter = PipelineResponseFormatter(output_normalizer)

@@ -48,22 +52,76 @@ def provider_route_name(self) -> str:
     async def _run_input_pipeline(
         self,
         normalized_request: ChatCompletionRequest,
+        is_fim_request: bool
     ) -> PipelineResult:
-        if self._pipeline_processor is None:
+        # Decide which pipeline processor to use
+        if is_fim_request:
+            pipeline_processor = self._fim_pipeline_processor
+            logger.info('FIM pipeline selected for execution.')
+        else:
+            pipeline_processor = self._pipeline_processor
+            logger.info('Chat completion pipeline selected for execution.')
+        if pipeline_processor is None:
             return PipelineResult(request=normalized_request)

-        result = await self._pipeline_processor.process_request(normalized_request)
+        result = await pipeline_processor.process_request(normalized_request)

         # TODO(jakub): handle this by returning a message to the client
         if result.error_message:
             raise Exception(result.error_message)

         return result

+    def _is_fim_request_url(self, request: Request) -> bool:
+        """
+        Checks the request URL to determine whether a request is FIM or chat completion.
+        Used by: llama.cpp
+        """
+        request_path = request.url.path
+        # Evaluate the longer suffix first.
+        if request_path.endswith("/chat/completions"):
+            return False
+
+        if request_path.endswith("/completions"):
+            return True
+
+        return False
+
+    def _is_fim_request_body(self, data: Dict) -> bool:
+        """
+        Determine from the raw incoming data whether it is a FIM request.
+        Used by: OpenAI and Anthropic
+        """
+        messages = data.get('messages', [])
+        if not messages:
+            return False
+
+        first_message_content = messages[0].get('content')
+        if first_message_content is None:
+            return False
+
+        fim_stop_sequences = ['</COMPLETION>', '<COMPLETION>', '</QUERY>', '<QUERY>']
+        if isinstance(first_message_content, str):
+            msg_prompt = first_message_content
+        elif isinstance(first_message_content, list):
+            msg_prompt = first_message_content[0].get('text', '')
+        else:
+            logger.warning(f'Could not determine if message was FIM from data: {data}')
+            return False
+        return all([stop_sequence in msg_prompt for stop_sequence in fim_stop_sequences])
+
+    def _is_fim_request(self, request: Request, data: Dict) -> bool:
+        """
+        Determine whether the request is FIM from the URL or the body of the request.
+        """
+        # Avoid the more expensive inspection of the body by checking the URL first.
+        if self._is_fim_request_url(request):
+            return True
+
+        return self._is_fim_request_body(data)
+
     async def complete(
-        self,
-        data: Dict,
-        api_key: Optional[str],
+        self, data: Dict, api_key: Optional[str], is_fim_request: bool
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """
         Main completion flow with pipeline integration
@@ -78,8 +136,7 @@ async def complete(
         """
         normalized_request = self._input_normalizer.normalize(data)
         streaming = data.get("stream", False)
-
-        input_pipeline_result = await self._run_input_pipeline(normalized_request)
+        input_pipeline_result = await self._run_input_pipeline(normalized_request, is_fim_request)
         if input_pipeline_result.response:
             return self._pipeline_response_formatter.handle_pipeline_response(
                 input_pipeline_result.response, streaming
@@ -93,7 +150,6 @@ async def complete(
         model_response = await self._completion_handler.execute_completion(
             provider_request, api_key=api_key, stream=streaming
         )
-
         if not streaming:
             return self._output_normalizer.denormalize(model_response)
         return self._output_normalizer.denormalize_streaming(model_response)
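
To make the body-based detection concrete, here is a small sketch (hypothetical payload, not from the diff) of the check _is_fim_request_body performs: a request only counts as FIM when the first message contains every marker.

fim_stop_sequences = ['</COMPLETION>', '<COMPLETION>', '</QUERY>', '<QUERY>']

data = {
    "messages": [{
        "role": "user",
        "content": "<QUERY>def add(a, b):</QUERY>\n<COMPLETION>    return a + b</COMPLETION>",
    }]
}

msg_prompt = data["messages"][0]["content"]
print(all(seq in msg_prompt for seq in fim_stop_sequences))  # True -> routed to the FIM pipeline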

src/codegate/providers/llamacpp/normalizer.py

Lines changed: 10 additions & 0 deletions

@@ -10,6 +10,11 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
        """
        Normalize the input data
        """
+        # When doing FIM, we receive "prompt" instead of "messages". Normalize it.
+        if "prompt" in data:
+            data["messages"] = [{"content": data.pop("prompt"), "role": "user"}]
+            # We can add as many parameters as we like to data. ChatCompletionRequest is not strict.
+            data["had_prompt_before"] = True
        try:
            return ChatCompletionRequest(**data)
        except Exception as e:
@@ -19,6 +24,11 @@ def denormalize(self, data: ChatCompletionRequest) -> Dict:
        """
        Denormalize the input data
        """
+        # If we received "prompt" in FIM, we need to convert it back.
+        if data.get("had_prompt_before", False):
+            data["prompt"] = data["messages"][0]["content"]
+            del data["had_prompt_before"]
+            del data["messages"]
        return data
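
A short sketch (hypothetical prompt, not from the diff) of the FIM round trip these changes enable: "prompt" is folded into "messages" on the way in and restored on the way out. It assumes the extra "had_prompt_before" key survives ChatCompletionRequest, as the comment above notes.

from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer

normalizer = LLamaCppInputNormalizer()

fim_data = {"prompt": "def fib(n):", "stream": True}
normalized = normalizer.normalize(fim_data)
# "prompt" is now messages[0]["content"], and had_prompt_before marks the conversion

restored = normalizer.denormalize(normalized)
# restored carries "prompt" again, with "messages" and the marker removed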

src/codegate/providers/llamacpp/provider.py

Lines changed: 10 additions & 3 deletions

@@ -1,20 +1,26 @@
 import json
+from typing import Optional

 from fastapi import Request

-from codegate.providers.base import BaseProvider
+from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
 from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler
 from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer


 class LlamaCppProvider(BaseProvider):
-    def __init__(self, pipeline_processor=None):
+    def __init__(
+        self,
+        pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+        fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None
+    ):
         completion_handler = LlamaCppCompletionHandler()
         super().__init__(
             LLamaCppInputNormalizer(),
             LLamaCppOutputNormalizer(),
             completion_handler,
             pipeline_processor,
+            fim_pipeline_processor
         )

     @property
@@ -34,5 +40,6 @@ async def create_completion(
             body = await request.body()
             data = json.loads(body)

-            stream = await self.complete(data, api_key=None)
+            is_fim_request = self._is_fim_request(request, data)
+            stream = await self.complete(data, None, is_fim_request=is_fim_request)
             return self._completion_handler.create_streaming_response(stream)

src/codegate/providers/openai/provider.py

Lines changed: 11 additions & 3 deletions

@@ -1,20 +1,26 @@
 import json
+from typing import Optional

 from fastapi import Header, HTTPException, Request

-from codegate.providers.base import BaseProvider
+from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
 from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
 from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer


 class OpenAIProvider(BaseProvider):
-    def __init__(self, pipeline_processor=None):
+    def __init__(
+        self,
+        pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+        fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None
+    ):
         completion_handler = LiteLLmShim(stream_generator=sse_stream_generator)
         super().__init__(
             OpenAIInputNormalizer(),
             OpenAIOutputNormalizer(),
             completion_handler,
             pipeline_processor,
+            fim_pipeline_processor
         )

     @property
@@ -29,6 +35,7 @@ def _setup_routes(self):
         """

         @self.router.post(f"/{self.provider_route_name}/chat/completions")
+        @self.router.post(f"/{self.provider_route_name}/completions")
         async def create_completion(
             request: Request,
             authorization: str = Header(..., description="Bearer token"),
@@ -40,5 +47,6 @@ async def create_completion(
             body = await request.body()
             data = json.loads(body)

-            stream = await self.complete(data, api_key)
+            is_fim_request = self._is_fim_request(request, data)
+            stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
             return self._completion_handler.create_streaming_response(stream)
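
With the extra route, a client posting to the provider's /completions endpoint is now accepted, and the URL check added in base.py routes such requests through the FIM pipeline. A client-side sketch for illustration only; the host, port, model name, and token are placeholders, and the HTTP client choice is arbitrary:

import httpx  # any HTTP client would do

payload = {
    "model": "gpt-4o-mini",  # placeholder model name
    "messages": [{"role": "user", "content": "<QUERY>print(</QUERY>"}],
    "stream": True,
}
response = httpx.post(
    "http://localhost:8989/openai/completions",  # assumed CodeGate address
    headers={"Authorization": "Bearer sk-placeholder"},
    json=payload,
)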

src/codegate/server.py

Lines changed: 17 additions & 4 deletions

@@ -21,15 +21,28 @@ def init_app() -> FastAPI:
     steps: List[PipelineStep] = [
         CodegateVersion(),
     ]
-
+    # Leaving the pipeline empty for now
+    fim_steps: List[PipelineStep] = [
+    ]
     pipeline = SequentialPipelineProcessor(steps)
+    fim_pipeline = SequentialPipelineProcessor(fim_steps)
+
     # Create provider registry
     registry = ProviderRegistry(app)

     # Register all known providers
-    registry.add_provider("openai", OpenAIProvider(pipeline_processor=pipeline))
-    registry.add_provider("anthropic", AnthropicProvider(pipeline_processor=pipeline))
-    registry.add_provider("llamacpp", LlamaCppProvider(pipeline_processor=pipeline))
+    registry.add_provider("openai", OpenAIProvider(
+        pipeline_processor=pipeline,
+        fim_pipeline_processor=fim_pipeline
+    ))
+    registry.add_provider("anthropic", AnthropicProvider(
+        pipeline_processor=pipeline,
+        fim_pipeline_processor=fim_pipeline
+    ))
+    registry.add_provider("llamacpp", LlamaCppProvider(
+        pipeline_processor=pipeline,
+        fim_pipeline_processor=fim_pipeline
+    ))

     # Create and add system routes
     system_router = APIRouter(tags=["System"])  # Tags group endpoints in the docs
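
Once FIM-specific steps land, the FIM pipeline could be populated the same way as the chat pipeline. A sketch, assuming the new SecretAnalyzer step is importable from wherever its module ends up (the file path is not shown in this diff):

fim_steps: List[PipelineStep] = [
    SecretAnalyzer(),  # hypothetical wiring; the import path is an assumption
]
fim_pipeline = SequentialPipelineProcessor(fim_steps)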
