Skip to content

Commit b3d0d7d

Browse files
authored
Merge pull request #106 from yzh3434/pr/bm25-jieba-rrf
fix(C9): 启用 hybrid_search 中真实的 BM25(jieba) 检索 + RRF 替代 round-robin
2 parents a2e0f24 + 9e4526a commit b3d0d7d

2 files changed

Lines changed: 217 additions & 69 deletions

File tree

code/C9/rag_modules/hybrid_retrieval.py

Lines changed: 216 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,34 @@
11
"""
22
混合检索模块
33
基于双层检索范式:实体级 + 主题级检索
4-
结合图结构检索和向量检索,使用Round-robin轮询策略
4+
结合 BM25(jieba 分词)、向量检索与图键值索引,使用 RRF 融合
55
"""
66

77
import json
88
import logging
9-
from typing import List, Dict, Tuple, Any
9+
from typing import List, Dict, Tuple, Any, Optional
1010
from dataclasses import dataclass
1111

12+
import jieba
13+
from rank_bm25 import BM25Okapi
1214
from langchain_core.documents import Document
13-
from langchain_community.retrievers import BM25Retriever
1415
from neo4j import GraphDatabase
1516
from .graph_indexing import GraphIndexingModule
1617

1718
logger = logging.getLogger(__name__)
1819

20+
# 中文停用词表:助词 / 连词 / 疑问词 / 人称 / 语气词 / 动词修饰
21+
# 不引第三方停用词包,按烹饪问答场景手挑(覆盖 testset 高频虚词)
22+
_CHINESE_STOPWORDS = set("""
23+
的 了 和 是 在 我 有 就 不 也 都 还 这 那 一 个 与 及 等 上 下 中 为 以 于 从 把 被 让 使 又 而 但 或
24+
什么 怎么 如何 哪些 哪个 哪里 谁 多少 几 你 他 她 它 我们 他们 她们 它们
25+
请问 请 想 要 需要 能 可以 应该 会 啊 呢 吧 嘛 吗 哦 呀 哈
26+
之 其 此 该 即 各 每 些 种 类 时 后 前 里 外 内 间 已经 正在 一些 一下
27+
""".split())
28+
29+
# RRF 融合的常数 k:Cormack et al. 2009 默认值
30+
_RRF_K = 60
31+
1932
@dataclass
2033
class RetrievalResult:
2134
"""检索结果数据结构"""
@@ -30,42 +43,61 @@ class HybridRetrievalModule:
3043
"""
3144
混合检索模块
3245
核心特点:
33-
1. 双层检索范式(实体级 + 主题级)
34-
2. 关键词提取和匹配
35-
3. 图结构+向量检索结合
36-
4. 一跳邻居扩展
37-
5. Round-robin轮询合并策略
46+
1. 双层检索范式(实体级 + 主题级,基于图键值索引)
47+
2. BM25 关键词检索(jieba 分词 + 停用词过滤)
48+
3. 向量检索(Milvus)+ 一跳邻居扩展
49+
4. RRF (Reciprocal Rank Fusion) 融合三路结果
3850
"""
39-
51+
4052
def __init__(self, config, milvus_module, data_module, llm_client):
4153
self.config = config
4254
self.milvus_module = milvus_module
4355
self.data_module = data_module
4456
self.llm_client = llm_client
4557
self.driver = None
46-
self.bm25_retriever = None
47-
58+
59+
# BM25 索引 + 原始文档(按索引位置对齐)
60+
self.bm25: Optional[BM25Okapi] = None
61+
self.bm25_corpus_docs: List[Document] = []
62+
4863
# 图索引模块
4964
self.graph_indexing = GraphIndexingModule(config, llm_client)
5065
self.graph_indexed = False
51-
66+
5267
def initialize(self, chunks: List[Document]):
5368
"""初始化检索系统"""
5469
logger.info("初始化混合检索模块...")
55-
70+
5671
# 连接Neo4j
5772
self.driver = GraphDatabase.driver(
58-
self.config.neo4j_uri,
73+
self.config.neo4j_uri,
5974
auth=(self.config.neo4j_user, self.config.neo4j_password)
6075
)
61-
62-
# 初始化BM25检索器
76+
77+
# 初始化 BM25(jieba 分词 + 中文停用词过滤)
6378
if chunks:
64-
self.bm25_retriever = BM25Retriever.from_documents(chunks)
65-
logger.info(f"BM25检索器初始化完成,文档数量: {len(chunks)}")
66-
79+
self.bm25_corpus_docs = list(chunks)
80+
tokenized_corpus = [self._tokenize_chinese(d.page_content) for d in chunks]
81+
self.bm25 = BM25Okapi(tokenized_corpus)
82+
avg_tokens = sum(len(t) for t in tokenized_corpus) / max(1, len(tokenized_corpus))
83+
logger.info(
84+
f"BM25(jieba+stopwords) 索引构建完成,文档数: {len(chunks)},"
85+
f"平均 token 数: {avg_tokens:.1f}"
86+
)
87+
6788
# 初始化图索引
6889
self._build_graph_index()
90+
91+
@staticmethod
92+
def _tokenize_chinese(text: str) -> List[str]:
93+
"""jieba 精确分词 + 停用词 / 空白 / 单字符过滤"""
94+
if not text:
95+
return []
96+
tokens = jieba.lcut(text)
97+
return [
98+
t for t in tokens
99+
if t.strip() and t not in _CHINESE_STOPWORDS and not t.isspace()
100+
]
69101

70102
def _build_graph_index(self):
71103
"""构建图索引"""
@@ -542,60 +574,175 @@ def _get_node_neighbors(self, node_id: str, max_neighbors: int = 3) -> List[str]
542574
logger.error(f"获取邻居节点失败: {e}")
543575
return []
544576

577+
def bm25_search(self, query: str, top_k: int = 5) -> List[Document]:
578+
"""
579+
BM25 检索:jieba 分词后查 BM25Okapi 索引,按分数降序返回 top_k。
580+
分数写入 metadata["bm25_score"],供调试与未来潜在的分数级融合使用。
581+
"""
582+
if self.bm25 is None or not self.bm25_corpus_docs:
583+
logger.warning("BM25 索引未初始化,bm25_search 返回空")
584+
return []
585+
586+
tokenized_query = self._tokenize_chinese(query)
587+
if not tokenized_query:
588+
logger.debug(f"BM25 query 分词为空,跳过: {query}")
589+
return []
590+
591+
scores = self.bm25.get_scores(tokenized_query)
592+
# 按分数降序取 top_k 索引
593+
top_indices = sorted(
594+
range(len(scores)), key=lambda i: scores[i], reverse=True
595+
)[:top_k]
596+
597+
docs: List[Document] = []
598+
for idx in top_indices:
599+
score = float(scores[idx])
600+
if score <= 0:
601+
# BM25 分数 ≤ 0 视为无关(IDF/TF 全无贡献),不进结果
602+
continue
603+
src = self.bm25_corpus_docs[idx]
604+
recipe_name = (
605+
src.metadata.get("recipe_name")
606+
or src.metadata.get("name")
607+
or "未知菜品"
608+
)
609+
doc = Document(
610+
page_content=src.page_content,
611+
metadata={
612+
**src.metadata,
613+
"recipe_name": recipe_name,
614+
"search_method": "bm25",
615+
"search_type": "bm25",
616+
"bm25_score": score,
617+
}
618+
)
619+
docs.append(doc)
620+
621+
logger.info(f"BM25 检索完成,返回 {len(docs)} 个文档(query tokens={tokenized_query})")
622+
return docs
623+
624+
@staticmethod
625+
def _rrf_merge(
626+
ranked_lists: List[Tuple[str, List[Document]]],
627+
top_k: int,
628+
k: int = _RRF_K,
629+
) -> List[Document]:
630+
"""
631+
Reciprocal Rank Fusion: score(d) = Σ_i 1 / (k + best_rank_i(d))
632+
633+
Args:
634+
ranked_lists: 多路 (source_name, ranked_docs) — docs 按相关度降序
635+
top_k: 最终返回个数
636+
k: RRF 平滑常数,默认 60(Cormack et al. 2009)
637+
638+
去重 key:node_id 优先,page_content[:200] hash 兜底。
639+
640+
同 source 内同 doc_id 多次命中(如一道菜的多个 chunk 共享 recipe.nodeId):
641+
- 算分只取该 source 内最佳 rank(最小 rank)一次,避免重复加分
642+
- 命中 chunk 数另存到 rrf_chunk_hits,供后续分析
643+
644+
canonical doc(最终展示给 LLM 的 page_content):
645+
选全局最小 rank 那个 chunk;rank 相同时按 ranked_lists 顺序优先。
646+
647+
返回的 Document 是新对象,不会 mutate 输入 list 里的 Document。
648+
"""
649+
# doc_id -> source_name -> 该 source 内最小 rank(用于算分)
650+
best_rank_per_source: Dict[str, Dict[str, int]] = {}
651+
# doc_id -> source_name -> 该 source 内命中 chunk 次数(信息存档)
652+
chunk_hits_per_source: Dict[str, Dict[str, int]] = {}
653+
# doc_id -> (global_best_rank, source_priority, doc) — 选 canonical doc
654+
best_doc_info: Dict[str, Tuple[int, int, Document]] = {}
655+
656+
for source_priority, (source_name, ranked_docs) in enumerate(ranked_lists):
657+
for rank, doc in enumerate(ranked_docs, start=1):
658+
node_id = doc.metadata.get("node_id")
659+
doc_id = (
660+
str(node_id) if node_id is not None
661+
else f"hash::{hash(doc.page_content[:200])}"
662+
)
663+
664+
if doc_id not in best_rank_per_source:
665+
best_rank_per_source[doc_id] = {}
666+
chunk_hits_per_source[doc_id] = {}
667+
668+
curr_best = best_rank_per_source[doc_id].get(source_name)
669+
# 如果是第一次出现或者当前rank比记录的更小,则更新
670+
if curr_best is None or rank < curr_best:
671+
best_rank_per_source[doc_id][source_name] = rank
672+
673+
chunk_hits_per_source[doc_id][source_name] = (
674+
chunk_hits_per_source[doc_id].get(source_name, 0) + 1
675+
)
676+
677+
new_key = (rank, source_priority)
678+
if (
679+
doc_id not in best_doc_info
680+
or new_key < (best_doc_info[doc_id][0], best_doc_info[doc_id][1])
681+
):
682+
best_doc_info[doc_id] = (rank, source_priority, doc)
683+
684+
# 每个 source 只用 best rank 算一次贡献
685+
rrf_scores: Dict[str, float] = {
686+
doc_id: sum(1.0 / (k + r) for r in source_ranks.values())
687+
for doc_id, source_ranks in best_rank_per_source.items()
688+
}
689+
690+
sorted_ids = sorted(
691+
rrf_scores.keys(), key=lambda d: rrf_scores[d], reverse=True
692+
)
693+
694+
merged: List[Document] = []
695+
for doc_id in sorted_ids[:top_k]:
696+
_, _, source_doc = best_doc_info[doc_id]
697+
# 浅 copy metadata,避免 mutate 上游 Document
698+
new_metadata = dict(source_doc.metadata)
699+
new_metadata["rrf_score"] = rrf_scores[doc_id]
700+
new_metadata["rrf_sources"] = list(best_rank_per_source[doc_id].keys())
701+
new_metadata["rrf_ranks"] = dict(best_rank_per_source[doc_id])
702+
new_metadata["rrf_chunk_hits"] = dict(chunk_hits_per_source[doc_id])
703+
new_metadata["final_score"] = rrf_scores[doc_id]
704+
merged.append(Document(
705+
page_content=source_doc.page_content,
706+
metadata=new_metadata,
707+
))
708+
709+
return merged
710+
545711
def hybrid_search(self, query: str, top_k: int = 5) -> List[Document]:
546712
"""
547-
混合检索:使用Round-robin轮询合并策略
548-
公平轮询合并不同检索结果,不使用权重配置
713+
混合检索:三路召回(图键值双层 + 向量 + BM25)→ RRF 融合
549714
"""
550-
logger.info(f"开始混合检索: {query}")
551-
552-
# 1. 双层检索(实体+主题检索)
553-
dual_docs = self.dual_level_retrieval(query, top_k)
554-
555-
# 2. 增强向量检索
556-
vector_docs = self.vector_search_enhanced(query, top_k)
557-
558-
# 3. Round-robin轮询合并
559-
merged_docs = []
560-
seen_doc_ids = set()
561-
max_len = max(len(dual_docs), len(vector_docs))
562-
origin_len = len(dual_docs) + len(vector_docs)
563-
564-
for i in range(max_len):
565-
# 先添加双层检索结果
566-
if i < len(dual_docs):
567-
doc = dual_docs[i]
568-
doc_id = doc.metadata.get("node_id", hash(doc.page_content))
569-
if doc_id not in seen_doc_ids:
570-
seen_doc_ids.add(doc_id)
571-
doc.metadata["search_method"] = "dual_level"
572-
doc.metadata["round_robin_order"] = len(merged_docs)
573-
# 设置统一的final_score字段
574-
doc.metadata["final_score"] = doc.metadata.get("relevance_score", 0.0)
575-
merged_docs.append(doc)
576-
577-
# 再添加向量检索结果
578-
if i < len(vector_docs):
579-
doc = vector_docs[i]
580-
doc_id = doc.metadata.get("node_id", hash(doc.page_content))
581-
if doc_id not in seen_doc_ids:
582-
seen_doc_ids.add(doc_id)
583-
doc.metadata["search_method"] = "vector_enhanced"
584-
doc.metadata["round_robin_order"] = len(merged_docs)
585-
# 设置统一的final_score字段(向量得分需要转换)
586-
vector_score = doc.metadata.get("score", 0.0)
587-
# COSINE距离转换为相似度:distance越小,相似度越高
588-
similarity_score = max(0.0, 1.0 - vector_score) if vector_score <= 1.0 else 0.0
589-
doc.metadata["final_score"] = similarity_score
590-
merged_docs.append(doc)
591-
592-
# 取前top_k个结果
593-
final_docs = merged_docs[:top_k]
594-
595-
logger.info(f"Round-robin合并:从总共{origin_len}个结果合并为{len(final_docs)}个文档")
596-
logger.info(f"混合检索完成,返回 {len(final_docs)} 个文档")
715+
logger.info(f"开始混合检索(dual + vector + bm25, RRF k={_RRF_K}): {query}")
716+
717+
# 每路给 RRF 留够候选空间,否则三路各自前 top_k 容易没交集,融合退化
718+
candidate_k = max(top_k * 2, 10)
719+
720+
dual_docs = self.dual_level_retrieval(query, candidate_k)
721+
vector_docs = self.vector_search_enhanced(query, candidate_k)
722+
bm25_docs = self.bm25_search(query, candidate_k)
723+
724+
# 标记每路来源(dual_level 内部会写 search_type 但不一定写 search_method)
725+
for d in dual_docs:
726+
d.metadata.setdefault("search_method", "dual_level")
727+
for d in vector_docs:
728+
d.metadata["search_method"] = "vector"
729+
# bm25_search 内部已写 search_method=bm25
730+
731+
final_docs = self._rrf_merge(
732+
ranked_lists=[
733+
("dual_level", dual_docs),
734+
("vector", vector_docs),
735+
("bm25", bm25_docs),
736+
],
737+
top_k=top_k,
738+
)
739+
740+
logger.info(
741+
f"RRF 融合完成:dual={len(dual_docs)} vector={len(vector_docs)} "
742+
f"bm25={len(bm25_docs)} → 最终 {len(final_docs)} 个文档"
743+
)
597744
return final_docs
598-
745+
599746
def close(self):
600747
"""关闭资源连接"""
601748
if self.driver:

code/C9/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ neo4j>=5.0.0
1313

1414
pymilvus==2.5.11
1515
rank-bm25>=0.2.2
16+
jieba>=0.42.1
1617

1718
lazy_loader==0.4
1819
huggingface-hub>=0.33.4

0 commit comments

Comments
 (0)