
Commit 5e750cc

Caikit embeddings examples + local run documentation
Signed-off-by: Flavia Beo <[email protected]>
1 parent c12cb82 commit 5e750cc

File tree

6 files changed: +1316 −0 lines changed


examples/embeddings/README.md

Lines changed: 273 additions & 0 deletions
# Set up and run the caikit embeddings server locally

### Setting up a virtual environment using Python venv

For [venv](https://docs.python.org/3/library/venv.html), make sure you are in an activated `venv` when running `python` in the example commands that follow. Use `deactivate` if you want to exit the `venv`.

```shell
python3 -m venv venv
source venv/bin/activate
```

### Models

To create a model configuration and artifacts, the best practice is to run the module's `bootstrap()` and `save()` methods. This will:

* Load the model by name (from the Hugging Face Hub or a repository) or from a local directory. The model is loaded using the sentence-transformers library.
* Save a `config.yml` which:
  * Ties the model to the module (with a `module_id` GUID)
  * Sets the `artifacts_path` to the default "artifacts" subdirectory
* Save the model in the artifacts subdirectory

> The reranker service supports bi-encoder models; these are the same models used by the other embedding tasks.

This can be achieved with the following lines of code, using BGE as an example model:

```python
import os

# Allow the model to be downloaded from the Hugging Face Hub
os.environ['ALLOW_DOWNLOADS'] = "1"

import caikit_nlp

model_name = "BAAI/bge-large-en-v1.5"
model = caikit_nlp.text_embedding.EmbeddingModule.bootstrap(model_name)
model.save(f"{model_name}-caikit")
```

To avoid overwriting your files, `save()` will return an error if the output directory already exists. You may want to use a temporary name. After success, move the output directory to a `<model-id>` directory under your local models dir.

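For example, assuming your local models dir is the `/models` directory used in the runtime configuration below, the move for the BGE model saved above could look like this (paths are illustrative; pick any `<model-id>` name you like):

```shell
# Illustrative paths: the save() call above wrote to BAAI/bge-large-en-v1.5-caikit
mv BAAI/bge-large-en-v1.5-caikit /models/bge-large-en-v1.5-caikit
```
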
### Environment variables

These are the variables that configure the environment in which the embeddings will run:

```yaml
# use IPEX optimization
IPEX_OPTIMIZE: 'true'

# use "xpu" for IPEX on GPU instead of IPEX on CPU
USE_XPU: 'false'

# IPEX performs best with autocast using bfloat16
BFLOAT16: '1'

# use MPS (Mac chip)
USE_MPS: 'false'

# use PyTorch compile
PT2_COMPILE: 'false'
```

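These can be exported as environment variables before starting the runtime. A minimal sketch, enabling IPEX optimization with bfloat16 autocast (whether these values help is hardware-dependent, and the values shown are illustrative):

```shell
export IPEX_OPTIMIZE=true
export BFLOAT16=1
```
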
### Starting the Caikit Runtime

Run caikit-runtime configured to use the caikit-nlp library. Set up the following environment variables:

```bash
export RUNTIME_HTTP_ENABLED=true
export RUNTIME_LOCAL_MODELS_DIR=/models
export RUNTIME_LAZY_LOAD_LOCAL_MODELS=true
export RUNTIME_LIBRARY='caikit_nlp'
```

In one terminal, start the runtime server:

```bash
source venv/bin/activate
pip install -r requirements.txt
caikit-runtime
```

To run with a local clone of the caikit-nlp library instead:

```bash
pip install caikit-nlp@file:///<path-to-your-local-caikit_nlp-clone-repo>/caikit-nlp
python -m caikit.runtime
```

### Embedding retrieval example Python client

In another terminal, run the example client code to retrieve embeddings.

```shell
source venv/bin/activate
cd demo/client
MODEL=<model-id> python embeddings.py
```

The client code calls the model and queries for embeddings using two example sentences.

You should see output similar to the following:

```ShellSession
$ python embeddings.py
INPUT TEXTS: ['test first sentence', 'another test sentence']
OUTPUT: {
    {
        "results": [
            [
                -0.17895537614822388,
                0.03200146183371544,
                -0.030327674001455307,
                ...
            ],
            [
                -0.17895537614822388,
                0.03200146183371544,
                -0.030327674001455307,
                ...
            ]
        ],
        "producerId": {
            "name": "EmbeddingModule",
            "version": "0.0.1"
        },
        "inputTokenCount": "9"
    }
}
LENGTH: 2 x 384
```

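The batch `EmbeddingTasksRequest` used by `embeddings.py` also has a single-text counterpart. A minimal sketch, reusing the `inference_service`, `client_stub`, and `MODEL_ID` setup from `embeddings.py` below, and assuming the runtime also exposes an `EmbeddingTaskRequest`/`EmbeddingTaskPredict` pair for single inputs:

```python
# Assumption: a single-input EmbeddingTask is exposed alongside the batch task
request = inference_service.messages.EmbeddingTaskRequest(text="test first sentence")
response = client_stub.EmbeddingTaskPredict(
    request, metadata=[("mm-model-id", MODEL_ID)]
)
# One embedding vector comes back instead of a list of vectors
print(response.result)
```
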
### Sentence similarity example Python client

In another terminal, run the client code to infer sentence similarity.

```shell
source venv/bin/activate
cd demo/client
MODEL=<model-id> python sentence_similarity.py
```

The client code calls the model and queries sentence similarity using one source sentence and two other sentences (hardcoded in sentence_similarity.py). The result produces the cosine similarity score by comparing the source sentence with each of the other sentences.

You should see output similar to the following:

```ShellSession
$ python sentence_similarity.py
SOURCE SENTENCE: first sentence
SENTENCES: ['test first sentence', 'another test sentence']
OUTPUT: {
    "result": {
        "scores": [
            1.0000001192092896
        ]
    },
    "producerId": {
        "name": "EmbeddingModule",
        "version": "0.0.1"
    },
    "inputTokenCount": "9"
}
```

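The gRPC part of `sentence_similarity.py` presumably mirrors the embeddings client. A minimal sketch, reusing the setup from `embeddings.py` below; the `source_sentence` and `sentences` field names are taken from the printed output above, while the exact request/RPC names are assumptions following the embeddings naming pattern:

```python
# Assumption: request/RPC names follow the same pattern as the embeddings client
request = inference_service.messages.SentenceSimilarityTaskRequest(
    source_sentence="first sentence",
    sentences=["test first sentence", "another test sentence"],
)
response = client_stub.SentenceSimilarityTaskPredict(
    request, metadata=[("mm-model-id", MODEL_ID)]
)
# One cosine similarity score per input sentence
print("SCORES:", response.result.scores)
```
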
### Reranker example Python client

In another terminal, run the client code to execute the reranker task using both gRPC and REST.

```shell
source venv/bin/activate
cd demo/client
MODEL=<model-id> python reranker.py
```

You should see output similar to the following:

```ShellSession
$ python reranker.py
======================
TOP N: 3
QUERIES: ['first sentence', 'any sentence']
DOCUMENTS: [{'text': 'first sentence', 'title': 'first title'}, {'_text': 'another sentence', 'more': 'more attributes here'}, {'text': 'a doc with a nested metadata', 'meta': {'foo': 'bar', 'i': 999, 'f': 12.34}}]
======================
RESPONSE from gRPC:
===
QUERY: first sentence
score: 0.9999997019767761 index: 0 text: first sentence
score: 0.7350112199783325 index: 1 text: another sentence
score: 0.10398174077272415 index: 2 text: a doc with a nested metadata
===
QUERY: any sentence
score: 0.6631797552108765 index: 0 text: first sentence
score: 0.6505964398384094 index: 1 text: another sentence
score: 0.11903437972068787 index: 2 text: a doc with a nested metadata
===================
RESPONSE from HTTP:
{
    "results": [
        {
            "query": "first sentence",
            "scores": [
                {
                    "document": {
                        "text": "first sentence",
                        "title": "first title"
                    },
                    "index": 0,
                    "score": 0.9999997019767761,
                    "text": "first sentence"
                },
                {
                    "document": {
                        "_text": "another sentence",
                        "more": "more attributes here"
                    },
                    "index": 1,
                    "score": 0.7350112199783325,
                    "text": "another sentence"
                },
                {
                    "document": {
                        "text": "a doc with a nested metadata",
                        "meta": {
                            "foo": "bar",
                            "i": 999,
                            "f": 12.34
                        }
                    },
                    "index": 2,
                    "score": 0.10398174077272415,
                    "text": "a doc with a nested metadata"
                }
            ]
        },
        {
            "query": "any sentence",
            "scores": [
                {
                    "document": {
                        "text": "first sentence",
                        "title": "first title"
                    },
                    "index": 0,
                    "score": 0.6631797552108765,
                    "text": "first sentence"
                },
                {
                    "document": {
                        "_text": "another sentence",
                        "more": "more attributes here"
                    },
                    "index": 1,
                    "score": 0.6505964398384094,
                    "text": "another sentence"
                },
                {
                    "document": {
                        "text": "a doc with a nested metadata",
                        "meta": {
                            "foo": "bar",
                            "i": 999,
                            "f": 12.34
                        }
                    },
                    "index": 2,
                    "score": 0.11903437972068787,
                    "text": "a doc with a nested metadata"
                }
            ]
        }
    ],
    "producerId": {
        "name": "EmbeddingModule",
        "version": "0.0.1"
    },
    "inputTokenCount": "9"
}
```

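For the gRPC call in `reranker.py`, the documents are free-form JSON objects, which map to protobuf `Struct`s. A minimal sketch, again reusing the setup from `embeddings.py` below; the `queries`, `documents`, and `top_n` field names are taken from the printed output above, while the request/RPC names are assumptions following the same naming pattern as the other tasks:

```python
from google.protobuf import struct_pb2

# Documents are arbitrary JSON-like objects (see DOCUMENTS in the output above)
docs = []
for d in [
    {"text": "first sentence", "title": "first title"},
    {"_text": "another sentence", "more": "more attributes here"},
]:
    s = struct_pb2.Struct()
    s.update(d)  # Struct supports dict-style update
    docs.append(s)

# Assumption: request/RPC names follow the pattern of the other embedding tasks
request = inference_service.messages.RerankTasksRequest(
    queries=["first sentence", "any sentence"],
    documents=docs,
    top_n=3,
)
response = client_stub.RerankTasksPredict(
    request, metadata=[("mm-model-id", MODEL_ID)]
)
print(response)
```
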

examples/embeddings/embeddings.py

Lines changed: 67 additions & 0 deletions
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
import os
from os import path
import sys

# Third Party
import grpc

# Local
import caikit
from caikit.runtime.service_factory import ServicePackageFactory

# Add the runtime/library to the path
sys.path.append(path.abspath(path.join(path.dirname(__file__), "../../")))

# Load configuration for Caikit runtime
CONFIG_PATH = path.realpath(path.join(path.dirname(__file__), "config.yml"))
caikit.configure(CONFIG_PATH)

# NOTE: The model id needs to be a path to a folder.
# NOTE: This is a relative path under the models directory.
MODEL_ID = os.getenv("MODEL", "mini")

inference_service = ServicePackageFactory().get_service_package(
    ServicePackageFactory.ServiceType.INFERENCE,
)

# Connect to the gRPC runtime; host/port are configurable via the environment
port = os.getenv("CAIKIT_EMBEDDINGS_PORT", 8085)
host = os.getenv("CAIKIT_EMBEDDINGS_HOST", "localhost")
channel = grpc.insecure_channel(f"{host}:{port}")
client_stub = inference_service.stub_class(channel)

# Create request object
texts = ["test first sentence", "another test sentence"]
request = inference_service.messages.EmbeddingTasksRequest(texts=texts)

# Fetch predictions from server (infer)
response = client_stub.EmbeddingTasksPredict(
    request, metadata=[("mm-model-id", MODEL_ID)]
)

# Print response
print("INPUT TEXTS: ", texts)
print("RESULTS: [")
for d in response.results.vectors:
    woo = d.WhichOneof("data")  # which one of data_<float_type>s did we get?
    print(getattr(d, woo).values)
print("]")
print(
    "LENGTH: ",
    len(response.results.vectors),
    " x ",
    len(getattr(response.results.vectors[0], woo).values),
)
