Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit e890340

Browse files
authored
Merge pull request #318 from stacklok/inspect-all-user-messages-for-malicious-pkgs
Inspect all user messages for malicious packages
2 parents 36e7fb3 + 16f240b commit e890340

File tree

4 files changed

+46
-46
lines changed

4 files changed

+46
-46
lines changed

prompts/default.yaml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,20 @@ default_chat: |
1919
if no malicious, deprecated or archived package is detected.
2020
2121
If no malicious, deprecated or archived packages are detected, respond to the user request using your prior
22-
knowledge. You always begin your responses with:
23-
"**CodeGate Security Analysis**"
24-
Structure your responses to emphasize security considerations, focusing on:
22+
knowledge.
23+
24+
If the user requests security review, perform the review and structure your responses to emphasize
25+
security considerations, focusing on:
2526
1. Security vulnerabilities and potential risks
2627
2. Best practices for secure implementation
2728
3. Recommendations for security improvements
2829
4. References to relevant security standards or guidelines
2930
31+
You always begin your responses for security review with:
32+
"**CodeGate Security Analysis**"
33+
34+
If you see a string that begins with REDACTED word, DO NOT MODIFY THAT STRING while responding back.
35+
3036
lookup_packages: |
3137
You are a software expert with knowledge of packages from various ecosystems.
3238
Your job is to extract any software packages from user's request.

src/codegate/pipeline/base.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -178,29 +178,30 @@ def get_last_user_message(
178178
for i in reversed(range(len(request["messages"]))):
179179
if request["messages"][i]["role"] == "user":
180180
content = request["messages"][i]["content"]
181+
return content, i
181182

182-
# This is really another LiteLLM weirdness. Depending on the
183-
# provider inside the ChatCompletionRequest you might either
184-
# have a string or a list of Union, one of which is a
185-
# ChatCompletionTextObject. We'll handle this better by
186-
# either dumping litellm completely or converting to a more sane
187-
# format # in our own adapter
183+
return None
188184

189-
# Handle string content
190-
if isinstance(content, str):
191-
return content, i
185+
@staticmethod
186+
def get_last_user_message_idx(request: ChatCompletionRequest) -> int:
187+
if request.get("messages") is None:
188+
return -1
192189

193-
# Handle iterable of ChatCompletionTextObject
194-
if isinstance(content, (list, tuple)):
195-
# Find first text content
196-
for item in content:
197-
if isinstance(item, dict) and item.get("type") == "text":
198-
return item["text"], i
190+
for idx, message in reversed(list(enumerate(request['messages']))):
191+
if message.get("role", "") == "user":
192+
return idx
199193

200-
# If no text content found, return None
201-
return None
194+
return -1
202195

203-
return None
196+
@staticmethod
197+
def get_all_user_messages(request: ChatCompletionRequest) -> str:
198+
all_user_messages = ""
199+
200+
for message in request.get("messages", []):
201+
if message["role"] == "user":
202+
all_user_messages += "\n" + message["content"]
203+
204+
return all_user_messages
204205

205206
@abstractmethod
206207
async def process(

src/codegate/pipeline/codegate_context_retriever/codegate.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,16 @@ async def process(
9393
Use RAG DB to add context to the user request
9494
"""
9595

96-
# Get the last user message
97-
last_user_message = self.get_last_user_message(request)
96+
# Get all user messages
97+
user_messages = self.get_all_user_messages(request)
9898

99-
# Nothing to do if the last user message is none
100-
if last_user_message is None:
99+
# Nothing to do if the user_messages string is empty
100+
if len(user_messages) == 0:
101101
return PipelineResult(request=request)
102102

103103
# Extract packages from the user message
104-
last_user_message_str, last_user_idx = last_user_message
105-
ecosystem = await self.__lookup_ecosystem(last_user_message_str, context)
106-
packages = await self.__lookup_packages(last_user_message_str, context)
104+
ecosystem = await self.__lookup_ecosystem(user_messages, context)
105+
packages = await self.__lookup_packages(user_messages, context)
107106
packages = [pkg.lower() for pkg in packages]
108107

109108
# If user message does not reference any packages, then just return
@@ -112,7 +111,7 @@ async def process(
112111

113112
# Look for matches in vector DB using list of packages as filter
114113
searched_objects = await self.get_objects_from_search(
115-
last_user_message_str, ecosystem, packages
114+
user_messages, ecosystem, packages
116115
)
117116

118117
logger.info(
@@ -136,24 +135,18 @@ async def process(
136135
else:
137136
context_str = "Codegate did not find any malicious or archived packages."
138137

138+
last_user_idx = self.get_last_user_message_idx(request)
139+
if last_user_idx == -1:
140+
return PipelineResult(request=request, context=context)
141+
139142
# Make a copy of the request
140143
new_request = request.copy()
141144

142145
# Add the context to the last user message
143-
# Format: "Context: {context_str} \n Query: {last user message conent}"
144-
# Handle the two cases: (a) message content is str, (b)message content
145-
# is list
146+
# Format: "Context: {context_str} \n Query: {last user message content}"
146147
message = new_request["messages"][last_user_idx]
147-
if isinstance(message["content"], str):
148-
context_msg = f'Context: {context_str} \n\n Query: {message["content"]}'
149-
message["content"] = context_msg
150-
elif isinstance(message["content"], (list, tuple)):
151-
for item in message["content"]:
152-
if isinstance(item, dict) and item.get("type") == "text":
153-
context_msg = f'Context: {context_str} \n\n Query: {item["text"]}'
154-
item["text"] = context_msg
155-
156-
return PipelineResult(request=new_request, context=context)
157-
158-
# Fall through
159-
return PipelineResult(request=request, context=context)
148+
context_msg = f'Context: {context_str} \n\n Query: {message["content"]}'
149+
message["content"] = context_msg
150+
151+
return PipelineResult(request=new_request, context=context)
152+

src/codegate/providers/copilot/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class CopilotPipeline(ABC):
2121
def __init__(self, pipeline_factory):
2222
self.pipeline_factory = pipeline_factory
2323
self.normalizer = self._create_normalizer()
24-
self.provider_name = "copilot"
24+
self.provider_name = "openai"
2525

2626
@abstractmethod
2727
def _create_normalizer(self):

0 commit comments

Comments
 (0)