
Litellm adaptive routing #26049

Open
krrish-berri-2 wants to merge 3 commits into litellm_internal_staging from litellm_adaptive_routing

Conversation

@krrish-berri-2 krrish-berri-2 commented Apr 19, 2026

What

New beta routing strategy: adaptive_router/. Give it a list of models and a quality/cost weight — it picks per-request and learns from real traffic which model is best at which kind of request.

Config

model_list:
  - model_name: gpt-4o
    litellm_params: { model: openai/gpt-4o }
    model_info:
      adaptive_router_preferences:
        quality_tier: 3
        strengths: ["code_generation", "analytical_reasoning"]

  - model_name: gpt-4o-mini
    litellm_params: { model: openai/gpt-4o-mini }
    model_info:
      adaptive_router_preferences:
        quality_tier: 2
        strengths: ["factual_lookup"]

  - model_name: my-router
    litellm_params:
      model: adaptive_router/smart-router
      adaptive_router_config:
        available_models: ["gpt-4o", "gpt-4o-mini"]
        weights: { quality: 0.7, cost: 0.3 }

Call model: my-router like any other model. Response header x-litellm-adaptive-router-model tells you which model was actually picked. Force a floor with x-litellm-min-quality-tier: 3.
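The quality floor is easiest to think of as a pre-sampling filter. A minimal sketch of the semantics (the header name comes from this PR; the function name and tier mapping are illustrative, not the actual implementation):

```python
def eligible_models(candidates: dict, headers: dict) -> list:
    """Filter candidates by the x-litellm-min-quality-tier request header.

    candidates maps model_name -> quality_tier (as set in
    adaptive_router_preferences); models below the floor are excluded
    before the router samples among the rest.
    """
    floor = int(headers.get("x-litellm-min-quality-tier", 0))
    return [m for m, tier in candidates.items() if tier >= floor]


models = {"gpt-4o": 3, "gpt-4o-mini": 2}
print(eligible_models(models, {"x-litellm-min-quality-tier": "3"}))  # ['gpt-4o']
print(eligible_models(models, {}))  # no floor: both models remain eligible
```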

How it learns

Each request → classified into 1 of 7 types (code_generation, code_understanding, technical_design, analytical_reasoning, writing, factual_lookup, general). Per (router, request_type, model) we keep a Beta posterior, updated from session signals (satisfaction, failure, loop, disengagement, stagnation, exhaustion, misalignment) extracted from message history.
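The learning loop above can be sketched as a per-cell Thompson-sampling bandit. This is an illustrative reimplementation, not the PR's code; the class and function names, the reward shaping, and the cost normalization are all assumptions:

```python
import random
from collections import defaultdict


class BetaCell:
    """Beta(alpha, beta) posterior for one (request_type, model) cell."""

    def __init__(self, alpha: float = 1.0, beta: float = 1.0):
        self.alpha, self.beta = alpha, beta

    def sample(self) -> float:
        return random.betavariate(self.alpha, self.beta)

    def update(self, reward: float) -> None:
        # reward in [0, 1], aggregated from session signals
        # (satisfaction raises alpha; failure/loop/etc. raise beta)
        self.alpha += reward
        self.beta += 1.0 - reward


def pick_model(cells, request_type, costs, w_quality=0.7, w_cost=0.3):
    """Thompson-sample each arm, then blend sampled quality with cheapness."""
    max_cost = max(costs.values())

    def score(model):
        quality = cells[(request_type, model)].sample()
        cheapness = 1.0 - costs[model] / max_cost
        return w_quality * quality + w_cost * cheapness

    return max(costs, key=score)


cells = defaultdict(BetaCell)
costs = {"gpt-4o": 10.0, "gpt-4o-mini": 0.6}  # hypothetical relative costs
# Simulate traffic where gpt-4o-mini handles factual_lookup well:
for _ in range(200):
    cells[("factual_lookup", "gpt-4o-mini")].update(0.9)
    cells[("factual_lookup", "gpt-4o")].update(0.5)
print(pick_model(cells, "factual_lookup", costs))  # gpt-4o-mini
```

Early on, wide posteriors make the sampler explore; as observations accumulate, the posteriors concentrate and the cheaper-but-good model wins its cells.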

Inspect state: GET /adaptive_router/{router_name}/state.

What's in the PR

  • litellm/router_strategy/adaptive_router/ — bandit, classifier, signals, hooks, config
  • litellm/proxy/db/db_transaction_queue/adaptive_router_update_queue.py — batched async writer
  • New tables LiteLLM_AdaptiveRouterState + LiteLLM_AdaptiveRouterSession (migration + all 3 schema.prisma copies)
  • GET /adaptive_router/{router_name}/state endpoint
  • Docs: docs/my-website/docs/adaptive_router.md
  • Tests under tests/test_litellm/router_strategy/adaptive_router/ + signal fixtures
  • Demo: scripts/adaptive_router_demo/ (chat UI, dashboard, traffic gen, eval)

Known limitations

  • No latency scoring — slow model can still win
  • Signals are regex-based, English-biased
  • Hard cap 200 obs/cell, no decay
  • Once a session picks a model, other models' turns in that session don't contribute to learning

Type

🆕 New Feature

allow model routing to improve based on conversation signals

ensures router is picking best model for task
@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

greptile-apps bot commented Apr 19, 2026

Greptile Summary

This PR introduces an adaptive routing strategy for LiteLLM using a Thompson-sampling (Beta-distribution) multi-armed bandit. It adds new DB tables (LiteLLM_AdaptiveRouterState, LiteLLM_AdaptiveRouterSession), a per-turn signal classifier, a session-state tracker, an in-memory aggregation queue with a background flusher, and a pre/post-call hook pair to close the feedback loop.

Three P1 issues need attention before merge:

  • Two bare print() calls in adaptive_router.py (record_turn) will spam stdout on every processed turn in production.
  • load_state_from_db is never invoked at proxy startup, so all learned bandit posteriors are discarded on every restart — the router always starts cold.
  • flush_state_to_db uses a non-atomic read-modify-write (find → compute → upsert) that silently drops updates when two pods flush concurrently for the same cell, which is the common multi-pod topology.

Confidence Score: 3/5

Not safe to merge: three P1 defects (debug prints, lost state on restart, race condition in multi-pod flush) should be resolved first.

Three independent P1 findings: leftover print() statements pollute production stdout, load_state_from_db is never called so the bandit forgets everything on restart, and the non-atomic flush loses concurrent updates in multi-pod deployments. These are current defects in the changed code, not speculative risks.

  • litellm/router_strategy/adaptive_router/adaptive_router.py (print statements + missing load_state_from_db call)
  • litellm/proxy/db/db_transaction_queue/adaptive_router_update_queue.py (non-atomic flush + PK fields in update block)
  • litellm/proxy/proxy_server.py (missing load_state_from_db at startup)

Important Files Changed

  • litellm/router_strategy/adaptive_router/adaptive_router.py
    Core adaptive router logic: two leftover debug print() statements will spam stdout in production on every turn, and load_state_from_db is never called at startup so all learned state is lost on restart
  • litellm/proxy/db/db_transaction_queue/adaptive_router_update_queue.py
    In-memory queue with async DB flusher: flush_state_to_db uses a non-atomic read-modify-write pattern that loses updates in multi-pod deployments; flush_session_to_db passes PK fields in the Prisma update block
  • litellm/router_strategy/adaptive_router/bandit.py
    Thompson sampling implementation: clean Beta-distribution bandit with quality/cost scoring; SAMPLE_CAP hard-drops updates once exceeded (silently, by design per D5 comment)
  • litellm/router_strategy/adaptive_router/signals.py
    O(1) per-turn signal detection (misalignment, stagnation, satisfaction, failure, loop, exhaustion) using Jaccard similarity and regex patterns; logic is straightforward and tests cover main scenarios
  • litellm/router_strategy/adaptive_router/hooks.py
    Post-call hook assembles Turn from request/response and feeds it through record_turn; session key derivation via SHA-256 over identity fields is sound; exceptions are swallowed to protect the request path
  • litellm/router_strategy/adaptive_router/classifier.py
    Rule-based prompt classifier using regex rules ordered from specific to general; deterministic and O(rules) per call with a 2000-char input cap
  • litellm-proxy-extras/litellm_proxy_extras/migrations/20260418000000_add_adaptive_router_tables/migration.sql
    Adds LiteLLM_AdaptiveRouterState and LiteLLM_AdaptiveRouterSession tables with appropriate indexes; schema is consistent with the Prisma model and application code
  • litellm/proxy/proxy_server.py
    Starts the adaptive-router flusher task at startup, but does not call load_state_from_db to restore persisted bandit cells — every proxy restart begins with cold-start priors
  • litellm/router.py
    Adds adaptive_routers dict and finalization logic; _finalize_adaptive_router_if_configured and init_adaptive_router_deployment correctly build AdaptiveRouter instances and register post-call hooks

Sequence Diagram

sequenceDiagram
    participant Client
    participant Router
    participant AdaptiveRouter
    participant PostCallHook
    participant UpdateQueue
    participant FlusherLoop
    participant Postgres

    Client->>Router: POST /v1/chat/completions (model=smart-cheap-router)
    Router->>AdaptiveRouter: async_pre_routing_hook()
    AdaptiveRouter->>AdaptiveRouter: classify_prompt() → RequestType
    AdaptiveRouter->>AdaptiveRouter: pick_model() [Thompson sample]
    AdaptiveRouter-->>Router: PreRoutingHookResponse(model=chosen)
    Router->>Router: Route to chosen upstream model
    Router-->>Client: Response + x-litellm-adaptive-router-model header

    Note over PostCallHook: Runs after response is sent
    Router->>PostCallHook: async_log_success_event(kwargs, response)
    PostCallHook->>PostCallHook: _resolve_session_key()
    PostCallHook->>AdaptiveRouter: claim_or_check_owner()
    AdaptiveRouter-->>PostCallHook: True/False (owner check)
    PostCallHook->>AdaptiveRouter: record_turn(session_id, model, turn)
    AdaptiveRouter->>AdaptiveRouter: apply_turn() → SignalDelta
    AdaptiveRouter->>AdaptiveRouter: _compute_bandit_delta() → Δα, Δβ
    AdaptiveRouter->>AdaptiveRouter: apply_delta() → update in-memory cell
    AdaptiveRouter->>UpdateQueue: add_session_state()
    AdaptiveRouter->>UpdateQueue: add_state_delta()

    loop Every FLUSH_INTERVAL seconds
        FlusherLoop->>UpdateQueue: flush_state_to_db(prisma)
        UpdateQueue->>Postgres: find_unique + upsert (non-atomic)
        FlusherLoop->>UpdateQueue: flush_session_to_db(prisma)
        UpdateQueue->>Postgres: upsert session row
    end

Reviews (1): Last reviewed commit: "docs: update docs"

Comment on lines +300 to +308
print("CALLS DELTA", delta)

snapshot = asdict(state)
await self.queue.add_session_state(
    session_id, self.router_name, model_name, snapshot
)

d_alpha, d_beta = self._compute_bandit_delta(delta)
print("CALLS D_ALPHA", d_alpha)

P1 Debug print() statements left in production code

Two bare print() calls on lines 300 and 308 will emit to stdout on every turn processed by the adaptive router. In a busy deployment this creates significant noise and can impact performance since stdout is unbuffered by default in many container runtimes. These should be replaced with verbose_router_logger.debug(...).

Suggested change

Before:
print("CALLS DELTA", delta)
snapshot = asdict(state)
await self.queue.add_session_state(
    session_id, self.router_name, model_name, snapshot
)
d_alpha, d_beta = self._compute_bandit_delta(delta)
print("CALLS D_ALPHA", d_alpha)

After:
delta = apply_turn(state, turn)
verbose_router_logger.debug("CALLS DELTA %s", delta)
snapshot = asdict(state)
await self.queue.add_session_state(
    session_id, self.router_name, model_name, snapshot
)
d_alpha, d_beta = self._compute_bandit_delta(delta)
verbose_router_logger.debug("CALLS D_ALPHA %s", d_alpha)

Comment on lines +97 to +128
async def load_state_from_db(self, prisma_client: Any) -> None:
    """Override cold-start cells with persisted state. Called once at startup."""
    if prisma_client is None:
        return
    try:
        rows = await prisma_client.db.litellm_adaptiverouterstate.find_many(
            where={"router_name": self.router_name}
        )
        loaded = 0
        for row in rows:
            try:
                rt = RequestType(row.request_type)
            except ValueError:
                # Unknown taxonomy entry from an older/newer version. Skip.
                continue
            if row.model_name not in self.config.available_models:
                continue
            self._cells[(rt, row.model_name)] = BanditCell(
                alpha=row.alpha, beta=row.beta
            )
            loaded += 1
        verbose_router_logger.info(
            "AdaptiveRouter[%s]: loaded %d cells from DB",
            self.router_name,
            loaded,
        )
    except Exception as e:
        verbose_router_logger.exception(
            "AdaptiveRouter[%s]: failed to load state from DB: %s",
            self.router_name,
            e,
        )

P1 load_state_from_db is defined but never called at proxy startup

The method correctly restores persisted bandit cells from LiteLLM_AdaptiveRouterState, but there is no call to it anywhere in proxy_server.py or router.py. Every time the proxy restarts, all learned posteriors are discarded and the router starts from cold-start priors — defeating the purpose of persisting state to Postgres.

init_adaptive_router_deployment in router.py (around line 7030) constructs the AdaptiveRouter but does not await load_state_from_db. _finalize_adaptive_router_if_configured is called from __init__ (synchronously), so state loading would need to be triggered from the async startup path in proxy_server.py, similar to how other startup tasks are scheduled.
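One possible wiring for the fix, as a sketch only: it assumes the adaptive_routers dict and load_state_from_db coroutine described in this PR, and the function name and startup-task placement are hypothetical, not the repo's actual code.

```python
import asyncio


async def restore_adaptive_router_state(llm_router, prisma_client) -> None:
    """Candidate async startup task: rehydrate each adaptive router's
    bandit cells from LiteLLM_AdaptiveRouterState before serving traffic."""
    if prisma_client is None:
        return
    for adaptive_router in getattr(llm_router, "adaptive_routers", {}).values():
        await adaptive_router.load_state_from_db(prisma_client)
```

Scheduling something like this next to where the flusher task is started in proxy_server.py would let learned posteriors survive restarts.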

Comment on lines +109 to +151
try:
    existing = (
        await prisma_client.db.litellm_adaptiverouterstate.find_unique(
            where={
                "router_name_request_type_model_name": {
                    "router_name": router,
                    "request_type": rt,
                    "model_name": model,
                }
            }
        )
    )
    new_alpha = (existing.alpha if existing else 0.0) + payload["delta_alpha"]
    new_beta = (existing.beta if existing else 0.0) + payload["delta_beta"]
    new_samples = (existing.total_samples if existing else 0) + int(
        payload["samples_added"]
    )
    await prisma_client.db.litellm_adaptiverouterstate.upsert(
        where={
            "router_name_request_type_model_name": {
                "router_name": router,
                "request_type": rt,
                "model_name": model,
            }
        },
        data={
            "create": {
                "router_name": router,
                "request_type": rt,
                "model_name": model,
                "alpha": new_alpha,
                "beta": new_beta,
                "total_samples": new_samples,
            },
            "update": {
                "alpha": new_alpha,
                "beta": new_beta,
                "total_samples": new_samples,
            },
        },
    )

P1 Non-atomic read-modify-write causes state loss in multi-pod deployments

flush_state_to_db reads the current alpha/beta from the DB, computes new values in Python, then upserts — a classic TOCTOU race. When two pods flush concurrently for the same (router, request_type, model) key, both read existing.alpha = N, both compute N + delta, and both upsert N + delta — one pod's update is silently discarded.

Multi-pod LiteLLM proxy is the standard production topology, so this data loss is not theoretical. A safe fix is to push the arithmetic into a single SQL statement:

await prisma_client.db.execute_raw(
    """
    INSERT INTO "LiteLLM_AdaptiveRouterState"
        (router_name, request_type, model_name, alpha, beta, total_samples)
    VALUES ($1, $2, $3, $4, $5, $6)
    ON CONFLICT (router_name, request_type, model_name)
    DO UPDATE SET
        alpha         = "LiteLLM_AdaptiveRouterState".alpha + EXCLUDED.alpha,
        beta          = "LiteLLM_AdaptiveRouterState".beta  + EXCLUDED.beta,
        total_samples = "LiteLLM_AdaptiveRouterState".total_samples + EXCLUDED.total_samples,
        last_updated_at = CURRENT_TIMESTAMP
    """,
    router, rt, model,
    payload["delta_alpha"], payload["delta_beta"], int(payload["samples_added"]),
)

This makes the merge atomic and eliminates the separate find_unique round-trip.

Comment on lines +186 to +197
        }
    },
    data={
        "create": {
            "session_id": session_id,
            "router_name": router,
            "model_name": model,
            **payload,
        },
        "update": payload,
    },
)

P2 PK fields included in Prisma upsert update block

payload is asdict(state) where SessionState contains session_id, router_name, and model_name as top-level fields. These end up in the update: payload block, which only accepts non-PK fields in Prisma. Depending on the Prisma version and the generated client, this may raise a validation error at runtime or silently be ignored.

Consider stripping the identity fields before passing to update:

update_payload = {
    k: v for k, v in payload.items()
    if k not in {"session_id", "router_name", "model_name"}
}
await prisma_client.db.litellm_adaptiveroutersession.upsert(
    where={"session_id_router_name_model_name": {...}},
    data={
        "create": {"session_id": session_id, "router_name": router, "model_name": model, **update_payload},
        "update": update_payload,
    },
)

Comment on lines +80 to +84
self._cells: Dict[Tuple[RequestType, str], BanditCell] = {}
self._owner_cache: Dict[str, Tuple[str, float]] = {}
self._session_states: Dict[Tuple[str, str], SessionState] = {}
self._skipped_updates_total: int = 0
self._lock = asyncio.Lock()

P2 _owner_cache grows without bound

_owner_cache maps session_key → (owner_model, expires_at) and is never pruned. Expired entries are checked on read but left in place, so a long-running proxy handling millions of unique sessions would accumulate unbounded memory. Consider evicting stale entries periodically (e.g., inside record_turn or on a background timer) using a simple scan of entries where expires_at < now.
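A minimal eviction pass along those lines, purely illustrative: the function name is invented and only mirrors the (owner_model, expires_at) tuple shape described above.

```python
import time


def prune_owner_cache(owner_cache, now=None):
    """Remove expired session-ownership entries in place.

    owner_cache maps session_key -> (owner_model, expires_at).
    Returns the number of entries evicted.
    """
    now = time.monotonic() if now is None else now
    stale = [k for k, (_, expires_at) in owner_cache.items() if expires_at < now]
    for k in stale:
        del owner_cache[k]
    return len(stale)


cache = {"s1": ("gpt-4o", 100.0), "s2": ("gpt-4o-mini", 300.0)}
prune_owner_cache(cache, now=200.0)  # evicts s1; s2 survives
```

Running this inside record_turn every N turns, or on the existing flusher timer, would bound memory at the cost of an O(cache) scan.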

