Commit e512728

Merge pull request #255 from gkumbhat/add_data_limitation
Add data limitation
2 parents (e8d176e + 7834fae)
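This merge adds a configurable training data limit: a new training_data_limit section in config.yml maps a module ID to per-model example caps, PeftPromptTuning.train() validates the size of the incoming train stream against that cap, pyproject.toml bumps the caikit dependency and fixes the peft commit pin, and four new tests cover the over-limit, within-limit, and no-limit paths.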

File tree: 4 files changed, +128 -2 lines changed

- caikit_nlp/config/config.yml
- caikit_nlp/modules/text_generation/peft_prompt_tuning.py
- pyproject.toml
- tests/modules/text_generation/test_peft_prompt_tuning.py

caikit_nlp/config/config.yml (5 additions, 0 deletions)

@@ -30,5 +30,10 @@ unload_tgis_prompt_artifacts: false
 master_addr: localhost
 master_port: 29550
 
+training_data_limit:
+  # Configuration for PeftPromptTuning module
+  6655831b-960a-4dc5-8df4-867026e2cd41:
+    add_model_name_here: 10000
+
 runtime:
   library: caikit_nlp
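The mapping is keyed by module ID, then by base model name, with the value giving the maximum number of training examples; add_model_name_here is a placeholder to replace with a real model name. A minimal sketch of the lookup this configuration feeds, mirroring the code in the next file and assuming the config above has been loaded (the model name is illustrative):

```python
from caikit import get_config

# MODULE_ID of PeftPromptTuning, taken from the config above
PEFT_PROMPT_TUNING_MODULE_ID = "6655831b-960a-4dc5-8df4-867026e2cd41"

# Missing module or model entries fall back to -1, which disables the check
max_num_examples = (
    get_config()
    .training_data_limit.get(PEFT_PROMPT_TUNING_MODULE_ID, {})
    .get("add_model_name_here", -1)
)
print(max_num_examples)  # 10000 with the config above; -1 when unconfigured
```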

caikit_nlp/modules/text_generation/peft_prompt_tuning.py (15 additions, 0 deletions)

@@ -41,6 +41,7 @@
 import transformers
 
 # First Party
+from caikit import get_config
 from caikit.core.data_model import DataStream
 from caikit.core.exceptions import error_handler
 from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
@@ -366,6 +367,20 @@ def train(
             verbalizer,
         )
 
+        # Check if data is within limit allowed for this module and model
+        max_num_examples = (
+            get_config()
+            .training_data_limit.get(cls.MODULE_ID, {})
+            .get(base_model_name, -1)
+        )
+
+        if max_num_examples > 0:
+            error.value_check(
+                "<NLP77627434E>",
+                len(train_stream) <= max_num_examples,
+                "Number of examples larger than maximum number of examples allowed for this model",
+            )
+
         # Coerce the passed model into a resource; if we have one, this is a noop
         # TODO: When splitting up this mono-module, use the configured resource
         # type of the concrete class to bootstrap
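The guard only fires when a positive limit is configured; on failure, error.value_check raises a ValueError (the tests below rely on this). A standalone sketch of the same logic, with a plain dict standing in for the loaded config and illustrative module/model names:

```python
def check_training_data_limit(limits, module_id, model_name, num_examples):
    """Raise if a configured per-module, per-model example limit is exceeded."""
    # Missing entries fall back to -1, which disables the check
    max_num_examples = limits.get(module_id, {}).get(model_name, -1)
    if 0 < max_num_examples < num_examples:
        raise ValueError(
            "Number of examples larger than maximum number of examples "
            "allowed for this model"
        )

check_training_data_limit({"mod": {"bloom-560m": 2}}, "mod", "bloom-560m", 2)  # at the limit: passes
check_training_data_limit({}, "mod", "bloom-560m", 10_000)  # no limit configured: passes
# check_training_data_limit({"mod": {"bloom-560m": 1}}, "mod", "bloom-560m", 2)  # raises ValueError
```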

pyproject.toml (2 additions, 2 deletions)

@@ -14,7 +14,7 @@ classifiers=[
     "License :: OSI Approved :: Apache Software License"
 ]
 dependencies = [
-    "caikit[runtime-grpc,runtime-http]>=0.22.0,<0.23.0",
+    "caikit[runtime-grpc,runtime-http]>=0.23.2,<0.25.0",
     "caikit-tgis-backend>=0.1.17,<0.2.0",
     # TODO: loosen dependencies
     "accelerate>=0.22.0",
@@ -32,7 +32,7 @@ dependencies = [
     # which broke caikit-nlp build. peft hasn't released newer version yet, so to get
     # the build fix, we pulling peft from main branch commit. In future, we will pull PEFT from
     # pypi
-    "peft@git+https://github.com/huggingface/peft.git#8c17d556a8fe9522e10d73d7bd3fad46a6ecae14"
+    "peft@git+https://github.com/huggingface/peft.git@8c17d556a8fe9522e10d73d7bd3fad46a6ecae14"
 ]
 
 [tool.setuptools.packages.find]
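Note that the peft change is a pin fix rather than a version bump: in a pip VCS requirement the revision must follow `@` (as in `git+https://github.com/huggingface/peft.git@<commit>`), while a `#<commit>` suffix is treated as a URL fragment (normally reserved for `#egg=` or `#subdirectory=`), so the old spec did not actually pin the intended commit.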

tests/modules/text_generation/test_peft_prompt_tuning.py (106 additions, 0 deletions)

@@ -33,6 +33,7 @@
     seq2seq_lm_dummy_model,
     seq2seq_lm_train_kwargs,
     set_cpu_device,
+    temp_config,
 )
 import caikit_nlp
 
@@ -399,3 +400,108 @@ def test_run_exponential_decay_len_penatly_object(causal_lm_dummy_model):
         exponential_decay_length_penalty=penalty,
     )
     assert isinstance(pred, GeneratedTextResult)
+
+
+def test_train_with_data_validation_raises(causal_lm_train_kwargs, set_cpu_device):
+    """Check that an error is raised when the number of examples exceeds the configured limit"""
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                ),
+                ClassificationTrainRecord(
+                    text="@bar this is the worst idea ever.", labels=["complaint"]
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+
+    model_name = causal_lm_train_kwargs["base_model"]._model_name
+    module = caikit_nlp.modules.text_generation.PeftPromptTuning
+    with temp_config(training_data_limit={module.MODULE_ID: {model_name: 1}}):
+        with pytest.raises(ValueError):
+            module.train(**causal_lm_train_kwargs)
+
+
+def test_train_with_data_validation_success(causal_lm_train_kwargs, set_cpu_device):
+    """Check that training succeeds when the training data is within the configured limit"""
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                ),
+                ClassificationTrainRecord(
+                    text="@bar this is the worst idea ever.", labels=["complaint"]
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+
+    model_name = causal_lm_train_kwargs["base_model"]._model_name
+    module = caikit_nlp.modules.text_generation.PeftPromptTuning
+    with temp_config(training_data_limit={module.MODULE_ID: {model_name: 2}}):
+
+        model = module.train(**causal_lm_train_kwargs)
+        assert model
+
+
+def test_train_with_non_existent_limit_success(causal_lm_train_kwargs, set_cpu_device):
+    """Check that training succeeds when no limit exists for this particular model"""
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                )
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+
+    model_name = causal_lm_train_kwargs["base_model"]._model_name
+    module = caikit_nlp.modules.text_generation.PeftPromptTuning
+    with temp_config(training_data_limit={module.MODULE_ID: {"foo": 2}}):
+
+        model = module.train(**causal_lm_train_kwargs)
+        assert model
+
+
+def test_train_with_no_limit_for_module(causal_lm_train_kwargs, set_cpu_device):
+    """Check that training succeeds when no limit exists for the prompt tuning module"""
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                )
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+
+    model_name = causal_lm_train_kwargs["base_model"]._model_name
+    module = caikit_nlp.modules.text_generation.PeftPromptTuning
+    with temp_config(training_data_limit={}):
+
+        model = module.train(**causal_lm_train_kwargs)
+        assert model
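temp_config is imported from the shared test fixtures and scopes config overrides to a with block. Its implementation is not shown in this diff; purely as a hypothetical sketch (assuming caikit's config object is dict-like), such a helper could look something like:

```python
from contextlib import contextmanager

import caikit

@contextmanager
def temp_config(**overrides):
    # Hypothetical helper: overlay overrides onto the global caikit config,
    # then restore the previous values on exit (simplified: keys that were
    # absent before are restored as None rather than removed).
    cfg = caikit.get_config()
    saved = {key: cfg.get(key) for key in overrides}
    cfg.update(overrides)
    try:
        yield cfg
    finally:
        cfg.update(saved)
```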
