@@ -505,3 +505,135 @@ def test_train_with_no_limit_for_module(causal_lm_train_kwargs, set_cpu_device):
505505
506506 model = module .train (** causal_lm_train_kwargs )
507507 assert model
508+
509+
def test_train_module_level_data_validation_raises(
    causal_lm_train_kwargs, set_cpu_device
):
    """Train must raise when a module-level __default__ training data limit is
    exceeded and no per-model override matches the model being trained.
    """
    train_records = [
        ClassificationTrainRecord(
            text="@foo what a cute dog!", labels=["no complaint"]
        ),
        ClassificationTrainRecord(
            text="@bar this is the worst idea ever.", labels=["complaint"]
        ),
    ]
    causal_lm_train_kwargs.update(
        {
            "num_epochs": 1,
            "verbalizer": "Tweet text : {{input}} Label : ",
            "train_stream": caikit.core.data_model.DataStream.from_iterable(
                train_records
            ),
            "torch_dtype": torch.bfloat16,
            "device": "cpu",
        }
    )

    module = caikit_nlp.modules.text_generation.PeftPromptTuning
    # "foo" does not match the actual model name, so the __default__ limit of 1
    # applies; the 2-record stream exceeds it and train should raise.
    limits = {module.MODULE_ID: {"__default__": 1, "foo": 2}}
    with temp_config(training_data_limit=limits):
        with pytest.raises(ValueError):
            module.train(**causal_lm_train_kwargs)
540+
541+
def test_train_module_level_data_validation_success(
    causal_lm_train_kwargs, set_cpu_device
):
    """Training succeeds when a module-level, per-model limit matching the
    base model allows the size of the provided training data stream.
    """
    train_records = [
        ClassificationTrainRecord(
            text="@foo what a cute dog!", labels=["no complaint"]
        ),
        ClassificationTrainRecord(
            text="@bar this is the worst idea ever.", labels=["complaint"]
        ),
    ]
    causal_lm_train_kwargs.update(
        {
            "num_epochs": 1,
            "verbalizer": "Tweet text : {{input}} Label : ",
            "train_stream": caikit.core.data_model.DataStream.from_iterable(
                train_records
            ),
            "torch_dtype": torch.bfloat16,
            "device": "cpu",
        }
    )

    model_name = causal_lm_train_kwargs["base_model"]._model_name
    module = caikit_nlp.modules.text_generation.PeftPromptTuning
    # The per-model override (limit 2) matches the base model, so the stricter
    # __default__ of 1 is not applied and the 2-record stream is accepted.
    limits = {module.MODULE_ID: {"__default__": 1, model_name: 2}}
    with temp_config(training_data_limit=limits):
        model = module.train(**causal_lm_train_kwargs)
        assert model
574+
575+
def test_train_global_default_data_validation_raises(
    causal_lm_train_kwargs, set_cpu_device
):
    """Train must raise when the global __default__ training data limit is
    exceeded and the module's per-model config does not cover this model.
    """
    train_records = [
        ClassificationTrainRecord(
            text="@foo what a cute dog!", labels=["no complaint"]
        ),
        ClassificationTrainRecord(
            text="@bar this is the worst idea ever.", labels=["complaint"]
        ),
    ]
    causal_lm_train_kwargs.update(
        {
            "num_epochs": 1,
            "verbalizer": "Tweet text : {{input}} Label : ",
            "train_stream": caikit.core.data_model.DataStream.from_iterable(
                train_records
            ),
            "torch_dtype": torch.bfloat16,
            "device": "cpu",
        }
    )

    module = caikit_nlp.modules.text_generation.PeftPromptTuning
    # Only "foo" has a module-level entry; the model in use falls back to the
    # global __default__ of 1, which the 2-record stream exceeds.
    limits = {"__default__": 1, module.MODULE_ID: {"foo": 2}}
    with temp_config(training_data_limit=limits):
        with pytest.raises(ValueError):
            module.train(**causal_lm_train_kwargs)
606+
607+
def test_train_global_default_data_validation_success(
    causal_lm_train_kwargs, set_cpu_device
):
    """Training succeeds when a per-model limit under the module overrides the
    stricter global __default__ and permits the training data size.
    """
    train_records = [
        ClassificationTrainRecord(
            text="@foo what a cute dog!", labels=["no complaint"]
        ),
        ClassificationTrainRecord(
            text="@bar this is the worst idea ever.", labels=["complaint"]
        ),
    ]
    causal_lm_train_kwargs.update(
        {
            "num_epochs": 1,
            "verbalizer": "Tweet text : {{input}} Label : ",
            "train_stream": caikit.core.data_model.DataStream.from_iterable(
                train_records
            ),
            "torch_dtype": torch.bfloat16,
            "device": "cpu",
        }
    )

    model_name = causal_lm_train_kwargs["base_model"]._model_name
    module = caikit_nlp.modules.text_generation.PeftPromptTuning
    # The model-specific limit of 2 takes precedence over the global
    # __default__ of 1, so the 2-record stream is allowed through.
    limits = {"__default__": 1, module.MODULE_ID: {model_name: 2}}
    with temp_config(training_data_limit=limits):
        model = module.train(**causal_lm_train_kwargs)
        assert model
0 commit comments