// AppFlowy-Cloud/libs/realtime/src/collaborate/server.rs
// (source viewer metadata: 453 lines, 15 KiB, Rust)

use crate::entities::{
ClientMessage, ClientStreamMessage, Connect, Disconnect, Editing, RealtimeMessage, RealtimeUser,
};
use crate::error::{RealtimeError, StreamError};
use anyhow::Result;
use actix::{Actor, Context, Handler, ResponseFuture};
use futures_util::future::BoxFuture;
use parking_lot::Mutex;
use realtime_entity::collab_msg::CollabMessage;
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::pin::Pin;
use actix::dev::Stream;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::interval;
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
use tokio_stream::StreamExt;
use tracing::{error, event, info, instrument, trace, warn};
use crate::client::ClientWSSink;
use crate::collaborate::group::CollabGroupCache;
use crate::collaborate::permission::CollabAccessControl;
use crate::collaborate::retry::{CollabUserMessage, SubscribeGroupIfNeed};
use crate::util::channel_ext::UnboundedSenderSink;
use database::collab::CollabStorage;
/// Actix actor that tracks connected realtime users and routes their collab messages to the
/// collab groups they edit.
pub struct CollabServer<S, U, AC> {
  /// Storage backend; handed to the group cache in [CollabServer::new] and otherwise only held.
  #[allow(dead_code)]
  storage: Arc<S>,
  /// Keep track of all collab groups
  groups: Arc<CollabGroupCache<S, U, AC>>,
  /// Maps a user's uid to the user value. The `ClientStreamMessage` handler uses it to resolve
  /// the sender of a stream message. Entries are inserted on connect and removed on disconnect.
  user_by_uid: Arc<parking_lot::RwLock<HashMap<i64, U>>>,
  /// This map stores the session IDs for users currently connected to the server.
  /// The user's identifier [U] is used as the key, and their corresponding session ID is the value.
  ///
  /// When a user disconnects, their session ID is retrieved using their user identifier [U].
  /// This session ID is then compared with the session ID provided in the [Disconnect] message.
  /// If the two session IDs differ, it indicates that the user has established a new connection
  /// to the server since the stored session ID was last updated.
  session_id_by_user: Arc<RwLock<HashMap<U, String>>>,
  /// The collabs each user is currently editing. When the user reconnects or disconnects, the
  /// user is removed from the corresponding groups.
  editing_collab_by_user: Arc<Mutex<HashMap<U, HashSet<Editing>>>>,
  /// The active [CollabClientStream] of each connected user.
  client_stream_by_user: Arc<RwLock<HashMap<U, CollabClientStream>>>,
  /// Permission checks used when subscribing users to collab groups.
  access_control: Arc<AC>,
}

// Implemented by hand because `#[derive(Clone)]` would require `S: Clone, U: Clone, AC: Clone`.
// Every field is an `Arc`, so cloning only bumps reference counts and needs no such bounds.
impl<S, U, AC> Clone for CollabServer<S, U, AC> {
  fn clone(&self) -> Self {
    Self {
      storage: Arc::clone(&self.storage),
      groups: Arc::clone(&self.groups),
      user_by_uid: Arc::clone(&self.user_by_uid),
      session_id_by_user: Arc::clone(&self.session_id_by_user),
      editing_collab_by_user: Arc::clone(&self.editing_collab_by_user),
      client_stream_by_user: Arc::clone(&self.client_stream_by_user),
      access_control: Arc::clone(&self.access_control),
    }
  }
}
impl<S, U, AC> CollabServer<S, U, AC>
where
  S: CollabStorage,
  U: RealtimeUser,
  AC: CollabAccessControl,
{
  /// Creates a collab server backed by `storage` and guarded by `access_control`.
  ///
  /// Also spawns a background task on the current tokio runtime. The task calls `tick` on the
  /// collab group cache every 60 seconds. It only keeps a weak reference to the cache and exits
  /// once the cache has been dropped.
  ///
  /// # Errors
  /// Never fails today; the `Result` is part of the public signature.
  pub fn new(storage: Arc<S>, access_control: AC) -> Result<Self, RealtimeError> {
    let access_control = Arc::new(access_control);
    let groups = Arc::new(CollabGroupCache::new(
      storage.clone(),
      access_control.clone(),
    ));
    let edit_collab_by_user = Arc::new(Mutex::new(HashMap::new()));

    // Periodically check the collab groups. A weak reference is used so that this task does not
    // keep the group cache alive after the server is gone.
    let weak_group = Arc::downgrade(&groups);
    tokio::spawn(async move {
      let mut interval = interval(Duration::from_secs(60));
      loop {
        interval.tick().await;
        match weak_group.upgrade() {
          Some(groups) => groups.tick().await,
          None => break,
        }
      }
    });

    Ok(Self {
      storage,
      groups,
      user_by_uid: Default::default(),
      session_id_by_user: Default::default(),
      editing_collab_by_user: edit_collab_by_user,
      client_stream_by_user: Default::default(),
      access_control,
    })
  }

  /// Processes every [RealtimeMessage] yielded by `message_stream` on behalf of `user`.
  ///
  /// For each [RealtimeMessage::Collab], the user is first subscribed to the message's collab
  /// group if needed. The message is then pushed into the user's [CollabClientStream]. Any other
  /// message kind is logged with a warning and ignored.
  ///
  /// The returned future completes once the stream is exhausted.
  ///
  /// # Errors
  /// Returns the first error produced while subscribing the user to a group. Messages that
  /// follow the failing one are not processed.
  fn process_realtime_message<MS>(
    &self,
    user: U,
    mut message_stream: MS,
  ) -> Pin<Box<impl Future<Output = Result<(), RealtimeError>>>>
  where
    MS: Stream<Item = RealtimeMessage> + Unpin + Send,
  {
    let client_stream_by_user = self.client_stream_by_user.clone();
    let groups = self.groups.clone();
    let edit_collab_by_user = self.editing_collab_by_user.clone();
    let access_control = self.access_control.clone();
    Box::pin(async move {
      // Drain the whole stream: handling only the first item would silently drop every other
      // message of a multi-message stream.
      while let Some(realtime_msg) = message_stream.next().await {
        trace!("Receive client:{} message:{}", user.uid(), realtime_msg);
        match realtime_msg {
          RealtimeMessage::Collab(collab_message) => {
            let msg = CollabUserMessage {
              user: &user,
              collab_message: &collab_message,
            };
            SubscribeGroupIfNeed {
              collab_user_message: &msg,
              groups: &groups,
              edit_collab_by_user: &edit_collab_by_user,
              client_stream_by_user: &client_stream_by_user,
              access_control: &access_control,
            }
            .run()
            .await?;
            broadcast_message(&user, collab_message, &client_stream_by_user).await;
          },
          _ => {
            warn!("Receive unsupported message: {}", realtime_msg);
          },
        }
      }
      Ok(())
    })
  }
}
async fn remove_user<S, U, AC>(
groups: &Arc<CollabGroupCache<S, U, AC>>,
editing_collab_by_user: &Arc<Mutex<HashMap<U, HashSet<Editing>>>>,
user: &U,
) where
S: CollabStorage,
U: RealtimeUser,
AC: CollabAccessControl,
{
let editing_set = editing_collab_by_user
.try_lock()
.and_then(|mut guard| guard.remove(user));
if let Some(editing_set) = editing_set {
for editing in editing_set {
remove_user_from_group(user, groups, &editing).await;
}
}
}
impl<S, U, AC> Actor for CollabServer<S, U, AC>
where
  S: 'static + Unpin,
  U: RealtimeUser + Unpin,
  AC: CollabAccessControl + Unpin,
{
  type Context = Context<Self>;

  /// Bounds the actor's mailbox as soon as the actor starts.
  fn started(&mut self, ctx: &mut Self::Context) {
    // Maximum number of messages that may be queued in the mailbox.
    const MAILBOX_CAPACITY: usize = 100;
    ctx.set_mailbox_capacity(MAILBOX_CAPACITY);
  }
}
impl<S, U, AC> Handler<Connect<U>> for CollabServer<S, U, AC>
where
U: RealtimeUser + Unpin,
S: CollabStorage + Unpin,
AC: CollabAccessControl + Unpin,
{
type Result = ResponseFuture<Result<(), RealtimeError>>;
fn handle(&mut self, new_conn: Connect<U>, _ctx: &mut Context<Self>) -> Self::Result {
// User with the same id and same device will be replaced with the new connection [CollabClientStream]
let client_stream = CollabClientStream::new(ClientWSSink(new_conn.socket));
let groups = self.groups.clone();
let user_by_uid = self.user_by_uid.clone();
let client_stream_by_user = self.client_stream_by_user.clone();
let editing_collab_by_user = self.editing_collab_by_user.clone();
let user_by_session_id = self.session_id_by_user.clone();
Box::pin(async move {
trace!("[realtime]: new connection => {} ", new_conn.user);
user_by_session_id
.write()
.await
.insert(new_conn.user.clone(), new_conn.session_id);
user_by_uid
.write()
.insert(new_conn.user.uid(), new_conn.user.clone());
// when a new connection is established, remove the old connection from all groups
remove_user(&groups, &editing_collab_by_user, &new_conn.user).await;
info!("new client stream:{}", &new_conn.user);
if let Some(old_stream) = client_stream_by_user
.write()
.await
.insert(new_conn.user, client_stream)
{
old_stream.disconnect();
}
Ok(())
})
}
}
impl<S, U, AC> Handler<Disconnect<U>> for CollabServer<S, U, AC>
where
  U: RealtimeUser + Unpin,
  S: CollabStorage + Unpin,
  AC: CollabAccessControl + Unpin,
{
  type Result = ResponseFuture<Result<(), RealtimeError>>;

  /// Handles the disconnection of a user from the collaboration server.
  ///
  /// 1. Compares the session ID stored for the user in `session_id_by_user` with the one carried
  ///    by the [Disconnect] message. If they differ, the user has already reconnected with a new
  ///    session, so the stale disconnect is ignored.
  /// 2. Otherwise, removes the user from all collab groups it was editing.
  /// 3. Removes its client stream. If a stream was present, also removes its `user_by_uid` entry.
  ///
  /// The tokio locks are awaited instead of tried. With `try_read`/`try_write`, a merely
  /// contended lock made the handler return early and leave the user's state behind.
  fn handle(&mut self, msg: Disconnect<U>, _: &mut Context<Self>) -> Self::Result {
    trace!("[realtime]: disconnect => {}", msg.user);
    let groups = self.groups.clone();
    let user_by_uid = self.user_by_uid.clone();
    let client_stream_by_user = self.client_stream_by_user.clone();
    let editing_collab_by_user = self.editing_collab_by_user.clone();
    let session_id_by_user = self.session_id_by_user.clone();
    Box::pin(async move {
      // The read guard is a temporary of this statement, so it is released before the awaits
      // below instead of being held for the whole cleanup.
      let is_stale_session = matches!(
        session_id_by_user.read().await.get(&msg.user),
        Some(session_id) if session_id != &msg.session_id
      );
      if is_stale_session {
        return Ok(());
      }

      remove_user(&groups, &editing_collab_by_user, &msg.user).await;

      let mut client_streams = client_stream_by_user.write().await;
      if client_streams.remove(&msg.user).is_some() {
        user_by_uid.write().remove(&msg.user.uid());
        info!("remove client stream: {}", &msg.user);
      }
      Ok(())
    })
  }
}
impl<S, U, AC> Handler<ClientMessage<U>> for CollabServer<S, U, AC>
where
  U: RealtimeUser + Unpin,
  S: CollabStorage + Unpin,
  AC: CollabAccessControl + Unpin,
{
  type Result = ResponseFuture<Result<(), RealtimeError>>;

  /// Processes a single realtime message sent by a known user.
  ///
  /// The message is wrapped in a one-item stream so that the single-message and stream paths
  /// share the same processing code.
  fn handle(&mut self, client_msg: ClientMessage<U>, _ctx: &mut Context<Self>) -> Self::Result {
    let single_message = tokio_stream::once(client_msg.message);
    self.process_realtime_message(client_msg.user, single_message)
  }
}
impl<S, U, AC> Handler<ClientStreamMessage> for CollabServer<S, U, AC>
where
  U: RealtimeUser + Unpin,
  S: CollabStorage + Unpin,
  AC: CollabAccessControl + Unpin,
{
  type Result = ResponseFuture<Result<(), RealtimeError>>;

  /// Processes a stream of realtime messages sent by the user identified by `uid`.
  ///
  /// # Errors
  /// Resolves to [RealtimeError::UserNotFound] when no connected user has that uid.
  fn handle(&mut self, client_msg: ClientStreamMessage, _ctx: &mut Context<Self>) -> Self::Result {
    let uid = client_msg.uid;
    // The read guard is dropped at the end of this statement, before processing starts.
    let known_user = self.user_by_uid.read().get(&uid).cloned();

    if let Some(user) = known_user {
      return self.process_realtime_message(user, client_msg.stream);
    }

    Box::pin(async move {
      Err(RealtimeError::UserNotFound(format!(
        "Can't find the user with given id: {}",
        uid
      )))
    })
  }
}
#[inline]
async fn broadcast_message<U>(
user: &U,
collab_message: CollabMessage,
client_streams: &Arc<RwLock<HashMap<U, CollabClientStream>>>,
) where
U: RealtimeUser,
{
let client_streams = client_streams.read().await;
if let Some(client_stream) = client_streams.get(user) {
trace!("[realtime]: receives collab message: {}", collab_message);
match client_stream
.stream_tx
.send(Ok(RealtimeMessage::Collab(collab_message)))
{
Ok(_) => {},
Err(e) => error!("send error: {}", e),
}
}
}
/// Removes `user` from the group of `editing.object_id`.
///
/// If that leaves the group empty, the group's collab is flushed and the group is evicted from
/// `groups`. The result of `CollabGroupCache::remove_user` is deliberately ignored.
#[instrument(level = "debug", skip_all)]
async fn remove_user_from_group<S, U, AC>(
  user: &U,
  groups: &Arc<CollabGroupCache<S, U, AC>>,
  editing: &Editing,
) where
  S: CollabStorage,
  U: RealtimeUser,
  AC: CollabAccessControl,
{
  let object_id = &editing.object_id;
  let _ = groups.remove_user(object_id, user).await;

  let group = match groups.get_group(object_id).await {
    Some(group) => group,
    None => return,
  };

  event!(
    tracing::Level::TRACE,
    "{}: Remove group subscriber:{}, Current group member: {}. member ids: {:?}",
    object_id,
    editing.origin,
    group.subscribers.read().await.len(),
    group
      .subscribers
      .read()
      .await
      .values()
      .map(|value| value.origin.to_string())
      .collect::<Vec<_>>(),
  );

  // An empty group has no reason to stay cached: persist its state, then drop it.
  if group.is_empty().await {
    group.flush_collab();
    event!(tracing::Level::INFO, "Remove group: {}", object_id);
    groups.remove_group(object_id).await;
  }
}
impl<S, U, AC> actix::Supervised for CollabServer<S, U, AC>
where
S: 'static + Unpin,
U: RealtimeUser + Unpin,
AC: CollabAccessControl + Unpin,
{
fn restarting(&mut self, _ctx: &mut Context<CollabServer<S, U, AC>>) {
tracing::warn!("restarting");
}
}
pub struct CollabClientStream {
sink: ClientWSSink,
/// Used to receive messages from the collab server. The message will forward to the [CollabBroadcast] which
/// will broadcast the message to all connected clients.
///
/// The message flow:
/// ClientSession(websocket) -> [CollabServer] -> [CollabClientStream] -> [CollabBroadcast] 1->* websocket(client)
pub(crate) stream_tx: tokio::sync::broadcast::Sender<Result<RealtimeMessage, StreamError>>,
}
impl CollabClientStream {
pub fn new(sink: ClientWSSink) -> Self {
// When receive a new connection, create a new [ClientStream] that holds the connection's websocket
let (stream_tx, _) = tokio::sync::broadcast::channel(1000);
Self { sink, stream_tx }
}
/// Returns a [UnboundedSenderSink] and a [ReceiverStream] for the object_id.
#[allow(clippy::type_complexity)]
pub fn client_channel<T, SinkFilter, StreamFilter>(
&mut self,
object_id: &str,
sink_filter: SinkFilter,
stream_filter: StreamFilter,
) -> (
UnboundedSenderSink<T>,
ReceiverStream<Result<CollabMessage, StreamError>>,
)
where
T: Into<RealtimeMessage> + Send + Sync + 'static,
SinkFilter: Fn(&str, &T) -> BoxFuture<'static, bool> + Sync + Send + 'static,
StreamFilter: Fn(&str, &CollabMessage) -> BoxFuture<'static, bool> + Sync + Send + 'static,
{
let client_ws_sink = self.sink.clone();
let mut stream_rx = BroadcastStream::new(self.stream_tx.subscribe());
let cloned_object_id = object_id.to_string();
// Send the message to the connected websocket client
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<T>();
tokio::task::spawn(async move {
while let Some(msg) = rx.recv().await {
let can_sink = sink_filter(&cloned_object_id, &msg).await;
if can_sink {
// Send the message to websocket client actor
client_ws_sink.do_send(msg.into());
}
}
});
let client_forward_sink = UnboundedSenderSink::<T>::new(tx);
// forward the message to the stream that was subscribed by the broadcast group, which will
// send the messages to all connected clients using the client_forward_sink
let cloned_object_id = object_id.to_string();
let (tx, rx) = tokio::sync::mpsc::channel(100);
tokio::spawn(async move {
while let Some(Ok(Ok(RealtimeMessage::Collab(msg)))) = stream_rx.next().await {
if stream_filter(&cloned_object_id, &msg).await {
let _ = tx.send(Ok(msg)).await;
}
}
});
let client_forward_stream = ReceiverStream::new(rx);
// When broadcast group write a message to the client_forward_sink, the message will be forwarded
// to the client's websocket sink, which will then send the message to the connected client
//
// When receiving a message from the client_forward_stream, it will send the message to the broadcast
// group. The message will be broadcast to all connected clients.
(client_forward_sink, client_forward_stream)
}
pub fn disconnect(&self) {
self.sink.do_send(RealtimeMessage::ServerKickedOff);
}
}