Skip to content

Commit 7c63ddf

Browse files
authored
feat: add code snippet for malicious packages (#1146)
When generating an alert for a malicious package, we can associate it with the code snippet that referenced it and provide this information to the user. Related-to: #423
1 parent 66b19fc commit 7c63ddf

File tree

1 file changed

+16
-5
lines changed
  • src/codegate/pipeline/codegate_context_retriever

1 file changed

+16
-5
lines changed

src/codegate/pipeline/codegate_context_retriever/codegate.py

+16-5
Original file line number | Diff line number | Diff line change
@@ -38,18 +38,25 @@ def name(self) -> str:
3838
"""
3939
return "codegate-context-retriever"
4040

41-
def generate_context_str(self, objects: list[object], context: PipelineContext) -> str:
41+
def generate_context_str(
42+
self, objects: list[object], context: PipelineContext, snippet_map: dict
43+
) -> str:
4244
context_str = ""
4345
matched_packages = []
4446
for obj in objects:
4547
# The object is already a dictionary with 'properties'
4648
package_obj = obj["properties"] # type: ignore
4749
matched_packages.append(f"{package_obj['name']} ({package_obj['type']})")
50+
51+
# Retrieve the related snippet if it exists
52+
code_snippet = snippet_map.get(package_obj["name"])
53+
4854
# Add one alert for each package found
4955
context.add_alert(
5056
self.name,
5157
trigger_string=json.dumps(package_obj),
5258
severity_category=AlertSeverity.CRITICAL,
59+
code_snippet=code_snippet,
5360
)
5461
package_str = generate_vector_string(package_obj)
5562
context_str += package_str + "\n"
@@ -80,14 +87,18 @@ async def process( # noqa: C901
8087
snippets = extractor.extract_snippets(user_message)
8188

8289
bad_snippet_packages = []
83-
if len(snippets) > 0:
90+
snippet_map = {}
91+
if snippets and len(snippets) > 0:
8492
snippet_language = snippets[0].language
8593
# Collect all packages referenced in the snippets
8694
snippet_packages = []
8795
for snippet in snippets:
88-
snippet_packages.extend(
89-
PackageExtractor.extract_packages(snippet.code, snippet.language) # type: ignore
96+
extracted_packages = PackageExtractor.extract_packages(
97+
snippet.code, snippet.language
9098
)
99+
snippet_packages.extend(extracted_packages)
100+
for package in extracted_packages:
101+
snippet_map[package] = snippet
91102

92103
logger.info(
93104
f"Found {len(snippet_packages)} packages "
@@ -127,7 +138,7 @@ async def process( # noqa: C901
127138
return PipelineResult(request=request, context=context)
128139
else:
129140
# Add context for bad packages
130-
context_str = self.generate_context_str(all_bad_packages, context)
141+
context_str = self.generate_context_str(all_bad_packages, context, snippet_map)
131142
context.bad_packages_found = True
132143

133144
# Make a copy of the request

0 commit comments

Comments
 (0)