use futures_util::{Sink, Stream, StreamExt}; use http::HeaderMap; use std::pin::Pin; use std::task::{Context, Poll}; use tokio_tungstenite::tungstenite::client::IntoClientRequest; use tokio_tungstenite::{ tungstenite::{ error::*, protocol::{frame::coding::Data, CloseFrame}, Message, Result, }, MaybeTlsStream, }; pub async fn connect_async(url: &str, header_map: HeaderMap) -> crate::Result { let mut request = url.into_client_request()?; request.headers_mut().extend(header_map); let (inner, _response) = tokio_tungstenite::connect_async(request).await?; let inner = inner.filter_map(to_fut_message as fn(_) -> _); Ok(WebSocketStream { inner }) } type TokioTungsteniteStream = tokio_tungstenite::WebSocketStream>; type FutMessage = futures_util::future::Ready>>; pub struct WebSocketStream { inner: futures_util::stream::FilterMap< TokioTungsteniteStream, FutMessage, fn(Result) -> FutMessage, >, } impl Stream for WebSocketStream { type Item = crate::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_next(cx) } fn size_hint(&self) -> (usize, Option) { self.inner.size_hint() } } impl Sink for WebSocketStream { type Error = crate::Error; fn poll_ready( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { Pin::new(&mut self.inner).poll_ready(cx).map_err(Into::into) } fn start_send( mut self: Pin<&mut Self>, item: crate::Message, ) -> std::result::Result<(), Self::Error> { Pin::new(&mut self.inner) .start_send(item.into()) .map_err(Into::into) } fn poll_flush( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { Pin::new(&mut self.inner).poll_flush(cx).map_err(Into::into) } fn poll_close( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { Pin::new(&mut self.inner).poll_close(cx).map_err(Into::into) } } fn to_fut_message(msg: Result) -> FutMessage { fn inner(msg: Result) -> Option> { let msg = match msg { Ok(msg) => match msg { Message::Text(inner) => Ok(crate::Message::Text(inner)), Message::Binary(inner) => Ok(crate::Message::Binary(inner)), Message::Close(inner) => Ok(crate::Message::Close(inner.map(Into::into))), Message::Pong(inner) => Ok(crate::Message::Pong(inner)), Message::Ping(inner) => Ok(crate::Message::Ping(inner)), Message::Frame(_) => return None, }, Err(err) => Err(crate::Error::from(err)), }; Some(msg) } futures_util::future::ready(inner(msg)) } impl<'a> From> for crate::message::CloseFrame<'a> { fn from(close_frame: CloseFrame<'a>) -> Self { crate::message::CloseFrame { code: u16::from(close_frame.code).into(), reason: close_frame.reason, } } } impl<'a> From> for CloseFrame<'a> { fn from(close_frame: crate::message::CloseFrame<'a>) -> Self { CloseFrame { code: u16::from(close_frame.code).into(), reason: close_frame.reason, } } } impl From for crate::Message { fn from(msg: Message) -> Self { match msg { Message::Text(inner) => crate::Message::Text(inner), Message::Binary(inner) => crate::Message::Binary(inner), Message::Close(inner) => crate::Message::Close(inner.map(Into::into)), Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => { unreachable!("Unsendable via interface.") }, } } } impl From for Message { fn from(msg: crate::Message) -> Self { match msg { crate::Message::Text(inner) => Message::Text(inner), crate::Message::Binary(inner) => Message::Binary(inner), crate::Message::Close(inner) => Message::Close(inner.map(Into::into)), crate::Message::Ping(data) => Message::Ping(data), crate::Message::Pong(data) => Message::Pong(data), } } } impl From for crate::Error { fn from(err: Error) -> Self { match err { Error::ConnectionClosed => crate::Error::ConnectionClosed, Error::AlreadyClosed => crate::Error::AlreadyClosed, Error::Io(inner) => crate::Error::Io(inner), Error::Tls(inner) => crate::Error::Tls(inner), Error::Capacity(inner) => crate::Error::Capacity(inner.into()), Error::Protocol(inner) => crate::Error::Protocol(inner.into()), Error::WriteBufferFull(inner) => crate::Error::WriteBufferFull(inner.into()), Error::Utf8 => crate::Error::Utf8, Error::AttackAttempt => crate::Error::AttackAttempt, Error::Url(inner) => crate::Error::Url(inner.into()), Error::Http(inner) => crate::Error::Http(inner), Error::HttpFormat(inner) => crate::Error::HttpFormat(inner), } } } impl From for crate::error::CapacityError { fn from(err: CapacityError) -> Self { match err { CapacityError::TooManyHeaders => crate::error::CapacityError::TooManyHeaders, CapacityError::MessageTooLong { size, max_size } => { crate::error::CapacityError::MessageTooLong { size, max_size } }, } } } impl From for crate::error::UrlError { fn from(err: UrlError) -> Self { match err { UrlError::TlsFeatureNotEnabled => crate::error::UrlError::TlsFeatureNotEnabled, UrlError::NoHostName => crate::error::UrlError::NoHostName, UrlError::UnableToConnect(inner) => crate::error::UrlError::UnableToConnect(inner), UrlError::UnsupportedUrlScheme => crate::error::UrlError::UnsupportedUrlScheme, UrlError::EmptyHostName => crate::error::UrlError::EmptyHostName, UrlError::NoPathOrQuery => crate::error::UrlError::NoPathOrQuery, } } } impl From for crate::error::ProtocolError { fn from(err: ProtocolError) -> Self { match err { ProtocolError::WrongHttpMethod => crate::error::ProtocolError::WrongHttpMethod, ProtocolError::WrongHttpVersion => crate::error::ProtocolError::WrongHttpVersion, ProtocolError::MissingConnectionUpgradeHeader => { crate::error::ProtocolError::MissingConnectionUpgradeHeader }, ProtocolError::MissingUpgradeWebSocketHeader => { crate::error::ProtocolError::MissingUpgradeWebSocketHeader }, ProtocolError::MissingSecWebSocketVersionHeader => { crate::error::ProtocolError::MissingSecWebSocketVersionHeader }, ProtocolError::MissingSecWebSocketKey => crate::error::ProtocolError::MissingSecWebSocketKey, ProtocolError::SecWebSocketAcceptKeyMismatch => { crate::error::ProtocolError::SecWebSocketAcceptKeyMismatch }, ProtocolError::JunkAfterRequest => crate::error::ProtocolError::JunkAfterRequest, ProtocolError::CustomResponseSuccessful => { crate::error::ProtocolError::CustomResponseSuccessful }, ProtocolError::InvalidHeader(header_name) => { crate::error::ProtocolError::InvalidHeader(header_name) }, ProtocolError::HandshakeIncomplete => crate::error::ProtocolError::HandshakeIncomplete, ProtocolError::HttparseError(inner) => crate::error::ProtocolError::HttparseError(inner), ProtocolError::SendAfterClosing => crate::error::ProtocolError::SendAfterClosing, ProtocolError::ReceivedAfterClosing => crate::error::ProtocolError::ReceivedAfterClosing, ProtocolError::NonZeroReservedBits => crate::error::ProtocolError::NonZeroReservedBits, ProtocolError::UnmaskedFrameFromClient => { crate::error::ProtocolError::UnmaskedFrameFromClient }, ProtocolError::MaskedFrameFromServer => crate::error::ProtocolError::MaskedFrameFromServer, ProtocolError::FragmentedControlFrame => crate::error::ProtocolError::FragmentedControlFrame, ProtocolError::ControlFrameTooBig => crate::error::ProtocolError::ControlFrameTooBig, ProtocolError::UnknownControlFrameType(inner) => { crate::error::ProtocolError::UnknownControlFrameType(inner) }, ProtocolError::UnknownDataFrameType(inner) => { crate::error::ProtocolError::UnknownDataFrameType(inner) }, ProtocolError::UnexpectedContinueFrame => { crate::error::ProtocolError::UnexpectedContinueFrame }, ProtocolError::ExpectedFragment(inner) => { crate::error::ProtocolError::ExpectedFragment(inner.into()) }, ProtocolError::ResetWithoutClosingHandshake => { crate::error::ProtocolError::ResetWithoutClosingHandshake }, ProtocolError::InvalidOpcode(inner) => crate::error::ProtocolError::InvalidOpcode(inner), ProtocolError::InvalidCloseSequence => crate::error::ProtocolError::InvalidCloseSequence, } } } impl From for crate::error::Data { fn from(data: Data) -> Self { match data { Data::Continue => crate::error::Data::Continue, Data::Text => crate::error::Data::Text, Data::Binary => crate::error::Data::Binary, Data::Reserved(inner) => crate::error::Data::Reserved(inner), } } }