11"""
22混合检索模块
33基于双层检索范式:实体级 + 主题级检索
4- 结合图结构检索和向量检索,使用Round-robin轮询策略
4+ 结合 BM25(jieba 分词)、向量检索与图键值索引,使用 RRF 融合
55"""
66
77import json
88import logging
9- from typing import List , Dict , Tuple , Any
9+ from typing import List , Dict , Tuple , Any , Optional
1010from dataclasses import dataclass
1111
12+ import jieba
13+ from rank_bm25 import BM25Okapi
1214from langchain_core .documents import Document
13- from langchain_community .retrievers import BM25Retriever
1415from neo4j import GraphDatabase
1516from .graph_indexing import GraphIndexingModule
1617
1718logger = logging .getLogger (__name__ )
1819
20+ # 中文停用词表:助词 / 连词 / 疑问词 / 人称 / 语气词 / 动词修饰
21+ # 不引第三方停用词包,按烹饪问答场景手挑(覆盖 testset 高频虚词)
22+ _CHINESE_STOPWORDS = set ("""
23+ 的 了 和 是 在 我 有 就 不 也 都 还 这 那 一 个 与 及 等 上 下 中 为 以 于 从 把 被 让 使 又 而 但 或
24+ 什么 怎么 如何 哪些 哪个 哪里 谁 多少 几 你 他 她 它 我们 他们 她们 它们
25+ 请问 请 想 要 需要 能 可以 应该 会 啊 呢 吧 嘛 吗 哦 呀 哈
26+ 之 其 此 该 即 各 每 些 种 类 时 后 前 里 外 内 间 已经 正在 一些 一下
27+ """ .split ())
28+
29+ # RRF 融合的常数 k:Cormack et al. 2009 默认值
30+ _RRF_K = 60
31+
1932@dataclass
2033class RetrievalResult :
2134 """检索结果数据结构"""
@@ -30,42 +43,61 @@ class HybridRetrievalModule:
3043 """
3144 混合检索模块
3245 核心特点:
33- 1. 双层检索范式(实体级 + 主题级)
34- 2. 关键词提取和匹配
35- 3. 图结构+向量检索结合
36- 4. 一跳邻居扩展
37- 5. Round-robin轮询合并策略
46+ 1. 双层检索范式(实体级 + 主题级,基于图键值索引)
47+ 2. BM25 关键词检索(jieba 分词 + 停用词过滤)
48+ 3. 向量检索(Milvus)+ 一跳邻居扩展
49+ 4. RRF (Reciprocal Rank Fusion) 融合三路结果
3850 """
39-
51+
def __init__(self, config, milvus_module, data_module, llm_client):
    """Store the injected collaborators and set up empty retrieval state.

    No connection is opened and no index is built here; the Neo4j driver
    and the BM25 index stay unset until ``initialize`` is called.

    Args:
        config: Settings object; also handed to ``GraphIndexingModule``.
        milvus_module: Vector-store module kept for later use.
        data_module: Data-access module kept for later use.
        llm_client: LLM client; also handed to ``GraphIndexingModule``.
    """
    # Injected collaborators.
    self.config = config
    self.milvus_module = milvus_module
    self.data_module = data_module
    self.llm_client = llm_client

    # Neo4j driver handle; assigned in initialize().
    self.driver = None

    # BM25 index plus the source documents it was built from. Position i
    # in bm25_corpus_docs corresponds to row i of the index.
    self.bm25: Optional[BM25Okapi] = None
    self.bm25_corpus_docs: List[Document] = []

    # Graph key-value index and a flag recording whether it has been built.
    self.graph_indexing = GraphIndexingModule(config, llm_client)
    self.graph_indexed = False
51-
66+
def initialize(self, chunks: List[Document]):
    """Connect to Neo4j, build the BM25 index and build the graph index.

    Args:
        chunks: Documents to index for BM25. When empty, the BM25 index is
            left untouched; the graph index is still built.
    """
    logger.info("初始化混合检索模块...")

    # Open the Neo4j driver with the credentials from the config object.
    neo4j_uri = self.config.neo4j_uri
    credentials = (self.config.neo4j_user, self.config.neo4j_password)
    self.driver = GraphDatabase.driver(neo4j_uri, auth=credentials)

    # Build the BM25 index from jieba tokens with stopwords removed.
    if chunks:
        # Keep our own list so BM25 row indices map back to documents.
        self.bm25_corpus_docs = list(chunks)
        tokenize = self._tokenize_chinese
        tokenized_corpus = [tokenize(doc.page_content) for doc in chunks]
        self.bm25 = BM25Okapi(tokenized_corpus)

        token_total = sum(map(len, tokenized_corpus))
        avg_tokens = token_total / max(1, len(tokenized_corpus))
        logger.info(
            f"BM25(jieba+stopwords) 索引构建完成,文档数: {len(chunks)} ,"
            f"平均 token 数: {avg_tokens:.1f} "
        )

    # Build the graph index regardless of whether chunks were supplied.
    self._build_graph_index()
90+
@staticmethod
def _tokenize_chinese(text: str) -> List[str]:
    """Tokenize text with jieba and drop stopwords and whitespace tokens.

    Uses ``jieba.lcut`` with its default arguments. Tokens that are empty
    or whitespace-only, or that appear in ``_CHINESE_STOPWORDS``, are
    removed. Single-character tokens are NOT removed by length; they are
    dropped only when they are listed as stopwords.

    Args:
        text: Raw text to tokenize. ``None`` or an empty string is allowed.

    Returns:
        The surviving tokens in their original order; an empty list when
        ``text`` is falsy.
    """
    if not text:
        return []
    # token.strip() is empty exactly for empty/whitespace-only tokens, so
    # no separate isspace() check is needed.
    return [
        token
        for token in jieba.lcut(text)
        if token.strip() and token not in _CHINESE_STOPWORDS
    ]
69101
70102 def _build_graph_index (self ):
71103 """构建图索引"""
@@ -542,60 +574,175 @@ def _get_node_neighbors(self, node_id: str, max_neighbors: int = 3) -> List[str]
542574 logger .error (f"获取邻居节点失败: { e } " )
543575 return []
544576
def bm25_search(self, query: str, top_k: int = 5) -> List[Document]:
    """Run a BM25 keyword search and return the best-scoring documents.

    The query is tokenized with ``_tokenize_chinese`` and scored against
    the BM25 index. Results are returned in descending score order, and
    documents whose score is not strictly positive are dropped.

    Args:
        query: Natural-language query text.
        top_k: Maximum number of documents to return. Values <= 0 yield an
            empty list.

    Returns:
        New ``Document`` objects (the indexed documents are not mutated).
        Each carries the source metadata plus ``recipe_name``,
        ``search_method``/``search_type`` set to ``"bm25"``, and the raw
        score in ``bm25_score``.
    """
    import heapq  # local import: only this method needs it

    if self.bm25 is None or not self.bm25_corpus_docs:
        logger.warning("BM25 索引未初始化,bm25_search 返回空")
        return []

    # A negative top_k used to slice to "all but the last |k|" results.
    if top_k <= 0:
        return []

    tokenized_query = self._tokenize_chinese(query)
    if not tokenized_query:
        logger.debug(f"BM25 query 分词为空,跳过: {query} ")
        return []

    scores = self.bm25.get_scores(tokenized_query)
    # nlargest gives the same result and tie order as
    # sorted(..., reverse=True)[:top_k] without sorting the whole corpus.
    top_indices = heapq.nlargest(
        top_k, range(len(scores)), key=scores.__getitem__
    )

    docs: List[Document] = []
    for idx in top_indices:
        score = float(scores[idx])
        if score <= 0:
            # A non-positive score means no query term contributed; treat
            # the document as irrelevant.
            continue
        src = self.bm25_corpus_docs[idx]
        recipe_name = (
            src.metadata.get("recipe_name")
            or src.metadata.get("name")
            or "未知菜品"
        )
        docs.append(
            Document(
                page_content=src.page_content,
                metadata={
                    **src.metadata,
                    "recipe_name": recipe_name,
                    "search_method": "bm25",
                    "search_type": "bm25",
                    "bm25_score": score,
                },
            )
        )

    logger.info(f"BM25 检索完成,返回 {len(docs)} 个文档(query tokens={tokenized_query} )")
    return docs
623+
@staticmethod
def _rrf_merge(
    ranked_lists: List[Tuple[str, List[Document]]],
    top_k: int,
    k: int = _RRF_K,
) -> List[Document]:
    """Fuse several ranked result lists with Reciprocal Rank Fusion.

    score(d) = sum over sources of 1 / (k + best_rank_in_source(d))

    Args:
        ranked_lists: ``(source_name, docs)`` pairs; each ``docs`` list is
            ordered from most to least relevant.
        top_k: Number of fused documents to return.
        k: RRF smoothing constant (defaults to ``_RRF_K``).

    Returns:
        New ``Document`` objects in descending fused-score order. Their
        metadata is a shallow copy of the canonical document's metadata
        extended with ``rrf_score``, ``rrf_sources``, ``rrf_ranks``,
        ``rrf_chunk_hits`` and ``final_score``. Input documents are not
        mutated.

    Documents are de-duplicated by ``metadata["node_id"]`` when present,
    otherwise by a hash of the first 200 characters of ``page_content``.
    When a source hits the same id several times, only its best (smallest)
    rank is scored; the hit count is recorded in ``rrf_chunk_hits``. The
    canonical document for an id is the hit with the smallest rank overall,
    ties going to the source listed earlier in ``ranked_lists``.
    """
    # doc key -> {source name -> best (smallest) rank in that source}
    ranks_by_doc: Dict[str, Dict[str, int]] = {}
    # doc key -> {source name -> number of hits in that source}
    hits_by_doc: Dict[str, Dict[str, int]] = {}
    # doc key -> (best rank overall, source priority, document)
    canonical: Dict[str, Tuple[int, int, Document]] = {}

    for priority, (source_name, ranked_docs) in enumerate(ranked_lists):
        for position, doc in enumerate(ranked_docs, start=1):
            raw_id = doc.metadata.get("node_id")
            if raw_id is None:
                key = f"hash::{hash(doc.page_content[:200])} "
            else:
                key = str(raw_id)

            source_ranks = ranks_by_doc.setdefault(key, {})
            source_hits = hits_by_doc.setdefault(key, {})

            # Keep only the best rank per source so repeated hits do not
            # add score more than once.
            source_ranks[source_name] = min(
                position, source_ranks.get(source_name, position)
            )
            source_hits[source_name] = source_hits.get(source_name, 0) + 1

            incumbent = canonical.get(key)
            if incumbent is None or (position, priority) < incumbent[:2]:
                canonical[key] = (position, priority, doc)

    # Each source contributes once, using its best rank.
    fused_scores: Dict[str, float] = {}
    for key, source_ranks in ranks_by_doc.items():
        fused_scores[key] = sum(1.0 / (k + r) for r in source_ranks.values())

    # Stable sort: equal scores keep first-seen order.
    ordered_keys = sorted(
        fused_scores, key=fused_scores.__getitem__, reverse=True
    )

    merged: List[Document] = []
    for key in ordered_keys[:top_k]:
        chosen = canonical[key][2]
        score = fused_scores[key]
        # Build fresh metadata so the upstream Document is left untouched.
        metadata = {
            **chosen.metadata,
            "rrf_score": score,
            "rrf_sources": list(ranks_by_doc[key].keys()),
            "rrf_ranks": dict(ranks_by_doc[key]),
            "rrf_chunk_hits": dict(hits_by_doc[key]),
            "final_score": score,
        }
        merged.append(
            Document(page_content=chosen.page_content, metadata=metadata)
        )

    return merged
710+
def hybrid_search(self, query: str, top_k: int = 5) -> List[Document]:
    """Retrieve from three sources and fuse the results with RRF.

    The three sources are dual-level graph retrieval, vector search and
    BM25 search, queried in that order.

    Args:
        query: Natural-language query text.
        top_k: Number of fused documents to return.

    Returns:
        The documents produced by ``_rrf_merge`` for the three ranked lists.
    """
    logger.info(f"开始混合检索(dual + vector + bm25, RRF k={_RRF_K} ): {query} ")

    # Ask each source for more than top_k candidates; otherwise the three
    # lists tend not to overlap and the fusion has little to work with.
    candidate_k = max(top_k * 2, 10)

    # List elements are evaluated left to right, so the sources are queried
    # in the order dual -> vector -> bm25.
    ranked_lists = [
        ("dual_level", self.dual_level_retrieval(query, candidate_k)),
        ("vector", self.vector_search_enhanced(query, candidate_k)),
        ("bm25", self.bm25_search(query, candidate_k)),
    ]
    dual_docs, vector_docs, bm25_docs = (docs for _, docs in ranked_lists)

    # Tag the origin of each hit. Dual-level hits keep an existing
    # search_method (presumably set inside dual_level_retrieval - confirm);
    # vector hits are always overwritten. bm25_search tags its own results.
    for doc in dual_docs:
        doc.metadata.setdefault("search_method", "dual_level")
    for doc in vector_docs:
        doc.metadata["search_method"] = "vector"

    final_docs = self._rrf_merge(ranked_lists=ranked_lists, top_k=top_k)

    logger.info(
        f"RRF 融合完成:dual={len(dual_docs)} vector={len(vector_docs)} "
        f"bm25={len(bm25_docs)} → 最终 {len(final_docs)} 个文档"
    )
    return final_docs
598-
745+
599746 def close (self ):
600747 """关闭资源连接"""
601748 if self .driver :
0 commit comments