diff --git a/Cargo.lock b/Cargo.lock index 3b293f79..bae0f4ae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1360,7 +1360,7 @@ dependencies = [ [[package]] name = "collab" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=0c4bcfaf033ef6bfe2ebb40c26b787bfd4cc095f#0c4bcfaf033ef6bfe2ebb40c26b787bfd4cc095f" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=01be7a981515dd02a8bd37e3e79f1d24eade0f47#01be7a981515dd02a8bd37e3e79f1d24eade0f47" dependencies = [ "anyhow", "async-trait", @@ -1382,7 +1382,7 @@ dependencies = [ [[package]] name = "collab-document" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=0c4bcfaf033ef6bfe2ebb40c26b787bfd4cc095f#0c4bcfaf033ef6bfe2ebb40c26b787bfd4cc095f" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=01be7a981515dd02a8bd37e3e79f1d24eade0f47#01be7a981515dd02a8bd37e3e79f1d24eade0f47" dependencies = [ "anyhow", "collab", @@ -1401,7 +1401,7 @@ dependencies = [ [[package]] name = "collab-entity" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=0c4bcfaf033ef6bfe2ebb40c26b787bfd4cc095f#0c4bcfaf033ef6bfe2ebb40c26b787bfd4cc095f" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=01be7a981515dd02a8bd37e3e79f1d24eade0f47#01be7a981515dd02a8bd37e3e79f1d24eade0f47" dependencies = [ "anyhow", "bytes", @@ -1416,7 +1416,7 @@ dependencies = [ [[package]] name = "collab-folder" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=0c4bcfaf033ef6bfe2ebb40c26b787bfd4cc095f#0c4bcfaf033ef6bfe2ebb40c26b787bfd4cc095f" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=01be7a981515dd02a8bd37e3e79f1d24eade0f47#01be7a981515dd02a8bd37e3e79f1d24eade0f47" dependencies = [ "anyhow", "chrono", @@ -4285,6 +4285,7 @@ dependencies = [ "realtime-protocol", "serde", "serde_json", + "serde_repr", "thiserror", "tokio-tungstenite", "websocket", diff --git a/Cargo.toml b/Cargo.toml index 34f9b3f7..0abde9e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -172,13 +172,15 @@ inherits = "release" debug = true [patch.crates-io] -collab = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "0c4bcfaf033ef6bfe2ebb40c26b787bfd4cc095f" } -collab-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "0c4bcfaf033ef6bfe2ebb40c26b787bfd4cc095f" } -collab-folder = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "0c4bcfaf033ef6bfe2ebb40c26b787bfd4cc095f" } -collab-document = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "0c4bcfaf033ef6bfe2ebb40c26b787bfd4cc095f" } +collab = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "01be7a981515dd02a8bd37e3e79f1d24eade0f47" } +collab-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "01be7a981515dd02a8bd37e3e79f1d24eade0f47" } +collab-folder = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "01be7a981515dd02a8bd37e3e79f1d24eade0f47" } +collab-document = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "01be7a981515dd02a8bd37e3e79f1d24eade0f47" } [features] custom_env= [] +# This feature will be removed once the cpu spike issue is resolved +disable_access_control = [] # Comment the above and uncomment the below to use local version of collab by cloning the repo and placing it in libs folder #collab = { path = "libs/AppFlowy-Collab/collab" } diff --git a/Dockerfile b/Dockerfile index caab48d5..e9bd534e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,7 +16,8 @@ COPY . . ENV SQLX_OFFLINE true # Build the project -RUN cargo build --profile=profiling --bin appflowy_cloud +RUN cargo build --profile=profiling --features="disable_access_control" --bin appflowy_cloud + FROM debian:bookworm-slim AS runtime WORKDIR /app diff --git a/build/client_api_deps_check.sh b/build/client_api_deps_check.sh index 2d22f2f9..189986a0 100644 --- a/build/client_api_deps_check.sh +++ b/build/client_api_deps_check.sh @@ -3,7 +3,7 @@ # Generate the current dependency list cargo tree > current_deps.txt -BASELINE_COUNT=1683 +BASELINE_COUNT=1684 CURRENT_COUNT=$(cat current_deps.txt | wc -l) echo "Expected dependency count (baseline): $BASELINE_COUNT" diff --git a/libs/client-api-test-util/src/log.rs b/libs/client-api-test-util/src/log.rs index b0bd2065..0a10ec09 100644 --- a/libs/client-api-test-util/src/log.rs +++ b/libs/client-api-test-util/src/log.rs @@ -8,9 +8,11 @@ use { pub fn setup_log() { static START: Once = Once::new(); START.call_once(|| { - let level = "info"; + let level = "trace"; let mut filters = vec![]; filters.push(format!("client_api={}", level)); + filters.push(format!("appflowy_cloud={}", level)); + filters.push(format!("collab={}", level)); std::env::set_var("RUST_LOG", filters.join(",")); let subscriber = Subscriber::builder() diff --git a/libs/client-api/src/collab_sync/error.rs b/libs/client-api/src/collab_sync/error.rs index 4dfe5b67..0a10b8b8 100644 --- a/libs/client-api/src/collab_sync/error.rs +++ b/libs/client-api/src/collab_sync/error.rs @@ -9,6 +9,9 @@ pub enum SyncError { #[error("failed to deserialize message: {0}")] DecodingError(#[from] yrs::encoding::read::Error), + #[error("Can not apply update for object:{0}")] + CannotApplyUpdate(String), + #[error(transparent)] SerdeError(#[from] serde_json::Error), @@ -24,3 +27,9 @@ pub enum SyncError { #[error(transparent)] Internal(#[from] anyhow::Error), } + +impl SyncError { + pub fn is_cannot_apply_update(&self) -> bool { + matches!(self, Self::CannotApplyUpdate(_)) + } +} diff --git a/libs/client-api/src/collab_sync/sink.rs b/libs/client-api/src/collab_sync/sink.rs index 9e62ce14..8bc6b84b 100644 --- a/libs/client-api/src/collab_sync/sink.rs +++ b/libs/client-api/src/collab_sync/sink.rs @@ -1,5 +1,3 @@ -use collab::core::origin::CollabOrigin; - use std::marker::PhantomData; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::{Arc, Weak}; @@ -10,7 +8,7 @@ use crate::collab_sync::{SyncError, SyncObject, DEFAULT_SYNC_TIMEOUT}; use futures_util::SinkExt; use crate::platform_spawn; -use realtime_entity::collab_msg::{CollabSinkMessage, MsgId}; +use realtime_entity::collab_msg::{CollabMessage, CollabSinkMessage, MsgId}; use tokio::sync::{mpsc, oneshot, watch, Mutex}; use tokio::time::{interval, Instant, Interval}; use tracing::{debug, error, event, trace, warn}; @@ -162,27 +160,28 @@ where } /// Notify the sink to process the next message and mark the current message as done. - pub async fn ack_msg( - &self, - _origin: Option<&CollabOrigin>, - _object_id: &str, - msg_id: MsgId, - ) -> bool { - match self.pending_msg_queue.lock().peek_mut() { - None => false, - Some(mut pending_msg) => { - // In most cases, the msg_id of the pending_msg is the same as the passed-in msg_id. However, - // due to network issues, the client might send multiple messages with the same msg_id. - // Therefore, the msg_id might not always match the msg_id of the pending_msg. - if pending_msg.msg_id() != msg_id { - return false; - } + pub async fn ack_msg(&self, msg: &CollabMessage) -> bool { + // the msg_id will be None if the message is [ServerBroadcast] or [ServerAwareness] + match msg.msg_id() { + None => true, + Some(msg_id) => { + match self.pending_msg_queue.lock().peek_mut() { + None => false, + Some(mut pending_msg) => { + // In most cases, the msg_id of the pending_msg is the same as the passed-in msg_id. However, + // due to network issues, the client might send multiple messages with the same msg_id. + // Therefore, the msg_id might not always match the msg_id of the pending_msg. + if pending_msg.msg_id() != msg_id { + return false; + } - let is_done = pending_msg.set_state(self.uid, MessageState::Done); - if is_done { - self.notify(); + let is_done = pending_msg.set_state(self.uid, MessageState::Done); + if is_done { + self.notify(); + } + is_done + }, } - is_done }, } } diff --git a/libs/client-api/src/collab_sync/sync.rs b/libs/client-api/src/collab_sync/sync.rs index 8ae3b09f..e116727b 100644 --- a/libs/client-api/src/collab_sync/sync.rs +++ b/libs/client-api/src/collab_sync/sync.rs @@ -8,7 +8,7 @@ use collab::core::collab::MutexCollab; use collab::core::collab_state::SyncState; use collab::core::origin::CollabOrigin; use futures_util::{SinkExt, StreamExt}; -use realtime_entity::collab_msg::{CollabMessage, InitSync, ServerInit, UpdateSync}; +use realtime_entity::collab_msg::{AckCode, CollabMessage, InitSync, ServerInit, UpdateSync}; use realtime_protocol::{handle_collab_message, ClientSyncProtocol, CollabSyncProtocol}; use realtime_protocol::{Message, MessageReader, SyncMessage}; use std::marker::PhantomData; @@ -16,7 +16,7 @@ use std::ops::Deref; use std::sync::{Arc, Weak}; use tokio::sync::watch; use tokio_stream::wrappers::WatchStream; -use tracing::{error, event, trace, warn, Level}; +use tracing::{error, span, trace, warn, Level}; use yrs::encoding::read::Cursor; use yrs::updates::decoder::DecoderV1; use yrs::updates::encoder::{Encoder, EncoderV1}; @@ -236,30 +236,45 @@ where ) where P: CollabSyncProtocol + Send + Sync + 'static, { - while let Some(collab_message) = stream.next().await { - match collab_message { - Ok(msg) => match (weak_collab.upgrade(), weak_sink.upgrade()) { - (Some(collab), Some(sink)) => { - let span = tracing::span!(Level::TRACE, "doc_stream", object_id = %msg.object_id()); - let _enter = span.enter(); - if let Err(error) = SyncStream::::process_message::

( - &origin, &object_id, &protocol, &collab, &sink, msg, - ) - .await - { - error!("Error while processing message: {}", error); - } - }, - _ => { - // The collab or sink is dropped, stop the stream. - warn!("Stop receive doc incoming changes."); - break; - }, - }, - Err(e) => { - warn!("Stream error: {},stop receive incoming changes", e.into()); + while let Some(collab_message_result) = stream.next().await { + let collab = match weak_collab.upgrade() { + Some(collab) => collab, + None => break, // Collab dropped, stop the stream. + }; + + let sink = match weak_sink.upgrade() { + Some(sink) => sink, + None => break, // Sink dropped, stop the stream. + }; + + let msg = match collab_message_result { + Ok(msg) => msg, + Err(err) => { + warn!( + "Stream error: {}, stop receive incoming changes", + err.into() + ); break; }, + }; + + let span = span!(Level::TRACE, "doc_stream", object_id = %msg.object_id()); + let _enter = span.enter(); + if let Err(error) = SyncStream::::process_message::

( + &origin, &object_id, &protocol, &collab, &sink, msg, + ) + .await + { + if error.is_cannot_apply_update() { + // TODO(nathan): ask the client to resolve the conflict. + error!( + "collab:{} can not be synced because of error: {}", + object_id, error + ); + break; + } else { + error!("Error while processing message: {}", error); + } } } } @@ -276,30 +291,29 @@ where where P: CollabSyncProtocol + Send + Sync + 'static, { - let should_process = match msg.msg_id() { - // The msg_id is None if the message is [ServerBroadcast] or [ServerAwareness] - None => true, - Some(msg_id) => sink.ack_msg(msg.origin(), msg.object_id(), msg_id).await, - }; - - if should_process { - if let Some(payload) = msg.payload() { - event!( - Level::TRACE, - "receive collab message: {}, payload: {}", - msg, - payload.len() - ); - if !payload.is_empty() { - trace!("start process message:{:?}", msg.msg_id()); - SyncStream::::process_payload( - origin, payload, object_id, protocol, collab, sink, - ) - .await?; - trace!("end process message: {:?}", msg.msg_id()); - } + // If server return the AckCode::ApplyInternalError, which means the server can not apply the + // update + if let CollabMessage::ClientAck(ack) = &msg { + if ack.code == AckCode::CannotApplyUpdate { + return Err(SyncError::CannotApplyUpdate(object_id.to_string())); } } + + // Check if the message is acknowledged by the sink. If not, return. + if !sink.ack_msg(&msg).await { + return Ok(()); + } + + // If there's no payload or the payload is empty, return. + let payload = match msg.payload() { + Some(payload) if !payload.is_empty() => payload, + _ => return Ok(()), + }; + + trace!("start process message:{:?}", msg.msg_id()); + SyncStream::::process_payload(origin, payload, object_id, protocol, collab, sink) + .await?; + trace!("end process message: {:?}", msg.msg_id()); Ok(()) } @@ -316,12 +330,11 @@ where { let mut decoder = DecoderV1::new(Cursor::new(payload)); let reader = MessageReader::new(&mut decoder); - let cloned_origin = Some(origin.clone()); for msg in reader { let msg = msg?; trace!(" {}", msg); let is_sync_step_1 = matches!(msg, Message::Sync(SyncMessage::SyncStep1(_))); - if let Some(payload) = handle_collab_message(&cloned_origin, protocol, collab, msg)? { + if let Some(payload) = handle_collab_message(origin, protocol, collab, msg)? { if is_sync_step_1 { // flush match collab.try_lock() { diff --git a/libs/realtime-entity/Cargo.toml b/libs/realtime-entity/Cargo.toml index 9f9e1a43..4c0e38f8 100644 --- a/libs/realtime-entity/Cargo.toml +++ b/libs/realtime-entity/Cargo.toml @@ -23,6 +23,7 @@ yrs.workspace = true thiserror = "1.0.56" realtime-protocol.workspace = true websocket.workspace = true +serde_repr = "0.1" [build-dependencies] protoc-bin-vendored = { version = "3.0" } diff --git a/libs/realtime-entity/src/collab_msg.rs b/libs/realtime-entity/src/collab_msg.rs index f359378b..733c6c9a 100644 --- a/libs/realtime-entity/src/collab_msg.rs +++ b/libs/realtime-entity/src/collab_msg.rs @@ -11,6 +11,7 @@ use collab::preclude::updates::encoder::{Encode, Encoder, EncoderV1}; use collab_entity::CollabType; use realtime_protocol::{Message, MessageReader, SyncMessage}; use serde::{Deserialize, Serialize}; +use serde_repr::{Deserialize_repr, Serialize_repr}; pub trait CollabSinkMessage: Clone + Send + Sync + 'static + Ord + Display { fn collab_object_id(&self) -> &str; @@ -143,19 +144,19 @@ impl CollabMessage { pub fn is_empty(&self) -> bool { self.len() == 0 } - pub fn origin(&self) -> Option<&CollabOrigin> { + pub fn origin(&self) -> &CollabOrigin { match self { - CollabMessage::ClientInitSync(value) => Some(&value.origin), - CollabMessage::ClientUpdateSync(value) => Some(&value.origin), - CollabMessage::ClientAck(value) => Some(&value.origin), - CollabMessage::ServerInitSync(value) => Some(&value.origin), - CollabMessage::ServerBroadcast(value) => Some(&value.origin), - CollabMessage::AwarenessSync(_) => None, + CollabMessage::ClientInitSync(value) => &value.origin, + CollabMessage::ClientUpdateSync(value) => &value.origin, + CollabMessage::ClientAck(value) => &value.origin, + CollabMessage::ServerInitSync(value) => &value.origin, + CollabMessage::ServerBroadcast(value) => &value.origin, + CollabMessage::AwarenessSync(value) => &value.origin, } } pub fn uid(&self) -> Option { - self.origin().and_then(|origin| origin.client_user_id()) + self.origin().client_user_id() } pub fn object_id(&self) -> &str { @@ -170,10 +171,10 @@ impl CollabMessage { } pub fn device_id(&self) -> Option { - self.origin().and_then(|origin| match origin { + match self.origin() { CollabOrigin::Client(origin) => Some(origin.device_id.clone()), _ => None, - }) + } } } @@ -194,6 +195,7 @@ impl Display for CollabMessage { pub struct CollabAwareness { object_id: String, payload: Bytes, + origin: CollabOrigin, } impl CollabAwareness { @@ -201,6 +203,7 @@ impl CollabAwareness { Self { object_id, payload: Bytes::from(payload), + origin: CollabOrigin::Server, } } } @@ -389,38 +392,55 @@ impl Display for UpdateSync { } } +#[derive(Clone, Eq, PartialEq, Debug, Serialize_repr, Deserialize_repr)] +#[repr(u8)] +pub enum AckCode { + Success = 0, + CannotApplyUpdate = 1, + Retry = 2, + Internal = 3, +} + #[derive(Clone, Eq, PartialEq, Debug, Serialize, Deserialize)] pub struct CollabAck { pub origin: CollabOrigin, pub object_id: String, pub source: AckSource, pub payload: Bytes, + pub code: AckCode, } #[derive(Clone, Eq, PartialEq, Debug, Serialize, Deserialize)] pub struct AckSource { - pub sync_verbose: String, + #[serde(rename = "sync_verbose")] + pub verbose: String, pub msg_id: MsgId, } impl CollabAck { - pub fn new( - origin: CollabOrigin, - object_id: String, - payload: Vec, - msg_id: MsgId, - sync_verbose: String, - ) -> Self { + pub fn new(origin: CollabOrigin, object_id: String, msg_id: MsgId) -> Self { + let source = AckSource { + verbose: "".to_string(), + msg_id, + }; Self { origin, object_id, - payload: Bytes::from(payload), - source: AckSource { - sync_verbose, - msg_id, - }, + source, + payload: Bytes::from(vec![]), + code: AckCode::Success, } } + + pub fn with_payload>(mut self, payload: T) -> Self { + self.payload = payload.into(); + self + } + + pub fn with_code(mut self, code: AckCode) -> Self { + self.code = code; + self + } } impl From for CollabMessage { @@ -432,11 +452,10 @@ impl From for CollabMessage { impl Display for CollabAck { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!( - "ack: [origin:{}|oid:{}|msg_id:{:?}|{}|len:{}]", + "ack: [origin:{}|oid:{}|msg_id:{:?}|len:{}]", self.origin, self.object_id, self.source.msg_id, - self.source.sync_verbose, self.payload.len(), )) } diff --git a/libs/realtime-protocol/src/message.rs b/libs/realtime-protocol/src/message.rs index d52b75f3..91c7cc1f 100644 --- a/libs/realtime-protocol/src/message.rs +++ b/libs/realtime-protocol/src/message.rs @@ -248,6 +248,9 @@ pub enum Error { #[error("{0}")] YrsTransaction(String), + #[error("{0}")] + YrsApplyUpdate(String), + #[error(transparent)] BinCodeSerde(#[from] bincode::Error), diff --git a/libs/realtime-protocol/src/protocol.rs b/libs/realtime-protocol/src/protocol.rs index 3c6108c0..0c3e153a 100644 --- a/libs/realtime-protocol/src/protocol.rs +++ b/libs/realtime-protocol/src/protocol.rs @@ -78,20 +78,17 @@ pub trait CollabSyncProtocol { /// an update to current `awareness` document instance. fn handle_sync_step2( &self, - origin: &Option, + origin: &CollabOrigin, awareness: &mut Awareness, update: Update, ) -> Result>, Error> { let mut retry_txn = TransactionRetry::new(awareness.doc()); - let mut txn = if let Some(origin) = origin.as_ref() { - retry_txn.try_get_write_txn_with(origin.clone()) - } else { - retry_txn.try_get_write_txn() - } - .map_err(|err| Error::YrsTransaction(format!("sync step2 transaction acquire: {}", err)))?; + let mut txn = retry_txn + .try_get_write_txn_with(origin.clone()) + .map_err(|err| Error::YrsTransaction(format!("sync step2 transaction acquire: {}", err)))?; txn .try_apply_update(update) - .map_err(|err| Error::YrsTransaction(format!("sync step2 apply update: {}", err)))?; + .map_err(|err| Error::YrsApplyUpdate(format!("sync step2 apply update: {}", err)))?; Ok(None) } @@ -99,7 +96,7 @@ pub trait CollabSyncProtocol { /// `awareness` document instance. fn handle_update( &self, - origin: &Option, + origin: &CollabOrigin, awareness: &mut Awareness, update: Update, ) -> Result>, Error> { @@ -140,7 +137,7 @@ pub trait CollabSyncProtocol { /// Handles incoming messages from the client/server pub fn handle_collab_message( - origin: &Option, + origin: &CollabOrigin, protocol: &P, collab: &MutexCollab, msg: Message, diff --git a/libs/realtime/src/collaborate/broadcast.rs b/libs/realtime/src/collaborate/broadcast.rs index 82e06ffe..2357eed6 100644 --- a/libs/realtime/src/collaborate/broadcast.rs +++ b/libs/realtime/src/collaborate/broadcast.rs @@ -11,7 +11,7 @@ use collab::core::awareness::{Awareness, AwarenessUpdate}; use collab::core::collab::MutexCollab; use collab::core::origin::CollabOrigin; use futures_util::{SinkExt, StreamExt}; -use realtime_protocol::handle_collab_message; +use realtime_protocol::{handle_collab_message, Error}; use realtime_protocol::{Message, MessageReader, MSG_SYNC, MSG_SYNC_UPDATE}; use tokio::select; use tokio::sync::broadcast::error::SendError; @@ -25,7 +25,9 @@ use yrs::updates::encoder::{Encode, Encoder, EncoderV1}; use yrs::UpdateSubscription; use crate::error::RealtimeError; -use realtime_entity::collab_msg::{CollabAck, CollabAwareness, CollabBroadcastData, CollabMessage}; +use realtime_entity::collab_msg::{ + AckCode, CollabAck, CollabAwareness, CollabBroadcastData, CollabMessage, +}; use tracing::{error, trace, warn}; use yrs::encoding::write::Write; @@ -40,6 +42,12 @@ pub struct CollabBroadcast { doc_subscription: Mutex>, } +impl Drop for CollabBroadcast { + fn drop(&mut self) { + trace!("Drop collab broadcast:{}", self.object_id); + } +} + impl CollabBroadcast { /// Creates a new [CollabBroadcast] over a provided `collab` instance. All changes triggered /// by this collab will be propagated to all subscribers which have been registered via @@ -121,12 +129,35 @@ impl CollabBroadcast { Ok(()) } - /// Subscribes a new connection - represented by `sink`/`stream` pair implementing a futures - /// Sink and Stream protocols - to a current broadcast group. + /// Subscribes a new connection to a broadcast group, enabling real-time collaboration. + /// + /// This function takes a `sink`/`stream` pair representing the connection to a subscriber. The `sink` + /// is used to send messages to the subscriber, while the `stream` receives messages from the subscriber. + /// + /// # Arguments + /// - `subscriber_origin`: Identifies the subscriber's origin to avoid echoing messages back. + /// - `sink`: A `Sink` implementation for sending messages to the subscriber(Each connected client). + /// - `stream`: A `Stream` implementation for receiving messages from the subscriber((Each connected client)). + /// - `modified_at`: A shared, mutable reference to track the last modification time of the document. + /// + /// # Behavior + /// - [Sink] Forwards updates received from the document observer to all subscribers through 'sink', excluding the originator + /// of the message, to prevent echoing back the same message. + /// - [Stream] Processes incoming messages from the `stream` associated with the subscriber. If a message alters + /// the document's state, it triggers an update broadcast to all subscribers. + /// + /// - Utilizes two asynchronous tasks: one for broadcasting updates to the `sink`, and another for + /// processing messages from the `stream`. + /// + /// # Termination + /// - The subscription can be manually stopped by dropping the returned `Subscription` structure or + /// by awaiting its `stop` method. This action will terminate both the sink and stream tasks. + /// - Internal errors or disconnection will also terminate the tasks, ending the subscription. + /// + /// # Returns + /// A `Subscription` instance that represents the active subscription. Dropping this structure or + /// calling its `stop` method will unsubscribe the connection and cease all related activities. /// - /// Returns a subscription structure, which can be dropped in order to unsubscribe or awaited - /// via [Subscription::stop] method in order to complete of its own volition (due to - /// an internal connection error or closed connection). pub fn subscribe( &self, subscriber_origin: CollabOrigin, @@ -155,10 +186,8 @@ impl CollabBroadcast { result = receiver.recv() => { match result { Ok(message) => { - if let Some(msg_origin) = message.origin() { - if msg_origin == &subscriber_origin { - continue; - } + if message.origin() == &subscriber_origin { + continue; } trace!("[realtime]: broadcast collab message: {}", message); @@ -240,39 +269,34 @@ async fn handle_client_collab_message( Some(payload) => { let object_id = object_id.to_string(); let mut decoder = DecoderV1::from(payload.as_ref()); - let origin = Arc::new(collab_msg.origin().cloned()); + let origin = collab_msg.origin().clone(); let reader = MessageReader::new(&mut decoder); for msg in reader { match msg { Ok(msg) => { let cloned_collab = collab.clone(); - let cloned_origin = origin.clone(); - let result = - handle_collab_message(&cloned_origin, &ServerSyncProtocol, &cloned_collab, msg); + let result = handle_collab_message(&origin, &ServerSyncProtocol, &cloned_collab, msg); + if let Some(msg_id) = collab_msg.msg_id() { + match result { + Ok(payload) => { + let resp = CollabAck::new(origin.clone(), object_id.clone(), msg_id) + .with_payload(payload.unwrap_or_default()); - match result { - Ok(payload) => match origin.as_ref() { - None => warn!("Client message does not have a origin"), - Some(origin) => { - if let Some(msg_id) = collab_msg.msg_id() { - let resp = CollabAck::new( - origin.clone(), - object_id.clone(), - payload.unwrap_or_default(), - msg_id, - collab_msg.type_str(), - ); - - trace!("Send response to client: {}", resp); - if let Err(err) = sink.send(resp.into()).await { - trace!("fail to send response to client: {}", err); - } + trace!("Send response to client: {}", resp); + if let Err(err) = sink.send(resp.into()).await { + trace!("fail to send response to client: {}", err); } }, - }, - Err(err) => { - error!("object id:{} =>{}", object_id, err); - }, + Err(err) => { + error!("handle collab:{} message error:{}", object_id, err); + let resp = CollabAck::new(origin.clone(), object_id.clone(), msg_id) + .with_code(ack_code_from_error(&err)); + + if let Err(err) = sink.send(resp.into()).await { + trace!("fail to send response to client: {}", err); + } + }, + } } }, Err(e) => { @@ -288,6 +312,15 @@ async fn handle_client_collab_message( } } +#[inline] +fn ack_code_from_error(error: &Error) -> AckCode { + match error { + Error::YrsTransaction(_) => AckCode::Retry, + Error::YrsApplyUpdate(_) => AckCode::CannotApplyUpdate, + _ => AckCode::Internal, + } +} + /// A subscription structure returned from [CollabBroadcast::subscribe], which represents a /// subscribed connection. It can be dropped in order to unsubscribe or awaited via /// [Subscription::stop] method in order to complete of its own volition (due to an internal @@ -302,10 +335,14 @@ pub struct Subscription { impl Subscription { pub async fn stop(&mut self) { if let Some(sink_stop_tx) = self.sink_stop_tx.take() { - let _ = sink_stop_tx.send(()).await; + if let Err(err) = sink_stop_tx.send(()).await { + error!("fail to stop sink:{}", err); + } } if let Some(stream_stop_tx) = self.stream_stop_tx.take() { - let _ = stream_stop_tx.send(()).await; + if let Err(err) = stream_stop_tx.send(()).await { + error!("fail to stop stream:{}", err); + } } } } diff --git a/libs/realtime/src/collaborate/group.rs b/libs/realtime/src/collaborate/group.rs index b5134df3..764b3f8e 100644 --- a/libs/realtime/src/collaborate/group.rs +++ b/libs/realtime/src/collaborate/group.rs @@ -37,13 +37,13 @@ where U: RealtimeUser, AC: CollabAccessControl, { - pub async fn run(mut self) { + pub async fn run(mut self, object_id: String) { let mut receiver = self.recv.take().expect("Only take once"); let stream = stream! { while let Some(msg) = receiver.recv().await { yield msg; } - trace!("The group command runner is stopped"); + trace!("Collab group:{} command runner is stopped", object_id); }; stream diff --git a/libs/realtime/src/collaborate/group_control.rs b/libs/realtime/src/collaborate/group_control.rs index 846887d0..ac29b998 100644 --- a/libs/realtime/src/collaborate/group_control.rs +++ b/libs/realtime/src/collaborate/group_control.rs @@ -14,7 +14,7 @@ use std::sync::Arc; use tokio::sync::Mutex; use tokio::task::spawn_blocking; use tokio::time::Instant; -use tracing::{debug, error, event, instrument}; +use tracing::{debug, error, event, instrument, trace}; pub struct CollabGroupControl { group_by_object_id: Arc>>>, @@ -174,6 +174,12 @@ pub struct CollabGroup { pub modified_at: Arc>, } +impl Drop for CollabGroup { + fn drop(&mut self) { + trace!("Drop collab group:{}", self.collab.lock().object_id); + } +} + impl CollabGroup where U: RealtimeUser, @@ -221,6 +227,61 @@ where } } + /// Subscribes a new connection to the broadcast group for collaborative activities. + /// + /// This method establishes a new subscription for a user, represented by a `sink`/`stream` pair. + /// These pairs implement the futures `Sink` and `Stream` protocols, facilitating real-time + /// communication between the server and the client. + /// + /// # Parameters + /// - `user`: Reference to the user initiating the subscription. Used for managing user-specific + /// subscriptions and ensuring unique subscriptions per user-device combination. + /// - `subscriber_origin`: Identifies the origin of the subscription, used to prevent echoing + /// messages back to the sender. + /// - `sink`: A `Sink` implementation used for sending collaboration changes to the client. + /// - `stream`: A `Stream` implementation for receiving messages from the client. + /// + /// # Behavior + /// - **Sink**: Utilized for forwarding any collaboration changes within the group to the client. + /// Ensures that updates are communicated in real-time. + /// + /// Collaboration Group Changes + /// | + /// | (1) Detect Change + /// V + /// +---------------------------+ + /// | Subscribe Function | + /// +---------------------------+ + /// | + /// | (2) Forward Update + /// V + /// +-------------+ + /// | | + /// | Sink |-----> (To Client) + /// | | + /// +-------------+ + /// + /// - **Stream**: Processes incoming messages from the client. After processing, responses are + /// dispatched back to the client through the `sink`. + /// (From Client) + /// | + /// | (1) Receive Message + /// V + /// +-------------+ + /// | | + /// | Stream | + /// | | + /// +-------------+ + /// | + /// | (2) Process Message + /// V + /// +---------------------------+ + /// | Subscribe Function | + /// +---------------------------+ + /// | + /// | (3) Alter Document (if applicable) + /// V + /// Collaboration Group Updates (triggers Sink flow) pub async fn subscribe( &self, user: &U, diff --git a/libs/realtime/src/collaborate/group_sub.rs b/libs/realtime/src/collaborate/group_sub.rs index 718c09e7..f0a661f4 100644 --- a/libs/realtime/src/collaborate/group_sub.rs +++ b/libs/realtime/src/collaborate/group_sub.rs @@ -33,10 +33,7 @@ where AC: CollabAccessControl, { fn get_origin(collab_message: &CollabMessage) -> &CollabOrigin { - collab_message.origin().unwrap_or_else(|| { - error!("🔴The origin from client message is empty"); - &CollabOrigin::Empty - }) + collab_message.origin() } fn make_channel<'b>( @@ -56,7 +53,7 @@ where object_id, move |object_id, msg| { if msg.object_id() != object_id { - warn!( + error!( "The object id:{} from message is not matched with the object id:{} from sink", msg.object_id(), object_id @@ -65,9 +62,9 @@ where } let object_id = object_id.to_string(); - let cloned_sink_permission_service = sink_permission_service.clone(); + let permission_service = sink_permission_service.clone(); Box::pin(async move { - match cloned_sink_permission_service + match permission_service .can_receive_collab_update(&client_uid, &object_id) .await { @@ -79,7 +76,6 @@ where object_id, ); } - is_allowed }, Err(err) => { diff --git a/libs/realtime/src/collaborate/server.rs b/libs/realtime/src/collaborate/server.rs index 4abca6f6..4d13ec94 100644 --- a/libs/realtime/src/collaborate/server.rs +++ b/libs/realtime/src/collaborate/server.rs @@ -131,22 +131,26 @@ where let sender = match old_sender { Some(sender) => sender, - None => match group_sender_by_object_id.entry(collab_message.object_id().to_string()) { - Entry::Occupied(entry) => entry.get().clone(), - Entry::Vacant(entry) => { - let (new_sender, recv) = tokio::sync::mpsc::channel(1000); - let runner = GroupCommandRunner { - group_control: groups.clone(), - client_stream_by_user: client_stream_by_user.clone(), - edit_collab_by_user: edit_collab_by_user.clone(), - access_control: access_control.clone(), - recv: Some(recv), - }; + None => { + let object_id = collab_message.object_id().to_string(); + match group_sender_by_object_id.entry(object_id) { + Entry::Occupied(entry) => entry.get().clone(), + Entry::Vacant(entry) => { + let (new_sender, recv) = tokio::sync::mpsc::channel(1000); + let runner = GroupCommandRunner { + group_control: groups.clone(), + client_stream_by_user: client_stream_by_user.clone(), + edit_collab_by_user: edit_collab_by_user.clone(), + access_control: access_control.clone(), + recv: Some(recv), + }; - tokio::task::spawn_local(runner.run()); - entry.insert(new_sender.clone()); - new_sender - }, + let object_id = entry.key().clone(); + tokio::task::spawn_local(runner.run(object_id)); + entry.insert(new_sender.clone()); + new_sender + }, + } }, }; @@ -436,6 +440,12 @@ impl CollabClientStream { } /// Returns a [UnboundedSenderSink] and a [ReceiverStream] for the object_id. + /// [Sink] will be used to receive changes from the collab object. Before receiving the changes, the sink_filter + /// will be used to check if the client is allowed to receive the changes. + /// + /// [Stream] will be used to send changes to the collab object. Before sending the changes, the stream_filter + /// will be used to check if the client is allowed to send the changes. + /// #[allow(clippy::type_complexity)] pub fn client_channel( &mut self, @@ -456,20 +466,20 @@ impl CollabClientStream { let cloned_object_id = object_id.to_string(); // Send the message to the connected websocket client - let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + let (client_sink_tx, mut client_sink_rx) = tokio::sync::mpsc::unbounded_channel::(); tokio::task::spawn(async move { - while let Some(msg) = rx.recv().await { + while let Some(msg) = client_sink_rx.recv().await { let can_sink = sink_filter(&cloned_object_id, &msg).await; if can_sink { // Send the message to websocket client actor client_ws_sink.do_send(msg.into()); } else { - // when then client is not allowed to receive the message + // when then client is not allowed to receive messages tokio::time::sleep(Duration::from_secs(2)).await; } } }); - let client_forward_sink = UnboundedSenderSink::::new(tx); + let client_sink = UnboundedSenderSink::::new(client_sink_tx); // forward the message to the stream that was subscribed by the broadcast group, which will // send the messages to all connected clients using the client_forward_sink @@ -480,19 +490,14 @@ impl CollabClientStream { if stream_filter(&cloned_object_id, &msg).await { let _ = tx.send(Ok(msg)).await; } else { - // when then client is not allowed to receive the message + // when then client is not allowed to send messages tokio::time::sleep(Duration::from_secs(2)).await; } } }); - let client_forward_stream = ReceiverStream::new(rx); + let client_stream = ReceiverStream::new(rx); - // When broadcast group write a message to the client_forward_sink, the message will be forwarded - // to the client's websocket sink, which will then send the message to the connected client - // - // When receiving a message from the client_forward_stream, it will send the message to the broadcast - // group. The message will be broadcast to all connected clients. - (client_forward_sink, client_forward_stream) + (client_sink, client_stream) } pub fn disconnect(&self) { diff --git a/src/biz/casbin/access_control.rs b/src/biz/casbin/access_control.rs index ccafabf4..f978720d 100644 --- a/src/biz/casbin/access_control.rs +++ b/src/biz/casbin/access_control.rs @@ -86,37 +86,57 @@ impl AccessControl { obj: &ObjectType<'_>, act: &ActionType, ) -> Result { - self.enforcer.update(uid, obj, act).await + if cfg!(feature = "disable_access_control") { + Ok(true) + } else { + self.enforcer.update(uid, obj, act).await + } } pub async fn remove(&self, uid: &i64, obj: &ObjectType<'_>) -> Result<(), AppError> { - self.enforcer.remove(uid, obj).await?; - Ok(()) + if cfg!(feature = "disable_access_control") { + Ok(()) + } else { + self.enforcer.remove(uid, obj).await?; + Ok(()) + } } pub async fn enforce(&self, uid: &i64, obj: &ObjectType<'_>, act: A) -> Result where A: ToCasbinAction, { - self.enforcer.enforce(uid, obj, act).await + if cfg!(feature = "disable_access_control") { + Ok(true) + } else { + self.enforcer.enforce(uid, obj, act).await + } } pub async fn get_access_level(&self, uid: &i64, oid: &str) -> Option { - let collab_id = ObjectType::Collab(oid); - self - .enforcer - .get_action(uid, &collab_id) - .await - .map(|value| AFAccessLevel::from_action(&value)) + if cfg!(feature = "disable_access_control") { + Some(AFAccessLevel::FullAccess) + } else { + let collab_id = ObjectType::Collab(oid); + self + .enforcer + .get_action(uid, &collab_id) + .await + .map(|value| AFAccessLevel::from_action(&value)) + } } pub async fn get_role(&self, uid: &i64, workspace_id: &str) -> Option { - let workspace_id = ObjectType::Workspace(workspace_id); - self - .enforcer - .get_action(uid, &workspace_id) - .await - .map(|value| AFRole::from_action(&value)) + if cfg!(feature = "disable_access_control") { + Some(AFRole::Owner) + } else { + let workspace_id = ObjectType::Workspace(workspace_id); + self + .enforcer + .get_action(uid, &workspace_id) + .await + .map(|value| AFRole::from_action(&value)) + } } } diff --git a/src/biz/casbin/collab_ac.rs b/src/biz/casbin/collab_ac.rs index 8ef76a9e..b97a2eca 100644 --- a/src/biz/casbin/collab_ac.rs +++ b/src/biz/casbin/collab_ac.rs @@ -81,16 +81,24 @@ impl CollabAccessControl for CollabAccessControlImpl { } async fn can_send_collab_update(&self, uid: &i64, oid: &str) -> Result { - self - .access_control - .enforce(uid, &ObjectType::Collab(oid), Action::Write) - .await + if cfg!(feature = "disable_access_control") { + Ok(true) + } else { + self + .access_control + .enforce(uid, &ObjectType::Collab(oid), Action::Write) + .await + } } async fn can_receive_collab_update(&self, uid: &i64, oid: &str) -> Result { - self - .access_control - .enforce(uid, &ObjectType::Collab(oid), Action::Read) - .await + if cfg!(feature = "disable_access_control") { + Ok(true) + } else { + self + .access_control + .enforce(uid, &ObjectType::Collab(oid), Action::Read) + .await + } } } diff --git a/src/biz/casbin/enforcer.rs b/src/biz/casbin/enforcer.rs index 7c9324ba..a2b65699 100644 --- a/src/biz/casbin/enforcer.rs +++ b/src/biz/casbin/enforcer.rs @@ -84,7 +84,6 @@ impl AFEnforcer { Ok(value) => { trace!("[access control]: add policy:{} => {}", policy_key.0, value); self.action_cache.insert(object_key, act.to_action()); - self.enforcer_result_cache.insert(policy_key, *value); }, Err(err) => { trace!( @@ -160,7 +159,6 @@ impl AFEnforcer { .get_filtered_policy(POLICY_FIELD_INDEX_OBJECT, vec![obj.to_object_id()]); if policies_for_object.is_empty() { - self.enforcer_result_cache.insert(policy_key, true); return Ok(true); } diff --git a/src/biz/casbin/pg_listen.rs b/src/biz/casbin/pg_listen.rs index 7c789d40..29eddd63 100644 --- a/src/biz/casbin/pg_listen.rs +++ b/src/biz/casbin/pg_listen.rs @@ -22,7 +22,8 @@ pub(crate) fn spawn_listen_on_collab_member_change( match change.action_type { CollabMemberAction::INSERT | CollabMemberAction::UPDATE => { if let Some(member_row) = change.new { - if let Ok(Some(row)) = select_permission(&pg_pool, &member_row.permission_id).await { + let permission_row = select_permission(&pg_pool, &member_row.permission_id).await; + if let Ok(Some(row)) = permission_row { if let Err(err) = enforcer .update( &member_row.uid, diff --git a/src/middleware/access_control_mw.rs b/src/middleware/access_control_mw.rs index 4603a9e9..00b1fef1 100644 --- a/src/middleware/access_control_mw.rs +++ b/src/middleware/access_control_mw.rs @@ -10,6 +10,8 @@ use async_trait::async_trait; use futures_util::future::LocalBoxFuture; use actix_web::web::Data; +use dashmap::DashMap; +use once_cell::sync::Lazy; use std::collections::HashMap; use std::future::{ready, Ready}; use std::ops::{Deref, DerefMut}; @@ -20,6 +22,8 @@ use crate::state::AppState; use app_error::AppError; use uuid::Uuid; +static RESOURCE_DEF_CACHE: Lazy> = Lazy::new(DashMap::new); + #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub enum AccessResource { Workspace, @@ -176,11 +180,16 @@ where fn call(&self, mut req: ServiceRequest) -> Self::Future { let path = req.match_pattern().map(|pattern| { - let resource_ref = ResourceDef::new(pattern); + // Create ResourceDef will cause memory leak, so we use the cache to store the ResourceDef let mut path = req.match_info().clone(); - resource_ref.capture_match_info(&mut path); + RESOURCE_DEF_CACHE + .entry(pattern.to_owned()) + .or_insert_with(|| ResourceDef::new(pattern)) + .value() + .capture_match_info(&mut path); path }); + match path { None => { let fut = self.service.call(req); diff --git a/tests/access_control/mod.rs b/tests/access_control/mod.rs index 73f765f9..00207aab 100644 --- a/tests/access_control/mod.rs +++ b/tests/access_control/mod.rs @@ -278,7 +278,7 @@ pub async fn assert_can_access_http_method( expected: bool, ) -> Result<(), Error> { let timeout_duration = Duration::from_secs(10); - let retry_interval = Duration::from_millis(300); + let retry_interval = Duration::from_millis(1000); let mut retries = 0usize; let max_retries = 10;