Skip to content

Commit 023bf2e

Browse files
committed
Implement rate limiting
1 parent 0457ad4 commit 023bf2e

File tree

6 files changed

+185
-4
lines changed

6 files changed

+185
-4
lines changed

example-encrypted-dns.toml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,22 @@ enabled = false
259259

260260
tokens = ["Y2oHkDJNHz", "G5zY3J5cHQtY", "C5zZWN1cmUuZG5z"]
261261

262+
263+
################################
264+
# Rate limiting #
265+
################################
266+
267+
[rate_limit]
268+
269+
# Enable per-client rate limiting
270+
271+
enabled = false
272+
273+
# Maximum queries per second per client IP
274+
275+
max_queries_per_second = 100
276+
277+
# Maximum number of client IPs to track (uses SIEVE cache for automatic eviction)
278+
279+
capacity = 10000
280+

src/config.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ pub struct AccessControlConfig {
1515
pub tokens: Vec<String>,
1616
}
1717

18+
#[derive(Serialize, Deserialize, Debug, Clone)]
19+
pub struct RateLimitConfig {
20+
pub enabled: bool,
21+
pub max_queries_per_second: u32,
22+
pub capacity: usize,
23+
}
24+
1825
#[derive(Serialize, Deserialize, Debug, Clone)]
1926
pub struct AnonymizedDNSConfig {
2027
pub enabled: bool,
@@ -111,6 +118,7 @@ pub struct Config {
111118
pub metrics: Option<MetricsConfig>,
112119
pub anonymized_dns: Option<AnonymizedDNSConfig>,
113120
pub access_control: Option<AccessControlConfig>,
121+
pub rate_limit: Option<RateLimitConfig>,
114122
}
115123

116124
impl Config {

src/globals.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use crate::blacklist::*;
1414
use crate::cache::*;
1515
use crate::crypto::*;
1616
use crate::dnscrypt_certs::*;
17+
use crate::rate_limiter::*;
1718
#[cfg(feature = "metrics")]
1819
use crate::varz::*;
1920

@@ -52,6 +53,8 @@ pub struct Globals {
5253
pub access_control_tokens: Option<Vec<String>>,
5354
pub client_ttl_holdon: u32,
5455
pub my_ip: Option<Vec<u8>>,
56+
#[educe(Debug(ignore))]
57+
pub rate_limiter: SharedRateLimiter,
5558
#[cfg(feature = "metrics")]
5659
#[educe(Debug(ignore))]
5760
pub varz: Varz,

src/main.rs

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ mod errors;
3131
mod globals;
3232
#[cfg(feature = "metrics")]
3333
mod metrics;
34+
mod rate_limiter;
3435
mod resolver;
3536
#[cfg(feature = "metrics")]
3637
mod varz;
@@ -61,6 +62,7 @@ use futures::join;
6162
use futures::prelude::*;
6263
use globals::*;
6364
use parking_lot::Mutex;
65+
use rate_limiter::*;
6466
use parking_lot::RwLock;
6567
#[cfg(target_family = "unix")]
6668
use privdrop::PrivDrop;
@@ -298,7 +300,7 @@ async fn tcp_acceptor(globals: Arc<Globals>, tcp_listener: TcpListener) -> Resul
298300
let concurrent_connections = globals.tcp_concurrent_connections.clone();
299301
let active_connections = globals.tcp_active_connections.clone();
300302
loop {
301-
let (mut client_connection, _client_addr) = match tcp_listener.accept().await {
303+
let (mut client_connection, client_addr) = match tcp_listener.accept().await {
302304
Ok(x) => x,
303305
Err(e) => {
304306
if e.kind() == std::io::ErrorKind::WouldBlock {
@@ -333,6 +335,16 @@ async fn tcp_acceptor(globals: Arc<Globals>, tcp_listener: TcpListener) -> Resul
333335
continue;
334336
}
335337
};
338+
339+
if let Some(ref rate_limiter) = globals.rate_limiter {
340+
if !rate_limiter.is_allowed(client_addr.ip()) {
341+
debug!("Rate limit exceeded for {}", client_addr.ip());
342+
#[cfg(feature = "metrics")]
343+
globals.varz.client_queries_rate_limited.inc();
344+
continue;
345+
}
346+
}
347+
336348
let (tx, rx) = oneshot::channel::<()>();
337349
let tx_channel_index = {
338350
let mut active_connections = active_connections.lock();
@@ -409,11 +421,18 @@ async fn udp_acceptor(
409421
if packet_len < DNS_HEADER_SIZE {
410422
continue;
411423
}
412-
// Create a socket clone only when we've checked the packet is valid
413-
// This helps avoid resource exhaustion
424+
425+
if let Some(ref rate_limiter) = globals.rate_limiter {
426+
if !rate_limiter.is_allowed(client_addr.ip()) {
427+
debug!("Rate limit exceeded for {}", client_addr.ip());
428+
#[cfg(feature = "metrics")]
429+
globals.varz.client_queries_rate_limited.inc();
430+
continue;
431+
}
432+
}
433+
414434
packet.truncate(packet_len);
415435

416-
// Only create a new socket if there's capacity for a new connection
417436
let active_count = concurrent_connections.load(Ordering::Relaxed);
418437
if active_count >= globals.udp_max_active_connections {
419438
debug!("UDP connection limit reached, dropping packet");
@@ -835,6 +854,19 @@ fn main() -> Result<(), Error> {
835854
}
836855
_ => None,
837856
};
857+
let rate_limiter: SharedRateLimiter = match config.rate_limit {
858+
Some(rate_limit) if rate_limit.enabled => {
859+
info!(
860+
"Rate limiting enabled: {} queries/second per client",
861+
rate_limit.max_queries_per_second
862+
);
863+
Some(Arc::new(RateLimiter::new(
864+
rate_limit.capacity,
865+
rate_limit.max_queries_per_second,
866+
)))
867+
}
868+
_ => None,
869+
};
838870
let runtime_handle = runtime.handle();
839871
let globals = Arc::new(Globals {
840872
runtime_handle: runtime_handle.clone(),
@@ -875,6 +907,7 @@ fn main() -> Result<(), Error> {
875907
access_control_tokens,
876908
my_ip: config.my_ip.map(|ip| ip.as_bytes().to_ascii_lowercase()),
877909
client_ttl_holdon: config.client_ttl_holdon.unwrap_or(60),
910+
rate_limiter,
878911
#[cfg(feature = "metrics")]
879912
varz: Varz::default(),
880913
});

src/rate_limiter.rs

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
use std::net::IpAddr;
2+
use std::sync::Arc;
3+
4+
use coarsetime::Instant;
5+
use parking_lot::Mutex;
6+
use sieve_cache::SieveCache;
7+
8+
const DEFAULT_CAPACITY: usize = 10000;
9+
const DEFAULT_MAX_QPS: u32 = 100;
10+
const MICROTOKENS_PER_TOKEN: u64 = 1_000_000;
11+
12+
struct ClientState {
13+
microtokens: u64,
14+
last_update: Instant,
15+
}
16+
17+
pub struct RateLimiter {
18+
clients: Mutex<SieveCache<IpAddr, ClientState>>,
19+
max_microtokens: u64,
20+
refill_rate: u64, // microtokens per microsecond (equals max_qps)
21+
}
22+
23+
impl RateLimiter {
24+
pub fn new(capacity: usize, max_queries_per_second: u32) -> Self {
25+
let capacity = if capacity == 0 {
26+
DEFAULT_CAPACITY
27+
} else {
28+
capacity
29+
};
30+
let max_qps = if max_queries_per_second == 0 {
31+
DEFAULT_MAX_QPS
32+
} else {
33+
max_queries_per_second
34+
};
35+
RateLimiter {
36+
clients: Mutex::new(
37+
SieveCache::new(capacity).expect("Failed to create rate limiter cache"),
38+
),
39+
max_microtokens: (max_qps as u64).saturating_mul(MICROTOKENS_PER_TOKEN),
40+
refill_rate: max_qps as u64,
41+
}
42+
}
43+
44+
pub fn is_allowed(&self, client_ip: IpAddr) -> bool {
45+
let now = Instant::now();
46+
let mut clients = self.clients.lock();
47+
48+
if let Some(state) = clients.get_mut(&client_ip) {
49+
let elapsed_us = now.as_ticks().saturating_sub(state.last_update.as_ticks());
50+
let refill = elapsed_us.saturating_mul(self.refill_rate);
51+
state.microtokens = state.microtokens.saturating_add(refill).min(self.max_microtokens);
52+
state.last_update = now;
53+
54+
if state.microtokens >= MICROTOKENS_PER_TOKEN {
55+
state.microtokens -= MICROTOKENS_PER_TOKEN;
56+
true
57+
} else {
58+
false
59+
}
60+
} else {
61+
let state = ClientState {
62+
microtokens: self.max_microtokens.saturating_sub(MICROTOKENS_PER_TOKEN),
63+
last_update: now,
64+
};
65+
clients.insert(client_ip, state);
66+
true
67+
}
68+
}
69+
}
70+
71+
pub type SharedRateLimiter = Option<Arc<RateLimiter>>;
72+
73+
#[cfg(test)]
74+
mod tests {
75+
use super::*;
76+
use std::net::{IpAddr, Ipv4Addr};
77+
78+
#[test]
79+
fn test_rate_limiter_allows_initial_requests() {
80+
let limiter = RateLimiter::new(100, 10);
81+
let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
82+
83+
assert!(limiter.is_allowed(ip));
84+
}
85+
86+
#[test]
87+
fn test_rate_limiter_exhausts_tokens() {
88+
let limiter = RateLimiter::new(100, 3);
89+
let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2));
90+
91+
let mut allowed = 0;
92+
for _ in 0..10 {
93+
if limiter.is_allowed(ip) {
94+
allowed += 1;
95+
}
96+
}
97+
assert!(allowed >= 3 && allowed <= 5);
98+
}
99+
100+
#[test]
101+
fn test_rate_limiter_separate_clients() {
102+
let limiter = RateLimiter::new(100, 100);
103+
let ip1 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
104+
let ip2 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2));
105+
106+
assert!(limiter.is_allowed(ip1));
107+
assert!(limiter.is_allowed(ip2));
108+
assert!(limiter.is_allowed(ip1));
109+
assert!(limiter.is_allowed(ip2));
110+
}
111+
}

src/varz.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pub struct Inner {
2020
pub client_queries_blocked: IntCounter,
2121
pub client_queries_resolved: IntCounter,
2222
pub client_queries_rcode_nxdomain: IntCounter,
23+
pub client_queries_rate_limited: IntCounter,
2324
pub inflight_udp_queries: IntGauge,
2425
pub inflight_tcp_queries: IntGauge,
2526
pub upstream_errors: IntCounter,
@@ -117,6 +118,12 @@ impl Inner {
117118
labels! {"handler" => "all",}
118119
))
119120
.unwrap(),
121+
client_queries_rate_limited: register_int_counter!(opts!(
122+
"encrypted_dns_client_queries_rate_limited",
123+
"Number of client queries dropped due to rate limiting",
124+
labels! {"handler" => "all",}
125+
))
126+
.unwrap(),
120127
inflight_udp_queries: register_int_gauge!(opts!(
121128
"encrypted_dns_inflight_udp_queries",
122129
"Number of UDP queries currently waiting for a response",

0 commit comments

Comments
 (0)