diff --git a/libs/realtime/src/collaborate/broadcast.rs b/libs/realtime/src/collaborate/broadcast.rs index 78e17d33..82e06ffe 100644 --- a/libs/realtime/src/collaborate/broadcast.rs +++ b/libs/realtime/src/collaborate/broadcast.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Error}; +use anyhow::anyhow; use collab::core::awareness; use std::future::Future; use std::iter::Take; @@ -130,19 +130,18 @@ impl CollabBroadcast { pub fn subscribe( &self, subscriber_origin: CollabOrigin, - sink: Sink, + mut sink: Sink, mut stream: Stream, modified_at: Arc>, ) -> Subscription where - Sink: SinkExt + Send + Sync + Unpin + 'static, + Sink: SinkExt + Clone + Send + Sync + Unpin + 'static, Stream: StreamExt> + Send + Sync + Unpin + 'static, >::Error: std::error::Error + Send + Sync, E: Into + Send + Sync + 'static, { let cloned_origin = subscriber_origin.clone(); trace!("[realtime]: new subscriber: {}", subscriber_origin); - let sink = Arc::new(Mutex::new(sink)); // Receive a update from the document observer and forward the update to all // connected subscribers using its Sink. let sink_stop_tx = { @@ -199,7 +198,7 @@ impl CollabBroadcast { match result { Some(Ok(collab_msg)) => { if object_id == collab_msg.object_id() && collab_msg.payload().is_some() { - handle_user_collab_message(&object_id, &sink, &collab_msg, &collab).await; + handle_client_collab_message(&object_id, &mut sink, &collab_msg, &collab).await; if let Ok(mut modified_at) = modified_at.try_lock() { *modified_at = Instant::now(); } @@ -226,9 +225,10 @@ impl CollabBroadcast { } } -async fn handle_user_collab_message( +/// Handle the message sent from the client +async fn handle_client_collab_message( object_id: &str, - sink: &Arc>, + sink: &mut Sink, collab_msg: &CollabMessage, collab: &MutexCollab, ) where @@ -247,13 +247,11 @@ async fn handle_user_collab_message( Ok(msg) => { let cloned_collab = collab.clone(); let cloned_origin = origin.clone(); - let result = tokio::task::spawn_blocking(move || { - handle_collab_message(&cloned_origin, &ServerSyncProtocol, &cloned_collab, msg) - }) - .await; + let result = + handle_collab_message(&cloned_origin, &ServerSyncProtocol, &cloned_collab, msg); match result { - Ok(Ok(payload)) => match origin.as_ref() { + Ok(payload) => match origin.as_ref() { None => warn!("Client message does not have a origin"), Some(origin) => { if let Some(msg_id) = collab_msg.msg_id() { @@ -266,22 +264,14 @@ async fn handle_user_collab_message( ); trace!("Send response to client: {}", resp); - match sink.try_lock() { - Ok(mut sink) => { - if let Err(err) = sink.send(resp.into()).await { - trace!("fail to send response to client: {}", err); - } - }, - Err(err) => error!("Requires sink lock failed: {:?}", err), + if let Err(err) = sink.send(resp.into()).await { + trace!("fail to send response to client: {}", err); } } }, }, - Ok(Err(err)) => { - error!("object id:{} =>{}", object_id, err); - }, Err(err) => { - error!("internal error when handle user ws message: {}", err); + error!("object id:{} =>{}", object_id, err); }, } }, @@ -356,14 +346,14 @@ fn gen_awareness_update_message( Ok(update) } -pub struct SinkCollabMessageAction<'a, Sink> { - pub sink: &'a Arc>, +pub struct SinkCollabMessageAction<'a, Sink: Clone> { + pub sink: &'a Sink, pub message: CollabMessage, } impl<'a, Sink> SinkCollabMessageAction<'a, Sink> where - Sink: SinkExt + Send + Sync + Unpin + 'a, + Sink: SinkExt + Clone + Send + Sync + Unpin + 'a, { pub fn run(self) -> Retry, SinkCollabMessageAction<'a, Sink>> { let retry_strategy = FixedInterval::new(Duration::from_secs(2)).take(5); @@ -373,19 +363,16 @@ where impl<'a, Sink> Action for SinkCollabMessageAction<'a, Sink> where - Sink: SinkExt + Send + Sync + Unpin + 'a, + Sink: SinkExt + Clone + Send + Sync + Unpin + 'a, { type Future = Pin> + Send + Sync + 'a>>; type Item = (); type Error = RealtimeError; fn run(&mut self) -> Self::Future { - let sink = self.sink.clone(); + let mut sink = self.sink.clone(); let message = self.message.clone(); Box::pin(async move { - let mut sink = sink - .try_lock() - .map_err(|err| RealtimeError::Internal(Error::from(err)))?; sink .send(message) .await diff --git a/libs/realtime/src/collaborate/group_control.rs b/libs/realtime/src/collaborate/group_control.rs index 9cc50752..846887d0 100644 --- a/libs/realtime/src/collaborate/group_control.rs +++ b/libs/realtime/src/collaborate/group_control.rs @@ -228,7 +228,7 @@ where sink: Sink, stream: Stream, ) where - Sink: SinkExt + Send + Sync + Unpin + 'static, + Sink: SinkExt + Clone + Send + Sync + Unpin + 'static, Stream: StreamExt> + Send + Sync + Unpin + 'static, >::Error: std::error::Error + Send + Sync, E: Into + Send + Sync + 'static, diff --git a/libs/realtime/src/collaborate/metrics.rs b/libs/realtime/src/collaborate/metrics.rs index dc3176ae..d4207044 100644 --- a/libs/realtime/src/collaborate/metrics.rs +++ b/libs/realtime/src/collaborate/metrics.rs @@ -46,8 +46,8 @@ impl RealtimeMetrics { } pub fn record_mem_cache_usage(&self, size_in_bytes: usize) { - let size_in_mb = size_in_bytes / (1024 * 1024); - trace!("[metrics]: mem_cache_usage: {} MB", size_in_mb); + let size_in_mb = size_in_bytes / 1024; + trace!("[metrics]: mem_cache_usage: {} KB", size_in_mb); self.mem_cache_usage.set(size_in_mb as i64); } diff --git a/libs/realtime/src/util/channel_ext.rs b/libs/realtime/src/util/channel_ext.rs index c23d55bd..887312cb 100644 --- a/libs/realtime/src/util/channel_ext.rs +++ b/libs/realtime/src/util/channel_ext.rs @@ -4,6 +4,7 @@ use std::fmt::Debug; use std::pin::Pin; use std::task::{Context, Poll}; +#[derive(Clone)] pub struct UnboundedSenderSink(pub tokio::sync::mpsc::UnboundedSender); impl UnboundedSenderSink { diff --git a/src/biz/casbin/access_control.rs b/src/biz/casbin/access_control.rs index 11f5b335..ccafabf4 100644 --- a/src/biz/casbin/access_control.rs +++ b/src/biz/casbin/access_control.rs @@ -14,7 +14,7 @@ use anyhow::anyhow; use dashmap::DashMap; use sqlx::PgPool; use std::sync::Arc; -use std::time::Instant; + use tokio::sync::broadcast; /// Manages access control. @@ -32,6 +32,7 @@ use tokio::sync::broadcast; #[derive(Clone)] pub struct AccessControl { enforcer: Arc, + #[allow(dead_code)] access_control_metrics: Arc, } @@ -97,12 +98,7 @@ impl AccessControl { where A: ToCasbinAction, { - let start = Instant::now(); - let result = self.enforcer.enforce(uid, obj, act).await; - self - .access_control_metrics - .record_enforce_duration(start.elapsed().as_millis() as u64); - result + self.enforcer.enforce(uid, obj, act).await } pub async fn get_access_level(&self, uid: &i64, oid: &str) -> Option { @@ -124,6 +120,38 @@ impl AccessControl { } } +/// policy in db: +/// p = 1, 123, 1 (1 mean AFRole::Owner) +/// p = 1, 456, 50 (50 mean AFAccessLevel::FullAccess) +/// +/// role_definition in db: +/// g = _, _ +/// af role: +/// ["1", "delete"], ["1", "write"], ["1", "read"], +/// ["2", "write"], ["2", "read"], +/// ["3", "read"], +/// af access level: +/// ["10", "read"], +/// ["20", "read"], +/// ["30", "read"], ["30", "write"], +/// ["50", "read"], ["50", "write"], ["50", "delete"] +/// +/// matchers: +/// r.sub == p.sub && p.obj == r.obj && g(p.act, r.act) +/// +/// Example: +/// request: +/// 1. api/workspace/123, user=1, workspace_id=123 GET +/// r = sub = 1, obj = 123, act =read +/// p = sub = 1, obj = 123, act = 1 +/// +/// Evaluation: +/// 1. Subject Match: r.sub == p.sub +/// 2. Object Match: p.obj == r.obj +/// 3. Action Permission: g(p.act, r.act) => g(1, read) => ["1", "read"] +/// Result: +/// Allow +/// pub const MODEL_CONF: &str = r###" [request_definition] r = sub, obj, act @@ -133,13 +161,12 @@ p = sub, obj, act [role_definition] g = _, _ # rule for action -g2 = _, _ # rule for collab object id [policy_effect] e = some(where (p.eft == allow)) [matchers] -m = r.sub == p.sub && g2(p.obj, r.obj) && g(p.act, r.act) +m = r.sub == p.sub && p.obj == r.obj && g(p.act, r.act) "###; /// Represents the entity stored at the index of the access control policy. diff --git a/src/biz/casbin/adapter.rs b/src/biz/casbin/adapter.rs index 78797559..8d9f2727 100644 --- a/src/biz/casbin/adapter.rs +++ b/src/biz/casbin/adapter.rs @@ -134,7 +134,6 @@ impl Adapter for PgAdapter { }, } } - // Grouping definition `g` of type `g`. See `model.conf` model.add_policies("g", "g", grouping_policies); self diff --git a/src/biz/casbin/collab_ac.rs b/src/biz/casbin/collab_ac.rs index 3941adba..8ef76a9e 100644 --- a/src/biz/casbin/collab_ac.rs +++ b/src/biz/casbin/collab_ac.rs @@ -1,4 +1,4 @@ -use crate::biz::casbin::access_control::AccessControl; +use crate::biz::casbin::access_control::{AccessControl, Action}; use crate::biz::casbin::access_control::{ActionType, ObjectType}; use actix_http::Method; use app_error::AppError; @@ -69,31 +69,28 @@ impl CollabAccessControl for CollabAccessControlImpl { async fn can_access_http_method( &self, - _uid: &i64, - _oid: &str, - _method: &Method, + uid: &i64, + oid: &str, + method: &Method, ) -> Result { - Ok(true) - // let action = Action::from(method); - // self - // .access_control - // .enforce(uid, &ObjectType::Collab(oid), action) - // .await + let action = Action::from(method); + self + .access_control + .enforce(uid, &ObjectType::Collab(oid), action) + .await } - async fn can_send_collab_update(&self, _uid: &i64, _oid: &str) -> Result { - Ok(true) - // self - // .access_control - // .enforce(uid, &ObjectType::Collab(oid), Action::Write) - // .await + async fn can_send_collab_update(&self, uid: &i64, oid: &str) -> Result { + self + .access_control + .enforce(uid, &ObjectType::Collab(oid), Action::Write) + .await } - async fn can_receive_collab_update(&self, _uid: &i64, _oid: &str) -> Result { - Ok(true) - // self - // .access_control - // .enforce(uid, &ObjectType::Collab(oid), Action::Read) - // .await + async fn can_receive_collab_update(&self, uid: &i64, oid: &str) -> Result { + self + .access_control + .enforce(uid, &ObjectType::Collab(oid), Action::Read) + .await } } diff --git a/src/biz/casbin/enforcer.rs b/src/biz/casbin/enforcer.rs index 5da0565f..7c9324ba 100644 --- a/src/biz/casbin/enforcer.rs +++ b/src/biz/casbin/enforcer.rs @@ -160,6 +160,7 @@ impl AFEnforcer { .get_filtered_policy(POLICY_FIELD_INDEX_OBJECT, vec![obj.to_object_id()]); if policies_for_object.is_empty() { + self.enforcer_result_cache.insert(policy_key, true); return Ok(true); } diff --git a/src/biz/collab/storage.rs b/src/biz/collab/storage.rs index 06e50698..0bb8493f 100644 --- a/src/biz/collab/storage.rs +++ b/src/biz/collab/storage.rs @@ -240,7 +240,7 @@ where Some(encoded_collab) => { event!( tracing::Level::DEBUG, - "Get encoded collab:{} from redis", + "Get encoded collab:{} from cache", params.object_id ); Ok(encoded_collab) diff --git a/tests/access_control/collab_ac_test.rs b/tests/access_control/collab_ac_test.rs index 2e4ac2e5..aeef6733 100644 --- a/tests/access_control/collab_ac_test.rs +++ b/tests/access_control/collab_ac_test.rs @@ -120,7 +120,9 @@ async fn test_collab_access_control_when_obj_not_exist(pool: PgPool) -> anyhow:: let user = create_user(&pool).await?; for method in [Method::GET, Method::POST, Method::PUT, Method::DELETE] { - assert_can_access_http_method(&collab_access_control, &user.uid, "fake_id", method, true).await; + assert_can_access_http_method(&collab_access_control, &user.uid, "fake_id", method, true) + .await + .unwrap(); } Ok(()) @@ -160,7 +162,8 @@ async fn test_collab_access_control_access_http_method(pool: PgPool) -> anyhow:: method, true, ) - .await; + .await + .unwrap(); } assert!( @@ -178,7 +181,8 @@ async fn test_collab_access_control_access_http_method(pool: PgPool) -> anyhow:: Method::GET, true, ) - .await; + .await + .unwrap(); // guest should not have write access assert_can_access_http_method( @@ -188,7 +192,8 @@ async fn test_collab_access_control_access_http_method(pool: PgPool) -> anyhow:: Method::POST, false, ) - .await; + .await + .unwrap(); assert!( !collab_access_control @@ -242,7 +247,6 @@ async fn test_collab_access_control_send_receive_collab_update(pool: PgPool) -> .await; // Need to wait for the listener(spawn_listen_on_workspace_member_change) to receive the event - // sleep(Duration::from_secs(2)).await; assert!( diff --git a/tests/access_control/mod.rs b/tests/access_control/mod.rs index f9e7281c..73f765f9 100644 --- a/tests/access_control/mod.rs +++ b/tests/access_control/mod.rs @@ -1,5 +1,5 @@ use actix_http::Method; -use anyhow::Context; +use anyhow::{Context, Error}; use app_error::ErrorCode; use appflowy_cloud::biz; use appflowy_cloud::biz::casbin::{CollabAccessControlImpl, WorkspaceAccessControlImpl}; @@ -276,7 +276,7 @@ pub async fn assert_can_access_http_method( object_id: &str, method: Method, expected: bool, -) { +) -> Result<(), Error> { let timeout_duration = Duration::from_secs(10); let retry_interval = Duration::from_millis(300); let mut retries = 0usize; @@ -307,9 +307,8 @@ pub async fn assert_can_access_http_method( } }; - timeout(timeout_duration, operation) - .await - .expect("Operation timed out"); + timeout(timeout_duration, operation).await?; + Ok(()) } pub async fn add_workspace_members_in_tx( diff --git a/tests/collab/edit_permission.rs b/tests/collab/edit_permission.rs index 277f1138..e853536d 100644 --- a/tests/collab/edit_permission.rs +++ b/tests/collab/edit_permission.rs @@ -434,6 +434,9 @@ async fn multiple_user_with_read_and_write_permission_edit_same_collab_test() { expected_json.insert(index.to_string(), s); } + // wait 5 seconds to make sure all the server broadcast the updates to all the clients + sleep(Duration::from_secs(5)).await; + // all the clients should have the same collab object assert_json_include!( actual: json!(expected_json), diff --git a/tests/collab/single_device_edit.rs b/tests/collab/single_device_edit.rs index c235ec40..cfc35be4 100644 --- a/tests/collab/single_device_edit.rs +++ b/tests/collab/single_device_edit.rs @@ -467,7 +467,7 @@ async fn post_realtime_message_test() { let task = tokio::spawn(async move { let mut new_user = TestClient::new_user().await; // sleep 2 secs to make sure it do not trigger register user too fast in gotrue - sleep(Duration::from_secs(i % 3)).await; + sleep(Duration::from_secs(i % 5)).await; let object_id = Uuid::new_v4().to_string(); let workspace_id = new_user.workspace_id().await; diff --git a/tests/user/sign_up.rs b/tests/user/sign_up.rs index 393c6603..1afb8042 100644 --- a/tests/user/sign_up.rs +++ b/tests/user/sign_up.rs @@ -64,7 +64,7 @@ async fn sign_up_oauth_not_available() { #[tokio::test] async fn concurrent_user_sign_up_test() { let mut tasks = Vec::new(); - for _i in 0..50 { + for _i in 0..30 { let task = tokio::spawn(async move { let _ = TestClient::new_user().await; tokio::time::sleep(Duration::from_millis(300)).await;