Skip to content

Commit 8a4126b

Browse files
authored
Merge pull request #196 from malmeloo/fix/improve-keygen-performance
feat: cache more intermediate accessory keys
2 parents 0eb3a2c + 997cf82 commit 8a4126b

File tree

1 file changed

+34
-11
lines changed

1 file changed

+34
-11
lines changed

findmy/accessory.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,11 @@ def from_json(
377377
class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
378378
"""KeyPair generator. Uses the same algorithm internally as FindMy accessories do."""
379379

380+
# cache enough keys for an entire week.
381+
# every interval'th key is cached.
382+
_CACHE_SIZE = 4 * 24 * 7 # 4 keys / hour
383+
_CACHE_INTERVAL = 10
384+
380385
def __init__(
381386
self,
382387
master_key: bytes,
@@ -401,8 +406,7 @@ def __init__(
401406
self._initial_sk = initial_sk
402407
self._key_type = key_type
403408

404-
self._cur_sk = initial_sk
405-
self._cur_sk_ind = 0
409+
self._sk_cache: dict[int, bytes] = {}
406410

407411
self._iter_ind = 0
408412

@@ -426,14 +430,33 @@ def _get_sk(self, ind: int) -> bytes:
426430
msg = "The key index must be non-negative"
427431
raise ValueError(msg)
428432

429-
if ind < self._cur_sk_ind: # behind us; need to reset :(
430-
self._cur_sk = self._initial_sk
431-
self._cur_sk_ind = 0
433+
# retrieve from cache
434+
cached_sk = self._sk_cache.get(ind)
435+
if cached_sk is not None:
436+
return cached_sk
437+
438+
# not in cache: find largest cached index smaller than ind (if exists)
439+
start_ind: int = 0
440+
cur_sk: bytes = self._initial_sk
441+
for cached_ind in self._sk_cache:
442+
if cached_ind < ind and cached_ind > start_ind:
443+
start_ind = cached_ind
444+
cur_sk = self._sk_cache[cached_ind]
445+
446+
# compute and update cache
447+
for cur_ind in range(start_ind, ind):
448+
cur_sk = crypto.x963_kdf(cur_sk, b"update", 32)
449+
450+
# insert intermediate result into cache and evict oldest entry if necessary
451+
if cur_ind % self._CACHE_INTERVAL == 0:
452+
self._sk_cache[cur_ind] = cur_sk
453+
454+
if len(self._sk_cache) > self._CACHE_SIZE:
455+
# evict oldest entry
456+
oldest_ind = min(self._sk_cache.keys())
457+
del self._sk_cache[oldest_ind]
432458

433-
for _ in range(self._cur_sk_ind, ind):
434-
self._cur_sk = crypto.x963_kdf(self._cur_sk, b"update", 32)
435-
self._cur_sk_ind += 1
436-
return self._cur_sk
459+
return cur_sk
437460

438461
def _get_keypair(self, ind: int) -> KeyPair:
439462
sk = self._get_sk(ind)
@@ -449,14 +472,14 @@ def _generate_keys(self, start: int, stop: int | None) -> Generator[KeyPair, Non
449472

450473
@override
451474
def __iter__(self) -> KeyGenerator:
452-
self._iter_ind = -1
453475
return self
454476

455477
@override
456478
def __next__(self) -> KeyPair:
479+
key = self._get_keypair(self._iter_ind)
457480
self._iter_ind += 1
458481

459-
return self._get_keypair(self._iter_ind)
482+
return key
460483

461484
@overload
462485
def __getitem__(self, val: int) -> KeyPair: ...

0 commit comments

Comments
 (0)