
Commit 83a968f

Merge pull request #58 from grace-sng7/augmented_suggester
2 parents 0d1c2b5 + 00b61a2 commit 83a968f

6 files changed: +63 −61 lines

docs/notebooks/augmented_model_suggester_examples.ipynb

Lines changed: 18 additions & 0 deletions
@@ -53,6 +53,15 @@
       "execution_count": null,
       "outputs": []
      },
+     {
+      "cell_type": "markdown",
+      "source": [
+       "Here we introduce the AugmentedModelSuggester class. Creating an instance of it lets the chosen LLM use Retrieval Augmented Generation (RAG) to assess causality. It currently does this by searching the CauseNet dataset for a relevant causal pair and augmenting the LLM with the corresponding evidence stored in CauseNet."
+      ],
+      "metadata": {
+       "id": "DjYECuX84vbN"
+      }
+     },
      {
       "cell_type": "code",
       "source": [
@@ -66,6 +75,15 @@
       "execution_count": null,
       "outputs": []
      },
+     {
+      "cell_type": "markdown",
+      "source": [
+       "AugmentedModelSuggester suggests the pairwise causal relationship between two variables. If a relevant causal pair is found in CauseNet, the LLM is augmented with the corresponding CauseNet evidence; if none is found, the LLM falls back on its own knowledge by default."
+      ],
+      "metadata": {
+       "id": "dES0LwHV57eX"
+      }
+     },
      {
       "cell_type": "code",
       "source": [

poetry.lock

Lines changed: 1 addition & 18 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -60,7 +60,6 @@ langchain-chroma = ">=0.2.4"
 langchain-community = ">=0.3.24"
 langchain-core = ">=0.3.60"
 langchain-huggingface = ">=0.2.0"
-langchain-openai = ">=0.3.17"
 rank-bm25 = ">=0.2.2"
 sentence-transformers = ">=4.1.0"

pywhyllm/suggesters/augmented_model_suggester.py

Lines changed: 33 additions & 3 deletions
@@ -2,12 +2,29 @@
 import re
 
 from .simple_model_suggester import SimpleModelSuggester
-from pywhyllm.utils.data_loader import *
+from pywhyllm.utils.data_loader import download_causenet, load_causenet_json, create_causenet_dict
 from pywhyllm.utils.augmented_model_suggester_utils import *
 
 
 class AugmentedModelSuggester(SimpleModelSuggester):
+    """
+    A class that extends SimpleModelSuggester and currently provides methods for suggesting causal relationships between variables by leveraging the CauseNet dataset for Retrieval Augmented Generation (RAG).
+
+    Methods:
+    - suggest_pairwise_relationship(variable1: str, variable2: str) -> List[str]:
+        Suggests the causal relationship between two variables and returns a list containing the cause, the effect, and a description of the relationship.
+    """
+
     def __init__(self, llm, file_path: str = 'data/causenet-precision.jsonl.bz2'):
+        """
+        Initialize the AugmentedModelSuggester with a language model and download CauseNet data.
+
+        Args:
+            llm: The language model instance to be used for querying.
+            file_path (str, optional): Path to save the downloaded CauseNet JSONL file.
+                Defaults to 'data/causenet-precision.jsonl.bz2'.
+        """
+
         super().__init__(llm)
         self.file_path = file_path
 
@@ -23,13 +40,26 @@ def __init__(self, llm, file_path: str = 'data/causenet-precision.jsonl.bz2'):
             print("Download failed")
 
     def suggest_pairwise_relationship(self, variable1: str, variable2: str):
+        """
+        Suggests a cause-and-effect relationship between two variables, leveraging the CauseNet dataset for Retrieval Augmented Generation (RAG).
+        If a relevant causal pair is found in CauseNet, the LLM is augmented with the corresponding information about the relationship stored
+        in CauseNet; if none is found, the LLM falls back on its own knowledge by default.
+
+        Args:
+            variable1 (str): The name of the first variable.
+            variable2 (str): The name of the second variable.
+
+        Returns:
+            list: A list containing the suggested cause variable, the suggested effect variable, and a description of the reasoning behind the suggestion. If there is no relationship between the two variables, the first two elements will be None.
+        """
+
         result = find_top_match_in_causenet(self.causenet_dict, variable1, variable2)
         if result:
             source_text = get_source_text(result)
             retriever = split_data_and_create_vectorstore_retriever(source_text)
-            response = query_llm(variable1, variable2, source_text, retriever)
+            response = query_llm(self.llm, variable1, variable2, source_text, retriever)
         else:
-            response = query_llm(variable1, variable2)
+            response = query_llm(self.llm, variable1, variable2)
 
         answer = re.findall(r'<answer>(.*?)</answer>', response)
         answer = [ans.strip() for ans in answer]
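
The substantive change in this file is dependency injection: query_llm no longer builds its own ChatOpenAI (see the utils diff below) but receives self.llm from the suggester. A sketch of the helper's new call shape; the backend and variable names are illustrative assumptions:

```python
# Sketch of the new query_llm signature; ChatOpenAI and the
# "exercise"/"stress" pair are assumptions for illustration.
from langchain_openai import ChatOpenAI

from pywhyllm.utils.augmented_model_suggester_utils import query_llm

llm = ChatOpenAI(model="gpt-4")  # any LangChain chat model can be injected

# With no source_text/retriever, the helper skips RAG and the model
# answers the A/B/C causal question from its own knowledge.
response = query_llm(llm, "exercise", "stress")
```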

pywhyllm/utils/augmented_model_suggester_utils.py

Lines changed: 4 additions & 27 deletions
@@ -3,7 +3,6 @@
 from langchain_core.documents import Document
 from langchain_chroma import Chroma
 from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_openai import ChatOpenAI
 from langchain_core.prompts import ChatPromptTemplate
 from langchain.chains import create_retrieval_chain
 from langchain.chains.combine_documents import create_stuff_documents_chain
@@ -13,49 +12,39 @@
 
 
 def find_top_match_in_causenet(causenet_dict, variable1, variable2, threshold=0.7):
-    # Sample dictionary
     pair_strings = [
         f"{causenet_dict[key]['causal_relation']['cause']}-{causenet_dict[key]['causal_relation']['effect']}"
         for key in causenet_dict]
 
-    # Tokenize for BM25
     tokenized_pairs = [text.split() for text in pair_strings]
     bm25 = BM25Okapi(tokenized_pairs)
 
-    # Original and reverse queries
     query = variable1 + "-" + variable2
     reverse_query = variable2 + "-" + variable1
     tokenized_query = query.split()
     tokenized_reverse_query = reverse_query.split()
 
-    # Combine tokens from both queries (remove duplicates)
     combined_query = list(set(tokenized_query + tokenized_reverse_query))
 
-    # Get top-k candidates using BM25 with combined query
     k = 5
     scores = bm25.get_scores(combined_query)
     top_k_indices = np.argsort(scores)[::-1][:k]
     candidate_pairs = [pair_strings[i] for i in top_k_indices]
 
-    # Apply SBERT to candidates
     model = SentenceTransformer('all-MiniLM-L6-v2')
     query_embedding = model.encode(query, convert_to_tensor=True)
     reverse_query_embedding = model.encode(reverse_query, convert_to_tensor=True)
     candidate_embeddings = model.encode(candidate_pairs, convert_to_tensor=True)
 
-    # Compute similarities for both original and reverse queries
     similarities = util.cos_sim(query_embedding, candidate_embeddings).flatten()
     reverse_similarities = util.cos_sim(reverse_query_embedding, candidate_embeddings).flatten()
 
-    # Take the maximum similarity for each candidate (original or reverse)
     max_similarities = np.maximum(similarities, reverse_similarities)
 
-    # Get the top match and its similarity score
     top_idx = np.argmax(max_similarities)
     top_similarity = max_similarities[top_idx]
     top_pair = candidate_pairs[top_idx]
 
-    # Check if the top similarity meets the threshold
     if top_similarity >= threshold:
         print(f"Best match: {top_pair} (Similarity: {top_similarity:.4f})")
         return causenet_dict[top_pair]
@@ -77,36 +66,29 @@ def get_source_text(causenet_query_result):
 
 def split_data_and_create_vectorstore_retriever(source_text):
     document = Document(page_content=source_text)
 
-    # Initialize the text splitter
     text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=100,  # Adjust chunk size as needed
-        chunk_overlap=20  # Overlap for context
+        chunk_size=100,
+        chunk_overlap=20
     )
-    # Split the documents
     splits = text_splitter.split_documents([document])
 
     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
-    # Create a vector store from the document splits
     vectorstore = Chroma.from_documents(
         documents=splits,
         embedding=embeddings,
         persist_directory="./chroma_db"  # Optional: Save to disk for reuse
     )
 
-    # Create a retriever from the vector store
     retriever = vectorstore.as_retriever(
         search_type="similarity",
-        search_kwargs={"k": 5}  # Retrieve top 5 relevant chunks
+        search_kwargs={"k": 5}
     )
 
     return retriever
 
 
-def query_llm(variable1, variable2, source_text=None, retriever=None):
-    # Initialize the language model
-    llm = ChatOpenAI(model="gpt-4")
-
+def query_llm(llm, variable1, variable2, source_text=None, retriever=None):
     if source_text:
         system_prompt = """You are a helpful assistant for causal reasoning.
@@ -116,7 +98,6 @@ def query_llm(variable1, variable2, source_text=None, retriever=None):
         system_prompt = """You are a helpful assistant for causal reasoning.
         """
 
-    # prompt template
     prompt = ChatPromptTemplate.from_messages([
         ("system", system_prompt),
         ("human", "{input}")
@@ -125,12 +106,8 @@
     query = f"""Which cause-and-effect-relationship is more likely? Provide reasoning and you must give your final answer (A, B, or C) in <answer> </answer> tags with the letter only.
     A. {variable1} causes {variable2} B. {variable2} causes {variable1} C. neither {variable1} nor {variable2} cause each other."""
 
-    # Define the system prompt
     if source_text:
-        # Create a document chain to combine retrieved documents
         question_answer_chain = create_stuff_documents_chain(llm, prompt)
-
-        # Create the RAG chain
         rag_chain = create_retrieval_chain(retriever, question_answer_chain)
 
         response = rag_chain.invoke({"input": query})
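
find_top_match_in_causenet expects a dictionary keyed by "cause-effect" strings, the shape create_causenet_dict (next file) produces. A toy call on made-up entries illustrates the two-stage match: BM25 shortlists up to five candidates, SBERT scores both query directions, and the stored entry is returned only when the best cosine similarity clears the 0.7 default threshold:

```python
# Toy illustration with made-up entries; real dictionaries come from
# create_causenet_dict over the CauseNet dump.
from pywhyllm.utils.augmented_model_suggester_utils import find_top_match_in_causenet

toy_dict = {
    "smoking-cancer": {
        "causal_relation": {"cause": "smoking", "effect": "cancer"},
        "sources": [],
    },
    "rain-flooding": {
        "causal_relation": {"cause": "rain", "effect": "flooding"},
        "sources": [],
    },
}

# BM25 shortlist -> SBERT rerank of "smoking-cancer" and "cancer-smoking"
# -> returns toy_dict["smoking-cancer"], whose similarity clears 0.7.
match = find_top_match_in_causenet(toy_dict, "smoking", "cancer")
```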

pywhyllm/utils/data_loader.py

Lines changed: 7 additions & 12 deletions
@@ -20,7 +20,9 @@ def download_causenet(url: str, file_path: str) -> bool:
     International Conference on Information & Knowledge Management (CIKM '20). Association for
     Computing Machinery, New York, NY, USA, 3023–3030. https://doi.org/10.1145/3340531.3412763
 
-    TODO: Add license
+    License:
+        CauseNet data is licensed under the Creative Commons Attribution (CC BY) license.
+        For full license details, see: https://creativecommons.org/licenses/by/4.0/
 
     Args:
         url (str): The URL of the file to download.
@@ -30,21 +32,16 @@
         bool: True if the download was successful, False otherwise.
     """
     try:
-        # Ensure the output directory exists
         os.makedirs(os.path.dirname(file_path), exist_ok=True)
 
-        # Send a GET request to the URL
         response = requests.get(url, stream=True)
 
-        # Check if the request was successful
         if response.status_code != 200:
             logging.error(f"Failed to download file from {url}. Status code: {response.status_code}")
             return False
 
-        # Get the total file size for progress bar (if available)
         total_size = int(response.headers.get("content-length", 0))
 
-        # Download and save the file with a progress bar
         with open(file_path, "wb") as file, tqdm(
             desc="Downloading",
             total=total_size,
@@ -73,12 +70,11 @@ def load_causenet_json(file_path):
     print("Loading CauseNet using json")
     with bz2.open(file_path, 'rt',
                   encoding='utf-8') as file:
-        # Read each line and parse as JSON
         for line in file:
-            line = line.strip()  # Remove trailing newlines
-            if line:  # Skip empty lines
-                json_obj = json.loads(line)  # Parse the line as JSON
-                json_data.append(json_obj)  # Add to list
+            line = line.strip()
+            if line:
+                json_obj = json.loads(line)
+                json_data.append(json_obj)
     print("Done loading CauseNet using json")
     return json_data
 
@@ -97,7 +93,6 @@ def create_causenet_dict(json_data):
                 'sources': item['sources']
             }
         else:
-            # Append sources to existing list
             causenet_dict[key]['sources'].extend(item['sources'])
     print("Done creating dictionary from CauseNet json data")
     return causenet_dict
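
Taken together, the three loaders named in the suggester's now-explicit import form a small pipeline. A sketch; the URL is a placeholder, since this diff does not show the one AugmentedModelSuggester.__init__ actually passes to download_causenet:

```python
# Pipeline sketch; the URL below is a placeholder, not the real source.
from pywhyllm.utils.data_loader import (
    create_causenet_dict, download_causenet, load_causenet_json)

file_path = "data/causenet-precision.jsonl.bz2"
url = "https://example.org/causenet-precision.jsonl.bz2"  # placeholder

if download_causenet(url, file_path):                # streamed, with a tqdm bar
    json_data = load_causenet_json(file_path)        # one JSON object per bz2 line
    causenet_dict = create_causenet_dict(json_data)  # keyed by "cause-effect"
```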
