fix: potential leak (#347)

* fix: potential leak

* fix: potential leak
This commit is contained in:
Nathan.fooo 2024-02-24 07:40:46 +08:00 committed by GitHub
parent 66b7637ad0
commit 39323d99ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 45 additions and 57 deletions

View File

@ -7,7 +7,7 @@ use futures_util::{SinkExt, StreamExt};
use realtime_protocol::{handle_collab_message, Error};
use realtime_protocol::{Message, MessageReader, MSG_SYNC, MSG_SYNC_UPDATE};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::sync::{Arc, Weak};
use tokio::select;
use tokio::sync::broadcast::error::SendError;
use tokio::sync::broadcast::{channel, Sender};
@ -31,7 +31,6 @@ use yrs::encoding::write::Write;
///
pub struct CollabBroadcast {
object_id: String,
collab: MutexCollab,
sender: Sender<CollabMessage>,
awareness_sub: Mutex<Option<awareness::UpdateSubscription>>,
/// Keep the lifetime of the document observer subscription. The subscription will be stopped
@ -55,13 +54,12 @@ impl CollabBroadcast {
///
/// The overflow of the incoming events that needs to be propagates will be buffered up to a
/// provided `buffer_capacity` size.
pub fn new(object_id: &str, collab: MutexCollab, buffer_capacity: usize) -> Self {
pub fn new(object_id: &str, buffer_capacity: usize) -> Self {
let object_id = object_id.to_owned();
// broadcast channel
let (sender, _) = channel(buffer_capacity);
CollabBroadcast {
object_id,
collab,
sender,
awareness_sub: Default::default(),
doc_subscription: Default::default(),
@ -70,9 +68,9 @@ impl CollabBroadcast {
}
}
pub async fn observe_collab_changes(&self) {
pub async fn observe_collab_changes(&self, collab: &Arc<MutexCollab>) {
let (doc_sub, awareness_sub) = {
let mut mutex_collab = self.collab.lock();
let mut mutex_collab = collab.lock();
// Observer the document's update and broadcast it to all subscribers.
let cloned_oid = self.object_id.clone();
@ -129,11 +127,6 @@ impl CollabBroadcast {
*self.awareness_sub.lock().await = Some(awareness_sub);
}
/// Returns a reference to an underlying [MutexCollab] instance.
pub fn collab(&self) -> &MutexCollab {
&self.collab
}
/// Broadcasts user message to all active subscribers. Returns error if message could not have
/// been broadcast.
#[allow(clippy::result_large_err)]
@ -175,6 +168,7 @@ impl CollabBroadcast {
subscriber_origin: CollabOrigin,
mut sink: Sink,
mut stream: Stream,
collab: Weak<MutexCollab>,
) -> Subscription
where
Sink: SinkExt<CollabMessage> + Clone + Send + Sync + Unpin + 'static,
@ -218,7 +212,6 @@ impl CollabBroadcast {
let stream_stop_tx = {
let (stream_stop_tx, mut stop_rx) = tokio::sync::mpsc::channel::<()>(1);
let collab = self.collab().clone();
let object_id = self.object_id.clone();
// the stream will continue to receive messages from the client and it will stop if the stop_rx
@ -227,21 +220,26 @@ impl CollabBroadcast {
tokio::spawn(async move {
loop {
select! {
_ = stop_rx.recv() => break,
result = stream.next() => {
match result {
Some(Ok(collab_msg)) => {
// The message is valid if it has a payload and the object_id matches the broadcast's object_id.
if object_id == collab_msg.object_id() && collab_msg.payload().is_some() {
handle_client_collab_message(&object_id, &mut sink, &collab_msg, &collab).await;
} else {
warn!("Invalid collab message: {:?}", collab_msg);
}
},
Some(Err(e)) => error!("Error receiving collab message: {:?}", e.into()),
None => break,
}
}
_ = stop_rx.recv() => break,
result = stream.next() => {
match result {
Some(Ok(collab_msg)) => {
match collab.upgrade() {
None => break, // break the loop if the collab is dropped
Some(collab) => {
// The message is valid if it has a payload and the object_id matches the broadcast's object_id.
if object_id == collab_msg.object_id() && collab_msg.payload().is_some() {
handle_client_collab_message(&object_id, &mut sink, &collab_msg, &collab).await;
} else {
warn!("Invalid collab message: {:?}", collab_msg);
}
}
}
},
Some(Err(e)) => error!("Error receiving collab message: {:?}", e.into()),
None => break,
}
}
}
}
});

View File

@ -2,9 +2,9 @@ use crate::collaborate::{CollabAccessControl, CollabBroadcast, CollabStoragePlug
use crate::entities::RealtimeUser;
use anyhow::Error;
use collab::core::collab::MutexCollab;
use collab::core::collab_plugin::EncodedCollab;
use collab::core::origin::CollabOrigin;
use collab::preclude::Collab;
use collab_entity::CollabType;
use dashmap::DashMap;
use database::collab::CollabStorage;
@ -37,8 +37,7 @@ where
}
/// Performs a periodic check to remove groups based on the following conditions:
/// 1. Groups without any subscribers.
/// 2. Groups that have been inactive for a specified period of time.
/// Groups that have been inactive for a specified period of time.
pub async fn tick(&self) -> Vec<String> {
let mut inactive_group_ids = vec![];
for entry in self.group_by_object_id.iter() {
@ -123,9 +122,8 @@ where
collab_type: CollabType,
) -> Arc<CollabGroup<U>> {
event!(tracing::Level::TRACE, "New group:{}", object_id);
let collab = MutexCollab::new(CollabOrigin::Server, object_id, vec![]);
let broadcast = CollabBroadcast::new(object_id, collab.clone(), 10);
let collab = Arc::new(collab.clone());
let collab = Arc::new(MutexCollab::new(CollabOrigin::Server, object_id, vec![]));
let broadcast = CollabBroadcast::new(object_id, 10);
// The lifecycle of the collab is managed by the group.
let group = Arc::new(CollabGroup::new(
@ -149,7 +147,7 @@ where
.storage
.cache_collab(object_id, Arc::downgrade(&collab))
.await;
group.observe_collab().await;
group.observe_collab(&collab).await;
group
}
@ -160,7 +158,7 @@ where
/// A group used to manage a single [Collab] object
pub struct CollabGroup<U> {
pub collab: Arc<MutexCollab>,
collab: Arc<MutexCollab>,
collab_type: CollabType,
/// A broadcast used to propagate updates produced by yrs [yrs::Doc] and [Awareness]
/// to subscribes.
@ -195,8 +193,8 @@ where
}
}
pub async fn observe_collab(&self) {
self.broadcast.observe_collab_changes().await;
pub async fn observe_collab(&self, collab: &Arc<MutexCollab>) {
self.broadcast.observe_collab_changes(collab).await;
}
pub fn contains_user(&self, user: &U) -> bool {
@ -289,7 +287,12 @@ where
<Sink as futures_util::Sink<CollabMessage>>::Error: std::error::Error + Send + Sync,
E: Into<Error> + Send + Sync + 'static,
{
let sub = self.broadcast.subscribe(subscriber_origin, sink, stream);
let sub = self.broadcast.subscribe(
subscriber_origin,
sink,
stream,
Arc::downgrade(&self.collab),
);
// Remove the old user if it exists
let user_device = user.user_device();
@ -305,19 +308,6 @@ where
self.subscribers.insert((*user).clone(), sub);
}
/// Mutate the [Collab] by the given closure
pub fn get_mut_collab<F>(&self, f: F)
where
F: FnOnce(&Collab),
{
let collab = self.collab.lock();
f(&collab);
}
pub fn encode_v1(&self) -> EncodedCollab {
self.collab.lock().encode_collab_v1()
}
pub async fn is_empty(&self) -> bool {
self.subscribers.is_empty()
}

View File

@ -85,15 +85,15 @@ where
loop {
interval.tick().await;
if let Some(groups) = weak_groups.upgrade() {
cloned_metrics.record_opening_collab_count(groups.number_of_groups().await);
cloned_metrics.record_connected_users(cloned_client_stream_by_user.len());
cloned_metrics
.record_encode_collab_mem_hit_rate(cloned_storage.encode_collab_mem_hit_rate());
let inactive_group_ids = groups.tick().await;
for id in inactive_group_ids {
cloned_group_sender_by_object_id.remove(&id);
}
cloned_metrics.record_opening_collab_count(groups.number_of_groups().await);
cloned_metrics.record_connected_users(cloned_client_stream_by_user.len());
cloned_metrics
.record_encode_collab_mem_hit_rate(cloned_storage.encode_collab_mem_hit_rate());
} else {
break;
}