chore: separate connect control (#413)
* chore: separate connect control * chore: add tests * chore: add tests
This commit is contained in:
parent
acc13414cf
commit
4878d51c1b
|
|
@ -21,7 +21,7 @@ serde_json.workspace = true
|
|||
thiserror = "1.0.56"
|
||||
anyhow = "1"
|
||||
|
||||
collab = { version = "0.1.0"}
|
||||
collab = { version = "0.1.0", features = ["async-plugin"]}
|
||||
collab-entity = { version = "0.1.0" }
|
||||
collab-folder = { version = "0.1.0" }
|
||||
collab-document = { version = "0.1.0" }
|
||||
|
|
|
|||
|
|
@ -0,0 +1,174 @@
|
|||
use crate::CollabClientStream;
|
||||
use collab_rt_entity::message::{RealtimeMessage, SystemMessage};
|
||||
use collab_rt_entity::user::{RealtimeUser, UserDevice};
|
||||
use dashmap::DashMap;
|
||||
|
||||
use std::sync::Arc;
|
||||
use tracing::{info, trace};
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct ConnectState {
|
||||
pub(crate) user_by_device: Arc<DashMap<UserDevice, RealtimeUser>>,
|
||||
/// Maintains a record of all client streams. A client stream associated with a user may be terminated for the following reasons:
|
||||
/// 1. User disconnection.
|
||||
/// 2. Server closes the connection due to a ping/pong timeout.
|
||||
pub(crate) client_stream_by_user: Arc<DashMap<RealtimeUser, CollabClientStream>>,
|
||||
}
|
||||
|
||||
impl ConnectState {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
pub fn handle_user_connect(
|
||||
&self,
|
||||
new_user: RealtimeUser,
|
||||
client_stream: CollabClientStream,
|
||||
) -> Option<RealtimeUser> {
|
||||
let old_user = self
|
||||
.user_by_device
|
||||
.insert(UserDevice::from(&new_user), new_user.clone());
|
||||
|
||||
trace!(
|
||||
"[realtime]: new connection => {}, removing old: {:?}",
|
||||
new_user,
|
||||
old_user
|
||||
);
|
||||
|
||||
if let Some(old_user) = &old_user {
|
||||
// Remove and retrieve the old client stream if it exists.
|
||||
if let Some((_, client_stream)) = self.client_stream_by_user.remove(old_user) {
|
||||
info!("Removing old stream for same user and device: {}", old_user);
|
||||
// Notify the old stream of the duplicate connection.
|
||||
client_stream
|
||||
.sink
|
||||
.do_send(RealtimeMessage::System(SystemMessage::DuplicateConnection));
|
||||
}
|
||||
// Remove the old user from all collaboration groups.
|
||||
}
|
||||
self.client_stream_by_user.insert(new_user, client_stream);
|
||||
|
||||
old_user
|
||||
}
|
||||
|
||||
pub fn handle_user_disconnect(
|
||||
&self,
|
||||
disconnect_user: &RealtimeUser,
|
||||
) -> Option<(UserDevice, RealtimeUser)> {
|
||||
let user_device = UserDevice::from(disconnect_user);
|
||||
let was_removed = self
|
||||
.user_by_device
|
||||
.remove_if(&user_device, |_, existing_user| {
|
||||
existing_user.session_id == disconnect_user.session_id
|
||||
});
|
||||
|
||||
if was_removed.is_some() && self.client_stream_by_user.remove(disconnect_user).is_some() {
|
||||
info!("remove client stream: {}", &disconnect_user);
|
||||
}
|
||||
|
||||
was_removed
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn num_connected_users(&self) -> usize {
|
||||
self.user_by_device.len()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn get_user_by_device(&self, user_device: &UserDevice) -> Option<RealtimeUser> {
|
||||
self.user_by_device.get(user_device).map(|v| v.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::connect_state::ConnectState;
|
||||
use crate::{CollabClientStream, RealtimeClientWebsocketSink};
|
||||
use collab_rt_entity::message::RealtimeMessage;
|
||||
use collab_rt_entity::user::{RealtimeUser, UserDevice};
|
||||
|
||||
struct MockSink;
|
||||
|
||||
impl RealtimeClientWebsocketSink for MockSink {
|
||||
fn do_send(&self, _message: RealtimeMessage) {}
|
||||
}
|
||||
|
||||
fn mock_user(uid: i64, device_id: &str) -> RealtimeUser {
|
||||
RealtimeUser::new(
|
||||
uid,
|
||||
device_id.to_string(),
|
||||
uuid::Uuid::new_v4().to_string(),
|
||||
chrono::Utc::now().timestamp(),
|
||||
)
|
||||
}
|
||||
|
||||
fn mock_stream() -> CollabClientStream {
|
||||
CollabClientStream::new(MockSink)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn same_user_different_device_connect_test() {
|
||||
let connect_state = ConnectState::new();
|
||||
let user_device_a = mock_user(1, "device_a");
|
||||
let user_device_b = mock_user(1, "device_b");
|
||||
connect_state.handle_user_connect(user_device_a, mock_stream());
|
||||
connect_state.handle_user_connect(user_device_b, mock_stream());
|
||||
|
||||
assert_eq!(connect_state.num_connected_users(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn same_user_same_device_connect_test() {
|
||||
let connect_state = ConnectState::new();
|
||||
let user_device_a = mock_user(1, "device_a");
|
||||
let user_device_b = mock_user(1, "device_a");
|
||||
connect_state.handle_user_connect(user_device_a, mock_stream());
|
||||
connect_state.handle_user_connect(user_device_b.clone(), mock_stream());
|
||||
|
||||
assert_eq!(connect_state.num_connected_users(), 1);
|
||||
let user = connect_state
|
||||
.get_user_by_device(&UserDevice::from(&user_device_b))
|
||||
.unwrap();
|
||||
assert_eq!(user, user_device_b);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn multiple_devices_connect_test() {
|
||||
let user_a = vec![
|
||||
mock_user(1, "device_a"),
|
||||
mock_user(1, "device_b"),
|
||||
mock_user(1, "device_c"),
|
||||
mock_user(1, "device_d"),
|
||||
];
|
||||
|
||||
let user_b = vec![
|
||||
mock_user(2, "device_a"),
|
||||
mock_user(2, "device_b"),
|
||||
mock_user(2, "device_b"),
|
||||
mock_user(2, "device_a"),
|
||||
];
|
||||
|
||||
let connect_state = ConnectState::new();
|
||||
|
||||
let (tx, rx_1) = tokio::sync::oneshot::channel();
|
||||
let cloned_connect_state = connect_state.clone();
|
||||
tokio::spawn(async move {
|
||||
for user in user_a {
|
||||
cloned_connect_state.handle_user_connect(user, mock_stream());
|
||||
}
|
||||
tx.send(()).unwrap();
|
||||
});
|
||||
|
||||
let (tx, rx_2) = tokio::sync::oneshot::channel();
|
||||
let cloned_connect_state = connect_state.clone();
|
||||
tokio::spawn(async move {
|
||||
for user in user_b {
|
||||
cloned_connect_state.handle_user_connect(user, mock_stream());
|
||||
}
|
||||
tx.send(()).unwrap();
|
||||
});
|
||||
|
||||
let _ = futures::future::join(rx_1, rx_2).await;
|
||||
|
||||
assert_eq!(connect_state.num_connected_users(), 6);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,10 +1,12 @@
|
|||
mod collaborate;
|
||||
pub mod command;
|
||||
pub mod connect_state;
|
||||
pub mod error;
|
||||
mod metrics;
|
||||
mod permission;
|
||||
mod rt_server;
|
||||
mod util;
|
||||
|
||||
pub use metrics::*;
|
||||
pub use permission::*;
|
||||
pub use rt_server::*;
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ use crate::{spawn_metrics, CollabRealtimeMetrics, RealtimeAccessControl};
|
|||
|
||||
use anyhow::Result;
|
||||
use collab_rt_entity::collab_msg::{ClientCollabMessage, CollabSinkMessage};
|
||||
use collab_rt_entity::message::{MessageByObjectId, RealtimeMessage, SystemMessage};
|
||||
use collab_rt_entity::message::{MessageByObjectId, RealtimeMessage};
|
||||
use collab_rt_entity::user::{Editing, RealtimeUser, UserDevice};
|
||||
use dashmap::mapref::entry::Entry;
|
||||
use dashmap::DashMap;
|
||||
|
|
@ -15,30 +15,24 @@ use database::collab::CollabStorage;
|
|||
use std::collections::HashSet;
|
||||
use std::future::Future;
|
||||
|
||||
use crate::connect_state::ConnectState;
|
||||
use async_trait::async_trait;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{error, event, info, trace, warn};
|
||||
use tracing::{error, event, trace, warn};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CollabRealtimeServer<S, AC> {
|
||||
#[allow(dead_code)]
|
||||
storage: Arc<S>,
|
||||
/// Keep track of all collab groups
|
||||
groups: Arc<AllGroup<S, AC>>,
|
||||
//
|
||||
pub user_by_device: Arc<DashMap<UserDevice, RealtimeUser>>,
|
||||
/// Keep track of all object ids that a user is subscribed to
|
||||
editing_collab_by_user: Arc<DashMap<RealtimeUser, HashSet<Editing>>>,
|
||||
/// Maintains a record of all client streams. A client stream associated with a user may be terminated for the following reasons:
|
||||
/// 1. User disconnection.
|
||||
/// 2. Server closes the connection due to a ping/pong timeout.
|
||||
client_stream_by_user: Arc<DashMap<RealtimeUser, CollabClientStream>>,
|
||||
connect_state: ConnectState,
|
||||
group_sender_by_object_id: Arc<DashMap<String, GroupCommandSender>>,
|
||||
access_control: Arc<AC>,
|
||||
/// Keep track of all object ids that a user is subscribed to
|
||||
editing_collab_by_user: Arc<DashMap<RealtimeUser, HashSet<Editing>>>,
|
||||
#[allow(dead_code)]
|
||||
metrics: Arc<CollabRealtimeMetrics>,
|
||||
}
|
||||
|
|
@ -54,10 +48,9 @@ where
|
|||
metrics: Arc<CollabRealtimeMetrics>,
|
||||
command_recv: RTCommandReceiver,
|
||||
) -> Result<Self, RealtimeError> {
|
||||
let connect_state = ConnectState::new();
|
||||
let access_control = Arc::new(access_control);
|
||||
let groups = Arc::new(AllGroup::new(storage.clone(), access_control.clone()));
|
||||
let client_stream_by_user: Arc<DashMap<RealtimeUser, CollabClientStream>> = Default::default();
|
||||
let editing_collab_by_user = Default::default();
|
||||
let group_sender_by_object_id: Arc<DashMap<String, GroupCommandSender>> =
|
||||
Arc::new(Default::default());
|
||||
|
||||
|
|
@ -67,18 +60,16 @@ where
|
|||
&group_sender_by_object_id,
|
||||
Arc::downgrade(&groups),
|
||||
&metrics,
|
||||
&client_stream_by_user,
|
||||
&connect_state.client_stream_by_user,
|
||||
&storage,
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
storage,
|
||||
groups,
|
||||
user_by_device: Default::default(),
|
||||
editing_collab_by_user,
|
||||
client_stream_by_user,
|
||||
connect_state,
|
||||
group_sender_by_object_id,
|
||||
access_control,
|
||||
editing_collab_by_user: Arc::new(Default::default()),
|
||||
metrics,
|
||||
})
|
||||
}
|
||||
|
|
@ -97,38 +88,15 @@ where
|
|||
) -> Pin<Box<dyn Future<Output = Result<(), RealtimeError>>>> {
|
||||
let new_client_stream = CollabClientStream::new(conn_sink);
|
||||
let groups = self.groups.clone();
|
||||
let device_by_user = self.user_by_device.clone();
|
||||
let client_stream_by_user = self.client_stream_by_user.clone();
|
||||
let connect_control = self.connect_state.clone();
|
||||
let editing_collab_by_user = self.editing_collab_by_user.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let old_user =
|
||||
device_by_user.insert(UserDevice::from(&connected_user), connected_user.clone());
|
||||
|
||||
trace!(
|
||||
"[realtime]: new connection => {}, removing old: {:?}",
|
||||
connected_user,
|
||||
old_user
|
||||
);
|
||||
|
||||
// If there was a previous connection for the same user, handle cleanup.
|
||||
if let Some(old_user) = old_user {
|
||||
// Remove and retrieve the old client stream if it exists.
|
||||
if let Some((_, client_stream)) = client_stream_by_user.remove(&old_user) {
|
||||
info!(
|
||||
"Removing old stream for same user and device: {}",
|
||||
&old_user
|
||||
);
|
||||
// Notify the old stream of the duplicate connection.
|
||||
client_stream
|
||||
.sink
|
||||
.do_send(RealtimeMessage::System(SystemMessage::DuplicateConnection));
|
||||
}
|
||||
if let Some(old_user) = connect_control.handle_user_connect(connected_user, new_client_stream)
|
||||
{
|
||||
// Remove the old user from all collaboration groups.
|
||||
remove_user_in_groups(&groups, &editing_collab_by_user, &old_user).await;
|
||||
}
|
||||
|
||||
client_stream_by_user.insert(connected_user, new_client_stream);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
|
@ -145,23 +113,14 @@ where
|
|||
disconnect_user: RealtimeUser,
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), RealtimeError>>>> {
|
||||
let groups = self.groups.clone();
|
||||
let client_stream_by_user = self.client_stream_by_user.clone();
|
||||
let connect_control = self.connect_state.clone();
|
||||
let editing_collab_by_user = self.editing_collab_by_user.clone();
|
||||
let device_by_user = self.user_by_device.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let user_device = UserDevice::from(&disconnect_user);
|
||||
let was_removed = device_by_user.remove_if(&user_device, |_, existing_user| {
|
||||
existing_user.session_id == disconnect_user.session_id
|
||||
});
|
||||
|
||||
trace!("[realtime]: disconnect => {}", disconnect_user);
|
||||
let was_removed = connect_control.handle_user_disconnect(&disconnect_user);
|
||||
if was_removed.is_some() {
|
||||
trace!("[realtime]: disconnect => {}", disconnect_user);
|
||||
|
||||
remove_user_in_groups(&groups, &editing_collab_by_user, &disconnect_user).await;
|
||||
if client_stream_by_user.remove(&disconnect_user).is_some() {
|
||||
info!("remove client stream: {}", &disconnect_user);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
@ -175,7 +134,7 @@ where
|
|||
message_by_oid: MessageByObjectId,
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), RealtimeError>>>> {
|
||||
let group_sender_by_object_id = self.group_sender_by_object_id.clone();
|
||||
let client_stream_by_user = self.client_stream_by_user.clone();
|
||||
let client_stream_by_user = self.connect_state.client_stream_by_user.clone();
|
||||
let groups = self.groups.clone();
|
||||
let edit_collab_by_user = self.editing_collab_by_user.clone();
|
||||
let access_control = self.access_control.clone();
|
||||
|
|
@ -223,6 +182,14 @@ where
|
|||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_user_by_device(&self, user_device: &UserDevice) -> Option<RealtimeUser> {
|
||||
self
|
||||
.connect_state
|
||||
.user_by_device
|
||||
.get(user_device)
|
||||
.map(|entry| entry.value().clone())
|
||||
}
|
||||
}
|
||||
|
||||
async fn remove_user_in_groups<S, AC>(
|
||||
|
|
@ -257,7 +224,7 @@ pub trait RealtimeClientWebsocketSink: Send + Sync + 'static {
|
|||
}
|
||||
|
||||
pub struct CollabClientStream {
|
||||
sink: Arc<dyn RealtimeClientWebsocketSink>,
|
||||
pub(crate) sink: Arc<dyn RealtimeClientWebsocketSink>,
|
||||
/// Used to receive messages from the collab server. The message will forward to the [CollabBroadcast] which
|
||||
/// will broadcast the message to all connected clients.
|
||||
///
|
||||
|
|
|
|||
|
|
@ -99,10 +99,7 @@ where
|
|||
message,
|
||||
} = client_msg;
|
||||
|
||||
let user = self
|
||||
.user_by_device
|
||||
.get(&UserDevice::new(&device_id, uid))
|
||||
.map(|entry| entry.value().clone());
|
||||
let user = self.get_user_by_device(&UserDevice::new(&device_id, uid));
|
||||
|
||||
match (user, message.transform()) {
|
||||
(Some(user), Ok(messages)) => self.handle_client_message(user, messages),
|
||||
|
|
|
|||
Loading…
Reference in New Issue