diff --git a/CODEOWNERS b/CODEOWNERS index 53b536af..6cfa142b 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -8,4 +8,4 @@ # https://help.github.com/en/articles/about-code-owners # -* @alex-jw-brooks @gkumbhat @evaline-ju +* @alex-jw-brooks @gkumbhat @evaline-ju @gabe-l-hart diff --git a/README.md b/README.md index d5303d07..1ae92382 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,36 @@ # Caikit NLP +Welcome to `caikit-nlp`! This repository provides NLP domain capabilities that run on the [caikit](https://github.com/caikit/caikit) framework. + +## Introduction + `caikit_nlp` is a [Caikit](https://github.com/caikit/caikit) library that currently provides [PEFT prompt tuning](https://github.com/huggingface/peft) and MPT (multi-task prompt tuning) functionalities. -More information on MPT can be found at: https://arxiv.org/abs/2303.02861 +### Getting Started + +To help you quickly get started with Caikit, we have prepared a [Jupyter notebook](examples/Caikit_Getting_Started.ipynb) that can be run in Google Colab. Caikit-nlp is a powerful library that leverages prompt tuning and fine-tuning to add NLP domain capabilities to caikit. + + +### Contributing + +We welcome contributions from the community! If you would like to contribute to `caikit-nlp`, please read the guidelines in the main project's [CONTRIBUTING.md](main/CONTRIBUTING.md) file. It includes information on submitting bug reports, feature requests, and pull requests. Make sure to follow our coding standards, [code of conduct](code-of-conduct.md), [security standards](https://github.com/caikit/community/blob/main/SECURITY.md), and documentation guidelines to streamline the contribution process. + +### License + +This project is licensed under the [ASFv2 License](LICENSE). + +### Glossary + +A list of terms that either may be unfamiliar or that have nebulous definitions based on who and where you hear them, defined for how they are used/thought of in the `caikit`/`caikit-nlp` project: + +* Fine tuning - trains the base model on new data; this changes the base model's weights. +* Prompt engineering - (usually) manually crafting text that is appended to the input to make the model do a better job. E.g., if you wanted to do something like sentiment on movie reviews, you might come up with a prompt like The movie was: _____ and replace the _____ with the movie review you're considering, to try to get something like happy/sad out of it. +* PEFT - library by Hugging Face containing implementations of different tuning methods that scale well; things like prompt tuning and MPT live there. So PEFT itself isn't an approach, even though parameter-efficient fine-tuning sounds like one. +* Prompt tuning - learning soft prompts. This is different from prompt engineering in that you're not trying to learn tokens. Instead, you're basically trying to learn new embedded representations (sometimes called virtual tokens) that can be concatenated onto your embedded input text to improve performance. This can work well, but it can also be sensitive to initialization. +* Multitask prompt tuning (MPT) - tries to fix some of the issues with prompt tuning by allowing you to effectively learn 'source prompts' across different tasks and leverage them to initialize your prompt tuning. More information on MPT can be found at: https://arxiv.org/abs/2303.02861 -Currently causal language models and sequence-to-sequence models are supported.
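Editor's note (not part of the diff): as an illustration of the glossary entries above, here is a minimal sketch of vanilla prompt tuning with the Hugging Face PEFT library that this README references. The base model name and initialization text are placeholders rather than anything this repository prescribes.

```python
from transformers import AutoModelForSeq2SeqLM
from peft import PromptTuningConfig, PromptTuningInit, TaskType, get_peft_model

model_name = "google/flan-t5-small"  # placeholder base model
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Prompt tuning learns a small number of "virtual token" embeddings that are
# prepended to the embedded input; the base model's weights stay frozen.
peft_config = PromptTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    prompt_tuning_init_text="Classify the sentiment of this movie review:",
    num_virtual_tokens=8,
    tokenizer_name_or_path=model_name,
)
peft_model = get_peft_model(base_model, peft_config)
peft_model.print_trainable_parameters()  # only the prompt embeddings are trainable
```

Because only the prompt embeddings are trained, many such prompts can be stored and injected against a single shared base model at inference time, which is the point made in the paragraph that follows.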
+The important difference between fine tuning and capabilities like prompt tuning/multi-task prompt tuning is that the latter doesn't change the base model's weights at all. So when you run inference for prompt-tuned models, you can have _n_ prompts to 1 base model, and just inject the prompt tensors you need when they're requested instead of having _n_ separate fine-tuned models. #### Notes -- The data model for text generative capabilities is baked into this repository itself at `caikit_nlp/data_model/generation.py`. +- Currently causal language models and sequence-to-sequence models are supported. diff --git a/caikit_nlp/__init__.py b/caikit_nlp/__init__.py index 6dbb0c00..0c9c020f 100644 --- a/caikit_nlp/__init__.py +++ b/caikit_nlp/__init__.py @@ -24,7 +24,7 @@ # Local # Import subpackages -from . import config, data_model +from . import config, data_model, model_management from .config import * from .data_model import * from .modules import * diff --git a/caikit_nlp/data_model/__init__.py b/caikit_nlp/data_model/__init__.py index f6defedb..6826b631 100644 --- a/caikit_nlp/data_model/__init__.py +++ b/caikit_nlp/data_model/__init__.py @@ -15,7 +15,5 @@ """ # Local -from . import classification, generation, text -from .classification import * +from . import generation from .generation import * -from .text import * diff --git a/caikit_nlp/data_model/classification.py b/caikit_nlp/data_model/classification.py deleted file mode 100644 index 9ffee4f1..00000000 --- a/caikit_nlp/data_model/classification.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
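Editor's note (not part of the diff): the classification data model classes deleted here are replaced by equivalents that now live upstream in `caikit.interfaces.nlp`, as the import changes later in this diff show. A minimal sketch of consuming the relocated interfaces; the field names are taken from their usage elsewhere in this diff, and the sample values are illustrative.

```python
from caikit.interfaces.nlp.data_model import (
    ClassificationResult,   # one label + score (replaces caikit_nlp's Classification)
    ClassificationResults,  # wraps a list of ClassificationResult
)

# Build a result the same way sequence_classification.py does after this change.
result = ClassificationResults(
    results=[ClassificationResult(label="positive", score=0.98)]
)
print(result.results[0].label, result.results[0].score)
```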
- -"""These interfaces can be promoted to caikit/caikit for wider usage -when applicable to multiple modules -""" -# Standard -from typing import List - -# First Party -from caikit.core import DataObjectBase -import caikit - - -@caikit.core.dataobject(package="caikit_data_model.caikit_nlp") -class ClassificationTrainRecord(DataObjectBase): - text: str - labels: List[str] - - -@caikit.core.dataobject(package="caikit_data_model.caikit_nlp") -class Classification(DataObjectBase): - label: str - score: float - - -@caikit.core.dataobject(package="caikit_data_model.caikit_nlp") -class ClassificationResult(DataObjectBase): - results: List[Classification] - - -# NOTE: This is meant to align with the HuggingFace token classification task: -# https://huggingface.co/docs/transformers/tasks/token_classification#inference -# The field `word` does not necessarily correspond to a single "word", -# and `entity` may not always be applicable beyond "entity" in the NER -# (named entity recognition) sense -@caikit.core.dataobject(package="caikit_data_model.caikit_nlp") -class TokenClassification(DataObjectBase): - start: int - end: int - word: str # could be thought of as text - entity: str # could be thought of as label - entity_group: str # could be thought of as aggregate label, if applicable - score: float - - -@caikit.core.dataobject(package="caikit_data_model.caikit_nlp") -class TokenClassificationResult(DataObjectBase): - results: List[TokenClassification] - - -# Streaming result that indicates up to where in stream is processed -@caikit.core.dataobject(package="caikit_data_model.caikit_nlp") -class StreamingTokenClassificationResult(TokenClassificationResult): - # Result index up to which text is processed - processed_index: int diff --git a/caikit_nlp/data_model/text.py b/caikit_nlp/data_model/text.py deleted file mode 100644 index 4bf2f9d7..00000000 --- a/caikit_nlp/data_model/text.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""These interfaces can be promoted to caikit/caikit for wider usage -when applicable to multiple modules -""" - -# Standard -from typing import List - -# First Party -from caikit.core import DataObjectBase -import caikit - - -@caikit.core.dataobject(package="caikit_data_model.caikit_nlp") -class Token(DataObjectBase): - """Tokens here are the basic units of text. Tokens can be characters, words, - sub-words, or other segments of text or code, depending on the method of - tokenization chosen or the task being implemented. 
- """ - - start: int - end: int - text: str - - -@caikit.core.dataobject(package="caikit_data_model.caikit_nlp") -class TokenizationResult(DataObjectBase): - results: List[Token] diff --git a/caikit_nlp/modules/tokenization/tokenization_task.py b/caikit_nlp/model_management/__init__.py similarity index 71% rename from caikit_nlp/modules/tokenization/tokenization_task.py rename to caikit_nlp/model_management/__init__.py index 9d093191..3816104d 100644 --- a/caikit_nlp/modules/tokenization/tokenization_task.py +++ b/caikit_nlp/model_management/__init__.py @@ -12,16 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -# First Party -from caikit.core import TaskBase, task - # Local -from ...data_model import TokenizationResult - - -@task( - required_parameters={"text": str}, - output_type=TokenizationResult, -) -class TokenizationTask(TaskBase): - pass +from .tgis_auto_finder import TGISAutoFinder diff --git a/caikit_nlp/model_management/tgis_auto_finder.py b/caikit_nlp/model_management/tgis_auto_finder.py new file mode 100644 index 00000000..a4c13d81 --- /dev/null +++ b/caikit_nlp/model_management/tgis_auto_finder.py @@ -0,0 +1,142 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The TGISAutoFinder implements the ModelFinder interface to provide automatic +discovery of text-generation models that can be auto-configured to run against +a remote TGIS model. +""" +# Standard +from typing import Optional + +# First Party +from caikit.core import MODEL_MANAGER, error_handler +from caikit.core.model_management import ModelFinderBase, model_finder_factory +from caikit.core.modules import ModuleConfig +from caikit_tgis_backend import TGISBackend +import aconfig +import alog + +# Local +from ..modules.text_generation import TextGenerationTGIS + +log = alog.use_channel("TGIS_FND") +error = error_handler.get(log) + + +class TGISAutoFinder(ModelFinderBase): + __doc__ = __doc__ + + name = "TGIS-AUTO" + + # Constants for the keys of the config blob + _LOCAL_INITIALIZER_NAME_KEY = "local_initializer_name" + _TGIS_BACKEND_PRIORITY_KEY = "tgis_backend_priority" + + def __init__(self, config: aconfig.Config, instance_name: str = ""): + """Initialize from the model finder factory config + + Config schema: + + local_initializer_name: + type: string + default: "default" + description: The name within the initializers config for the LOCAL + initializer that will hold the tgis backend to use + + tgis_backend_priority: + type: integer + description: Index within the backend_priority list for the TGIS + backend to use. If not set, the first TGIS backend found will be + used. 
+ + Args: + config (aconfig.Config): The configuration blob from caikit's + model_management factory construction + instance_name (str): The name of this finder instance + """ + local_initializer_name = config.get(self._LOCAL_INITIALIZER_NAME_KEY, "default") + tgis_backend_priority = config.get(self._TGIS_BACKEND_PRIORITY_KEY) + error.type_check( + "", str, local_initializer_name=local_initializer_name + ) + error.type_check( + "", + int, + tgis_backend_priority=tgis_backend_priority, + allow_none=True, + ) + + # Extract the TGIS backend instance + local_initializer = MODEL_MANAGER.get_initializer(local_initializer_name) + backends = local_initializer.backends + if tgis_backend_priority is not None: + error.value_check( + "", + 0 <= tgis_backend_priority < len(backends), + "Invalid {}: {}", + self._TGIS_BACKEND_PRIORITY_KEY, + tgis_backend_priority, + ) + self._tgis_backend = backends[tgis_backend_priority] + error.value_check( + "", + self._tgis_backend.backend_type == TGISBackend.backend_type, + "Index {} is not a TGIS backend", + tgis_backend_priority, + ) + else: + tgis_backend = None + for backend in backends: + if backend.backend_type == TGISBackend.backend_type: + tgis_backend = backend + break + error.value_check( + "", + tgis_backend is not None, + "No TGIS backend found!", + ) + self._tgis_backend = tgis_backend + + def find_model( + self, + model_path: str, + **kwargs, + ) -> Optional[ModuleConfig]: + """Find the model if""" + + # Get a connection to this model in tgis + log.debug2("Attempting to setup TGIS client for %s", model_path) + if self._tgis_backend.get_connection(model_id=model_path) is None: + log.debug2("TGIS cannot connect to model %s", model_path) + return None + + # If connection is ok, set up the module config to point to the remote + # TGIS text generation module + cfg = ModuleConfig( + { + "module_id": TextGenerationTGIS.MODULE_ID, + "module_class": TextGenerationTGIS.MODULE_CLASS, + "name": TextGenerationTGIS.MODULE_NAME, + "version": TextGenerationTGIS.MODULE_VERSION, + "model_name": model_path, + } + ) + # Set a special indicator in the module config to use the backend that + # this finder found. This will override the backend found by the local + # initializer. 
+ cfg.tgis_backend = self._tgis_backend + return cfg + + +model_finder_factory.register(TGISAutoFinder) diff --git a/caikit_nlp/modules/text_classification/__init__.py b/caikit_nlp/modules/text_classification/__init__.py index eb56bf6d..79ff6537 100644 --- a/caikit_nlp/modules/text_classification/__init__.py +++ b/caikit_nlp/modules/text_classification/__init__.py @@ -14,4 +14,3 @@ # Local from .sequence_classification import SequenceClassification -from .text_classification_task import TextClassificationTask diff --git a/caikit_nlp/modules/text_classification/classification_prompt_tuning.py b/caikit_nlp/modules/text_classification/classification_prompt_tuning.py new file mode 100644 index 00000000..72bff310 --- /dev/null +++ b/caikit_nlp/modules/text_classification/classification_prompt_tuning.py @@ -0,0 +1,235 @@ +# Standard +from typing import List, Optional, Union +import os + +# First Party +from caikit.core.data_model import DataStream +from caikit.core.modules import ( + ModuleBase, + ModuleConfig, + ModuleLoader, + ModuleSaver, + module, +) +from caikit.core.toolkit import error_handler, wip_decorator +from caikit.interfaces.nlp.data_model import ( + ClassificationResult, + ClassificationResults, + ClassificationTrainRecord, +) +from caikit.interfaces.nlp.tasks import TextClassificationTask +import alog + +# Local +from ...data_model import TuningConfig +from ...toolkit.task_specific_utils import get_sorted_unique_class_labels +from ..text_generation import PeftPromptTuning + +log = alog.use_channel("CLASSIFICATION_PROMPT") +error = error_handler.get(log) + +# TODO: try to refactor this into a smaller module +# pylint: disable=too-many-lines,too-many-instance-attributes +@module( + id="6713731b-160b-4sc5-8df4-167126e2cd11", + name="Classification Peft Tuning", + version="0.1.0", + task=TextClassificationTask, +) +class ClassificationPeftPromptTuning(ModuleBase): + + _DETECT_DEVICE = "__DETECT__" + + def __init__( + self, + classifier: PeftPromptTuning, + unique_class_labels: List[str], + ): + super().__init__() + error.type_check( + "", + PeftPromptTuning, + classifier=classifier, + ) + error.type_check( + "", + List, + unique_class_labels=unique_class_labels, + ) + self.classifier = classifier + self.unique_class_labels = unique_class_labels + + @classmethod + @wip_decorator.work_in_progress( + category=wip_decorator.WipCategory.WIP, action=wip_decorator.Action.WARNING + ) + def train( + cls, + base_model: str, # TODO: Union[str, PretrainedModelBase] + train_stream: DataStream[ClassificationTrainRecord], + tuning_config: TuningConfig, + val_stream: DataStream[ClassificationTrainRecord] = None, + device: str = _DETECT_DEVICE, # TODO: Union[int, str] + tuning_type: str = "PROMPT_TUNING", # TODO: Union[str, TuningType] + num_epochs: int = 20, + lr: float = 0.3, + verbalizer: str = "{{input}}", + batch_size: int = 8, + max_source_length: int = 256, + max_target_length: int = 128, + accumulate_steps: int = 32, + torch_dtype: str = None, # TODO: Optional[Union[torch.dtype, str]] + silence_progress_bars: bool = True, + **kwargs, + ) -> "ClassificationPeftPromptTuning": + """Run prompt tuning (vanilla or MPT) through PEFT on a CausalLM or Seq2seq model + to refine a text generation model. + + Args: + base_model: Union[str, caikit_nlp.resources.pretrained_model.base.PretrainedModelBase] + Base resource model used for underlying generation. + train_stream: DataStream[ClassificationTrainRecord] + Data to be used for training the prompt vectors of the generation model. 
+ tuning_config: TuningConfig + Additional model tuning configurations to be considered for prompt vector + initialization and training behavior. + val_stream: Optional[DataStream[ClassificationTrainRecord] + Data to be used for validation throughout the train process or None. + device: str + Device to be used for training the model. Default: cls._DETECT_DEVICE, which + will fall back to "cuda" if available, else None. + tuning_type: str + Type of Peft Tuning config which we would like to build. + num_epochs: int + Number of epochs to tune the prompt vectors. Default: 20. + lr: float + Learning rate to be used while tuning prompt vectors. Default: 1e-3. + verbalizer: str + Verbalizer template to be used for formatting data at train and inference time. + This template may use brackets to indicate where fields from the data model + TrainGenerationRecord must be rendered. Default: "{{input}}", i.e., the raw text. + batch_size: int + Batch sized to be used for training / evaluation data. Default: 8. + max_source_length: int + Max length of input sequences being considered. Default: 256. + max_target_length: int + Max length of target sequences being predicted. Default: 128. + accumulate_steps: int + Number of steps to use for gradient accumulation. Default: 1. + torch_dtype: str + TODO: Optional[Union[torch.dtype, str]] + Data type to use for training/inference of the underlying text generation model. + If no value is provided, we pull from torch_dtype in config. If an in memory + resource is provided which does not match the specified data type, the model + underpinning the resource will be converted in place to the correct torch dtype. + silence_progress_bars: bool + Silences TQDM progress bars at train time. Default: True. + Returns: + ClassificationPeftPromptTuning + Instance of this class with tuned prompt vectors. + """ + + unique_class_labels = get_sorted_unique_class_labels(train_stream) + # Wrap up the trained model in a class instance + return cls( + classifier=PeftPromptTuning.train( + base_model, + train_stream, + tuning_config, + val_stream, + device, + tuning_type, + num_epochs, + lr, + verbalizer, + batch_size, + max_source_length, + max_target_length, + accumulate_steps, + torch_dtype, + silence_progress_bars, + **kwargs, + ), + unique_class_labels=unique_class_labels, + # TODO: Export other training params to model as well + ) + + # TODO: enable passing save_base_model flag as argument when supported by caikit + @wip_decorator.work_in_progress( + category=wip_decorator.WipCategory.WIP, action=wip_decorator.Action.WARNING + ) + def save(self, model_path): + """Save classification model + + Args: + model_path: str + Folder to save classification prompt tuning model + """ + saver = ModuleSaver(self, model_path=model_path) + with saver: + saver.save_module(self.classifier, "artifacts") + saver.update_config( + { + "unique_class_labels": self.unique_class_labels, + } + ) + + @classmethod + @wip_decorator.work_in_progress( + category=wip_decorator.WipCategory.WIP, action=wip_decorator.Action.WARNING + ) + def load(cls, model_path: str) -> "ClassificationPeftPromptTuning": + """Load a classification model. + + Args: + model_path: str + Path to the model to be loaded. + + Returns: + ClassificationPeftPromptTuning + Instance of this class. 
+ """ + config = ModuleConfig.load(os.path.abspath(model_path)) + loader = ModuleLoader(model_path) + classifier = loader.load_module("artifacts") + return cls( + classifier=classifier, + unique_class_labels=config.unique_class_labels, + ) + + # TODO: Currently only singlelabel classification is supported, \ + # hence it will always return list of 1 element. + # Support for multilabel may be added in future. + def run( + self, + text: str, + device: Optional[Union[str, int]] = _DETECT_DEVICE, + max_new_tokens=20, + min_new_tokens=0, + ) -> ClassificationResults: + """Run the classifier model. + + Args: + text: str + Input string to be used to the classification model. + device: Optional[Union[str, int]] + Device on which we should run inference; by default, we use the detected device. + max_new_tokens: int + The maximum numbers of tokens to generate for class label. + Default: 20 + min_new_tokens: int + The minimum numbers of tokens to generate. + Default: 0 - means no minimum + + Returns: + ClassificationResults + """ + gen_result = self.classifier.run(text, device, max_new_tokens, min_new_tokens) + # Either return supported class labels or None + label = ( + gen_result.generated_text + if gen_result.generated_text in self.unique_class_labels + else None + ) + + return ClassificationResults(results=[ClassificationResult(label=label)]) diff --git a/caikit_nlp/modules/text_classification/sequence_classification.py b/caikit_nlp/modules/text_classification/sequence_classification.py index 3be60656..8a88f489 100644 --- a/caikit_nlp/modules/text_classification/sequence_classification.py +++ b/caikit_nlp/modules/text_classification/sequence_classification.py @@ -23,14 +23,14 @@ # First Party from caikit.core.modules import ModuleBase, ModuleLoader, ModuleSaver, module from caikit.core.toolkit import error_handler +from caikit.interfaces.nlp.data_model import ClassificationResult, ClassificationResults +from caikit.interfaces.nlp.tasks import TextClassificationTask import alog # Local -from ...data_model import Classification, ClassificationResult from ...resources.pretrained_model.hf_auto_seq_classifier import ( HFAutoSequenceClassifier, ) -from .text_classification_task import TextClassificationTask log = alog.use_channel("SEQ_CLASS") error = error_handler.get(log) @@ -62,7 +62,7 @@ def __init__( ################################## API functions ############################################# - def run(self, text: str) -> ClassificationResult: + def run(self, text: str) -> ClassificationResults: """Run the sequence classification. 
NOTE: This will truncate sequences that are too long for model @@ -71,13 +71,13 @@ def run(self, text: str) -> ClassificationResult: Input string to be classified Returns: - ClassificationResult + ClassificationResults """ scores_dict = self._get_scores(text) # Re-organize scores_dict - for one text, this is just the first score return SequenceClassification._process_predictions(scores_dict, text_idx=0) - def run_batch(self, texts: List[str]) -> List[ClassificationResult]: + def run_batch(self, texts: List[str]) -> List[ClassificationResults]: """Run the sequence classification on batch, truncates sequences too long for model Args: @@ -85,7 +85,7 @@ def run_batch(self, texts: List[str]) -> List[ClassificationResult]: Input strings to be classified Returns: - List[ClassificationResult] + List[ClassificationResults] """ scores_dict = self._get_scores(texts) num_texts = len(texts) @@ -207,8 +207,8 @@ def _get_scores(self, text: Union[str, List[str]]): return scores_dict @staticmethod - def _process_predictions(scores_dict: Dict, text_idx: int) -> ClassificationResult: - """Process dictionary of label: scores to ClassificationResult + def _process_predictions(scores_dict: Dict, text_idx: int) -> ClassificationResults: + """Process dictionary of label: scores to ClassificationResults Args: scores_dict: Dict @@ -218,13 +218,13 @@ def _process_predictions(scores_dict: Dict, text_idx: int) -> ClassificationResu Integer index of text in batch Returns: - ClassificationResult + ClassificationResults """ error.type_check("", Dict, scores_dict=scores_dict) classification_list = [] for label, score_array in scores_dict.items(): # NOTE: labels are expected to be str, especially for config classification_list.append( - Classification(label=str(label), score=score_array[text_idx]) + ClassificationResult(label=str(label), score=score_array[text_idx]) ) - return ClassificationResult(results=classification_list) + return ClassificationResults(results=classification_list) diff --git a/caikit_nlp/modules/text_generation/__init__.py b/caikit_nlp/modules/text_generation/__init__.py index ce7363d3..8c696d81 100644 --- a/caikit_nlp/modules/text_generation/__init__.py +++ b/caikit_nlp/modules/text_generation/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Local -from .fine_tuning import FineTuning from .peft_prompt_tuning import PeftPromptTuning from .peft_tgis_remote import PeftPromptTuningTGIS -from .text_generation import TextGeneration +from .text_generation_local import TextGeneration +from .text_generation_tgis import TextGenerationTGIS diff --git a/caikit_nlp/modules/text_generation/fine_tuning.py b/caikit_nlp/modules/text_generation/fine_tuning.py deleted file mode 100644 index cf2ab102..00000000 --- a/caikit_nlp/modules/text_generation/fine_tuning.py +++ /dev/null @@ -1,309 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
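Editor's note (not part of the diff): the standalone `FineTuning` module deleted here is superseded by the fine-tuning support in the new `text_generation_local.TextGeneration` module added later in this diff. A rough usage sketch follows; it assumes `DataStream.from_iterable` is available and uses a small seq2seq checkpoint as a placeholder base model, neither of which is shown in the diff itself.

```python
from caikit.core.data_model import DataStream
from caikit_nlp.data_model import GenerationTrainRecord
from caikit_nlp.modules.text_generation import TextGeneration

# Hypothetical toy dataset; GenerationTrainRecord pairs an input with a target output.
train_stream = DataStream.from_iterable([
    GenerationTrainRecord(input="What does caikit-nlp provide?", output="NLP modules."),
    GenerationTrainRecord(input="ping", output="pong"),
])

model = TextGeneration.train(
    base_model="google/flan-t5-small",  # placeholder seq2seq model id
    train_stream=train_stream,
    num_epochs=1,
    checkpoint_dir="/tmp/caikit_ft_demo",
)
model.save("/tmp/caikit_ft_model")
print(model.run("ping").generated_text)
```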
- - -# Third Party -from torch.utils.data import IterableDataset -from transformers import ( - AutoConfig, - AutoTokenizer, - DataCollatorForSeq2Seq, - Seq2SeqTrainer, - Seq2SeqTrainingArguments, - Trainer, -) -import torch - -# First Party -from caikit.core.data_model import DataStream -from caikit.core.modules import ModuleBase, module -from caikit.core.toolkit import error_handler, wip_decorator -from caikit.interfaces.nlp.data_model import GeneratedTextResult -from caikit.interfaces.nlp.tasks import TextGenerationTask -import alog - -# Local -from ...data_model import GenerationTrainRecord -from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper -from ...toolkit.data_type_utils import get_torch_dtype - -log = alog.use_channel("FIN_TUN_GEN") -error = error_handler.get(log) - - -# pylint: disable=too-many-lines,too-many-instance-attributes -@module( - id="28a81449-32ce-4be3-b688-545bde68f738", - name="Text Generation", - version="0.1.0", - task=TextGenerationTask, -) -@wip_decorator.work_in_progress( - category=wip_decorator.WipCategory.WIP, action=wip_decorator.Action.ERROR -) -class FineTuning(ModuleBase): - """Module to provide fine-tuning support for text generation task""" - - def __init__(self, tokenizer, model): - super().__init__() - - self.tokenizer = tokenizer - # NOTE: self.model here can also be HF trainer. This is because - # if we have just trained the model then the models weights might be - # available in different devices (and configuration), depending on - # how it was trained. For now (July 10, 2023), we are not trying to - # extract the model out from trainer itself, since that would require - # us to essentially save it or reconstruct it to do normal inferring. - self.model = model - - @classmethod - def train( - cls, - base_model: str, # TODO: Union[str, PretrainedModelBase] - train_stream: DataStream[GenerationTrainRecord], - torch_dtype: str = None, # TODO: Optional[Union[torch.dtype, str]] - max_source_length: int = 256, - max_target_length: int = 128, - batch_size: int = 8, - num_epochs: int = 5, - accumulate_steps: int = 32, - lr: float = 2e-5, - # Directory where model predictions and checkpoints will be written - checkpoint_dir: str = "/tmp", - ): - """ - # FIXME: Below is currently configured for Seq2Seq only - """ - - torch_dtype = get_torch_dtype(torch_dtype) - - ## NOTE: Below code has been used in couple of places at this point, like in - # text_generation module. 
In future, we would want to consolidate this into - # a base class or a toolkit function - # pylint: disable=duplicate-code - ## Load base model - if isinstance(base_model, str): - model_config = AutoConfig.from_pretrained(base_model) - - resource_type = None - for resource in cls.supported_resources: - if model_config.model_type in resource.SUPPORTED_MODEL_TYPES: - resource_type = resource - break - - if not resource_type: - error( - "", - "{} model type is not supported currently!".format( - model_config.model_type - ), - ) - log.debug("Bootstrapping base resource [%s]", base_model) - base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype) - - ## Generate data loader from stream - training_dataset: IterableDataset = cls._preprocess_function( - train_stream=train_stream, - tokenizer=base_model.tokenizer, - max_source_length=max_source_length, - max_target_length=max_target_length, - shuffle=True, - ) - - ### Dtype based processing - # NOTE: Following is not exhaustive list of all parameters - # for all dtypes - if torch_dtype == torch.float16: - dtype_based_params = { - "fp16": True, - } - elif torch_dtype == torch.bfloat16: - dtype_based_params = { - "bf16": True, - } - else: - # default to float32 - dtype_based_params = {} - - ## TODO: Add automatic sharding selection based on number of parameters - # in base model - ## TODO: Fetch trainer from resource - - # TODO: Make this whole thing configurable by end-users, - # by optionally accepting `training_args` - # as argument to this train function. - # TODO: Remove all the default used below and make them all configurable - training_args = Seq2SeqTrainingArguments( - output_dir=checkpoint_dir, - per_device_train_batch_size=batch_size, - per_device_eval_batch_size=batch_size, - num_train_epochs=num_epochs, - # NOTE: We have disabled evaluation for now - do_eval=False, - # evaluation_strategy = "epoch", - learning_rate=lr, - weight_decay=0.01, - save_total_limit=3, - predict_with_generate=True, - push_to_hub=False, - no_cuda=False, # Default - generation_max_length=max_target_length, - remove_unused_columns=False, - dataloader_pin_memory=False, - gradient_accumulation_steps=accumulate_steps, - eval_accumulation_steps=accumulate_steps, - logging_strategy="epoch", - disable_tqdm=True, - # NOTE: Following not possible without save and eval strategy - # load_best_model_at_end=True, - # eval_steps=1, - **dtype_based_params, - ## TODO: Make below configurable - # fsdp="full_shard auto_wrap", - # local_rank=0, - ) - - data_collator = DataCollatorForSeq2Seq( - tokenizer=base_model.tokenizer, model=base_model.model - ) - - trainer = Seq2SeqTrainer( - base_model.model, - training_args, - train_dataset=training_dataset, - data_collator=data_collator, - tokenizer=base_model.tokenizer, - # compute_metrics=compute_metrics, - ) - - if num_epochs < 1: - log.warning( - "", - f"Number of epochs configured is {num_epochs} which is less than minimum 1. \ - No training will be performed", - ) - - return cls( - tokenizer=base_model.tokenizer, - model=trainer, - ) - - # Start training via Trainer.train function - trainer.train() - # NOTE: By default the model would be available in different ways - # depending on where and how it was trained. So we need to fetch the model - # from the trainer depending on the training method, like fsdp, ddp etc. 
- # For simplicity, currently we will use trainer as the model since it anyways - # enable the `predict` function on it and has all the layers of the model - # distributed already, so it will be most optimized to use trainer to - # perform prediction at this stage. - - return cls( - tokenizer=base_model.tokenizer, - model=trainer, - ) - - # pylint: disable=unused-argument - def run( - self, text, preserve_input_text=False, max_new_tokens=20, min_new_tokens=0 - ) -> "GeneratedTextResult": - """Run inference against the model running in TGIS. - - Args: - text: str - Source string to be encoded for generation. - preserve_input_text: bool - Whether or not the source string should be contained in the generated output, - e.g., as a prefix. - max_new_tokens: int - The maximum numbers of tokens to generate. - Default: 128 - min_new_tokens: int - The minimum numbers of tokens to generate. - Default: 0 - means no minimum - Returns: - GeneratedTextResult - Generated text result - """ - if isinstance(self.model, Trainer): - # Apply the tokenizer to the sample text & move to correct device - tok_tensors = self.tokenizer(text, return_tensors="pt") - # NOTE: below function is prediction on trainer, for which we need to supply - # the actual underlying model as well - # NOTE: We are using prediction_step instead of calling `self.model.generate` - # because this way HF Trainer automatically handles device placement of the - # data and model. Since the model is with Trainer at this point - # and thus the device placement be according to training strategy, - # its better to let Trainer handle the evaluation / prediction - - # TODO: Add support for passing extra arguments to prediction_step - _, generated_tokens, _ = self.model.prediction_step( - self.model.model, - tok_tensors, - prediction_loss_only=False, - max_new_tokens=max_new_tokens, - min_new_tokens=min_new_tokens, - ) - - generated_text = self.tokenizer.batch_decode( - generated_tokens.detach().cpu().numpy(), skip_special_tokens=True - )[0] - - else: - error( - "", - NotImplementedError( - "model prediction on pre-finetuned model currently not supported" - ), - ) - - return GeneratedTextResult(generated_text=generated_text) - - ################################## Private Functions ########################################### - - @staticmethod - def _preprocess_function( - train_stream: DataStream[GenerationTrainRecord], - tokenizer: AutoTokenizer, - max_source_length: int, - max_target_length: int, - shuffle: bool, - ): - """Pre-process each example to get it prepared for training.""" - - # FIXME: Below is currently configured for Seq2Seq only - - def _tokenization_func( - example: GenerationTrainRecord, - ): - model_inputs = tokenizer( - example.input, - max_length=max_source_length, - truncation=True, - ) - - labels = tokenizer( - example.output, - max_length=max_target_length, - padding="max_length", - truncation=True, - ) - - model_inputs["labels"] = labels["input_ids"] - - return model_inputs - - return SimpleIterableStreamWrapper( - train_stream.map(_tokenization_func), shuffle=shuffle - ) diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py index a432cf0a..df27ef6a 100644 --- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py +++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py @@ -49,8 +49,9 @@ from caikit import get_config from caikit.core.data_model import DataStream from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module -from 
caikit.core.toolkit import error_handler, wip_decorator +from caikit.core.toolkit import error_handler from caikit.interfaces.nlp.data_model import ( + ClassificationTrainRecord, GeneratedTextResult, GeneratedTextStreamResult, ) @@ -58,12 +59,7 @@ import alog # Local -from ...data_model import ( - ClassificationTrainRecord, - GenerationTrainRecord, - PromptOutputModelType, - TuningConfig, -) +from ...data_model import GenerationTrainRecord, PromptOutputModelType, TuningConfig from ...resources.pretrained_model import ( HFAutoCausalLM, HFAutoSeq2SeqLM, @@ -158,6 +154,7 @@ def __init__( self.tuning_type = tuning_type self.output_model_types = output_model_types + # pylint: disable=duplicate-code def __del__(self): del self.model del self.tokenizer @@ -199,6 +196,7 @@ def run( verbalized_text = render_verbalizer(self.verbalizer, {"input": text}) # Apply the tokenizer to the sample text & move to correct device tok_tensors = self.tokenizer(verbalized_text, return_tensors="pt") + device = PeftPromptTuning._get_device(device) inputs = {k: v.to(device) for k, v in tok_tensors.items()} with torch.no_grad(): @@ -215,10 +213,13 @@ def run( ) return GeneratedTextResult(generated_text=gen_text[0]) + # NOTE: We need to disable wip decorator here otherwise we get issues in + # proto generation for streaming. We are keeping it commented out for now, + # to essentially document that this streaming function is WIP. + # @wip_decorator.work_in_progress( + # category=wip_decorator.WipCategory.WIP, action=wip_decorator.Action.WARNING + # ) @TextGenerationTask.taskmethod(output_streaming=True) - @wip_decorator.work_in_progress( - category=wip_decorator.WipCategory.WIP, action=wip_decorator.Action.WARNING - ) def run_stream_out( self, text: str, max_new_tokens=20, min_new_tokens=0 ) -> Iterable[GeneratedTextStreamResult]: @@ -616,7 +617,12 @@ def save(self, model_path: str, save_base_model: bool = False): module_saver.update_config(config_options) @classmethod - def load(cls, model_path: str, torch_dtype: str = None) -> "PeftPromptTuning": + def load( + cls, + model_path: str, + torch_dtype: str = None, + device: str = _DETECT_DEVICE, # TODO: Union[int, str] + ) -> "PeftPromptTuning": """Load a PEFT prompt tuning model. This method will currently fail if the original model was not saved with the arg value save_base_model=True. 
@@ -638,7 +644,7 @@ def load(cls, model_path: str, torch_dtype: str = None) -> "PeftPromptTuning": torch_dtype = str_to_torch_dtype(config.trained_torch_dtype) if config.has_base_model: # TODO: Implement logic for resource loading - device = cls._get_device(cls._DETECT_DEVICE) + device = cls._get_device(device) model_config = os.path.join(model_path, config.full_model_path) peft_config = PeftConfig.from_pretrained(model_config) if peft_config.task_type == "CAUSAL_LM": @@ -1017,7 +1023,7 @@ def _get_data_loaders_from_stream( tokenize_function, requires_unwrapping, ) = base_model.build_task_tokenize_function( - tokenizer, max_source_length, max_target_length, verbalizer + tokenizer, max_source_length, max_target_length, verbalizer, task_ids=0 ) mapped_stream = train_stream.map(tokenize_function) if requires_unwrapping: @@ -1077,8 +1083,11 @@ def _execute_train_loop( num_warmup_steps=0, num_training_steps=(len(train_dataloader) * num_epochs), ) - # Configure accelerator for gradient accumulation - accelerator = Accelerator(gradient_accumulation_steps=accumulate_steps) + + accelerator = Accelerator( + gradient_accumulation_steps=accumulate_steps, device_placement=True + ) + for epoch in range(num_epochs): model.train() total_loss = 0 diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index af9626e8..65112bea 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -160,7 +160,12 @@ def save(self, model_path: str): @TextGenerationTask.taskmethod() def run( - self, text, preserve_input_text=False, max_new_tokens=20, min_new_tokens=0 + self, + text, + preserve_input_text=False, + max_new_tokens=20, + min_new_tokens=0, + truncate_input_tokens=0, ) -> GeneratedTextResult: """Run inference against the model running in TGIS. Currently we leverage greedy decoding and apply the same verbalizer used for training the local model prior to sending the @@ -178,7 +183,11 @@ def run( min_new_tokens: int The minimum numbers of tokens to generate. Default: 0 - means no minimum - + truncate_input_tokens: int + Truncate inputs to provided number of tokens. This can be + use to avoid failing due to input being longer than + configured limits. + Default: 0 - means don't truncate, thus throw error. Returns: GeneratedTextResult Generated text result produced by TGIS. @@ -190,12 +199,21 @@ def run( ) verbalized_text = render_verbalizer(self.verbalizer, {"input": text}) return self.tgis_generation_client.unary_generate( - verbalized_text, preserve_input_text, max_new_tokens, min_new_tokens + verbalized_text, + preserve_input_text, + max_new_tokens, + min_new_tokens, + truncate_input_tokens, ) @TextGenerationTask.taskmethod(output_streaming=True) def run_stream_out( - self, text: str, preserve_input_text=False, max_new_tokens=20, min_new_tokens=0 + self, + text: str, + preserve_input_text=False, + max_new_tokens=20, + min_new_tokens=0, + truncate_input_tokens=0, ) -> Iterable[GeneratedTextStreamResult]: """Run output stream inferencing against the model running in TGIS @@ -211,7 +229,11 @@ def run_stream_out( min_new_tokens: int The minimum numbers of tokens to generate. Default: 0 - means no minimum - + truncate_input_tokens: int + Truncate inputs to provided number of tokens. This can be + use to avoid failing due to input being longer than + configured limits. + Default: 0 - means don't truncate, thus throw error. 
Returns: Iterable[GeneratedTextStreamResult] """ @@ -223,5 +245,9 @@ def run_stream_out( ) verbalized_text = render_verbalizer(self.verbalizer, {"input": text}) return self.tgis_generation_client.stream_generate( - verbalized_text, preserve_input_text, max_new_tokens, min_new_tokens + verbalized_text, + preserve_input_text, + max_new_tokens, + min_new_tokens, + truncate_input_tokens, ) diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py new file mode 100644 index 00000000..3898f16d --- /dev/null +++ b/caikit_nlp/modules/text_generation/text_generation_local.py @@ -0,0 +1,510 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Standard +from typing import Optional +import gc +import os + +# Third Party +from torch.utils.data import IterableDataset +from transformers import AutoConfig, AutoTokenizer +import torch + +# First Party +from caikit.core.data_model import DataStream +from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module +from caikit.core.toolkit import error_handler +from caikit.interfaces.nlp.data_model import GeneratedTextResult +from caikit.interfaces.nlp.tasks import TextGenerationTask +import alog + +# Local +from ...data_model import GenerationTrainRecord +from ...resources.pretrained_model import ( + HFAutoCausalLM, + HFAutoSeq2SeqLM, + PretrainedModelBase, +) +from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper +from ...toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype + +log = alog.use_channel("TXT_GEN") +error = error_handler.get(log) + + +# pylint: disable=too-many-lines,too-many-instance-attributes +@module( + id="f9181353-4ccf-4572-bd1e-f12bcda26792", + name="Text Generation", + version="0.1.0", + task=TextGenerationTask, +) +class TextGeneration(ModuleBase): + """Module to provide text generation capabilities""" + + RANDOM_SEED = 73 + supported_resources = [HFAutoCausalLM, HFAutoSeq2SeqLM] + + def __init__( + self, + model_name: str, + model: PretrainedModelBase = None, + bos_token: Optional[str] = None, + sep_token: Optional[str] = None, + eos_token: Optional[str] = None, + pad_token: Optional[str] = None, + ): + super().__init__() + + error.type_check("", str, allow_none=True, eos_token=eos_token) + self.model = model + self.model_name = model_name + + self._bos_token = bos_token + self._sep_token = sep_token + self._eos_token = eos_token + self._pad_token = pad_token + + # pylint: disable=duplicate-code + def __del__(self): + del self.model + gc.collect() + try: + torch.cuda.empty_cache() + except AttributeError: + pass + + @classmethod + def bootstrap(cls, base_model_path: str, torch_dtype: str = "float32"): + """Function to bootstrap a pre-trained transformers model and + get a caikit text-generation 'model'. 
+ + Args: + base_model_path: str + Path to transformers model + NOTE: Model path needs to contain tokenizer as well + torch_dtype: str + Torch data type to be used when loading the model. + Default: float32 + Returns: + caikit_nlp.blocks.text_generation.TextGeneration + Object of TextGeneration class (model) + """ + # pylint: disable=duplicate-code + model_config = AutoConfig.from_pretrained(base_model_path) + + resource_type = None + for resource in cls.supported_resources: + if model_config.model_type in resource.SUPPORTED_MODEL_TYPES: + resource_type = resource + break + + if not resource_type: + error( + "", + "{} model type is not supported currently!".format( + model_config.model_type + ), + ) + log.debug("Bootstrapping base resource [%s]", base_model_path) + base_model = resource_type.bootstrap( + base_model_path, + tokenizer_name=base_model_path, + torch_dtype=torch_dtype, + ) + eos_token = base_model._tokenizer.eos_token or None + return cls( + base_model_path, + base_model, + eos_token=eos_token, + ) + + @classmethod + def train( + cls, + base_model: str, # TODO: Union[str, PretrainedModelBase] + train_stream: DataStream[GenerationTrainRecord], + torch_dtype: str = None, # TODO: Optional[Union[torch.dtype, str]] + max_source_length: int = 256, + max_target_length: int = 128, + batch_size: int = 8, + num_epochs: int = 5, + accumulate_steps: int = 32, + random_seed: int = RANDOM_SEED, + lr: float = 2e-5, + # Directory where model predictions and checkpoints will be written + checkpoint_dir: str = "/tmp", + **kwargs, + ) -> "TextGeneration": + """ + Fine-tune a CausalLM or Seq2seq text generation model. + + Args: + base_model: Union[str, caikit_nlp.resources.pretrained_model.base.PretrainedModelBase] + Base resource model used for underlying generation. + train_stream: DataStream[GenerationTrainRecord] or DataStream[ClassificationTrainRecord] + Data to be used for fine-tuning the generation model. + torch_dtype: str + TODO: Optional[Union[torch.dtype, str]] + Data type to use for training/inference of the underlying text generation model. + If no value is provided, we pull from torch_dtype in config. If an in memory + resource is provided which does not match the specified data type, the model + underpinning the resource will be converted in place to the correct torch dtype. + max_source_length: int + Max length of input sequences being considered. Default: 256. + max_target_length: int + Max length of target sequences being predicted. Default: 128. + batch_size: int + Batch sized to be used for training / evaluation data. Default: 8. + num_epochs: int + Number of epochs to tune the model. Default: 20. + accumulate_steps: int + Number of steps to use for gradient accumulation. Default: 1. + lr: float + Learning rate to be used while tuning model. Default: 2e-5. + checkpoint_dir: str + Directory where model predictions and checkpoints will be written + **kwargs: + Arguments supported by HF Training Arguments. + TrainingArguments: + https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.TrainingArguments + Seq2SeqTrainingArguments: + https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments + Returns: + TextGeneration + Instance of this class with fine-tuned models. + """ + + torch_dtype = get_torch_dtype(torch_dtype) + + ## NOTE: Below code has been used in couple of places at this point, like in + # text_generation module. 
In future, we would want to consolidate this into + # a base class or a toolkit function + # pylint: disable=duplicate-code + resource_type = None + + ## Load base model + if isinstance(base_model, str): + model_config = AutoConfig.from_pretrained(base_model) + + for resource in cls.supported_resources: + if model_config.model_type in resource.SUPPORTED_MODEL_TYPES: + resource_type = resource + break + + if not resource_type: + error( + "", + "{} model type is not supported currently!".format( + model_config.model_type + ), + ) + log.debug("Bootstrapping base resource [%s]", base_model) + base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype) + + else: + # base_model is actually a resource object + resource_type = type(base_model) + + error.type_check("", PretrainedModelBase, base_model=base_model) + ## Generate data loader from stream + training_dataset: IterableDataset = cls._preprocess_function( + base_model=base_model, + train_stream=train_stream, + tokenizer=base_model.tokenizer, + max_source_length=max_source_length, + max_target_length=max_target_length, + shuffle=True, + ) + + ### Dtype based processing + # NOTE: Following is not exhaustive list of all parameters + # for all dtypes + if torch_dtype == torch.float16: + dtype_based_params = { + "fp16": True, + } + elif torch_dtype == torch.bfloat16: + dtype_based_params = { + "bf16": True, + } + else: + # default to float32 + dtype_based_params = {} + + ## TODO: Add automatic sharding selection based on number of parameters + # in base model + ## TODO: Fetch trainer from resource + + # TODO: Make this whole thing configurable by end-users, + # by optionally accepting `training_args` + # as argument to this train function. + # TODO: Remove all the default used below and make them all configurable + + training_args = { + "output_dir": checkpoint_dir, + "per_device_train_batch_size": batch_size, + "per_device_eval_batch_size": batch_size, + "num_train_epochs": num_epochs, + "seed": random_seed, + # NOTE: We have disabled evaluation for now + "do_eval": False, + # "evaluation_strategy ": "epoch", + "learning_rate": lr, + "weight_decay": 0.01, + "save_total_limit": 3, + "push_to_hub": False, + "no_cuda": False, # Default + "remove_unused_columns": False, + "dataloader_pin_memory": False, + "gradient_accumulation_steps": accumulate_steps, + "eval_accumulation_steps": accumulate_steps, + # eval_steps=1, + # load_best_model_at_end + **kwargs, + **dtype_based_params, + } + + trainer = base_model.get_trainer( + train_dataset=training_dataset, **training_args + ) + + if num_epochs < 1: + log.warning( + "", + f"Number of epochs configured is {num_epochs} which is less than minimum 1. 
\ + No training will be performed", + ) + + return cls( + model_name=base_model._model_name, + model=base_model, + ) + + # Start training via Trainer.train function + trainer.train() + + # save the model temporarily and reload it + # this is done, since otherwise the model might be distributed in different + # devices, in which case its better to use trainer's `prediction_step` + # functions, but then, they don't always give API similar to `generate` + # and thus cause incompatibilities in `run` function + trainer.save_model(checkpoint_dir) + + model = resource_type.bootstrap( + checkpoint_dir, checkpoint_dir, torch_dtype=torch_dtype + ) + + return cls( + model_name=base_model._model_name, + model=model, + bos_token=model.tokenizer.bos_token or None, + sep_token=model.tokenizer.sep_token or None, + eos_token=model.tokenizer.eos_token or None, + pad_token=model.tokenizer.pad_token or None, + ) + + @classmethod + def load( + cls, + model_path: str, + torch_dtype: str = None, + ) -> "TextGeneration": + """Function to load text-generation model + + Args: + model_path: str + Path to the model to be loaded. + torch_dtype: str + Torch data type to be used when loading the model. + Returns: + TextGeneration + Instance of this class built from the on disk model. + """ + + config = ModuleConfig.load(model_path) + + if torch_dtype is not None: + torch_dtype = str_to_torch_dtype(torch_dtype) + elif config.trained_torch_dtype: + torch_dtype = str_to_torch_dtype(config.trained_torch_dtype) + + base_model_path = config.get("artifact_path") + error.type_check("", str, base_model_path=base_model_path) + + base_model_path = os.path.join(model_path, base_model_path) + error.dir_check("", base_model_path) + return cls.bootstrap(base_model_path, torch_dtype) + + def save(self, model_path): + """Save caikit model + + Args: + model_path: str + Folder to save text-generation caikit model + """ + saver = ModuleSaver( + self, + model_path=model_path, + ) + with saver: + artifacts_dir = "artifacts" + saver.update_config( + { + "artifact_path": artifacts_dir, + "eos_token": self._eos_token, + "torch_dtype": str(self.model._torch_dtype), + } + ) + if self.model: + # This will save both tokenizer and base model + self.model.save( + model_path, + tokenizer_dirname=artifacts_dir, + base_model_dirname=artifacts_dir, + ) + + def run( + self, + text: str, + repetition_penalty: float = 2.5, + length_penalty: float = 1.0, + early_stopping: bool = True, + num_beams: int = 1, + max_new_tokens: int = 20, + min_new_tokens: int = 0, + truncate_input_tokens: int = 0, + **kwargs, + ) -> "GeneratedTextResult": + """Run inference against the model running in TGIS. + + Args: + text: str + Source string to be encoded for generation. + repetition_penalty: float + The parameter for repetition penalty. 1.0 means no penalty. + Default: 2.5 + length_penalty: float + Exponential penalty to the length that is used with beam-based generation. + It is applied as an exponent to the sequence length, \ + which is used to divide the score of the sequence. + Since the score is the log likelihood of the sequence (i.e. negative), \ + length_penalty > 0.0 promotes longer sequences, \ + while length_penalty < 0.0 encourages shorter sequences. + Default: 1.0. + early_stopping: bool + Controls the stopping condition for beam-based methods, like beam-search. 
+ It accepts the following values: + True, where the generation stops as soon as there are num_beams complete candidates; + False, where an heuristic is applied and the generation stops when \ + is it very unlikely to find better candidates; + "never", where the beam search procedure only stops \ + when there cannot be better candidates (canonical beam search algorithm). + num_beams: int + Number of beams for beam search. 1 means no beam search. + Default: 1 + max_new_tokens: int + The maximum numbers of tokens to generate. + Default: 20 + min_new_tokens: int + The minimum numbers of tokens to generate. + Default: 0 - means no minimum + truncate_input_tokens: int + Truncate inputs to provided number of tokens. This can be + use to avoid failing due to input being longer than + configured limits. + Default: 0 - means don't truncate, thus throw error. + kwargs: + Any other parameters to pass to generate as specified in GenerationConfig. + https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/text_generation#transformers.GenerationConfig + Returns: + GeneratedTextResult + Generated text result produced by the model. + """ + + # NOTE: below is to match TGIS API, where 0 identifies as no truncation + if truncate_input_tokens == 0: + # NOTE: below will make model throw error in case inputs are longer + # than allowed length + truncation = False + + else: + truncation = True + + inputs = self.model.tokenizer( + text, + truncation=truncation, + max_length=truncate_input_tokens, + return_tensors="pt", + ) + generate_ids = self.model.model.generate( + input_ids=inputs["input_ids"], + num_beams=num_beams, + max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + early_stopping=early_stopping, + use_cache=True, + **kwargs, + ) + token_count = generate_ids.size(1) - 1 + preds = [ + self.model.tokenizer.decode( + g, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + for g in generate_ids + ] + if generate_ids[0][-1].item() == self._eos_token: + finish_reason = "EOS_TOKEN" + elif generate_ids.size(1) - 1 == max_new_tokens: + finish_reason = "MAX_TOKENS" + else: + finish_reason = "OTHER" + return GeneratedTextResult( + generated_tokens=token_count, + generated_text=preds[0], + finish_reason=finish_reason, + producer_id=self.PRODUCER_ID, + ) + + ################################## Private Functions ###################################### + + @staticmethod + def _preprocess_function( + base_model: PretrainedModelBase, + train_stream: DataStream[GenerationTrainRecord], + tokenizer: AutoTokenizer, + max_source_length: int, + max_target_length: int, + shuffle: bool, + ): + """Pre-process each example to get it prepared for training.""" + + # TODO: We are using a default verbalizer which is strictly tied to + # source training record currently. 
We need to figure out a better + # way to make verbalizer optional for build_task_tokenize_function + ( + tokenize_function, + requires_unwrapping, + ) = base_model.build_task_tokenize_function( + tokenizer, max_source_length, max_target_length, verbalizer="{{input}}" + ) + mapped_stream = train_stream.map(tokenize_function) + if requires_unwrapping: + mapped_stream = mapped_stream.flatten() + + return SimpleIterableStreamWrapper(mapped_stream, shuffle=shuffle) diff --git a/caikit_nlp/modules/text_generation/text_generation.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py similarity index 64% rename from caikit_nlp/modules/text_generation/text_generation.py rename to caikit_nlp/modules/text_generation/text_generation_tgis.py index 271843a0..036e86fa 100644 --- a/caikit_nlp/modules/text_generation/text_generation.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -14,14 +14,10 @@ # Standard -from typing import Iterable, Optional +from typing import Iterable, Optional, Union import os -# Third Party -from transformers import AutoConfig - # First Party -from caikit import get_config from caikit.core.module_backends import BackendBase, backend_types from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module from caikit.core.toolkit import error_handler @@ -40,20 +36,14 @@ PretrainedModelBase, ) from ...toolkit.tgis_utils import TGISGenerationClient +from .text_generation_local import TextGeneration log = alog.use_channel("TXT_GEN") error = error_handler.get(log) - -# pylint: disable=too-many-lines,too-many-instance-attributes -@module( - id="f9181353-4ccf-4572-bd1e-f12bcda26792", - name="Text Generation", - version="0.1.0", - backend_type=TGISBackend.backend_type, - task=TextGenerationTask, -) -class TextGeneration(ModuleBase): +# pylint: disable=too-many-instance-attributes +@module(backend_type=TGISBackend.backend_type, base_module=TextGeneration) +class TextGenerationTGIS(ModuleBase): """Module to provide text generation capabilities""" SUPPORTED_LOAD_BACKENDS = [TGISBackend.backend_type, backend_types.LOCAL] @@ -62,8 +52,8 @@ class TextGeneration(ModuleBase): def __init__( self, - base_model_name: str, - base_model: Optional[PretrainedModelBase] = None, + model_name: str, + model: Optional[PretrainedModelBase] = None, bos_token: Optional[str] = None, sep_token: Optional[str] = None, eos_token: Optional[str] = None, @@ -76,8 +66,8 @@ def __init__( error.type_check("", str, allow_none=True, sep_token=sep_token) error.type_check("", str, allow_none=True, eos_token=eos_token) error.type_check("", str, allow_none=True, pad_token=pad_token) - self.base_model = base_model - self.base_model_name = base_model_name + self.model = model + self.model_name = model_name # Set _model_loaded as False by default. This will only get set to True if # we enable the tgis_backend and we are able to fetch the client successfully. @@ -87,25 +77,26 @@ def __init__( # for example, bootstrapping a model to caikit format and saving. 
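# Illustrative usage sketch (not from the patched sources): the comment above
# distinguishes building this module with and without a TGIS backend. The
# model id and output directory are placeholders; `tgis_backend` is assumed
# to be an already-configured TGISBackend instance.
from caikit_nlp.modules.text_generation.text_generation_tgis import (
    TextGenerationTGIS,
)

# Without a backend: no TGIS client is created, so the instance can only be
# saved in "remote" style (config only, no local artifacts).
offline = TextGenerationTGIS("bigscience/bloom-560m")
offline.save("models/bloom-560m-remote")

# With a backend (e.g. at serving time): a client is fetched for the model
# name and run()/run_stream_out() delegate generation to TGIS.
# served = TextGenerationTGIS.load(
#     "models/bloom-560m-remote", load_backend=tgis_backend
# )
# print(served.run("Hello world", max_new_tokens=20).generated_text)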
self._client = None if tgis_backend: - self._client = tgis_backend.get_client(base_model_name) + self._client = tgis_backend.get_client(model_name) # mark that the model is loaded so that we can unload it later self._model_loaded = True + self.tgis_backend = tgis_backend self._bos_token = bos_token self._sep_token = sep_token self._eos_token = eos_token self._pad_token = pad_token self.tgis_generation_client = TGISGenerationClient( - self.base_model_name, self._eos_token, self._client, self.PRODUCER_ID + self.model_name, self._eos_token, self._client, self.PRODUCER_ID ) def __del__(self): # nothing to unload if we didn't finish loading - if self._model_loaded and self.load_backend: - self.load_backend.unload_model(self._model_path) + if self._model_loaded and self.tgis_backend: + self.tgis_backend.unload_model(self.model_name) @classmethod - def bootstrap(cls, base_model_path: str, load_backend: BackendBase = None): + def bootstrap(cls, model_path: str, load_backend: Union[BackendBase, None] = None): """Function to bootstrap a pre-trained transformers model and get a caikit text-generation 'model'. @@ -121,35 +112,14 @@ def bootstrap(cls, base_model_path: str, load_backend: BackendBase = None): caikit_nlp.blocks.text_generation.TextGeneration Object of TextGeneration class (model) """ - # pylint: disable=duplicate-code - model_config = AutoConfig.from_pretrained( - base_model_path, - local_files_only=not get_config().allow_downloads, - ) - resource_type = None - for resource in cls.supported_resources: - if model_config.model_type in resource.SUPPORTED_MODEL_TYPES: - resource_type = resource - break - - if not resource_type: - error( - "", - "{} model type is not supported currently!".format( - model_config.model_type - ), - ) - log.debug("Bootstrapping base resource [%s]", base_model_path) - base_model = resource_type.bootstrap( - base_model_path, tokenizer_name=base_model_path - ) - bos_token = base_model._tokenizer.bos_token - sep_token = base_model._tokenizer.sep_token - eos_token = base_model._tokenizer.eos_token or None - pad_token = base_model._tokenizer.pad_token + text_generation_inst = TextGeneration.bootstrap(base_model_path) + bos_token = text_generation_inst.base_model._tokenizer.bos_token + sep_token = text_generation_inst.base_model._tokenizer.sep_token + eos_token = text_generation_inst.base_model._tokenizer.eos_token or None + pad_token = text_generation_inst.base_model._tokenizer.pad_token return cls( - base_model_path, - base_model, + text_generation_inst.model_name, + text_generation_inst.model, bos_token=bos_token, sep_token=sep_token, eos_token=eos_token, @@ -157,38 +127,12 @@ def bootstrap(cls, base_model_path: str, load_backend: BackendBase = None): tgis_backend=load_backend, ) - def save(self, model_path): - """Save caikit model - - Args: - model_path: str - Folder to save text-generation caikit model - """ - saver = ModuleSaver(self, model_path=model_path) - with saver: - saver.update_config( - { - "base_model_name": self.base_model_name, - "bos_token": self._bos_token, - "sep_token": self._sep_token, - "eos_token": self._eos_token, - "pad_token": self._pad_token, - } - ) - if self.base_model: - artifacts_dir = "artifacts" - log.debug("Saving model artifacts to %s", artifacts_dir) - saver.update_config({"artifact_path": artifacts_dir}) - # This will save both tokenizer and base model - self.base_model.save( - model_path, - tokenizer_dirname=artifacts_dir, - base_model_dirname=artifacts_dir, - ) - @classmethod def load(cls, model_path: str, load_backend: 
BackendBase) -> "TextGeneration": - """Function to load text-generation model + """Function to load text-generation model. Note, this only loads + "remote" style model, i.e the cakit-model that doesn't + necessarily required to have actual artifacts in it + and thus only saves them in "remote" format. Args: model_path: str @@ -202,27 +146,58 @@ def load(cls, model_path: str, load_backend: BackendBase) -> "TextGeneration": error.type_check("", TGISBackend, load_backend=load_backend) config = ModuleConfig.load(model_path) + tgis_backend = config.tgis_backend or load_backend artifacts_path = config.artifact_path if artifacts_path: - base_model_name = os.path.join(model_path, artifacts_path) - error.dir_check("", base_model_name) - log.debug("Loading with on-disk artifacts: %s", base_model_name) + model_name = os.path.join(model_path, artifacts_path) + error.dir_check("", model_name) + log.debug("Loading with on-disk artifacts: %s", model_name) else: - base_model_name = config.base_model_name - error.type_check("", str, base_model_name=base_model_name) - log.debug("Loading with model name: %s", base_model_name) + model_name = config.model_name + error.type_check("", str, model_name=model_name) + log.debug("Loading with model name: %s", model_name) return cls( - base_model_name, + model_name, bos_token=config.bos_token, sep_token=config.sep_token, eos_token=config.eos_token, pad_token=config.pad_token, - tgis_backend=load_backend, + tgis_backend=tgis_backend, + ) + + def save(self, model_path: str): + """Export the config for this model. + This saves the model in "remote" style + and does not store the actual model artifacts + along with the caikit-model. + + model_path: str + Path to which we should write our model. + """ + # pylint: disable=duplicate-code + saver = ModuleSaver( + self, + model_path=model_path, ) + with saver: + saver.update_config( + { + "model_name": self.model_name, + "bos_token": self._bos_token, + "sep_token": self._sep_token, + "eos_token": self._eos_token, + "pad_token": self._pad_token, + } + ) @TextGenerationTask.taskmethod() def run( - self, text, preserve_input_text=False, max_new_tokens=20, min_new_tokens=0 + self, + text: str, + preserve_input_text: bool = False, + max_new_tokens: int = 20, + min_new_tokens: int = 0, + truncate_input_tokens: int = 0, ) -> GeneratedTextResult: """Run inference against the model running in TGIS. @@ -238,18 +213,32 @@ def run( min_new_tokens: int The minimum numbers of tokens to generate. Default: 0 - means no minimum + truncate_input_tokens: int + Truncate inputs to provided number of tokens. This can be + use to avoid failing due to input being longer than + configured limits. + Default: 0 - means don't truncate, thus throw error. Returns: GeneratedTextResult Generated text result produced by TGIS. """ if self._model_loaded: return self.tgis_generation_client.unary_generate( - text, preserve_input_text, max_new_tokens, min_new_tokens + text, + preserve_input_text, + max_new_tokens, + min_new_tokens, + truncate_input_tokens, ) @TextGenerationTask.taskmethod(output_streaming=True) def run_stream_out( - self, text: str, preserve_input_text=False, max_new_tokens=20, min_new_tokens=0 + self, + text: str, + preserve_input_text=False, + max_new_tokens=20, + min_new_tokens=0, + truncate_input_tokens=0, ) -> Iterable[GeneratedTextStreamResult]: """Run output stream inferencing for text generation module. 
@@ -263,11 +252,19 @@ def run_stream_out( Maximum tokens for the model to generate min_new_tokens: int Minimum tokens for the model to generate - + truncate_input_tokens: int + Truncate inputs to provided number of tokens. This can be + use to avoid failing due to input being longer than + configured limits. + Default: 0 - means don't truncate, thus throw error. Returns: Iterable[GeneratedTextStreamResult] """ if self._model_loaded: return self.tgis_generation_client.stream_generate( - text, preserve_input_text, max_new_tokens, min_new_tokens + text, + preserve_input_text, + max_new_tokens, + min_new_tokens, + truncate_input_tokens, ) diff --git a/caikit_nlp/modules/token_classification/__init__.py b/caikit_nlp/modules/token_classification/__init__.py index 4a9cd1f2..604cf0a6 100644 --- a/caikit_nlp/modules/token_classification/__init__.py +++ b/caikit_nlp/modules/token_classification/__init__.py @@ -14,4 +14,3 @@ # Local from .filtered_span_classification import FilteredSpanClassification -from .token_classification_task import TokenClassificationTask diff --git a/caikit_nlp/modules/token_classification/filtered_span_classification.py b/caikit_nlp/modules/token_classification/filtered_span_classification.py index e7c7cfa0..ded05969 100644 --- a/caikit_nlp/modules/token_classification/filtered_span_classification.py +++ b/caikit_nlp/modules/token_classification/filtered_span_classification.py @@ -29,17 +29,17 @@ module, ) from caikit.core.toolkit import error_handler -import alog - -# Local -from ...data_model import ( - StreamingTokenClassificationResult, - TokenClassification, +from caikit.interfaces.nlp.data_model import ( TokenClassificationResult, + TokenClassificationResults, + TokenClassificationStreamResult, ) -from ..text_classification.text_classification_task import TextClassificationTask -from .token_classification_task import TokenClassificationTask -from caikit_nlp.modules.tokenization.tokenization_task import TokenizationTask +from caikit.interfaces.nlp.tasks import ( + TextClassificationTask, + TokenClassificationTask, + TokenizationTask, +) +import alog log = alog.use_channel("FILT_SPAN") error = error_handler.get(log) @@ -73,10 +73,10 @@ def __init__( lang: str 2 letter language code tokenizer: ModuleBase - Tokenizer that returns TokenizationResult + Tokenizer that returns TokenizationResults classifier: ModuleBase Classification model instance returning Classification or - TokenClassification output on .run + TokenClassificationResult output on .run default_threshold: float Default threshold for scores labels_to_output: List[str] @@ -117,7 +117,7 @@ def __init__( @TokenClassificationTask.taskmethod() def run( self, text: str, threshold: Optional[float] = None - ) -> TokenClassificationResult: + ) -> TokenClassificationResults: """Run classification on text split into spans. 
Returns results based on score threshold for labels that are to be outputted @@ -128,7 +128,7 @@ def run( (Optional) Threshold based on which to return score results Returns: - TokenClassificationResult + TokenClassificationResults """ if threshold is None: threshold = self.default_threshold @@ -163,7 +163,7 @@ def run( if not self.labels_to_output or ( self.labels_to_output and label in self.labels_to_output ): - token_classification = TokenClassification( + token_classification = TokenClassificationResult( start=start, end=end, word=word, @@ -171,12 +171,12 @@ def run( score=classification.score, ) token_classification_results.append(token_classification) - return TokenClassificationResult(results=token_classification_results) + return TokenClassificationResults(results=token_classification_results) @TokenClassificationTask.taskmethod(input_streaming=True, output_streaming=True) def run_bidi_stream( self, text_stream: Iterable[str], threshold: Optional[float] = None - ) -> Iterable[StreamingTokenClassificationResult]: + ) -> Iterable[TokenClassificationStreamResult]: """Run bi-directional streaming inferencing for this module. Run classification on text split into spans. Returns results based on score threshold for labels that are to be outputted @@ -188,7 +188,7 @@ def run_bidi_stream( (Optional) Threshold based on which to return score results Returns: - Iterable[StreamingTokenClassificationResult] + Iterable[TokenClassificationStreamResult] """ # TODO: For optimization implement window based approach. if threshold is None: @@ -215,9 +215,9 @@ def run_bidi_stream( ): # Need to add offset to track actual place of spans within a stream, # as the span splitting will be expected to stream and detect spans - yield StreamingTokenClassificationResult( + yield TokenClassificationStreamResult( results=[ - TokenClassification( + TokenClassificationResult( start=start, end=end, word=word, @@ -231,7 +231,7 @@ def run_bidi_stream( results_to_end_of_span = True if not results_to_end_of_span: - yield StreamingTokenClassificationResult( + yield TokenClassificationStreamResult( results=[], processed_index=span_output.end ) @@ -296,10 +296,10 @@ def bootstrap( lang: str 2 letter language code tokenizer: ModuleBase - Tokenizer that returns TokenizationResult + Tokenizer that returns TokenizationResults classifier: ModuleBase Classification model instance returning Classification or - TokenClassification output on .run + TokenClassificationResult output on .run default_threshold: float Default threshold for scores labels_to_output: List[str] diff --git a/caikit_nlp/modules/token_classification/token_classification_task.py b/caikit_nlp/modules/token_classification/token_classification_task.py deleted file mode 100644 index 9dde2ce5..00000000 --- a/caikit_nlp/modules/token_classification/token_classification_task.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
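# Illustrative sketch (not from the patched sources): with the local task and
# data model classes removed, the equivalent types now come from
# caikit.interfaces.nlp. Field values below are placeholders.
from caikit.interfaces.nlp.data_model import (
    TokenClassificationResult,
    TokenClassificationResults,
)

span = TokenClassificationResult(
    start=0, end=5, word="moose", entity="animal", score=0.8
)
print(TokenClassificationResults(results=[span]).to_json())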
-"""This task can be promoted to caikit/caikit for wider usage when applicable -to multiple modules -""" -# Standard -from typing import Iterable - -# First Party -from caikit.core import TaskBase, task - -# Local -from ...data_model import StreamingTokenClassificationResult, TokenClassificationResult - - -@task( - unary_parameters={"text": str}, - streaming_parameters={"text_stream": Iterable[str]}, - unary_output_type=TokenClassificationResult, - streaming_output_type=Iterable[StreamingTokenClassificationResult], -) -class TokenClassificationTask(TaskBase): - pass diff --git a/caikit_nlp/modules/tokenization/__init__.py b/caikit_nlp/modules/tokenization/__init__.py index 444fb4a9..1358fd0f 100644 --- a/caikit_nlp/modules/tokenization/__init__.py +++ b/caikit_nlp/modules/tokenization/__init__.py @@ -14,4 +14,3 @@ # Local from .regex_sentence_splitter import RegexSentenceSplitter -from .tokenization_task import TokenizationTask diff --git a/caikit_nlp/modules/tokenization/regex_sentence_splitter.py b/caikit_nlp/modules/tokenization/regex_sentence_splitter.py index 8b8d57dc..d9f578ee 100644 --- a/caikit_nlp/modules/tokenization/regex_sentence_splitter.py +++ b/caikit_nlp/modules/tokenization/regex_sentence_splitter.py @@ -20,12 +20,10 @@ # First Party from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module from caikit.core.toolkit import error_handler +from caikit.interfaces.nlp.data_model import Token, TokenizationResults +from caikit.interfaces.nlp.tasks import TokenizationTask import alog -# Local -from ...data_model import Token, TokenizationResult -from .tokenization_task import TokenizationTask - log = alog.use_channel("RGX_SNT_SPLT") error = error_handler.get(log) @@ -100,7 +98,7 @@ def load(cls, model_path: str) -> "RegexSentenceSplitter": config = ModuleConfig.load(os.path.abspath(model_path)) return cls(regex_str=config.regex_str) - def run(self, text: str) -> TokenizationResult: + def run(self, text: str) -> TokenizationResults: """Run sentence splitting regex on input text. Args: @@ -108,8 +106,8 @@ def run(self, text: str) -> TokenizationResult: Document to run sentence splitting on. Returns: - TokenizationResult - TokenizationResult object containing tokens where each token + TokenizationResults + TokenizationResults object containing tokens where each token corresponds to a detected sentence. 
""" @@ -121,4 +119,4 @@ def run(self, text: str) -> TokenizationResult: token = Token(start=match.start(), end=match.end(), text=match.group()) tokens.append(token) - return TokenizationResult(results=tokens) + return TokenizationResults(results=tokens) diff --git a/caikit_nlp/resources/pretrained_model/base.py b/caikit_nlp/resources/pretrained_model/base.py index c2232a6c..59bb4d45 100644 --- a/caikit_nlp/resources/pretrained_model/base.py +++ b/caikit_nlp/resources/pretrained_model/base.py @@ -14,12 +14,18 @@ # Standard from abc import ABC, abstractmethod -from typing import Callable, List, Optional, Tuple, Type +from typing import Callable, List, Optional, Tuple, Type, Union import json import os # Third Party -from transformers import AutoTokenizer +from torch.utils.data import IterableDataset +from transformers import ( + AutoTokenizer, + DataCollatorWithPadding, + Trainer, + TrainingArguments, +) from transformers.models.auto.auto_factory import _BaseAutoModelClass import torch @@ -233,6 +239,61 @@ def save( self.tokenizer.save_pretrained(tok_abs_path) self.model.save_pretrained(model_abs_path) + def get_trainer( + self, + train_dataset: IterableDataset, + eval_dataset: Union[IterableDataset, None] = None, + optimizers=(None, None), + **kwargs, + ): + """ + Args: + **kwargs: arguments supported by HF TrainingArguments: + https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.TrainingArguments + + NOTE: following parameters are not supported currently: + 1. model_init + 2. compute_metrics + 3. callbacks + 4. preprocess_logits_for_metrics + """ + + training_args = TrainingArguments(**kwargs) + + data_collator = self._get_data_collator(**kwargs) + + trainer_arguments = { + "train_dataset": train_dataset, + "data_collator": data_collator, + "tokenizer": self._tokenizer, + "optimizers": optimizers, + "eval_dataset": eval_dataset, + } + + return Trainer(self._model, training_args, **trainer_arguments) + + def _get_data_collator(self, **kwargs): + """Function to return appropriate data collator based on resource. + + The default implementation of the base resource uses + DataCollatorWithPadding which will dynamically pad the inputs received. + + Args: + **kwargs: + All the keyword arguments passed to this function + will get filtered out to appropriate ones that are + applicable to implemented data collator. + Returns: + transformers.DataCollator + """ + + applicable_args = ["max_length", "pad_to_multiple_of"] + collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs} + + return DataCollatorWithPadding( + tokenizer=self._tokenizer, padding=True, **collator_kwargs + ) + # pylint: disable=unused-argument @classmethod def get_num_transformers_submodules( @@ -249,6 +310,7 @@ def build_task_tokenize_function( max_source_length: int, max_target_length: int, verbalizer: str, + task_ids: Union[None, int] = None, ) -> Tuple[Callable, bool]: """Builds tokenizer functions which can be mapped over train streams to process data which can then be easily passed to a DataLoader for different model types. @@ -263,6 +325,10 @@ def build_task_tokenize_function( verbalizer: str Verbalizer template to be used for formatting data. This template may use brackets to indicate where fields from the data model TrainGenerationRecord must be rendered. + task_ids: Union[None, int] + Task id corresponding particular task for multi-task prompt tuning. 
+ NOTE: Only required for MPT (Multi-task prompt tuning) + Default: None Returns: Tuple(Callable, bool) diff --git a/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py b/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py index b98a2983..30c0be20 100644 --- a/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py +++ b/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py @@ -16,10 +16,10 @@ """ # Standard from copy import deepcopy -from typing import Callable, Tuple +from typing import Callable, Tuple, Union # Third Party -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, DataCollatorForLanguageModeling from transformers.models.auto import modeling_auto # First Party @@ -52,6 +52,7 @@ def build_task_tokenize_function( max_source_length: int, max_target_length: int, verbalizer: str, + task_ids: Union[None, int] = None, ) -> Tuple[Callable, bool]: """Builds tokenizer functions which can be mapped over train streams to process data which can then be easily passed to a DataLoader for CausalLM models. @@ -66,6 +67,10 @@ def build_task_tokenize_function( verbalizer: str Verbalizer template to be used for formatting data. This template may use brackets to indicate where fields from the data model TrainGenerationRecord must be rendered. + task_ids: Union[None, int] + Task id corresponding particular task for multi-task prompt tuning. + NOTE: Only required for MPT (Multi-task prompt tuning) + Default: None Returns: Tuple(Callable, bool) @@ -104,7 +109,9 @@ def tokenize_function_language_model( # Here, we need to yield and manipulate the attention mask to attend # to the input seq + the tokens we have seen so far... num_target_samples = len(target_ids.input_ids) - source_ids["task_ids"] = 0 + + if task_ids is not None: + source_ids["task_ids"] = task_ids def generator_func(): for idx in range(num_target_samples): @@ -122,3 +129,32 @@ def generator_func(): return DataStream(generator_func) return (tokenize_function_language_model, True) + + def _get_data_collator(self, **kwargs): + """Function to return appropriate data collator based on resource. + + DataCollatorForLanguageModeling is used here which will dynamically + padded to maximum length of a batch if they are not all of the same + length. + + NOTE: If mlm (masked language modeling) is not passed in kwargs, + this function will automatically set it to `False`. + + Args: + **kwargs: + All the keyword arguments passed to this function + will get filtered out to appropriate ones that are + applicable to implemented data collator. 
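# Illustrative sketch (not from the patched sources): what the kwargs
# filtering described above amounts to for the causal LM collator; unrelated
# training arguments are dropped and `mlm` defaults to False. The tokenizer
# checkpoint is a placeholder.
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
train_kwargs = {"num_train_epochs": 1, "learning_rate": 2e-5, "pad_to_multiple_of": 8}

applicable_args = ["mlm", "pad_to_multiple_of"]
collator_kwargs = {k: train_kwargs[k] for k in applicable_args if k in train_kwargs}
collator_kwargs.setdefault("mlm", False)  # causal LM, not masked LM

collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, return_tensors="pt", **collator_kwargs
)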
+ Returns: + transformers.DataCollator + """ + + applicable_args = ["mlm", "pad_to_multiple_of"] + collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs} + + if "mlm" not in collator_kwargs: + collator_kwargs["mlm"] = False + + return DataCollatorForLanguageModeling( + tokenizer=self._tokenizer, return_tensors="pt", **collator_kwargs + ) diff --git a/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py b/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py index a8b41d45..bdd69aa1 100644 --- a/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py +++ b/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py @@ -15,10 +15,16 @@ Huggingface auto causal LM resource type """ # Standard -from typing import Callable, List, Tuple +from typing import Callable, List, Tuple, Union # Third Party -from transformers import AutoModelForSeq2SeqLM +from torch.utils.data import IterableDataset +from transformers import ( + AutoModelForSeq2SeqLM, + DataCollatorForSeq2Seq, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, +) from transformers.models.auto import modeling_auto # First Party @@ -68,12 +74,72 @@ def get_num_transformers_submodules( ) return num_transformer_submodules + def get_trainer( + self, + train_dataset: IterableDataset, + eval_dataset: Union[IterableDataset, None] = None, + optimizers=(None, None), + **kwargs + ): + """ + Args: + *kwargs: arguments supported by HF Seq2SeqTrainingArguments: + https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments + + NOTE: following parameters are not supported currently: + 1. model_init + 2. compute_metrics + 3. callbacks + 4. preprocess_logits_for_metrics + """ + + # NOTE: predict_with_generate is incompatible with fsdp + training_args = Seq2SeqTrainingArguments(**kwargs) + + # pylint: disable=duplicate-code + # TODO: Fetch DataCollator either from property of this + # class or fetch it as an argument. + data_collator = self._get_data_collator(**kwargs) + + trainer_arguments = { + "train_dataset": train_dataset, + "data_collator": data_collator, + "tokenizer": self._tokenizer, + "optimizers": optimizers, + "eval_dataset": eval_dataset, + # "generation_max_length": max_target_length, + } + + return Seq2SeqTrainer(self._model, training_args, **trainer_arguments) + + def _get_data_collator(self, **kwargs): + """Function to return appropriate data collator based on resource. + + This implementation uses DataCollatorForSeq2Seq + + Args: + **kwargs: + All the keyword arguments passed to this function + will get filtered out to appropriate ones that are + applicable to implemented data collator. + Returns: + transformers.DataCollator + """ + + applicable_args = ["max_length", "pad_to_multiple_of"] + collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs} + + return DataCollatorForSeq2Seq( + tokenizer=self._tokenizer, model=self._model, **collator_kwargs + ) + @staticmethod def build_task_tokenize_function( tokenizer: "AutoTokenizer", max_source_length: int, max_target_length: int, verbalizer: str, + task_ids: Union[None, int] = None, ) -> Tuple[Callable, bool]: """Builds tokenizer functions which can be mapped over train streams to process data which can then be easily passed to a DataLoader for seq2seq models. @@ -88,6 +154,10 @@ def build_task_tokenize_function( verbalizer: str Verbalizer template to be used for formatting data. 
This template may use brackets to indicate where fields from the data model TrainGenerationRecord must be rendered. + task_ids: Union[None, int] + Task id corresponding particular task for multi-task prompt tuning. + NOTE: Only required for MPT (Multi-task prompt tuning) + Default: None Returns: Tuple(Callable, bool) @@ -134,7 +204,9 @@ def tokenize_function_seq2seq( map(lambda x: IGNORE_ID if x == tokenizer.pad_token_id else x, labels) ) model_inputs["labels"] = labels - model_inputs["task_ids"] = 0 + if task_ids is not None: + model_inputs["task_ids"] = task_ids + return model_inputs return (tokenize_function_seq2seq, False) diff --git a/caikit_nlp/toolkit/task_specific_utils.py b/caikit_nlp/toolkit/task_specific_utils.py index 5b728da0..fcc3e9de 100644 --- a/caikit_nlp/toolkit/task_specific_utils.py +++ b/caikit_nlp/toolkit/task_specific_utils.py @@ -14,10 +14,11 @@ # First Party from caikit.core.toolkit import error_handler +from caikit.interfaces.nlp.data_model import ClassificationTrainRecord import alog # Local -from ..data_model import ClassificationTrainRecord, GenerationTrainRecord +from ..data_model import GenerationTrainRecord log = alog.use_channel("TASK_UTILS") error = error_handler.get(log) @@ -38,3 +39,22 @@ def convert_to_generation_record(train_record): and GenerationTrainRecord are supported" ), ) + + +def get_sorted_unique_class_labels(data_stream): + """Get the list of sorted unique class labels from a data stream of ClassificationTrainRecord. + + Args: + data_stream: DataStream[ClassificationTrainRecord] + Data stream of ClassificationTrainRecord from which to extract unique class labels + Returns: + unique_labels + Sorted list containing the unique set of classes discovered in the data stream + """ + labels_data_stream = data_stream.map(lambda item: item.labels) + unique_labels = set() + for label_list in labels_data_stream: + for label in label_list: + unique_labels.add(label) + + return sorted(unique_labels) diff --git a/caikit_nlp/toolkit/tgis_utils.py b/caikit_nlp/toolkit/tgis_utils.py index be32657d..a1d44d4c 100644 --- a/caikit_nlp/toolkit/tgis_utils.py +++ b/caikit_nlp/toolkit/tgis_utils.py @@ -31,19 +31,27 @@ error = error_handler.get(log) -def get_params(preserve_input_text, eos_token, max_new_tokens, min_new_tokens): +def get_params( + preserve_input_text, + eos_token, + max_new_tokens, + min_new_tokens, + truncate_input_tokens, +): """Get generation parameters Args: - preserve_input_text: str + preserve_input_text: str Whether or not the source string should be contained in the generated output, e.g., as a prefix. - eos_token: str + eos_token: str A special token representing the end of a sentence. - max_new_tokens: int + max_new_tokens: int The maximum numbers of tokens to generate. - min_new_tokens: int + min_new_tokens: int The minimum numbers of tokens to generate. + truncate_input_tokens: int + Truncate inputs to provided number of tokens. 
""" res_options = generation_pb2.ResponseOptions( input_text=preserve_input_text, @@ -60,6 +68,7 @@ def get_params(preserve_input_text, eos_token, max_new_tokens, min_new_tokens): params = generation_pb2.Parameters( response=res_options, stopping=stopping, + truncate_input_tokens=truncate_input_tokens, ) return params @@ -77,7 +86,12 @@ def __init__( self.prefix_id = prefix_id def unary_generate( - self, text, preserve_input_text, max_new_tokens, min_new_tokens + self, + text, + preserve_input_text, + max_new_tokens, + min_new_tokens, + truncate_input_tokens, ) -> GeneratedTextResult: """Generate unary output from model in TGIS @@ -93,7 +107,11 @@ def unary_generate( min_new_tokens: int The minimum numbers of tokens to generate. Default: 0 - means no minimum - + truncate_input_tokens: int + Truncate inputs to provided number of tokens. This can be + use to avoid failing due to input being longer than + configured limits. + 0 - means don't truncate, thus throw error. Returns: GeneratedTextResult Generated text result produced by TGIS. @@ -114,6 +132,7 @@ def unary_generate( eos_token=self.eos_token, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, + truncate_input_tokens=truncate_input_tokens, ) gen_reqs = [generation_pb2.GenerationRequest(text=text)] @@ -150,7 +169,12 @@ def unary_generate( ) def stream_generate( - self, text, preserve_input_text, max_new_tokens, min_new_tokens + self, + text, + preserve_input_text, + max_new_tokens, + min_new_tokens, + truncate_input_tokens, ) -> Iterable[GeneratedTextStreamResult]: """Generate stream output from model in TGIS @@ -164,6 +188,11 @@ def stream_generate( Maximum tokens for the model to generate min_new_tokens: int Minimum tokens for the model to generate + truncate_input_tokens: int + Truncate inputs to provided number of tokens. This can be + use to avoid failing due to input being longer than + configured limits. + 0 - means don't truncate, thus throw error. Returns: Iterable[GeneratedTextStreamResult] @@ -183,6 +212,7 @@ def stream_generate( eos_token=self.eos_token, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, + truncate_input_tokens=truncate_input_tokens, ) gen_req = generation_pb2.GenerationRequest(text=text) diff --git a/examples/run_fine_tuning.py b/examples/run_fine_tuning.py index 4768859e..9d7d5bf4 100644 --- a/examples/run_fine_tuning.py +++ b/examples/run_fine_tuning.py @@ -20,6 +20,7 @@ SUPPORTED_METRICS, DatasetInfo, configure_random_seed_and_logging, + load_model, print_colored, ) @@ -30,7 +31,7 @@ # Local from caikit_nlp.data_model import GenerationTrainRecord, TuningConfig -from caikit_nlp.modules.text_generation import FineTuning +from caikit_nlp.modules.text_generation import TextGeneration from caikit_nlp.resources.pretrained_model import ( HFAutoCausalLM, HFAutoSeq2SeqLM, @@ -170,6 +171,11 @@ def register_common_arguments(subparser: argparse.ArgumentParser) -> None: nargs="*", default=["accuracy"], ) + subparser.add_argument( + "--tgis", + help="Run inference using TGIS. 
NOTE: This involves saving and reloading model in TGIS container", + action="store_true", + ) def validate_common_args(args: argparse.Namespace): @@ -226,7 +232,9 @@ def show_experiment_configuration(args, dataset_info, model_type) -> None: print_colored("\n".join([print_str for print_str in print_strs if print_str])) -def get_model_preds_and_references(model, validation_stream): +def get_model_preds_and_references( + model, validation_stream, truncate_input_tokens, max_new_tokens +): """Given a model & a validation stream, run the model against every example in the validation stream and compare the outputs to the target/output sequence. @@ -236,7 +244,9 @@ def get_model_preds_and_references(model, validation_stream): validation_stream: DataStream[GenerationTrainRecord] Validation stream with labeled targets that we want to compare to our model's predictions. - + truncate_input_tokens: int + maximum number of tokens to be accepted by the model and rest will be + truncated. Returns: Tuple(List) Tuple of 2 lists; the model predictions and the expected output sequences. @@ -247,7 +257,11 @@ def get_model_preds_and_references(model, validation_stream): for datum in tqdm(validation_stream): # Local .run() currently prepends the input text to the generated string; # Ensure that we're just splitting the first predicted token & beyond. - raw_model_text = model.run(datum.input).text + raw_model_text = model.run( + datum.input, + truncate_input_tokens=truncate_input_tokens, + max_new_tokens=max_new_tokens, + ).generated_text parse_pred_text = raw_model_text.split(datum.input)[-1].strip() model_preds.append(parse_pred_text) targets.append(datum.output) @@ -299,24 +313,24 @@ def export_model_preds(preds_file, predictions, validation_stream): # Init the resource & Build the tuning config from our dataset/arg info print_colored("[Loading the base model resource...]") - base_model = model_type.bootstrap(args.model_name, tokenizer_name=args.model_name) + base_model = model_type.bootstrap( + args.model_name, tokenizer_name=args.model_name, torch_dtype=args.torch_dtype + ) # Then actually train the model & save it print_colored("[Starting the training...]") - model = FineTuning.train( + model = TextGeneration.train( base_model, train_stream, max_source_length=args.max_source_length, max_target_length=args.max_target_length, lr=args.learning_rate, - torch_dtype="float16", + torch_dtype=args.torch_dtype, batch_size=args.batch_size, accumulate_steps=args.accumulate_steps, num_epochs=args.num_epochs, ) - # model.save(args.output_dir, save_base_model=not args.prompt_only) - print_colored("[Training Complete]") # Prediction @@ -325,13 +339,35 @@ def export_model_preds(preds_file, predictions, validation_stream): print("Generated text: ", prediction_results) + # Saving model + model.save(args.output_dir) + + if args.tgis: + + # Load model in TGIS + # HACK: export args.output_dir as MODEL_NAME for TGIS + # container to pick up automatically + os.environ["MODEL_DIR"] = os.path.dirname(args.output_dir) + os.environ["MODEL_NAME"] = os.path.join( + "models", os.path.basename(args.output_dir), "artifacts" + ) + + loaded_model = load_model(is_distributed=True, model_path=args.output_dir) + + else: + # Use trained model directly + loaded_model = model + ## Evaluation print_colored("[Starting Evaluation]") validation_stream = dataset_info.dataset_loader()[1] print_colored("Getting model predictions...") - predictions, references = get_model_preds_and_references(model, validation_stream) + truncate_input_tokens = 
args.max_source_length + args.max_target_length + predictions, references = get_model_preds_and_references( + loaded_model, validation_stream, truncate_input_tokens, args.max_target_length + ) export_model_preds(args.preds_file, predictions, validation_stream) diff --git a/examples/text-generation-launcher b/examples/text-generation-launcher index 4204f93c..6273dcb2 100755 --- a/examples/text-generation-launcher +++ b/examples/text-generation-launcher @@ -1,16 +1,17 @@ #!/usr/bin/env bash -# This script is primarily meant for illustrative purposes; if we don't have +# This script is primarily meant for illustrative purposes; if we don't have # the text-generation-launcher command locally available, but we do have a Docker # container, we add this script onto our path so when the TGIS backend in caikit # tries to start the server, it runs this script instead. # -# NOTE: +# NOTE: # - Model ID, directories, etc are hardcoded to our example, params from the backend, # e.g., shard configuration, are ignored. # # - We need to export port 3000 (for probes in core distributed), and we forward 8033->50055 # so that our gRPC server is exposed on the expected port for local TGIS. TGIS_MODEL="${MODEL_NAME:-bigscience/bloom-560m}" +MODEL_DIR="${MODEL_DIR:-models}" echo "Running TGIS with model: $TGIS_MODEL" docker run --rm \ @@ -21,7 +22,7 @@ docker run --rm \ -p 8087:8087 \ -p 50055:8033 \ -p 3000:3000 \ - -v $(pwd)/models:/models \ + -v $(pwd)/${MODEL_DIR}:/models \ -v $(pwd)/../runtime_config.yaml:/conf/runtime_config.yaml \ -v $(pwd)/transformers_cache:/shared_model_storage/transformers_cache \ -v $(pwd)/prompt_prefixes:/prompt_prefixes \ diff --git a/examples/utils.py b/examples/utils.py index 80d18cf2..cfb944fd 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -39,6 +39,8 @@ "formatter": "pretty", } +log = alog.use_channel("EXMPL_UTILS") + def configure_random_seed_and_logging(): """Ensure that random experiments will be deterministic & set up default ALOG configuration.""" @@ -86,7 +88,7 @@ def get_distributed_model(model_path): "initializers": { "default": { "config": { - "backend_priority": {[{"type": TGISBackend.backend_type}]} + "backend_priority": [{"type": TGISBackend.backend_type}] } } } @@ -98,14 +100,18 @@ def get_distributed_model(model_path): # make sure that its suffix (base model name) aligns with what we have in our config. # NOTE: bloom-560m is the default here because that's the default model used in our # text generation server hack script. 
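# Illustrative sketch (not from the patched sources): how the --tgis path in
# examples/run_fine_tuning.py hands a freshly saved model to the
# text-generation-launcher wrapper via MODEL_DIR/MODEL_NAME. The output
# directory is a placeholder; load_model and get_model_preds_and_references
# are the helpers used by that example script.
import os

output_dir = "models/my_tuned_model"  # directory passed to model.save()
os.environ["MODEL_DIR"] = os.path.dirname(output_dir)
os.environ["MODEL_NAME"] = os.path.join(
    "models", os.path.basename(output_dir), "artifacts"
)
# loaded_model = load_model(is_distributed=True, model_path=output_dir)
# predictions, references = get_model_preds_and_references(
#     loaded_model, validation_stream, truncate_input_tokens=192, max_new_tokens=64
# )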
- model_name_override = os.getenv("MODEL_NAME", "bloom-560m") - loaded_base_model = dist_model.base_model_name + model_name_override = os.getenv("MODEL_NAME", "bigscience/bloom-560m") + if hasattr(dist_model, "base_model_name"): + loaded_base_model = dist_model.base_model_name + else: + loaded_base_model = dist_model.model_name if not model_name_override.endswith(loaded_base_model): - raise ValueError( + log.error( "TGIS using model name: {} conflicts with base model name: {}; set env var MODEL_NAME to the correct base model!".format( model_name_override, loaded_base_model ) ) + return dist_model diff --git a/pyproject.toml b/pyproject.toml index ef954769..ae9be782 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,9 +14,8 @@ classifiers=[ "License :: OSI Approved :: Apache Software License" ] dependencies = [ - "caikit[runtime-grpc,runtime-http]>=0.13.0,<0.15.0", - "caikit-tgis-backend>=0.1.14,<0.2.0", - + "caikit[runtime-grpc,runtime-http]>=0.16.0,<0.18.0", + "caikit-tgis-backend>=0.1.16,<0.2.0", # TODO: loosen dependencies "accelerate>=0.18.0", "datasets>=2.4.0", diff --git a/runtime_config.yaml b/runtime_config.yaml index 5799b5ba..3c3812b3 100644 --- a/runtime_config.yaml +++ b/runtime_config.yaml @@ -9,6 +9,13 @@ runtime: size: 0 # Set to batch size for batching model_management: + finders: + default: + type: LOCAL + remote_tgis: + type: TGIS-AUTO + config: + test_connection: true initializers: default: type: LOCAL diff --git a/caikit_nlp/modules/text_classification/text_classification_task.py b/tests/conftest.py similarity index 64% rename from caikit_nlp/modules/text_classification/text_classification_task.py rename to tests/conftest.py index 8de9c088..ff72f10b 100644 --- a/caikit_nlp/modules/text_classification/text_classification_task.py +++ b/tests/conftest.py @@ -11,19 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""This task can be promoted to caikit/caikit for wider usage when applicable -to multiple modules """ -# First Party -from caikit.core import TaskBase, task +This sets up global test configs when pytest starts +""" -# Local -from ...data_model import ClassificationResult +# Standard +import os +# First Party +import alog -@task( - required_parameters={"text": str}, - output_type=ClassificationResult, +# Configure logging from the environment +alog.configure( + default_level=os.environ.get("LOG_LEVEL", "off"), + filters=os.environ.get("LOG_FILTERS", "urllib3:off"), + thread_id=os.environ.get("LOG_THREAD_ID", "") == "true", ) -class TextClassificationTask(TaskBase): - pass diff --git a/tests/data_model/test_classification.py b/tests/data_model/test_classification.py deleted file mode 100644 index 81aebe5c..00000000 --- a/tests/data_model/test_classification.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
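# Illustrative sketch (not from the patched sources): the classification
# records these removed tests exercised now come from caikit.interfaces.nlp,
# and the label set for a training stream can be derived with the new helper
# in caikit_nlp/toolkit/task_specific_utils.py. Records are placeholders.
from caikit.core.data_model import DataStream
from caikit.interfaces.nlp.data_model import ClassificationTrainRecord

from caikit_nlp.toolkit.task_specific_utils import get_sorted_unique_class_labels

stream = DataStream.from_iterable(
    [
        ClassificationTrainRecord(text="what a cute dog!", labels=["no complaint"]),
        ClassificationTrainRecord(text="worst idea ever.", labels=["complaint"]),
    ]
)
print(get_sorted_unique_class_labels(stream))  # -> ['complaint', 'no complaint']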
- -# Local -from caikit_nlp import data_model as dm - -## Setup ######################################################################### - -classification1 = dm.Classification(label="temperature", score=0.71) - -classification2 = dm.Classification(label="conditions", score=0.98) - -classification_result = dm.ClassificationResult( - results=[classification1, classification2] -) - -token_classification1 = dm.TokenClassification( - start=0, end=5, word="moose", entity="animal", score=0.8 -) -token_classification2 = dm.TokenClassification( - start=7, end=12, word="goose", entity="animal", score=0.7 -) -token_classification_result = dm.TokenClassificationResult( - results=[token_classification1, token_classification2] -) - -classification_train_record = dm.ClassificationTrainRecord( - text="It is 20 degrees today", labels=["temperature"] -) - -## Tests ######################################################################## - -### Classification - - -def test_classification_all_fields_accessible(): - classification_result = dm.Classification(label="temperature", score=0.71) - assert classification_result.label == "temperature" - assert classification_result.score == 0.71 - - -def test_classification_from_proto_and_back(): - new = dm.Classification.from_proto(classification1.to_proto()) - assert new.label == "temperature" - assert new.score == 0.71 - - -def test_classification_from_json_and_back(): - new = dm.Classification.from_json(classification1.to_json()) - assert new.label == "temperature" - assert new.score == 0.71 - - -### ClassificationResult - - -def test_classification_result_all_fields_accessible(): - classification_result = dm.ClassificationResult(results=[classification1]) - assert classification_result.results[0].label == "temperature" - assert classification_result.results[0].score == 0.71 - - -def test_classification_result_from_proto_and_back(): - new = dm.ClassificationResult.from_proto(classification_result.to_proto()) - assert new.results[0].label == "temperature" - assert new.results[0].score == 0.71 - assert new.results[1].label == "conditions" - assert new.results[1].score == 0.98 - - -def test_classification_result_from_json_and_back(): - new = dm.ClassificationResult.from_json(classification_result.to_json()) - assert new.results[0].label == "temperature" - assert new.results[0].score == 0.71 - assert new.results[1].label == "conditions" - assert new.results[1].score == 0.98 - - -### TokenClassification - - -def test_token_classification_all_fields_accessible(): - token_classification = dm.TokenClassification( - start=0, - end=28, - word="The cow jumped over the moon", - entity="neutral", - score=0.6, - ) - assert token_classification.start == 0 - assert token_classification.end == 28 - assert token_classification.word == "The cow jumped over the moon" - assert token_classification.entity == "neutral" - assert token_classification.score == 0.6 - - -def test_classification_from_proto_and_back(): - new = dm.TokenClassification.from_proto(token_classification1.to_proto()) - assert new.start == 0 - assert new.word == "moose" - assert new.score == 0.8 - - -def test_classification_from_json_and_back(): - new = dm.TokenClassification.from_json(token_classification1.to_json()) - assert new.start == 0 - assert new.word == "moose" - assert new.score == 0.8 - - -### TokenClassificationResult - - -def test_token_classification_result_all_fields_accessible(): - token_classification_result = dm.TokenClassificationResult( - results=[token_classification1] - ) - assert 
token_classification_result.results[0].start == 0 - assert token_classification_result.results[0].word == "moose" - assert token_classification_result.results[0].score == 0.8 - - -def test_token_classification_result_from_proto_and_back(): - new = dm.TokenClassificationResult.from_proto( - token_classification_result.to_proto() - ) - assert new.results[0].start == 0 - assert new.results[0].word == "moose" - assert new.results[0].score == 0.8 - assert new.results[1].end == 12 - assert new.results[1].entity == "animal" - - -def test_classification_result_from_json_and_back(): - new = dm.TokenClassificationResult.from_json(token_classification_result.to_json()) - assert new.results[0].start == 0 - assert new.results[0].word == "moose" - assert new.results[0].score == 0.8 - assert new.results[1].end == 12 - assert new.results[1].entity == "animal" - - -### ClassificationTrainRecord - - -def test_all_fields_accessible(): - classification_train_record = dm.ClassificationTrainRecord( - text="It is 20 degrees today", labels=["temperature"] - ) - assert classification_train_record.text == "It is 20 degrees today" - assert classification_train_record.labels == ["temperature"] - - -def test_from_proto_and_back(): - new = dm.ClassificationTrainRecord.from_proto( - classification_train_record.to_proto() - ) - assert new.text == "It is 20 degrees today" - assert new.labels == ["temperature"] - - -def test_from_json_and_back(): - new = dm.ClassificationTrainRecord.from_json(classification_train_record.to_json()) - assert new.text == "It is 20 degrees today" - assert new.labels == ["temperature"] diff --git a/tests/data_model/test_text.py b/tests/data_model/test_text.py deleted file mode 100644 index cb2c2fa3..00000000 --- a/tests/data_model/test_text.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Local -from caikit_nlp import data_model as dm - -## Setup ######################################################################### - -dummy_token = dm.Token(start=0, end=11, text="Hello World") - -## Tests ######################################################################## - - -def test_all_fields_accessible(): - token = dm.Token(start=0, end=11, text="Hello World") - assert token.start == 0 - assert token.end == 11 - assert token.text == "Hello World" - - -def test_from_proto_and_back(): - new = dm.Token.from_proto(dummy_token.to_proto()) - assert new.start == 0 - assert new.end == 11 - assert new.text == "Hello World" - - -def test_from_json_and_back(): - new = dm.Token.from_json(dummy_token.to_json()) - assert new.start == 0 - assert new.end == 11 - assert new.text == "Hello World" diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index 2cc4f0b7..99d5febe 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -32,6 +32,20 @@ SEQ2SEQ_LM_MODEL = os.path.join(TINY_MODELS_DIR, "T5ForConditionalGeneration") +@pytest.fixture() +def set_cpu_device(request): + """Fixture to set default cuda device. 
+ This fixture is particularly useful for running the unit tests where + cuda devices are available, in which case, some transformers function + may try to consume cuda and give device mismatch error. + """ + visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "") + os.environ["CUDA_VISIBLE_DEVICES"] = "" + with mock.patch.object(torch.cuda, "is_available", return_value=False): + yield + os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices + + @pytest.fixture def disable_wip(request): """Fixture to temporarily disable wip decorator""" @@ -228,6 +242,7 @@ def __init__( def get_client(self, base_model_name): self._model_connections[base_model_name] = TGISConnection( hostname="foo.bar", + model_id=base_model_name, prompt_dir=self._temp_dir, ) return StubTGISClient(base_model_name) diff --git a/tests/data_model/__init__.py b/tests/model_management/__init__.py similarity index 100% rename from tests/data_model/__init__.py rename to tests/model_management/__init__.py diff --git a/tests/model_management/test_tgis_auto_finder.py b/tests/model_management/test_tgis_auto_finder.py new file mode 100644 index 00000000..6638e865 --- /dev/null +++ b/tests/model_management/test_tgis_auto_finder.py @@ -0,0 +1,263 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
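# Illustrative sketch (not from the patched sources): what these tests verify
# in plain usage terms. With a TGIS-AUTO finder configured (see the
# runtime_config.yaml change above) and a TGIS backend in the local
# initializer's backend_priority, loading a bare model id yields a TGIS-backed
# text generation module; the configuration wiring is assumed to already be in
# place, as done by temp_model_manager() below.
from caikit.core.model_manager import ModelManager

mmgr = ModelManager()  # picks up the global caikit configuration
mmgr.initialize_components()
model = mmgr.load("flan-t5-xl")  # no local artifacts required
# isinstance(model, TextGenerationTGIS) -> True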
+"""Tests for the TGISAutoFinder""" + +# Standard +from contextlib import contextmanager +from typing import Optional +from unittest.mock import patch +import copy + +# Third Party +import pytest + +# First Party +from caikit.config.config import merge_configs +from caikit.core.model_manager import ModelManager +from caikit_tgis_backend import TGISBackend +import aconfig +import caikit + +# Local +from caikit_nlp.model_management import tgis_auto_finder +from caikit_nlp.modules.text_generation.text_generation_tgis import TextGenerationTGIS + +## Helpers ##################################################################### + +# Convenient aliases +LOCAL_INITIALIZER_NAME = tgis_auto_finder.TGISAutoFinder._LOCAL_INITIALIZER_NAME_KEY +TGIS_BACKEND_PRIORITY = tgis_auto_finder.TGISAutoFinder._TGIS_BACKEND_PRIORITY_KEY + + +def make_tgis_config(hostname: str): + return { + "type": TGISBackend.backend_type, + "config": { + "test_connections": False, + "connection": { + "hostname": hostname, + }, + }, + } + + +@contextmanager +def temp_model_manager( + auto_finder_config: Optional[dict] = None, + backend_priority: Optional[list] = None, + local_initializer_key: str = "default", +): + global_config = copy.deepcopy(getattr(caikit.config.config, "_CONFIG")) + if backend_priority is None: + backend_priority = [make_tgis_config("foo.bar:123")] + updated_config = merge_configs( + global_config, + { + "model_management": { + "finders": { + "default": { + "type": tgis_auto_finder.TGISAutoFinder.name, + "config": auto_finder_config or {}, + }, + }, + "initializers": { + local_initializer_key: { + "type": "LOCAL", + "config": { + "backend_priority": backend_priority, + }, + }, + }, + }, + }, + ) + with patch("caikit.core.model_manager.get_config", lambda: updated_config): + mmgr = ModelManager() + # NOTE: The TGISAutoFinder relies on searching for the TGISBackend in + # the global caikit.core.MODEL_MANAGER and can inadvertently trigger + # that global instance to set up an initializer that prefers TGIS over + # LOCAL. We need isolation for tests, so we mock that global with this + # temporary instance here. 
+ with patch.object(tgis_auto_finder, "MODEL_MANAGER", new=mmgr): + mmgr.initialize_components() + yield mmgr + + +## Tests ####################################################################### + + +def test_auto_find_tgis_model_ok(): + """Test that a TGIS text-gen model can be auto-found""" + with temp_model_manager() as mmgr: + model = mmgr.load("flan-t5-xl") + assert model + assert isinstance(model, TextGenerationTGIS) + + +def test_auto_find_tgis_model_non_default_local_initializer(): + """Test that a TGIS text-gen model can be auto-found when the local + initializer is not the default + """ + init_name = "notdefault" + with temp_model_manager( + auto_finder_config={LOCAL_INITIALIZER_NAME: init_name}, + local_initializer_key=init_name, + ) as mmgr: + model = mmgr.load("flan-t5-xl", initializer=init_name) + assert model + assert isinstance(model, TextGenerationTGIS) + + +def test_auto_find_tgis_model_multiple_tgis_backends_use_first(): + """Test that a TGIS text-gen model can be auto-found when there are + multiple TGIS backends configured and no explicit priority is given + """ + with temp_model_manager( + backend_priority=[ + make_tgis_config("foo.bar:1234"), + make_tgis_config("baz.bat:4567"), + ] + ) as mmgr: + tgis_be0 = mmgr.get_initializer("default").backends[0] + tgis_be1 = mmgr.get_initializer("default").backends[1] + with patch.object(tgis_be0, "get_connection") as get_con_mock0: + with patch.object(tgis_be1, "get_connection") as get_con_mock1: + model = mmgr.load("flan-t5-xl") + assert model + assert isinstance(model, TextGenerationTGIS) + get_con_mock0.assert_called() + get_con_mock1.assert_not_called() + + +def test_auto_find_tgis_model_multiple_tgis_backends_set_order(): + """Test that a TGIS text-gen model can be auto-found when there are + multiple TGIS backends configured and priority is explicitly given + """ + with temp_model_manager( + backend_priority=[ + make_tgis_config("foo.bar:1234"), + make_tgis_config("baz.bat:4567"), + ], + auto_finder_config={TGIS_BACKEND_PRIORITY: 1}, + ) as mmgr: + tgis_be0 = mmgr.get_initializer("default").backends[0] + tgis_be1 = mmgr.get_initializer("default").backends[1] + with patch.object(tgis_be0, "get_connection") as get_con_mock0: + with patch.object(tgis_be1, "get_connection") as get_con_mock1: + model = mmgr.load("flan-t5-xl") + assert model + assert isinstance(model, TextGenerationTGIS) + get_con_mock0.assert_not_called() + get_con_mock1.assert_called() + + +def test_bad_config_args(): + """Make sure that all flavors of bad configuration args are handled with + appropriate errors + """ + # Bad initializer name type + with pytest.raises(TypeError): + tgis_auto_finder.TGISAutoFinder( + aconfig.Config( + {LOCAL_INITIALIZER_NAME: 123}, + override_env_vars=False, + ) + ) + + # Bad tgis_backend_priority type + with pytest.raises(TypeError): + tgis_auto_finder.TGISAutoFinder( + aconfig.Config( + {TGIS_BACKEND_PRIORITY: "1"}, + override_env_vars=False, + ) + ) + + # Invalid tgis_backend_priority index value + with pytest.raises(ValueError): + tgis_auto_finder.TGISAutoFinder( + aconfig.Config( + {TGIS_BACKEND_PRIORITY: 123}, + override_env_vars=False, + ) + ) + + # Invalid tgis_backend_priority index value + with pytest.raises(ValueError): + tgis_auto_finder.TGISAutoFinder( + aconfig.Config( + {TGIS_BACKEND_PRIORITY: 123}, + override_env_vars=False, + ) + ) + + # Non-TGIS tgis_backend_priority index value + with pytest.raises(ValueError): + tgis_auto_finder.TGISAutoFinder( + aconfig.Config( + {TGIS_BACKEND_PRIORITY: 0}, + 
override_env_vars=False, + ) + ) + + # No TGIS backend available + with pytest.raises(ValueError): + tgis_auto_finder.TGISAutoFinder(aconfig.Config({})) + + +def test_unsupported_tgis_model(): + """Test that attempting to use a TGIS connection for a model that isn't + supported by the TGIS backend results in not finding the model + """ + with temp_model_manager( + backend_priority=[ + { + "type": TGISBackend.backend_type, + "config": { + "remote_models": { + "not-flan-t5-xl": { + "hostname": "foo.bar:123", + } + } + }, + } + ], + ) as mmgr: + finder = mmgr._finders["default"] + assert isinstance(finder, tgis_auto_finder.TGISAutoFinder) + assert finder.find_model("flan-t5-xl") is None + + +def test_bad_tgis_connection(): + """Test that attempting to use a TGIS connection that can't connect results + in not finding the model + """ + with temp_model_manager( + backend_priority=[ + { + "type": TGISBackend.backend_type, + "config": { + "remote_models": { + "flan-t5-xl": { + "hostname": "foo.bar:123", + } + }, + "test_connections": True, + }, + } + ], + ) as mmgr: + finder = mmgr._finders["default"] + assert isinstance(finder, tgis_auto_finder.TGISAutoFinder) + assert finder.find_model("flan-t5-xl") is None diff --git a/tests/modules/text_classification/test_classification_prompt_tuning.py b/tests/modules/text_classification/test_classification_prompt_tuning.py new file mode 100644 index 00000000..a5e842f6 --- /dev/null +++ b/tests/modules/text_classification/test_classification_prompt_tuning.py @@ -0,0 +1,123 @@ +"""Tests for sequence classification module +""" +# Standard +import os +import tempfile + +# Third Party +import torch + +# First Party +from caikit.interfaces.nlp.data_model import ( + ClassificationResults, + ClassificationTrainRecord, +) +import caikit + +# Local +from caikit_nlp.modules.text_classification.classification_prompt_tuning import ( + ClassificationPeftPromptTuning, +) +from caikit_nlp.modules.text_generation.peft_prompt_tuning import PeftPromptTuning +from tests.fixtures import causal_lm_dummy_model, causal_lm_train_kwargs + +#################### +## train/run ## +#################### + + +def test_train_model(causal_lm_train_kwargs): + """Ensure that we can train a model on some toy data for 1+ steps""" + patch_kwargs = { + "num_epochs": 1, + "verbalizer": "Tweet text : {{input}} Label : ", + "train_stream": caikit.core.data_model.DataStream.from_iterable( + [ + ClassificationTrainRecord( + text="@foo what a cute dog!", labels=["no complaint"] + ), + ClassificationTrainRecord( + text="@bar this is the worst idea ever.", labels=["complaint"] + ), + ] + ), + "torch_dtype": torch.bfloat16, + "device": "cpu", + } + causal_lm_train_kwargs.update(patch_kwargs) + model = ClassificationPeftPromptTuning.train(**causal_lm_train_kwargs) + # Test fallback to float32 behavior if this machine doesn't support bfloat16 + assert model.classifier.model.dtype is torch.float32 + assert isinstance(model, ClassificationPeftPromptTuning) + + +# TODO: add test for scores in future when implemented +def test_run_classification_model(causal_lm_dummy_model): + classifier_model = ClassificationPeftPromptTuning( + classifier=causal_lm_dummy_model, + unique_class_labels=["LABEL_0", "LABEL_1", "LABEL_2"], + ) + output = classifier_model.run("Text does not matter") + assert isinstance(output, ClassificationResults) + # Returns supported class labels or None + classifier_model.unique_class_labels.append(None) + assert output.results[0].label in classifier_model.unique_class_labels + + +def 
test_train_run_model_classification_record(causal_lm_train_kwargs): + """Ensure that we can train a model on some toy data for 1+ steps & run inference.""" + patch_kwargs = { + "num_epochs": 1, + "verbalizer": "Tweet text : {{input}} Label : ", + "train_stream": caikit.core.data_model.DataStream.from_iterable( + [ + ClassificationTrainRecord( + text="@foo what a cute dog!", labels=["no complaint"] + ), + ClassificationTrainRecord( + text="@bar this is the worst idea ever.", labels=["complaint"] + ), + ] + ), + "torch_dtype": torch.bfloat16, + "device": "cpu", + } + causal_lm_train_kwargs.update(patch_kwargs) + model = ClassificationPeftPromptTuning.train(**causal_lm_train_kwargs) + # Test fallback to float32 behavior if this machine doesn't support bfloat16 + assert model.classifier.model.dtype is torch.float32 + assert isinstance(model, ClassificationPeftPromptTuning) + + +#################### +## save/load ## +#################### + + +def test_save(causal_lm_dummy_model): + classifier_model = ClassificationPeftPromptTuning( + classifier=causal_lm_dummy_model, unique_class_labels=["label1", "label2"] + ) + with tempfile.TemporaryDirectory() as model_dir: + classifier_model.save(model_dir) + assert os.path.exists(os.path.join(model_dir, "config.yml")) + assert os.path.exists(os.path.join(model_dir, "artifacts", "config.yml")) + + +# TODO: Enable test when saving of base model is enabled in module_saver +# def test_save_and_load(causal_lm_dummy_model): +# classifier_model = ClassificationPeftPromptTuning( +# classifier=causal_lm_dummy_model, unique_class_labels=["label1", "label2"] +# ) +# with tempfile.TemporaryDirectory() as model_dir: +# classifier_model.save(model_dir) +# model_load = caikit_nlp.load(model_dir) +# assert isinstance(model_load, ClassificationPeftPromptTuning) +# assert isinstance(model_load.classifier, PeftPromptTuning) +# assert model_load.unique_class_labels == ["label1", "label2"] + +#################### +## save/load/run ## +#################### + +# TODO after load is fixed diff --git a/tests/modules/text_classification/test_sequence_classification.py b/tests/modules/text_classification/test_sequence_classification.py index 2ba0538d..58b75409 100644 --- a/tests/modules/text_classification/test_sequence_classification.py +++ b/tests/modules/text_classification/test_sequence_classification.py @@ -7,8 +7,10 @@ from pytest import approx import pytest +# First Party +from caikit.interfaces.nlp.data_model import ClassificationResult, ClassificationResults + # Local -from caikit_nlp.data_model.classification import Classification, ClassificationResult from caikit_nlp.modules.text_classification import SequenceClassification from tests.fixtures import SEQ_CLASS_MODEL @@ -32,10 +34,10 @@ def test_bootstrap_and_run(): """Check if we can bootstrap and run sequence classification models""" model = SequenceClassification.bootstrap(SEQ_CLASS_MODEL) classification_result = model.run(TEXTS[0]) - assert isinstance(classification_result, ClassificationResult) + assert isinstance(classification_result, ClassificationResults) assert len(classification_result.results) == 2 # 2 labels - assert isinstance(classification_result.results[0], Classification) + assert isinstance(classification_result.results[0], ClassificationResult) assert classification_result.results[0].label == "LABEL_0" assert approx(classification_result.results[0].score) == 0.49526197 assert classification_result.results[1].label == "LABEL_1" @@ -48,7 +50,7 @@ def test_bootstrap_and_run_batch(): assert 
len(classification_result_list) == 2 first_result = classification_result_list[0] - assert isinstance(first_result, ClassificationResult) + assert isinstance(first_result, ClassificationResults) assert first_result.results[0].label == "LABEL_0" assert approx(first_result.results[0].score) == 0.49526197 assert first_result.results[1].label == "LABEL_1" @@ -61,9 +63,9 @@ def test_load_save_and_run_model(): BOOTSTRAPPED_SEQ_CLASS_MODEL.save(model_dir) new_model = SequenceClassification.load(model_dir) classification_result = new_model.run(TEXTS[0]) - assert isinstance(classification_result, ClassificationResult) + assert isinstance(classification_result, ClassificationResults) assert len(classification_result.results) == 2 # 2 labels - assert isinstance(classification_result.results[0], Classification) + assert isinstance(classification_result.results[0], ClassificationResult) assert classification_result.results[0].label == "LABEL_0" assert approx(classification_result.results[0].score) == 0.49526197 diff --git a/tests/modules/text_generation/test_fine_tuning.py b/tests/modules/text_generation/test_fine_tuning.py deleted file mode 100644 index 17a611b1..00000000 --- a/tests/modules/text_generation/test_fine_tuning.py +++ /dev/null @@ -1,63 +0,0 @@ -# Third Party -from transformers import Trainer -import pytest -import torch - -# First Party -from caikit.interfaces.nlp.data_model import GeneratedTextResult -import caikit - -# Local -from caikit_nlp.data_model import GenerationTrainRecord -from caikit_nlp.modules.text_generation import FineTuning -from caikit_nlp.resources.pretrained_model import HFAutoSeq2SeqLM -from tests.fixtures import SEQ2SEQ_LM_MODEL, disable_wip - - -def test_train_model(disable_wip): - """Ensure that we can train a model on some toy data for 1+ steps & run inference.""" - train_kwargs = { - "base_model": HFAutoSeq2SeqLM.bootstrap( - model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL - ), - "num_epochs": 1, - "train_stream": caikit.core.data_model.DataStream.from_iterable( - [ - GenerationTrainRecord( - input="@foo what a cute dog!", output="no complaint" - ), - GenerationTrainRecord( - input="@bar this is the worst idea ever.", output="complaint" - ), - ] - ), - "torch_dtype": torch.float32, - } - model = FineTuning.train(**train_kwargs) - assert isinstance(model.model, Trainer) - # Ensure that we can get something out of it - pred = model.run("@bar what a cute cat!") - assert isinstance(pred, GeneratedTextResult) - - -############################## Error Cases ################################ - - -def test_zero_epoch_case(disable_wip): - """Test to ensure 0 epoch training request doesn't explode""" - train_kwargs = { - "base_model": HFAutoSeq2SeqLM.bootstrap( - model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL - ), - "num_epochs": 0, - "train_stream": caikit.core.data_model.DataStream.from_iterable( - [ - GenerationTrainRecord( - input="@foo what a cute dog!", output="no complaint" - ), - ] - ), - "torch_dtype": torch.float32, - } - model = FineTuning.train(**train_kwargs) - assert isinstance(model.model, Trainer) diff --git a/tests/modules/text_generation/test_peft_prompt_tuning.py b/tests/modules/text_generation/test_peft_prompt_tuning.py index 28100270..5ed1c9fe 100644 --- a/tests/modules/text_generation/test_peft_prompt_tuning.py +++ b/tests/modules/text_generation/test_peft_prompt_tuning.py @@ -17,6 +17,7 @@ # First Party from caikit.interfaces.nlp.data_model import ( + ClassificationTrainRecord, GeneratedTextResult, 
GeneratedTextStreamResult, ) @@ -30,14 +31,16 @@ causal_lm_train_kwargs, seq2seq_lm_dummy_model, seq2seq_lm_train_kwargs, + set_cpu_device, ) import caikit_nlp # Indexes into the peft config dictionary to get the actual prompt tuning config DEFAULT_ADAPTER = "default" + ### Tests validating block interfaces and behavior -def test_save_and_reload_with_base_model(causal_lm_dummy_model): +def test_save_and_reload_with_base_model(causal_lm_dummy_model, set_cpu_device): """Ensure that we can save a model + its base to a tempdir and reload it.""" with tempfile.TemporaryDirectory() as model_dir: causal_lm_dummy_model.save(model_dir, save_base_model=True) @@ -109,7 +112,7 @@ def test_verbalizer_cannot_be_static(causal_lm_train_kwargs): ) -def test_train_model(causal_lm_train_kwargs): +def test_train_model(causal_lm_train_kwargs, set_cpu_device): """Ensure that we can train a model on some toy data for 1+ steps & run inference.""" patch_kwargs = { "num_epochs": 1, @@ -138,17 +141,17 @@ def test_train_model(causal_lm_train_kwargs): assert isinstance(pred, GeneratedTextResult) -def test_train_model_classification_record(causal_lm_train_kwargs): +def test_train_model_classification_record(causal_lm_train_kwargs, set_cpu_device): """Ensure that we can train a model on some toy data for 1+ steps & run inference.""" patch_kwargs = { "num_epochs": 1, "verbalizer": "Tweet text : {{input}} Label : ", "train_stream": caikit.core.data_model.DataStream.from_iterable( [ - caikit_nlp.data_model.ClassificationTrainRecord( + ClassificationTrainRecord( text="@foo what a cute dog!", labels=["no complaint"] ), - caikit_nlp.data_model.ClassificationTrainRecord( + ClassificationTrainRecord( text="@bar this is the worst idea ever.", labels=["complaint"] ), ] diff --git a/tests/modules/text_generation/test_text_generation_local.py b/tests/modules/text_generation/test_text_generation_local.py new file mode 100644 index 00000000..d83ed380 --- /dev/null +++ b/tests/modules/text_generation/test_text_generation_local.py @@ -0,0 +1,178 @@ +"""Tests for text-generation module +""" +# Standard +import os +import platform +import tempfile + +# Third Party +import pytest +import torch + +# First Party +from caikit.interfaces.nlp.data_model import GeneratedTextResult +import caikit + +# Local +from caikit_nlp.data_model import GenerationTrainRecord +from caikit_nlp.modules.text_generation import TextGeneration +from caikit_nlp.resources.pretrained_model import HFAutoCausalLM, HFAutoSeq2SeqLM +from tests.fixtures import ( + CAUSAL_LM_MODEL, + SEQ2SEQ_LM_MODEL, + disable_wip, + set_cpu_device, +) + +### Stub Modules + + +def test_bootstrap_and_run_causallm(): + """Check if we can bootstrap and run causallm models""" + + model = TextGeneration.bootstrap(CAUSAL_LM_MODEL) + + sample_text = "Hello stub" + generated_text = model.run(sample_text) + assert isinstance(generated_text, GeneratedTextResult) + + +def test_bootstrap_and_run_seq2seq(): + """Check if we can bootstrap and run seq2seq models""" + + model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL) + + sample_text = "Hello stub" + generated_text = model.run(sample_text) + assert isinstance(generated_text, GeneratedTextResult) + + +def test_bootstrap_and_save_model(): + """Check if we can bootstrap and save the model successfully""" + + model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL) + + with tempfile.TemporaryDirectory() as model_dir: + model.save(model_dir) + assert os.path.isfile(os.path.join(model_dir, "config.yml")) + + +def test_save_model_can_run(): + """Check if the model 
we bootstrap and save is able to load and run successfully""" + model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL) + + with tempfile.TemporaryDirectory() as model_dir: + model.save(model_dir) + del model + new_model = TextGeneration.load(model_dir) + sample_text = "Hello stub" + generated_text = new_model.run(sample_text) + assert isinstance(generated_text, GeneratedTextResult) + + +############################## Training ################################ + + +@pytest.mark.skipif(platform.processor() == "arm", reason="ARM training not supported") +def test_train_model_seq2seq(disable_wip, set_cpu_device): + """Ensure that we can finetune a seq2seq model on some toy data for 1+ + steps & run inference.""" + train_kwargs = { + "base_model": HFAutoSeq2SeqLM.bootstrap( + model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL + ), + "num_epochs": 1, + "train_stream": caikit.core.data_model.DataStream.from_iterable( + [ + GenerationTrainRecord( + input="@foo what a cute dog!", output="no complaint" + ), + GenerationTrainRecord( + input="@bar this is the worst idea ever.", output="complaint" + ), + ] + ), + "torch_dtype": torch.float32, + } + model = TextGeneration.train(**train_kwargs) + assert isinstance(model.model, HFAutoSeq2SeqLM) + + # Ensure that we can get something out of it + pred = model.run("@bar what a cute cat!") + assert isinstance(pred, GeneratedTextResult) + + +@pytest.mark.skipif(platform.processor() == "arm", reason="ARM training not supported") +def test_train_model_save_and_load(disable_wip, set_cpu_device): + """Ensure that we are able to save and load a finetuned model and execute inference on it""" + train_kwargs = { + "base_model": HFAutoSeq2SeqLM.bootstrap( + model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL + ), + "num_epochs": 1, + "train_stream": caikit.core.data_model.DataStream.from_iterable( + [ + GenerationTrainRecord( + input="@foo what a cute dog!", output="no complaint" + ) + ] + ), + "torch_dtype": torch.float32, + } + model = TextGeneration.train(**train_kwargs) + assert isinstance(model.model, HFAutoSeq2SeqLM) + with tempfile.TemporaryDirectory() as model_dir: + model.save(model_dir) + new_model = TextGeneration.load(model_dir) + sample_text = "Hello stub" + generated_text = new_model.run(sample_text) + assert isinstance(generated_text, GeneratedTextResult) + + +@pytest.mark.skipif(platform.processor() == "arm", reason="ARM training not supported") +def test_train_model_causallm(disable_wip, set_cpu_device): + """Ensure that we can finetune a causal-lm model on some toy data for 1+ + steps & run inference.""" + train_kwargs = { + "base_model": HFAutoCausalLM.bootstrap( + model_name=CAUSAL_LM_MODEL, tokenizer_name=CAUSAL_LM_MODEL + ), + "num_epochs": 1, + "train_stream": caikit.core.data_model.DataStream.from_iterable( + [ + GenerationTrainRecord( + input="@foo what a cute dog!", output="no complaint" + ), + ] + ), + "torch_dtype": torch.float32, + } + model = TextGeneration.train(**train_kwargs) + assert isinstance(model.model, HFAutoCausalLM) + + # Ensure that we can get something out of it + pred = model.run("@bar what a cute cat!") + assert isinstance(pred, GeneratedTextResult) + + +############################## Error Cases ################################ + + +def test_zero_epoch_case(disable_wip): + """Test to ensure 0 epoch training request doesn't explode""" + train_kwargs = { + "base_model": HFAutoSeq2SeqLM.bootstrap( + model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL + ), + "num_epochs": 0, + "train_stream": 
caikit.core.data_model.DataStream.from_iterable( + [ + GenerationTrainRecord( + input="@foo what a cute dog!", output="no complaint" + ), + ] + ), + "torch_dtype": torch.float32, + } + model = TextGeneration.train(**train_kwargs) + assert isinstance(model.model, HFAutoSeq2SeqLM) diff --git a/tests/modules/text_generation/test_text_generation.py b/tests/modules/text_generation/test_text_generation_tgis.py similarity index 59% rename from tests/modules/text_generation/test_text_generation.py rename to tests/modules/text_generation/test_text_generation_tgis.py index fb99dfea..d066ce20 100644 --- a/tests/modules/text_generation/test_text_generation.py +++ b/tests/modules/text_generation/test_text_generation_tgis.py @@ -3,13 +3,21 @@ # Standard from unittest import mock import os +import platform import tempfile # Third Party import pytest +import torch + +# First Party +from caikit.interfaces.nlp.data_model import GeneratedTextResult +import caikit # Local -from caikit_nlp.modules.text_generation import TextGeneration +from caikit_nlp.data_model.generation import GenerationTrainRecord +from caikit_nlp.modules.text_generation import TextGeneration, TextGenerationTGIS +from caikit_nlp.resources.pretrained_model.hf_auto_seq2seq_lm import HFAutoSeq2SeqLM from tests.fixtures import ( CAUSAL_LM_MODEL, SEQ2SEQ_LM_MODEL, @@ -23,7 +31,9 @@ def test_bootstrap_and_run_causallm(): """Check if we can bootstrap and run causallm models""" - model = TextGeneration.bootstrap(CAUSAL_LM_MODEL, load_backend=StubTGISBackend()) + model = TextGenerationTGIS.bootstrap( + CAUSAL_LM_MODEL, load_backend=StubTGISBackend() + ) result = model.run(SAMPLE_TEXT, preserve_input_text=True) StubTGISClient.validate_unary_generate_response(result) @@ -32,7 +42,9 @@ def test_bootstrap_and_run_causallm(): def test_bootstrap_and_run_seq2seq(): """Check if we can bootstrap and run seq2seq models""" - model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL, load_backend=StubTGISBackend()) + model = TextGenerationTGIS.bootstrap( + SEQ2SEQ_LM_MODEL, load_backend=StubTGISBackend() + ) result = model.run(SAMPLE_TEXT, preserve_input_text=True) StubTGISClient.validate_unary_generate_response(result) @@ -55,7 +67,9 @@ def test_run_multi_response_errors(): def test_bootstrap_and_save_model(): """Check if we can bootstrap and save the model successfully""" - model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL) + model = TextGenerationTGIS.bootstrap( + SEQ2SEQ_LM_MODEL, load_backend=StubTGISBackend() + ) with tempfile.TemporaryDirectory() as model_dir: model.save(model_dir) @@ -64,27 +78,58 @@ def test_bootstrap_and_save_model(): def test_save_model_can_run(): """Check if the model we bootstrap and save is able to load and run successfully""" - model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL) + model = TextGenerationTGIS.bootstrap(SEQ2SEQ_LM_MODEL) + with tempfile.TemporaryDirectory() as model_dir: model.save(model_dir) del model - new_model = TextGeneration.load( + new_model = TextGenerationTGIS.load( model_dir, load_backend=StubTGISBackend(mock_remote=True) ) result = new_model.run(SAMPLE_TEXT, preserve_input_text=True) StubTGISClient.validate_unary_generate_response(result) +@pytest.mark.skipif(platform.processor() == "arm", reason="ARM training not supported") +def test_local_train_load_tgis(): + """Check if the model trained in local module is able to + be loaded in TGIS module / backend + """ + train_kwargs = { + "base_model": HFAutoSeq2SeqLM.bootstrap( + model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL + ), + "num_epochs": 1, + 
"train_stream": caikit.core.data_model.DataStream.from_iterable( + [ + GenerationTrainRecord( + input="@foo what a cute dog!", output="no complaint" + ) + ] + ), + "torch_dtype": torch.float32, + } + model = TextGeneration.train(**train_kwargs) + with tempfile.TemporaryDirectory() as model_dir: + model.save(model_dir) + new_model = TextGenerationTGIS.load( + model_dir, load_backend=StubTGISBackend(mock_remote=True) + ) + sample_text = "Hello stub" + generated_text = new_model.run(sample_text) + assert isinstance(generated_text, GeneratedTextResult) + + def test_remote_tgis_only_model(): """Make sure that a model can be created and used that will only work with a remote TGIS connection (i.e. it has no artifacts) """ model_name = "model-name" tgis_backend = StubTGISBackend(mock_remote=True) - model = TextGeneration(model_name, tgis_backend=tgis_backend) + model = TextGenerationTGIS(model_name, tgis_backend=tgis_backend) with tempfile.TemporaryDirectory() as model_dir: model.save(model_dir) - TextGeneration.load(model_dir, load_backend=tgis_backend) + TextGenerationTGIS.load(model_dir, load_backend=tgis_backend) ### Output streaming tests ############################################################## @@ -92,7 +137,9 @@ def test_remote_tgis_only_model(): def test_bootstrap_and_run_stream_out(): """Check if we can bootstrap and run_stream_out""" - model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL, load_backend=StubTGISBackend()) + model = TextGenerationTGIS.bootstrap( + SEQ2SEQ_LM_MODEL, load_backend=StubTGISBackend() + ) stream_result = model.run_stream_out(SAMPLE_TEXT) StubTGISClient.validate_stream_generate_response(stream_result) @@ -104,7 +151,7 @@ def test_run_stream_out_with_runtime_error(): with mock.patch.object(StubTGISClient, "GenerateStream") as mock_gen_stream: mock_gen_stream.side_effect = RuntimeError("An error!") - model = TextGeneration.bootstrap( + model = TextGenerationTGIS.bootstrap( SEQ2SEQ_LM_MODEL, load_backend=StubTGISBackend() ) with pytest.raises(RuntimeError): diff --git a/tests/modules/token_classification/test_filtered_span_classification.py b/tests/modules/token_classification/test_filtered_span_classification.py index 9d619b27..ce20c5cb 100644 --- a/tests/modules/token_classification/test_filtered_span_classification.py +++ b/tests/modules/token_classification/test_filtered_span_classification.py @@ -12,17 +12,15 @@ # First Party from caikit.core import data_model from caikit.core.modules import ModuleBase, module - -# Local -from caikit_nlp.data_model.classification import ( - TokenClassification, +from caikit.interfaces.nlp.data_model import ( TokenClassificationResult, + TokenClassificationResults, ) +from caikit.interfaces.nlp.tasks import TokenClassificationTask + +# Local from caikit_nlp.modules.text_classification import SequenceClassification -from caikit_nlp.modules.token_classification import ( - FilteredSpanClassification, - TokenClassificationTask, -) +from caikit_nlp.modules.token_classification import FilteredSpanClassification from caikit_nlp.modules.tokenization.regex_sentence_splitter import ( RegexSentenceSplitter, ) @@ -42,16 +40,16 @@ ) # Token classifications in document -FOX_CLASS = TokenClassification( +FOX_CLASS = TokenClassificationResult( start=16, end=19, word="fox", entity="animal", score=0.8 ) -DOG_CLASS = TokenClassification( +DOG_CLASS = TokenClassificationResult( start=40, end=43, word="dog", entity="animal", score=0.3 ) -LAND_CLASS = TokenClassification( +LAND_CLASS = TokenClassificationResult( start=22, end=26, word="land", 
entity="thing", score=0.7 ) -TOK_CLASSIFICATION_RESULT = TokenClassificationResult(results=[FOX_CLASS, DOG_CLASS]) +TOK_CLASSIFICATION_RESULT = TokenClassificationResults(results=[FOX_CLASS, DOG_CLASS]) # Modules that already returns token classification for tests @module( @@ -62,33 +60,33 @@ ) class FakeTokenClassificationModule(ModuleBase): # This returns results for the whole document - def run(self, text: str) -> TokenClassificationResult: + def run(self, text: str) -> TokenClassificationResults: return TOK_CLASSIFICATION_RESULT - def run_batch(self, texts: List[str]) -> List[TokenClassificationResult]: + def run_batch(self, texts: List[str]) -> List[TokenClassificationResults]: return [ TOK_CLASSIFICATION_RESULT, - TokenClassificationResult(results=[LAND_CLASS]), + TokenClassificationResults(results=[LAND_CLASS]), ] class StreamFakeTokenClassificationModule(FakeTokenClassificationModule): # Make module return results per sentence - def run(self, text: str) -> TokenClassificationResult: + def run(self, text: str) -> TokenClassificationResults: if "land" in text: - return TokenClassificationResult(results=[LAND_CLASS]) + return TokenClassificationResults(results=[LAND_CLASS]) else: return TOK_CLASSIFICATION_RESULT class EmptyResFakeTokenClassificationModule(FakeTokenClassificationModule): - def run(self, text: str) -> TokenClassificationResult: - return TokenClassificationResult(results=[]) + def run(self, text: str) -> TokenClassificationResults: + return TokenClassificationResults(results=[]) - def run_batch(self, texts: List[str]) -> List[TokenClassificationResult]: + def run_batch(self, texts: List[str]) -> List[TokenClassificationResults]: return [ - TokenClassificationResult(results=[]), - TokenClassificationResult(results=[]), + TokenClassificationResults(results=[]), + TokenClassificationResults(results=[]), ] @@ -108,9 +106,9 @@ def test_bootstrap_run(): default_threshold=0.5, ) token_classification_result = model.run(DOCUMENT) - assert isinstance(token_classification_result, TokenClassificationResult) + assert isinstance(token_classification_result, TokenClassificationResults) assert len(token_classification_result.results) == 2 # 2 results over 0.5 expected - assert isinstance(token_classification_result.results[0], TokenClassification) + assert isinstance(token_classification_result.results[0], TokenClassificationResult) first_result = token_classification_result.results[0] assert first_result.start == 0 assert first_result.end == 44 @@ -129,7 +127,7 @@ def test_bootstrap_run_with_threshold(): default_threshold=0.5, ) token_classification_result = model.run(DOCUMENT, threshold=0.0) - assert isinstance(token_classification_result, TokenClassificationResult) + assert isinstance(token_classification_result, TokenClassificationResults) assert ( len(token_classification_result.results) == 4 ) # 4 (all) results over 0.0 expected @@ -164,9 +162,9 @@ def test_bootstrap_run_with_token_classification(): default_threshold=0.5, ) token_classification_result = model.run(DOCUMENT) - assert isinstance(token_classification_result, TokenClassificationResult) + assert isinstance(token_classification_result, TokenClassificationResults) assert len(token_classification_result.results) == 2 # 2 results over 0.5 expected - assert isinstance(token_classification_result.results[0], TokenClassification) + assert isinstance(token_classification_result.results[0], TokenClassificationResult) first_result = token_classification_result.results[0] assert first_result.start == 16 assert first_result.end == 
19 @@ -185,7 +183,7 @@ def test_bootstrap_run_with_token_classification_no_results(): default_threshold=0.5, ) token_classification_result = model.run(DOCUMENT) - assert isinstance(token_classification_result, TokenClassificationResult) + assert isinstance(token_classification_result, TokenClassificationResults) assert len(token_classification_result.results) == 0 @@ -205,7 +203,7 @@ def test_save_load_and_run_model(): new_model = FilteredSpanClassification.load(model_dir) token_classification_result = new_model.run(DOCUMENT) - assert isinstance(token_classification_result, TokenClassificationResult) + assert isinstance(token_classification_result, TokenClassificationResults) assert ( len(token_classification_result.results) == 2 ) # 2 results over 0.5 expected @@ -231,7 +229,7 @@ def test_run_bidi_stream_model(): result_list = list(streaming_token_classification_result) first_result = result_list[0].results[0] - assert isinstance(first_result, TokenClassification) + assert isinstance(first_result, TokenClassificationResult) assert first_result.start == 0 assert first_result.end == 44 assert first_result.word == "The quick brown fox jumps over the lazy dog." @@ -263,7 +261,7 @@ def test_run_bidi_stream_with_token_classification(): result_list = list(streaming_token_classification_result) # Convert to list to more easily check outputs first_result = result_list[0].results[0] - assert isinstance(first_result, TokenClassification) + assert isinstance(first_result, TokenClassificationResult) assert first_result.start == 16 assert first_result.end == 19 assert first_result.word == "fox" @@ -326,7 +324,7 @@ def test_run_bidi_stream_chunk_stream_input(): result_list = list(streaming_token_classification_result) # Convert to list to more easily check outputs first_result = result_list[0].results[0] - assert isinstance(first_result, TokenClassification) + assert isinstance(first_result, TokenClassificationResult) assert first_result.start == 16 assert first_result.end == 19 assert first_result.word == "fox" @@ -366,7 +364,7 @@ def test_run_bidi_stream_with_multiple_spans_in_chunk(): result_list = list(streaming_token_classification_result) first_result = result_list[0].results[0] - assert isinstance(first_result, TokenClassification) + assert isinstance(first_result, TokenClassificationResult) assert first_result.start == 0 assert first_result.end == 44 assert first_result.word == "The quick brown fox jumps over the lazy dog." 
diff --git a/tests/modules/tokenization/test_regex_sentence_splitter.py b/tests/modules/tokenization/test_regex_sentence_splitter.py index 15585292..f8590c29 100644 --- a/tests/modules/tokenization/test_regex_sentence_splitter.py +++ b/tests/modules/tokenization/test_regex_sentence_splitter.py @@ -4,8 +4,10 @@ import os import tempfile +# First Party +from caikit.interfaces.nlp.data_model import TokenizationResults + # Local -from caikit_nlp.data_model.text import TokenizationResult from caikit_nlp.modules.tokenization.regex_sentence_splitter import ( RegexSentenceSplitter, ) @@ -23,7 +25,7 @@ def test_bootstrap_and_run(): """Check if we can bootstrap and run regex sentence splitter""" tokenization_result = SENTENCE_TOKENIZER.run(DOCUMENT) - assert isinstance(tokenization_result, TokenizationResult) + assert isinstance(tokenization_result, TokenizationResults) assert len(tokenization_result.results) == 2 @@ -35,5 +37,5 @@ def test_save_load_and_run_model(): new_splitter = RegexSentenceSplitter.load(model_dir) tokenization_result = new_splitter.run(DOCUMENT) - assert isinstance(tokenization_result, TokenizationResult) + assert isinstance(tokenization_result, TokenizationResults) assert len(tokenization_result.results) == 2 diff --git a/tests/resources/test_pretrained_model.py b/tests/resources/test_pretrained_model.py index 0b377e28..d7e5a748 100644 --- a/tests/resources/test_pretrained_model.py +++ b/tests/resources/test_pretrained_model.py @@ -128,6 +128,7 @@ def test_causal_lm_tok_output_correctness(models_cache_dir): max_source_length=100, max_target_length=100, verbalizer="{{input}}", + task_ids=0, ) input_tok = causal_lm.tokenizer.encode(sample.input) output_tok = causal_lm.tokenizer.encode(sample.output) @@ -170,6 +171,7 @@ def test_seq2seq_tokenize_func_contains_unwrapped_stream(models_cache_dir): max_source_length=100, max_target_length=100, verbalizer="{{input}}", + task_ids=0, ) tok_res = tok_func(GenerationTrainRecord(input="hello", output="world")) map_stream = SAMPLE_TRAINING_DATA.map(tok_func) @@ -195,6 +197,7 @@ def test_seq2seq_tok_output_correctness(models_cache_dir): max_source_length=20, max_target_length=20, verbalizer="{{input}}", + task_ids=0, ) input_tok = seq2seq.tokenizer.encode(sample.input) output_tok = seq2seq.tokenizer.encode(sample.output) diff --git a/tests/toolkit/test_task_specific_utils.py b/tests/toolkit/test_task_specific_utils.py index ab9d519c..4569d3e1 100644 --- a/tests/toolkit/test_task_specific_utils.py +++ b/tests/toolkit/test_task_specific_utils.py @@ -15,25 +15,31 @@ # Third Party import pytest +# First Party +from caikit.interfaces.nlp.data_model import ClassificationTrainRecord + # Local -from caikit_nlp import data_model as dm -from caikit_nlp.toolkit.task_specific_utils import convert_to_generation_record +from caikit_nlp.data_model import GenerationTrainRecord +from caikit_nlp.toolkit.task_specific_utils import ( + convert_to_generation_record, + get_sorted_unique_class_labels, +) def test_convert_classification_train_record_to_generation_record(): - classification_train_record = dm.ClassificationTrainRecord( + classification_train_record = ClassificationTrainRecord( text="foo bar", labels=["label1"] ) generated_train = convert_to_generation_record(classification_train_record) - assert isinstance(generated_train, dm.GenerationTrainRecord) + assert isinstance(generated_train, GenerationTrainRecord) assert generated_train.input == "foo bar" assert generated_train.output == "label1" def test_convert_generation_record_to_generation_record(): - 
generation_train_record = dm.GenerationTrainRecord(input="foo bar", output="label1") + generation_train_record = GenerationTrainRecord(input="foo bar", output="label1") generated_train = convert_to_generation_record(generation_train_record) - assert isinstance(generated_train, dm.GenerationTrainRecord) + assert isinstance(generated_train, GenerationTrainRecord) assert generated_train.input == generation_train_record.input assert generated_train.output == generation_train_record.output @@ -53,3 +59,17 @@ def test_convert_to_generation_record_gives_error_with_unsupported_type(): string_record = "test record" with pytest.raises(TypeError): convert_to_generation_record(string_record) + + +def test_get_sorted_unique_class_labels(): + # Sample train data + sample_data = [ + ClassificationTrainRecord(text="foo bar", labels=["label1"]), + ClassificationTrainRecord( + text="foo bar", labels=["label1", "label2", "label3"] + ), + ] + output_labels = ["label1", "label2", "label3"] + sample_stream = DataStream.from_iterable(sample_data) + class_labels = get_sorted_unique_class_labels(sample_stream) + assert output_labels == class_labels