
Commit 1a5dde8

utility to generate LoraConfig
Signed-off-by: Sukriti-Sharma4 <[email protected]>
1 parent 2682ef3 commit 1a5dde8

File tree

caikit_nlp/data_model/generation.py
caikit_nlp/modules/text_generation/peft_config.py
caikit_nlp/modules/text_generation/peft_prompt_tuning.py

3 files changed: +130 -66 lines

caikit_nlp/data_model/generation.py

Lines changed: 36 additions & 1 deletion
@@ -15,7 +15,7 @@
 """
 # Standard
 from enum import Enum
-from typing import List
+from typing import List, Union
 
 # First Party
 from caikit.core import DataObjectBase
@@ -73,6 +73,41 @@ class TuningConfig(DataObjectBase):
     # encoder_hidden_size: int # Optional - The hidden size of the prompt encoder.
 
 
+@caikit.core.dataobject(package="caikit_data_model.caikit_nlp")
+class LoraTuningConfig(DataObjectBase):
+    # LoRA attention dimension.
+    r: int
+    # The names of the modules to apply LoRA to.
+    target_modules: Union[List[str], str]
+    # The alpha parameter for LoRA scaling.
+    lora_alpha: int
+    # The dropout probability for LoRA layers.
+    lora_dropout: float
+    # Set this to True if the layer to replace stores weights like (fan_in, fan_out).
+    # For example, gpt-2 uses Conv1D, which stores weights like (fan_in, fan_out),
+    # and hence this should be set to True.
+    fan_in_fan_out: bool
+    # Bias type for LoRA. Can be 'none', 'all' or 'lora_only'.
+    # If 'all' or 'lora_only', the corresponding biases will be updated during training.
+    # Be aware that this means that, even when disabling the adapters,
+    # the model will not produce the same output
+    # as the base model would have without adaptation.
+    bias: str
+    # List of modules apart from LoRA layers to be set as trainable
+    # and saved in the final checkpoint.
+    modules_to_save: List[str]
+    # The layer indexes to transform. If this argument is specified,
+    # the LoRA transformations are applied only
+    # to the layer indexes in this list.
+    # If a single integer is passed,
+    # the LoRA transformation is applied to the layer at that index.
+    layers_to_transform: Union[List[int], int]
+    # The layer pattern name, used only if layers_to_transform is not None and
+    # the layer pattern is not in the common layers pattern.
+    layers_pattern: str
+    output_model_types: List[str]
+
+
 @caikit.core.dataobject(package="caikit_data_model.caikit_nlp")
 class ExponentialDecayLengthPenalty(DataObjectBase):
     start_index: int
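
For orientation, the LoraTuningConfig fields mirror the constructor arguments of Hugging Face's peft.LoraConfig (plus output_model_types, which is Caikit-specific). A minimal sketch of the corresponding peft object, with purely illustrative values rather than defaults from this repo:

from peft import LoraConfig, TaskType

# Illustrative values only; target_modules depends on the base model architecture.
example_lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)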

caikit_nlp/modules/text_generation/peft_config.py

Lines changed: 86 additions & 34 deletions
@@ -18,7 +18,7 @@
 import re
 
 # Third Party
-from peft import MultitaskPromptTuningInit
+from peft import LoraConfig, MultitaskPromptTuningInit
 from transformers import AutoConfig
 
 # First Party
@@ -51,10 +51,10 @@
 class TuningType(str, Enum):
     PROMPT_TUNING = "PROMPT_TUNING"
     MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
+    LORA = "LORA"
     # MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING"
     # P_TUNING = "P_TUNING"
     # PREFIX_TUNING = "PREFIX_TUNING"
-    # LORA = "LORA"
 
 
 def resolve_base_model(base_model, cls, torch_dtype):
@@ -99,7 +99,15 @@ def get_peft_config(
     tuning_type, tuning_config, base_model, cls, torch_dtype, verbalizer
 ):
 
-    if tuning_type not in TuningType._member_names_:
+    if isinstance(tuning_type, str):
+        tuning_type = TuningType(tuning_type)
+
+    error.type_check("<NLP65714993E>", TuningType, tuning_type=tuning_type)
+
+    if tuning_type not in [
+        TuningType.PROMPT_TUNING,
+        TuningType.MULTITASK_PROMPT_TUNING,
+    ]:
         raise NotImplementedError("{} tuning type not supported!".format(tuning_type))
 
     if tuning_config.prompt_tuning_init_method:
@@ -147,27 +155,7 @@ def get_peft_config(
     error.type_check("<NLP65714919E>", PretrainedModelBase, base_model=base_model)
 
     # Validate if tuned output model type is compatible with base model or not
-    if not tuning_config.output_model_types:
-        output_model_types = base_model.PROMPT_OUTPUT_TYPES
-    else:
-        # If the first element is not PromptOutputModelType, assume the entire list
-        # isn't and convert
-        if not isinstance(tuning_config.output_model_types[0], PromptOutputModelType):
-            output_model_types = []
-            for output_type in tuning_config.output_model_types:
-                output_model_types.append(PromptOutputModelType(output_type))
-        else:
-            output_model_types = tuning_config.output_model_types
-        error.value_check(
-            "<NLP36947542E>",
-            all(
-                output_type in base_model.PROMPT_OUTPUT_TYPES
-                for output_type in output_model_types
-            ),
-            "{} not supported for base model type {}".format(
-                output_model_types, base_model.MODEL_TYPE
-            ),
-        )
+    output_model_types = _get_output_types(tuning_config, base_model)
 
     error.value_check(
         "<NLP30542004E>",
@@ -185,16 +173,6 @@ def get_peft_config(
     # NOTE: Base model is a resource at this point
     task_type = base_model.TASK_TYPE
 
-    if isinstance(tuning_type, str):
-        error.value_check(
-            "<NLP65714994E>",
-            tuning_type in TuningType._member_names_,
-            f"Invalid tuning type [{tuning_type}]. Allowed types: "
-            f"[{TuningType._member_names_}]",
-        )
-        tuning_type = TuningType(tuning_type)
-    error.type_check("<NLP65714993E>", TuningType, tuning_type=tuning_type)
-
     # Coerce the passed model into a resource; if we have one, this is a noop
     # TODO: When splitting up this mono-module, use the configured resource
     # type of the concrete class to bootstrap
@@ -218,3 +196,77 @@ def get_peft_config(
     )
 
     return task_type, output_model_types, peft_config, tuning_type
+
+
+def _get_output_types(tuning_config, base_model):
+    "Validate and return output_model_types"
+    # Validate if tuned output model type is compatible with base model or not
+    if not tuning_config.output_model_types:
+        output_model_types = base_model.PROMPT_OUTPUT_TYPES
+    else:
+        # If the first element is not PromptOutputModelType, assume the entire list
+        # isn't and convert
+        if not isinstance(tuning_config.output_model_types[0], PromptOutputModelType):
+            output_model_types = []
+            for output_type in tuning_config.output_model_types:
+                output_model_types.append(PromptOutputModelType(output_type))
+        else:
+            output_model_types = tuning_config.output_model_types
+        error.value_check(
+            "<NLP36947542E>",
+            all(
+                output_type in base_model.PROMPT_OUTPUT_TYPES
+                for output_type in output_model_types
+            ),
+            "{} not supported for base model type {}".format(
+                output_model_types, base_model.MODEL_TYPE
+            ),
+        )
+    return output_model_types
+
+
+def _filter_params_for_prompt_config(prompt_config, params):
+    """Utility function to filter out required parameters for prompt_config
+    from `params`
+
+    Args:
+        prompt_config: PromptTuningConfig
+            Tuning config type, e.g., PromptTuningConfig
+        params: dict
+            Dictionary containing all the input training params
+
+    Returns:
+        dict:
+            Dictionary containing required params for prompt_config
+    """
+    # Inspect the underlying dataclass fields; we do this because the common super class
+    # used for multi/vanilla prompt/prefix tuning is a DataClass; we can't use __dict__
+    # because the dataclass fields are omitted.
+    allowed_keys = list(prompt_config.__dataclass_fields__.keys())
+    allowed_params = dict(filter(lambda x: x[0] in allowed_keys, params.items()))
+    log.info(
+        "<NLP18184771I>",
+        "[{}] config params not supported by provided tuning type!".format(
+            params.keys() - allowed_params.keys()
+        ),
+    )
+    return allowed_params
+
+
+def get_lora_config(tuning_type, tuning_config, base_model) -> LoraConfig:
+    """Creates Huggingface LoraConfig from Caikit tuning configuration."""
+    if isinstance(tuning_type, str):
+        tuning_type = TuningType(tuning_type)
+
+    if tuning_type != TuningType.LORA:
+        raise NotImplementedError("{} tuning type not supported!".format(tuning_type))
+
+    error.type_check("<NLP65714919E>", PretrainedModelBase, base_model=base_model)
+    # NOTE: Base model is a resource at this point
+    task_type = base_model.TASK_TYPE
+    config_kwargs = tuning_config.to_dict()
+    log.info("<NLP61012781I>", f"Parameters used: {config_kwargs}")
+    config_params = _filter_params_for_prompt_config(tuning_config, config_kwargs)
+    output_model_types = _get_output_types(tuning_config, base_model)
+    lora_config = LoraConfig(task_type=task_type, **config_params)
+    return task_type, output_model_types, lora_config, tuning_type
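
The _filter_params_for_prompt_config helper (used by get_lora_config above and by create_hf_tuning_config in peft_prompt_tuning.py) relies on the target config being a dataclass, so its accepted keys can be read from __dataclass_fields__. A standalone sketch of that filtering technique, using a hypothetical DemoTuningConfig and params dict for illustration:

from dataclasses import dataclass

@dataclass
class DemoTuningConfig:  # hypothetical stand-in for a tuning config dataclass
    num_virtual_tokens: int = 20
    task_type: str = "CAUSAL_LM"

params = {"num_virtual_tokens": 50, "task_type": "CAUSAL_LM", "learning_rate": 3e-4}
allowed_keys = list(DemoTuningConfig.__dataclass_fields__.keys())
allowed_params = dict(filter(lambda x: x[0] in allowed_keys, params.items()))
# allowed_params == {"num_virtual_tokens": 50, "task_type": "CAUSAL_LM"}
# params.keys() - allowed_params.keys() == {"learning_rate"}  -> logged as unsupported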

caikit_nlp/modules/text_generation/peft_prompt_tuning.py

Lines changed: 8 additions & 31 deletions
@@ -74,7 +74,12 @@
 )
 from ...toolkit.trainer_utils import validate_training_data
 from ...toolkit.verbalizer_utils import render_verbalizer
-from .peft_config import TuningType, get_peft_config, resolve_base_model
+from .peft_config import (
+    TuningType,
+    _filter_params_for_prompt_config,
+    get_peft_config,
+    resolve_base_model,
+)
 
 log = alog.use_channel("PEFT_PROMPT")
 error = error_handler.get(log)
@@ -99,10 +104,10 @@ class PeftPromptTuning(ModuleBase):
     tuning_type_to_huggingface = {
         TuningType.PROMPT_TUNING: PeftType.PROMPT_TUNING,
         TuningType.MULTITASK_PROMPT_TUNING: PeftType.MULTITASK_PROMPT_TUNING,
+        TuningType.LORA: PeftType.LORA,
         # TuningType.MULTITASK_PREFIX_TUNING: PeftType.MULTITASK_PREFIX_TUNING,
         # TuningType.P_TUNING: PeftType.P_TUNING,
         # TuningType.PREFIX_TUNING: PeftType.PREFIX_TUNING,
-        # TuningType.LORA: PeftType.LORA,
     }
 
     RANDOM_SEED = 73
@@ -856,7 +861,7 @@ def create_hf_tuning_config(
         elif tuning_type == TuningType.MULTITASK_PROMPT_TUNING:
             tuning_config_type = MultitaskPromptTuningConfig
 
-        config_params = cls._filter_params_for_prompt_config(
+        config_params = _filter_params_for_prompt_config(
            tuning_config_type, config_kwargs
         )
         log.info("<NLP41038481I>", f"Parameters used: {config_params}")
@@ -1150,34 +1155,6 @@ def _execute_train_loop(
         )
         return {"loss": training_loss_tracker}
 
-    @classmethod
-    def _filter_params_for_prompt_config(cls, prompt_config, params):
-        """Utility function to filter out required parameters for prompt_config
-        from `params`
-
-        Args:
-            prompt_config: PromptTuningConfig
-                Tuning config type, e.g., PromptTuningConfig
-            params: dict
-                Dictionary containing all the input training params
-
-        Returns:
-            dict:
-                Dictionary containing required params for prompt_config
-        """
-        # Inspect the underlying dataclass fields; we do this because the common super class
-        # used for multi/vanilla prompt/prefix tuning is a DataClass; we can't use __dict__
-        # because the dataclass fields are omitted.
-        allowed_keys = list(prompt_config.__dataclass_fields__.keys())
-        allowed_params = dict(filter(lambda x: x[0] in allowed_keys, params.items()))
-        log.info(
-            "<NLP18184771I>",
-            "[{}] config params not supported by provided tuning type!".format(
-                params.keys() - allowed_params.keys()
-            ),
-        )
-        return allowed_params
-
     @staticmethod
     def convert_peft_model_to_type(
         device: str, peft_model: PeftModel, torch_dtype=Union[str, torch.dtype]
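
Taken together, the three files give callers two parallel entry points: get_peft_config for the prompt-tuning variants and get_lora_config for LoRA, both returning (task_type, output_model_types, config, tuning_type). A hypothetical dispatch sketch (not part of this commit; build_tuning_config is an invented name) showing how a caller might choose between them:

from caikit_nlp.modules.text_generation.peft_config import (
    TuningType,
    get_lora_config,
    get_peft_config,
)

def build_tuning_config(
    tuning_type, tuning_config, base_model, cls=None, torch_dtype=None, verbalizer=None
):
    # Hypothetical helper, not defined in this commit.
    if isinstance(tuning_type, str):
        tuning_type = TuningType(tuning_type)
    if tuning_type == TuningType.LORA:
        return get_lora_config(tuning_type, tuning_config, base_model)
    return get_peft_config(
        tuning_type, tuning_config, base_model, cls, torch_dtype, verbalizer
    )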
