Skip to content

AgentKilledError #77

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ API

.. autoclass:: torchrunx.LaunchResult
:members:

.. autoclass:: torchrunx.AgentKilledError
3 changes: 2 additions & 1 deletion src/torchrunx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .launcher import Launcher, LaunchResult, launch
from .launcher import AgentKilledError, Launcher, LaunchResult, launch
from .logging_utils import add_filter_to_handler, file_handler, stream_handler

__all__ = [
"AgentKilledError",
"Launcher",
"launch",
"LaunchResult",
Expand Down
262 changes: 135 additions & 127 deletions src/torchrunx/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,130 +25,8 @@
from .utils import AgentStatus, LauncherAgentGroup, LauncherPayload, WorkerException, get_open_port


def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]:
    """Turn the user-facing ``hostnames`` argument into a concrete list.

    :param hostnames: An explicit list of hostnames, or one of the keywords
        ``"auto"`` / ``"slurm"``.
    :return: The result of ``auto_hosts()`` for ``"auto"``, the result of
        ``slurm_hosts()`` for ``"slurm"``, otherwise ``hostnames`` unchanged.
    """
    if hostnames == "slurm":
        resolved = slurm_hosts()
    elif hostnames == "auto":
        resolved = auto_hosts()
    else:
        # Anything that is not a keyword is passed through untouched.
        resolved = hostnames
    return resolved


def resolve_workers_per_host(
    workers_per_host: int | list[int] | Literal["auto", "slurm"],
    num_hosts: int,
) -> list[int]:
    """Expand the ``workers_per_host`` argument into one worker count per host.

    :param workers_per_host: A single count applied to every host, an explicit
        per-host list, or one of the keywords ``"auto"`` / ``"slurm"`` (replaced
        by the result of ``auto_workers()`` / ``slurm_workers()`` respectively).
    :param num_hosts: Number of hosts the result must cover.
    :return: A list with ``num_hosts`` entries.
    :raises ValueError: If a list is supplied whose length differs from
        ``num_hosts``.
    """
    # Replace a keyword with the value reported by the matching helper.
    if workers_per_host == "slurm":
        workers_per_host = slurm_workers()
    elif workers_per_host == "auto":
        workers_per_host = auto_workers()

    # A single integer is broadcast to every host.
    if isinstance(workers_per_host, int):
        return [workers_per_host] * num_hosts

    if len(workers_per_host) != num_hosts:
        msg = "len(workers_per_host) != len(hostnames)"
        raise ValueError(msg)

    return workers_per_host


def build_logging_server(
    log_handlers: list[Handler] | Literal["auto"] | None,
    launcher_hostname: str,
    hostnames: list[str],
    workers_per_host: list[int],
    log_dir: str | os.PathLike,
    log_level: int,
) -> LogRecordSocketReceiver:
    """Create the socket receiver that collects log records on the launcher.

    :param log_handlers: Handlers to attach to the receiver. ``None`` means no
        handlers; ``"auto"`` means the handlers returned by ``default_handlers``.
    :param launcher_hostname: Host the receiver is created for.
    :param hostnames: Forwarded to ``default_handlers`` when ``"auto"`` is used.
    :param workers_per_host: Forwarded to ``default_handlers`` when ``"auto"`` is used.
    :param log_dir: Forwarded to ``default_handlers`` when ``"auto"`` is used.
    :param log_level: Forwarded to ``default_handlers`` when ``"auto"`` is used.
    :return: A ``LogRecordSocketReceiver`` on a port chosen by ``get_open_port()``.
    """
    if log_handlers == "auto":
        handlers = default_handlers(
            hostnames=hostnames,
            workers_per_host=workers_per_host,
            log_dir=log_dir,
            log_level=log_level,
        )
    elif log_handlers is None:
        handlers = []
    else:
        handlers = log_handlers

    receiver_port = get_open_port()
    return LogRecordSocketReceiver(
        host=launcher_hostname,
        port=receiver_port,
        handlers=handlers,
    )


def build_launch_command(
    launcher_hostname: str,
    launcher_port: int,
    logger_port: int,
    world_size: int,
    rank: int,
    env_vars: tuple[str, ...],
    env_file: str | os.PathLike | None,
) -> str:
    """Assemble the shell command line that starts one torchrunx agent.

    The returned string chains the following steps with ``&&``: a ``cd`` into
    the current working directory of this process, an ``export`` of every
    environment variable whose name matches one of ``env_vars`` (omitted when
    nothing matches), a ``source`` of ``env_file`` (omitted when it is ``None``),
    and finally the ``python -u -m torchrunx`` invocation.

    :param launcher_hostname: Value passed as ``--launcher-hostname``.
    :param launcher_port: Value passed as ``--launcher-port``.
    :param logger_port: Value passed as ``--logger-port``.
    :param world_size: Value passed as ``--world-size``.
    :param rank: Value passed as ``--rank``.
    :param env_vars: ``fnmatch`` patterns selecting variables of ``os.environ`` to export.
    :param env_file: Optional file to ``source`` before starting the agent.
    :return: The complete command as a single string.
    """
    # Every interpolated string goes through shlex.quote so the result can be
    # handed to a shell (this is what justifies the S602 suppression in
    # execute_command).
    steps = ["cd " + shlex.quote(str(Path.cwd()))]

    exported = [
        shlex.quote(f"{name}={value}")
        for name, value in os.environ.items()
        if any(fnmatch.fnmatch(name, pattern) for pattern in env_vars)
    ]
    if exported:
        steps.append("export " + " ".join(exported))

    if env_file is not None:
        steps.append("source " + shlex.quote(str(env_file)))

    interpreter = shlex.quote(sys.executable)
    quoted_host = shlex.quote(launcher_hostname)
    agent_invocation = " ".join(
        [
            f"{interpreter} -u -m torchrunx",
            f"--launcher-hostname {quoted_host}",
            f"--launcher-port {launcher_port}",
            f"--logger-port {logger_port}",
            f"--world-size {world_size}",
            f"--rank {rank}",
        ],
    )
    steps.append(agent_invocation)

    return " && ".join(steps)


def execute_command(
    command: str,
    hostname: str,
    ssh_config_file: str | os.PathLike | None = None,
) -> None:
    """Start ``command`` on ``hostname`` without waiting for it to finish.

    If ``hostname`` refers to this machine the command is spawned through a
    local shell; otherwise it is sent over a ``fabric`` SSH connection and put
    in the background on the remote side.

    :param command: Shell command line to run.
    :param hostname: Hostname or IP address of the target machine.
    :param ssh_config_file: Optional SSH config path handed to ``fabric.Config``.
    """
    try:
        address = ipaddress.ip_address(hostname)
    except ValueError:
        # Not an IP literal: resolve the name first.
        address = ipaddress.ip_address(socket.gethostbyname(hostname))

    if address.is_loopback:
        runs_locally = True
    else:
        # Treat the target as local when it shares at least one address with
        # the addresses reported for this machine's own hostname.
        target_addrs = {info[4][0] for info in socket.getaddrinfo(str(address), None)}
        local_addrs = {info[4][0] for info in socket.getaddrinfo(socket.gethostname(), None)}
        runs_locally = bool(target_addrs & local_addrs)

    if not runs_locally:
        if isinstance(ssh_config_file, os.PathLike):
            ssh_path = str(ssh_config_file)
        else:
            ssh_path = ssh_config_file

        connection = fabric.Connection(
            host=hostname,
            config=fabric.Config(runtime_ssh_path=ssh_path),
        )
        with connection as conn:
            # Output is discarded and the command is backgrounded remotely.
            conn.run(f"{command} >> /dev/null 2>&1 &", asynchronous=True)
        return

    # S602: subprocess.Popen is called with shell=True (https://docs.python.org/3.9/library/subprocess.html#security-considerations)
    # build_launch_command applies shlex.quote to its arguments to prevent shell injection
    subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)  # noqa: S602
class AgentKilledError(Exception):
    """Raised by the launcher when an agent is killed.

    The launcher raises this in place of the ``RuntimeError`` that occurs when
    an agent dies and syncing agent statuses times out.
    """


@dataclass
Expand Down Expand Up @@ -263,8 +141,11 @@ def run( # noqa: C901, PLR0912
# loop to monitor agent statuses (until failed or done)

while True:
# raises RuntimeError if communication timeout due to death of any agent
agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
try:
agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
except RuntimeError as e:
# occurs if any agent dies and communication times out
raise AgentKilledError from e

# raises specific exception if any agent fails
for s in agent_statuses:
Expand Down Expand Up @@ -334,7 +215,8 @@ def launch(
:param default_env_vars: A list of environmental variables to be copied from the launcher process to workers. Allows for bash pattern matching syntax.
:param extra_env_vars: Additional, user-specified variables to copy.
:param env_file: A file (like ``.env``) with additional environment variables to copy.
:raises RuntimeError: May fail if ``torch.distributed`` not available or communication timeout between nodes
:raises RuntimeError: If ``torch.distributed`` not available
:raises AgentKilledError: If any agent is killed
:raises Exception: Propagates exceptions raised in worker processes
""" # noqa: E501
return Launcher(
Expand Down Expand Up @@ -409,3 +291,129 @@ def value(self, rank: int) -> Any:

msg = f"Rank {rank} larger than world_size"
raise ValueError(msg)


def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]:
    """Turn the user-facing ``hostnames`` argument into a concrete list.

    :param hostnames: An explicit list of hostnames, or one of the keywords
        ``"auto"`` / ``"slurm"``.
    :return: The result of ``auto_hosts()`` for ``"auto"``, the result of
        ``slurm_hosts()`` for ``"slurm"``, otherwise ``hostnames`` unchanged.
    """
    if hostnames == "slurm":
        resolved = slurm_hosts()
    elif hostnames == "auto":
        resolved = auto_hosts()
    else:
        # Anything that is not a keyword is passed through untouched.
        resolved = hostnames
    return resolved


def resolve_workers_per_host(
    workers_per_host: int | list[int] | Literal["auto", "slurm"],
    num_hosts: int,
) -> list[int]:
    """Expand the ``workers_per_host`` argument into one worker count per host.

    :param workers_per_host: A single count applied to every host, an explicit
        per-host list, or one of the keywords ``"auto"`` / ``"slurm"`` (replaced
        by the result of ``auto_workers()`` / ``slurm_workers()`` respectively).
    :param num_hosts: Number of hosts the result must cover.
    :return: A list with ``num_hosts`` entries.
    :raises ValueError: If a list is supplied whose length differs from
        ``num_hosts``.
    """
    # Replace a keyword with the value reported by the matching helper.
    if workers_per_host == "slurm":
        workers_per_host = slurm_workers()
    elif workers_per_host == "auto":
        workers_per_host = auto_workers()

    # A single integer is broadcast to every host.
    if isinstance(workers_per_host, int):
        return [workers_per_host] * num_hosts

    if len(workers_per_host) != num_hosts:
        msg = "len(workers_per_host) != len(hostnames)"
        raise ValueError(msg)

    return workers_per_host


def build_logging_server(
    log_handlers: list[Handler] | Literal["auto"] | None,
    launcher_hostname: str,
    hostnames: list[str],
    workers_per_host: list[int],
    log_dir: str | os.PathLike,
    log_level: int,
) -> LogRecordSocketReceiver:
    """Create the socket receiver that collects log records on the launcher.

    :param log_handlers: Handlers to attach to the receiver. ``None`` means no
        handlers; ``"auto"`` means the handlers returned by ``default_handlers``.
    :param launcher_hostname: Host the receiver is created for.
    :param hostnames: Forwarded to ``default_handlers`` when ``"auto"`` is used.
    :param workers_per_host: Forwarded to ``default_handlers`` when ``"auto"`` is used.
    :param log_dir: Forwarded to ``default_handlers`` when ``"auto"`` is used.
    :param log_level: Forwarded to ``default_handlers`` when ``"auto"`` is used.
    :return: A ``LogRecordSocketReceiver`` on a port chosen by ``get_open_port()``.
    """
    if log_handlers == "auto":
        handlers = default_handlers(
            hostnames=hostnames,
            workers_per_host=workers_per_host,
            log_dir=log_dir,
            log_level=log_level,
        )
    elif log_handlers is None:
        handlers = []
    else:
        handlers = log_handlers

    receiver_port = get_open_port()
    return LogRecordSocketReceiver(
        host=launcher_hostname,
        port=receiver_port,
        handlers=handlers,
    )


def build_launch_command(
    launcher_hostname: str,
    launcher_port: int,
    logger_port: int,
    world_size: int,
    rank: int,
    env_vars: tuple[str, ...],
    env_file: str | os.PathLike | None,
) -> str:
    """Assemble the shell command line that starts one torchrunx agent.

    The returned string chains the following steps with ``&&``: a ``cd`` into
    the current working directory of this process, an ``export`` of every
    environment variable whose name matches one of ``env_vars`` (omitted when
    nothing matches), a ``source`` of ``env_file`` (omitted when it is ``None``),
    and finally the ``python -u -m torchrunx`` invocation.

    :param launcher_hostname: Value passed as ``--launcher-hostname``.
    :param launcher_port: Value passed as ``--launcher-port``.
    :param logger_port: Value passed as ``--logger-port``.
    :param world_size: Value passed as ``--world-size``.
    :param rank: Value passed as ``--rank``.
    :param env_vars: ``fnmatch`` patterns selecting variables of ``os.environ`` to export.
    :param env_file: Optional file to ``source`` before starting the agent.
    :return: The complete command as a single string.
    """
    # Every interpolated string goes through shlex.quote so the result can be
    # handed to a shell (this is what justifies the S602 suppression in
    # execute_command).
    steps = ["cd " + shlex.quote(str(Path.cwd()))]

    exported = [
        shlex.quote(f"{name}={value}")
        for name, value in os.environ.items()
        if any(fnmatch.fnmatch(name, pattern) for pattern in env_vars)
    ]
    if exported:
        steps.append("export " + " ".join(exported))

    if env_file is not None:
        steps.append("source " + shlex.quote(str(env_file)))

    interpreter = shlex.quote(sys.executable)
    quoted_host = shlex.quote(launcher_hostname)
    agent_invocation = " ".join(
        [
            f"{interpreter} -u -m torchrunx",
            f"--launcher-hostname {quoted_host}",
            f"--launcher-port {launcher_port}",
            f"--logger-port {logger_port}",
            f"--world-size {world_size}",
            f"--rank {rank}",
        ],
    )
    steps.append(agent_invocation)

    return " && ".join(steps)


def execute_command(
    command: str,
    hostname: str,
    ssh_config_file: str | os.PathLike | None = None,
) -> None:
    """Start ``command`` on ``hostname`` without waiting for it to finish.

    If ``hostname`` refers to this machine the command is spawned through a
    local shell; otherwise it is sent over a ``fabric`` SSH connection and put
    in the background on the remote side.

    :param command: Shell command line to run.
    :param hostname: Hostname or IP address of the target machine.
    :param ssh_config_file: Optional SSH config path handed to ``fabric.Config``.
    """
    try:
        address = ipaddress.ip_address(hostname)
    except ValueError:
        # Not an IP literal: resolve the name first.
        address = ipaddress.ip_address(socket.gethostbyname(hostname))

    if address.is_loopback:
        runs_locally = True
    else:
        # Treat the target as local when it shares at least one address with
        # the addresses reported for this machine's own hostname.
        target_addrs = {info[4][0] for info in socket.getaddrinfo(str(address), None)}
        local_addrs = {info[4][0] for info in socket.getaddrinfo(socket.gethostname(), None)}
        runs_locally = bool(target_addrs & local_addrs)

    if not runs_locally:
        if isinstance(ssh_config_file, os.PathLike):
            ssh_path = str(ssh_config_file)
        else:
            ssh_path = ssh_config_file

        connection = fabric.Connection(
            host=hostname,
            config=fabric.Config(runtime_ssh_path=ssh_path),
        )
        with connection as conn:
            # Output is discarded and the command is backgrounded remotely.
            conn.run(f"{command} >> /dev/null 2>&1 &", asynchronous=True)
        return

    # S602: subprocess.Popen is called with shell=True (https://docs.python.org/3.9/library/subprocess.html#security-considerations)
    # build_launch_command applies shlex.quote to its arguments to prevent shell injection
    subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)  # noqa: S602