diff --git a/src/snapshot_agent/README.md b/src/snapshot_agent/README.md new file mode 100644 index 00000000..0cb68d91 --- /dev/null +++ b/src/snapshot_agent/README.md @@ -0,0 +1,14 @@ +# Snapshot Agent + +The snapshot agent is a small process-level GPU residency primitive. + +It exposes four commands over a Unix socket: + +- `REGISTER(run_id, pid)` records the process that owns a run. +- `ACQUIRE(run_id)` grants that process the right to touch CUDA. +- `RELEASE(run_id)` checkpoints that process before another run can acquire CUDA. +- `UNREGISTER(run_id)` removes the process registration. + +Today every successful `RELEASE` checkpoints the process. This is simple and +conservative, but it is slow because even a single run pays checkpoint cost after +each acquire window. diff --git a/src/snapshot_agent/__init__.py b/src/snapshot_agent/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/snapshot_agent/checkpoint.py b/src/snapshot_agent/checkpoint.py new file mode 100644 index 00000000..52eb6393 --- /dev/null +++ b/src/snapshot_agent/checkpoint.py @@ -0,0 +1,50 @@ +import logging +import os +import shlex +import subprocess +import time +from typing import Protocol + +logger = logging.getLogger(__name__) + + +class CheckpointRestorer(Protocol): + def checkpoint(self, pid: int) -> None: + pass + + def restore(self, pid: int) -> None: + pass + + +class CudaCheckpointRestorer: + def __init__(self, cuda_checkpoint_bin: str | None = None, timeout_ms: int | None = None): + self.cuda_checkpoint_bin = cuda_checkpoint_bin or os.getenv("CUDA_CHECKPOINT_BIN", "cuda-checkpoint") + self.timeout_ms = timeout_ms + + def checkpoint(self, pid: int) -> None: + start = time.perf_counter() + logger.info("checkpoint pid=%s", pid) + lock_args = ["--action", "lock", "--pid", str(pid)] + if self.timeout_ms is not None: + lock_args.extend(["--timeout", str(self.timeout_ms)]) + + self.run_cuda_checkpoint(lock_args) + self.run_cuda_checkpoint(["--action", "checkpoint", 
"--pid", str(pid)]) + logger.info("checkpoint pid=%s took %.0f ms", pid, (time.perf_counter() - start) * 1000) + + def restore(self, pid: int) -> None: + start = time.perf_counter() + logger.info("restore pid=%s", pid) + self.run_cuda_checkpoint(["--action", "restore", "--pid", str(pid)]) + self.run_cuda_checkpoint(["--action", "unlock", "--pid", str(pid)]) + logger.info("restore pid=%s took %.0f ms", pid, (time.perf_counter() - start) * 1000) + + def run_cuda_checkpoint(self, args: list[str]) -> None: + full_argv = [self.cuda_checkpoint_bin, *args] + result = subprocess.run(full_argv, capture_output=True, check=False, text=True) + if result.returncode != 0: + stderr = result.stderr.strip() + stdout = result.stdout.strip() + detail = stderr or stdout or f"exit code {result.returncode}" + rendered_argv = " ".join(shlex.quote(arg) for arg in full_argv) + raise RuntimeError(f"{rendered_argv} failed: {detail}") diff --git a/src/snapshot_agent/client.py b/src/snapshot_agent/client.py new file mode 100644 index 00000000..7b2fb294 --- /dev/null +++ b/src/snapshot_agent/client.py @@ -0,0 +1,55 @@ +import asyncio +import json +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + + +class SnapshotAgentClient: + def __init__(self, socket_path: str): + self.socket_path = socket_path + self.reader: asyncio.StreamReader | None = None + self.writer: asyncio.StreamWriter | None = None + + async def connect(self) -> None: + if self.writer is not None and not self.writer.is_closing(): + return + self.reader, self.writer = await asyncio.open_unix_connection(self.socket_path) + + async def close(self) -> None: + if self.writer is None: + return + self.writer.close() + await self.writer.wait_closed() + self.reader = None + self.writer = None + + async def register(self, run_id: str, pid: int) -> dict[str, Any]: + return await self.request({"command": "REGISTER", "run_id": run_id, "pid": pid}) + + async def unregister(self, 
class SnapshotAgent:
    """Track registered runs and hand out a single active acquire at a time.

    Every public coroutine does its work inside ``async with self.condition``.
    Pending acquires are queued in ``waiting_run_ids`` and granted from the
    left end of that deque once no run is active. Checkpoint and restore calls
    are delegated to ``self.restorer``.

    Command coroutines return ``{"ok": True}`` on success and
    ``{"ok": False, "error": <message>}`` on failure.
    """

    def __init__(self, restorer: CheckpointRestorer):
        """Create an agent with no registrations that delegates checkpoint/restore to *restorer*."""
        self.restorer = restorer
        # run_id -> registration record for each currently registered run.
        self.processes: dict[str, ProcessRegistration] = {}
        # Run ids with a pending acquire, oldest first.
        self.waiting_run_ids: deque[str] = deque()
        # Run id currently holding the acquire, or None when nothing is active.
        self.active_run_id: str | None = None
        # Held by every public coroutine while it touches the fields above; also wakes waiters.
        self.condition = asyncio.Condition()

    def clear_run(self, run_id: str) -> None:
        """Remove *run_id* from the wait queue and from the active slot, wherever it appears.

        Does not take the condition lock itself; the callers in this class
        invoke it from inside ``async with self.condition``.
        """
        if run_id in self.waiting_run_ids:
            self.waiting_run_ids.remove(run_id)
        if self.active_run_id == run_id:
            self.active_run_id = None

    async def register(self, run_id: str, pid: int, connection_id: int | None = None) -> dict[str, Any]:
        """Record *pid* (and the owning *connection_id*, if any) under *run_id*.

        Fails if *run_id* already has a registration, including one that was
        marked failed and not yet unregistered.
        """
        async with self.condition:
            process = self.processes.get(run_id)

            if process is not None:
                return {"ok": False, "error": f"run '{run_id}' is already registered"}

            self.processes[run_id] = ProcessRegistration(pid=pid, connection_id=connection_id)
            self.condition.notify_all()
            return {"ok": True}

    async def acquire(self, run_id: str) -> dict[str, Any]:
        """Queue *run_id* and wait until it becomes the active run.

        Fails immediately if the run is unregistered, marked failed, or already
        queued/active. After waiting, fails with "not available" if the run was
        unregistered, marked failed, or removed from the queue in the meantime.
        If the registration is flagged ``checkpointed``, the restorer's restore
        is awaited before success is returned.
        """
        async with self.condition:
            process = self.processes.get(run_id)
            if process is None:
                return {"ok": False, "error": f"run '{run_id}' is not registered"}
            if process.failed:
                return {"ok": False, "error": f"run '{run_id}' is failed"}
            if run_id in self.waiting_run_ids or self.active_run_id == run_id:
                return {"ok": False, "error": f"run '{run_id}' already has a pending or active acquire"}

            self.waiting_run_ids.append(run_id)
            try:
                # Wait while another run is active, or while we are queued but not at the head.
                # If our queue entry has been removed we stop caring about position, but we
                # still wait until no run is active before falling through to the failure path.
                while self.active_run_id is not None or (run_id in self.waiting_run_ids and self.waiting_run_ids[0] != run_id):
                    await self.condition.wait()
            except BaseException:
                # BaseException so that task cancellation is covered too: drop our queue
                # entry so that runs queued behind us are not blocked by it.
                if run_id in self.waiting_run_ids:
                    self.waiting_run_ids.remove(run_id)
                self.condition.notify_all()
                raise

            # State may have changed while we were waiting, so look the registration up again.
            process = self.processes.get(run_id)
            if process is None or process.failed or run_id not in self.waiting_run_ids:
                self.clear_run(run_id)
                self.condition.notify_all()
                return {"ok": False, "error": f"run '{run_id}' is not available"}

            # The wait loop only exits with us queued when we are at index 0, so popleft removes our entry.
            self.waiting_run_ids.popleft()
            self.active_run_id = run_id
            if process.checkpointed:
                # Awaited with the condition lock held: no other command proceeds until restore finishes.
                await self.run_restore(process.pid)
                process.checkpointed = False

            self.condition.notify_all()
            return {"ok": True}

    async def release(self, run_id: str) -> dict[str, Any]:
        """Checkpoint the active run's process and free the active slot.

        Fails if *run_id* is not registered or is not the active run.
        """
        async with self.condition:
            process = self.processes.get(run_id)
            if process is None:
                return {"ok": False, "error": f"run '{run_id}' is not registered"}
            if self.active_run_id != run_id:
                return {"ok": False, "error": f"run '{run_id}' does not hold an active acquire"}

            # Checkpoint first, with the lock held and the active slot still occupied,
            # so a waiter can only be granted after the checkpoint has finished.
            await self.run_checkpoint(process.pid)
            process.checkpointed = True
            self.clear_run(run_id)
            self.condition.notify_all()
            return {"ok": True}

    async def unregister(self, run_id: str) -> dict[str, Any]:
        """Delete the registration for *run_id*, also dropping it from the queue and the active slot."""
        async with self.condition:
            if run_id not in self.processes:
                return {"ok": False, "error": f"run '{run_id}' is not registered"}

            self.clear_run(run_id)
            del self.processes[run_id]
            self.condition.notify_all()
            return {"ok": True}

    async def connection_closed(self, connection_id: int) -> None:
        """Mark every run registered through *connection_id* as failed.

        Each matching run is removed from the queue and the active slot; its
        registration stays in ``processes`` with ``failed=True``,
        ``checkpointed=False`` and ``connection_id=None``. No checkpoint is
        attempted here.
        """
        async with self.condition:
            for run_id, process in self.processes.items():
                if process.connection_id != connection_id:
                    continue
                self.clear_run(run_id)
                process.failed = True
                process.checkpointed = False
                process.connection_id = None
            self.condition.notify_all()

    async def run_checkpoint(self, pid: int) -> None:
        """Run the restorer's blocking ``checkpoint(pid)`` in a worker thread.

        If it raises any ``Exception``, the error is logged at CRITICAL and the
        whole process is terminated with ``os._exit(1)``; this coroutine does
        not return in that case.
        """
        try:
            await asyncio.to_thread(self.restorer.checkpoint, pid)
        except Exception:
            logger.critical("checkpoint failed for pid %s; GPU state is unknown, killing snapshot agent", pid, exc_info=True)
            os._exit(1)

    async def run_restore(self, pid: int) -> None:
        """Run the restorer's blocking ``restore(pid)`` in a worker thread.

        If it raises any ``Exception``, the error is logged at CRITICAL and the
        whole process is terminated with ``os._exit(1)``; this coroutine does
        not return in that case.
        """
        try:
            await asyncio.to_thread(self.restorer.restore, pid)
        except Exception:
            logger.critical("restore failed for pid %s; GPU state is unknown, killing snapshot agent", pid, exc_info=True)
            os._exit(1)
asyncio.start_unix_server(handle_connection, path=socket_path) + + +async def dispatch(agent: SnapshotAgent, line: bytes, connection_id: int) -> dict[str, Any]: + payload = json.loads(line.decode("utf-8")) + + command = payload.get("command", "").upper() + run_id = payload.get("run_id") + + assert run_id is not None, "run_id is required" + + match command: + case "REGISTER": + return await agent.register(run_id, payload["pid"], connection_id=connection_id) + case "ACQUIRE": + return await agent.acquire(run_id) + case "RELEASE": + return await agent.release(run_id) + case "UNREGISTER": + return await agent.unregister(run_id) + case _: + return {"ok": False, "error": f"unknown command '{command}'"} + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run the OpenRL snapshot agent.") + parser.add_argument("--socket", default=os.getenv("OPEN_RL_SNAPSHOT_AGENT_SOCKET", "/tmp/open-rl/snapshot-agent.sock")) + parser.add_argument("--cuda-checkpoint-bin", default=os.getenv("CUDA_CHECKPOINT_BIN", "cuda-checkpoint")) + parser.add_argument("--cuda-checkpoint-timeout-ms", type=int, default=None) + return parser.parse_args() + + +async def main_async() -> None: + args = parse_args() + restorer = CudaCheckpointRestorer(args.cuda_checkpoint_bin, args.cuda_checkpoint_timeout_ms) + agent = SnapshotAgent(restorer) + server = await start_snapshot_agent(agent, args.socket) + logger.info("listening on %s", args.socket) + async with server: + await server.serve_forever() + + +def main() -> None: + logging.basicConfig(level=logging.INFO, format="[SNAPSHOT_AGENT] %(levelname)s %(message)s") + asyncio.run(main_async()) + + +if __name__ == "__main__": + main() diff --git a/tests/test_snapshot_agent.py b/tests/test_snapshot_agent.py new file mode 100644 index 00000000..8ad1bade --- /dev/null +++ b/tests/test_snapshot_agent.py @@ -0,0 +1,249 @@ +import asyncio +import sys +import tempfile +import threading +import unittest +from pathlib import Path + 
class RecordingRestorer:
    """Restorer test double that only records which operation was requested for which pid."""

    def __init__(self):
        self.calls: list[tuple[str, int]] = []

    def checkpoint(self, pid: int) -> None:
        """Append a ``("checkpoint", pid)`` entry to ``calls``."""
        entry = ("checkpoint", pid)
        self.calls.append(entry)

    def restore(self, pid: int) -> None:
        """Append a ``("restore", pid)`` entry to ``calls``."""
        entry = ("restore", pid)
        self.calls.append(entry)


class BlockingRestorer(RecordingRestorer):
    """Recording restorer whose operations can be made to pause until the test lets them finish.

    When ``block_checkpoint`` / ``block_restore`` is true, the matching call
    sets its ``*_started`` event and then waits up to five seconds on its
    ``finish_*`` event.
    """

    def __init__(self):
        super().__init__()
        self.block_checkpoint = False
        self.block_restore = False
        self.checkpoint_started = threading.Event()
        self.finish_checkpoint = threading.Event()
        self.restore_started = threading.Event()
        self.finish_restore = threading.Event()

    def checkpoint(self, pid: int) -> None:
        """Record the call, then pause on ``finish_checkpoint`` if blocking is enabled."""
        super().checkpoint(pid)
        if not self.block_checkpoint:
            return
        self.checkpoint_started.set()
        self.finish_checkpoint.wait(timeout=5.0)

    def restore(self, pid: int) -> None:
        """Record the call, then pause on ``finish_restore`` if blocking is enabled."""
        super().restore(pid)
        if not self.block_restore:
            return
        self.restore_started.set()
        self.finish_restore.wait(timeout=5.0)
async def test_release_with_no_waiters_checkpoints_process(self) -> None:
    """Releasing with an empty queue still checkpoints the process and frees the active slot."""
    recorder = RecordingRestorer()
    agent = SnapshotAgent(recorder)
    await agent.register("run-a", 101)

    acquired = await agent.acquire("run-a")
    self.assertTrue(acquired["ok"])

    released = await agent.release("run-a")
    self.assertTrue(released["ok"])

    self.assertIsNone(agent.active_run_id)
    registration = agent.processes["run-a"]
    self.assertTrue(registration.checkpointed)
    self.assertFalse(registration.failed)
    self.assertEqual(recorder.calls, [("checkpoint", 101)])
async def test_unregister_waiting_process_prevents_later_grant(self) -> None:
    """A queued run that gets unregistered is refused once the current holder releases."""
    agent = SnapshotAgent(RecordingRestorer())
    for name, pid in (("run-a", 101), ("run-b", 202)):
        await agent.register(name, pid)

    holder_grant = await agent.acquire("run-a")
    self.assertTrue(holder_grant["ok"])

    queued = asyncio.create_task(agent.acquire("run-b"))
    # Let the task run far enough to enqueue itself behind run-a.
    await asyncio.sleep(0.05)
    self.assertFalse(queued.done())

    unregistered = await agent.unregister("run-b")
    self.assertTrue(unregistered["ok"])
    released = await agent.release("run-a")
    self.assertTrue(released["ok"])

    outcome = await asyncio.wait_for(queued, timeout=1.0)
    self.assertFalse(outcome["ok"])
    self.assertIsNone(agent.active_run_id)
async def test_waiters_are_granted_in_fifo_order(self) -> None:
    """Runs queued behind an active holder are granted in the order their acquires arrived."""
    registrations = {"run-a": 101, "run-b": 202, "run-c": 303, "run-d": 404}
    arrival_order = ["run-c", "run-b", "run-d"]

    agent = SnapshotAgent(RecordingRestorer())
    for name, pid in registrations.items():
        await agent.register(name, pid)

    first_grant = await agent.acquire("run-a")
    self.assertTrue(first_grant["ok"])

    granted: list[str] = []

    async def hold_briefly(name: str) -> None:
        await agent.acquire(name)
        granted.append(name)
        await agent.release(name)

    pending = []
    for name in arrival_order:
        pending.append(asyncio.create_task(hold_briefly(name)))
        # Give each task a chance to start and enqueue before the next one is created.
        await asyncio.sleep(0.01)

    handoff = await agent.release("run-a")
    self.assertTrue(handoff["ok"])
    await asyncio.wait_for(asyncio.gather(*pending), timeout=1.0)

    self.assertEqual(granted, ["run-c", "run-b", "run-d"])
async def acquire_once(client: SnapshotAgentClient, run_id: str) -> str:
    """Enter and immediately leave the client's acquire window for *run_id*, then return *run_id*."""
    acquire_window = client.acquire(run_id)
    async with acquire_window:
        pass
    return run_id