diff --git a/src/fair_queue.rs b/src/fair_queue.rs index 7a91311..ec850b1 100644 --- a/src/fair_queue.rs +++ b/src/fair_queue.rs @@ -90,6 +90,7 @@ where { type Item = (K, T); + #[allow(clippy::needless_continue)] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let fair_queue = self.get_mut(); loop { @@ -163,7 +164,61 @@ impl FairQueue { mod test { use crate::async_rt; use crate::fair_queue::FairQueue; - use futures::{stream, StreamExt}; + use futures::task::noop_waker; + use futures::{stream, Stream, StreamExt}; + use std::collections::VecDeque; + use std::pin::Pin; + use std::task::{Context, Poll}; + + /// Test stream that yields Pending for the first N polls, then emits messages FIFO + struct TestStream { + pending_polls: usize, + messages: VecDeque<&'static str>, + } + + impl TestStream { + fn new(pending_polls: usize, messages: &[&'static str]) -> Self { + Self { + pending_polls, + messages: messages.iter().copied().collect(), + } + } + + fn ready(messages: &[&'static str]) -> Self { + Self::new(0, messages) + } + + fn pending_once(messages: &[&'static str]) -> Self { + Self::new(1, messages) + } + } + + impl Stream for TestStream { + type Item = &'static str; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + if this.pending_polls > 0 { + this.pending_polls -= 1; + return Poll::Pending; + } + Poll::Ready(this.messages.pop_front()) + } + } + + enum UnifiedStream { + Test(TestStream), + } + + impl Stream for UnifiedStream { + type Item = &'static str; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + UnifiedStream::Test(stream) => Pin::new(stream).poll_next(cx), + } + } + } #[async_rt::test] async fn test_fair_queue_ready() { @@ -184,6 +239,7 @@ mod test { while let Some(i) = f_queue.next().await { results.push(i); } + assert_eq!( results, vec![ @@ -219,6 +275,7 @@ mod test { while let Some(i) = f_queue.next().await { results.push(i); } + assert_eq!( results, vec![ @@ -231,4 +288,101 @@ mod test { ] ); } + + #[test] + fn test_fair_queue_continues_on_pending() { + let waker = noop_waker(); + let mut cx = Context::from_waker(&waker); + + let mut fair_queue: FairQueue = FairQueue::new(false); + { + let inner = fair_queue.inner(); + let mut lock = inner.lock(); + lock.insert( + "slow", + UnifiedStream::Test(TestStream::pending_once(&["s1"])), + ); + lock.insert( + "fast", + UnifiedStream::Test(TestStream::ready(&["f1", "f2"])), + ); + } + + // First poll should return fast stream (regression test: no starvation) + let result = Pin::new(&mut fair_queue).poll_next(&mut cx); + match result { + Poll::Ready(Some((key, value))) => { + assert_eq!(key, "fast"); + assert_eq!(value, "f1"); + } + other => panic!("Expected fast stream first, got: {:#?}", other), + } + + // Second poll: fast stream still ready, slow stream pending + let result = Pin::new(&mut fair_queue).poll_next(&mut cx); + match result { + Poll::Ready(Some((key, value))) => { + assert_eq!(key, "fast"); + assert_eq!(value, "f2"); + } + other => panic!("Expected fast stream second, got: {:#?}", other), + } + + // Third poll: With noop_waker, slow stream hasn't been re-polled + let result = Pin::new(&mut fair_queue).poll_next(&mut cx); + match result { + Poll::Pending => {} // Expected with noop_waker + other => panic!("Expected Pending, got: {:#?}", other), + } + } + + #[test] + fn test_fair_queue_multiple_clients_fairness() { + let waker = noop_waker(); + let mut cx = Context::from_waker(&waker); + + let mut fair_queue: FairQueue = FairQueue::new(false); + { + let inner = fair_queue.inner(); + let mut lock = inner.lock(); + lock.insert( + "fast", + UnifiedStream::Test(TestStream::ready(&["f1", "f2", "f3"])), + ); + lock.insert("slow", UnifiedStream::Test(TestStream::new(2, &["s1"]))); + lock.insert( + "mid", + UnifiedStream::Test(TestStream::new(1, &["m1", "m2"])), + ); + } + + let mut messages = Vec::new(); + const MAX_ITERATIONS: usize = 20; // Upper bound - 3 for fast, 2 for mid, 1 for slow. + + for _ in 0..MAX_ITERATIONS { + match Pin::new(&mut fair_queue).poll_next(&mut cx) { + Poll::Ready(Some((key, value))) => { + messages.push(format!("{}:{}", key, value)); + + let has_slow = messages.iter().any(|m| m.starts_with("slow:")); + let fast_count = messages.iter().filter(|m| m.starts_with("fast:")).count(); + let mid_count = messages.iter().filter(|m| m.starts_with("mid:")).count(); + + if has_slow && fast_count == 3 && mid_count == 2 { + break; + } + } + Poll::Ready(None) => break, + Poll::Pending => continue, + } + } + + // Ensure fast stream isn't starved by pending streams + let fast_messages = messages.iter().filter(|m| m.starts_with("fast:")).count(); + assert!( + fast_messages >= 1, + "Fast stream was starved: {:?}", + messages + ); + } }