diff --git a/libs/appflowy-ai-client/src/dto.rs b/libs/appflowy-ai-client/src/dto.rs index a041abbc..fb5f0085 100644 --- a/libs/appflowy-ai-client/src/dto.rs +++ b/libs/appflowy-ai-client/src/dto.rs @@ -198,16 +198,18 @@ impl Display for EmbeddingsModel { #[repr(u8)] pub enum AIModel { #[default] - GPT35 = 0, - GPT4o = 1, - Claude3Sonnet = 2, - Claude3Opus = 3, - Local = 4, + DefaultModel = 0, + GPT35 = 1, + GPT4o = 2, + Claude3Sonnet = 3, + Claude3Opus = 4, + Local = 5, } impl AIModel { pub fn to_str(&self) -> &str { match self { + AIModel::DefaultModel => "default-model", AIModel::GPT35 => "gpt-3.5-turbo", AIModel::GPT4o => "gpt-4o", AIModel::Claude3Sonnet => "claude-3-sonnet-20240229", @@ -227,7 +229,7 @@ impl FromStr for AIModel { "claude-3-sonnet" => Ok(AIModel::Claude3Sonnet), "claude-3-opus" => Ok(AIModel::Claude3Opus), "local" => Ok(AIModel::Local), - _ => Ok(AIModel::GPT35), + _ => Ok(AIModel::DefaultModel), } } } diff --git a/libs/client-api/Cargo.toml b/libs/client-api/Cargo.toml index b6bb87a0..54150512 100644 --- a/libs/client-api/Cargo.toml +++ b/libs/client-api/Cargo.toml @@ -72,4 +72,4 @@ test_util = ["scraper"] template = ["workspace-template"] sync_verbose_log = ["collab-rt-protocol/verbose_log"] test_fast_sync = [] -enable_brotli = ["brotli"] +enable_brotli = ["brotli"] \ No newline at end of file diff --git a/libs/client-api/src/http_settings.rs b/libs/client-api/src/http_settings.rs index 4ca4f0a6..8af82645 100644 --- a/libs/client-api/src/http_settings.rs +++ b/libs/client-api/src/http_settings.rs @@ -4,6 +4,7 @@ use tracing::instrument; use client_api_entity::AFWorkspaceSettings; use shared_entity::response::{AppResponse, AppResponseError}; +use crate::entity::AFWorkspaceSettingsChange; use crate::http::log_request_id; use crate::Client; @@ -32,8 +33,8 @@ impl Client { pub async fn update_workspace_settings>( &self, workspace_id: T, - settings: &AFWorkspaceSettings, - ) -> Result<(), AppResponseError> { + changes: &AFWorkspaceSettingsChange, + ) -> Result { let url = format!( "{}/api/workspace/{}/settings", self.base_url, @@ -42,10 +43,11 @@ impl Client { let resp = self .http_client_with_auth(Method::POST, &url) .await? - .json(&settings) + .json(&changes) .send() .await?; log_request_id(&resp); - AppResponse::<()>::from_response(resp).await?.into_error() + let resp = AppResponse::::from_response(resp).await?; + resp.into_data() } } diff --git a/libs/database-entity/src/dto.rs b/libs/database-entity/src/dto.rs index 58d4981c..4c9b62ab 100644 --- a/libs/database-entity/src/dto.rs +++ b/libs/database-entity/src/dto.rs @@ -511,10 +511,47 @@ pub struct AFWorkspace { #[derive(Serialize, Deserialize)] pub struct AFWorkspaces(pub Vec); -#[derive(Default, Serialize, Deserialize)] +#[derive(Serialize, Deserialize)] pub struct AFWorkspaceSettings { #[serde(default)] - pub disable_indexing: bool, + pub disable_search_indexing: bool, + + #[serde(default)] + pub ai_model: String, +} + +impl Default for AFWorkspaceSettings { + fn default() -> Self { + Self { + disable_search_indexing: false, + ai_model: "".to_string(), + } + } +} + +#[derive(Default, Serialize, Deserialize)] +pub struct AFWorkspaceSettingsChange { + #[serde(skip_serializing_if = "Option::is_none")] + pub disable_search_indexing: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub ai_model: Option, +} + +impl AFWorkspaceSettingsChange { + pub fn new() -> Self { + Self { + disable_search_indexing: None, + ai_model: None, + } + } + pub fn disable_search_indexing(mut self, disable_search_indexing: bool) -> Self { + self.disable_search_indexing = Some(disable_search_indexing); + self + } + pub fn ai_model(mut self, ai_model: String) -> Self { + self.ai_model = Some(ai_model); + self + } } #[derive(Serialize, Deserialize)] diff --git a/src/api/workspace.rs b/src/api/workspace.rs index 18d42ca3..9cc85b3d 100644 --- a/src/api/workspace.rs +++ b/src/api/workspace.rs @@ -324,18 +324,18 @@ async fn post_workspace_settings_handler( user_uuid: UserUuid, state: Data, workspace_id: web::Path, - data: Json, -) -> Result> { + data: Json, +) -> Result> { let uid = state.user_cache.get_user_uid(&user_uuid).await?; - workspace::ops::update_workspace_settings( + let settings = workspace::ops::update_workspace_settings( &state.pg_pool, &state.workspace_access_control, &workspace_id, &uid, - &data.into_inner(), + data.into_inner(), ) .await?; - Ok(AppResponse::Ok().into()) + Ok(AppResponse::Ok().with_data(settings).into()) } #[instrument(skip_all, err)] diff --git a/src/biz/workspace/ops.rs b/src/biz/workspace/ops.rs index c10115be..6a684c5a 100644 --- a/src/biz/workspace/ops.rs +++ b/src/biz/workspace/ops.rs @@ -1,4 +1,4 @@ -use database_entity::dto::PublishCollabItem; +use database_entity::dto::{AFWorkspaceSettingsChange, PublishCollabItem}; use std::collections::HashMap; use database_entity::dto::PublishInfo; @@ -517,8 +517,8 @@ pub async fn update_workspace_settings( workspace_access_control: &impl WorkspaceAccessControl, workspace_id: &Uuid, owner_uid: &i64, - workspace_settings: &AFWorkspaceSettings, -) -> Result<(), AppResponseError> { + change: AFWorkspaceSettingsChange, +) -> Result { let has_access = workspace_access_control .enforce_role(owner_uid, &workspace_id.to_string(), AFRole::Owner) .await?; @@ -531,9 +531,21 @@ pub async fn update_workspace_settings( } let mut tx = pg_pool.begin().await?; - upsert_workspace_settings(&mut tx, workspace_id, workspace_settings).await?; + let mut setting = select_workspace_settings(tx.deref_mut(), workspace_id) + .await? + .unwrap_or_default(); + if let Some(disable_indexing) = change.disable_search_indexing { + setting.disable_search_indexing = disable_indexing; + } + + if let Some(ai_model) = change.ai_model { + setting.ai_model = ai_model; + } + + // Update the workspace settings in the database + upsert_workspace_settings(&mut tx, workspace_id, &setting).await?; tx.commit().await?; - Ok(()) + Ok(setting) } async fn check_workspace_owner( diff --git a/tests/workspace/workspace_settings.rs b/tests/workspace/workspace_settings.rs index 5f61250c..4166cdee 100644 --- a/tests/workspace/workspace_settings.rs +++ b/tests/workspace/workspace_settings.rs @@ -1,7 +1,7 @@ use app_error::ErrorCode; use client_api::Client; use client_api_test::generate_unique_registered_user_client; -use database_entity::dto::{AFRole, AFWorkspaceInvitationStatus, AFWorkspaceSettings}; +use database_entity::dto::{AFRole, AFWorkspaceInvitationStatus, AFWorkspaceSettingsChange}; use shared_entity::dto::workspace_dto::WorkspaceMemberInvitation; use uuid::Uuid; @@ -13,17 +13,20 @@ async fn get_and_set_workspace_by_owner() { let mut settings = c.get_workspace_settings(&workspace_id).await.unwrap(); assert!( - !settings.disable_indexing, + !settings.disable_search_indexing, "indexing should be enabled by default" ); - settings.disable_indexing = true; - c.update_workspace_settings(&workspace_id, &settings) - .await - .unwrap(); + settings.disable_search_indexing = true; + c.update_workspace_settings( + &workspace_id, + &AFWorkspaceSettingsChange::new().disable_search_indexing(true), + ) + .await + .unwrap(); let settings = c.get_workspace_settings(&workspace_id).await.unwrap(); - assert!(settings.disable_indexing); + assert!(settings.disable_search_indexing); } #[tokio::test] @@ -48,9 +51,7 @@ async fn get_and_set_workspace_by_non_owner() { let resp = bob_client .update_workspace_settings( &alice_workspace_id.to_string(), - &AFWorkspaceSettings { - disable_indexing: true, - }, + &AFWorkspaceSettingsChange::new().disable_search_indexing(true), ) .await; assert!(