diff --git a/Cargo.toml b/Cargo.toml index adf10f991..80efaf12c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ all-features = true async-compression = { version = "0.3.7", features = ["brotli", "deflate", "gzip", "tokio"], optional = true } bytes = "1.0" futures-util = { version = "0.3", default-features = false, features = ["sink"] } +futures-channel = { version = "0.3.17", features = ["sink"]} headers = "0.3" http = "0.2" hyper = { version = "0.14", features = ["stream", "server", "http1", "tcp", "client"] } diff --git a/src/test.rs b/src/test.rs index 06262d275..a656aca43 100644 --- a/src/test.rs +++ b/src/test.rs @@ -89,10 +89,14 @@ use std::net::SocketAddr; #[cfg(feature = "websocket")] use std::pin::Pin; #[cfg(feature = "websocket")] +use std::task::Context; +#[cfg(feature = "websocket")] use std::task::{self, Poll}; use bytes::Bytes; #[cfg(feature = "websocket")] +use futures_channel::mpsc; +#[cfg(feature = "websocket")] use futures_util::StreamExt; use futures_util::{future, FutureExt, TryFutureExt}; use http::{ @@ -102,15 +106,17 @@ use http::{ use serde::Serialize; use serde_json; #[cfg(feature = "websocket")] -use tokio::sync::{mpsc, oneshot}; -#[cfg(feature = "websocket")] -use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio::sync::oneshot; use crate::filter::Filter; +#[cfg(feature = "websocket")] +use crate::filters::ws::Message; use crate::reject::IsReject; use crate::reply::Reply; use crate::route::{self, Route}; use crate::Request; +#[cfg(feature = "websocket")] +use crate::{Sink, Stream}; use self::inner::OneOrTuple; @@ -484,9 +490,8 @@ impl WsBuilder { F::Error: IsReject + Send, { let (upgraded_tx, upgraded_rx) = oneshot::channel(); - let (wr_tx, wr_rx) = mpsc::unbounded_channel(); - let wr_rx = UnboundedReceiverStream::new(wr_rx); - let (rd_tx, rd_rx) = mpsc::unbounded_channel(); + let (wr_tx, wr_rx) = mpsc::unbounded(); + let (rd_tx, rd_rx) = mpsc::unbounded(); tokio::spawn(async move { use tokio_tungstenite::tungstenite::protocol; @@ -546,7 +551,7 @@ impl WsBuilder { Ok(m) => future::ready(!m.is_close()), }) .for_each(move |item| { - rd_tx.send(item).expect("ws receive error"); + rd_tx.unbounded_send(item).expect("ws receive error"); future::ready(()) }); @@ -573,13 +578,13 @@ impl WsClient { /// Send a websocket message to the server. pub async fn send(&mut self, msg: crate::ws::Message) { - self.tx.send(msg).unwrap(); + self.tx.unbounded_send(msg).unwrap(); } /// Receive a websocket message from the server. pub async fn recv(&mut self) -> Result { self.rx - .recv() + .next() .await .map(|result| result.map_err(WsError::new)) .unwrap_or_else(|| { @@ -591,7 +596,7 @@ impl WsClient { /// Assert the server has closed the connection. pub async fn recv_closed(&mut self) -> Result<(), WsError> { self.rx - .recv() + .next() .await .map(|result| match result { Ok(msg) => Err(WsError::new(format!("received message: {:?}", msg))), @@ -602,6 +607,11 @@ impl WsClient { Ok(()) }) } + + fn pinned_tx(self: Pin<&mut Self>) -> Pin<&mut mpsc::UnboundedSender> { + let this = Pin::into_inner(self); + Pin::new(&mut this.tx) + } } #[cfg(feature = "websocket")] @@ -611,6 +621,51 @@ impl fmt::Debug for WsClient { } } +#[cfg(feature = "websocket")] +impl Sink for WsClient { + type Error = WsError; + + fn poll_ready( + self: Pin<&mut Self>, + context: &mut Context<'_>, + ) -> Poll> { + self.pinned_tx().poll_ready(context).map_err(WsError::new) + } + + fn start_send(self: Pin<&mut Self>, message: Message) -> Result<(), Self::Error> { + self.pinned_tx().start_send(message).map_err(WsError::new) + } + + fn poll_flush( + self: Pin<&mut Self>, + context: &mut Context<'_>, + ) -> Poll> { + self.pinned_tx().poll_flush(context).map_err(WsError::new) + } + + fn poll_close( + self: Pin<&mut Self>, + context: &mut Context<'_>, + ) -> Poll> { + self.pinned_tx().poll_close(context).map_err(WsError::new) + } +} + +#[cfg(feature = "websocket")] +impl Stream for WsClient { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll> { + let this = Pin::into_inner(self); + let rx = Pin::new(&mut this.rx); + match rx.poll_next(context) { + Poll::Ready(Some(result)) => Poll::Ready(Some(result.map_err(WsError::new))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + // ===== impl WsError ===== #[cfg(feature = "websocket")] diff --git a/tests/ws.rs b/tests/ws.rs index 2a14d1e18..d5b60356e 100644 --- a/tests/ws.rs +++ b/tests/ws.rs @@ -81,6 +81,21 @@ async fn binary() { assert_eq!(msg.as_bytes(), &b"bonk"[..]); } +#[tokio::test] +async fn wsclient_sink_and_stream() { + let _ = pretty_env_logger::try_init(); + + let mut client = warp::test::ws() + .handshake(ws_echo()) + .await + .expect("handshake"); + + let message = warp::ws::Message::text("hello"); + SinkExt::send(&mut client, message.clone()).await.unwrap(); + let received_message = client.next().await.unwrap().unwrap(); + assert_eq!(message, received_message); +} + #[tokio::test] async fn close_frame() { let _ = pretty_env_logger::try_init();