Commit 4a5b2f8

Merge pull request #247 from dtrifiro/fix-uncaught-exception

fix uncaught exception in http client

2 parents e512728 + 9c0330c
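
Judging from the diff below, the uncaught exception came from finish_reason being assigned plain strings ("EOS_TOKEN", "STOP_SEQUENCE", "MAX_TOKENS") where the data model expects a FinishReason enum value, so responses built that way could fail to serialize when returned over HTTP. The fix switches to the enum and adds a regression test that round-trips GeneratedTextResult through to_dict, to_json, and to_proto.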

File tree

2 files changed: +49 −3 lines changed
caikit_nlp/toolkit/text_generation/model_run_utils.py

Lines changed: 4 additions & 3 deletions
@@ -27,6 +27,7 @@
 from caikit.core.data_model.producer import ProducerId
 from caikit.core.exceptions import error_handler
 from caikit.interfaces.nlp.data_model import (
+    FinishReason,
     GeneratedTextResult,
     GeneratedTextStreamResult,
     TokenStreamDetails,
@@ -237,16 +238,16 @@ def generate_text_func(
     if (eos_token and tokenizer.decode(generate_ids[0, -1].item()) == eos_token) or (
         generate_ids[0, -1] == tokenizer.eos_token_id
     ):
-        finish_reason = "EOS_TOKEN"
+        finish_reason = FinishReason.EOS_TOKEN
     elif ("stopping_criteria" in gen_optional_params) and (
         gen_optional_params["stopping_criteria"](
             generate_ids,
             None,  # scores, unused by SequenceStoppingCriteria
         )
     ):
-        finish_reason = "STOP_SEQUENCE"
+        finish_reason = FinishReason.STOP_SEQUENCE
     else:
-        finish_reason = "MAX_TOKENS"
+        finish_reason = FinishReason.MAX_TOKENS
 
     return GeneratedTextResult(
         generated_tokens=token_count,
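
For context, here is a minimal sketch of what the fixed code now produces. Field names other than generated_tokens (generated_text, finish_reason, producer_id) are assumptions based on this diff's imports and the new test, not confirmed by the diff itself:

# Minimal sketch, assuming GeneratedTextResult exposes generated_text,
# finish_reason, and producer_id fields alongside generated_tokens
# (only generated_tokens is visible in this diff).
from caikit.core.data_model.producer import ProducerId
from caikit.interfaces.nlp.data_model import FinishReason, GeneratedTextResult

result = GeneratedTextResult(
    generated_text="liquid nitrogen boils at -196 C",  # hypothetical output
    generated_tokens=8,
    finish_reason=FinishReason.EOS_TOKEN,  # enum member, not the string "EOS_TOKEN"
    producer_id=ProducerId("TextGeneration", "0.1.0"),
)

# The serialization paths the new test exercises; with a raw string in
# finish_reason, these are the calls that could raise.
print(result.to_dict())
print(result.to_json())
print(result.to_proto())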
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+# Third Party
+import pytest
+
+# First Party
+from caikit.core.data_model.producer import ProducerId
+from caikit.interfaces.nlp.data_model import GeneratedTextResult
+
+# Local
+from caikit_nlp.toolkit.text_generation.model_run_utils import generate_text_func
+from tests.fixtures import (
+    causal_lm_dummy_model,
+    causal_lm_train_kwargs,
+    seq2seq_lm_dummy_model,
+    seq2seq_lm_train_kwargs,
+)
+
+
+@pytest.mark.parametrize(
+    "model_fixture", ["seq2seq_lm_dummy_model", "causal_lm_dummy_model"]
+)
+@pytest.mark.parametrize(
+    "serialization_method,expected_type",
+    [
+        ("to_dict", dict),
+        ("to_json", str),
+        ("to_proto", GeneratedTextResult._proto_class),
+    ],
+)
+def test_generate_text_func_serialization_json(
+    request,
+    model_fixture,
+    serialization_method,
+    expected_type,
+):
+    model = request.getfixturevalue(model_fixture)
+    generated_text = generate_text_func(
+        model=model.model,
+        tokenizer=model.tokenizer,
+        producer_id=ProducerId("TextGeneration", "0.1.0"),
+        eos_token="<\n>",
+        text="What is the boiling point of liquid Nitrogen?",
+    )
+
+    serialized = getattr(generated_text, serialization_method)()
+    assert isinstance(serialized, expected_type)
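
The two parametrize decorators above yield six cases: two model fixtures times three serialization methods. As a hand-expanded illustration of a single case, using only names that appear in the new test itself:

# One parametrized case written out by hand (illustrative only; fixture and
# helper names are taken from the new test above).
from caikit.core.data_model.producer import ProducerId
from caikit_nlp.toolkit.text_generation.model_run_utils import generate_text_func
from tests.fixtures import causal_lm_dummy_model  # noqa: F401 (pytest fixture)


def test_to_json_case(request):
    model = request.getfixturevalue("causal_lm_dummy_model")
    generated_text = generate_text_func(
        model=model.model,
        tokenizer=model.tokenizer,
        producer_id=ProducerId("TextGeneration", "0.1.0"),
        eos_token="<\n>",
        text="What is the boiling point of liquid Nitrogen?",
    )
    # getattr(generated_text, "to_json")() in the parametrized test is
    # equivalent to calling the method directly:
    assert isinstance(generated_text.to_json(), str)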
