Commit 4be54cf

Merge pull request #264 from gkumbhat/add_global_training_data_limit
Update training data validation to consider global and module level defaults
2 parents: 4a5b2f8 + b387a77

File tree: 5 files changed (+166 −13 lines)

caikit_nlp/config/config.yml

Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@ master_addr: localhost
 master_port: 29550
 
 training_data_limit:
+  __default__: -1
   # Configuration for PeftPromptTuning module
   6655831b-960a-4dc5-8df4-867026e2cd41:
     add_model_name_here: 10000
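
Together with the existing module-level entries, this gives the limit lookup three tiers: a per-model value if one is set, then the module's own __default__, then the global __default__ added here (where -1 disables the check). A minimal sketch of that resolution order, using a plain Python dict in place of caikit's real config object (the unlisted model and module names are illustrative):

    # Sketch only: mirrors the lookup order in validate_training_data below,
    # but against a plain dict rather than caikit's config object.
    training_data_limit = {
        "__default__": -1,  # global default; -1 disables the check
        "6655831b-960a-4dc5-8df4-867026e2cd41": {  # PeftPromptTuning module ID
            "add_model_name_here": 10000,
        },
    }

    def resolve_limit(module_id: str, model_name: str) -> int:
        # Per-model limit wins, then the module default, then the global default.
        global_default = training_data_limit["__default__"]
        module_cfg = training_data_limit.get(module_id, {})
        module_default = module_cfg.get("__default__", global_default)
        return module_cfg.get(model_name, module_default)

    assert resolve_limit("6655831b-960a-4dc5-8df4-867026e2cd41", "add_model_name_here") == 10000
    assert resolve_limit("6655831b-960a-4dc5-8df4-867026e2cd41", "unlisted-model") == -1
    assert resolve_limit("unlisted-module", "any-model") == -1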

caikit_nlp/modules/text_generation/peft_prompt_tuning.py

Lines changed: 5 additions & 12 deletions

@@ -41,7 +41,6 @@
 import transformers
 
 # First Party
-from caikit import get_config
 from caikit.core.data_model import DataStream
 from caikit.core.exceptions import error_handler
 from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
@@ -73,6 +72,7 @@
     generate_text_func,
     generate_text_func_stream,
 )
+from ...toolkit.trainer_utils import validate_training_data
 from ...toolkit.verbalizer_utils import render_verbalizer
 from .peft_config import TuningType, get_peft_config, resolve_base_model
 
@@ -368,19 +368,12 @@ def train(
         )
 
         # Check if data is within limit allowed for this module and model
-        max_num_examples = (
-            get_config()
-            .training_data_limit.get(cls.MODULE_ID, {})
-            .get(base_model_name, -1)
+        validate_training_data(
+            train_stream,
+            base_model_name,
+            cls.MODULE_ID,
         )
 
-        if max_num_examples > 0:
-            error.value_check(
-                "<NLP77627434E>",
-                len(train_stream) <= max_num_examples,
-                "Number of examples larger than maximum number of examples allowed for this model",
-            )
-
         # Coerce the passed model into a resource; if we have one, this is a noop
         # TODO: When splitting up this mono-module, use the configured resource
         # type of the concrete class to bootstrap
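
With the shared helper in place, the inline limit check in train() reduces to a single call, and other training modules can adopt the same pattern. A hedged sketch of invoking the new helper directly, outside of train() (the stream contents and model name are illustrative; it assumes a config like the config.yml above is loaded):

    # Illustrative use of the new helper, not part of this diff.
    from caikit.core.data_model import DataStream
    from caikit_nlp.toolkit.trainer_utils import validate_training_data

    train_stream = DataStream.from_iterable(["example 1", "example 2"])
    # Raises ValueError (via error.value_check) if the resolved limit for this
    # (module ID, model name) pair is positive and the stream exceeds it;
    # otherwise it is a no-op.
    validate_training_data(
        train_stream,
        "my-base-model",                         # hypothetical model name
        "6655831b-960a-4dc5-8df4-867026e2cd41",  # PeftPromptTuning module ID
    )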

caikit_nlp/toolkit/trainer_utils.py

Lines changed: 27 additions & 0 deletions

@@ -19,9 +19,36 @@
 import torch
 
 # First Party
+from caikit import get_config
+from caikit.core.data_model import DataStream
+from caikit.core.exceptions import error_handler
 import alog
 
 log = alog.use_channel("TRNR_UTILS")
+error = error_handler.get(log)
+
+
+def validate_training_data(train_stream: DataStream, model_name: str, module_id: str):
+
+    global_default = get_config().training_data_limit.__default__
+    module_default = (
+        get_config()
+        .training_data_limit.get(module_id, {})
+        .get("__default__", global_default)
+    )
+
+    max_num_examples = (
+        get_config()
+        .training_data_limit.get(module_id, {})
+        .get(model_name, module_default)
+    )
+
+    if max_num_examples > 0:
+        error.value_check(
+            "<NLP77627434E>",
+            len(train_stream) <= max_num_examples,
+            "Number of examples larger than maximum number of examples allowed for this model",
+        )
 
 
 def log_step(state, logs):
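
Note that the helper mixes attribute access (training_data_limit.__default__) with dict-style .get(...) lookups. This works because caikit's get_config() returns an aconfig.Config-style object that supports both; a short sketch of the behavior the helper relies on (this characterization of the config type is an assumption, not part of this diff):

    import aconfig

    # Assumed: caikit's config behaves like aconfig.Config, a dict subclass
    # that also exposes keys as attributes.
    cfg = aconfig.Config(
        {"training_data_limit": {"__default__": -1, "mod-id": {"model-a": 5}}},
        override_env_vars=False,
    )
    assert cfg.training_data_limit.__default__ == -1                      # attribute access
    assert cfg.training_data_limit.get("mod-id", {}).get("model-a") == 5  # dict-style access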

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@ classifiers=[
     "License :: OSI Approved :: Apache Software License"
 ]
 dependencies = [
-    "caikit[runtime-grpc,runtime-http]>=0.23.2,<0.25.0",
+    "caikit[runtime-grpc,runtime-http]>=0.24.0,<0.25.0",
     "caikit-tgis-backend>=0.1.17,<0.2.0",
     # TODO: loosen dependencies
     "accelerate>=0.22.0",

tests/modules/text_generation/test_peft_prompt_tuning.py

Lines changed: 132 additions & 0 deletions

@@ -505,3 +505,135 @@ def test_train_with_no_limit_for_module(causal_lm_train_kwargs, set_cpu_device):
 
     model = module.train(**causal_lm_train_kwargs)
     assert model
+
+
+def test_train_module_level_data_validation_raises(
+    causal_lm_train_kwargs, set_cpu_device
+):
+    """Check that train raises with a module-level default configuration
+    when the training data exceeds the limit and no model-specific config is provided
+    """
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                ),
+                ClassificationTrainRecord(
+                    text="@bar this is the worst idea ever.", labels=["complaint"]
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+
+    module = caikit_nlp.modules.text_generation.PeftPromptTuning
+    with temp_config(
+        training_data_limit={module.MODULE_ID: {"__default__": 1, "foo": 2}}
+    ):
+        with pytest.raises(ValueError):
+            module.train(**causal_lm_train_kwargs)
+
+
+def test_train_module_level_data_validation_success(
+    causal_lm_train_kwargs, set_cpu_device
+):
+    """Check that we can train successfully with a module-level default configuration
+    when a model-specific limit is present and the training data is within it
+    """
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                ),
+                ClassificationTrainRecord(
+                    text="@bar this is the worst idea ever.", labels=["complaint"]
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+
+    model_name = causal_lm_train_kwargs["base_model"]._model_name
+    module = caikit_nlp.modules.text_generation.PeftPromptTuning
+    with temp_config(
+        training_data_limit={module.MODULE_ID: {"__default__": 1, model_name: 2}}
+    ):
+
+        model = module.train(**causal_lm_train_kwargs)
+        assert model
+
+
+def test_train_global_default_data_validation_raises(
+    causal_lm_train_kwargs, set_cpu_device
+):
+    """Check that train raises with a global default configuration
+    when the training data exceeds the limit and no matching model config is provided
+    """
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                ),
+                ClassificationTrainRecord(
+                    text="@bar this is the worst idea ever.", labels=["complaint"]
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+
+    module = caikit_nlp.modules.text_generation.PeftPromptTuning
+    with temp_config(
+        training_data_limit={"__default__": 1, module.MODULE_ID: {"foo": 2}}
+    ):
+        with pytest.raises(ValueError):
+            module.train(**causal_lm_train_kwargs)
+
+
+def test_train_global_default_data_validation_success(
+    causal_lm_train_kwargs, set_cpu_device
+):
+    """Check that we can train successfully with a global default configuration
+    when a model-specific limit is present and the training data is within it
+    """
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                ),
+                ClassificationTrainRecord(
+                    text="@bar this is the worst idea ever.", labels=["complaint"]
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+
+    model_name = causal_lm_train_kwargs["base_model"]._model_name
+    module = caikit_nlp.modules.text_generation.PeftPromptTuning
+    with temp_config(
+        training_data_limit={"__default__": 1, module.MODULE_ID: {model_name: 2}}
+    ):
+
+        model = module.train(**causal_lm_train_kwargs)
+        assert model
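
These tests lean on a temp_config helper to swap training_data_limit values in and out around each train() call. A minimal sketch of such a context manager, assuming caikit's config object is a mutable mapping (this illustrates the pattern; the actual test fixture may be implemented differently):

    from contextlib import contextmanager

    from caikit import get_config

    @contextmanager
    def temp_config(**overrides):
        # Overlay the given keys onto the live caikit config and restore the
        # previous values on exit, even if the body raises. Keys that were
        # absent beforehand are removed again afterwards.
        cfg = get_config()
        missing = object()
        saved = {key: cfg.get(key, missing) for key in overrides}
        try:
            cfg.update(overrides)
            yield cfg
        finally:
            for key, value in saved.items():
                if value is missing:
                    cfg.pop(key, None)
                else:
                    cfg[key] = value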
