55import tempfile
66
77# Third Party
8+ import pytest
89import torch
910
1011# First Party
2021)
2122from caikit_nlp .modules .text_generation .peft_prompt_tuning import PeftPromptTuning
2223from tests .fixtures import causal_lm_dummy_model , causal_lm_train_kwargs
24+ import caikit_nlp
2325
2426####################
2527## train/run ##
@@ -62,6 +64,7 @@ def test_run_classification_model(causal_lm_dummy_model):
6264 # Returns supported class labels or None
6365 classifier_model .unique_class_labels .append (None )
6466 assert output .results [0 ].label in classifier_model .unique_class_labels
67+ assert output .results [0 ].score == None
6568
6669
6770def test_train_run_model_classification_record (causal_lm_train_kwargs ):
@@ -87,6 +90,13 @@ def test_train_run_model_classification_record(causal_lm_train_kwargs):
8790 # Test fallback to float32 behavior if this machine doesn't support bfloat16
8891 assert model .classifier .model .dtype is torch .float32
8992 assert isinstance (model , ClassificationPeftPromptTuning )
93+ output = model .run ("Text does not matter" )
94+ assert isinstance (output , ClassificationResults )
95+ assert model .unique_class_labels == ["complaint" , "no complaint" ]
96+ # Returns supported class labels or None
97+ model .unique_class_labels .append (None )
98+ assert output .results [0 ].label in model .unique_class_labels
99+ assert output .results [0 ].score == None
90100
91101
92102####################
@@ -104,20 +114,45 @@ def test_save(causal_lm_dummy_model):
104114 assert os .path .exists (os .path .join (model_dir , "artifacts" , "config.yml" ))
105115
106116
107- # TODO: Enable test when saving of base model is enabled in module_saver
108- # def test_save_and_load(causal_lm_dummy_model):
109- # classifier_model = ClassificationPeftPromptTuning(
110- # classifier=causal_lm_dummy_model, unique_class_labels=["label1", "label2"]
111- # )
112- # with tempfile.TemporaryDirectory() as model_dir:
113- # classifier_model.save(model_dir)
114- # model_load = caikit_nlp.load(model_dir)
115- # assert isinstance(model_load, ClassificationPeftPromptTuning)
116- # assert isinstance(model_load.classifier, PeftPromptTuning)
117- # assert model_load.unique_class_labels == ["label1", "label2"]
117+ def test_save_and_reload_with_base_model (causal_lm_dummy_model ):
118+ classifier_model = ClassificationPeftPromptTuning (
119+ classifier = causal_lm_dummy_model , unique_class_labels = ["label1" , "label2" ]
120+ )
121+ with tempfile .TemporaryDirectory () as model_dir :
122+ classifier_model .save (model_dir , save_base_model = True )
123+ model_load = caikit_nlp .load (model_dir )
124+ assert isinstance (model_load , ClassificationPeftPromptTuning )
125+ assert isinstance (model_load .classifier , PeftPromptTuning )
126+ assert model_load .unique_class_labels == ["label1" , "label2" ]
127+
128+
129+ def test_save_and_reload_without_base_model (causal_lm_dummy_model ):
130+ """Ensure that if we don't save the base model, we get the expected behavior."""
131+ with tempfile .TemporaryDirectory () as model_dir :
132+ causal_lm_dummy_model .save (model_dir , save_base_model = False )
133+ # For now, if we are missing the base model at load time, we throw ValueError
134+ with pytest .raises (ValueError ):
135+ caikit_nlp .load (model_dir )
136+
118137
119138####################
120139## save/load/run ##
121140####################
122141
123- # TODO after load is fixed
142+
143+ def test_save_reload_and_run_with_base_model (causal_lm_dummy_model ):
144+ classifier_model = ClassificationPeftPromptTuning (
145+ classifier = causal_lm_dummy_model , unique_class_labels = ["label1" , "label2" ]
146+ )
147+ with tempfile .TemporaryDirectory () as model_dir :
148+ classifier_model .save (model_dir , save_base_model = True )
149+ model_load = caikit_nlp .load (model_dir )
150+ assert isinstance (model_load , ClassificationPeftPromptTuning )
151+ assert isinstance (model_load .classifier , PeftPromptTuning )
152+ assert model_load .unique_class_labels == ["label1" , "label2" ]
153+ output = model_load .run ("Text does not matter" )
154+ assert isinstance (output , ClassificationResults )
155+ # Returns supported class labels or None
156+ model_load .unique_class_labels .append (None )
157+ assert output .results [0 ].label in model_load .unique_class_labels
158+ assert output .results [0 ].score == None
0 commit comments