6464 from .types .audit_log import AuditLog as AuditLogPayload
6565 from .types .guild import Guild as GuildPayload
6666 from .types .message import Message as MessagePayload
67+ from .types .monetization import Entitlement as EntitlementPayload
6768 from .types .threads import Thread as ThreadPayload
6869 from .types .user import PartialUser as PartialUserPayload
6970 from .user import User
@@ -988,11 +989,21 @@ def __init__(
988989 self .guild_id = guild_id
989990 self .exclude_ended = exclude_ended
990991
992+ self ._filter = None
993+
994+ if self .before and self .after :
995+ self ._retrieve_entitlements = self ._retrieve_entitlements_before_strategy
996+ self ._filter = lambda e : int (e ["id" ]) > self .after .id
997+ elif self .after :
998+ self ._retrieve_entitlements = self ._retrieve_entitlements_after_strategy
999+ else :
1000+ self ._retrieve_entitlements = self ._retrieve_entitlements_before_strategy
1001+
9911002 self .state = state
9921003 self .get_entitlements = state .http .list_entitlements
9931004 self .entitlements = asyncio .Queue ()
9941005
995- async def next (self ) -> BanEntry :
1006+ async def next (self ) -> Entitlement :
9961007 if self .entitlements .empty ():
9971008 await self .fill_entitlements ()
9981009
@@ -1014,30 +1025,57 @@ async def fill_entitlements(self):
10141025 if not self ._get_retrieve ():
10151026 return
10161027
1028+ data = await self ._retrieve_entitlements (self .retrieve )
1029+
1030+ if self ._filter :
1031+ data = list (filter (self ._filter , data ))
1032+
1033+ if len (data ) < 100 :
1034+ self .limit = 0 # terminate loop
1035+
1036+ for element in data :
1037+ await self .entitlements .put (Entitlement (data = element , state = self .state ))
1038+
1039+ async def _retrieve_entitlements (self , retrieve ) -> list [Entitlement ]:
1040+ """Retrieve entitlements and update next parameters."""
1041+ raise NotImplementedError
1042+
1043+ async def _retrieve_entitlements_before_strategy (
1044+ self , retrieve : int
1045+ ) -> list [EntitlementPayload ]:
1046+ """Retrieve entitlements using before parameter."""
10171047 before = self .before .id if self .before else None
1018- after = self .after .id if self .after else None
10191048 data = await self .get_entitlements (
10201049 self .state .application_id ,
10211050 before = before ,
1022- after = after ,
1023- limit = self .retrieve ,
1051+ limit = retrieve ,
10241052 user_id = self .user_id ,
10251053 guild_id = self .guild_id ,
10261054 sku_ids = self .sku_ids ,
10271055 exclude_ended = self .exclude_ended ,
10281056 )
1057+ if data :
1058+ if self .limit is not None :
1059+ self .limit -= retrieve
1060+ self .before = Object (id = int (data [- 1 ]["id" ]))
1061+ return data
10291062
1030- if not data :
1031- # no data, terminate
1032- return
1033-
1034- if self .limit :
1035- self .limit -= self .retrieve
1036-
1037- if len (data ) < 100 :
1038- self .limit = 0 # terminate loop
1039-
1040- self .after = Object (id = int (data [- 1 ]["id" ]))
1041-
1042- for element in reversed (data ):
1043- await self .entitlements .put (Entitlement (data = element , state = self .state ))
1063+ async def _retrieve_entitlements_after_strategy (
1064+ self , retrieve : int
1065+ ) -> list [EntitlementPayload ]:
1066+ """Retrieve entitlements using after parameter."""
1067+ after = self .after .id if self .after else None
1068+ data = await self .get_entitlements (
1069+ self .state .application_id ,
1070+ after = after ,
1071+ limit = retrieve ,
1072+ user_id = self .user_id ,
1073+ guild_id = self .guild_id ,
1074+ sku_ids = self .sku_ids ,
1075+ exclude_ended = self .exclude_ended ,
1076+ )
1077+ if data :
1078+ if self .limit is not None :
1079+ self .limit -= retrieve
1080+ self .after = Object (id = int (data [- 1 ]["id" ]))
1081+ return data
0 commit comments