Skip to content

Commit 0444f46

Browse files
authored
Update deprecated chains to LCEL (#198)
* Update chains
* Format
1 parent 8f5b9c6 commit 0444f46

File tree

5 files changed

+56
-48
lines changed

5 files changed

+56
-48
lines changed

api.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,7 @@ def qstream(question: Question = Depends()):
128128
q = Queue()
129129

130130
def cb():
131-
output_function(
132-
{"question": question.text, "chat_history": []},
133-
callbacks=[QueueCallback(q)],
134-
)
131+
output_function.invoke(question.text, config={"callbacks": [QueueCallback(q)]})
135132

136133
def generate():
137134
yield json.dumps({"init": True, "model": llm_name})
@@ -146,9 +143,7 @@ async def ask(question: Question = Depends()):
146143
output_function = llm_chain
147144
if question.rag:
148145
output_function = rag_chain
149-
result = output_function(
150-
{"question": question.text, "chat_history": []}, callbacks=[]
151-
)
146+
result = output_function.invoke(question.text)
152147

153148
return {"result": result["answer"], "model": llm_name}
154149

bot.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ def chat_input():
9292
with st.chat_message("assistant"):
9393
st.caption(f"RAG: {name}")
9494
stream_handler = StreamHandler(st.empty())
95-
result = output_function(
96-
{"question": user_input, "chat_history": []}, callbacks=[stream_handler]
97-
)["answer"]
98-
output = result
95+
output = output_function.invoke(
96+
user_input, config={"callbacks": [stream_handler]}
97+
)
98+
9999
st.session_state[f"user_input"].append(user_input)
100100
st.session_state[f"generated"].append(output)
101101
st.session_state[f"rag_mode"].append(name)

chains.py

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
from langchain_openai import OpenAIEmbeddings
32
from langchain_ollama import OllamaEmbeddings
43
from langchain_aws import BedrockEmbeddings
@@ -10,17 +9,17 @@
109

1110
from langchain_neo4j import Neo4jVector
1211

13-
from langchain.chains import RetrievalQAWithSourcesChain
14-
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
12+
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
13+
from langchain_core.output_parsers import StrOutputParser
1514

1615
from langchain.prompts import (
1716
ChatPromptTemplate,
1817
HumanMessagePromptTemplate,
19-
SystemMessagePromptTemplate
18+
SystemMessagePromptTemplate,
2019
)
2120

2221
from typing import List, Any
23-
from utils import BaseLogger, extract_title_and_question
22+
from utils import BaseLogger, extract_title_and_question, format_docs
2423
from langchain_google_genai import GoogleGenerativeAIEmbeddings
2524

2625
AWS_MODELS = (
@@ -32,6 +31,7 @@
3231
"mistral.mi",
3332
)
3433

34+
3535
def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
3636
if embedding_model_name == "ollama":
3737
embeddings = OllamaEmbeddings(
@@ -47,10 +47,8 @@ def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config=
4747
embeddings = BedrockEmbeddings()
4848
dimension = 1536
4949
logger.info("Embedding: Using AWS")
50-
elif embedding_model_name == "google-genai-embedding-001":
51-
embeddings = GoogleGenerativeAIEmbeddings(
52-
model="models/embedding-001"
53-
)
50+
elif embedding_model_name == "google-genai-embedding-001":
51+
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
5452
dimension = 768
5553
logger.info("Embedding: Using Google Generative AI Embeddings")
5654
else:
@@ -112,17 +110,8 @@ def configure_llm_only_chain(llm):
112110
chat_prompt = ChatPromptTemplate.from_messages(
113111
[system_message_prompt, human_message_prompt]
114112
)
115-
116-
def generate_llm_output(
117-
user_input: str, callbacks: List[Any], prompt=chat_prompt
118-
) -> str:
119-
chain = prompt | llm
120-
answer = chain.invoke(
121-
{"question": user_input}, config={"callbacks": callbacks}
122-
).content
123-
return {"answer": answer}
124-
125-
return generate_llm_output
113+
chain = chat_prompt | llm | StrOutputParser()
114+
return chain
126115

127116

128117
def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, password):
@@ -152,12 +141,6 @@ def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, pass
152141
]
153142
qa_prompt = ChatPromptTemplate.from_messages(messages)
154143

155-
qa_chain = load_qa_with_sources_chain(
156-
llm,
157-
chain_type="stuff",
158-
prompt=qa_prompt,
159-
)
160-
161144
# Vector + Knowledge Graph response
162145
kg = Neo4jVector.from_existing_index(
163146
embedding=embeddings,
@@ -183,12 +166,16 @@ def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, pass
183166
ORDER BY similarity ASC // so that best answers are the last
184167
""",
185168
)
186-
187-
kg_qa = RetrievalQAWithSourcesChain(
188-
combine_documents_chain=qa_chain,
189-
retriever=kg.as_retriever(search_kwargs={"k": 2}),
190-
reduce_k_below_max_tokens=False,
191-
max_tokens_limit=3375,
169+
kg_qa = (
170+
RunnableParallel(
171+
{
172+
"summaries": kg.as_retriever(search_kwargs={"k": 2}) | format_docs,
173+
"question": RunnablePassthrough(),
174+
}
175+
)
176+
| qa_prompt
177+
| llm
178+
| StrOutputParser()
192179
)
193180
return kg_qa
194181

pdf_bot.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
import os
22

33
import streamlit as st
4-
from langchain.chains import RetrievalQA
54
from PyPDF2 import PdfReader
65
from langchain.callbacks.base import BaseCallbackHandler
76
from langchain.text_splitter import RecursiveCharacterTextSplitter
7+
from langchain.prompts import ChatPromptTemplate
88
from langchain_neo4j import Neo4jVector
99
from streamlit.logger import get_logger
1010
from chains import (
1111
load_embedding_model,
1212
load_llm,
1313
)
14+
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
15+
from langchain_core.output_parsers import StrOutputParser
16+
from utils import format_docs
1417

1518
# load api key lib
1619
from dotenv import load_dotenv
@@ -67,6 +70,14 @@ def main():
6770
)
6871

6972
chunks = text_splitter.split_text(text=text)
73+
qa_prompt = ChatPromptTemplate.from_messages(
74+
[
75+
(
76+
"human",
77+
"Based on the provided summary: {summaries} \n Answer the following question:{question}",
78+
)
79+
]
80+
)
7081

7182
# Store the chunks part in db (vector)
7283
vectorstore = Neo4jVector.from_texts(
@@ -79,16 +90,25 @@ def main():
7990
node_label="PdfBotChunk",
8091
pre_delete_collection=True, # Delete existing PDF data
8192
)
82-
qa = RetrievalQA.from_chain_type(
83-
llm=llm, chain_type="stuff", retriever=vectorstore.as_retriever()
93+
qa = (
94+
RunnableParallel(
95+
{
96+
"summaries": vectorstore.as_retriever(search_kwargs={"k": 2})
97+
| format_docs,
98+
"question": RunnablePassthrough(),
99+
}
100+
)
101+
| qa_prompt
102+
| llm
103+
| StrOutputParser()
84104
)
85105

86106
# Accept user questions/query
87107
query = st.text_input("Ask questions about your PDF file")
88108

89109
if query:
90110
stream_handler = StreamHandler(st.empty())
91-
qa.run(query, callbacks=[stream_handler])
111+
qa.invoke(query, {"callbacks": [stream_handler]})
92112

93113

94114
if __name__ == "__main__":

utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def create_vector_index(driver) -> None:
3232
driver.query(index_query)
3333
except: # Already exists
3434
pass
35-
index_query = "CREATE VECTOR INDEX top_answers IF NOT EXISTS FOR (m:Answer) ON m.embedding"
35+
index_query = (
36+
"CREATE VECTOR INDEX top_answers IF NOT EXISTS FOR (m:Answer) ON m.embedding"
37+
)
3638
try:
3739
driver.query(index_query)
3840
except: # Already exists
@@ -52,3 +54,7 @@ def create_constraints(driver):
5254
driver.query(
5355
"CREATE CONSTRAINT tag_name IF NOT EXISTS FOR (t:Tag) REQUIRE (t.name) IS UNIQUE"
5456
)
57+
58+
59+
def format_docs(docs):
60+
return "\n\n".join(doc.page_content for doc in docs)

0 commit comments

Comments (0)