@@ -377,6 +377,11 @@ def from_json(
377377class _AccessoryKeyGenerator (KeyGenerator [KeyPair ]):
378378 """KeyPair generator. Uses the same algorithm internally as FindMy accessories do."""
379379
380+ # cache enough keys for an entire week.
381+ # every interval'th key is cached.
382+ _CACHE_SIZE = 4 * 24 * 7 # 4 keys / hour
383+ _CACHE_INTERVAL = 10
384+
380385 def __init__ (
381386 self ,
382387 master_key : bytes ,
@@ -401,8 +406,7 @@ def __init__(
401406 self ._initial_sk = initial_sk
402407 self ._key_type = key_type
403408
404- self ._cur_sk = initial_sk
405- self ._cur_sk_ind = 0
409+ self ._sk_cache : dict [int , bytes ] = {}
406410
407411 self ._iter_ind = 0
408412
@@ -426,14 +430,33 @@ def _get_sk(self, ind: int) -> bytes:
426430 msg = "The key index must be non-negative"
427431 raise ValueError (msg )
428432
429- if ind < self ._cur_sk_ind : # behind us; need to reset :(
430- self ._cur_sk = self ._initial_sk
431- self ._cur_sk_ind = 0
433+ # retrieve from cache
434+ cached_sk = self ._sk_cache .get (ind )
435+ if cached_sk is not None :
436+ return cached_sk
437+
438+ # not in cache: find largest cached index smaller than ind (if exists)
439+ start_ind : int = 0
440+ cur_sk : bytes = self ._initial_sk
441+ for cached_ind in self ._sk_cache :
442+ if cached_ind < ind and cached_ind > start_ind :
443+ start_ind = cached_ind
444+ cur_sk = self ._sk_cache [cached_ind ]
445+
446+ # compute and update cache
447+ for cur_ind in range (start_ind , ind ):
448+ cur_sk = crypto .x963_kdf (cur_sk , b"update" , 32 )
449+
450+ # insert intermediate result into cache and evict oldest entry if necessary
451+ if cur_ind % self ._CACHE_INTERVAL == 0 :
452+ self ._sk_cache [cur_ind ] = cur_sk
453+
454+ if len (self ._sk_cache ) > self ._CACHE_SIZE :
455+ # evict oldest entry
456+ oldest_ind = min (self ._sk_cache .keys ())
457+ del self ._sk_cache [oldest_ind ]
432458
433- for _ in range (self ._cur_sk_ind , ind ):
434- self ._cur_sk = crypto .x963_kdf (self ._cur_sk , b"update" , 32 )
435- self ._cur_sk_ind += 1
436- return self ._cur_sk
459+ return cur_sk
437460
438461 def _get_keypair (self , ind : int ) -> KeyPair :
439462 sk = self ._get_sk (ind )
@@ -449,14 +472,14 @@ def _generate_keys(self, start: int, stop: int | None) -> Generator[KeyPair, Non
449472
450473 @override
451474 def __iter__ (self ) -> KeyGenerator :
452- self ._iter_ind = - 1
453475 return self
454476
455477 @override
456478 def __next__ (self ) -> KeyPair :
479+ key = self ._get_keypair (self ._iter_ind )
457480 self ._iter_ind += 1
458481
459- return self . _get_keypair ( self . _iter_ind )
482+ return key
460483
461484 @overload
462485 def __getitem__ (self , val : int ) -> KeyPair : ...
0 commit comments