use collab_rt_entity::user::{RealtimeUser, UserDevice}; use collab_rt_entity::{RealtimeMessage, SystemMessage}; use dashmap::DashMap; use crate::client_msg_router::ClientMessageRouter; use dashmap::mapref::entry::Entry; use std::sync::Arc; use tracing::{info, trace}; #[derive(Clone, Default)] pub struct ConnectState { pub(crate) user_by_device: Arc>, /// Maintains a record of all client streams. A client stream associated with a user may be terminated for the following reasons: /// 1. User disconnection. /// 2. Server closes the connection due to a ping/pong timeout. pub(crate) client_message_routers: Arc>, } impl ConnectState { pub fn new() -> Self { Self::default() } /// Handles a new user connection, updating the connection state accordingly. /// /// This function checks if there is already a connection from the same user device. If an existing /// connection is found and the new connection is more recent (`connect_at` is greater), the old connection /// is replaced with the new one. This process involves: /// /// - Removing the old user's client stream, if present, and sending a `DuplicateConnection` system message. /// - Inserting the new user connection into the `user_by_device` and `client_stream_by_user` maps. /// pub fn handle_user_connect( &self, new_user: RealtimeUser, client_message_router: ClientMessageRouter, ) -> Option { let user_device = UserDevice::from(&new_user); let entry = self.user_by_device.entry(user_device); match entry { Entry::Occupied(mut e) => { if e.get().connect_at <= new_user.connect_at { let old_user = e.insert(new_user.clone()); trace!("[realtime]: new connection replaces old => {}", new_user); if let Some((_, old_stream)) = self.client_message_routers.remove(&old_user) { info!( "Removing old stream for same user and device: {}", old_user.uid ); old_stream .sink .do_send(RealtimeMessage::System(SystemMessage::DuplicateConnection)); } self .client_message_routers .insert(new_user, client_message_router); Some(old_user) } else { None } }, Entry::Vacant(e) => { trace!("[realtime]: new connection => {}", new_user); e.insert(new_user.clone()); self .client_message_routers .insert(new_user, client_message_router); None }, } } /// Handles the disconnection of a user from the system. /// /// remove a user based on their device and session ID. If the session ID of the disconnecting user matches /// the session ID stored in the system for that device, the user is removed. Additionally, it also /// attempts to remove the associated client stream for the disconnecting user. /// pub fn handle_user_disconnect( &self, disconnect_user: &RealtimeUser, ) -> Option<(UserDevice, RealtimeUser)> { let user_device = UserDevice::from(disconnect_user); let was_removed = self .user_by_device .remove_if(&user_device, |_, existing_user| { existing_user.session_id == disconnect_user.session_id }); if was_removed.is_some() && self .client_message_routers .remove(disconnect_user) .is_some() { info!("remove client stream: {}", &disconnect_user); } was_removed } pub fn number_of_connected_users(&self) -> usize { self.user_by_device.len() } #[allow(dead_code)] fn get_user_by_device(&self, user_device: &UserDevice) -> Option { self.user_by_device.get(user_device).map(|v| v.clone()) } } #[cfg(test)] mod tests { use crate::client_msg_router::{ClientMessageRouter, RealtimeClientWebsocketSink}; use crate::connect_state::ConnectState; use collab_rt_entity::user::{RealtimeUser, UserDevice}; use collab_rt_entity::RealtimeMessage; use std::time::Duration; use tokio::time::sleep; struct MockSink; impl RealtimeClientWebsocketSink for MockSink { fn do_send(&self, _message: RealtimeMessage) {} } fn mock_user(uid: i64, device_id: &str, connect_at: i64) -> RealtimeUser { RealtimeUser::new( uid, device_id.to_string(), uuid::Uuid::new_v4().to_string(), connect_at, ) } fn mock_stream() -> ClientMessageRouter { ClientMessageRouter::new(MockSink) } #[tokio::test] async fn same_user_different_device_connect_test() { let connect_state = ConnectState::new(); let user_device_a = mock_user(1, "device_a", 1); let user_device_b = mock_user(1, "device_b", 1); connect_state.handle_user_connect(user_device_a, mock_stream()); connect_state.handle_user_connect(user_device_b, mock_stream()); assert_eq!(connect_state.user_by_device.len(), 2); } #[tokio::test] async fn same_user_same_device_connect_test() { let connect_state = ConnectState::new(); let user_device_a = mock_user(1, "device_a", 1); let user_device_b = mock_user(1, "device_a", 1); connect_state.handle_user_connect(user_device_a, mock_stream()); connect_state.handle_user_connect(user_device_b.clone(), mock_stream()); assert_eq!(connect_state.user_by_device.len(), 1); let user = connect_state .get_user_by_device(&UserDevice::from(&user_device_b)) .unwrap(); assert_eq!(user, user_device_b); } #[tokio::test] async fn multiple_same_devices_connect_test() { let connect_state = ConnectState::new(); let mut handles = vec![]; for i in 0..1000 { let cloned_connect_state = connect_state.clone(); let handle = tokio::spawn(async move { let random_seconds = rand::random::() % 500; sleep(Duration::from_millis(random_seconds)).await; let user = mock_user(1, "device_a", i); cloned_connect_state.handle_user_connect(user, mock_stream()); }); handles.push(handle); } let _ = futures::future::join_all(handles).await; let user = connect_state .get_user_by_device(&UserDevice::new("device_a", 1)) .unwrap(); assert_eq!(connect_state.user_by_device.len(), 1); assert_eq!(connect_state.client_message_routers.len(), 1); assert_eq!(user.connect_at, 999); } #[tokio::test] async fn multiple_same_devices_connect_disconnect_test() { let connect_state = ConnectState::new(); let mut handles = vec![]; for i in 0..2000 { let should_disconnect = i % 2 == 0; let cloned_connect_state = connect_state.clone(); let handle = tokio::spawn(async move { let random_seconds = rand::random::() % 500; sleep(Duration::from_millis(random_seconds)).await; let user = mock_user(1, "device_a", i); if should_disconnect { cloned_connect_state.handle_user_disconnect(&user); } else { cloned_connect_state.handle_user_connect(user, mock_stream()); } }); handles.push(handle); } let _ = futures::future::join_all(handles).await; let user = connect_state .get_user_by_device(&UserDevice::new("device_a", 1)) .unwrap(); assert_eq!(connect_state.user_by_device.len(), 1); assert_eq!(connect_state.client_message_routers.len(), 1); assert_eq!(user.connect_at, 1999); } #[tokio::test] async fn multiple_devices_connect_test() { let user_a = vec![ mock_user(1, "device_a", 1), mock_user(1, "device_b", 2), mock_user(1, "device_c", 1), mock_user(1, "device_d", 1), ]; let user_b = vec![ mock_user(2, "device_a", 1), // device_b starts two connections, last connection should be kept. mock_user(2, "device_b", 1), mock_user(2, "device_b", 2), mock_user(2, "device_a", 1), ]; let connect_state = ConnectState::new(); let (tx, rx_1) = tokio::sync::oneshot::channel(); let cloned_connect_state = connect_state.clone(); tokio::spawn(async move { for user in user_a { cloned_connect_state.handle_user_connect(user, mock_stream()); } tx.send(()).unwrap(); }); let (tx, rx_2) = tokio::sync::oneshot::channel(); let cloned_connect_state = connect_state.clone(); let clone_user_b = user_b.clone(); tokio::spawn(async move { for user in clone_user_b { cloned_connect_state.handle_user_connect(user, mock_stream()); } tx.send(()).unwrap(); }); let _ = futures::future::join(rx_1, rx_2).await; assert_eq!(connect_state.user_by_device.len(), 6); // device_b with connect_at 2 should be kept. let user = connect_state .get_user_by_device(&UserDevice::from(&user_b[2])) .unwrap(); assert_eq!(user.connect_at, 2); } }