chore: use dashmap (#319)

This commit is contained in:
Nathan.fooo 2024-02-18 11:55:47 +08:00 committed by GitHub
parent 2f0f093331
commit 7f12628547
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 93 additions and 154 deletions

1
Cargo.lock generated
View File

@ -4215,6 +4215,7 @@ dependencies = [
"chrono",
"collab",
"collab-entity",
"dashmap",
"database",
"database-entity",
"futures-util",

View File

@ -24,6 +24,7 @@ serde_repr = "0.1.18"
tokio-retry = "0.3.0"
reqwest = "0.11.23"
app-error = { workspace = true }
dashmap.workspace = true
collab = { version = "0.1.0"}
collab-entity = { version = "0.1.0" }

View File

@ -9,6 +9,7 @@ use database::collab::CollabStorage;
use std::collections::HashMap;
use collab::core::collab_plugin::EncodedCollab;
use dashmap::DashMap;
use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
@ -19,7 +20,7 @@ use realtime_entity::collab_msg::CollabMessage;
use tracing::{debug, error, event, instrument, trace, warn};
pub struct CollabGroupCache<S, U, AC> {
group_by_object_id: Arc<RwLock<HashMap<String, Arc<CollabGroup<U>>>>>,
group_by_object_id: Arc<DashMap<String, Arc<CollabGroup<U>>>>,
storage: Arc<S>,
access_control: Arc<AC>,
}
@ -32,7 +33,7 @@ where
{
pub fn new(storage: Arc<S>, access_control: Arc<AC>) -> Self {
Self {
group_by_object_id: Arc::new(RwLock::new(HashMap::new())),
group_by_object_id: Arc::new(DashMap::new()),
storage,
access_control,
}
@ -43,13 +44,12 @@ where
/// 2. Groups that have been inactive for a specified period of time.
pub async fn tick(&self) {
let mut inactive_group_ids = vec![];
if let Ok(groups) = self.group_by_object_id.try_read() {
for (object_id, group) in groups.iter() {
if group.is_inactive().await {
inactive_group_ids.push(object_id.clone());
if inactive_group_ids.len() > 5 {
break;
}
for entry in self.group_by_object_id.iter() {
let (object_id, group) = (entry.key(), entry.value());
if group.is_inactive().await {
inactive_group_ids.push(object_id.clone());
if inactive_group_ids.len() > 5 {
break;
}
}
}
@ -62,17 +62,16 @@ where
}
pub async fn contains_user(&self, object_id: &str, user: &U) -> Result<bool, Error> {
let group_by_object_id = self.group_by_object_id.try_read()?;
if let Some(group) = group_by_object_id.get(object_id) {
Ok(group.subscribers.try_read()?.get(user).is_some())
if let Some(entry) = self.group_by_object_id.get(object_id) {
Ok(entry.value().subscribers.try_read()?.get(user).is_some())
} else {
Ok(false)
}
}
pub async fn remove_user(&self, object_id: &str, user: &U) -> Result<(), Error> {
let group_by_object_id = self.group_by_object_id.try_read()?;
if let Some(group) = group_by_object_id.get(object_id) {
if let Some(entry) = self.group_by_object_id.get(object_id) {
let group = entry.value();
if let Some(mut subscriber) = group.subscribers.try_write()?.remove(user) {
trace!("Remove subscriber: {}", subscriber.origin);
tokio::spawn(async move {
@ -84,39 +83,29 @@ where
}
pub async fn contains_group(&self, object_id: &str) -> Result<bool, Error> {
let group_by_object_id = self.group_by_object_id.try_read()?;
Ok(group_by_object_id.get(object_id).is_some())
Ok(self.group_by_object_id.get(object_id).is_some())
}
pub async fn get_group(&self, object_id: &str) -> Option<Arc<CollabGroup<U>>> {
self
.group_by_object_id
.try_read()
.ok()?
.get(object_id)
.cloned()
.map(|v| v.value().clone())
}
#[instrument(skip(self))]
pub async fn remove_group(&self, object_id: &str) {
let mut group_by_object_id = match self.group_by_object_id.try_write() {
Ok(lock) => lock,
Err(err) => {
error!("Failed to acquire write lock to remove group: {:?}", err);
return;
},
};
let group = group_by_object_id.remove(object_id);
drop(group_by_object_id);
let entry = self.group_by_object_id.remove(object_id);
if let Some(group) = group {
group.flush_collab().await;
if let Some(entry) = entry {
let group = entry.1;
// As we've already removed the group, we directly operate on the removed group's subscribers.
if let Ok(mut subscribers) = group.subscribers.try_write() {
for (_, subscriber) in subscribers.iter_mut() {
subscriber.stop().await;
}
}
group.flush_collab().await;
} else {
// Log error if the group doesn't exist
error!("Group for object_id:{} not found", object_id);
@ -132,21 +121,16 @@ where
object_id: &str,
collab_type: CollabType,
) {
match self.group_by_object_id.try_write() {
Ok(mut group_by_object_id) => {
if group_by_object_id.contains_key(object_id) {
warn!("Group for object_id:{} already exists", object_id);
return;
}
let group = self
.init_group(uid, workspace_id, object_id, collab_type)
.await;
debug!("[realtime]: {} create group:{}", uid, object_id);
group_by_object_id.insert(object_id.to_string(), group);
},
Err(err) => error!("Failed to acquire write lock to create group: {:?}", err),
if self.group_by_object_id.contains_key(object_id) {
warn!("Group for object_id:{} already exists", object_id);
return;
}
let group = self
.init_group(uid, workspace_id, object_id, collab_type)
.await;
debug!("[realtime]: {} create group:{}", uid, object_id);
self.group_by_object_id.insert(object_id.to_string(), group);
}
#[tracing::instrument(level = "trace", skip(self))]
@ -189,8 +173,7 @@ where
}
pub async fn number_of_groups(&self) -> Option<usize> {
let read_guard = self.group_by_object_id.try_read().ok()?;
Some(read_guard.keys().len())
Some(self.group_by_object_id.len())
}
}

View File

@ -4,18 +4,18 @@ use anyhow::{anyhow, Error};
use collab::core::origin::CollabOrigin;
use database::collab::CollabStorage;
use futures_util::SinkExt;
use parking_lot::Mutex;
use realtime_entity::collab_msg::{CollabMessage, CollabSinkMessage};
use dashmap::DashMap;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::collections::HashSet;
use std::future;
use std::future::Future;
use std::iter::Take;
use std::pin::Pin;
use std::sync::{Arc, Weak};
use std::time::Duration;
use tokio::sync::RwLock;
use crate::entities::{Editing, RealtimeUser};
use tokio_retry::strategy::FixedInterval;
@ -36,8 +36,8 @@ pub(crate) struct CollabUserMessage<'a, U> {
pub(crate) struct SubscribeGroupIfNeed<'a, U, S, AC> {
pub(crate) collab_user_message: &'a CollabUserMessage<'a, U>,
pub(crate) groups: &'a Arc<CollabGroupCache<S, U, AC>>,
pub(crate) edit_collab_by_user: &'a Arc<Mutex<HashMap<U, HashSet<Editing>>>>,
pub(crate) client_stream_by_user: &'a Arc<RwLock<HashMap<U, CollabClientStream>>>,
pub(crate) edit_collab_by_user: &'a Arc<DashMap<U, HashSet<Editing>>>,
pub(crate) client_stream_by_user: &'a Arc<DashMap<U, CollabClientStream>>,
pub(crate) access_control: &'a Arc<AC>,
}
@ -107,12 +107,7 @@ where
}
let origin = Self::get_origin(collab_message);
if let Some(client_stream) = self
.client_stream_by_user
.try_write()
.map_err(|err| RealtimeError::Internal(err.into()))?
.get_mut(user)
{
if let Some(mut client_stream) = self.client_stream_by_user.get_mut(user) {
if let Some(collab_group) = self.groups.get_group(object_id).await {
if let Entry::Vacant(entry) = collab_group
.subscribers
@ -129,10 +124,6 @@ where
let client_uid = user.uid();
self
.edit_collab_by_user
.try_lock()
.ok_or(RealtimeError::Internal(anyhow!(
"Failed to acquire lock to insert editing"
)))?
.entry((*user).clone())
.or_default()
.insert(Editing {
@ -142,7 +133,7 @@ where
let (sink, stream) = Self::make_channel(
object_id,
client_stream,
client_stream.value_mut(),
client_uid,
self.access_control.clone(),
self.access_control.clone(),
@ -295,7 +286,7 @@ where
}
}
pub struct SubscribeGroupCondition<U>(pub Weak<RwLock<HashMap<U, CollabClientStream>>>);
pub struct SubscribeGroupCondition<U>(pub Weak<DashMap<U, CollabClientStream>>);
impl<U> Condition<RealtimeError> for SubscribeGroupCondition<U> {
fn should_retry(&mut self, _error: &RealtimeError) -> bool {
self.0.upgrade().is_some()

View File

@ -1,35 +1,29 @@
use crate::entities::{
ClientMessage, ClientStreamMessage, Connect, Disconnect, Editing, RealtimeMessage, RealtimeUser,
};
use crate::error::{RealtimeError, StreamError};
use anyhow::{anyhow, Result};
use actix::{Actor, Context, Handler, ResponseFuture};
use futures_util::future::BoxFuture;
use parking_lot::Mutex;
use realtime_entity::collab_msg::CollabMessage;
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::interval;
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
use tokio_stream::StreamExt;
use tracing::{error, event, info, instrument, trace, warn};
use crate::client::ClientWSSink;
use crate::collaborate::group::CollabGroupCache;
use crate::collaborate::permission::CollabAccessControl;
use crate::collaborate::retry::{CollabUserMessage, SubscribeGroupIfNeed};
use crate::collaborate::RealtimeMetrics;
use crate::entities::{
ClientMessage, ClientStreamMessage, Connect, Disconnect, Editing, RealtimeMessage, RealtimeUser,
};
use crate::error::{RealtimeError, StreamError};
use crate::util::channel_ext::UnboundedSenderSink;
use actix::{Actor, Context, Handler, ResponseFuture};
use anyhow::{anyhow, Result};
use dashmap::DashMap;
use database::collab::CollabStorage;
use futures_util::future::BoxFuture;
use realtime_entity::collab_msg::CollabMessage;
use realtime_entity::message::SystemMessage;
use std::collections::HashSet;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::interval;
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
use tokio_stream::StreamExt;
use tracing::{error, event, info, instrument, trace, warn};
#[derive(Clone)]
pub struct CollabServer<S, U, AC> {
@ -37,7 +31,7 @@ pub struct CollabServer<S, U, AC> {
storage: Arc<S>,
/// Keep track of all collab groups
groups: Arc<CollabGroupCache<S, U, AC>>,
user_by_device: Arc<parking_lot::RwLock<HashMap<UserDevice, U>>>,
user_by_device: Arc<DashMap<UserDevice, U>>,
/// This map stores the session IDs for users currently connected to the server.
/// The user's identifier [U] is used as the key, and their corresponding session ID is the value.
///
@ -46,11 +40,11 @@ pub struct CollabServer<S, U, AC> {
/// If the two session IDs differ, it indicates that the user has established a new connection
/// to the server since the stored session ID was last updated.
///
session_id_by_user: Arc<RwLock<HashMap<U, String>>>,
session_id_by_user: Arc<DashMap<U, String>>,
/// Keep track of all object ids that a user is subscribed to
editing_collab_by_user: Arc<Mutex<HashMap<U, HashSet<Editing>>>>,
editing_collab_by_user: Arc<DashMap<U, HashSet<Editing>>>,
/// Keep track of all client streams
client_stream_by_user: Arc<RwLock<HashMap<U, CollabClientStream>>>,
client_stream_by_user: Arc<DashMap<U, CollabClientStream>>,
access_control: Arc<AC>,
#[allow(dead_code)]
metrics: Arc<RealtimeMetrics>,
@ -72,7 +66,7 @@ where
storage.clone(),
access_control.clone(),
));
let client_stream_by_user: Arc<RwLock<HashMap<U, CollabClientStream>>> = Default::default();
let client_stream_by_user: Arc<DashMap<U, CollabClientStream>> = Default::default();
let editing_collab_by_user = Default::default();
let weak_groups = Arc::downgrade(&groups);
@ -84,14 +78,11 @@ where
loop {
interval.tick().await;
if let Some(groups) = weak_groups.upgrade() {
// Perform operations that require awaiting outside of the synchronous code block
if let Some(groups_operation) = groups.number_of_groups().await {
cloned_metrics.record_opening_collab_count(groups_operation);
}
if let Ok(read_guard) = cloned_client_stream_by_user.try_read() {
cloned_metrics.record_connected_users(read_guard.keys().len());
}
cloned_metrics.record_connected_users(cloned_client_stream_by_user.len());
// Assuming mem_usage() is synchronous and quick to execute
let mem_usage = cloned_storage.mem_usage();
@ -119,9 +110,9 @@ where
fn process_realtime_message(
user: U,
client_stream_by_user: Arc<RwLock<HashMap<U, CollabClientStream>>>,
client_stream_by_user: Arc<DashMap<U, CollabClientStream>>,
groups: Arc<CollabGroupCache<S, U, AC>>,
edit_collab_by_user: Arc<Mutex<HashMap<U, HashSet<Editing>>>>,
edit_collab_by_user: Arc<DashMap<U, HashSet<Editing>>>,
access_control: Arc<AC>,
realtime_msg: RealtimeMessage,
) -> Pin<Box<impl Future<Output = Result<(), RealtimeError>>>> {
@ -130,17 +121,7 @@ where
match realtime_msg {
RealtimeMessage::Collab(collab_message) => {
// 1.Check the client is connected with the websocket server
if client_stream_by_user
.try_read()
.map_err(|err| {
RealtimeError::Internal(anyhow!(
"failed to acquire the lock for client stream:{}",
err
))
})?
.get(&user)
.is_none()
{
if client_stream_by_user.get(&user).is_none() {
let msg = anyhow!(
"The client stream: {} is not found, it should be created when the client is connected with this websocket server",
user
@ -183,19 +164,16 @@ where
async fn remove_user<S, U, AC>(
groups: &Arc<CollabGroupCache<S, U, AC>>,
editing_collab_by_user: &Arc<Mutex<HashMap<U, HashSet<Editing>>>>,
editing_collab_by_user: &Arc<DashMap<U, HashSet<Editing>>>,
user: &U,
) where
S: CollabStorage,
U: RealtimeUser,
AC: CollabAccessControl,
{
let editing_set = editing_collab_by_user
.try_lock()
.and_then(|mut guard| guard.remove(user));
if let Some(editing_set) = editing_set {
for editing in editing_set {
let entry = editing_collab_by_user.remove(user);
if let Some(entry) = entry {
for editing in entry.1 {
remove_user_from_group(user, groups, &editing).await;
}
}
@ -233,17 +211,13 @@ where
Box::pin(async move {
trace!("[realtime]: new connection => {} ", new_conn.user);
user_by_session_id
.write()
.await
.insert(new_conn.user.clone(), new_conn.session_id);
user_by_session_id.insert(new_conn.user.clone(), new_conn.session_id);
let old_user = user_by_device
.write()
.insert(UserDevice::from(&new_conn.user), new_conn.user.clone());
let old_user = user_by_device.insert(UserDevice::from(&new_conn.user), new_conn.user.clone());
if let Some(old_user) = old_user {
if let Some(old_stream) = client_stream_by_user.write().await.remove(&old_user) {
if let Some(value) = client_stream_by_user.remove(&old_user) {
let old_stream = value.1;
info!("same user connect again, remove the stream: {}", &old_user);
old_stream.disconnect();
}
@ -252,15 +226,7 @@ where
remove_user(&groups, &editing_collab_by_user, &old_user).await;
}
let mut write_guard = client_stream_by_user.write().await;
info!(
"new user: {}, connected user:{}",
&new_conn.user,
write_guard.keys().len()
);
write_guard.insert(new_conn.user, client_stream);
drop(write_guard);
client_stream_by_user.insert(new_conn.user, client_stream);
Ok(())
})
}
@ -291,26 +257,18 @@ where
let session_id_by_user = self.session_id_by_user.clone();
Box::pin(async move {
let guard = match session_id_by_user.try_read() {
Ok(guard) => guard,
Err(_) => {
return Ok(());
},
};
if let Some(session_id) = guard.get(&msg.user) {
if session_id != &msg.session_id {
// If the user has reconnected with a new session, the session id will be different.
// So do not remove the user from groups and client streams
if let Some(entry) = session_id_by_user.get(&msg.user) {
if entry.value() != &msg.session_id {
return Ok(());
}
}
remove_user(&groups, &editing_collab_by_user, &msg.user).await;
if let Ok(mut client_stream_by_user) = client_stream_by_user.try_write() {
if client_stream_by_user.remove(&msg.user).is_some() {
info!("remove client stream: {}", &msg.user);
}
if client_stream_by_user.remove(&msg.user).is_some() {
info!("remove client stream: {}", &msg.user);
}
Ok(())
})
}
@ -371,15 +329,15 @@ where
}
}
let device_user = UserDevice { device_id, uid };
let user = user_by_device.read().get(&device_user).cloned();
match user {
let entry = user_by_device.get(&device_user);
match entry {
None => Err(RealtimeError::UserNotFound(format!(
"Can't find the user:{} device_id:{} from client stream message",
uid, device_user.device_id
))),
Some(user) => {
Some(entry) => {
Self::process_realtime_message(
user,
entry.value().clone(),
client_stream_by_user,
groups,
edit_collab_by_user,
@ -400,11 +358,10 @@ where
async fn broadcast_message<U>(
user: &U,
collab_message: CollabMessage,
client_streams: &Arc<RwLock<HashMap<U, CollabClientStream>>>,
client_streams: &Arc<DashMap<U, CollabClientStream>>,
) where
U: RealtimeUser,
{
let client_streams = client_streams.read().await;
if let Some(client_stream) = client_streams.get(user) {
trace!("[realtime]: receives collab message: {}", collab_message);
match client_stream

View File

@ -16,7 +16,7 @@ use std::time::Duration;
use tokio::time::sleep;
#[sqlx::test(migrations = false)]
async fn test_collab_access_control_get_access_level(pool: PgPool) -> anyhow::Result<()> {
async fn test_collab_access_control(pool: PgPool) -> anyhow::Result<()> {
setup_db(&pool).await?;
let model = DefaultModel::from_str(MODEL_CONF).await?;
@ -61,6 +61,7 @@ async fn test_collab_access_control_get_access_level(pool: PgPool) -> anyhow::Re
.await
.context("adding users to workspace")?;
// user that created the workspace should have full access
assert_access_level(
&access_control,
&user.uid,
@ -69,6 +70,7 @@ async fn test_collab_access_control_get_access_level(pool: PgPool) -> anyhow::Re
)
.await;
// member should have read and write access
assert_access_level(
&access_control,
&member.uid,
@ -77,6 +79,7 @@ async fn test_collab_access_control_get_access_level(pool: PgPool) -> anyhow::Re
)
.await;
// guest should have read access
assert_access_level(
&access_control,
&guest.uid,
@ -90,6 +93,7 @@ async fn test_collab_access_control_get_access_level(pool: PgPool) -> anyhow::Re
.await
.context("acquire transaction to update collab member")?;
// update guest access level to read and comment
database::collab::upsert_collab_member_with_txn(
guest.uid,
&workspace.workspace_id.to_string(),
@ -103,6 +107,7 @@ async fn test_collab_access_control_get_access_level(pool: PgPool) -> anyhow::Re
.await
.expect("commit transaction to update collab member");
// guest should have read and comment access
assert_access_level(
&access_control,
&guest.uid,
@ -115,6 +120,7 @@ async fn test_collab_access_control_get_access_level(pool: PgPool) -> anyhow::Re
.await
.context("delete collab member")?;
// guest should not have access after removed from collab
assert_access_level(
&access_control,
&guest.uid,