-
Notifications
You must be signed in to change notification settings - Fork 130
Description
Hello, I am fine-tuning the LLaMA-2 7B model on an A100 40 GB GPU. Initially, I was getting a CUDA out-of-memory error. I tried various methods, such as reducing batch size, but none worked. Then I enabled:
model.gradient_checkpointing_enable()
After doing this, the OOM issue was resolved, but now I get the following error during backpropagation:
torch.autograd.backward(
File ".../torch/autograd/__init__.py", line 354, in backward
_engine_run_backward(
File ".../torch/autograd/graph.py", line 829, in _engine_run_backward
return Variable._execution_engine.run_backward(
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
I also tried:
model.enable_input_require_grads()
but the error still persists. I suspect the issue is related to enabling gradient checkpointing.
In model_init():
reft_model.gradient_checkpointing_enable()
reft_model.enable_input_require_grads()
Is there something I am missing when using gradient checkpointing in this setup?
Here is the finetune and model initialization function of the train.py file -
import argparse
import datetime
import json
import os
from contextlib import nullcontext
from copy import deepcopy
from functools import partial

import evaluate
import numpy as np
import torch
import wandb
from peft import PeftModel
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from torch import profiler
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    DataCollatorWithPadding,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import EvalPrediction

from compute_metrics import compute_metrics
from dataset import LoReftGLUEDataset, LoReftSupervisedDataset
from pyreft import (
    LoreftIntervention,
    MoReIntervention,
    NoIntervention,
    NoreftIntervention,
    ReftConfig,
    ReftDataCollator,
    ReftModel,
    ReftTrainerForCausalLM,
    ReftTrainerForSequenceClassification,
    TaskType,
    get_reft_model,
)
from task_config import task_config
repo_root = os.path.dirname(os.path.dirname(os.path.abspath(file)))
config_path = config_path
args = None
tokenizer = None
reft_model = None
best_hyperparams = None
config = None
device = "cuda" if torch.cuda.is_available() else "cpu"
classification_tasks = {"glue"}
residual_stream_component_mapping = {"robertaformaskedlm": "roberta.encoder.layer[%s].output"}
dtype_mapping = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float8": "float8",
}
#peft_config = json.load(open("/fly/task_configs/llama/peft_config.json", "r")) # removed /fly prefix
with open(config_path, "r") as f:
peft_config = json.load(f)
def model_init(hyperparams: dict = best_hyperparams):
global peft_config, args, tokenizer, reft_model, config
if hyperparams == None:
hyperparams = {}
# everything is guarded by a single seed
set_seed(args.seed)
dtype = dtype_mapping[args.dtype]
model_name = args.model
# Hyperparameter search
if args.blk_r != -1:
hyperparams["blk_r"] = args.blk_r
if args.nblocks != -1:
hyperparams["nblocks"] = args.nblocks
if hyperparams is not None:
for k in peft_config.keys():
if k in hyperparams.keys() and hyperparams[k] != peft_config[k]:
print("Overriding the {} in best HP to {}".format(k, hyperparams[k]))
peft_config[k] = hyperparams[k]
if wandb.run is not None:
wandb.run.config.update(peft_config)
wandb.run.config.update({"dtype": dtype})
if reft_model is None:
if args.task in classification_tasks:
config = AutoConfig.from_pretrained(
model_path,
num_labels=args.num_labels,
finetuning_task=args.train_dataset,
load_in_8bit=True if args.dtype == "float8" else False,
device_map=device,
)
# full precision loading since usually for small models
reft_model = AutoModelForSequenceClassification.from_pretrained(
model_path,
config=config, # just providing the label
torch_dtype=dtype if args.dtype != "float8" else None,
load_in_8bit=True if args.dtype == "float8" else False,
device_map=device,
max_memory={0: 0.8},
# attn_implementation="flash_attention_2"
)
else:
reft_model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=dtype if args.dtype != "float8" else None, # save memory
load_in_8bit=True if args.dtype == "float8" else False,
device_map=device,
max_memory={0: 0.8}, # device: memory
attn_implementation="flash_attention_2",
)
config = reft_model.config
if not isinstance(reft_model, (ReftModel, PeftModel)):
# Optionally apply ReFT
if args.intervention_type == "LoreftIntervention":
intervention_type = LoreftIntervention
elif args.intervention_type == "NoreftIntervention":
intervention_type = NoreftIntervention
elif args.intervention_type == "MoReIntervention":
intervention_type = partial(MoReIntervention, dtype=dtype, blk_r=args.blk_r, nblocks=args.nblocks)
else:
intervention_type = NoIntervention
# intervention config based on model type
intervention_dtype = torch.bfloat16 if isinstance(dtype, str) else dtype
model_arch = config.architectures[0].lower()
if model_arch in residual_stream_component_mapping:
representations = [
{
"component": residual_stream_component_mapping[model_arch] % l,
"intervention": intervention_type(
embed_dim=config.hidden_size,
low_rank_dimension=args.rank,
dropout=args.dropout,
dtype=intervention_dtype,
act_fn=args.act_fn,
device=device,
add_bias=args.add_bias,
),
}
for l in args.layers
]
TaskType.SEQ_CLS
else:
representations = [
{
"layer": l,
"component": "block_output",
"low_rank_dimension": args.rank,
"intervention": intervention_type(
embed_dim=config.hidden_size,
low_rank_dimension=args.rank,
dropout=args.dropout,
dtype=intervention_dtype,
act_fn=args.act_fn,
device=device,
add_bias=args.add_bias,
),
}
for l in args.layers
]
TaskType.CAUSAL_LM
reft_model.gradient_checkpointing_enable()
reft_model.enable_input_require_grads()
reft_config = ReftConfig(representations=representations)
reft_model = get_reft_model(reft_model, reft_config, set_device=not isinstance(dtype, str))
# for GLUE tasks, we enable gradients on the classifier head.
# the parameter will be counted as well.
if args.task == "glue" and args.allow_cls_grad:
for param in reft_model.model.classifier.parameters():
# reft_model with HF trainer will automatically pick up these params to optimize
param.requires_grad = True
# Monarch adaptation
if args.mode == "monarch":
print("###### INIT MONARCH ######")
peft_config["dtype"] = dtype
init_monarch(reft_model, peft_config)
elif args.mode == "lora":
peft_config = {"target_modules": peft_config["target_modules"]}
peft_config["r"] = 32
peft_config["lora_alpha"] = 64
init_lora(reft_model, peft_config)
elif args.mode == "boft":
print("###### INIT BOFT ######")
peft_config["dtype"] = dtype
reft_model = init_boft(reft_model, peft_config)
else:
raise NotImplementedError()
param_stats(reft_model, training=False)
reft_model.print_trainable_parameters()
return reft_model
def finetune(
act_fn: str,
add_bias: bool,
model: str,
layers: str,
rank: int,
position: str,
epochs: int,
seed: int,
intervention_type: str,
max_n_train_example: int,
max_n_eval_example: int,
wandb_name: str,
gradient_accumulation_steps: int,
batch_size: int,
output_dir: str,
task: str,
lr: float,
schedule: str,
data_dir: str,
train_dataset: str,
eval_dataset: str,
save_model: bool,
eval_batch_size: int,
warmup_ratio: float,
weight_decay: float,
dropout: float,
test_split: str,
train_on_inputs: bool,
max_length: int,
use_normalized_template: bool,
allow_cls_grad: bool,
metric_for_best_model: str,
dtype: str,
logging_steps: int,
wandb_dir: str,
wandb_proj: str,
share_weights: bool,
greedy_decoding: bool,
temperature: float,
top_p: float,
top_k: float,
**kwargs,
):
"""
Generic Representation Finetuning.
"""
global tokenizer, reft_model, peft_config
use_wandb = kwargs.pop("wandb", True)
assert task in task_config
if data_dir is not None:
assert os.path.exists(data_dir), f"Data directory {data_dir} does not exist."
# store/log run details
print(
f"task: {task}, model: {model}, intervention_type: {intervention_type}, "
f"layers: {layers}, rank: {rank}, "
f"position: {position}, epoch: {epochs}, train_on_inputs: {train_on_inputs}, "
f"max_length: {max_length}, allow_cls_grad: {allow_cls_grad}"
)
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
model_str = args.model.split("/")[-1]
if args.train_dataset is not None:
run_name = f"{model_str}.{args.task}.{args.train_dataset}.{args.test_split}.{now}"
else:
run_name = f"{model_str}.{args.task}.{now}"
# load dataset splits
assert task in task_config, f"Unrecognized task: {task}"
train_datasets = task_config[task]["train_datasets"] if train_dataset is None else [train_dataset]
if task == "glue":
eval_datasets = [train_dataset]
else:
eval_datasets = task_config[task]["eval_datasets"] if args.eval_dataset is None else [args.eval_dataset]
# initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(
'/nlsasfs/home/obfuscated/shilpyk/MISHA/local_models/llama-7b-hf',
model_max_length=args.max_length,
padding_side="right",
use_fast=False,
)
tokenizer.pad_token = tokenizer.unk_token
# which layers to intervene on
if isinstance(args.layers, str):
if args.layers != "all":
args.layers = [int(l) for l in args.layers.split(";")]
else:
temp_config = AutoConfig.from_pretrained('/nlsasfs/home/obfuscated/shilpyk/MISHA/local_models/llama-7b-hf')
args.layers = [l for l in range(temp_config.num_hidden_layers)]
# position str takes the following formats:
# f1 -> first token; f2 -> first two tokens.
# f1+l1 -> first and last tokens; f2+l2 -> first and last two tokens.
# fn or ln shares the same intervention.
if "+" in args.position and not args.share_weights:
args.layers += args.layers
ReftDataset = LoReftGLUEDataset if task == "glue" else LoReftSupervisedDataset
path = os.path.join(data_dir, train_datasets[0]) if data_dir is not None else train_datasets[0]
if not args.do_train and not args.do_tune:
max_n_train_example = 1
train_dataset = ReftDataset(
task,
train_datasets[0] if task == "glue" else path,
tokenizer,
data_split="train",
seed=seed,
max_n_example=max_n_train_example,
**{"num_interventions": len(args.layers), "position": position, "share_weights": share_weights},
)
trigger_tokens = train_dataset.trigger_tokens
args.num_labels = train_dataset.num_labels
all_eval_datasets = {}
for eval_dataset in eval_datasets:
test_splits = test_split.split(";")
all_eval_datasets[eval_dataset] = {}
for split in test_splits:
if args.do_tune:
split = "train" # TODO: Ensure eval_loop doesn't throw a bug.. need to change later
path = os.path.join(data_dir, eval_dataset) if data_dir is not None else eval_dataset
raw_eval = ReftDataset(
task,
eval_dataset if task == "glue" else path,
tokenizer,
data_split=split,
seed=seed,
max_n_example=max_n_eval_example,
**{"num_interventions": len(args.layers), "position": position, "share_weights": share_weights},
is_eval=True,
)
all_eval_datasets[eval_dataset][split] = [raw_eval, raw_eval.raw_dataset]
eval_datasets = all_eval_datasets
# Initialize model
if args.all_linear:
peft_config["target_modules"] += ["o_proj", "up_proj", "down_proj", "gate_proj"]
reft_model = model_init()
n_params = reft_model.count_parameters(include_model=False)
if task == "glue":
# we repartition the eval_datatsets into [1] 50% validation + [2] 50% test
# we select the best model on [1] during training
# we test the selected model on [2] to ensure fairness
to_split_eval_datasets = eval_datasets[args.train_dataset][test_split][0]
if len(to_split_eval_datasets) > 5000:
in_train_n_eval_sample = 1000
else:
in_train_n_eval_sample = len(to_split_eval_datasets) // 2
new_splits = torch.utils.data.random_split(
to_split_eval_datasets, [len(to_split_eval_datasets) - in_train_n_eval_sample, in_train_n_eval_sample]
)
in_test_eval_datasets, in_train_eval_datasets = new_splits[0], new_splits[1]
eval_datasets[args.train_dataset][test_split][0] = in_test_eval_datasets
print("GLUE validation split (in training): ", len(in_train_eval_datasets))
print("GLUE validation split (testing): ", len(eval_datasets[args.train_dataset][test_split][0]))
is_regression = args.train_dataset == "stsb"
metric = evaluate.load("glue", args.train_dataset)
# You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.
def in_training_compute_metrics(p: EvalPrediction):
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
result = metric.compute(predictions=preds, references=p.label_ids)
if len(result) > 1:
result["combined_score"] = np.mean(list(result.values())).item()
return result
# select collator based on the type
if task in classification_tasks:
data_collator_fn = DataCollatorWithPadding(tokenizer=tokenizer, padding="longest")
else:
data_collator_fn = DataCollatorForSeq2Seq(
tokenizer=tokenizer, model=model, label_pad_token_id=-100, padding="longest"
)
data_collator = ReftDataCollator(data_collator=data_collator_fn)
# start wandb logging
# if task == "tune_math":
# task = "math"
# Now datasets are set up, use the actual task name
if "tune" in task:
task = task.split("_")[-1]
task_dir = os.path.join(output_dir, task)
output_dir = os.path.join(task_dir, args.group) if args.group else task_dir
if args.notes:
run_name = args.notes + "_" + run_name
# Must set env variables to carry to them ray tune workers!
if args.wandb == False:
os.environ["WANDB_MODE"] = "offline"
if args.resume:
group_path = os.path.join(output_dir, "full_group.txt")
group = None
os.environ["WANDB_RUN_GROUP"] = group = get_run_group(
task, group=args.group, notes=args.notes, do_tune=args.do_tune
)
if os.path.exists(group_path):
os.environ["WANDB_RUN_GROUP"] = group = open(group_path, "r").read().strip()
else:
os.environ["WANDB_RUN_GROUP"] = group = get_run_group(
task, group=args.group, notes=args.notes, do_tune=args.do_tune
)
os.environ["WANDB_PROJECT"] = f"reft-monarch-{task}"
run = wandb.init(
project=os.environ["WANDB_PROJECT"],
name=run_name,
dir=wandb_dir,
group=group,
)
if not args.do_tune:
watch_layers(reft_model.model)
run.summary.update(vars(args))
wandb.log({"train/n_params": n_params})
# # training args
training_args = TrainingArguments(
output_dir=output_dir,
run_name=run_name,
num_train_epochs=epochs,
max_steps=args.max_steps,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=eval_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
eval_strategy="steps",
save_strategy="steps",
save_steps=args.save_steps,
metric_for_best_model=metric_for_best_model if task == "glue" else None,
# load_best_model_at_end=True if task == "glue" else False,
logging_strategy="steps",
save_total_limit=6,
logging_steps=logging_steps,
lr_scheduler_type=schedule,
learning_rate=lr,
warmup_ratio=warmup_ratio,
optim="adamw_torch",
weight_decay=weight_decay,
report_to="wandb" if use_wandb else "none",
use_cpu=False if device == "cuda" else True,
seed=seed,
# until HF supports ReFT, this remains False! :)
remove_unused_columns=False,
do_eval=True,
bf16=True if dtype == torch.bfloat16 else False,
)
# make trainer
trainer_class = ReftTrainerForSequenceClassification if task in classification_tasks else ReftTrainerForCausalLM
if task == "glue":
eval_dataset = in_train_eval_datasets
# Ray Tune requires eval. Otherwise the authors didn't use eval
elif args.do_tune:
assert len(eval_datasets) == 1, "Use only one eval set for HPO!"
raw_eval = list(list(eval_datasets.values())[0].values())[0][0] # TODO: beautify
eval_dataset = raw_eval
elif task == "commonsense":
eval_dataset = ReftDataset(
task,
train_datasets[0] if task == "glue" else path,
tokenizer,
data_split="train",
seed=seed,
max_n_example=max_n_eval_example,
**{"num_interventions": len(args.layers), "position": position, "share_weights": share_weights},
is_eval=True,
)
else:
eval_dataset = None
training_args.evaluation_strategy = "no"
trainer = trainer_class(
model=reft_model,
model_init=model_init,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=data_collator,
compute_metrics=in_training_compute_metrics if task == "glue" else None,
)
# Eval more than once per epoch
evals_per_epoch = args.evals_per_epoch
trainer.args.eval_steps = len(train_dataset) // (
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * evals_per_epoch
)
if args.do_tune:
# Save full tune group name for resuming
with open(os.path.join(training_args.output_dir, "full_group.txt"), "w") as f:
f.write(group)
_args = deepcopy(trainer.args)
trainer.args.save_total_limit = 0 # Avoid flooding the disk during HPO
trainer.args.load_best_model_at_end = False
trainer.args.save_strategy = "no"
trainer.args.evaluation_strategy = "steps" # Requried for Ray Tune
######################### Need to change for each task #########################
# PEFT monarch search space
if task == "math":
real_bs = [16, 32, 64]
elif task == "commonsense":
real_bs = [16, 32, 64]
else:
raise NotImplementedError(f"Don't forget to manually pick bs for task {task} !")
grad_acc_steps = [i // args.batch_size for i in real_bs]
param_space = {
# "nblocks": tune.choice(['sqrt(n)', 4]),
"seed": training_args.seed,
# "num_train_epochs": tune.choice([20, 25]),
"learning_rate": tune.quniform(1e-4, 9e-4, 1e-4),
"gradient_accumulation_steps": tune.choice(grad_acc_steps), # Will OOM if tune batch size
"weight_decay": tune.choice([0]),
"lr_scheduler_type": tune.choice(["cosine", "linear"]), # mostly linear underperforms
"dropout": tune.choice([0.05, 0.1]),
"blk_r": peft_config["blk_r"],
"nblocks": peft_config["nblocks"],
}
n_trials = args.n_trials
# Set up scheduler and reporter etc.
metric = f"eval_loss"
direction = "min" if "loss" in metric else "max" # minimize eval loss
tune_unit = "iter"
max_t = 40 * 60 if "tune_unit" == "time" else args.epochs * evals_per_epoch
grade_period = 4 * 60 if tune_unit == "time" else 2
time_attr = "time_total_s" if tune_unit == "time" else "training_iteration"
############################## End of task specific ##############################
scheduler = ASHAScheduler(
time_attr=time_attr,
max_t=max_t,
metric=metric,
mode=direction,
grace_period=grade_period,
)
# Do hyperparam optimization with Ray Tune
best_run = trainer.hyperparameter_search(
hp_space=lambda _: param_space,
backend="ray",
n_trials=n_trials, # under the hood it calls ray.tune.run(num_samples=n_trials, ...)
scheduler=scheduler,
keep_checkpoints_num=None,
resources_per_trial={"cpu": 1, "gpu": 1},
name=group,
local_dir="/fly/ray_results",
max_failures=9999, # tolerate OOM
direction="maximize" if direction == "max" else "minimize",
compute_objective=partial(get_hpo_metric, metric),
resume=args.resume,
)
del trainer.model
model = None
free_memory() # Re-init model
trainer.args = _args
best_hyperparams = best_run.hyperparameters
# Save the best HP for full training
print("Best hyperparameters: ", best_hyperparams)
# Save hyperparams
run_hp_path = os.path.join(training_args.output_dir, "best_hyperparams.json")
task_hp_path = os.path.join(task_dir, "best_hyperparams.json")
json.dump(best_hyperparams, open(run_hp_path, "w"))
json.dump(best_hyperparams, open(task_hp_path, "w"))
last_ckpt, _ = get_last_checkpoint(training_args.output_dir)
# last_ckpt = os.path.join(last_ckpt, "intervenable_model")
if args.do_train:
load_best_hp(training_args.output_dir, task_dir)
# TODO:enable resume
if args.profile:
ctx = profiler.profile(
schedule=profiler.schedule(wait=1, warmup=1, active=2, repeat=1),
on_trace_ready=profiler.tensorboard_trace_handler(f"./llama_reasoning_{args.mode}_log"),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
trainer.add_callback(ProfCallback(ctx))
else:
ctx = nullcontext()
with ctx:
if args.resume:
trainer.train(resume_from_checkpoint=last_ckpt)
else:
trainer.train()
# dump config
args_dict = vars(args)
args_dict["n_params"] = n_params
json_file_name = f"{output_dir}/args.json"
with open(json_file_name, "w") as json_file:
json.dump(args_dict, json_file, indent=4)
# save model
reft_model.save(output_dir)
trainer.save_state()
# NOTE: force load best
trainer._load_best_model()
else:
print("Skipping training, loading last checkpoint...")
trainer.model.load_intervention(last_ckpt, include_model=True)
# ensure everything is in eval mode
reft_model.model.eval()
for k, v in reft_model.interventions.items():
_ = v[0].eval()
print({"n_params": n_params})
# do eval
eval_results = {}
for dataset_name in eval_datasets:
# split evalset into chunks
for split, (eval_dataset, data_items) in eval_datasets[dataset_name].items():
generations, stats = compute_metrics(
task,
dataset_name,
reft_model,
tokenizer,
eval_dataset,
data_items,
trigger_tokens,
run_name,
eval_batch_size,
data_collator if task in classification_tasks else None,
split,
greedy_decoding,
temperature,
top_p,
top_k,
)
# log
eval_results.update(stats)
if use_wandb:
wandb.log(stats)
generations = stats if generations is None else generations
result_json_file_name = f"{output_dir}/{dataset_name}_{split}_outputs.json"
with open(result_json_file_name, "w") as json_file:
json.dump(generations, json_file, indent=4)
# log final eval stats
result_json_file_name = f"{output_dir}/eval_results.json"
eval_results["n_params"] = n_params
with open(result_json_file_name, "w") as json_file:
json.dump(eval_results, json_file, indent=4)