|
33 | 33 | seq2seq_lm_dummy_model, |
34 | 34 | seq2seq_lm_train_kwargs, |
35 | 35 | set_cpu_device, |
| 36 | + temp_config, |
36 | 37 | ) |
37 | 38 | import caikit_nlp |
38 | 39 |
|
@@ -399,3 +400,108 @@ def test_run_exponential_decay_len_penatly_object(causal_lm_dummy_model): |
399 | 400 | exponential_decay_length_penalty=penalty, |
400 | 401 | ) |
401 | 402 | assert isinstance(pred, GeneratedTextResult) |
| 403 | + |
| 404 | + |
def test_train_with_data_validation_raises(causal_lm_train_kwargs, set_cpu_device):
    """Training must raise ValueError when the example count exceeds the configured limit."""
    train_stream = caikit.core.data_model.DataStream.from_iterable(
        [
            ClassificationTrainRecord(
                text="@foo what a cute dog!", labels=["no complaint"]
            ),
            ClassificationTrainRecord(
                text="@bar this is the worst idea ever.", labels=["complaint"]
            ),
        ]
    )
    causal_lm_train_kwargs.update(
        {
            "num_epochs": 1,
            "verbalizer": "Tweet text : {{input}} Label : ",
            "train_stream": train_stream,
            "torch_dtype": torch.bfloat16,
            "device": "cpu",
        }
    )

    module = caikit_nlp.modules.text_generation.PeftPromptTuning
    base_model_name = causal_lm_train_kwargs["base_model"]._model_name
    # Limit of 1 is below the 2 records in the stream, so train() must reject the data
    with temp_config(training_data_limit={module.MODULE_ID: {base_model_name: 1}}):
        with pytest.raises(ValueError):
            module.train(**causal_lm_train_kwargs)
| 430 | + |
| 431 | + |
def test_train_with_data_validation_success(causal_lm_train_kwargs, set_cpu_device):
    """Training should succeed when the example count is within the configured limit."""
    train_stream = caikit.core.data_model.DataStream.from_iterable(
        [
            ClassificationTrainRecord(
                text="@foo what a cute dog!", labels=["no complaint"]
            ),
            ClassificationTrainRecord(
                text="@bar this is the worst idea ever.", labels=["complaint"]
            ),
        ]
    )
    causal_lm_train_kwargs.update(
        {
            "num_epochs": 1,
            "verbalizer": "Tweet text : {{input}} Label : ",
            "train_stream": train_stream,
            "torch_dtype": torch.bfloat16,
            "device": "cpu",
        }
    )

    module = caikit_nlp.modules.text_generation.PeftPromptTuning
    base_model_name = causal_lm_train_kwargs["base_model"]._model_name
    # Limit of 2 exactly matches the 2 records in the stream, so training proceeds
    with temp_config(training_data_limit={module.MODULE_ID: {base_model_name: 2}}):
        model = module.train(**causal_lm_train_kwargs)
    assert model
| 458 | + |
| 459 | + |
def test_train_with_non_existent_limit_success(causal_lm_train_kwargs, set_cpu_device):
    """Check if we are able to train successfully if training data limit doesn't exist for particular model"""
    patch_kwargs = {
        "num_epochs": 1,
        "verbalizer": "Tweet text : {{input}} Label : ",
        "train_stream": caikit.core.data_model.DataStream.from_iterable(
            [
                ClassificationTrainRecord(
                    text="@foo what a cute dog!", labels=["no complaint"]
                )
            ]
        ),
        "torch_dtype": torch.bfloat16,
        "device": "cpu",
    }
    causal_lm_train_kwargs.update(patch_kwargs)

    module = caikit_nlp.modules.text_generation.PeftPromptTuning
    # A limit is configured only for model "foo"; the base model under test has
    # no entry, so no cap should apply and training should succeed.
    # (Removed an unused `model_name` local that a sibling test pattern left behind.)
    with temp_config(training_data_limit={module.MODULE_ID: {"foo": 2}}):
        model = module.train(**causal_lm_train_kwargs)
    assert model
| 483 | + |
| 484 | + |
def test_train_with_no_limit_for_module(causal_lm_train_kwargs, set_cpu_device):
    """Check if we are able to train successfully if training data limit doesn't exist for the prompt tuning module"""
    patch_kwargs = {
        "num_epochs": 1,
        "verbalizer": "Tweet text : {{input}} Label : ",
        "train_stream": caikit.core.data_model.DataStream.from_iterable(
            [
                ClassificationTrainRecord(
                    text="@foo what a cute dog!", labels=["no complaint"]
                )
            ]
        ),
        "torch_dtype": torch.bfloat16,
        "device": "cpu",
    }
    causal_lm_train_kwargs.update(patch_kwargs)

    module = caikit_nlp.modules.text_generation.PeftPromptTuning
    # Empty limit config: no entry for this module at all, so training must not
    # be capped. (Removed an unused `model_name` local copied from sibling tests.)
    with temp_config(training_data_limit={}):
        model = module.train(**causal_lm_train_kwargs)
    assert model
0 commit comments