From 7f12628547fbd37a98f91c879e0bad1570831d82 Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Sun, 18 Feb 2024 11:55:47 +0800 Subject: [PATCH] chore: use dashmap (#319) --- Cargo.lock | 1 + libs/realtime/Cargo.toml | 1 + libs/realtime/src/collaborate/group.rs | 75 +++++-------- libs/realtime/src/collaborate/retry.rs | 25 ++--- libs/realtime/src/collaborate/server.rs | 137 ++++++++---------------- tests/casbin/collab_ac_test.rs | 8 +- 6 files changed, 93 insertions(+), 154 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3f438363..ee466c3d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4215,6 +4215,7 @@ dependencies = [ "chrono", "collab", "collab-entity", + "dashmap", "database", "database-entity", "futures-util", diff --git a/libs/realtime/Cargo.toml b/libs/realtime/Cargo.toml index 54c27642..85d893f5 100644 --- a/libs/realtime/Cargo.toml +++ b/libs/realtime/Cargo.toml @@ -24,6 +24,7 @@ serde_repr = "0.1.18" tokio-retry = "0.3.0" reqwest = "0.11.23" app-error = { workspace = true } +dashmap.workspace = true collab = { version = "0.1.0"} collab-entity = { version = "0.1.0" } diff --git a/libs/realtime/src/collaborate/group.rs b/libs/realtime/src/collaborate/group.rs index 0cf2d01d..b905c0b2 100644 --- a/libs/realtime/src/collaborate/group.rs +++ b/libs/realtime/src/collaborate/group.rs @@ -9,6 +9,7 @@ use database::collab::CollabStorage; use std::collections::HashMap; use collab::core::collab_plugin::EncodedCollab; +use dashmap::DashMap; use futures_util::{SinkExt, StreamExt}; use std::sync::Arc; use tokio::sync::{Mutex, RwLock}; @@ -19,7 +20,7 @@ use realtime_entity::collab_msg::CollabMessage; use tracing::{debug, error, event, instrument, trace, warn}; pub struct CollabGroupCache { - group_by_object_id: Arc>>>>, + group_by_object_id: Arc>>>, storage: Arc, access_control: Arc, } @@ -32,7 +33,7 @@ where { pub fn new(storage: Arc, access_control: Arc) -> Self { Self { - group_by_object_id: Arc::new(RwLock::new(HashMap::new())), + group_by_object_id: Arc::new(DashMap::new()), storage, access_control, } @@ -43,13 +44,12 @@ where /// 2. Groups that have been inactive for a specified period of time. pub async fn tick(&self) { let mut inactive_group_ids = vec![]; - if let Ok(groups) = self.group_by_object_id.try_read() { - for (object_id, group) in groups.iter() { - if group.is_inactive().await { - inactive_group_ids.push(object_id.clone()); - if inactive_group_ids.len() > 5 { - break; - } + for entry in self.group_by_object_id.iter() { + let (object_id, group) = (entry.key(), entry.value()); + if group.is_inactive().await { + inactive_group_ids.push(object_id.clone()); + if inactive_group_ids.len() > 5 { + break; } } } @@ -62,17 +62,16 @@ where } pub async fn contains_user(&self, object_id: &str, user: &U) -> Result { - let group_by_object_id = self.group_by_object_id.try_read()?; - if let Some(group) = group_by_object_id.get(object_id) { - Ok(group.subscribers.try_read()?.get(user).is_some()) + if let Some(entry) = self.group_by_object_id.get(object_id) { + Ok(entry.value().subscribers.try_read()?.get(user).is_some()) } else { Ok(false) } } pub async fn remove_user(&self, object_id: &str, user: &U) -> Result<(), Error> { - let group_by_object_id = self.group_by_object_id.try_read()?; - if let Some(group) = group_by_object_id.get(object_id) { + if let Some(entry) = self.group_by_object_id.get(object_id) { + let group = entry.value(); if let Some(mut subscriber) = group.subscribers.try_write()?.remove(user) { trace!("Remove subscriber: {}", subscriber.origin); tokio::spawn(async move { @@ -84,39 +83,29 @@ where } pub async fn contains_group(&self, object_id: &str) -> Result { - let group_by_object_id = self.group_by_object_id.try_read()?; - Ok(group_by_object_id.get(object_id).is_some()) + Ok(self.group_by_object_id.get(object_id).is_some()) } pub async fn get_group(&self, object_id: &str) -> Option>> { self .group_by_object_id - .try_read() - .ok()? .get(object_id) - .cloned() + .map(|v| v.value().clone()) } #[instrument(skip(self))] pub async fn remove_group(&self, object_id: &str) { - let mut group_by_object_id = match self.group_by_object_id.try_write() { - Ok(lock) => lock, - Err(err) => { - error!("Failed to acquire write lock to remove group: {:?}", err); - return; - }, - }; - let group = group_by_object_id.remove(object_id); - drop(group_by_object_id); + let entry = self.group_by_object_id.remove(object_id); - if let Some(group) = group { - group.flush_collab().await; + if let Some(entry) = entry { + let group = entry.1; // As we've already removed the group, we directly operate on the removed group's subscribers. if let Ok(mut subscribers) = group.subscribers.try_write() { for (_, subscriber) in subscribers.iter_mut() { subscriber.stop().await; } } + group.flush_collab().await; } else { // Log error if the group doesn't exist error!("Group for object_id:{} not found", object_id); @@ -132,21 +121,16 @@ where object_id: &str, collab_type: CollabType, ) { - match self.group_by_object_id.try_write() { - Ok(mut group_by_object_id) => { - if group_by_object_id.contains_key(object_id) { - warn!("Group for object_id:{} already exists", object_id); - return; - } - - let group = self - .init_group(uid, workspace_id, object_id, collab_type) - .await; - debug!("[realtime]: {} create group:{}", uid, object_id); - group_by_object_id.insert(object_id.to_string(), group); - }, - Err(err) => error!("Failed to acquire write lock to create group: {:?}", err), + if self.group_by_object_id.contains_key(object_id) { + warn!("Group for object_id:{} already exists", object_id); + return; } + + let group = self + .init_group(uid, workspace_id, object_id, collab_type) + .await; + debug!("[realtime]: {} create group:{}", uid, object_id); + self.group_by_object_id.insert(object_id.to_string(), group); } #[tracing::instrument(level = "trace", skip(self))] @@ -189,8 +173,7 @@ where } pub async fn number_of_groups(&self) -> Option { - let read_guard = self.group_by_object_id.try_read().ok()?; - Some(read_guard.keys().len()) + Some(self.group_by_object_id.len()) } } diff --git a/libs/realtime/src/collaborate/retry.rs b/libs/realtime/src/collaborate/retry.rs index 15e40f30..ed83170e 100644 --- a/libs/realtime/src/collaborate/retry.rs +++ b/libs/realtime/src/collaborate/retry.rs @@ -4,18 +4,18 @@ use anyhow::{anyhow, Error}; use collab::core::origin::CollabOrigin; use database::collab::CollabStorage; use futures_util::SinkExt; -use parking_lot::Mutex; + use realtime_entity::collab_msg::{CollabMessage, CollabSinkMessage}; +use dashmap::DashMap; use std::collections::hash_map::Entry; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::future; use std::future::Future; use std::iter::Take; use std::pin::Pin; use std::sync::{Arc, Weak}; use std::time::Duration; -use tokio::sync::RwLock; use crate::entities::{Editing, RealtimeUser}; use tokio_retry::strategy::FixedInterval; @@ -36,8 +36,8 @@ pub(crate) struct CollabUserMessage<'a, U> { pub(crate) struct SubscribeGroupIfNeed<'a, U, S, AC> { pub(crate) collab_user_message: &'a CollabUserMessage<'a, U>, pub(crate) groups: &'a Arc>, - pub(crate) edit_collab_by_user: &'a Arc>>>, - pub(crate) client_stream_by_user: &'a Arc>>, + pub(crate) edit_collab_by_user: &'a Arc>>, + pub(crate) client_stream_by_user: &'a Arc>, pub(crate) access_control: &'a Arc, } @@ -107,12 +107,7 @@ where } let origin = Self::get_origin(collab_message); - if let Some(client_stream) = self - .client_stream_by_user - .try_write() - .map_err(|err| RealtimeError::Internal(err.into()))? - .get_mut(user) - { + if let Some(mut client_stream) = self.client_stream_by_user.get_mut(user) { if let Some(collab_group) = self.groups.get_group(object_id).await { if let Entry::Vacant(entry) = collab_group .subscribers @@ -129,10 +124,6 @@ where let client_uid = user.uid(); self .edit_collab_by_user - .try_lock() - .ok_or(RealtimeError::Internal(anyhow!( - "Failed to acquire lock to insert editing" - )))? .entry((*user).clone()) .or_default() .insert(Editing { @@ -142,7 +133,7 @@ where let (sink, stream) = Self::make_channel( object_id, - client_stream, + client_stream.value_mut(), client_uid, self.access_control.clone(), self.access_control.clone(), @@ -295,7 +286,7 @@ where } } -pub struct SubscribeGroupCondition(pub Weak>>); +pub struct SubscribeGroupCondition(pub Weak>); impl Condition for SubscribeGroupCondition { fn should_retry(&mut self, _error: &RealtimeError) -> bool { self.0.upgrade().is_some() diff --git a/libs/realtime/src/collaborate/server.rs b/libs/realtime/src/collaborate/server.rs index 77dd3fda..266720b1 100644 --- a/libs/realtime/src/collaborate/server.rs +++ b/libs/realtime/src/collaborate/server.rs @@ -1,35 +1,29 @@ -use crate::entities::{ - ClientMessage, ClientStreamMessage, Connect, Disconnect, Editing, RealtimeMessage, RealtimeUser, -}; -use crate::error::{RealtimeError, StreamError}; -use anyhow::{anyhow, Result}; - -use actix::{Actor, Context, Handler, ResponseFuture}; -use futures_util::future::BoxFuture; -use parking_lot::Mutex; -use realtime_entity::collab_msg::CollabMessage; -use std::collections::{HashMap, HashSet}; -use std::future::Future; -use std::pin::Pin; - -use std::sync::Arc; -use std::time::Duration; - -use tokio::sync::RwLock; -use tokio::time::interval; - -use tokio_stream::wrappers::{BroadcastStream, ReceiverStream}; -use tokio_stream::StreamExt; -use tracing::{error, event, info, instrument, trace, warn}; - use crate::client::ClientWSSink; use crate::collaborate::group::CollabGroupCache; use crate::collaborate::permission::CollabAccessControl; use crate::collaborate::retry::{CollabUserMessage, SubscribeGroupIfNeed}; use crate::collaborate::RealtimeMetrics; +use crate::entities::{ + ClientMessage, ClientStreamMessage, Connect, Disconnect, Editing, RealtimeMessage, RealtimeUser, +}; +use crate::error::{RealtimeError, StreamError}; use crate::util::channel_ext::UnboundedSenderSink; +use actix::{Actor, Context, Handler, ResponseFuture}; +use anyhow::{anyhow, Result}; +use dashmap::DashMap; use database::collab::CollabStorage; +use futures_util::future::BoxFuture; +use realtime_entity::collab_msg::CollabMessage; use realtime_entity::message::SystemMessage; +use std::collections::HashSet; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; +use tokio::time::interval; +use tokio_stream::wrappers::{BroadcastStream, ReceiverStream}; +use tokio_stream::StreamExt; +use tracing::{error, event, info, instrument, trace, warn}; #[derive(Clone)] pub struct CollabServer { @@ -37,7 +31,7 @@ pub struct CollabServer { storage: Arc, /// Keep track of all collab groups groups: Arc>, - user_by_device: Arc>>, + user_by_device: Arc>, /// This map stores the session IDs for users currently connected to the server. /// The user's identifier [U] is used as the key, and their corresponding session ID is the value. /// @@ -46,11 +40,11 @@ pub struct CollabServer { /// If the two session IDs differ, it indicates that the user has established a new connection /// to the server since the stored session ID was last updated. /// - session_id_by_user: Arc>>, + session_id_by_user: Arc>, /// Keep track of all object ids that a user is subscribed to - editing_collab_by_user: Arc>>>, + editing_collab_by_user: Arc>>, /// Keep track of all client streams - client_stream_by_user: Arc>>, + client_stream_by_user: Arc>, access_control: Arc, #[allow(dead_code)] metrics: Arc, @@ -72,7 +66,7 @@ where storage.clone(), access_control.clone(), )); - let client_stream_by_user: Arc>> = Default::default(); + let client_stream_by_user: Arc> = Default::default(); let editing_collab_by_user = Default::default(); let weak_groups = Arc::downgrade(&groups); @@ -84,14 +78,11 @@ where loop { interval.tick().await; if let Some(groups) = weak_groups.upgrade() { - // Perform operations that require awaiting outside of the synchronous code block if let Some(groups_operation) = groups.number_of_groups().await { cloned_metrics.record_opening_collab_count(groups_operation); } - if let Ok(read_guard) = cloned_client_stream_by_user.try_read() { - cloned_metrics.record_connected_users(read_guard.keys().len()); - } + cloned_metrics.record_connected_users(cloned_client_stream_by_user.len()); // Assuming mem_usage() is synchronous and quick to execute let mem_usage = cloned_storage.mem_usage(); @@ -119,9 +110,9 @@ where fn process_realtime_message( user: U, - client_stream_by_user: Arc>>, + client_stream_by_user: Arc>, groups: Arc>, - edit_collab_by_user: Arc>>>, + edit_collab_by_user: Arc>>, access_control: Arc, realtime_msg: RealtimeMessage, ) -> Pin>>> { @@ -130,17 +121,7 @@ where match realtime_msg { RealtimeMessage::Collab(collab_message) => { // 1.Check the client is connected with the websocket server - if client_stream_by_user - .try_read() - .map_err(|err| { - RealtimeError::Internal(anyhow!( - "failed to acquire the lock for client stream:{}", - err - )) - })? - .get(&user) - .is_none() - { + if client_stream_by_user.get(&user).is_none() { let msg = anyhow!( "The client stream: {} is not found, it should be created when the client is connected with this websocket server", user @@ -183,19 +164,16 @@ where async fn remove_user( groups: &Arc>, - editing_collab_by_user: &Arc>>>, + editing_collab_by_user: &Arc>>, user: &U, ) where S: CollabStorage, U: RealtimeUser, AC: CollabAccessControl, { - let editing_set = editing_collab_by_user - .try_lock() - .and_then(|mut guard| guard.remove(user)); - - if let Some(editing_set) = editing_set { - for editing in editing_set { + let entry = editing_collab_by_user.remove(user); + if let Some(entry) = entry { + for editing in entry.1 { remove_user_from_group(user, groups, &editing).await; } } @@ -233,17 +211,13 @@ where Box::pin(async move { trace!("[realtime]: new connection => {} ", new_conn.user); - user_by_session_id - .write() - .await - .insert(new_conn.user.clone(), new_conn.session_id); + user_by_session_id.insert(new_conn.user.clone(), new_conn.session_id); - let old_user = user_by_device - .write() - .insert(UserDevice::from(&new_conn.user), new_conn.user.clone()); + let old_user = user_by_device.insert(UserDevice::from(&new_conn.user), new_conn.user.clone()); if let Some(old_user) = old_user { - if let Some(old_stream) = client_stream_by_user.write().await.remove(&old_user) { + if let Some(value) = client_stream_by_user.remove(&old_user) { + let old_stream = value.1; info!("same user connect again, remove the stream: {}", &old_user); old_stream.disconnect(); } @@ -252,15 +226,7 @@ where remove_user(&groups, &editing_collab_by_user, &old_user).await; } - let mut write_guard = client_stream_by_user.write().await; - info!( - "new user: {}, connected user:{}", - &new_conn.user, - write_guard.keys().len() - ); - write_guard.insert(new_conn.user, client_stream); - drop(write_guard); - + client_stream_by_user.insert(new_conn.user, client_stream); Ok(()) }) } @@ -291,26 +257,18 @@ where let session_id_by_user = self.session_id_by_user.clone(); Box::pin(async move { - let guard = match session_id_by_user.try_read() { - Ok(guard) => guard, - Err(_) => { - return Ok(()); - }, - }; - - if let Some(session_id) = guard.get(&msg.user) { - if session_id != &msg.session_id { + // If the user has reconnected with a new session, the session id will be different. + // So do not remove the user from groups and client streams + if let Some(entry) = session_id_by_user.get(&msg.user) { + if entry.value() != &msg.session_id { return Ok(()); } } remove_user(&groups, &editing_collab_by_user, &msg.user).await; - if let Ok(mut client_stream_by_user) = client_stream_by_user.try_write() { - if client_stream_by_user.remove(&msg.user).is_some() { - info!("remove client stream: {}", &msg.user); - } + if client_stream_by_user.remove(&msg.user).is_some() { + info!("remove client stream: {}", &msg.user); } - Ok(()) }) } @@ -371,15 +329,15 @@ where } } let device_user = UserDevice { device_id, uid }; - let user = user_by_device.read().get(&device_user).cloned(); - match user { + let entry = user_by_device.get(&device_user); + match entry { None => Err(RealtimeError::UserNotFound(format!( "Can't find the user:{} device_id:{} from client stream message", uid, device_user.device_id ))), - Some(user) => { + Some(entry) => { Self::process_realtime_message( - user, + entry.value().clone(), client_stream_by_user, groups, edit_collab_by_user, @@ -400,11 +358,10 @@ where async fn broadcast_message( user: &U, collab_message: CollabMessage, - client_streams: &Arc>>, + client_streams: &Arc>, ) where U: RealtimeUser, { - let client_streams = client_streams.read().await; if let Some(client_stream) = client_streams.get(user) { trace!("[realtime]: receives collab message: {}", collab_message); match client_stream diff --git a/tests/casbin/collab_ac_test.rs b/tests/casbin/collab_ac_test.rs index 5541fc2a..0ecca3f6 100644 --- a/tests/casbin/collab_ac_test.rs +++ b/tests/casbin/collab_ac_test.rs @@ -16,7 +16,7 @@ use std::time::Duration; use tokio::time::sleep; #[sqlx::test(migrations = false)] -async fn test_collab_access_control_get_access_level(pool: PgPool) -> anyhow::Result<()> { +async fn test_collab_access_control(pool: PgPool) -> anyhow::Result<()> { setup_db(&pool).await?; let model = DefaultModel::from_str(MODEL_CONF).await?; @@ -61,6 +61,7 @@ async fn test_collab_access_control_get_access_level(pool: PgPool) -> anyhow::Re .await .context("adding users to workspace")?; + // user that created the workspace should have full access assert_access_level( &access_control, &user.uid, @@ -69,6 +70,7 @@ async fn test_collab_access_control_get_access_level(pool: PgPool) -> anyhow::Re ) .await; + // member should have read and write access assert_access_level( &access_control, &member.uid, @@ -77,6 +79,7 @@ async fn test_collab_access_control_get_access_level(pool: PgPool) -> anyhow::Re ) .await; + // guest should have read access assert_access_level( &access_control, &guest.uid, @@ -90,6 +93,7 @@ async fn test_collab_access_control_get_access_level(pool: PgPool) -> anyhow::Re .await .context("acquire transaction to update collab member")?; + // update guest access level to read and comment database::collab::upsert_collab_member_with_txn( guest.uid, &workspace.workspace_id.to_string(), @@ -103,6 +107,7 @@ async fn test_collab_access_control_get_access_level(pool: PgPool) -> anyhow::Re .await .expect("commit transaction to update collab member"); + // guest should have read and comment access assert_access_level( &access_control, &guest.uid, @@ -115,6 +120,7 @@ async fn test_collab_access_control_get_access_level(pool: PgPool) -> anyhow::Re .await .context("delete collab member")?; + // guest should not have access after removed from collab assert_access_level( &access_control, &guest.uid,