@@ -474,8 +474,16 @@ async def query_passages(
474474 # for hybrid mode, we get a multi-query response
475475 vector_results = self ._process_single_query_results (result .results [0 ], archive_id , tags )
476476 fts_results = self ._process_single_query_results (result .results [1 ], archive_id , tags , is_fts = True )
477- # use backwards-compatible wrapper which calls generic RRF
478- return self ._reciprocal_rank_fusion (vector_results , fts_results , vector_weight , fts_weight , top_k )
477+ # use RRF and return only (passage, score) for backwards compatibility
478+ results_with_metadata = self ._reciprocal_rank_fusion (
479+ vector_results = [passage for passage , _ in vector_results ],
480+ fts_results = [passage for passage , _ in fts_results ],
481+ get_id_func = lambda p : p .id ,
482+ vector_weight = vector_weight ,
483+ fts_weight = fts_weight ,
484+ top_k = top_k ,
485+ )
486+ return [(passage , rrf_score ) for passage , rrf_score , metadata in results_with_metadata ]
479487 else :
480488 # for single queries (vector, fts, timestamp)
481489 is_fts = search_mode == "fts"
@@ -499,7 +507,7 @@ async def query_messages(
499507 fts_weight : float = 0.5 ,
500508 start_date : Optional [datetime ] = None ,
501509 end_date : Optional [datetime ] = None ,
502- ) -> List [Tuple [dict , float ]]:
510+ ) -> List [Tuple [dict , float , dict ]]:
503511 """Query messages from Turbopuffer using vector search, full-text search, or hybrid search.
504512
505513 Args:
@@ -516,7 +524,10 @@ async def query_messages(
516524 end_date: Optional datetime to filter messages created before this date
517525
518526 Returns:
519- List of (message_dict, score) tuples where message_dict contains id, text, role, created_at
527+ List of (message_dict, score, metadata) tuples where:
528+ - message_dict contains id, text, role, created_at
529+ - score is the final relevance score
530+ - metadata contains the combined score plus rank information (vector_rank/fts_rank for hybrid mode; search_mode and <search_mode>_rank for single modes)
520531 """
521532 # Check if we should fallback to timestamp-based retrieval
522533 if query_embedding is None and query_text is None and search_mode not in ["timestamp" ]:
@@ -576,28 +587,42 @@ async def query_messages(
576587 if search_mode == "hybrid" :
577588 # for hybrid mode, we get a multi-query response
578589 vector_results = self ._process_message_query_results (result .results [0 ])
579- fts_results = self ._process_message_query_results (result .results [1 ], is_fts = True )
580- # use generic RRF with lambda to extract ID from dict
581- return self ._generic_reciprocal_rank_fusion (
590+ fts_results = self ._process_message_query_results (result .results [1 ])
591+ # use RRF with lambda to extract ID from dict - returns metadata
592+ results_with_metadata = self ._reciprocal_rank_fusion (
582593 vector_results = vector_results ,
583594 fts_results = fts_results ,
584595 get_id_func = lambda msg_dict : msg_dict ["id" ],
585596 vector_weight = vector_weight ,
586597 fts_weight = fts_weight ,
587598 top_k = top_k ,
588599 )
600+ # return results with metadata
601+ return results_with_metadata
589602 else :
590603 # for single queries (vector, fts, timestamp)
591- is_fts = search_mode == "fts"
592- return self ._process_message_query_results (result , is_fts = is_fts )
604+ results = self ._process_message_query_results (result )
605+ # add simple metadata for single search modes
606+ results_with_metadata = []
607+ for idx , msg_dict in enumerate (results ):
608+ metadata = {
609+ "combined_score" : 1.0 / (idx + 1 ), # Use rank-based score for single mode
610+ "search_mode" : search_mode ,
611+ f"{ search_mode } _rank" : idx + 1 , # Add the rank for this search mode
612+ }
613+ results_with_metadata .append ((msg_dict , metadata ["combined_score" ], metadata ))
614+ return results_with_metadata
593615
594616 except Exception as e :
595617 logger .error (f"Failed to query messages from Turbopuffer: { e } " )
596618 raise
597619
598- def _process_message_query_results (self , result , is_fts : bool = False ) -> List [Tuple [dict , float ]]:
599- """Process results from a message query into message dicts with scores."""
600- messages_with_scores = []
620+ def _process_message_query_results (self , result ) -> List [dict ]:
621+ """Process results from a message query into message dicts.
622+
623+ For RRF, we only need the rank order - scores are not used.
624+ """
625+ messages = []
601626
602627 for row in result .rows :
603628 # Build message dict with key fields
@@ -609,19 +634,9 @@ def _process_message_query_results(self, result, is_fts: bool = False) -> List[T
609634 "role" : getattr (row , "role" , None ),
610635 "created_at" : getattr (row , "created_at" , None ),
611636 }
637+ messages .append (message_dict )
612638
613- # handle score based on search type
614- if is_fts :
615- # for FTS, use the BM25 score directly (higher is better)
616- score = getattr (row , "$score" , 0.0 )
617- else :
618- # for vector search, convert distance to similarity score
619- distance = getattr (row , "$dist" , 0.0 )
620- score = 1.0 - distance
621-
622- messages_with_scores .append ((message_dict , score ))
623-
624- return messages_with_scores
639+ return messages
625640
626641 def _process_single_query_results (
627642 self , result , archive_id : str , tags : Optional [List [str ]], is_fts : bool = False
@@ -663,74 +678,78 @@ def _process_single_query_results(
663678
664679 return passages_with_scores
665680
def _reciprocal_rank_fusion(
    self,
    vector_results: List[Any],
    fts_results: List[Any],
    get_id_func: Callable[[Any], str],
    vector_weight: float,
    fts_weight: float,
    top_k: int,
) -> List[Tuple[Any, float, dict]]:
    """Fuse two ranked result lists with weighted Reciprocal Rank Fusion (RRF).

    For every distinct item ID the fused score is:

        vector_weight / (k + vector_rank) + fts_weight / (k + fts_rank)

    where ``vector_rank`` / ``fts_rank`` are the 1-based positions of the item
    in the respective list and ``k`` is the standard RRF constant (60). A list
    that does not contain the item contributes nothing to its score
    (equivalent to an infinite rank). Only positions matter; any raw
    similarity/BM25 scores are irrelevant to the fusion.

    Args:
        vector_results: Items from vector search, ordered best-first.
        fts_results: Items from full-text search, ordered best-first.
        get_id_func: Function extracting a hashable ID from an item; used to
            match the same item across both lists.
        vector_weight: Weight applied to the vector-search contribution.
        fts_weight: Weight applied to the full-text-search contribution.
        top_k: Maximum number of results to return. Values <= 0 yield an
            empty list.

    Returns:
        List of ``(item, score, metadata)`` tuples sorted by descending RRF
        score, truncated to ``top_k``. ``metadata`` has the keys
        ``combined_score``, ``vector_rank`` and ``fts_rank``; a rank is
        ``None`` when the item is absent from that list. Ties keep the order
        in which items were first seen (vector results, then FTS results).
    """
    # A negative top_k would otherwise slice from the end of the list and
    # silently drop the *worst* results instead of returning nothing.
    if top_k <= 0:
        return []

    k = 60  # standard RRF constant from Cormack et al. (2009)

    def _first_occurrence_ranks(results: List[Any]) -> dict:
        # Ranks are 1-based. If an ID appears more than once in a list, keep
        # its first (best) position rather than letting a later duplicate
        # overwrite it with a worse rank.
        ranks: dict = {}
        for position, item in enumerate(results, start=1):
            ranks.setdefault(get_id_func(item), position)
        return ranks

    vector_ranks = _first_occurrence_ranks(vector_results)
    fts_ranks = _first_occurrence_ranks(fts_results)

    # Collect every distinct item. When an ID occurs in both lists the FTS
    # copy of the item is the one returned (later assignment wins).
    all_items = {}
    for item in vector_results:
        all_items[get_id_func(item)] = item
    for item in fts_results:
        all_items[get_id_func(item)] = item

    fused = []
    for item_id, item in all_items.items():
        vector_rank = vector_ranks.get(item_id)
        fts_rank = fts_ranks.get(item_id)

        # Missing from a list => that list contributes 0 to the score.
        vector_rrf_score = vector_weight / (k + vector_rank) if vector_rank is not None else 0.0
        fts_rrf_score = fts_weight / (k + fts_rank) if fts_rank is not None else 0.0
        combined_score = vector_rrf_score + fts_rrf_score

        metadata = {
            "combined_score": combined_score,  # Final RRF score
            "vector_rank": vector_rank,
            "fts_rank": fts_rank,
        }
        fused.append((item, combined_score, metadata))

    # sorted() is stable, so equal scores keep first-seen order.
    fused.sort(key=lambda entry: entry[1], reverse=True)
    return fused[:top_k]
715752
716- def _reciprocal_rank_fusion (
717- self ,
718- vector_results : List [Tuple [PydanticPassage , float ]],
719- fts_results : List [Tuple [PydanticPassage , float ]],
720- vector_weight : float ,
721- fts_weight : float ,
722- top_k : int ,
723- ) -> List [Tuple [PydanticPassage , float ]]:
724- """Wrapper for backwards compatibility - uses generic RRF for passages."""
725- return self ._generic_reciprocal_rank_fusion (
726- vector_results = vector_results ,
727- fts_results = fts_results ,
728- get_id_func = lambda p : p .id ,
729- vector_weight = vector_weight ,
730- fts_weight = fts_weight ,
731- top_k = top_k ,
732- )
733-
734753 @trace_method
735754 async def delete_passage (self , archive_id : str , passage_id : str ) -> bool :
736755 """Delete a passage from Turbopuffer."""
0 commit comments