
Commit 2682ef3

Merge pull request #262 from dtrifiro/improve-test-times
tests: make models fixtures session-scoped
2 parents 4be54cf + 84a7924

File tree

tests/fixtures/__init__.py
tests/modules/text_generation/test_peft_prompt_tuning.py

2 files changed: +14 -10 lines changed

tests/fixtures/__init__.py

Lines changed: 5 additions & 5 deletions

@@ -108,7 +108,7 @@ def models_cache_dir(request):
 
 ### Fixtures for grabbing a randomly initialized model to test interfaces against
 ## Causal LM
-@pytest.fixture
+@pytest.fixture(scope="session")
 def causal_lm_train_kwargs():
     """Get the kwargs for a valid train call to a Causal LM."""
     model_kwargs = {
@@ -124,15 +124,15 @@ def causal_lm_train_kwargs():
     return model_kwargs
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def causal_lm_dummy_model(causal_lm_train_kwargs):
     """Train a Causal LM dummy model."""
     return caikit_nlp.modules.text_generation.PeftPromptTuning.train(
         **causal_lm_train_kwargs
     )
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def saved_causal_lm_dummy_model(causal_lm_dummy_model):
     """Give a path to a saved dummy model that can be loaded"""
     with tempfile.TemporaryDirectory() as workdir:
@@ -142,7 +142,7 @@ def saved_causal_lm_dummy_model(causal_lm_dummy_model):
 
 
 ## Seq2seq
-@pytest.fixture
+@pytest.fixture(scope="session")
 def seq2seq_lm_train_kwargs():
     """Get the kwargs for a valid train call to a Causal LM."""
     model_kwargs = {
@@ -158,7 +158,7 @@ def seq2seq_lm_train_kwargs():
     return model_kwargs
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def seq2seq_lm_dummy_model(seq2seq_lm_train_kwargs):
     """Train a Seq2Seq LM dummy model."""
     return caikit_nlp.modules.text_generation.PeftPromptTuning.train(
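
The diff above switches the model fixtures from the default function scope to session scope, so the dummy models are trained once per test session and shared by every test that requests them, instead of being re-trained for each test. A minimal sketch of the behavioral difference, assuming pytest; the fixture and test names below are hypothetical and not part of this commit:

# Hypothetical sketch: a session-scoped fixture body runs once per session.
import pytest

TRAIN_CALLS = {"count": 0}


@pytest.fixture(scope="session")
def dummy_model():
    """Stand-in for an expensive train() call; built once, then reused."""
    TRAIN_CALLS["count"] += 1
    return object()


def test_first_use(dummy_model):
    # The first test to request the fixture triggers the single "training" run.
    assert TRAIN_CALLS["count"] == 1


def test_reuse(dummy_model):
    # Later tests reuse the cached instance; no additional run happens.
    assert TRAIN_CALLS["count"] == 1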

tests/modules/text_generation/test_peft_prompt_tuning.py

Lines changed: 9 additions & 5 deletions

@@ -93,13 +93,17 @@ def test_run_stream_out_model(causal_lm_dummy_model):
     assert isinstance(pred, GeneratedTextStreamResult)
 
 
-def test_verbalizer_rendering(causal_lm_dummy_model):
+def test_verbalizer_rendering(causal_lm_dummy_model, monkeypatch):
     """Ensure that our model renders its verbalizer text correctly before calling tokenizer."""
     # Mock the tokenizer; we want to make sure its inputs are rendered properly
-    causal_lm_dummy_model.tokenizer = mock.Mock(
-        side_effect=RuntimeError("Tokenizer is a mock!"),
-        # Set eos token property to be attribute of tokenizer
-        eos_token="</s>",
+    monkeypatch.setattr(
+        causal_lm_dummy_model,
+        "tokenizer",
+        mock.Mock(
+            side_effect=RuntimeError("Tokenizer is a mock!"),
+            # Set eos token property to be attribute of tokenizer
+            eos_token="</s>",
+        ),
     )
     input_text = "This text doesn't matter"
     causal_lm_dummy_model.verbalizer = " | {{input}} |"
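
Because causal_lm_dummy_model is now shared across the whole session, the test no longer assigns a mock tokenizer directly onto the model; monkeypatch.setattr registers the change and restores the original attribute at teardown, so later tests that reuse the same fixture instance still see the real tokenizer. A minimal sketch of that pattern, assuming pytest and unittest.mock; the Model class and test names below are hypothetical:

# Hypothetical sketch: monkeypatch undoes attribute changes on shared objects.
from unittest import mock

import pytest


class Model:
    def __init__(self):
        self.tokenizer = "real-tokenizer"


@pytest.fixture(scope="session")
def shared_model():
    return Model()


def test_mocks_tokenizer(shared_model, monkeypatch):
    # The attribute swap is tracked by monkeypatch and undone after this test.
    monkeypatch.setattr(shared_model, "tokenizer", mock.Mock(eos_token="</s>"))
    assert shared_model.tokenizer.eos_token == "</s>"


def test_sees_real_tokenizer(shared_model):
    # The session-scoped instance is unchanged for subsequent tests.
    assert shared_model.tokenizer == "real-tokenizer"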
