feat: add post realtime message stream handler to independent collab ws (#647)

This commit is contained in:
Khor Shu Heng 2024-06-25 13:30:05 +08:00 committed by GitHub
parent 0f9fcf2042
commit bdae165849
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 210 additions and 8 deletions

2
Cargo.lock generated
View File

@ -692,6 +692,7 @@ dependencies = [
"async-stream",
"async-trait",
"authentication",
"brotli",
"bytes",
"chrono",
"collab",
@ -714,6 +715,7 @@ dependencies = [
"md5",
"parking_lot 0.12.1",
"prometheus-client",
"prost",
"rand 0.8.5",
"redis 0.25.2",
"secrecy",

View File

@ -62,7 +62,7 @@ uuid = "1.6.1"
tokio-tungstenite = { version = "0.20.1", features = ["native-tls"] }
dotenvy.workspace = true
url = "2.5.0"
brotli = "3.4.0"
brotli.workspace = true
dashmap.workspace = true
async-stream.workspace = true
futures.workspace = true
@ -179,6 +179,7 @@ workspace-access = { path = "libs/workspace-access" }
app-error = { path = "libs/app-error" }
async-trait = "0.1.77"
prometheus-client = "0.22.0"
brotli = "3.4.0"
collab-stream = { path = "libs/collab-stream" }
dotenvy = "0.15.7"
secrecy = { version = "0.8", features = ["serde"] }

View File

@ -19,6 +19,7 @@ actix-http = { workspace = true, default-features = false, features = ["openssl"
actix-web-actors = { version = "4.3" }
app-error = { workspace = true, features = ["sqlx_error", "actix_web_error", "tokio_error"] }
authentication.workspace = true
brotli.workspace = true
dashmap.workspace = true
dotenvy.workspace = true
async-stream.workspace = true
@ -29,6 +30,7 @@ tokio-util = { version = "0.7", features = ["codec"] }
tokio-stream = { version = "0.1.14", features = ["sync"] }
tokio = { workspace = true, features = ["net", "sync", "macros", "rt-multi-thread"] }
async-trait = "0.1.77"
prost.workspace = true
serde.workspace = true
serde_json.workspace = true
serde_repr.workspace = true

View File

@ -1,32 +1,52 @@
use std::collections::HashMap;
use std::str::FromStr;
use std::time::Duration;
use actix::Addr;
use actix_http::header::AUTHORIZATION;
use actix_web::web::{Data, Payload};
use actix_http::header::{HeaderMap, AUTHORIZATION};
use actix_web::web::{Data, Json, Payload, PayloadConfig};
use actix_web::{web, HttpRequest, HttpResponse, Result, Scope};
use actix_web_actors::ws;
use anyhow::anyhow;
use bytes::{Bytes, BytesMut};
use prost::Message;
use secrecy::Secret;
use semver::Version;
use tokio::sync::mpsc;
use tokio::sync::mpsc::Sender;
use tracing::{debug, error, instrument, trace};
use tokio_stream::StreamExt;
use tracing::{debug, error, event, instrument, trace};
use app_error::AppError;
use collab_rt_entity::user::{AFUserChange, RealtimeUser, UserMessage};
use collab_rt_entity::RealtimeMessage;
use shared_entity::response::AppResponseError;
use collab_rt_entity::{HttpRealtimeMessage, RealtimeMessage};
use shared_entity::response::{AppResponse, AppResponseError};
use crate::actix_ws::client::RealtimeClient;
use crate::actix_ws::entities::ClientStreamMessage;
use crate::actix_ws::server::RealtimeServerActor;
use crate::collab::access_control::RealtimeCollabAccessControlImpl;
use crate::collab::storage::CollabAccessControlStorage;
use crate::compression::{
decompress, CompressionType, X_COMPRESSION_BUFFER_SIZE, X_COMPRESSION_TYPE,
};
use crate::state::AppState;
use authentication::jwt::{authorization_from_token, UserUuid};
pub fn ws_scope() -> Scope {
web::scope("/ws").service(web::resource("/v1").route(web::get().to(establish_ws_connection_v1)))
}
pub fn collab_scope() -> Scope {
web::scope("/api/realtime").service(
web::resource("post/stream")
.app_data(
PayloadConfig::new(10 * 1024 * 1024), // 10 MB
)
.route(web::post().to(post_realtime_message_stream_handler)),
)
}
const MAX_FRAME_SIZE: usize = 65_536; // 64 KiB
pub type RealtimeServerAddr =
@ -70,6 +90,147 @@ pub async fn establish_ws_connection_v1(
.await
}
#[instrument(level = "info", skip_all, err)]
async fn post_realtime_message_stream_handler(
user_uuid: UserUuid,
mut payload: Payload,
server: Data<RealtimeServerAddr>,
state: Data<AppState>,
req: HttpRequest,
) -> Result<Json<AppResponse<()>>> {
// TODO(nathan): after upgrade the client application, then the device_id should not be empty
let device_id = device_id_from_headers(req.headers()).unwrap_or_else(|_| "".to_string());
let uid = state
.user_cache
.get_user_uid(&user_uuid)
.await
.map_err(AppResponseError::from)?;
let mut bytes = BytesMut::new();
while let Some(item) = payload.next().await {
bytes.extend_from_slice(&item?);
}
event!(tracing::Level::INFO, "message len: {}", bytes.len());
let device_id = device_id.to_string();
// Only send message to websocket server when the user is connected
if !state
.realtime_shared_state
.is_user_connected(&uid, &device_id)
.await
.unwrap_or(false)
{
return Ok(Json(AppResponse::Ok()));
}
let message = parser_realtime_msg(bytes.freeze(), req.clone()).await?;
let stream_message = ClientStreamMessage {
uid,
device_id,
message,
};
// When the server is under heavy load, try_send may fail. In client side, it will retry to send
// the message later.
match server.try_send(stream_message) {
Ok(_) => return Ok(Json(AppResponse::Ok())),
Err(err) => Err(
AppError::Internal(anyhow!(
"Failed to send message to websocket server, error:{}",
err
))
.into(),
),
}
}
fn device_id_from_headers(headers: &HeaderMap) -> std::result::Result<String, AppError> {
headers
.get("device_id")
.ok_or(AppError::InvalidRequest(
"Missing device_id header".to_string(),
))
.and_then(|header| {
header
.to_str()
.map_err(|err| AppError::InvalidRequest(format!("Failed to parse device_id: {}", err)))
})
.map(|s| s.to_string())
}
fn compress_type_from_header_value(
headers: &HeaderMap,
) -> std::result::Result<CompressionType, AppError> {
let compression_type_str = headers
.get(X_COMPRESSION_TYPE)
.ok_or(AppError::InvalidRequest(
"Missing X-Compression-Type header".to_string(),
))?
.to_str()
.map_err(|err| {
AppError::InvalidRequest(format!("Failed to parse X-Compression-Type: {}", err))
})?;
let buffer_size_str = headers
.get(X_COMPRESSION_BUFFER_SIZE)
.ok_or_else(|| {
AppError::InvalidRequest("Missing X-Compression-Buffer-Size header".to_string())
})?
.to_str()
.map_err(|err| {
AppError::InvalidRequest(format!(
"Failed to parse X-Compression-Buffer-Size: {}",
err
))
})?;
let buffer_size = usize::from_str(buffer_size_str).map_err(|err| {
AppError::InvalidRequest(format!(
"X-Compression-Buffer-Size is not a valid usize: {}",
err
))
})?;
match compression_type_str {
"brotli" => Ok(CompressionType::Brotli { buffer_size }),
s => Err(AppError::InvalidRequest(format!(
"Unknown compression type: {}",
s
))),
}
}
async fn parser_realtime_msg(
payload: Bytes,
req: HttpRequest,
) -> Result<RealtimeMessage, AppError> {
let HttpRealtimeMessage {
device_id: _,
payload,
} =
HttpRealtimeMessage::decode(payload.as_ref()).map_err(|err| AppError::Internal(err.into()))?;
let payload = match req.headers().get(X_COMPRESSION_TYPE) {
None => payload,
Some(_) => match compress_type_from_header_value(req.headers())? {
CompressionType::Brotli { buffer_size } => {
let decompressed_data = decompress(payload, buffer_size).await?;
event!(
tracing::Level::TRACE,
"Decompress realtime http message with len: {}",
decompressed_data.len()
);
decompressed_data
},
},
};
let realtime_msg = tokio::task::spawn_blocking(move || {
RealtimeMessage::decode(&payload)
.map_err(|err| AppError::InvalidRequest(format!("Failed to parse RealtimeMessage: {}", err)))
})
.await
.map_err(AppError::from)??;
Ok(realtime_msg)
}
#[allow(clippy::too_many_arguments)]
#[inline]
async fn start_connect(

View File

@ -10,14 +10,14 @@ use anyhow::{Context, Error};
use secrecy::ExposeSecret;
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
use tracing::info;
use tracing::{info, warn};
use crate::actix_ws::server::RealtimeServerActor;
use access_control::access::AccessControl;
use workspace_access::notification::spawn_listen_on_workspace_member_change;
use workspace_access::WorkspaceAccessControlImpl;
use crate::api::ws_scope;
use crate::api::{collab_scope, ws_scope};
use crate::collab::access_control::{
CollabAccessControlImpl, CollabStorageAccessControlImpl, RealtimeCollabAccessControlImpl,
};
@ -27,6 +27,7 @@ use crate::collab::storage::CollabStorageImpl;
use crate::command::{CLCommandReceiver, CLCommandSender};
use crate::config::{Config, DatabaseSetting};
use crate::pg_listener::PgListeners;
use crate::shared_state::RealtimeSharedState;
use crate::snapshot::SnapshotControl;
use crate::state::{AppMetrics, AppState, UserCache};
use crate::CollaborationServer;
@ -85,6 +86,7 @@ pub async fn run_actix_server(
.app_data(Data::new(state.config.gotrue.jwt_secret.clone()))
.app_data(Data::new(realtime_server_actor.clone()))
.service(ws_scope())
.service(collab_scope())
});
server = server.listen(listener)?;
@ -97,8 +99,13 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result<A
// User cache
let user_cache = UserCache::new(pg_pool.clone()).await;
info!("Connecting to Redis...");
let redis_conn_manager = get_redis_client(config.redis_uri.expose_secret()).await?;
let realtime_shared_state = RealtimeSharedState::new(redis_conn_manager.clone());
if let Err(err) = realtime_shared_state.remove_all_connected_users().await {
warn!("Failed to remove all connected users: {:?}", err);
}
// Pg listeners
info!("Setting up Pg listeners...");
@ -146,6 +153,7 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result<A
access_control,
collab_access_control_storage: collab_storage,
metrics,
realtime_shared_state,
};
Ok(app_state)
}

View File

@ -0,0 +1,25 @@
use app_error::AppError;
use brotli::Decompressor;
use std::io::Read;
pub const X_COMPRESSION_TYPE: &str = "X-Compression-Type";
pub const X_COMPRESSION_BUFFER_SIZE: &str = "X-Compression-Buffer-Size";
pub enum CompressionType {
Brotli { buffer_size: usize },
}
pub async fn decompress(data: Vec<u8>, buffer_size: usize) -> Result<Vec<u8>, AppError> {
tokio::task::spawn_blocking(move || {
let mut decompressor = Decompressor::new(&*data, buffer_size);
let mut decompressed_data = Vec::new();
decompressor
.read_to_end(&mut decompressed_data)
.map_err(|err| {
AppError::InvalidRequest(format!("Failed to decompress data:{} {}", data.len(), err))
})?;
Ok(decompressed_data)
})
.await
.map_err(AppError::from)?
}

View File

@ -4,6 +4,7 @@ pub mod application;
mod client;
pub mod collab;
pub mod command;
pub mod compression;
pub mod config;
pub mod connect_state;
pub mod error;

View File

@ -14,6 +14,7 @@ use crate::collab::storage::CollabAccessControlStorage;
use crate::config::Config;
use crate::metrics::CollabMetrics;
use crate::pg_listener::PgListeners;
use crate::shared_state::RealtimeSharedState;
use crate::CollabRealtimeMetrics;
pub type RedisConnectionManager = redis::aio::ConnectionManager;
@ -27,6 +28,7 @@ pub struct AppState {
pub access_control: AccessControl,
pub collab_access_control_storage: Arc<CollabAccessControlStorage>,
pub metrics: AppMetrics,
pub realtime_shared_state: RealtimeSharedState,
}
#[derive(Clone)]