4 changes: 2 additions & 2 deletions backend/.env.example
@@ -54,9 +54,9 @@ KBR_INITIAL_USER_BALANCE_USD=0.0
 # KBR_EMAIL_WHITELIST_ENABLED=false
 
 # Uvicorn server settings
-KBR_UVICORN_TIMEOUT_KEEP_ALIVE=120
+KBR_UVICORN_TIMEOUT_KEEP_ALIVE=3700
 KBR_UVICORN_LIMIT_CONCURRENCY=100
-KBR_UVICORN_LIMIT_MAX_REQUESTS=10000
+KBR_UVICORN_LIMIT_MAX_REQUESTS=0
 
 # Streaming settings
 KBR_STREAM_HEARTBEAT_INTERVAL=15
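
The new values are tuned for long-lived streaming: a keep-alive timeout just over an hour (3700 s) lets idle connections survive hour-long streamed responses, and a max-requests limit of 0 disables worker recycling. A minimal sketch of how these settings presumably reach uvicorn — the entrypoint path, `settings` attribute names, and `run_server` helper are illustrative assumptions, not taken from this PR:

# Hedged sketch only: how the KBR_UVICORN_* settings might be passed
# through to uvicorn; paths and names below are assumptions.
import uvicorn

from app.core.config import settings  # hypothetical import path

def run_server() -> None:
    uvicorn.run(
        "app.main:app",  # hypothetical app path
        host="0.0.0.0",
        port=8000,
        # 3700 s is just over an hour, so idle keep-alive connections
        # outlive the longest expected streamed response
        timeout_keep_alive=settings.UVICORN_TIMEOUT_KEEP_ALIVE,
        limit_concurrency=settings.UVICORN_LIMIT_CONCURRENCY,
        # uvicorn treats None as "no restart limit"; map the 0 sentinel to None
        limit_max_requests=settings.UVICORN_LIMIT_MAX_REQUESTS or None,
    )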
158 changes: 30 additions & 128 deletions backend/app/api/admin/endpoints/models.py
@@ -7,7 +7,6 @@
 from typing import Dict, List, Optional
 from uuid import UUID
 
-import boto3
 from fastapi import APIRouter, Depends, HTTPException
 from pydantic import BaseModel
 from sqlalchemy import select
@@ -102,155 +101,58 @@ async def list_aws_available_models(
     """
     Get list of available Bedrock models from AWS (for selection).
 
-    This endpoint queries AWS Bedrock to get all Claude and Nova models
-    available for the current AWS account/region.
-    Results are cached for 12 hours to reduce API calls.
+    Returns models that the proxy can actually invoke:
+      * Inference profiles available in the deployment region
+      * Foundation models available in the deployment region
+      * Foundation models available in the fallback region (routed automatically)
 
     This is used when user wants to add a new model.
+    Results are cached for 12 hours in memory. The underlying profile cache
+    is refreshed at startup and daily at 03:00 UTC.
     """
-    # Check cache first
+    # Check in-memory cache first
    cached_models = _get_aws_models_cache()
     if cached_models is not None:
         return {"models": cached_models}
 
     try:
-        # Use us-east-1 for model discovery — it has the most complete model catalog.
-        # The deployment region (settings.AWS_REGION) is used for actual inference,
-        # but listing available models should always query the primary catalog region.
-        catalog_region = "us-east-1"
-        bedrock_client = boto3.client(
-            service_name="bedrock", region_name=catalog_region
-        )
-        logger.info(
-            f"Bedrock client initialized for model catalog region: {catalog_region} (deployment region: {settings.AWS_REGION})"
-        )
+        bc = BedrockClient.get_instance()
 
-        # Allowed providers
-        allowed_providers = {
-            "ai21",
-            "amazon",
-            "anthropic",
-            "cohere",
-            "deepseek",
-            "google",
-            "luma",
-            "meta",
-            "minimax",
-            "mistral",
-            "moonshot",
-            "moonshotai",
-            "nvidia",
-            "openai",
-            "qwen",
-            "stability",
-            "twelvelabs",
-            "writer",
-            "zai",
-        }
+        # Refresh profile cache if stale or empty
+        if bc._profile_cache.is_stale or bc._profile_cache.is_empty:
+            await bc.refresh_profile_cache()
 
+        # Build model list from the profile cache
+        raw_models = bc._profile_cache.get_all_available_models()
 
-        # ── Step 1: Get inference profiles (cross-region models) ──
-        # These are the real, callable model IDs with prefix (e.g. "us.anthropic.claude-sonnet-4-6")
-        profiles_response = bedrock_client.list_inference_profiles()
-        profile_ids: set[str] = set()
         models = []
+        for m in raw_models:
+            model_id = m["model_id"]
+            base_id = m["base_model_id"]
 
-        for profile in profiles_response.get("inferenceProfileSummaries", []):
-            profile_id = profile.get("inferenceProfileId", "")
-            profile_name = profile.get("inferenceProfileName", "")
-
-            # Extract base model ID by stripping known prefixes
-            base_id = profile_id
-            prefix = None
-            for pfx in BedrockClient.INFERENCE_PROFILE_PREFIXES:
-                if profile_id.startswith(pfx):
-                    base_id = profile_id[len(pfx) :]
-                    prefix = pfx.rstrip(".")
-                    break
-
-            # Filter by provider
-            provider = base_id.split(".")[0] if "." in base_id else ""
-            if provider not in allowed_providers:
-                continue
-
-            # Skip embedding and rerank models
-            if any(kw in base_id for kw in ("embed-", "rerank-")):
-                continue
-
-            friendly_name = profile_name
-            # Clean up common prefixes from profile name
-            for strip_prefix in ("US ", "EU ", "APAC ", "GLOBAL ", "Global ", "JP "):
-                if friendly_name.startswith(strip_prefix):
-                    friendly_name = friendly_name[len(strip_prefix) :]
-                    break
-
-            streaming = True  # Inference profiles generally support streaming
+            # Derive friendly name from base model ID
+            # e.g. "anthropic.claude-sonnet-4-6" → "Claude Sonnet 4.6"
+            friendly_name = base_id
+            if m["cross_region_type"]:
+                # For profiles, use prefix + base for display
+                friendly_name = base_id
 
             models.append(
                 {
-                    "model_id": profile_id,
+                    "model_id": model_id,
                     "model_name": friendly_name,
                     "friendly_name": friendly_name,
                     "provider": "bedrock-converse",
-                    "is_cross_region": True,
-                    "cross_region_type": prefix,
-                    "streaming_supported": streaming,
+                    "is_cross_region": m["is_cross_region"],
+                    "cross_region_type": m["cross_region_type"],
+                    "streaming_supported": True,
+                    "is_fallback": m.get("is_fallback", False),
                 }
             )
-            profile_ids.add(profile_id)
-
-        # ── Step 2: Get foundation models (standard on-demand) ──
-        # Only add models that don't already have an inference profile entry
-        fm_response = bedrock_client.list_foundation_models()
-
-        for model_summary in fm_response.get("modelSummaries", []):
-            base_model_id = model_summary.get("modelId", "")
-            model_name = model_summary.get("modelName", "")
-
-            # Filter by provider
-            provider_prefix = (
-                base_model_id.split(".")[0] if "." in base_model_id else ""
-            )
-            if provider_prefix not in allowed_providers:
-                continue
-
-            # Skip embedding and rerank models
-            if any(kw in base_model_id for kw in ("embed-", "rerank-")):
-                continue
-
-            # Skip if already covered by an inference profile
-            already_covered = any(
-                pid.endswith(f".{base_model_id}") or pid == base_model_id
-                for pid in profile_ids
-            )
-            if already_covered:
-                continue
-
-            # Skip throughput-variant IDs (e.g. "amazon.nova-premier-v1:0:1000k")
-            if base_model_id.count(":") > 1:
-                continue
-
-            friendly_name = model_name or base_model_id
-            streaming = model_summary.get("responseStreamingSupported", False)
-
-            models.append(
-                {
-                    "model_id": base_model_id,
-                    "model_name": model_name,
-                    "friendly_name": friendly_name,
-                    "provider": "bedrock-converse",
-                    "is_cross_region": False,
-                    "cross_region_type": None,
-                    "streaming_supported": streaming,
-                }
-            )
-
-        # Sort: cross-region first, then by model_id
-        models.sort(key=lambda m: (not m["is_cross_region"], m["model_id"]))
-
-        logger.info(
-            f"Retrieved {len(models)} Bedrock models from AWS "
-            f"({len(profile_ids)} inference profiles + foundation models)"
-        )
+        logger.info(f"Built model list from profile cache: {len(models)} models")
 
         # Append Gemini models dynamically (only if API key is configured)
         if settings.GEMINI_API_KEY:
@@ -263,7 +165,7 @@ async def list_aws_available_models(
     except Exception as e:
         logger.warning(f"Failed to fetch Gemini models (non-fatal): {e}")
 
-    # Cache the results (includes Gemini models)
+    # Cache the results
     _set_aws_models_cache(models)
 
     return {"models": models}
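
For readers tracing the new loop: the record shape below is inferred purely from the keys the endpoint reads off `_profile_cache.get_all_available_models()`; the real cache entries may carry more fields.

# Assumed shape of one record from _profile_cache.get_all_available_models(),
# reconstructed from the keys used above; not the verbatim structure.
example_record = {
    "model_id": "us.anthropic.claude-sonnet-4-6",    # callable ID, profile-prefixed when cross-region
    "base_model_id": "anthropic.claude-sonnet-4-6",  # ID with the region prefix stripped
    "is_cross_region": True,
    "cross_region_type": "us",                       # None for plain foundation models
    "is_fallback": False,                            # True when routed via the fallback region
}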
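The endpoint also leans on a module-level 12-hour cache via `_get_aws_models_cache` / `_set_aws_models_cache`, which the diff calls but does not show. A minimal sketch, assuming a simple timestamped module global:

# Hedged sketch of the TTL cache helpers referenced in the endpoint;
# only the function names and the 12-hour TTL come from the diff.
import time
from typing import Dict, List, Optional

_aws_models_cache: Optional[List[Dict]] = None
_aws_models_cached_at: float = 0.0
_AWS_MODELS_TTL_SECONDS = 12 * 3600  # "cached for 12 hours in memory"

def _get_aws_models_cache() -> Optional[List[Dict]]:
    """Return the cached model list, or None when absent or expired."""
    if _aws_models_cache is None:
        return None
    if time.time() - _aws_models_cached_at >= _AWS_MODELS_TTL_SECONDS:
        return None
    return _aws_models_cache

def _set_aws_models_cache(models: List[Dict]) -> None:
    """Store a freshly built model list and stamp the cache time."""
    global _aws_models_cache, _aws_models_cached_at
    _aws_models_cache = models
    _aws_models_cached_at = time.time()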