Skip to content
Merged
18 changes: 9 additions & 9 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 18 additions & 11 deletions src/agent/cost_guard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,19 @@ impl CostGuard {
/// Record a completed LLM action: its token costs and the action timestamp.
///
/// Call this AFTER an LLM call completes so that costs are tracked.
///
/// When `cost_per_token` is `Some`, those rates are used directly (provider-
/// sourced pricing). When `None`, falls back to the static `costs::model_cost`
/// lookup table, then `costs::default_cost`.
pub async fn record_llm_call(
&self,
model: &str,
input_tokens: u32,
output_tokens: u32,
cost_per_token: Option<(Decimal, Decimal)>,
) -> Decimal {
let (input_rate, output_rate) =
costs::model_cost(model).unwrap_or_else(costs::default_cost);
let (input_rate, output_rate) = cost_per_token
.unwrap_or_else(|| costs::model_cost(model).unwrap_or_else(costs::default_cost));
let cost =
input_rate * Decimal::from(input_tokens) + output_rate * Decimal::from(output_tokens);

Expand Down Expand Up @@ -261,7 +266,9 @@ mod tests {
assert!(guard.check_allowed().await.is_ok());

// Record a big call, still allowed
guard.record_llm_call("gpt-4o", 100_000, 100_000).await;
guard
.record_llm_call("gpt-4o", 100_000, 100_000, None)
.await;
assert!(guard.check_allowed().await.is_ok());
}

Expand All @@ -278,7 +285,7 @@ mod tests {
// Record a call that costs more than $0.01
// gpt-4o: input=$0.0000025/tok, output=$0.00001/tok
// 10000 input + 10000 output = $0.025 + $0.10 = $0.125
guard.record_llm_call("gpt-4o", 10_000, 10_000).await;
guard.record_llm_call("gpt-4o", 10_000, 10_000, None).await;

// Now should be blocked
let result = guard.check_allowed().await;
Expand All @@ -301,7 +308,7 @@ mod tests {
// First 3 actions allowed
for _ in 0..3 {
assert!(guard.check_allowed().await.is_ok());
guard.record_llm_call("gpt-4o", 10, 10).await;
guard.record_llm_call("gpt-4o", 10, 10, None).await;
}

// 4th should be blocked
Expand All @@ -322,7 +329,7 @@ mod tests {

assert_eq!(guard.daily_spend().await, Decimal::ZERO);

let cost = guard.record_llm_call("gpt-4o", 1000, 500).await;
let cost = guard.record_llm_call("gpt-4o", 1000, 500, None).await;
assert!(cost > Decimal::ZERO);
assert_eq!(guard.daily_spend().await, cost);
}
Expand All @@ -333,8 +340,8 @@ mod tests {

assert_eq!(guard.actions_this_hour().await, 0);

guard.record_llm_call("gpt-4o", 10, 10).await;
guard.record_llm_call("gpt-4o", 10, 10).await;
guard.record_llm_call("gpt-4o", 10, 10, None).await;
guard.record_llm_call("gpt-4o", 10, 10, None).await;

assert_eq!(guard.actions_this_hour().await, 2);
}
Expand Down Expand Up @@ -371,10 +378,10 @@ mod tests {
assert!(guard.model_usage().await.is_empty());

// Record calls for two different models
guard.record_llm_call("gpt-4o", 1000, 500).await;
guard.record_llm_call("gpt-4o", 2000, 1000).await;
guard.record_llm_call("gpt-4o", 1000, 500, None).await;
guard.record_llm_call("gpt-4o", 2000, 1000, None).await;
guard
.record_llm_call("claude-3-5-sonnet-20241022", 500, 200)
.record_llm_call("claude-3-5-sonnet-20241022", 500, 200, None)
.await;

let usage = guard.model_usage().await;
Expand Down
1 change: 1 addition & 0 deletions src/agent/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ impl Agent {
&model_name,
output.usage.input_tokens,
output.usage.output_tokens,
Some(self.llm().cost_per_token()),
)
.await;
tracing::debug!(
Expand Down
108 changes: 59 additions & 49 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::skills::SkillRegistry;
use crate::skills::catalog::SkillCatalog;
use crate::tools::ToolRegistry;
use crate::tools::mcp::McpSessionManager;
use crate::tools::wasm::SharedCredentialRegistry;
use crate::tools::wasm::WasmToolRuntime;
use crate::workspace::{EmbeddingProvider, Workspace};

Expand All @@ -48,6 +49,8 @@ pub struct AppComponents {
pub skill_catalog: Option<Arc<SkillCatalog>>,
pub cost_guard: Arc<crate::agent::cost_guard::CostGuard>,
pub session: Arc<SessionManager>,
pub catalog_entries: Vec<crate::extensions::RegistryEntry>,
pub dev_loaded_tool_names: Vec<String>,
}

/// Options that control optional init phases.
Expand Down Expand Up @@ -313,54 +316,41 @@ impl AppBuilder {
),
anyhow::Error,
> {
use crate::workspace::{NearAiEmbeddings, OpenAiEmbeddings};

let safety = Arc::new(SafetyLayer::new(&self.config.safety));
tracing::info!("Safety layer initialized");

let tools = Arc::new(ToolRegistry::new());
tools.register_builtin_tools();

// Create embeddings provider if configured
let embeddings: Option<Arc<dyn EmbeddingProvider>> = if self.config.embeddings.enabled {
match self.config.embeddings.provider.as_str() {
"nearai" => {
tracing::info!(
"Embeddings enabled via NEAR AI (model: {})",
self.config.embeddings.model
);
Some(Arc::new(
NearAiEmbeddings::new(
&self.config.llm.nearai.base_url,
self.session.clone(),
)
.with_model(&self.config.embeddings.model, 1536),
))
}
_ => {
if let Some(api_key) = self.config.embeddings.openai_api_key() {
tracing::info!(
"Embeddings enabled via OpenAI (model: {})",
self.config.embeddings.model
);
Some(Arc::new(OpenAiEmbeddings::with_model(
api_key,
&self.config.embeddings.model,
match self.config.embeddings.model.as_str() {
"text-embedding-3-large" => 3072,
_ => 1536,
},
)))
} else {
tracing::warn!("Embeddings configured but OPENAI_API_KEY not set");
None
}
}
}
// Initialize tool registry with credential injection support
let credential_registry = Arc::new(SharedCredentialRegistry::new());
let tools = if let Some(ref ss) = self.secrets_store {
Arc::new(
ToolRegistry::new()
.with_credentials(Arc::clone(&credential_registry), Arc::clone(ss)),
)
} else {
tracing::info!("Embeddings disabled (set OPENAI_API_KEY or EMBEDDING_ENABLED=true)");
None
Arc::new(ToolRegistry::new())
};
tools.register_builtin_tools();

// Create embeddings provider using the unified method
let embeddings = self
.config
.embeddings
.create_provider(&self.config.llm.nearai.base_url, self.session.clone());

// Warn if libSQL backend is used with non-1536 embedding dimension.
if self.config.database.backend == crate::config::DatabaseBackend::LibSql
&& self.config.embeddings.enabled
&& self.config.embeddings.dimension != 1536
{
tracing::warn!(
configured_dimension = self.config.embeddings.dimension,
"Embedding dimension {} is not 1536. The libSQL schema uses \
F32_BLOB(1536) which requires exactly 1536 dimensions. \
Embedding storage will fail. Use PostgreSQL or set \
EMBEDDING_DIMENSION=1536.",
self.config.embeddings.dimension
);
}

// Register memory tools if database is available
let workspace = if let Some(ref db) = self.db {
Expand Down Expand Up @@ -402,6 +392,8 @@ impl AppBuilder {
Arc<McpSessionManager>,
Option<Arc<WasmToolRuntime>>,
Option<Arc<ExtensionManager>>,
Vec<crate::extensions::RegistryEntry>,
Vec<String>,
),
anyhow::Error,
> {
Expand Down Expand Up @@ -431,6 +423,8 @@ impl AppBuilder {
let tools = Arc::clone(tools);
let wasm_config = self.config.wasm.clone();
async move {
let mut dev_loaded_tool_names: Vec<String> = Vec::new();

if let Some(ref runtime) = wasm_tool_runtime {
let mut loader = WasmToolLoader::new(Arc::clone(runtime), Arc::clone(&tools));
if let Some(ref secrets) = secrets_store {
Expand Down Expand Up @@ -461,10 +455,11 @@ impl AppBuilder {

match load_dev_tools(&loader, &wasm_config.tools_dir).await {
Ok(results) => {
if !results.loaded.is_empty() {
dev_loaded_tool_names.extend(results.loaded.iter().cloned());
if !dev_loaded_tool_names.is_empty() {
tracing::info!(
"Loaded {} dev WASM tools from build artifacts",
results.loaded.len()
dev_loaded_tool_names.len()
);
}
}
Expand All @@ -473,6 +468,8 @@ impl AppBuilder {
}
}
}

dev_loaded_tool_names
}
};

Expand Down Expand Up @@ -577,7 +574,7 @@ impl AppBuilder {
}
};

tokio::join!(wasm_tools_future, mcp_servers_future);
let (dev_loaded_tool_names, _) = tokio::join!(wasm_tools_future, mcp_servers_future);

// Load registry catalog entries for extension discovery
let catalog_entries = match crate::registry::RegistryCatalog::load_or_embedded() {
Expand Down Expand Up @@ -640,7 +637,13 @@ impl AppBuilder {
tools.register_dev_tools();
}

Ok((mcp_session_manager, wasm_tool_runtime, extension_manager))
Ok((
mcp_session_manager,
wasm_tool_runtime,
extension_manager,
catalog_entries,
dev_loaded_tool_names,
))
}
Comment on lines +640 to 647
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The return tuple should include the list of development tool names collected during the WASM loading phase to ensure they can be passed to the hook bootstrapping process.

References
  1. When implementing mutually exclusive logic or refactoring shared state, ensure that all relevant state flags and identifiers are updated consistently to prevent incorrect behavior.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in bd5523e. `init_extensions()` now captures and returns `dev_loaded_tool_names` from the WASM loading future. The names are exposed on `AppComponents` and passed through to `bootstrap_hooks()` in `main.rs`.


/// Run all init phases in order and return the assembled components.
Expand All @@ -654,8 +657,13 @@ impl AppBuilder {
// Create hook registry early so runtime extension activation can register hooks.
let hooks = Arc::new(HookRegistry::new());

let (mcp_session_manager, wasm_tool_runtime, extension_manager) =
self.init_extensions(&tools, &hooks).await?;
let (
mcp_session_manager,
wasm_tool_runtime,
extension_manager,
catalog_entries,
dev_loaded_tool_names,
) = self.init_extensions(&tools, &hooks).await?;

// Seed workspace and backfill embeddings
if let Some(ref ws) = workspace {
Expand Down Expand Up @@ -730,6 +738,8 @@ impl AppBuilder {
skill_catalog,
cost_guard,
session: self.session,
catalog_entries,
dev_loaded_tool_names,
})
}
}
Loading
Loading