Skip to content

[P1] RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn #170

@Mishajain1110

Description

@Mishajain1110

Hello, I am fine-tuning the LLaMA-2 7B model on an A100 40 GB GPU. Initially, I was getting a CUDA out-of-memory error. I tried various methods, such as reducing batch size, but none worked. Then I enabled:

model.gradient_checkpointing_enable()

After doing this, the OOM issue was resolved, but now I get the following error during backpropagation:

torch.autograd.backward(
File ".../torch/autograd/__init__.py", line 354, in backward
_engine_run_backward(
File ".../torch/autograd/graph.py", line 829, in _engine_run_backward
return Variable._execution_engine.run_backward(
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I also tried:

model.enable_input_require_grads()

but the error still persists. I suspect the issue is related to enabling gradient checkpointing.

In model_init():
reft_model.gradient_checkpointing_enable()
reft_model.enable_input_require_grads()

Is there something I am missing when using gradient checkpointing in this setup?

Here is the finetune and model initialization function of the train.py file -

import argparse
import datetime
import json
from contextlib import nullcontext
from copy import deepcopy

import evaluate
import numpy as np
import torch
from compute_metrics import compute_metrics
from dataset import LoReftGLUEDataset, LoReftSupervisedDataset
from peft import PeftModel
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from task_config import task_config
from torch import profiler
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorForSeq2Seq,
DataCollatorWithPadding,
TrainingArguments,
set_seed,
)
from transformers.trainer_utils import EvalPrediction

import wandb
from pyreft import (
LoreftIntervention,
MoReIntervention,
NoIntervention,
NoreftIntervention,
ReftConfig,
ReftDataCollator,
ReftModel,
ReftTrainerForCausalLM,
ReftTrainerForSequenceClassification,
TaskType,
get_reft_model,
)
# --- Module-level configuration -------------------------------------------
# Repository root, derived from this file's location.
# NOTE(review): the original read `os.path.abspath(file)` — `file` is a
# NameError; `__file__` is the intended name.
repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Path to the PEFT hyperparameter config.
# NOTE(review): the original line was `config_path = config_path`, which is a
# NameError. Reconstructed from the previously commented-out path
# ("task_configs/llama/peft_config.json", with the "/fly" prefix removed) —
# TODO confirm the intended location.
config_path = os.path.join(repo_root, "task_configs", "llama", "peft_config.json")

# Globals populated by finetune() / model_init(); kept at module level so the
# Ray Tune `model_init` callback can see the parsed CLI args and tokenizer.
args = None
tokenizer = None
reft_model = None
best_hyperparams = None
config = None
device = "cuda" if torch.cuda.is_available() else "cpu"
# Tasks trained with a sequence-classification head instead of a causal LM.
classification_tasks = {"glue"}
# Encoder-style models are intervened on via a named submodule path template
# (the "%s" slot receives the layer index) rather than pyvene's
# "block_output" component.
residual_stream_component_mapping = {"robertaformaskedlm": "roberta.encoder.layer[%s].output"}
# Maps the --dtype CLI string to a torch dtype; "float8" is kept as a string
# sentinel and handled via load_in_8bit at model-load time.
dtype_mapping = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float8": "float8",
}
with open(config_path, "r") as f:
    peft_config = json.load(f)

def model_init(hyperparams: dict = best_hyperparams):
    """Build (or rebuild) the model with ReFT / Monarch / LoRA / BOFT adapters.

    Also used as the HF Trainer ``model_init`` callback during Ray Tune HPO,
    so the expensive base-model load is cached in the module-level
    ``reft_model`` global and reused across calls.

    Args:
        hyperparams: hyperparameter overrides merged into ``peft_config``.
            NOTE: the default binds ``best_hyperparams`` at import time
            (i.e. ``None``); pass the searched values explicitly after HPO.

    Returns:
        The adapted model, ready for training.
    """
    global peft_config, args, tokenizer, reft_model, config
    if hyperparams is None:  # was `== None`; identity check is the idiom
        hyperparams = {}

    # everything is guarded by a single seed
    set_seed(args.seed)
    dtype = dtype_mapping[args.dtype]
    model_name = args.model
    # Hyperparameter search: explicit CLI values win over searched ones.
    if args.blk_r != -1:
        hyperparams["blk_r"] = args.blk_r
    if args.nblocks != -1:
        hyperparams["nblocks"] = args.nblocks
    # (the original re-checked `hyperparams is not None` here; it can no
    # longer be None at this point, so the guard was dropped)
    for k in peft_config.keys():
        if k in hyperparams and hyperparams[k] != peft_config[k]:
            print("Overriding the {} in best HP to {}".format(k, hyperparams[k]))
            peft_config[k] = hyperparams[k]

    if wandb.run is not None:
        wandb.run.config.update(peft_config)
        wandb.run.config.update({"dtype": dtype})

    if reft_model is None:
        # NOTE(review): the original referenced an undefined `model_path`
        # (NameError at runtime); `model_name` — assigned above and otherwise
        # unused — is clearly the intended variable.
        if args.task in classification_tasks:
            config = AutoConfig.from_pretrained(
                model_name,
                num_labels=args.num_labels,
                finetuning_task=args.train_dataset,
                load_in_8bit=True if args.dtype == "float8" else False,
                device_map=device,
            )
            # full precision loading since usually for small models
            reft_model = AutoModelForSequenceClassification.from_pretrained(
                model_name,
                config=config,  # just providing the label
                torch_dtype=dtype if args.dtype != "float8" else None,
                load_in_8bit=True if args.dtype == "float8" else False,
                device_map=device,
                max_memory={0: 0.8},  # NOTE(review): HF expects e.g. "30GiB" or bytes — confirm
                # attn_implementation="flash_attention_2"
            )
        else:
            reft_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=dtype if args.dtype != "float8" else None,  # save memory
                load_in_8bit=True if args.dtype == "float8" else False,
                device_map=device,
                max_memory={0: 0.8},  # device: memory
                attn_implementation="flash_attention_2",
            )
            config = reft_model.config

    if not isinstance(reft_model, (ReftModel, PeftModel)):
        # Optionally apply ReFT
        if args.intervention_type == "LoreftIntervention":
            intervention_type = LoreftIntervention
        elif args.intervention_type == "NoreftIntervention":
            intervention_type = NoreftIntervention
        elif args.intervention_type == "MoReIntervention":
            intervention_type = partial(MoReIntervention, dtype=dtype, blk_r=args.blk_r, nblocks=args.nblocks)
        else:
            intervention_type = NoIntervention

        # intervention config based on model type; "float8" (a str sentinel)
        # falls back to bfloat16 for the intervention parameters
        intervention_dtype = torch.bfloat16 if isinstance(dtype, str) else dtype
        model_arch = config.architectures[0].lower()
        if model_arch in residual_stream_component_mapping:
            # Encoder-style model: address the residual stream by module path.
            representations = [
                {
                    "component": residual_stream_component_mapping[model_arch] % l,
                    "intervention": intervention_type(
                        embed_dim=config.hidden_size,
                        low_rank_dimension=args.rank,
                        dropout=args.dropout,
                        dtype=intervention_dtype,
                        act_fn=args.act_fn,
                        device=device,
                        add_bias=args.add_bias,
                    ),
                }
                for l in args.layers
            ]
            # (the original evaluated `TaskType.SEQ_CLS` here as a bare
            # no-op expression; removed)
        else:
            # Decoder-style model: intervene on pyvene's block output.
            representations = [
                {
                    "layer": l,
                    "component": "block_output",
                    "low_rank_dimension": args.rank,
                    "intervention": intervention_type(
                        embed_dim=config.hidden_size,
                        low_rank_dimension=args.rank,
                        dropout=args.dropout,
                        dtype=intervention_dtype,
                        act_fn=args.act_fn,
                        device=device,
                        add_bias=args.add_bias,
                    ),
                }
                for l in args.layers
            ]
            # (bare `TaskType.CAUSAL_LM` no-op removed)
        # The default *reentrant* checkpoint implementation raises
        # "element 0 of tensors does not require grad and does not have a
        # grad_fn" when every checkpointed input comes from frozen weights
        # (the usual ReFT setup). The non-reentrant variant does not have
        # this requirement, so request it explicitly.
        reft_model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
        # Make embedding outputs require grad so checkpointed segments have
        # a grad_fn to backprop through even though base weights are frozen.
        reft_model.enable_input_require_grads()
        reft_config = ReftConfig(representations=representations)
        reft_model = get_reft_model(reft_model, reft_config, set_device=not isinstance(dtype, str))
        # for GLUE tasks, we enable gradients on the classifier head.
        # the parameter will be counted as well.
        if args.task == "glue" and args.allow_cls_grad:
            for param in reft_model.model.classifier.parameters():
                # reft_model with HF trainer will automatically pick up these params to optimize
                param.requires_grad = True

        # Monarch adaptation
        if args.mode == "monarch":
            print("###### INIT MONARCH ######")
            peft_config["dtype"] = dtype
            init_monarch(reft_model, peft_config)
        elif args.mode == "lora":
            # keep only target_modules, then pin a fixed LoRA rank/alpha
            peft_config = {"target_modules": peft_config["target_modules"]}
            peft_config["r"] = 32
            peft_config["lora_alpha"] = 64
            init_lora(reft_model, peft_config)
        elif args.mode == "boft":
            print("###### INIT BOFT ######")
            peft_config["dtype"] = dtype
            reft_model = init_boft(reft_model, peft_config)
        else:
            raise NotImplementedError()

    param_stats(reft_model, training=False)
    reft_model.print_trainable_parameters()
    return reft_model

def finetune(
    act_fn: str,
    add_bias: bool,
    model: str,
    layers: str,
    rank: int,
    position: str,
    epochs: int,
    seed: int,
    intervention_type: str,
    max_n_train_example: int,
    max_n_eval_example: int,
    wandb_name: str,
    gradient_accumulation_steps: int,
    batch_size: int,
    output_dir: str,
    task: str,
    lr: float,
    schedule: str,
    data_dir: str,
    train_dataset: str,
    eval_dataset: str,
    save_model: bool,
    eval_batch_size: int,
    warmup_ratio: float,
    weight_decay: float,
    dropout: float,
    test_split: str,
    train_on_inputs: bool,
    max_length: int,
    use_normalized_template: bool,
    allow_cls_grad: bool,
    metric_for_best_model: str,
    dtype: str,
    logging_steps: int,
    wandb_dir: str,
    wandb_proj: str,
    share_weights: bool,
    greedy_decoding: bool,
    temperature: float,
    top_p: float,
    top_k: float,
    **kwargs,
):
    """Generic Representation Finetuning.

    End-to-end driver: builds datasets and tokenizer, instantiates the model
    via ``model_init()``, optionally runs Ray Tune hyperparameter search,
    trains, and evaluates. Reads CLI options from the module-level ``args``
    and mutates the ``tokenizer``, ``reft_model`` and ``peft_config``
    globals.
    """
    global tokenizer, reft_model, peft_config
    use_wandb = kwargs.pop("wandb", True)

    assert task in task_config
    if data_dir is not None:
        assert os.path.exists(data_dir), f"Data directory {data_dir} does not exist."

    # store/log run details
    print(
        f"task: {task}, model: {model}, intervention_type: {intervention_type}, "
        f"layers: {layers}, rank: {rank}, "
        f"position: {position}, epoch: {epochs}, train_on_inputs: {train_on_inputs}, "
        f"max_length: {max_length}, allow_cls_grad: {allow_cls_grad}"
    )

    now = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
    model_str = args.model.split("/")[-1]
    if args.train_dataset is not None:
        run_name = f"{model_str}.{args.task}.{args.train_dataset}.{args.test_split}.{now}"
    else:
        run_name = f"{model_str}.{args.task}.{now}"
    # load dataset splits
    assert task in task_config, f"Unrecognized task: {task}"
    train_datasets = task_config[task]["train_datasets"] if train_dataset is None else [train_dataset]
    if task == "glue":
        eval_datasets = [train_dataset]
    else:
        eval_datasets = task_config[task]["eval_datasets"] if args.eval_dataset is None else [args.eval_dataset]

    # initialize tokenizer
    # NOTE(review): tokenizer path is hard-coded to a local checkout;
    # consider loading from args.model instead.
    tokenizer = AutoTokenizer.from_pretrained(
        '/nlsasfs/home/obfuscated/shilpyk/MISHA/local_models/llama-7b-hf',
        model_max_length=args.max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    # which layers to intervene on
    if isinstance(args.layers, str):
        if args.layers != "all":
            args.layers = [int(l) for l in args.layers.split(";")]
        else:
            temp_config = AutoConfig.from_pretrained('/nlsasfs/home/obfuscated/shilpyk/MISHA/local_models/llama-7b-hf')
            args.layers = [l for l in range(temp_config.num_hidden_layers)]

        # position str takes the following formats:
        # f1 -> first token; f2 -> first two tokens.
        # f1+l1 -> first and last tokens; f2+l2 -> first and last two tokens.
        # fn or ln shares the same intervention.
        if "+" in args.position and not args.share_weights:
            # unshared first/last positions need one intervention per position
            args.layers += args.layers

    ReftDataset = LoReftGLUEDataset if task == "glue" else LoReftSupervisedDataset
    path = os.path.join(data_dir, train_datasets[0]) if data_dir is not None else train_datasets[0]
    if not args.do_train and not args.do_tune:
        # eval-only run: a single training example is enough to build the dataset
        max_n_train_example = 1
    train_dataset = ReftDataset(
        task,
        train_datasets[0] if task == "glue" else path,
        tokenizer,
        data_split="train",
        seed=seed,
        max_n_example=max_n_train_example,
        **{"num_interventions": len(args.layers), "position": position, "share_weights": share_weights},
    )
    trigger_tokens = train_dataset.trigger_tokens
    args.num_labels = train_dataset.num_labels

    # build {dataset_name: {split: [ReftDataset, raw HF dataset]}}
    all_eval_datasets = {}
    for eval_dataset in eval_datasets:
        test_splits = test_split.split(";")
        all_eval_datasets[eval_dataset] = {}
        for split in test_splits:
            if args.do_tune:
                split = "train"  # TODO: Ensure eval_loop doesn't throw a bug.. need to change later

            path = os.path.join(data_dir, eval_dataset) if data_dir is not None else eval_dataset
            raw_eval = ReftDataset(
                task,
                eval_dataset if task == "glue" else path,
                tokenizer,
                data_split=split,
                seed=seed,
                max_n_example=max_n_eval_example,
                **{"num_interventions": len(args.layers), "position": position, "share_weights": share_weights},
                is_eval=True,
            )
            all_eval_datasets[eval_dataset][split] = [raw_eval, raw_eval.raw_dataset]
    eval_datasets = all_eval_datasets
    # Initialize model
    if args.all_linear:
        peft_config["target_modules"] += ["o_proj", "up_proj", "down_proj", "gate_proj"]

    reft_model = model_init()
    n_params = reft_model.count_parameters(include_model=False)

    if task == "glue":
        # we repartition the eval_datatsets into [1] 50% validation + [2] 50% test
        # we select the best model on [1] during training
        # we test the selected model on [2] to ensure fairness
        to_split_eval_datasets = eval_datasets[args.train_dataset][test_split][0]
        if len(to_split_eval_datasets) > 5000:
            in_train_n_eval_sample = 1000
        else:
            in_train_n_eval_sample = len(to_split_eval_datasets) // 2

        new_splits = torch.utils.data.random_split(
            to_split_eval_datasets, [len(to_split_eval_datasets) - in_train_n_eval_sample, in_train_n_eval_sample]
        )

        in_test_eval_datasets, in_train_eval_datasets = new_splits[0], new_splits[1]
        eval_datasets[args.train_dataset][test_split][0] = in_test_eval_datasets
        print("GLUE validation split (in training): ", len(in_train_eval_datasets))
        print("GLUE validation split (testing): ", len(eval_datasets[args.train_dataset][test_split][0]))

        is_regression = args.train_dataset == "stsb"
        metric = evaluate.load("glue", args.train_dataset)

        # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
        # predictions and label_ids field) and has to return a dictionary string to float.
        def in_training_compute_metrics(p: EvalPrediction):
            preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
            preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
            result = metric.compute(predictions=preds, references=p.label_ids)
            if len(result) > 1:
                result["combined_score"] = np.mean(list(result.values())).item()
            return result

    # select collator based on the type
    if task in classification_tasks:
        data_collator_fn = DataCollatorWithPadding(tokenizer=tokenizer, padding="longest")
    else:
        # NOTE(review): `model` here is the model *name* (a str), so the
        # collator's decoder-input preparation silently no-ops — confirm.
        data_collator_fn = DataCollatorForSeq2Seq(
            tokenizer=tokenizer, model=model, label_pad_token_id=-100, padding="longest"
        )

    data_collator = ReftDataCollator(data_collator=data_collator_fn)

    # start wandb logging
    # Now datasets are set up, use the actual task name (e.g. tune_math -> math)
    if "tune" in task:
        task = task.split("_")[-1]
    task_dir = os.path.join(output_dir, task)
    output_dir = os.path.join(task_dir, args.group) if args.group else task_dir

    if args.notes:
        run_name = args.notes + "_" + run_name

    # Must set env variables to carry to them ray tune workers!
    if args.wandb == False:
        os.environ["WANDB_MODE"] = "offline"
    if args.resume:
        group_path = os.path.join(output_dir, "full_group.txt")
        os.environ["WANDB_RUN_GROUP"] = group = get_run_group(
            task, group=args.group, notes=args.notes, do_tune=args.do_tune
        )
        # a previously saved group name takes precedence when resuming
        if os.path.exists(group_path):
            os.environ["WANDB_RUN_GROUP"] = group = open(group_path, "r").read().strip()
    else:
        os.environ["WANDB_RUN_GROUP"] = group = get_run_group(
            task, group=args.group, notes=args.notes, do_tune=args.do_tune
        )

    os.environ["WANDB_PROJECT"] = f"reft-monarch-{task}"
    run = wandb.init(
        project=os.environ["WANDB_PROJECT"],
        name=run_name,
        dir=wandb_dir,
        group=group,
    )
    if not args.do_tune:
        watch_layers(reft_model.model)
    run.summary.update(vars(args))
    wandb.log({"train/n_params": n_params})

    # training args
    training_args = TrainingArguments(
        output_dir=output_dir,
        run_name=run_name,
        num_train_epochs=epochs,
        max_steps=args.max_steps,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=eval_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        eval_strategy="steps",
        save_strategy="steps",
        save_steps=args.save_steps,
        metric_for_best_model=metric_for_best_model if task == "glue" else None,
        # load_best_model_at_end=True if task == "glue" else False,
        logging_strategy="steps",
        save_total_limit=6,
        logging_steps=logging_steps,
        lr_scheduler_type=schedule,
        learning_rate=lr,
        warmup_ratio=warmup_ratio,
        optim="adamw_torch",
        weight_decay=weight_decay,
        report_to="wandb" if use_wandb else "none",
        use_cpu=False if device == "cuda" else True,
        seed=seed,
        # until HF supports ReFT, this remains False! :)
        remove_unused_columns=False,
        do_eval=True,
        bf16=True if dtype == torch.bfloat16 else False,
    )

    # make trainer
    trainer_class = ReftTrainerForSequenceClassification if task in classification_tasks else ReftTrainerForCausalLM
    if task == "glue":
        eval_dataset = in_train_eval_datasets
    # Ray Tune requires eval. Otherwise the authors didn't use eval
    elif args.do_tune:
        assert len(eval_datasets) == 1, "Use only one eval set for HPO!"
        raw_eval = list(list(eval_datasets.values())[0].values())[0][0]  # TODO: beautify
        eval_dataset = raw_eval
    elif task == "commonsense":
        eval_dataset = ReftDataset(
            task,
            train_datasets[0] if task == "glue" else path,
            tokenizer,
            data_split="train",
            seed=seed,
            max_n_example=max_n_eval_example,
            **{"num_interventions": len(args.layers), "position": position, "share_weights": share_weights},
            is_eval=True,
        )
    else:
        eval_dataset = None
        training_args.evaluation_strategy = "no"

    trainer = trainer_class(
        model=reft_model,
        model_init=model_init,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        compute_metrics=in_training_compute_metrics if task == "glue" else None,
    )
    # Eval more than once per epoch
    evals_per_epoch = args.evals_per_epoch
    trainer.args.eval_steps = len(train_dataset) // (
        training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * evals_per_epoch
    )

    if args.do_tune:
        # Save full tune group name for resuming
        with open(os.path.join(training_args.output_dir, "full_group.txt"), "w") as f:
            f.write(group)

        _args = deepcopy(trainer.args)
        trainer.args.save_total_limit = 0  # Avoid flooding the disk during HPO
        trainer.args.load_best_model_at_end = False
        trainer.args.save_strategy = "no"
        trainer.args.evaluation_strategy = "steps"  # Required for Ray Tune

        ######################### Need to change for each task #########################
        # PEFT monarch search space
        if task == "math":
            real_bs = [16, 32, 64]
        elif task == "commonsense":
            real_bs = [16, 32, 64]
        else:
            raise NotImplementedError(f"Don't forget to manually pick bs for task {task} !")
        grad_acc_steps = [i // args.batch_size for i in real_bs]

        param_space = {
            # "nblocks": tune.choice(['sqrt(n)', 4]),
            "seed": training_args.seed,
            # "num_train_epochs": tune.choice([20, 25]),
            "learning_rate": tune.quniform(1e-4, 9e-4, 1e-4),
            "gradient_accumulation_steps": tune.choice(grad_acc_steps),  # Will OOM if tune batch size
            "weight_decay": tune.choice([0]),
            "lr_scheduler_type": tune.choice(["cosine", "linear"]),  # mostly linear underperforms
            "dropout": tune.choice([0.05, 0.1]),
            "blk_r": peft_config["blk_r"],
            "nblocks": peft_config["nblocks"],
        }
        n_trials = args.n_trials

        # Set up scheduler and reporter etc.
        metric = "eval_loss"
        direction = "min" if "loss" in metric else "max"  # minimize eval loss
        tune_unit = "iter"
        # was `"tune_unit" == "time"`, a comparison between two string
        # literals that is always False — compare the variable instead
        max_t = 40 * 60 if tune_unit == "time" else args.epochs * evals_per_epoch
        grace_period = 4 * 60 if tune_unit == "time" else 2
        time_attr = "time_total_s" if tune_unit == "time" else "training_iteration"
        ############################## End of task specific ##############################

        scheduler = ASHAScheduler(
            time_attr=time_attr,
            max_t=max_t,
            metric=metric,
            mode=direction,
            grace_period=grace_period,
        )
        # Do hyperparam optimization with Ray Tune
        best_run = trainer.hyperparameter_search(
            hp_space=lambda _: param_space,
            backend="ray",
            n_trials=n_trials,  # under the hood it calls ray.tune.run(num_samples=n_trials, ...)
            scheduler=scheduler,
            keep_checkpoints_num=None,
            resources_per_trial={"cpu": 1, "gpu": 1},
            name=group,
            local_dir="/fly/ray_results",
            max_failures=9999,  # tolerate OOM
            direction="maximize" if direction == "max" else "minimize",
            compute_objective=partial(get_hpo_metric, metric),
            resume=args.resume,
        )
        del trainer.model
        model = None
        free_memory()  # Re-init model
        trainer.args = _args
        best_hyperparams = best_run.hyperparameters
        # Save the best HP for full training
        print("Best hyperparameters: ", best_hyperparams)

        # Save hyperparams
        run_hp_path = os.path.join(training_args.output_dir, "best_hyperparams.json")
        task_hp_path = os.path.join(task_dir, "best_hyperparams.json")
        json.dump(best_hyperparams, open(run_hp_path, "w"))
        json.dump(best_hyperparams, open(task_hp_path, "w"))

    last_ckpt, _ = get_last_checkpoint(training_args.output_dir)

    if args.do_train:
        load_best_hp(training_args.output_dir, task_dir)
        # TODO: enable resume
        if args.profile:
            ctx = profiler.profile(
                schedule=profiler.schedule(wait=1, warmup=1, active=2, repeat=1),
                on_trace_ready=profiler.tensorboard_trace_handler(f"./llama_reasoning_{args.mode}_log"),
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
            )
            trainer.add_callback(ProfCallback(ctx))
        else:
            ctx = nullcontext()

        with ctx:
            if args.resume:
                trainer.train(resume_from_checkpoint=last_ckpt)
            else:
                trainer.train()

        # dump config
        args_dict = vars(args)
        args_dict["n_params"] = n_params
        json_file_name = f"{output_dir}/args.json"
        with open(json_file_name, "w") as json_file:
            json.dump(args_dict, json_file, indent=4)

        # save model
        reft_model.save(output_dir)
        trainer.save_state()
        # NOTE: force load best
        trainer._load_best_model()
    else:
        print("Skipping training, loading last checkpoint...")
        trainer.model.load_intervention(last_ckpt, include_model=True)

    # ensure everything is in eval mode
    reft_model.model.eval()
    for k, v in reft_model.interventions.items():
        _ = v[0].eval()

    print({"n_params": n_params})
    # do eval
    eval_results = {}
    for dataset_name in eval_datasets:
        # split evalset into chunks
        for split, (eval_dataset, data_items) in eval_datasets[dataset_name].items():

            generations, stats = compute_metrics(
                task,
                dataset_name,
                reft_model,
                tokenizer,
                eval_dataset,
                data_items,
                trigger_tokens,
                run_name,
                eval_batch_size,
                data_collator if task in classification_tasks else None,
                split,
                greedy_decoding,
                temperature,
                top_p,
                top_k,
            )

            # log
            eval_results.update(stats)
            if use_wandb:
                wandb.log(stats)
            generations = stats if generations is None else generations
            result_json_file_name = f"{output_dir}/{dataset_name}_{split}_outputs.json"
            with open(result_json_file_name, "w") as json_file:
                json.dump(generations, json_file, indent=4)

    # log final eval stats
    result_json_file_name = f"{output_dir}/eval_results.json"
    eval_results["n_params"] = n_params
    with open(result_json_file_name, "w") as json_file:
        json.dump(eval_results, json_file, indent=4)

Metadata

Metadata

Assignees

No one assigned

    Labels

    documentation — Improvements or additions to documentation

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions