# lollms-webui/train/train.py
#
# 252 lines · 8.9 KiB · Python  (Raw | Normal View | History)
#
# 2023-06-06 20:20:29 +00:00
import os
from argparse import ArgumentParser
# 2024-12-19 12:48:57 +00:00
import torch
import wandb
# 2023-06-06 20:20:29 +00:00
from accelerate import Accelerator
# 2024-12-19 12:48:57 +00:00
from accelerate.utils import DummyOptim, DummyScheduler, set_seed
# 2023-06-06 20:20:29 +00:00
from data import load_data
# 2024-12-19 12:48:57 +00:00
from peft import LoraConfig, TaskType, get_peft_model
from read import read_config
from torch.optim import AdamW
# 2023-06-06 20:20:29 +00:00
from torchmetrics import MeanMetric
from tqdm import tqdm
# 2024-12-19 12:48:57 +00:00
from transformers import (AutoModelForCausalLM, AutoTokenizer,
LlamaForCausalLM, get_scheduler)
# 2023-06-06 20:20:29 +00:00
# Permit TF32 math for CUDA float32 matmuls: trades some precision for speed
# on GPUs that support it. This runs at module level, so it takes effect
# process-wide as soon as this file is imported.
torch.backends.cuda.matmul.allow_tf32 = True
# 2024-12-19 12:48:57 +00:00
# 2023-06-06 20:20:29 +00:00
def format_metrics(metrics, split, prefix=""):
    """Render ``metrics`` as a single log line.

    The result is ``"[<split>]" + prefix`` followed by ``"key: value"`` pairs
    joined by single spaces, each value formatted with four decimals.
    ``prefix`` is inserted verbatim; no separator is added around it.
    """
    parts = []
    for name, val in metrics.items():
        parts.append("{}: {:.4f}".format(name, val))
    return f"[{split}]" + prefix + " ".join(parts)
def evaluate(model, val_dataloader, accelerator=None):
    """Accumulate the validation loss of ``model`` over ``val_dataloader``.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    Per-batch losses are gathered across processes with
    ``accelerator.gather_for_metrics`` before being added to the metric.

    Args:
        model: Model called as ``model(**batch)``. Its output must expose a
            ``.loss`` tensor.
        val_dataloader: Iterable of keyword-argument dicts for ``model``.
        accelerator: ``Accelerator`` used to gather losses. When omitted, the
            module-level ``accelerator`` (bound by the ``__main__`` block) is
            used, as before. If no such name exists, each process's local
            losses are used as-is.

    Returns:
        A ``torchmetrics.MeanMetric`` holding the losses; call ``.compute()``
        for the mean.

    Note:
        The model is left in eval mode; callers must switch it back to train
        mode themselves.
    """
    if accelerator is None:
        # Backward compatibility: previously this function read the global
        # `accelerator` unconditionally, so calling it outside `__main__`
        # raised NameError.
        accelerator = globals().get("accelerator")
    model.eval()
    val_loss = MeanMetric(nan_strategy="error").to(model.device)
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            loss = model(**batch).loss.detach()
            if accelerator is not None:
                # Gather per-process losses before updating the metric.
                loss = accelerator.gather_for_metrics({"loss": loss})["loss"]
            val_loss.update(loss)
    return val_loss
def _deepspeed_section_configured(accelerator, section):
    """Return True when DeepSpeed is active and its config has a ``section`` key."""
    plugin = accelerator.state.deepspeed_plugin
    return plugin is not None and section in plugin.deepspeed_config


def _gradient_accumulation_steps(accelerator, config):
    """Return how many micro-batches are accumulated per optimizer step.

    The value comes from the DeepSpeed config when DeepSpeed is active.
    Otherwise it is ``config["gradient_accumulation_steps"]`` if present,
    else 1.
    """
    plugin = accelerator.state.deepspeed_plugin
    if plugin is not None:
        return plugin.deepspeed_config["gradient_accumulation_steps"]
    # Previously the value was bound only inside the DeepSpeed branch, so a
    # run without DeepSpeed failed with NameError on first use.
    return config.get("gradient_accumulation_steps", 1)


def _build_model(config):
    """Load ``config["model_name"]``; enable grad checkpointing / LoRA as configured."""
    grad_ckpt = config["gradient_checkpointing"]
    model = AutoModelForCausalLM.from_pretrained(
        config["model_name"],
        # The generation KV cache is turned off when gradient checkpointing is on.
        use_cache=not grad_ckpt,
        trust_remote_code=True,
    )
    if grad_ckpt:
        model.gradient_checkpointing_enable()
    if config["lora"]:
        peft_config = LoraConfig(
            # should R be configurable?
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    return model


def _build_scheduler(accelerator, config, optimizer, num_batches, grad_accum):
    """Create the LR scheduler for ``config["num_epochs"]`` epochs of ``num_batches``.

    A DeepSpeed ``DummyScheduler`` is used when the DeepSpeed config defines a
    scheduler. Otherwise a cosine schedule with warmup is created via
    ``transformers.get_scheduler``.
    """
    import math

    # decay to min_lr instead of 0
    lr_ratio = config["min_lr"] / config["lr"]
    # Optimizer updates over the whole run, rounded up to a whole step count.
    total_num_steps = math.ceil(num_batches / grad_accum * config["num_epochs"])
    # instead of decaying to zero, decay to ratio of min_lr / lr
    total_num_steps += int(total_num_steps * lr_ratio) + config["warmup_steps"]
    accelerator.print(f"Total training steps: {total_num_steps}")

    if _deepspeed_section_configured(accelerator, "scheduler"):
        # Size the DeepSpeed schedule for the whole run. It was previously
        # given total_num_steps=warmup_steps, so it was only as long as the
        # warmup itself.
        return DummyScheduler(
            optimizer,
            total_num_steps=total_num_steps,
            warmup_num_steps=config["warmup_steps"],
        )
    return get_scheduler(
        name="cosine",
        optimizer=optimizer,
        num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
        num_training_steps=total_num_steps,
    )


def _resume_position(checkpoint, steps_per_epoch):
    """Map a ``.../step_<N>`` checkpoint path to ``(start_epoch, batches_to_skip)``.

    ``N`` is the global index of the last batch processed before the state was
    saved (``step + epoch * steps_per_epoch`` in ``train``), so training
    resumes at global batch ``N + 1``. A trailing path separator is tolerated.
    """
    name = os.path.splitext(os.path.basename(os.path.normpath(checkpoint)))[0]
    last_step = int(name.replace("step_", ""))
    return divmod(last_step + 1, steps_per_epoch)


def _save_model(accelerator, model, path):
    """Save the unwrapped ``model`` to ``path`` with ``save_pretrained``.

    Called on every process. ``is_main_process`` is forwarded so that
    ``save_pretrained`` can restrict the actual write.
    """
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        path,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
        state_dict=accelerator.get_state_dict(model),
    )


def train(accelerator, config):
    """Fine-tune a causal LM as described by ``config``, driven by ``accelerator``.

    Steps:
      1. Load the tokenizer, data and model.
      2. Build the optimizer and LR scheduler. DeepSpeed dummies are used when
         the DeepSpeed config defines them.
      3. Optionally resume from ``config["checkpoint"]`` (a ``step_<N>``
         directory written by this loop).
      4. Train for ``config["num_epochs"]`` epochs, with periodic LR logging,
         state checkpoints and validation.
      5. After each epoch, push to the HF hub (best effort) and save to
         ``<output_dir>/epoch_<k>``. The final model goes to
         ``<output_dir>/final``.

    Args:
        accelerator: ``accelerate.Accelerator`` handling device placement,
            cross-process gathering, logging and checkpointing.
        config: Mapping of training options, as returned by ``read_config``.
    """
    set_seed(config["seed"])

    accelerator.print(config)
    accelerator.print(f"Using {accelerator.num_processes} GPUs")

    tokenizer = AutoTokenizer.from_pretrained(
        config["tokenizer_name"], model_max_length=config["max_length"]
    )
    # if no pad token, set it to eos
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load data on the main process first, then on the remaining processes.
    with accelerator.main_process_first():
        train_dataloader, val_dataloader = load_data(config, tokenizer)

    model = _build_model(config)

    optimizer_cls = (
        DummyOptim if _deepspeed_section_configured(accelerator, "optimizer") else AdamW
    )
    # karpathy doesn't decay embeddings, maybe we should exclude
    # https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
    optimizer = optimizer_cls(
        model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"]
    )

    gradient_accumulation_steps = _gradient_accumulation_steps(accelerator, config)
    accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
    scheduler = _build_scheduler(
        accelerator, config, optimizer, len(train_dataloader), gradient_accumulation_steps
    )

    model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, val_dataloader, scheduler
    )

    # setup for saving training states in case preemption
    accelerator.register_for_checkpointing(scheduler)

    # Per-process batches per epoch, after `prepare` has sharded the loader.
    steps_per_epoch = len(train_dataloader)
    start_epoch, skip_batches = 0, 0
    if config["checkpoint"]:
        accelerator.load_state(config["checkpoint"])
        accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
        # Derive the position from the checkpoint that was actually loaded
        # (previously read from config["train_args"]["resume_from_checkpoint"]).
        start_epoch, skip_batches = _resume_position(config["checkpoint"], steps_per_epoch)
        accelerator.print(f"Resuming from epoch {start_epoch}, batch {skip_batches}")

    # log gradients
    if accelerator.is_main_process and config["wandb"]:
        wandb.watch(model, log_freq=config["log_grads_every"], log="all")

    # LR is logged ten times per eval interval. Clamp to 1 so that
    # eval_every < 10 does not cause a modulo by zero.
    lr_log_every = max(1, config["eval_every"] // 10)

    for epoch in range(start_epoch, config["num_epochs"]):
        train_loss = MeanMetric(nan_strategy="error").to(model.device)

        epoch_loader, first_step = train_dataloader, 0
        if epoch == start_epoch and skip_batches:
            # skip_first_batches returns a NEW dataloader; the old code dropped
            # it, so no batches were skipped. Numbering `step` from
            # `skip_batches` keeps the global step, accumulation boundaries and
            # the last-batch check aligned with an uninterrupted run.
            epoch_loader = accelerator.skip_first_batches(train_dataloader, skip_batches)
            first_step = skip_batches

        for step, batch in enumerate(tqdm(epoch_loader), start=first_step):
            curr_step = step + epoch * steps_per_epoch
            model.train()
            outputs = model(**batch)
            loss = outputs.loss

            # gather loss before backprop in case of gradient accumulation
            loss_values = accelerator.gather_for_metrics(
                {"loss": loss.detach().float()}
            )
            train_loss.update(loss_values["loss"])

            loss = loss / gradient_accumulation_steps
            accelerator.backward(loss)

            # log LR in case something weird happens
            if step > 0 and step % lr_log_every == 0 and config["wandb"]:
                accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step)

            is_last_batch = step == steps_per_epoch - 1
            if (step + 1) % gradient_accumulation_steps == 0 or is_last_batch:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if step > 0 and step % config["save_every"] == 0:
                accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")

            if step > 0 and (step % config["eval_every"] == 0 or is_last_batch):
                val_loss = evaluate(model, val_dataloader)

                log_train = {"train_loss": train_loss.compute()}
                log_val = {"val_loss": val_loss.compute()}

                if config["wandb"]:
                    accelerator.log({**log_train, **log_val}, step=curr_step)

                accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
                accelerator.print(format_metrics(log_train, "train", f" step {step} "))
                accelerator.print(format_metrics(log_val, "val", f" step {step} "))

                train_loss.reset()

        accelerator.print(f"Epoch {epoch} finished")
        accelerator.print("Pushing to HF hub")
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        try:
            if accelerator.is_main_process:
                unwrapped_model.push_to_hub(
                    config["save_name"] + f"-epoch_{epoch}", private=True
                )
        except Exception as e:  # best effort: a failed upload must not stop training
            accelerator.print(e)
            accelerator.print("Failed to push to hub")

        _save_model(accelerator, model, f"{config['output_dir']}/epoch_{epoch}")

    accelerator.wait_for_everyone()
    _save_model(accelerator, model, f"{config['output_dir']}/final")

    accelerator.end_training()
if __name__ == "__main__":
    # Script entry point: read the YAML config named on the command line,
    # create the Accelerator (with wandb tracking if enabled) and train.
    cli = ArgumentParser()
    cli.add_argument("--config", type=str, default="config.yaml")
    cli_args = cli.parse_args()
    config = read_config(cli_args.config)

    # `accelerator` stays a module-level name: evaluate() uses the
    # module-level `accelerator`.
    use_wandb = config["wandb"]
    accelerator = Accelerator(log_with="wandb") if use_wandb else Accelerator()
    if use_wandb:
        accelerator.init_trackers(
            project_name=config["wandb_project_name"],
            config=config,
            init_kwargs={"wandb": {"entity": config["wandb_entity"]}},
        )

    train(accelerator, config=config)