diff --git a/src/application.rs b/src/application.rs index f9a92fed..901ae335 100644 --- a/src/application.rs +++ b/src/application.rs @@ -1,5 +1,5 @@ use crate::api::metrics::{metrics_scope, AppFlowyCloudMetrics}; -use crate::biz::casbin::adapter::PgAdapter; + use crate::component::auth::HEADER_TOKEN; use crate::config::config::{Config, DatabaseSetting, GoTrueSetting, S3Setting}; use crate::middleware::request_id::RequestIdMiddleware; @@ -29,7 +29,7 @@ use crate::api::file_storage::file_storage_scope; use crate::api::user::user_scope; use crate::api::workspace::{collab_scope, workspace_scope}; use crate::api::ws::ws_scope; -use crate::biz::casbin::access_control::{AccessControl, MODEL_CONF}; +use crate::biz::casbin::access_control::AccessControl; use crate::biz::collab::access_control::CollabHttpAccessControl; use crate::biz::collab::storage::init_collab_storage; use crate::biz::pg_listener::PgListeners; @@ -38,7 +38,6 @@ use crate::biz::workspace::access_control::WorkspaceHttpAccessControl; use crate::middleware::access_control_mw::WorkspaceAccessControl; use crate::middleware::metrics_mw::MetricsMiddleware; -use casbin::CoreApi; use database::file::bucket_s3_impl::S3BucketStorage; use prometheus_client::registry::Registry; use realtime::collaborate::{CollabServer, RealtimeMetrics}; @@ -188,15 +187,12 @@ pub async fn init_state(config: &Config) -> Result { let workspace_member_listener = pg_listeners.subscribe_workspace_member_change(); info!("Setting up access controls..."); - let access_control_model = casbin::DefaultModel::from_str(MODEL_CONF).await?; - let access_control_adapter = PgAdapter::new(pg_pool.clone()); - let enforcer = casbin::Enforcer::new(access_control_model, access_control_adapter).await?; let access_control = AccessControl::new( pg_pool.clone(), collab_member_listener, workspace_member_listener, - enforcer, - ); + ) + .await?; let collab_access_control = access_control.new_collab_access_control(); let workspace_access_control = access_control.new_workspace_access_control(); diff --git a/src/biz/casbin/access_control.rs b/src/biz/casbin/access_control.rs index 1b6b1898..74dfa899 100644 --- a/src/biz/casbin/access_control.rs +++ b/src/biz/casbin/access_control.rs @@ -4,10 +4,13 @@ use crate::biz::casbin::pg_listen::*; use crate::biz::casbin::workspace_ac::WorkspaceAccessControlImpl; use app_error::AppError; -use casbin::Enforcer; +use casbin::CoreApi; use database_entity::dto::{AFAccessLevel, AFRole}; +use crate::biz::casbin::adapter::PgAdapter; use actix_http::Method; +use anyhow::anyhow; +use dashmap::DashMap; use sqlx::PgPool; use std::sync::Arc; use tokio::sync::broadcast; @@ -30,16 +33,32 @@ pub struct AccessControl { } impl AccessControl { - pub fn new( + pub async fn new( pg_pool: PgPool, collab_listener: broadcast::Receiver, workspace_listener: broadcast::Receiver, - enforcer: Enforcer, - ) -> Self { - let enforcer = Arc::new(AFEnforcer::new(enforcer)); + ) -> Result { + let enforcer_result_cache = Arc::new(DashMap::new()); + let action_cache = Arc::new(DashMap::new()); + + let access_control_model = casbin::DefaultModel::from_str(MODEL_CONF) + .await + .map_err(|e| AppError::Internal(anyhow!("Failed to create access control model: {}", e)))?; + let access_control_adapter = PgAdapter::new(pg_pool.clone(), action_cache.clone()); + let enforcer = casbin::Enforcer::new(access_control_model, access_control_adapter) + .await + .map_err(|e| { + AppError::Internal(anyhow!("Failed to create access control enforcer: {}", e)) + })?; + + let enforcer = Arc::new(AFEnforcer::new( + enforcer, + enforcer_result_cache, + action_cache, + )); spawn_listen_on_workspace_member_change(workspace_listener, enforcer.clone()); spawn_listen_on_collab_member_change(pg_pool, collab_listener, enforcer.clone()); - Self { enforcer } + Ok(Self { enforcer }) } pub fn new_collab_access_control(&self) -> CollabAccessControlImpl { CollabAccessControlImpl::new(self.clone()) diff --git a/src/biz/casbin/adapter.rs b/src/biz/casbin/adapter.rs index 3b9f4809..41230f9e 100644 --- a/src/biz/casbin/adapter.rs +++ b/src/biz/casbin/adapter.rs @@ -1,10 +1,12 @@ use crate::biz::casbin::access_control::{Action, ObjectType, ToCasbinAction}; +use crate::biz::casbin::enforcer::ActionCacheKey; use async_trait::async_trait; -use casbin::error::AdapterError; + use casbin::Adapter; use casbin::Filter; use casbin::Model; use casbin::Result; +use dashmap::DashMap; use database::collab::select_collab_member_access_level; use database::pg_row::AFCollabMemerAccessLevelRow; use database::pg_row::AFWorkspaceMemberPermRow; @@ -12,52 +14,58 @@ use database::workspace::select_workspace_member_perm_stream; use database_entity::dto::{AFAccessLevel, AFRole}; use futures_util::stream::BoxStream; use sqlx::PgPool; +use std::sync::Arc; use tokio_stream::StreamExt; -/// Implmentation of [`casbin::Adapter`] for access control authorisation. +/// Implementation of [`casbin::Adapter`] for access control authorisation. /// Access control policies that are managed by workspace and collab CRUD. pub struct PgAdapter { pg_pool: PgPool, + action_cache: Arc>, } impl PgAdapter { - pub fn new(pg_pool: PgPool) -> Self { - Self { pg_pool } + pub fn new(pg_pool: PgPool, action_cache: Arc>) -> Self { + Self { + pg_pool, + action_cache, + } } } -async fn create_collab_policies( +async fn load_collab_policies( + action_cache: &Arc>, mut stream: BoxStream<'_, sqlx::Result>, ) -> Result>> { let mut policies: Vec> = Vec::new(); - while let Some(result) = stream.next().await { - let member_access_lv = result.map_err(|err| AdapterError(Box::new(err)))?; - let policy = [ - member_access_lv.uid.to_string(), - ObjectType::Collab(&member_access_lv.oid).to_object_id(), - member_access_lv.access_level.to_action(), - ] - .to_vec(); + while let Some(Ok(member_access_lv)) = stream.next().await { + let uid = member_access_lv.uid; + let object_type = ObjectType::Collab(&member_access_lv.oid); + let action = member_access_lv.access_level.to_action(); + action_cache.insert(ActionCacheKey::new(&uid, &object_type), action.clone()); + + let policy = [uid.to_string(), object_type.to_object_id(), action].to_vec(); policies.push(policy); } Ok(policies) } -async fn create_workspace_policies( +async fn load_workspace_policies( + action_cache: &Arc>, mut stream: BoxStream<'_, sqlx::Result>, ) -> Result>> { let mut policies: Vec> = Vec::new(); - while let Some(result) = stream.next().await { - let member_permission = result.map_err(|err| AdapterError(Box::new(err)))?; - let policy = [ - member_permission.uid.to_string(), - ObjectType::Workspace(&member_permission.workspace_id.to_string()).to_object_id(), - member_permission.role.to_action(), - ] - .to_vec(); + while let Some(Ok(member_permission)) = stream.next().await { + let uid = member_permission.uid; + let workspace_id = member_permission.workspace_id.to_string(); + let object_type = ObjectType::Workspace(&workspace_id); + let action = member_permission.role.to_action(); + action_cache.insert(ActionCacheKey::new(&uid, &object_type), action.clone()); + + let policy = [uid.to_string(), object_type.to_object_id(), action].to_vec(); policies.push(policy); } @@ -68,13 +76,15 @@ async fn create_workspace_policies( impl Adapter for PgAdapter { async fn load_policy(&mut self, model: &mut dyn Model) -> Result<()> { let workspace_member_perm_stream = select_workspace_member_perm_stream(&self.pg_pool); - let workspace_policies = create_workspace_policies(workspace_member_perm_stream).await?; + let workspace_policies = + load_workspace_policies(&self.action_cache, workspace_member_perm_stream).await?; // Policy definition `p` of type `p`. See `model.conf` model.add_policies("p", "p", workspace_policies); let collab_member_access_lv_stream = select_collab_member_access_level(&self.pg_pool); - let collab_policies = create_collab_policies(collab_member_access_lv_stream).await?; + let collab_policies = + load_collab_policies(&self.action_cache, collab_member_access_lv_stream).await?; // Policy definition `p` of type `p`. See `model.conf` model.add_policies("p", "p", collab_policies); diff --git a/src/biz/casbin/enforcer.rs b/src/biz/casbin/enforcer.rs index 40a0a0c9..c5d196e0 100644 --- a/src/biz/casbin/enforcer.rs +++ b/src/biz/casbin/enforcer.rs @@ -8,22 +8,27 @@ use casbin::{CoreApi, Enforcer, MgmtApi}; use dashmap::DashMap; use std::ops::Deref; +use std::sync::Arc; use tokio::sync::RwLock; use tracing::{event, trace}; pub struct AFEnforcer { enforcer: RwLock, /// Cache for the result of the policy check. It's a memory cache for faster access. - result_by_policy_cache: DashMap, - action_by_object_cache: DashMap, + enforcer_result_cache: Arc>, + action_cache: Arc>, } impl AFEnforcer { - pub fn new(enforcer: Enforcer) -> Self { + pub fn new( + enforcer: Enforcer, + enforcer_result_cache: Arc>, + action_cache: Arc>, + ) -> Self { Self { enforcer: RwLock::new(enforcer), - result_by_policy_cache: DashMap::new(), - action_by_object_cache: Default::default(), + enforcer_result_cache, + action_cache, } } pub async fn contains(&self, obj: &ObjectType<'_>) -> bool { @@ -65,17 +70,17 @@ impl AFEnforcer { ) -> Result { validate_obj_action(obj, act)?; let policy = vec![uid.to_string(), obj.to_object_id(), act.to_action()]; - let policy_key = CachePolicyKey::new(&policy); + let policy_key = PolicyCacheKey::new(&policy); // if the policy is already in the cache, return. Only update the policy if it's not in the cache. - if let Some(value) = self.result_by_policy_cache.get(&policy_key) { + if let Some(value) = self.enforcer_result_cache.get(&policy_key) { return Ok(*value); } event!(tracing::Level::INFO, "updating policy: {:?}", policy); // only one policy per user per object. So remove the old policy and add the new one. let _remove_policies = self.remove(uid, obj).await?; - let object_key = CacheObjectKey::new(uid, obj); + let object_key = ActionCacheKey::new(uid, obj); let result = self .enforcer .write() @@ -85,9 +90,7 @@ impl AFEnforcer { .map_err(|e| AppError::Internal(anyhow!("fail to add policy: {e:?}"))); if result.is_ok() { trace!("cache action: {}:{}", object_key.0, act.to_action()); - self - .action_by_object_cache - .insert(object_key, act.to_action()); + self.action_cache.insert(object_key, act.to_action()); } result } @@ -126,12 +129,12 @@ impl AFEnforcer { .await .map_err(|e| AppError::Internal(anyhow!("casbin error enforce: {e:?}")))?; - let object_key = CacheObjectKey::new(uid, object_type); - self.action_by_object_cache.remove(&object_key); + let object_key = ActionCacheKey::new(uid, object_type); + self.action_cache.remove(&object_key); for policy in &policies_for_user_on_object { self - .result_by_policy_cache - .remove(&CachePolicyKey::new(policy)); + .enforcer_result_cache + .remove(&PolicyCacheKey::new(policy)); } Ok(policies_for_user_on_object) @@ -142,8 +145,8 @@ impl AFEnforcer { A: ToCasbinAction, { let policy = vec![uid.to_string(), obj.to_object_id(), act.to_action()]; - let policy_key = CachePolicyKey::new(&policy); - if let Some(value) = self.result_by_policy_cache.get(&policy_key) { + let policy_key = PolicyCacheKey::new(&policy); + if let Some(value) = self.enforcer_result_cache.get(&policy_key) { return Ok(*value); } @@ -163,13 +166,13 @@ impl AFEnforcer { .await .enforce(policy) .map_err(|e| AppError::Internal(anyhow!("error enforce: {e:?}")))?; - self.result_by_policy_cache.insert(policy_key, result); + self.enforcer_result_cache.insert(policy_key, result); Ok(result) } pub async fn get_action(&self, uid: &i64, object_type: &ObjectType<'_>) -> Option { - let object_key = CacheObjectKey::new(uid, object_type); - if let Some(value) = self.action_by_object_cache.get(&object_key) { + let object_key = ActionCacheKey::new(uid, object_type); + if let Some(value) = self.action_cache.get(&object_key) { return Some(value.clone()); } @@ -180,23 +183,21 @@ impl AFEnforcer { let action = policies.first()?[POLICY_FIELD_INDEX_ACTION].clone(); trace!("cache action: {}:{}", object_key.0, action.clone()); - self - .action_by_object_cache - .insert(object_key, action.clone()); + self.action_cache.insert(object_key, action.clone()); Some(action) } } #[derive(Debug, Hash, Eq, PartialEq)] -struct CachePolicyKey(String); +pub struct PolicyCacheKey(String); -impl CachePolicyKey { +impl PolicyCacheKey { fn new(policy: &[String]) -> Self { Self(policy.join(":")) } } -impl Deref for CachePolicyKey { +impl Deref for PolicyCacheKey { type Target = str; fn deref(&self) -> &Self::Target { &self.0 @@ -204,10 +205,10 @@ impl Deref for CachePolicyKey { } #[derive(Debug, Hash, Eq, PartialEq)] -struct CacheObjectKey(String); +pub struct ActionCacheKey(String); -impl CacheObjectKey { - fn new(uid: &i64, object_type: &ObjectType<'_>) -> Self { +impl ActionCacheKey { + pub(crate) fn new(uid: &i64, object_type: &ObjectType<'_>) -> Self { Self(format!("{}:{}", uid, object_type.to_object_id())) } } diff --git a/tests/access_control/collab_ac_test.rs b/tests/access_control/collab_ac_test.rs index 5bc8ffad..8a394947 100644 --- a/tests/access_control/collab_ac_test.rs +++ b/tests/access_control/collab_ac_test.rs @@ -2,12 +2,9 @@ use crate::access_control::*; use actix_http::Method; use anyhow::{anyhow, Context}; use appflowy_cloud::biz; -use appflowy_cloud::biz::casbin::access_control::{ - AccessControl, Action, ActionType, ObjectType, MODEL_CONF, -}; -use appflowy_cloud::biz::casbin::adapter::PgAdapter; +use appflowy_cloud::biz::casbin::access_control::{AccessControl, Action, ActionType, ObjectType}; + use appflowy_cloud::biz::pg_listener::PgListeners; -use casbin::{CoreApi, DefaultModel, Enforcer}; use database_entity::dto::{AFAccessLevel, AFRole}; use realtime::collaborate::CollabAccessControl; use shared_entity::dto::workspace_dto::CreateWorkspaceMember; @@ -19,15 +16,14 @@ use tokio::time::sleep; async fn test_collab_access_control(pool: PgPool) -> anyhow::Result<()> { setup_db(&pool).await?; - let model = DefaultModel::from_str(MODEL_CONF).await?; - let enforcer = Enforcer::new(model, PgAdapter::new(pool.clone())).await?; let listeners = PgListeners::new(&pool).await?; let access_control = AccessControl::new( pool.clone(), listeners.subscribe_collab_member_change(), listeners.subscribe_workspace_member_change(), - enforcer, - ); + ) + .await + .unwrap(); let access_control = access_control.new_collab_access_control(); let user = create_user(&pool).await?; @@ -135,15 +131,14 @@ async fn test_collab_access_control(pool: PgPool) -> anyhow::Result<()> { async fn test_collab_access_control_when_obj_not_exist(pool: PgPool) -> anyhow::Result<()> { setup_db(&pool).await?; - let model = DefaultModel::from_str(MODEL_CONF).await?; - let enforcer = Enforcer::new(model, PgAdapter::new(pool.clone())).await?; let listeners = PgListeners::new(&pool).await?; let access_control = AccessControl::new( pool.clone(), listeners.subscribe_collab_member_change(), listeners.subscribe_workspace_member_change(), - enforcer, - ); + ) + .await + .unwrap(); let access_control = access_control.new_collab_access_control(); let user = create_user(&pool).await?; @@ -158,15 +153,14 @@ async fn test_collab_access_control_when_obj_not_exist(pool: PgPool) -> anyhow:: async fn test_collab_access_control_access_http_method(pool: PgPool) -> anyhow::Result<()> { setup_db(&pool).await?; - let model = DefaultModel::from_str(MODEL_CONF).await?; - let enforcer = Enforcer::new(model, PgAdapter::new(pool.clone())).await?; let listeners = PgListeners::new(&pool).await?; let access_control = AccessControl::new( pool.clone(), listeners.subscribe_collab_member_change(), listeners.subscribe_workspace_member_change(), - enforcer, - ); + ) + .await + .unwrap(); let access_control = access_control.new_collab_access_control(); let user = create_user(&pool).await?; @@ -259,15 +253,14 @@ async fn test_collab_access_control_access_http_method(pool: PgPool) -> anyhow:: async fn test_collab_access_control_send_receive_collab_update(pool: PgPool) -> anyhow::Result<()> { setup_db(&pool).await?; - let model = DefaultModel::from_str(MODEL_CONF).await?; - let enforcer = Enforcer::new(model, PgAdapter::new(pool.clone())).await?; let listeners = PgListeners::new(&pool).await?; let access_control = AccessControl::new( pool.clone(), listeners.subscribe_collab_member_change(), listeners.subscribe_workspace_member_change(), - enforcer, - ); + ) + .await + .unwrap(); let access_control = access_control.new_collab_access_control(); let user = create_user(&pool).await?; @@ -344,15 +337,14 @@ async fn test_collab_access_control_send_receive_collab_update(pool: PgPool) -> async fn test_collab_access_control_cache_collab_access_level(pool: PgPool) -> anyhow::Result<()> { setup_db(&pool).await?; - let model = DefaultModel::from_str(MODEL_CONF).await?; - let enforcer = Enforcer::new(model, PgAdapter::new(pool.clone())).await?; let listeners = PgListeners::new(&pool).await?; let access_control = AccessControl::new( pool.clone(), listeners.subscribe_collab_member_change(), listeners.subscribe_workspace_member_change(), - enforcer, - ); + ) + .await + .unwrap(); let access_control = access_control.new_collab_access_control(); let uid = 123; @@ -381,16 +373,14 @@ async fn test_collab_access_control_cache_collab_access_level(pool: PgPool) -> a #[sqlx::test(migrations = false)] async fn test_casbin_access_control_update_remove(pool: PgPool) -> anyhow::Result<()> { setup_db(&pool).await?; - - let model = DefaultModel::from_str(MODEL_CONF).await?; - let enforcer = Enforcer::new(model, PgAdapter::new(pool.clone())).await?; let listeners = PgListeners::new(&pool).await?; let access_control = AccessControl::new( pool.clone(), listeners.subscribe_collab_member_change(), listeners.subscribe_workspace_member_change(), - enforcer, - ); + ) + .await + .unwrap(); let uid = 123; assert!( diff --git a/tests/access_control/member_ac_test.rs b/tests/access_control/member_ac_test.rs index 99110ace..9642213a 100644 --- a/tests/access_control/member_ac_test.rs +++ b/tests/access_control/member_ac_test.rs @@ -4,10 +4,9 @@ use crate::access_control::{ use anyhow::{anyhow, Context}; use app_error::ErrorCode; use appflowy_cloud::biz; -use appflowy_cloud::biz::casbin::access_control::{AccessControl, MODEL_CONF}; -use appflowy_cloud::biz::casbin::adapter::PgAdapter; +use appflowy_cloud::biz::casbin::access_control::AccessControl; + use appflowy_cloud::biz::pg_listener::PgListeners; -use casbin::{CoreApi, DefaultModel, Enforcer}; use database_entity::dto::AFRole; use shared_entity::dto::workspace_dto::{CreateWorkspaceMember, WorkspaceMemberChangeset}; use sqlx::PgPool; @@ -16,15 +15,14 @@ use sqlx::PgPool; async fn test_workspace_access_control_get_role(pool: PgPool) -> anyhow::Result<()> { setup_db(&pool).await?; - let model = DefaultModel::from_str(MODEL_CONF).await?; - let enforcer = Enforcer::new(model, PgAdapter::new(pool.clone())).await?; let listeners = PgListeners::new(&pool).await?; let access_control = AccessControl::new( pool.clone(), listeners.subscribe_collab_member_change(), listeners.subscribe_workspace_member_change(), - enforcer, - ); + ) + .await + .unwrap(); let access_control = access_control.new_workspace_access_control(); let user = create_user(&pool).await?; diff --git a/tests/access_control/user_ac_test.rs b/tests/access_control/user_ac_test.rs index 88a7f674..5ed801d7 100644 --- a/tests/access_control/user_ac_test.rs +++ b/tests/access_control/user_ac_test.rs @@ -4,9 +4,11 @@ use appflowy_cloud::biz; use appflowy_cloud::biz::casbin::access_control::{Action, ObjectType, ToCasbinAction, MODEL_CONF}; use appflowy_cloud::biz::casbin::adapter::PgAdapter; use casbin::{CoreApi, DefaultModel, Enforcer}; +use dashmap::DashMap; use database_entity::dto::{AFAccessLevel, AFRole}; use shared_entity::dto::workspace_dto::CreateWorkspaceMember; use sqlx::PgPool; +use std::sync::Arc; #[sqlx::test(migrations = false)] async fn test_create_user(pool: PgPool) -> anyhow::Result<()> { @@ -22,7 +24,11 @@ async fn test_create_user(pool: PgPool) -> anyhow::Result<()> { .ok_or(anyhow!("workspace should be created"))?; let model = DefaultModel::from_str(MODEL_CONF).await?; - let enforcer = Enforcer::new(model, PgAdapter::new(pool.clone())).await?; + let enforcer = Enforcer::new( + model, + PgAdapter::new(pool.clone(), Arc::new(DashMap::new())), + ) + .await?; assert!(enforcer .enforce(( @@ -107,8 +113,11 @@ async fn test_add_users_to_workspace(pool: PgPool) -> anyhow::Result<()> { .context("adding users to workspace")?; let model = DefaultModel::from_str(MODEL_CONF).await?; - let enforcer = Enforcer::new(model, PgAdapter::new(pool.clone())).await?; - + let enforcer = Enforcer::new( + model, + PgAdapter::new(pool.clone(), Arc::new(DashMap::new())), + ) + .await?; { // Owner let user = user_owner; @@ -233,8 +242,11 @@ async fn test_reload_policy_after_adding_user_to_workspace(pool: PgPool) -> anyh // Create enforcer before adding user to workspace let model = DefaultModel::from_str(MODEL_CONF).await?; - let mut enforcer = Enforcer::new(model, PgAdapter::new(pool.clone())).await?; - + let mut enforcer = Enforcer::new( + model, + PgAdapter::new(pool.clone(), Arc::new(DashMap::new())), + ) + .await?; let members = vec![CreateWorkspaceMember { email: user_member.email.clone(), role: AFRole::Member,