torchrunx 🔥

By Apoorv Khandelwal and Peter Curtin

The easiest way to run PyTorch on multiple GPUs or machines.


torchrunx is a functional utility for distributing PyTorch code across devices. This is a more convenient, robust, and featureful alternative to CLI-based launchers, like torchrun, accelerate launch, and deepspeed.

pip install torchrunx

Requires: Linux (+ SSH & shared filesystem if using multiple machines)


Vanilla Example: Training a model on 2 machines with 2 GPUs each

Dummy distributed training function:

import os
import torch
import torch.nn as nn

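def train(model: nn.Module, num_steps: int = 5) -> nn.Module | None:
    # Each worker reads its global rank and its GPU index on this machine
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])

    # Place this worker's replica on its local GPU and wrap it in DDP
    model.to(local_rank)
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(ddp_model.parameters())

    for step in range(num_steps):
        optimizer.zero_grad()

        # Random dummy batch: 5 samples with 10 features each
        inputs = torch.randn(5, 10).to(local_rank)
        labels = torch.randn(5, 10).to(local_rank)
        outputs = ddp_model(inputs)

        torch.nn.functional.mse_loss(outputs, labels).backward()
        optimizer.step()

    # Only rank 0 returns the trained model (moved to CPU); other ranks return None
    if rank == 0:
        return model.cpu()
    return None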

Launching training with torchrunx:

import torchrunx

results = torchrunx.launch(
    func = train,
    func_kwargs = dict(
        model = nn.Linear(10, 10),
        num_steps = 10
    ),
    hostnames = ["localhost", "second_machine"],
    workers_per_host = 2
)

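trained_model: nn.Module = results.rank(0)
# Create the output directory first; torch.save does not create it
os.makedirs("output", exist_ok=True)
torch.save(trained_model.state_dict(), "output/model.pth")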

See examples where we fine-tune LLMs (e.g. GPT-2 on WikiText) using:

Refer to our API and Advanced Usage Guide for many more capabilities!


torchrunx uniquely offers

  1. An automatic launcher that “just works” for everyone 🚀

torchrunx is an SSH-based, pure-Python library that is universally easy to install.
It needs no system-specific dependencies or orchestration to distribute automatically across multiple nodes.

  2. Conventional CLI commands 🖥️

Run familiar commands, like python my_script.py ..., and customize arguments as you wish.

Other launchers replace the python command with a cumbersome one, e.g. torchrun --nproc_per_node=2 --nnodes=2 --node_rank=0 --master_addr=100.43.131.111 --master_port=1234 my_script.py ...
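To make the contrast concrete, the sketch below only builds (it does not run) the torchrun command that each machine would need for a static two-node launch, using the flags from the example above. The helper name `torchrun_commands` and the choice of the first hostname as the master address are illustrative assumptions. The point is that every machine needs its own command line, differing at least in `--node_rank`.

```python
import shlex


def torchrun_commands(hostnames: list[str], nproc_per_node: int, master_port: int,
                      script: str, *script_args: str) -> dict[str, list[str]]:
    """Hypothetical helper: build the torchrun command each node would have to run."""
    master_addr = hostnames[0]  # assumption: the first host acts as the rendezvous master
    commands = {}
    for node_rank, host in enumerate(hostnames):
        commands[host] = [
            "torchrun",
            f"--nproc_per_node={nproc_per_node}",
            f"--nnodes={len(hostnames)}",
            f"--node_rank={node_rank}",  # differs on every machine
            f"--master_addr={master_addr}",
            f"--master_port={master_port}",
            script,
            *script_args,
        ]
    return commands


if __name__ == "__main__":
    # Print the command you would have to type on each (hypothetical) node
    for host, cmd in torchrun_commands(["node0", "node1"], 2, 1234, "my_script.py").items():
        print(f"{host}: {shlex.join(cmd)}")
```

With torchrunx, by contrast, the script is started once with a plain python my_script.py ... invocation.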

  3. Support for more complex workflows in a single script 🎛️

Your workflow may have independent steps that need different parallelizations (e.g. training on 8 GPUs, testing on 1 GPU; comparing throughput on 4, then 8 GPUs; and so forth). CLI-based launchers naively parallelize the entire script for exactly N GPUs. In contrast, our library treats these steps in a modular way and permits different degrees of parallelism within a single script.

We clean up memory leaks as we go, so earlier steps won’t crash or adversely affect later ones.
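The step-by-step pattern can be illustrated with the standard library alone. The analogy below is not how torchrunx is implemented: each step gets its own pool of worker processes with its own size, and the pool is torn down before the next step starts. `run_step`, `train_step`, and `eval_step` are hypothetical names, and the "fork" start method assumes Linux, in line with the requirement stated earlier.

```python
import multiprocessing as mp
from typing import Callable, TypeVar

T = TypeVar("T")


def run_step(func: Callable[[int], T], num_workers: int) -> list[T]:
    """Run func(rank) on num_workers fresh processes and collect results in rank order."""
    ctx = mp.get_context("fork")  # assumption: Linux, where "fork" is available
    # A new pool per step; leaving the 'with' block terminates its worker processes,
    # so nothing from this step lingers into the next one.
    with ctx.Pool(processes=num_workers) as pool:
        return pool.map(func, range(num_workers))


def train_step(rank: int) -> str:
    return f"train on worker {rank}"  # placeholder for real training work


def eval_step(rank: int) -> str:
    return f"eval on worker {rank}"  # placeholder for real evaluation work


if __name__ == "__main__":
    print(run_step(train_step, 4))  # first step: 4 workers
    print(run_step(eval_step, 1))   # second step: a single worker
```

Each call decides its own degree of parallelism, which is the flexibility a launcher that parallelizes the whole script for exactly N GPUs cannot offer.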

  4. Better handling of system failures. No more zombies! 🧟

With torchrun, your “work” is inherently coupled to your main Python process. If the system kills one of your workers (e.g. due to RAM OOM or segmentation faults), there is no way to fail gracefully in Python. Your processes might hang for 10 minutes (the NCCL timeout) or become perpetual zombies.

torchrunx decouples “launcher” and “worker” processes. If the system kills a worker, our launcher immediately raises a WorkerFailure exception, which users can handle as they wish. We always clean up all nodes, so no more zombies!
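The launcher/worker split can be sketched with the standard library. In this simplified sketch (not torchrunx's code), a launcher starts worker processes, polls their exit codes, and raises as soon as one dies with a nonzero code; a `finally` block kills and reaps any survivors. `launch_and_monitor` and `WorkerKilledError` are hypothetical stand-ins, and sending SIGKILL mimics the system killing a worker, as the OOM killer would. The "fork" start method assumes Linux.

```python
from __future__ import annotations

import multiprocessing as mp
import os
import signal
import time


class WorkerKilledError(RuntimeError):
    """Hypothetical stand-in for a launcher-side worker-failure exception."""


def worker(seconds: float) -> None:
    time.sleep(seconds)  # placeholder for real training work


def launch_and_monitor(num_workers: int, kill_rank: int | None = None,
                       work_seconds: float = 1.0, poll_interval: float = 0.05) -> list[int]:
    ctx = mp.get_context("fork")  # assumption: Linux
    procs = [ctx.Process(target=worker, args=(work_seconds,)) for _ in range(num_workers)]
    for p in procs:
        p.start()
    try:
        if kill_rank is not None:
            # Simulate the OS killing one worker (e.g. the OOM killer sends SIGKILL)
            os.kill(procs[kill_rank].pid, signal.SIGKILL)
        while True:
            codes = [p.exitcode for p in procs]  # None while a worker is still running
            for rank, code in enumerate(codes):
                if code is not None and code != 0:
                    # Fail fast instead of waiting on the surviving workers
                    raise WorkerKilledError(f"worker {rank} exited with code {code}")
            if all(code == 0 for code in codes):
                return codes
            time.sleep(poll_interval)
    finally:
        # Always clean up: kill and reap any worker that is still running
        for p in procs:
            if p.is_alive():
                p.kill()
            p.join()
```

For example, `launch_and_monitor(2, kill_rank=1, work_seconds=30.0)` raises `WorkerKilledError` almost immediately instead of waiting 30 seconds, and leaves no worker processes behind.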

  5. Bonus features 🎁

  • Fine-grained, custom handling of logging, environment variables, and exception propagation. We have nice defaults too: no more interleaved logs and irrelevant exceptions!

  • No need to manually call dist.init_process_group.

  • Automatic detection of SLURM environments.

  • Start multi-node training from Python notebooks!
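Setting up a process group by hand means giving every worker the variables that dist.init_process_group reads with its default env:// method (MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK), plus LOCAL_RANK, which the train function above also reads. The sketch below computes them for two hosts with two workers each. `worker_envs` is a hypothetical helper; the host-major rank order is an assumption, and torchrunx's actual assignment may differ. The master address must be reachable from every host, and 29500 is only a placeholder port (it is torchrun's default).

```python
def worker_envs(hostnames: list[str], workers_per_host: int, master_addr: str,
                master_port: int = 29500) -> list[tuple[str, dict[str, str]]]:
    """Hypothetical helper: (host, environment) for every worker of a static launch."""
    world_size = len(hostnames) * workers_per_host
    result = []
    for node_rank, host in enumerate(hostnames):
        for local_rank in range(workers_per_host):
            env = {
                "MASTER_ADDR": master_addr,  # must be reachable from every host
                "MASTER_PORT": str(master_port),
                "WORLD_SIZE": str(world_size),
                # Global rank, assuming host-major ordering
                "RANK": str(node_rank * workers_per_host + local_rank),
                "LOCAL_RANK": str(local_rank),  # GPU index on this host
            }
            result.append((host, env))
    return result


if __name__ == "__main__":
    for host, env in worker_envs(["node0", "node1"], workers_per_host=2, master_addr="node0"):
        print(host, env["RANK"], env["LOCAL_RANK"])
```

Every one of these values has to be correct on every worker before training starts, which is the bookkeeping a launcher can take off your hands.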

On our roadmap: higher-order parallelism, support for debuggers, fuller typing, and more!