TXT2KG refactor v2 config + vanillaRAG #10252

Merged
merged 6 commits on May 7, 2025
Changes from 4 commits
119 changes: 76 additions & 43 deletions examples/llm/txt2kg_rag.py
@@ -11,6 +11,8 @@
from glob import glob
from itertools import chain

import yaml

try:
import wandb
wandb_available = True
@@ -62,7 +64,6 @@
EVAL_BATCH_SIZE_DEFAULT = 2
LLM_GEN_MODE_DEFAULT = "full"
DEFAULT_ENDPOINT_URL = "https://integrate.api.nvidia.com/v1"
DOC_CHUNK_SIZE_DEFAULT = 8192


def parse_args():
@@ -98,21 +99,16 @@ def parse_args():
parser.add_argument('--llm_generator_name', type=str,
default=LLM_GENERATOR_NAME_DEFAULT,
help="The LLM to use for Generation")
parser.add_argument(
'--doc_chunk_size', type=int, default=DOC_CHUNK_SIZE_DEFAULT,
help="The chunk size to use VectorRAG (document retrieval)")
parser.add_argument(
'--llm_generator_mode', type=str, default=LLM_GEN_MODE_DEFAULT,
choices=["frozen", "lora",
"full"], help="Whether to freeze the Generator LLM,\
use LORA, or fully finetune")
parser.add_argument('--dont_save_model', action="store_true",
help="Whether to skip model saving.")
parser.add_argument('--k_for_docs', type=int, default=2,
help="Number of docs to retrieve for each question.")
parser.add_argument('--log_steps', type=int, default=30,
help="Log to wandb every N steps")
parser.add_argument('--wandb_project', type=str, default="tech-qa",
parser.add_argument('--wandb_project', type=str, default="hotpotqa",
help="Weights & Biases project name")
parser.add_argument('--wandb', action="store_true",
help="Enable wandb logging")
@@ -123,12 +119,41 @@
parser.add_argument('--regenerate_dataset', action="store_true",
help="Regenerate the dataset")
parser.add_argument(
'--doc_parsing_mode', type=str, default="file",
'--doc_parsing_mode', type=str, default=None,
choices=["paragraph",
"file"], help="How to parse documents: 'paragraph' splits "
"files by paragraphs, 'file' treats each file as"
"one document")
return parser.parse_args()
parser.add_argument('--k_for_docs', type=int, default=None,
help="Number of docs to retrieve for each question.")
parser.add_argument(
'--doc_chunk_size', type=int, default=None,
help="The chunk size to use VectorRAG (document retrieval)")
parser.add_argument(
'--dataset', type=str, default="hotpotqa", help="Dataset folder name, "
"should contain corpus and train.json files."
"extracted triples, processed dataset, "
"document retriever, and model checkpoints "
"will be saved in the dataset folder")
args = parser.parse_args()

config_path = f"{args.dataset}/config.yaml"
if os.path.exists(config_path):
print(f"Loading config from {config_path}...")
with open(config_path) as config_file:
config = yaml.safe_load(config_file)

if config is not None:
# Use a loop to check and apply config values for each parameter
config_params = [
'doc_parsing_mode', 'doc_chunk_size', 'k_for_docs'
]
for param in config_params:
if param in config and getattr(args, param) is None:
setattr(args, param, config[param])
print(f"Using config value for {param}: {config[param]}")

return args


# Answer this question based on retrieved contexts. Just give the answer without explanation.
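
For reference, a minimal sketch of the optional per-dataset config file that the new parse_args() logic above looks for (e.g. hotpotqa/config.yaml). Only the three keys checked in the loop are read, and explicit CLI flags still take precedence since config values are applied only when the argument is None; the values below are illustrative, taken from the previous hard-coded defaults:

    doc_parsing_mode: file
    doc_chunk_size: 8192
    k_for_docs: 2
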
@@ -166,7 +191,7 @@ def _process_and_chunk_text(text, chunk_size, doc_parsing_mode):

def get_data(args):
# need a JSON dict of Questions and answers, see below for how its used
with open('train.json') as file:
with open(f"{args.dataset}/train.json") as file:
json_obj = json.load(file)
text_contexts = []

@@ -175,7 +200,7 @@ def get_data(args):
# TODO: add support for additional corpus file formats: PDF, CSV, XML,
# HTML, possibly others.
# corpus folder is simply a folder with context documents in it.
file_paths = glob(f"corpus/*.json")
file_paths = glob(f"{args.dataset}/corpus/*.json")
if len(file_paths) > 0:
for file_path in file_paths:
with open(file_path, "r+") as f:
@@ -189,7 +214,7 @@ def get_data(args):
args.doc_chunk_size,
args.doc_parsing_mode))
else:
for file_path in glob(f"corpus/*"):
for file_path in glob(f"{args.dataset}/corpus/*"):
with open(file_path, "r+") as f:
text_context = f.read()
text_contexts.extend(
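
Taken together, get_data() now resolves everything relative to the dataset folder. A sketch of the assumed layout for --dataset hotpotqa (paths taken from this diff; the exact corpus format is whatever the files contain):

    hotpotqa/
        config.yaml    # optional, see the parse_args() sketch above
        train.json     # question/answer pairs
        corpus/        # context documents; *.json preferred, other files read as raw text
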
@@ -212,9 +237,9 @@ def index_kg(args, context_docs):
"w/ --ENDPOINT_URL flag.") # noqa
total_tqdm_count = len(context_docs)
initial_tqdm_count = 0
if os.path.exists("checkpoint_kg.pt"):
if os.path.exists(f"{args.dataset}/checkpoint_kg.pt"):
print("Restoring KG from checkpoint...")
saved_relevant_triples = torch.load("checkpoint_kg.pt",
saved_relevant_triples = torch.load(f"{args.dataset}/checkpoint_kg.pt",
weights_only=False)
kg_maker.relevant_triples = saved_relevant_triples
kg_maker.doc_id_counter = len(saved_relevant_triples)
@@ -230,17 +255,17 @@
chkpt_count += 1
if chkpt_count == chkpt_interval:
chkpt_count = 0
kg_maker.save_kg("checkpoint_kg.pt")
kg_maker.save_kg(f"{args.dataset}/checkpoint_kg.pt")
relevant_triples = kg_maker.relevant_triples

triples.extend(
list(
chain.from_iterable(triple_set
for triple_set in relevant_triples.values())))
triples = list(dict.fromkeys(triples))
torch.save(triples, "tech_qa_just_triples.pt")
if os.path.exists("checkpoint_kg.pt"):
os.remove("checkpoint_kg.pt")
torch.save(triples, f"{args.dataset}/raw_triples.pt")
if os.path.exists(f"{args.dataset}/checkpoint_kg.pt"):
os.remove(f"{args.dataset}/checkpoint_kg.pt")
return triples


@@ -254,11 +279,11 @@ def update_data_lists(args, data_lists):
"output_device": device,
"batch_size": int(sent_trans_batch_size / 4),
}
if os.path.exists("document_retriever.pt"):
if os.path.exists(f"{args.dataset}/document_retriever.pt"):
print("Loading document retriever from checkpoint...")
vector_retriever = DocumentRetriever.load("document_retriever.pt",
model=model.encode,
model_kwargs=model_kwargs)
vector_retriever = DocumentRetriever.load(
f"{args.dataset}/document_retriever.pt", model=model.encode,
model_kwargs=model_kwargs)
if args.k_for_docs != vector_retriever.k_for_docs:
vector_retriever.k_for_docs = args.k_for_docs
else:
@@ -280,11 +305,13 @@

progress_bar.close()

vector_retriever.save(f"{args.dataset}/document_retriever.pt")

del vector_retriever
gc.collect()
torch.cuda.empty_cache()

torch.save(data_lists, "tech_qa.pt")
torch.save(data_lists, f"{args.dataset}/{args.dataset}.pt")
return data_lists
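
For quick orientation, the artifacts that used to be written next to the script are now all keyed by args.dataset. A rough mapping of the old global file names to the new per-dataset paths, as they appear across this diff:

    tech_qa_just_triples.pt  ->  {dataset}/raw_triples.pt
    checkpoint_kg.pt         ->  {dataset}/checkpoint_kg.pt
    document_retriever.pt    ->  {dataset}/document_retriever.pt
    tech_qa.pt               ->  {dataset}/{dataset}.pt
    tech-qa-model.pt         ->  {dataset}/model.pt
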


@@ -294,8 +321,9 @@ def make_dataset(args):
data_lists = {"train": [], "validation": [], "test": []}

triples = []
if os.path.exists("tech_qa_just_triples.pt"):
triples = torch.load("tech_qa_just_triples.pt", weights_only=False)
if os.path.exists(f"{args.dataset}/raw_triples.pt"):
triples = torch.load(f"{args.dataset}/raw_triples.pt",
weights_only=False)
else:
triples = index_kg(args, context_docs)

@@ -334,11 +362,11 @@ def make_dataset(args):
"verbose": True
}

if os.path.exists("document_retriever.pt"):
if os.path.exists(f"{args.dataset}/document_retriever.pt"):
print("Loading document retriever from checkpoint...")
vector_retriever = DocumentRetriever.load("document_retriever.pt",
model=model.encode,
model_kwargs=model_kwargs)
vector_retriever = DocumentRetriever.load(
f"{args.dataset}/document_retriever.pt", model=model.encode,
model_kwargs=model_kwargs)
if args.k_for_docs != vector_retriever.k_for_docs:
vector_retriever.k_for_docs = args.k_for_docs
else:
@@ -347,7 +375,7 @@
k_for_docs=args.k_for_docs,
model=model.encode,
model_kwargs=model_kwargs)
vector_retriever.save("document_retriever.pt")
vector_retriever.save(f"{args.dataset}/document_retriever.pt")

subgraph_filter = make_pcst_filter(
triples,
@@ -406,7 +434,7 @@ def make_dataset(args):
len(total_data_list))]
data_lists["test"] = total_data_list[int(.8 * len(total_data_list)):]

torch.save(data_lists, "tech_qa.pt")
torch.save(data_lists, f"{args.dataset}/{args.dataset}.pt")
del model
gc.collect()
torch.cuda.empty_cache()
@@ -430,26 +458,29 @@ def train(args, data_lists):
pin_memory=True, shuffle=False)
test_loader = DataLoader(data_lists["test"], batch_size=eval_batch_size,
drop_last=False, pin_memory=True, shuffle=False)
gnn = GAT(in_channels=768, hidden_channels=hidden_channels,
out_channels=1024, num_layers=num_gnn_layers, heads=4)
if args.num_gnn_layers > 0:
gnn = GAT(in_channels=768, hidden_channels=hidden_channels,
out_channels=1024, num_layers=num_gnn_layers, heads=4)
else:
gnn = None

if args.llm_generator_mode == "full":
llm = LLM(model_name=args.llm_generator_name, sys_prompt=sys_prompt,
n_gpus=args.num_gpus)
model = GRetriever(llm=llm, gnn=gnn)
elif args.llm_generator_mode == "lora":
llm = LLM(model_name=args.llm_generator_name, sys_prompt=sys_prompt,
dtype=torch.float32, n_gpus=args.num_gpus)
model = GRetriever(llm=llm, gnn=gnn, use_lora=True)
else:
# frozen
llm = LLM(model_name=args.llm_generator_name, sys_prompt=sys_prompt,
dtype=torch.float32, n_gpus=args.num_gpus).eval()

for _, p in llm.named_parameters():
p.requires_grad = False
model = GRetriever(llm=llm, gnn=gnn)

save_name = "tech-qa-model.pt"
model = GRetriever(llm=llm, gnn=gnn,
use_lora=args.llm_generator_mode == "lora")

save_name = f"{args.dataset}/model.pt"
if os.path.exists(save_name) and not args.regenerate_dataset:
print("Re-using saved G-retriever model for testing...")
model = load_params_dict(model, save_name)
@@ -579,11 +610,13 @@ def eval(question: str, pred: str, correct_answer: str):
print("Please install wandb and rerun the script.")
sys.exit(1)

print("Starting TechQA training with args: ", args)
if os.path.exists("tech_qa.pt") and not args.regenerate_dataset:
print("Re-using Saved TechQA KG-RAG Dataset...")
data_lists = torch.load("tech_qa.pt", weights_only=False)
if os.path.exists("document_retriever.pt"):
print(f"Starting {args.dataset} training with args: ", args)
if os.path.exists(f"{args.dataset}/{args.dataset}.pt"
) and not args.regenerate_dataset: # noqa
print(f"Re-using Saved {args.dataset} KG-RAG Dataset...")
data_lists = torch.load(f"{args.dataset}/{args.dataset}.pt",
weights_only=False)
if os.path.exists(f"{args.dataset}/document_retriever.pt"):
print("Updating data lists with document retriever...")
data_lists = update_data_lists(args, data_lists)
else:
86 changes: 46 additions & 40 deletions torch_geometric/nn/models/g_retriever.py
@@ -40,14 +40,14 @@
def __init__(
self,
llm: LLM,
gnn: torch.nn.Module,
gnn: torch.nn.Module = None,
use_lora: bool = False,
mlp_out_tokens: int = 1,
) -> None:
super().__init__()

self.llm = llm
self.gnn = gnn.to(self.llm.device)
self.gnn = gnn.to(self.llm.device) if gnn is not None else None

self.word_embedding = self.llm.word_embedding
self.llm_generator = self.llm.llm
@@ -73,15 +73,16 @@
)
self.llm_generator = get_peft_model(self.llm_generator, config)

mlp_out_channels = llm.word_embedding.embedding_dim
mlp_hidden_channels = self.gnn.out_channels
self.projector = torch.nn.Sequential(
torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
torch.nn.Sigmoid(),
torch.nn.Linear(mlp_hidden_channels,
mlp_out_channels * mlp_out_tokens),
torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),
).to(self.llm.device)
if self.gnn is not None:
mlp_out_channels = llm.word_embedding.embedding_dim
mlp_hidden_channels = self.gnn.out_channels
self.projector = torch.nn.Sequential(

torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
torch.nn.Sigmoid(),
torch.nn.Linear(mlp_hidden_channels,
mlp_out_channels * mlp_out_tokens),
torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),
).to(self.llm.device)

self.seq_length_stats = []

@@ -127,21 +128,23 @@
to give to the LLM, such as textified knowledge graphs.
(default: :obj:`None`)
"""
x = self.encode(x, edge_index, batch, edge_attr)
x = self.projector(x)
xs = x.split(1, dim=0)

# Handle case where theres more than one embedding for each sample
xs = [x.squeeze(0) for x in xs]

# Handle questions without node features:
batch_unique = batch.unique()
batch_size = len(question)
if len(batch_unique) < batch_size:
xs = [
xs[i] if i in batch_unique else None for i in range(batch_size)
]

xs = None
if self.gnn is not None:
x = self.encode(x, edge_index, batch, edge_attr)
x = self.projector(x)
xs = x.split(1, dim=0)

# Handle case where theres more than one embedding for each sample
xs = [x.squeeze(0) for x in xs]

# Handle questions without node features:
batch_unique = batch.unique()
batch_size = len(question)
if len(batch_unique) < batch_size:
xs = [

xs[i] if i in batch_unique else None
for i in range(batch_size)
]
(
inputs_embeds,
attention_mask,
@@ -189,20 +192,23 @@
max_out_tokens (int, optional): How many tokens for the LLM to
generate. (default: :obj:`32`)
"""
x = self.encode(x, edge_index, batch, edge_attr)
x = self.projector(x)
xs = x.split(1, dim=0)

# Handle case where theres more than one embedding for each sample
xs = [x.squeeze(0) for x in xs]

# Handle questions without node features:
batch_unique = batch.unique()
batch_size = len(question)
if len(batch_unique) < batch_size:
xs = [
xs[i] if i in batch_unique else None for i in range(batch_size)
]
xs = None
if self.gnn is not None:
x = self.encode(x, edge_index, batch, edge_attr)
x = self.projector(x)
xs = x.split(1, dim=0)

# Handle case where theres more than one embedding for each sample
xs = [x.squeeze(0) for x in xs]

# Handle questions without node features:
batch_unique = batch.unique()
batch_size = len(question)
if len(batch_unique) < batch_size:
xs = [

xs[i] if i in batch_unique else None
for i in range(batch_size)
]

inputs_embeds, attention_mask, _ = self.llm._get_embeds(
question, additional_text_context, xs)
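
Combined with the txt2kg_rag.py changes above, gnn is now optional, which enables a vanilla text-only RAG configuration. A rough usage sketch, assuming the signatures visible in this diff rather than documented API (the model name is a placeholder):

    llm = LLM(model_name="<generator-model>", sys_prompt=sys_prompt)
    model = GRetriever(llm=llm)  # gnn defaults to None: no GNN encoding, no projector
    preds = model.inference(question=questions, x=None, edge_index=None,
                            batch=None, edge_attr=None,
                            additional_text_context=retrieved_docs)
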