Skip to content

Commit 8514d88

Browse files
authored
Merge pull request #6 from aws-samples/feature/gemini-sdk-playground
Feature/gemini sdk playground
2 parents 7affcfa + 99b4e42 commit 8514d88

36 files changed

Lines changed: 2071 additions & 771 deletions

backend/.env.example

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ KBR_INITIAL_USER_BALANCE_USD=0.0
5454
# KBR_EMAIL_WHITELIST_ENABLED=false
5555

5656
# Uvicorn server settings
57-
KBR_UVICORN_TIMEOUT_KEEP_ALIVE=120
57+
KBR_UVICORN_TIMEOUT_KEEP_ALIVE=3700
5858
KBR_UVICORN_LIMIT_CONCURRENCY=100
59-
KBR_UVICORN_LIMIT_MAX_REQUESTS=10000
59+
KBR_UVICORN_LIMIT_MAX_REQUESTS=0
6060

6161
# Streaming settings
6262
KBR_STREAM_HEARTBEAT_INTERVAL=15

backend/app/api/admin/endpoints/models.py

Lines changed: 30 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typing import Dict, List, Optional
88
from uuid import UUID
99

10-
import boto3
1110
from fastapi import APIRouter, Depends, HTTPException
1211
from pydantic import BaseModel
1312
from sqlalchemy import select
@@ -102,155 +101,58 @@ async def list_aws_available_models(
102101
"""
103102
Get list of available Bedrock models from AWS (for selection).
104103
105-
This endpoint queries AWS Bedrock to get all Claude and Nova models
106-
available for the current AWS account/region.
107-
Results are cached for 12 hours to reduce API calls.
104+
Returns models that the proxy can actually invoke:
105+
* Inference profiles available in the deployment region
106+
* Foundation models available in the deployment region
107+
* Foundation models available in the fallback region (routed automatically)
108108
109-
This is used when user wants to add a new model.
109+
Results are cached for 12 hours in memory. The underlying profile cache
110+
is refreshed at startup and daily at 03:00 UTC.
110111
"""
111-
# Check cache first
112+
# Check in-memory cache first
112113
cached_models = _get_aws_models_cache()
113114
if cached_models is not None:
114115
return {"models": cached_models}
115116

116117
try:
117-
# Use us-east-1 for model discovery — it has the most complete model catalog.
118-
# The deployment region (settings.AWS_REGION) is used for actual inference,
119-
# but listing available models should always query the primary catalog region.
120-
catalog_region = "us-east-1"
121-
bedrock_client = boto3.client(
122-
service_name="bedrock", region_name=catalog_region
123-
)
124-
logger.info(
125-
f"Bedrock client initialized for model catalog region: {catalog_region} (deployment region: {settings.AWS_REGION})"
126-
)
118+
bc = BedrockClient.get_instance()
127119

128-
# Allowed providers
129-
allowed_providers = {
130-
"ai21",
131-
"amazon",
132-
"anthropic",
133-
"cohere",
134-
"deepseek",
135-
"google",
136-
"luma",
137-
"meta",
138-
"minimax",
139-
"mistral",
140-
"moonshot",
141-
"moonshotai",
142-
"nvidia",
143-
"openai",
144-
"qwen",
145-
"stability",
146-
"twelvelabs",
147-
"writer",
148-
"zai",
149-
}
120+
# Refresh profile cache if stale or empty
121+
if bc._profile_cache.is_stale or bc._profile_cache.is_empty:
122+
await bc.refresh_profile_cache()
123+
124+
# Build model list from the profile cache
125+
raw_models = bc._profile_cache.get_all_available_models()
150126

151-
# ── Step 1: Get inference profiles (cross-region models) ──
152-
# These are the real, callable model IDs with prefix (e.g. "us.anthropic.claude-sonnet-4-6")
153-
profiles_response = bedrock_client.list_inference_profiles()
154-
profile_ids: set[str] = set()
155127
models = []
128+
for m in raw_models:
129+
model_id = m["model_id"]
130+
base_id = m["base_model_id"]
156131

157-
for profile in profiles_response.get("inferenceProfileSummaries", []):
158-
profile_id = profile.get("inferenceProfileId", "")
159-
profile_name = profile.get("inferenceProfileName", "")
160-
161-
# Extract base model ID by stripping known prefixes
162-
base_id = profile_id
163-
prefix = None
164-
for pfx in BedrockClient.INFERENCE_PROFILE_PREFIXES:
165-
if profile_id.startswith(pfx):
166-
base_id = profile_id[len(pfx) :]
167-
prefix = pfx.rstrip(".")
168-
break
169-
170-
# Filter by provider
171-
provider = base_id.split(".")[0] if "." in base_id else ""
172-
if provider not in allowed_providers:
173-
continue
174-
175-
# Skip embedding and rerank models
176-
if any(kw in base_id for kw in ("embed-", "rerank-")):
177-
continue
178-
179-
friendly_name = profile_name
180-
# Clean up common prefixes from profile name
181-
for strip_prefix in ("US ", "EU ", "APAC ", "GLOBAL ", "Global ", "JP "):
182-
if friendly_name.startswith(strip_prefix):
183-
friendly_name = friendly_name[len(strip_prefix) :]
184-
break
185-
186-
streaming = True # Inference profiles generally support streaming
132+
# Use the base model ID as the display name; no prettified name is derived
133+
# e.g. "anthropic.claude-sonnet-4-6" is shown unchanged
134+
friendly_name = base_id
135+
if m["cross_region_type"]:
136+
# For profiles the display name is also just the base ID (no prefix is prepended)
137+
friendly_name = base_id
187138

188139
models.append(
189140
{
190-
"model_id": profile_id,
141+
"model_id": model_id,
191142
"model_name": friendly_name,
192143
"friendly_name": friendly_name,
193144
"provider": "bedrock-converse",
194-
"is_cross_region": True,
195-
"cross_region_type": prefix,
196-
"streaming_supported": streaming,
197-
}
198-
)
199-
profile_ids.add(profile_id)
200-
201-
# ── Step 2: Get foundation models (standard on-demand) ──
202-
# Only add models that don't already have an inference profile entry
203-
fm_response = bedrock_client.list_foundation_models()
204-
205-
for model_summary in fm_response.get("modelSummaries", []):
206-
base_model_id = model_summary.get("modelId", "")
207-
model_name = model_summary.get("modelName", "")
208-
209-
# Filter by provider
210-
provider_prefix = (
211-
base_model_id.split(".")[0] if "." in base_model_id else ""
212-
)
213-
if provider_prefix not in allowed_providers:
214-
continue
215-
216-
# Skip embedding and rerank models
217-
if any(kw in base_model_id for kw in ("embed-", "rerank-")):
218-
continue
219-
220-
# Skip if already covered by an inference profile
221-
already_covered = any(
222-
pid.endswith(f".{base_model_id}") or pid == base_model_id
223-
for pid in profile_ids
224-
)
225-
if already_covered:
226-
continue
227-
228-
# Skip throughput-variant IDs (e.g. "amazon.nova-premier-v1:0:1000k")
229-
if base_model_id.count(":") > 1:
230-
continue
231-
232-
friendly_name = model_name or base_model_id
233-
streaming = model_summary.get("responseStreamingSupported", False)
234-
235-
models.append(
236-
{
237-
"model_id": base_model_id,
238-
"model_name": model_name,
239-
"friendly_name": friendly_name,
240-
"provider": "bedrock-converse",
241-
"is_cross_region": False,
242-
"cross_region_type": None,
243-
"streaming_supported": streaming,
145+
"is_cross_region": m["is_cross_region"],
146+
"cross_region_type": m["cross_region_type"],
147+
"streaming_supported": True,
148+
"is_fallback": m.get("is_fallback", False),
244149
}
245150
)
246151

247152
# Sort: cross-region first, then by model_id
248153
models.sort(key=lambda m: (not m["is_cross_region"], m["model_id"]))
249154

250-
logger.info(
251-
f"Retrieved {len(models)} Bedrock models from AWS "
252-
f"({len(profile_ids)} inference profiles + foundation models)"
253-
)
155+
logger.info(f"Built model list from profile cache: {len(models)} models")
254156

255157
# Append Gemini models dynamically (only if API key is configured)
256158
if settings.GEMINI_API_KEY:
@@ -263,7 +165,7 @@ async def list_aws_available_models(
263165
except Exception as e:
264166
logger.warning(f"Failed to fetch Gemini models (non-fatal): {e}")
265167

266-
# Cache the results (includes Gemini models)
168+
# Cache the results
267169
_set_aws_models_cache(models)
268170

269171
return {"models": models}

0 commit comments

Comments
 (0)