@@ -31,6 +31,7 @@ mod errors;
3131mod globals;
3232#[ cfg( feature = "metrics" ) ]
3333mod metrics;
34+ mod rate_limiter;
3435mod resolver;
3536#[ cfg( feature = "metrics" ) ]
3637mod varz;
@@ -61,6 +62,7 @@ use futures::join;
6162use futures:: prelude:: * ;
6263use globals:: * ;
6364use parking_lot:: Mutex ;
65+ use rate_limiter:: * ;
6466use parking_lot:: RwLock ;
6567#[ cfg( target_family = "unix" ) ]
6668use privdrop:: PrivDrop ;
@@ -298,7 +300,7 @@ async fn tcp_acceptor(globals: Arc<Globals>, tcp_listener: TcpListener) -> Resul
298300 let concurrent_connections = globals. tcp_concurrent_connections . clone ( ) ;
299301 let active_connections = globals. tcp_active_connections . clone ( ) ;
300302 loop {
301- let ( mut client_connection, _client_addr ) = match tcp_listener. accept ( ) . await {
303+ let ( mut client_connection, client_addr ) = match tcp_listener. accept ( ) . await {
302304 Ok ( x) => x,
303305 Err ( e) => {
304306 if e. kind ( ) == std:: io:: ErrorKind :: WouldBlock {
@@ -333,6 +335,16 @@ async fn tcp_acceptor(globals: Arc<Globals>, tcp_listener: TcpListener) -> Resul
333335 continue ;
334336 }
335337 } ;
338+
339+ if let Some ( ref rate_limiter) = globals. rate_limiter {
340+ if !rate_limiter. is_allowed ( client_addr. ip ( ) ) {
341+ debug ! ( "Rate limit exceeded for {}" , client_addr. ip( ) ) ;
342+ #[ cfg( feature = "metrics" ) ]
343+ globals. varz . client_queries_rate_limited . inc ( ) ;
344+ continue ;
345+ }
346+ }
347+
336348 let ( tx, rx) = oneshot:: channel :: < ( ) > ( ) ;
337349 let tx_channel_index = {
338350 let mut active_connections = active_connections. lock ( ) ;
@@ -409,11 +421,18 @@ async fn udp_acceptor(
409421 if packet_len < DNS_HEADER_SIZE {
410422 continue ;
411423 }
412- // Create a socket clone only when we've checked the packet is valid
413- // This helps avoid resource exhaustion
424+
425+ if let Some ( ref rate_limiter) = globals. rate_limiter {
426+ if !rate_limiter. is_allowed ( client_addr. ip ( ) ) {
427+ debug ! ( "Rate limit exceeded for {}" , client_addr. ip( ) ) ;
428+ #[ cfg( feature = "metrics" ) ]
429+ globals. varz . client_queries_rate_limited . inc ( ) ;
430+ continue ;
431+ }
432+ }
433+
414434 packet. truncate ( packet_len) ;
415435
416- // Only create a new socket if there's capacity for a new connection
417436 let active_count = concurrent_connections. load ( Ordering :: Relaxed ) ;
418437 if active_count >= globals. udp_max_active_connections {
419438 debug ! ( "UDP connection limit reached, dropping packet" ) ;
@@ -835,6 +854,19 @@ fn main() -> Result<(), Error> {
835854 }
836855 _ => None ,
837856 } ;
857+ let rate_limiter: SharedRateLimiter = match config. rate_limit {
858+ Some ( rate_limit) if rate_limit. enabled => {
859+ info ! (
860+ "Rate limiting enabled: {} queries/second per client" ,
861+ rate_limit. max_queries_per_second
862+ ) ;
863+ Some ( Arc :: new ( RateLimiter :: new (
864+ rate_limit. capacity ,
865+ rate_limit. max_queries_per_second ,
866+ ) ) )
867+ }
868+ _ => None ,
869+ } ;
838870 let runtime_handle = runtime. handle ( ) ;
839871 let globals = Arc :: new ( Globals {
840872 runtime_handle : runtime_handle. clone ( ) ,
@@ -875,6 +907,7 @@ fn main() -> Result<(), Error> {
875907 access_control_tokens,
876908 my_ip : config. my_ip . map ( |ip| ip. as_bytes ( ) . to_ascii_lowercase ( ) ) ,
877909 client_ttl_holdon : config. client_ttl_holdon . unwrap_or ( 60 ) ,
910+ rate_limiter,
878911 #[ cfg( feature = "metrics" ) ]
879912 varz : Varz :: default ( ) ,
880913 } ) ;
0 commit comments