diff --git a/src/bytes.rs b/src/bytes.rs index 7927635..6b5c309 100644 --- a/src/bytes.rs +++ b/src/bytes.rs @@ -1,7 +1,8 @@ -//! Provides abstractions to use `AsyncRead` and `AsyncWrite` with a `WebSocketStream`. +//! Provides abstractions to use `AsyncRead` and `AsyncWrite` with +//! a [`WebSocketStream`](crate::WebSocketStream) or a [`WebSocketSender`](crate::WebSocketSender). use std::{ - io, + fmt, io, pin::Pin, task::{Context, Poll}, }; @@ -10,99 +11,184 @@ use futures_core::stream::Stream; use crate::{tungstenite::Bytes, Message, WsError}; -/// Treat a `WebSocketStream` as an `AsyncWrite` implementation. +/// Treat a websocket [sender](Sender) as an `AsyncWrite` implementation. /// /// Every write sends a binary message. If you want to group writes together, consider wrapping /// this with a `BufWriter`. -#[cfg(feature = "futures-03-sink")] -#[derive(Debug)] -pub struct ByteWriter(S); +pub struct ByteWriter { + sender: S, + state: State, +} -#[cfg(feature = "futures-03-sink")] impl ByteWriter { - /// Create a new `ByteWriter` from a `Sink` that accepts a WebSocket `Message` + /// Create a new `ByteWriter` from a [sender](Sender) that accepts a websocket [`Message`]. #[inline(always)] - pub fn new(s: S) -> Self { - Self(s) + pub fn new(sender: S) -> Self + where + S: Sender, + { + Self { + sender, + state: State::Open, + } } - /// Get the underlying `Sink` back. + /// Get the underlying [sender](Sender) back. #[inline(always)] pub fn into_inner(self) -> S { - self.0 + self.sender } } -#[cfg(feature = "futures-03-sink")] -fn poll_write_helper( - mut s: Pin<&mut ByteWriter>, - cx: &mut Context<'_>, - buf: &[u8], -) -> Poll> +impl fmt::Debug for ByteWriter where - S: futures_util::Sink + Unpin, + S: fmt::Debug, { - match Pin::new(&mut s.0).poll_ready(cx).map_err(convert_err) { - Poll::Ready(Ok(())) => {} - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Pending => return Poll::Pending, - } - let len = buf.len(); - let msg = Message::binary(buf.to_owned()); - Poll::Ready( - Pin::new(&mut s.0) - .start_send(msg) - .map_err(convert_err) - .map(|()| len), - ) + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ByteWriter") + .field("sender", &self.sender) + .field("state", &"..") + .finish() + } +} + +enum State { + Open, + Closing(Option), +} + +impl State { + fn close(&mut self) -> &mut Option { + match self { + State::Open => { + *self = State::Closing(Some(Message::Close(None))); + if let State::Closing(msg) = self { + msg + } else { + unreachable!() + } + } + State::Closing(msg) => msg, + } + } +} + +/// Sends bytes as a websocket [`Message`]. +/// +/// It's implemented for [`WebSocketStream`](crate::WebSocketStream) +/// and [`WebSocketSender`](crate::WebSocketSender). +/// It's also implemeted for every `Sink` type that accepts +/// a websocket [`Message`] and returns [`WsError`] type as +/// an error when `futures-03-sink` feature is enabled. +pub trait Sender: private::SealedSender {} + +pub(crate) mod private { + use super::*; + + pub trait SealedSender { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll>; + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + msg: &mut Option, + ) -> Poll>; + } + + impl Sender for S where S: SealedSender {} } #[cfg(feature = "futures-03-sink")] -impl futures_io::AsyncWrite for ByteWriter +impl private::SealedSender for S where S: futures_util::Sink + Unpin, { fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + use std::task::ready; + + ready!(self.as_mut().poll_ready(cx))?; + let len = buf.len(); + self.start_send(Message::binary(buf.to_owned()))?; + Poll::Ready(Ok(len)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + >::poll_flush(self, cx) + } + + fn poll_close( self: Pin<&mut Self>, cx: &mut Context<'_>, + _: &mut Option, + ) -> Poll> { + >::poll_close(self, cx) + } +} + +impl futures_io::AsyncWrite for ByteWriter +where + S: Sender + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - poll_write_helper(self, cx, buf) + ::poll_write(Pin::new(&mut self.sender), cx, buf) + .map_err(convert_err) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_flush(cx).map_err(convert_err) + ::poll_flush(Pin::new(&mut self.sender), cx) + .map_err(convert_err) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_close(cx).map_err(convert_err) + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.get_mut(); + let msg = me.state.close(); + ::poll_close(Pin::new(&mut me.sender), cx, msg) + .map_err(convert_err) } } -#[cfg(feature = "futures-03-sink")] #[cfg(feature = "tokio-runtime")] impl tokio::io::AsyncWrite for ByteWriter where - S: futures_util::Sink + Unpin, + S: Sender + Unpin, { fn poll_write( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - poll_write_helper(self, cx, buf) + ::poll_write(Pin::new(&mut self.sender), cx, buf) + .map_err(convert_err) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_flush(cx).map_err(convert_err) + ::poll_flush(Pin::new(&mut self.sender), cx) + .map_err(convert_err) } - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_close(cx).map_err(convert_err) + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.get_mut(); + let msg = me.state.close(); + ::poll_close(Pin::new(&mut me.sender), cx, msg) + .map_err(convert_err) } } -/// Treat a `WebSocketStream` as an `AsyncRead` implementation. +/// Treat a websocket [stream](Stream) as an `AsyncRead` implementation. /// /// This also works with any other `Stream` of `Message`, such as a `SplitStream`. /// @@ -115,7 +201,7 @@ pub struct ByteReader { } impl ByteReader { - /// Create a new `ByteReader` from a `Stream` that returns a WebSocket `Message` + /// Create a new `ByteReader` from a [stream](Stream) that returns a WebSocket [`Message`]. #[inline(always)] pub fn new(stream: S) -> Self { Self { diff --git a/src/lib.rs b/src/lib.rs index 89b7e26..14845a4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -95,7 +95,6 @@ pub mod tokio; pub mod bytes; pub use bytes::ByteReader; -#[cfg(feature = "futures-03-sink")] pub use bytes::ByteWriter; use tungstenite::protocol::CloseFrame; @@ -358,9 +357,9 @@ impl WebSocketStream { } } -impl WebSocketStream +impl WebSocketStream where - T: AsyncRead + AsyncWrite + Unpin, + S: AsyncRead + AsyncWrite + Unpin, { fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>> { #[cfg(feature = "verbose-logging")] @@ -465,9 +464,9 @@ where } } -impl Stream for WebSocketStream +impl Stream for WebSocketStream where - T: AsyncRead + AsyncWrite + Unpin, + S: AsyncRead + AsyncWrite + Unpin, { type Item = Result; @@ -476,9 +475,9 @@ where } } -impl FusedStream for WebSocketStream +impl FusedStream for WebSocketStream where - T: AsyncRead + AsyncWrite + Unpin, + S: AsyncRead + AsyncWrite + Unpin, { fn is_terminated(&self) -> bool { self.ended @@ -486,9 +485,9 @@ where } #[cfg(feature = "futures-03-sink")] -impl futures_util::Sink for WebSocketStream +impl futures_util::Sink for WebSocketStream where - T: AsyncRead + AsyncWrite + Unpin, + S: AsyncRead + AsyncWrite + Unpin, { type Error = WsError; @@ -509,6 +508,37 @@ where } } +#[cfg(not(feature = "futures-03-sink"))] +impl bytes::private::SealedSender for WebSocketStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let me = self.get_mut(); + ready!(me.poll_ready(cx))?; + let len = buf.len(); + me.start_send(Message::binary(buf.to_owned()))?; + Poll::Ready(Ok(len)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().poll_flush(cx) + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + msg: &mut Option, + ) -> Poll> { + let me = self.get_mut(); + send_helper(me, msg, cx) + } +} + impl WebSocketStream { /// Simple send method to replace `futures_sink::Sink` (till v0.3). pub async fn send(&mut self, msg: Message) -> Result<(), WsError> @@ -629,6 +659,39 @@ where } } +#[cfg(not(feature = "futures-03-sink"))] +impl bytes::private::SealedSender for WebSocketSender +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let me = self.get_mut(); + let mut ws = me.shared.lock(); + ready!(ws.poll_ready(cx))?; + let len = buf.len(); + ws.start_send(Message::binary(buf.to_owned()))?; + Poll::Ready(Ok(len)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.shared.lock().poll_flush(cx) + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + msg: &mut Option, + ) -> Poll> { + let me = self.get_mut(); + let mut ws = me.shared.lock(); + send_helper(&mut ws, msg, cx) + } +} + /// The receiver part of a [websocket](WebSocketStream) stream. #[derive(Debug)] pub struct WebSocketReceiver {