55 changes: 39 additions & 16 deletions caikit_nlp/modules/text_generation/peft_config.py
@@ -17,7 +17,7 @@
 import os
 
 # Third Party
-from peft import MultitaskPromptTuningInit
+from peft import LoraConfig, MultitaskPromptTuningInit
 from transformers import AutoConfig
 
 # First Party
@@ -51,7 +51,7 @@ class TuningType(str, Enum):
     # MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING"
     # P_TUNING = "P_TUNING"
     # PREFIX_TUNING = "PREFIX_TUNING"
-    # LORA = "LORA"
+    LORA = "LORA"
 
 
 def resolve_base_model(base_model, cls, torch_dtype):
@@ -79,7 +79,19 @@ def resolve_base_model(base_model, cls, torch_dtype):
 
 
 def get_peft_config(
-    tuning_type, tuning_config, base_model, cls, torch_dtype, verbalizer
+    tuning_type,
+    tuning_config,
+    base_model,
+    cls,
+    torch_dtype,
+    verbalizer,
+    lora_r=8,
+    lora_target_modules=None,
+    lora_alpha=8,
+    lora_dropout=0.0,
+    lora_fan_in_fan_out=False,
+    lora_bias="none",
+    lora_modules_to_save=None,
 ):
 
     if tuning_type not in TuningType._member_names_:
@@ -186,18 +198,29 @@ def get_peft_config(
         # Take tokenizer name/path from the model
         tokenizer_name_or_path = base_model.model.config._name_or_path
 
-    # Build the peft config; this is how we determine that we want a sequence classifier.
-    # If we want more types, we will likely need to map this to data model outputs etc.
-
-    # NOTE: We currently only support TEXT as init type, this is to later only easily
-    # switch to MPT
-    peft_config = cls.create_hf_tuning_config(
-        base_model=base_model,
-        tuning_type=tuning_type,
-        task_type=task_type,
-        tokenizer_name_or_path=tokenizer_name_or_path,
-        tuning_config=tuning_config,
-        output_model_types=output_model_types,
-    )
+    if tuning_type == TuningType.LORA:
+        peft_config = LoraConfig(
+            r=lora_r,
+            target_modules=lora_target_modules,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            fan_in_fan_out=lora_fan_in_fan_out,
+            bias=lora_bias,
+            modules_to_save=lora_modules_to_save,
+        )
+    else:
+        # Build the PEFT config; this is how we determine that we want a sequence
+        # classifier. For more output types we will likely map this to data model outputs.
+
+        # NOTE: We currently only support TEXT as the init type; this keeps a later
+        # switch to MPT easy.
+        peft_config = cls.create_hf_tuning_config(
+            base_model=base_model,
+            tuning_type=tuning_type,
+            task_type=task_type,
+            tokenizer_name_or_path=tokenizer_name_or_path,
+            tuning_config=tuning_config,
+            output_model_types=output_model_types,
+        )
 
     return task_type, output_model_types, peft_config, tuning_type
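The LoRA branch above maps the new keyword arguments one-to-one onto peft's `LoraConfig`. As a reference point, here is a minimal, self-contained sketch of the config those defaults produce; the `target_modules` value is an illustrative choice for OPT-style attention layers, not something this diff sets (its default is `None`):

```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,                                  # rank of the low-rank update matrices
    lora_alpha=8,                         # scaling factor applied to the update
    lora_dropout=0.0,                     # dropout on the LoRA layers during training
    target_modules=["q_proj", "v_proj"],  # illustrative: attention projections in OPT-style models
    fan_in_fan_out=False,                 # True only for Conv1D-style layers (e.g. GPT-2's c_attn)
    bias="none",                          # leave bias terms frozen
    modules_to_save=None,                 # extra modules to train and serialize in full
)
```

With the diff's default `lora_target_modules=None`, peft falls back to its built-in per-architecture target mapping and raises for architectures it does not recognize, so callers tuning uncommon models will want to pass explicit module names.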
21 changes: 21 additions & 0 deletions caikit_nlp/modules/text_generation/text_generation_local.py
@@ -23,6 +23,7 @@
 # Third Party
 from datasets import Dataset
 from datasets import IterableDataset as TransformersIterableDataset
+from peft import LoraConfig, get_peft_model
 from transformers import AutoConfig, AutoTokenizer
 import torch
 
@@ -182,6 +183,7 @@ def train(
         random_seed: int = RANDOM_SEED,
         lr: float = 2e-5,
         use_iterable_dataset: bool = True,
+        lora_config: LoraConfig = None,
         **kwargs,
     ) -> "TextGeneration":
         """
@@ -214,6 +216,9 @@ def train(
                 Indicates whether or not we should load the full dataset into memory
                 NOTE: use True for this option if you are fine tuning a causal LM with
                 a large target sequence length unless your dataset is VERY small!
+            lora_config: LoraConfig
+                If provided, the LoRA technique will be used for fine-tuning (as
+                opposed to a 'full' fine-tuning).
             **kwargs:
                 Arguments supported by HF Training Arguments.
             TrainingArguments:
@@ -430,6 +435,22 @@ def train(
             torch_dtype=torch_dtype,
         )
 
+        if lora_config is not None:
+
+            # Apply the LoRA adapters, then merge the adapter weights back into
+            # the base weights so a plain merged model is returned
+            model_to_merge = get_peft_model(model, lora_config)
+            merged_model = model_to_merge.merge_and_unload()
+
+            return cls(
+                model_name=base_model._model_name,
+                model=merged_model,
+                bos_token=model.tokenizer.bos_token or None,
+                sep_token=model.tokenizer.sep_token or None,
+                eos_token=model.tokenizer.eos_token or None,
+                pad_token=model.tokenizer.pad_token or None,
+                training_metadata={"loss": training_loss_history},
+            )
         return cls(
             model_name=base_model._model_name,
             model=model,
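Taken together with the `lora_config` plumbing above, an end-to-end LoRA fine-tune could look like the sketch below. This is hedged: the OPT checkpoint is an arbitrary example, and `GenerationTrainRecord` / `DataStream.from_iterable` are assumed from the caikit / caikit_nlp data model rather than shown in this diff.

```python
from caikit.core.data_model import DataStream
from peft import LoraConfig

from caikit_nlp.data_model import GenerationTrainRecord
from caikit_nlp.modules.text_generation import TextGeneration

# Toy training stream; a real run would stream an actual dataset
train_stream = DataStream.from_iterable(
    [GenerationTrainRecord(input="@foo what a day", output="complaint")]
)

model = TextGeneration.train(
    base_model="facebook/opt-125m",  # assumed HF model id, resolved by resolve_base_model
    train_stream=train_stream,
    lora_config=LoraConfig(r=8, lora_alpha=8, target_modules=["q_proj", "v_proj"]),
)
```

Because the adapters are merged via `merge_and_unload()` before the module is constructed, the returned `TextGeneration` wraps a plain transformers model: save/load paths need no peft awareness, at the cost of losing the adapters as separate artifacts.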
11 changes: 11 additions & 0 deletions caikit_nlp/resources/pretrained_model/base.py
@@ -71,6 +71,17 @@ class PretrainedModelBase(ABC, ModuleBase):
     _MODEL_ARTIFACTS_CONFIG_KEY = "model_artifacts"
     _LEFT_PAD_MODEL_TYPES = ("gpt", "opt", "bloom")
 
+    def named_modules(self):
+        return self._model.named_modules()
+
+    def get_submodule(self, target: str):
+        return self._model.get_submodule(target)
+
+    def named_parameters(
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ):
+        return self._model.named_parameters(prefix, recurse, remove_duplicate)
+
     @classmethod
     @property
     def REQUIRES_TOKEN_UNWRAPPING(cls) -> str:
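These three pass-throughs exist because peft operates on whatever object it is handed: `get_peft_model` walks `named_modules()` to find entries matching `target_modules`, fetches them with `get_submodule()` for replacement, and iterates `named_parameters()` when marking only the LoRA weights as trainable. Since the caikit resource keeps the real transformers model in `self._model`, forwarding the calls lets peft treat the wrapper as the model itself. A toy illustration of the pattern (the `Wrapper` class is hypothetical, not part of this PR):

```python
import torch

class Wrapper:
    """Minimal stand-in for PretrainedModelBase's delegation pattern."""

    def __init__(self, model: torch.nn.Module):
        self._model = model

    def named_modules(self):
        # Forward traversal to the wrapped model so callers see its real layers
        return self._model.named_modules()

wrapped = Wrapper(torch.nn.Sequential(torch.nn.Linear(4, 4)))
print([name for name, _ in wrapped.named_modules()])  # ['', '0'] -- the inner model's modules
```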