diff --git a/CHANGELOG.md b/CHANGELOG.md index d530dfaa1..2e73c4d97 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,15 @@ ## [Unreleased] +### Changes + +- Remote embedding, reranking, and query expansion via OpenAI-compatible API + (vLLM, Ollama, OpenAI, etc.). Set `QMD_EMBED_API_URL` / `QMD_EMBED_API_MODEL` + (and optionally `QMD_RERANK_API_*` / `QMD_EXPAND_API_*`) env vars or add + the equivalent keys to `models:` in `index.yml`. Local generation and + tokenization are preserved via a hybrid routing layer. Includes circuit + breakers, dimension validation, and batch splitting. + ### Fixes - GPU: respect explicit `QMD_LLAMA_GPU=metal|vulkan|cuda` backend overrides instead of always using auto GPU selection. #529 diff --git a/README.md b/README.md index 6f318446b..c0b421d07 100644 --- a/README.md +++ b/README.md @@ -939,6 +939,41 @@ Uses node-llama-cpp's `createRankingContext()` and `rankAndSort()` API for cross Used for generating query variations via `LlamaChatSession`. +### Remote Embedding & Reranking + +QMD can offload embedding and reranking to a remote OpenAI-compatible server (vLLM, Ollama, LM Studio, OpenAI, etc.) while keeping query expansion local. + +**Environment variables** (presence of `QMD_EMBED_API_URL` activates remote mode): + +| Variable | Required | Description | +|----------|----------|-------------| +| `QMD_EMBED_API_URL` | Yes | Base URL, e.g. `http://gpu-host:8000/v1` | +| `QMD_EMBED_API_MODEL` | Yes | Model name, e.g. `BAAI/bge-m3` | +| `QMD_EMBED_API_KEY` | No | Bearer token for auth | +| `QMD_RERANK_API_URL` | No | Rerank endpoint (defaults to embed URL) | +| `QMD_RERANK_API_MODEL` | No | Rerank model name | +| `QMD_RERANK_API_KEY` | No | Rerank auth (defaults to embed key) | + +**YAML config** (`~/.config/qmd/index.yml`): +```yaml +models: + embed_api_url: "http://gpu-host:8000/v1" + embed_api_model: "BAAI/bge-m3" + rerank_api_model: "BAAI/bge-reranker-v2-m3" +``` + +**Example with vLLM:** +```sh +# Start vLLM with an embedding model +vllm serve BAAI/bge-m3 --task embed + +# Point QMD at it +export QMD_EMBED_API_URL=http://localhost:8000/v1 +export QMD_EMBED_API_MODEL=BAAI/bge-m3 +qmd embed +qmd query "your search query" +``` + ## License MIT diff --git a/src/cli/qmd.ts b/src/cli/qmd.ts index 50ae76486..24d6b0527 100755 --- a/src/cli/qmd.ts +++ b/src/cli/qmd.ts @@ -78,7 +78,9 @@ import { type ReindexResult, type ChunkStrategy, } from "../store.js"; -import { disposeDefaultLlamaCpp, getDefaultLlamaCpp, setDefaultLlamaCpp, LlamaCpp, withLLMSession, pullModels, DEFAULT_EMBED_MODEL_URI, DEFAULT_GENERATE_MODEL_URI, DEFAULT_RERANK_MODEL_URI, DEFAULT_MODEL_CACHE_DIR } from "../llm.js"; +import { disposeDefaultLlamaCpp, getDefaultLLM, setDefaultLLM, LlamaCpp, withLLMSession, pullModels, DEFAULT_EMBED_MODEL_URI, DEFAULT_GENERATE_MODEL_URI, DEFAULT_RERANK_MODEL_URI, DEFAULT_MODEL_CACHE_DIR } from "../llm.js"; +import { RemoteLLM, remoteConfigFromEnv } from "../remote-llm.js"; +import { HybridLLM } from "../hybrid-llm.js"; import { formatSearchResults, formatDocuments, @@ -121,11 +123,28 @@ function getStore(): ReturnType { const config = loadConfig(); syncConfigToDb(store.db, config); if (config.models) { - setDefaultLlamaCpp(new LlamaCpp({ + const localLlm = new LlamaCpp({ embedModel: config.models.embed, generateModel: config.models.generate, rerankModel: config.models.rerank, - })); + }); + + // Check if remote embedding is configured (env vars take precedence over YAML) + const remoteConfig = remoteConfigFromEnv(config.models); + if (remoteConfig) { + const remoteLlm = new RemoteLLM(remoteConfig); + setDefaultLLM(new HybridLLM(remoteLlm, localLlm)); + } else { + setDefaultLLM(localLlm); + } + } else { + // No YAML models config — still check env vars for remote embedding + const remoteConfig = remoteConfigFromEnv(); + if (remoteConfig) { + const remoteLlm = new RemoteLLM(remoteConfig); + const localLlm = new LlamaCpp(); + setDefaultLLM(new HybridLLM(remoteLlm, localLlm)); + } } } catch { // Config may not exist yet — that's fine, DB works without it @@ -1681,6 +1700,9 @@ async function vectorIndex( const storeInstance = getStore(); const db = storeInstance.db; + // Use the actual model name from the configured LLM (may be remote, not the default GGUF URI) + model = getDefaultLLM().embedModelName; + if (force) { console.log(`${c.yellow}Force re-indexing: clearing all vectors...${c.reset}`); } diff --git a/src/collections.ts b/src/collections.ts index e68ff65b5..1bf4bb11a 100644 --- a/src/collections.ts +++ b/src/collections.ts @@ -40,6 +40,18 @@ export interface ModelsConfig { embed?: string; rerank?: string; generate?: string; + /** Remote embedding API base URL (e.g. http://gpu-host:8000/v1) */ + embed_api_url?: string; + /** Remote embedding model name (e.g. BAAI/bge-m3) */ + embed_api_model?: string; + /** Bearer token for remote embedding API */ + embed_api_key?: string; + /** Remote rerank API base URL (defaults to embed_api_url) */ + rerank_api_url?: string; + /** Remote rerank model name */ + rerank_api_model?: string; + /** Bearer token for remote rerank API */ + rerank_api_key?: string; } /** diff --git a/src/hybrid-llm.ts b/src/hybrid-llm.ts new file mode 100644 index 000000000..9e93e31ca --- /dev/null +++ b/src/hybrid-llm.ts @@ -0,0 +1,70 @@ +/** + * hybrid-llm.ts - Compositor that routes LLM operations between remote and local backends + * + * Embed/rerank → remote (GPU-heavy, benefits from offloading) + * Generate → local LlamaCpp + * ExpandQuery → remote when expandApiModel is configured, otherwise local LlamaCpp + * tokenize/countTokens → local LlamaCpp (CPU-cheap, needed for chunking) + */ + +import type { + LLM, + EmbedOptions, + EmbeddingResult, + GenerateOptions, + GenerateResult, + ModelInfo, + Queryable, + RerankDocument, + RerankOptions, + RerankResult, +} from "./llm.js"; +import { RemoteLLM } from "./remote-llm.js"; + +export class HybridLLM implements LLM { + constructor( + private readonly remote: LLM, + private readonly local: LLM, + ) {} + + get embedModelName(): string { + return this.remote.embedModelName; + } + + // Route to remote + embed(text: string, options?: EmbedOptions): Promise { + return this.remote.embed(text, options); + } + + embedBatch(texts: string[], options?: EmbedOptions): Promise<(EmbeddingResult | null)[]> { + return this.remote.embedBatch(texts, options); + } + + rerank(query: string, documents: RerankDocument[], options?: RerankOptions): Promise { + return this.remote.rerank(query, documents, options); + } + + // Route to local + generate(prompt: string, options?: GenerateOptions): Promise { + return this.local.generate(prompt, options); + } + + /** + * Route expandQuery to remote when the remote backend supports it + * (i.e., RemoteLLM with expandApiModel configured), otherwise fall back to local. + */ + expandQuery(query: string, options?: { context?: string; includeLexical?: boolean; intent?: string }): Promise { + if (this.remote instanceof RemoteLLM && this.remote.supportsExpand) { + return this.remote.expandQuery(query, options); + } + return this.local.expandQuery(query, options); + } + + modelExists(model: string): Promise { + return this.local.modelExists(model); + } + + async dispose(): Promise { + await Promise.all([this.remote.dispose(), this.local.dispose()]); + } +} diff --git a/src/llm.ts b/src/llm.ts index 7cccc3fa8..bb1c639c4 100644 --- a/src/llm.ts +++ b/src/llm.ts @@ -30,13 +30,24 @@ export function isQwen3EmbeddingModel(modelUri: string): boolean { return /qwen.*embed/i.test(modelUri) || /embed.*qwen/i.test(modelUri); } +/** + * Detect if a model URI refers to a remote API model (not a local GGUF model). + * Remote models handle their own prompt formatting, so no prefixes should be added. + */ +export function isRemoteModel(modelUri: string): boolean { + // Local models use hf: URIs or local file paths ending in .gguf + return !modelUri.startsWith("hf:") && !modelUri.endsWith(".gguf"); +} + /** * Format a query for embedding. * Uses nomic-style task prefix format for embeddinggemma (default). * Uses Qwen3-Embedding instruct format when a Qwen embedding model is active. + * Remote models receive raw text (they handle their own formatting). */ export function formatQueryForEmbedding(query: string, modelUri?: string): string { const uri = modelUri ?? process.env.QMD_EMBED_MODEL ?? DEFAULT_EMBED_MODEL; + if (isRemoteModel(uri)) return query; if (isQwen3EmbeddingModel(uri)) { return `Instruct: Retrieve relevant documents for the given query\nQuery: ${query}`; } @@ -47,9 +58,11 @@ export function formatQueryForEmbedding(query: string, modelUri?: string): strin * Format a document for embedding. * Uses nomic-style format with title and text fields (default). * Qwen3-Embedding encodes documents as raw text without special prefixes. + * Remote models receive raw text (they handle their own formatting). */ export function formatDocForEmbedding(text: string, title?: string, modelUri?: string): string { const uri = modelUri ?? process.env.QMD_EMBED_MODEL ?? DEFAULT_EMBED_MODEL; + if (isRemoteModel(uri)) return title ? `${title}\n${text}` : text; if (isQwen3EmbeddingModel(uri)) { // Qwen3-Embedding: documents are raw text, no task prefix return title ? `${title}\n${text}` : text; @@ -156,7 +169,7 @@ export type LLMSessionOptions = { export interface ILLMSession { embed(text: string, options?: EmbedOptions): Promise; embedBatch(texts: string[], options?: EmbedOptions): Promise<(EmbeddingResult | null)[]>; - expandQuery(query: string, options?: { context?: string; includeLexical?: boolean }): Promise; + expandQuery(query: string, options?: { context?: string; includeLexical?: boolean; intent?: string }): Promise; rerank(query: string, documents: RerankDocument[], options?: RerankOptions): Promise; /** Whether this session is still valid (not released or aborted) */ readonly isValid: boolean; @@ -372,6 +385,16 @@ export interface LLM { */ embed(text: string, options?: EmbedOptions): Promise; + /** + * Batch embed multiple texts + */ + embedBatch(texts: string[], options?: EmbedOptions): Promise<(EmbeddingResult | null)[]>; + + /** + * The embedding model name/URI + */ + readonly embedModelName: string; + /** * Generate text completion */ @@ -386,7 +409,7 @@ export interface LLM { * Expand a search query into multiple variations for different backends. * Returns a list of Queryable objects. */ - expandQuery(query: string, options?: { context?: string, includeLexical?: boolean }): Promise; + expandQuery(query: string, options?: { context?: string; includeLexical?: boolean; intent?: string }): Promise; /** * Rerank documents by relevance to a query @@ -1394,11 +1417,11 @@ export class LlamaCpp implements LLM { * Coordinates with LlamaCpp idle timeout to prevent disposal during active sessions. */ class LLMSessionManager { - private llm: LlamaCpp; + private llm: LLM; private _activeSessionCount = 0; private _inFlightOperations = 0; - constructor(llm: LlamaCpp) { + constructor(llm: LLM) { this.llm = llm; } @@ -1434,7 +1457,7 @@ class LLMSessionManager { this._inFlightOperations = Math.max(0, this._inFlightOperations - 1); } - getLlamaCpp(): LlamaCpp { + getLLM(): LLM { return this.llm; } } @@ -1537,18 +1560,18 @@ class LLMSession implements ILLMSession { } async embed(text: string, options?: EmbedOptions): Promise { - return this.withOperation(() => this.manager.getLlamaCpp().embed(text, options)); + return this.withOperation(() => this.manager.getLLM().embed(text, options)); } async embedBatch(texts: string[], options?: EmbedOptions): Promise<(EmbeddingResult | null)[]> { - return this.withOperation(() => this.manager.getLlamaCpp().embedBatch(texts, options)); + return this.withOperation(() => this.manager.getLLM().embedBatch(texts, options)); } async expandQuery( query: string, options?: { context?: string; includeLexical?: boolean } ): Promise { - return this.withOperation(() => this.manager.getLlamaCpp().expandQuery(query, options)); + return this.withOperation(() => this.manager.getLLM().expandQuery(query, options)); } async rerank( @@ -1556,19 +1579,19 @@ class LLMSession implements ILLMSession { documents: RerankDocument[], options?: RerankOptions ): Promise { - return this.withOperation(() => this.manager.getLlamaCpp().rerank(query, documents, options)); + return this.withOperation(() => this.manager.getLLM().rerank(query, documents, options)); } } -// Session manager for the default LlamaCpp instance +// Session manager for the default LLM instance let defaultSessionManager: LLMSessionManager | null = null; /** - * Get the session manager for the default LlamaCpp instance. + * Get the session manager for the default LLM instance. */ function getSessionManager(): LLMSessionManager { - const llm = getDefaultLlamaCpp(); - if (!defaultSessionManager || defaultSessionManager.getLlamaCpp() !== llm) { + const llm = getDefaultLLM(); + if (!defaultSessionManager || defaultSessionManager.getLLM() !== llm) { defaultSessionManager = new LLMSessionManager(llm); } return defaultSessionManager; @@ -1603,11 +1626,11 @@ export async function withLLMSession( } /** - * Execute a function with a scoped LLM session using a specific LlamaCpp instance. + * Execute a function with a scoped LLM session using a specific LLM instance. * Unlike withLLMSession, this does not use the global singleton. */ export async function withLLMSessionForLlm( - llm: LlamaCpp, + llm: LLM, fn: (session: ILLMSession) => Promise, options?: LLMSessionOptions ): Promise { @@ -1631,35 +1654,45 @@ export function canUnloadLLM(): boolean { } // ============================================================================= -// Singleton for default LlamaCpp instance +// Singleton for default LLM instance // ============================================================================= -let defaultLlamaCpp: LlamaCpp | null = null; +let defaultLLMInstance: LLM | null = null; /** - * Get the default LlamaCpp instance (creates one if needed) + * Get the default LLM instance (creates a LlamaCpp if none set) */ -export function getDefaultLlamaCpp(): LlamaCpp { - if (!defaultLlamaCpp) { - defaultLlamaCpp = new LlamaCpp(); +export function getDefaultLLM(): LLM { + if (!defaultLLMInstance) { + defaultLLMInstance = new LlamaCpp(); } - return defaultLlamaCpp; + return defaultLLMInstance; } /** - * Set a custom default LlamaCpp instance (useful for testing) + * Set the default LLM instance */ -export function setDefaultLlamaCpp(llm: LlamaCpp | null): void { - defaultLlamaCpp = llm; +export function setDefaultLLM(llm: LLM | null): void { + defaultLLMInstance = llm; +} + +/** @deprecated Use getDefaultLLM() */ +export function getDefaultLlamaCpp(): LLM { + return getDefaultLLM(); +} + +/** @deprecated Use setDefaultLLM() */ +export function setDefaultLlamaCpp(llm: LLM | null): void { + setDefaultLLM(llm); } /** - * Dispose the default LlamaCpp instance if it exists. + * Dispose the default LLM instance if it exists. * Call this before process exit to prevent NAPI crashes. */ export async function disposeDefaultLlamaCpp(): Promise { - if (defaultLlamaCpp) { - await defaultLlamaCpp.dispose(); - defaultLlamaCpp = null; + if (defaultLLMInstance) { + await defaultLLMInstance.dispose(); + defaultLLMInstance = null; } } diff --git a/src/remote-llm.ts b/src/remote-llm.ts new file mode 100644 index 000000000..5e0aafa8e --- /dev/null +++ b/src/remote-llm.ts @@ -0,0 +1,485 @@ +/** + * remote-llm.ts - OpenAI-compatible remote embedding, reranking & query expansion backend + * + * Implements the LLM interface by calling HTTP endpoints (vLLM, Ollama, OpenAI, etc.). + * Supports embed, rerank, and (when expandApiModel is set) query expansion via chat completions. + * generate() is not supported — use HybridLLM to pair with a local backend. + */ + +import type { + LLM, + EmbedOptions, + EmbeddingResult, + GenerateOptions, + GenerateResult, + ModelInfo, + Queryable, + QueryType, + RerankDocument, + RerankOptions, + RerankResult, +} from "./llm.js"; + +// ============================================================================= +// Configuration +// ============================================================================= + +export type RemoteLLMConfig = { + /** Base URL for embedding endpoint (e.g. http://gpu-host:8000/v1) */ + embedApiUrl: string; + /** Model name for embedding (e.g. BAAI/bge-m3) */ + embedApiModel: string; + /** Optional bearer token for embedding endpoint */ + embedApiKey?: string; + /** Base URL for rerank endpoint (defaults to embedApiUrl) */ + rerankApiUrl?: string; + /** Model name for reranking */ + rerankApiModel?: string; + /** Optional bearer token for rerank endpoint */ + rerankApiKey?: string; + /** Base URL for query expansion endpoint (defaults to embedApiUrl) */ + expandApiUrl?: string; + /** Model name for query expansion via chat completions (enables remote expansion when set) */ + expandApiModel?: string; + /** Optional bearer token for expansion endpoint (defaults to embedApiKey) */ + expandApiKey?: string; + /** Read timeout for query expansion in ms (default: 30000) */ + expandReadTimeoutMs?: number; + /** Connect timeout in ms (default: 5000) */ + connectTimeoutMs?: number; + /** Read timeout for embedding in ms (default: 30000) */ + embedReadTimeoutMs?: number; + /** Read timeout for reranking in ms (default: 60000) */ + rerankReadTimeoutMs?: number; + /** Max texts per embed HTTP request (default: 32) */ + maxBatchSize?: number; +}; + +// ============================================================================= +// Circuit Breaker +// ============================================================================= + +type CircuitState = "closed" | "open" | "half-open"; + +class CircuitBreaker { + private state: CircuitState = "closed"; + private failures = 0; + private lastFailureTime = 0; + private readonly maxFailures: number; + private readonly cooldownMs: number; + + constructor(maxFailures = 3, cooldownMs = 10 * 60 * 1000) { + this.maxFailures = maxFailures; + this.cooldownMs = cooldownMs; + } + + canAttempt(): boolean { + if (this.state === "closed") return true; + if (this.state === "open") { + if (Date.now() - this.lastFailureTime >= this.cooldownMs) { + this.state = "half-open"; + return true; + } + return false; + } + // half-open: allow one attempt + return true; + } + + onSuccess(): void { + this.state = "closed"; + this.failures = 0; + } + + onFailure(): void { + this.failures++; + this.lastFailureTime = Date.now(); + if (this.state === "half-open" || this.failures >= this.maxFailures) { + this.state = "open"; + } + } + + getState(): CircuitState { + return this.state; + } +} + +// ============================================================================= +// RemoteLLM +// ============================================================================= + +export class RemoteLLM implements LLM { + private readonly config: Required< + Pick + > & RemoteLLMConfig; + + private readonly embedBreaker = new CircuitBreaker(); + private readonly rerankBreaker = new CircuitBreaker(); + private readonly expandBreaker = new CircuitBreaker(); + private expectedDimensions: number | null = null; + + constructor(config: RemoteLLMConfig) { + this.config = { + connectTimeoutMs: 5000, + embedReadTimeoutMs: 30000, + rerankReadTimeoutMs: 60000, + maxBatchSize: 32, + ...config, + }; + } + + get embedModelName(): string { + return this.config.embedApiModel; + } + + /** True when expandApiModel is configured and remote query expansion is available. */ + get supportsExpand(): boolean { + return !!this.config.expandApiModel; + } + + // --------------------------------------------------------------------------- + // Embedding + // --------------------------------------------------------------------------- + + async embed(text: string, options?: EmbedOptions): Promise { + const results = await this.embedBatch([text], options); + return results[0] ?? null; + } + + async embedBatch(texts: string[], _options?: EmbedOptions): Promise<(EmbeddingResult | null)[]> { + if (texts.length === 0) return []; + + if (!this.embedBreaker.canAttempt()) { + throw new Error( + `Remote embedding circuit breaker is open — endpoint ${this.config.embedApiUrl} is unavailable. ` + + `Will retry after cooldown.` + ); + } + + const batchSize = this.config.maxBatchSize; + const results: (EmbeddingResult | null)[] = []; + + for (let i = 0; i < texts.length; i += batchSize) { + const batch = texts.slice(i, i + batchSize); + const batchResults = await this.embedBatchRequest(batch); + results.push(...batchResults); + } + + return results; + } + + private async embedBatchRequest(texts: string[]): Promise<(EmbeddingResult | null)[]> { + const url = normalizeUrl(this.config.embedApiUrl, "/embeddings"); + const headers: Record = { "Content-Type": "application/json" }; + if (this.config.embedApiKey) { + headers["Authorization"] = `Bearer ${this.config.embedApiKey}`; + } + + const body = JSON.stringify({ + model: this.config.embedApiModel, + input: texts, + }); + + try { + const response = await fetchWithTimeout(url, { + method: "POST", + headers, + body, + }, this.config.embedReadTimeoutMs); + + if (!response.ok) { + const errText = await response.text().catch(() => ""); + throw new Error(`Embedding API returned ${response.status}: ${errText}`); + } + + const json = await response.json() as { + data: { embedding: number[]; index: number }[]; + }; + + // Validate dimensions consistency + if (json.data.length > 0) { + const dim = json.data[0]!.embedding.length; + if (this.expectedDimensions === null) { + this.expectedDimensions = dim; + } else if (dim !== this.expectedDimensions) { + throw new Error( + `Embedding dimension mismatch: expected ${this.expectedDimensions}, got ${dim}. ` + + `This usually means the remote model changed.` + ); + } + } + + // Sort by index to match input order + const sorted = [...json.data].sort((a, b) => a.index - b.index); + const results: (EmbeddingResult | null)[] = sorted.map(item => ({ + embedding: item.embedding, + model: this.config.embedApiModel, + })); + + this.embedBreaker.onSuccess(); + return results; + } catch (err) { + this.embedBreaker.onFailure(); + throw err; + } + } + + // --------------------------------------------------------------------------- + // Reranking + // --------------------------------------------------------------------------- + + async rerank(query: string, documents: RerankDocument[], _options?: RerankOptions): Promise { + const rerankUrl = this.config.rerankApiUrl || this.config.embedApiUrl; + const rerankModel = this.config.rerankApiModel; + const rerankKey = this.config.rerankApiKey || this.config.embedApiKey; + + if (!rerankModel) { + throw new Error("Remote reranking requires rerankApiModel to be configured"); + } + + if (!this.rerankBreaker.canAttempt()) { + throw new Error( + `Remote rerank circuit breaker is open — endpoint ${rerankUrl} is unavailable. ` + + `Will retry after cooldown.` + ); + } + + const url = normalizeUrl(rerankUrl, "/rerank"); + const headers: Record = { "Content-Type": "application/json" }; + if (rerankKey) { + headers["Authorization"] = `Bearer ${rerankKey}`; + } + + const body = JSON.stringify({ + model: rerankModel, + query, + documents: documents.map(d => d.text), + }); + + try { + const response = await fetchWithTimeout(url, { + method: "POST", + headers, + body, + }, this.config.rerankReadTimeoutMs); + + if (!response.ok) { + const errText = await response.text().catch(() => ""); + throw new Error(`Rerank API returned ${response.status}: ${errText}`); + } + + const json = await response.json() as { + results: { index: number; relevance_score: number }[]; + }; + + const results = json.results.map(r => ({ + file: documents[r.index]!.file, + score: r.relevance_score, + index: r.index, + })); + + this.rerankBreaker.onSuccess(); + return { results, model: rerankModel }; + } catch (err) { + this.rerankBreaker.onFailure(); + throw err; + } + } + + // --------------------------------------------------------------------------- + // Unsupported operations (these require local models) + // --------------------------------------------------------------------------- + + async generate(_prompt: string, _options?: GenerateOptions): Promise { + throw new Error("RemoteLLM does not support text generation — use HybridLLM to route generation to a local backend"); + } + + async modelExists(_model: string): Promise { + return { name: this.config.embedApiModel, exists: true }; + } + + async expandQuery(query: string, options?: { context?: string; includeLexical?: boolean; intent?: string }): Promise { + if (!this.config.expandApiModel) { + throw new Error("RemoteLLM: expandApiModel not configured — set expandApiUrl and expandApiModel to enable remote query expansion"); + } + + if (!this.expandBreaker.canAttempt()) { + throw new Error( + `Remote query expansion circuit breaker is open — endpoint unavailable. Will retry after cooldown.` + ); + } + + const expandUrl = this.config.expandApiUrl || this.config.embedApiUrl; + const expandKey = this.config.expandApiKey || this.config.embedApiKey; + const includeLexical = options?.includeLexical ?? true; + const intent = options?.intent; + + const systemPrompt = + "You are a search query expansion assistant. " + + "Given a search query, produce expanded variants in EXACTLY this format:\n" + + "lex: \n" + + "vec: \n" + + "hyde: \n\n" + + "Output only those three lines. No explanation, no extra text."; + + const userPrompt = intent + ? `Expand this search query: ${query}\nQuery intent: ${intent}` + : `Expand this search query: ${query}`; + + const url = normalizeUrl(expandUrl, "/chat/completions"); + const headers: Record = { "Content-Type": "application/json" }; + if (expandKey) headers["Authorization"] = `Bearer ${expandKey}`; + + const body = JSON.stringify({ + model: this.config.expandApiModel, + messages: [ + { role: "system", content: systemPrompt }, + { role: "user", content: userPrompt }, + ], + temperature: 0.7, + max_tokens: 200, + }); + + try { + const response = await fetchWithTimeout(url, { + method: "POST", + headers, + body, + }, this.config.expandReadTimeoutMs ?? 30000); + + if (!response.ok) { + const errText = await response.text().catch(() => ""); + throw new Error(`Query expansion API returned ${response.status}: ${errText}`); + } + + const json = await response.json() as { + choices: { message: { content: string } }[]; + }; + const content = json.choices[0]?.message?.content ?? ""; + + this.expandBreaker.onSuccess(); + return parseExpandResponse(content, query, includeLexical); + } catch (err) { + this.expandBreaker.onFailure(); + throw err; + } + } + + async dispose(): Promise { + // Nothing to dispose for HTTP client + } +} + +// ============================================================================= +// Helpers +// ============================================================================= + +/** + * Parse the chat completion response from the expand endpoint into Queryable[]. + * Expects lines in "type: content" format where type ∈ {lex, vec, hyde}. + */ +function parseExpandResponse(content: string, originalQuery: string, includeLexical: boolean): Queryable[] { + const lines = content.trim().split("\n"); + const queryTerms = originalQuery.toLowerCase() + .replace(/[^a-z0-9\s]/g, " ") + .split(/\s+/) + .filter(Boolean); + + const hasQueryTerm = (text: string): boolean => { + if (queryTerms.length === 0) return true; + const lower = text.toLowerCase(); + return queryTerms.some(term => lower.includes(term)); + }; + + const queryables: Queryable[] = lines.flatMap(line => { + const colonIdx = line.indexOf(":"); + if (colonIdx === -1) return []; + const type = line.slice(0, colonIdx).trim() as QueryType; + if (type !== "lex" && type !== "vec" && type !== "hyde") return []; + const text = line.slice(colonIdx + 1).trim(); + if (!text || !hasQueryTerm(text)) return []; + return [{ type, text }]; + }); + + const filtered = includeLexical ? queryables : queryables.filter(q => q.type !== "lex"); + if (filtered.length > 0) return filtered; + + // Fallback when model output couldn't be parsed + const fallback: Queryable[] = [ + { type: "hyde", text: `Information about ${originalQuery}` }, + { type: "lex", text: originalQuery }, + { type: "vec", text: originalQuery }, + ]; + return includeLexical ? fallback : fallback.filter(q => q.type !== "lex"); +} + +/** + * Normalize a base URL and append a path, handling trailing slashes. + */ +function normalizeUrl(baseUrl: string, path: string): string { + const base = baseUrl.replace(/\/+$/, ""); + return `${base}${path}`; +} + +/** + * Fetch with a timeout using AbortSignal.timeout(). + */ +async function fetchWithTimeout( + url: string, + init: RequestInit, + timeoutMs: number +): Promise { + return fetch(url, { + ...init, + signal: AbortSignal.timeout(timeoutMs), + }); +} + +// ============================================================================= +// Configuration from environment +// ============================================================================= + +/** + * Create a RemoteLLMConfig from environment variables and optional YAML config. + * Returns null if remote embedding is not configured. + */ +export function remoteConfigFromEnv(yamlModels?: { + embed_api_url?: string; + embed_api_model?: string; + embed_api_key?: string; + rerank_api_url?: string; + rerank_api_model?: string; + rerank_api_key?: string; + expand_api_url?: string; + expand_api_model?: string; + expand_api_key?: string; +}): RemoteLLMConfig | null { + const embedApiUrl = process.env.QMD_EMBED_API_URL || yamlModels?.embed_api_url; + const embedApiModel = process.env.QMD_EMBED_API_MODEL || yamlModels?.embed_api_model; + + if (!embedApiUrl || !embedApiModel) return null; + + return { + embedApiUrl, + embedApiModel, + embedApiKey: process.env.QMD_EMBED_API_KEY || yamlModels?.embed_api_key, + rerankApiUrl: process.env.QMD_RERANK_API_URL || yamlModels?.rerank_api_url, + rerankApiModel: process.env.QMD_RERANK_API_MODEL || yamlModels?.rerank_api_model, + rerankApiKey: process.env.QMD_RERANK_API_KEY || yamlModels?.rerank_api_key, + expandApiUrl: process.env.QMD_EXPAND_API_URL || yamlModels?.expand_api_url, + expandApiModel: process.env.QMD_EXPAND_API_MODEL || yamlModels?.expand_api_model, + expandApiKey: process.env.QMD_EXPAND_API_KEY || yamlModels?.expand_api_key, + connectTimeoutMs: parseEnvInt("QMD_REMOTE_CONNECT_TIMEOUT", 5000), + embedReadTimeoutMs: parseEnvInt("QMD_REMOTE_READ_TIMEOUT", 30000), + rerankReadTimeoutMs: parseEnvInt("QMD_REMOTE_RERANK_TIMEOUT", 60000), + expandReadTimeoutMs: parseEnvInt("QMD_REMOTE_EXPAND_TIMEOUT", 30000), + maxBatchSize: parseEnvInt("QMD_REMOTE_BATCH_SIZE", 32), + }; +} + +function parseEnvInt(name: string, defaultValue: number): number { + const val = process.env[name]; + if (!val) return defaultValue; + const parsed = parseInt(val, 10); + return Number.isFinite(parsed) && parsed > 0 ? parsed : defaultValue; +} diff --git a/src/store.ts b/src/store.ts index 16a55b7df..bce4a6695 100644 --- a/src/store.ts +++ b/src/store.ts @@ -19,11 +19,11 @@ import { readFileSync, realpathSync, statSync, mkdirSync } from "node:fs"; // Note: node:path resolve is not imported — we export our own cross-platform resolve() import fastGlob from "fast-glob"; import { - LlamaCpp, - getDefaultLlamaCpp, + getDefaultLLM, formatQueryForEmbedding, formatDocForEmbedding, withLLMSessionForLlm, + type LLM, type RerankDocument, type ILLMSession, } from "./llm.js"; @@ -59,11 +59,11 @@ export const CHUNK_WINDOW_TOKENS = 200; export const CHUNK_WINDOW_CHARS = CHUNK_WINDOW_TOKENS * 4; // 800 chars /** - * Get the LlamaCpp instance for a store — prefers the store's own instance, + * Get the LLM instance for a store — prefers the store's own instance, * falls back to the global singleton. */ -function getLlm(store: Store): LlamaCpp { - return store.llm ?? getDefaultLlamaCpp(); +function getLlm(store: Store): LLM { + return store.llm ?? getDefaultLLM(); } // ============================================================================= @@ -1087,8 +1087,8 @@ function ensureVecTableInternal(db: Database, dimensions: number): void { export type Store = { db: Database; dbPath: string; - /** Optional LlamaCpp instance for this store (overrides the global singleton) */ - llm?: LlamaCpp; + /** Optional LLM instance for this store (overrides the global singleton) */ + llm?: LLM; close: () => void; ensureVecTable: (dimensions: number) => void; @@ -2276,7 +2276,10 @@ export async function chunkDocumentByTokens( chunkStrategy: ChunkStrategy = "regex", signal?: AbortSignal ): Promise<{ text: string; pos: number; tokens: number }[]> { - const llm = getDefaultLlamaCpp(); + const llm = getDefaultLLM(); + + // Check if the LLM supports tokenization (LlamaCpp does, RemoteLLM doesn't) + const canTokenize = typeof (llm as any).tokenize === "function"; // Use moderate chars/token estimate (prose ~4, code ~2, mixed ~3) // If chunks exceed limit, they'll be re-split with actual ratio @@ -3186,12 +3189,12 @@ export async function searchVec(db: Database, query: string, model: string, limi // Embeddings // ============================================================================= -async function getEmbedding(text: string, model: string, isQuery: boolean, session?: ILLMSession, llmOverride?: LlamaCpp): Promise { +async function getEmbedding(text: string, model: string, isQuery: boolean, session?: ILLMSession, llmOverride?: LLM): Promise { // Format text using the appropriate prompt template const formattedText = isQuery ? formatQueryForEmbedding(text, model) : formatDocForEmbedding(text, undefined, model); const result = session ? await session.embed(formattedText, { model, isQuery }) - : await (llmOverride ?? getDefaultLlamaCpp()).embed(formattedText, { model, isQuery }); + : await (llmOverride ?? getDefaultLLM()).embed(formattedText, { model, isQuery }); return result?.embedding || null; } @@ -3255,7 +3258,7 @@ export function insertEmbedding( // Query expansion // ============================================================================= -export async function expandQuery(query: string, model: string = DEFAULT_QUERY_MODEL, db: Database, intent?: string, llmOverride?: LlamaCpp): Promise { +export async function expandQuery(query: string, model: string = DEFAULT_QUERY_MODEL, db: Database, intent?: string, llmOverride?: LLM): Promise { // Check cache first — stored as JSON preserving types const cacheKey = getCacheKey("expandQuery", { query, model, ...(intent && { intent }) }); const cached = getCachedResult(db, cacheKey); @@ -3273,7 +3276,7 @@ export async function expandQuery(query: string, model: string = DEFAULT_QUERY_M } } - const llm = llmOverride ?? getDefaultLlamaCpp(); + const llm = llmOverride ?? getDefaultLLM(); // Note: LlamaCpp uses hardcoded model, model parameter is ignored const results = await llm.expandQuery(query, { intent }); @@ -3294,7 +3297,7 @@ export async function expandQuery(query: string, model: string = DEFAULT_QUERY_M // Reranking // ============================================================================= -export async function rerank(query: string, documents: { file: string; text: string }[], model: string = DEFAULT_RERANK_MODEL, db: Database, intent?: string, llmOverride?: LlamaCpp): Promise<{ file: string; score: number }[]> { +export async function rerank(query: string, documents: { file: string; text: string }[], model: string = DEFAULT_RERANK_MODEL, db: Database, intent?: string, llmOverride?: LLM): Promise<{ file: string; score: number }[]> { // Prepend intent to rerank query so the reranker scores with domain context const rerankQuery = intent ? `${intent}\n\n${query}` : query; @@ -3319,7 +3322,7 @@ export async function rerank(query: string, documents: { file: string; text: str // Rerank uncached documents using LlamaCpp if (uncachedDocsByChunk.size > 0) { - const llm = llmOverride ?? getDefaultLlamaCpp(); + const llm = llmOverride ?? getDefaultLLM(); const uncachedDocs = [...uncachedDocsByChunk.values()]; const rerankResult = await llm.rerank(rerankQuery, uncachedDocs, { model }); diff --git a/test/remote-llm-integration.test.ts b/test/remote-llm-integration.test.ts new file mode 100644 index 000000000..b168d825c --- /dev/null +++ b/test/remote-llm-integration.test.ts @@ -0,0 +1,514 @@ +/** + * Integration tests for RemoteLLM against live vLLM servers. + * + * Requires environment variables: + * VLLM_EMBED_URL - e.g. http://gpu-host:8002/v1 + * VLLM_EMBED_MODEL - e.g. Qwen/Qwen3-Embedding-0.6B + * VLLM_RERANK_URL - e.g. http://gpu-host:8001/v1 + * VLLM_RERANK_MODEL - e.g. qwen3-reranker-4b + * VLLM_EXPAND_URL - e.g. http://gpu-host:8005/v1 (optional — skips expand tests if absent) + * VLLM_EXPAND_MODEL - e.g. cyankiwi/Qwen3.5-35B-A3B-AWQ-4bit + * + * Skip these tests when no server is available (all tests guard on EMBED_URL). + */ + +import { describe, it, expect, beforeAll, beforeEach } from "vitest"; +import { RemoteLLM } from "../src/remote-llm.js"; +import { HybridLLM } from "../src/hybrid-llm.js"; +import { formatQueryForEmbedding, formatDocForEmbedding } from "../src/llm.js"; +import type { LLM } from "../src/llm.js"; + +const EMBED_URL = process.env.VLLM_EMBED_URL ?? ""; +const EMBED_MODEL = process.env.VLLM_EMBED_MODEL ?? ""; +const RERANK_URL = process.env.VLLM_RERANK_URL ?? ""; +const RERANK_MODEL = process.env.VLLM_RERANK_MODEL ?? ""; +const EXPAND_URL = process.env.VLLM_EXPAND_URL ?? ""; +const EXPAND_MODEL = process.env.VLLM_EXPAND_MODEL ?? ""; + +const SKIP = !EMBED_URL || !EMBED_MODEL; +const SKIP_EXPAND = SKIP || !EXPAND_URL || !EXPAND_MODEL; + +let remoteLlm: RemoteLLM; + +beforeAll(() => { + if (SKIP) return; + remoteLlm = new RemoteLLM({ + embedApiUrl: EMBED_URL, + embedApiModel: EMBED_MODEL, + rerankApiUrl: RERANK_URL, + rerankApiModel: RERANK_MODEL, + }); +}); + +// ============================================================================= +// Connectivity +// ============================================================================= + +describe.skipIf(SKIP)("Server connectivity", () => { + it("can reach the embedding server", async () => { + const res = await fetch(`${EMBED_URL}/models`); + expect(res.ok).toBe(true); + const json = await res.json() as any; + expect(json.data.length).toBeGreaterThan(0); + }); + + it("can reach the reranking server", async () => { + const res = await fetch(`${RERANK_URL}/models`); + expect(res.ok).toBe(true); + }); +}); + +// ============================================================================= +// Single embedding +// ============================================================================= + +describe.skipIf(SKIP)("Single embedding", () => { + it("returns a non-empty embedding vector", async () => { + const result = await remoteLlm.embed("The quick brown fox jumps over the lazy dog"); + expect(result).not.toBeNull(); + expect(result!.embedding.length).toBeGreaterThan(0); + expect(result!.model).toBe(EMBED_MODEL); + }); + + it("embedding values are finite numbers", async () => { + const result = await remoteLlm.embed("test embedding quality"); + expect(result).not.toBeNull(); + for (const val of result!.embedding) { + expect(Number.isFinite(val)).toBe(true); + } + }); + + it("embedding is normalized (L2 norm ≈ 1.0)", async () => { + const result = await remoteLlm.embed("normalization check"); + expect(result).not.toBeNull(); + const norm = Math.sqrt(result!.embedding.reduce((sum, v) => sum + v * v, 0)); + expect(norm).toBeCloseTo(1.0, 1); // within 0.1 + }); + + it("different texts produce different embeddings", async () => { + const [a, b] = await Promise.all([ + remoteLlm.embed("cats are wonderful pets"), + remoteLlm.embed("quantum computing research paper"), + ]); + expect(a).not.toBeNull(); + expect(b).not.toBeNull(); + // Cosine similarity should be < 1 (they are different) + const dot = a!.embedding.reduce((sum, v, i) => sum + v * b!.embedding[i]!, 0); + expect(dot).toBeLessThan(0.95); + }); + + it("similar texts produce similar embeddings", async () => { + const [a, b] = await Promise.all([ + remoteLlm.embed("how to train a puppy"), + remoteLlm.embed("puppy training tips"), + ]); + expect(a).not.toBeNull(); + expect(b).not.toBeNull(); + const dot = a!.embedding.reduce((sum, v, i) => sum + v * b!.embedding[i]!, 0); + expect(dot).toBeGreaterThan(0.7); + }); +}); + +// ============================================================================= +// Dimension consistency +// ============================================================================= + +describe.skipIf(SKIP)("Dimension consistency", () => { + it("all embeddings have the same dimension", async () => { + const texts = [ + "short", + "a medium length sentence about embedding dimensions", + "a much longer piece of text that goes on and on to test whether the embedding dimension stays consistent regardless of input length, which it absolutely should because the model always projects to a fixed-size output vector", + ]; + const results = await Promise.all(texts.map(t => remoteLlm.embed(t))); + const dims = results.map(r => r!.embedding.length); + expect(new Set(dims).size).toBe(1); + console.log(` Embedding dimension: ${dims[0]}`); + }); +}); + +// ============================================================================= +// Batch embedding +// ============================================================================= + +describe.skipIf(SKIP)("Batch embedding", () => { + it("embeds a batch of texts", async () => { + const texts = [ + "document one about machine learning", + "document two about cooking recipes", + "document three about space exploration", + ]; + const results = await remoteLlm.embedBatch(texts); + expect(results).toHaveLength(3); + for (const r of results) { + expect(r).not.toBeNull(); + expect(r!.embedding.length).toBeGreaterThan(0); + } + }); + + it("batch results match individual results", async () => { + const texts = ["alpha text", "beta text"]; + const [batchResults, individual1, individual2] = await Promise.all([ + remoteLlm.embedBatch(texts), + remoteLlm.embed("alpha text"), + remoteLlm.embed("beta text"), + ]); + + // Compare batch[0] with individual1 + expect(batchResults[0]!.embedding.length).toBe(individual1!.embedding.length); + // Embeddings should be very close (may not be exactly identical due to batching) + const dot = batchResults[0]!.embedding.reduce( + (sum, v, i) => sum + v * individual1!.embedding[i]!, 0 + ); + expect(dot).toBeGreaterThan(0.99); + }); + + it("handles empty batch", async () => { + const results = await remoteLlm.embedBatch([]); + expect(results).toEqual([]); + }); + + it("handles large batch (>32 texts, triggers splitting)", async () => { + const texts = Array.from({ length: 50 }, (_, i) => `document number ${i} about topic ${i % 5}`); + const results = await remoteLlm.embedBatch(texts); + expect(results).toHaveLength(50); + for (const r of results) { + expect(r).not.toBeNull(); + } + }); +}); + +// ============================================================================= +// Edge cases +// ============================================================================= + +describe.skipIf(SKIP)("Edge cases", () => { + it("handles very short text", async () => { + const result = await remoteLlm.embed("a"); + expect(result).not.toBeNull(); + expect(result!.embedding.length).toBeGreaterThan(0); + }); + + it("handles text with special characters", async () => { + const result = await remoteLlm.embed("café résumé naïve 日本語 中文 🎉 "); + expect(result).not.toBeNull(); + expect(result!.embedding.length).toBeGreaterThan(0); + }); + + it("handles multi-paragraph text", async () => { + const text = `# Introduction + +This is a long document with multiple paragraphs and markdown formatting. + +## Section 1 + +Some content here with **bold** and *italic* text. + +## Section 2 + +More content with a list: +- item one +- item two +- item three + +\`\`\`python +def hello(): + print("hello world") +\`\`\` +`; + const result = await remoteLlm.embed(text); + expect(result).not.toBeNull(); + expect(result!.embedding.length).toBeGreaterThan(0); + }); +}); + +// ============================================================================= +// Reranking +// ============================================================================= + +describe.skipIf(SKIP)("Reranking", () => { + it("reranks documents by relevance", async () => { + const query = "how to bake chocolate chip cookies"; + const documents = [ + { file: "space.md", text: "The Mars rover collected soil samples from the crater rim." }, + { file: "cookies.md", text: "Preheat oven to 375°F. Mix flour, butter, sugar and chocolate chips. Bake for 12 minutes." }, + { file: "quantum.md", text: "Quantum entanglement allows particles to be correlated over large distances." }, + { file: "baking.md", text: "Cookie recipes require precise measurements of ingredients like flour and sugar." }, + ]; + + const result = await remoteLlm.rerank(query, documents); + expect(result.model).toBe(RERANK_MODEL); + expect(result.results).toHaveLength(4); + + // The cookie/baking docs should rank higher than space/quantum + const scores = new Map(result.results.map(r => [r.file, r.score])); + console.log(" Rerank scores:", Object.fromEntries(scores)); + + expect(scores.get("cookies.md")!).toBeGreaterThan(scores.get("space.md")!); + expect(scores.get("cookies.md")!).toBeGreaterThan(scores.get("quantum.md")!); + }); + + it("scores are between 0 and 1", async () => { + const result = await remoteLlm.rerank("test query", [ + { file: "a.md", text: "relevant document about testing" }, + { file: "b.md", text: "unrelated document about gardening" }, + ]); + for (const r of result.results) { + expect(r.score).toBeGreaterThanOrEqual(0); + expect(r.score).toBeLessThanOrEqual(1); + } + }); + + it("preserves file mapping through index", async () => { + const documents = [ + { file: "first.md", text: "First document" }, + { file: "second.md", text: "Second document" }, + { file: "third.md", text: "Third document" }, + ]; + const result = await remoteLlm.rerank("query", documents); + const files = new Set(result.results.map(r => r.file)); + expect(files).toEqual(new Set(["first.md", "second.md", "third.md"])); + }); + + it("handles single document", async () => { + const result = await remoteLlm.rerank("test", [ + { file: "only.md", text: "The only document" }, + ]); + expect(result.results).toHaveLength(1); + expect(result.results[0]!.file).toBe("only.md"); + }); + + it("handles many documents", async () => { + const documents = Array.from({ length: 20 }, (_, i) => ({ + file: `doc${i}.md`, + text: `Document ${i} contains some text about topic ${i % 4}`, + })); + const result = await remoteLlm.rerank("topic about topic 2", documents); + expect(result.results).toHaveLength(20); + }); +}); + +// ============================================================================= +// Embedding format (remote models skip prefixes) +// ============================================================================= + +describe.skipIf(SKIP)("Embedding format for remote models", () => { + it("formatQueryForEmbedding returns raw text for remote model name", () => { + const formatted = formatQueryForEmbedding("search query", EMBED_MODEL); + expect(formatted).toBe("search query"); + }); + + it("formatDocForEmbedding returns raw text for remote model name", () => { + const formatted = formatDocForEmbedding("doc content", undefined, EMBED_MODEL); + expect(formatted).toBe("doc content"); + }); + + it("formatDocForEmbedding with title prepends title for remote model", () => { + const formatted = formatDocForEmbedding("doc content", "Title", EMBED_MODEL); + expect(formatted).toBe("Title\ndoc content"); + }); +}); + +// ============================================================================= +// HybridLLM integration +// ============================================================================= + +describe.skipIf(SKIP)("HybridLLM with real remote backend", () => { + // Mock local LLM for generate/expandQuery + function createMockLocal(): LLM { + return { + embedModelName: "local-embed-model", + embed: async () => ({ embedding: [0.5], model: "local" }), + embedBatch: async (texts) => texts.map(() => ({ embedding: [0.5], model: "local" })), + generate: async () => ({ text: "generated text", model: "local", done: true }), + modelExists: async (model) => ({ name: model, exists: true }), + expandQuery: async () => [{ type: "lex" as const, text: "expanded" }], + rerank: async () => ({ results: [], model: "local" }), + dispose: async () => {}, + }; + } + + it("routes embed through remote, returning real embeddings", async () => { + const hybrid = new HybridLLM(remoteLlm, createMockLocal()); + const result = await hybrid.embed("testing hybrid embedding"); + expect(result).not.toBeNull(); + // Real embedding has many dimensions, not just [0.5] + expect(result!.embedding.length).toBeGreaterThan(10); + expect(result!.model).toBe(EMBED_MODEL); + }); + + it("routes embedBatch through remote", async () => { + const hybrid = new HybridLLM(remoteLlm, createMockLocal()); + const results = await hybrid.embedBatch(["text one", "text two"]); + expect(results).toHaveLength(2); + expect(results[0]!.embedding.length).toBeGreaterThan(10); + }); + + it("routes rerank through remote", async () => { + const hybrid = new HybridLLM(remoteLlm, createMockLocal()); + const result = await hybrid.rerank("cookies", [ + { file: "a.md", text: "baking cookies at 350 degrees" }, + { file: "b.md", text: "orbiting space station" }, + ]); + expect(result.model).toBe(RERANK_MODEL); + expect(result.results).toHaveLength(2); + }); + + it("routes generate through local mock", async () => { + const hybrid = new HybridLLM(remoteLlm, createMockLocal()); + const result = await hybrid.generate("prompt"); + expect(result!.text).toBe("generated text"); + expect(result!.model).toBe("local"); + }); + + it("routes expandQuery through local mock", async () => { + const hybrid = new HybridLLM(remoteLlm, createMockLocal()); + const result = await hybrid.expandQuery("query"); + expect(result[0]!.text).toBe("expanded"); + }); + + it("embedModelName comes from remote", async () => { + const hybrid = new HybridLLM(remoteLlm, createMockLocal()); + expect(hybrid.embedModelName).toBe(EMBED_MODEL); + }); +}); + +// ============================================================================= +// End-to-end: embed → cosine similarity search +// ============================================================================= + +describe.skipIf(SKIP)("End-to-end embed + search simulation", () => { + it("finds the most relevant document via cosine similarity", async () => { + // Index some "documents" + const docs = [ + { file: "git.md", text: "Git is a distributed version control system for tracking changes in source code" }, + { file: "cooking.md", text: "To make pasta, boil water, add salt, cook noodles for 8 minutes" }, + { file: "docker.md", text: "Docker containers package applications with their dependencies for consistent deployment" }, + { file: "gardening.md", text: "Tomatoes need full sun and regular watering to produce fruit" }, + { file: "typescript.md", text: "TypeScript adds static type checking to JavaScript for safer code" }, + ]; + + // Embed all documents + const docEmbeddings = await remoteLlm.embedBatch(docs.map(d => d.text)); + + // Embed a query + const queryResult = await remoteLlm.embed("how to use version control for my code"); + expect(queryResult).not.toBeNull(); + + // Compute cosine similarities + const similarities = docEmbeddings.map((docEmb, i) => { + const dot = queryResult!.embedding.reduce((sum, v, j) => sum + v * docEmb!.embedding[j]!, 0); + return { file: docs[i]!.file, similarity: dot }; + }); + + similarities.sort((a, b) => b.similarity - a.similarity); + console.log(" Similarity ranking:"); + for (const s of similarities) { + console.log(` ${s.file}: ${s.similarity.toFixed(4)}`); + } + + // git.md should be the top result for a version control query + expect(similarities[0]!.file).toBe("git.md"); + // cooking/gardening should be near the bottom + const cookingRank = similarities.findIndex(s => s.file === "cooking.md"); + const gitRank = similarities.findIndex(s => s.file === "git.md"); + expect(gitRank).toBeLessThan(cookingRank); + }); +}); + +// ============================================================================= +// Remote query expansion +// ============================================================================= + +describe.skipIf(SKIP_EXPAND)("Remote query expansion", () => { + let expandLlm: RemoteLLM; + + beforeAll(() => { + if (SKIP_EXPAND) return; + expandLlm = new RemoteLLM({ + embedApiUrl: EMBED_URL, + embedApiModel: EMBED_MODEL, + expandApiUrl: EXPAND_URL, + expandApiModel: EXPAND_MODEL, + }); + }); + + it("supportsExpand is true when expandApiModel is configured", () => { + expect(expandLlm.supportsExpand).toBe(true); + }); + + it("returns non-empty Queryable[] for a plain query", async () => { + const result = await expandLlm.expandQuery("how to bake chocolate chip cookies"); + expect(result.length).toBeGreaterThan(0); + for (const q of result) { + expect(["lex", "vec", "hyde"]).toContain(q.type); + expect(typeof q.text).toBe("string"); + expect(q.text.length).toBeGreaterThan(0); + } + console.log(" Expansion result:", result.map(q => `${q.type}: ${q.text}`)); + }); + + it("returns all three types (lex, vec, hyde) for a well-formed query", async () => { + const result = await expandLlm.expandQuery("git rebase workflow"); + const types = new Set(result.map(q => q.type)); + // Model should return at least two distinct types + expect(types.size).toBeGreaterThanOrEqual(2); + }); + + it("excludes lex entries when includeLexical is false", async () => { + const result = await expandLlm.expandQuery("docker container networking", { includeLexical: false }); + expect(result.every(q => q.type !== "lex")).toBe(true); + expect(result.length).toBeGreaterThan(0); + }); + + it("incorporates intent into expansion when provided", async () => { + const withIntent = await expandLlm.expandQuery("python", { intent: "find beginner tutorials" }); + // Should produce results tailored to beginner/tutorial angle — at minimum non-empty + expect(withIntent.length).toBeGreaterThan(0); + }); + + it("expanded queries improve recall over original query alone", async () => { + // Use a short ambiguous query + const expanded = await expandLlm.expandQuery("bank"); + // Should get at least two different expansion texts — model interpreted the query + const texts = new Set(expanded.map(q => q.text)); + expect(texts.size).toBeGreaterThanOrEqual(2); + }); +}); + +// ============================================================================= +// HybridLLM with remote expand +// ============================================================================= + +describe.skipIf(SKIP_EXPAND)("HybridLLM with remote expand", () => { + it("routes expandQuery to remote when expandApiModel is configured", async () => { + function createMockLocal(): LLM { + return { + embedModelName: "local", + embed: async () => null, + embedBatch: async (texts) => texts.map(() => null), + generate: async () => null, + modelExists: async (m) => ({ name: m, exists: false }), + expandQuery: async () => [{ type: "lex" as const, text: "LOCAL_SENTINEL" }], + rerank: async () => ({ results: [], model: "local" }), + dispose: async () => {}, + }; + } + + const remote = new RemoteLLM({ + embedApiUrl: EMBED_URL, + embedApiModel: EMBED_MODEL, + expandApiUrl: EXPAND_URL, + expandApiModel: EXPAND_MODEL, + }); + + const hybrid = new HybridLLM(remote, createMockLocal()); + const result = await hybrid.expandQuery("version control with git"); + + // Must NOT come from local mock + expect(result.every(q => q.text !== "LOCAL_SENTINEL")).toBe(true); + // Should have real expanded results + expect(result.length).toBeGreaterThan(0); + console.log(" HybridLLM remote expand:", result.map(q => `${q.type}: ${q.text}`)); + }); +}); diff --git a/test/remote-llm.test.ts b/test/remote-llm.test.ts new file mode 100644 index 000000000..1bd450b85 --- /dev/null +++ b/test/remote-llm.test.ts @@ -0,0 +1,762 @@ +/** + * Tests for RemoteLLM and HybridLLM + * + * Uses a local HTTP server to mock OpenAI-compatible endpoints. + */ + +import { describe, it, expect, beforeAll, afterAll, beforeEach, afterEach } from "vitest"; +import { createServer, type Server, type IncomingMessage, type ServerResponse } from "http"; +import { RemoteLLM, remoteConfigFromEnv, type RemoteLLMConfig } from "../src/remote-llm.js"; +import { HybridLLM } from "../src/hybrid-llm.js"; +import { isRemoteModel, formatQueryForEmbedding, formatDocForEmbedding, getDefaultLLM, setDefaultLLM, LlamaCpp } from "../src/llm.js"; +import type { LLM, EmbeddingResult, RerankResult, Queryable, GenerateResult, ModelInfo } from "../src/llm.js"; + +// ============================================================================= +// Mock HTTP server +// ============================================================================= + +type MockHandler = (req: IncomingMessage, body: string) => { status: number; body: any }; + +let server: Server; +let serverPort: number; +let mockHandler: MockHandler; + +function setMockHandler(handler: MockHandler) { + mockHandler = handler; +} + +beforeAll(async () => { + server = createServer(async (req: IncomingMessage, res: ServerResponse) => { + const chunks: Buffer[] = []; + for await (const chunk of req) { + chunks.push(chunk as Buffer); + } + const body = Buffer.concat(chunks).toString(); + + try { + const result = mockHandler(req, body); + res.writeHead(result.status, { "Content-Type": "application/json" }); + res.end(JSON.stringify(result.body)); + } catch (err: any) { + res.writeHead(500, { "Content-Type": "application/json" }); + res.end(JSON.stringify({ error: err.message })); + } + }); + + await new Promise((resolve) => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address(); + if (typeof addr === "object" && addr) { + serverPort = addr.port; + } + resolve(); + }); + }); +}); + +afterAll(() => { + server.close(); +}); + +function baseUrl(): string { + return `http://127.0.0.1:${serverPort}/v1`; +} + +function createRemoteLLM(overrides?: Partial): RemoteLLM { + return new RemoteLLM({ + embedApiUrl: baseUrl(), + embedApiModel: "test-model", + ...overrides, + }); +} + +// ============================================================================= +// RemoteLLM Tests +// ============================================================================= + +describe("RemoteLLM", () => { + describe("embed", () => { + it("should embed a single text", async () => { + setMockHandler((req, body) => { + const parsed = JSON.parse(body); + expect(parsed.model).toBe("test-model"); + expect(parsed.input).toEqual(["hello world"]); + return { + status: 200, + body: { + data: [{ embedding: [0.1, 0.2, 0.3], index: 0 }], + }, + }; + }); + + const llm = createRemoteLLM(); + const result = await llm.embed("hello world"); + expect(result).not.toBeNull(); + expect(result!.embedding).toEqual([0.1, 0.2, 0.3]); + expect(result!.model).toBe("test-model"); + }); + + it("should embed a batch of texts", async () => { + setMockHandler((_req, body) => { + const parsed = JSON.parse(body); + return { + status: 200, + body: { + data: parsed.input.map((text: string, i: number) => ({ + embedding: [i * 0.1, i * 0.2], + index: i, + })), + }, + }; + }); + + const llm = createRemoteLLM(); + const results = await llm.embedBatch(["text1", "text2", "text3"]); + expect(results).toHaveLength(3); + expect(results[0]!.embedding).toEqual([0, 0]); + expect(results[2]!.embedding).toEqual([0.2, 0.4]); + }); + + it("should return empty array for empty input", async () => { + const llm = createRemoteLLM(); + const results = await llm.embedBatch([]); + expect(results).toEqual([]); + }); + + it("should split large batches", async () => { + const requestBodies: string[][] = []; + setMockHandler((_req, body) => { + const parsed = JSON.parse(body); + requestBodies.push(parsed.input); + return { + status: 200, + body: { + data: parsed.input.map((_: string, i: number) => ({ + embedding: [1.0], + index: i, + })), + }, + }; + }); + + const llm = createRemoteLLM({ maxBatchSize: 2 }); + const texts = ["a", "b", "c", "d", "e"]; + const results = await llm.embedBatch(texts); + + expect(results).toHaveLength(5); + // Should have made 3 requests: [a,b], [c,d], [e] + expect(requestBodies).toHaveLength(3); + expect(requestBodies[0]).toEqual(["a", "b"]); + expect(requestBodies[1]).toEqual(["c", "d"]); + expect(requestBodies[2]).toEqual(["e"]); + }); + + it("should sort response by index", async () => { + setMockHandler(() => ({ + status: 200, + body: { + // Return in reverse order + data: [ + { embedding: [0.3], index: 2 }, + { embedding: [0.1], index: 0 }, + { embedding: [0.2], index: 1 }, + ], + }, + })); + + const llm = createRemoteLLM(); + const results = await llm.embedBatch(["a", "b", "c"]); + expect(results[0]!.embedding).toEqual([0.1]); + expect(results[1]!.embedding).toEqual([0.2]); + expect(results[2]!.embedding).toEqual([0.3]); + }); + }); + + describe("auth", () => { + it("should send Authorization header when key is set", async () => { + let authHeader: string | undefined; + setMockHandler((req) => { + authHeader = req.headers["authorization"] as string; + return { + status: 200, + body: { data: [{ embedding: [1.0], index: 0 }] }, + }; + }); + + const llm = createRemoteLLM({ embedApiKey: "test-key-123" }); + await llm.embed("test"); + expect(authHeader).toBe("Bearer test-key-123"); + }); + + it("should not send Authorization header when no key", async () => { + let authHeader: string | undefined; + setMockHandler((req) => { + authHeader = req.headers["authorization"] as string; + return { + status: 200, + body: { data: [{ embedding: [1.0], index: 0 }] }, + }; + }); + + const llm = createRemoteLLM(); + await llm.embed("test"); + expect(authHeader).toBeUndefined(); + }); + }); + + describe("dimension validation", () => { + it("should reject dimension mismatch after first response", async () => { + let callCount = 0; + setMockHandler(() => { + callCount++; + const dim = callCount === 1 ? [1.0, 2.0, 3.0] : [1.0, 2.0]; + return { + status: 200, + body: { data: [{ embedding: dim, index: 0 }] }, + }; + }); + + const llm = createRemoteLLM(); + // First call succeeds and locks dimensions to 3 + await llm.embed("first"); + // Second call should fail because dimensions changed + await expect(llm.embed("second")).rejects.toThrow("dimension mismatch"); + }); + }); + + describe("error handling", () => { + it("should throw on HTTP error", async () => { + setMockHandler(() => ({ + status: 500, + body: { error: "Internal server error" }, + })); + + const llm = createRemoteLLM(); + await expect(llm.embed("test")).rejects.toThrow("500"); + }); + + it("should open circuit breaker after failures", async () => { + setMockHandler(() => ({ + status: 500, + body: { error: "down" }, + })); + + const llm = createRemoteLLM(); + // Fail 3 times to trip the breaker + for (let i = 0; i < 3; i++) { + await expect(llm.embed("test")).rejects.toThrow(); + } + // Next call should fail immediately with circuit breaker message + await expect(llm.embed("test")).rejects.toThrow("circuit breaker"); + }); + }); + + describe("rerank", () => { + it("should rerank documents", async () => { + setMockHandler((_req, body) => { + const parsed = JSON.parse(body); + expect(parsed.model).toBe("rerank-model"); + expect(parsed.query).toBe("test query"); + expect(parsed.documents).toEqual(["doc A text", "doc B text"]); + return { + status: 200, + body: { + results: [ + { index: 1, relevance_score: 0.9 }, + { index: 0, relevance_score: 0.3 }, + ], + }, + }; + }); + + const llm = createRemoteLLM({ + rerankApiModel: "rerank-model", + }); + const result = await llm.rerank( + "test query", + [ + { file: "a.md", text: "doc A text" }, + { file: "b.md", text: "doc B text" }, + ] + ); + + expect(result.model).toBe("rerank-model"); + expect(result.results).toHaveLength(2); + expect(result.results.find(r => r.file === "b.md")!.score).toBe(0.9); + expect(result.results.find(r => r.file === "a.md")!.score).toBe(0.3); + }); + + it("should throw when rerankApiModel not configured", async () => { + const llm = createRemoteLLM(); + await expect( + llm.rerank("query", [{ file: "a.md", text: "text" }]) + ).rejects.toThrow("rerankApiModel"); + }); + }); + + describe("unsupported operations", () => { + it("should throw on generate", async () => { + const llm = createRemoteLLM(); + await expect(llm.generate("prompt")).rejects.toThrow("does not support text generation"); + }); + + it("should throw on expandQuery when expandApiModel not configured", async () => { + const llm = createRemoteLLM(); + await expect(llm.expandQuery("query")).rejects.toThrow("expandApiModel not configured"); + }); + }); + + // =========================================================================== + // expandQuery + // =========================================================================== + + describe("expandQuery", () => { + function createExpandLLM(overrides?: Partial): RemoteLLM { + return createRemoteLLM({ expandApiModel: "expand-model", ...overrides }); + } + + function expandResponse(content: string) { + return { + status: 200, + body: { choices: [{ message: { content } }] }, + }; + } + + it("supportsExpand is false without expandApiModel", () => { + expect(createRemoteLLM().supportsExpand).toBe(false); + }); + + it("supportsExpand is true with expandApiModel", () => { + expect(createExpandLLM().supportsExpand).toBe(true); + }); + + it("calls /chat/completions with correct payload", async () => { + let capturedBody: any; + setMockHandler((req, body) => { + expect(req.url).toBe("/v1/chat/completions"); + capturedBody = JSON.parse(body); + return expandResponse("lex: foo keywords\nvec: foo semantic\nhyde: A document about foo"); + }); + + await createExpandLLM().expandQuery("foo"); + + expect(capturedBody.model).toBe("expand-model"); + expect(capturedBody.messages).toHaveLength(2); + expect(capturedBody.messages[0].role).toBe("system"); + expect(capturedBody.messages[1].role).toBe("user"); + expect(capturedBody.messages[1].content).toContain("foo"); + }); + + it("includes intent in user prompt when provided", async () => { + let userContent = ""; + setMockHandler((_req, body) => { + userContent = JSON.parse(body).messages[1].content; + return expandResponse("lex: foo\nvec: foo semantic\nhyde: A document about foo"); + }); + + await createExpandLLM().expandQuery("foo", { intent: "find documentation" }); + expect(userContent).toContain("find documentation"); + }); + + it("sends Authorization header for expand endpoint", async () => { + let authHeader: string | undefined; + setMockHandler((req) => { + authHeader = req.headers["authorization"] as string; + return expandResponse("lex: t\nvec: t semantic\nhyde: A document about t"); + }); + + await createExpandLLM({ expandApiKey: "expand-secret" }).expandQuery("t"); + expect(authHeader).toBe("Bearer expand-secret"); + }); + + it("falls back to embedApiKey when expandApiKey not set", async () => { + let authHeader: string | undefined; + setMockHandler((req) => { + authHeader = req.headers["authorization"] as string; + return expandResponse("lex: t\nvec: t semantic\nhyde: A document about t"); + }); + + await createExpandLLM({ embedApiKey: "embed-key" }).expandQuery("t"); + expect(authHeader).toBe("Bearer embed-key"); + }); + + it("parses lex/vec/hyde lines into Queryable[]", async () => { + setMockHandler(() => + expandResponse( + "lex: chocolate chip cookies keywords\n" + + "vec: how to bake chocolate chip cookies\n" + + "hyde: Preheat oven to 375F and mix chocolate chips into the dough before baking" + ) + ); + + const result = await createExpandLLM().expandQuery("chocolate chip cookies"); + + expect(result).toHaveLength(3); + expect(result.find(q => q.type === "lex")?.text).toBe("chocolate chip cookies keywords"); + expect(result.find(q => q.type === "vec")?.text).toBe("how to bake chocolate chip cookies"); + expect(result.find(q => q.type === "hyde")?.text).toContain("chocolate"); + }); + + it("excludes lex entries when includeLexical is false", async () => { + setMockHandler(() => + expandResponse("lex: foo keywords\nvec: foo semantic\nhyde: A document about foo") + ); + + const result = await createExpandLLM().expandQuery("foo", { includeLexical: false }); + + expect(result.every(q => q.type !== "lex")).toBe(true); + expect(result.length).toBeGreaterThan(0); + }); + + it("returns fallback Queryable[] when model output is unparseable", async () => { + setMockHandler(() => expandResponse("Sorry, I cannot help with that.")); + + const result = await createExpandLLM().expandQuery("cats"); + + expect(result.length).toBeGreaterThan(0); + expect(result.some(q => q.text === "cats")).toBe(true); + }); + + it("skips lines whose text shares no term with the original query", async () => { + setMockHandler(() => + // "completely different" shares no words with "cats" + expandResponse("lex: completely different\nvec: cats semantic paraphrase\nhyde: A document about cats") + ); + + const result = await createExpandLLM().expandQuery("cats"); + + const lexEntry = result.find(q => q.type === "lex"); + expect(lexEntry?.text).not.toBe("completely different"); + }); + + it("opens circuit breaker after repeated expansion failures", async () => { + setMockHandler(() => ({ status: 503, body: { error: "unavailable" } })); + + const llm = createExpandLLM(); + for (let i = 0; i < 3; i++) { + await expect(llm.expandQuery("query")).rejects.toThrow("503"); + } + await expect(llm.expandQuery("query")).rejects.toThrow("circuit breaker"); + }); + + it("throws on HTTP error response", async () => { + setMockHandler(() => ({ status: 500, body: { error: "server error" } })); + await expect(createExpandLLM().expandQuery("query")).rejects.toThrow("500"); + }); + }); +}); + +// ============================================================================= +// HybridLLM Tests +// ============================================================================= + +describe("HybridLLM", () => { + // Simple mock local LLM + function createMockLocalLLM(): LLM { + return { + embedModelName: "local-model", + embed: async () => ({ embedding: [0.5], model: "local-model" }), + embedBatch: async (texts) => texts.map(() => ({ embedding: [0.5], model: "local-model" })), + generate: async () => ({ text: "expanded", model: "local-model", done: true }), + modelExists: async (model) => ({ name: model, exists: true }), + expandQuery: async () => [{ type: "lex" as const, text: "expanded query" }], + rerank: async () => ({ results: [], model: "local-model" }), + dispose: async () => {}, + }; + } + + it("should route embed to remote", async () => { + setMockHandler(() => ({ + status: 200, + body: { data: [{ embedding: [0.9], index: 0 }] }, + })); + + const remote = createRemoteLLM(); + const local = createMockLocalLLM(); + const hybrid = new HybridLLM(remote, local); + + const result = await hybrid.embed("test"); + // Should come from remote (0.9), not local (0.5) + expect(result!.embedding).toEqual([0.9]); + }); + + it("should route embedBatch to remote", async () => { + setMockHandler((_req, body) => { + const parsed = JSON.parse(body); + return { + status: 200, + body: { + data: parsed.input.map((_: string, i: number) => ({ + embedding: [0.9 + i * 0.01], + index: i, + })), + }, + }; + }); + + const remote = createRemoteLLM(); + const local = createMockLocalLLM(); + const hybrid = new HybridLLM(remote, local); + + const results = await hybrid.embedBatch(["a", "b"]); + expect(results[0]!.embedding).toEqual([0.9]); + expect(results[1]!.embedding).toEqual([0.91]); + }); + + it("should route generate to local", async () => { + const local = createMockLocalLLM(); + const remote = createRemoteLLM(); + const hybrid = new HybridLLM(remote, local); + + const result = await hybrid.generate("prompt"); + expect(result!.text).toBe("expanded"); + expect(result!.model).toBe("local-model"); + }); + + it("should route expandQuery to local when expandApiModel not set", async () => { + const local = createMockLocalLLM(); + const remote = createRemoteLLM(); // no expandApiModel + const hybrid = new HybridLLM(remote, local); + + const result = await hybrid.expandQuery("test query"); + expect(result[0]!.text).toBe("expanded query"); // from mock local + }); + + it("should route expandQuery to remote when expandApiModel is set", async () => { + setMockHandler((_req, body) => { + const parsed = JSON.parse(body); + expect(parsed.model).toBe("remote-expand-model"); + return { + status: 200, + body: { + choices: [{ + message: { + content: "lex: test keywords\nvec: test semantic\nhyde: A document about test query", + }, + }], + }, + }; + }); + + const remote = createRemoteLLM({ expandApiModel: "remote-expand-model" }); + const local = createMockLocalLLM(); + const hybrid = new HybridLLM(remote, local); + + const result = await hybrid.expandQuery("test query"); + // Should come from remote, not local mock ("expanded query") + expect(result.some(q => q.text === "expanded query")).toBe(false); + expect(result.some(q => q.type === "lex")).toBe(true); + }); + + it("should use remote embedModelName", async () => { + const remote = createRemoteLLM({ embedApiModel: "BAAI/bge-m3" }); + const local = createMockLocalLLM(); + const hybrid = new HybridLLM(remote, local); + + expect(hybrid.embedModelName).toBe("BAAI/bge-m3"); + }); +}); + +// ============================================================================= +// Config Tests +// ============================================================================= + +describe("remoteConfigFromEnv", () => { + const origEnv = { ...process.env }; + + beforeEach(() => { + // Clear any QMD_ env vars + for (const key of Object.keys(process.env)) { + if (key.startsWith("QMD_") && key.includes("API")) { + delete process.env[key]; + } + } + }); + + afterAll(() => { + // Restore original env + for (const key of Object.keys(process.env)) { + if (key.startsWith("QMD_") && key.includes("API")) { + delete process.env[key]; + } + } + Object.assign(process.env, origEnv); + }); + + it("should return null when no config", () => { + expect(remoteConfigFromEnv()).toBeNull(); + }); + + it("should parse env vars", () => { + process.env.QMD_EMBED_API_URL = "http://gpu:8000/v1"; + process.env.QMD_EMBED_API_MODEL = "bge-m3"; + process.env.QMD_EMBED_API_KEY = "secret"; + + const config = remoteConfigFromEnv(); + expect(config).not.toBeNull(); + expect(config!.embedApiUrl).toBe("http://gpu:8000/v1"); + expect(config!.embedApiModel).toBe("bge-m3"); + expect(config!.embedApiKey).toBe("secret"); + }); + + it("should use YAML config as fallback", () => { + const config = remoteConfigFromEnv({ + embed_api_url: "http://yaml:8000/v1", + embed_api_model: "yaml-model", + }); + expect(config).not.toBeNull(); + expect(config!.embedApiUrl).toBe("http://yaml:8000/v1"); + }); + + it("should prefer env vars over YAML", () => { + process.env.QMD_EMBED_API_URL = "http://env:8000/v1"; + process.env.QMD_EMBED_API_MODEL = "env-model"; + + const config = remoteConfigFromEnv({ + embed_api_url: "http://yaml:8000/v1", + embed_api_model: "yaml-model", + }); + expect(config!.embedApiUrl).toBe("http://env:8000/v1"); + expect(config!.embedApiModel).toBe("env-model"); + }); + + it("should parse expand env vars", () => { + process.env.QMD_EMBED_API_URL = "http://gpu:8000/v1"; + process.env.QMD_EMBED_API_MODEL = "bge-m3"; + process.env.QMD_EXPAND_API_URL = "http://gpu:8001/v1"; + process.env.QMD_EXPAND_API_MODEL = "llama3-8b"; + process.env.QMD_EXPAND_API_KEY = "expand-key"; + + const config = remoteConfigFromEnv(); + expect(config!.expandApiUrl).toBe("http://gpu:8001/v1"); + expect(config!.expandApiModel).toBe("llama3-8b"); + expect(config!.expandApiKey).toBe("expand-key"); + }); + + it("should parse expand fields from YAML config", () => { + const config = remoteConfigFromEnv({ + embed_api_url: "http://gpu:8000/v1", + embed_api_model: "bge-m3", + expand_api_model: "yaml-expand-model", + }); + expect(config!.expandApiModel).toBe("yaml-expand-model"); + expect(config!.expandApiUrl).toBeUndefined(); + }); + + it("expandApiModel not set when not in env or YAML", () => { + const config = remoteConfigFromEnv({ + embed_api_url: "http://gpu:8000/v1", + embed_api_model: "bge-m3", + }); + expect(config!.expandApiModel).toBeUndefined(); + }); +}); + +// ============================================================================= +// Embedding format tests +// ============================================================================= + +describe("isRemoteModel", () => { + it("should detect remote models", () => { + expect(isRemoteModel("BAAI/bge-m3")).toBe(true); + expect(isRemoteModel("intfloat/multilingual-e5-large")).toBe(true); + expect(isRemoteModel("text-embedding-ada-002")).toBe(true); + }); + + it("should detect local models", () => { + expect(isRemoteModel("hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf")).toBe(false); + expect(isRemoteModel("/path/to/model.gguf")).toBe(false); + }); +}); + +describe("formatQueryForEmbedding with remote models", () => { + it("should return raw query for remote models", () => { + expect(formatQueryForEmbedding("test query", "BAAI/bge-m3")).toBe("test query"); + }); + + it("should add prefix for local nomic models", () => { + expect(formatQueryForEmbedding("test query", "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf")).toContain("task:"); + }); +}); + +describe("formatDocForEmbedding with remote models", () => { + it("should return raw text for remote models", () => { + expect(formatDocForEmbedding("doc text", undefined, "BAAI/bge-m3")).toBe("doc text"); + }); + + it("should include title when provided for remote models", () => { + expect(formatDocForEmbedding("doc text", "My Title", "BAAI/bge-m3")).toBe("My Title\ndoc text"); + }); +}); + +// ============================================================================= +// Local-only path (no remote config) +// ============================================================================= + +describe("Local-only LlamaCpp path", () => { + afterEach(() => { + // Reset to default so other tests aren't affected + setDefaultLLM(null); + }); + + it("getDefaultLLM() returns a LlamaCpp instance when nothing is configured", () => { + setDefaultLLM(null); + const llm = getDefaultLLM(); + expect(llm).toBeInstanceOf(LlamaCpp); + }); + + it("LlamaCpp instance satisfies the LLM interface", () => { + const llm = new LlamaCpp(); + // All LLM interface methods exist + expect(typeof llm.embed).toBe("function"); + expect(typeof llm.embedBatch).toBe("function"); + expect(typeof llm.generate).toBe("function"); + expect(typeof llm.modelExists).toBe("function"); + expect(typeof llm.expandQuery).toBe("function"); + expect(typeof llm.rerank).toBe("function"); + expect(typeof llm.dispose).toBe("function"); + expect(typeof llm.embedModelName).toBe("string"); + }); + + it("LlamaCpp has tokenize method (used by chunkDocumentByTokens duck-typing)", () => { + const llm = new LlamaCpp(); + expect(typeof llm.tokenize).toBe("function"); + }); + + it("setDefaultLLM with LlamaCpp is retrievable via getDefaultLLM", () => { + const llm = new LlamaCpp(); + setDefaultLLM(llm); + expect(getDefaultLLM()).toBe(llm); + }); + + it("remoteConfigFromEnv returns null when no env vars or YAML set", () => { + // Clear any remote env vars + const saved: Record = {}; + for (const key of ["QMD_EMBED_API_URL", "QMD_EMBED_API_MODEL"]) { + saved[key] = process.env[key]; + delete process.env[key]; + } + try { + expect(remoteConfigFromEnv()).toBeNull(); + expect(remoteConfigFromEnv({})).toBeNull(); + expect(remoteConfigFromEnv({ embed_api_url: undefined })).toBeNull(); + } finally { + for (const [key, val] of Object.entries(saved)) { + if (val !== undefined) process.env[key] = val; + } + } + }); + + it("formatQueryForEmbedding adds nomic prefix for default local model", () => { + // Default model is embeddinggemma (hf: URI), should get task prefix + const formatted = formatQueryForEmbedding("hello"); + expect(formatted).toContain("task:"); + expect(formatted).toContain("hello"); + }); + + it("formatDocForEmbedding adds nomic prefix for default local model", () => { + const formatted = formatDocForEmbedding("doc content", "My Doc"); + expect(formatted).toContain("title: My Doc"); + expect(formatted).toContain("text: doc content"); + }); +});