From ded3b55342a8fb0379dd3c05665a022fd3d31928 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Thu, 18 Apr 2024 11:46:22 -0300 Subject: [PATCH] Caikit embeddings examples + local run documentation Signed-off-by: Flavia Beo --- examples/embeddings/README.md | 284 +++++++ examples/embeddings/config.py | 11 + examples/embeddings/embedding_tasks.ipynb | 789 ++++++++++++++++++ examples/embeddings/embeddings.py | 68 ++ .../artifacts/example_model_dir.txt | 1 + .../models_/all-minilm-l6-v2/config.yml | 9 + examples/embeddings/requirements.txt | 2 + examples/embeddings/reranker.py | 114 +++ examples/embeddings/sentence_similarity.py | 62 ++ 9 files changed, 1340 insertions(+) create mode 100644 examples/embeddings/README.md create mode 100644 examples/embeddings/config.py create mode 100644 examples/embeddings/embedding_tasks.ipynb create mode 100644 examples/embeddings/embeddings.py create mode 100644 examples/embeddings/models_/all-minilm-l6-v2/artifacts/example_model_dir.txt create mode 100644 examples/embeddings/models_/all-minilm-l6-v2/config.yml create mode 100644 examples/embeddings/requirements.txt create mode 100644 examples/embeddings/reranker.py create mode 100644 examples/embeddings/sentence_similarity.py diff --git a/examples/embeddings/README.md b/examples/embeddings/README.md new file mode 100644 index 00000000..84bb8d76 --- /dev/null +++ b/examples/embeddings/README.md @@ -0,0 +1,284 @@ +# Set up and run locally caikit embeddings server + +#### Setting Up Virtual Environment using Python venv + +For [(venv)](https://docs.python.org/3/library/venv.html), make sure you are in an activated `venv` when running `python` in the example commands that follow. Use `deactivate` if you want to exit the `venv`. 
+ +```shell +python3 -m venv venv +source venv/bin/activate +``` + +### Models + +For this tutorial, you can download [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2), to do that you need to follow the steps to clone and use `git lfs` to get all the models files: + +```shell +# Make sure you have git-lfs installed (https://git-lfs.com) +git lfs install + +git clone https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 +``` + +To create a model configuration and artifacts, the best practice is to run the module's bootstrap() and save() methods. This will: + +* Load the model by name (from Hugging Face hub or repository) or from a local directory. The model is loaded using the sentence-transformers library. +* Save a config.yml which: + * Ties the model to the module (with a module_id GUID) + * Sets the artifacts_path to the default "artifacts" subdirectory + * Saves the model in the artifacts subdirectory +* Check an example of the folder structure at [models_](./models_/) + +> For the reranker service, models supported are bi-encoder and are the same used by the other embeddings tasks. + +This can be achieved by the following lines of code, using BGE as example model: + +```python +import os +os.environ['ALLOW_DOWNLOADS'] = "1" + +import caikit_nlp +model_name = "BAAI/bge-large-en-v1.5" +model = caikit_nlp.text_embedding.EmbeddingModule.bootstrap(model_name) +model.save(f"{model_name}-caikit") +``` + +To avoid overwriting your files, the save() will return an error if the output directory already exists. You may want to use a temporary name. After success, move the output directory to a `` directory under your local models dir. 
+ +### Environment variables + +These are the set of variables/params related to the environment which embeddings will be run: + +```bash +# use IPEX optimization +IPEX_OPTIMIZE: 'true' + +# use "xpu" for IPEX on GPU instead of IPEX on CPU +USE_XPU: 'false' + +# IPEX performs best with autocast using bfloat16 +BFLOAT16: '1' + +# use Mac chip +USE_MPS: 'false' + +# use Pytorch compile +PT2_COMPILE: 'false' +``` + +### Starting the Caikit Runtime + +Run caikit-runtime configured to use the caikit-nlp library. Set up the following environment variables: + +```bash +# set where the runtime should look for the models +export RUNTIME_LOCAL_MODELS_DIR=/models_ + +# load the models from the path set up at previous var +export RUNTIME_LAZY_LOAD_LOCAL_MODELS=true + +# set the runtime +export RUNTIME_LIBRARY='caikit_nlp' +``` + +In one terminal, start the runtime server: + +```bash +source venv/bin/activate +pip install -r requirements.txt +caikit-runtime +``` + +To run the library locally: + +```bash +pip install caikit-nlp@file:////caikit-nlp +python -m caikit.runtime +``` + +### Embedding retrieval example Python client + +In another terminal, run the example client code to retrieve embeddings. + +```shell +source venv/bin/activate +MODEL= python embeddings.py +``` + +The client code calls the model and queries for embeddings using 2 example sentences. + +You should see output similar to the following: + +```ShellSession +$ python embeddings.py +INPUT TEXTS: ['test first sentence', 'another test sentence'] +OUTPUT: { + { + "results": [ + [ + -0.17895537614822388, + 0.03200146183371544, + -0.030327674001455307, + ... + ], + [ + -0.17895537614822388, + 0.03200146183371544, + -0.030327674001455307, + ... + ] + ], + "producerId": { + "name": "EmbeddingModule", + "version": "0.0.1" + }, + "inputTokenCount": "9" + } +} +LENGTH: 2 x 384 +``` + +### Sentence similarity example Python client + +In another terminal, run the client code to infer sentence similarity. 
+ +```shell +source venv/bin/activate +MODEL= python sentence_similarity.py +``` + +The client code calls the model and queries sentence similarity using 1 source sentence and 2 other sentences (hardcoded in sentence_similarity.py). The result produces the cosine similarity score by comparing the source sentence with each of the other sentences. + +You should see output similar to the following: + +```ShellSession +$ python sentence_similarity.py +SOURCE SENTENCE: first sentence +SENTENCES: ['test first sentence', 'another test sentence'] +OUTPUT: { + "result": { + "scores": [ + 1.0000001192092896 + ] + }, + "producerId": { + "name": "EmbeddingModule", + "version": "0.0.1" + }, + "inputTokenCount": "9" +} +``` + +### Reranker example Python client + +In another terminal, run the client code to execute the reranker task using both gRPC and REST. + +```shell +source venv/bin/activate +MODEL= python reranker.py +``` + +You should see output similar to the following: + +```ShellSession +$ python reranker.py +====================== +TOP N: 3 +QUERIES: ['first sentence', 'any sentence'] +DOCUMENTS: [{'text': 'first sentence', 'title': 'first title'}, {'_text': 'another sentence', 'more': 'more attributes here'}, {'text': 'a doc with a nested metadata', 'meta': {'foo': 'bar', 'i': 999, 'f': 12.34}}] +====================== +RESPONSE from gRPC: +=== +QUERY: first sentence + score: 0.9999997019767761 index: 0 text: first sentence + score: 0.7350112199783325 index: 1 text: another sentence + score: 0.10398174077272415 index: 2 text: a doc with a nested metadata +=== +QUERY: any sentence + score: 0.6631797552108765 index: 0 text: first sentence + score: 0.6505964398384094 index: 1 text: another sentence + score: 0.11903437972068787 index: 2 text: a doc with a nested metadata +=================== +RESPONSE from HTTP: +{ + "results": [ + { + "query": "first sentence", + "scores": [ + { + "document": { + "text": "first sentence", + "title": "first title" + }, + "index": 0, + 
"score": 0.9999997019767761, + "text": "first sentence" + }, + { + "document": { + "_text": "another sentence", + "more": "more attributes here" + }, + "index": 1, + "score": 0.7350112199783325, + "text": "another sentence" + }, + { + "document": { + "text": "a doc with a nested metadata", + "meta": { + "foo": "bar", + "i": 999, + "f": 12.34 + } + }, + "index": 2, + "score": 0.10398174077272415, + "text": "a doc with a nested metadata" + } + ] + }, + { + "query": "any sentence", + "scores": [ + { + "document": { + "text": "first sentence", + "title": "first title" + }, + "index": 0, + "score": 0.6631797552108765, + "text": "first sentence" + }, + { + "document": { + "_text": "another sentence", + "more": "more attributes here" + }, + "index": 1, + "score": 0.6505964398384094, + "text": "another sentence" + }, + { + "document": { + "text": "a doc with a nested metadata", + "meta": { + "foo": "bar", + "i": 999, + "f": 12.34 + } + }, + "index": 2, + "score": 0.11903437972068787, + "text": "a doc with a nested metadata" + } + ] + } + ], + "producerId": { + "name": "EmbeddingModule", + "version": "0.0.1" + }, + "inputTokenCount": "9" +} +``` \ No newline at end of file diff --git a/examples/embeddings/config.py b/examples/embeddings/config.py new file mode 100644 index 00000000..c3375c76 --- /dev/null +++ b/examples/embeddings/config.py @@ -0,0 +1,11 @@ +# Standard +import os + +port = ( + os.getenv("CAIKIT_EMBEDDINGS_PORT") if os.getenv("CAIKIT_EMBEDDINGS_PORT") else 8085 +) +host = ( + os.getenv("CAIKIT_EMBEDDINGS_HOST") + if os.getenv("CAIKIT_EMBEDDINGS_HOST") + else "localhost" +) diff --git a/examples/embeddings/embedding_tasks.ipynb b/examples/embeddings/embedding_tasks.ipynb new file mode 100644 index 00000000..9a2ab0ae --- /dev/null +++ b/examples/embeddings/embedding_tasks.ipynb @@ -0,0 +1,789 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + 
"cell_type": "markdown", + "metadata": { + "id": "nqAU3Yh-rha5" + }, + "source": [ + "# Embeddings Taks Examples\n", + "\n", + "### Installation and Setup\n", + "\n", + "In this example Jupyter notebook, we'll be caikit-nlp to run the embeddings tasks available.\n", + "\n", + "### Installing `caikit` and `caikit-nlp`\n", + "\n", + "Next, we'll install specific versions of the caikit and caikit-nlp libraries, as the project is still in beta and breaking changes can happen." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZhZcVULDrTRz", + "outputId": "aa1b6a72-1f39-4d37-b8a4-04d2176a6478" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Running command git clone --filter=blob:none --quiet https://github.com/caikit/caikit /private/var/folders/5x/cztshy892cbf92p2fdgqlxhc0000gn/T/pip-req-build-1dc7bgl5\n", + "Successfully installed caikit-0.26.24.dev2+g2d02e00\n", + " Running command git clone --filter=blob:none --quiet https://github.com/caikit/caikit-nlp /private/var/folders/5x/cztshy892cbf92p2fdgqlxhc0000gn/T/pip-req-build-4zjb5wzg\n", + "Successfully installed caikit-nlp-0.4.11 grpcio-1.63.0 grpcio-health-checking-1.62.2 grpcio-reflection-1.62.2\n" + ] + } + ], + "source": [ + "!pip install git+https://github.com/caikit/caikit | tail -n 1\n", + "!pip install git+https://github.com/caikit/caikit-nlp | tail -n 1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import the EmbeddingsModule\n", + "Then we instantiate the caikit module that contains the embeddings taks we want to run." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " is still in the BETA phase and subject to change!\n",
+ "[2024-05-19 15:58:17,179] torch.distributed.elastic.multiprocessing.redirects: [WARNING] NOTE: Redirects are currently not supported in Windows or MacOs.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from caikit_nlp.modules.text_embedding import EmbeddingModule"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Loading the model\n",
+ "\n",
+ "When running the code from the module without the runtime, we need to load the model we want to use by passing the path to its directory that contains the bootstrapped `config.yml` and the `artifacts` folder.\n",
+ "\n",
+ "> Make sure you get the correct path from the model downloaded at the [Models](./README.md#models) section of the documentation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "embeddings_module = EmbeddingModule.load('caikit-nlp/examples/embeddings/models/all-minilm-l6-v2/')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "seq = \"Generate a summary of the context that answers the question. Explain the answer in multiple steps if possible. Answer style should match the context. Ideal Answer Length 2-3 sentences. To start a huddle: In Slack, open a channel or DM. Huddles work in Slack Connect, including Slack Connect DMs. On the bottom-left of your Slack sidebar, Open mini window icon. For more details, refer to Available Features. You can also start a huddle in a channel or DM. 
In the upper-right corner of your message window, click the headphones toggle.\" " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Code to retrieve embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "EmbeddingResults(results={\n", + " \"vectors\": [\n", + " {\n", + " \"data\": {\n", + " \"values\": [\n", + " -0.021235033869743347,\n", + " 0.005407060030847788,\n", + " -0.05031250789761543,\n", + " 0.02103376016020775,\n", + " -0.01952761597931385,\n", + " 0.04022081568837166,\n", + " 0.05267741531133652,\n", + " 0.07997128367424011,\n", + " -0.05027708038687706,\n", + " 0.0029696396086364985,\n", + " 0.02008694037795067,\n", + " -0.022737698629498482,\n", + " -0.03342238441109657,\n", + " -0.02924107201397419,\n", + " 0.07941676676273346,\n", + " 0.08066023886203766,\n", + " 0.012848478741943836,\n", + " -0.08107789605855942,\n", + " 0.013536344282329082,\n", + " 0.004463522229343653,\n", + " 0.07578443735837936,\n", + " -0.024243105202913284,\n", + " 0.009474135003983974,\n", + " 0.001784364227205515,\n", + " 0.011258895508944988,\n", + " -0.010532456450164318,\n", + " 0.004065200220793486,\n", + " 0.03179578483104706,\n", + " -0.005736625753343105,\n", + " -0.013960667885839939,\n", + " 0.08669246733188629,\n", + " -0.017347481101751328,\n", + " 0.13358548283576965,\n", + " 0.06395360827445984,\n", + " -0.015684617683291435,\n", + " 0.10341953486204147,\n", + " 0.006815536879003048,\n", + " 0.007421928457915783,\n", + " -0.029734350740909576,\n", + " -0.02422873117029667,\n", + " 0.026100093498826027,\n", + " -0.02857046201825142,\n", + " -0.012553386390209198,\n", + " 0.04234964773058891,\n", + " 0.04080246016383171,\n", + " -0.10286334902048111,\n", + " -0.06252577155828476,\n", + " -0.012835538946092129,\n", + " 0.019094401970505714,\n", + " 0.00535888085141778,\n", + " -0.0328969843685627,\n", + " -0.06895513832569122,\n", + " 
0.0221876073628664,\n", + " -0.046914491802453995,\n", + " -0.005367487668991089,\n", + " -0.007742544636130333,\n", + " 0.0013248560717329383,\n", + " 0.10055586695671082,\n", + " 0.023445509374141693,\n", + " -0.00013345986371859908,\n", + " 0.0203069057315588,\n", + " -0.029342805966734886,\n", + " -0.07989127933979034,\n", + " -0.01451518852263689,\n", + " 0.03955736756324768,\n", + " 0.004272318445146084,\n", + " 0.0273045152425766,\n", + " 0.0019088068511337042,\n", + " -0.08474176377058029,\n", + " 0.05107026547193527,\n", + " -0.0908379778265953,\n", + " -0.02937319315969944,\n", + " -0.0548190139234066,\n", + " -0.02358042635023594,\n", + " -0.02644682303071022,\n", + " -0.06952518969774246,\n", + " 0.07299194484949112,\n", + " 0.013201785273849964,\n", + " -0.014183051884174347,\n", + " 0.005473137833178043,\n", + " -0.078541100025177,\n", + " 0.009881802834570408,\n", + " -0.00950150191783905,\n", + " 0.08684621006250381,\n", + " 0.042046044021844864,\n", + " 0.08892218768596649,\n", + " 0.05986150726675987,\n", + " 0.012897167354822159,\n", + " -0.05137587711215019,\n", + " -0.021497627720236778,\n", + " -0.12070707231760025,\n", + " -0.05436254292726517,\n", + " -0.034722913056612015,\n", + " 0.007247344125062227,\n", + " 0.0010218544630333781,\n", + " 0.017282985150814056,\n", + " -0.011663887649774551,\n", + " 0.013926330022513866,\n", + " -0.00971834734082222,\n", + " 0.04917515441775322,\n", + " 0.07010623812675476,\n", + " -0.001550190499983728,\n", + " -0.005577955394983292,\n", + " -0.020590296015143394,\n", + " 0.004789913538843393,\n", + " -0.044980697333812714,\n", + " -0.007839592173695564,\n", + " -0.0012734548654407263,\n", + " -0.009631364606320858,\n", + " 0.04565273970365524,\n", + " -0.001930751372128725,\n", + " -0.005540340207517147,\n", + " 0.03726514056324959,\n", + " -0.0031732716597616673,\n", + " 0.06050105392932892,\n", + " -0.008025945164263248,\n", + " -0.02727639675140381,\n", + " 0.07360674440860748,\n", + " 
0.022478198632597923,\n", + " 0.0031944538932293653,\n", + " 0.08479750156402588,\n", + " 0.029232097789645195,\n", + " -0.026779362931847572,\n", + " -0.01567859761416912,\n", + " -0.0500980019569397,\n", + " -0.055112872272729874,\n", + " 0.012919710017740726,\n", + " 3.906626170579282e-33,\n", + " 0.06209629774093628,\n", + " -0.03467091917991638,\n", + " 0.06048177555203438,\n", + " 0.11794383823871613,\n", + " 0.02726908028125763,\n", + " -0.020709939301013947,\n", + " -0.04636675491929054,\n", + " 0.07772329449653625,\n", + " 0.024567145854234695,\n", + " 0.08308780938386917,\n", + " -0.017148278653621674,\n", + " 0.007116342894732952,\n", + " 0.031904421746730804,\n", + " -0.04350250959396362,\n", + " -0.029232650995254517,\n", + " -0.10858789831399918,\n", + " -0.08988913148641586,\n", + " -0.014144341461360455,\n", + " -0.05512676015496254,\n", + " 0.002025468507781625,\n", + " -0.0851856917142868,\n", + " -0.15837782621383667,\n", + " 0.013074349611997604,\n", + " 0.032473087310791016,\n", + " 0.08602656424045563,\n", + " -0.04621557891368866,\n", + " 0.09109809994697571,\n", + " -0.05884924158453941,\n", + " 0.013843545690178871,\n", + " 0.0109795443713665,\n", + " -0.05635201558470726,\n", + " -0.11180073767900467,\n", + " 0.011798023246228695,\n", + " -0.0021985198836773634,\n", + " -0.011004103347659111,\n", + " -0.02257312461733818,\n", + " -0.01864079013466835,\n", + " 0.03766394406557083,\n", + " -0.029037808999419212,\n", + " -0.007440229412168264,\n", + " -0.005429113283753395,\n", + " -0.03880683332681656,\n", + " 0.024294480681419373,\n", + " -0.04527240991592407,\n", + " -0.04138528183102608,\n", + " 0.05977773666381836,\n", + " 0.04934180527925491,\n", + " -0.0037967776879668236,\n", + " 0.06915342062711716,\n", + " 0.00983446929603815,\n", + " 0.032780930399894714,\n", + " -0.032358407974243164,\n", + " 0.013523639179766178,\n", + " -0.003942179027944803,\n", + " 0.04101790115237236,\n", + " -0.0776507705450058,\n", + " 
0.00023890483134891838,\n", + " 0.018628118559718132,\n", + " -0.0043238429352641106,\n", + " 0.05584253370761871,\n", + " 0.03746451810002327,\n", + " 0.010690205730497837,\n", + " 0.008108246140182018,\n", + " 0.018529025837779045,\n", + " -0.02777915820479393,\n", + " 0.03906390815973282,\n", + " -0.007802413310855627,\n", + " -0.022678233683109283,\n", + " 0.013777872547507286,\n", + " -0.027063053101301193,\n", + " -0.03839331492781639,\n", + " 0.004471300169825554,\n", + " -0.03480680286884308,\n", + " 0.07302685081958771,\n", + " -0.08336115628480911,\n", + " 0.05537783354520798,\n", + " -0.012942178174853325,\n", + " 0.03744035214185715,\n", + " 0.00681062089279294,\n", + " -0.04396096616983414,\n", + " 0.07073654979467392,\n", + " -0.0181974358856678,\n", + " -0.05265658721327782,\n", + " -0.013782795518636703,\n", + " 0.0006169207626953721,\n", + " -0.020498482510447502,\n", + " 0.0329594649374485,\n", + " -0.0809134840965271,\n", + " -0.0970836654305458,\n", + " -0.031381379812955856,\n", + " -0.14832927286624908,\n", + " 0.032261546701192856,\n", + " 0.047380272299051285,\n", + " -0.01687476970255375,\n", + " 0.08339183032512665,\n", + " -3.8514128335791746e-33,\n", + " 0.06588003784418106,\n", + " -0.01516030915081501,\n", + " 0.019150158390402794,\n", + " -0.05137329548597336,\n", + " 0.02077265828847885,\n", + " 0.14335110783576965,\n", + " 0.0834193229675293,\n", + " 0.01308922003954649,\n", + " -0.020682260394096375,\n", + " 0.1347867101430893,\n", + " 0.01445856224745512,\n", + " -0.009214757941663265,\n", + " 0.028443923220038414,\n", + " 0.03464259207248688,\n", + " -0.005576323717832565,\n", + " -0.021238137036561966,\n", + " -0.05885425582528114,\n", + " -0.025843992829322815,\n", + " -0.01501612551510334,\n", + " 0.12529130280017853,\n", + " -0.06814100593328476,\n", + " -0.030087590217590332,\n", + " 0.03912295028567314,\n", + " -0.049984242767095566,\n", + " -0.03947773203253746,\n", + " 0.06323368847370148,\n", + " 0.02809607796370983,\n", 
+ " 0.005717149004340172,\n", + " -0.009920122101902962,\n", + " -0.006467635277658701,\n", + " -0.02106902375817299,\n", + " -0.057317014783620834,\n", + " -0.0351557657122612,\n", + " -0.017118817195296288,\n", + " -0.10647141188383102,\n", + " 0.08869069069623947,\n", + " 0.005715256091207266,\n", + " 0.02274831011891365,\n", + " 0.0020732416305691004,\n", + " -0.04369700327515602,\n", + " 0.07731905579566956,\n", + " -0.08648794144392014,\n", + " 0.035607676953077316,\n", + " 0.00825593899935484,\n", + " -0.027381066232919693,\n", + " -0.032219048589468,\n", + " -0.05114440619945526,\n", + " -0.03060261532664299,\n", + " -0.1824631690979004,\n", + " 0.05379294604063034,\n", + " 0.024520311504602432,\n", + " 0.010771253146231174,\n", + " 0.01153299119323492,\n", + " -0.039151158183813095,\n", + " -0.09760303050279617,\n", + " 0.05836302787065506,\n", + " -0.03163729980587959,\n", + " -0.03566080704331398,\n", + " 0.0029704368207603693,\n", + " -0.05960788577795029,\n", + " 0.04115477576851845,\n", + " -0.008758237585425377,\n", + " -0.052527885884046555,\n", + " 0.032494425773620605,\n", + " 0.039141032844781876,\n", + " -0.030843526124954224,\n", + " -0.016273807734251022,\n", + " 0.008914263918995857,\n", + " -0.06448430567979813,\n", + " -0.07994430512189865,\n", + " -0.04104636237025261,\n", + " 0.048222556710243225,\n", + " 0.04914189875125885,\n", + " -0.01670641265809536,\n", + " 0.05294075980782509,\n", + " 0.07141680270433426,\n", + " -0.04140758514404297,\n", + " -0.15950848162174225,\n", + " 0.024030549451708794,\n", + " -0.03772575408220291,\n", + " 0.05671562999486923,\n", + " 0.040088627487421036,\n", + " -0.000479479058412835,\n", + " 0.09809868782758713,\n", + " -0.10765805095434189,\n", + " 0.03403845429420471,\n", + " 0.033328719437122345,\n", + " 0.06883085519075394,\n", + " 0.02439040318131447,\n", + " 0.004099009558558464,\n", + " -0.022474415600299835,\n", + " 0.04787461832165718,\n", + " 0.054905399680137634,\n", + " 
0.09703219681978226,\n", + " -0.0021149744279682636,\n", + " -5.0256957706551475e-08,\n", + " -0.08289197832345963,\n", + " -0.028697030618786812,\n", + " 0.07010963559150696,\n", + " -0.0010948119452223182,\n", + " -0.072735495865345,\n", + " 0.06303180009126663,\n", + " -0.0073213884606957436,\n", + " -0.058972302824258804,\n", + " 0.034158773720264435,\n", + " -0.013043764047324657,\n", + " -0.008664760738611221,\n", + " -0.004007968585938215,\n", + " -0.01517412532120943,\n", + " 0.07948628067970276,\n", + " -0.01072731427848339,\n", + " 0.0070273312740027905,\n", + " -0.0021736789494752884,\n", + " 0.0945856049656868,\n", + " -0.009432066231966019,\n", + " -0.07700041681528091,\n", + " 0.09006796777248383,\n", + " 0.03282124921679497,\n", + " 0.03300708904862404,\n", + " 0.13614419102668762,\n", + " 0.019504176452755928,\n", + " -0.03615563362836838,\n", + " -0.007411245256662369,\n", + " 0.11008388549089432,\n", + " -0.043609559535980225,\n", + " -0.013116138055920601,\n", + " 0.0017177090048789978,\n", + " 0.020423712208867073,\n", + " -0.09361141175031662,\n", + " -0.009915024973452091,\n", + " -0.06769847124814987,\n", + " 0.022375263273715973,\n", + " -0.05416032299399376,\n", + " -0.04927198216319084,\n", + " 0.06300338357686996,\n", + " 0.03314073383808136,\n", + " 0.01125667616724968,\n", + " -0.02000957913696766,\n", + " -0.03769528120756149,\n", + " -0.002417295938357711,\n", + " -0.05737944692373276,\n", + " 0.07864715158939362,\n", + " -0.006273307837545872,\n", + " -0.0807550698518753,\n", + " -0.046455398201942444,\n", + " -0.01982717402279377,\n", + " -0.05301357060670853,\n", + " -0.04026421904563904,\n", + " 0.07347984611988068,\n", + " -0.024325627833604813,\n", + " -0.046872396022081375,\n", + " 0.009954046458005905,\n", + " 0.009784957394003868,\n", + " 0.03125231713056564,\n", + " -0.03335946053266525,\n", + " 0.04947611689567566,\n", + " -0.06522347033023834,\n", + " 0.09552241861820221,\n", + " -0.02097415179014206,\n", + " 
0.03368087857961655\n", + " ]\n", + " }\n", + " }\n", + " ]\n", + "}, producer_id={\n", + " \"name\": \"EmbeddingModule\",\n", + " \"version\": \"0.0.1\"\n", + "}, input_token_count=124)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embeddings_response = embeddings_module.run_embeddings(texts=[seq], truncate_input_tokens=0)\n", + "embeddings_response" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Sentence Similarity task" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{\n", + " \"result\": {\n", + " \"scores\": [\n", + " 0.8578777313232422,\n", + " 0.5489557981491089\n", + " ]\n", + " },\n", + " \"producer_id\": {\n", + " \"name\": \"EmbeddingModule\",\n", + " \"version\": \"0.0.1\"\n", + " },\n", + " \"input_token_count\": 18\n", + "}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ss_response = embeddings_module.run_sentence_similarity(\n", + " source_sentence=\"This is an apple\", \n", + " sentences=[\"This is another apple\", \"This is a banana\"],\n", + " truncate_input_tokens=0)\n", + "ss_response" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reranker " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{\n", + " \"result\": {\n", + " \"query\": \"second\",\n", + " \"scores\": [\n", + " {\n", + " \"document\": {\n", + " \"text\": \"second sentence\",\n", + " \"additionalProp1\": 0,\n", + " \"additionalProp2\": 0,\n", + " \"additionalProp3\": 0\n", + " },\n", + " \"index\": 1,\n", + " \"score\": 0.5184812545776367,\n", + " \"text\": \"second sentence\"\n", + " },\n", + " {\n", + " \"document\": {\n", + " \"text\": \"first sentence\",\n", + " \"additionalProp1\": 0,\n", + " \"additionalProp2\": 0,\n", + " 
\"additionalProp3\": 0\n", + " },\n", + " \"index\": 0,\n", + " \"score\": 0.4005824625492096,\n", + " \"text\": \"first sentence\"\n", + " }\n", + " ]\n", + " },\n", + " \"producer_id\": {\n", + " \"name\": \"EmbeddingModule\",\n", + " \"version\": \"0.0.1\"\n", + " },\n", + " \"input_token_count\": 11\n", + "}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rr_results = embeddings_module.run_rerank_query(\n", + " documents= [\n", + " {\n", + " \"text\": \"first sentence\",\n", + " \"additionalProp1\": 0,\n", + " \"additionalProp2\": 0,\n", + " \"additionalProp3\": 0\n", + " },\n", + "\n", + " {\n", + " \"text\": \"second sentence\",\n", + " \"additionalProp1\": 0,\n", + " \"additionalProp2\": 0,\n", + " \"additionalProp3\": 0\n", + " }\n", + "\n", + " ],\n", + " query=\"second\")\n", + "rr_results" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{\n", + " \"results\": [\n", + " {\n", + " \"query\": \"banana\",\n", + " \"scores\": [\n", + " {\n", + " \"document\": {\n", + " \"text\": \"first sentence is this is a banana\",\n", + " \"additionalProp1\": 0,\n", + " \"additionalProp2\": 0,\n", + " \"additionalProp3\": 0\n", + " },\n", + " \"index\": 0,\n", + " \"score\": 0.723056972026825,\n", + " \"text\": \"first sentence is this is a banana\"\n", + " },\n", + " {\n", + " \"document\": {\n", + " \"text\": \"second sentence is this is an apple\",\n", + " \"additionalProp1\": 0,\n", + " \"additionalProp2\": 0,\n", + " \"additionalProp3\": 0\n", + " },\n", + " \"index\": 1,\n", + " \"score\": 0.28278833627700806,\n", + " \"text\": \"second sentence is this is an apple\"\n", + " }\n", + " ]\n", + " },\n", + " {\n", + " \"query\": \"is an apple\",\n", + " \"scores\": [\n", + " {\n", + " \"document\": {\n", + " \"text\": \"second sentence is this is an apple\",\n", + " \"additionalProp1\": 0,\n", + " \"additionalProp2\": 0,\n", 
+ " \"additionalProp3\": 0\n", + " },\n", + " \"index\": 1,\n", + " \"score\": 0.8389687538146973,\n", + " \"text\": \"second sentence is this is an apple\"\n", + " },\n", + " {\n", + " \"document\": {\n", + " \"text\": \"first sentence is this is a banana\",\n", + " \"additionalProp1\": 0,\n", + " \"additionalProp2\": 0,\n", + " \"additionalProp3\": 0\n", + " },\n", + " \"index\": 0,\n", + " \"score\": 0.3868182599544525,\n", + " \"text\": \"first sentence is this is a banana\"\n", + " }\n", + " ]\n", + " }\n", + " ],\n", + " \"producer_id\": {\n", + " \"name\": \"EmbeddingModule\",\n", + " \"version\": \"0.0.1\"\n", + " },\n", + " \"input_token_count\": 26\n", + "}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# more than 1 query\n", + "rr_results_multi = embeddings_module.run_rerank_queries(\n", + " documents=[\n", + " {\n", + " \"text\": \"first sentence is this is a banana\",\n", + " \"additionalProp1\": 0,\n", + " \"additionalProp2\": 0,\n", + " \"additionalProp3\": 0\n", + " },\n", + "\n", + " {\n", + " \"text\": \"second sentence is this is an apple\",\n", + " \"additionalProp1\": 0,\n", + " \"additionalProp2\": 0,\n", + " \"additionalProp3\": 0\n", + " }\n", + "\n", + " ],\n", + " queries=[\n", + " \"banana\",\n", + " \"is an apple\"\n", + " ])\n", + "rr_results_multi" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "authorship_tag": "ABX9TyM3vxACY/vOArAktjoeYZmS", + "gpuType": "T4", + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/embeddings/embeddings.py 
b/examples/embeddings/embeddings.py new file mode 100644 index 00000000..280cc24d --- /dev/null +++ b/examples/embeddings/embeddings.py @@ -0,0 +1,68 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from os import path +import os +import sys + +# Third Party +from config import host, port +import grpc + +# First Party +from caikit.runtime.service_factory import ServicePackageFactory +import caikit + +# Add the runtime/library to the path +sys.path.append(path.abspath(path.join(path.dirname(__file__), "../../"))) + +# Load configuration for Caikit runtime +CONFIG_PATH = path.realpath(path.join(path.dirname(__file__), "config.yml")) +caikit.configure(CONFIG_PATH) + +# NOTE: The model id needs to be a path to folder. 
+# NOTE: This is relative path to the models directory +MODEL_ID = os.getenv("MODEL", "mini") + +inference_service = ServicePackageFactory().get_service_package( + ServicePackageFactory.ServiceType.INFERENCE, +) + +channel = grpc.insecure_channel(f"{host}:{port}") +client_stub = inference_service.stub_class(channel) + +# Create request object + +texts = ["test first sentence", "another test sentence"] +request = inference_service.messages.EmbeddingTasksRequest(texts=texts) + +# Fetch predictions from server (infer) +response = client_stub.EmbeddingTasksPredict( + request, metadata=[("mm-model-id", MODEL_ID)] +) + +# Print response +print("INPUTS TEXTS: ", texts) +print("RESULTS: [") +for d in response.results.vectors: + woo = d.WhichOneof("data") # which one of data_s did we get? + print(getattr(d, woo).values) +print("]") +print( + "LENGTH: ", + len(response.results.vectors), + " x ", + len(getattr(response.results.vectors[0], woo).values), +) diff --git a/examples/embeddings/models_/all-minilm-l6-v2/artifacts/example_model_dir.txt b/examples/embeddings/models_/all-minilm-l6-v2/artifacts/example_model_dir.txt new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/examples/embeddings/models_/all-minilm-l6-v2/artifacts/example_model_dir.txt @@ -0,0 +1 @@ + diff --git a/examples/embeddings/models_/all-minilm-l6-v2/config.yml b/examples/embeddings/models_/all-minilm-l6-v2/config.yml new file mode 100644 index 00000000..559b31c8 --- /dev/null +++ b/examples/embeddings/models_/all-minilm-l6-v2/config.yml @@ -0,0 +1,9 @@ +artifacts_path: artifacts +caikit_nlp_version: 0.4.0 +created: "2024-01-23 13:01:58.022674" +module_class: caikit_nlp.modules.text_embedding.embedding.EmbeddingModule +module_id: eeb12558-b4fa-4f34-a9fd-3f5890e9cd3f +name: EmbeddingModule +saved: "2024-01-23 13:01:58.022692" +tracking_id: 88a3a3de-da3d-4201-9220-274ad1540a82 +version: 0.0.1 diff --git a/examples/embeddings/requirements.txt b/examples/embeddings/requirements.txt new file mode 
100644 index 00000000..8238a8b1 --- /dev/null +++ b/examples/embeddings/requirements.txt @@ -0,0 +1,2 @@ +caikit[runtime-grpc,runtime-http] +caikit-nlp \ No newline at end of file diff --git a/examples/embeddings/reranker.py b/examples/embeddings/reranker.py new file mode 100644 index 00000000..8a4c2824 --- /dev/null +++ b/examples/embeddings/reranker.py @@ -0,0 +1,114 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from os import getenv, path +import json +import sys + +# Third Party +from config import host, port +from google.protobuf.struct_pb2 import Struct +import grpc +import requests + +# First Party +from caikit.config.config import get_config +from caikit.runtime.service_factory import ServicePackageFactory +import caikit + +if __name__ == "__main__": + model_id = getenv("MODEL", "mini") + + # Add the runtime/library to the path + sys.path.append(path.abspath(path.join(path.dirname(__file__), "../../"))) + + # Load configuration for Caikit runtime + CONFIG_PATH = path.realpath(path.join(path.dirname(__file__), "config.yml")) + caikit.configure(CONFIG_PATH) + + inference_service = ServicePackageFactory().get_service_package( + ServicePackageFactory.ServiceType.INFERENCE, + ) + + top_n = 3 + queries = ["first sentence", "any sentence"] + documents = [ + {"text": "first sentence", "title": "first title"}, + {"_text": "another sentence", "more": "more attributes here"}, + { + "text": "a doc with a nested metadata", 
+ "meta": {"foo": "bar", "i": 999, "f": 12.34}, + }, + ] + + print("======================") + print("TOP N: ", top_n) + print("QUERIES: ", queries) + print("DOCUMENTS: ", documents) + print("======================") + + if get_config().runtime.grpc.enabled: + + # Setup the client + channel = grpc.insecure_channel(f"{host}:{port}") + client_stub = inference_service.stub_class(channel) + + # gRPC JSON documents go in Structs + docs = [] + for d in documents: + s = Struct() + s.update(d) + docs.append(s) + + request = inference_service.messages.RerankTasksRequest( + queries=queries, documents=docs, top_n=top_n + ) + response = client_stub.RerankTasksPredict( + request, metadata=[("mm-model-id", model_id)], timeout=1 + ) + + # print("RESPONSE:", response) + + # gRPC response + print("RESPONSE from gRPC:") + for i, r in enumerate(response.results): + print("===") + print("QUERY: ", r.query) + for s in r.scores: + print(f" score: {s.score} index: {s.index} text: {s.text}") + + if get_config().runtime.http.enabled: + # REST payload + payload = { + "inputs": { + "documents": documents, + "queries": queries, + }, + "parameters": { + "top_n": -1, + "return_documents": True, + "return_queries": True, + "return_text": True, + }, + "model_id": model_id, + } + response = requests.post( + f"http://{host}:8080/api/v1/task/rerank-tasks", + json=payload, + timeout=1, + ) + print("===================") + print("RESPONSE from HTTP:") + print(json.dumps(response.json(), indent=4)) diff --git a/examples/embeddings/sentence_similarity.py b/examples/embeddings/sentence_similarity.py new file mode 100644 index 00000000..4f8e8291 --- /dev/null +++ b/examples/embeddings/sentence_similarity.py @@ -0,0 +1,62 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from os import path +import os +import sys + +# Third Party +from config import host, port +import grpc + +# First Party +from caikit.runtime.service_factory import ServicePackageFactory +import caikit + +# Add the runtime/library to the path +sys.path.append(path.abspath(path.join(path.dirname(__file__), "../../"))) + +# Load configuration for Caikit runtime +CONFIG_PATH = path.realpath(path.join(path.dirname(__file__), "config.yml")) +caikit.configure(CONFIG_PATH) + +# NOTE: The model id needs to be a path to folder. +# NOTE: This is relative path to the models directory +MODEL_ID = os.getenv("MODEL", "mini") + +inference_service = ServicePackageFactory().get_service_package( + ServicePackageFactory.ServiceType.INFERENCE, +) + +channel = grpc.insecure_channel(f"{host}:{port}") +client_stub = inference_service.stub_class(channel) + +# Create request object + +source_sentence = "first sentence" +sentences = ["test first sentence", "another test sentence"] +request = inference_service.messages.SentenceSimilarityTaskRequest( + source_sentence=source_sentence, sentences=sentences +) + +# Fetch predictions from server (infer) +response = client_stub.SentenceSimilarityTaskPredict( + request, metadata=[("mm-model-id", MODEL_ID)] +) + +# Print response +print("SOURCE SENTENCE: ", source_sentence) +print("SENTENCES: ", sentences) +print("RESULTS: ", [v for v in response.result.scores])