Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog/3684.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- `AICFilter` now shares read-only AIC models via a singleton `AICModelManager` in `aic_filter.py`.
- Multiple filters using the same `model path` or `(model_id, model_download_dir)` share one loaded model, with reference counting and concurrent load deduplication.
196 changes: 184 additions & 12 deletions src/pipecat/audio/filters/aic_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@

Classes:
AICFilter: For aic-sdk (uses 'aic_sdk' module)
AICModelManager: Singleton manager for read-only AIC Model instances.
"""

import asyncio
from pathlib import Path
from typing import List, Optional
from threading import Lock
from typing import List, Optional, Tuple

import numpy as np
from aic_sdk import (
Expand All @@ -33,6 +36,174 @@
from pipecat.frames.frames import FilterControlFrame, FilterEnableFrame


class AICModelManager:
"""Singleton manager for read-only AIC Model instances with reference counting.

Caches Model instances by path or (model_id + download_dir). Multiple
AICFilter instances using the same model share one Model; the manager
acquires on first use and releases when the last reference is dropped.
"""

_cache: dict[str, Tuple[Model, int]] = {} # key -> (model, ref_count)
_lock = Lock()
_loading: dict[
str, asyncio.Task[Model]
] = {} # key -> load task (deduplicates concurrent loads)

@classmethod
def _increment_reference(cls, cache_key: str, entry: Tuple[Model, int]) -> Tuple[Model, str]:
"""Increment reference count for cached entry. Caller must hold _lock."""
cached_model, ref_count = entry
cls._cache[cache_key] = (cached_model, ref_count + 1)
logger.debug(f"AIC model cache key={cache_key!r} ref_count={ref_count + 1}")
return cached_model, cache_key

@classmethod
def _store_new_reference(cls, cache_key: str, model: Model) -> Tuple[Model, str]:
"""Store new model in cache with ref count 1. Caller must hold _lock."""
cls._cache[cache_key] = (model, 1)
logger.debug(f"AIC model cached key={cache_key!r} ref_count=1")
return model, cache_key

@classmethod
async def _load_model_from_file(
cls,
cache_key: str,
*,
model_path: Optional[Path] = None,
model_id: Optional[str] = None,
model_download_dir: Optional[Path] = None,
) -> Model:
"""Run the actual load (file or download). Separate to allow create_task and deduplication."""
if model_path is not None:
logger.debug(f"Loading AIC model from file: {model_path}")
return Model.from_file(str(model_path))
Comment thread
gkmngrgn marked this conversation as resolved.
Outdated

if model_id is not None and model_download_dir is not None:
logger.debug(f"Downloading AIC model: {model_id}")
model_download_dir.mkdir(parents=True, exist_ok=True)
path = await Model.download_async(model_id, str(model_download_dir))
logger.debug(f"Model downloaded to: {path}")
return Model.from_file(path)

raise ValueError("Unexpected model_path or (model_id and model_download_dir) state.")

@staticmethod
def _get_cache_key(
*,
model_path: Optional[Path] = None,
model_id: Optional[str] = None,
model_download_dir: Optional[Path] = None,
) -> str:
"""Build a stable cache key for the model.

Args:
model_path: Path to a local .aicmodel file.
model_id: Model identifier (See https://artifacts.ai-coustics.io/ for available models).
model_download_dir: Directory used for downloading models.

Returns:
A string key unique per (path) or (model_id + download_dir).
"""
if model_path is not None:
return f"path:{model_path.resolve()}"

if model_id is not None and model_download_dir is not None:
return f"id:{model_id}:{model_download_dir.resolve()}"

raise ValueError("Either model_path or (model_id and model_download_dir) must be set.")

@classmethod
async def acquire(
cls,
*,
model_path: Optional[Path] = None,
model_id: Optional[str] = None,
model_download_dir: Optional[Path] = None,
) -> Tuple[Model, str]:
"""Get or load a Model and increment its reference count.

Call this when starting a filter. Store the returned key and pass it
to release() when stopping the filter.

Args:
model_path: Path to a local .aicmodel file. If set, model_id is ignored.
model_id: Model identifier to download from CDN.
model_download_dir: Directory for downloading models. Required if
model_id is used.

Returns:
Tuple of (shared Model instance, cache key for release).

Raises:
ValueError: If neither model_path nor (model_id + model_download_dir)
is provided, or if model_id is set without model_download_dir.
"""
cache_key = cls._get_cache_key(
model_path=model_path,
model_id=model_id,
model_download_dir=model_download_dir,
)

with cls._lock:
entry = cls._cache.get(cache_key)
if entry is not None:
return cls._increment_reference(cache_key, entry)

# Deduplicate concurrent loads for the same key
load_task = cls._loading.get(cache_key)
if load_task is None:
load_task = asyncio.create_task(
cls._load_model_from_file(
cache_key,
model_path=model_path,
model_id=model_id,
model_download_dir=model_download_dir,
)
)
cls._loading[cache_key] = load_task

try:
model = await load_task
finally:
with cls._lock:
cls._loading.pop(cache_key, None)

with cls._lock:
entry = cls._cache.get(cache_key)
if entry is not None:
return cls._increment_reference(cache_key, entry)
return cls._store_new_reference(cache_key, model)

@classmethod
def release(cls, key: str) -> None:
"""Release a reference to a cached model.

Call this when stopping a filter, with the key returned from
get_model(). When the last reference is released, the model
is removed from the cache.

Args:
key: Cache key returned by get_model().
"""
with cls._lock:
entry = cls._cache.get(key)

if entry is None:
logger.warning(f"AIC model release unknown key={key!r}")
return

model, ref_count = entry
ref_count -= 1

if ref_count <= 0:
del cls._cache[key]
logger.debug(f"AIC model evicted key={key!r}")
else:
cls._cache[key] = (model, ref_count)
logger.debug(f"AIC model key={key!r} ref_count={ref_count}")


class AICFilter(BaseAudioFilter):
"""Audio filter using ai-coustics' AIC SDK for real-time enhancement.

Expand Down Expand Up @@ -91,7 +262,8 @@ def __init__(
32768.0 # 2^15, for normalizing int16 (-32768 to 32767) to float32 (-1.0 to 1.0)
)

# AIC SDK objects
# AIC SDK objects; model is shared via AICModelManager
self._model_cache_key: Optional[str] = None
self._model = None
self._processor = None
self._processor_ctx = None
Expand Down Expand Up @@ -162,16 +334,12 @@ async def start(self, sample_rate: int):
"""
self._sample_rate = sample_rate

# Load or download model
if self._model_path:
logger.debug(f"Loading AIC model from: {self._model_path}")
self._model = Model.from_file(str(self._model_path))
else:
logger.debug(f"Downloading AIC model: {self._model_id}")
self._model_download_dir.mkdir(parents=True, exist_ok=True)
model_path = await Model.download_async(self._model_id, str(self._model_download_dir))
logger.debug(f"Model downloaded to: {model_path}")
self._model = Model.from_file(model_path)
# Acquire shared read-only model from singleton manager
self._model, self._model_cache_key = await AICModelManager.acquire(
model_path=self._model_path,
model_id=self._model_id,
model_download_dir=self._model_download_dir,
)

# Get optimal frames for this sample rate
self._frames_per_block = self._model.get_optimal_num_frames(self._sample_rate)
Expand Down Expand Up @@ -242,6 +410,10 @@ async def stop(self):
self._aic_ready = False
self._audio_buffer.clear()

if self._model_cache_key is not None:
AICModelManager.release(self._model_cache_key)
self._model_cache_key = None

async def process_frame(self, frame: FilterControlFrame):
"""Process control frames to enable/disable filtering.

Expand Down
Loading