 from unittest.mock import Mock

 # Third Party
-from peft import PromptTuningConfig
+from peft import LoraConfig, PromptTuningConfig
 import pytest

 # Local
-from caikit_nlp.data_model import TuningConfig
+from caikit_nlp.data_model import LoraTuningConfig, TuningConfig
 from caikit_nlp.modules.text_generation import TextGeneration
 from caikit_nlp.modules.text_generation.peft_config import (
     TuningType,
+    get_lora_config,
     get_peft_config,
     resolve_base_model,
 )
@@ -74,6 +75,46 @@ def test_get_peft_config(train_kwargs, dummy_model, request):
     assert peft_config.prompt_tuning_init_text == tuning_config.prompt_tuning_init_text


+@pytest.mark.parametrize(
+    "train_kwargs,dummy_model",
+    [
+        (
+            "seq2seq_lm_train_kwargs",
+            "seq2seq_lm_dummy_model",
+        ),
+        ("causal_lm_train_kwargs", "causal_lm_dummy_model"),
+    ],
+)
+def test_get_lora_config(train_kwargs, dummy_model, request):
+    # Fixtures can't be called directly or passed to mark.parametrize;
+    # currently, passing the fixture by name and retrieving it through
+    # the request is the 'right' way to do this.
+    train_kwargs = request.getfixturevalue(train_kwargs)
+    dummy_model = request.getfixturevalue(dummy_model)
+
+    # Define some sample values for testing
+    tuning_type = TuningType.LORA
+    tuning_config = LoraTuningConfig(r=8, lora_alpha=8, lora_dropout=0.0)
+    dummy_resource = train_kwargs["base_model"]
+
+    # Call the function being tested
+    task_type, output_model_types, lora_config, tuning_type = get_lora_config(
+        tuning_type, tuning_config, dummy_resource
+    )
+
+    # Validate the behavior of the function
+    assert task_type == dummy_resource.TASK_TYPE
+    assert output_model_types == dummy_resource.PROMPT_OUTPUT_TYPES
+    assert tuning_type == TuningType.LORA
+
+    # Validate the type & important fields in the peft config
+    assert isinstance(lora_config, LoraConfig)
+    assert lora_config.task_type == dummy_resource.TASK_TYPE
+    assert lora_config.r == tuning_config.r
+    assert lora_config.lora_alpha == tuning_config.lora_alpha
+    assert lora_config.lora_dropout == tuning_config.lora_dropout
+
+
 def test_resolve_model_with_invalid_path_raises():
     """Test passing invalid path to resolve_model function raises"""

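For context, a minimal sketch of what the get_lora_config helper exercised by the new test might look like. This is an assumption inferred from the test's assertions, not part of this change; the real implementation lives in caikit_nlp.modules.text_generation.peft_config and may differ. It shows the assumed return tuple (task_type, output_model_types, lora_config, tuning_type) and how the LoraTuningConfig fields would map onto a peft LoraConfig.

# Hypothetical sketch only -- inferred from the assertions in
# test_get_lora_config, not the actual library implementation.
from peft import LoraConfig


def get_lora_config(tuning_type, tuning_config, base_model_resource):
    """Build a peft LoraConfig from a caikit-nlp LoraTuningConfig."""
    # Task type and prompt output types come from the base model resource class
    task_type = base_model_resource.TASK_TYPE
    output_model_types = base_model_resource.PROMPT_OUTPUT_TYPES
    # Map the tuning config fields onto the peft LoraConfig
    lora_config = LoraConfig(
        task_type=task_type,
        r=tuning_config.r,
        lora_alpha=tuning_config.lora_alpha,
        lora_dropout=tuning_config.lora_dropout,
    )
    return task_type, output_model_types, lora_config, tuning_type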