use crate::entities::{ ClientMessage, ClientStreamMessage, Connect, Disconnect, Editing, RealtimeMessage, RealtimeUser, }; use crate::error::{RealtimeError, StreamError}; use anyhow::Result; use actix::{Actor, Context, Handler, ResponseFuture}; use futures_util::future::BoxFuture; use parking_lot::Mutex; use realtime_entity::collab_msg::CollabMessage; use std::collections::{HashMap, HashSet}; use std::future::Future; use std::pin::Pin; use actix::dev::Stream; use std::sync::Arc; use std::time::Duration; use tokio::sync::RwLock; use tokio::time::interval; use tokio_stream::wrappers::{BroadcastStream, ReceiverStream}; use tokio_stream::StreamExt; use tracing::{error, event, info, instrument, trace, warn}; use crate::client::ClientWSSink; use crate::collaborate::group::CollabGroupCache; use crate::collaborate::permission::CollabAccessControl; use crate::collaborate::retry::{CollabUserMessage, SubscribeGroupIfNeed}; use crate::util::channel_ext::UnboundedSenderSink; use database::collab::CollabStorage; #[derive(Clone)] pub struct CollabServer { #[allow(dead_code)] storage: Arc, /// Keep track of all collab groups groups: Arc>, /// This map stores the session IDs for users currently connected to the server. /// The user's identifier [U] is used as the key, and their corresponding session ID is the value. /// /// When a user disconnects, their session ID is retrieved using their user identifier [U]. /// This session ID is then compared with the session ID provided in the [Disconnect] message. /// If the two session IDs differ, it indicates that the user has established a new connection /// to the server since the stored session ID was last updated. /// user_by_uid: Arc>>, session_id_by_user: Arc>>, /// Keep track of all object ids that a user is subscribed to editing_collab_by_user: Arc>>>, /// Keep track of all client streams client_stream_by_user: Arc>>, access_control: Arc, } impl CollabServer where S: CollabStorage, U: RealtimeUser, AC: CollabAccessControl, { pub fn new(storage: Arc, access_control: AC) -> Result { let access_control = Arc::new(access_control); let groups = Arc::new(CollabGroupCache::new( storage.clone(), access_control.clone(), )); let edit_collab_by_user = Arc::new(Mutex::new(HashMap::new())); // Periodically check the collab groups let weak_group = Arc::downgrade(&groups); tokio::spawn(async move { let mut interval = interval(Duration::from_secs(60)); loop { interval.tick().await; match weak_group.upgrade() { Some(groups) => groups.tick().await, None => break, } } }); Ok(Self { storage, groups, user_by_uid: Default::default(), session_id_by_user: Default::default(), editing_collab_by_user: edit_collab_by_user, client_stream_by_user: Default::default(), access_control, }) } fn process_realtime_message( &mut self, user: U, mut message_stream: MS, ) -> Pin>>> where MS: Stream + Unpin + Send, { let client_stream_by_user = self.client_stream_by_user.clone(); let groups = self.groups.clone(); let edit_collab_by_user = self.editing_collab_by_user.clone(); let access_control = self.access_control.clone(); Box::pin(async move { match message_stream.next().await { None => Ok(()), Some(realtime_msg) => { trace!("Receive client:{} message:{}", user.uid(), realtime_msg); match realtime_msg { RealtimeMessage::Collab(collab_message) => { let msg = CollabUserMessage { user: &user, collab_message: &collab_message, }; SubscribeGroupIfNeed { collab_user_message: &msg, groups: &groups, edit_collab_by_user: &edit_collab_by_user, client_stream_by_user: &client_stream_by_user, access_control: &access_control, } .run() .await?; broadcast_message(&user, collab_message, &client_stream_by_user).await; Ok(()) }, _ => { warn!("Receive unsupported message: {}", realtime_msg); Ok(()) }, } }, } }) } } async fn remove_user( groups: &Arc>, editing_collab_by_user: &Arc>>>, user: &U, ) where S: CollabStorage, U: RealtimeUser, AC: CollabAccessControl, { let editing_set = editing_collab_by_user .try_lock() .and_then(|mut guard| guard.remove(user)); if let Some(editing_set) = editing_set { for editing in editing_set { remove_user_from_group(user, groups, &editing).await; } } } impl Actor for CollabServer where S: 'static + Unpin, U: RealtimeUser + Unpin, AC: CollabAccessControl + Unpin, { type Context = Context; fn started(&mut self, ctx: &mut Self::Context) { ctx.set_mailbox_capacity(100); } } impl Handler> for CollabServer where U: RealtimeUser + Unpin, S: CollabStorage + Unpin, AC: CollabAccessControl + Unpin, { type Result = ResponseFuture>; fn handle(&mut self, new_conn: Connect, _ctx: &mut Context) -> Self::Result { // User with the same id and same device will be replaced with the new connection [CollabClientStream] let client_stream = CollabClientStream::new(ClientWSSink(new_conn.socket)); let groups = self.groups.clone(); let user_by_uid = self.user_by_uid.clone(); let client_stream_by_user = self.client_stream_by_user.clone(); let editing_collab_by_user = self.editing_collab_by_user.clone(); let user_by_session_id = self.session_id_by_user.clone(); Box::pin(async move { trace!("[realtime]: new connection => {} ", new_conn.user); user_by_session_id .write() .await .insert(new_conn.user.clone(), new_conn.session_id); user_by_uid .write() .insert(new_conn.user.uid(), new_conn.user.clone()); // when a new connection is established, remove the old connection from all groups remove_user(&groups, &editing_collab_by_user, &new_conn.user).await; info!("new client stream:{}", &new_conn.user); if let Some(old_stream) = client_stream_by_user .write() .await .insert(new_conn.user, client_stream) { old_stream.disconnect(); } Ok(()) }) } } impl Handler> for CollabServer where U: RealtimeUser + Unpin, S: CollabStorage + Unpin, AC: CollabAccessControl + Unpin, { type Result = ResponseFuture>; /// Handles the disconnection of a user from the collaboration server. /// /// Upon receiving a `Disconnect` message, the method performs the following actions: /// 1. Attempts to acquire a read lock on `session_id_by_user` to compare the stored session ID /// with the session ID in the `Disconnect` message. /// - If the session IDs match, it proceeds to remove the user from groups and client streams. /// - If the session IDs do not match, indicating the user has reconnected with a new session, /// no action is taken and the function returns. /// 2. Removes the user from the collaboration groups and client streams, if applicable. /// fn handle(&mut self, msg: Disconnect, _: &mut Context) -> Self::Result { trace!("[realtime]: disconnect => {}", msg.user); let groups = self.groups.clone(); let user_by_uid = self.user_by_uid.clone(); let client_stream_by_user = self.client_stream_by_user.clone(); let editing_collab_by_user = self.editing_collab_by_user.clone(); let session_id_by_user = self.session_id_by_user.clone(); Box::pin(async move { let guard = match session_id_by_user.try_read() { Ok(guard) => guard, Err(_) => { return Ok(()); }, }; if let Some(session_id) = guard.get(&msg.user) { if session_id != &msg.session_id { return Ok(()); } } remove_user(&groups, &editing_collab_by_user, &msg.user).await; if let Ok(mut client_stream_by_user) = client_stream_by_user.try_write() { if client_stream_by_user.remove(&msg.user).is_some() { user_by_uid.write().remove(&msg.user.uid()); info!("remove client stream: {}", &msg.user); } } Ok(()) }) } } impl Handler> for CollabServer where U: RealtimeUser + Unpin, S: CollabStorage + Unpin, AC: CollabAccessControl + Unpin, { type Result = ResponseFuture>; fn handle(&mut self, client_msg: ClientMessage, _ctx: &mut Context) -> Self::Result { let ClientMessage { user, message } = client_msg; self.process_realtime_message(user, tokio_stream::once(message)) } } impl Handler for CollabServer where U: RealtimeUser + Unpin, S: CollabStorage + Unpin, AC: CollabAccessControl + Unpin, { type Result = ResponseFuture>; fn handle(&mut self, client_msg: ClientStreamMessage, _ctx: &mut Context) -> Self::Result { let ClientStreamMessage { uid, stream } = client_msg; let user = self.user_by_uid.read().get(&uid).cloned(); match user { None => Box::pin(async move { Err(RealtimeError::UserNotFound(format!( "Can't find the user with given id: {}", uid ))) }), Some(user) => self.process_realtime_message(user, stream), } } } #[inline] async fn broadcast_message( user: &U, collab_message: CollabMessage, client_streams: &Arc>>, ) where U: RealtimeUser, { let client_streams = client_streams.read().await; if let Some(client_stream) = client_streams.get(user) { trace!("[realtime]: receives collab message: {}", collab_message); match client_stream .stream_tx .send(Ok(RealtimeMessage::Collab(collab_message))) { Ok(_) => {}, Err(e) => error!("send error: {}", e), } } } /// Remove the user from the group and remove the group from the cache if the group is empty. #[instrument(level = "debug", skip_all)] async fn remove_user_from_group( user: &U, groups: &Arc>, editing: &Editing, ) where S: CollabStorage, U: RealtimeUser, AC: CollabAccessControl, { let _ = groups.remove_user(&editing.object_id, user).await; if let Some(group) = groups.get_group(&editing.object_id).await { event!( tracing::Level::TRACE, "{}: Remove group subscriber:{}, Current group member: {}. member ids: {:?}", &editing.object_id, editing.origin, group.subscribers.read().await.len(), group .subscribers .read() .await .values() .map(|value| value.origin.to_string()) .collect::>(), ); // Destroy the group if the group is empty let should_remove = group.is_empty().await; if should_remove { group.flush_collab(); event!(tracing::Level::INFO, "Remove group: {}", editing.object_id); groups.remove_group(&editing.object_id).await; } } } impl actix::Supervised for CollabServer where S: 'static + Unpin, U: RealtimeUser + Unpin, AC: CollabAccessControl + Unpin, { fn restarting(&mut self, _ctx: &mut Context>) { tracing::warn!("restarting"); } } pub struct CollabClientStream { sink: ClientWSSink, /// Used to receive messages from the collab server. The message will forward to the [CollabBroadcast] which /// will broadcast the message to all connected clients. /// /// The message flow: /// ClientSession(websocket) -> [CollabServer] -> [CollabClientStream] -> [CollabBroadcast] 1->* websocket(client) pub(crate) stream_tx: tokio::sync::broadcast::Sender>, } impl CollabClientStream { pub fn new(sink: ClientWSSink) -> Self { // When receive a new connection, create a new [ClientStream] that holds the connection's websocket let (stream_tx, _) = tokio::sync::broadcast::channel(1000); Self { sink, stream_tx } } /// Returns a [UnboundedSenderSink] and a [ReceiverStream] for the object_id. #[allow(clippy::type_complexity)] pub fn client_channel( &mut self, object_id: &str, sink_filter: SinkFilter, stream_filter: StreamFilter, ) -> ( UnboundedSenderSink, ReceiverStream>, ) where T: Into + Send + Sync + 'static, SinkFilter: Fn(&str, &T) -> BoxFuture<'static, bool> + Sync + Send + 'static, StreamFilter: Fn(&str, &CollabMessage) -> BoxFuture<'static, bool> + Sync + Send + 'static, { let client_ws_sink = self.sink.clone(); let mut stream_rx = BroadcastStream::new(self.stream_tx.subscribe()); let cloned_object_id = object_id.to_string(); // Send the message to the connected websocket client let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); tokio::task::spawn(async move { while let Some(msg) = rx.recv().await { let can_sink = sink_filter(&cloned_object_id, &msg).await; if can_sink { // Send the message to websocket client actor client_ws_sink.do_send(msg.into()); } } }); let client_forward_sink = UnboundedSenderSink::::new(tx); // forward the message to the stream that was subscribed by the broadcast group, which will // send the messages to all connected clients using the client_forward_sink let cloned_object_id = object_id.to_string(); let (tx, rx) = tokio::sync::mpsc::channel(100); tokio::spawn(async move { while let Some(Ok(Ok(RealtimeMessage::Collab(msg)))) = stream_rx.next().await { if stream_filter(&cloned_object_id, &msg).await { let _ = tx.send(Ok(msg)).await; } } }); let client_forward_stream = ReceiverStream::new(rx); // When broadcast group write a message to the client_forward_sink, the message will be forwarded // to the client's websocket sink, which will then send the message to the connected client // // When receiving a message from the client_forward_stream, it will send the message to the broadcast // group. The message will be broadcast to all connected clients. (client_forward_sink, client_forward_stream) } pub fn disconnect(&self) { self.sink.do_send(RealtimeMessage::ServerKickedOff); } }