Skip to content

Commit 1db0119

Browse files
author
Jelmer Kuperus
committed
Add support for incremental models
1 parent ffea776 commit 1db0119

File tree

39 files changed

+289
-189
lines changed

39 files changed

+289
-189
lines changed

hnswlib-core-jdk17/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
<parent>
1010
<groupId>com.github.jelmerk</groupId>
1111
<artifactId>hnswlib-parent-pom</artifactId>
12-
<version>1.0.1</version>
12+
<version>1.1.0</version>
1313
<relativePath>..</relativePath>
1414
</parent>
1515

hnswlib-core/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
<parent>
1010
<groupId>com.github.jelmerk</groupId>
1111
<artifactId>hnswlib-parent-pom</artifactId>
12-
<version>1.0.1</version>
12+
<version>1.1.0</version>
1313
<relativePath>..</relativePath>
1414
</parent>
1515

hnswlib-core/src/main/java/com/github/jelmerk/knn/hnsw/HnswIndex.java

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import com.github.jelmerk.knn.*;
55
import com.github.jelmerk.knn.util.*;
6-
import com.github.jelmerk.knn.util.BitSet;
76
import org.eclipse.collections.api.list.primitive.MutableIntList;
87
import org.eclipse.collections.api.map.primitive.MutableObjectIntMap;
98
import org.eclipse.collections.api.map.primitive.MutableObjectLongMap;
@@ -67,9 +66,9 @@ public class HnswIndex<TId, TVector, TItem extends Item<TId, TVector>, TDistance
6766

6867
private ReentrantLock globalLock;
6968

70-
private GenericObjectPool<BitSet> visitedBitSetPool;
69+
private GenericObjectPool<ArrayBitSet> visitedBitSetPool;
7170

72-
private BitSet excludedCandidates;
71+
private ArrayBitSet excludedCandidates;
7372

7473
private ExactView exactView;
7574

@@ -103,7 +102,7 @@ private HnswIndex(RefinedBuilder<TId, TVector, TItem, TDistance> builder) {
103102
this.visitedBitSetPool = new GenericObjectPool<>(() -> new ArrayBitSet(this.maxItemCount),
104103
Runtime.getRuntime().availableProcessors());
105104

106-
this.excludedCandidates = new SynchronizedBitSet(new ArrayBitSet(this.maxItemCount));
105+
this.excludedCandidates = new ArrayBitSet(this.maxItemCount);
107106

108107
this.exactView = new ExactView();
109108
}
@@ -250,7 +249,9 @@ public boolean add(TItem item) {
250249

251250
int newNodeId = nodeCount++;
252251

253-
excludedCandidates.add(newNodeId);
252+
synchronized (excludedCandidates) {
253+
excludedCandidates.add(newNodeId);
254+
}
254255

255256
Node<TItem> newNode = new Node<>(newNodeId, connections, item, false);
256257

@@ -339,7 +340,9 @@ public boolean add(TItem item) {
339340
}
340341
}
341342
} finally {
342-
excludedCandidates.remove(newNodeId);
343+
synchronized (excludedCandidates) {
344+
excludedCandidates.remove(newNodeId);
345+
}
343346
}
344347
} finally {
345348
if (globalLock.isHeldByCurrentThread()) {
@@ -363,8 +366,10 @@ private void mutuallyConnectNewElement(Node<TItem> newNode,
363366
while (!topCandidates.isEmpty()) {
364367
int selectedNeighbourId = topCandidates.poll().nodeId;
365368

366-
if (excludedCandidates.contains(selectedNeighbourId)) {
367-
continue;
369+
synchronized (excludedCandidates) {
370+
if (excludedCandidates.contains(selectedNeighbourId)) {
371+
continue;
372+
}
368373
}
369374

370375
newItemConnections.add(selectedNeighbourId);
@@ -519,10 +524,34 @@ public List<SearchResult<TItem, TDistance>> findNearest(TVector destination, int
519524
return results;
520525
}
521526

527+
/**
528+
* Changes the maximum capacity of the index.
529+
* @param newSize new size of the index
530+
*/
531+
public void resize(int newSize) {
532+
globalLock.lock();
533+
try {
534+
this.maxItemCount = newSize;
535+
536+
this.visitedBitSetPool = new GenericObjectPool<>(() -> new ArrayBitSet(this.maxItemCount),
537+
Runtime.getRuntime().availableProcessors());
538+
539+
AtomicReferenceArray<Node<TItem>> newNodes = new AtomicReferenceArray<>(newSize);
540+
for(int i = 0; i < this.nodes.length(); i++) {
541+
newNodes.set(i, this.nodes.get(i));
542+
}
543+
this.nodes = newNodes;
544+
545+
this.excludedCandidates = new ArrayBitSet(this.excludedCandidates, newSize);
546+
} finally {
547+
globalLock.unlock();
548+
}
549+
}
550+
522551
private PriorityQueue<NodeIdAndDistance<TDistance>> searchBaseLayer(
523552
Node<TItem> entryPointNode, TVector destination, int k, int layer) {
524553

525-
BitSet visitedBitSet = visitedBitSetPool.borrowObject();
554+
ArrayBitSet visitedBitSet = visitedBitSetPool.borrowObject();
526555

527556
try {
528557
PriorityQueue<NodeIdAndDistance<TDistance>> topCandidates =
@@ -778,7 +807,7 @@ private void readObject(ObjectInputStream ois) throws IOException, ClassNotFound
778807
this.globalLock = new ReentrantLock();
779808
this.visitedBitSetPool = new GenericObjectPool<>(() -> new ArrayBitSet(this.maxItemCount),
780809
Runtime.getRuntime().availableProcessors());
781-
this.excludedCandidates = new SynchronizedBitSet(new ArrayBitSet(this.maxItemCount));
810+
this.excludedCandidates = new ArrayBitSet(this.maxItemCount);
782811
this.locks = new HashMap<>();
783812
this.exactView = new ExactView();
784813
}

hnswlib-core/src/main/java/com/github/jelmerk/knn/util/ArrayBitSet.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
/**
77
* Bitset.
88
*/
9-
public class ArrayBitSet implements BitSet, Serializable {
9+
public class ArrayBitSet implements Serializable {
1010

1111
private static final long serialVersionUID = 1L;
1212

@@ -21,10 +21,18 @@ public ArrayBitSet(int count) {
2121
this.buffer = new int[(count >> 5) + 1];
2222
}
2323

24+
/**
25+
* Initializes a new instance of the {@link ArrayBitSet} class. and copies the values
26+
* of another bitset
27+
* @param count The number of items in the set.
28+
*/
29+
public ArrayBitSet(ArrayBitSet other, int count) {
30+
this.buffer = Arrays.copyOf(other.buffer, (count >> 5) + 1);
31+
}
32+
2433
/**
2534
* {@inheritDoc}
2635
*/
27-
@Override
2836
public boolean contains(int id) {
2937
int carrier = this.buffer[id >> 5];
3038
return ((1 << (id & 31)) & carrier) != 0;
@@ -33,7 +41,6 @@ public boolean contains(int id) {
3341
/**
3442
* {@inheritDoc}
3543
*/
36-
@Override
3744
public void add(int id) {
3845
int mask = 1 << (id & 31);
3946
this.buffer[id >> 5] |= mask;
@@ -42,7 +49,6 @@ public void add(int id) {
4249
/**
4350
* {@inheritDoc}
4451
*/
45-
@Override
4652
public void remove(int id) {
4753
int mask = 1 << (id & 31);
4854
this.buffer[id >> 5] &= ~mask;
@@ -51,7 +57,6 @@ public void remove(int id) {
5157
/**
5258
* {@inheritDoc}
5359
*/
54-
@Override
5560
public void clear() {
5661
Arrays.fill(this.buffer, 0);
5762
}

hnswlib-core/src/main/java/com/github/jelmerk/knn/util/BitSet.java

Lines changed: 0 additions & 31 deletions
This file was deleted.

hnswlib-core/src/main/java/com/github/jelmerk/knn/util/SynchronizedBitSet.java

Lines changed: 0 additions & 51 deletions
This file was deleted.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package com.github.jelmerk.knn.util;
2+
3+
import org.junit.jupiter.api.Test;
4+
5+
import static org.hamcrest.CoreMatchers.is;
6+
import static org.hamcrest.MatcherAssert.assertThat;
7+
8+
public class ArrayBitSetTest {
9+
10+
@Test
11+
void copyConstructor() {
12+
ArrayBitSet bitset = new ArrayBitSet(100);
13+
bitset.add(50);
14+
ArrayBitSet other = new ArrayBitSet(bitset, 200);
15+
other.add(101);
16+
assertThat(other.contains(50), is(true));
17+
assertThat(other.contains(101), is(true));
18+
}
19+
}

hnswlib-examples/hnswlib-examples-java/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
<parent>
1010
<groupId>com.github.jelmerk</groupId>
1111
<artifactId>hnswlib-examples-parent-pom</artifactId>
12-
<version>1.0.1</version>
12+
<version>1.1.0</version>
1313
<relativePath>..</relativePath>
1414
</parent>
1515

hnswlib-examples/hnswlib-examples-pyspark-luigi/flow.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class Convert(SparkSubmitTask):
6969

7070
app = 'convert.py'
7171

72-
packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.0.1']
72+
packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0']
7373

7474
def requires(self):
7575
return Unzip()
@@ -109,7 +109,7 @@ class HnswIndex(SparkSubmitTask):
109109

110110
app = 'hnsw_index.py'
111111

112-
packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.0.1']
112+
packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0']
113113

114114
m = IntParameter(default=16)
115115

@@ -164,7 +164,7 @@ class Query(SparkSubmitTask):
164164

165165
executor_cores = IntParameter(default=2)
166166

167-
packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.0.1']
167+
packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0']
168168

169169
name = 'Query index'
170170

@@ -230,7 +230,7 @@ class BruteForceIndex(SparkSubmitTask):
230230

231231
app = 'bruteforce_index.py'
232232

233-
packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.0.1']
233+
packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0']
234234

235235
@property
236236
def conf(self):
@@ -291,7 +291,7 @@ class Evaluate(SparkSubmitTask):
291291

292292
app = 'evaluate_performance.py'
293293

294-
packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.0.1']
294+
packages = ['com.github.jelmerk:hnswlib-spark_2.4_2.11:1.1.0']
295295

296296
@property
297297
def conf(self):

hnswlib-examples/hnswlib-examples-scala/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
<parent>
1010
<groupId>com.github.jelmerk</groupId>
1111
<artifactId>hnswlib-examples-parent-pom</artifactId>
12-
<version>1.0.1</version>
12+
<version>1.1.0</version>
1313
<relativePath>..</relativePath>
1414
</parent>
1515

0 commit comments

Comments
 (0)