From 39323d99ac2d96896b0844112f368bde104fdb0a Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Sat, 24 Feb 2024 07:40:46 +0800 Subject: [PATCH] fix: potential leak (#347) * fix: potential leak * fix: potential leak --- libs/realtime/src/collaborate/broadcast.rs | 52 +++++++++---------- .../realtime/src/collaborate/group_control.rs | 40 ++++++-------- libs/realtime/src/collaborate/server.rs | 10 ++-- 3 files changed, 45 insertions(+), 57 deletions(-) diff --git a/libs/realtime/src/collaborate/broadcast.rs b/libs/realtime/src/collaborate/broadcast.rs index 0b7ebe9b..4fd1dac0 100644 --- a/libs/realtime/src/collaborate/broadcast.rs +++ b/libs/realtime/src/collaborate/broadcast.rs @@ -7,7 +7,7 @@ use futures_util::{SinkExt, StreamExt}; use realtime_protocol::{handle_collab_message, Error}; use realtime_protocol::{Message, MessageReader, MSG_SYNC, MSG_SYNC_UPDATE}; use std::sync::atomic::{AtomicU32, Ordering}; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use tokio::select; use tokio::sync::broadcast::error::SendError; use tokio::sync::broadcast::{channel, Sender}; @@ -31,7 +31,6 @@ use yrs::encoding::write::Write; /// pub struct CollabBroadcast { object_id: String, - collab: MutexCollab, sender: Sender, awareness_sub: Mutex>, /// Keep the lifetime of the document observer subscription. The subscription will be stopped @@ -55,13 +54,12 @@ impl CollabBroadcast { /// /// The overflow of the incoming events that needs to be propagates will be buffered up to a /// provided `buffer_capacity` size. - pub fn new(object_id: &str, collab: MutexCollab, buffer_capacity: usize) -> Self { + pub fn new(object_id: &str, buffer_capacity: usize) -> Self { let object_id = object_id.to_owned(); // broadcast channel let (sender, _) = channel(buffer_capacity); CollabBroadcast { object_id, - collab, sender, awareness_sub: Default::default(), doc_subscription: Default::default(), @@ -70,9 +68,9 @@ impl CollabBroadcast { } } - pub async fn observe_collab_changes(&self) { + pub async fn observe_collab_changes(&self, collab: &Arc) { let (doc_sub, awareness_sub) = { - let mut mutex_collab = self.collab.lock(); + let mut mutex_collab = collab.lock(); // Observer the document's update and broadcast it to all subscribers. let cloned_oid = self.object_id.clone(); @@ -129,11 +127,6 @@ impl CollabBroadcast { *self.awareness_sub.lock().await = Some(awareness_sub); } - /// Returns a reference to an underlying [MutexCollab] instance. - pub fn collab(&self) -> &MutexCollab { - &self.collab - } - /// Broadcasts user message to all active subscribers. Returns error if message could not have /// been broadcast. #[allow(clippy::result_large_err)] @@ -175,6 +168,7 @@ impl CollabBroadcast { subscriber_origin: CollabOrigin, mut sink: Sink, mut stream: Stream, + collab: Weak, ) -> Subscription where Sink: SinkExt + Clone + Send + Sync + Unpin + 'static, @@ -218,7 +212,6 @@ impl CollabBroadcast { let stream_stop_tx = { let (stream_stop_tx, mut stop_rx) = tokio::sync::mpsc::channel::<()>(1); - let collab = self.collab().clone(); let object_id = self.object_id.clone(); // the stream will continue to receive messages from the client and it will stop if the stop_rx @@ -227,21 +220,26 @@ impl CollabBroadcast { tokio::spawn(async move { loop { select! { - _ = stop_rx.recv() => break, - result = stream.next() => { - match result { - Some(Ok(collab_msg)) => { - // The message is valid if it has a payload and the object_id matches the broadcast's object_id. - if object_id == collab_msg.object_id() && collab_msg.payload().is_some() { - handle_client_collab_message(&object_id, &mut sink, &collab_msg, &collab).await; - } else { - warn!("Invalid collab message: {:?}", collab_msg); - } - }, - Some(Err(e)) => error!("Error receiving collab message: {:?}", e.into()), - None => break, - } - } + _ = stop_rx.recv() => break, + result = stream.next() => { + match result { + Some(Ok(collab_msg)) => { + match collab.upgrade() { + None => break, // break the loop if the collab is dropped + Some(collab) => { + // The message is valid if it has a payload and the object_id matches the broadcast's object_id. + if object_id == collab_msg.object_id() && collab_msg.payload().is_some() { + handle_client_collab_message(&object_id, &mut sink, &collab_msg, &collab).await; + } else { + warn!("Invalid collab message: {:?}", collab_msg); + } + } + } + }, + Some(Err(e)) => error!("Error receiving collab message: {:?}", e.into()), + None => break, + } + } } } }); diff --git a/libs/realtime/src/collaborate/group_control.rs b/libs/realtime/src/collaborate/group_control.rs index 99e31949..d5cf90ea 100644 --- a/libs/realtime/src/collaborate/group_control.rs +++ b/libs/realtime/src/collaborate/group_control.rs @@ -2,9 +2,9 @@ use crate::collaborate::{CollabAccessControl, CollabBroadcast, CollabStoragePlug use crate::entities::RealtimeUser; use anyhow::Error; use collab::core::collab::MutexCollab; -use collab::core::collab_plugin::EncodedCollab; + use collab::core::origin::CollabOrigin; -use collab::preclude::Collab; + use collab_entity::CollabType; use dashmap::DashMap; use database::collab::CollabStorage; @@ -37,8 +37,7 @@ where } /// Performs a periodic check to remove groups based on the following conditions: - /// 1. Groups without any subscribers. - /// 2. Groups that have been inactive for a specified period of time. + /// Groups that have been inactive for a specified period of time. pub async fn tick(&self) -> Vec { let mut inactive_group_ids = vec![]; for entry in self.group_by_object_id.iter() { @@ -123,9 +122,8 @@ where collab_type: CollabType, ) -> Arc> { event!(tracing::Level::TRACE, "New group:{}", object_id); - let collab = MutexCollab::new(CollabOrigin::Server, object_id, vec![]); - let broadcast = CollabBroadcast::new(object_id, collab.clone(), 10); - let collab = Arc::new(collab.clone()); + let collab = Arc::new(MutexCollab::new(CollabOrigin::Server, object_id, vec![])); + let broadcast = CollabBroadcast::new(object_id, 10); // The lifecycle of the collab is managed by the group. let group = Arc::new(CollabGroup::new( @@ -149,7 +147,7 @@ where .storage .cache_collab(object_id, Arc::downgrade(&collab)) .await; - group.observe_collab().await; + group.observe_collab(&collab).await; group } @@ -160,7 +158,7 @@ where /// A group used to manage a single [Collab] object pub struct CollabGroup { - pub collab: Arc, + collab: Arc, collab_type: CollabType, /// A broadcast used to propagate updates produced by yrs [yrs::Doc] and [Awareness] /// to subscribes. @@ -195,8 +193,8 @@ where } } - pub async fn observe_collab(&self) { - self.broadcast.observe_collab_changes().await; + pub async fn observe_collab(&self, collab: &Arc) { + self.broadcast.observe_collab_changes(collab).await; } pub fn contains_user(&self, user: &U) -> bool { @@ -289,7 +287,12 @@ where >::Error: std::error::Error + Send + Sync, E: Into + Send + Sync + 'static, { - let sub = self.broadcast.subscribe(subscriber_origin, sink, stream); + let sub = self.broadcast.subscribe( + subscriber_origin, + sink, + stream, + Arc::downgrade(&self.collab), + ); // Remove the old user if it exists let user_device = user.user_device(); @@ -305,19 +308,6 @@ where self.subscribers.insert((*user).clone(), sub); } - /// Mutate the [Collab] by the given closure - pub fn get_mut_collab(&self, f: F) - where - F: FnOnce(&Collab), - { - let collab = self.collab.lock(); - f(&collab); - } - - pub fn encode_v1(&self) -> EncodedCollab { - self.collab.lock().encode_collab_v1() - } - pub async fn is_empty(&self) -> bool { self.subscribers.is_empty() } diff --git a/libs/realtime/src/collaborate/server.rs b/libs/realtime/src/collaborate/server.rs index 90f8a5d3..6a90eb68 100644 --- a/libs/realtime/src/collaborate/server.rs +++ b/libs/realtime/src/collaborate/server.rs @@ -85,15 +85,15 @@ where loop { interval.tick().await; if let Some(groups) = weak_groups.upgrade() { - cloned_metrics.record_opening_collab_count(groups.number_of_groups().await); - cloned_metrics.record_connected_users(cloned_client_stream_by_user.len()); - cloned_metrics - .record_encode_collab_mem_hit_rate(cloned_storage.encode_collab_mem_hit_rate()); - let inactive_group_ids = groups.tick().await; for id in inactive_group_ids { cloned_group_sender_by_object_id.remove(&id); } + + cloned_metrics.record_opening_collab_count(groups.number_of_groups().await); + cloned_metrics.record_connected_users(cloned_client_stream_by_user.len()); + cloned_metrics + .record_encode_collab_mem_hit_rate(cloned_storage.encode_collab_mem_hit_rate()); } else { break; }