-
Notifications
You must be signed in to change notification settings - Fork 6
feat: snapshot agent #109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
feat: snapshot agent #109
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| # Snapshot Agent | ||
|
|
||
| The snapshot agent is a small process-level GPU residency primitive. | ||
|
|
||
| It exposes four commands over a Unix socket: | ||
|
|
||
| - `REGISTER(run_id, pid)` records the process that owns a run. | ||
| - `ACQUIRE(run_id)` grants that process the right to touch CUDA. | ||
| - `RELEASE(run_id)` checkpoints that process before another run can acquire CUDA. | ||
| - `UNREGISTER(run_id)` removes the process registration. | ||
|
|
||
| Today every successful `RELEASE` checkpoints the process. This is simple and | ||
| conservative, but it is slow because even a single run pays checkpoint cost after | ||
| each acquire window. | ||
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| import logging | ||
| import os | ||
| import shlex | ||
| import subprocess | ||
| import time | ||
| from typing import Protocol | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class CheckpointRestorer(Protocol): | ||
| def checkpoint(self, pid: int) -> None: | ||
| pass | ||
|
|
||
| def restore(self, pid: int) -> None: | ||
| pass | ||
|
|
||
|
|
||
| class CudaCheckpointRestorer: | ||
| def __init__(self, cuda_checkpoint_bin: str | None = None, timeout_ms: int | None = None): | ||
| self.cuda_checkpoint_bin = cuda_checkpoint_bin or os.getenv("CUDA_CHECKPOINT_BIN", "cuda-checkpoint") | ||
| self.timeout_ms = timeout_ms | ||
|
|
||
| def checkpoint(self, pid: int) -> None: | ||
| start = time.perf_counter() | ||
| logger.info("checkpoint pid=%s", pid) | ||
| lock_args = ["--action", "lock", "--pid", str(pid)] | ||
| if self.timeout_ms is not None: | ||
| lock_args.extend(["--timeout", str(self.timeout_ms)]) | ||
|
|
||
| self.run_cuda_checkpoint(lock_args) | ||
| self.run_cuda_checkpoint(["--action", "checkpoint", "--pid", str(pid)]) | ||
| logger.info("checkpoint pid=%s took %.0f ms", pid, (time.perf_counter() - start) * 1000) | ||
|
|
||
| def restore(self, pid: int) -> None: | ||
| start = time.perf_counter() | ||
| logger.info("restore pid=%s", pid) | ||
| self.run_cuda_checkpoint(["--action", "restore", "--pid", str(pid)]) | ||
| self.run_cuda_checkpoint(["--action", "unlock", "--pid", str(pid)]) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in our internal code we used |
||
| logger.info("restore pid=%s took %.0f ms", pid, (time.perf_counter() - start) * 1000) | ||
|
|
||
| def run_cuda_checkpoint(self, args: list[str]) -> None: | ||
| full_argv = [self.cuda_checkpoint_bin, *args] | ||
| result = subprocess.run(full_argv, capture_output=True, check=False, text=True) | ||
| if result.returncode != 0: | ||
| stderr = result.stderr.strip() | ||
| stdout = result.stdout.strip() | ||
| detail = stderr or stdout or f"exit code {result.returncode}" | ||
| rendered_argv = " ".join(shlex.quote(arg) for arg in full_argv) | ||
| raise RuntimeError(f"{rendered_argv} failed: {detail}") | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| import asyncio | ||
| import json | ||
| from collections.abc import AsyncIterator | ||
| from contextlib import asynccontextmanager | ||
| from typing import Any | ||
|
|
||
|
|
||
| class SnapshotAgentClient: | ||
| def __init__(self, socket_path: str): | ||
| self.socket_path = socket_path | ||
| self.reader: asyncio.StreamReader | None = None | ||
| self.writer: asyncio.StreamWriter | None = None | ||
|
|
||
| async def connect(self) -> None: | ||
| if self.writer is not None and not self.writer.is_closing(): | ||
| return | ||
| self.reader, self.writer = await asyncio.open_unix_connection(self.socket_path) | ||
|
|
||
| async def close(self) -> None: | ||
| if self.writer is None: | ||
| return | ||
| self.writer.close() | ||
| await self.writer.wait_closed() | ||
| self.reader = None | ||
| self.writer = None | ||
|
|
||
| async def register(self, run_id: str, pid: int) -> dict[str, Any]: | ||
| return await self.request({"command": "REGISTER", "run_id": run_id, "pid": pid}) | ||
|
|
||
| async def unregister(self, run_id: str) -> dict[str, Any]: | ||
| return await self.request({"command": "UNREGISTER", "run_id": run_id}) | ||
|
|
||
| @asynccontextmanager | ||
| async def acquire(self, run_id: str) -> AsyncIterator[None]: | ||
| await self.request({"command": "ACQUIRE", "run_id": run_id}) | ||
| try: | ||
| yield | ||
| finally: | ||
| await self.request({"command": "RELEASE", "run_id": run_id}) | ||
|
|
||
| async def request(self, payload: dict[str, Any]) -> dict[str, Any]: | ||
| await self.connect() | ||
| assert self.reader is not None | ||
| assert self.writer is not None | ||
|
|
||
| self.writer.write(json.dumps(payload).encode("utf-8") + b"\n") | ||
| await self.writer.drain() | ||
| line = await self.reader.readline() | ||
| if not line: | ||
| raise RuntimeError("snapshot agent connection closed") | ||
|
|
||
| response = json.loads(line.decode("utf-8")) | ||
| if not response.get("ok"): | ||
| raise RuntimeError(response.get("error", "snapshot agent command failed")) | ||
| return response |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,199 @@ | ||
| import argparse | ||
| import asyncio | ||
| import json | ||
| import logging | ||
| import os | ||
| from collections import deque | ||
| from dataclasses import dataclass | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| from .checkpoint import CheckpointRestorer, CudaCheckpointRestorer | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| @dataclass | ||
| class ProcessRegistration: | ||
| pid: int | ||
| connection_id: int | None | ||
| checkpointed: bool = False | ||
| failed: bool = False | ||
|
|
||
|
|
||
| class SnapshotAgent: | ||
| def __init__(self, restorer: CheckpointRestorer): | ||
| self.restorer = restorer | ||
| self.processes: dict[str, ProcessRegistration] = {} | ||
| self.waiting_run_ids: deque[str] = deque() | ||
| self.active_run_id: str | None = None | ||
| self.condition = asyncio.Condition() | ||
|
|
||
| def clear_run(self, run_id: str) -> None: | ||
| if run_id in self.waiting_run_ids: | ||
| self.waiting_run_ids.remove(run_id) | ||
| if self.active_run_id == run_id: | ||
| self.active_run_id = None | ||
|
|
||
| async def register(self, run_id: str, pid: int, connection_id: int | None = None) -> dict[str, Any]: | ||
| async with self.condition: | ||
| process = self.processes.get(run_id) | ||
|
|
||
| if process is not None: | ||
| return {"ok": False, "error": f"run '{run_id}' is already registered"} | ||
|
|
||
| self.processes[run_id] = ProcessRegistration(pid=pid, connection_id=connection_id) | ||
| self.condition.notify_all() | ||
| return {"ok": True} | ||
|
|
||
| async def acquire(self, run_id: str) -> dict[str, Any]: | ||
| async with self.condition: | ||
| process = self.processes.get(run_id) | ||
| if process is None: | ||
| return {"ok": False, "error": f"run '{run_id}' is not registered"} | ||
| if process.failed: | ||
| return {"ok": False, "error": f"run '{run_id}' is failed"} | ||
| if run_id in self.waiting_run_ids or self.active_run_id == run_id: | ||
| return {"ok": False, "error": f"run '{run_id}' already has a pending or active acquire"} | ||
|
|
||
| self.waiting_run_ids.append(run_id) | ||
| try: | ||
| while self.active_run_id is not None or (run_id in self.waiting_run_ids and self.waiting_run_ids[0] != run_id): | ||
| await self.condition.wait() | ||
| except BaseException: | ||
| if run_id in self.waiting_run_ids: | ||
| self.waiting_run_ids.remove(run_id) | ||
| self.condition.notify_all() | ||
| raise | ||
|
|
||
| process = self.processes.get(run_id) | ||
| if process is None or process.failed or run_id not in self.waiting_run_ids: | ||
| self.clear_run(run_id) | ||
| self.condition.notify_all() | ||
| return {"ok": False, "error": f"run '{run_id}' is not available"} | ||
|
|
||
| self.waiting_run_ids.popleft() | ||
| self.active_run_id = run_id | ||
| if process.checkpointed: | ||
| await self.run_restore(process.pid) | ||
| process.checkpointed = False | ||
|
|
||
| self.condition.notify_all() | ||
| return {"ok": True} | ||
|
|
||
| async def release(self, run_id: str) -> dict[str, Any]: | ||
| async with self.condition: | ||
| process = self.processes.get(run_id) | ||
| if process is None: | ||
| return {"ok": False, "error": f"run '{run_id}' is not registered"} | ||
| if self.active_run_id != run_id: | ||
| return {"ok": False, "error": f"run '{run_id}' does not hold an active acquire"} | ||
|
|
||
| await self.run_checkpoint(process.pid) | ||
| process.checkpointed = True | ||
| self.clear_run(run_id) | ||
| self.condition.notify_all() | ||
| return {"ok": True} | ||
|
|
||
| async def unregister(self, run_id: str) -> dict[str, Any]: | ||
| async with self.condition: | ||
| if run_id not in self.processes: | ||
| return {"ok": False, "error": f"run '{run_id}' is not registered"} | ||
|
|
||
| self.clear_run(run_id) | ||
| del self.processes[run_id] | ||
| self.condition.notify_all() | ||
| return {"ok": True} | ||
|
|
||
| async def connection_closed(self, connection_id: int) -> None: | ||
| async with self.condition: | ||
| for run_id, process in self.processes.items(): | ||
| if process.connection_id != connection_id: | ||
| continue | ||
| self.clear_run(run_id) | ||
| process.failed = True | ||
| process.checkpointed = False | ||
| process.connection_id = None | ||
| self.condition.notify_all() | ||
|
|
||
| async def run_checkpoint(self, pid: int) -> None: | ||
| try: | ||
| await asyncio.to_thread(self.restorer.checkpoint, pid) | ||
| except Exception: | ||
| logger.critical("checkpoint failed for pid %s; GPU state is unknown, killing snapshot agent", pid, exc_info=True) | ||
| os._exit(1) | ||
|
|
||
| async def run_restore(self, pid: int) -> None: | ||
| try: | ||
| await asyncio.to_thread(self.restorer.restore, pid) | ||
| except Exception: | ||
| logger.critical("restore failed for pid %s; GPU state is unknown, killing snapshot agent", pid, exc_info=True) | ||
| os._exit(1) | ||
|
|
||
|
|
||
| async def start_snapshot_agent(agent: SnapshotAgent, socket_path: str) -> asyncio.Server: | ||
| socket = Path(socket_path) | ||
| socket.parent.mkdir(parents=True, exist_ok=True) | ||
| socket.unlink(missing_ok=True) | ||
|
|
||
| async def handle_connection(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: | ||
| connection_id = id(writer) | ||
| try: | ||
| while line := await reader.readline(): | ||
| response = await dispatch(agent, line, connection_id) | ||
| writer.write(json.dumps(response).encode("utf-8") + b"\n") | ||
| await writer.drain() | ||
| finally: | ||
| await agent.connection_closed(connection_id) | ||
| writer.close() | ||
| await writer.wait_closed() | ||
|
|
||
| return await asyncio.start_unix_server(handle_connection, path=socket_path) | ||
|
|
||
|
|
||
| async def dispatch(agent: SnapshotAgent, line: bytes, connection_id: int) -> dict[str, Any]: | ||
| payload = json.loads(line.decode("utf-8")) | ||
|
|
||
| command = payload.get("command", "").upper() | ||
| run_id = payload.get("run_id") | ||
|
|
||
| assert run_id is not None, "run_id is required" | ||
|
|
||
| match command: | ||
| case "REGISTER": | ||
| return await agent.register(run_id, payload["pid"], connection_id=connection_id) | ||
| case "ACQUIRE": | ||
| return await agent.acquire(run_id) | ||
| case "RELEASE": | ||
| return await agent.release(run_id) | ||
| case "UNREGISTER": | ||
| return await agent.unregister(run_id) | ||
| case _: | ||
| return {"ok": False, "error": f"unknown command '{command}'"} | ||
|
|
||
|
|
||
| def parse_args() -> argparse.Namespace: | ||
| parser = argparse.ArgumentParser(description="Run the OpenRL snapshot agent.") | ||
| parser.add_argument("--socket", default=os.getenv("OPEN_RL_SNAPSHOT_AGENT_SOCKET", "/tmp/open-rl/snapshot-agent.sock")) | ||
| parser.add_argument("--cuda-checkpoint-bin", default=os.getenv("CUDA_CHECKPOINT_BIN", "cuda-checkpoint")) | ||
| parser.add_argument("--cuda-checkpoint-timeout-ms", type=int, default=None) | ||
| return parser.parse_args() | ||
|
|
||
|
|
||
| async def main_async() -> None: | ||
| args = parse_args() | ||
| restorer = CudaCheckpointRestorer(args.cuda_checkpoint_bin, args.cuda_checkpoint_timeout_ms) | ||
| agent = SnapshotAgent(restorer) | ||
| server = await start_snapshot_agent(agent, args.socket) | ||
| logger.info("listening on %s", args.socket) | ||
| async with server: | ||
| await server.serve_forever() | ||
|
|
||
|
|
||
| def main() -> None: | ||
| logging.basicConfig(level=logging.INFO, format="[SNAPSHOT_AGENT] %(levelname)s %(message)s") | ||
| asyncio.run(main_async()) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is run_id ?