Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 155 additions & 1 deletion src/fair_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ where
{
type Item = (K, T);

#[allow(clippy::needless_continue)]
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let fair_queue = self.get_mut();
loop {
Expand Down Expand Up @@ -163,7 +164,61 @@ impl<S, K: Clone> FairQueue<S, K> {
mod test {
use crate::async_rt;
use crate::fair_queue::FairQueue;
use futures::{stream, StreamExt};
use futures::task::noop_waker;
use futures::{stream, Stream, StreamExt};
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::{Context, Poll};

/// Test stream that yields Pending for the first N polls, then emits messages FIFO
struct TestStream {
pending_polls: usize,
messages: VecDeque<&'static str>,
}

impl TestStream {
fn new(pending_polls: usize, messages: &[&'static str]) -> Self {
Self {
pending_polls,
messages: messages.iter().copied().collect(),
}
}

fn ready(messages: &[&'static str]) -> Self {
Self::new(0, messages)
}

fn pending_once(messages: &[&'static str]) -> Self {
Self::new(1, messages)
}
}

impl Stream for TestStream {
type Item = &'static str;

fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
if this.pending_polls > 0 {
this.pending_polls -= 1;
return Poll::Pending;
}
Poll::Ready(this.messages.pop_front())
}
}

enum UnifiedStream {
Test(TestStream),
}

impl Stream for UnifiedStream {
type Item = &'static str;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.get_mut() {
UnifiedStream::Test(stream) => Pin::new(stream).poll_next(cx),
}
}
}

#[async_rt::test]
async fn test_fair_queue_ready() {
Expand All @@ -184,6 +239,7 @@ mod test {
while let Some(i) = f_queue.next().await {
results.push(i);
}

assert_eq!(
results,
vec![
Expand Down Expand Up @@ -219,6 +275,7 @@ mod test {
while let Some(i) = f_queue.next().await {
results.push(i);
}

assert_eq!(
results,
vec![
Expand All @@ -231,4 +288,101 @@ mod test {
]
);
}

#[test]
fn test_fair_queue_continues_on_pending() {
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);

let mut fair_queue: FairQueue<UnifiedStream, &str> = FairQueue::new(false);
{
let inner = fair_queue.inner();
let mut lock = inner.lock();
lock.insert(
"slow",
UnifiedStream::Test(TestStream::pending_once(&["s1"])),
);
lock.insert(
"fast",
UnifiedStream::Test(TestStream::ready(&["f1", "f2"])),
);
}

// First poll should return fast stream (regression test: no starvation)
let result = Pin::new(&mut fair_queue).poll_next(&mut cx);
match result {
Poll::Ready(Some((key, value))) => {
assert_eq!(key, "fast");
assert_eq!(value, "f1");
}
other => panic!("Expected fast stream first, got: {:#?}", other),
}

// Second poll: fast stream still ready, slow stream pending
let result = Pin::new(&mut fair_queue).poll_next(&mut cx);
match result {
Poll::Ready(Some((key, value))) => {
assert_eq!(key, "fast");
assert_eq!(value, "f2");
}
other => panic!("Expected fast stream second, got: {:#?}", other),
}

// Third poll: With noop_waker, slow stream hasn't been re-polled
let result = Pin::new(&mut fair_queue).poll_next(&mut cx);
match result {
Poll::Pending => {} // Expected with noop_waker
other => panic!("Expected Pending, got: {:#?}", other),
}
}

#[test]
fn test_fair_queue_multiple_clients_fairness() {
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);

let mut fair_queue: FairQueue<UnifiedStream, &str> = FairQueue::new(false);
{
let inner = fair_queue.inner();
let mut lock = inner.lock();
lock.insert(
"fast",
UnifiedStream::Test(TestStream::ready(&["f1", "f2", "f3"])),
);
lock.insert("slow", UnifiedStream::Test(TestStream::new(2, &["s1"])));
lock.insert(
"mid",
UnifiedStream::Test(TestStream::new(1, &["m1", "m2"])),
);
}

let mut messages = Vec::new();
const MAX_ITERATIONS: usize = 20; // Upper bound - 3 for fast, 2 for mid, 1 for slow.

for _ in 0..MAX_ITERATIONS {
match Pin::new(&mut fair_queue).poll_next(&mut cx) {
Poll::Ready(Some((key, value))) => {
messages.push(format!("{}:{}", key, value));

let has_slow = messages.iter().any(|m| m.starts_with("slow:"));
let fast_count = messages.iter().filter(|m| m.starts_with("fast:")).count();
let mid_count = messages.iter().filter(|m| m.starts_with("mid:")).count();

if has_slow && fast_count == 3 && mid_count == 2 {
break;
}
}
Poll::Ready(None) => break,
Poll::Pending => continue,
}
}

// Ensure fast stream isn't starved by pending streams
let fast_messages = messages.iter().filter(|m| m.starts_with("fast:")).count();
assert!(
fast_messages >= 1,
"Fast stream was starved: {:?}",
messages
);
}
}
Loading