Lora config data model and utilities #270
Changes from all commits
```diff
@@ -18,7 +18,7 @@
 import re

 # Third Party
-from peft import MultitaskPromptTuningInit
+from peft import LoraConfig, MultitaskPromptTuningInit, PromptTuningConfig
 from transformers import AutoConfig

 # First Party
```
```diff
@@ -55,10 +55,10 @@
 class TuningType(str, Enum):
     PROMPT_TUNING = "PROMPT_TUNING"
     MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
+    LORA = "LORA"
     # MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING"
     # P_TUNING = "P_TUNING"
     # PREFIX_TUNING = "PREFIX_TUNING"
-    # LORA = "LORA"


 def resolve_base_model(base_model, cls, torch_dtype):
```
```diff
@@ -101,12 +101,142 @@ def resolve_base_model(base_model, cls, torch_dtype):

 def get_peft_config(
-    tuning_type, tuning_config, base_model, cls, torch_dtype, verbalizer
+    tuning_type,
+    tuning_config,
+    base_model,
+    cls=None,
+    torch_dtype=None,
+    verbalizer="{{input}}",
```
Collaborator: Is there a reason these defaults are now being set here? I.e., the calling module is always going to pass them in positionally and override these defaults, right?

Collaborator (Author): We want to deprecate all three of these parameters (once I move the create_hf_tuning_config function into this file, we won't need cls either; the other two are unused). So I set the defaults here so that we can stop passing them from the calling module and remove them in the next major release. None of the three is needed for LoRA either, so they don't have to be set from the calling module.
```diff
 ):

+    if isinstance(tuning_type, str):
+        error.value_check(
+            "<NLP65714994E>",
+            tuning_type in TuningType._member_names_,
+            f"Invalid tuning type [{tuning_type}]. Allowed types: "
+            f"[{TuningType._member_names_}]",
+        )
+        tuning_type = TuningType(tuning_type)
+
+    error.type_check("<NLP65714993E>", TuningType, tuning_type=tuning_type)
+
+    if tuning_type not in TuningType._member_names_:
+        raise NotImplementedError("{} tuning type not supported!".format(tuning_type))
+
+    error.type_check("<NLP65714919E>", PretrainedModelBase, base_model=base_model)
+
+    # Validate if tuned output model type is compatible with base model or not
+    output_model_types = _get_output_types(tuning_config, base_model)
+
+    # NOTE: Base model is a resource at this point
+    task_type = base_model.TASK_TYPE
+
+    error.value_check(
+        "<NLP30542004E>",
+        len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS,
+        f"Too many output model types. Got {len(output_model_types)}, "
+        f"maximum {base_model.MAX_NUM_TRANSFORMERS}",
+    )
+
+    if verbalizer:
+        log.warning(
+            "<NLP21323085W>",
+            "verbalizer parameter is DEPRECATED for this function \
+            and will be removed in the future. \
+            This parameter is also not used in the creation of the peft config",
+        )
+        # Ensure that our verbalizer is a string and
+        # will not render to a hardcoded string
+        # TODO: This check should happen in the prompt tuning module and not here
+        error.value_check(
+            "<NLP83837412E>",
+            is_valid_verbalizer(verbalizer),
+            "Provided verbalizer is an invalid type or has no renderable placeholders",
+        )
+
+    if torch_dtype:
+        log.warning(
+            "<NLP16173085W>",
+            "torch_dtype parameter is DEPRECATED for this function \
+            and will be removed in the future. \
+            This parameter is also not used in the creation of the peft config",
+        )
+        torch_dtype = get_torch_dtype(torch_dtype)
+
+    if tuning_type in [
+        TuningType.PROMPT_TUNING,
+        TuningType.MULTITASK_PROMPT_TUNING,
+    ]:
+        peft_config = _create_prompt_tuning_config(
+            tuning_type, tuning_config, cls, base_model, task_type, output_model_types
+        )
+    else:
+        # We only support LoRA besides the two prompt tuning types for now
+        peft_config = _create_lora_config(tuning_config, task_type)
+
+    return task_type, output_model_types, peft_config, tuning_type
```
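With the new defaults discussed in the review thread above, a caller would only need to pass the tuning type, tuning config, and base model. A rough, hypothetical call-site sketch (not part of this PR; `my_tuning_config` and `my_base_model` stand in for a Caikit tuning config and a resolved `PretrainedModelBase` resource):

```python
# Hypothetical call site: only the non-deprecated arguments are supplied.
# The string "LORA" is coerced to TuningType.LORA inside get_peft_config.
task_type, output_model_types, peft_config, tuning_type = get_peft_config(
    tuning_type="LORA",
    tuning_config=my_tuning_config,  # assumed to carry LoRA fields such as r and lora_alpha
    base_model=my_base_model,
    # cls, torch_dtype, and verbalizer are deprecated and left at their defaults
)
```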
```diff
+
+def _get_output_types(tuning_config, base_model):
+    "Validate and return output_model_types"
+    # Validate if tuned output model type is compatible with base model or not
+    if not tuning_config.output_model_types:
+        output_model_types = base_model.PROMPT_OUTPUT_TYPES
+    else:
+        # If the first element is not PromptOutputModelType, assume the entire list
+        # isn't and convert
+        if not isinstance(tuning_config.output_model_types[0], PromptOutputModelType):
+            output_model_types = []
+            for output_type in tuning_config.output_model_types:
+                output_model_types.append(PromptOutputModelType(output_type))
+        else:
+            output_model_types = tuning_config.output_model_types
+    error.value_check(
+        "<NLP36947542E>",
+        all(
+            output_type in base_model.PROMPT_OUTPUT_TYPES
+            for output_type in output_model_types
+        ),
+        "{} not supported for base model type {}".format(
+            output_model_types, base_model.MODEL_TYPE
+        ),
+    )
+    return output_model_types
```
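`_get_output_types` accepts either `PromptOutputModelType` members or their string values and normalizes everything to enum members before validating against the base model. A minimal standalone sketch of the same coercion pattern, using a stand-in enum rather than the real `PromptOutputModelType`:

```python
from enum import Enum


class OutputType(str, Enum):
    """Stand-in for PromptOutputModelType, purely for illustration."""

    ENCODER = "ENCODER"
    DECODER = "DECODER"


# Strings as they might arrive from a user-supplied config
raw_types = ["ENCODER", "DECODER"]

# Coerce anything that is not already an enum member
output_types = [t if isinstance(t, OutputType) else OutputType(t) for t in raw_types]

assert all(isinstance(t, OutputType) for t in output_types)
```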
```diff
+
+def _filter_params_for_prompt_config(prompt_config, params):
+    """Utility function to filter out required parameters for prompt_config
+    from `params`
+
+    Args:
+        prompt_config: PromptTuningConfig
+            Tuning config type, e.g., PromptTuningConfig
+        params: dict
+            Dictionary containing all the input training params
+
+    Returns:
+        dict:
+            Dictionary containing required params for prompt_config
+    """
+    # Inspect the underlying dataclass fields; we do this because the common super class
+    # used for multi/vanilla prompt/prefix tuning is a DataClass; we can't use __dict__
+    # because the dataclass fields are omitted.
+    allowed_keys = list(prompt_config.__dataclass_fields__.keys())
+    allowed_params = dict(filter(lambda x: x[0] in allowed_keys, params.items()))
+    log.info(
+        "<NLP18184771I>",
+        "[{}] config params not supported by provided tuning type!".format(
+            params.keys() - allowed_params.keys()
+        ),
+    )
+    return allowed_params
```
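`_filter_params_for_prompt_config` works for any dataclass-based config, since it only inspects `__dataclass_fields__`. As an illustration of the pattern (not how the PR wires it up; here the filter target is peft's `LoraConfig` itself, and the parameter values are made up):

```python
from peft import LoraConfig

# Mixed parameter dict: LoRA fields plus a prompt-tuning-only field
params = {"r": 8, "lora_alpha": 32, "lora_dropout": 0.05, "num_virtual_tokens": 20}

allowed_keys = list(LoraConfig.__dataclass_fields__.keys())
allowed_params = {k: v for k, v in params.items() if k in allowed_keys}

# num_virtual_tokens is not a LoraConfig field, so it gets filtered out
print(params.keys() - allowed_params.keys())  # -> {'num_virtual_tokens'}
```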
```diff
+
+def _create_prompt_tuning_config(
+    tuning_type, tuning_config, cls, base_model, task_type, output_model_types
+) -> PromptTuningConfig:
+    """Creates Huggingface PromptTuningConfig from Caikit tuning configuration."""
+
     if tuning_config.prompt_tuning_init_method:
         # NOTE: GK-APR-5-2023
         # MultitaskPromptTuningInit and MultitaskPrefixTuningInit are same at the
```
```diff
@@ -149,62 +279,6 @@ def get_peft_config(
             tuning_config.prompt_tuning_init_source_model,
         )

-    error.type_check("<NLP65714919E>", PretrainedModelBase, base_model=base_model)
-
-    # Validate if tuned output model type is compatible with base model or not
-    if not tuning_config.output_model_types:
-        output_model_types = base_model.PROMPT_OUTPUT_TYPES
-    else:
-        # If the first element is not PromptOutputModelType, assume the entire list
-        # isn't and convert
-        if not isinstance(tuning_config.output_model_types[0], PromptOutputModelType):
-            output_model_types = []
-            for output_type in tuning_config.output_model_types:
-                output_model_types.append(PromptOutputModelType(output_type))
-        else:
-            output_model_types = tuning_config.output_model_types
-    error.value_check(
-        "<NLP36947542E>",
-        all(
-            output_type in base_model.PROMPT_OUTPUT_TYPES
-            for output_type in output_model_types
-        ),
-        "{} not supported for base model type {}".format(
-            output_model_types, base_model.MODEL_TYPE
-        ),
-    )
-
-    error.value_check(
-        "<NLP30542004E>",
-        len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS,
-        f"Too many output model types. Got {len(output_model_types)}, "
-        f"maximum {base_model.MAX_NUM_TRANSFORMERS}",
-    )
-    # Ensure that our verbalizer is a string and will not render to a hardcoded string
-    error.value_check(
-        "<NLP83837412E>",
-        is_valid_verbalizer(verbalizer),
-        "Provided verbalizer is an invalid type or has no renderable placeholders",
-    )
-
-    # NOTE: Base model is a resource at this point
-    task_type = base_model.TASK_TYPE
-
-    if isinstance(tuning_type, str):
-        error.value_check(
-            "<NLP65714994E>",
-            tuning_type in TuningType._member_names_,
-            f"Invalid tuning type [{tuning_type}]. Allowed types: "
-            f"[{TuningType._member_names_}]",
-        )
-        tuning_type = TuningType(tuning_type)
-    error.type_check("<NLP65714993E>", TuningType, tuning_type=tuning_type)
-
-    # Coerce the passed model into a resource; if we have one, this is a noop
-    # TODO: When splitting up this mono-module, use the configured resource
-    # type of the concrete class to bootstrap
-    torch_dtype = get_torch_dtype(torch_dtype)
-
     # Take tokenizer name/path from the model
     tokenizer_name_or_path = base_model.model.config._name_or_path
```
```diff
@@ -213,13 +287,24 @@ def get_peft_config(

     # NOTE: We currently only support TEXT as init type, this is to later only easily
     # switch to MPT
-    peft_config = cls.create_hf_tuning_config(
+    prompt_tuning_config = cls.create_hf_tuning_config(
         base_model=base_model,
         tuning_type=tuning_type,
         task_type=task_type,
         tokenizer_name_or_path=tokenizer_name_or_path,
         tuning_config=tuning_config,
         output_model_types=output_model_types,
     )
+    return prompt_tuning_config

-    return task_type, output_model_types, peft_config, tuning_type
+
+def _create_lora_config(tuning_config, task_type) -> LoraConfig:
+    """Creates Huggingface LoraConfig from Caikit tuning configuration."""
+
+    config_kwargs = tuning_config.to_dict()
+    log.info("<NLP61012781I>", f"Parameters used: {config_kwargs}")
+    config_params = _filter_params_for_prompt_config(tuning_config, config_kwargs)
+    if "output_model_types" in config_params:
+        del config_params["output_model_types"]
+    lora_config = LoraConfig(task_type=task_type, **config_params)
+    return lora_config
```
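For context on how the returned config is typically consumed downstream, here is a minimal sketch using plain peft and transformers (this wiring is not part of the diff; the model name and LoRA hyperparameters are illustrative):

```python
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

# Small causal LM purely for illustration
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Roughly the kind of object _create_lora_config returns for a causal-LM task type,
# assuming the tuning config carried these LoRA fields
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
)

peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()  # only the LoRA adapter weights are trainable
```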