Skip to content

Commit c12cb82

Browse files
authored
Merge pull request #344 from markstur/embed_same_test
Embedding add a test that would have helped
2 parents d34987a + 55b07f8 commit c12cb82

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

tests/modules/text_embedding/test_embedding.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,3 +1054,43 @@ def test_encode_extensions(loaded_model):
10541054
BOOTSTRAPPED_MODEL._encode_with_retry(
10551055
"text here"
10561056
) # and no KeyError trying to remove non-existing keys
1057+
1058+
1059+
@pytest.mark.parametrize(
    "truncate_input_tokens",
    [0, 1, 2, 3, 4, 5, 99, 100, 101, 300, 510, 511, 512, 513, 1000, -1],
)
def test_same_same(loaded_model: EmbeddingModule, truncate_input_tokens):
    """Check that identical text embeds identically, one at a time or batched.

    The first two inputs are the same sentence and the third differs. Each
    text is embedded individually with ``run_embedding`` and then all three
    together with ``run_embeddings``, using the parametrized
    ``truncate_input_tokens`` value for both paths.
    """
    inputs = ["What is generative ai?", "What is generative ai?", "different"]

    # One call per text first, then a single batched call over all texts.
    separate_embeddings = []
    for text in inputs:
        separate_embeddings.append(
            loaded_model.run_embedding(
                text=text, truncate_input_tokens=truncate_input_tokens
            )
        )
    combined_embeddings = loaded_model.run_embeddings(
        texts=inputs, truncate_input_tokens=truncate_input_tokens
    )

    # Pull the raw vector values out of both result shapes.
    separate_vectors = []
    for embedding in separate_embeddings:
        separate_vectors.append(embedding.to_dict()["result"]["data"]["values"])

    combined_vectors = []
    for vector in combined_embeddings.to_dict()["results"]["vectors"]:
        combined_vectors.append(vector["data"]["values"])

    assert len(separate_vectors) == len(
        combined_vectors
    ), "expected the same number separate and combined embeddings"

    # Order check: the n-th individual embedding must match the n-th batched one.
    for separate, combined in zip(separate_vectors, combined_vectors):
        assert np.allclose(separate, combined)

    # Within each path the duplicated sentence must yield exactly equal
    # vectors, while the third (different) text must not.
    for vectors in (combined_vectors, separate_vectors):
        assert np.array_equal(vectors[0], vectors[1])
        assert not np.array_equal(vectors[1], vectors[2])

0 commit comments

Comments
 (0)