Commit 6eb80c6 (parent: 16d2118)

add tests for save and load

Signed-off-by: Sukriti-Sharma4 <[email protected]>

2 files changed (+54 lines, −14 lines)

caikit_nlp/modules/text_classification/classification_prompt_tuning.py

Lines changed: 7 additions & 2 deletions
@@ -158,16 +158,21 @@ def train(
     @wip_decorator.work_in_progress(
         category=wip_decorator.WipCategory.WIP, action=wip_decorator.Action.WARNING
     )
-    def save(self, model_path):
+    def save(self, model_path: str, save_base_model: bool = False):
         """Save classification model
 
         Args:
             model_path: str
                 Folder to save classification prompt tuning model
+            save_base_model: bool
+                Save base model along with the prompts in the model_path provided.
+                Default: False
         """
         saver = ModuleSaver(self, model_path=model_path)
         with saver:
-            saver.save_module(self.classifier, "artifacts")
+            saver.save_module(
+                self.classifier, "artifacts", save_base_model=save_base_model
+            )
             saver.update_config(
                 {
                     "unique_class_labels": self.unique_class_labels,

tests/modules/text_classification/test_classification_prompt_tuning.py

Lines changed: 47 additions & 12 deletions
@@ -5,6 +5,7 @@
 import tempfile
 
 # Third Party
+import pytest
 import torch
 
 # First Party
@@ -20,6 +21,7 @@
 )
 from caikit_nlp.modules.text_generation.peft_prompt_tuning import PeftPromptTuning
 from tests.fixtures import causal_lm_dummy_model, causal_lm_train_kwargs
+import caikit_nlp
 
 ####################
 ## train/run ##
@@ -62,6 +64,7 @@ def test_run_classification_model(causal_lm_dummy_model):
     # Returns supported class labels or None
     classifier_model.unique_class_labels.append(None)
     assert output.results[0].label in classifier_model.unique_class_labels
+    assert output.results[0].score == None
 
 
 def test_train_run_model_classification_record(causal_lm_train_kwargs):
@@ -87,6 +90,13 @@ def test_train_run_model_classification_record(causal_lm_train_kwargs):
     # Test fallback to float32 behavior if this machine doesn't support bfloat16
     assert model.classifier.model.dtype is torch.float32
     assert isinstance(model, ClassificationPeftPromptTuning)
+    output = model.run("Text does not matter")
+    assert isinstance(output, ClassificationResults)
+    assert model.unique_class_labels == ["complaint", "no complaint"]
+    # Returns supported class labels or None
+    model.unique_class_labels.append(None)
+    assert output.results[0].label in model.unique_class_labels
+    assert output.results[0].score == None
 
 
 ####################
@@ -104,20 +114,45 @@ def test_save(causal_lm_dummy_model):
     assert os.path.exists(os.path.join(model_dir, "artifacts", "config.yml"))
 
 
-# TODO: Enable test when saving of base model is enabled in module_saver
-# def test_save_and_load(causal_lm_dummy_model):
-#     classifier_model = ClassificationPeftPromptTuning(
-#         classifier=causal_lm_dummy_model, unique_class_labels=["label1", "label2"]
-#     )
-#     with tempfile.TemporaryDirectory() as model_dir:
-#         classifier_model.save(model_dir)
-#         model_load = caikit_nlp.load(model_dir)
-#         assert isinstance(model_load, ClassificationPeftPromptTuning)
-#         assert isinstance(model_load.classifier, PeftPromptTuning)
-#         assert model_load.unique_class_labels == ["label1", "label2"]
+def test_save_and_reload_with_base_model(causal_lm_dummy_model):
+    classifier_model = ClassificationPeftPromptTuning(
+        classifier=causal_lm_dummy_model, unique_class_labels=["label1", "label2"]
+    )
+    with tempfile.TemporaryDirectory() as model_dir:
+        classifier_model.save(model_dir, save_base_model=True)
+        model_load = caikit_nlp.load(model_dir)
+        assert isinstance(model_load, ClassificationPeftPromptTuning)
+        assert isinstance(model_load.classifier, PeftPromptTuning)
+        assert model_load.unique_class_labels == ["label1", "label2"]
+
+
+def test_save_and_reload_without_base_model(causal_lm_dummy_model):
+    """Ensure that if we don't save the base model, we get the expected behavior."""
+    with tempfile.TemporaryDirectory() as model_dir:
+        causal_lm_dummy_model.save(model_dir, save_base_model=False)
+        # For now, if we are missing the base model at load time, we throw ValueError
+        with pytest.raises(ValueError):
+            caikit_nlp.load(model_dir)
+
 
 ####################
 ## save/load/run ##
 ####################
 
-# TODO after load is fixed
+
+def test_save_reload_and_run_with_base_model(causal_lm_dummy_model):
+    classifier_model = ClassificationPeftPromptTuning(
+        classifier=causal_lm_dummy_model, unique_class_labels=["label1", "label2"]
+    )
+    with tempfile.TemporaryDirectory() as model_dir:
+        classifier_model.save(model_dir, save_base_model=True)
+        model_load = caikit_nlp.load(model_dir)
+        assert isinstance(model_load, ClassificationPeftPromptTuning)
+        assert isinstance(model_load.classifier, PeftPromptTuning)
+        assert model_load.unique_class_labels == ["label1", "label2"]
+        output = model_load.run("Text does not matter")
+        assert isinstance(output, ClassificationResults)
+        # Returns supported class labels or None
+        model_load.unique_class_labels.append(None)
+        assert output.results[0].label in model_load.unique_class_labels
+        assert output.results[0].score == None
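The without-base-model test pins down a contract worth noting: the prompts-only artifacts are written, but loading fails until base-model resolution is supported. A hedged sketch of how a caller could guard for that today (the directory name is a placeholder):

    # Sketch: loading a prompts-only save currently raises ValueError.
    import caikit_nlp

    try:
        model = caikit_nlp.load("prompts_only_dir")  # placeholder path
    except ValueError:
        # The base model was not saved with the prompts; fall back, or
        # re-save the source model with save_base_model=True.
        model = None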
