fix: peer sync (#408)

This commit is contained in:
Nathan.fooo 2024-03-22 21:32:41 +08:00 committed by GitHub
parent c85383b21d
commit c015ee7c7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 135 additions and 89 deletions

View File

@ -14,7 +14,7 @@ use collab_rt_entity::collab_msg::{
AckCode, BroadcastSync, ClientCollabMessage, InitSync, ServerCollabMessage, ServerInit,
UpdateSync,
};
use collab_rt_protocol::{handle_collab_message, ClientSyncProtocol, CollabSyncProtocol};
use collab_rt_protocol::{handle_message, ClientSyncProtocol, CollabSyncProtocol};
use collab_rt_protocol::{Message, MessageReader, SyncMessage};
use futures_util::{SinkExt, StreamExt};
use std::marker::PhantomData;
@ -412,8 +412,7 @@ where
for msg in reader {
let msg = msg?;
let is_server_sync_step_1 = matches!(msg, Message::Sync(SyncMessage::SyncStep1(_)));
if let Some(payload) = handle_collab_message(origin, &ClientSyncProtocol, &mut collab, msg)?
{
if let Some(payload) = handle_message(origin, &ClientSyncProtocol, &mut collab, msg)? {
let object_id = object_id.to_string();
sink.queue_msg(|msg_id| {
if is_server_sync_step_1 {

View File

@ -64,8 +64,11 @@ pub trait RealtimeUser:
{
fn uid(&self) -> i64;
fn device_id(&self) -> &str;
fn connect_at(&self) -> i64;
fn user_device(&self) -> String {
format!("{}-{}", self.uid(), self.device_id())
format!("{}:{}", self.uid(), self.device_id())
}
}
@ -80,4 +83,8 @@ where
/// Returns the device id of the wrapped user.
///
/// Pure delegation: the value is read from whatever `self.as_ref()` yields, so this
/// wrapper adds no behavior of its own.
fn device_id(&self) -> &str {
self.as_ref().device_id()
}
/// Returns the wrapped user's `connect_at` value.
///
/// Pure delegation to the value obtained from `self.as_ref()`, mirroring `device_id`.
fn connect_at(&self) -> i64 {
self.as_ref().connect_at()
}
}

View File

@ -138,7 +138,7 @@ pub trait CollabSyncProtocol {
}
/// Handles incoming messages from the client/server
pub fn handle_collab_message<P: CollabSyncProtocol>(
pub fn handle_message<P: CollabSyncProtocol>(
origin: &CollabOrigin,
protocol: &P,
collab: &mut Collab,

View File

@ -66,7 +66,7 @@ where
}
pub async fn remove_user(&self, user: &U) {
trace!("remove subscribe: {}", user);
trace!("{} remove subscriber from group: {}", self.object_id, user);
if let Some((_, mut old_sub)) = self.subscribers.remove(user) {
old_sub.stop().await;
}
@ -142,32 +142,41 @@ where
Stream: StreamExt<Item = MessageByObjectId> + Send + Sync + Unpin + 'static,
<Sink as futures_util::Sink<CollabMessage>>::Error: std::error::Error + Send + Sync,
{
trace!(
"[realtime]: {} new subscriber: {}, connected members: {}",
self.object_id,
user.uid(),
self.subscribers.len(),
);
let sub =
self
.broadcast
.subscribe(subscriber_origin, sink, stream, Rc::downgrade(&self.collab));
// Remove the old user if it exists
let user_device = user.user_device();
if let Some((_, old)) = self.user_by_user_device.remove(&user_device) {
trace!("remove subscriber: {}", old);
trace!(
"{} remove subscriber when resubscribing: {}",
self.object_id,
old
);
if let Some((_, mut old_sub)) = self.subscribers.remove(&old) {
old_sub.stop().await;
}
}
// create new subscription for new subscriber
let sub = self.broadcast.subscribe(
user,
subscriber_origin,
sink,
stream,
Rc::downgrade(&self.collab),
);
// insert the device for given user
self
.user_by_user_device
.insert(user_device, (*user).clone());
trace!("insert subscriber: {}", user);
self.subscribers.insert((*user).clone(), sub);
trace!(
"[realtime]:{} new subscriber:{}, connect at:{}, connected members: {}",
self.object_id,
user.user_device(),
user.connect_at(),
self.subscribers.len(),
);
}
/// Check if the group is active. A group is considered active if it has at least one

View File

@ -3,7 +3,7 @@ use std::rc::{Rc, Weak};
use collab::core::origin::CollabOrigin;
use collab::preclude::Collab;
use collab_rt_protocol::{handle_collab_message, Error};
use collab_rt_protocol::{handle_message, Error};
use collab_rt_protocol::{Message, MessageReader, MSG_SYNC, MSG_SYNC_UPDATE};
use futures_util::{SinkExt, StreamExt};
use std::sync::atomic::{AtomicU32, Ordering};
@ -22,6 +22,7 @@ use collab_rt_entity::collab_msg::{
AckCode, AwarenessSync, BroadcastSync, ClientCollabMessage, CollabAck, CollabMessage,
};
use collab_rt_entity::message::MessageByObjectId;
use collab_rt_entity::user::RealtimeUser;
use tracing::{error, trace, warn};
use yrs::encoding::write::Write;
@ -158,14 +159,16 @@ impl CollabBroadcast {
/// A `Subscription` instance that represents the active subscription. Dropping this structure or
/// calling its `stop` method will unsubscribe the connection and cease all related activities.
///
pub fn subscribe<Sink, Stream>(
pub fn subscribe<Sink, Stream, U>(
&self,
user: &U,
subscriber_origin: CollabOrigin,
mut sink: Sink,
mut stream: Stream,
collab: Weak<Mutex<Collab>>,
) -> Subscription
where
U: RealtimeUser,
Sink: SinkExt<CollabMessage> + Clone + Send + Sync + Unpin + 'static,
Stream: StreamExt<Item = MessageByObjectId> + Send + Sync + Unpin + 'static,
<Sink as futures_util::Sink<CollabMessage>>::Error: std::error::Error + Send + Sync,
@ -203,6 +206,7 @@ impl CollabBroadcast {
stop_tx
};
let user = user.clone();
let stream_stop_tx = {
let (stream_stop_tx, mut stop_rx) = tokio::sync::mpsc::channel::<()>(1);
let object_id = self.object_id.clone();
@ -213,16 +217,23 @@ impl CollabBroadcast {
tokio::task::spawn_local(async move {
loop {
select! {
_ = stop_rx.recv() => break,
_ = stop_rx.recv() => {
trace!("stop receiving {} stream from user:{} connect at:{}", object_id, user.uid(), user.connect_at());
break
},
result = stream.next() => {
match result {
Some(map) => {
match collab.upgrade() {
None => break, // break the loop if the collab is dropped
None => {
trace!("{} stop receiving user:{} messages because of collab is drop", user.user_device(), object_id);
// break the loop if the collab is dropped
break
},
Some(collab) => {
for (msg_oid, collab_messages) in map {
if collab_messages.is_empty() {
warn!("collab messages is empty");
warn!("{} collab messages is empty", object_id);
}
// The message is valid if it has a payload and the object_id matches the broadcast's object_id.
@ -237,7 +248,10 @@ impl CollabBroadcast {
}
}
},
None => break,
None => {
trace!("{} stop receiving user:{} messages", object_id, user.user_device());
break
},
}
}
}
@ -274,7 +288,7 @@ async fn handle_client_collab_message<Sink>(
match msg {
Ok(msg) => {
if let Ok(mut collab) = collab.try_lock() {
let result = handle_collab_message(&origin, &ServerSyncProtocol, &mut collab, msg);
let result = handle_message(&origin, &ServerSyncProtocol, &mut collab, msg);
match result {
Ok(payload) => {
let resp = CollabAck::new(origin.clone(), object_id.to_string(), collab_msg.msg_id())

View File

@ -63,7 +63,7 @@ where
collab_messages,
} => {
if let Err(err) = self
.handle_collab_message(&user, object_id, collab_messages)
.handle_client_collab_message(&user, object_id, collab_messages)
.await
{
error!("handle client message error: {}", err);
@ -96,7 +96,7 @@ where
/// - If the group does not exist: The client is prompted to send an 'init sync' message first.
#[instrument(level = "trace", skip_all)]
async fn handle_collab_message(
async fn handle_client_collab_message(
&self,
user: &U,
object_id: String,
@ -120,19 +120,14 @@ where
let is_group_exist = self.all_groups.contains_group(&object_id).await;
if is_group_exist {
let first_message = messages.first().unwrap();
// If a group exists for the specified object_id and the message is an 'init sync',
// then remove any existing subscriber from that group and add the new user as a subscriber to the group.
if first_message.is_client_init_sync() {
self.all_groups.remove_user(&object_id, user).await?;
}
// subscribe the user to the group. then the user will receive the changes from the group
let is_user_subscribed = self.all_groups.contains_user(&object_id, user).await;
if !is_user_subscribed {
// safety: messages is not empty because we have checked it before
let first_message = messages.first().unwrap();
self.subscribe_group(user, first_message).await?;
}
broadcast_client_collab_message(user, object_id, messages, &self.client_stream_by_user).await;
forward_message_to_group(user, object_id, messages, &self.client_stream_by_user).await;
} else {
let first_message = messages.first().unwrap();
// If there is no existing group for the given object_id and the message is an 'init message',
@ -140,8 +135,7 @@ where
if first_message.is_client_init_sync() {
self.create_group(first_message).await?;
self.subscribe_group(user, first_message).await?;
broadcast_client_collab_message(user, object_id, messages, &self.client_stream_by_user)
.await;
forward_message_to_group(user, object_id, messages, &self.client_stream_by_user).await;
} else {
warn!(
"The group:{} is not found, the client:{} should send the init message first",
@ -196,8 +190,10 @@ where
}
}
/// Forward the message to the group.
/// When the group receives the message, it will broadcast the message to all the users in the group.
#[inline]
pub async fn broadcast_client_collab_message<U>(
pub async fn forward_message_to_group<U>(
user: &U,
object_id: String,
collab_messages: Vec<ClientCollabMessage>,

View File

@ -61,7 +61,6 @@ where
collab_message.object_id()
);
let client_uid = user.uid();
self
.edit_collab_by_user
.entry((*user).clone())
@ -73,9 +72,9 @@ where
let (sink, stream) = client_stream
.value_mut()
.client_channel::<CollabMessage, _>(
.client_channel::<CollabMessage, _, U>(
&collab_group.workspace_id,
client_uid,
user,
object_id,
self.access_control.clone(),
);

View File

@ -38,7 +38,7 @@ pub struct CollabRealtimeServer<S, U, AC> {
/// If the two session IDs differ, it indicates that the user has established a new connection
/// to the server since the stored session ID was last updated.
///
session_id_by_user: Arc<DashMap<U, String>>,
user_by_ws_connect_id: Arc<DashMap<U, String>>,
/// Keep track of all object ids that a user is subscribed to
editing_collab_by_user: Arc<DashMap<U, HashSet<Editing>>>,
/// Maintains a record of all client streams. A client stream associated with a user may be terminated for the following reasons:
@ -84,7 +84,7 @@ where
storage,
groups,
user_by_device: Default::default(),
session_id_by_user: Default::default(),
user_by_ws_connect_id: Default::default(),
editing_collab_by_user,
client_stream_by_user,
group_sender_by_object_id,
@ -96,7 +96,7 @@ where
pub fn handle_new_connection(
&self,
user: U,
session_id: String,
ws_connect_id: String,
conn_sink: impl RealtimeClientWebsocketSink,
) -> Pin<Box<dyn Future<Output = Result<(), RealtimeError>>>> {
// User with the same id and same device will be replaced with the new connection [CollabClientStream]
@ -105,13 +105,16 @@ where
let user_by_device = self.user_by_device.clone();
let client_stream_by_user = self.client_stream_by_user.clone();
let editing_collab_by_user = self.editing_collab_by_user.clone();
let user_by_session_id = self.session_id_by_user.clone();
let user_by_ws_connect_id = self.user_by_ws_connect_id.clone();
Box::pin(async move {
trace!("[realtime]: new connection => {} ", user);
user_by_session_id.insert(user.clone(), session_id);
user_by_ws_connect_id.insert(user.clone(), ws_connect_id);
let old_user = user_by_device.insert(UserDevice::from(&user), user.clone());
trace!(
"[realtime]: new connection => {}, remove old: {:?}",
user,
old_user
);
if let Some(old_user) = old_user {
if let Some(value) = client_stream_by_user.remove(&old_user) {
let old_stream = value.1;
@ -146,19 +149,19 @@ where
pub fn handle_disconnect(
&self,
user: U,
session_id: String,
ws_connect_id: String,
) -> Pin<Box<dyn Future<Output = Result<(), RealtimeError>>>> {
let groups = self.groups.clone();
let client_stream_by_user = self.client_stream_by_user.clone();
let editing_collab_by_user = self.editing_collab_by_user.clone();
let session_id_by_user = self.session_id_by_user.clone();
let session_id_by_user = self.user_by_ws_connect_id.clone();
Box::pin(async move {
trace!("[realtime]: disconnect => {}", user);
// If the user has reconnected over a new websocket connection, the stored connect id will
// differ from the one carried by this disconnect request. In that case, do not remove the
// user from groups and client streams.
if let Some(entry) = session_id_by_user.get(&user) {
if entry.value() != &session_id {
if entry.value() != &ws_connect_id {
return Ok(());
}
}
@ -285,14 +288,15 @@ impl CollabClientStream {
///
/// [Stream] will be used to send changes to the collab object.
///
pub fn client_channel<T, AC>(
pub fn client_channel<T, AC, U>(
&mut self,
workspace_id: &str,
uid: i64,
user: &U,
object_id: &str,
access_control: Arc<AC>,
) -> (UnboundedSenderSink<T>, ReceiverStream<MessageByObjectId>)
where
U: RealtimeUser,
T: Into<RealtimeMessage> + Send + Sync + 'static,
AC: RealtimeAccessControl,
{
@ -305,6 +309,7 @@ impl CollabClientStream {
let (client_sink_tx, mut client_sink_rx) = tokio::sync::mpsc::unbounded_channel::<T>();
let sink_access_control = access_control.clone();
let sink_workspace_id = workspace_id.to_string();
let uid = user.uid();
tokio::spawn(async move {
while let Some(msg) = client_sink_rx.recv().await {
let result = sink_access_control
@ -336,7 +341,9 @@ impl CollabClientStream {
let cloned_object_id = object_id.to_string();
let stream_workspace_id = workspace_id.to_string();
// forward the message to the stream that was subscribed by the broadcast group
let user = user.clone();
// stream_rx continuously receives messages from the websocket client and then
// forwards them to the subscriber, which is the broadcast channel [CollabBroadcast].
let (tx, rx) = tokio::sync::mpsc::channel(100);
tokio::spawn(async move {
while let Some(Ok(realtime_msg)) = stream_rx.next().await {
@ -349,15 +356,16 @@ impl CollabClientStream {
let (valid_messages, invalid_message) = Self::access_control(
&stream_workspace_id,
&uid,
&user.uid(),
&msg_oid,
&access_control,
original_messages,
)
.await;
trace!(
"{} receive message: valid:{} invalid:{}",
"{} receive {} client message: valid:{} invalid:{}",
msg_oid,
user.uid(),
valid_messages.len(),
invalid_message.len()
);
@ -365,8 +373,16 @@ impl CollabClientStream {
if valid_messages.is_empty() {
continue;
}
if tx.send([(msg_oid, valid_messages)].into()).await.is_err() {
break;
// If tx.send returns an error, it means the client is disconnected from the group.
if let Err(err) = tx.send([(msg_oid, valid_messages)].into()).await {
trace!(
"{} send message to user:{} stream fail with error: {}, break the loop",
cloned_object_id,
user.user_device(),
err,
);
return;
}
}
},

View File

@ -32,7 +32,7 @@ pub trait CollabStorageAccessControl: Send + Sync + 'static {
async fn enforce_write_collab(
&self,
worksapce_id: &str,
workspace_id: &str,
uid: &i64,
oid: &str,
) -> Result<bool, AppError>;

View File

@ -82,7 +82,7 @@ where
act.server.do_send(Disconnect {
user,
session_id: session_id.clone(),
ws_connect_id: session_id.clone(),
});
ctx.stop();
return;
@ -186,7 +186,7 @@ where
.send(Connect {
socket: ctx.address().recipient(),
user: self.user.clone(),
session_id: self.session_id.clone(),
ws_connect_id: self.session_id.clone(),
})
.into_actor(self)
.then(|res, _session, ctx| {
@ -215,7 +215,7 @@ where
trace!("{} stopping websocket connect", user);
self.server.do_send(Disconnect {
user,
session_id: self.session_id.clone(),
ws_connect_id: self.session_id.clone(),
});
Running::Stop
}

View File

@ -11,14 +11,16 @@ pub use collab_rt_entity::message::RealtimeMessage;
pub struct Connect<U> {
pub socket: Recipient<RealtimeMessage>,
pub user: U,
pub session_id: String,
/// Each websocket connection has a unique id
pub ws_connect_id: String,
}
#[derive(Debug, Message, Clone)]
#[rtype(result = "Result<(), RealtimeError>")]
pub struct Disconnect<U> {
pub user: U,
pub session_id: String,
/// Each websocket connection has a unique id
pub ws_connect_id: String,
}
#[derive(Debug, Message, Clone)]

View File

@ -52,7 +52,7 @@ where
fn handle(&mut self, new_conn: Connect<U>, _ctx: &mut Context<Self>) -> Self::Result {
let conn_sink = RealtimeClientWebsocketSinkImpl(new_conn.socket);
self.handle_new_connection(new_conn.user, new_conn.session_id, conn_sink)
self.handle_new_connection(new_conn.user, new_conn.ws_connect_id, conn_sink)
}
}
@ -64,7 +64,7 @@ where
{
type Result = ResponseFuture<anyhow::Result<(), RealtimeError>>;
fn handle(&mut self, msg: Disconnect<U>, _: &mut Context<Self>) -> Self::Result {
self.handle_disconnect(msg.user, msg.session_id)
self.handle_disconnect(msg.user, msg.ws_connect_id)
}
}

View File

@ -118,12 +118,10 @@ where
trace!("Skip access control for the request");
return Ok(());
}
let collab_exists = self.collab_cache.is_exist(oid).await?;
let collab_exists = self.collab_cache.is_exist_in_disk(oid).await?;
if !collab_exists {
return Err(AppError::RecordNotFound(format!(
"Collab not exist in db. {}",
oid
)));
// If the collab does not exist, we should not enforce the access control
return Ok(());
}
let access_level = self.require_access_level(&method, path);
@ -190,12 +188,11 @@ where
uid: &i64,
oid: &str,
) -> Result<bool, AppError> {
let collab_exists = self.cache.is_exist(oid).await?;
let collab_exists = self.cache.is_exist_in_disk(oid).await?;
if !collab_exists {
return Err(AppError::RecordNotFound(format!(
"Collab not exist in db. {}",
oid
)));
// If the collab does not exist, we should not enforce the access control. We consider the user
// has the permission to read the collab
return Ok(true);
}
self
.collab_access_control
@ -209,12 +206,11 @@ where
uid: &i64,
oid: &str,
) -> Result<bool, AppError> {
let collab_exists = self.cache.is_exist(oid).await?;
let collab_exists = self.cache.is_exist_in_disk(oid).await?;
if !collab_exists {
return Err(AppError::RecordNotFound(format!(
"Collab not exist in db. {}",
oid
)));
// If the collab does not exist, we should not enforce the access control. We consider the user
// has the permission to write the collab
return Ok(true);
}
self
.collab_access_control

View File

@ -149,7 +149,7 @@ impl CollabCache {
Ok(())
}
pub async fn is_exist(&self, oid: &str) -> Result<bool, AppError> {
pub async fn is_exist_in_disk(&self, oid: &str) -> Result<bool, AppError> {
let is_exist = self.disk_cache.is_exist(oid).await?;
Ok(is_exist)
}

View File

@ -23,7 +23,7 @@ use std::ops::DerefMut;
use std::time::Duration;
use tokio::sync::oneshot;
use tokio::time::timeout;
use tracing::{error, instrument};
use tracing::{error, instrument, trace};
use validator::Validate;
pub type CollabAccessControlStorage = CollabStorageImpl<
@ -201,6 +201,7 @@ where
Ok(())
}
#[instrument(level = "trace", skip_all, fields(oid = %params.object_id, is_collab_init = %is_collab_init))]
async fn get_collab_encoded(
&self,
uid: &i64,
@ -224,6 +225,7 @@ where
// Early return if editing collab is initialized, as it indicates no need to query further.
if !is_collab_init {
trace!("Get encode collab {} from editing collab", params.object_id);
// Attempt to retrieve encoded collab from the editing collab
if let Some(value) = self.get_encode_collab_from_editing(&params.object_id).await {
return Ok(value);

View File

@ -207,7 +207,7 @@ pub type UserListener = crate::biz::pg_listener::PostgresDBListener<AFUserNotifi
pub struct RealtimeUserImpl {
pub uid: i64,
pub device_id: String,
pub timestamp: i64,
pub connect_at: i64,
}
impl RealtimeUserImpl {
@ -215,7 +215,7 @@ impl RealtimeUserImpl {
Self {
uid,
device_id,
timestamp: chrono::Utc::now().timestamp(),
connect_at: chrono::Utc::now().timestamp(),
}
}
}
@ -224,7 +224,7 @@ impl Display for RealtimeUserImpl {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!(
"uid:{}|device_id:{}|connected_at:{}",
self.uid, self.device_id, self.timestamp,
self.uid, self.device_id, self.connect_at,
))
}
}
@ -237,4 +237,8 @@ impl RealtimeUser for RealtimeUserImpl {
/// Returns a borrowed view of this user's `device_id` field.
fn device_id(&self) -> &str {
&self.device_id
}
/// Returns the `connect_at` field.
///
/// The struct initializer in this file sets the field from
/// `chrono::Utc::now().timestamp()`, and the `Display` impl prints it as `connected_at`.
fn connect_at(&self) -> i64 {
self.connect_at
}
}

View File

@ -1,9 +1,9 @@
mod awareness_test;
mod collab_curd_test;
mod edit_permission;
mod edit_workspace;
mod member_crud;
mod multi_devices_edit;
mod single_device_edit;
mod storage_test;
mod team_edit_test;
mod util;

View File

@ -0,0 +1 @@

View File

@ -1,4 +1,5 @@
mod blob;
mod edit_workspace;
mod invitation_crud;
mod member_crud;
mod template_test;