diff --git a/Cargo.lock b/Cargo.lock index f0d2be3d..de045a43 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -649,6 +649,7 @@ dependencies = [ "prometheus-client", "prost", "rand 0.8.5", + "rayon", "rcgen", "redis 0.25.4", "reqwest 0.11.27", diff --git a/Cargo.toml b/Cargo.toml index 34418520..9244ac9a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -147,6 +147,7 @@ lettre = { version = "0.11.7", features = ["tokio1", "tokio1-native-tls"] } handlebars = "5.1.2" pin-project = "1.1.5" byteorder = "1.5.0" +rayon = "1.10.0" [dev-dependencies] diff --git a/src/api/workspace.rs b/src/api/workspace.rs index 8eaf98a1..098dd428 100644 --- a/src/api/workspace.rs +++ b/src/api/workspace.rs @@ -8,11 +8,12 @@ use collab::entity::EncodedCollab; use collab_entity::CollabType; use futures_util::future::try_join_all; use prost::Message as ProstMessage; +use rayon::prelude::*; use sqlx::types::uuid; -use tokio::time::Instant; + use tokio_stream::StreamExt; use tokio_tungstenite::tungstenite::Message; -use tracing::{error, event, instrument, trace}; +use tracing::{error, event, info, instrument, trace}; use uuid::Uuid; use validator::Validate; @@ -42,7 +43,9 @@ use crate::biz::workspace::ops::{ create_comment_on_published_view, create_reaction_on_comment, get_comments_on_published_view, get_reactions_on_published_view, remove_comment_on_published_view, remove_reaction_on_comment, }; -use crate::domain::compression::{decompress, CompressionType, X_COMPRESSION_TYPE}; +use crate::domain::compression::{ + blocking_decompress, decompress, CompressionType, X_COMPRESSION_TYPE, +}; use crate::state::AppState; pub const WORKSPACE_ID_PATH: &str = "workspace_id"; @@ -59,40 +62,39 @@ pub const WORKSPACE_PUBLISH_NAMESPACE_PATTERN: &str = pub fn workspace_scope() -> Scope { web::scope("/api/workspace") - .service(web::resource("") - .route(web::get().to(list_workspace_handler)) - .route(web::post().to(create_workspace_handler)) - .route(web::patch().to(patch_workspace_handler)) + .service( + web::resource("") + .route(web::get().to(list_workspace_handler)) + .route(web::post().to(create_workspace_handler)) + .route(web::patch().to(patch_workspace_handler)), ) .service( - web::resource("/{workspace_id}/invite") - .route(web::post().to(post_workspace_invite_handler)) // invite members to workspace + web::resource("/{workspace_id}/invite").route(web::post().to(post_workspace_invite_handler)), // invite members to workspace ) .service( - web::resource("/invite") - .route(web::get().to(get_workspace_invite_handler)) // show invites for user + web::resource("/invite").route(web::get().to(get_workspace_invite_handler)), // show invites for user ) .service( web::resource("/accept-invite/{invite_id}") - .route(web::post().to(post_accept_workspace_invite_handler)) // accept invitation to workspace + .route(web::post().to(post_accept_workspace_invite_handler)), // accept invitation to workspace ) - .service(web::resource("/{workspace_id}") - .route(web::delete().to(delete_workspace_handler)) - ) - .service(web::resource("/{workspace_id}/settings") + .service(web::resource("/{workspace_id}").route(web::delete().to(delete_workspace_handler))) + .service( + web::resource("/{workspace_id}/settings") .route(web::get().to(get_workspace_settings_handler)) - .route(web::post().to(post_workspace_settings_handler)) + .route(web::post().to(post_workspace_settings_handler)), ) .service(web::resource("/{workspace_id}/open").route(web::put().to(open_workspace_handler))) .service(web::resource("/{workspace_id}/leave").route(web::post().to(leave_workspace_handler))) .service( web::resource("/{workspace_id}/member") - .route(web::get().to(get_workspace_members_handler)) - .route(web::put().to(update_workspace_member_handler)) - .route(web::delete().to(remove_workspace_member_handler)) + .route(web::get().to(get_workspace_members_handler)) + .route(web::put().to(update_workspace_member_handler)) + .route(web::delete().to(remove_workspace_member_handler)), ) .service( - web::resource("/{workspace_id}/member/user/{user_id}").route(web::get().to(get_workspace_member_handler)) + web::resource("/{workspace_id}/member/user/{user_id}") + .route(web::get().to(get_workspace_member_handler)), ) .service( web::resource("/{workspace_id}/collab/{object_id}") @@ -106,18 +108,12 @@ pub fn workspace_scope() -> Scope { ) .service( web::resource("/v1/{workspace_id}/collab/{object_id}") - .route(web::get().to(v1_get_collab_handler)) + .route(web::get().to(v1_get_collab_handler)), ) .service( web::resource("/{workspace_id}/batch/collab") .route(web::post().to(batch_create_collab_handler)), ) - // will be deprecated - .service( - web::resource("/{workspace_id}/collabs") - .app_data(PayloadConfig::new(10 * 1024 * 1024)) - .route(web::post().to(create_collab_list_handler)), - ) .service( web::resource("/{workspace_id}/usage").route(web::get().to(get_workspace_usage_handler)), ) @@ -139,49 +135,48 @@ pub fn workspace_scope() -> Scope { ) .service( web::resource("/published/{publish_namespace}/{publish_name}") - .route(web::get().to(get_published_collab_handler)) + .route(web::get().to(get_published_collab_handler)), ) .service( web::resource("/published/{publish_namespace}/{publish_name}/blob") - .route(web::get().to(get_published_collab_blob_handler)) + .route(web::get().to(get_published_collab_blob_handler)), ) .service( web::resource("{workspace_id}/published-duplicate") - .route(web::post().to(post_published_duplicate_handler)) + .route(web::post().to(post_published_duplicate_handler)), ) .service( web::resource("/published-info/{view_id}") - .route(web::get().to(get_published_collab_info_handler)) + .route(web::get().to(get_published_collab_info_handler)), ) .service( web::resource("/published-info/{view_id}/comment") .route(web::get().to(get_published_collab_comment_handler)) .route(web::post().to(post_published_collab_comment_handler)) - .route(web::delete().to(delete_published_collab_comment_handler)) + .route(web::delete().to(delete_published_collab_comment_handler)), ) .service( web::resource("/published-info/{view_id}/reaction") .route(web::get().to(get_published_collab_reaction_handler)) .route(web::post().to(post_published_collab_reaction_handler)) - .route(web::delete().to(delete_published_collab_reaction_handler)) + .route(web::delete().to(delete_published_collab_reaction_handler)), ) .service( web::resource("/{workspace_id}/publish-namespace") .route(web::put().to(put_publish_namespace_handler)) - .route(web::get().to(get_publish_namespace_handler)) + .route(web::get().to(get_publish_namespace_handler)), ) .service( web::resource("/{workspace_id}/publish") .route(web::post().to(post_publish_collabs_handler)) - .route(web::delete().to(delete_published_collabs_handler)) + .route(web::delete().to(delete_published_collabs_handler)), ) .service( - web::resource("/{workspace_id}/folder") - .route(web::get().to(get_workspace_folder_handler)) + web::resource("/{workspace_id}/folder").route(web::get().to(get_workspace_folder_handler)), ) .service( web::resource("/published-outline/{publish_namespace}") - .route(web::get().to(get_workspace_publish_outline_handler)) + .route(web::get().to(get_workspace_publish_outline_handler)), ) .service( web::resource("/{workspace_id}/collab/{object_id}/member/list") @@ -504,7 +499,7 @@ async fn create_collab_handler( })?, Some(_) => match compress_type_from_header_value(req.headers())? { CompressionType::Brotli { buffer_size } => { - let decompress_data = decompress(payload.to_vec(), buffer_size).await?; + let decompress_data = blocking_decompress(payload.to_vec(), buffer_size).await?; CreateCollabParams::from_bytes(&decompress_data).map_err(|err| { AppError::InvalidRequest(format!( "Failed to parse CreateCollabParams with brotli decompression data: {}", @@ -583,68 +578,72 @@ async fn batch_create_collab_handler( req: HttpRequest, ) -> Result>> { let uid = state.user_cache.get_user_uid(&user_uuid).await?; - let mut collab_params_list = vec![]; let workspace_id = workspace_id.into_inner().to_string(); let compress_type = compress_type_from_header_value(req.headers())?; - event!( - tracing::Level::DEBUG, - "start decompressing collab params list" - ); + event!(tracing::Level::DEBUG, "start decompressing collab list"); - let start_time = Instant::now(); let mut payload_buffer = Vec::new(); + let mut offset_len_list = Vec::new(); + let mut current_offset = 0; + while let Some(item) = payload.next().await { if let Ok(bytes) = item { - match compress_type { - CompressionType::Brotli { buffer_size } => { - payload_buffer.extend_from_slice(&bytes); + payload_buffer.extend_from_slice(&bytes); + while current_offset + 4 <= payload_buffer.len() { + // The length of the next frame is determined by the first 4 bytes + let size = u32::from_be_bytes([ + payload_buffer[current_offset], + payload_buffer[current_offset + 1], + payload_buffer[current_offset + 2], + payload_buffer[current_offset + 3], + ]) as usize; - // The client API uses a u32 value as the frame separator, which determines the size of each data frame. - // The length of a u32 is fixed at 4 bytes. It's important not to change the size (length) of the u32, - // unless you also make a corresponding update in the client API. Any mismatch in frame size handling - // between the client and server could lead to incorrect data processing or communication errors. - while payload_buffer.len() >= 4 { - let size = u32::from_be_bytes([ - payload_buffer[0], - payload_buffer[1], - payload_buffer[2], - payload_buffer[3], - ]) as usize; + // Ensure there is enough data for the frame (4 bytes for size + `size` bytes for data) + if current_offset + 4 + size > payload_buffer.len() { + break; + } - if payload_buffer.len() < 4 + size { - break; - } - - let compressed_data = payload_buffer[4..4 + size].to_vec(); - let decompress_data = decompress(compressed_data, buffer_size).await?; - let params = CollabParams::from_bytes(&decompress_data).map_err(|err| { - AppError::InvalidRequest(format!( - "Failed to parse CollabParams with brotli decompression data: {}", - err - )) - })?; - params.validate().map_err(AppError::from)?; - match params.check_encode_collab().await { - Ok(_) => collab_params_list.push(params), - Err(err) => error!("Failed to validate collab params: {:?}", err), - } - - payload_buffer = payload_buffer[4 + size..].to_vec(); - } - }, + // Collect the (offset, len) for the current frame (data starts at current_offset + 4) + offset_len_list.push((current_offset + 4, size)); + current_offset += 4 + size; } } } - let duration = start_time.elapsed(); - event!( - tracing::Level::DEBUG, - "end decompressing collab params list, time taken: {:?}", - duration - ); + // Perform decompression and processing in a Rayon thread pool + let mut collab_params_list = tokio::task::spawn_blocking(move || { + match compress_type { + CompressionType::Brotli { buffer_size } => { + let list = offset_len_list + .par_iter() // Use Rayon parallel iterator + .filter_map(|(offset, len)| { + let compressed_data = &payload_buffer[*offset..*offset + *len]; + match decompress(compressed_data.to_vec(), buffer_size) { + Ok(decompressed_data) => { + if let Ok(params) = CollabParams::from_bytes(&decompressed_data) { + if params.validate().is_ok() { + return Some(params); + } + } + }, + Err(err) => { + error!("Failed to decompress data: {:?}", err); + }, + } + None + }) + .collect::>(); + Ok::<_, AppError>(list) + }, + } + }) + .await + .map_err(|_| AppError::InvalidRequest("Failed to decompress data".to_string()))??; + info!("batch create {} collab objects", collab_params_list.len()); if collab_params_list.is_empty() { return Err(AppError::InvalidRequest("Empty collab params list".to_string()).into()); } + if state .indexer_provider .can_index_workspace(&workspace_id) @@ -658,6 +657,8 @@ async fn batch_create_collab_handler( ); } } + + // Process each collab params for params in collab_params_list { let object_id = params.object_id.clone(); if validate_encode_collab( @@ -679,89 +680,6 @@ async fn batch_create_collab_handler( .await?; } } - Ok(Json(AppResponse::Ok())) -} - -#[instrument(skip(state, payload), err)] -async fn create_collab_list_handler( - user_uuid: UserUuid, - payload: Bytes, - state: Data, - req: HttpRequest, -) -> Result>> { - let uid = state.user_cache.get_user_uid(&user_uuid).await?; - let params = match req.headers().get(X_COMPRESSION_TYPE) { - None => BatchCreateCollabParams::from_bytes(&payload).map_err(|err| { - AppError::InvalidRequest(format!( - "Failed to parse batch BatchCreateCollabParams: {}", - err - )) - })?, - Some(_) => match compress_type_from_header_value(req.headers())? { - CompressionType::Brotli { buffer_size } => { - let decompress_data = decompress(payload.to_vec(), buffer_size).await?; - BatchCreateCollabParams::from_bytes(&decompress_data).map_err(|err| { - AppError::InvalidRequest(format!( - "Failed to parse BatchCreateCollabParams with decompression data: {}", - err - )) - })? - }, - }, - }; - - params.validate().map_err(AppError::from)?; - let BatchCreateCollabParams { - workspace_id, - params_list, - } = params; - - let mut valid_items = Vec::with_capacity(params_list.len()); - for params in params_list { - match params.check_encode_collab().await { - Ok(_) => valid_items.push(params), - Err(err) => error!("Failed to validate collab params: {:?}", err), - } - } - - if valid_items.is_empty() { - return Err(AppError::InvalidRequest("Empty collab params list".to_string()).into()); - } - - if state - .indexer_provider - .can_index_workspace(&workspace_id) - .await? - { - if let Err(err) = fetch_embeddings(&state.indexer_provider, &mut valid_items).await { - tracing::warn!( - "failed to fetch embeddings for {} new documents: {}", - valid_items.len(), - err - ); - } - } - - let mut transaction = state - .pg_pool - .begin() - .await - .map_err(|err| AppError::Internal(anyhow!("Failed to start inserting collab: {}", err)))?; - - for params in valid_items { - let _object_id = params.object_id.clone(); - state - .collab_access_control_storage - .insert_new_collab_with_transaction(&workspace_id, &uid, params, &mut transaction) - .await?; - } - - transaction.commit().await.map_err(|err| { - AppError::Internal(anyhow!( - "Failed to finish inserting list of collab: {}", - err - )) - })?; Ok(Json(AppResponse::Ok())) } @@ -1425,7 +1343,7 @@ async fn parser_realtime_msg( None => payload, Some(_) => match compress_type_from_header_value(req.headers())? { CompressionType::Brotli { buffer_size } => { - let decompressed_data = decompress(payload, buffer_size).await?; + let decompressed_data = blocking_decompress(payload, buffer_size).await?; event!( tracing::Level::TRACE, "Decompress realtime http message with len: {}", diff --git a/src/domain/compression.rs b/src/domain/compression.rs index 23d381ea..21356a2e 100644 --- a/src/domain/compression.rs +++ b/src/domain/compression.rs @@ -9,6 +9,14 @@ pub enum CompressionType { Brotli { buffer_size: usize }, } +impl CompressionType { + pub fn buffer_size(&self) -> usize { + match self { + CompressionType::Brotli { buffer_size } => *buffer_size, + } + } +} + pub async fn compress( data: Vec, quality: u32, @@ -26,17 +34,19 @@ pub async fn compress( .map_err(AppError::from)? } -pub async fn decompress(data: Vec, buffer_size: usize) -> Result, AppError> { - tokio::task::spawn_blocking(move || { - let mut decompressor = Decompressor::new(&*data, buffer_size); - let mut decompressed_data = Vec::new(); - decompressor - .read_to_end(&mut decompressed_data) - .map_err(|err| { - AppError::InvalidRequest(format!("Failed to decompress data:{} {}", data.len(), err)) - })?; - Ok(decompressed_data) - }) - .await - .map_err(AppError::from)? +pub fn decompress(data: Vec, buffer_size: usize) -> Result, AppError> { + let mut decompressor = Decompressor::new(&*data, buffer_size); + let mut decompressed_data = Vec::new(); + decompressor + .read_to_end(&mut decompressed_data) + .map_err(|err| { + AppError::InvalidRequest(format!("Failed to decompress data:{} {}", data.len(), err)) + })?; + Ok(decompressed_data) +} + +pub async fn blocking_decompress(data: Vec, buffer_size: usize) -> Result, AppError> { + tokio::task::spawn_blocking(move || decompress(data, buffer_size)) + .await + .map_err(AppError::from)? } diff --git a/tests/collab/collab_curd_test.rs b/tests/collab/collab_curd_test.rs index 2aee323e..d95ca67e 100644 --- a/tests/collab/collab_curd_test.rs +++ b/tests/collab/collab_curd_test.rs @@ -4,8 +4,7 @@ use collab::entity::EncodedCollab; use collab_document::document_data::default_document_collab_data; use collab_entity::CollabType; use database_entity::dto::{ - BatchCreateCollabParams, CollabParams, CreateCollabParams, QueryCollab, QueryCollabParams, - QueryCollabResult, + CollabParams, CreateCollabParams, QueryCollab, QueryCollabParams, QueryCollabResult, }; use reqwest::Method; @@ -222,68 +221,6 @@ async fn create_collab_compatibility_with_json_params_test() { assert_eq!(encoded_collab, encoded_collab_from_server); } -#[tokio::test] -async fn batch_create_collab_compatibility_with_uncompress_params_test() { - let test_client = TestClient::new_user().await; - let workspace_id = test_client.workspace_id().await; - let object_id = Uuid::new_v4().to_string(); - let api_client = &test_client.api_client; - let url = format!( - "{}/api/workspace/{}/collabs", - api_client.base_url, workspace_id, - ); - - let encoded_collab = test_encode_collab_v1(&object_id, "title", "hello world"); - let params = BatchCreateCollabParams { - workspace_id: workspace_id.to_string(), - params_list: vec![CollabParams { - object_id: object_id.clone(), - encoded_collab_v1: encoded_collab.encode_to_bytes().unwrap().into(), - collab_type: CollabType::Unknown, - embeddings: None, - }], - } - .to_bytes() - .unwrap(); - - test_client - .api_client - .http_client_with_auth(Method::POST, &url) - .await - .unwrap() - .body(params) - .send() - .await - .unwrap(); - - let url = format!( - "{}/api/workspace/{}/collab/{}", - api_client.base_url, workspace_id, &object_id - ); - let resp = test_client - .api_client - .http_client_with_auth(Method::GET, &url) - .await - .unwrap() - .json(&QueryCollabParams { - workspace_id, - inner: QueryCollab { - object_id: object_id.clone(), - collab_type: CollabType::Unknown, - }, - }) - .send() - .await - .unwrap(); - - let encoded_collab_from_server = AppResponse::::from_response(resp) - .await - .unwrap() - .into_data() - .unwrap(); - assert_eq!(encoded_collab, encoded_collab_from_server); -} - #[derive(Debug, Clone, Serialize)] pub struct OldCreateCollabParams { #[serde(flatten)]