Accelerate

Here’s an example script that uses torchrunx with Accelerate to fine-tune any causal language model (from transformers) on any text dataset (from datasets) with any number of GPUs or nodes.

https://torchrun.xyz/accelerate_train.py
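
In outline, the script defines a training function, hands it to torchrunx to run on every worker (one per GPU), and collects each worker's return value on the launcher process. A minimal sketch of that pattern (not the full script, which appears below):

import torchrunx

def train() -> str:
    ...  # Accelerate training loop goes here; return a checkpoint path
    return "checkpoint.pt"

launcher = torchrunx.Launcher()    # defaults: hostnames="auto", one worker per GPU
results = launcher.run(train)      # blocks until every worker finishes
checkpoint_path = results.rank(0)  # return value of the rank-0 worker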

python accelerate_train.py --help

usage: accelerate_train.py [-h] [OPTIONS]

╭─ options ──────────────────────────────────────────────────────────────────╮
│ -h, --help                                                                 │
│     show this help message and exit                                        │
│ --batch-size INT                                                           │
│     (required)                                                             │
│ --output-dir PATH                                                          │
│     (required)                                                             │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ launcher options ─────────────────────────────────────────────────────────╮
│ For configuring the function launch environment.                           │
│ ────────────────────────────────────────────────────────────────────────── │
│ --launcher.hostnames {[STR [STR ...]]}|{auto,slurm}                        │
│     Nodes to launch the function on. By default, infer from SLURM, else    │
│     ``["localhost"]``. (default: auto)                                     │
│ --launcher.workers-per-host INT|{[INT [INT ...]]}|{cpu,gpu}                │
│     Number of processes to run per node. By default, number of GPUs per    │
│     host. (default: gpu)                                                   │
│ --launcher.ssh-config-file {None}|STR|PATHLIKE                             │
│     For connecting to nodes. By default, ``"~/.ssh/config"`` or            │
│     ``"/etc/ssh/ssh_config"``. (default: None)                             │
│ --launcher.backend {None,nccl,gloo,mpi,ucc}                                │
│     `Backend                                                               │
│     <https://pytorch.org/docs/stable/distributed.html#torch.distributed.B… │
│             for worker process group. By default, NCCL (GPU backend).      │
│             Use GLOO for CPU backend. ``None`` for no process group.       │
│     (default: nccl)                                                        │
│ --launcher.timeout INT                                                     │
│     Worker process group timeout (seconds). (default: 600)                 │
│ --launcher.copy-env-vars [STR [STR ...]]                                   │
│     Environment variables to copy from the launcher process to workers.    │
│     Supports Unix pattern matching syntax. (default: PATH LD_LIBRARY       │
│     LIBRARY_PATH 'PYTHON*' 'CUDA*' 'TORCH*' 'PYTORCH*' 'NCCL*')            │
│ --launcher.extra-env-vars {None}|{[STR STR [STR STR ...]]}                 │
│     Additional environment variables to load onto workers. (default: None) │
│ --launcher.env-file {None}|STR|PATHLIKE                                    │
│     Path to a ``.env`` file, containing environment variables to load onto │
│     workers. (default: None)                                               │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ model options ────────────────────────────────────────────────────────────╮
│ --model.name STR                                                           │
│     (required)                                                             │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ dataset options ──────────────────────────────────────────────────────────╮
│ --dataset.path STR                                                         │
│     (required)                                                             │
│ --dataset.name {None}|STR                                                  │
│     (default: None)                                                        │
│ --dataset.split {None}|STR                                                 │
│     (default: None)                                                        │
│ --dataset.text-column STR                                                  │
│     (default: text)                                                        │
│ --dataset.num-samples {None}|INT                                           │
│     (default: None)                                                        │
╰────────────────────────────────────────────────────────────────────────────╯
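
These launcher flags correspond to fields of torchrunx.Launcher, so the same configuration can also be written in code. A sketch (the hostnames are placeholders):

import torchrunx

launcher = torchrunx.Launcher(
    hostnames=["node1", "node2"],  # or "auto" / "slurm"
    workers_per_host="gpu",        # one worker process per GPU on each host
    backend="nccl",                # worker process group backend
    timeout=600,                   # process group timeout, in seconds
)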

Training GPT-2 on WikiText in One Line

The following command installs dependencies and runs our script (for example, with GPT-2 on WikiText). For multi-node training (when not using SLURM, which detects hostnames automatically), you should also pass the node hostnames, e.g. --launcher.hostnames node1 node2, as in the second command below.

Prerequisite: uv

uv run --python "3.12" https://torchrun.xyz/accelerate_train.py \
   --batch-size 8 --output-dir output \
   --model.name gpt2 \
   --dataset.path "Salesforce/wikitext" --dataset.name "wikitext-2-v1" --dataset.split "train" --dataset.num-samples 80
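
For example, the same run on two nodes (node1 and node2 are placeholders) adds the hostnames flag:

uv run --python "3.12" https://torchrun.xyz/accelerate_train.py \
   --launcher.hostnames node1 node2 \
   --batch-size 8 --output-dir output \
   --model.name gpt2 \
   --dataset.path "Salesforce/wikitext" --dataset.name "wikitext-2-v1" --dataset.split "train" --dataset.num-samples 80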

Script

from __future__ import annotations

import functools
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated

import torch
import tyro
from accelerate import Accelerator
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

import torchrunx

logging.basicConfig(level=logging.INFO)


@dataclass
class ModelConfig:
    name: str


@dataclass
class DatasetConfig:
    path: str
    name: str | None = None
    split: str | None = None
    text_column: str = "text"
    num_samples: int | None = None


def load_training_data(
    tokenizer_name: str,
    dataset_config: DatasetConfig,
) -> Dataset:
    # Load dataset

    dataset = load_dataset(
        dataset_config.path, name=dataset_config.name, split=dataset_config.split
    )
    if dataset_config.num_samples is not None:
        dataset = dataset.select(range(dataset_config.num_samples))

    # Build tokenizer

    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # to suppress warnings
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenize_fn = functools.partial(
        tokenizer,
        max_length=tokenizer.model_max_length,
        truncation=True,
        padding="max_length",
    )

    # Tokenize dataset

    return dataset.map(
        tokenize_fn,
        batched=True,
        input_columns=[dataset_config.text_column],
        remove_columns=[dataset_config.text_column],
    ).map(lambda x: {"labels": x["input_ids"]})


def train(
    model: PreTrainedModel,
    train_dataset: Dataset,
    batch_size: int,
    output_dir: Path,
) -> Path:
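    # Accelerate picks up the distributed environment (rank, world size, device)
    # that torchrunx sets up for each worker process.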
    accelerator = Accelerator()

    optimizer = torch.optim.Adam(model.parameters())
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)

    model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

    model.train()
    for batch_idx, batch in enumerate(train_dataloader):
        # default collation yields, for each key, a list of per-position tensors of
        # shape [batch_size]; stacking along the last dim gives [batch_size, seq_len]
        device_batch = {k: torch.stack(v, dim=-1).to(accelerator.device) for k, v in batch.items()}
        optimizer.zero_grad()

        loss = model(**device_batch).loss
        print(f"Step {batch_idx}, loss: {loss.item()}", flush=True, end="")
        accelerator.backward(loss)

        optimizer.step()

    accelerator.wait_for_everyone()
    accelerator.save_state(output_dir=output_dir, safe_serialization=False)
    return output_dir / "pytorch_model.bin"


def main(
    launcher: torchrunx.Launcher,
    model_config: Annotated[ModelConfig, tyro.conf.arg(name="model")],
    dataset_config: Annotated[DatasetConfig, tyro.conf.arg(name="dataset")],
    batch_size: int,
    output_dir: Path,
):
    model = AutoModelForCausalLM.from_pretrained(model_config.name)
    train_dataset = load_training_data(
        tokenizer_name=model_config.name, dataset_config=dataset_config
    )

    # Launch training
    results = launcher.run(train, model, train_dataset, batch_size, output_dir)
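    # `results` collects each worker's return value; every rank returns the same
    # checkpoint path here.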

    # Loading trained model from checkpoint
    checkpoint_path = results.rank(0)
    trained_model = AutoModelForCausalLM.from_pretrained(
        model_config.name, state_dict=torch.load(checkpoint_path)
    )


if __name__ == "__main__":
    tyro.cli(main)
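
The trained_model recovered at the end of main is an ordinary transformers model. For instance, continuing inside main (a sketch; the prompt and generation settings are arbitrary):

    tokenizer = AutoTokenizer.from_pretrained(model_config.name)
    inputs = tokenizer("The meaning of life is", return_tensors="pt")
    outputs = trained_model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))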