Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/snapshot_agent/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Snapshot Agent

The snapshot agent is a small process-level GPU residency primitive.

It exposes four commands over a Unix socket:

- `REGISTER(run_id, pid)` records the process that owns a run.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is run_id ?

- `ACQUIRE(run_id)` grants that process the right to touch CUDA.
- `RELEASE(run_id)` checkpoints that process before another run can acquire CUDA.
- `UNREGISTER(run_id)` removes the process registration.

Today every successful `RELEASE` checkpoints the process. This is simple and
conservative, but it is slow because even a single run pays checkpoint cost after
each acquire window.
Empty file added src/snapshot_agent/__init__.py
Empty file.
50 changes: 50 additions & 0 deletions src/snapshot_agent/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import logging
import os
import shlex
import subprocess
import time
from typing import Protocol

logger = logging.getLogger(__name__)


class CheckpointRestorer(Protocol):
def checkpoint(self, pid: int) -> None:
pass

def restore(self, pid: int) -> None:
pass


class CudaCheckpointRestorer:
def __init__(self, cuda_checkpoint_bin: str | None = None, timeout_ms: int | None = None):
self.cuda_checkpoint_bin = cuda_checkpoint_bin or os.getenv("CUDA_CHECKPOINT_BIN", "cuda-checkpoint")
self.timeout_ms = timeout_ms

def checkpoint(self, pid: int) -> None:
start = time.perf_counter()
logger.info("checkpoint pid=%s", pid)
lock_args = ["--action", "lock", "--pid", str(pid)]
if self.timeout_ms is not None:
lock_args.extend(["--timeout", str(self.timeout_ms)])

self.run_cuda_checkpoint(lock_args)
self.run_cuda_checkpoint(["--action", "checkpoint", "--pid", str(pid)])
logger.info("checkpoint pid=%s took %.0f ms", pid, (time.perf_counter() - start) * 1000)

def restore(self, pid: int) -> None:
start = time.perf_counter()
logger.info("restore pid=%s", pid)
self.run_cuda_checkpoint(["--action", "restore", "--pid", str(pid)])
self.run_cuda_checkpoint(["--action", "unlock", "--pid", str(pid)])

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in our internal code we used cuda-checkpoint --toggle to restore and I am not sure if it matters.

logger.info("restore pid=%s took %.0f ms", pid, (time.perf_counter() - start) * 1000)

def run_cuda_checkpoint(self, args: list[str]) -> None:
full_argv = [self.cuda_checkpoint_bin, *args]
result = subprocess.run(full_argv, capture_output=True, check=False, text=True)
if result.returncode != 0:
stderr = result.stderr.strip()
stdout = result.stdout.strip()
detail = stderr or stdout or f"exit code {result.returncode}"
rendered_argv = " ".join(shlex.quote(arg) for arg in full_argv)
raise RuntimeError(f"{rendered_argv} failed: {detail}")
55 changes: 55 additions & 0 deletions src/snapshot_agent/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import asyncio
import json
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any


class SnapshotAgentClient:
def __init__(self, socket_path: str):
self.socket_path = socket_path
self.reader: asyncio.StreamReader | None = None
self.writer: asyncio.StreamWriter | None = None

async def connect(self) -> None:
if self.writer is not None and not self.writer.is_closing():
return
self.reader, self.writer = await asyncio.open_unix_connection(self.socket_path)

async def close(self) -> None:
if self.writer is None:
return
self.writer.close()
await self.writer.wait_closed()
self.reader = None
self.writer = None

async def register(self, run_id: str, pid: int) -> dict[str, Any]:
return await self.request({"command": "REGISTER", "run_id": run_id, "pid": pid})

async def unregister(self, run_id: str) -> dict[str, Any]:
return await self.request({"command": "UNREGISTER", "run_id": run_id})

@asynccontextmanager
async def acquire(self, run_id: str) -> AsyncIterator[None]:
await self.request({"command": "ACQUIRE", "run_id": run_id})
try:
yield
finally:
await self.request({"command": "RELEASE", "run_id": run_id})

async def request(self, payload: dict[str, Any]) -> dict[str, Any]:
await self.connect()
assert self.reader is not None
assert self.writer is not None

self.writer.write(json.dumps(payload).encode("utf-8") + b"\n")
await self.writer.drain()
line = await self.reader.readline()
if not line:
raise RuntimeError("snapshot agent connection closed")

response = json.loads(line.decode("utf-8"))
if not response.get("ok"):
raise RuntimeError(response.get("error", "snapshot agent command failed"))
return response
199 changes: 199 additions & 0 deletions src/snapshot_agent/serve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
import argparse
import asyncio
import json
import logging
import os
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from .checkpoint import CheckpointRestorer, CudaCheckpointRestorer

logger = logging.getLogger(__name__)


@dataclass
class ProcessRegistration:
pid: int
connection_id: int | None
checkpointed: bool = False
failed: bool = False


class SnapshotAgent:
def __init__(self, restorer: CheckpointRestorer):
self.restorer = restorer
self.processes: dict[str, ProcessRegistration] = {}
self.waiting_run_ids: deque[str] = deque()
self.active_run_id: str | None = None
self.condition = asyncio.Condition()

def clear_run(self, run_id: str) -> None:
if run_id in self.waiting_run_ids:
self.waiting_run_ids.remove(run_id)
if self.active_run_id == run_id:
self.active_run_id = None

async def register(self, run_id: str, pid: int, connection_id: int | None = None) -> dict[str, Any]:
async with self.condition:
process = self.processes.get(run_id)

if process is not None:
return {"ok": False, "error": f"run '{run_id}' is already registered"}

self.processes[run_id] = ProcessRegistration(pid=pid, connection_id=connection_id)
self.condition.notify_all()
return {"ok": True}

async def acquire(self, run_id: str) -> dict[str, Any]:
async with self.condition:
process = self.processes.get(run_id)
if process is None:
return {"ok": False, "error": f"run '{run_id}' is not registered"}
if process.failed:
return {"ok": False, "error": f"run '{run_id}' is failed"}
if run_id in self.waiting_run_ids or self.active_run_id == run_id:
return {"ok": False, "error": f"run '{run_id}' already has a pending or active acquire"}

self.waiting_run_ids.append(run_id)
try:
while self.active_run_id is not None or (run_id in self.waiting_run_ids and self.waiting_run_ids[0] != run_id):
await self.condition.wait()
except BaseException:
if run_id in self.waiting_run_ids:
self.waiting_run_ids.remove(run_id)
self.condition.notify_all()
raise

process = self.processes.get(run_id)
if process is None or process.failed or run_id not in self.waiting_run_ids:
self.clear_run(run_id)
self.condition.notify_all()
return {"ok": False, "error": f"run '{run_id}' is not available"}

self.waiting_run_ids.popleft()
self.active_run_id = run_id
if process.checkpointed:
await self.run_restore(process.pid)
process.checkpointed = False

self.condition.notify_all()
return {"ok": True}

async def release(self, run_id: str) -> dict[str, Any]:
async with self.condition:
process = self.processes.get(run_id)
if process is None:
return {"ok": False, "error": f"run '{run_id}' is not registered"}
if self.active_run_id != run_id:
return {"ok": False, "error": f"run '{run_id}' does not hold an active acquire"}

await self.run_checkpoint(process.pid)
process.checkpointed = True
self.clear_run(run_id)
self.condition.notify_all()
return {"ok": True}

async def unregister(self, run_id: str) -> dict[str, Any]:
async with self.condition:
if run_id not in self.processes:
return {"ok": False, "error": f"run '{run_id}' is not registered"}

self.clear_run(run_id)
del self.processes[run_id]
self.condition.notify_all()
return {"ok": True}

async def connection_closed(self, connection_id: int) -> None:
async with self.condition:
for run_id, process in self.processes.items():
if process.connection_id != connection_id:
continue
self.clear_run(run_id)
process.failed = True
process.checkpointed = False
process.connection_id = None
self.condition.notify_all()

async def run_checkpoint(self, pid: int) -> None:
try:
await asyncio.to_thread(self.restorer.checkpoint, pid)
except Exception:
logger.critical("checkpoint failed for pid %s; GPU state is unknown, killing snapshot agent", pid, exc_info=True)
os._exit(1)

async def run_restore(self, pid: int) -> None:
try:
await asyncio.to_thread(self.restorer.restore, pid)
except Exception:
logger.critical("restore failed for pid %s; GPU state is unknown, killing snapshot agent", pid, exc_info=True)
os._exit(1)


async def start_snapshot_agent(agent: SnapshotAgent, socket_path: str) -> asyncio.Server:
socket = Path(socket_path)
socket.parent.mkdir(parents=True, exist_ok=True)
socket.unlink(missing_ok=True)

async def handle_connection(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
connection_id = id(writer)
try:
while line := await reader.readline():
response = await dispatch(agent, line, connection_id)
writer.write(json.dumps(response).encode("utf-8") + b"\n")
await writer.drain()
finally:
await agent.connection_closed(connection_id)
writer.close()
await writer.wait_closed()

return await asyncio.start_unix_server(handle_connection, path=socket_path)


async def dispatch(agent: SnapshotAgent, line: bytes, connection_id: int) -> dict[str, Any]:
payload = json.loads(line.decode("utf-8"))

command = payload.get("command", "").upper()
run_id = payload.get("run_id")

assert run_id is not None, "run_id is required"

match command:
case "REGISTER":
return await agent.register(run_id, payload["pid"], connection_id=connection_id)
case "ACQUIRE":
return await agent.acquire(run_id)
case "RELEASE":
return await agent.release(run_id)
case "UNREGISTER":
return await agent.unregister(run_id)
case _:
return {"ok": False, "error": f"unknown command '{command}'"}


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Run the OpenRL snapshot agent.")
parser.add_argument("--socket", default=os.getenv("OPEN_RL_SNAPSHOT_AGENT_SOCKET", "/tmp/open-rl/snapshot-agent.sock"))
parser.add_argument("--cuda-checkpoint-bin", default=os.getenv("CUDA_CHECKPOINT_BIN", "cuda-checkpoint"))
parser.add_argument("--cuda-checkpoint-timeout-ms", type=int, default=None)
return parser.parse_args()


async def main_async() -> None:
args = parse_args()
restorer = CudaCheckpointRestorer(args.cuda_checkpoint_bin, args.cuda_checkpoint_timeout_ms)
agent = SnapshotAgent(restorer)
server = await start_snapshot_agent(agent, args.socket)
logger.info("listening on %s", args.socket)
async with server:
await server.serve_forever()


def main() -> None:
logging.basicConfig(level=logging.INFO, format="[SNAPSHOT_AGENT] %(levelname)s %(message)s")
asyncio.run(main_async())


if __name__ == "__main__":
main()
Loading
Loading