Skip to content

Commit 25d37e2

Browse files
authored
fix: change to pure rank-based RRF for relevance ordering (#4411)
* Fix RRF * Fix turbopuffer tests
1 parent 17c2783 commit 25d37e2

File tree

5 files changed

+174
-120
lines changed

5 files changed

+174
-120
lines changed

letta/functions/function_sets/base.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,33 +55,31 @@ def conversation_search(
5555
str: Query result string containing matching messages with timestamps and content.
5656
"""
5757

58-
import math
59-
6058
from letta.constants import RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
6159
from letta.helpers.json_helpers import json_dumps
6260

63-
if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
64-
page = 0
65-
try:
66-
page = int(page)
67-
except:
68-
raise ValueError("'page' argument must be an integer")
69-
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
70-
# TODO: add paging by page number. currently cursor only works with strings.
71-
# original: start=page * count
61+
# Use provided limit or default
62+
if limit is None:
63+
limit = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
64+
7265
messages = self.message_manager.list_messages_for_agent(
7366
agent_id=self.agent_state.id,
7467
actor=self.user,
7568
query_text=query,
76-
limit=count,
69+
roles=roles,
70+
limit=limit,
7771
)
78-
total = len(messages)
79-
num_pages = math.ceil(total / count) - 1 # 0 index
72+
8073
if len(messages) == 0:
8174
results_str = "No results found."
8275
else:
83-
results_pref = f"Showing {len(messages)} of {total} results (page {page}/{num_pages}):"
84-
results_formatted = [message.content[0].text for message in messages]
76+
results_pref = f"Found {len(messages)} results:"
77+
results_formatted = []
78+
for message in messages:
79+
# Extract text content from message
80+
text_content = message.content[0].text if message.content else ""
81+
result_entry = {"role": message.role, "content": text_content}
82+
results_formatted.append(result_entry)
8583
results_str = f"{results_pref} {json_dumps(results_formatted)}"
8684
return results_str
8785

letta/helpers/tpuf_client.py

Lines changed: 83 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -474,8 +474,16 @@ async def query_passages(
474474
# for hybrid mode, we get a multi-query response
475475
vector_results = self._process_single_query_results(result.results[0], archive_id, tags)
476476
fts_results = self._process_single_query_results(result.results[1], archive_id, tags, is_fts=True)
477-
# use backwards-compatible wrapper which calls generic RRF
478-
return self._reciprocal_rank_fusion(vector_results, fts_results, vector_weight, fts_weight, top_k)
477+
# use RRF and return only (passage, score) for backwards compatibility
478+
results_with_metadata = self._reciprocal_rank_fusion(
479+
vector_results=[passage for passage, _ in vector_results],
480+
fts_results=[passage for passage, _ in fts_results],
481+
get_id_func=lambda p: p.id,
482+
vector_weight=vector_weight,
483+
fts_weight=fts_weight,
484+
top_k=top_k,
485+
)
486+
return [(passage, rrf_score) for passage, rrf_score, metadata in results_with_metadata]
479487
else:
480488
# for single queries (vector, fts, timestamp)
481489
is_fts = search_mode == "fts"
@@ -499,7 +507,7 @@ async def query_messages(
499507
fts_weight: float = 0.5,
500508
start_date: Optional[datetime] = None,
501509
end_date: Optional[datetime] = None,
502-
) -> List[Tuple[dict, float]]:
510+
) -> List[Tuple[dict, float, dict]]:
503511
"""Query messages from Turbopuffer using vector search, full-text search, or hybrid search.
504512
505513
Args:
@@ -516,7 +524,10 @@ async def query_messages(
516524
end_date: Optional datetime to filter messages created before this date
517525
518526
Returns:
519-
List of (message_dict, score) tuples where message_dict contains id, text, role, created_at
527+
List of (message_dict, score, metadata) tuples where:
528+
- message_dict contains id, text, role, created_at
529+
- score is the final relevance score
530+
- metadata contains individual scores and ranking information
520531
"""
521532
# Check if we should fallback to timestamp-based retrieval
522533
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
@@ -576,28 +587,42 @@ async def query_messages(
576587
if search_mode == "hybrid":
577588
# for hybrid mode, we get a multi-query response
578589
vector_results = self._process_message_query_results(result.results[0])
579-
fts_results = self._process_message_query_results(result.results[1], is_fts=True)
580-
# use generic RRF with lambda to extract ID from dict
581-
return self._generic_reciprocal_rank_fusion(
590+
fts_results = self._process_message_query_results(result.results[1])
591+
# use RRF with lambda to extract ID from dict - returns metadata
592+
results_with_metadata = self._reciprocal_rank_fusion(
582593
vector_results=vector_results,
583594
fts_results=fts_results,
584595
get_id_func=lambda msg_dict: msg_dict["id"],
585596
vector_weight=vector_weight,
586597
fts_weight=fts_weight,
587598
top_k=top_k,
588599
)
600+
# return results with metadata
601+
return results_with_metadata
589602
else:
590603
# for single queries (vector, fts, timestamp)
591-
is_fts = search_mode == "fts"
592-
return self._process_message_query_results(result, is_fts=is_fts)
604+
results = self._process_message_query_results(result)
605+
# add simple metadata for single search modes
606+
results_with_metadata = []
607+
for idx, msg_dict in enumerate(results):
608+
metadata = {
609+
"combined_score": 1.0 / (idx + 1), # Use rank-based score for single mode
610+
"search_mode": search_mode,
611+
f"{search_mode}_rank": idx + 1, # Add the rank for this search mode
612+
}
613+
results_with_metadata.append((msg_dict, metadata["combined_score"], metadata))
614+
return results_with_metadata
593615

594616
except Exception as e:
595617
logger.error(f"Failed to query messages from Turbopuffer: {e}")
596618
raise
597619

598-
def _process_message_query_results(self, result, is_fts: bool = False) -> List[Tuple[dict, float]]:
599-
"""Process results from a message query into message dicts with scores."""
600-
messages_with_scores = []
620+
def _process_message_query_results(self, result) -> List[dict]:
621+
"""Process results from a message query into message dicts.
622+
623+
For RRF, we only need the rank order - scores are not used.
624+
"""
625+
messages = []
601626

602627
for row in result.rows:
603628
# Build message dict with key fields
@@ -609,19 +634,9 @@ def _process_message_query_results(self, result, is_fts: bool = False) -> List[T
609634
"role": getattr(row, "role", None),
610635
"created_at": getattr(row, "created_at", None),
611636
}
637+
messages.append(message_dict)
612638

613-
# handle score based on search type
614-
if is_fts:
615-
# for FTS, use the BM25 score directly (higher is better)
616-
score = getattr(row, "$score", 0.0)
617-
else:
618-
# for vector search, convert distance to similarity score
619-
distance = getattr(row, "$dist", 0.0)
620-
score = 1.0 - distance
621-
622-
messages_with_scores.append((message_dict, score))
623-
624-
return messages_with_scores
639+
return messages
625640

626641
def _process_single_query_results(
627642
self, result, archive_id: str, tags: Optional[List[str]], is_fts: bool = False
@@ -663,74 +678,78 @@ def _process_single_query_results(
663678

664679
return passages_with_scores
665680

666-
def _generic_reciprocal_rank_fusion(
681+
def _reciprocal_rank_fusion(
667682
self,
668-
vector_results: List[Tuple[Any, float]],
669-
fts_results: List[Tuple[Any, float]],
683+
vector_results: List[Any],
684+
fts_results: List[Any],
670685
get_id_func: Callable[[Any], str],
671686
vector_weight: float,
672687
fts_weight: float,
673688
top_k: int,
674-
) -> List[Tuple[Any, float]]:
675-
"""Generic RRF implementation that works with any object type.
689+
) -> List[Tuple[Any, float, dict]]:
690+
"""RRF implementation that works with any object type.
676691
677-
RRF score = vector_weight * (1/(k + vector_rank)) + fts_weight * (1/(k + fts_rank))
692+
RRF score = vector_weight * (1/(k + rank)) + fts_weight * (1/(k + rank))
678693
where k is a constant (typically 60) to avoid division by zero
679694
695+
This is a pure rank-based fusion following the standard RRF algorithm.
696+
680697
Args:
681-
vector_results: List of (item, score) tuples from vector search
682-
fts_results: List of (item, score) tuples from FTS
698+
vector_results: List of items from vector search (ordered by relevance)
699+
fts_results: List of items from FTS (ordered by relevance)
683700
get_id_func: Function to extract ID from an item
684701
vector_weight: Weight for vector search results
685702
fts_weight: Weight for FTS results
686703
top_k: Number of results to return
687704
688705
Returns:
689-
List of (item, score) tuples sorted by RRF score
706+
List of (item, score, metadata) tuples sorted by RRF score
707+
metadata contains ranks from each result list
690708
"""
691-
k = 60 # standard RRF constant
709+
k = 60 # standard RRF constant from Cormack et al. (2009)
692710

693-
# create rank mappings using the get_id_func
694-
vector_ranks = {get_id_func(item): rank + 1 for rank, (item, _) in enumerate(vector_results)}
695-
fts_ranks = {get_id_func(item): rank + 1 for rank, (item, _) in enumerate(fts_results)}
711+
# create rank mappings based on position in result lists
712+
# rank starts at 1, not 0
713+
vector_ranks = {get_id_func(item): rank + 1 for rank, item in enumerate(vector_results)}
714+
fts_ranks = {get_id_func(item): rank + 1 for rank, item in enumerate(fts_results)}
696715

697-
# combine all unique items
716+
# combine all unique items from both result sets
698717
all_items = {}
699-
for item, _ in vector_results:
718+
for item in vector_results:
700719
all_items[get_id_func(item)] = item
701-
for item, _ in fts_results:
720+
for item in fts_results:
702721
all_items[get_id_func(item)] = item
703722

704-
# calculate RRF scores
723+
# calculate RRF scores based purely on ranks
705724
rrf_scores = {}
725+
score_metadata = {}
706726
for item_id in all_items:
707-
vector_score = vector_weight / (k + vector_ranks.get(item_id, k + top_k))
708-
fts_score = fts_weight / (k + fts_ranks.get(item_id, k + top_k))
709-
rrf_scores[item_id] = vector_score + fts_score
727+
# RRF formula: sum of 1/(k + rank) across result lists
728+
# If item not in a list, we don't add anything (equivalent to rank = infinity)
729+
vector_rrf_score = 0.0
730+
fts_rrf_score = 0.0
731+
732+
if item_id in vector_ranks:
733+
vector_rrf_score = vector_weight / (k + vector_ranks[item_id])
734+
if item_id in fts_ranks:
735+
fts_rrf_score = fts_weight / (k + fts_ranks[item_id])
736+
737+
combined_score = vector_rrf_score + fts_rrf_score
738+
739+
rrf_scores[item_id] = combined_score
740+
score_metadata[item_id] = {
741+
"combined_score": combined_score, # Final RRF score
742+
"vector_rank": vector_ranks.get(item_id),
743+
"fts_rank": fts_ranks.get(item_id),
744+
}
710745

711-
# sort by RRF score and return top_k
712-
sorted_results = sorted([(all_items[iid], score) for iid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True)
746+
# sort by RRF score and return with metadata
747+
sorted_results = sorted(
748+
[(all_items[iid], score, score_metadata[iid]) for iid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True
749+
)
713750

714751
return sorted_results[:top_k]
715752

716-
def _reciprocal_rank_fusion(
717-
self,
718-
vector_results: List[Tuple[PydanticPassage, float]],
719-
fts_results: List[Tuple[PydanticPassage, float]],
720-
vector_weight: float,
721-
fts_weight: float,
722-
top_k: int,
723-
) -> List[Tuple[PydanticPassage, float]]:
724-
"""Wrapper for backwards compatibility - uses generic RRF for passages."""
725-
return self._generic_reciprocal_rank_fusion(
726-
vector_results=vector_results,
727-
fts_results=fts_results,
728-
get_id_func=lambda p: p.id,
729-
vector_weight=vector_weight,
730-
fts_weight=fts_weight,
731-
top_k=top_k,
732-
)
733-
734753
@trace_method
735754
async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
736755
"""Delete a passage from Turbopuffer."""

letta/services/message_manager.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import uuid
33
from datetime import datetime
4-
from typing import List, Optional, Sequence
4+
from typing import List, Optional, Sequence, Tuple
55

66
from sqlalchemy import delete, exists, func, select, text
77

@@ -1065,7 +1065,7 @@ async def search_messages_async(
10651065
start_date: Optional[datetime] = None,
10661066
end_date: Optional[datetime] = None,
10671067
embedding_config: Optional[EmbeddingConfig] = None,
1068-
) -> List[PydanticMessage]:
1068+
) -> List[Tuple[PydanticMessage, dict]]:
10691069
"""
10701070
Search messages using Turbopuffer if enabled, otherwise fall back to SQL search.
10711071
@@ -1082,7 +1082,7 @@ async def search_messages_async(
10821082
embedding_config: Optional embedding configuration for generating query embedding
10831083
10841084
Returns:
1085-
List of matching messages
1085+
List of tuples (message, metadata) where metadata contains relevance scores
10861086
"""
10871087
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
10881088

@@ -1133,8 +1133,8 @@ async def search_messages_async(
11331133
from letta.schemas.letta_message_content import TextContent
11341134
from letta.schemas.message import Message as PydanticMessage
11351135

1136-
turbopuffer_messages = []
1137-
for msg_dict, score in results:
1136+
message_tuples = []
1137+
for msg_dict, score, metadata in results:
11381138
# create a message object with the properly extracted text from turbopuffer
11391139
message = PydanticMessage(
11401140
id=msg_dict["id"],
@@ -1146,9 +1146,10 @@ async def search_messages_async(
11461146
created_by_id=actor.id,
11471147
last_updated_by_id=actor.id,
11481148
)
1149-
turbopuffer_messages.append(message)
1149+
# Return tuple of (message, metadata)
1150+
message_tuples.append((message, metadata))
11501151

1151-
return turbopuffer_messages
1152+
return message_tuples
11521153
else:
11531154
return []
11541155

@@ -1163,7 +1164,16 @@ async def search_messages_async(
11631164
limit=limit,
11641165
ascending=False,
11651166
)
1166-
return self._combine_assistant_tool_messages(messages)
1167+
combined_messages = self._combine_assistant_tool_messages(messages)
1168+
# Add basic metadata for SQL fallback
1169+
message_tuples = []
1170+
for message in combined_messages:
1171+
metadata = {
1172+
"search_mode": "sql_fallback",
1173+
"combined_score": None, # SQL doesn't provide scores
1174+
}
1175+
message_tuples.append((message, metadata))
1176+
return message_tuples
11671177
else:
11681178
# use sql-based search
11691179
messages = await self.list_messages_for_agent_async(
@@ -1174,4 +1184,13 @@ async def search_messages_async(
11741184
limit=limit,
11751185
ascending=False,
11761186
)
1177-
return self._combine_assistant_tool_messages(messages)
1187+
combined_messages = self._combine_assistant_tool_messages(messages)
1188+
# Add basic metadata for SQL search
1189+
message_tuples = []
1190+
for message in combined_messages:
1191+
metadata = {
1192+
"search_mode": "sql",
1193+
"combined_score": None, # SQL doesn't provide scores
1194+
}
1195+
message_tuples.append((message, metadata))
1196+
return message_tuples

0 commit comments

Comments
 (0)