diff --git a/Cargo.lock b/Cargo.lock index f8b9ad19..4cbddbd0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -684,11 +684,14 @@ version = "0.1.0" dependencies = [ "access-control", "actix", + "actix-http", + "actix-web", "actix-web-actors", "anyhow", "app-error", "async-stream", "async-trait", + "authentication", "bytes", "chrono", "collab", @@ -701,6 +704,7 @@ dependencies = [ "dashmap", "database", "database-entity", + "dotenvy", "futures", "futures-util", "governor", @@ -712,10 +716,12 @@ dependencies = [ "prometheus-client", "rand 0.8.5", "redis 0.25.2", + "secrecy", "semver", "serde", "serde_json", "serde_repr", + "shared-entity", "sqlx", "thiserror", "tokio", diff --git a/Cargo.toml b/Cargo.toml index a1799323..16bf9cb3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,7 +60,7 @@ prometheus-client.workspace = true itertools = "0.11" uuid = "1.6.1" tokio-tungstenite = { version = "0.20.1", features = ["native-tls"] } -dotenvy = "0.15.7" +dotenvy.workspace = true url = "2.5.0" brotli = "3.4.0" dashmap.workspace = true @@ -180,6 +180,7 @@ app-error = { path = "libs/app-error" } async-trait = "0.1.77" prometheus-client = "0.22.0" collab-stream = { path = "libs/collab-stream" } +dotenvy = "0.15.7" secrecy = { version = "0.8", features = ["serde"] } serde_json = "1.0.111" serde_repr = "0.1.18" diff --git a/services/appflowy-collaborate/Cargo.toml b/services/appflowy-collaborate/Cargo.toml index 63d0e90c..f3cbd8c9 100644 --- a/services/appflowy-collaborate/Cargo.toml +++ b/services/appflowy-collaborate/Cargo.toml @@ -14,9 +14,13 @@ path = "src/lib.rs" [dependencies] access-control.workspace = true actix.workspace = true +actix-web.workspace = true +actix-http = { workspace = true, default-features = false, features = ["openssl", "compress-brotli", "compress-gzip"] } actix-web-actors = { version = "4.3" } -app-error = { workspace = true, features = ["sqlx_error", "tokio_error"] } +app-error = { workspace = true, features = ["sqlx_error", "actix_web_error", "tokio_error"] } +authentication.workspace = true dashmap.workspace = true +dotenvy.workspace = true async-stream.workspace = true futures.workspace = true tracing = "0.1.40" @@ -51,6 +55,8 @@ prometheus-client = "0.22.1" indexmap = "2.2.5" semver = "1.0.22" redis = "0.25.2" +secrecy.workspace = true +shared-entity = { workspace = true, features = ["cloud"]} parking_lot = "0.12.1" lazy_static = "1.4.0" itertools = "0.12.0" diff --git a/services/appflowy-collaborate/src/actix_ws/client/mod.rs b/services/appflowy-collaborate/src/actix_ws/client/mod.rs index 1ccd995b..7be4fcd0 100644 --- a/services/appflowy-collaborate/src/actix_ws/client/mod.rs +++ b/services/appflowy-collaborate/src/actix_ws/client/mod.rs @@ -1 +1,2 @@ pub mod rt_client; +pub use crate::actix_ws::client::rt_client::*; diff --git a/services/appflowy-collaborate/src/api.rs b/services/appflowy-collaborate/src/api.rs new file mode 100644 index 00000000..58b602d2 --- /dev/null +++ b/services/appflowy-collaborate/src/api.rs @@ -0,0 +1,231 @@ +use std::collections::HashMap; +use std::time::Duration; + +use actix::Addr; +use actix_http::header::AUTHORIZATION; +use actix_web::web::{Data, Payload}; +use actix_web::{web, HttpRequest, HttpResponse, Result, Scope}; +use actix_web_actors::ws; +use secrecy::Secret; +use semver::Version; +use tokio::sync::mpsc; +use tokio::sync::mpsc::Sender; +use tracing::{debug, error, instrument, trace}; + +use app_error::AppError; +use collab_rt_entity::user::{AFUserChange, RealtimeUser, UserMessage}; +use collab_rt_entity::RealtimeMessage; +use shared_entity::response::AppResponseError; + +use crate::actix_ws::client::RealtimeClient; +use crate::actix_ws::server::RealtimeServerActor; +use crate::collab::access_control::RealtimeCollabAccessControlImpl; +use crate::collab::storage::CollabAccessControlStorage; +use crate::state::AppState; +use authentication::jwt::{authorization_from_token, UserUuid}; + +pub fn ws_scope() -> Scope { + web::scope("/ws").service(web::resource("/v1").route(web::get().to(establish_ws_connection_v1))) +} +const MAX_FRAME_SIZE: usize = 65_536; // 64 KiB + +pub type RealtimeServerAddr = + Addr>; + +#[instrument(skip_all, err)] +pub async fn establish_ws_connection_v1( + request: HttpRequest, + payload: Payload, + state: Data, + jwt_secret: Data>, + server: Data, + web::Query(query_params): web::Query>, +) -> Result { + // Try to parse the connect info from the request body + // If it fails, try to parse it from the query params + let ConnectInfo { + access_token, + client_version, + device_id, + connect_at, + } = match ConnectInfo::parse_from(&request) { + Ok(info) => info, + Err(_) => { + trace!("Failed to parse connect info from request body. Trying to parse from query params."); + ConnectInfo::parse_from(&query_params)? + }, + }; + + start_connect( + &request, + payload, + &state, + &jwt_secret, + server, + access_token, + device_id, + client_version, + connect_at, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +#[inline] +async fn start_connect( + request: &HttpRequest, + payload: Payload, + state: &Data, + jwt_secret: &Data>, + server: Data, + access_token: String, + device_id: String, + client_app_version: Version, + connect_at: i64, +) -> Result { + let auth = authorization_from_token(access_token.as_str(), jwt_secret)?; + let user_uuid = UserUuid::from_auth(auth)?; + let result = state.user_cache.get_user_uid(&user_uuid).await; + + match result { + Ok(uid) => { + debug!( + "🚀new websocket connect: uid={}, device_id={}, client_version:{}", + uid, device_id, client_app_version + ); + + let session_id = uuid::Uuid::new_v4().to_string(); + let realtime_user = RealtimeUser::new( + uid, + device_id, + session_id, + connect_at, + client_app_version.to_string(), + ); + let (tx, external_source) = mpsc::channel(100); + let client = RealtimeClient::new( + realtime_user, + server.get_ref().clone(), + Duration::from_secs(state.config.websocket.heartbeat_interval as u64), + Duration::from_secs(state.config.websocket.client_timeout as u64), + client_app_version, + external_source, + 10, + ); + + // Receive user change notifications and send them to the client. + listen_on_user_change(state, uid, tx); + + match ws::WsResponseBuilder::new(client, request, payload) + .frame_size(MAX_FRAME_SIZE * 2) + .start() + { + Ok(response) => Ok(response), + Err(e) => { + error!("🔴ws connection error: {:?}", e); + Err(e) + }, + } + }, + Err(err) => { + if err.is_record_not_found() { + return Ok(HttpResponse::NotFound().json("user not found")); + } + Err(AppResponseError::from(err).into()) + }, + } +} + +fn listen_on_user_change(state: &Data, uid: i64, tx: Sender) { + let mut user_change_recv = state.pg_listeners.subscribe_user_change(uid); + actix::spawn(async move { + while let Some(notification) = user_change_recv.recv().await { + // Extract the user object from the notification payload. + if let Some(user) = notification.payload { + trace!("Receive user change: {:?}", user); + // Since bincode serialization is used for RealtimeMessage but does not support the + // Serde `deserialize_any` method, the user metadata is serialized into a JSON string. + // This step ensures compatibility and flexibility for the metadata field. + let metadata = serde_json::to_string(&user.metadata).ok(); + // Construct a UserMessage with the user's details, including the serialized metadata. + let msg = UserMessage::ProfileChange(AFUserChange { + uid: user.uid, + name: user.name, + email: user.email, + metadata, + }); + if tx.send(RealtimeMessage::User(msg)).await.is_err() { + break; + } + } + } + }); +} + +struct ConnectInfo { + access_token: String, + client_version: Version, + device_id: String, + connect_at: i64, +} + +const CLIENT_VERSION: &str = "client-version"; +const DEVICE_ID: &str = "device-id"; +const CONNECT_AT: &str = "connect-at"; + +// Trait for parameter extraction +trait ExtractParameter { + fn extract_param(&self, key: &str) -> Result; +} + +impl ExtractParameter for HashMap { + fn extract_param(&self, key: &str) -> Result { + self + .get(key) + .ok_or_else(|| { + AppError::InvalidRequest(format!("Parameter with given key:{} not found", key)) + }) + .map(|s| s.to_string()) + } +} + +// Implement the trait for HttpRequest +impl ExtractParameter for HttpRequest { + fn extract_param(&self, key: &str) -> Result { + self + .headers() + .get(key) + .ok_or_else(|| AppError::InvalidRequest(format!("Header with given key:{} not found", key))) + .and_then(|value| { + value + .to_str() + .map_err(|_| { + AppError::InvalidRequest(format!("Invalid header value for given key:{}", key)) + }) + .map(|s| s.to_string()) + }) + } +} + +impl ConnectInfo { + fn parse_from(source: &T) -> Result { + let access_token = source.extract_param(AUTHORIZATION.as_str())?; + let client_version_str = source.extract_param(CLIENT_VERSION)?; + let client_version = Version::parse(&client_version_str) + .map_err(|_| AppError::InvalidRequest(format!("Invalid version:{}", client_version_str)))?; + let device_id = source.extract_param(DEVICE_ID)?; + let connect_at = match source.extract_param(CONNECT_AT) { + Ok(start_at) => start_at + .parse::() + .unwrap_or_else(|_| chrono::Utc::now().timestamp()), + Err(_) => chrono::Utc::now().timestamp(), + }; + + Ok(Self { + access_token, + client_version, + device_id, + connect_at, + }) + } +} diff --git a/services/appflowy-collaborate/src/application.rs b/services/appflowy-collaborate/src/application.rs new file mode 100644 index 00000000..d11c070a --- /dev/null +++ b/services/appflowy-collaborate/src/application.rs @@ -0,0 +1,173 @@ +use std::net::TcpListener; +use std::sync::Arc; +use std::time::Duration; + +use actix::Supervisor; +use actix_web::dev::Server; +use actix_web::web::Data; +use actix_web::{App, HttpServer}; +use anyhow::{Context, Error}; +use secrecy::ExposeSecret; +use sqlx::postgres::PgPoolOptions; +use sqlx::PgPool; +use tracing::info; + +use crate::actix_ws::server::RealtimeServerActor; +use access_control::access::AccessControl; +use workspace_access::notification::spawn_listen_on_workspace_member_change; +use workspace_access::WorkspaceAccessControlImpl; + +use crate::api::ws_scope; +use crate::collab::access_control::{ + CollabAccessControlImpl, CollabStorageAccessControlImpl, RealtimeCollabAccessControlImpl, +}; +use crate::collab::cache::CollabCache; +use crate::collab::notification::spawn_listen_on_collab_member_change; +use crate::collab::storage::CollabStorageImpl; +use crate::command::{CLCommandReceiver, CLCommandSender}; +use crate::config::{Config, DatabaseSetting}; +use crate::pg_listener::PgListeners; +use crate::snapshot::SnapshotControl; +use crate::state::{AppMetrics, AppState, UserCache}; +use crate::CollaborationServer; + +pub struct Application { + actix_server: Server, +} + +impl Application { + pub async fn build( + config: Config, + state: AppState, + rt_cmd_recv: CLCommandReceiver, + ) -> Result { + let address = format!("{}:{}", config.application.host, config.application.port); + let listener = TcpListener::bind(&address)?; + info!( + "Collab Service started at {}", + listener.local_addr().unwrap() + ); + let actix_server = run_actix_server(listener, state, config, rt_cmd_recv).await?; + + Ok(Self { actix_server }) + } + + pub async fn run_until_stopped(self) -> Result<(), std::io::Error> { + self.actix_server.await + } +} + +pub async fn run_actix_server( + listener: TcpListener, + state: AppState, + config: Config, + rt_cmd_recv: CLCommandReceiver, +) -> Result { + let storage = state.collab_access_control_storage.clone(); + + // Initialize metrics that which are registered in the registry. + let realtime_server = CollaborationServer::<_, _>::new( + storage.clone(), + RealtimeCollabAccessControlImpl::new(state.access_control.clone()), + state.metrics.realtime_metrics.clone(), + rt_cmd_recv, + state.redis_connection_manager.clone(), + Duration::from_secs(config.collab.group_persistence_interval_secs), + config.collab.edit_state_max_count, + config.collab.edit_state_max_secs, + ) + .await + .unwrap(); + let realtime_server_actor = Supervisor::start(|_| RealtimeServerActor(realtime_server)); + let mut server = HttpServer::new(move || { + App::new() + .app_data(Data::new(state.clone())) + .app_data(Data::new(state.config.gotrue.jwt_secret.clone())) + .app_data(Data::new(realtime_server_actor.clone())) + .service(ws_scope()) + }); + server = server.listen(listener)?; + + Ok(server.run()) +} + +pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result { + let metrics = AppMetrics::new(); + let pg_pool = get_connection_pool(&config.db_settings).await?; + + // User cache + let user_cache = UserCache::new(pg_pool.clone()).await; + info!("Connecting to Redis..."); + let redis_conn_manager = get_redis_client(config.redis_uri.expose_secret()).await?; + + // Pg listeners + info!("Setting up Pg listeners..."); + let pg_listeners = Arc::new(PgListeners::new(&pg_pool).await?); + let access_control = + AccessControl::new(pg_pool.clone(), metrics.access_control_metrics.clone()).await?; + let collab_member_listener = pg_listeners.subscribe_collab_member_change(); + let workspace_member_listener = pg_listeners.subscribe_workspace_member_change(); + + spawn_listen_on_workspace_member_change(workspace_member_listener, access_control.clone()); + spawn_listen_on_collab_member_change( + pg_pool.clone(), + collab_member_listener, + access_control.clone(), + ); + + let collab_access_control = CollabAccessControlImpl::new(access_control.clone()); + let workspace_access_control = WorkspaceAccessControlImpl::new(access_control.clone()); + let collab_cache = CollabCache::new(redis_conn_manager.clone(), pg_pool.clone()); + + let collab_storage_access_control = CollabStorageAccessControlImpl { + collab_access_control: collab_access_control.clone().into(), + workspace_access_control: workspace_access_control.clone().into(), + cache: collab_cache.clone(), + }; + let snapshot_control = SnapshotControl::new( + redis_conn_manager.clone(), + pg_pool.clone(), + metrics.collab_metrics.clone(), + ) + .await; + let collab_storage = Arc::new(CollabStorageImpl::new( + collab_cache.clone(), + collab_storage_access_control, + snapshot_control, + rt_cmd_tx, + redis_conn_manager.clone(), + metrics.collab_metrics.clone(), + )); + let app_state = AppState { + config: Arc::new(config.clone()), + pg_listeners, + user_cache, + redis_connection_manager: redis_conn_manager, + access_control, + collab_access_control_storage: collab_storage, + metrics, + }; + Ok(app_state) +} + +async fn get_redis_client(redis_uri: &str) -> Result { + info!("Connecting to redis with uri: {}", redis_uri); + let manager = redis::Client::open(redis_uri) + .context("failed to connect to redis")? + .get_connection_manager() + .await + .context("failed to get the connection manager")?; + Ok(manager) +} + +async fn get_connection_pool(setting: &DatabaseSetting) -> Result { + info!("Connecting to postgres database with setting: {}", setting); + PgPoolOptions::new() + .max_connections(setting.max_connections) + .acquire_timeout(Duration::from_secs(10)) + .max_lifetime(Duration::from_secs(30 * 60)) + .idle_timeout(Duration::from_secs(30)) + .connect_with(setting.pg_connect_options()) + .await + .map_err(|e| anyhow::anyhow!("Failed to connect to postgres database: {}", e)) +} diff --git a/services/appflowy-collaborate/src/config.rs b/services/appflowy-collaborate/src/config.rs new file mode 100644 index 00000000..da527322 --- /dev/null +++ b/services/appflowy-collaborate/src/config.rs @@ -0,0 +1,123 @@ +use std::fmt::Display; +use std::str::FromStr; + +use anyhow::Context; +use secrecy::Secret; +use sqlx::postgres::{PgConnectOptions, PgSslMode}; + +#[derive(Clone, Debug)] +pub struct Config { + pub application: ApplicationSetting, + pub websocket: WebsocketSetting, + pub db_settings: DatabaseSetting, + pub gotrue: GoTrueSetting, + pub collab: CollabSetting, + pub redis_uri: Secret, +} + +#[derive(Clone, Debug)] +pub struct ApplicationSetting { + pub port: u16, + pub host: String, +} + +#[derive(Clone, Debug)] +pub struct WebsocketSetting { + pub heartbeat_interval: u8, + pub client_timeout: u8, +} + +#[derive(Clone, Debug)] +pub struct DatabaseSetting { + pub pg_conn_opts: PgConnectOptions, + pub require_ssl: bool, + /// PostgreSQL has a maximum of 115 connections to the database, 15 connections are reserved to + /// the super user to maintain the integrity of the PostgreSQL database, and 100 PostgreSQL + /// connections are reserved for system applications. + /// When we exceed the limit of the database connection, then it shows an error message. + pub max_connections: u32, +} + +impl Display for DatabaseSetting { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "DatabaseSetting {{ pg_conn_opts: {:?}, require_ssl: {}, max_connections: {} }}", + self.pg_conn_opts, self.require_ssl, self.max_connections + ) + } +} + +impl DatabaseSetting { + pub fn pg_connect_options(&self) -> PgConnectOptions { + let ssl_mode = if self.require_ssl { + PgSslMode::Require + } else { + PgSslMode::Prefer + }; + let options = self.pg_conn_opts.clone(); + options.ssl_mode(ssl_mode) + } +} + +#[derive(serde::Deserialize, Clone, Debug)] +pub struct GoTrueSetting { + pub jwt_secret: Secret, +} + +#[derive(Clone, Debug)] +pub struct CollabSetting { + pub group_persistence_interval_secs: u64, + pub edit_state_max_count: u32, + pub edit_state_max_secs: i64, +} + +pub fn get_env_var(key: &str, default: &str) -> String { + std::env::var(key).unwrap_or_else(|e| { + tracing::warn!( + "failed to read environment variable: {}, using default value: {}", + e, + default + ); + default.to_owned() + }) +} + +pub fn get_configuration() -> Result { + let config = Config { + application: ApplicationSetting { + port: get_env_var("APPFLOWY_COLLAB_SERVICE_PORT", "8001").parse()?, + host: get_env_var("APPFLOWY_COLLAB_SERVICE_HOST", "0.0.0.0"), + }, + websocket: WebsocketSetting { + heartbeat_interval: get_env_var("APPFLOWY_WEBSOCKET_HEARTBEAT_INTERVAL", "6").parse()?, + client_timeout: get_env_var("APPFLOWY_WEBSOCKET_CLIENT_TIMEOUT", "60").parse()?, + }, + db_settings: DatabaseSetting { + pg_conn_opts: PgConnectOptions::from_str(&get_env_var( + "APPFLOWY_DATABASE_URL", + "postgres://postgres:password@localhost:5432/postgres", + ))?, + require_ssl: get_env_var("APPFLOWY_DATABASE_REQUIRE_SSL", "false") + .parse() + .context("fail to get APPFLOWY_DATABASE_REQUIRE_SSL")?, + max_connections: get_env_var("APPFLOWY_DATABASE_MAX_CONNECTIONS", "40") + .parse() + .context("fail to get APPFLOWY_DATABASE_MAX_CONNECTIONS")?, + }, + gotrue: GoTrueSetting { + jwt_secret: get_env_var("APPFLOWY_GOTRUE_JWT_SECRET", "hello456").into(), + }, + collab: CollabSetting { + group_persistence_interval_secs: get_env_var( + "APPFLOWY_COLLAB_GROUP_PERSISTENCE_INTERVAL", + "60", + ) + .parse()?, + edit_state_max_count: get_env_var("APPFLOWY_COLLAB_EDIT_STATE_MAX_COUNT", "100").parse()?, + edit_state_max_secs: get_env_var("APPFLOWY_COLLAB_EDIT_STATE_MAX_SECS", "360").parse()?, + }, + redis_uri: get_env_var("APPFLOWY_REDIS_URI", "redis://localhost:6379").into(), + }; + Ok(config) +} diff --git a/services/appflowy-collaborate/src/lib.rs b/services/appflowy-collaborate/src/lib.rs index 9d6df64b..a97648b7 100644 --- a/services/appflowy-collaborate/src/lib.rs +++ b/services/appflowy-collaborate/src/lib.rs @@ -1,7 +1,10 @@ pub mod actix_ws; +pub mod api; +pub mod application; mod client; pub mod collab; pub mod command; +pub mod config; pub mod connect_state; pub mod error; mod group; diff --git a/services/appflowy-collaborate/src/main.rs b/services/appflowy-collaborate/src/main.rs index 00306738..d5f7131b 100644 --- a/services/appflowy-collaborate/src/main.rs +++ b/services/appflowy-collaborate/src/main.rs @@ -1,5 +1,19 @@ -#[tokio::main] -async fn main() -> Result<(), Box> { - // add more logics that supports deploy appflowy-collaborate as single service +use appflowy_collaborate::application::{init_state, Application}; +use appflowy_collaborate::config::get_configuration; + +#[actix_web::main] +async fn main() -> anyhow::Result<()> { + // Load environment variables from .env file + dotenvy::dotenv().ok(); + + let conf = + get_configuration().map_err(|e| anyhow::anyhow!("Failed to read configuration: {}", e))?; + + let (tx, rx) = tokio::sync::mpsc::channel(1000); + let state = init_state(&conf, tx) + .await + .map_err(|e| anyhow::anyhow!("Failed to initialize application state: {}", e))?; + let application = Application::build(conf, state, rx).await?; + application.run_until_stopped().await?; Ok(()) } diff --git a/services/appflowy-collaborate/src/pg_listener.rs b/services/appflowy-collaborate/src/pg_listener.rs index 4043cd99..696299e8 100644 --- a/services/appflowy-collaborate/src/pg_listener.rs +++ b/services/appflowy-collaborate/src/pg_listener.rs @@ -6,14 +6,12 @@ use sqlx::PgPool; use tokio::sync::broadcast; use workspace_access::notification::WorkspaceMemberNotification; -#[allow(dead_code)] pub struct PgListeners { user_listener: UserListener, workspace_member_listener: WorkspaceMemberListener, collab_member_listener: CollabMemberListener, } -#[allow(dead_code)] impl PgListeners { pub async fn new(pg_pool: &PgPool) -> Result { let user_listener = UserListener::new(pg_pool, "af_user_channel").await?; @@ -57,9 +55,6 @@ impl PgListeners { } } -#[allow(dead_code)] pub type CollabMemberListener = PostgresDBListener; -#[allow(dead_code)] pub type UserListener = PostgresDBListener; -#[allow(dead_code)] pub type WorkspaceMemberListener = PostgresDBListener; diff --git a/services/appflowy-collaborate/src/state.rs b/services/appflowy-collaborate/src/state.rs index 704b004e..41aa85a5 100644 --- a/services/appflowy-collaborate/src/state.rs +++ b/services/appflowy-collaborate/src/state.rs @@ -1 +1,106 @@ +use std::sync::Arc; + +use dashmap::DashMap; +use futures_util::StreamExt; +use sqlx::PgPool; +use uuid::Uuid; + +use access_control::access::AccessControl; +use access_control::metrics::AccessControlMetrics; +use app_error::AppError; +use database::user::{select_all_uid_uuid, select_uid_from_uuid}; + +use crate::collab::storage::CollabAccessControlStorage; +use crate::config::Config; +use crate::metrics::CollabMetrics; +use crate::pg_listener::PgListeners; +use crate::CollabRealtimeMetrics; + pub type RedisConnectionManager = redis::aio::ConnectionManager; + +#[derive(Clone)] +pub struct AppState { + pub config: Arc, + pub pg_listeners: Arc, + pub user_cache: UserCache, + pub redis_connection_manager: RedisConnectionManager, + pub access_control: AccessControl, + pub collab_access_control_storage: Arc, + pub metrics: AppMetrics, +} + +#[derive(Clone)] +pub struct AppMetrics { + #[allow(dead_code)] + pub registry: Arc, + pub access_control_metrics: Arc, + pub realtime_metrics: Arc, + pub collab_metrics: Arc, +} + +impl Default for AppMetrics { + fn default() -> Self { + Self::new() + } +} + +impl AppMetrics { + pub fn new() -> Self { + let mut registry = prometheus_client::registry::Registry::default(); + let access_control_metrics = Arc::new(AccessControlMetrics::register(&mut registry)); + let realtime_metrics = Arc::new(CollabRealtimeMetrics::register(&mut registry)); + let collab_metrics = Arc::new(CollabMetrics::register(&mut registry)); + Self { + registry: Arc::new(registry), + access_control_metrics, + realtime_metrics, + collab_metrics, + } + } +} + +pub struct AuthenticateUser { + pub uid: i64, +} + +#[derive(Clone)] +pub struct UserCache { + pool: PgPool, + users: Arc>, +} + +impl UserCache { + /// Load all users from database when initializing the cache. + pub async fn new(pool: PgPool) -> Self { + let users = { + let users = DashMap::new(); + let mut stream = select_all_uid_uuid(&pool); + while let Some(Ok(af_user_id)) = stream.next().await { + users.insert( + af_user_id.uuid, + AuthenticateUser { + uid: af_user_id.uid, + }, + ); + } + users + }; + + Self { + pool, + users: Arc::new(users), + } + } + + /// Get the user's uid from the cache or the database. + pub async fn get_user_uid(&self, uuid: &Uuid) -> Result { + if let Some(entry) = self.users.get(uuid) { + return Ok(entry.value().uid); + } + + // If the user is not found in the cache, query the database. + let uid = select_uid_from_uuid(&self.pool, uuid).await?; + self.users.insert(*uuid, AuthenticateUser { uid }); + Ok(uid) + } +}