General
Multiple functions in one script
Consider multiple stages of training: pre-training, supervised fine-tuning, RLHF, etc.
Normally, this kind of work is delegated to multiple scripts. Why? Each stage is complicated (and prone to memory leaks), and we don't want the stages to interfere with each other. They may even require different degrees of parallelism.
torchrunx solves these problems, even within a single script, by modularizing workloads into isolated, self-cleaning processes.
# 2 nodes x 8 GPUs
train_launcher = torchrunx.Launcher(hostnames=["node1", "node2"], workers_per_host=8)
# 1 GPU
eval_launcher = torchrunx.Launcher(hostnames=["node1"], workers_per_host=1)
# Training & testing
pretrained_model = train_launcher.run(train).rank(0)
pretrained_acc = eval_launcher.run(evaluation, model=pretrained_model).rank(0)
print(f"Pre-trained model accuracy: {pretrained_acc}")
finetuned_model = train_launcher.run(finetuning, model=pretrained_model).rank(0)
finetuned_acc = eval_launcher.run(evaluation, model=finetuned_model).rank(0)
print(f"Fine-tuned model accuracy: {finetuned_acc}")
Exceptions
Exceptions that are raised in workers will be raised by the launcher process. A torchrunx.AgentFailedError or torchrunx.WorkerFailedError will be raised if any agent or worker dies unexpectedly (e.g. if sent a signal from the OS, due to segmentation faults or OOM).
You can catch these errors and handle them as you wish!
for config in configs:  # e.g. hyper-parameter sweep
    try:
        torchrunx.Launcher().run(train, config)
    except torch.cuda.OutOfMemoryError:
        print(f"{config} results in OOM... continuing...")
If you are expecting intermittent failures, you can catch errors and invoke retries:
for retry in range(3):
    try:
        torchrunx.Launcher().run(train, resume_from_checkpoint=True)
    except torchrunx.WorkerFailedError as e:
        print(f"Error occurred: {e}")
        print(f"Retrying ({retry}) ...")
    else:  # if run() is successful
        break
Environment variables
Environment variables in the launcher process that pattern-match the copy_env_vars argument are automatically copied to agents and workers. We set useful defaults for Python and PyTorch. You can replace these defaults, or extend them like so:
torchrunx.Launcher(copy_env_vars=(
    torchrunx.DEFAULT_ENV_VARS_FOR_COPY + ("HF_HOME", "WANDB_*",)
))
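Or, to replace the defaults entirely, pass only the names or patterns you want copied (the selection below is just illustrative):
torchrunx.Launcher(copy_env_vars=("PATH", "PYTHON*", "CUDA_VISIBLE_DEVICES"))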
You can also pass (1) specific environment variables and values via extra_env_vars or (2) a .env-style file via env_file. Our agents source {env_file}.
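For example (the exact type accepted by extra_env_vars is an assumption here; it is shown as a simple name-to-value mapping):
torchrunx.Launcher(
    extra_env_vars={"TOKENIZERS_PARALLELISM": "false"},  # explicit name/value pairs (assumed mapping type)
    env_file=".env",  # sourced by each agent
)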
Finally, we set the following environment variables in each worker: LOCAL_RANK, RANK, LOCAL_WORLD_SIZE, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT.
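These are the same variables that torchrun sets, so the usual PyTorch patterns work unchanged inside a worker function. A generic sketch (not torchrunx-specific API; the dataset is a hypothetical stand-in):
import os

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def make_dataloader_and_device() -> tuple[DataLoader, torch.device]:
    device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")  # GPU for this worker on its node
    dataset = TensorDataset(torch.randn(1024, 128))  # hypothetical data
    sampler = DistributedSampler(  # shard the dataset across all workers
        dataset,
        num_replicas=int(os.environ["WORLD_SIZE"]),
        rank=int(os.environ["RANK"]),
    )
    return DataLoader(dataset, batch_size=32, sampler=sampler), device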