|
1 | | -"""Circuit breaker for resilient API calls.""" |
| 1 | +"""Circuit breaker for resilient API calls. |
| 2 | +
|
| 3 | +This module adapts the resilience module's circuit breaker implementations |
| 4 | +for use in the APIs module, maintaining backward compatibility. |
| 5 | +""" |
2 | 6 |
|
3 | 7 | from __future__ import annotations |
4 | 8 |
|
5 | | -import asyncio |
6 | | -import time |
7 | 9 | from collections.abc import Callable |
8 | 10 | from typing import Any |
9 | 11 |
|
10 | 12 | from agentle.agents.apis.circuit_breaker_error import CircuitBreakerError |
11 | | -from agentle.agents.apis.circuit_breaker_state import CircuitBreakerState |
12 | 13 | from agentle.agents.apis.request_config import RequestConfig |
| 14 | +from agentle.resilience.circuit_breaker.in_memory_circuit_breaker import ( |
| 15 | + InMemoryCircuitBreaker, |
| 16 | +) |
13 | 17 |
|
14 | 18 |
|
15 | 19 | class CircuitBreaker: |
16 | | - """Circuit breaker implementation for resilient API calls.""" |
| 20 | + """ |
| 21 | + Circuit breaker implementation for resilient API calls. |
| 22 | +
|
| 23 | + This wraps the resilience module's InMemoryCircuitBreaker to provide |
| 24 | + a simpler call-based API for endpoint usage. |
| 25 | + """ |
17 | 26 |
|
18 | 27 | def __init__(self, config: RequestConfig): |
19 | 28 | self.config = config |
20 | | - self.state = CircuitBreakerState.CLOSED |
21 | | - self.failure_count = 0 |
22 | | - self.success_count = 0 |
23 | | - self.last_failure_time: float | None = None |
24 | | - self._lock = asyncio.Lock() |
| 29 | + self._circuit_id = "default" # Single circuit per endpoint |
| 30 | + # Initialize the underlying circuit breaker from resilience module |
| 31 | + self._impl = InMemoryCircuitBreaker( |
| 32 | + failure_threshold=config.circuit_breaker_failure_threshold, |
| 33 | + recovery_timeout=config.circuit_breaker_recovery_timeout, |
| 34 | + half_open_success_threshold=config.circuit_breaker_success_threshold, |
| 35 | + enable_metrics=config.enable_metrics, |
| 36 | + ) |
25 | 37 |
|
26 | 38 | async def call(self, func: Callable[[], Any]) -> Any: |
27 | | - """Execute function with circuit breaker protection.""" |
28 | | - async with self._lock: |
29 | | - # Check if circuit is open |
30 | | - if self.state == CircuitBreakerState.OPEN: |
31 | | - # Check if we should transition to half-open |
32 | | - if self.last_failure_time: |
33 | | - elapsed = time.time() - self.last_failure_time |
34 | | - if elapsed >= self.config.circuit_breaker_recovery_timeout: |
35 | | - self.state = CircuitBreakerState.HALF_OPEN |
36 | | - self.success_count = 0 |
37 | | - else: |
38 | | - raise CircuitBreakerError( |
39 | | - f"Circuit breaker is OPEN. Retry after {self.config.circuit_breaker_recovery_timeout - elapsed:.1f}s" |
40 | | - ) |
| 39 | + """ |
| 40 | + Execute function with circuit breaker protection. |
| 41 | +
|
| 42 | + Args: |
| 43 | + func: Async function to execute |
| 44 | +
|
| 45 | + Returns: |
| 46 | + Result of func call |
| 47 | +
|
| 48 | + Raises: |
| 49 | + CircuitBreakerError: If circuit is open |
| 50 | + """ |
| 51 | + # Check if circuit is open |
| 52 | + if await self._impl.is_open(self._circuit_id): |
| 53 | + # Get circuit state for more details |
| 54 | + state = await self._impl.get_circuit_state(self._circuit_id) |
| 55 | + next_retry_seconds = state.get("next_recovery_attempt_in_seconds", 0) |
| 56 | + |
| 57 | + if next_retry_seconds > 0: |
| 58 | + raise CircuitBreakerError( |
| 59 | + f"Circuit breaker is OPEN. Retry after {next_retry_seconds:.1f}s" |
| 60 | + ) |
41 | 61 |
|
42 | 62 | # Execute the function |
43 | 63 | try: |
44 | 64 | result = await func() |
45 | | - await self._on_success() |
| 65 | + await self._impl.record_success(self._circuit_id) |
46 | 66 | return result |
47 | 67 | except Exception: |
48 | | - await self._on_failure() |
| 68 | + await self._impl.record_failure(self._circuit_id) |
49 | 69 | raise |
50 | | - |
51 | | - async def _on_success(self) -> None: |
52 | | - """Handle successful call.""" |
53 | | - async with self._lock: |
54 | | - if self.state == CircuitBreakerState.HALF_OPEN: |
55 | | - self.success_count += 1 |
56 | | - if self.success_count >= self.config.circuit_breaker_success_threshold: |
57 | | - self.state = CircuitBreakerState.CLOSED |
58 | | - self.failure_count = 0 |
59 | | - elif self.state == CircuitBreakerState.CLOSED: |
60 | | - self.failure_count = 0 |
61 | | - |
62 | | - async def _on_failure(self) -> None: |
63 | | - """Handle failed call.""" |
64 | | - async with self._lock: |
65 | | - self.failure_count += 1 |
66 | | - self.last_failure_time = time.time() |
67 | | - |
68 | | - if self.state == CircuitBreakerState.HALF_OPEN: |
69 | | - # Failure in half-open state reopens circuit |
70 | | - self.state = CircuitBreakerState.OPEN |
71 | | - elif self.state == CircuitBreakerState.CLOSED: |
72 | | - # Check if we've hit the failure threshold |
73 | | - if self.failure_count >= self.config.circuit_breaker_failure_threshold: |
74 | | - self.state = CircuitBreakerState.OPEN |
0 commit comments