API¶

torchrunx.launch(
func: Callable,
func_args: tuple | None = None,
func_kwargs: dict[str, Any] | None = None,
hostnames: list[str] | Literal['auto', 'slurm'] = 'auto',
workers_per_host: int | list[int] | Literal['auto', 'slurm'] = 'auto',
ssh_config_file: str | os.PathLike | None = None,
backend: Literal['nccl', 'gloo', 'mpi', 'ucc', 'auto'] | None = 'auto',
timeout: int = 600,
default_env_vars: tuple[str, ...] = ('PATH', 'LD_LIBRARY', 'LIBRARY_PATH', 'PYTHON*', 'CUDA*', 'TORCH*', 'PYTORCH*', 'NCCL*'),
extra_env_vars: tuple[str, ...] = (),
env_file: str | os.PathLike | None = None,
handler_factory: Callable[[], list[Handler]] | Literal['auto'] | None = 'auto',
) → LaunchResult[source]¶

Launch a distributed PyTorch function on the specified nodes.

Parameters:
  • func – Function to run on each worker.

  • func_args – Positional arguments for func.

  • func_kwargs – Keyword arguments for func.

  • hostnames – Nodes on which to launch the function. Defaults to nodes inferred from a SLURM environment or localhost.

  • workers_per_host – Number of processes to run per node. Can specify different counts per node with a list.

  • ssh_config_file – Path to an SSH configuration file for connecting to nodes. Defaults to ~/.ssh/config or /etc/ssh/ssh_config.

  • backend – Backend for the worker process group. Defaults to NCCL (GPU) or GLOO (CPU). Set to None to disable process group initialization.

  • timeout – Worker process group timeout (seconds).

  • default_env_vars – Environment variables to copy from the launcher process to workers. Supports bash pattern matching syntax.

  • extra_env_vars – Additional user-specified environment variables to copy.

  • env_file – Path to a file (e.g., .env) with additional environment variables to copy.

  • handler_factory – Function to build logging handlers that process agent and worker logs. Defaults to an automatic basic logging scheme.
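The "bash pattern matching" for default_env_vars and extra_env_vars can be illustrated with Python's fnmatch module. This is a sketch of the matching semantics only, not torchrunx's actual implementation:

```python
import os
from fnmatch import fnmatchcase

# The default patterns from the signature above (bash-style globs).
patterns = ("PATH", "LD_LIBRARY", "LIBRARY_PATH",
            "PYTHON*", "CUDA*", "TORCH*", "PYTORCH*", "NCCL*")

# Launcher environment variables that match a pattern would be copied
# to the worker processes.
copied = {name: value for name, value in os.environ.items()
          if any(fnmatchcase(name, p) for p in patterns)}

assert fnmatchcase("CUDA_VISIBLE_DEVICES", "CUDA*")  # glob pattern matches
assert not fnmatchcase("MY_PATH", "PATH")            # exact pattern, no glob
```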

Raises:
  • RuntimeError – If there are configuration issues.

  • AgentFailedError – If an agent fails, e.g. from an OS signal.

  • WorkerFailedError – If a worker fails, e.g. from a segmentation fault.

  • Exception – Any exception raised in a worker process is propagated.
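A minimal invocation sketch. The function name `train` and its arguments are illustrative, not part of the API; the launch call itself is wrapped in `main()` and assumes torchrunx is installed and `localhost` has two available workers:

```python
# Hypothetical worker function: torchrunx runs it in every worker process.
def train(num_steps: int) -> int:
    # A real function would build a model and use torch.distributed here;
    # this stand-in just returns a value to show result collection.
    return num_steps * 2

def main():
    import torchrunx  # assumes torchrunx is installed

    result = torchrunx.launch(
        func=train,
        func_kwargs={"num_steps": 100},
        hostnames=["localhost"],   # or "auto" / "slurm"
        workers_per_host=2,
    )
    # One return value per worker, e.g. ordered by global rank:
    print(result.by_ranks())
```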

We also provide the torchrunx.Launcher class as a reusable, object-oriented counterpart to torchrunx.launch.

class torchrunx.Launcher(
hostnames: list[str] | Literal['auto', 'slurm'] = 'auto',
workers_per_host: int | list[int] | Literal['auto', 'slurm'] = 'auto',
ssh_config_file: str | os.PathLike | None = None,
backend: Literal['nccl', 'gloo', 'mpi', 'ucc', 'auto'] | None = 'auto',
timeout: int = 600,
default_env_vars: tuple[str, ...] = ('PATH', 'LD_LIBRARY', 'LIBRARY_PATH', 'PYTHON*', 'CUDA*', 'TORCH*', 'PYTORCH*', 'NCCL*'),
extra_env_vars: tuple[str, ...] = (),
env_file: str | os.PathLike | None = None,
)[source]¶

Useful for sequential invocations or for specifying arguments via CLI.

run(
func: Callable,
func_args: tuple | None = None,
func_kwargs: dict[str, Any] | None = None,
handler_factory: Callable[[], list[Handler]] | Literal['auto'] | None = 'auto',
) → LaunchResult[source]¶

Run a function using the torchrunx.Launcher configuration.
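A sketch of reusing one Launcher configuration across invocations (the worker function `evaluate` is illustrative; the body assumes torchrunx is installed and a SLURM environment is available):

```python
def evaluate() -> str:
    # Stand-in worker function (illustrative).
    return "ok"

def run_twice():
    import torchrunx  # assumes torchrunx is installed

    # Configure once (here: infer hosts and worker counts from SLURM) ...
    launcher = torchrunx.Launcher(hostnames="slurm", workers_per_host="slurm")

    # ... then run sequentially with the same configuration.
    first = launcher.run(func=evaluate)
    second = launcher.run(func=evaluate)
    return first, second
```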

Results¶

class torchrunx.LaunchResult(hostnames: list[str], return_values: list[list[Any]])[source]¶

Container for objects returned from workers after successful launches.

by_hostnames() → dict[str, list[Any]][source]¶

All return values from workers, indexed by host and local rank.

by_ranks() → list[Any][source]¶

All return values from workers, indexed by global rank.

index(hostname: str, rank: int) → Any[source]¶

Get return value from worker by host and local rank.

rank(i: int) → Any[source]¶

Get return value from worker by global rank.
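The relationship between these accessors can be sketched in plain Python. This illustrates the indexing semantics described above, assuming global ranks enumerate workers in hostname order and then by local rank; it is not torchrunx's internal code:

```python
# Hypothetical per-host return values for two hosts (node1 has 2 workers,
# node2 has 3), mirroring LaunchResult(hostnames, return_values).
hostnames = ["node1", "node2"]
return_values = [["a0", "a1"], ["b0", "b1", "b2"]]

# by_hostnames(): values keyed by host, then indexed by local rank.
by_hostnames = dict(zip(hostnames, return_values))

# by_ranks(): values flattened in hostname order -> global rank order.
by_ranks = [v for host_values in return_values for v in host_values]

assert by_hostnames["node2"][1] == "b1"  # like index("node2", 1)
assert by_ranks[3] == "b1"               # like rank(3)
```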

Exceptions¶

class torchrunx.AgentFailedError[source]¶

Raised if an agent fails (e.g. due to an OS signal).

class torchrunx.WorkerFailedError[source]¶

Raised if a worker fails (e.g. due to an OS signal or a segmentation fault).
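A sketch of distinguishing these failure modes at the call site (the wrapper function is hypothetical, and its body assumes torchrunx is installed):

```python
def guarded_launch(func):
    import torchrunx  # assumes torchrunx is installed

    try:
        return torchrunx.launch(func=func, hostnames=["localhost"])
    except torchrunx.AgentFailedError:
        # An agent process died, e.g. killed by an OS signal.
        raise
    except torchrunx.WorkerFailedError:
        # A worker process died, e.g. from a segmentation fault.
        raise
    # Note: exceptions raised *inside* func propagate as themselves,
    # not as WorkerFailedError.
```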