PyTorch Lightning

Here’s an example script that uses torchrunx with PyTorch Lightning to fine-tune any causal language model (from transformers) on any text dataset (from datasets) with any number of GPUs or nodes.

https://torchrun.xyz/lightning_train.py

python lightning_train.py --help

Output:
usage: lightning_train.py [-h] [OPTIONS]

╭─ options ──────────────────────────────────────────────────────────────────╮
│ -h, --help                                                                 │
│     show this help message and exit                                        │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ launcher options ─────────────────────────────────────────────────────────╮
│ For configuring the function launch environment.                           │
│ ────────────────────────────────────────────────────────────────────────── │
│ --launcher.hostnames {[STR [STR ...]]}|{auto,slurm}                        │
│     Nodes to launch the function on. By default, infer from SLURM, else    │
│     ``["localhost"]``. (default: auto)                                     │
│ --launcher.workers-per-host INT|{[INT [INT ...]]}|{cpu,gpu}                │
│     Number of processes to run per node. By default, number of GPUs per    │
│     host. (default: gpu)                                                   │
│ --launcher.ssh-config-file {None}|STR|PATHLIKE                             │
│     For connecting to nodes. By default, ``"~/.ssh/config"`` or            │
│     ``"/etc/ssh/ssh_config"``. (default: None)                             │
│ --launcher.backend {None,nccl,gloo,mpi,ucc}                                │
│     `Backend                                                               │
│     <https://pytorch.org/docs/stable/distributed.html#torch.distributed.B… │
│             for worker process group. By default, NCCL (GPU backend).      │
│             Use GLOO for CPU backend. ``None`` for no process group.       │
│     (default: nccl)                                                        │
│ --launcher.timeout INT                                                     │
│     Worker process group timeout (seconds). (default: 600)                 │
│ --launcher.copy-env-vars [STR [STR ...]]                                   │
│     Environment variables to copy from the launcher process to workers.    │
│     Supports Unix pattern matching syntax. (default: PATH LD_LIBRARY       │
│     LIBRARY_PATH 'PYTHON*' 'CUDA*' 'TORCH*' 'PYTORCH*' 'NCCL*')            │
│ --launcher.extra-env-vars {None}|{[STR STR [STR STR ...]]}                 │
│     Additional environment variables to load onto workers. (default: None) │
│ --launcher.env-file {None}|STR|PATHLIKE                                    │
│     Path to a ``.env`` file, containing environment variables to load onto │
│     workers. (default: None)                                               │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ model options ────────────────────────────────────────────────────────────╮
│ --model.name STR                                                           │
│     (required)                                                             │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ dataset options ──────────────────────────────────────────────────────────╮
│ --dataset.path STR                                                         │
│     (required)                                                             │
│ --dataset.name {None}|STR                                                  │
│     (default: None)                                                        │
│ --dataset.split {None}|STR                                                 │
│     (default: None)                                                        │
│ --dataset.text-column STR                                                  │
│     (default: text)                                                        │
│ --dataset.num-samples {None}|INT                                           │
│     (default: None)                                                        │
╰────────────────────────────────────────────────────────────────────────────╯

Training GPT-2 on WikiText in One Line

The following command runs our script end-to-end: installing all dependencies, downloading model and data, training, etc.

Prerequisite: uv (https://docs.astral.sh/uv/)

uv run --python "3.12" https://torchrun.xyz/lightning_train.py \
   --model.name gpt2 \
   --dataset.path "Salesforce/wikitext" --dataset.name "wikitext-2-v1" --dataset.split "train" --dataset.num-samples 80

For multi-node training without SLURM, also pass the node hostnames, e.g. --launcher.hostnames node1 node2. (Under SLURM, hostnames are inferred automatically.)

Script

from __future__ import annotations

import functools
import logging
import os
from dataclasses import dataclass
from typing import Annotated

import lightning as L
import torch
import tyro
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

import torchrunx
from torchrunx.integrations.lightning import TorchrunxClusterEnvironment

# Configure the root logger to show INFO-level (and higher) messages.
logging.basicConfig(level="INFO")


@dataclass
class ModelConfig:
    """Pretrained causal language model to fine-tune."""

    # Hugging Face model name or local path; also used to load the matching tokenizer.
    name: str


@dataclass
class DatasetConfig:
    """Text dataset to load with `datasets.load_dataset`, and how much of it to use."""

    # Dataset path or Hub identifier (first argument to `datasets.load_dataset`).
    path: str
    # Dataset configuration name, e.g. "wikitext-2-v1".
    name: str | None = None
    # Split to load, e.g. "train". Set this for training: without a split,
    # `load_dataset` returns all splits together instead of a single dataset.
    split: str | None = None
    # Column containing the raw text to tokenize.
    text_column: str = "text"
    # If set, only the first `num_samples` rows are used.
    num_samples: int | None = None


# Label value ignored by the Hugging Face causal-LM loss (the default
# `ignore_index` of PyTorch's cross-entropy).
_LABEL_IGNORE_INDEX = -100


def _mask_padding_labels(batch):
    """Return ``{"labels": ...}`` copied from ``input_ids`` with padding ignored.

    Positions where ``attention_mask`` is 0 get ``_LABEL_IGNORE_INDEX`` so they do
    not contribute to the loss. If the batch has no ``attention_mask`` column,
    padding cannot be identified and ``input_ids`` are copied unchanged.
    """
    if "attention_mask" not in batch:
        return {"labels": [list(ids) for ids in batch["input_ids"]]}
    return {
        "labels": [
            [token if keep else _LABEL_IGNORE_INDEX for token, keep in zip(ids, mask)]
            for ids, mask in zip(batch["input_ids"], batch["attention_mask"])
        ]
    }


def _has_training_target(attention_mask) -> bool:
    """Return True if a row has at least two real (non-padding) tokens.

    Causal-LM labels are shifted by one position, so a row needs two real tokens
    to yield any prediction target. A batch consisting only of rows without a
    target would make the mean loss NaN.
    """
    return sum(attention_mask) >= 2


def load_training_data(
    tokenizer_name: str,
    dataset_config: DatasetConfig,
) -> Dataset:
    """Load, tokenize and label a text dataset for causal-LM fine-tuning.

    Every row is truncated/padded to the tokenizer's ``model_max_length``.
    ``labels`` mirror ``input_ids`` except that padded positions are set to -100
    (ignored by the loss), so the model is not trained to predict padding. When
    the tokenizer returns an ``attention_mask``, rows with fewer than two real
    tokens (e.g. empty lines) are dropped because they provide no target.

    Args:
        tokenizer_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
        dataset_config: Dataset path/name/split, text column and optional row limit.

    Returns:
        The dataset with the text column replaced by the tokenizer outputs plus a
        ``labels`` column.
    """
    # Load dataset

    dataset = load_dataset(
        dataset_config.path, name=dataset_config.name, split=dataset_config.split
    )
    if dataset_config.num_samples is not None:
        dataset = dataset.select(range(dataset_config.num_samples))

    # Build tokenizer

    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # to suppress warnings
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        # Some causal-LM tokenizers define no pad token; reuse EOS so padding works.
        tokenizer.pad_token = tokenizer.eos_token
    tokenize_fn = functools.partial(
        tokenizer,
        max_length=tokenizer.model_max_length,
        truncation=True,
        padding="max_length",
    )

    # Tokenize dataset

    tokenized = dataset.map(
        tokenize_fn,
        batched=True,
        input_columns=[dataset_config.text_column],
        remove_columns=[dataset_config.text_column],
    )

    # Build labels, excluding padded positions from the loss.
    labeled = tokenized.map(_mask_padding_labels, batched=True)
    if "attention_mask" in labeled.column_names:
        labeled = labeled.filter(_has_training_target, input_columns=["attention_mask"])
    return labeled


class CausalLMLightningWrapper(L.LightningModule):
    """LightningModule that trains a Hugging Face causal LM on tokenized batches.

    Args:
        model: Model whose forward accepts the batch columns (``input_ids``,
            ``attention_mask``, ``labels``, ...) as keyword arguments and returns an
            output with a ``loss`` attribute.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, model, lr: float = 1e-5):
        super().__init__()
        self.model = model
        self.lr = lr

    @staticmethod
    def _to_batch_major(value):
        """Return ``value`` as a ``(batch_size, seq_len)`` tensor.

        For list-valued features, ``default_collate`` yields one tensor of shape
        ``(batch_size,)`` per sequence position. Stacking along ``dim=1`` restores
        ``(batch_size, seq_len)``; stacking along ``dim=0`` would silently produce
        ``(seq_len, batch_size)``. Tensors (e.g. from a torch-formatted dataset) are
        already batch-major and are returned unchanged.
        """
        if isinstance(value, torch.Tensor):
            return value
        return torch.stack(value, dim=1)

    def training_step(self, batch, *args):  # pyright: ignore
        """Run the model on one batch and return (and log) its loss."""
        device_batch = {k: self._to_batch_major(v).to(self.device) for k, v in batch.items()}
        loss = self.model(**device_batch).loss
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Optimize all parameters with Adam at ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


def train(model: PreTrainedModel, train_dataset: Dataset) -> str:
    """Fine-tune ``model`` with Lightning DDP inside a torchrunx worker.

    Trains for one epoch over ``train_dataset`` (batch size 8), then writes a
    final checkpoint.

    Returns:
        Path of the saved checkpoint: ``final.ckpt`` in the trainer's log directory.
    """
    module = CausalLMLightningWrapper(model)
    loader = torch.utils.data.DataLoader(train_dataset, batch_size=8)

    trainer = L.Trainer(
        accelerator="gpu",
        strategy="ddp",
        max_epochs=1,
        # No automatic checkpoints; only the explicit one saved after fit().
        enable_checkpointing=False,
        # Cluster-environment plugin from torchrunx's Lightning integration.
        plugins=[TorchrunxClusterEnvironment()],
    )
    trainer.fit(model=module, train_dataloaders=loader)

    ckpt_path = f"{trainer.log_dir}/final.ckpt"
    trainer.save_checkpoint(ckpt_path)
    return ckpt_path


def main(
    launcher: torchrunx.Launcher,
    model_config: Annotated[ModelConfig, tyro.conf.arg(name="model")],
    dataset_config: Annotated[DatasetConfig, tyro.conf.arg(name="dataset")],
) -> PreTrainedModel:
    """Fine-tune a causal LM with PyTorch Lightning on workers launched by torchrunx.

    Args:
        launcher: Where and how to start the training workers.
        model_config: Pretrained model (and tokenizer) to fine-tune.
        dataset_config: Text dataset to train on.

    Returns:
        The fine-tuned model, with weights loaded from the rank-0 checkpoint.
    """
    model = AutoModelForCausalLM.from_pretrained(model_config.name)
    train_dataset = load_training_data(
        tokenizer_name=model_config.name, dataset_config=dataset_config
    )

    # Launch training: every worker runs `train` and returns a checkpoint path.
    results = launcher.run(train, model, train_dataset)

    # Load trained weights on the launcher process. Workers train on GPU, so map
    # the tensors to CPU; otherwise loading fails on a launcher without CUDA.
    checkpoint_path = results.rank(0)
    state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]

    # load_state_dict (strict by default) overwrites every parameter, so reuse the
    # local `model` rather than loading the pretrained weights a second time.
    trained_model = CausalLMLightningWrapper(model)
    trained_model.load_state_dict(state_dict)
    return trained_model.model


if __name__ == "__main__":
    # Build the command-line interface from main's signature and run it.
    tyro.cli(main)