// AppFlowy-Cloud/libs/collab-rt/src/connect_state.rs

// 269 lines
// 8.7 KiB
// Rust

use collab_rt_entity::user::{RealtimeUser, UserDevice};
use collab_rt_entity::{RealtimeMessage, SystemMessage};
use dashmap::DashMap;
use crate::client_msg_router::ClientMessageRouter;
use dashmap::mapref::entry::Entry;
use std::sync::Arc;
use tracing::{info, trace};
/// Shared registry of the realtime connections currently known to the server.
///
/// Both maps live behind an `Arc`, so cloning a `ConnectState` yields another handle to the
/// same underlying maps rather than an independent copy.
#[derive(Clone, Default)]
pub struct ConnectState {
/// The connection currently registered for each `UserDevice` key. At most one `RealtimeUser`
/// is stored per key; `handle_user_connect` overwrites the entry when a connection with an
/// equal or later `connect_at` arrives for the same key.
pub(crate) user_by_device: Arc<DashMap<UserDevice, RealtimeUser>>,
/// Maintains a record of all client streams, keyed by the `RealtimeUser` that owns the stream.
/// A client stream associated with a user may be terminated for the following reasons:
/// 1. User disconnection.
/// 2. Server closes the connection due to a ping/pong timeout.
pub(crate) client_message_routers: Arc<DashMap<RealtimeUser, ClientMessageRouter>>,
}
impl ConnectState {
  /// Creates an empty connection state with no registered devices or message routers.
  pub fn new() -> Self {
    Self::default()
  }

  /// Handles a new user connection, updating the connection state accordingly.
  ///
  /// The connection is looked up by the `UserDevice` derived from `new_user`:
  ///
  /// - If no connection is registered for that device, `new_user` is stored in `user_by_device`
  ///   and its router is stored in `client_message_routers`.
  /// - If a connection is already registered and `new_user.connect_at` is greater than or equal
  ///   to the registered one, the registered connection is replaced. The old connection's router,
  ///   if still present, is removed and sent a `DuplicateConnection` system message before the
  ///   new router is stored.
  /// - If the registered connection is more recent than `new_user`, nothing is changed and
  ///   `client_message_router` is dropped.
  ///
  /// # Returns
  ///
  /// `Some(old_user)` when an existing connection was replaced. `None` both when this is the
  /// first connection for the device and when `new_user` was ignored because it is older than
  /// the registered connection.
  pub fn handle_user_connect(
    &self,
    new_user: RealtimeUser,
    client_message_router: ClientMessageRouter,
  ) -> Option<RealtimeUser> {
    let user_device = UserDevice::from(&new_user);
    // The entry is held for the whole update so the check on `connect_at` and the replacement
    // are applied to the same `user_by_device` slot.
    match self.user_by_device.entry(user_device) {
      Entry::Occupied(mut occupied) => {
        if occupied.get().connect_at > new_user.connect_at {
          // The registered connection is newer: keep it. Log the rejection so a dropped
          // router can be traced instead of disappearing silently.
          trace!(
            "[realtime]: ignore stale connection => {}, current connection => {}",
            new_user,
            occupied.get()
          );
          return None;
        }

        let old_user = occupied.insert(new_user.clone());
        trace!("[realtime]: new connection replaces old => {}", new_user);
        if let Some((_, old_router)) = self.client_message_routers.remove(&old_user) {
          info!(
            "Removing old stream for same user and device: {}",
            old_user.uid
          );
          // Tell the replaced client why its stream is going away.
          old_router
            .sink
            .do_send(RealtimeMessage::System(SystemMessage::DuplicateConnection));
        }
        self
          .client_message_routers
          .insert(new_user, client_message_router);
        Some(old_user)
      },
      Entry::Vacant(vacant) => {
        trace!("[realtime]: new connection => {}", new_user);
        vacant.insert(new_user.clone());
        self
          .client_message_routers
          .insert(new_user, client_message_router);
        None
      },
    }
  }

  /// Handles the disconnection of a user from the system.
  ///
  /// The user is removed from `user_by_device` only if the session ID registered for that
  /// device matches the session ID of `disconnect_user`; a disconnect coming from a session
  /// that is no longer the registered one therefore leaves the current connection untouched.
  /// When the user was removed, the message router stored for `disconnect_user` is removed too.
  ///
  /// # Returns
  ///
  /// The removed `(UserDevice, RealtimeUser)` pair, or `None` if nothing was removed.
  pub fn handle_user_disconnect(
    &self,
    disconnect_user: &RealtimeUser,
  ) -> Option<(UserDevice, RealtimeUser)> {
    let user_device = UserDevice::from(disconnect_user);
    let removed = self
      .user_by_device
      .remove_if(&user_device, |_, existing_user| {
        existing_user.session_id == disconnect_user.session_id
      })?;

    if self
      .client_message_routers
      .remove(disconnect_user)
      .is_some()
    {
      info!("remove client stream: {}", disconnect_user);
    }
    Some(removed)
  }

  /// Returns the number of entries in `user_by_device`.
  ///
  /// Entries are keyed by `UserDevice`, so the same user connected from two devices is
  /// counted twice.
  pub fn number_of_connected_users(&self) -> usize {
    self.user_by_device.len()
  }

  /// Returns a clone of the connection registered for `user_device`, if any.
  // Within this file the helper is only called from the unit tests, so it is unused in
  // non-test builds; the attribute silences that warning.
  #[allow(dead_code)]
  fn get_user_by_device(&self, user_device: &UserDevice) -> Option<RealtimeUser> {
    self
      .user_by_device
      .get(user_device)
      .map(|entry| entry.value().clone())
  }
}
#[cfg(test)]
mod tests {
  use crate::client_msg_router::{ClientMessageRouter, RealtimeClientWebsocketSink};
  use crate::connect_state::ConnectState;
  use collab_rt_entity::user::{RealtimeUser, UserDevice};
  use collab_rt_entity::RealtimeMessage;
  use std::time::Duration;
  use tokio::time::sleep;

  /// Websocket sink that discards every message sent to it.
  struct MockSink;

  impl RealtimeClientWebsocketSink for MockSink {
    fn do_send(&self, _message: RealtimeMessage) {}
  }

  /// Builds a user with the given uid, device and connect time. Every call gets a freshly
  /// generated session id, so two calls with identical arguments are still distinct sessions.
  fn mock_user(uid: i64, device_id: &str, connect_at: i64) -> RealtimeUser {
    RealtimeUser::new(
      uid,
      device_id.to_string(),
      uuid::Uuid::new_v4().to_string(),
      connect_at,
    )
  }

  /// Builds a message router backed by [`MockSink`].
  fn mock_stream() -> ClientMessageRouter {
    ClientMessageRouter::new(MockSink)
  }

  /// Sleeps for a random delay below 500 ms so that spawned tasks reach the shared state in a
  /// shuffled order.
  async fn random_delay() {
    let delay_ms = rand::random::<u64>() % 500;
    sleep(Duration::from_millis(delay_ms)).await;
  }

  /// Waits for every spawned task and fails the test if any of them panicked, instead of
  /// discarding the join results.
  async fn join_all_tasks(handles: Vec<tokio::task::JoinHandle<()>>) {
    for result in futures::future::join_all(handles).await {
      result.expect("spawned task should complete without panicking");
    }
  }

  #[tokio::test]
  async fn same_user_different_device_connect_test() {
    let connect_state = ConnectState::new();
    let user_device_a = mock_user(1, "device_a", 1);
    let user_device_b = mock_user(1, "device_b", 1);

    connect_state.handle_user_connect(user_device_a, mock_stream());
    connect_state.handle_user_connect(user_device_b, mock_stream());

    // Different devices of the same user are tracked independently.
    assert_eq!(connect_state.user_by_device.len(), 2);
  }

  #[tokio::test]
  async fn same_user_same_device_connect_test() {
    let connect_state = ConnectState::new();
    let first_connection = mock_user(1, "device_a", 1);
    let second_connection = mock_user(1, "device_a", 1);

    connect_state.handle_user_connect(first_connection, mock_stream());
    connect_state.handle_user_connect(second_connection.clone(), mock_stream());

    // Equal `connect_at` values: the connection handled last replaces the earlier one, and the
    // router of the replaced connection is removed along with it.
    assert_eq!(connect_state.user_by_device.len(), 1);
    assert_eq!(connect_state.client_message_routers.len(), 1);
    let user = connect_state
      .get_user_by_device(&UserDevice::from(&second_connection))
      .unwrap();
    assert_eq!(user, second_connection);
  }

  #[tokio::test]
  async fn multiple_same_devices_connect_test() {
    let connect_state = ConnectState::new();
    let handles: Vec<_> = (0..1000_i64)
      .map(|connect_at| {
        let state = connect_state.clone();
        tokio::spawn(async move {
          random_delay().await;
          let user = mock_user(1, "device_a", connect_at);
          state.handle_user_connect(user, mock_stream());
        })
      })
      .collect();
    join_all_tasks(handles).await;

    // Whatever order the tasks ran in, only the connection with the largest `connect_at`
    // and its router may remain.
    let user = connect_state
      .get_user_by_device(&UserDevice::new("device_a", 1))
      .unwrap();
    assert_eq!(connect_state.user_by_device.len(), 1);
    assert_eq!(connect_state.client_message_routers.len(), 1);
    assert_eq!(user.connect_at, 999);
  }

  // NOTE(review): every disconnect below uses a user created by `mock_user`, which carries a
  // freshly generated session id. `handle_user_disconnect` only removes an entry whose session
  // id matches, so these disconnects never remove a registered connection; the test checks that
  // unrelated disconnects do not disturb the state rather than a real connect/disconnect cycle.
  #[tokio::test]
  async fn multiple_same_devices_connect_disconnect_test() {
    let connect_state = ConnectState::new();
    let handles: Vec<_> = (0..2000_i64)
      .map(|connect_at| {
        let should_disconnect = connect_at % 2 == 0;
        let state = connect_state.clone();
        tokio::spawn(async move {
          random_delay().await;
          let user = mock_user(1, "device_a", connect_at);
          if should_disconnect {
            state.handle_user_disconnect(&user);
          } else {
            state.handle_user_connect(user, mock_stream());
          }
        })
      })
      .collect();
    join_all_tasks(handles).await;

    // Only odd values connect, so the newest connection is the one with `connect_at` 1999.
    let user = connect_state
      .get_user_by_device(&UserDevice::new("device_a", 1))
      .unwrap();
    assert_eq!(connect_state.user_by_device.len(), 1);
    assert_eq!(connect_state.client_message_routers.len(), 1);
    assert_eq!(user.connect_at, 1999);
  }

  #[tokio::test]
  async fn multiple_devices_connect_test() {
    let user_a = vec![
      mock_user(1, "device_a", 1),
      mock_user(1, "device_b", 2),
      mock_user(1, "device_c", 1),
      mock_user(1, "device_d", 1),
    ];
    let user_b = vec![
      mock_user(2, "device_a", 1),
      // device_b starts two connections, last connection should be kept.
      mock_user(2, "device_b", 1),
      mock_user(2, "device_b", 2),
      mock_user(2, "device_a", 1),
    ];
    let connect_state = ConnectState::new();

    let state_a = connect_state.clone();
    let task_a = tokio::spawn(async move {
      for user in user_a {
        state_a.handle_user_connect(user, mock_stream());
      }
    });

    let state_b = connect_state.clone();
    let user_b_connections = user_b.clone();
    let task_b = tokio::spawn(async move {
      for user in user_b_connections {
        state_b.handle_user_connect(user, mock_stream());
      }
    });

    join_all_tasks(vec![task_a, task_b]).await;

    // Four devices for user 1 plus two distinct devices for user 2.
    assert_eq!(connect_state.user_by_device.len(), 6);

    // device_b with connect_at 2 should be kept.
    let user = connect_state
      .get_user_by_device(&UserDevice::from(&user_b[2]))
      .unwrap();
    assert_eq!(user.connect_at, 2);
  }
}