Accelerate
Here’s an example script that uses torchrunx with Accelerate to fine-tune any causal language model (from transformers) on any text dataset (from datasets) with any number of GPUs or nodes.
https://torchrun.xyz/accelerate_train.py
python accelerate_train.py --help
usage: accelerate_train.py [-h] [OPTIONS]
╭─ options ──────────────────────────────────────────────────────────────────╮
│ -h, --help │
│ show this help message and exit │
│ --batch-size INT │
│ (required) │
│ --output-dir PATH │
│ (required) │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ launcher options ─────────────────────────────────────────────────────────╮
│ For configuring the function launch environment. │
│ ────────────────────────────────────────────────────────────────────────── │
│ --launcher.hostnames {[STR [STR ...]]}|{auto,slurm} │
│ Nodes to launch the function on. By default, infer from SLURM, else │
│ ``["localhost"]``. (default: auto) │
│ --launcher.workers-per-host INT|{[INT [INT ...]]}|{cpu,gpu} │
│ Number of processes to run per node. By default, number of GPUs per │
│ host. (default: gpu) │
│ --launcher.ssh-config-file {None}|STR|PATHLIKE │
│ For connecting to nodes. By default, ``"~/.ssh/config"`` or │
│ ``"/etc/ssh/ssh_config"``. (default: None) │
│ --launcher.backend {None,nccl,gloo,mpi,ucc} │
│ `Backend │
│ <https://pytorch.org/docs/stable/distributed.html#torch.distributed.B… │
│ for worker process group. By default, NCCL (GPU backend). │
│ Use GLOO for CPU backend. ``None`` for no process group. │
│ (default: nccl) │
│ --launcher.timeout INT │
│ Worker process group timeout (seconds). (default: 600) │
│ --launcher.copy-env-vars [STR [STR ...]] │
│ Environment variables to copy from the launcher process to workers. │
│ Supports Unix pattern matching syntax. (default: PATH LD_LIBRARY │
│ LIBRARY_PATH 'PYTHON*' 'CUDA*' 'TORCH*' 'PYTORCH*' 'NCCL*') │
│ --launcher.extra-env-vars {None}|{[STR STR [STR STR ...]]} │
│ Additional environment variables to load onto workers. (default: None) │
│ --launcher.env-file {None}|STR|PATHLIKE │
│ Path to a ``.env`` file, containing environment variables to load onto │
│ workers. (default: None) │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ model options ────────────────────────────────────────────────────────────╮
│ --model.name STR │
│ (required) │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ dataset options ──────────────────────────────────────────────────────────╮
│ --dataset.path STR │
│ (required) │
│ --dataset.name {None}|STR │
│ (default: None) │
│ --dataset.split {None}|STR │
│ (default: None) │
│ --dataset.text-column STR │
│ (default: text) │
│ --dataset.num-samples {None}|INT │
│ (default: None) │
╰────────────────────────────────────────────────────────────────────────────╯
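These option groups are generated by tyro: the launcher options come from the fields of torchrunx.Launcher, and the model and dataset options from the ModelConfig and DatasetConfig dataclasses in the script below. The launcher configuration can therefore also be written directly in Python. A minimal sketch (the hostnames and values here are illustrative, not defaults):

import torchrunx

launcher = torchrunx.Launcher(
    hostnames=["node1", "node2"],  # or leave the default to infer from SLURM / fall back to localhost
    workers_per_host="gpu",        # one worker process per GPU on each host
    backend="nccl",                # "gloo" for CPU-only training
    timeout=600,
)
# launcher.run(fn, ...) then executes `fn` in every worker process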
Training GPT-2 on WikiText in One Line
The following command installs dependencies and runs our script (here, with GPT-2 on WikiText). For multi-node training (if not using SLURM), you should also pass e.g. --launcher.hostnames node1 node2; a full multi-node variant is shown after the command below.
Pre-requisite: uv
uv run --python "3.12" https://torchrun.xyz/accelerate_train.py \
--batch-size 8 --output-dir output \
--model.name gpt2 \
--dataset.path "Salesforce/wikitext" --dataset.name "wikitext-2-v1" --dataset.split "train" --dataset.num-samples 80
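For example, a multi-node variant of the same command (node1 and node2 are placeholders for your own hostnames):
uv run --python "3.12" https://torchrun.xyz/accelerate_train.py \
    --batch-size 8 --output-dir output \
    --model.name gpt2 \
    --dataset.path "Salesforce/wikitext" --dataset.name "wikitext-2-v1" --dataset.split "train" --dataset.num-samples 80 \
    --launcher.hostnames node1 node2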
Script
from __future__ import annotations

import functools
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated

import torch
import tyro
from accelerate import Accelerator
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

import torchrunx

logging.basicConfig(level=logging.INFO)


@dataclass
class ModelConfig:
    name: str


@dataclass
class DatasetConfig:
    path: str
    name: str | None = None
    split: str | None = None
    text_column: str = "text"
    num_samples: int | None = None
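

# Runs on the launcher process: loads the text dataset and tokenizes it with the
# model's tokenizer, padding/truncating each example to the model's max length.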
def load_training_data(
    tokenizer_name: str,
    dataset_config: DatasetConfig,
) -> Dataset:
    # Load dataset
    dataset = load_dataset(
        dataset_config.path, name=dataset_config.name, split=dataset_config.split
    )

    if dataset_config.num_samples is not None:
        dataset = dataset.select(range(dataset_config.num_samples))

    # Build tokenizer
    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # to suppress warnings
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    tokenize_fn = functools.partial(
        tokenizer,
        max_length=tokenizer.model_max_length,
        truncation=True,
        padding="max_length",
    )

    # Tokenize dataset
    return dataset.map(
        tokenize_fn,
        batched=True,
        input_columns=[dataset_config.text_column],
        remove_columns=[dataset_config.text_column],
    ).map(lambda x: {"labels": x["input_ids"]})
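

# Distributed training function: torchrunx launches this in every worker process.
# `accelerator.prepare` moves the model to the local device, wraps it for
# distributed training, and shards the dataloader across processes.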
def train(
    model: PreTrainedModel,
    train_dataset: Dataset,
    batch_size: int,
    output_dir: Path,
) -> Path:
    accelerator = Accelerator()

    optimizer = torch.optim.Adam(model.parameters())
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)

    model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
    model.train()

    for batch_idx, batch in enumerate(train_dataloader):
        device_batch = {
            k: torch.stack(v, dim=0).to(accelerator.device) for k, v in batch.items()
        }
        optimizer.zero_grad()

        loss = model(**device_batch).loss
        print(f"Step {batch_idx}, loss: {loss.item()}", flush=True)

        accelerator.backward(loss)
        optimizer.step()

    accelerator.wait_for_everyone()
    accelerator.save_state(output_dir=output_dir, safe_serialization=False)
    return output_dir / "pytorch_model.bin"
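

# Entry point (launcher process): tyro parses the CLI into these arguments,
# the model and dataset are prepared locally, `train` is launched on all workers
# via torchrunx, and the trained weights are reloaded from rank 0's checkpoint.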
def main(
    launcher: torchrunx.Launcher,
    model_config: Annotated[ModelConfig, tyro.conf.arg(name="model")],
    dataset_config: Annotated[DatasetConfig, tyro.conf.arg(name="dataset")],
    batch_size: int,
    output_dir: Path,
):
    model = AutoModelForCausalLM.from_pretrained(model_config.name)
    train_dataset = load_training_data(
        tokenizer_name=model_config.name, dataset_config=dataset_config
    )

    # Launch training
    results = launcher.run(train, model, train_dataset, batch_size, output_dir)

    # Load trained model from checkpoint (path returned by the rank-0 worker)
    checkpoint_path = results.rank(0)
    trained_model = AutoModelForCausalLM.from_pretrained(
        model_config.name, state_dict=torch.load(checkpoint_path)
    )


if __name__ == "__main__":
    tyro.cli(main)
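You can also drive the same pipeline from Python instead of the CLI. A minimal sketch, using the definitions above (values mirror the example command; adjust for your setup):

main(
    launcher=torchrunx.Launcher(),  # e.g. torchrunx.Launcher(hostnames=["node1", "node2"]) for multi-node
    model_config=ModelConfig(name="gpt2"),
    dataset_config=DatasetConfig(
        path="Salesforce/wikitext", name="wikitext-2-v1", split="train", num_samples=80
    ),
    batch_size=8,
    output_dir=Path("output"),
)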