use futures_util::{SinkExt, StreamExt}; use std::borrow::Cow; use parking_lot::RwLock; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::{Arc, Weak}; use std::time::Duration; use crate::ws::ping::ServerFixIntervalPing; use crate::ws::retry::ConnectAction; use crate::ws::state::{ConnectState, ConnectStateNotify}; use crate::ws::{BusinessID, ClientRealtimeMessage, WSError, WebSocketChannel}; use tokio::sync::broadcast::{channel, Receiver, Sender}; use tokio::sync::{oneshot, Mutex}; use tokio_retry::strategy::FixedInterval; use tokio_retry::{Condition, RetryIf}; use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; use tokio_tungstenite::tungstenite::protocol::CloseFrame; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::MaybeTlsStream; use tracing::{debug, error, info, trace, warn}; pub struct WSClientConfig { /// specifies the number of messages that the channel can hold at any given /// time. It is used to set the initial size of the channel's internal buffer pub buffer_capacity: usize, /// specifies the number of seconds between each ping message pub ping_per_secs: u64, /// specifies the number of pings that the client will start reconnecting pub retry_connect_per_pings: u32, } impl Default for WSClientConfig { fn default() -> Self { Self { buffer_capacity: 1000, ping_per_secs: 8, retry_connect_per_pings: 20, } } } type ChannelByObjectId = HashMap>; pub type WSConnectStateReceiver = Receiver; pub struct WSClient { addr: Arc>>, config: WSClientConfig, state_notify: Arc>, /// Sender used to send messages to the websocket. sender: Sender, channels: Arc>>, ping: Arc>>, stop_tx: Mutex>>, } impl WSClient { pub fn new(config: WSClientConfig) -> Self { let (sender, _) = channel(config.buffer_capacity); let state_notify = Arc::new(parking_lot::Mutex::new(ConnectStateNotify::new())); let channels = Arc::new(RwLock::new(HashMap::new())); let ping = Arc::new(Mutex::new(None)); WSClient { addr: Arc::new(parking_lot::Mutex::new(None)), config, state_notify, sender, channels, ping, stop_tx: Mutex::new(None), } } pub async fn connect(&self, addr: String) -> Result, WSError> { let (stop_tx, mut stop_rx) = oneshot::channel(); *self.stop_tx.lock().await = Some(stop_tx); self.set_state(ConnectState::Connecting).await; *self.addr.lock() = Some(addr.clone()); if let Some(old_ping) = self.ping.lock().await.as_ref() { old_ping.stop().await; } // let retry_strategy = FibonacciBackoff::from_millis(2000).max_delay(Duration::from_secs(10 * 60)); let retry_strategy = FixedInterval::new(Duration::from_secs(6)); let action = ConnectAction::new(addr.clone()); let cond = RetryCondition { connecting_addr: addr, addr: Arc::downgrade(&self.addr), }; let stream = RetryIf::spawn(retry_strategy, action, cond).await?; let addr = match stream.get_ref() { MaybeTlsStream::Plain(s) => s.local_addr().ok(), _ => None, }; self.set_state(ConnectState::Connected).await; let (mut sink, mut stream) = stream.split(); let weak_channels = Arc::downgrade(&self.channels); let sender = self.sender.clone(); let mut ping = ServerFixIntervalPing::new( Duration::from_secs(self.config.ping_per_secs), self.state_notify.clone(), sender.clone(), self.config.retry_connect_per_pings, ); ping.run(); *self.ping.lock().await = Some(ping); // Receive messages from the websocket, and send them to the channels. tokio::spawn(async move { while let Some(Ok(msg)) = stream.next().await { match msg { Message::Text(_) => {}, Message::Binary(_) => { if let Ok(msg) = ClientRealtimeMessage::try_from(&msg) { if let Some(channels) = weak_channels.upgrade() { if let Some(channel) = channels .read() .get(&msg.business_id) .and_then(|map| map.get(&msg.object_id)) { match channel.upgrade() { None => { // when calling [WSClient::subscribe], the caller is responsible for keeping // the channel alive as long as it wants to receive messages from the websocket. trace!("channel is dropped"); }, Some(channel) => { channel.recv_msg(&msg); }, } } } else { warn!("channels are closed"); } } else { error!("🔴Parser ClientRealtimeMessage failed"); } }, Message::Ping(_) => match sender.send(Message::Pong(vec![])) { Ok(_) => {}, Err(e) => { error!("🔴Failed to send pong message to websocket: {:?}", e); }, }, Message::Pong(_) => {}, Message::Close(close) => { info!("{:?}", close); }, Message::Frame(_) => {}, } } }); let mut sink_rx = self.sender.subscribe(); tokio::spawn(async move { loop { tokio::select! { _ = &mut stop_rx => { info!("Client stop sending message using websocket"); break; }, Ok(msg) = sink_rx.recv() => { match sink.send(msg).await { Ok(_) => {}, Err(e) => { error!("Failed to send message via websocket: {:?}", e); break; }, } } } } }); Ok(addr) } /// Return a [WebSocketChannel] that can be used to send messages to the websocket. Caller should /// keep the channel alive as long as it wants to receive messages from the websocket. pub fn subscribe( &self, business_id: BusinessID, object_id: String, ) -> Result, WSError> { let channel = Arc::new(WebSocketChannel::new(business_id, self.sender.clone())); self .channels .write() .entry(business_id) .or_insert_with(HashMap::new) .insert(object_id, Arc::downgrade(&channel)); Ok(channel) } pub fn subscribe_connect_state(&self) -> WSConnectStateReceiver { self.state_notify.lock().subscribe() } pub fn is_connected(&self) -> bool { self.state_notify.lock().state.is_connected() } pub async fn disconnect(&self) { if let Some(stop_tx) = self.stop_tx.lock().await.take() { debug!("client disconnect"); let _ = stop_tx.send(()); let _ = self.sender.send(Message::Close(Some(CloseFrame { code: CloseCode::Normal, reason: Cow::from("client disconnect"), }))); *self.addr.lock() = None; self.set_state(ConnectState::Disconnected).await; } } async fn set_state(&self, state: ConnectState) { trace!("websocket state: {:?}", state); self.state_notify.lock().set_state(state); } } struct RetryCondition { connecting_addr: String, addr: Weak>>, } impl Condition for RetryCondition { fn should_retry(&mut self, _error: &WSError) -> bool { let should_retry = self .addr .upgrade() .map(|addr| match addr.try_lock() { None => false, Some(addr) => match &*addr { None => false, Some(addr) => addr == &self.connecting_addr, }, }) .unwrap_or(false); debug!("WSClient should_retry: {}", should_retry); should_retry } }