diff --git a/caikit_nlp/modules/text_generation/peft_config.py b/caikit_nlp/modules/text_generation/peft_config.py
index 57c9fc58..781aefef 100644
--- a/caikit_nlp/modules/text_generation/peft_config.py
+++ b/caikit_nlp/modules/text_generation/peft_config.py
@@ -17,7 +17,7 @@
 import os
 
 # Third Party
-from peft import MultitaskPromptTuningInit
+from peft import LoraConfig, MultitaskPromptTuningInit
 from transformers import AutoConfig
 
 # First Party
@@ -51,7 +51,7 @@ class TuningType(str, Enum):
     # MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING"
     # P_TUNING = "P_TUNING"
     # PREFIX_TUNING = "PREFIX_TUNING"
-    # LORA = "LORA"
+    LORA = "LORA"
 
 
 def resolve_base_model(base_model, cls, torch_dtype):
@@ -79,7 +79,19 @@
 
 
 def get_peft_config(
-    tuning_type, tuning_config, base_model, cls, torch_dtype, verbalizer
+    tuning_type,
+    tuning_config,
+    base_model,
+    cls,
+    torch_dtype,
+    verbalizer,
+    lora_r=8,
+    lora_target_modules=None,
+    lora_alpha=8,
+    lora_dropout=0.0,
+    lora_fan_in_fan_out=False,
+    lora_bias="none",
+    lora_modules_to_save=None,
 ):
 
     if tuning_type not in TuningType._member_names_:
@@ -186,18 +198,29 @@ def get_peft_config(
         # Take tokenizer name/path from the model
         tokenizer_name_or_path = base_model.model.config._name_or_path
 
-        # Build the peft config; this is how we determine that we want a sequence classifier.
-        # If we want more types, we will likely need to map this to data model outputs etc.
-
-        # NOTE: We currently only support TEXT as init type, this is to later only easily
-        # switch to MPT
-        peft_config = cls.create_hf_tuning_config(
-            base_model=base_model,
-            tuning_type=tuning_type,
-            task_type=task_type,
-            tokenizer_name_or_path=tokenizer_name_or_path,
-            tuning_config=tuning_config,
-            output_model_types=output_model_types,
-        )
+        if tuning_type == TuningType.LORA:
+            peft_config = LoraConfig(
+                r=lora_r,
+                target_modules=lora_target_modules,
+                lora_alpha=lora_alpha,
+                lora_dropout=lora_dropout,
+                fan_in_fan_out=lora_fan_in_fan_out,
+                bias=lora_bias,
+                modules_to_save=lora_modules_to_save,
+            )
+        else:
+            # Build the peft config; this is how we determine that we want a sequence classifier.
+            # If we want more types, we will likely need to map this to data model outputs etc.
+
+            # NOTE: We currently only support TEXT as init type, this is to later only easily
+            # switch to MPT
+            peft_config = cls.create_hf_tuning_config(
+                base_model=base_model,
+                tuning_type=tuning_type,
+                task_type=task_type,
+                tokenizer_name_or_path=tokenizer_name_or_path,
+                tuning_config=tuning_config,
+                output_model_types=output_model_types,
+            )
 
     return task_type, output_model_types, peft_config, tuning_type
diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py
index 63807594..a77f77d4 100644
--- a/caikit_nlp/modules/text_generation/text_generation_local.py
+++ b/caikit_nlp/modules/text_generation/text_generation_local.py
@@ -23,6 +23,7 @@
 # Third Party
 from datasets import Dataset
 from datasets import IterableDataset as TransformersIterableDataset
+from peft import LoraConfig, get_peft_model
 from transformers import AutoConfig, AutoTokenizer
 import torch
 
@@ -182,6 +183,7 @@
         random_seed: int = RANDOM_SEED,
         lr: float = 2e-5,
         use_iterable_dataset: bool = True,
+        lora_config: LoraConfig = None,
         **kwargs,
     ) -> "TextGeneration":
         """
@@ -214,6 +216,9 @@
                 Indicates whether or not we should load the full dataset into memory
                 NOTE: use True for this option if you are fine tuning a causal LM with a large
                 target sequence length unless your dataset is VERY small!
+            lora_config: LoraConfig
+                If provided, the LoRA technique for fine tuning will be applied (as
+                opposed to a 'full' fine tuning).
             **kwargs:
                 Arguments supported by HF Training Arguments.
                 TrainingArguments:
@@ -430,6 +435,22 @@
             torch_dtype=torch_dtype,
         )
 
+        if lora_config is not None:
+
+            # Wrap the model with the LoRA adapters and merge them into the base weights
+            model_to_merge = get_peft_model(model, lora_config)
+            merged_model = model_to_merge.merge_and_unload()
+
+            # Return the merged model as a standard Transformers model
+            return cls(
+                model_name=base_model._model_name,
+                model=merged_model,
+                bos_token=model.tokenizer.bos_token or None,
+                sep_token=model.tokenizer.sep_token or None,
+                eos_token=model.tokenizer.eos_token or None,
+                pad_token=model.tokenizer.pad_token or None,
+                training_metadata={"loss": training_loss_history},
+            )
         return cls(
             model_name=base_model._model_name,
             model=model,
diff --git a/caikit_nlp/resources/pretrained_model/base.py b/caikit_nlp/resources/pretrained_model/base.py
index eba74744..4672011f 100644
--- a/caikit_nlp/resources/pretrained_model/base.py
+++ b/caikit_nlp/resources/pretrained_model/base.py
@@ -71,6 +71,17 @@ class PretrainedModelBase(ABC, ModuleBase):
     _MODEL_ARTIFACTS_CONFIG_KEY = "model_artifacts"
     _LEFT_PAD_MODEL_TYPES = ("gpt", "opt", "bloom")
 
+    def named_modules(self):
+        return self._model.named_modules()
+
+    def get_submodule(self, target: str):
+        return self._model.get_submodule(target)
+
+    def named_parameters(
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ):
+        return self._model.named_parameters(prefix, recurse, remove_duplicate)
+
     @classmethod
     @property
     def REQUIRES_TOKEN_UNWRAPPING(cls) -> str:
diff --git a/examples/run_lora_tuning.py b/examples/run_lora_tuning.py
new file mode 100644
index 00000000..0beebaac
--- /dev/null
+++ b/examples/run_lora_tuning.py
@@ -0,0 +1,430 @@
+"""This script illustrates how to fine-tune a model with LoRA.
+
+Supported model types:
+- Seq2Seq LM
+"""
+
+# Standard
+from typing import Any, Tuple
+import argparse
+import json
+import os
+import shutil
+
+# Third Party
+from peft import LoraConfig
+from tqdm import tqdm
+from transformers import AutoConfig
+from utils import (
+    ALOG_OPTS,
+    SUPPORTED_DATASETS,
+    SUPPORTED_METRICS,
+    DatasetInfo,
+    configure_random_seed_and_logging,
+    load_model,
+    print_colored,
+)
+
+# First Party
+from caikit.core.data_model import DataStream
+from caikit.core.toolkit import wip_decorator
+import alog
+
+# Local
+from caikit_nlp.data_model import GenerationTrainRecord, TuningConfig
+from caikit_nlp.modules.text_generation import TextGeneration
+from caikit_nlp.resources.pretrained_model import (
+    HFAutoCausalLM,
+    HFAutoSeq2SeqLM,
+    PretrainedModelBase,
+)
+
+# TODO: Remove me once fine-tuning is out of WIP
+wip_decorator.disable_wip()
+
+
+def get_resource_type(model_name: str) -> PretrainedModelBase:
+    """Given a model name, or a path to a model, get the resource type to be initialized.
+
+    Args:
+        model_name: str
+            Model name or path to the model to be leveraged.
+
+    Returns:
+        type
+            PretrainedModel subclass wrapping the loaded Transformer model, e.g.,
+            a HFAutoCausalLM or HFAutoSeq2SeqLM. We return the type here so that
+            we can initialize it later, after we show a nice experiment configuration.
+    """
+    try:
+        model_type = AutoConfig.from_pretrained(model_name).model_type
+    except OSError:
+        raise ValueError("Failed to load model with name: {}".format(model_name))
+    if model_type in HFAutoCausalLM.SUPPORTED_MODEL_TYPES:
+        return HFAutoCausalLM
+    elif model_type in HFAutoSeq2SeqLM.SUPPORTED_MODEL_TYPES:
+        return HFAutoSeq2SeqLM
+    raise NotImplementedError(
+        "Provided model is not supported by any supported resource type!"
+    )
+
+
+### Functions for arg parsing & validation
+def parse_args() -> argparse.Namespace:
+    """Parse command line arguments.
+
+    Returns:
+        argparse.Namespace
+            Parsed arguments to be leveraged for the fine tuning application.
+    """
+    parser = argparse.ArgumentParser(
+        description="Fine-tuning a text generation model.",
+    )
+    # Register all of the common args, as well as specific tuning args for subcommands
+    register_common_arguments(parser)
+
+    args = parser.parse_args()
+    # Reconfigure logging to debug level, while preserving filters etc.
+
+    alog_settings = {**ALOG_OPTS, **{"default_level": "debug"}}
+    alog.configure(**alog_settings)
+    # Validate common arg values
+    validate_common_args(args)
+    return args
+
+
+def register_common_arguments(subparser: argparse.ArgumentParser) -> None:
+    """Registers common arguments intended to be shared across all subparsers.
+
+    Args:
+        subparser: argparse.ArgumentParser
+            Argument parser to which the common arguments should be added.
+    """
+    subparser.add_argument(
+        "--dataset",
+        help="Dataset to use for fine tuning. Options: {}".format(
Options: {}".format( + list(SUPPORTED_DATASETS.keys()) + ), + default="twitter_complaints", + ) + subparser.add_argument( + "--model_name", + help="Name of base model or path to model to use to train prompt vectors", + default="t5-small", + ) + subparser.add_argument( + "--output_dir", + help="Name of the directory that we want to export the model to", + default="sample_tuned_model", + ) + subparser.add_argument( + "--num_epochs", + help="Number of epochs to use for prompt tuning", + type=int, + default=10, + ) + subparser.add_argument( + "--learning_rate", + help="Learning rate to use while training", + type=float, + default=1e-4, + ), + subparser.add_argument( + "--batch_size", help="Batch size to use while training", type=int, default=8 + ) + subparser.add_argument( + "--max_source_length", + help="Maximum source sequence length.", + default=256, + type=int, + ) + subparser.add_argument( + "--max_target_length", + help="Maximum target sequence length.", + default=128, + type=int, + ) + subparser.add_argument( + "--lora_alpha", + help="The alpha parameter for Lora scaling", + default=8, + type=int, + ) + subparser.add_argument( + "--lora_r", + help="Lora attention dimension", + default=8, + type=int, + ) + subparser.add_argument( + "--lora_bias", + help="""Bias type for Lora. Can be ‘none’, ‘all’ or ‘lora_only’. If ‘all’ or ‘lora_only’, + the corresponding biases will be updated during training. Be aware that this + means that, even when disabling the adapters, the model will not produce the + same output as the base model would have without adaptation""", + default="none", + choices=["none", "all", "lora_only"], + ) + + subparser.add_argument( + "--lora_dropout", + help="The dropout probability for Lora layers", + type=float, + default=0.0, + ) + + subparser.add_argument( + "--lora_target_modules", + help="The names of the modules to apply Lora to", + nargs="+", + type=str, + default=None, + ) + subparser.add_argument( + "--evaluate", + help="Enable evaluation on trained model", + action="store_true", + ) + subparser.add_argument( + "--preds_file", + help="JSON file to dump raw source / target texts to.", + default="model_preds.json", + ) + subparser.add_argument( + "--torch_dtype", + help="Torch dtype to use for training", + type=str, + default="float16", + ) + subparser.add_argument( + "--metrics", + help="Metrics to calculate. Options: {}".format(list(SUPPORTED_METRICS.keys())), + nargs="*", + default=["accuracy"], + ) + subparser.add_argument( + "--tgis", + help="Run inference using TGIS. NOTE: This involves saving and reloading model in TGIS container", + action="store_true", + ) + subparser.add_argument( + "--iterable_dataset", + help="Enable evaluation on trained model", + action="store_true", + ) + + +def validate_common_args(args: argparse.Namespace): + """Validates common arguments to ensure values make sense; here, we only validate things that + are not (or should not) be handled within the module. + + Args: + args: argparse.Namespace + Parsed args corresponding to one tuning task. + """ + # Validate that the dataset is one of our allowed values + if args.dataset not in SUPPORTED_DATASETS: + raise KeyError( + "[{}] is not a supported dataset; see --help for options.".format( + args.dataset + ) + ) + # Purge our output directory if one already exists. 
+    if os.path.isdir(args.output_dir):
+        print("Existing model directory found; purging it now.")
+        shutil.rmtree(args.output_dir)
+
+
+def show_experiment_configuration(args, dataset_info, model_type) -> None:
+    """Show the complete configuration for this experiment, i.e., the model info,
+    the resource type we built, the training params, metadata about the dataset where
+    possible, and so on.
+
+    Args:
+        args: argparse.Namespace
+            Parsed args corresponding to one tuning task.
+        dataset_info: DatasetInfo
+            Dataset information.
+        model_type: type
+            Resource class corresponding to the base model.
+    """
+    print_strs = [
+        "Experiment Configuration",
+        "- Model Name: [{}]".format(args.model_name),
+        " |- Inferred Model Resource Type: [{}]".format(model_type),
+        "- Dataset: [{}]".format(args.dataset),
+        "- Number of Epochs: [{}]".format(args.num_epochs),
+        "- Learning Rate: [{}]".format(args.learning_rate),
+        "- Batch Size: [{}]".format(args.batch_size),
+        "- Output Directory: [{}]".format(args.output_dir),
+        "- Maximum source sequence length: [{}]".format(args.max_source_length),
+        "- Maximum target sequence length: [{}]".format(args.max_target_length),
+        "- Enable evaluation: [{}]".format(args.evaluate),
+        "- Evaluation metrics: [{}]".format(args.metrics),
+        "- Torch dtype to use for training: [{}]".format(args.torch_dtype),
+        "- Using iterable dataset: [{}]".format(args.iterable_dataset),
+        "- LoRA r: [{}]".format(args.lora_r),
+        "- LoRA Alpha: [{}]".format(args.lora_alpha),
+        "- LoRA Bias: [{}]".format(args.lora_bias),
+        "- LoRA Dropout: [{}]".format(args.lora_dropout),
+        "- LoRA Target Modules: [{}]".format(args.lora_target_modules),
+    ]
+    # Print the full configuration so people actually get a chance to read it...
+    print_colored("\n".join([print_str for print_str in print_strs if print_str]))
+
+
+def get_model_preds_and_references(
+    model, validation_stream, truncate_input_tokens, max_new_tokens
+):
+    """Given a model & a validation stream, run the model against every example in the validation
+    stream and compare the outputs to the target/output sequence.
+
+    Args:
+        model
+            Fine-tuned model to be evaluated (may leverage different backends).
+        validation_stream: DataStream[GenerationTrainRecord]
+            Validation stream with labeled targets that we want to compare to our model's
+            predictions.
+        truncate_input_tokens: int
+            Maximum number of input tokens accepted by the model; anything longer
+            will be truncated.
+    Returns:
+        Tuple(List)
+            Tuple of 2 lists; the model predictions and the expected output sequences.
+    """
+    model_preds = []
+    targets = []
+
+    for datum in tqdm(validation_stream):
+        # Local .run() currently prepends the input text to the generated string;
+        # Ensure that we're just splitting the first predicted token & beyond.
+        raw_model_text = model.run(
+            datum.input,
+            truncate_input_tokens=truncate_input_tokens,
+            max_new_tokens=max_new_tokens,
+        ).generated_text
+        parse_pred_text = raw_model_text.split(datum.input)[-1].strip()
+        model_preds.append(parse_pred_text)
+        targets.append(datum.output)
+    return (
+        model_preds,
+        targets,
+    )
+
+
+def export_model_preds(preds_file, predictions, validation_stream):
+    """Exports a JSON file containing a list of objects, where every object contains:
+    - source: str - Source string used for generation.
+    - target: str - Ground truth target label used for generation.
+    - predicted_target: str - Predicted model target.
+
+    Args:
+        preds_file: str
+            Path on disk to JSON file to be written.
+        predictions: List
+            Model prediction list, where each predicted text excludes source text as a prefix.
+        validation_stream: DataStream
+            Datastream object of GenerationTrainRecord objects used for validation against a model
+            to generate predictions.
+    """
+    pred_objs = []
+    for pred, record in zip(predictions, validation_stream):
+        pred_objs.append(
+            {
+                "source": record.input,
+                "target": record.output,
+                "predicted_target": pred,
+            }
+        )
+    with open(preds_file, "w") as jfile:
+        json.dump(pred_objs, jfile, indent=4, sort_keys=True)
+
+
+if __name__ == "__main__":
+    configure_random_seed_and_logging()
+    args = parse_args()
+    model_type = get_resource_type(args.model_name)
+    # Look up the dataset info for the requested dataset
+    dataset_info = SUPPORTED_DATASETS[args.dataset]
+    show_experiment_configuration(args, dataset_info, model_type)
+    # Convert the loaded dataset to a stream
+    print_colored("[Loading the dataset...]")
+    # TODO - conditionally enable validation stream
+    train_stream = dataset_info.dataset_loader()[0]
+
+    # Init the base model resource & build the LoRA config from our arg info
+    print_colored("[Loading the base model resource...]")
+    base_model = model_type.bootstrap(
+        args.model_name, tokenizer_name=args.model_name, torch_dtype=args.torch_dtype
+    )
+
+    lora_config = LoraConfig(
+        lora_alpha=args.lora_alpha,
+        r=args.lora_r,
+        bias=args.lora_bias,
+        lora_dropout=args.lora_dropout,
+        target_modules=args.lora_target_modules,
+    )
+
+    # Then actually train the model & save it
+    print_colored("[Starting the training...]")
+
+    model = TextGeneration.train(
+        base_model,
+        train_stream,
+        max_source_length=args.max_source_length,
+        max_target_length=args.max_target_length,
+        lr=args.learning_rate,
+        torch_dtype=args.torch_dtype,
+        batch_size=args.batch_size,
+        num_epochs=args.num_epochs,
+        use_iterable_dataset=args.iterable_dataset,
+        lora_config=lora_config,
+    )
+
+    print_colored("[Training Complete]")
+
+    # Prediction
+    # sample_text = "summarize: The Inflation Reduction Act lowers prescription drug costs, health care costs, and energy costs. It's the most aggressive action on tackling the climate crisis in American history, which will lift up American workers and create good-paying, union jobs across the country. It'll lower the deficit and ask the ultra-wealthy and corporations to pay their fair share. And no one making under $400,000 per year will pay a penny more in taxes."
+    # prediction_results = model.run(sample_text)
+    #
+    # print("Generated text: ", prediction_results)
+
+    # Saving model
+    model.save(args.output_dir)
+
+    if args.tgis:
+
+        # Load model in TGIS
+        # HACK: export args.output_dir as MODEL_NAME for TGIS
+        # container to pick up automatically
+        os.environ["MODEL_DIR"] = os.path.dirname(args.output_dir)
+        os.environ["MODEL_NAME"] = os.path.join(
+            "models", os.path.basename(args.output_dir), "artifacts"
+        )
+
+        loaded_model = load_model(is_distributed=True, model_path=args.output_dir)
+
+    else:
+        # Use trained model directly
+        loaded_model = model
+
+    ## Evaluation
+    print_colored("[Starting Evaluation]")
+
+    validation_stream = dataset_info.dataset_loader()[1]
+
+    print_colored("Getting model predictions...")
+    truncate_input_tokens = args.max_source_length + args.max_target_length
+    predictions, references = get_model_preds_and_references(
+        loaded_model, validation_stream, truncate_input_tokens, args.max_target_length
+    )
+
+    export_model_preds(args.preds_file, predictions, validation_stream)
+
+    metric_funcs = [SUPPORTED_METRICS[metric_name] for metric_name in args.metrics]
+    print_colored("Metrics to be calculated: {}".format(args.metrics))
+
+    for metric_func in metric_funcs:
+        metric_res = metric_func(predictions=predictions, references=references)
+        print_colored(metric_res)
diff --git a/tests/modules/text_generation/test_text_generation_local.py b/tests/modules/text_generation/test_text_generation_local.py
index d83ed380..13496a03 100644
--- a/tests/modules/text_generation/test_text_generation_local.py
+++ b/tests/modules/text_generation/test_text_generation_local.py
@@ -6,6 +6,7 @@
 import tempfile
 
 # Third Party
+from peft import LoraConfig
 import pytest
 import torch
 
@@ -155,6 +156,58 @@ def test_train_model_causallm(disable_wip, set_cpu_device):
     assert isinstance(pred, GeneratedTextResult)
 
 
+@pytest.mark.skipif(platform.processor() == "arm", reason="ARM training not supported")
+def test_train_lora_model_causallm(disable_wip, set_cpu_device):
+    """Ensure that we can finetune a causal-lm model on some toy data for 1+
+    steps & run inference."""
+    train_kwargs = {
+        "base_model": HFAutoCausalLM.bootstrap(
+            model_name=CAUSAL_LM_MODEL, tokenizer_name=CAUSAL_LM_MODEL
+        ),
+        "num_epochs": 1,
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                GenerationTrainRecord(
+                    input="@foo what a cute dog!", output="no complaint"
+                ),
+            ]
+        ),
+        "torch_dtype": torch.float32,
+        "lora_config": LoraConfig(target_modules=["query_key_value"]),
+    }
+    model = TextGeneration.train(**train_kwargs)
+
+    # Ensure that we can get something out of it
+    pred = model.run("@bar what a cute cat!")
+    assert isinstance(pred, GeneratedTextResult)
+
+
+@pytest.mark.skipif(platform.processor() == "arm", reason="ARM training not supported")
+def test_train_lora_model_seq2seq(disable_wip, set_cpu_device):
+    """Ensure that we can finetune a seq2seq-lm model on some toy data for 1+
+    steps & run inference."""
+    train_kwargs = {
+        "base_model": HFAutoSeq2SeqLM.bootstrap(
+            model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL
+        ),
+        "num_epochs": 1,
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                GenerationTrainRecord(
+                    input="@foo what a cute dog!", output="no complaint"
+                ),
+            ]
+        ),
+        "torch_dtype": torch.float32,
+        "lora_config": LoraConfig(target_modules=["q"]),
+    }
+    model = TextGeneration.train(**train_kwargs)
+
+    # Ensure that we can get something out of it
+    pred = model.run("@bar what a cute cat!")
+    assert isinstance(pred, GeneratedTextResult)
+
+
 ############################## Error Cases ################################
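
Example usage of the new lora_config parameter, as a minimal sketch that mirrors the added tests. The model name, the toy training record, and the target_modules values below are illustrative assumptions and must match whichever base model is actually being tuned.

    # Sketch only: exercises TextGeneration.train with a LoRA config, assuming a
    # small seq2seq base model (t5-small) whose attention projections are named "q"/"v".
    import torch
    from peft import LoraConfig

    from caikit.core.data_model import DataStream
    from caikit_nlp.data_model import GenerationTrainRecord
    from caikit_nlp.modules.text_generation import TextGeneration
    from caikit_nlp.resources.pretrained_model import HFAutoSeq2SeqLM

    # Bootstrap the base model resource and build a tiny training stream
    base_model = HFAutoSeq2SeqLM.bootstrap(
        model_name="t5-small", tokenizer_name="t5-small"
    )
    train_stream = DataStream.from_iterable(
        [GenerationTrainRecord(input="@foo what a cute dog!", output="no complaint")]
    )

    # Passing lora_config switches the run from full fine-tuning to LoRA
    model = TextGeneration.train(
        base_model,
        train_stream,
        num_epochs=1,
        torch_dtype=torch.float32,
        lora_config=LoraConfig(r=8, lora_alpha=8, target_modules=["q", "v"]),
    )
    model.save("sample_lora_tuned_model")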