@@ -59,44 +59,53 @@ async def process(
59
59
"""
60
60
Use RAG DB to add context to the user request
61
61
"""
62
- # Get the latest user messages
63
- user_messages = self .get_latest_user_messages (request )
64
-
65
- # Nothing to do if the user_messages string is empty
66
- if len (user_messages ) == 0 :
62
+ # Get the latest user message
63
+ user_message = self .get_last_user_message_block (request )
64
+ if not user_message :
67
65
return PipelineResult (request = request )
68
66
69
67
# Create storage engine object
70
68
storage_engine = StorageEngine ()
71
69
72
70
# Extract any code snippets
73
- snippets = extract_snippets (user_messages )
71
+ snippets = extract_snippets (user_message )
74
72
75
73
bad_snippet_packages = []
76
74
if len (snippets ) > 0 :
75
+ snippet_language = snippets [0 ].language
77
76
# Collect all packages referenced in the snippets
78
77
snippet_packages = []
79
78
for snippet in snippets :
80
79
snippet_packages .extend (
81
- PackageExtractor .extract_packages (snippet .code , snippet .language )
80
+ PackageExtractor .extract_packages (snippet .code , snippet .language ) # type: ignore
82
81
)
83
- logger .info (f"Found { len (snippet_packages )} packages in code snippets." )
84
82
83
+ logger .info (
84
+ f"Found { len (snippet_packages )} packages "
85
+ f"for language { snippet_language } in code snippets."
86
+ )
85
87
# Find bad packages in the snippets
86
88
bad_snippet_packages = await storage_engine .search (
87
- language = snippets [ 0 ]. language , packages = snippet_packages
88
- )
89
+ language = snippet_language , packages = snippet_packages
90
+ ) # type: ignore
89
91
logger .info (f"Found { len (bad_snippet_packages )} bad packages in code snippets." )
90
92
91
93
# Remove code snippets from the user messages and search for bad packages
92
94
# in the rest of the user query/messages
93
- user_messages = re .sub (r"```.*?```" , "" , user_messages , flags = re .DOTALL )
94
-
95
- # Vector search to find bad packages
96
- bad_packages = await storage_engine .search (query = user_messages , distance = 0.5 , limit = 100 )
95
+ user_messages = re .sub (r"```.*?```" , "" , user_message , flags = re .DOTALL )
96
+ user_messages = re .sub (r"⋮...*?⋮...\n\n" , "" , user_messages , flags = re .DOTALL )
97
+
98
+ # Split the message on double newlines to avoid passing too much content to a single search
99
+ split_messages = user_messages .split ("\n \n " )
100
+ collected_bad_packages = []
101
+ for item_message in split_messages :
102
+ # Vector search to find bad packages
103
+ bad_packages = await storage_engine .search (query = item_message , distance = 0.5 , limit = 100 )
104
+ if bad_packages and len (bad_packages ) > 0 :
105
+ collected_bad_packages .extend (bad_packages )
97
106
98
107
# All bad packages
99
- all_bad_packages = bad_snippet_packages + bad_packages
108
+ all_bad_packages = bad_snippet_packages + collected_bad_packages
100
109
101
110
logger .info (f"Adding { len (all_bad_packages )} bad packages to the context." )
102
111
@@ -119,7 +128,7 @@ async def process(
119
128
# Add the context to the last user message
120
129
# Format: "Context: {context_str} \n\n Query: {last user message content}"
121
130
message = new_request ["messages" ][last_user_idx ]
122
- context_msg = f'Context: { context_str } \n \n Query: { message ["content" ]} '
131
+ context_msg = f'Context: { context_str } \n \n Query: { message ["content" ]} ' # type: ignore
123
132
message ["content" ] = context_msg
124
133
125
134
logger .debug ("Final context message" , context_message = context_msg )
0 commit comments