11"""
22混合检索模块
33基于双层检索范式:实体级 + 主题级检索
4- 结合图结构检索和向量检索,使用Round-robin轮询策略
4+ 结合 BM25(jieba 分词)、向量检索与图键值索引,使用 RRF 融合
55"""
66
77import json
88import logging
9- from typing import List , Dict , Tuple , Any
9+ from typing import List , Dict , Tuple , Any , Optional
1010from dataclasses import dataclass
1111
12+ import jieba
13+ from rank_bm25 import BM25Okapi
1214from langchain_core .documents import Document
13- from langchain_community .retrievers import BM25Retriever
1415from neo4j import GraphDatabase
1516from .graph_indexing import GraphIndexingModule
1617
1718logger = logging .getLogger (__name__ )
1819
20+ # 中文停用词表:助词 / 连词 / 疑问词 / 人称 / 语气词 / 动词修饰
21+ # 不引第三方停用词包,按烹饪问答场景手挑(覆盖 testset 高频虚词)
22+ _CHINESE_STOPWORDS = set ("""
23+ 的 了 和 是 在 我 有 就 不 也 都 还 这 那 一 个 与 及 等 上 下 中 为 以 于 从 把 被 让 使 又 而 但 或
24+ 什么 怎么 如何 哪些 哪个 哪里 谁 多少 几 你 他 她 它 我们 他们 她们 它们
25+ 请问 请 想 要 需要 能 可以 应该 会 啊 呢 吧 嘛 吗 哦 呀 哈
26+ 之 其 此 该 即 各 每 些 种 类 时 后 前 里 外 内 间 已经 正在 一些 一下
27+ """ .split ())
28+
29+ # RRF 融合的常数 k:Cormack et al. 2009 默认值
30+ _RRF_K = 60
31+
1932@dataclass
2033class RetrievalResult :
2134 """检索结果数据结构"""
@@ -30,42 +43,61 @@ class HybridRetrievalModule:
3043 """
3144 混合检索模块
3245 核心特点:
33- 1. 双层检索范式(实体级 + 主题级)
34- 2. 关键词提取和匹配
35- 3. 图结构+向量检索结合
36- 4. 一跳邻居扩展
37- 5. Round-robin轮询合并策略
46+ 1. 双层检索范式(实体级 + 主题级,基于图键值索引)
47+ 2. BM25 关键词检索(jieba 分词 + 停用词过滤)
48+ 3. 向量检索(Milvus)+ 一跳邻居扩展
49+ 4. RRF (Reciprocal Rank Fusion) 融合三路结果
3850 """
39-
51+
4052 def __init__ (self , config , milvus_module , data_module , llm_client ):
4153 self .config = config
4254 self .milvus_module = milvus_module
4355 self .data_module = data_module
4456 self .llm_client = llm_client
4557 self .driver = None
46- self .bm25_retriever = None
47-
58+
59+ # BM25 索引 + 原始文档(按索引位置对齐)
60+ self .bm25 : Optional [BM25Okapi ] = None
61+ self .bm25_corpus_docs : List [Document ] = []
62+
4863 # 图索引模块
4964 self .graph_indexing = GraphIndexingModule (config , llm_client )
5065 self .graph_indexed = False
51-
66+
5267 def initialize (self , chunks : List [Document ]):
5368 """初始化检索系统"""
5469 logger .info ("初始化混合检索模块..." )
55-
70+
5671 # 连接Neo4j
5772 self .driver = GraphDatabase .driver (
58- self .config .neo4j_uri ,
73+ self .config .neo4j_uri ,
5974 auth = (self .config .neo4j_user , self .config .neo4j_password )
6075 )
61-
62- # 初始化BM25检索器
76+
77+ # 初始化 BM25(jieba 分词 + 中文停用词过滤)
6378 if chunks :
64- self .bm25_retriever = BM25Retriever .from_documents (chunks )
65- logger .info (f"BM25检索器初始化完成,文档数量: { len (chunks )} " )
66-
79+ self .bm25_corpus_docs = list (chunks )
80+ tokenized_corpus = [self ._tokenize_chinese (d .page_content ) for d in chunks ]
81+ self .bm25 = BM25Okapi (tokenized_corpus )
82+ avg_tokens = sum (len (t ) for t in tokenized_corpus ) / max (1 , len (tokenized_corpus ))
83+ logger .info (
84+ f"BM25(jieba+stopwords) 索引构建完成,文档数: { len (chunks )} ,"
85+ f"平均 token 数: { avg_tokens :.1f} "
86+ )
87+
6788 # 初始化图索引
6889 self ._build_graph_index ()
90+
91+ @staticmethod
92+ def _tokenize_chinese (text : str ) -> List [str ]:
93+ """jieba 精确分词 + 停用词 / 空白 / 单字符过滤"""
94+ if not text :
95+ return []
96+ tokens = jieba .lcut (text )
97+ return [
98+ t for t in tokens
99+ if t .strip () and t not in _CHINESE_STOPWORDS and not t .isspace ()
100+ ]
69101
70102 def _build_graph_index (self ):
71103 """构建图索引"""
@@ -542,60 +574,147 @@ def _get_node_neighbors(self, node_id: str, max_neighbors: int = 3) -> List[str]
542574 logger .error (f"获取邻居节点失败: { e } " )
543575 return []
544576
577+ def bm25_search (self , query : str , top_k : int = 5 ) -> List [Document ]:
578+ """
579+ BM25 检索:jieba 分词后查 BM25Okapi 索引,按分数降序返回 top_k。
580+ 分数写入 metadata["bm25_score"],供调试与未来潜在的分数级融合使用。
581+ """
582+ if self .bm25 is None or not self .bm25_corpus_docs :
583+ logger .warning ("BM25 索引未初始化,bm25_search 返回空" )
584+ return []
585+
586+ tokenized_query = self ._tokenize_chinese (query )
587+ if not tokenized_query :
588+ logger .debug (f"BM25 query 分词为空,跳过: { query } " )
589+ return []
590+
591+ scores = self .bm25 .get_scores (tokenized_query )
592+ # 按分数降序取 top_k 索引
593+ top_indices = sorted (
594+ range (len (scores )), key = lambda i : scores [i ], reverse = True
595+ )[:top_k ]
596+
597+ docs : List [Document ] = []
598+ for idx in top_indices :
599+ score = float (scores [idx ])
600+ if score <= 0 :
601+ # BM25 分数 ≤ 0 视为无关(IDF/TF 全无贡献),不进结果
602+ continue
603+ src = self .bm25_corpus_docs [idx ]
604+ recipe_name = (
605+ src .metadata .get ("recipe_name" )
606+ or src .metadata .get ("name" )
607+ or "未知菜品"
608+ )
609+ doc = Document (
610+ page_content = src .page_content ,
611+ metadata = {
612+ ** src .metadata ,
613+ "recipe_name" : recipe_name ,
614+ "search_method" : "bm25" ,
615+ "search_type" : "bm25" ,
616+ "bm25_score" : score ,
617+ }
618+ )
619+ docs .append (doc )
620+
621+ logger .info (f"BM25 检索完成,返回 { len (docs )} 个文档(query tokens={ tokenized_query } )" )
622+ return docs
623+
624+ @staticmethod
625+ def _rrf_merge (
626+ ranked_lists : List [Tuple [str , List [Document ]]],
627+ top_k : int ,
628+ k : int = _RRF_K ,
629+ ) -> List [Document ]:
630+ """
631+ Reciprocal Rank Fusion: score(d) = Σ_i 1 / (k + rank_i(d))
632+
633+ Args:
634+ ranked_lists: 多路 (source_name, ranked_docs) — docs 按相关度降序
635+ top_k: 最终返回个数
636+ k: RRF 平滑常数,默认 60(Cormack et al. 2009)
637+
638+ 去重 key:node_id 优先,page_content[:200] hash 兜底。
639+ 合并后 metadata 写入 rrf_score / rrf_sources / final_score。
640+ """
641+ rrf_scores : Dict [str , float ] = {}
642+ doc_index : Dict [str , Document ] = {}
643+ sources : Dict [str , List [str ]] = {}
644+ ranks_by_source : Dict [str , Dict [str , int ]] = {}
645+
646+ for source_name , ranked_docs in ranked_lists :
647+ for rank , doc in enumerate (ranked_docs , start = 1 ):
648+ node_id = doc .metadata .get ("node_id" )
649+ doc_id = (
650+ str (node_id ) if node_id is not None
651+ else f"hash::{ hash (doc .page_content [:200 ])} "
652+ )
653+
654+ contribution = 1.0 / (k + rank )
655+ rrf_scores [doc_id ] = rrf_scores .get (doc_id , 0.0 ) + contribution
656+
657+ # 第一次见到这个 doc 时记录为 canonical(通常是 rank 较高的那路)
658+ if doc_id not in doc_index :
659+ doc_index [doc_id ] = doc
660+ sources [doc_id ] = []
661+ ranks_by_source [doc_id ] = {}
662+
663+ if source_name not in sources [doc_id ]:
664+ sources [doc_id ].append (source_name )
665+ ranks_by_source [doc_id ][source_name ] = rank
666+
667+ # 按 RRF score 降序
668+ sorted_ids = sorted (
669+ rrf_scores .keys (), key = lambda d : rrf_scores [d ], reverse = True
670+ )
671+
672+ merged : List [Document ] = []
673+ for doc_id in sorted_ids [:top_k ]:
674+ doc = doc_index [doc_id ]
675+ doc .metadata ["rrf_score" ] = rrf_scores [doc_id ]
676+ doc .metadata ["rrf_sources" ] = list (sources [doc_id ])
677+ doc .metadata ["rrf_ranks" ] = dict (ranks_by_source [doc_id ])
678+ doc .metadata ["final_score" ] = rrf_scores [doc_id ]
679+ merged .append (doc )
680+
681+ return merged
682+
545683 def hybrid_search (self , query : str , top_k : int = 5 ) -> List [Document ]:
546684 """
547- 混合检索:使用Round-robin轮询合并策略
548- 公平轮询合并不同检索结果,不使用权重配置
685+ 混合检索:三路召回(图键值双层 + 向量 + BM25)→ RRF 融合
549686 """
550- logger .info (f"开始混合检索: { query } " )
551-
552- # 1. 双层检索(实体+主题检索)
553- dual_docs = self .dual_level_retrieval (query , top_k )
554-
555- # 2. 增强向量检索
556- vector_docs = self .vector_search_enhanced (query , top_k )
557-
558- # 3. Round-robin轮询合并
559- merged_docs = []
560- seen_doc_ids = set ()
561- max_len = max (len (dual_docs ), len (vector_docs ))
562- origin_len = len (dual_docs ) + len (vector_docs )
563-
564- for i in range (max_len ):
565- # 先添加双层检索结果
566- if i < len (dual_docs ):
567- doc = dual_docs [i ]
568- doc_id = doc .metadata .get ("node_id" , hash (doc .page_content ))
569- if doc_id not in seen_doc_ids :
570- seen_doc_ids .add (doc_id )
571- doc .metadata ["search_method" ] = "dual_level"
572- doc .metadata ["round_robin_order" ] = len (merged_docs )
573- # 设置统一的final_score字段
574- doc .metadata ["final_score" ] = doc .metadata .get ("relevance_score" , 0.0 )
575- merged_docs .append (doc )
576-
577- # 再添加向量检索结果
578- if i < len (vector_docs ):
579- doc = vector_docs [i ]
580- doc_id = doc .metadata .get ("node_id" , hash (doc .page_content ))
581- if doc_id not in seen_doc_ids :
582- seen_doc_ids .add (doc_id )
583- doc .metadata ["search_method" ] = "vector_enhanced"
584- doc .metadata ["round_robin_order" ] = len (merged_docs )
585- # 设置统一的final_score字段(向量得分需要转换)
586- vector_score = doc .metadata .get ("score" , 0.0 )
587- # COSINE距离转换为相似度:distance越小,相似度越高
588- similarity_score = max (0.0 , 1.0 - vector_score ) if vector_score <= 1.0 else 0.0
589- doc .metadata ["final_score" ] = similarity_score
590- merged_docs .append (doc )
591-
592- # 取前top_k个结果
593- final_docs = merged_docs [:top_k ]
594-
595- logger .info (f"Round-robin合并:从总共{ origin_len } 个结果合并为{ len (final_docs )} 个文档" )
596- logger .info (f"混合检索完成,返回 { len (final_docs )} 个文档" )
687+ logger .info (f"开始混合检索(dual + vector + bm25, RRF k={ _RRF_K } ): { query } " )
688+
689+ # 每路给 RRF 留够候选空间,否则三路各自前 top_k 容易没交集,融合退化
690+ candidate_k = max (top_k * 2 , 10 )
691+
692+ dual_docs = self .dual_level_retrieval (query , candidate_k )
693+ vector_docs = self .vector_search_enhanced (query , candidate_k )
694+ bm25_docs = self .bm25_search (query , candidate_k )
695+
696+ # 标记每路来源(dual_level 内部会写 search_type 但不一定写 search_method)
697+ for d in dual_docs :
698+ d .metadata .setdefault ("search_method" , "dual_level" )
699+ for d in vector_docs :
700+ d .metadata ["search_method" ] = "vector"
701+ # bm25_search 内部已写 search_method=bm25
702+
703+ final_docs = self ._rrf_merge (
704+ ranked_lists = [
705+ ("dual_level" , dual_docs ),
706+ ("vector" , vector_docs ),
707+ ("bm25" , bm25_docs ),
708+ ],
709+ top_k = top_k ,
710+ )
711+
712+ logger .info (
713+ f"RRF 融合完成:dual={ len (dual_docs )} vector={ len (vector_docs )} "
714+ f"bm25={ len (bm25_docs )} → 最终 { len (final_docs )} 个文档"
715+ )
597716 return final_docs
598-
717+
599718 def close (self ):
600719 """关闭资源连接"""
601720 if self .driver :
0 commit comments