chore: add test for connect state (#414)

* chore: add test

* chore: add test

* chore: disable redis test
This commit is contained in:
Nathan.fooo 2024-03-24 10:35:26 +08:00 committed by GitHub
parent 4878d51c1b
commit d0c0d7832c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 134 additions and 49 deletions

1
Cargo.lock generated
View File

@ -1509,6 +1509,7 @@ dependencies = [
"md5",
"parking_lot 0.12.1",
"prometheus-client",
"rand 0.8.5",
"redis 0.25.2",
"semver",
"serde",

View File

@ -60,6 +60,8 @@ pub struct RealtimeUser {
pub uid: i64,
pub device_id: String,
pub connect_at: i64,
/// Represents the websocket connection session id.
/// When each websocket connection is established, a unique session id is generated.
pub session_id: String,
}

View File

@ -37,4 +37,7 @@ prometheus-client = "0.22.1"
indexmap = "2.2.5"
semver = "1.0.22"
redis = "0.25.2"
parking_lot = "0.12.1"
parking_lot = "0.12.1"
[dev-dependencies]
rand = "0.8.5"

View File

@ -3,6 +3,7 @@ use collab_rt_entity::message::{RealtimeMessage, SystemMessage};
use collab_rt_entity::user::{RealtimeUser, UserDevice};
use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use std::sync::Arc;
use tracing::{info, trace};
@ -19,37 +20,58 @@ impl ConnectState {
pub fn new() -> Self {
Self::default()
}
/// Handles a new user connection, updating the connection state accordingly.
///
/// This function checks if there is already a connection from the same user device. If an existing
/// connection is found and the new connection is more recent (`connect_at` is greater), the old connection
/// is replaced with the new one. This process involves:
///
/// - Removing the old user's client stream, if present, and sending a `DuplicateConnection` system message.
/// - Inserting the new user connection into the `user_by_device` and `client_stream_by_user` maps.
///
pub fn handle_user_connect(
&self,
new_user: RealtimeUser,
client_stream: CollabClientStream,
) -> Option<RealtimeUser> {
let old_user = self
.user_by_device
.insert(UserDevice::from(&new_user), new_user.clone());
trace!(
"[realtime]: new connection => {}, removing old: {:?}",
new_user,
old_user
);
if let Some(old_user) = &old_user {
// Remove and retrieve the old client stream if it exists.
if let Some((_, client_stream)) = self.client_stream_by_user.remove(old_user) {
info!("Removing old stream for same user and device: {}", old_user);
// Notify the old stream of the duplicate connection.
client_stream
.sink
.do_send(RealtimeMessage::System(SystemMessage::DuplicateConnection));
}
// Remove the old user from all collaboration groups.
let user_device = UserDevice::from(&new_user);
let entry = self.user_by_device.entry(user_device);
match entry {
Entry::Occupied(mut e) => {
if e.get().connect_at <= new_user.connect_at {
let old_user = e.insert(new_user.clone());
trace!("[realtime]: new connection replaces old => {}", new_user);
if let Some((_, old_stream)) = self.client_stream_by_user.remove(&old_user) {
info!(
"Removing old stream for same user and device: {}",
old_user.uid
);
old_stream
.sink
.do_send(RealtimeMessage::System(SystemMessage::DuplicateConnection));
}
self.client_stream_by_user.insert(new_user, client_stream);
Some(old_user)
} else {
None
}
},
Entry::Vacant(e) => {
trace!("[realtime]: new connection => {}", new_user);
e.insert(new_user.clone());
self.client_stream_by_user.insert(new_user, client_stream);
None
},
}
self.client_stream_by_user.insert(new_user, client_stream);
old_user
}
/// Handles the disconnection of a user from the system.
///
/// remove a user based on their device and session ID. If the session ID of the disconnecting user matches
/// the session ID stored in the system for that device, the user is removed. Additionally, it also
/// attempts to remove the associated client stream for the disconnecting user.
///
pub fn handle_user_disconnect(
&self,
disconnect_user: &RealtimeUser,
@ -68,11 +90,6 @@ impl ConnectState {
was_removed
}
#[allow(dead_code)]
fn num_connected_users(&self) -> usize {
self.user_by_device.len()
}
#[allow(dead_code)]
fn get_user_by_device(&self, user_device: &UserDevice) -> Option<RealtimeUser> {
self.user_by_device.get(user_device).map(|v| v.clone())
@ -85,6 +102,8 @@ mod tests {
use crate::{CollabClientStream, RealtimeClientWebsocketSink};
use collab_rt_entity::message::RealtimeMessage;
use collab_rt_entity::user::{RealtimeUser, UserDevice};
use std::time::Duration;
use tokio::time::sleep;
struct MockSink;
@ -92,12 +111,12 @@ mod tests {
fn do_send(&self, _message: RealtimeMessage) {}
}
fn mock_user(uid: i64, device_id: &str) -> RealtimeUser {
fn mock_user(uid: i64, device_id: &str, connect_at: i64) -> RealtimeUser {
RealtimeUser::new(
uid,
device_id.to_string(),
uuid::Uuid::new_v4().to_string(),
chrono::Utc::now().timestamp(),
connect_at,
)
}
@ -108,43 +127,99 @@ mod tests {
#[tokio::test]
async fn same_user_different_device_connect_test() {
let connect_state = ConnectState::new();
let user_device_a = mock_user(1, "device_a");
let user_device_b = mock_user(1, "device_b");
let user_device_a = mock_user(1, "device_a", 1);
let user_device_b = mock_user(1, "device_b", 1);
connect_state.handle_user_connect(user_device_a, mock_stream());
connect_state.handle_user_connect(user_device_b, mock_stream());
assert_eq!(connect_state.num_connected_users(), 2);
assert_eq!(connect_state.user_by_device.len(), 2);
}
#[tokio::test]
async fn same_user_same_device_connect_test() {
let connect_state = ConnectState::new();
let user_device_a = mock_user(1, "device_a");
let user_device_b = mock_user(1, "device_a");
let user_device_a = mock_user(1, "device_a", 1);
let user_device_b = mock_user(1, "device_a", 1);
connect_state.handle_user_connect(user_device_a, mock_stream());
connect_state.handle_user_connect(user_device_b.clone(), mock_stream());
assert_eq!(connect_state.num_connected_users(), 1);
assert_eq!(connect_state.user_by_device.len(), 1);
let user = connect_state
.get_user_by_device(&UserDevice::from(&user_device_b))
.unwrap();
assert_eq!(user, user_device_b);
}
#[tokio::test]
async fn multiple_same_devices_connect_test() {
let connect_state = ConnectState::new();
let mut handles = vec![];
for i in 0..1000 {
let cloned_connect_state = connect_state.clone();
let handle = tokio::spawn(async move {
let random_seconds = rand::random::<u64>() % 500;
sleep(Duration::from_millis(random_seconds)).await;
let user = mock_user(1, "device_a", i);
cloned_connect_state.handle_user_connect(user, mock_stream());
});
handles.push(handle);
}
let _ = futures::future::join_all(handles).await;
let user = connect_state
.get_user_by_device(&UserDevice::new("device_a", 1))
.unwrap();
assert_eq!(connect_state.user_by_device.len(), 1);
assert_eq!(connect_state.client_stream_by_user.len(), 1);
assert_eq!(user.connect_at, 999);
}
#[tokio::test]
async fn multiple_same_devices_connect_disconnect_test() {
let connect_state = ConnectState::new();
let mut handles = vec![];
for i in 0..2000 {
let should_disconnect = i % 2 == 0;
let cloned_connect_state = connect_state.clone();
let handle = tokio::spawn(async move {
let random_seconds = rand::random::<u64>() % 500;
sleep(Duration::from_millis(random_seconds)).await;
let user = mock_user(1, "device_a", i);
if should_disconnect {
cloned_connect_state.handle_user_disconnect(&user);
} else {
cloned_connect_state.handle_user_connect(user, mock_stream());
}
});
handles.push(handle);
}
let _ = futures::future::join_all(handles).await;
let user = connect_state
.get_user_by_device(&UserDevice::new("device_a", 1))
.unwrap();
assert_eq!(connect_state.user_by_device.len(), 1);
assert_eq!(connect_state.client_stream_by_user.len(), 1);
assert_eq!(user.connect_at, 1999);
}
#[tokio::test]
async fn multiple_devices_connect_test() {
let user_a = vec![
mock_user(1, "device_a"),
mock_user(1, "device_b"),
mock_user(1, "device_c"),
mock_user(1, "device_d"),
mock_user(1, "device_a", 1),
mock_user(1, "device_b", 2),
mock_user(1, "device_c", 1),
mock_user(1, "device_d", 1),
];
let user_b = vec![
mock_user(2, "device_a"),
mock_user(2, "device_b"),
mock_user(2, "device_b"),
mock_user(2, "device_a"),
mock_user(2, "device_a", 1),
// device_b starts two connections, last connection should be kept.
mock_user(2, "device_b", 1),
mock_user(2, "device_b", 2),
mock_user(2, "device_a", 1),
];
let connect_state = ConnectState::new();
@ -160,15 +235,21 @@ mod tests {
let (tx, rx_2) = tokio::sync::oneshot::channel();
let cloned_connect_state = connect_state.clone();
let clone_user_b = user_b.clone();
tokio::spawn(async move {
for user in user_b {
for user in clone_user_b {
cloned_connect_state.handle_user_connect(user, mock_stream());
}
tx.send(()).unwrap();
});
let _ = futures::future::join(rx_1, rx_2).await;
assert_eq!(connect_state.user_by_device.len(), 6);
assert_eq!(connect_state.num_connected_users(), 6);
// device_b with connect_at 2 should be kept.
let user = connect_state
.get_user_by_device(&UserDevice::from(&user_b[2]))
.unwrap();
assert_eq!(user.connect_at, 2);
}
}

View File

@ -1,12 +1,10 @@
use crate::collab::util::test_encode_collab_v1;
use app_error::ErrorCode;
use client_api_test_util::*;
use collab_entity::CollabType;
use database_entity::dto::{
CreateCollabParams, DeleteCollabParams, QueryCollab, QueryCollabParams, QueryCollabResult,
};
use sqlx::types::Uuid;
use std::collections::HashMap;