DeepSpeed

Here’s an example script that uses torchrunx with DeepSpeed to fine-tune any causal language model (from transformers) on any text dataset (from datasets), across any number of GPUs or nodes.

https://torchrun.xyz/deepspeed_train.py

python deepspeed_train.py --help

[2025-02-23 16:02:38,914] [WARNING] [real_accelerator.py:181:get_accelerator] Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it.
[2025-02-23 16:02:38,930] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cpu (auto detect)
usage: deepspeed_train.py [-h] [OPTIONS]

╭─ options ──────────────────────────────────────────────────────────────────╮
│ -h, --help                                                                 │
│     show this help message and exit                                        │
│ --model-name STR                                                           │
│     (required)                                                             │
│ --deepspeed-config PATH                                                    │
│     (required)                                                             │
│ --checkpoint-dir PATH                                                      │
│     (required)                                                             │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ dataset options ──────────────────────────────────────────────────────────╮
│ --dataset.path STR                                                         │
│     (required)                                                             │
│ --dataset.name {None}|STR                                                  │
│     (default: None)                                                        │
│ --dataset.split {None}|STR                                                 │
│     (default: None)                                                        │
│ --dataset.text-column STR                                                  │
│     (default: text)                                                        │
│ --dataset.num-samples {None}|INT                                           │
│     (default: None)                                                        │
╰────────────────────────────────────────────────────────────────────────────╯
╭─ launcher options ─────────────────────────────────────────────────────────╮
│ For configuring the function launch environment.                           │
│ ────────────────────────────────────────────────────────────────────────── │
│ --launcher.hostnames {[STR [STR ...]]}|{auto,slurm}                        │
│     Nodes to launch the function on. By default, infer from SLURM, else    │
│     ``["localhost"]``. (default: auto)                                     │
│ --launcher.workers-per-host INT|{[INT [INT ...]]}|{cpu,gpu}                │
│     Number of processes to run per node. By default, number of GPUs per    │
│     host. (default: gpu)                                                   │
│ --launcher.ssh-config-file {None}|STR|PATHLIKE                             │
│     For connecting to nodes. By default, ``"~/.ssh/config"`` or            │
│     ``"/etc/ssh/ssh_config"``. (default: None)                             │
│ --launcher.backend {None,nccl,gloo,mpi,ucc}                                │
│     `Backend                                                               │
│     <https://pytorch.org/docs/stable/distributed.html#torch.distributed.B… │
│             for worker process group. By default, NCCL (GPU backend).      │
│             Use GLOO for CPU backend. ``None`` for no process group.       │
│     (default: nccl)                                                        │
│ --launcher.timeout INT                                                     │
│     Worker process group timeout (seconds). (default: 600)                 │
│ --launcher.copy-env-vars [STR [STR ...]]                                   │
│     Environment variables to copy from the launcher process to workers.    │
│     Supports Unix pattern matching syntax. (default: PATH LD_LIBRARY       │
│     LIBRARY_PATH 'PYTHON*' 'CUDA*' 'TORCH*' 'PYTORCH*' 'NCCL*')            │
│ --launcher.extra-env-vars {None}|{[STR STR [STR STR ...]]}                 │
│     Additional environment variables to load onto workers. (default: None) │
│ --launcher.env-file {None}|STR|PATHLIKE                                    │
│     Path to a ``.env`` file, containing environment variables to load onto │
│     workers. (default: None)                                               │
╰────────────────────────────────────────────────────────────────────────────╯
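
These options are generated (via tyro) from the torchrunx.Launcher dataclass, so each flag should correspond to a Launcher field. Below is a minimal programmatic sketch of the equivalent, assuming that one-to-one mapping ("node1"/"node2" and distributed_function are placeholders):

import torchrunx

def distributed_function() -> None:
    pass  # placeholder; e.g. the train() function from the script below

launcher = torchrunx.Launcher(
    hostnames=["node1", "node2"],  # --launcher.hostnames node1 node2
    workers_per_host="gpu",        # one worker per GPU (the default)
    backend="nccl",                # worker process group backend (the default)
    timeout=600,                   # process group timeout, seconds (the default)
)
launcher.run(distributed_function)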

Training GPT-2 on WikiText

DeepSpeed requires additional (non-Python) dependencies: it JIT-compiles C++/CUDA extensions at runtime, so it needs a CUDA toolkit and compatible C/C++ compilers available. Use the following commands to set up a project. [source: Apoorv’s Blog — Managing Project Dependencies]

Prerequisite: pixi

pixi init my-project --format pyproject
cd my-project

# Install dependencies
pixi project channel add "conda-forge" "nvidia/label/cuda-12.4.0"
pixi add "python=3.12.7" "cuda=12.4.0" "gcc=11.4.0" "gxx=11.4.0"
pixi add --pypi "torch==2.5.1" "deepspeed" "datasets" "tensorboard" "torchrunx" "transformers" "tyro"

cat <<EOF > .env
export PYTHONNOUSERSITE="1"
export LIBRARY_PATH="\$CONDA_PREFIX/lib"
export LD_LIBRARY_PATH="\$CONDA_PREFIX/lib"
export CUDA_HOME="\$CONDA_PREFIX"
EOF

# Activate environment
pixi shell
source .env

Download deepspeed_train.py and create deepspeed_config.json with:

{
    "train_batch_size": 8,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": { "lr": 0.00015 }
    },
    "fp16": { "enabled": true },
    "zero_optimization": true,
    "tensorboard": {
        "enabled": true,
        "output_path": "output/tensorboard/",
        "job_name": "gpt2_wikitext"
    }
}
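
Note that DeepSpeed requires train_batch_size = micro_batch_size_per_gpu × gradient_accumulation_steps × (number of workers), so with gradient_accumulation_steps of 1 the total worker count must evenly divide 8 here (e.g. 8 workers at micro-batch 1, or 4 workers at micro-batch 2).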

Then run:

python deepspeed_train.py --model-name gpt2 --deepspeed-config deepspeed_config.json --checkpoint-dir output \
       --dataset.path "Salesforce/wikitext" --dataset.name "wikitext-2-v1" --dataset.split "train" --dataset.num-samples 80

For multi-node training (when not running under SLURM, where hostnames are inferred automatically), you should also pass the hostnames explicitly, e.g. --launcher.hostnames node1 node2.

You can visualize the logs with:

tensorboard --logdir output/tensorboard/gpt2_wikitext
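
After training, each worker saves a shard of the ZeRO checkpoint; the script then consolidates these shards into a single fp32 state dict (see main() below). As a sketch, if you also want to export the merged weights for later use ("output/merged" is a placeholder path):

from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from transformers import AutoModelForCausalLM

state_dict = get_fp32_state_dict_from_zero_checkpoint("output")  # from --checkpoint-dir
model = AutoModelForCausalLM.from_pretrained("gpt2")             # from --model-name
model.load_state_dict(state_dict)
model.save_pretrained("output/merged")  # placeholder export path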

Script

# /// script
# requires-python = ">=3.9"
# dependencies = [
#     "datasets",
#     "deepspeed",
#     "tensorboard",
#     "torch",
#     "torchrunx",
#     "transformers",
#     "tyro",
# ]
# ///

from __future__ import annotations

import functools
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated

import deepspeed
import torch
import tyro
from datasets import load_dataset
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

import torchrunx

logging.basicConfig(level=logging.INFO)


@dataclass
class DatasetConfig:
    path: str
    name: str | None = None
    split: str | None = None
    text_column: str = "text"
    num_samples: int | None = None


def load_training_data(
    tokenizer_name: str,
    dataset_config: DatasetConfig,
) -> Dataset:
    # Load dataset

    dataset = load_dataset(
        dataset_config.path, name=dataset_config.name, split=dataset_config.split
    )
    if dataset_config.num_samples is not None:
        dataset = dataset.select(range(dataset_config.num_samples))

    # Build tokenizer

    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # to suppress warnings
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenize_fn = functools.partial(
        tokenizer,
        max_length=tokenizer.model_max_length,
        truncation=True,
        padding="max_length",
    )

    # Tokenize dataset

    return dataset.map(
        tokenize_fn,
        batched=True,
        input_columns=[dataset_config.text_column],
        remove_columns=[dataset_config.text_column],
    ).map(lambda x: {"labels": x["input_ids"]})


def train(
    model: PreTrainedModel,
    train_dataset: Dataset,
    deepspeed_config: str | dict,
    checkpoint_dir: str,
) -> None:
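    # deepspeed.initialize wraps the model in a DeepSpeedEngine (applying the
    # optimizer, fp16, and ZeRO settings from the config) and builds a
    # distributed data loader over the training dataset.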
    model_engine, _, data_loader, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        training_data=train_dataset,
        config=deepspeed_config,
    )

    model_engine.train()

    for step, batch in enumerate(data_loader):
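        # Default collation yields, for each key, a list of per-position tensors
        # of shape (batch_size,); stack and transpose to (batch_size, seq_len).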
        input_batch = {k: torch.stack(v).T.to(model_engine.device) for k, v in batch.items()}
        loss = model_engine(**input_batch).loss
        model_engine.backward(loss)
        model_engine.step()
        print(f"Step {step}, loss: {loss.item()}", flush=True, end="")

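    # Each rank writes its shard of the ZeRO-partitioned model/optimizer state;
    # main() later consolidates these shards into a single fp32 state dict.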
    model_engine.save_checkpoint(checkpoint_dir)


def main(
    model_name: str,
    deepspeed_config: Path,
    checkpoint_dir: Path,
    dataset_config: Annotated[DatasetConfig, tyro.conf.arg(name="dataset")],
    launcher: torchrunx.Launcher,
):
    model = AutoModelForCausalLM.from_pretrained(model_name)
    train_dataset = load_training_data(tokenizer_name=model_name, dataset_config=dataset_config)

    # Launch distributed training on all workers (blocks until they finish)
    launcher.run(train, model, train_dataset, str(deepspeed_config), str(checkpoint_dir))

    # Load the trained model from the ZeRO checkpoint
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)
    trained_model = AutoModelForCausalLM.from_pretrained(model_name)
    trained_model.load_state_dict(state_dict)


if __name__ == "__main__":
    tyro.cli(main)