import os
from pathlib import Path
import lightning as L
import torch
from datasets import load_dataset
from torch import nn
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torchrunx
from torchrunx.integrations.lightning import TorchrunxClusterEnvironment
class GPT2CausalLMDataset(Dataset):
def __init__(self, text_dataset):
self.dataset = text_dataset
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
self.tokenizer.pad_token = self.tokenizer.eos_token
self.max_length = 1024
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
encoded = self.tokenizer(
self.dataset[idx]["text"],
max_length=self.max_length,
truncation=True,
padding="max_length",
return_tensors="pt",
)
input_ids = encoded.input_ids.squeeze()
attention_mask = encoded.attention_mask.squeeze()
labels = input_ids.clone()
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
class GPT2LightningWrapper(L.LightningModule):
def __init__(self):
super().__init__()
self.model = AutoModelForCausalLM.from_pretrained("gpt2")
def training_step(self, batch, *args): # pyright: ignore
device_batch = {k: v.to(self.model.device) for k, v in batch.items()}
loss = self.model(**device_batch).loss
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
return optimizer
def train():
lightning_model = GPT2LightningWrapper()
wikitext_train = load_dataset("Salesforce/wikitext", "wikitext-2-v1", split="train")
train_dataset = GPT2CausalLMDataset(text_dataset=wikitext_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8)
trainer = L.Trainer(
accelerator="gpu",
limit_train_batches=10,
max_epochs=1,
devices=2,
num_nodes=1,
strategy="ddp",
plugins=[TorchrunxClusterEnvironment()],
enable_checkpointing=False
)
trainer.fit(model=lightning_model, train_dataloaders=train_loader)
checkpoint = f"{trainer.log_dir}/final.ckpt"
trainer.save_checkpoint(checkpoint)
return checkpoint
if __name__ == "__main__":
results = torchrunx.launch(
func=train,
hostnames=["localhost"],
workers_per_host=2,
)
checkpoint_path = results.rank(0)
print(f"Checkpoint at: {checkpoint_path}")