|
| 1 | +import asyncio |
1 | 2 | import os |
2 | 3 | import threading |
3 | 4 | import time |
@@ -116,6 +117,13 @@ def __init__(self): |
116 | 117 | self.config = LettaConfig.load() |
117 | 118 | self.logger = get_logger(__name__) |
118 | 119 |
|
| 120 | + if settings.db_max_concurrent_sessions: |
| 121 | + self._db_semaphore = asyncio.Semaphore(settings.db_max_concurrent_sessions) |
| 122 | + self.logger.info(f"Initialized database throttling with max {settings.db_max_concurrent_sessions} concurrent sessions") |
| 123 | + else: |
| 124 | + self.logger.info("Database throttling is disabled") |
| 125 | + self._db_semaphore = None |
| 126 | + |
119 | 127 | def initialize_sync(self, force: bool = False) -> None: |
120 | 128 | """Initialize the synchronous database engine if not already initialized.""" |
121 | 129 | with self._lock: |
@@ -364,16 +372,33 @@ def session(self, name: str = "default") -> Generator[Any, None, None]: |
364 | 372 | @trace_method |
365 | 373 | @asynccontextmanager |
366 | 374 | async def async_session(self, name: str = "default") -> AsyncGenerator[AsyncSession, None]: |
367 | | - """Async context manager for database sessions.""" |
368 | | - session_factory = self.get_async_session_factory(name) |
369 | | - if not session_factory: |
370 | | - raise ValueError(f"No async session factory found for '{name}' or async database is not configured") |
| 375 | + """Async context manager for database sessions with throttling.""" |
| 376 | + if self._db_semaphore: |
| 377 | + async with self._db_semaphore: |
| 378 | + session_factory = self.get_async_session_factory(name) |
| 379 | + if not session_factory: |
| 380 | + raise ValueError(f"No async session factory found for '{name}' or async database is not configured") |
| 381 | + |
| 382 | + session = session_factory() |
| 383 | + try: |
| 384 | + yield session |
| 385 | + finally: |
| 386 | + await session.close() |
| 387 | + else: |
| 388 | + session_factory = self.get_async_session_factory(name) |
| 389 | + if not session_factory: |
| 390 | + raise ValueError(f"No async session factory found for '{name}' or async database is not configured") |
371 | 391 |
|
372 | | - session = session_factory() |
373 | | - try: |
374 | | - yield session |
375 | | - finally: |
376 | | - await session.close() |
| 392 | + session = session_factory() |
| 393 | + try: |
| 394 | + yield session |
| 395 | + finally: |
| 396 | + await session.close() |
| 397 | + |
| 398 | + @trace_method |
| 399 | + def session_caller_trace(self, caller_info: str): |
| 400 | + """Trace sync db caller information for debugging purposes.""" |
| 401 | + pass # wrapper used for otel tracing only |
377 | 402 |
|
378 | 403 | @trace_method |
379 | 404 | def session_caller_trace(self, caller_info: str): |
|
0 commit comments