Skip to content

Commit ff056e6

Browse files
committed
Fix concurrency issues across storage layer and core loop
- Snapshot active_subagents dict before iteration in cancel_all to prevent RuntimeError
- Add proper shutdown event handling with fallback force exit in main loop
- Replace connection caching with per-operation connections in memory store for thread safety
- Add row-level locking and retry logic to issue_response_manager and task_type_manager
- Rewrite work_queue with WAL mode, busy timeout, and immediate transactions
- Add comprehensive concurrency test suite (16 tests) and benchmarks
- Update pytest markers with clearer descriptions
1 parent f4ccf7a commit ff056e6

File tree

10 files changed

+1664
-302
lines changed

10 files changed

+1664
-302
lines changed

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,9 @@ testpaths = ["tests"]
161161
python_files = ["test_*.py", "*_test.py"]
162162
addopts = "-v --cov=sugar --cov-branch --cov-report=term-missing --cov-report=xml"
163163
markers = [
164-
"unit: Unit tests",
165-
"integration: Integration tests",
166-
"slow: Slow running tests"
164+
"unit: Unit tests (no I/O, no database)",
165+
"integration: Integration tests (real database, aiosqlite)",
166+
"slow: Slow running tests (throughput, load)"
167167
]
168168

169169
# MCP Registry identification

sugar/agent/subagent_manager.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,10 +316,17 @@ async def cancel_all(self) -> None:
316316
317317
Note: This will attempt graceful shutdown but may not
318318
interrupt tasks that are already executing.
319+
320+
Takes a snapshot of the dict before iterating to avoid
321+
RuntimeError if spawn()'s finally block mutates the dict
322+
concurrently (e.g., during shutdown while tasks are in-flight).
319323
"""
320324
logger.warning(f"Cancelling {len(self._active_subagents)} active sub-agents")
321325

322-
for task_id, subagent in self._active_subagents.items():
326+
# Snapshot to avoid RuntimeError: dictionary changed size during iteration
327+
active_snapshot = dict(self._active_subagents)
328+
329+
for task_id, subagent in active_snapshot.items():
323330
try:
324331
await subagent.end_session()
325332
logger.debug(f"Cancelled sub-agent task: {task_id}")

sugar/main.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,10 @@ def signal_handler(signum, frame):
210210
shutdown_event.set()
211211
logger.info("🔔 Shutdown event triggered")
212212
else:
213-
logger.warning("⚠️ Shutdown event not available")
213+
# Fallback: if shutdown_event isn't ready yet (shouldn't happen
214+
# now that we create it before registering handlers), exit cleanly.
215+
logger.warning("⚠️ Shutdown event not available, forcing exit")
216+
sys.exit(128 + signum)
214217

215218

216219
@click.group(invoke_without_command=True)
@@ -2035,6 +2038,11 @@ def run(ctx, dry_run, once, validate):
20352038
asyncio.run(validate_config(sugar_loop))
20362039
return
20372040

2041+
# Create shutdown event BEFORE registering signal handlers to avoid
2042+
# a window where a signal arrives but the event doesn't exist yet.
2043+
global shutdown_event
2044+
shutdown_event = asyncio.Event()
2045+
20382046
# Set up signal handlers for graceful shutdown
20392047
signal.signal(signal.SIGINT, signal_handler)
20402048
signal.signal(signal.SIGTERM, signal_handler)
@@ -2117,9 +2125,10 @@ async def run_once(sugar_loop):
21172125

21182126

21192127
async def run_continuous(sugar_loop):
2120-
"""Run Sugar continuously"""
2128+
"""Run Sugar continuously (shutdown_event created before signal handlers)"""
21212129
global shutdown_event
2122-
shutdown_event = asyncio.Event()
2130+
if shutdown_event is None:
2131+
shutdown_event = asyncio.Event()
21232132

21242133
# Create PID file for stop command
21252134
import os

sugar/memory/store.py

Lines changed: 86 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import json
88
import logging
99
import sqlite3
10+
import threading
1011
import uuid
1112
from datetime import datetime, timezone
1213
from pathlib import Path
@@ -59,6 +60,7 @@ def __init__(
5960
self.embedder = embedder or create_embedder()
6061
self._has_vec = self._check_sqlite_vec()
6162
self._conn: Optional[sqlite3.Connection] = None
63+
self._lock = threading.Lock()
6264

6365
self._init_db()
6466

@@ -73,9 +75,16 @@ def _check_sqlite_vec(self) -> bool:
7375
return False
7476

7577
def _get_connection(self) -> sqlite3.Connection:
76-
"""Get or create database connection."""
78+
"""Get or create database connection.
79+
80+
Uses check_same_thread=False to allow safe cross-thread use when
81+
called from asyncio's run_in_executor. Thread safety is ensured
82+
by self._lock around all public methods that access the connection.
83+
"""
7784
if self._conn is None:
78-
self._conn = sqlite3.connect(str(self.db_path))
85+
self._conn = sqlite3.connect(
86+
str(self.db_path), check_same_thread=False
87+
)
7988
self._conn.row_factory = sqlite3.Row
8089

8190
if self._has_vec:
@@ -209,96 +218,100 @@ def store(self, entry: MemoryEntry) -> str:
209218
if entry.created_at is None:
210219
entry.created_at = datetime.now(timezone.utc)
211220

212-
conn = self._get_connection()
213-
cursor = conn.cursor()
221+
with self._lock:
222+
conn = self._get_connection()
223+
cursor = conn.cursor()
214224

215-
# Store main entry
216-
cursor.execute(
217-
"""
218-
INSERT OR REPLACE INTO memory_entries
219-
(id, memory_type, source_id, content, summary, metadata,
220-
importance, created_at, last_accessed_at, access_count, expires_at)
221-
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
222-
""",
223-
(
224-
entry.id,
225+
# Store main entry
226+
cursor.execute(
227+
"""
228+
INSERT OR REPLACE INTO memory_entries
229+
(id, memory_type, source_id, content, summary, metadata,
230+
importance, created_at, last_accessed_at, access_count, expires_at)
231+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
232+
""",
225233
(
226-
entry.memory_type.value
227-
if isinstance(entry.memory_type, MemoryType)
228-
else entry.memory_type
234+
entry.id,
235+
(
236+
entry.memory_type.value
237+
if isinstance(entry.memory_type, MemoryType)
238+
else entry.memory_type
239+
),
240+
entry.source_id,
241+
entry.content,
242+
entry.summary,
243+
json.dumps(entry.metadata) if entry.metadata else None,
244+
entry.importance,
245+
entry.created_at.isoformat() if entry.created_at else None,
246+
entry.last_accessed_at.isoformat() if entry.last_accessed_at else None,
247+
entry.access_count,
248+
entry.expires_at.isoformat() if entry.expires_at else None,
229249
),
230-
entry.source_id,
231-
entry.content,
232-
entry.summary,
233-
json.dumps(entry.metadata) if entry.metadata else None,
234-
entry.importance,
235-
entry.created_at.isoformat() if entry.created_at else None,
236-
entry.last_accessed_at.isoformat() if entry.last_accessed_at else None,
237-
entry.access_count,
238-
entry.expires_at.isoformat() if entry.expires_at else None,
239-
),
240-
)
250+
)
241251

242-
# Generate and store embedding if we have semantic search
243-
if self._has_vec and not isinstance(self.embedder, FallbackEmbedder):
244-
try:
245-
embedding = self.embedder.embed(entry.content)
246-
if embedding:
247-
cursor.execute(
248-
"""
249-
INSERT OR REPLACE INTO memory_vectors (id, embedding)
250-
VALUES (?, ?)
251-
""",
252-
(entry.id, _serialize_embedding(embedding)),
253-
)
254-
except Exception as e:
255-
logger.warning(f"Failed to store embedding: {e}")
252+
# Generate and store embedding if we have semantic search
253+
if self._has_vec and not isinstance(self.embedder, FallbackEmbedder):
254+
try:
255+
embedding = self.embedder.embed(entry.content)
256+
if embedding:
257+
cursor.execute(
258+
"""
259+
INSERT OR REPLACE INTO memory_vectors (id, embedding)
260+
VALUES (?, ?)
261+
""",
262+
(entry.id, _serialize_embedding(embedding)),
263+
)
264+
except Exception as e:
265+
logger.warning(f"Failed to store embedding: {e}")
256266

257-
conn.commit()
258-
return entry.id
267+
conn.commit()
268+
return entry.id
259269

260270
def get(self, entry_id: str) -> Optional[MemoryEntry]:
261271
"""Get a memory entry by ID."""
262-
conn = self._get_connection()
263-
cursor = conn.cursor()
272+
with self._lock:
273+
conn = self._get_connection()
274+
cursor = conn.cursor()
264275

265-
cursor.execute(
266-
"""
267-
SELECT * FROM memory_entries WHERE id = ?
268-
""",
269-
(entry_id,),
270-
)
276+
cursor.execute(
277+
"""
278+
SELECT * FROM memory_entries WHERE id = ?
279+
""",
280+
(entry_id,),
281+
)
271282

272-
row = cursor.fetchone()
273-
if row:
274-
return self._row_to_entry(row)
275-
return None
283+
row = cursor.fetchone()
284+
if row:
285+
return self._row_to_entry(row)
286+
return None
276287

277288
def delete(self, entry_id: str) -> bool:
278289
"""Delete a memory entry."""
279-
conn = self._get_connection()
280-
cursor = conn.cursor()
290+
with self._lock:
291+
conn = self._get_connection()
292+
cursor = conn.cursor()
281293

282-
cursor.execute("DELETE FROM memory_entries WHERE id = ?", (entry_id,))
294+
cursor.execute("DELETE FROM memory_entries WHERE id = ?", (entry_id,))
283295

284-
if self._has_vec:
285-
try:
286-
cursor.execute("DELETE FROM memory_vectors WHERE id = ?", (entry_id,))
287-
except Exception:
288-
pass
296+
if self._has_vec:
297+
try:
298+
cursor.execute("DELETE FROM memory_vectors WHERE id = ?", (entry_id,))
299+
except Exception:
300+
pass
289301

290-
conn.commit()
291-
return cursor.rowcount > 0
302+
conn.commit()
303+
return cursor.rowcount > 0
292304

293305
def search(self, query: MemoryQuery) -> List[MemorySearchResult]:
294306
"""
295307
Search memories.
296308
297309
Uses vector similarity if available, falls back to FTS5.
298310
"""
299-
if self._has_vec and not isinstance(self.embedder, FallbackEmbedder):
300-
return self._search_semantic(query)
301-
return self._search_keyword(query)
311+
with self._lock:
312+
if self._has_vec and not isinstance(self.embedder, FallbackEmbedder):
313+
return self._search_semantic(query)
314+
return self._search_keyword(query)
302315

303316
def _search_semantic(self, query: MemoryQuery) -> List[MemorySearchResult]:
304317
"""Search using vector similarity."""
@@ -638,6 +651,7 @@ def prune_expired(self) -> int:
638651

639652
def close(self):
640653
"""Close database connection."""
641-
if self._conn:
642-
self._conn.close()
643-
self._conn = None
654+
with self._lock:
655+
if self._conn:
656+
self._conn.close()
657+
self._conn = None

sugar/storage/issue_response_manager.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Issue Response Manager - Track GitHub issue responses
33
"""
44

5+
import asyncio
56
import json
67
import logging
78
import uuid
@@ -19,41 +20,46 @@ class IssueResponseManager:
1920
def __init__(self, db_path: str = ".sugar/sugar.db"):
2021
self.db_path = db_path
2122
self._initialized = False
23+
self._init_lock = asyncio.Lock()
2224

2325
async def initialize(self) -> None:
2426
"""Create table if not exists"""
2527
if self._initialized:
2628
return
2729

28-
async with aiosqlite.connect(self.db_path) as db:
29-
await db.execute(
30+
async with self._init_lock:
31+
if self._initialized:
32+
return
33+
34+
async with aiosqlite.connect(self.db_path) as db:
35+
await db.execute(
36+
"""
37+
CREATE TABLE IF NOT EXISTS issue_responses (
38+
id TEXT PRIMARY KEY,
39+
repo TEXT NOT NULL,
40+
issue_number INTEGER NOT NULL,
41+
response_type TEXT NOT NULL,
42+
work_item_id TEXT,
43+
confidence REAL,
44+
posted_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
45+
response_content TEXT,
46+
labels_applied TEXT,
47+
was_auto_posted BOOLEAN DEFAULT 0,
48+
UNIQUE(repo, issue_number, response_type)
49+
)
3050
"""
31-
CREATE TABLE IF NOT EXISTS issue_responses (
32-
id TEXT PRIMARY KEY,
33-
repo TEXT NOT NULL,
34-
issue_number INTEGER NOT NULL,
35-
response_type TEXT NOT NULL,
36-
work_item_id TEXT,
37-
confidence REAL,
38-
posted_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
39-
response_content TEXT,
40-
labels_applied TEXT,
41-
was_auto_posted BOOLEAN DEFAULT 0,
42-
UNIQUE(repo, issue_number, response_type)
4351
)
44-
"""
45-
)
4652

47-
await db.execute(
53+
await db.execute(
54+
"""
55+
CREATE INDEX IF NOT EXISTS idx_issue_responses_repo_number
56+
ON issue_responses (repo, issue_number)
4857
"""
49-
CREATE INDEX IF NOT EXISTS idx_issue_responses_repo_number
50-
ON issue_responses (repo, issue_number)
51-
"""
52-
)
58+
)
5359

54-
await db.commit()
60+
await db.commit()
5561

56-
self._initialized = True
62+
self._initialized = True
5763
logger.debug(f"Issue response manager initialized: {self.db_path}")
5864

5965
async def has_responded(

0 commit comments

Comments (0)