diff --git a/libs/client-api/src/native/http_native.rs b/libs/client-api/src/native/http_native.rs index e33891fc..bcd618c6 100644 --- a/libs/client-api/src/native/http_native.rs +++ b/libs/client-api/src/native/http_native.rs @@ -156,7 +156,7 @@ impl WSClientHttpSender for Client { self .post_realtime_msg(device_id, message) .await - .map_err(|err| WSError::Internal(anyhow::Error::from(err))) + .map_err(|err| WSError::Http(err.to_string())) } } diff --git a/libs/client-api/src/native/ping.rs b/libs/client-api/src/native/ping.rs index 70445aba..d2fec489 100644 --- a/libs/client-api/src/native/ping.rs +++ b/libs/client-api/src/native/ping.rs @@ -59,6 +59,10 @@ impl ServerFixIntervalPing { // send ping to server if let Err(err) = ping_sender.send(Message::Ping(vec![])) { tracing::error!("ping send error: {}", err); + if let Some(state) =weak_state.upgrade() { + state.lock().set_state(ConnectState::PingTimeout); + } + break; } if let Some(ping_count) = weak_ping_count.upgrade() { let mut lock = ping_count.lock().await; diff --git a/libs/client-api/src/native/retry.rs b/libs/client-api/src/native/retry.rs index ef077959..3058ab93 100644 --- a/libs/client-api/src/native/retry.rs +++ b/libs/client-api/src/native/retry.rs @@ -71,7 +71,7 @@ pub async fn retry_connect( ) -> Result { let connecting_addr = addr.to_owned(); let stream = RetryIf::spawn( - FixedInterval::new(Duration::from_secs(6)), + FixedInterval::new(Duration::from_secs(10)), ConnectAction::new(connecting_addr.clone()), RetryCondition { connecting_addr, diff --git a/libs/client-api/src/ws/client.rs b/libs/client-api/src/ws/client.rs index f37278c3..0679a4e3 100644 --- a/libs/client-api/src/ws/client.rs +++ b/libs/client-api/src/ws/client.rs @@ -102,6 +102,11 @@ impl WSClient { } pub async fn connect(&self, addr: String, device_id: &str) -> Result<(), WSError> { + if self.get_state().is_connecting() { + info!("websocket is connecting, skip connect request"); + return Ok(()); + } + self.set_state(ConnectState::Connecting).await; // stop receiving message from client @@ -136,6 +141,7 @@ impl WSClient { WSError::LostConnection(_) => state_notify.lock().set_state(ConnectState::Closed), WSError::AuthError(_) => state_notify.lock().set_state(ConnectState::Unauthorized), WSError::Internal(_) => {}, + WSError::Http(_) => {}, }, } }; @@ -233,10 +239,12 @@ impl WSClient { }, Message::Close(close) => { info!("websocket close: {:?}", close); + break; }, Message::Pong(_) => { if let Err(err) = pong_tx.send(()).await { error!("failed to receive server pong: {}", err); + break; } }, _ => warn!("received unexpected message from websocket: {:?}", ws_msg), @@ -272,9 +280,11 @@ impl WSClient { error!("The HTTP sender has been dropped, unable to send message."); continue; } - } else if let Err(err) = sink.send(msg).await.map_err(WSError::from){ - handle_ws_error(&err); - break; + } else if let Err(err) = sink.send(msg).await.map_err(WSError::from) { + if err.is_lost_connection() { + break; + } + handle_ws_error(&err); } } } diff --git a/libs/client-api/src/ws/error.rs b/libs/client-api/src/ws/error.rs index 747631ae..ef60794c 100644 --- a/libs/client-api/src/ws/error.rs +++ b/libs/client-api/src/ws/error.rs @@ -12,10 +12,19 @@ pub enum WSError { #[error("Auth error: {0}")] AuthError(String), + #[error("Fail to send message via http: {0}")] + Http(String), + #[error(transparent)] Internal(#[from] anyhow::Error), } +impl WSError { + pub fn is_lost_connection(&self) -> bool { + matches!(self, WSError::LostConnection(_)) + } +} + impl From for WSError { fn from(value: Error) -> Self { match &value { diff --git a/libs/client-api/tests/native/conn_test.rs b/libs/client-api/tests/native/conn_test.rs index 78821966..d66e6c81 100644 --- a/libs/client-api/tests/native/conn_test.rs +++ b/libs/client-api/tests/native/conn_test.rs @@ -22,32 +22,6 @@ async fn realtime_connect_test() { } } -#[tokio::test] -async fn realtime_connect_after_token_exp_test() { - let (c, _user) = generate_unique_registered_user_client().await; - - // Set the token to be expired - c.token().write().as_mut().unwrap().expires_at = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap() - .as_secs() as i64; - - let ws_client = WSClient::new(WSClientConfig::default(), c.clone()); - let mut state = ws_client.subscribe_connect_state(); - let device_id = "fake_device_id"; - loop { - tokio::select! { - _ = ws_client.connect(c.ws_url(device_id).await.unwrap(), device_id) => {}, - value = state.recv() => { - let new_state = value.unwrap(); - if new_state == ConnectState::Connected { - break; - } - }, - } - } -} - #[tokio::test] async fn realtime_disconnect_test() { let (c, _user) = generate_unique_registered_user_client().await; @@ -71,19 +45,3 @@ async fn realtime_disconnect_test() { } } } - -// use std::time::Duration; -// use tokio_tungstenite::tungstenite::Message; -// #[tokio::test] -// async fn max_frame_size() { -// let (c, _user) = generate_unique_registered_user_client().await; -// let ws_client = WSClient::new(WSClientConfig::default(), c.clone()); -// let device_id = "fake_device_id"; -// ws_client -// .connect(c.ws_url(device_id).unwrap(), device_id) -// .await -// .unwrap(); -// -// ws_client.send(Message::Binary(vec![0; 65536])).unwrap(); -// tokio::time::sleep(Duration::from_secs(5)).await; -// } diff --git a/tests/collab/edit_permission.rs b/tests/collab/edit_permission.rs index d02da7f5..d0832c81 100644 --- a/tests/collab/edit_permission.rs +++ b/tests/collab/edit_permission.rs @@ -440,9 +440,9 @@ async fn multiple_user_with_read_and_write_permission_edit_same_collab_test() { sleep(Duration::from_secs(5)).await; // all the clients should have the same collab object - assert_json_include!( - actual: json!(expected_json), - expected: arc_owner + assert_json_eq!( + json!(expected_json), + arc_owner .collab_by_object_id .get(&object_id) .unwrap() @@ -452,8 +452,8 @@ async fn multiple_user_with_read_and_write_permission_edit_same_collab_test() { for client in clients { assert_json_include!( - expected: json!(expected_json), - actual: client + actual: json!(expected_json), + expected: client .collab_by_object_id .get(&object_id) .unwrap() diff --git a/tests/websocket/conn_test.rs b/tests/websocket/conn_test.rs index be91b365..d46a6707 100644 --- a/tests/websocket/conn_test.rs +++ b/tests/websocket/conn_test.rs @@ -1,4 +1,5 @@ -use std::time::SystemTime; +use std::time::{Duration, SystemTime}; +use tokio::time::timeout; use client_api::ws::{ConnectState, WSClient, WSClientConfig}; use client_api_test_util::generate_unique_registered_user_client; @@ -9,16 +10,25 @@ async fn realtime_connect_test() { let ws_client = WSClient::new(WSClientConfig::default(), c.clone()); let mut state = ws_client.subscribe_connect_state(); let device_id = "fake_device_id"; - loop { - tokio::select! { - _ = ws_client.connect(c.ws_url(device_id).await.unwrap(), device_id) => {}, - value = state.recv() => { - let new_state = value.unwrap(); - if new_state == ConnectState::Connected { - break; - } - }, + + tokio::spawn(async move { + ws_client + .connect(c.ws_url(device_id).await.unwrap(), device_id) + .await + }); + + let connect_future = async { + while let Ok(state) = state.recv().await { + if state == ConnectState::Connected { + break; + } } + }; + + // Apply the timeout + match timeout(Duration::from_secs(10), connect_future).await { + Ok(_) => {}, + Err(_) => panic!("Connection timeout."), } } @@ -35,16 +45,25 @@ async fn realtime_connect_after_token_exp_test() { let ws_client = WSClient::new(WSClientConfig::default(), c.clone()); let mut state = ws_client.subscribe_connect_state(); let device_id = "fake_device_id"; - loop { - tokio::select! { - _ = ws_client.connect(c.ws_url(device_id).await.unwrap(), device_id) => {}, - value = state.recv() => { - let new_state = value.unwrap(); - if new_state == ConnectState::Connected { - break; - } - }, + + tokio::spawn(async move { + ws_client + .connect(c.ws_url(device_id).await.unwrap(), device_id) + .await + }); + + let connect_future = async { + while let Ok(state) = state.recv().await { + if state == ConnectState::Connected { + break; + } } + }; + + // Apply the timeout + match timeout(Duration::from_secs(10), connect_future).await { + Ok(_) => {}, + Err(_) => panic!("Connection timeout."), } }