@@ -18,7 +18,7 @@
 import re
 
 # Third Party
-from peft import MultitaskPromptTuningInit
+from peft import LoraConfig, MultitaskPromptTuningInit
 from transformers import AutoConfig
 
 # First Party
@@ -51,10 +51,10 @@
 class TuningType(str, Enum):
     PROMPT_TUNING = "PROMPT_TUNING"
     MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
+    LORA = "LORA"
     # MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING"
     # P_TUNING = "P_TUNING"
     # PREFIX_TUNING = "PREFIX_TUNING"
-    # LORA = "LORA"
 
 
 def resolve_base_model(base_model, cls, torch_dtype):
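Since `TuningType` mixes in `str`, user-supplied strings can be pushed straight through the enum constructor, which is what the simplified validation in the next hunk relies on. A minimal, self-contained sketch of that behavior (the enum body is copied from the hunk above):

```python
from enum import Enum

class TuningType(str, Enum):
    PROMPT_TUNING = "PROMPT_TUNING"
    MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
    LORA = "LORA"

# Lookup-by-value and comparison both work with plain strings:
assert TuningType("LORA") is TuningType.LORA
assert TuningType.LORA == "LORA"

# Unknown names fail fast in the constructor:
try:
    TuningType("QLORA")
except ValueError as err:
    print(err)  # 'QLORA' is not a valid TuningType
```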
@@ -99,7 +99,15 @@ def get_peft_config(
     tuning_type, tuning_config, base_model, cls, torch_dtype, verbalizer
 ):
 
-    if tuning_type not in TuningType._member_names_:
+    if isinstance(tuning_type, str):
+        tuning_type = TuningType(tuning_type)
+
+    error.type_check("<NLP65714993E>", TuningType, tuning_type=tuning_type)
+
+    if tuning_type not in [
+        TuningType.PROMPT_TUNING,
+        TuningType.MULTITASK_PROMPT_TUNING,
+    ]:
         raise NotImplementedError("{} tuning type not supported!".format(tuning_type))
 
     if tuning_config.prompt_tuning_init_method:
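A consequence of hoisting the coercion above the allow-list: an unrecognized string now raises `ValueError` from the `TuningType` constructor, while a recognized but unsupported member, notably the new `LORA`, falls through to the `NotImplementedError` branch because LoRA is routed through `get_lora_config` added at the end of this diff. A hypothetical caller showing both failure modes (every argument besides `tuning_type` is a placeholder assumed to be valid):

```python
# Hypothetical invocations; cfg, model, cls, dtype, verbalizer are placeholders.
try:
    get_peft_config("QLORA", cfg, model, cls, dtype, verbalizer)
except ValueError:
    pass  # rejected by TuningType("QLORA") before any tuning logic runs

try:
    get_peft_config(TuningType.LORA, cfg, model, cls, dtype, verbalizer)
except NotImplementedError:
    pass  # valid member, but handled by get_lora_config instead
```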
@@ -147,27 +155,7 @@ def get_peft_config(
     error.type_check("<NLP65714919E>", PretrainedModelBase, base_model=base_model)
 
     # Validate if tuned output model type is compatible with base model or not
-    if not tuning_config.output_model_types:
-        output_model_types = base_model.PROMPT_OUTPUT_TYPES
-    else:
-        # If the first element is not PromptOutputModelType, assume the entire list
-        # isn't and convert
-        if not isinstance(tuning_config.output_model_types[0], PromptOutputModelType):
-            output_model_types = []
-            for output_type in tuning_config.output_model_types:
-                output_model_types.append(PromptOutputModelType(output_type))
-        else:
-            output_model_types = tuning_config.output_model_types
-        error.value_check(
-            "<NLP36947542E>",
-            all(
-                output_type in base_model.PROMPT_OUTPUT_TYPES
-                for output_type in output_model_types
-            ),
-            "{} not supported for base model type {}".format(
-                output_model_types, base_model.MODEL_TYPE
-            ),
-        )
+    output_model_types = _get_output_types(tuning_config, base_model)
 
     error.value_check(
         "<NLP30542004E>",
@@ -185,16 +173,6 @@ def get_peft_config(
     # NOTE: Base model is a resource at this point
     task_type = base_model.TASK_TYPE
 
-    if isinstance(tuning_type, str):
-        error.value_check(
-            "<NLP65714994E>",
-            tuning_type in TuningType._member_names_,
-            f"Invalid tuning type [{tuning_type}]. Allowed types: "
-            f"[{TuningType._member_names_}]",
-        )
-        tuning_type = TuningType(tuning_type)
-    error.type_check("<NLP65714993E>", TuningType, tuning_type=tuning_type)
-
     # Coerce the passed model into a resource; if we have one, this is a noop
     # TODO: When splitting up this mono-module, use the configured resource
     # type of the concrete class to bootstrap
@@ -218,3 +196,77 @@ def get_peft_config(
     )
 
     return task_type, output_model_types, peft_config, tuning_type
+
+
+def _get_output_types(tuning_config, base_model):
+    """Validate and return output_model_types."""
+    # Validate if tuned output model type is compatible with base model or not
+    if not tuning_config.output_model_types:
+        output_model_types = base_model.PROMPT_OUTPUT_TYPES
+    else:
+        # If the first element is not PromptOutputModelType, assume the entire list
+        # isn't and convert
+        if not isinstance(tuning_config.output_model_types[0], PromptOutputModelType):
+            output_model_types = []
+            for output_type in tuning_config.output_model_types:
+                output_model_types.append(PromptOutputModelType(output_type))
+        else:
+            output_model_types = tuning_config.output_model_types
+        error.value_check(
+            "<NLP36947542E>",
+            all(
+                output_type in base_model.PROMPT_OUTPUT_TYPES
+                for output_type in output_model_types
+            ),
+            "{} not supported for base model type {}".format(
+                output_model_types, base_model.MODEL_TYPE
+            ),
+        )
+    return output_model_types
+
+
+def _filter_params_for_prompt_config(prompt_config, params):
+    """Utility function to filter out the parameters required by prompt_config
+    from `params`
+
+    Args:
+        prompt_config: PromptTuningConfig
+            Tuning config type, e.g., PromptTuningConfig
+        params: dict
+            Dictionary containing all the input training params
+
+    Returns:
+        dict:
+            Dictionary containing the params required by prompt_config
+    """
+    # Inspect the underlying dataclass fields; we do this because the common super class
+    # used for multi/vanilla prompt/prefix tuning is a dataclass; we can't use __dict__
+    # because the dataclass fields are omitted.
+    allowed_keys = list(prompt_config.__dataclass_fields__.keys())
+    allowed_params = dict(filter(lambda x: x[0] in allowed_keys, params.items()))
+    log.info(
+        "<NLP18184771I>",
+        "[{}] config params not supported by provided tuning type!".format(
+            params.keys() - allowed_params.keys()
+        ),
+    )
+    return allowed_params
+
+
+def get_lora_config(tuning_type, tuning_config, base_model):
+    """Create a Hugging Face LoraConfig (and related metadata) from the Caikit tuning configuration."""
+    if isinstance(tuning_type, str):
+        tuning_type = TuningType(tuning_type)
+
+    if tuning_type != TuningType.LORA:
+        raise NotImplementedError("{} tuning type not supported!".format(tuning_type))
+
+    error.type_check("<NLP65714919E>", PretrainedModelBase, base_model=base_model)
+    # NOTE: Base model is a resource at this point
+    task_type = base_model.TASK_TYPE
+    config_kwargs = tuning_config.to_dict()
+    log.info("<NLP61012781I>", f"Parameters used: {config_kwargs}")
+    config_params = _filter_params_for_prompt_config(tuning_config, config_kwargs)
+    output_model_types = _get_output_types(tuning_config, base_model)
+    lora_config = LoraConfig(task_type=task_type, **config_params)
+    return task_type, output_model_types, lora_config, tuning_type
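To see how the pieces of `get_lora_config` compose, here is a sketch of the same filter-then-construct idiom applied against `peft.LoraConfig` directly. `LoraConfig` is a dataclass, so `__dataclass_fields__` filtering works on it just as in `_filter_params_for_prompt_config`; the parameter dict is illustrative, and `unused_param` is a deliberately bogus key:

```python
from peft import LoraConfig

params = {"r": 8, "lora_alpha": 32, "lora_dropout": 0.05, "unused_param": True}

# Keep only keys that are real LoraConfig fields, mirroring the helper above
allowed_keys = list(LoraConfig.__dataclass_fields__.keys())
config_params = {k: v for k, v in params.items() if k in allowed_keys}
print(params.keys() - config_params.keys())  # {'unused_param'}

lora_config = LoraConfig(task_type="CAUSAL_LM", **config_params)
print(lora_config.r, lora_config.lora_alpha)  # 8 32
```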