diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py index 27dbcce3..e22874a6 100644 --- a/src/codegate/pipeline/codegate_context_retriever/codegate.py +++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py @@ -38,18 +38,25 @@ def name(self) -> str: """ return "codegate-context-retriever" - def generate_context_str(self, objects: list[object], context: PipelineContext) -> str: + def generate_context_str( + self, objects: list[object], context: PipelineContext, snippet_map: dict + ) -> str: context_str = "" matched_packages = [] for obj in objects: # The object is already a dictionary with 'properties' package_obj = obj["properties"] # type: ignore matched_packages.append(f"{package_obj['name']} ({package_obj['type']})") + + # Retrieve the related snippet if it exists + code_snippet = snippet_map.get(package_obj["name"]) + # Add one alert for each package found context.add_alert( self.name, trigger_string=json.dumps(package_obj), severity_category=AlertSeverity.CRITICAL, + code_snippet=code_snippet, ) package_str = generate_vector_string(package_obj) context_str += package_str + "\n" @@ -80,14 +87,18 @@ async def process( # noqa: C901 snippets = extractor.extract_snippets(user_message) bad_snippet_packages = [] - if len(snippets) > 0: + snippet_map = {} + if snippets and len(snippets) > 0: snippet_language = snippets[0].language # Collect all packages referenced in the snippets snippet_packages = [] for snippet in snippets: - snippet_packages.extend( - PackageExtractor.extract_packages(snippet.code, snippet.language) # type: ignore + extracted_packages = PackageExtractor.extract_packages( + snippet.code, snippet.language ) + snippet_packages.extend(extracted_packages) + for package in extracted_packages: + snippet_map[package] = snippet logger.info( f"Found {len(snippet_packages)} packages " @@ -127,7 +138,7 @@ async def process( # noqa: C901 return PipelineResult(request=request, context=context) else: # Add context for bad packages - context_str = self.generate_context_str(all_bad_packages, context) + context_str = self.generate_context_str(all_bad_packages, context, snippet_map) context.bad_packages_found = True # Make a copy of the request