Skip to content

Commit 55f168e

Browse files
committed
add unit tests
Signed-off-by: Sukriti-Sharma4 <[email protected]>
1 parent 1a5dde8 commit 55f168e

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

caikit_nlp/modules/text_generation/peft_config.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -268,5 +268,6 @@ def get_lora_config(tuning_type, tuning_config, base_model) -> LoraConfig:
268268
log.info("<NLP61012781I>", f"Parameters used: {config_kwargs}")
269269
config_params = _filter_params_for_prompt_config(tuning_config, config_kwargs)
270270
output_model_types = _get_output_types(tuning_config, base_model)
271+
del config_params["output_model_types"]
271272
lora_config = LoraConfig(task_type=task_type, **config_params)
272273
return task_type, output_model_types, lora_config, tuning_type

tests/modules/text_generation/test_peft_config.py

Lines changed: 43 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,14 +2,15 @@
22
from unittest.mock import Mock
33

44
# Third Party
5-
from peft import PromptTuningConfig
5+
from peft import LoraConfig, PromptTuningConfig
66
import pytest
77

88
# Local
9-
from caikit_nlp.data_model import TuningConfig
9+
from caikit_nlp.data_model import LoraTuningConfig, TuningConfig
1010
from caikit_nlp.modules.text_generation import TextGeneration
1111
from caikit_nlp.modules.text_generation.peft_config import (
1212
TuningType,
13+
get_lora_config,
1314
get_peft_config,
1415
resolve_base_model,
1516
)
@@ -74,6 +75,46 @@ def test_get_peft_config(train_kwargs, dummy_model, request):
7475
assert peft_config.prompt_tuning_init_text == tuning_config.prompt_tuning_init_text
7576

7677

78+
@pytest.mark.parametrize(
79+
"train_kwargs,dummy_model",
80+
[
81+
(
82+
"seq2seq_lm_train_kwargs",
83+
"seq2seq_lm_dummy_model",
84+
),
85+
("causal_lm_train_kwargs", "causal_lm_dummy_model"),
86+
],
87+
)
88+
def test_get_lora_config(train_kwargs, dummy_model, request):
89+
# Fixtures can't be called directly or passed to mark parametrize;
90+
# Currently, passing the fixture by name and retrieving it through
91+
# the request is the 'right' way to do this.
92+
train_kwargs = request.getfixturevalue(train_kwargs)
93+
dummy_model = request.getfixturevalue(dummy_model)
94+
95+
# Define some sample values for testing
96+
tuning_type = TuningType.LORA
97+
tuning_config = LoraTuningConfig(r=8, lora_alpha=8, lora_dropout=0.0)
98+
dummy_resource = train_kwargs["base_model"]
99+
100+
# Call the function being tested
101+
task_type, output_model_types, lora_config, tuning_type = get_lora_config(
102+
tuning_type, tuning_config, dummy_resource
103+
)
104+
105+
# Add assertions to validate the behavior of the function
106+
assert task_type == dummy_resource.TASK_TYPE
107+
assert output_model_types == dummy_resource.PROMPT_OUTPUT_TYPES
108+
assert tuning_type == TuningType.LORA
109+
110+
# Validation for type & important fields in the peft config
111+
assert isinstance(lora_config, LoraConfig)
112+
assert lora_config.task_type == dummy_resource.TASK_TYPE
113+
assert lora_config.r == tuning_config.r
114+
assert lora_config.lora_alpha == tuning_config.lora_alpha
115+
assert lora_config.lora_dropout == tuning_config.lora_dropout
116+
117+
77118
def test_resolve_model_with_invalid_path_raises():
78119
"""Test passing invalid path to resolve_model function raises"""
79120

0 commit comments

Comments (0)