diff --git a/caikit_nlp/data_model/generation.py b/caikit_nlp/data_model/generation.py
index 4fec2711..3821fc65 100644
--- a/caikit_nlp/data_model/generation.py
+++ b/caikit_nlp/data_model/generation.py
@@ -15,7 +15,7 @@
 """
 # Standard
 from enum import Enum
-from typing import List
+from typing import List, Optional, Union
 
 # First Party
 from caikit.core import DataObjectBase
@@ -73,6 +73,26 @@ class TuningConfig(DataObjectBase):
     # encoder_hidden_size: int
     #   Optional - The hidden size of the prompt encoder.
 
+
+@caikit.core.dataobject(package="caikit_data_model.caikit_nlp")
+class LoraTuningConfig(DataObjectBase):
+    # LoRA attention dimension (rank of the update matrices).
+    r: int
+    # List of module names or regex expression of the module names to replace with LoRA.
+    # For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'
+    target_modules: Optional[Union[List[str], str]] = None
+    # The alpha parameter for LoRA scaling.
+    lora_alpha: Optional[int] = 8
+    # The dropout probability for LoRA layers.
+    lora_dropout: Optional[float] = 0.0
+    # Bias type for LoRA. Can be 'none', 'all' or 'lora_only'.
+    # If 'all' or 'lora_only', the corresponding biases will be updated during training.
+    # Be aware that this means that, even when disabling the adapters,
+    # the model will not produce the same output
+    # as the base model would have without adaptation.
+    bias: Optional[str] = "none"
+    output_model_types: List[str]
+
 
 @caikit.core.dataobject(package="caikit_data_model.caikit_nlp")
 class ExponentialDecayLengthPenalty(DataObjectBase):
     start_index: int
diff --git a/caikit_nlp/modules/text_generation/peft_config.py b/caikit_nlp/modules/text_generation/peft_config.py
index af2589e5..a06a4175 100644
--- a/caikit_nlp/modules/text_generation/peft_config.py
+++ b/caikit_nlp/modules/text_generation/peft_config.py
@@ -18,7 +18,7 @@
 import re
 
 # Third Party
-from peft import MultitaskPromptTuningInit
+from peft import LoraConfig, MultitaskPromptTuningInit, PromptTuningConfig
 from transformers import AutoConfig
 
 # First Party
@@ -55,10 +55,10 @@ class TuningType(str, Enum):
     PROMPT_TUNING = "PROMPT_TUNING"
     MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
+    LORA = "LORA"
     # MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING"
     # P_TUNING = "P_TUNING"
     # PREFIX_TUNING = "PREFIX_TUNING"
-    # LORA = "LORA"
 
 
 def resolve_base_model(base_model, cls, torch_dtype):
@@ -101,12 +101,142 @@ def resolve_base_model(base_model, cls, torch_dtype):
 
 
 def get_peft_config(
-    tuning_type, tuning_config, base_model, cls, torch_dtype, verbalizer
+    tuning_type,
+    tuning_config,
+    base_model,
+    cls=None,
+    torch_dtype=None,
+    verbalizer="{{input}}",
 ):
+    if isinstance(tuning_type, str):
+        error.value_check(
+            "",
+            tuning_type in TuningType._member_names_,
+            f"Invalid tuning type [{tuning_type}]. Allowed types: "
+            f"[{TuningType._member_names_}]",
+        )
+        tuning_type = TuningType(tuning_type)
+
+    error.type_check("", TuningType, tuning_type=tuning_type)
+
     if tuning_type not in TuningType._member_names_:
         raise NotImplementedError("{} tuning type not supported!".format(tuning_type))
 
+    error.type_check("", PretrainedModelBase, base_model=base_model)
+
+    # Validate if tuned output model type is compatible with base model or not
+    output_model_types = _get_output_types(tuning_config, base_model)
+
+    # NOTE: Base model is a resource at this point
+    task_type = base_model.TASK_TYPE
+
+    error.value_check(
+        "",
+        len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS,
+        f"Too many output model types. Got {len(output_model_types)}, "
+        f"maximum {base_model.MAX_NUM_TRANSFORMERS}",
+    )
+
+    if verbalizer:
+        log.warning(
+            "",
+            "verbalizer parameter is DEPRECATED for this function "
+            "and will be removed in the future. "
+            "This parameter is not used in the creation of the peft config",
+        )
+        # Ensure that our verbalizer is a string and
+        # will not render to a hardcoded string
+        # TODO: This check should happen in prompt tuning module and not here
+        error.value_check(
+            "",
+            is_valid_verbalizer(verbalizer),
+            "Provided verbalizer is an invalid type or has no renderable placeholders",
+        )
+
+    if torch_dtype:
+        log.warning(
+            "",
+            "torch_dtype parameter is DEPRECATED for this function "
+            "and will be removed in the future. "
+            "This parameter is not used in the creation of the peft config",
+        )
+        torch_dtype = get_torch_dtype(torch_dtype)
+
+    if tuning_type in [
+        TuningType.PROMPT_TUNING,
+        TuningType.MULTITASK_PROMPT_TUNING,
+    ]:
+        peft_config = _create_prompt_tuning_config(
+            tuning_type, tuning_config, cls, base_model, task_type, output_model_types
+        )
+    else:
+        # LoRA is the only other supported tuning type for now
+        peft_config = _create_lora_config(tuning_config, task_type)
+
+    return task_type, output_model_types, peft_config, tuning_type
+
+
+def _get_output_types(tuning_config, base_model):
+    """Validate and return output_model_types."""
+    # Validate if tuned output model type is compatible with base model or not
+    if not tuning_config.output_model_types:
+        output_model_types = base_model.PROMPT_OUTPUT_TYPES
+    else:
+        # If the first element is not PromptOutputModelType, assume the entire list
+        # isn't and convert
+        if not isinstance(tuning_config.output_model_types[0], PromptOutputModelType):
+            output_model_types = []
+            for output_type in tuning_config.output_model_types:
+                output_model_types.append(PromptOutputModelType(output_type))
+        else:
+            output_model_types = tuning_config.output_model_types
+        error.value_check(
+            "",
+            all(
+                output_type in base_model.PROMPT_OUTPUT_TYPES
+                for output_type in output_model_types
+            ),
+            "{} not supported for base model type {}".format(
+                output_model_types, base_model.MODEL_TYPE
+            ),
+        )
+    return output_model_types
+
+
+def _filter_params_for_prompt_config(prompt_config, params):
+    """Utility function to filter out required parameters for prompt_config
+    from `params`
+
+    Args:
+        prompt_config: PromptTuningConfig
+            Tuning config type, e.g., PromptTuningConfig
+        params: dict
+            Dictionary containing all the input training params
+
+    Returns:
+        dict:
+            Dictionary containing required params for prompt_config
+    """
+    # Inspect the underlying dataclass fields; we do this because the common super class
+    # used for multi/vanilla prompt/prefix tuning is a DataClass; we can't use __dict__
+    # because the dataclass fields are omitted.
+    allowed_keys = list(prompt_config.__dataclass_fields__.keys())
+    allowed_params = dict(filter(lambda x: x[0] in allowed_keys, params.items()))
+    log.info(
+        "",
+        "[{}] config params not supported by provided tuning type!".format(
+            params.keys() - allowed_params.keys()
+        ),
+    )
+    return allowed_params
+
+
+def _create_prompt_tuning_config(
+    tuning_type, tuning_config, cls, base_model, task_type, output_model_types
+) -> PromptTuningConfig:
+    """Creates Huggingface PromptTuningConfig from Caikit tuning configuration."""
+    if tuning_config.prompt_tuning_init_method:
         # NOTE: GK-APR-5-2023
         # MultitaskPromptTuningInit and MultitaskPrefixTuningInit are same at the
@@ -149,62 +279,6 @@ def get_peft_config(
                 tuning_config.prompt_tuning_init_source_model,
             )
 
-    error.type_check("", PretrainedModelBase, base_model=base_model)
-
-    # Validate if tuned output model type is compatible with base model or not
-    if not tuning_config.output_model_types:
-        output_model_types = base_model.PROMPT_OUTPUT_TYPES
-    else:
-        # If the first element is not PromptOutputModelType, assume the entire list
-        # isn't and convert
-        if not isinstance(tuning_config.output_model_types[0], PromptOutputModelType):
-            output_model_types = []
-            for output_type in tuning_config.output_model_types:
-                output_model_types.append(PromptOutputModelType(output_type))
-        else:
-            output_model_types = tuning_config.output_model_types
-        error.value_check(
-            "",
-            all(
-                output_type in base_model.PROMPT_OUTPUT_TYPES
-                for output_type in output_model_types
-            ),
-            "{} not supported for base model type {}".format(
-                output_model_types, base_model.MODEL_TYPE
-            ),
-        )
-
-    error.value_check(
-        "",
-        len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS,
-        f"Too many output model types. Got {len(output_model_types)}, "
-        f"maximum {base_model.MAX_NUM_TRANSFORMERS}",
-    )
-    # Ensure that our verbalizer is a string and will not render to a hardcoded string
-    error.value_check(
-        "",
-        is_valid_verbalizer(verbalizer),
-        "Provided verbalizer is an invalid type or has no renderable placeholders",
-    )
-
-    # NOTE: Base model is a resource at this point
-    task_type = base_model.TASK_TYPE
-
-    if isinstance(tuning_type, str):
-        error.value_check(
-            "",
-            tuning_type in TuningType._member_names_,
-            f"Invalid tuning type [{tuning_type}]. Allowed types: "
-            f"[{TuningType._member_names_}]",
-        )
-        tuning_type = TuningType(tuning_type)
-    error.type_check("", TuningType, tuning_type=tuning_type)
-
-    # Coerce the passed model into a resource; if we have one, this is a noop
-    # TODO: When splitting up this mono-module, use the configured resource
-    # type of the concrete class to bootstrap
-    torch_dtype = get_torch_dtype(torch_dtype)
-
     # Take tokenizer name/path from the model
     tokenizer_name_or_path = base_model.model.config._name_or_path
 
@@ -213,7 +287,7 @@
 
     # NOTE: We currently only support TEXT as init type, this is to later only easily
     # switch to MPT
-    peft_config = cls.create_hf_tuning_config(
+    prompt_tuning_config = cls.create_hf_tuning_config(
         base_model=base_model,
         tuning_type=tuning_type,
         task_type=task_type,
@@ -221,5 +295,16 @@
         tuning_config=tuning_config,
         output_model_types=output_model_types,
     )
+    return prompt_tuning_config
 
-    return task_type, output_model_types, peft_config, tuning_type
+
+def _create_lora_config(tuning_config, task_type) -> LoraConfig:
+    """Creates Huggingface LoraConfig from Caikit tuning configuration."""
+
+    config_kwargs = tuning_config.to_dict()
+    # Keep only the kwargs that LoraConfig accepts (drops e.g. output_model_types)
+    config_params = _filter_params_for_prompt_config(LoraConfig, config_kwargs)
+    log.info("", f"Parameters used: {config_params}")
+    lora_config = LoraConfig(task_type=task_type, **config_params)
+    return lora_config
diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
index 435e439d..2f086eb8 100644
--- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
+++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -53,6 +53,7 @@
 from ...data_model import (
     ExponentialDecayLengthPenalty,
     GenerationTrainRecord,
+    LoraTuningConfig,
     PromptOutputModelType,
     TuningConfig,
 )
@@ -73,7 +74,12 @@
 )
 from ...toolkit.trainer_utils import validate_training_data
 from ...toolkit.verbalizer_utils import render_verbalizer
-from .peft_config import TuningType, get_peft_config, resolve_base_model
+from .peft_config import (
+    TuningType,
+    _filter_params_for_prompt_config,
+    get_peft_config,
+    resolve_base_model,
+)
 
 log = alog.use_channel("PEFT_PROMPT")
 error = error_handler.get(log)
@@ -98,10 +104,10 @@ class PeftPromptTuning(ModuleBase):
     tuning_type_to_huggingface = {
         TuningType.PROMPT_TUNING: PeftType.PROMPT_TUNING,
         TuningType.MULTITASK_PROMPT_TUNING: PeftType.MULTITASK_PROMPT_TUNING,
+        TuningType.LORA: PeftType.LORA,
         # TuningType.MULTITASK_PREFIX_TUNING: PeftType.MULTITASK_PREFIX_TUNING,
         # TuningType.P_TUNING: PeftType.P_TUNING,
         # TuningType.PREFIX_TUNING: PeftType.PREFIX_TUNING,
-        # TuningType.LORA: PeftType.LORA,
     }
 
     RANDOM_SEED = 73
@@ -186,6 +192,10 @@ def run(
 
         verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
 
+        # MPT requires task_ids at generation time; flag it for the shared helper
+        mpt = self.tuning_type == TuningType.MULTITASK_PROMPT_TUNING
+
         return generate_text_func(
             self.model,
             self.tokenizer,
@@ -207,6 +217,7 @@
             stop_sequences=stop_sequences,
             preserve_input_text=preserve_input_text,
             task_type=self.task_type,
+            mpt=mpt,
         )
 
     # NOTE: We need to disable wip decorator here otherwise we get issues in
@@ -281,7 +292,7 @@ def train(
             DataStream[GenerationTrainRecord],
             DataStream[ClassificationTrainRecord],
         ],
-        tuning_config: TuningConfig,
+        tuning_config: Union[TuningConfig, LoraTuningConfig],
         val_stream: Optional[
             Union[
                 DataStream[GenerationTrainRecord],
@@ -310,7 +321,7 @@ def train(
                 Base resource model used for underlying generation.
             train_stream: DataStream[GenerationTrainRecord] or DataStream[ClassificationTrainRecord]
                 Data to be used for training the prompt vectors of the generation model.
-            tuning_config: TuningConfig
+            tuning_config: Union[TuningConfig, LoraTuningConfig]
                 Additional model tuning configurations to be considered for prompt vector
                 initialization and training behavior.
             val_stream: Optional[DataStream[GenerationTrainRecord]
@@ -853,7 +864,7 @@ def create_hf_tuning_config(
         elif tuning_type == TuningType.MULTITASK_PROMPT_TUNING:
             tuning_config_type = MultitaskPromptTuningConfig
 
-        config_params = cls._filter_params_for_prompt_config(
+        config_params = _filter_params_for_prompt_config(
             tuning_config_type, config_kwargs
         )
         log.info("", f"Parameters used: {config_params}")
@@ -902,34 +913,6 @@ def _get_collate_fn(tokenizer: AutoTokenizer, task_type: str) -> Callable:
         # want to set labels ourselves. TODO: centralize collator management.
         return default_data_collator
 
-    @classmethod
-    def _filter_params_for_prompt_config(cls, prompt_config, params):
-        """Utility function to filter out required parameters for prompt_config
-        from `params`
-
-        Args:
-            prompt_config: PromptTuningConfig
-                Tuning config type, eg:, PromptTuningConfig
-            params: dict
-                Dictionary containing all the input training params
-
-        Returns:
-            dict:
-                Dictionary containing required params for prompt_config
-        """
-        # Inspect the underlying dataclass fileds; we do this because the common super class
-        # used for multi/vanilla prompt/prefix tuning is a DataClass; we can't use __dict__
-        # because the dataclass fields are omitted.
-        allowed_keys = list(prompt_config.__dataclass_fields__.keys())
-        allowed_params = dict(filter(lambda x: x[0] in allowed_keys, params.items()))
-        log.info(
-            "",
-            "[{}] config params not supported by provided tuning type!".format(
-                params.keys() - allowed_params.keys()
-            ),
-        )
-        return allowed_params
-
     @staticmethod
     def convert_peft_model_to_type(
         device: str, peft_model: PeftModel, torch_dtype=Union[str, torch.dtype]
diff --git a/caikit_nlp/toolkit/text_generation/model_run_utils.py b/caikit_nlp/toolkit/text_generation/model_run_utils.py
index 07f7ba29..3cadf77c 100644
--- a/caikit_nlp/toolkit/text_generation/model_run_utils.py
+++ b/caikit_nlp/toolkit/text_generation/model_run_utils.py
@@ -154,6 +154,7 @@ def generate_text_func(
     stop_sequences: Optional[List[str]] = None,
     preserve_input_text: Optional[bool] = True,
     task_type: Optional[str] = None,
+    mpt: bool = False,
     **kwargs,
 ):
     """
@@ -224,10 +225,11 @@ def generate_text_func(
     # NOTE: Below is required as `task_id` is a required field for generation
     # with MPT in PEFT. We are manually setting task id to 0 vector since
    # we do not allow setting task specific id anyways.
-    if isinstance(model, PeftModel):
-        gen_optional_params["task_ids"] = torch.zeros(
-            inputs["input_ids"].shape[0], dtype=inputs["input_ids"].dtype
-        ).to(model.device)
+    if mpt and isinstance(model, PeftModel):
+        gen_optional_params["task_ids"] = torch.zeros(
+            inputs["input_ids"].shape[0], dtype=inputs["input_ids"].dtype
+        ).to(model.device)
 
     with torch.no_grad():
         generate_ids = model.generate(
diff --git a/examples/run_lora_tuning.py b/examples/run_lora_tuning.py
new file mode 100644
index 00000000..ca2c2362
--- /dev/null
+++ b/examples/run_lora_tuning.py
@@ -0,0 +1,385 @@
+"""This script illustrates how to tune a text generation model using LoRA.
+Supported tuning types:
+- LoRA Tuning
+
+Supported model types:
+- Causal LM
+- Seq2Seq LM
+"""
+# Standard
+from collections import namedtuple
+from typing import Any, Tuple
+import argparse
+import os
+import pathlib
+import random
+import shutil
+import time
+
+# Third Party
+from transformers import AutoConfig
+from utils import (
+    ALOG_OPTS,
+    SUPPORTED_DATASETS,
+    DatasetInfo,
+    configure_random_seed_and_logging,
+    print_colored,
+)
+import datasets
+
+# First Party
+from caikit.core.data_model import DataStream
+import alog
+import caikit
+
+# Local
+from caikit_nlp.data_model import GenerationTrainRecord, LoraTuningConfig, TuningConfig
+from caikit_nlp.modules.text_generation.peft_prompt_tuning import (
+    PeftPromptTuning,
+    TuningType,
+)
+from caikit_nlp.resources.pretrained_model import (
+    HFAutoCausalLM,
+    HFAutoSeq2SeqLM,
+    PretrainedModelBase,
+)
+
+
+def subsample_stream(
+    train_stream: DataStream[GenerationTrainRecord], num_shots: int
+) -> DataStream[GenerationTrainRecord]:
+    """Given a training stream of length n, randomly extract num_shots <= n samples from it
+    for use in few-shot learning.
+
+    Args:
+        train_stream: DataStream
+            Full dataset to be used for training prior to few shot sampling.
+        num_shots: int
+            Number of samples to keep for use in training.
+
+    Returns:
+        DataStream[GenerationTrainRecord]
+            Train subsampled stream of len(x) == num_shots
+    """
+    num_samples = len(train_stream)
+    if num_shots > num_samples or num_shots <= 0:
+        raise ValueError(
+            "num_shots [{}] must be greater than 0 and may not exceed the train data size [{}]".format(
+                num_shots, num_samples
+            )
+        )
+    # If we have the same number of samples as shots, just give the raw stream back
+    elif num_shots == num_samples:
+        return train_stream
+    # Otherwise subsample the stream to condense its length; shuffle using
+    # the whole stream as a buffer, and build a new stream from the result.
+    # NOTE - this is not a great idea, but for now we do this, so that the sampling
+    # is exactly the same as the original MPT code, since sampling the whole dataset
+    # with a max buffer would load everything into memory anyway.
+    shuffled_dataset = random.sample(list(train_stream), num_shots)
+    return DataStream.from_iterable(shuffled_dataset)
+
+
+def get_resource_type(model_name: str) -> PretrainedModelBase:
+    """Given a model name, or a path to a model, get the resource type to be initialized.
+
+    Args:
+        model_name: str
+            Model name or path to the model to be leveraged.
+
+    Returns:
+        type
+            PretrainedModel subclass wrapping the loaded Transformer model, e.g.,
+            a HFAutoCausalLM or HFAutoSeq2SeqLM. We return the type here so that
+            we can initialize it later, after we show a nice experiment configuration.
+    """
+    try:
+        model_type = AutoConfig.from_pretrained(model_name).model_type
+    except OSError:
+        raise ValueError("Failed to load model with name: {}".format(model_name))
+    if model_type in HFAutoCausalLM.SUPPORTED_MODEL_TYPES:
+        return HFAutoCausalLM
+    elif model_type in HFAutoSeq2SeqLM.SUPPORTED_MODEL_TYPES:
+        return HFAutoSeq2SeqLM
+    raise NotImplementedError(
+        "Provided model is not supported by any supported resource type!"
+    )
+
+
+### Functions for arg parsing & validation
+def parse_args() -> argparse.Namespace:
+    """Parse command line arguments. Here, we set up each tuning task as a subparser
+    to prevent the arguments from being too confusing. Common arguments, e.g., the
+    base model name, are added to all parsers.
+
+    Returns:
+        argparse.Namespace
+            Parsed arguments to be leveraged for one LoRA tuning application.
+    """
+    parser = argparse.ArgumentParser(
+        description="Train LoRA adapters on top of a text generation model.",
+    )
+    ### Args specific to subparsers, i.e., tuning / training arguments
+    subparsers = parser.add_subparsers(
+        help="The type of tuning to apply.", dest="tuning_type", required=True
+    )
+    # NOTE: These keys should line up with the TuningType enum values
+    parser_LoRA_tuning = subparsers.add_parser(
+        "LORA",
+        help="Train low-rank adapters through LoRA tuning.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    subparsers = (parser_LoRA_tuning,)
+    # Register all of the common args, as well as specific tuning args for subcommands
+    register_common_arguments(subparsers)
+    register_lora_tuning_args(parser_LoRA_tuning)
+    args = parser.parse_args()
+    # Reconfigure logging level based on verbosity, while preserving filters etc.
+    default_level = "debug" if args.verbose else "info"
+    alog_settings = {**ALOG_OPTS, **{"default_level": default_level}}
+    alog.configure(**alog_settings)
+    # Validate common arg values
+    validate_common_args(args)
+    return args
+
+
+def register_common_arguments(subparsers: Tuple[argparse.ArgumentParser]) -> None:
+    """Registers common arguments intended to be shared across all subparsers.
+
+    Args:
+        subparsers: Tuple[argparse.ArgumentParser]
+            Iterable of argument subparsers that should have common args.
+    """
+    for subparser in subparsers:
+        subparser.add_argument(
+            "--dataset",
+            help="Dataset to use for tuning. Options: {}".format(
+                list(SUPPORTED_DATASETS.keys())
+            ),
+            default="twitter_complaints",
+        )
+        subparser.add_argument(
+            "--model_name",
+            help="Name of base model or path to model to tune",
+            default="bigscience/bloom-560m",
+        )
+        subparser.add_argument(
+            "--output_dir",
+            help="Name of the directory that we want to export the model to",
+            default="sample_prompt",
+        )
+        subparser.add_argument(
+            "--prompt_only",
+            help="Indicates that we do not need to export the full model, just the adapter weights",
+            action="store_true",
+        )
+        subparser.add_argument(
+            "--verbose",
+            help="If enabled, shows TQDM progress bars & debug logs",
+            action="store_true",
+        )
+        subparser.add_argument(
+            "--num_epochs",
+            help="Number of epochs to use for tuning",
+            type=int,
+            default=10,
+        )
+        subparser.add_argument(
+            "--learning_rate",
+            help="Learning rate to use while training",
+            type=float,
+            default=3e-2,
+        )
+        subparser.add_argument(
+            "--num_shots",
+            help="Number of training samples to use for few-shot learning",
+            type=int,
+            default=None,
+        )
+        subparser.add_argument(
+            "--batch_size", help="Batch size to use while training", type=int, default=8
+        )
+        subparser.add_argument(
+            "--max_source_length",
+            help="Maximum source sequence length.",
+            default=256,
+            type=int,
+        )
+        subparser.add_argument(
+            "--max_target_length",
+            help="Maximum target sequence length.",
+            default=128,
+            type=int,
+        )
+        subparser.add_argument(
+            "--accumulate_steps",
+            help="Gradient accumulation steps",
+            default=1,
+            type=int,
+        )
+        subparser.add_argument(
+            "--torch_dtype",
+            help="Torch dtype to be used for training",
+            default="float32",
+            choices=["float16", "bfloat16", "float32"],
+        )
+
+
+def register_lora_tuning_args(subparser: argparse.ArgumentParser):
+    """Register additional configuration options for the LoRA tuning subtask.
+
+    Args:
+        subparser: argparse.ArgumentParser
+            Configuration options for LoRA tuning specifically.
+    """
+    subparser.add_argument(
+        "--r",
+        help="Rank used for LoRA.",
+        type=int,
+        default=8,
+    )
+    subparser.add_argument(
+        "--lora_alpha",
+        help="LoRA scaling factor.",
+        type=int,
+        default=8,
+    )
+    subparser.add_argument(
+        "--lora_dropout",
+        help="LoRA dropout.",
+        type=float,
+        default=0.0,
+    )
+    subparser.add_argument(
+        "--bias",
+        help="Specifies if the bias parameters should be trained.",
+        type=str,
+        default="none",
+        choices=["none", "all", "lora_only"],
+    )
+
+
+def validate_common_args(args: argparse.Namespace):
+    """Validates common arguments to ensure values make sense; here, we only validate things
+    that are not (or should not be) handled within the module.
+
+    Args:
+        args: argparse.Namespace
+            Parsed args corresponding to one tuning task.
+    """
+    # Validate that the dataset is one of our allowed values
+    if args.dataset not in SUPPORTED_DATASETS:
+        raise KeyError(
+            "[{}] is not a supported dataset; see --help for options.".format(
+                args.dataset
+            )
+        )
+    # Purge our output directory if one already exists.
+    if os.path.isdir(args.output_dir):
+        print("Existing model directory found; purging it now.")
+        shutil.rmtree(args.output_dir)
+
+
+def build_lora_config(args: argparse.Namespace) -> LoraTuningConfig:
+    """Builds the LoRA tuning config for this tuning task.
+
+    Args:
+        args: argparse.Namespace
+            Parsed args corresponding to one tuning task.
+
+    Returns:
+        LoraTuningConfig
+            Tuning config object to be provided at .train() time.
+    """
+    base_kwargs = {
+        "r": args.r,
+        "lora_alpha": args.lora_alpha,
+        "lora_dropout": args.lora_dropout,
+        "bias": args.bias,
+    }
+    return LoraTuningConfig(**base_kwargs)
+
+
+def show_experiment_configuration(args, dataset_info, model_type) -> None:
+    """Show the complete configuration for this experiment, i.e., the model info,
+    the resource type we built, the training params, metadata about the dataset where
+    possible, and so on.
+
+    Args:
+        args: argparse.Namespace
+            Parsed args corresponding to one tuning task.
+        dataset_info: DatasetInfo
+            Dataset information, including the verbalizer used to render each example.
+        model_type: type
+            Resource class corresponding to the base model.
+    """
+    print_strs = [
+        "Experiment Configuration",
+        "- Model Name: [{}]".format(args.model_name),
+        " |- Inferred Model Resource Type: [{}]".format(model_type),
+        "- Tuning Type: [{}]".format(args.tuning_type),
+        "- Dataset: [{}]".format(args.dataset),
+        "- Verbalizer: [{}]".format(dataset_info.verbalizer),
+        "- Number of Epochs: [{}]".format(args.num_epochs),
+        "- Learning Rate: [{}]".format(args.learning_rate),
+        "- Batch Size: [{}]".format(args.batch_size),
+        "- Output Directory: [{}]".format(args.output_dir),
+        "- Exporting prompt only: [{}]".format(args.prompt_only),
+        "- Number of shots: [{}]".format(args.num_shots),
+        "- Maximum source sequence length: [{}]".format(args.max_source_length),
+        "- Maximum target sequence length: [{}]".format(args.max_target_length),
+        "- Gradient accumulation steps: [{}]".format(args.accumulate_steps),
+        "- LoRA Rank: [{}]".format(args.r),
+        "- LoRA Alpha: [{}]".format(args.lora_alpha),
+        "- LoRA Dropout: [{}]".format(args.lora_dropout),
+        "- LoRA Bias: [{}]".format(args.bias),
+    ]
+    # Log and sleep for a few seconds in case people actually want to read this...
+    print_colored("\n".join([print_str for print_str in print_strs if print_str]))
+
+
+if __name__ == "__main__":
+    configure_random_seed_and_logging()
+    args = parse_args()
+    model_type = get_resource_type(args.model_name)
+    # Unpack the dataset dictionary into a loaded dataset & verbalizer
+    dataset_info = SUPPORTED_DATASETS[args.dataset]
+    show_experiment_configuration(args, dataset_info, model_type)
+    # Convert the loaded dataset to a stream
+    print_colored("[Loading the dataset...]")
+    # TODO - conditionally enable validation stream
+    train_stream = dataset_info.dataset_loader()[0]
+    if args.num_shots is not None:
+        train_stream = subsample_stream(train_stream, args.num_shots)
+    # Init the resource & Build the tuning config from our dataset/arg info
+    print_colored("[Loading the base model resource...]")
+    base_model = model_type.bootstrap(args.model_name, tokenizer_name=args.model_name)
+    tuning_config = build_lora_config(args)
+    # Then actually train the model & save it
+    print_colored("[Starting the training...]")
+    model = PeftPromptTuning.train(
+        base_model,
+        train_stream,
+        tuning_config,
+        val_stream=None,
+        max_source_length=args.max_source_length,
+        max_target_length=args.max_target_length,
+        tuning_type=args.tuning_type,
+        num_epochs=args.num_epochs,
+        lr=args.learning_rate,
+        batch_size=args.batch_size,
+        verbalizer=dataset_info.verbalizer,
+        silence_progress_bars=not args.verbose,
+        accumulate_steps=args.accumulate_steps,
+        torch_dtype=args.torch_dtype,
+    )
+    model.save(args.output_dir, save_base_model=not args.prompt_only)
+    print_colored("[Training Complete]")
diff --git a/tests/data_model/test_generation.py b/tests/data_model/test_generation.py
index 64b5dc08..2535be0d 100644
--- a/tests/data_model/test_generation.py
+++ b/tests/data_model/test_generation.py
@@ -14,6 +14,7 @@
 
 # Local
 from caikit_nlp.data_model import ExponentialDecayLengthPenalty
+from caikit_nlp.data_model.generation import LoraTuningConfig
 
 ## Setup #########################################################################
 
@@ -43,3 +44,17 @@ def test_sampling_parameters_from_json_and_back():
     )
     assert new.start_index == 1
     assert new.decay_factor == 0.95
+
+
+## Setup #########################################################################
+
+lora_config = LoraTuningConfig(r=1)
+
+## Tests ########################################################################
+def test_loraconfig_fields_accessible():
+    assert lora_config.r == 1
+
+
+def test_from_proto_and_back():
+    new = LoraTuningConfig.from_proto(lora_config.to_proto())
+    assert new.r == 1
diff --git a/tests/modules/text_generation/test_peft_config.py b/tests/modules/text_generation/test_peft_config.py
index 7f037c6f..1f2601e7 100644
--- a/tests/modules/text_generation/test_peft_config.py
+++ b/tests/modules/text_generation/test_peft_config.py
@@ -2,11 +2,11 @@
 from unittest.mock import Mock
 
 # Third Party
-from peft import PromptTuningConfig
+from peft import LoraConfig, PromptTuningConfig
 import pytest
 
 # Local
-from caikit_nlp.data_model import TuningConfig
+from caikit_nlp.data_model import LoraTuningConfig, TuningConfig
 from caikit_nlp.modules.text_generation import TextGeneration
 from caikit_nlp.modules.text_generation.peft_config import (
     TuningType,
@@ -74,6 +74,46 @@ def test_get_peft_config(train_kwargs, dummy_model, request):
     assert peft_config.prompt_tuning_init_text == tuning_config.prompt_tuning_init_text
 
 
+@pytest.mark.parametrize(
+    "train_kwargs,dummy_model",
+    [
+        (
+            "seq2seq_lm_train_kwargs",
+            "seq2seq_lm_dummy_model",
+        ),
+        ("causal_lm_train_kwargs", "causal_lm_dummy_model"),
+    ],
+)
+def test_get_peft_config_with_lora(train_kwargs, dummy_model, request):
+    # Fixtures can't be called directly or passed to mark parametrize;
+    # Currently, passing the fixture by name and retrieving it through
+    # the request is the 'right' way to do this.
+    train_kwargs = request.getfixturevalue(train_kwargs)
+    dummy_model = request.getfixturevalue(dummy_model)
+
+    # Define some sample values for testing
+    tuning_type = TuningType.LORA
+    tuning_config = LoraTuningConfig(r=8, lora_alpha=8, lora_dropout=0.0)
+    dummy_resource = train_kwargs["base_model"]
+
+    # Call the function being tested
+    task_type, output_model_types, peft_config, tuning_type = get_peft_config(
+        tuning_type, tuning_config, dummy_resource
+    )
+
+    # Add assertions to validate the behavior of the function
+    assert task_type == dummy_resource.TASK_TYPE
+    assert output_model_types == dummy_resource.PROMPT_OUTPUT_TYPES
+    assert tuning_type == TuningType.LORA
+
+    # Validation for type & important fields in the peft config
+    assert isinstance(peft_config, LoraConfig)
+    assert peft_config.task_type == dummy_resource.TASK_TYPE
+    assert peft_config.r == tuning_config.r
+    assert peft_config.lora_alpha == tuning_config.lora_alpha
+    assert peft_config.lora_dropout == tuning_config.lora_dropout
+
+
 def test_resolve_model_with_invalid_path_raises():
     """Test passing invalid path to resolve_model function raises"""
diff --git a/tests/modules/text_generation/test_peft_prompt_tuning.py b/tests/modules/text_generation/test_peft_prompt_tuning.py
index 5cf82439..ddfd423c 100644
--- a/tests/modules/text_generation/test_peft_prompt_tuning.py
+++ b/tests/modules/text_generation/test_peft_prompt_tuning.py
@@ -162,6 +162,37 @@ def test_train_model(causal_lm_train_kwargs, set_cpu_device):
     assert isinstance(pred, GeneratedTextResult)
 
 
+def test_train_model_lora_config(causal_lm_train_kwargs, set_cpu_device):
+    """Ensure that we can train a LoRA model on some toy data for 1+ steps & run inference."""
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                caikit_nlp.data_model.GenerationTrainRecord(
+                    input="@foo what a cute dog!", output="no complaint"
+                ),
+                caikit_nlp.data_model.GenerationTrainRecord(
+                    input="@bar this is the worst idea ever.", output="complaint"
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+        "tuning_config": caikit_nlp.data_model.LoraTuningConfig(r=8),
+        "tuning_type": "LORA",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+    model = caikit_nlp.modules.text_generation.PeftPromptTuning.train(
+        **causal_lm_train_kwargs
+    )
+    # Test fallback to float32 behavior if this machine doesn't support bfloat16
+    assert model.model.dtype is torch.float32
+    # Ensure that we can get something out of it
+    pred = model.run("@bar what a cute cat!")
+    assert isinstance(pred, GeneratedTextResult)
+
+
 def test_gen_trained_mpt(causal_lm_train_kwargs, set_cpu_device):
     """Ensure that we are able to do generation on causal-lm model trained using MPT."""
 
@@ -181,6 +212,7 @@ def test_gen_trained_mpt(causal_lm_train_kwargs, set_cpu_device):
         "torch_dtype": torch.float32,
         "tuning_type": "MULTITASK_PROMPT_TUNING",
         "device": "cpu",
+        "tuning_config": caikit_nlp.data_model.TuningConfig(num_virtual_tokens=8),
     }
     causal_lm_train_kwargs.update(patch_kwargs)
     model = caikit_nlp.modules.text_generation.PeftPromptTuning.train(