@@ -2,7 +2,7 @@ use std::fmt::{self, Debug};
22
33use relay_common:: time:: UnixTimestamp ;
44use relay_log:: protocol:: value;
5- use relay_redis:: redis:: Script ;
5+ use relay_redis:: redis:: { self , FromRedisValue , Script } ;
66use relay_redis:: { AsyncRedisClient , RedisError , RedisScripts } ;
77use thiserror:: Error ;
88
@@ -346,13 +346,13 @@ impl<T: GlobalLimiter> RedisRateLimiter<T> {
346346 // client across await points, otherwise it might be held for too long, and we will run out
347347 // of connections.
348348 let mut connection = self . client . get_connection ( ) . await ?;
349- let rejections : Vec < bool > = invocation
349+ let results : ScriptResult = invocation
350350 . invoke_async ( & mut connection)
351351 . await
352352 . map_err ( RedisError :: Redis ) ?;
353353
354- for ( quota, is_rejected ) in tracked_quotas. iter ( ) . zip ( rejections ) {
355- if is_rejected {
354+ for ( quota, state ) in tracked_quotas. iter ( ) . zip ( results . 0 ) {
355+ if state . is_rejected {
356356 let retry_after = self . retry_after ( ( quota. expiry ( ) - timestamp) . as_secs ( ) ) ;
357357 rate_limits. add ( RateLimit :: from_quota ( quota, * item_scoping, retry_after) ) ;
358358 }
@@ -373,6 +373,51 @@ impl<T: GlobalLimiter> RedisRateLimiter<T> {
373373 }
374374}
375375
376+ /// The result returned from the rate limiting Redis script.
377+ #[ derive( Debug ) ]
378+ struct ScriptResult ( Vec < QuotaState > ) ;
379+
380+ impl FromRedisValue for ScriptResult {
381+ fn from_redis_value ( v : & redis:: Value ) -> redis:: RedisResult < Self > {
382+ let Some ( seq) = v. as_sequence ( ) else {
383+ return Err ( redis:: RedisError :: from ( (
384+ redis:: ErrorKind :: TypeError ,
385+ "Expected a sequence from the rate limiting script" ,
386+ format ! ( "{v:?}" ) ,
387+ ) ) ) ;
388+ } ;
389+
390+ let ( chunks, rem) = seq. as_chunks ( ) ;
391+ if !rem. is_empty ( ) {
392+ return Err ( redis:: RedisError :: from ( (
393+ redis:: ErrorKind :: TypeError ,
394+ "Expected an even number of values from the rate limiting script" ,
395+ format ! ( "{v:?}" ) ,
396+ ) ) ) ;
397+ }
398+
399+ let mut result = Vec :: with_capacity ( chunks. len ( ) ) ;
400+ for [ is_rejected, consumed] in chunks {
401+ result. push ( QuotaState {
402+ is_rejected : bool:: from_redis_value ( is_rejected) ?,
403+ consumed : i64:: from_redis_value ( consumed) ?,
404+ } ) ;
405+ }
406+
407+ Ok ( Self ( result) )
408+ }
409+ }
410+
411+ /// The state returned from the rate limiting script for a single quota.
412+ #[ derive( Debug ) ]
413+ struct QuotaState {
414+ /// Whether the quota rejects the request.
415+ is_rejected : bool ,
416+ /// How much of the quota has already been consumed, before adding the requested quantity.
417+ #[ expect( unused, reason = "not yet used" ) ]
418+ consumed : i64 ,
419+ }
420+
376421#[ cfg( test) ]
377422mod tests {
378423 use std:: time:: { SystemTime , UNIX_EPOCH } ;
@@ -1003,7 +1048,6 @@ mod tests {
10031048 }
10041049
10051050 #[ tokio:: test]
1006- #[ allow( clippy:: disallowed_names, clippy:: let_unit_value) ]
10071051 async fn test_is_rate_limited_script ( ) {
10081052 let now = SystemTime :: now ( )
10091053 . duration_since ( UNIX_EPOCH )
@@ -1024,6 +1068,17 @@ mod tests {
10241068
10251069 let script = RedisScripts :: load_is_rate_limited ( ) ;
10261070
1071+ macro_rules! assert_invocation {
1072+ ( $invocation: expr, $( $tt: tt) * ) => { {
1073+ let result = $invocation
1074+ . invoke_async:: <ScriptResult >( & mut conn)
1075+ . await
1076+ . unwrap( ) ;
1077+
1078+ insta:: assert_debug_snapshot!( result, $( $tt) * ) ;
1079+ } } ;
1080+ }
1081+
10271082 let mut invocation = script. prepare_invoke ( ) ;
10281083 invocation
10291084 . key ( & foo) // key
@@ -1039,42 +1094,124 @@ mod tests {
10391094 . arg ( 1 ) // quantity
10401095 . arg ( false ) ; // over accept once
10411096
1042- // The item should not be rate limited by either key.
1043- assert_eq ! (
1044- invocation
1045- . invoke_async:: <Vec <bool >>( & mut conn)
1046- . await
1047- . unwrap( ) ,
1048- vec![ false , false ]
1097+ // Craft a new invocation similar to the previous one, but it only applies to the quota
1098+ // with a higher limit (2).
1099+ let mut invocation2 = script. prepare_invoke ( ) ;
1100+ invocation2
1101+ . key ( & bar) // key
1102+ . key ( & r_bar) // refund key
1103+ . arg ( 2 ) // limit
1104+ . arg ( now + 120 ) // expiry
1105+ . arg ( 1 ) // quantity
1106+ . arg ( false ) ; // over accept once
1107+
1108+ // Current usage is 0. But current values are now incremented by 1 (quantity).
1109+ assert_invocation ! ( invocation, @r"
1110+ ScriptResult(
1111+ [
1112+ QuotaState {
1113+ is_rejected: false,
1114+ consumed: 0,
1115+ },
1116+ QuotaState {
1117+ is_rejected: false,
1118+ consumed: 0,
1119+ },
1120+ ],
1121+ )
1122+ "
10491123 ) ;
10501124
1051- // The item should be rate limited by the first key (1).
1052- assert_eq ! (
1053- invocation
1054- . invoke_async:: <Vec <bool >>( & mut conn)
1055- . await
1056- . unwrap( ) ,
1057- vec![ true , false ]
1125+ // The usage was incremented in the last invocation, this invocation fails the rate limit
1126+ // on the first quota. -> No changes are made to the counters, the next invocation still
1127+ // needs to be `[1, 1]`.
1128+ assert_invocation ! ( invocation, @r"
1129+ ScriptResult(
1130+ [
1131+ QuotaState {
1132+ is_rejected: true,
1133+ consumed: 1,
1134+ },
1135+ QuotaState {
1136+ is_rejected: false,
1137+ consumed: 1,
1138+ },
1139+ ],
1140+ )
1141+ "
10581142 ) ;
10591143
10601144 // The item should still be rate limited by the first key (1), but *not*
10611145 // rate limited by the second key (2) even though this is the third time
10621146 // we've checked the quotas. This ensures items that are rejected by a lower
10631147 // quota don't affect unrelated items that share a parent quota.
1064- assert_eq ! (
1065- invocation
1066- . invoke_async:: <Vec <bool >>( & mut conn)
1067- . await
1068- . unwrap( ) ,
1069- vec![ true , false ]
1148+ assert_invocation ! ( invocation, @r"
1149+ ScriptResult(
1150+ [
1151+ QuotaState {
1152+ is_rejected: true,
1153+ consumed: 1,
1154+ },
1155+ QuotaState {
1156+ is_rejected: false,
1157+ consumed: 1,
1158+ },
1159+ ],
1160+ )
1161+ "
1162+ ) ;
1163+
1164+ // Using the second invocation which only considers a quota with a higher limit, this
1165+ // should still yield the current value of `1` and the next invocation should yield `2`.
1166+ assert_invocation ! ( invocation2, @r"
1167+ ScriptResult(
1168+ [
1169+ QuotaState {
1170+ is_rejected: false,
1171+ consumed: 1,
1172+ },
1173+ ],
1174+ )
1175+ "
1176+ ) ;
1177+
1178+ // This now yields `2`. This is also the invocation at the limit, which means it should no
1179+ // longer increment the counter.
1180+ assert_invocation ! ( invocation2, @r"
1181+ ScriptResult(
1182+ [
1183+ QuotaState {
1184+ is_rejected: true,
1185+ consumed: 2,
1186+ },
1187+ ],
1188+ )
1189+ "
1190+ ) ;
1191+
1192+ // Check again with the original invocation, this now yields `[1, 2]`.
1193+ assert_invocation ! ( invocation, @r"
1194+ ScriptResult(
1195+ [
1196+ QuotaState {
1197+ is_rejected: true,
1198+ consumed: 1,
1199+ },
1200+ QuotaState {
1201+ is_rejected: true,
1202+ consumed: 2,
1203+ },
1204+ ],
1205+ )
1206+ "
10701207 ) ;
10711208
10721209 assert_eq ! ( conn. get:: <_, String >( & foo) . await . unwrap( ) , "1" ) ;
10731210 let ttl: u64 = conn. ttl ( & foo) . await . unwrap ( ) ;
10741211 assert ! ( ttl >= 59 ) ;
10751212 assert ! ( ttl <= 60 ) ;
10761213
1077- assert_eq ! ( conn. get:: <_, String >( & bar) . await . unwrap( ) , "1 " ) ;
1214+ assert_eq ! ( conn. get:: <_, String >( & bar) . await . unwrap( ) , "2 " ) ;
10781215 let ttl: u64 = conn. ttl ( & bar) . await . unwrap ( ) ;
10791216 assert ! ( ttl >= 119 ) ;
10801217 assert ! ( ttl <= 120 ) ;
@@ -1095,22 +1232,43 @@ mod tests {
10951232 . arg ( 1 ) // quantity
10961233 . arg ( false ) ;
10971234
1098- // increment
1099- assert_eq ! (
1100- invocation
1101- . invoke_async:: <Vec <bool >>( & mut conn)
1102- . await
1103- . unwrap( ) ,
1104- vec![ false ]
1235+ // increment, current quota is 0.
1236+ assert_invocation ! ( invocation, @r"
1237+ ScriptResult(
1238+ [
1239+ QuotaState {
1240+ is_rejected: false,
1241+ consumed: 0,
1242+ },
1243+ ],
1244+ )
1245+ "
11051246 ) ;
11061247
1107- // test that it's rate limited without refund
1108- assert_eq ! (
1109- invocation
1110- . invoke_async:: <Vec <bool >>( & mut conn)
1111- . await
1112- . unwrap( ) ,
1113- vec![ true ]
1248+ // test that it's rate limited without refund.
1249+ assert_invocation ! ( invocation, @r"
1250+ ScriptResult(
1251+ [
1252+ QuotaState {
1253+ is_rejected: true,
1254+ consumed: 1,
1255+ },
1256+ ],
1257+ )
1258+ "
1259+ ) ;
1260+
1261+ // Make sure, the counter wasn't incremented.
1262+ assert_invocation ! ( invocation, @r"
1263+ ScriptResult(
1264+ [
1265+ QuotaState {
1266+ is_rejected: true,
1267+ consumed: 1,
1268+ },
1269+ ],
1270+ )
1271+ "
11141272 ) ;
11151273
11161274 let mut invocation = script. prepare_invoke ( ) ;
@@ -1123,12 +1281,16 @@ mod tests {
11231281 . arg ( false ) ;
11241282
11251283 // test that refund key is used
1126- assert_eq ! (
1127- invocation
1128- . invoke_async:: <Vec <bool >>( & mut conn)
1129- . await
1130- . unwrap( ) ,
1131- vec![ false ]
1284+ assert_invocation ! ( invocation, @r"
1285+ ScriptResult(
1286+ [
1287+ QuotaState {
1288+ is_rejected: false,
1289+ consumed: -4,
1290+ },
1291+ ],
1292+ )
1293+ "
11321294 ) ;
11331295 }
11341296}
0 commit comments