PyTorch Lightning¶
Here’s an example script that uses torchrunx with PyTorch Lightning to fine-tune any causal language model (from transformers) on any text dataset (from datasets) with any number of GPUs or nodes.
https://torchrun.xyz/lightning_train.py
python lightning_train.py --help
usage: lightning_train.py [-h] [OPTIONS]
╭─ options ──────────────────────────────────────────────────────────────────╮
│ -h, --help │
│ show this help message and exit │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ launcher options ─────────────────────────────────────────────────────────╮
│ For configuring the function launch environment. │
│ ────────────────────────────────────────────────────────────────────────── │
│ --launcher.hostnames {[STR [STR ...]]}|{auto,slurm} │
│ Nodes to launch the function on. By default, infer from SLURM, else │
│ ``["localhost"]``. (default: auto) │
│ --launcher.workers-per-host INT|{[INT [INT ...]]}|{cpu,gpu} │
│ Number of processes to run per node. By default, number of GPUs per │
│ host. (default: gpu) │
│ --launcher.ssh-config-file {None}|STR|PATHLIKE │
│ For connecting to nodes. By default, ``"~/.ssh/config"`` or │
│ ``"/etc/ssh/ssh_config"``. (default: None) │
│ --launcher.backend {None,nccl,gloo,mpi,ucc} │
│ `Backend │
│ <https://pytorch.org/docs/stable/distributed.html#torch.distributed.B… │
│ for worker process group. By default, NCCL (GPU backend). │
│ Use GLOO for CPU backend. ``None`` for no process group. │
│ (default: nccl) │
│ --launcher.timeout INT │
│ Worker process group timeout (seconds). (default: 600) │
│ --launcher.copy-env-vars [STR [STR ...]] │
│ Environment variables to copy from the launcher process to workers. │
│ Supports Unix pattern matching syntax. (default: PATH LD_LIBRARY │
│ LIBRARY_PATH 'PYTHON*' 'CUDA*' 'TORCH*' 'PYTORCH*' 'NCCL*') │
│ --launcher.extra-env-vars {None}|{[STR STR [STR STR ...]]} │
│ Additional environment variables to load onto workers. (default: None) │
│ --launcher.env-file {None}|STR|PATHLIKE │
│ Path to a ``.env`` file, containing environment variables to load onto │
│ workers. (default: None) │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ model options ────────────────────────────────────────────────────────────╮
│ --model.name STR │
│ (required) │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ dataset options ──────────────────────────────────────────────────────────╮
│ --dataset.path STR │
│ (required) │
│ --dataset.name {None}|STR │
│ (default: None) │
│ --dataset.split {None}|STR │
│ (default: None) │
│ --dataset.text-column STR │
│ (default: text) │
│ --dataset.num-samples {None}|INT │
│ (default: None) │
╰────────────────────────────────────────────────────────────────────────────╯
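These --launcher.* flags mirror the fields of torchrunx.Launcher, which the script below passes straight to launcher.run(). As a rough sketch (the hostnames are placeholders and the other values simply restate the defaults listed above), the equivalent programmatic configuration would look something like:

import torchrunx

launcher = torchrunx.Launcher(
    hostnames=["node1", "node2"],  # placeholders; "auto" infers from SLURM, else localhost
    workers_per_host="gpu",        # one worker process per GPU on each host
    backend="nccl",                # worker process group backend ("gloo" for CPU)
    timeout=600,                   # process group timeout, in seconds
)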
Training GPT-2 on WikiText in One Line¶
The following command runs our script end-to-end: installing all dependencies, downloading the model and data, training, and so on.
Prerequisite: uv
uv run --python "3.12" https://torchrun.xyz/lightning_train.py \
--model.name gpt2 \
--dataset.path "Salesforce/wikitext" --dataset.name "wikitext-2-v1" --dataset.split "train" --dataset.num-samples 80
For multi-node training (when not running under SLURM), you should also pass the node hostnames, e.g. --launcher.hostnames node1 node2.
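For instance (node1 and node2 are placeholder hostnames; each node needs SSH access and a GPU), the two-node variant of the command above would look roughly like:

uv run --python "3.12" https://torchrun.xyz/lightning_train.py \
    --launcher.hostnames node1 node2 \
    --model.name gpt2 \
    --dataset.path "Salesforce/wikitext" --dataset.name "wikitext-2-v1" --dataset.split "train" --dataset.num-samples 80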
Script¶
from __future__ import annotations
import functools
import logging
import os
from dataclasses import dataclass
from typing import Annotated
import lightning as L
import torch
import tyro
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
import torchrunx
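# TorchrunxClusterEnvironment is a Lightning ClusterEnvironment plugin: it lets the
# Trainer reuse the distributed environment (ranks, world size, master address/port)
# that torchrunx sets up for each worker, instead of launching its own processes.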
from torchrunx.integrations.lightning import TorchrunxClusterEnvironment
logging.basicConfig(level=logging.INFO)
@dataclass
class ModelConfig:
    name: str


@dataclass
class DatasetConfig:
    path: str
    name: str | None = None
    split: str | None = None
    text_column: str = "text"
    num_samples: int | None = None
def load_training_data(
    tokenizer_name: str,
    dataset_config: DatasetConfig,
) -> Dataset:
    # Load dataset
    dataset = load_dataset(
        dataset_config.path, name=dataset_config.name, split=dataset_config.split
    )
    if dataset_config.num_samples is not None:
        dataset = dataset.select(range(dataset_config.num_samples))

    # Build tokenizer
    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # to suppress warnings
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenize_fn = functools.partial(
        tokenizer,
        max_length=tokenizer.model_max_length,
        truncation=True,
        padding="max_length",
    )

    # Tokenize dataset
    return dataset.map(
        tokenize_fn,
        batched=True,
        input_columns=[dataset_config.text_column],
        remove_columns=[dataset_config.text_column],
    ).map(lambda x: {"labels": x["input_ids"]})
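# Minimal LightningModule wrapper: the Hugging Face model computes the loss itself;
# we only log it and define the optimizer.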
class CausalLMLightningWrapper(L.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, *args):  # pyright: ignore
        # The default DataLoader collate yields, per key, a list of per-position
        # tensors of shape (batch_size,); stack along dim=1 to get (batch_size, seq_len).
        device_batch = {
            k: torch.stack(v, dim=1).to(self.model.device) for k, v in batch.items()
        }
        loss = self.model(**device_batch).loss
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer
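# Executed on every worker process launched by torchrunx. The "ddp" strategy plus the
# TorchrunxClusterEnvironment plugin make Lightning pick up rank and world size from
# the torchrunx-provided environment. Returns the path of the saved checkpoint.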
def train(model: PreTrainedModel, train_dataset: Dataset) -> str:
    lightning_model = CausalLMLightningWrapper(model)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8)

    trainer = L.Trainer(
        accelerator="gpu",
        max_epochs=1,
        strategy="ddp",
        plugins=[TorchrunxClusterEnvironment()],
        enable_checkpointing=False,
    )
    trainer.fit(model=lightning_model, train_dataloaders=train_loader)

    checkpoint = f"{trainer.log_dir}/final.ckpt"
    trainer.save_checkpoint(checkpoint)
    return checkpoint
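# Runs on the launcher (head) process. tyro builds the CLI shown above from these type
# annotations: --launcher.* flags populate torchrunx.Launcher, while --model.* and
# --dataset.* populate the config dataclasses.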
def main(
    launcher: torchrunx.Launcher,
    model_config: Annotated[ModelConfig, tyro.conf.arg(name="model")],
    dataset_config: Annotated[DatasetConfig, tyro.conf.arg(name="dataset")],
):
    model = AutoModelForCausalLM.from_pretrained(model_config.name)
    train_dataset = load_training_data(
        tokenizer_name=model_config.name, dataset_config=dataset_config
    )

    # Launch training
    results = launcher.run(train, model, train_dataset)

    # Loading trained model from checkpoint
    checkpoint_path = results.rank(0)
    dummy_model = AutoModelForCausalLM.from_pretrained(model_config.name)
    trained_model = CausalLMLightningWrapper(dummy_model)
    trained_model.load_state_dict(torch.load(checkpoint_path)["state_dict"])
    trained_model = trained_model.model
if __name__ == "__main__":
    tyro.cli(main)