diff --git a/Cargo.lock b/Cargo.lock index 38aebd60..3d2476d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -781,6 +781,7 @@ dependencies = [ "futures-util", "gotrue-entity", "opener", + "parking_lot", "reqwest", "serde", "serde_json", @@ -2477,6 +2478,7 @@ dependencies = [ "assert-json-diff", "async-trait", "bytes", + "chrono", "collab", "collab-define", "collab-plugins", diff --git a/libs/client-api/Cargo.toml b/libs/client-api/Cargo.toml index 32c4c154..8a86c3de 100644 --- a/libs/client-api/Cargo.toml +++ b/libs/client-api/Cargo.toml @@ -16,7 +16,7 @@ storage-entity = { path = "../storage-entity" } opener = "0.6.1" url = "2.4.1" tokio-stream = { version = "0.1.14" } - +parking_lot = "0.12.1" # ws tracing = { version = "0.1" } thiserror = "1.0.39" diff --git a/libs/client-api/src/ws/client.rs b/libs/client-api/src/ws/client.rs index 76e85305..f2fae51b 100644 --- a/libs/client-api/src/ws/client.rs +++ b/libs/client-api/src/ws/client.rs @@ -10,7 +10,7 @@ use crate::ws::{BusinessID, ClientRealtimeMessage, WSError, WebSocketChannel}; use tokio::sync::broadcast::{channel, Receiver, Sender}; use tokio::sync::{Mutex, RwLock}; use tokio_retry::strategy::FixedInterval; -use tokio_retry::Retry; +use tokio_retry::{Condition, RetryIf}; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::MaybeTlsStream; @@ -37,11 +37,12 @@ impl Default for WSClientConfig { type ChannelByObjectId = HashMap>; pub struct WSClient { - addr: Mutex>, + addr: Arc>>, + config: WSClientConfig, state: Arc>, sender: Sender, channels: Arc>>, - ping: Arc>, + ping: Arc>>, } impl WSClient { @@ -49,14 +50,10 @@ impl WSClient { let (sender, _) = channel(config.buffer_capacity); let state = Arc::new(Mutex::new(ConnectStateNotify::new())); let channels = Arc::new(RwLock::new(HashMap::new())); - let ping = Arc::new(Mutex::new(ServerFixIntervalPing::new( - Duration::from_secs(config.ping_per_secs), - state.clone(), - sender.clone(), - config.retry_connect_per_pings, - ))); + let ping = Arc::new(Mutex::new(None)); WSClient { - addr: Mutex::new(None), + addr: Arc::new(parking_lot::Mutex::new(None)), + config, state, sender, channels, @@ -65,12 +62,16 @@ impl WSClient { } pub async fn connect(&self, addr: String) -> Result, WSError> { - *self.addr.lock().await = Some(addr.clone()); + *self.addr.lock() = Some(addr.clone()); self.set_state(ConnectState::Connecting).await; let retry_strategy = FixedInterval::new(Duration::from_secs(2)).take(3); - let action = ConnectAction::new(addr); - let stream = Retry::spawn(retry_strategy, action).await?; + let action = ConnectAction::new(addr.clone()); + let cond = RetryCondition { + connecting_addr: addr, + addr: Arc::downgrade(&self.addr), + }; + let stream = RetryIf::spawn(retry_strategy, action, cond).await?; let addr = match stream.get_ref() { MaybeTlsStream::Plain(s) => s.local_addr().ok(), _ => None, @@ -80,7 +81,16 @@ impl WSClient { self.set_state(ConnectState::Connected).await; let weak_channels = Arc::downgrade(&self.channels); let sender = self.sender.clone(); - self.ping.lock().await.run(); + + let mut ping = ServerFixIntervalPing::new( + Duration::from_secs(self.config.ping_per_secs), + self.state.clone(), + sender.clone(), + self.config.retry_connect_per_pings, + ); + ping.run(); + *self.ping.lock().await = Some(ping); + // Receive messages from the websocket, and send them to the channels. tokio::spawn(async move { while let Some(Ok(msg)) = stream.next().await { @@ -158,7 +168,9 @@ impl WSClient { } pub async fn disconnect(&self) { + *self.addr.lock() = None; let _ = self.sender.send(Message::Close(None)); + self.set_state(ConnectState::Disconnected).await; } async fn set_state(&self, state: ConnectState) { @@ -295,3 +307,20 @@ impl ConnectState { matches!(self, ConnectState::Disconnected) } } + +struct RetryCondition { + connecting_addr: String, + addr: Weak>>, +} +impl Condition for RetryCondition { + fn should_retry(&mut self, _error: &WSError) -> bool { + self + .addr + .upgrade() + .map(|addr| match addr.lock().as_ref() { + None => false, + Some(addr) => addr == &self.connecting_addr, + }) + .unwrap_or(false) + } +} diff --git a/libs/realtime/Cargo.toml b/libs/realtime/Cargo.toml index 6699f8e4..0c473f61 100644 --- a/libs/realtime/Cargo.toml +++ b/libs/realtime/Cargo.toml @@ -30,6 +30,7 @@ storage-entity = { path = "../storage-entity" } y-sync = { version = "0.3.1" } yrs = "0.16.5" lib0 = "0.16.3" +chrono = "0.4.30" [dev-dependencies] actix = "0.13" diff --git a/libs/realtime/src/client.rs b/libs/realtime/src/client.rs index e18eb344..95b770bc 100644 --- a/libs/realtime/src/client.rs +++ b/libs/realtime/src/client.rs @@ -181,23 +181,26 @@ impl Deref for ClientWSSink { pub struct RealtimeUserImpl { pub uuid: String, pub device_id: String, + timestamp: i64, +} + +impl RealtimeUserImpl { + pub fn new(uuid: String, device_id: String) -> Self { + Self { + uuid, + device_id, + timestamp: chrono::Utc::now().timestamp(), + } + } } impl Display for RealtimeUserImpl { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!( - "uuid:{}|device_id:{}", - self.uuid, self.device_id, + "uuid:{}|device_id:{}:{}", + self.uuid, self.device_id, self.timestamp )) } } -impl RealtimeUser for RealtimeUserImpl { - fn id(&self) -> &str { - &self.uuid - } - - fn device_id(&self) -> &str { - &self.device_id - } -} +impl RealtimeUser for RealtimeUserImpl {} diff --git a/libs/realtime/src/collaborate/group.rs b/libs/realtime/src/collaborate/group.rs index d03d6f9b..b80641b2 100644 --- a/libs/realtime/src/collaborate/group.rs +++ b/libs/realtime/src/collaborate/group.rs @@ -12,17 +12,19 @@ use std::ops::{Deref, DerefMut}; use std::sync::Arc; use storage::collab::CollabStorage; +use crate::entities::RealtimeUser; use tokio::task::spawn_blocking; use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec}; -pub struct CollabGroupCache { - group_by_object_id: RwLock>>, +pub struct CollabGroupCache { + group_by_object_id: RwLock>>>, storage: S, } -impl CollabGroupCache +impl CollabGroupCache where S: CollabStorage + Clone, + U: RealtimeUser, { pub fn new(storage: S) -> Self { Self { @@ -57,7 +59,7 @@ where workspace_id: &str, object_id: &str, collab_type: CollabType, - ) -> Arc { + ) -> Arc> { tracing::trace!("Create new group for object_id:{}", object_id); let collab = MutexCollab::new(CollabOrigin::Server, object_id, vec![]); let broadcast = CollabBroadcast::new(object_id, collab.clone(), 10); @@ -89,20 +91,22 @@ where } } -impl Deref for CollabGroupCache +impl Deref for CollabGroupCache where S: CollabStorage, + U: RealtimeUser, { - type Target = RwLock>>; + type Target = RwLock>>>; fn deref(&self) -> &Self::Target { &self.group_by_object_id } } -impl DerefMut for CollabGroupCache +impl DerefMut for CollabGroupCache where S: CollabStorage, + U: RealtimeUser, { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.group_by_object_id @@ -110,7 +114,7 @@ where } /// A group used to manage a single [Collab] object -pub struct CollabGroup { +pub struct CollabGroup { pub collab: Arc, /// A broadcast used to propagate updates produced by yrs [yrs::Doc] and [Awareness] @@ -119,10 +123,13 @@ pub struct CollabGroup { /// A list of subscribers to this group. Each subscriber will receive updates from the /// broadcast. - pub subscribers: RwLock>, + pub subscribers: RwLock>, } -impl CollabGroup { +impl CollabGroup +where + U: RealtimeUser, +{ /// Mutate the [Collab] by the given closure pub fn get_mut_collab(&self, f: F) where diff --git a/libs/realtime/src/collaborate/plugin.rs b/libs/realtime/src/collaborate/plugin.rs index bce001e6..166e1c97 100644 --- a/libs/realtime/src/collaborate/plugin.rs +++ b/libs/realtime/src/collaborate/plugin.rs @@ -14,27 +14,28 @@ use storage::error::StorageError; use storage_entity::{InsertCollabParams, QueryCollabParams, RawData}; use crate::collaborate::group::CollabGroup; +use crate::entities::RealtimeUser; use y_sync::awareness::Awareness; use yrs::updates::decoder::Decode; use yrs::{ReadTxn, StateVector, Transact, Update}; -pub struct CollabStoragePlugin { +pub struct CollabStoragePlugin { uid: i64, workspace_id: String, storage: S, did_load: AtomicBool, update_count: AtomicU32, - group: Weak, + group: Weak>, collab_type: CollabType, } -impl CollabStoragePlugin { +impl CollabStoragePlugin { pub fn new( uid: i64, workspace_id: &str, collab_type: CollabType, storage: S, - group: Weak, + group: Weak>, ) -> Self { let workspace_id = workspace_id.to_string(); let did_load = AtomicBool::new(false); @@ -62,9 +63,10 @@ fn init_collab_with_raw_data(raw_data: RawData, doc: &Doc) -> Result<(), Realtim } #[async_trait] -impl CollabPlugin for CollabStoragePlugin +impl CollabPlugin for CollabStoragePlugin where S: CollabStorage, + U: RealtimeUser, { async fn init(&self, object_id: &str, _origin: &CollabOrigin, doc: &Doc) { let params = QueryCollabParams { diff --git a/libs/realtime/src/collaborate/server.rs b/libs/realtime/src/collaborate/server.rs index 01d7c75e..e16283d1 100644 --- a/libs/realtime/src/collaborate/server.rs +++ b/libs/realtime/src/collaborate/server.rs @@ -23,7 +23,7 @@ pub struct CollabServer { #[allow(dead_code)] storage: S, /// Keep track of all collab groups - groups: Arc>, + groups: Arc>, /// Keep track of all object ids that a user is subscribed to editing_collab_by_user: Arc>>>, /// Keep track of all client streams @@ -45,6 +45,18 @@ where client_stream_by_user: Default::default(), }) } + + fn remove_user(&self, user: &U) { + self.client_stream_by_user.write().remove(user); + + let editing_set = self.editing_collab_by_user.write().remove(user); + if let Some(editing_set) = editing_set { + tracing::info!("Remove user from group: {}", user); + for editing in editing_set { + remove_user_from_group(user, &self.groups, &editing); + } + } + } } impl Actor for CollabServer @@ -58,19 +70,20 @@ where impl Handler> for CollabServer where U: RealtimeUser + Unpin, - S: 'static + Unpin, + S: CollabStorage + Unpin, { type Result = Result<(), RealtimeError>; fn handle(&mut self, new_conn: Connect, _ctx: &mut Context) -> Self::Result { tracing::trace!("[💭Server]: new connection => {} ", new_conn.user); + // Remove the user from the group if the user is already connected + self.remove_user(&new_conn.user); let stream = CollabClientStream::new(ClientWSSink(new_conn.socket)); self .client_stream_by_user .write() .insert(new_conn.user, stream); - Ok(()) } } @@ -83,21 +96,7 @@ where type Result = Result<(), RealtimeError>; fn handle(&mut self, msg: Disconnect, _: &mut Context) -> Self::Result { tracing::trace!("[💭Server]: disconnect => {}", msg.user); - self.client_stream_by_user.write().remove(&msg.user); - - // Remove the user from all collab groups that the user is subscribed to - let editing_set = self.editing_collab_by_user.write().remove(&msg.user); - if let Some(editing_set) = editing_set { - if !editing_set.is_empty() { - let groups = self.groups.clone(); - tokio::task::spawn_blocking(move || { - for editing in editing_set { - remove_user_from_group(&groups, &editing); - } - }); - } - } - + self.remove_user(&msg.user); Ok(()) } } @@ -136,7 +135,8 @@ async fn forward_message_to_collab_group( { if let Some(client_stream) = client_streams.read().get(&client_msg.user) { tracing::trace!( - "[💭Server]: receives: [oid:{}|msg_id:{:?}]", + "[💭Server]: receives: user:{} message: [oid:{}|msg_id:{:?}]", + client_msg.user, client_msg.content.object_id(), client_msg.content.msg_id() ); @@ -154,7 +154,7 @@ async fn forward_message_to_collab_group( async fn subscribe_collab_group_change_if_need( client_msg: &ClientMessage, - groups: &Arc>, + groups: &Arc>, edit_collab_by_user: &Arc>>>, client_streams: &Arc>>, ) -> Result<(), RealtimeError> @@ -200,7 +200,7 @@ where if groups .read() .get(object_id) - .map(|group| group.subscribers.read().get(origin).is_some()) + .map(|group| group.subscribers.read().get(&client_msg.user).is_some()) .unwrap_or(false) { return Ok(()); @@ -213,7 +213,7 @@ where collab_group .subscribers .write() - .entry(origin.clone()) + .entry(client_msg.user.clone()) .or_insert_with(|| { tracing::trace!( "[💭Server]: {} subscribe group:{}", @@ -250,15 +250,16 @@ where } /// Remove the user from the group and remove the group from the cache if the group is empty. -fn remove_user_from_group(groups: &Arc>, editing: &Editing) +fn remove_user_from_group(user: &U, groups: &Arc>, editing: &Editing) where S: CollabStorage, + U: RealtimeUser, { let mut groups_write_guard = groups.write(); let should_remove_group = groups_write_guard.get_mut(&editing.object_id).map(|group| { - tracing::debug!("Remove subscriber: {}", editing.origin); - group.subscribers.write().remove(&editing.origin); + tracing::info!("Remove subscriber: {}", editing.origin); + group.subscribers.write().remove(user); let should_remove = group.is_empty(); if should_remove { group.flush_collab(); diff --git a/libs/realtime/src/entities.rs b/libs/realtime/src/entities.rs index a0641326..2235673e 100644 --- a/libs/realtime/src/entities.rs +++ b/libs/realtime/src/entities.rs @@ -8,15 +8,15 @@ use serde::{Deserialize, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; use std::fmt::{Debug, Display}; use std::hash::Hash; +use std::sync::Arc; pub trait RealtimeUser: Clone + Debug + Send + Sync + 'static + Display + Hash + Eq + PartialEq { - fn id(&self) -> &str; - - fn device_id(&self) -> &str; } +impl RealtimeUser for Arc where T: RealtimeUser {} + #[derive(Debug, Message, Clone)] #[rtype(result = "Result<(), RealtimeError>")] pub struct Connect { diff --git a/src/api/ws.rs b/src/api/ws.rs index e6fb1e4f..8663c5a9 100644 --- a/src/api/ws.rs +++ b/src/api/ws.rs @@ -3,6 +3,7 @@ use actix::Addr; use actix_web::web::{Data, Path, Payload}; use actix_web::{get, web, HttpRequest, HttpResponse, Result, Scope}; use actix_web_actors::ws; +use std::sync::Arc; use realtime::client::{ClientWSSession, RealtimeUserImpl}; use realtime::collaborate::CollabServer; @@ -23,16 +24,13 @@ pub async fn establish_ws_connection( payload: Payload, path: Path<(String, String)>, state: Data, - server: Data>>, + server: Data>>>, ) -> Result { tracing::info!("ws connect: {:?}", request); let (token, device_id) = path.into_inner(); let auth = authorization_from_token(token.as_str(), &state)?; let user_uuid = UserUuid::from_auth(auth)?; - let realtime_user = RealtimeUserImpl { - uuid: user_uuid.to_string(), - device_id, - }; + let realtime_user = Arc::new(RealtimeUserImpl::new(user_uuid.to_string(), device_id)); let client = ClientWSSession::new( realtime_user, server.get_ref().clone(), diff --git a/src/application.rs b/src/application.rs index 38d90015..7e626696 100644 --- a/src/application.rs +++ b/src/application.rs @@ -83,7 +83,7 @@ where .map(|(_, server_key)| Key::from(server_key.expose_secret().as_bytes())) .unwrap_or_else(Key::generate); - let collab_server = CollabServer::<_, RealtimeUserImpl>::new(storage.collab_storage.clone()) + let collab_server = CollabServer::<_, Arc>::new(storage.collab_storage.clone()) .unwrap() .start(); let mut server = HttpServer::new(move || { diff --git a/tests/realtime/connect_test.rs b/tests/realtime/connect_test.rs index d773890a..e387f72b 100644 --- a/tests/realtime/connect_test.rs +++ b/tests/realtime/connect_test.rs @@ -30,3 +30,36 @@ async fn realtime_connect_test() { } } } + +#[tokio::test] +async fn realtime_disconnect_test() { + let _guard = REGISTERED_USER_MUTEX.lock().await; + + let mut c = client_api_client(); + c.sign_in_password(®ISTERED_EMAIL, ®ISTERED_PASSWORD) + .await + .unwrap(); + + let ws_client = WSClient::new(WSClientConfig { + buffer_capacity: 100, + ping_per_secs: 2, + retry_connect_per_pings: 5, + }); + ws_client + .connect(c.ws_url("fake_device_id").unwrap()) + .await + .unwrap(); + + let mut state = ws_client.subscribe_connect_state().await; + loop { + tokio::select! { + _ = ws_client.disconnect() => {}, + value = state.recv() => { + let new_state = value.unwrap(); + if new_state == ConnectState::Disconnected { + break; + } + }, + } + } +} diff --git a/tests/realtime/edit_collab_test.rs b/tests/realtime/edit_collab_test.rs index 32329abf..caec16ff 100644 --- a/tests/realtime/edit_collab_test.rs +++ b/tests/realtime/edit_collab_test.rs @@ -6,8 +6,11 @@ use collab_define::CollabType; use crate::realtime::test_client::{assert_collab_json, TestClient}; use assert_json_diff::assert_json_eq; +use shared_entity::error_code::ErrorCode; use std::time::Duration; use storage::collab::FLUSH_PER_UPDATE; +use storage_entity::QueryCollabParams; +use uuid::Uuid; #[tokio::test] async fn realtime_write_collab_test() { @@ -25,7 +28,6 @@ async fn realtime_write_collab_test() { // Wait for the messages to be sent tokio::time::sleep(Duration::from_secs(2)).await; - test_client.disconnect().await; assert_collab_json( &mut test_client.api_client, @@ -73,6 +75,52 @@ async fn one_direction_peer_sync_test() { assert_json_eq!(json_1, json_2); } +#[tokio::test] +async fn same_user_with_same_device_id_test() { + let object_id = uuid::Uuid::new_v4().to_string(); + let collab_type = CollabType::Document; + + // Client_1_2 will force the server to disconnect client_1_1. So any changes made by client_1_1 + // will not be saved to the server. + let device_id = Uuid::new_v4().to_string(); + let client_1_1 = + TestClient::new_with_device_id(&object_id, &device_id, collab_type.clone()).await; + let mut client_1_2 = + TestClient::new_with_device_id(&object_id, &device_id, collab_type.clone()).await; + + client_1_1.collab.lock().insert("1", "a"); + client_1_2.collab.lock().insert("2", "b"); + client_1_1.collab.lock().insert("3", "c"); + + tokio::time::sleep(Duration::from_millis(200)).await; + + let json_1 = client_1_1.collab.lock().to_json_value(); + let json_2 = client_1_2.collab.lock().to_json_value(); + assert_json_eq!( + json_1, + json!({ + "1": "a", + "3": "c" + }) + ); + assert_json_eq!( + json_2, + json!({ + "2": "b" + }) + ); + assert_collab_json( + &mut client_1_2.api_client, + &object_id, + &collab_type, + 5, + json!({ + "2": "b" + }), + ) + .await; +} + #[tokio::test] async fn two_direction_peer_sync_test() { let _client_api = client_api_client(); @@ -174,3 +222,51 @@ async fn multiple_collab_edit_test() { ) .await; } + +#[tokio::test] +async fn ws_reconnect_sync_test() { + let object_id = uuid::Uuid::new_v4().to_string(); + let collab_type = CollabType::Document; + let mut test_client = TestClient::new(&object_id, collab_type.clone()).await; + + // Disconnect the client and edit the collab. The updates will not be sent to the server. + test_client.disconnect().await; + for i in 0..=5 { + test_client + .collab + .lock() + .insert(&i.to_string(), i.to_string()); + } + + // it will return RecordNotFound error when trying to get the collab from the server + let err = test_client + .api_client + .get_collab(QueryCollabParams { + object_id: object_id.clone(), + collab_type: collab_type.clone(), + }) + .await + .unwrap_err(); + assert_eq!(err.code, ErrorCode::RecordNotFound); + + // After reconnect the collab should be synced to the server. + test_client.reconnect().await; + // Wait for the messages to be sent + tokio::time::sleep(Duration::from_secs(2)).await; + + assert_collab_json( + &mut test_client.api_client, + &object_id, + &collab_type, + 3, + json!( { + "0": "0", + "1": "1", + "2": "2", + "3": "3", + "4": "4", + "5": "5", + }), + ) + .await; +} diff --git a/tests/realtime/test_client.rs b/tests/realtime/test_client.rs index a4847de1..85d6ccbb 100644 --- a/tests/realtime/test_client.rs +++ b/tests/realtime/test_client.rs @@ -25,10 +25,16 @@ pub(crate) struct TestClient { #[allow(dead_code)] pub handler: Arc, pub api_client: client_api::Client, + device_id: String, } impl TestClient { - pub(crate) async fn new(object_id: &str, collab_type: CollabType) -> Self { + pub(crate) async fn new_with_device_id( + object_id: &str, + device_id: &str, + collab_type: CollabType, + ) -> Self { + let device_id = device_id.to_string(); let mut api_client = client_api_client(); let _guard = REGISTERED_USER_MUTEX.lock().await; @@ -38,7 +44,6 @@ impl TestClient { .await .unwrap(); - let device_id = Uuid::new_v4().to_string(); // Connect to server via websocket let ws_client = WSClient::new(WSClientConfig { buffer_capacity: 100, @@ -67,7 +72,7 @@ impl TestClient { .await .unwrap(); let (sink, stream) = (handler.sink(), handler.stream()); - let origin = CollabOrigin::Client(CollabClient::new(uid, device_id)); + let origin = CollabOrigin::Client(CollabClient::new(uid, device_id.clone())); let collab = Arc::new(MutexCollab::new(origin.clone(), object_id, vec![])); let object = SyncObject::new(object_id, &workspace_id, collab_type); @@ -87,12 +92,26 @@ impl TestClient { origin, collab, handler, + device_id, } } + pub(crate) async fn new(object_id: &str, collab_type: CollabType) -> Self { + let device_id = Uuid::new_v4().to_string(); + Self::new_with_device_id(object_id, &device_id, collab_type).await + } + pub(crate) async fn disconnect(&self) { self.ws_client.disconnect().await; } + + pub(crate) async fn reconnect(&self) { + self + .ws_client + .connect(self.api_client.ws_url(&self.device_id).unwrap()) + .await + .unwrap(); + } } #[allow(dead_code)] @@ -130,9 +149,9 @@ pub async fn assert_collab_json( } tokio::time::sleep(Duration::from_millis(200)).await; }, - Err(_) => { + Err(e) => { if retry_count > 5 { - panic!("Query collab failed"); + panic!("Query collab failed: {}", e); } tokio::time::sleep(Duration::from_millis(200)).await; }