39 changes: 39 additions & 0 deletions core/example_workflows/talk_to_file_rag_config_workflow.yaml
@@ -0,0 +1,39 @@
{
"max_files": 20,
"llm_config": { "temperature": 0.3, "max_context_tokens": 20000 },
"max_history": 10,
"reranker_config":
{ "model": "rerank-v3.5", "top_n": 10, "supplier": "cohere" },
"workflow_config":
{
"name": "Standard RAG",
"nodes":
[
{
"name": "START",
"edges": ["filter_history"],
"description": "Starting workflow",
},
{
"name": "filter_history",
"edges": ["retrieve"],
"description": "Filtering history",
},
{
"name": "retrieve",
"edges": ["retrieve_full_documents_context"],
"description": "Retrieving relevant information",
},
{
"name": "retrieve_full_documents_context",
"edges": ["generate_zendesk_rag"],
"description": "Retrieving full tickets context",
},
{
"name": "generate_zendesk_rag",
"edges": ["END"],
"description": "Generating answer",
},
],
},
}
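
The config above defines a linear node graph: START → filter_history → retrieve → retrieve_full_documents_context → generate_zendesk_rag → END. As a minimal sketch of how such a workflow_config mapping could be walked once parsed (this is not quivr's actual config loader; it only assumes the name/edges shape shown in the file):

# Hypothetical helper: follow the single outgoing edge of each node from START to END.
def linearize_workflow(workflow_config: dict) -> list[str]:
    nodes = {node["name"]: node for node in workflow_config["nodes"]}
    order, current = [], "START"
    while current != "END":
        order.append(current)
        current = nodes[current]["edges"][0]  # this example workflow is strictly linear
    order.append("END")
    return order

# linearize_workflow(parsed["workflow_config"]) ->
# ['START', 'filter_history', 'retrieve', 'retrieve_full_documents_context',
#  'generate_zendesk_rag', 'END']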
32 changes: 29 additions & 3 deletions core/quivr_core/rag/prompts.py
@@ -1,14 +1,14 @@
import datetime
from pydantic import ConfigDict, create_model

from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
PromptTemplate,
SystemMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain_core.prompts.base import BasePromptTemplate
from pydantic import ConfigDict, create_model


class CustomPromptsDict(dict):
@@ -258,6 +258,32 @@ def _define_custom_prompts() -> CustomPromptsDict:

custom_prompts["TOOL_ROUTING_PROMPT"] = TOOL_ROUTING_PROMPT

system_message_zendesk_template = """

- You are a Zendesk Agent.
- You are answering a client query.
- You must provide a response with all the information you have. Do not leave placeholders to be filled in, such as [your name] or [your email].
- Give the most complete answer to the client query and include relevant links if needed.
- Based on the following similar client tickets, provide a response to the client query in the same format.

------ Zendesk Similar Tickets ------
{similar_tickets}
-------------------------------------

------ Client Query ------
{client_query}
--------------------------

Agent:
"""

ZENDESK_TEMPLATE_PROMPT = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_message_zendesk_template),
]
)
custom_prompts["ZENDESK_TEMPLATE_PROMPT"] = ZENDESK_TEMPLATE_PROMPT

return custom_prompts
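
For reference, the new prompt can be rendered the same way generate_zendesk_rag does further below; a small hedged example (the import path and the input values are assumptions for illustration, and the module-level custom_prompts registry is assumed to be built from _define_custom_prompts):

from quivr_core.rag.prompts import custom_prompts  # assumed module-level registry

msg = custom_prompts["ZENDESK_TEMPLATE_PROMPT"].format(
    similar_tickets="Ticket #1234: password reset, resolved by sending the reset link ...",
    client_query="How do I reset my password?",
)
# msg is the filled-in system message as a string, ready to pass to a chat model via llm.invoke(msg)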


85 changes: 85 additions & 0 deletions core/quivr_core/rag/quivr_rag_langgraph.py
@@ -1,5 +1,6 @@
import asyncio
import logging
from collections import OrderedDict
from typing import (
Annotated,
Any,
@@ -13,6 +14,7 @@
TypedDict,
)
from uuid import UUID, uuid4

import openai
from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
@@ -752,6 +754,73 @@ def _sort_docs_by_relevance(self, docs: List[Document]) -> List[Document]:
reverse=True,
)

async def retrieve_full_documents_context(self, state: AgentState) -> AgentState:
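"""Replace the task's retrieved chunks with fuller document context for the top-ranked knowledge entries."""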
if "tasks" in state:
tasks = state["tasks"]
else:
tasks = UserTasks([state["messages"][0].content])

if not tasks.has_tasks():
return {**state}

docs = tasks.docs if tasks else []

relevant_knowledge = {}
for doc in docs:
knowledge_id = doc.metadata["knowledge_id"]
similarity_score = doc.metadata.get("similarity", 0)
if knowledge_id in relevant_knowledge:
relevant_knowledge[knowledge_id]["count"] += 1
relevant_knowledge[knowledge_id]["max_similarity_score"] = max(
relevant_knowledge[knowledge_id]["max_similarity_score"],
similarity_score,
)
relevant_knowledge[knowledge_id]["chunk_index"] = max(
doc.metadata["chunk_index"],
relevant_knowledge[knowledge_id]["chunk_index"],
)
else:
relevant_knowledge[knowledge_id] = {
"count": 1,
"max_similarity_score": similarity_score,
"chunk_index": doc.metadata["chunk_index"],
}

top_n = min(3, len(relevant_knowledge))
# FIXME: Tweak this to return the most relevant knowledge entries
top_knowledge_ids = OrderedDict(
sorted(
relevant_knowledge.items(),
key=lambda x: (
x[1]["max_similarity_score"],
x[1]["count"],
),
reverse=True,
)[:top_n]
)

logger.info(f"Top knowledge IDs: {top_knowledge_ids}")

_docs = []

assert hasattr(
self.vector_store, "get_vectors_by_knowledge_id"
), "Vector store must have method 'get_vectors_by_knowledge_id', this is an enterprise only feature"

for knowledge_id in top_knowledge_ids:
_docs.append(
await self.vector_store.get_vectors_by_knowledge_id( # type: ignore
knowledge_id,
end_index=relevant_knowledge[knowledge_id]["chunk_index"],
)
)

tasks.set_docs(
id=tasks.ids[0], docs=_docs
)  # FIXME: the case of multiple task IDs is not handled.

return {**state, "tasks": tasks}

def get_rag_context_length(self, state: AgentState, docs: List[Document]) -> int:
final_inputs = self._build_rag_prompt_inputs(state, docs)
msg = custom_prompts.RAG_ANSWER_PROMPT.format(**final_inputs)
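
The selection heuristic in retrieve_full_documents_context (group chunks by knowledge_id, keep the best similarity, the chunk count, and the furthest chunk_index, then sort) can be exercised on its own; a minimal sketch with invented metadata, independent of the vector store:

from langchain_core.documents import Document

# Hypothetical retrieved chunks for two knowledge entries.
docs = [
    Document(page_content="a", metadata={"knowledge_id": "kb-1", "similarity": 0.91, "chunk_index": 4}),
    Document(page_content="b", metadata={"knowledge_id": "kb-1", "similarity": 0.80, "chunk_index": 7}),
    Document(page_content="c", metadata={"knowledge_id": "kb-2", "similarity": 0.63, "chunk_index": 2}),
]

stats: dict[str, dict] = {}
for doc in docs:
    entry = stats.setdefault(
        doc.metadata["knowledge_id"],
        {"count": 0, "max_similarity_score": 0.0, "chunk_index": 0},
    )
    entry["count"] += 1
    entry["max_similarity_score"] = max(entry["max_similarity_score"], doc.metadata.get("similarity", 0))
    entry["chunk_index"] = max(entry["chunk_index"], doc.metadata["chunk_index"])

# Same sort key as the node: best similarity first, ties broken by chunk count; keep at most 3.
top = sorted(stats.items(), key=lambda s: (s[1]["max_similarity_score"], s[1]["count"]), reverse=True)[:3]
# -> [('kb-1', {'count': 2, 'max_similarity_score': 0.91, 'chunk_index': 7}),
#     ('kb-2', {'count': 1, 'max_similarity_score': 0.63, 'chunk_index': 2})]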
@@ -836,6 +905,22 @@ def bind_tools_to_llm(self, node_name: str):
return self.llm_endpoint._llm.bind_tools(tools, tool_choice="any")
return self.llm_endpoint._llm

def generate_zendesk_rag(self, state: AgentState) -> AgentState:
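"""Format the retrieved tickets and the client query with ZENDESK_TEMPLATE_PROMPT and invoke the LLM."""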
tasks = state["tasks"]
docs = tasks.docs if tasks else []
messages = state["messages"]
user_task = messages[0].content
inputs = {
"similar_tickets": docs,
"client_query": user_task,
}

msg = custom_prompts.ZENDESK_TEMPLATE_PROMPT.format(**inputs)
llm = self.bind_tools_to_llm(self.generate_zendesk_rag.__name__)
response = llm.invoke(msg)

return {**state, "messages": [response]}

def generate_rag(self, state: AgentState) -> AgentState:
tasks = state["tasks"]
docs = tasks.docs if tasks else []
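
Putting the pieces together, the example workflow config maps onto a langgraph StateGraph whose nodes are the class methods above. The real graph is built dynamically from workflow_config rather than hard-coded, so the following is only a hedged sketch with placeholder nodes standing in for the real methods:

from typing import Annotated, TypedDict

from langchain_core.messages import AnyMessage
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages


class DemoState(TypedDict):
    # Stand-in state; the real AgentState also carries tasks, chat history, etc.
    messages: Annotated[list[AnyMessage], add_messages]


async def passthrough(state: DemoState) -> DemoState:
    # Placeholder node; in quivr these would be filter_history, retrieve,
    # retrieve_full_documents_context and generate_zendesk_rag.
    return state


builder = StateGraph(DemoState)
for name in [
    "filter_history",
    "retrieve",
    "retrieve_full_documents_context",
    "generate_zendesk_rag",
]:
    builder.add_node(name, passthrough)

builder.add_edge(START, "filter_history")
builder.add_edge("filter_history", "retrieve")
builder.add_edge("retrieve", "retrieve_full_documents_context")
builder.add_edge("retrieve_full_documents_context", "generate_zendesk_rag")
builder.add_edge("generate_zendesk_rag", END)

graph = builder.compile()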