22 changes: 21 additions & 1 deletion caikit_nlp/data_model/generation.py
@@ -15,7 +15,7 @@
"""
# Standard
from enum import Enum
from typing import List
from typing import List, Optional, Union

# First Party
from caikit.core import DataObjectBase
@@ -73,6 +73,26 @@ class TuningConfig(DataObjectBase):
# encoder_hidden_size: int # Optional - The hidden size of the prompt encoder.


@caikit.core.dataobject(package="caikit_data_model.caikit_nlp")
class LoraTuningConfig(DataObjectBase):
# Lora attention dimension.
r: int
# List of module names or regex expression of the module names to replace with Lora.
# For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'
target_modules: Optional[Union[List[str], str]] = None
# The alpha parameter for Lora scaling.
lora_alpha: Optional[int] = 8
# The dropout probability for Lora layers.
lora_dropout: Optional[float] = 0.0
# Bias type for Lora. Can be "none", "all" or "lora_only".
# If "all" or "lora_only", the corresponding biases will be updated during training.
# Be aware that this means that, even when disabling the adapters,
# the model will not produce the same output
# as the base model would have without adaptation.
bias: Optional[str] = "none"
output_model_types: List[str]
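
For reference, a minimal sketch of how this new data object might be constructed (the keyword constructor is assumed to be generated by the caikit dataobject decorator, and the field values, including the output model type, are illustrative assumptions):

```python
# Illustrative only: build a LoraTuningConfig with assumed values.
from caikit_nlp.data_model.generation import LoraTuningConfig

lora_tuning_config = LoraTuningConfig(
    r=8,                             # LoRA attention dimension
    target_modules=["q", "v"],       # modules to wrap with LoRA adapters
    lora_alpha=16,                   # scaling factor
    lora_dropout=0.05,               # dropout probability on LoRA layers
    bias="none",                     # do not train bias terms
    output_model_types=["DECODER"],  # assumed value; depends on the base model
)
```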


@caikit.core.dataobject(package="caikit_data_model.caikit_nlp")
class ExponentialDecayLengthPenalty(DataObjectBase):
start_index: int
207 changes: 146 additions & 61 deletions caikit_nlp/modules/text_generation/peft_config.py
@@ -18,7 +18,7 @@
import re

# Third Party
from peft import MultitaskPromptTuningInit
from peft import LoraConfig, MultitaskPromptTuningInit, PromptTuningConfig
from transformers import AutoConfig

# First Party
@@ -55,10 +55,10 @@
class TuningType(str, Enum):
PROMPT_TUNING = "PROMPT_TUNING"
MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
LORA = "LORA"
# MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING"
# P_TUNING = "P_TUNING"
# PREFIX_TUNING = "PREFIX_TUNING"
# LORA = "LORA"
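
Because TuningType mixes in str, callers can pass either the enum member or its string value to get_peft_config; a small sketch of the equivalence (the import assumes the module path shown in this diff):

```python
from caikit_nlp.modules.text_generation.peft_config import TuningType

assert TuningType("LORA") is TuningType.LORA  # construct the member from its string value
assert TuningType.LORA == "LORA"              # the str mixin makes direct comparison work
```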


def resolve_base_model(base_model, cls, torch_dtype):
@@ -101,12 +101,142 @@ def resolve_base_model(base_model, cls, torch_dtype):


def get_peft_config(
tuning_type, tuning_config, base_model, cls, torch_dtype, verbalizer
tuning_type,
tuning_config,
base_model,
cls=None,
torch_dtype=None,
verbalizer="{{input}}",
Review comment (Collaborator):

Is there a reason these defaults are now being set here? I.e., the calling module is always going to pass them in positionally and override these defaults, right?

Reply (Collaborator Author, @Ssukriti, Dec 4, 2023):
We want to deprecate all three of these parameters (once I move the create_hf_tuning function to this file, we won't need cls either; the other two are unused). So I just set them here so we can stop setting them from the calling module and remove them in the next major release. The three will not be needed for LoRA either and don't have to be set from the calling module.

):

if isinstance(tuning_type, str):
error.value_check(
"<NLP65714994E>",
tuning_type in TuningType._member_names_,
f"Invalid tuning type [{tuning_type}]. Allowed types: "
f"[{TuningType._member_names_}]",
)
tuning_type = TuningType(tuning_type)

error.type_check("<NLP65714993E>", TuningType, tuning_type=tuning_type)

if tuning_type not in TuningType._member_names_:
raise NotImplementedError("{} tuning type not supported!".format(tuning_type))

error.type_check("<NLP65714919E>", PretrainedModelBase, base_model=base_model)

# Validate if tuned output model type is compatible with base model or not
output_model_types = _get_output_types(tuning_config, base_model)

# NOTE: Base model is a resource at this point
task_type = base_model.TASK_TYPE

error.value_check(
"<NLP30542004E>",
len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS,
f"Too many output model types. Got {len(output_model_types)}, "
f"maximum {base_model.MAX_NUM_TRANSFORMERS}",
)

if verbalizer:
log.warning(
"<NLP21323085W>",
"verbalizer parameter is DEPRECATED for this function \
and will be removed in a future release. \
It is not used when creating the PEFT config",
)
# Ensure that our verbalizer is a string and
# will not render to a hardcoded string
# TODO: This check should happen in prompt tuning module and not here
error.value_check(
"<NLP83837412E>",
is_valid_verbalizer(verbalizer),
"Provided verbalizer is an invalid type or has no renderable placeholders",
)

if torch_dtype:
log.warning(
"<NLP16173085W>",
"torch_dtype parameter is DEPRECATED for this function \
and will be removed in a future release. \
It is not used when creating the PEFT config",
)
torch_dtype = get_torch_dtype(torch_dtype)
log.debug("Resolved tuning type: %s", tuning_type)
if tuning_type in [
TuningType.PROMPT_TUNING,
TuningType.MULTITASK_PROMPT_TUNING,
]:
peft_config = _create_prompt_tuning_config(
tuning_type, tuning_config, cls, base_model, task_type, output_model_types
)
else:
# LoRA is currently the only other supported tuning type
peft_config = _create_lora_config(tuning_config, task_type)

return task_type, output_model_types, peft_config, tuning_type
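
A hedged usage sketch of the updated entry point for LoRA: only tuning_type, tuning_config, and base_model need to be supplied now, with cls, torch_dtype, and verbalizer left at their new defaults. The resource class, its bootstrap call, the model name, and the config values below are assumptions for illustration, not part of this diff:

```python
# Illustrative only: request a LoRA peft config from the new get_peft_config signature.
from caikit_nlp.data_model.generation import LoraTuningConfig
from caikit_nlp.modules.text_generation.peft_config import TuningType, get_peft_config
from caikit_nlp.resources.pretrained_model import HFAutoCausalLM  # assumed resource class

base_model = HFAutoCausalLM.bootstrap("bigscience/bloom-560m")  # assumed bootstrap API
tuning_config = LoraTuningConfig(
    r=8, lora_alpha=16, lora_dropout=0.05, output_model_types=[]
)

task_type, output_model_types, peft_config, tuning_type = get_peft_config(
    tuning_type=TuningType.LORA,
    tuning_config=tuning_config,
    base_model=base_model,
)
```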


def _get_output_types(tuning_config, base_model):
"Validate and return output_model_types"
# Validate if tuned output model type is compatible with base model or not
if not tuning_config.output_model_types:
output_model_types = base_model.PROMPT_OUTPUT_TYPES
else:
# If the first element is not PromptOutputModelType, assume the entire list
# isn't and convert
if not isinstance(tuning_config.output_model_types[0], PromptOutputModelType):
output_model_types = []
for output_type in tuning_config.output_model_types:
output_model_types.append(PromptOutputModelType(output_type))
else:
output_model_types = tuning_config.output_model_types
error.value_check(
"<NLP36947542E>",
all(
output_type in base_model.PROMPT_OUTPUT_TYPES
for output_type in output_model_types
),
"{} not supported for base model type {}".format(
output_model_types, base_model.MODEL_TYPE
),
)
return output_model_types


def _filter_params_for_prompt_config(prompt_config, params):
"""Utility function to filter out required parameters for prompt_config
from `params`

Args:
prompt_config: PromptTuningConfig
Tuning config instance, e.g., PromptTuningConfig
params: dict
Dictionary containing all the input training params

Returns:
dict:
Dictionary containing required params for prompt_config
"""
# Inspect the underlying dataclass fields; we do this because the common super class
# used for multi/vanilla prompt/prefix tuning is a DataClass; we can't use __dict__
# because the dataclass fields are omitted.
allowed_keys = list(prompt_config.__dataclass_fields__.keys())
allowed_params = dict(filter(lambda x: x[0] in allowed_keys, params.items()))
log.info(
"<NLP18184771I>",
"[{}] config params not supported by provided tuning type!".format(
params.keys() - allowed_params.keys()
),
)
return allowed_params
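
To illustrate the filtering, the parameter dict below is hypothetical, and peft's LoraConfig stands in for the tuning dataclass whose __dataclass_fields__ drive the filter:

```python
# Illustrative only: keep just the keys that LoraConfig declares as dataclass fields.
from peft import LoraConfig

raw_params = {"r": 8, "lora_alpha": 16, "num_virtual_tokens": 20}  # hypothetical inputs
allowed_keys = list(LoraConfig.__dataclass_fields__.keys())
allowed_params = {k: v for k, v in raw_params.items() if k in allowed_keys}
# allowed_params == {"r": 8, "lora_alpha": 16}; "num_virtual_tokens" is dropped
```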


def _create_prompt_tuning_config(
tuning_type, tuning_config, cls, base_model, task_type, output_model_types
) -> PromptTuningConfig:
"""Creates Huggingface PromptTuningConfig from Caikit tuning configuration."""

if tuning_config.prompt_tuning_init_method:
# NOTE: GK-APR-5-2023
# MultitaskPromptTuningInit and MultitaskPrefixTuningInit are same at the
@@ -149,62 +279,6 @@ def get_peft_config(
tuning_config.prompt_tuning_init_source_model,
)

error.type_check("<NLP65714919E>", PretrainedModelBase, base_model=base_model)

# Validate if tuned output model type is compatible with base model or not
if not tuning_config.output_model_types:
output_model_types = base_model.PROMPT_OUTPUT_TYPES
else:
# If the first element is not PromptOutputModelType, assume the entire list
# isn't and convert
if not isinstance(tuning_config.output_model_types[0], PromptOutputModelType):
output_model_types = []
for output_type in tuning_config.output_model_types:
output_model_types.append(PromptOutputModelType(output_type))
else:
output_model_types = tuning_config.output_model_types
error.value_check(
"<NLP36947542E>",
all(
output_type in base_model.PROMPT_OUTPUT_TYPES
for output_type in output_model_types
),
"{} not supported for base model type {}".format(
output_model_types, base_model.MODEL_TYPE
),
)

error.value_check(
"<NLP30542004E>",
len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS,
f"Too many output model types. Got {len(output_model_types)}, "
f"maximum {base_model.MAX_NUM_TRANSFORMERS}",
)
# Ensure that our verbalizer is a string and will not render to a hardcoded string
error.value_check(
"<NLP83837412E>",
is_valid_verbalizer(verbalizer),
"Provided verbalizer is an invalid type or has no renderable placeholders",
)

# NOTE: Base model is a resource at this point
task_type = base_model.TASK_TYPE

if isinstance(tuning_type, str):
error.value_check(
"<NLP65714994E>",
tuning_type in TuningType._member_names_,
f"Invalid tuning type [{tuning_type}]. Allowed types: "
f"[{TuningType._member_names_}]",
)
tuning_type = TuningType(tuning_type)
error.type_check("<NLP65714993E>", TuningType, tuning_type=tuning_type)

# Coerce the passed model into a resource; if we have one, this is a noop
# TODO: When splitting up this mono-module, use the configured resource
# type of the concrete class to bootstrap
torch_dtype = get_torch_dtype(torch_dtype)

# Take tokenizer name/path from the model
tokenizer_name_or_path = base_model.model.config._name_or_path

@@ -213,13 +287,24 @@

# NOTE: We currently only support TEXT as the init type; this makes it easy to
# switch to MPT later
peft_config = cls.create_hf_tuning_config(
prompt_tuning_config = cls.create_hf_tuning_config(
base_model=base_model,
tuning_type=tuning_type,
task_type=task_type,
tokenizer_name_or_path=tokenizer_name_or_path,
tuning_config=tuning_config,
output_model_types=output_model_types,
)
return prompt_tuning_config

return task_type, output_model_types, peft_config, tuning_type

def _create_lora_config(tuning_config, task_type) -> LoraConfig:
"""Creates Huggingface LoraConfig from Caikit tuning configuration."""

config_kwargs = tuning_config.to_dict()
log.info("<NLP61012781I>", f"Parameters used: {config_kwargs}")
config_params = _filter_params_for_prompt_config(tuning_config, config_kwargs)
if "output_model_types" in config_params:
del config_params["output_model_types"]
lora_config = LoraConfig(task_type=task_type, **config_params)
return lora_config
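
Roughly, for a causal-LM base model this helper resolves to a plain peft LoraConfig like the sketch below; the task type string and numeric values are assumptions for illustration:

```python
# Illustrative only: the kind of peft config _create_lora_config ends up building.
from peft import LoraConfig

lora_config = LoraConfig(
    task_type="CAUSAL_LM",  # taken from base_model.TASK_TYPE by the caller
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
)
```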