Accelerate
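
This example fine-tunes a causal language model with Hugging Face Accelerate, using torchrunx to launch the distributed training function and tyro to generate the command-line interface. Dependencies are declared inline (PEP 723), so the script can be run directly with a tool such as uv.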

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "accelerate",
#     "datasets",
#     "tensorboard",
#     "torch",
#     "torchrunx",
#     "transformers",
#     "tyro",
# ]
# ///

import functools
import os
from dataclasses import dataclass
from typing import Annotated

import torch
from accelerate import Accelerator
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

import torchrunx
import tyro

@dataclass
class ModelConfig:
    name: str


@dataclass
class DatasetConfig:
    path: str
    name: str | None = None
    split: str | None = None
    text_column: str = "text"
    num_samples: int | None = None

def load_training_data(
    tokenizer_name: str,
    dataset_config: DatasetConfig,
) -> Dataset:
    # Load dataset

    dataset = load_dataset(dataset_config.path, name=dataset_config.name, split=dataset_config.split)
    if dataset_config.num_samples is not None:
        dataset = dataset.select(range(dataset_config.num_samples))

    # Build tokenizer

    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # to suppress warnings
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
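    # Many causal LM tokenizers (e.g. GPT-2) ship without a pad token, so reuse EOS for padding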
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenize_fn = functools.partial(
        tokenizer,
        max_length=tokenizer.model_max_length,
        truncation=True,
        padding="max_length",
    )
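    # Padding every example to the same max_length keeps sequence shapes uniform,
    # so the default DataLoader collate function can batch them without a custom collator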

    # Tokenize dataset and copy input_ids into labels (causal LM models shift labels internally)

    return dataset.map(
        tokenize_fn,
        batched=True,
        input_columns=[dataset_config.text_column],
        remove_columns=[dataset_config.text_column],
    ).map(lambda x: {"labels": x["input_ids"]})


def train(
    model: PreTrainedModel,
    train_dataset: Dataset,
) -> str:

    accelerator = Accelerator()

    optimizer = torch.optim.Adam(model.parameters())
    train_dataloader = DataLoader(train_dataset, batch_size=8)

    model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
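    # After prepare(), the model is wrapped for distributed training (e.g. DDP), parameters sit
    # on the correct device, and the dataloader is sharded so each process sees its own slice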

    model.train()
    for batch_idx, batch in enumerate(train_dataloader):
        if batch_idx == 10:  # train for only a few steps, for demonstration
            break
        # The default collate yields, per key, a list of per-position tensors of shape [batch_size];
        # stack them along dim=1 to recover [batch_size, seq_len] before moving to the device
        device_batch = {k: torch.stack(v, dim=1).to(accelerator.device) for k, v in batch.items()}
        optimizer.zero_grad()

        loss = model(**device_batch).loss
        print(f"Step {batch_idx}, loss: {loss.item()}", flush=True, end="")
        accelerator.backward(loss)

        optimizer.step()

    # Wait for all ranks, then write the checkpoint from the main process only
    accelerator.wait_for_everyone()
    accelerator.unwrap_model(model).save_pretrained("output/", is_main_process=accelerator.is_main_process)

    return "output/"

def main(
    launcher: torchrunx.Launcher,
    model_config: Annotated[ModelConfig, tyro.conf.arg(name="model")],
    dataset_config: Annotated[DatasetConfig, tyro.conf.arg(name="dataset")],
):
    model = AutoModelForCausalLM.from_pretrained(model_config.name)
    train_dataset = load_training_data(tokenizer_name=model_config.name, dataset_config=dataset_config)

    # Launch training
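    # launcher.run executes train on every distributed worker and collects each rank's return value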
    results = launcher.run(train, (model, train_dataset))

    # Load the trained model from the checkpoint path returned by rank 0
    checkpoint_path = results.rank(0)
    trained_model = AutoModelForCausalLM.from_pretrained(checkpoint_path)


if __name__ == "__main__":
    tyro.cli(main)
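
Because main is wrapped with tyro.cli, its parameters are exposed as command-line flags: for example --model.name, --dataset.path, --dataset.split, and --dataset.num-samples, plus the --launcher.* options defined by torchrunx.Launcher. Note that --dataset.split should name a single split (such as "train"); otherwise load_dataset returns a DatasetDict rather than a single Dataset.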