diff --git a/libs/database/src/collab/collab_storage.rs b/libs/database/src/collab/collab_storage.rs index e3f6c2fa..01a6c29c 100644 --- a/libs/database/src/collab/collab_storage.rs +++ b/libs/database/src/collab/collab_storage.rs @@ -154,6 +154,9 @@ pub trait CollabStorage: Send + Sync + 'static { /// Returns list of snapshots for given object_id in descending order of creation time. async fn get_collab_snapshot_list(&self, oid: &str) -> AppResult; + + async fn add_connected_user(&self, uid: i64, device_id: &str); + async fn remove_connected_user(&self, uid: i64, device_id: &str); } #[async_trait] @@ -268,6 +271,14 @@ where async fn get_collab_snapshot_list(&self, oid: &str) -> AppResult { self.as_ref().get_collab_snapshot_list(oid).await } + + async fn add_connected_user(&self, uid: i64, device_id: &str) { + self.as_ref().add_connected_user(uid, device_id).await + } + + async fn remove_connected_user(&self, uid: i64, device_id: &str) { + self.as_ref().remove_connected_user(uid, device_id).await + } } #[derive(Debug, Clone, Deserialize, Serialize)] diff --git a/services/appflowy-collaborate/src/lib.rs b/services/appflowy-collaborate/src/lib.rs index c7599617..34a0ae9e 100644 --- a/services/appflowy-collaborate/src/lib.rs +++ b/services/appflowy-collaborate/src/lib.rs @@ -6,6 +6,7 @@ mod group; mod metrics; mod permission; mod rt_server; +pub mod shared_state; mod util; pub use metrics::*; diff --git a/services/appflowy-collaborate/src/rt_server.rs b/services/appflowy-collaborate/src/rt_server.rs index d11b7c00..a4f24866 100644 --- a/services/appflowy-collaborate/src/rt_server.rs +++ b/services/appflowy-collaborate/src/rt_server.rs @@ -32,6 +32,7 @@ pub struct CollaborationServer { connect_state: ConnectState, group_sender_by_object_id: Arc>, access_control: Arc, + storage: Arc, #[allow(dead_code)] metrics: Arc, metrics_calculate: CollabMetricsCalculate, @@ -68,7 +69,9 @@ where spawn_collaboration_command(command_recv, &group_sender_by_object_id); spawn_metrics(&metrics, &metrics_calculate, &storage); + Ok(Self { + storage, group_manager, connect_state, group_sender_by_object_id, @@ -94,13 +97,17 @@ where let group_manager = self.group_manager.clone(); let connect_state = self.connect_state.clone(); let metrics_calculate = self.metrics_calculate.clone(); + let storage = self.storage.clone(); Box::pin(async move { + storage + .add_connected_user(connected_user.uid, &connected_user.device_id) + .await; + if let Some(old_user) = connect_state.handle_user_connect(connected_user, new_client_router) { // Remove the old user from all collaboration groups. group_manager.remove_user(&old_user).await; } - metrics_calculate.connected_users.store( connect_state.number_of_connected_users() as i64, std::sync::atomic::Ordering::Relaxed, @@ -123,11 +130,16 @@ where let group_manager = self.group_manager.clone(); let connect_state = self.connect_state.clone(); let metrics_calculate = self.metrics_calculate.clone(); + let storage = self.storage.clone(); Box::pin(async move { trace!("[realtime]: disconnect => {}", disconnect_user); let was_removed = connect_state.handle_user_disconnect(&disconnect_user); if was_removed.is_some() { + storage + .remove_connected_user(disconnect_user.uid, &disconnect_user.device_id) + .await; + metrics_calculate.connected_users.store( connect_state.number_of_connected_users() as i64, std::sync::atomic::Ordering::Relaxed, diff --git a/services/appflowy-collaborate/src/shared_state.rs b/services/appflowy-collaborate/src/shared_state.rs new file mode 100644 index 00000000..f57b2d72 --- /dev/null +++ b/services/appflowy-collaborate/src/shared_state.rs @@ -0,0 +1,74 @@ +use crate::error::RealtimeError; +use futures_util::StreamExt; +use redis::{pipe, AsyncCommands, AsyncIter}; + +#[derive(Clone)] +pub struct RealtimeSharedState { + redis_conn_manager: redis::aio::ConnectionManager, +} + +impl RealtimeSharedState { + pub fn new(redis_conn_manager: redis::aio::ConnectionManager) -> Self { + Self { redis_conn_manager } + } + pub async fn add_connected_user(&self, uid: i64, device_id: &str) -> Result<(), RealtimeError> { + let mut conn = self.redis_conn_manager.clone(); + let key = realtime_shared_state_cache_key(&uid, device_id); + conn + .set_ex(key, "1", 60 * 60 * 3) + .await + .map_err(|err| RealtimeError::Internal(err.into()))?; + Ok(()) + } + + pub async fn remove_connected_user( + &self, + uid: i64, + device_id: &str, + ) -> Result<(), RealtimeError> { + let mut conn = self.redis_conn_manager.clone(); + let key = realtime_shared_state_cache_key(&uid, device_id); + conn + .del(key) + .await + .map_err(|err| RealtimeError::Internal(err.into()))?; + Ok(()) + } + + pub async fn is_user_connected(&self, uid: &i64, device_id: &str) -> Result { + let mut conn = self.redis_conn_manager.clone(); + let key = realtime_shared_state_cache_key(uid, device_id); + let result: Option = conn + .get(key) + .await + .map_err(|err| RealtimeError::Internal(err.into()))?; + Ok(result.is_some()) + } + + pub async fn remove_all_connected_users(&self) -> Result<(), RealtimeError> { + let mut conn = self.redis_conn_manager.clone(); + let iter: AsyncIter = conn + .scan_match(format!("{}:*", REALTIME_SHARE_STATE_PREFIX)) + .await + .map_err(|err| RealtimeError::Internal(err.into()))?; + let keys_to_delete: Vec<_> = iter.collect().await; + if !keys_to_delete.is_empty() { + let mut pipeline = pipe(); + for key in keys_to_delete.iter() { + pipeline.del(key); + } + pipeline + .query_async(&mut conn) + .await + .map_err(|err| RealtimeError::Internal(err.into()))?; + } + Ok(()) + } +} + +pub(crate) const REALTIME_SHARE_STATE_PREFIX: &str = "realtime_shared_state_v0"; + +#[inline] +pub(crate) fn realtime_shared_state_cache_key(uid: &i64, device_id: &str) -> String { + format!("{}:{}:{}", REALTIME_SHARE_STATE_PREFIX, uid, device_id) +} diff --git a/services/appflowy-collaborate/tests/main.rs b/services/appflowy-collaborate/tests/main.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/services/appflowy-collaborate/tests/main.rs @@ -0,0 +1 @@ + diff --git a/services/appflowy-collaborate/tests/shared_state_test.rs b/services/appflowy-collaborate/tests/shared_state_test.rs new file mode 100644 index 00000000..0485b613 --- /dev/null +++ b/services/appflowy-collaborate/tests/shared_state_test.rs @@ -0,0 +1,63 @@ +use anyhow::Context; +use appflowy_collaborate::shared_state::RealtimeSharedState; + +async fn redis_client() -> redis::Client { + let redis_uri = "redis://localhost:6379"; + redis::Client::open(redis_uri) + .context("failed to connect to redis") + .unwrap() +} + +#[tokio::test] +async fn connected_user_test() { + let redis_client = redis_client().await; + let shared_state = RealtimeSharedState::new(redis_client.get_connection_manager().await.unwrap()); + + let device_id = uuid::Uuid::new_v4().to_string(); + let is_connected = shared_state + .is_user_connected(&1, &device_id) + .await + .unwrap(); + assert!(!is_connected); + + shared_state + .add_connected_user(1, &device_id) + .await + .unwrap(); + + let is_connected = shared_state + .is_user_connected(&1, &device_id) + .await + .unwrap(); + assert!(is_connected); + + shared_state + .remove_connected_user(1, &device_id) + .await + .unwrap(); + + let is_connected = shared_state + .is_user_connected(&1, &device_id) + .await + .unwrap(); + assert!(!is_connected); +} + +#[tokio::test] +async fn remove_all_connected_user_test() { + let redis_client = redis_client().await; + let shared_state = RealtimeSharedState::new(redis_client.get_connection_manager().await.unwrap()); + + let device_id = uuid::Uuid::new_v4().to_string(); + shared_state + .add_connected_user(1, &device_id) + .await + .unwrap(); + shared_state.remove_all_connected_users().await.unwrap(); + + let is_connected = shared_state + .is_user_connected(&1, &device_id) + .await + .unwrap(); + assert!(!is_connected); +} diff --git a/src/api/workspace.rs b/src/api/workspace.rs index 5a03f87e..fe326c64 100644 --- a/src/api/workspace.rs +++ b/src/api/workspace.rs @@ -934,6 +934,16 @@ async fn post_realtime_message_stream_handler( event!(tracing::Level::INFO, "message len: {}", bytes.len()); let device_id = device_id.to_string(); + // Only send message to websocket server when the user is connected + if !state + .realtime_shared_state + .is_user_connected(&uid, &device_id) + .await + .unwrap_or(false) + { + return Ok(Json(AppResponse::Ok())); + } + let message = parser_realtime_msg(bytes.freeze(), req.clone()).await?; let stream_message = ClientStreamMessage { uid, diff --git a/src/application.rs b/src/application.rs index 19c26104..9d254907 100644 --- a/src/application.rs +++ b/src/application.rs @@ -34,6 +34,7 @@ use actix_web::{dev::Server, web, web::Data, App, HttpServer}; use anyhow::{Context, Error}; use appflowy_ai_client::client::AppFlowyAIClient; use appflowy_collaborate::command::{CLCommandReceiver, CLCommandSender}; +use appflowy_collaborate::shared_state::RealtimeSharedState; use appflowy_collaborate::CollaborationServer; use database::file::bucket_s3_impl::S3BucketStorage; use openssl::ssl::{SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod}; @@ -246,6 +247,10 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result Result { snapshot_control: SnapshotControl, rt_cmd_sender: CLCommandSender, queue: Arc, + shared_state: RealtimeSharedState, } impl CollabStorageImpl @@ -55,6 +57,7 @@ where redis_conn_manager: RedisConnectionManager, metrics: Arc, ) -> Self { + let shared_state = RealtimeSharedState::new(redis_conn_manager.clone()); let queue = Arc::new(StorageQueue::new_with_metrics( cache.clone(), redis_conn_manager, @@ -67,6 +70,7 @@ where snapshot_control, rt_cmd_sender, queue, + shared_state, } } @@ -376,36 +380,20 @@ where async fn get_collab_snapshot_list(&self, oid: &str) -> AppResult { self.snapshot_control.get_collab_snapshot_list(oid).await } -} -#[allow(dead_code)] -#[derive(Clone)] -pub struct CollabUserState { - redis_conn_manager: RedisConnectionManager, -} - -#[allow(dead_code)] -impl CollabUserState { - async fn add_connected_user(&self, uid: i64, device_id: &str) -> AppResult<()> { - let mut conn = self.redis_conn_manager.clone(); - redis::cmd("HSET") - .arg(uid) - .arg(device_id) - .arg("true") - .query_async(&mut conn) - .await - .map_err(|err| AppError::Internal(err.into()))?; - Ok(()) + async fn add_connected_user(&self, uid: i64, device_id: &str) { + if let Err(err) = self.shared_state.add_connected_user(uid, device_id).await { + error!("Failed to add connected user: {}", err); + } } - async fn remove_connected_user(&self, uid: i64, device_id: &str) -> AppResult<()> { - let mut conn = self.redis_conn_manager.clone(); - redis::cmd("HDEL") - .arg(uid) - .arg(device_id) - .query_async(&mut conn) + async fn remove_connected_user(&self, uid: i64, device_id: &str) { + if let Err(err) = self + .shared_state + .remove_connected_user(uid, device_id) .await - .map_err(|err| AppError::Internal(err.into()))?; - Ok(()) + { + error!("Failed to remove connected user: {}", err); + } } } diff --git a/src/state.rs b/src/state.rs index 5167849f..c16383e5 100644 --- a/src/state.rs +++ b/src/state.rs @@ -11,6 +11,7 @@ use access_control::access::AccessControl; use access_control::metrics::AccessControlMetrics; use app_error::AppError; use appflowy_ai_client::client::AppFlowyAIClient; +use appflowy_collaborate::shared_state::RealtimeSharedState; use appflowy_collaborate::CollabRealtimeMetrics; use dashmap::DashMap; use database::file::bucket_s3_impl::S3BucketStorage; @@ -45,6 +46,7 @@ pub struct AppState { pub gotrue_admin: GoTrueAdmin, pub mailer: Mailer, pub ai_client: AppFlowyAIClient, + pub realtime_shared_state: RealtimeSharedState, #[cfg(feature = "history")] pub grpc_history_client: tonic_proto::history::history_client::HistoryClient, diff --git a/tests/collab/single_device_edit.rs b/tests/collab/single_device_edit.rs index 74740203..98392fef 100644 --- a/tests/collab/single_device_edit.rs +++ b/tests/collab/single_device_edit.rs @@ -1,16 +1,20 @@ -use crate::collab::util::{generate_random_string, make_big_collab_doc_state}; +use crate::collab::util::{ + generate_random_bytes, generate_random_string, make_big_collab_doc_state, +}; use assert_json_diff::assert_json_eq; use client_api_test_util::*; use collab_entity::CollabType; use database_entity::dto::AFAccessLevel; use serde_json::json; use std::collections::HashMap; +use std::sync::Arc; +use collab::core::origin::CollabOrigin; use std::time::Duration; use tokio::time::sleep; -use collab_rt_entity::MAXIMUM_REALTIME_MESSAGE_SIZE; +use collab_rt_entity::{CollabMessage, RealtimeMessage, UpdateSync, MAXIMUM_REALTIME_MESSAGE_SIZE}; use uuid::Uuid; #[tokio::test] @@ -576,27 +580,45 @@ async fn post_realtime_message_test() { } } -// #[tokio::test] -// async fn post_large_num_of_realtime_message_request_test() { -// let client = Arc::new(TestClient::new_user().await); -// let mut handles = vec![]; -// for _ in 0..100 { -// let cloned_client = client.clone(); -// let handle = tokio::spawn(async move { -// let message = RealtimeMessage::Collab(CollabMessage::ClientUpdateSync(UpdateSync::new( -// CollabOrigin::Empty, -// "fake_object_id".to_string(), -// generate_random_bytes(1024), -// 1, -// ))) -// .encode() -// .unwrap(); -// cloned_client.post_realtime_binary(message).await.unwrap(); -// }); -// handles.push(handle); -// } -// futures::future::join_all(handles).await; -// } +#[tokio::test] +async fn post_realtime_message_without_ws_connect_test() { + let client = Arc::new(TestClient::new_user_without_ws_conn().await); + let mut handles = vec![]; + + // try to post 10 realtime message without connect to the websocket server. + for _ in 0..10 { + let cloned_client = client.clone(); + let handle = tokio::spawn(async move { + let message = RealtimeMessage::Collab(CollabMessage::ClientUpdateSync(UpdateSync::new( + CollabOrigin::Empty, + uuid::Uuid::new_v4().to_string(), + generate_random_bytes(1024), + 1, + ))) + .encode() + .unwrap(); + cloned_client.post_realtime_binary(message).await.unwrap(); + }); + handles.push(handle); + } + for result in futures::future::join_all(handles).await { + result.unwrap(); + } +} + +#[tokio::test] +async fn post_realtime_message_with_ws_connect_test() { + let client = Arc::new(TestClient::new_user().await); + let message = RealtimeMessage::Collab(CollabMessage::ClientUpdateSync(UpdateSync::new( + CollabOrigin::Empty, + uuid::Uuid::new_v4().to_string(), + generate_random_bytes(1024), + 1, + ))) + .encode() + .unwrap(); + client.post_realtime_binary(message).await.unwrap(); +} #[tokio::test] async fn simulate_10_offline_user_connect_and_then_sync_document_test() {