fix: potential leak (#347)

* fix: potential leak

* fix: potential leak
This commit is contained in:
Nathan.fooo 2024-02-24 07:40:46 +08:00 committed by GitHub
parent 66b7637ad0
commit 39323d99ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 45 additions and 57 deletions

View File

@ -7,7 +7,7 @@ use futures_util::{SinkExt, StreamExt};
use realtime_protocol::{handle_collab_message, Error}; use realtime_protocol::{handle_collab_message, Error};
use realtime_protocol::{Message, MessageReader, MSG_SYNC, MSG_SYNC_UPDATE}; use realtime_protocol::{Message, MessageReader, MSG_SYNC, MSG_SYNC_UPDATE};
use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc; use std::sync::{Arc, Weak};
use tokio::select; use tokio::select;
use tokio::sync::broadcast::error::SendError; use tokio::sync::broadcast::error::SendError;
use tokio::sync::broadcast::{channel, Sender}; use tokio::sync::broadcast::{channel, Sender};
@ -31,7 +31,6 @@ use yrs::encoding::write::Write;
/// ///
pub struct CollabBroadcast { pub struct CollabBroadcast {
object_id: String, object_id: String,
collab: MutexCollab,
sender: Sender<CollabMessage>, sender: Sender<CollabMessage>,
awareness_sub: Mutex<Option<awareness::UpdateSubscription>>, awareness_sub: Mutex<Option<awareness::UpdateSubscription>>,
/// Keep the lifetime of the document observer subscription. The subscription will be stopped /// Keep the lifetime of the document observer subscription. The subscription will be stopped
@ -55,13 +54,12 @@ impl CollabBroadcast {
/// ///
/// The overflow of the incoming events that needs to be propagates will be buffered up to a /// The overflow of the incoming events that needs to be propagates will be buffered up to a
/// provided `buffer_capacity` size. /// provided `buffer_capacity` size.
pub fn new(object_id: &str, collab: MutexCollab, buffer_capacity: usize) -> Self { pub fn new(object_id: &str, buffer_capacity: usize) -> Self {
let object_id = object_id.to_owned(); let object_id = object_id.to_owned();
// broadcast channel // broadcast channel
let (sender, _) = channel(buffer_capacity); let (sender, _) = channel(buffer_capacity);
CollabBroadcast { CollabBroadcast {
object_id, object_id,
collab,
sender, sender,
awareness_sub: Default::default(), awareness_sub: Default::default(),
doc_subscription: Default::default(), doc_subscription: Default::default(),
@ -70,9 +68,9 @@ impl CollabBroadcast {
} }
} }
pub async fn observe_collab_changes(&self) { pub async fn observe_collab_changes(&self, collab: &Arc<MutexCollab>) {
let (doc_sub, awareness_sub) = { let (doc_sub, awareness_sub) = {
let mut mutex_collab = self.collab.lock(); let mut mutex_collab = collab.lock();
// Observer the document's update and broadcast it to all subscribers. // Observer the document's update and broadcast it to all subscribers.
let cloned_oid = self.object_id.clone(); let cloned_oid = self.object_id.clone();
@ -129,11 +127,6 @@ impl CollabBroadcast {
*self.awareness_sub.lock().await = Some(awareness_sub); *self.awareness_sub.lock().await = Some(awareness_sub);
} }
/// Returns a reference to an underlying [MutexCollab] instance.
pub fn collab(&self) -> &MutexCollab {
&self.collab
}
/// Broadcasts user message to all active subscribers. Returns error if message could not have /// Broadcasts user message to all active subscribers. Returns error if message could not have
/// been broadcast. /// been broadcast.
#[allow(clippy::result_large_err)] #[allow(clippy::result_large_err)]
@ -175,6 +168,7 @@ impl CollabBroadcast {
subscriber_origin: CollabOrigin, subscriber_origin: CollabOrigin,
mut sink: Sink, mut sink: Sink,
mut stream: Stream, mut stream: Stream,
collab: Weak<MutexCollab>,
) -> Subscription ) -> Subscription
where where
Sink: SinkExt<CollabMessage> + Clone + Send + Sync + Unpin + 'static, Sink: SinkExt<CollabMessage> + Clone + Send + Sync + Unpin + 'static,
@ -218,7 +212,6 @@ impl CollabBroadcast {
let stream_stop_tx = { let stream_stop_tx = {
let (stream_stop_tx, mut stop_rx) = tokio::sync::mpsc::channel::<()>(1); let (stream_stop_tx, mut stop_rx) = tokio::sync::mpsc::channel::<()>(1);
let collab = self.collab().clone();
let object_id = self.object_id.clone(); let object_id = self.object_id.clone();
// the stream will continue to receive messages from the client and it will stop if the stop_rx // the stream will continue to receive messages from the client and it will stop if the stop_rx
@ -227,21 +220,26 @@ impl CollabBroadcast {
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
select! { select! {
_ = stop_rx.recv() => break, _ = stop_rx.recv() => break,
result = stream.next() => { result = stream.next() => {
match result { match result {
Some(Ok(collab_msg)) => { Some(Ok(collab_msg)) => {
// The message is valid if it has a payload and the object_id matches the broadcast's object_id. match collab.upgrade() {
if object_id == collab_msg.object_id() && collab_msg.payload().is_some() { None => break, // break the loop if the collab is dropped
handle_client_collab_message(&object_id, &mut sink, &collab_msg, &collab).await; Some(collab) => {
} else { // The message is valid if it has a payload and the object_id matches the broadcast's object_id.
warn!("Invalid collab message: {:?}", collab_msg); if object_id == collab_msg.object_id() && collab_msg.payload().is_some() {
} handle_client_collab_message(&object_id, &mut sink, &collab_msg, &collab).await;
}, } else {
Some(Err(e)) => error!("Error receiving collab message: {:?}", e.into()), warn!("Invalid collab message: {:?}", collab_msg);
None => break, }
} }
} }
},
Some(Err(e)) => error!("Error receiving collab message: {:?}", e.into()),
None => break,
}
}
} }
} }
}); });

View File

@ -2,9 +2,9 @@ use crate::collaborate::{CollabAccessControl, CollabBroadcast, CollabStoragePlug
use crate::entities::RealtimeUser; use crate::entities::RealtimeUser;
use anyhow::Error; use anyhow::Error;
use collab::core::collab::MutexCollab; use collab::core::collab::MutexCollab;
use collab::core::collab_plugin::EncodedCollab;
use collab::core::origin::CollabOrigin; use collab::core::origin::CollabOrigin;
use collab::preclude::Collab;
use collab_entity::CollabType; use collab_entity::CollabType;
use dashmap::DashMap; use dashmap::DashMap;
use database::collab::CollabStorage; use database::collab::CollabStorage;
@ -37,8 +37,7 @@ where
} }
/// Performs a periodic check to remove groups based on the following conditions: /// Performs a periodic check to remove groups based on the following conditions:
/// 1. Groups without any subscribers. /// Groups that have been inactive for a specified period of time.
/// 2. Groups that have been inactive for a specified period of time.
pub async fn tick(&self) -> Vec<String> { pub async fn tick(&self) -> Vec<String> {
let mut inactive_group_ids = vec![]; let mut inactive_group_ids = vec![];
for entry in self.group_by_object_id.iter() { for entry in self.group_by_object_id.iter() {
@ -123,9 +122,8 @@ where
collab_type: CollabType, collab_type: CollabType,
) -> Arc<CollabGroup<U>> { ) -> Arc<CollabGroup<U>> {
event!(tracing::Level::TRACE, "New group:{}", object_id); event!(tracing::Level::TRACE, "New group:{}", object_id);
let collab = MutexCollab::new(CollabOrigin::Server, object_id, vec![]); let collab = Arc::new(MutexCollab::new(CollabOrigin::Server, object_id, vec![]));
let broadcast = CollabBroadcast::new(object_id, collab.clone(), 10); let broadcast = CollabBroadcast::new(object_id, 10);
let collab = Arc::new(collab.clone());
// The lifecycle of the collab is managed by the group. // The lifecycle of the collab is managed by the group.
let group = Arc::new(CollabGroup::new( let group = Arc::new(CollabGroup::new(
@ -149,7 +147,7 @@ where
.storage .storage
.cache_collab(object_id, Arc::downgrade(&collab)) .cache_collab(object_id, Arc::downgrade(&collab))
.await; .await;
group.observe_collab().await; group.observe_collab(&collab).await;
group group
} }
@ -160,7 +158,7 @@ where
/// A group used to manage a single [Collab] object /// A group used to manage a single [Collab] object
pub struct CollabGroup<U> { pub struct CollabGroup<U> {
pub collab: Arc<MutexCollab>, collab: Arc<MutexCollab>,
collab_type: CollabType, collab_type: CollabType,
/// A broadcast used to propagate updates produced by yrs [yrs::Doc] and [Awareness] /// A broadcast used to propagate updates produced by yrs [yrs::Doc] and [Awareness]
/// to subscribes. /// to subscribes.
@ -195,8 +193,8 @@ where
} }
} }
pub async fn observe_collab(&self) { pub async fn observe_collab(&self, collab: &Arc<MutexCollab>) {
self.broadcast.observe_collab_changes().await; self.broadcast.observe_collab_changes(collab).await;
} }
pub fn contains_user(&self, user: &U) -> bool { pub fn contains_user(&self, user: &U) -> bool {
@ -289,7 +287,12 @@ where
<Sink as futures_util::Sink<CollabMessage>>::Error: std::error::Error + Send + Sync, <Sink as futures_util::Sink<CollabMessage>>::Error: std::error::Error + Send + Sync,
E: Into<Error> + Send + Sync + 'static, E: Into<Error> + Send + Sync + 'static,
{ {
let sub = self.broadcast.subscribe(subscriber_origin, sink, stream); let sub = self.broadcast.subscribe(
subscriber_origin,
sink,
stream,
Arc::downgrade(&self.collab),
);
// Remove the old user if it exists // Remove the old user if it exists
let user_device = user.user_device(); let user_device = user.user_device();
@ -305,19 +308,6 @@ where
self.subscribers.insert((*user).clone(), sub); self.subscribers.insert((*user).clone(), sub);
} }
/// Mutate the [Collab] by the given closure
pub fn get_mut_collab<F>(&self, f: F)
where
F: FnOnce(&Collab),
{
let collab = self.collab.lock();
f(&collab);
}
pub fn encode_v1(&self) -> EncodedCollab {
self.collab.lock().encode_collab_v1()
}
pub async fn is_empty(&self) -> bool { pub async fn is_empty(&self) -> bool {
self.subscribers.is_empty() self.subscribers.is_empty()
} }

View File

@ -85,15 +85,15 @@ where
loop { loop {
interval.tick().await; interval.tick().await;
if let Some(groups) = weak_groups.upgrade() { if let Some(groups) = weak_groups.upgrade() {
cloned_metrics.record_opening_collab_count(groups.number_of_groups().await);
cloned_metrics.record_connected_users(cloned_client_stream_by_user.len());
cloned_metrics
.record_encode_collab_mem_hit_rate(cloned_storage.encode_collab_mem_hit_rate());
let inactive_group_ids = groups.tick().await; let inactive_group_ids = groups.tick().await;
for id in inactive_group_ids { for id in inactive_group_ids {
cloned_group_sender_by_object_id.remove(&id); cloned_group_sender_by_object_id.remove(&id);
} }
cloned_metrics.record_opening_collab_count(groups.number_of_groups().await);
cloned_metrics.record_connected_users(cloned_client_stream_by_user.len());
cloned_metrics
.record_encode_collab_mem_hit_rate(cloned_storage.encode_collab_mem_hit_rate());
} else { } else {
break; break;
} }