diff --git a/Cargo.lock b/Cargo.lock index 21821fed..bb81df4f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1518,6 +1518,7 @@ dependencies = [ "tokio", "tokio-retry", "tokio-stream", + "tokio-util", "tracing", "url", "uuid", diff --git a/libs/client-api/Cargo.toml b/libs/client-api/Cargo.toml index b10d8964..1d7edaca 100644 --- a/libs/client-api/Cargo.toml +++ b/libs/client-api/Cargo.toml @@ -46,6 +46,7 @@ scraper = { version = "0.17.1", optional = true } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] tokio-retry = "0.3" +tokio-util = "0.7" [target.'cfg(not(target_arch = "wasm32"))'.dependencies.tokio] workspace = true diff --git a/libs/client-api/src/collab_sync/collab_sink.rs b/libs/client-api/src/collab_sync/collab_sink.rs index 8983bd78..65cd811d 100644 --- a/libs/client-api/src/collab_sync/collab_sink.rs +++ b/libs/client-api/src/collab_sync/collab_sink.rs @@ -1,7 +1,6 @@ use crate::af_spawn; use crate::collab_sync::collab_stream::SeqNumCounter; -use crate::collab_sync::period_state_check::CollabStateCheckRunner; use crate::collab_sync::{SinkConfig, SyncError, SyncObject}; use anyhow::Error; use collab::core::origin::{CollabClient, CollabOrigin}; @@ -71,18 +70,11 @@ where let mut interval = interval(SEND_INTERVAL); let weak_sending_messages = Arc::downgrade(&sending_messages); - let weak_notifier = Arc::downgrade(&notifier); - let origin = CollabOrigin::Client(CollabClient { + let _weak_notifier = Arc::downgrade(&notifier); + let _origin = CollabOrigin::Client(CollabClient { uid, device_id: object.device_id.clone(), }); - CollabStateCheckRunner::run( - origin, - object.object_id.clone(), - Arc::downgrade(&message_queue), - weak_notifier, - state.clone(), - ); let cloned_state = state.clone(); let weak_notifier = Arc::downgrade(&notifier); diff --git a/libs/client-api/src/collab_sync/collab_stream.rs b/libs/client-api/src/collab_sync/collab_stream.rs index 8a3ef89e..b9ae82ca 100644 --- a/libs/client-api/src/collab_sync/collab_stream.rs +++ 
b/libs/client-api/src/collab_sync/collab_stream.rs @@ -1,5 +1,7 @@ use crate::af_spawn; -use crate::collab_sync::{start_sync, CollabSink, SyncError, SyncObject, SyncReason}; +use crate::collab_sync::{ + start_sync, CollabSink, MissUpdateReason, SyncError, SyncObject, SyncReason, +}; use collab::core::collab::MutexCollab; use collab::core::origin::CollabOrigin; @@ -11,6 +13,9 @@ use futures_util::{SinkExt, StreamExt}; use std::marker::PhantomData; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, Weak}; +use tokio::select; +use tokio::sync::Mutex; +use tokio_util::sync::CancellationToken; use tracing::{error, instrument, trace, warn}; use yrs::encoding::read::Cursor; @@ -52,13 +57,16 @@ where let cloned_weak_collab = weak_collab.clone(); let seq_num_counter = Arc::new(SeqNumCounter::default()); let cloned_seq_num_counter = seq_num_counter.clone(); + let init_sync_cancel_token = Arc::new(Mutex::new(CancellationToken::new())); + let arc_object = Arc::new(object); af_spawn(ObserveCollab::::observer_collab_message( origin, - object, + arc_object, stream, cloned_weak_collab, sink, cloned_seq_num_counter, + init_sync_cancel_token, )); Self { object_id, @@ -72,11 +80,12 @@ where // Spawn the stream that continuously reads the doc's updates from remote. 
async fn observer_collab_message( origin: CollabOrigin, - object: SyncObject, + object: Arc, mut stream: Stream, weak_collab: Weak, weak_sink: Weak>, seq_num_counter: Arc, + cancel_token: Arc>, ) { while let Some(collab_message_result) = stream.next().await { let collab = match weak_collab.upgrade() { @@ -92,12 +101,16 @@ where let msg = match collab_message_result { Ok(msg) => msg, Err(err) => { - warn!("Stream error:{}, stop receive incoming changes", err.into()); + warn!( + "{} stream error:{}, stop receive incoming changes", + object.object_id, + err.into() + ); break; }, }; - if let Err(error) = ObserveCollab::::process_message( + if let Err(error) = ObserveCollab::::process_remote_message( &object, &collab, &sink, @@ -111,8 +124,29 @@ where state_vector_v1, reason, } => { - Self::pull_missing_updates(&origin, &object, &collab, &sink, state_vector_v1, reason) - .await; + let mut cancel_token_lock = cancel_token.lock().await; + cancel_token_lock.cancel(); + let new_cancel_token = CancellationToken::new(); + *cancel_token_lock = new_cancel_token.clone(); + drop(cancel_token_lock); + + let cloned_origin = origin.clone(); + let cloned_object = object.clone(); + let collab = collab.clone(); + let sink = sink.clone(); + tokio::spawn(async move { + select! 
{ + _ = new_cancel_token.cancelled() => { + if cfg!(feature = "sync_verbose_log") { + trace!("{} receive cancel signal, cancel pull missing updates", cloned_object.object_id); + } + }, + _ = tokio::time::sleep(tokio::time::Duration::from_secs(3)) => { + Self::pull_missing_updates(&cloned_origin, &cloned_object, &collab, &sink, state_vector_v1, reason) + .await; + } + } + }); }, SyncError::CannotApplyUpdate => { if let Some(lock_guard) = collab.try_lock() { @@ -136,13 +170,17 @@ where } /// Continuously handle messages from the remote doc - async fn process_message( + async fn process_remote_message( object: &SyncObject, collab: &Arc, sink: &Arc>, msg: ServerCollabMessage, seq_num_counter: &Arc, ) -> Result<(), SyncError> { + if cfg!(feature = "sync_verbose_log") { + trace!("handle server: {}", msg); + } + if let ServerCollabMessage::ClientAck(ack) = &msg { let ack_code = ack.get_code(); // if the server can not apply the update, we start the init sync. @@ -153,7 +191,7 @@ where if ack_code == AckCode::MissUpdate { return Err(SyncError::MissUpdates { state_vector_v1: Some(ack.payload.to_vec()), - reason: "server miss updates".to_string(), + reason: MissUpdateReason::ServerMissUpdates, }); } } @@ -161,12 +199,14 @@ where // msg_id will be None for [ServerBroadcast] or [ServerAwareness]. match msg.msg_id() { None => { + // apply the broadcast data and then check the continuity of the broadcast sequence number. 
+ Self::process_message_follow_protocol(&object.object_id, &msg, collab, sink).await?; + sink.notify_next(); + if let ServerCollabMessage::ServerBroadcast(ref data) = msg { seq_num_counter.check_broadcast_contiguous(&object.object_id, data.seq_num)?; seq_num_counter.store_broadcast_seq_num(data.seq_num); } - Self::process_message_follow_protocol(&object.object_id, &msg, collab, sink).await?; - sink.notify_next(); Ok(()) }, Some(msg_id) => { @@ -190,7 +230,7 @@ where collab: &Arc, sink: &Arc>, state_vector_v1: Option>, - reason: String, + reason: MissUpdateReason, ) { if let Some(lock_guard) = collab.try_lock() { let reason = SyncReason::MissUpdates { @@ -334,17 +374,17 @@ impl SeqNumCounter { /// messages may have been missed, and an error is returned. pub fn check_broadcast_contiguous( &self, - object_id: &str, + _object_id: &str, broadcast_seq_num: u32, ) -> Result<(), SyncError> { let current = self.broadcast_seq_counter.load(Ordering::SeqCst); if current > 0 && broadcast_seq_num > current + 1 { return Err(SyncError::MissUpdates { state_vector_v1: None, - reason: format!( - "{} broadcast is not contiguous, current:{}, broadcast:{}", - object_id, current, broadcast_seq_num, - ), + reason: MissUpdateReason::BroadcastSeqNotContinuous { + current, + expected: broadcast_seq_num, + }, }); } @@ -354,7 +394,14 @@ impl SeqNumCounter { pub fn check_ack_broadcast_contiguous(&self, object_id: &str) -> Result<(), SyncError> { let ack_seq_num = self.ack_seq_counter.load(Ordering::SeqCst); let broadcast_seq_num = self.broadcast_seq_counter.load(Ordering::SeqCst); - log_ack_and_broadcast(object_id, ack_seq_num, broadcast_seq_num); + if cfg!(feature = "sync_verbose_log") { + trace!( + "receive {} seq_num, ack:{}, broadcast:{}", + object_id, + ack_seq_num, + broadcast_seq_num, + ); + } if ack_seq_num > broadcast_seq_num { // calculate the number of times the ack is greater than the broadcast. 
We don't do return MissingUpdates @@ -372,10 +419,10 @@ impl SeqNumCounter { return Err(SyncError::MissUpdates { state_vector_v1: None, - reason: format!( - "ack is not equal to broadcast, ack:{}, broadcast:{}", - ack_seq_num, broadcast_seq_num, - ), + reason: MissUpdateReason::AckSeqAdvanceBroadcastSeq { + ack_seq: ack_seq_num, + broadcast_seq: broadcast_seq_num, + }, }); } } @@ -383,13 +430,3 @@ impl SeqNumCounter { Ok(()) } } - -#[cfg(feature = "sync_verbose_log")] -fn log_ack_and_broadcast(object_id: &str, ack_seq_num: u32, broadcast_seq_num: u32) { - trace!( - "receive {} seq_num, ack:{}, broadcast:{}", - object_id, - ack_seq_num, - broadcast_seq_num, - ); -} diff --git a/libs/client-api/src/collab_sync/error.rs b/libs/client-api/src/collab_sync/error.rs index 49545dde..5e6a6177 100644 --- a/libs/client-api/src/collab_sync/error.rs +++ b/libs/client-api/src/collab_sync/error.rs @@ -1,4 +1,5 @@ use collab_rt_protocol::RTProtocolError; +use std::fmt::Display; #[derive(Debug, thiserror::Error)] pub enum SyncError { @@ -29,7 +30,7 @@ pub enum SyncError { #[error("Missing updates")] MissUpdates { state_vector_v1: Option>, - reason: String, + reason: MissUpdateReason, }, #[error("Can not apply update")] @@ -39,6 +40,40 @@ pub enum SyncError { Internal(#[from] anyhow::Error), } +#[derive(Debug)] +pub enum MissUpdateReason { + BroadcastSeqNotContinuous { current: u32, expected: u32 }, + AckSeqAdvanceBroadcastSeq { ack_seq: u32, broadcast_seq: u32 }, + ServerMissUpdates, + Other(String), +} + +impl Display for MissUpdateReason { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MissUpdateReason::BroadcastSeqNotContinuous { current, expected } => { + write!( + f, + "Broadcast sequence not continuous: current={}, expected={}", + current, expected + ) + }, + MissUpdateReason::AckSeqAdvanceBroadcastSeq { + ack_seq, + broadcast_seq, + } => { + write!( + f, + "Ack sequence advance broadcast sequence: ack_seq={}, broadcast_seq={}", + 
ack_seq, broadcast_seq + ) + }, + MissUpdateReason::ServerMissUpdates => write!(f, "Server miss updates"), + MissUpdateReason::Other(reason) => write!(f, "{}", reason), + } + } +} + impl From for SyncError { fn from(value: RTProtocolError) -> Self { match value { @@ -47,7 +82,7 @@ impl From for SyncError { reason, } => Self::MissUpdates { state_vector_v1, - reason, + reason: MissUpdateReason::Other(reason), }, RTProtocolError::DecodingError(e) => Self::DecodingError(e), RTProtocolError::YAwareness(e) => Self::YAwareness(e), diff --git a/libs/client-api/src/collab_sync/sync_control.rs b/libs/client-api/src/collab_sync/sync_control.rs index bdb6d092..c5e7c3b0 100644 --- a/libs/client-api/src/collab_sync/sync_control.rs +++ b/libs/client-api/src/collab_sync/sync_control.rs @@ -1,7 +1,8 @@ use crate::af_spawn; use crate::collab_sync::collab_stream::ObserveCollab; use crate::collab_sync::{ - CollabSink, CollabSinkRunner, CollabSyncState, SinkSignal, SyncError, SyncObject, + CollabSink, CollabSinkRunner, CollabSyncState, MissUpdateReason, SinkSignal, SyncError, + SyncObject, }; use collab::core::awareness::Awareness; @@ -130,7 +131,7 @@ pub enum SyncReason { CollabInitialize, MissUpdates { state_vector_v1: Option>, - reason: String, + reason: MissUpdateReason, }, ServerCannotApplyUpdate, NetworkResume, diff --git a/libs/collab-rt-entity/src/server_message.rs b/libs/collab-rt-entity/src/server_message.rs index e0aa7054..63b69686 100644 --- a/libs/collab-rt-entity/src/server_message.rs +++ b/libs/collab-rt-entity/src/server_message.rs @@ -285,9 +285,10 @@ impl BroadcastSync { impl Display for BroadcastSync { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!( - "broadcast: [oid:{}|len:{}]", + "broadcast: [oid:{}|len:{}|seq_num:{}]", self.object_id, self.payload.len(), + self.seq_num )) } } diff --git a/tests/collab/pending_write_test.rs b/tests/collab/pending_write_test.rs index 0c713183..34c4a73d 100644 --- 
a/tests/collab/pending_write_test.rs +++ b/tests/collab/pending_write_test.rs @@ -104,35 +104,25 @@ async fn simulate_large_data_set_write(pool: PgPool) { let storage_queue = StorageQueue::new(collab_cache.clone(), conn, &queue_name); let queries = Arc::new(Mutex::new(Vec::new())); - for i in 0..5 { - let cloned_storage_queue = storage_queue.clone(); - let cloned_queries = queries.clone(); - let cloned_user = user.clone(); - tokio::spawn(async move { - // sleep random seconds less than 2 seconds. because the runtime is single-threaded, - // we need sleep a little time to let the runtime switch to other tasks. - sleep(Duration::from_millis(i % 2)).await; + for i in 0..3 { + // sleep random seconds less than 2 seconds. because the runtime is single-threaded, + // we need sleep a little time to let the runtime switch to other tasks. + sleep(Duration::from_millis(i % 2)).await; - let encode_collab = EncodedCollab::new_v1( - generate_random_bytes(10 * 1024), - generate_random_bytes(1024 * 1024), - ); - let params = CollabParams { - object_id: format!("object_id_{}", i), - collab_type: CollabType::Unknown, - encoded_collab_v1: encode_collab.encode_to_bytes().unwrap(), - }; - cloned_storage_queue - .push( - &cloned_user.workspace_id, - &cloned_user.uid, - &params, - WritePriority::Low, - ) - .await - .unwrap(); - cloned_queries.lock().await.push((params, encode_collab)); - }); + let encode_collab = EncodedCollab::new_v1( + generate_random_bytes(10 * 1024), + generate_random_bytes(2 * 1024 * 1024), + ); + let params = CollabParams { + object_id: format!("object_id_{}", i), + collab_type: CollabType::Unknown, + encoded_collab_v1: encode_collab.encode_to_bytes().unwrap(), + }; + storage_queue + .push(&user.workspace_id, &user.uid, &params, WritePriority::Low) + .await + .unwrap(); + queries.lock().await.push((params, encode_collab)); } // Allow some time for processing