From d0c0d7832c1ab00adf3f505a5af3e7914a2b208e Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Sun, 24 Mar 2024 10:35:26 +0800 Subject: [PATCH] chore: add test for connect state (#414) * chore: add test * chore: add test * chore: disable redis test --- Cargo.lock | 1 + libs/collab-rt-entity/src/user.rs | 2 + libs/collab-rt/Cargo.toml | 5 +- libs/collab-rt/src/connect_state.rs | 173 ++++++++++++++++++++-------- tests/collab/storage_test.rs | 2 - 5 files changed, 134 insertions(+), 49 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e1fc3047..86db8e76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1509,6 +1509,7 @@ dependencies = [ "md5", "parking_lot 0.12.1", "prometheus-client", + "rand 0.8.5", "redis 0.25.2", "semver", "serde", diff --git a/libs/collab-rt-entity/src/user.rs b/libs/collab-rt-entity/src/user.rs index 345c5214..833d51c7 100644 --- a/libs/collab-rt-entity/src/user.rs +++ b/libs/collab-rt-entity/src/user.rs @@ -60,6 +60,8 @@ pub struct RealtimeUser { pub uid: i64, pub device_id: String, pub connect_at: i64, + /// Represents the websocket connection session id. + /// When each websocket connection is established, a unique session id is generated. pub session_id: String, } diff --git a/libs/collab-rt/Cargo.toml b/libs/collab-rt/Cargo.toml index 8e4d93dc..2f8cc0dc 100644 --- a/libs/collab-rt/Cargo.toml +++ b/libs/collab-rt/Cargo.toml @@ -37,4 +37,7 @@ prometheus-client = "0.22.1" indexmap = "2.2.5" semver = "1.0.22" redis = "0.25.2" -parking_lot = "0.12.1" \ No newline at end of file +parking_lot = "0.12.1" + +[dev-dependencies] +rand = "0.8.5" \ No newline at end of file diff --git a/libs/collab-rt/src/connect_state.rs b/libs/collab-rt/src/connect_state.rs index eef1fcf7..e0b204de 100644 --- a/libs/collab-rt/src/connect_state.rs +++ b/libs/collab-rt/src/connect_state.rs @@ -3,6 +3,7 @@ use collab_rt_entity::message::{RealtimeMessage, SystemMessage}; use collab_rt_entity::user::{RealtimeUser, UserDevice}; use dashmap::DashMap; +use dashmap::mapref::entry::Entry; use std::sync::Arc; use tracing::{info, trace}; @@ -19,37 +20,58 @@ impl ConnectState { pub fn new() -> Self { Self::default() } + + /// Handles a new user connection, updating the connection state accordingly. + /// + /// This function checks if there is already a connection from the same user device. If an existing + /// connection is found and the new connection is more recent (`connect_at` is greater), the old connection + /// is replaced with the new one. This process involves: + /// + /// - Removing the old user's client stream, if present, and sending a `DuplicateConnection` system message. + /// - Inserting the new user connection into the `user_by_device` and `client_stream_by_user` maps. + /// pub fn handle_user_connect( &self, new_user: RealtimeUser, client_stream: CollabClientStream, ) -> Option { - let old_user = self - .user_by_device - .insert(UserDevice::from(&new_user), new_user.clone()); - - trace!( - "[realtime]: new connection => {}, removing old: {:?}", - new_user, - old_user - ); - - if let Some(old_user) = &old_user { - // Remove and retrieve the old client stream if it exists. - if let Some((_, client_stream)) = self.client_stream_by_user.remove(old_user) { - info!("Removing old stream for same user and device: {}", old_user); - // Notify the old stream of the duplicate connection. - client_stream - .sink - .do_send(RealtimeMessage::System(SystemMessage::DuplicateConnection)); - } - // Remove the old user from all collaboration groups. + let user_device = UserDevice::from(&new_user); + let entry = self.user_by_device.entry(user_device); + match entry { + Entry::Occupied(mut e) => { + if e.get().connect_at <= new_user.connect_at { + let old_user = e.insert(new_user.clone()); + trace!("[realtime]: new connection replaces old => {}", new_user); + if let Some((_, old_stream)) = self.client_stream_by_user.remove(&old_user) { + info!( + "Removing old stream for same user and device: {}", + old_user.uid + ); + old_stream + .sink + .do_send(RealtimeMessage::System(SystemMessage::DuplicateConnection)); + } + self.client_stream_by_user.insert(new_user, client_stream); + Some(old_user) + } else { + None + } + }, + Entry::Vacant(e) => { + trace!("[realtime]: new connection => {}", new_user); + e.insert(new_user.clone()); + self.client_stream_by_user.insert(new_user, client_stream); + None + }, } - self.client_stream_by_user.insert(new_user, client_stream); - - old_user } + /// Handles the disconnection of a user from the system. + /// + /// remove a user based on their device and session ID. If the session ID of the disconnecting user matches + /// the session ID stored in the system for that device, the user is removed. Additionally, it also + /// attempts to remove the associated client stream for the disconnecting user. + /// pub fn handle_user_disconnect( &self, disconnect_user: &RealtimeUser, @@ -68,11 +90,6 @@ impl ConnectState { was_removed } - #[allow(dead_code)] - fn num_connected_users(&self) -> usize { - self.user_by_device.len() - } - #[allow(dead_code)] fn get_user_by_device(&self, user_device: &UserDevice) -> Option { self.user_by_device.get(user_device).map(|v| v.clone()) @@ -85,6 +102,8 @@ mod tests { use crate::{CollabClientStream, RealtimeClientWebsocketSink}; use collab_rt_entity::message::RealtimeMessage; use collab_rt_entity::user::{RealtimeUser, UserDevice}; + use std::time::Duration; + use tokio::time::sleep; struct MockSink; @@ -92,12 +111,12 @@ mod tests { fn do_send(&self, _message: RealtimeMessage) {} } - fn mock_user(uid: i64, device_id: &str) -> RealtimeUser { + fn mock_user(uid: i64, device_id: &str, connect_at: i64) -> RealtimeUser { RealtimeUser::new( uid, device_id.to_string(), uuid::Uuid::new_v4().to_string(), - chrono::Utc::now().timestamp(), + connect_at, ) } @@ -108,43 +127,99 @@ mod tests { #[tokio::test] async fn same_user_different_device_connect_test() { let connect_state = ConnectState::new(); - let user_device_a = mock_user(1, "device_a"); - let user_device_b = mock_user(1, "device_b"); + let user_device_a = mock_user(1, "device_a", 1); + let user_device_b = mock_user(1, "device_b", 1); connect_state.handle_user_connect(user_device_a, mock_stream()); connect_state.handle_user_connect(user_device_b, mock_stream()); - assert_eq!(connect_state.num_connected_users(), 2); + assert_eq!(connect_state.user_by_device.len(), 2); } #[tokio::test] async fn same_user_same_device_connect_test() { let connect_state = ConnectState::new(); - let user_device_a = mock_user(1, "device_a"); - let user_device_b = mock_user(1, "device_a"); + let user_device_a = mock_user(1, "device_a", 1); + let user_device_b = mock_user(1, "device_a", 1); connect_state.handle_user_connect(user_device_a, mock_stream()); connect_state.handle_user_connect(user_device_b.clone(), mock_stream()); - assert_eq!(connect_state.num_connected_users(), 1); + assert_eq!(connect_state.user_by_device.len(), 1); let user = connect_state .get_user_by_device(&UserDevice::from(&user_device_b)) .unwrap(); assert_eq!(user, user_device_b); } + #[tokio::test] + async fn multiple_same_devices_connect_test() { + let connect_state = ConnectState::new(); + let mut handles = vec![]; + for i in 0..1000 { + let cloned_connect_state = connect_state.clone(); + let handle = tokio::spawn(async move { + let random_seconds = rand::random::() % 500; + sleep(Duration::from_millis(random_seconds)).await; + let user = mock_user(1, "device_a", i); + cloned_connect_state.handle_user_connect(user, mock_stream()); + }); + handles.push(handle); + } + + let _ = futures::future::join_all(handles).await; + let user = connect_state + .get_user_by_device(&UserDevice::new("device_a", 1)) + .unwrap(); + assert_eq!(connect_state.user_by_device.len(), 1); + assert_eq!(connect_state.client_stream_by_user.len(), 1); + assert_eq!(user.connect_at, 999); + } + + #[tokio::test] + async fn multiple_same_devices_connect_disconnect_test() { + let connect_state = ConnectState::new(); + let mut handles = vec![]; + for i in 0..2000 { + let should_disconnect = i % 2 == 0; + let cloned_connect_state = connect_state.clone(); + let handle = tokio::spawn(async move { + let random_seconds = rand::random::() % 500; + sleep(Duration::from_millis(random_seconds)).await; + let user = mock_user(1, "device_a", i); + + if should_disconnect { + cloned_connect_state.handle_user_disconnect(&user); + } else { + cloned_connect_state.handle_user_connect(user, mock_stream()); + } + }); + handles.push(handle); + } + + let _ = futures::future::join_all(handles).await; + let user = connect_state + .get_user_by_device(&UserDevice::new("device_a", 1)) + .unwrap(); + + assert_eq!(connect_state.user_by_device.len(), 1); + assert_eq!(connect_state.client_stream_by_user.len(), 1); + assert_eq!(user.connect_at, 1999); + } + #[tokio::test] async fn multiple_devices_connect_test() { let user_a = vec![ - mock_user(1, "device_a"), - mock_user(1, "device_b"), - mock_user(1, "device_c"), - mock_user(1, "device_d"), + mock_user(1, "device_a", 1), + mock_user(1, "device_b", 2), + mock_user(1, "device_c", 1), + mock_user(1, "device_d", 1), ]; let user_b = vec![ - mock_user(2, "device_a"), - mock_user(2, "device_b"), - mock_user(2, "device_b"), - mock_user(2, "device_a"), + mock_user(2, "device_a", 1), + // device_b starts two connections, last connection should be kept. + mock_user(2, "device_b", 1), + mock_user(2, "device_b", 2), + mock_user(2, "device_a", 1), ]; let connect_state = ConnectState::new(); @@ -160,15 +235,21 @@ mod tests { let (tx, rx_2) = tokio::sync::oneshot::channel(); let cloned_connect_state = connect_state.clone(); + let clone_user_b = user_b.clone(); tokio::spawn(async move { - for user in user_b { + for user in clone_user_b { cloned_connect_state.handle_user_connect(user, mock_stream()); } tx.send(()).unwrap(); }); let _ = futures::future::join(rx_1, rx_2).await; + assert_eq!(connect_state.user_by_device.len(), 6); - assert_eq!(connect_state.num_connected_users(), 6); + // device_b with connect_at 2 should be kept. + let user = connect_state + .get_user_by_device(&UserDevice::from(&user_b[2])) + .unwrap(); + assert_eq!(user.connect_at, 2); } } diff --git a/tests/collab/storage_test.rs b/tests/collab/storage_test.rs index 5b8986c1..818e7501 100644 --- a/tests/collab/storage_test.rs +++ b/tests/collab/storage_test.rs @@ -1,12 +1,10 @@ use crate::collab::util::test_encode_collab_v1; use app_error::ErrorCode; use client_api_test_util::*; - use collab_entity::CollabType; use database_entity::dto::{ CreateCollabParams, DeleteCollabParams, QueryCollab, QueryCollabParams, QueryCollabResult, }; - use sqlx::types::Uuid; use std::collections::HashMap;