torchrunx 🔥
By Apoorv Khandelwal and Peter Curtin
The easiest way to run PyTorch on multiple GPUs or machines.
`torchrunx` is a functional utility for distributing PyTorch code across devices. This is a more convenient, robust, and featureful alternative to CLI-based launchers, like `torchrun`, `accelerate launch`, and `deepspeed`.
It enables complex workflows within a single script (see the multi-stage sketch near the end of this page) and offers useful features even if you are only using 1 GPU.
```
pip install torchrunx
```

Requires: Linux. If using multiple machines: SSH & shared filesystem.
Example: simple training loop
Suppose we have some distributed training function (needs to run on every GPU):
```python
def distributed_training(output_dir: str, num_steps: int = 10) -> str | None:
    # returns path to model checkpoint (from rank 0; None from other ranks)
```

Implementation:
```python
from __future__ import annotations

import os

import torch
import torch.nn as nn


def distributed_training(output_dir: str, num_steps: int = 10) -> str | None:
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])

    model = nn.Linear(10, 10)
    model.to(local_rank)
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(ddp_model.parameters())

    for step in range(num_steps):
        optimizer.zero_grad()

        inputs = torch.randn(5, 10).to(local_rank)
        labels = torch.randn(5, 10).to(local_rank)
        outputs = ddp_model(inputs)

        torch.nn.functional.mse_loss(outputs, labels).backward()
        optimizer.step()

    if rank == 0:
        os.makedirs(output_dir, exist_ok=True)
        checkpoint_path = os.path.join(output_dir, "model.pt")
        torch.save(model, checkpoint_path)
        return checkpoint_path

    return None
```
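A note on the environment: the function above never calls `torch.distributed.init_process_group`, because torchrunx initializes the default process group in every worker before invoking your function (which is why `DistributedDataParallel` works here). Each worker also receives `torchrun`-style environment variables; `RANK` and `LOCAL_RANK` are used above, and the sketch below assumes the rest of that convention carries over (verify against the API docs):

```python
import os

# Set per worker by the launcher (torchrun-style convention):
rank = int(os.environ["RANK"])              # global rank across all machines
local_rank = int(os.environ["LOCAL_RANK"])  # rank (and GPU index) within this machine
world_size = int(os.environ["WORLD_SIZE"])  # total number of workers (assumed)
```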
We can distribute and run this function (e.g. on 2 machines x 2 GPUs) using `torchrunx`!
```python
import logging

import torch
import torchrunx

logging.basicConfig(level=logging.INFO)

launcher = torchrunx.Launcher(
    hostnames = ["localhost", "second_machine"],  # or IP addresses
    workers_per_host = "gpu"  # default, or just: 2
)

results = launcher.run(
    distributed_training,
    output_dir = "outputs",
    num_steps = 10,
)
```
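torchrunx also re-raises worker exceptions in the launcher process (per the project's feature list). Assuming that behavior, a failed run can be handled with ordinary Python control flow; a minimal sketch:

```python
try:
    results = launcher.run(distributed_training, output_dir = "outputs")
except Exception:
    # A worker raised; log it, then handle or re-raise as usual.
    logging.exception("distributed_training failed on a worker")
    raise
```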
Once the run completes, you can retrieve the results and process them as you wish.
```python
checkpoint_path: str = results.rank(0)
# or: results.index(hostname="localhost", local_rank=0)

# and continue your script
model = torch.load(checkpoint_path, weights_only=False)
model.eval()
```
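Because launching is just a function call rather than a CLI entry point, you can chain multiple distributed stages in one script: this is the "complex workflows within a single script" mentioned above. A minimal sketch, where `distributed_evaluation` is a hypothetical function you would write in the same style as `distributed_training` (it is not part of torchrunx):

```python
# Hypothetical second stage: reuse the same launcher (and machines) for evaluation.
eval_results = launcher.run(
    distributed_evaluation,              # your own function, analogous to the above
    checkpoint_path = checkpoint_path,   # feed in the first stage's output
)
print(eval_results.rank(0))
```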
See our documentation for more examples, where we fine-tune LLMs with popular training libraries.
Refer to our API, Features, and Usage for many more capabilities!