Advanced Usage

Multiple functions in one script

We can also launch multiple functions in sequence (e.g. train on many GPUs, then test on one GPU):

import torchrunx as trx

trained_model = trx.launch(
    func=train,
    hostnames=["node1", "node2"],
    workers_per_host=8
).rank(0)

accuracy = trx.launch(
    func=test,
    func_args=(trained_model,),
    hostnames=["localhost"],
    workers_per_host=1
).rank(0)

print(f'Accuracy: {accuracy}')

torchrunx.launch is self-cleaning: all processes are terminated (and the used memory is completely released) before the subsequent invocation.

CLI integration

We can use torchrunx.Launcher to populate arguments from the CLI (e.g. with tyro):

import torchrunx as trx
import tyro

def distributed_function():
    pass

if __name__ == "__main__":
    launcher = tyro.cli(trx.Launcher)
    launcher.run(distributed_function)

python ... --help then results in:

╭─ options ─────────────────────────────────────────────╮
│ -h, --help           show this help message and exit  │
│ --hostnames {[STR [STR ...]]}|{auto,slurm}            │
│                      (default: auto)                  │
│ --workers-per-host INT|{[INT [INT ...]]}|{auto,slurm} │
│                      (default: auto)                  │
│ --ssh-config-file {None}|STR|PATH                     │
│                      (default: None)                  │
│ --backend {None,nccl,gloo,mpi,ucc,auto}               │
│                      (default: auto)                  │
│ --timeout INT        (default: 600)                   │
│ --default-env-vars [STR [STR ...]]                    │
│                      (default: PATH LD_LIBRARY ...)   │
│ --extra-env-vars [STR [STR ...]]                      │
│                      (default: )                      │
│ --env-file {None}|STR|PATH                            │
│                      (default: None)                  │
╰───────────────────────────────────────────────────────╯

SLURM integration

By default (hostnames="auto" and workers_per_host="auto"), both arguments are populated from the current SLURM allocation. If no allocation is detected, we assume 1 machine (localhost) with N workers (the number of GPUs or CPUs). A RuntimeError is raised if hostnames="slurm" or workers_per_host="slurm" but no allocation is detected.
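The hostname resolution order described above can be sketched in plain Python. This is a hypothetical helper, not torchrunx's implementation. It makes several assumptions. An allocation is detected by the presence of SLURM_JOB_ID. The compressed node list is expanded with `scontrol show hostnames`. The local GPU count is passed in as a parameter (e.g. from torch.cuda.device_count()) so the sketch does not import PyTorch. The CPU fallback applies only when there are no GPUs. The SLURM-based computation of workers_per_host is omitted.

```python
import os
import subprocess


def slurm_allocation_detected(env) -> bool:
    # Assumption: a SLURM allocation is signalled by SLURM_JOB_ID being set.
    return "SLURM_JOB_ID" in env


def resolve_hostnames(hostnames="auto", env=None):
    """Hypothetical helper mirroring the documented resolution order."""
    env = os.environ if env is None else env
    if hostnames not in ("auto", "slurm"):
        return list(hostnames)  # explicit list of hosts: use as-is
    if slurm_allocation_detected(env):
        # Expand a compressed node list such as "node[1-2]" into one name per line.
        result = subprocess.run(
            ["scontrol", "show", "hostnames", env["SLURM_JOB_NODELIST"]],
            capture_output=True,
            text=True,
            check=True,
        )
        return result.stdout.split()
    if hostnames == "slurm":
        raise RuntimeError("hostnames='slurm' but no SLURM allocation was detected")
    return ["localhost"]  # "auto" outside SLURM: a single local machine


def fallback_workers_per_host(num_gpus: int) -> int:
    # Assumption: one worker per GPU, or one per CPU core when there are no GPUs.
    return num_gpus if num_gpus > 0 else (os.cpu_count() or 1)
```

Outside a SLURM job, resolve_hostnames("auto") returns ["localhost"], while resolve_hostnames("slurm") raises RuntimeError, matching the behavior described above.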

Propagating exceptions

Exceptions raised in workers are re-raised by the launcher process.

A torchrunx.AgentFailedError or torchrunx.WorkerFailedError will be raised if any agent or worker dies unexpectedly (e.g. if sent a signal from the OS, due to segmentation faults or OOM).
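Both behaviors can be illustrated with a minimal standard-library sketch. It is not how torchrunx is implemented. It assumes a POSIX system (the "fork" start method). WorkerDiedError and run_and_propagate are hypothetical names standing in for torchrunx.WorkerFailedError and the launcher.

The worker sends either its return value or the exception object back through a pipe. If the process dies before sending anything (e.g. it is killed by a signal), the pipe reaches EOF and the launcher raises a failure error instead.

```python
import multiprocessing as mp


class WorkerDiedError(RuntimeError):
    """Hypothetical stand-in for torchrunx.WorkerFailedError."""


def _worker_entry(func, conn):
    # Runs in the child: report either the result or the raised exception.
    try:
        conn.send(("ok", func()))
    except Exception as exc:
        conn.send(("error", exc))  # the exception object is pickled across the pipe
    finally:
        conn.close()


def run_and_propagate(func):
    ctx = mp.get_context("fork")  # assumption: POSIX; fork avoids pickling func
    recv_conn, send_conn = ctx.Pipe(duplex=False)
    proc = ctx.Process(target=_worker_entry, args=(func, send_conn))
    proc.start()
    send_conn.close()  # parent keeps only the read end, so EOF means the child is gone
    try:
        status, payload = recv_conn.recv()
    except EOFError:
        # Child exited without reporting (signal, segfault, OOM kill, ...).
        proc.join()
        raise WorkerDiedError(f"worker exited with code {proc.exitcode}") from None
    finally:
        recv_conn.close()
    proc.join()
    if status == "error":
        raise payload  # re-raise the worker's exception in the launcher
    return payload
```

For example, run_and_propagate(lambda: 1 / 0) raises ZeroDivisionError in the calling process. A worker killed with SIGKILL instead produces WorkerDiedError with exit code -9.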

Environment variables

Environment variables in the launcher process that match the default_env_vars argument are automatically copied to agents and workers. We set useful defaults for Python and PyTorch. Environment variables are pattern-matched with this list using fnmatch.

default_env_vars can be overridden if desired, and the list can be augmented using extra_env_vars. Additional environment variables (and more custom bash logic) can be included via the env_file argument: each agent sources this file.
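The matching rule can be sketched with the standard-library fnmatch module. In this sketch, select_env_vars is a hypothetical helper, not part of torchrunx. The patterns are illustrative only, since the full default list is elided in the help output.

```python
from fnmatch import fnmatch


def select_env_vars(environ, default_patterns, extra_patterns=()):
    """Keep variables whose names match any default or extra pattern."""
    patterns = [*default_patterns, *extra_patterns]
    return {
        name: value
        for name, value in environ.items()
        if any(fnmatch(name, pattern) for pattern in patterns)
    }


# Illustrative environment and patterns (not torchrunx's actual defaults).
example_env = {"PATH": "/usr/bin", "PYTHONPATH": "/src", "HOME": "/root", "MY_FLAG": "1"}
forwarded = select_env_vars(example_env, ["PATH", "PYTHON*"], extra_patterns=["MY_*"])
print(sorted(forwarded))  # ['MY_FLAG', 'PATH', 'PYTHONPATH']
```

HOME is dropped because no pattern matches it. MY_FLAG is forwarded only because of the extra pattern.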

We also set the following environment variables in each worker: LOCAL_RANK, RANK, LOCAL_WORLD_SIZE, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT.
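Inside the distributed function, these variables can be read with os.environ. The sketch below collects them into a small dataclass. WorkerInfo and read_worker_info are hypothetical names, not torchrunx APIs. MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE are the same variables that torch.distributed's "env://" initialization reads.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class WorkerInfo:
    rank: int              # global rank across all hosts
    local_rank: int        # rank within this host (e.g. which local GPU to use)
    world_size: int        # total number of workers
    local_world_size: int  # number of workers on this host
    master_addr: str
    master_port: int


def read_worker_info(environ: Mapping[str, str] = os.environ) -> WorkerInfo:
    """Parse the per-worker variables into typed fields (hypothetical helper)."""
    return WorkerInfo(
        rank=int(environ["RANK"]),
        local_rank=int(environ["LOCAL_RANK"]),
        world_size=int(environ["WORLD_SIZE"]),
        local_world_size=int(environ["LOCAL_WORLD_SIZE"]),
        master_addr=environ["MASTER_ADDR"],
        master_port=int(environ["MASTER_PORT"]),
    )
```

A worker could then, for instance, restrict checkpointing to read_worker_info().rank == 0 and pick its device from local_rank.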

Custom logging

We forward all logs (i.e. from logging and sys.stdout/sys.stderr) from workers and agents to the launcher. By default, the logs from the first agent and its first worker are printed into the launcher’s stdout stream. Logs from all agents and workers are written to files in $TORCHRUNX_LOG_DIR (default: ./torchrunx_logs) and are named by timestamp, hostname, and local_rank.

logging.Handler objects can be supplied via the handler_factory argument for further customization (e.g. mapping specific agents/workers to custom output streams). handler_factory must be a function that returns a list of logging.Handler objects.
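As a standard-library sketch of the expected shape, the factory below returns two handlers. The first sends warnings and errors to stdout, and the second writes every record to one file. It assumes handler_factory is called with no arguments. make_handler_factory is a hypothetical helper, and these handlers do not distinguish between hosts or ranks (that is what the utilities below are for).

```python
import logging
import sys


def make_handler_factory(log_path):
    """Build a zero-argument factory suitable for handler_factory (assumption)."""

    def handler_factory() -> list[logging.Handler]:
        # Only WARNING and above from forwarded records reach stdout.
        console = logging.StreamHandler(sys.stdout)
        console.setLevel(logging.WARNING)
        # Every record, down to DEBUG, is appended to a single file.
        file = logging.FileHandler(log_path)
        file.setLevel(logging.DEBUG)
        return [console, file]

    return handler_factory
```

The result would then be passed as handler_factory=make_handler_factory("all.log").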

We provide some utilities to help:

torchrunx.file_handler(
    hostname: str,
    local_rank: int | None,
    file_path: str | os.PathLike,
    log_level: int = 0,
) -> Handler

Handler builder function for writing logs from specified hostname/rank to a file.

torchrunx.stream_handler(hostname: str, local_rank: int | None, log_level: int = 0) -> Handler

Handler builder function for writing logs from specified hostname/rank to stdout.

torchrunx.add_filter_to_handler(
    handler: Handler,
    hostname: str,
    local_rank: int | None,
    log_level: int = 0,
) -> None

Apply a filter to logging.Handler so only specific worker logs are handled.

Parameters:
  • handler – Handler to be modified.

  • hostname – Name of specified host.

  • local_rank – Rank of specified worker on host (or None for agent itself).

  • log_level – Minimum log level to capture.
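To make these filtering semantics concrete, here is a standard-library sketch of a filter with the same parameters. It is hypothetical, not torchrunx's code. It assumes each forwarded LogRecord carries hostname and local_rank attributes, with local_rank None for the agent. In practice, use the torchrunx utilities above.

```python
import logging

_MISSING = object()  # distinguishes "attribute absent" from local_rank=None


class HostRankFilter(logging.Filter):
    """Pass only records from one host/rank at or above a minimum level."""

    def __init__(self, hostname, local_rank, log_level=logging.NOTSET):
        super().__init__()
        self.hostname = hostname
        self.local_rank = local_rank
        self.log_level = log_level

    def filter(self, record):
        # Assumption: the forwarding layer attaches hostname/local_rank to records.
        return (
            getattr(record, "hostname", _MISSING) == self.hostname
            and getattr(record, "local_rank", _MISSING) == self.local_rank
            and record.levelno >= self.log_level
        )


def add_host_rank_filter(handler, hostname, local_rank, log_level=logging.NOTSET):
    """Hypothetical analogue of add_filter_to_handler."""
    handler.addFilter(HostRankFilter(hostname, local_rank, log_level))
```

With add_host_rank_filter(handler, "node1", 0, logging.INFO), the handler ignores DEBUG records from node1's first worker and everything from other hosts or ranks.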