Skip to content

Commit 6c68820

Browse files
committed
Return scores from storage providers
1 parent c986daf commit 6c68820

File tree

7 files changed

+67
-14
lines changed

7 files changed

+67
-14
lines changed

src/django_ai_core/contrib/index/storage/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@ class BaseStorageDocument(VirtualModel):
3737
document_key: str
3838
content: str
3939
metadata: dict[str, Any]
40+
score: float = 0
4041

4142
class Meta:
42-
fields = ["document_key", "content", "metadata"]
43+
fields = ["document_key", "content", "metadata", "score"]
4344
storage_provider: "StorageProvider"
4445

4546
def __str__(self):

src/django_ai_core/contrib/index/storage/pgvector/models.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
1+
from typing import Self, Sequence
2+
13
from django.db import models
2-
from pgvector.django import VectorField
4+
from pgvector.django import CosineDistance, VectorField
5+
6+
7+
class PgvectorEmbeddingQuerySet(models.QuerySet["BasePgVectorEmbedding"]):
    """QuerySet for pgvector embedding models that can expose vector distance."""

    def annotate_with_distance(
        self,
        query_vector: Sequence[float],
    ) -> Self:
        """Annotate each row with its cosine distance to ``query_vector``.

        The value is computed with ``CosineDistance`` against the ``vector``
        column and exposed on every row under the name ``distance``.
        """
        return self.annotate(distance=CosineDistance("vector", query_vector))
14+
15+
16+
class PgvectorEmbeddingManager(models.Manager.from_queryset(PgvectorEmbeddingQuerySet)):
    """Manager built from ``PgvectorEmbeddingQuerySet``; adds nothing of its own."""
318

419

520
class BasePgVectorEmbedding(models.Model):
@@ -12,6 +27,8 @@ class BasePgVectorEmbedding(models.Model):
1227
content = models.TextField()
1328
metadata = models.JSONField(default=dict)
1429

30+
objects = PgvectorEmbeddingManager()
31+
1532
class Meta:
1633
abstract = True
1734

src/django_ai_core/contrib/index/storage/pgvector/provider.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
from typing import TYPE_CHECKING, Generator, Type
2-
3-
from pgvector.django import CosineDistance
1+
from typing import TYPE_CHECKING, Generator, Type, cast
42

53
from ...schema import EmbeddedDocument
64
from ..base import BaseStorageDocument, BaseStorageQuerySet, StorageProvider
75

86
if TYPE_CHECKING:
9-
from .models import BasePgVectorEmbedding
7+
from .models import BasePgVectorEmbedding, PgvectorEmbeddingQuerySet
108

119

1210
class PgVectorQuerySet(BaseStorageQuerySet["PgVectorProvider"]):
@@ -18,6 +16,7 @@ def get_instance(self, val: "BasePgVectorEmbedding") -> BaseStorageDocument:
1816
document_key=val.document_key,
1917
content=val.content,
2018
metadata=val.metadata,
19+
score=1 - val.distance,
2120
)
2221

2322
def run_query(self) -> Generator[BaseStorageDocument, None, None]:
@@ -41,8 +40,9 @@ def run_query(self) -> Generator[BaseStorageDocument, None, None]:
4140
raise ValueError("Model class is required")
4241

4342
queryset = model.objects.filter(index_name=storage_provider.index_name)
43+
queryset = cast("PgvectorEmbeddingQuerySet", queryset)
4444

45-
queryset = queryset.order_by(CosineDistance("vector", embedding))
45+
queryset = queryset.annotate_with_distance(embedding).order_by("distance")
4646

4747
# Apply metadata filters if any
4848
for key, value in filter_map.items():

src/django_ai_core/contrib/index/storage/qdrant.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,29 @@
1+
import uuid
2+
13
from qdrant_client import QdrantClient
24
from qdrant_client.http import models as qdrant_models
35
from qdrant_client.models import Distance
46

57
from ..schema import EmbeddedDocument
68
from .base import BaseStorageDocument, BaseStorageQuerySet, StorageProvider
79

10+
# Payload key under which each document's original content is stored alongside its metadata
11+
CONTENT_METADATA_KEY = "dj_ai_core_content"
12+
813

914
class QdrantQuerySet(BaseStorageQuerySet["QdrantProvider"]):
1015
def get_instance(self, val) -> BaseStorageDocument:
    """Convert a Qdrant search hit into a storage document.

    Args:
        val: A point returned by the Qdrant search; its ``payload`` and
            ``score`` attributes are read.

    Returns:
        An instance of ``self.model`` built from the payload when a model
        is configured; otherwise the point's raw payload.

    Raises:
        KeyError: If a model is configured and the payload has no
            ``"document_key"`` entry.
    """
    if not self.model:
        return val.payload

    # Work on a shallow copy so that stripping the content key below does
    # not mutate the payload dict owned by the response object. A missing
    # (None) payload is treated as empty.
    metadata = dict(val.payload or {})
    # Points stored without the content key yield empty content instead of
    # raising KeyError.
    content = metadata.pop(CONTENT_METADATA_KEY, "")
    return self.model(
        document_key=metadata["document_key"],
        content=content,
        metadata=metadata,
        score=val.score,
    )
2027

2128
def run_query(self):
2229
if not self.storage_provider:
@@ -51,7 +58,7 @@ def run_query(self):
5158
query_filter=qdrant_models.Filter(must=filters),
5259
)
5360

54-
for vector in response["vectors"]:
61+
for vector in response:
5562
yield self.get_instance(vector)
5663

5764

@@ -89,7 +96,13 @@ def add(self, documents: list["EmbeddedDocument"]):
8996
for doc in documents:
9097
points.append(
9198
qdrant_models.PointStruct(
92-
id=doc.document_key, vector=doc.vector, payload=doc.metadata
99+
id=str(uuid.uuid4()),
100+
vector=doc.vector,
101+
payload={
102+
**doc.metadata,
103+
CONTENT_METADATA_KEY: doc.content,
104+
"document_key": doc.document_key,
105+
},
93106
)
94107
)
95108

@@ -99,7 +112,16 @@ def delete(self, document_keys: list[str]):
99112
"""Delete documents by their keys."""
100113
self.client.delete(
101114
collection_name=self.index_name,
102-
points_selector=qdrant_models.PointIdsList(points=document_keys),
115+
points_selector=qdrant_models.FilterSelector(
116+
filter=qdrant_models.Filter(
117+
must=[
118+
qdrant_models.FieldCondition(
119+
key="document_key",
120+
match=qdrant_models.MatchAny(any=document_keys),
121+
)
122+
]
123+
)
124+
),
103125
)
104126

105127
def clear(self):

src/django_ai_core/contrib/index/storage/s3vectors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def get_instance(self, val) -> BaseStorageDocument:
1616
document_key=val["key"],
1717
content=content,
1818
metadata=metadata,
19+
score=1 - val["distance"],
1920
)
2021
else:
2122
return val

tests/testapp/indexes.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@
2929
storage_provider = S3VectorProvider(
3030
bucket_name="prototyping-vector-bucket", dimensions=1536
3131
)
32+
elif storage_provider_setting == "qdrant":
33+
from django_ai_core.contrib.index.storage.qdrant import QdrantProvider
34+
35+
storage_provider = QdrantProvider(
36+
host=settings.AI_CORE_TESTAPP_QDRANT_HOST,
37+
port=6333,
38+
api_key=settings.AI_CORE_TESTAPP_QDRANT_API_KEY,
39+
dimensions=1536,
40+
)
3241
else:
3342
from django_ai_core.contrib.index.storage.inmemory import InMemoryProvider
3443

tests/testapp/settings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,6 @@
174174
"postgres://postgres:[email protected]:5432/postgres"
175175
),
176176
}
177+
178+
# Qdrant host for the test app, read from the environment (None when unset).
AI_CORE_TESTAPP_QDRANT_HOST = os.getenv("AI_CORE_TESTAPP_QDRANT_HOST")
179+
# Qdrant API key for the test app, read from the environment (None when unset).
AI_CORE_TESTAPP_QDRANT_API_KEY = os.getenv("AI_CORE_TESTAPP_QDRANT_API_KEY")

0 commit comments

Comments
 (0)