chore: separate connect control (#413)

* chore: separate connect control

* chore: add tests

* chore: add tests
This commit is contained in:
Nathan.fooo 2024-03-24 07:30:21 +08:00 committed by GitHub
parent acc13414cf
commit 4878d51c1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 204 additions and 64 deletions

View File

@ -21,7 +21,7 @@ serde_json.workspace = true
thiserror = "1.0.56"
anyhow = "1"
collab = { version = "0.1.0"}
collab = { version = "0.1.0", features = ["async-plugin"]}
collab-entity = { version = "0.1.0" }
collab-folder = { version = "0.1.0" }
collab-document = { version = "0.1.0" }

View File

@ -0,0 +1,174 @@
use crate::CollabClientStream;
use collab_rt_entity::message::{RealtimeMessage, SystemMessage};
use collab_rt_entity::user::{RealtimeUser, UserDevice};
use dashmap::DashMap;
use std::sync::Arc;
use tracing::{info, trace};
#[derive(Clone, Default)]
pub struct ConnectState {
pub(crate) user_by_device: Arc<DashMap<UserDevice, RealtimeUser>>,
/// Maintains a record of all client streams. A client stream associated with a user may be terminated for the following reasons:
/// 1. User disconnection.
/// 2. Server closes the connection due to a ping/pong timeout.
pub(crate) client_stream_by_user: Arc<DashMap<RealtimeUser, CollabClientStream>>,
}
impl ConnectState {
pub fn new() -> Self {
Self::default()
}
pub fn handle_user_connect(
&self,
new_user: RealtimeUser,
client_stream: CollabClientStream,
) -> Option<RealtimeUser> {
let old_user = self
.user_by_device
.insert(UserDevice::from(&new_user), new_user.clone());
trace!(
"[realtime]: new connection => {}, removing old: {:?}",
new_user,
old_user
);
if let Some(old_user) = &old_user {
// Remove and retrieve the old client stream if it exists.
if let Some((_, client_stream)) = self.client_stream_by_user.remove(old_user) {
info!("Removing old stream for same user and device: {}", old_user);
// Notify the old stream of the duplicate connection.
client_stream
.sink
.do_send(RealtimeMessage::System(SystemMessage::DuplicateConnection));
}
// Remove the old user from all collaboration groups.
}
self.client_stream_by_user.insert(new_user, client_stream);
old_user
}
pub fn handle_user_disconnect(
&self,
disconnect_user: &RealtimeUser,
) -> Option<(UserDevice, RealtimeUser)> {
let user_device = UserDevice::from(disconnect_user);
let was_removed = self
.user_by_device
.remove_if(&user_device, |_, existing_user| {
existing_user.session_id == disconnect_user.session_id
});
if was_removed.is_some() && self.client_stream_by_user.remove(disconnect_user).is_some() {
info!("remove client stream: {}", &disconnect_user);
}
was_removed
}
#[allow(dead_code)]
fn num_connected_users(&self) -> usize {
self.user_by_device.len()
}
#[allow(dead_code)]
fn get_user_by_device(&self, user_device: &UserDevice) -> Option<RealtimeUser> {
self.user_by_device.get(user_device).map(|v| v.clone())
}
}
#[cfg(test)]
mod tests {
use crate::connect_state::ConnectState;
use crate::{CollabClientStream, RealtimeClientWebsocketSink};
use collab_rt_entity::message::RealtimeMessage;
use collab_rt_entity::user::{RealtimeUser, UserDevice};
struct MockSink;
impl RealtimeClientWebsocketSink for MockSink {
fn do_send(&self, _message: RealtimeMessage) {}
}
fn mock_user(uid: i64, device_id: &str) -> RealtimeUser {
RealtimeUser::new(
uid,
device_id.to_string(),
uuid::Uuid::new_v4().to_string(),
chrono::Utc::now().timestamp(),
)
}
fn mock_stream() -> CollabClientStream {
CollabClientStream::new(MockSink)
}
#[tokio::test]
async fn same_user_different_device_connect_test() {
let connect_state = ConnectState::new();
let user_device_a = mock_user(1, "device_a");
let user_device_b = mock_user(1, "device_b");
connect_state.handle_user_connect(user_device_a, mock_stream());
connect_state.handle_user_connect(user_device_b, mock_stream());
assert_eq!(connect_state.num_connected_users(), 2);
}
#[tokio::test]
async fn same_user_same_device_connect_test() {
let connect_state = ConnectState::new();
let user_device_a = mock_user(1, "device_a");
let user_device_b = mock_user(1, "device_a");
connect_state.handle_user_connect(user_device_a, mock_stream());
connect_state.handle_user_connect(user_device_b.clone(), mock_stream());
assert_eq!(connect_state.num_connected_users(), 1);
let user = connect_state
.get_user_by_device(&UserDevice::from(&user_device_b))
.unwrap();
assert_eq!(user, user_device_b);
}
#[tokio::test]
async fn multiple_devices_connect_test() {
let user_a = vec![
mock_user(1, "device_a"),
mock_user(1, "device_b"),
mock_user(1, "device_c"),
mock_user(1, "device_d"),
];
let user_b = vec![
mock_user(2, "device_a"),
mock_user(2, "device_b"),
mock_user(2, "device_b"),
mock_user(2, "device_a"),
];
let connect_state = ConnectState::new();
let (tx, rx_1) = tokio::sync::oneshot::channel();
let cloned_connect_state = connect_state.clone();
tokio::spawn(async move {
for user in user_a {
cloned_connect_state.handle_user_connect(user, mock_stream());
}
tx.send(()).unwrap();
});
let (tx, rx_2) = tokio::sync::oneshot::channel();
let cloned_connect_state = connect_state.clone();
tokio::spawn(async move {
for user in user_b {
cloned_connect_state.handle_user_connect(user, mock_stream());
}
tx.send(()).unwrap();
});
let _ = futures::future::join(rx_1, rx_2).await;
assert_eq!(connect_state.num_connected_users(), 6);
}
}

View File

@ -1,10 +1,12 @@
mod collaborate;
pub mod command;
pub mod connect_state;
pub mod error;
mod metrics;
mod permission;
mod rt_server;
mod util;
pub use metrics::*;
pub use permission::*;
pub use rt_server::*;

View File

@ -7,7 +7,7 @@ use crate::{spawn_metrics, CollabRealtimeMetrics, RealtimeAccessControl};
use anyhow::Result;
use collab_rt_entity::collab_msg::{ClientCollabMessage, CollabSinkMessage};
use collab_rt_entity::message::{MessageByObjectId, RealtimeMessage, SystemMessage};
use collab_rt_entity::message::{MessageByObjectId, RealtimeMessage};
use collab_rt_entity::user::{Editing, RealtimeUser, UserDevice};
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
@ -15,30 +15,24 @@ use database::collab::CollabStorage;
use std::collections::HashSet;
use std::future::Future;
use crate::connect_state::ConnectState;
use async_trait::async_trait;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
use tokio_stream::StreamExt;
use tracing::{error, event, info, trace, warn};
use tracing::{error, event, trace, warn};
#[derive(Clone)]
pub struct CollabRealtimeServer<S, AC> {
#[allow(dead_code)]
storage: Arc<S>,
/// Keep track of all collab groups
groups: Arc<AllGroup<S, AC>>,
//
pub user_by_device: Arc<DashMap<UserDevice, RealtimeUser>>,
/// Keep track of all object ids that a user is subscribed to
editing_collab_by_user: Arc<DashMap<RealtimeUser, HashSet<Editing>>>,
/// Maintains a record of all client streams. A client stream associated with a user may be terminated for the following reasons:
/// 1. User disconnection.
/// 2. Server closes the connection due to a ping/pong timeout.
client_stream_by_user: Arc<DashMap<RealtimeUser, CollabClientStream>>,
connect_state: ConnectState,
group_sender_by_object_id: Arc<DashMap<String, GroupCommandSender>>,
access_control: Arc<AC>,
/// Keep track of all object ids that a user is subscribed to
editing_collab_by_user: Arc<DashMap<RealtimeUser, HashSet<Editing>>>,
#[allow(dead_code)]
metrics: Arc<CollabRealtimeMetrics>,
}
@ -54,10 +48,9 @@ where
metrics: Arc<CollabRealtimeMetrics>,
command_recv: RTCommandReceiver,
) -> Result<Self, RealtimeError> {
let connect_state = ConnectState::new();
let access_control = Arc::new(access_control);
let groups = Arc::new(AllGroup::new(storage.clone(), access_control.clone()));
let client_stream_by_user: Arc<DashMap<RealtimeUser, CollabClientStream>> = Default::default();
let editing_collab_by_user = Default::default();
let group_sender_by_object_id: Arc<DashMap<String, GroupCommandSender>> =
Arc::new(Default::default());
@ -67,18 +60,16 @@ where
&group_sender_by_object_id,
Arc::downgrade(&groups),
&metrics,
&client_stream_by_user,
&connect_state.client_stream_by_user,
&storage,
);
Ok(Self {
storage,
groups,
user_by_device: Default::default(),
editing_collab_by_user,
client_stream_by_user,
connect_state,
group_sender_by_object_id,
access_control,
editing_collab_by_user: Arc::new(Default::default()),
metrics,
})
}
@ -97,38 +88,15 @@ where
) -> Pin<Box<dyn Future<Output = Result<(), RealtimeError>>>> {
let new_client_stream = CollabClientStream::new(conn_sink);
let groups = self.groups.clone();
let device_by_user = self.user_by_device.clone();
let client_stream_by_user = self.client_stream_by_user.clone();
let connect_control = self.connect_state.clone();
let editing_collab_by_user = self.editing_collab_by_user.clone();
Box::pin(async move {
let old_user =
device_by_user.insert(UserDevice::from(&connected_user), connected_user.clone());
trace!(
"[realtime]: new connection => {}, removing old: {:?}",
connected_user,
old_user
);
// If there was a previous connection for the same user, handle cleanup.
if let Some(old_user) = old_user {
// Remove and retrieve the old client stream if it exists.
if let Some((_, client_stream)) = client_stream_by_user.remove(&old_user) {
info!(
"Removing old stream for same user and device: {}",
&old_user
);
// Notify the old stream of the duplicate connection.
client_stream
.sink
.do_send(RealtimeMessage::System(SystemMessage::DuplicateConnection));
}
if let Some(old_user) = connect_control.handle_user_connect(connected_user, new_client_stream)
{
// Remove the old user from all collaboration groups.
remove_user_in_groups(&groups, &editing_collab_by_user, &old_user).await;
}
client_stream_by_user.insert(connected_user, new_client_stream);
Ok(())
})
}
@ -145,23 +113,14 @@ where
disconnect_user: RealtimeUser,
) -> Pin<Box<dyn Future<Output = Result<(), RealtimeError>>>> {
let groups = self.groups.clone();
let client_stream_by_user = self.client_stream_by_user.clone();
let connect_control = self.connect_state.clone();
let editing_collab_by_user = self.editing_collab_by_user.clone();
let device_by_user = self.user_by_device.clone();
Box::pin(async move {
let user_device = UserDevice::from(&disconnect_user);
let was_removed = device_by_user.remove_if(&user_device, |_, existing_user| {
existing_user.session_id == disconnect_user.session_id
});
trace!("[realtime]: disconnect => {}", disconnect_user);
let was_removed = connect_control.handle_user_disconnect(&disconnect_user);
if was_removed.is_some() {
trace!("[realtime]: disconnect => {}", disconnect_user);
remove_user_in_groups(&groups, &editing_collab_by_user, &disconnect_user).await;
if client_stream_by_user.remove(&disconnect_user).is_some() {
info!("remove client stream: {}", &disconnect_user);
}
}
Ok(())
@ -175,7 +134,7 @@ where
message_by_oid: MessageByObjectId,
) -> Pin<Box<dyn Future<Output = Result<(), RealtimeError>>>> {
let group_sender_by_object_id = self.group_sender_by_object_id.clone();
let client_stream_by_user = self.client_stream_by_user.clone();
let client_stream_by_user = self.connect_state.client_stream_by_user.clone();
let groups = self.groups.clone();
let edit_collab_by_user = self.editing_collab_by_user.clone();
let access_control = self.access_control.clone();
@ -223,6 +182,14 @@ where
Ok(())
})
}
pub fn get_user_by_device(&self, user_device: &UserDevice) -> Option<RealtimeUser> {
self
.connect_state
.user_by_device
.get(user_device)
.map(|entry| entry.value().clone())
}
}
async fn remove_user_in_groups<S, AC>(
@ -257,7 +224,7 @@ pub trait RealtimeClientWebsocketSink: Send + Sync + 'static {
}
pub struct CollabClientStream {
sink: Arc<dyn RealtimeClientWebsocketSink>,
pub(crate) sink: Arc<dyn RealtimeClientWebsocketSink>,
/// Used to receive messages from the collab server. The message will forward to the [CollabBroadcast] which
/// will broadcast the message to all connected clients.
///

View File

@ -99,10 +99,7 @@ where
message,
} = client_msg;
let user = self
.user_by_device
.get(&UserDevice::new(&device_id, uid))
.map(|entry| entry.value().clone());
let user = self.get_user_by_device(&UserDevice::new(&device_id, uid));
match (user, message.transform()) {
(Some(user), Ok(messages)) => self.handle_client_message(user, messages),