diff --git a/Cargo.lock b/Cargo.lock index dd026210..811af9d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5474,6 +5474,7 @@ dependencies = [ "actix-web", "anyhow", "app-error", + "appflowy-ai-client", "chrono", "collab-entity", "database-entity", diff --git a/Cargo.toml b/Cargo.toml index c61ff709..2d78ad09 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,7 +73,7 @@ tonic-proto.workspace = true appflowy-collaborate = { path = "services/appflowy-collaborate" } # ai -appflowy-ai-client = { path = "libs/appflowy-ai-client" } +appflowy-ai-client = { workspace = true, features = ["dto", "client-api"] } collab = { workspace = true } collab-document = { workspace = true } @@ -195,6 +195,7 @@ lazy_static = "1.4.0" tonic = "0.11" prost = "0.12" tonic-proto = { path = "libs/tonic-proto" } +appflowy-ai-client = { path = "libs/appflowy-ai-client" } # collaboration yrs = "0.18.7" diff --git a/deploy.env b/deploy.env index 3dc44717..c687b447 100644 --- a/deploy.env +++ b/deploy.env @@ -108,7 +108,6 @@ CLOUDFLARE_TUNNEL_TOKEN= # AppFlowy AI APPFLOWY_AI_OPENAI_API_KEY= -APPFLOWY_AI_SERVER_HOST=ai APPFLOWY_AI_SERVER_PORT=5001 # AppFlowy History diff --git a/dev.env b/dev.env index 6b25d2f2..2aa3cee3 100644 --- a/dev.env +++ b/dev.env @@ -100,7 +100,6 @@ CLOUDFLARE_TUNNEL_TOKEN= # AppFlowy AI APPFLOWY_AI_OPENAI_API_KEY= -APPFLOWY_AI_SERVER_HOST=localhost APPFLOWY_AI_SERVER_PORT=5001 # AppFlowy History diff --git a/libs/appflowy-ai-client/Cargo.toml b/libs/appflowy-ai-client/Cargo.toml index 95bc738a..d1ce896f 100644 --- a/libs/appflowy-ai-client/Cargo.toml +++ b/libs/appflowy-ai-client/Cargo.toml @@ -6,16 +6,17 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -reqwest = { version = "0.12", features = ["json", "rustls-tls", "cookies"] } -serde = { version = "1.0.199", features = ["derive"] } -serde_json = "1.0" +reqwest = { version = "0.12", features = ["json", "rustls-tls", "cookies"], optional = true } +serde = { version = "1.0.199", features = ["derive"], optional = true } +serde_json = { version = "1.0", optional = true } thiserror = "1.0.58" anyhow = "1.0.81" -tracing = "0.1" -serde_repr = "0.1.19" +tracing = { version = "0.1", optional = true } +serde_repr = { version = "0.1.19", optional = true } [dev-dependencies] tokio = { version = "1.37.0", features = ["macros", "test-util"] } [features] -verbose_log = [] \ No newline at end of file +client-api = ["dto", "reqwest", "serde", "serde_json", "tracing", "serde_repr"] +dto = ["serde", "serde_json", "serde_repr"] \ No newline at end of file diff --git a/libs/appflowy-ai-client/src/client.rs b/libs/appflowy-ai-client/src/client.rs index 30ec9429..c0aa7429 100644 --- a/libs/appflowy-ai-client/src/client.rs +++ b/libs/appflowy-ai-client/src/client.rs @@ -1,5 +1,5 @@ use crate::dto::{ - CompletionResponse, CompletionType, Document, SearchDocumentsRequest, SummarizeRowResponse, + CompleteTextResponse, CompletionType, Document, SearchDocumentsRequest, SummarizeRowResponse, TranslateRowResponse, }; use crate::error::AIError; @@ -27,7 +27,7 @@ impl AppFlowyAIClient { &self, text: &str, completion_type: CompletionType, - ) -> Result { + ) -> Result { if text.is_empty() { return Err(AIError::InvalidRequest("Empty text".to_string())); } @@ -43,7 +43,7 @@ impl AppFlowyAIClient { .json(¶ms) .send() .await?; - AIResponse::::from_response(resp) + AIResponse::::from_response(resp) .await? .into_data() } diff --git a/libs/appflowy-ai-client/src/dto.rs b/libs/appflowy-ai-client/src/dto.rs index 5a3ef5fd..37836a43 100644 --- a/libs/appflowy-ai-client/src/dto.rs +++ b/libs/appflowy-ai-client/src/dto.rs @@ -12,7 +12,7 @@ pub struct TranslateRowResponse { } #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct CompletionResponse { +pub struct CompleteTextResponse { pub text: String, } diff --git a/libs/appflowy-ai-client/src/lib.rs b/libs/appflowy-ai-client/src/lib.rs index f3ccaac2..f4f16228 100644 --- a/libs/appflowy-ai-client/src/lib.rs +++ b/libs/appflowy-ai-client/src/lib.rs @@ -1,3 +1,8 @@ +#[cfg(feature = "client-api")] pub mod client; + +#[cfg(feature = "dto")] pub mod dto; + +#[cfg(feature = "client-api")] pub mod error; diff --git a/libs/client-api/src/http.rs b/libs/client-api/src/http.rs index 78587f35..37bed9f4 100644 --- a/libs/client-api/src/http.rs +++ b/libs/client-api/src/http.rs @@ -44,7 +44,9 @@ use url::Url; use crate::ws::ConnectInfo; use gotrue_entity::dto::SignUpResponse::{Authenticated, NotAuthenticated}; use gotrue_entity::dto::{GotrueTokenResponse, UpdateGotrueUserParams, User}; -use shared_entity::dto::ai_dto::{SummarizeRowParams, SummarizeRowResponse}; +use shared_entity::dto::ai_dto::{ + CompleteTextParams, CompleteTextResponse, SummarizeRowParams, SummarizeRowResponse, +}; pub const X_COMPRESSION_TYPE: &str = "X-Compression-Type"; pub const X_COMPRESSION_BUFFER_SIZE: &str = "X-Compression-Buffer-Size"; @@ -1300,7 +1302,7 @@ impl Client { params: SummarizeRowParams, ) -> Result { let url = format!( - "{}/api/workspace/{}/summarize_row", + "{}/api/ai/{}/summarize_row", self.base_url, params.workspace_id ); @@ -1317,6 +1319,25 @@ impl Client { .into_data() } + #[instrument(level = "info", skip_all)] + pub async fn completion_text( + &self, + workspace_id: &str, + params: CompleteTextParams, + ) -> Result { + let url = format!("{}/api/ai/{}/complete_text", self.base_url, workspace_id); + let resp = self + .http_client_with_auth(Method::POST, &url) + .await? + .json(¶ms) + .send() + .await?; + log_request_id(&resp); + AppResponse::::from_response(resp) + .await? + .into_data() + } + #[instrument(level = "debug", skip_all, err)] pub async fn http_client_with_auth( &self, diff --git a/libs/shared-entity/Cargo.toml b/libs/shared-entity/Cargo.toml index 150f01f8..3222f0ae 100644 --- a/libs/shared-entity/Cargo.toml +++ b/libs/shared-entity/Cargo.toml @@ -20,6 +20,7 @@ database-entity.workspace = true collab-entity = { workspace = true } app-error = { workspace = true } chrono = "0.4.31" +appflowy-ai-client = { workspace = true, features = ["dto"] } actix-web = { version = "4.4.1", default-features = false, features = ["http2"], optional = true } validator = { version = "0.16", features = ["validator_derive", "derive"], optional = true } diff --git a/libs/shared-entity/src/dto/ai_dto.rs b/libs/shared-entity/src/dto/ai_dto.rs index 94209d17..64a9950e 100644 --- a/libs/shared-entity/src/dto/ai_dto.rs +++ b/libs/shared-entity/src/dto/ai_dto.rs @@ -1,6 +1,7 @@ use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; +pub use appflowy_ai_client::dto::*; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SummarizeRowParams { pub workspace_id: String, @@ -27,3 +28,9 @@ pub enum SummarizeRowData { pub struct SummarizeRowResponse { pub text: String, } + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct CompleteTextParams { + pub text: String, + pub completion_type: CompletionType, +} diff --git a/src/api/ai_tool.rs b/src/api/ai_tool.rs new file mode 100644 index 00000000..452da634 --- /dev/null +++ b/src/api/ai_tool.rs @@ -0,0 +1,66 @@ +use crate::state::AppState; +use actix_web::web::{Data, Json}; +use actix_web::{web, Scope}; +use app_error::AppError; +use appflowy_ai_client::dto::CompleteTextResponse; +use shared_entity::dto::ai_dto::{ + CompleteTextParams, SummarizeRowData, SummarizeRowParams, SummarizeRowResponse, +}; +use shared_entity::response::{AppResponse, JsonAppResponse}; +use tracing::{error, instrument}; + +pub fn ai_tool_scope() -> Scope { + web::scope("/api/ai/{workspace_id}") + .service(web::resource("/complete_text").route(web::post().to(complete_text_handler))) + .service(web::resource("/summarize_row").route(web::post().to(summarize_row_handler))) +} + +async fn complete_text_handler( + state: Data, + payload: Json, +) -> actix_web::Result> { + let params = payload.into_inner(); + let resp = state + .ai_client + .completion_text(¶ms.text, params.completion_type) + .await + .map_err(|err| AppError::Internal(err.into()))?; + Ok(AppResponse::Ok().with_data(resp).into()) +} + +#[instrument(level = "debug", skip(state, payload), err)] +async fn summarize_row_handler( + state: Data, + payload: Json, +) -> actix_web::Result>> { + let params = payload.into_inner(); + match params.data { + SummarizeRowData::Identity { .. } => { + return Err(AppError::InvalidRequest("Identity data is not supported".to_string()).into()); + }, + SummarizeRowData::Content(content) => { + if content.is_empty() { + return Ok( + AppResponse::Ok() + .with_data(SummarizeRowResponse { + text: "No content".to_string(), + }) + .into(), + ); + } + + let result = state.ai_client.summarize_row(&content).await; + let resp = match result { + Ok(resp) => SummarizeRowResponse { text: resp.text }, + Err(err) => { + error!("Failed to summarize row: {:?}", err); + SummarizeRowResponse { + text: "No content".to_string(), + } + }, + }; + + Ok(AppResponse::Ok().with_data(resp).into()) + }, + } +} diff --git a/src/api/mod.rs b/src/api/mod.rs index bfc4ecc9..03852332 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,3 +1,4 @@ +pub mod ai_tool; pub mod chat; pub mod file_storage; pub mod metrics; diff --git a/src/api/workspace.rs b/src/api/workspace.rs index df216f4b..dd9c46fe 100644 --- a/src/api/workspace.rs +++ b/src/api/workspace.rs @@ -22,7 +22,6 @@ use collab_rt_protocol::validate_encode_collab; use database::collab::CollabStorage; use database::user::select_uid_from_email; use database_entity::dto::*; -use shared_entity::dto::ai_dto::{SummarizeRowData, SummarizeRowParams, SummarizeRowResponse}; use shared_entity::dto::workspace_dto::*; use shared_entity::response::AppResponseError; use shared_entity::response::{AppResponse, JsonAppResponse}; @@ -132,9 +131,6 @@ pub fn workspace_scope() -> Scope { // for GET request .route(web::post().to(batch_get_collab_handler)), ) - .service( - web::resource("/{workspace_id}/summarize_row").route(web::post().to(summary_row_handler)), - ) } pub fn collab_scope() -> Scope { @@ -1000,40 +996,3 @@ async fn parser_realtime_msg( ))), } } - -#[instrument(level = "debug", skip(state, payload), err)] -async fn summary_row_handler( - state: Data, - payload: Json, -) -> Result>> { - let params = payload.into_inner(); - match params.data { - SummarizeRowData::Identity { .. } => { - return Err(AppError::InvalidRequest("Identity data is not supported".to_string()).into()); - }, - SummarizeRowData::Content(content) => { - if content.is_empty() { - return Ok( - AppResponse::Ok() - .with_data(SummarizeRowResponse { - text: "No content".to_string(), - }) - .into(), - ); - } - - let result = state.ai_client.summarize_row(&content).await; - let resp = match result { - Ok(resp) => SummarizeRowResponse { text: resp.text }, - Err(err) => { - error!("Failed to summarize row: {:?}", err); - SummarizeRowResponse { - text: "No content".to_string(), - } - }, - }; - - Ok(AppResponse::Ok().with_data(resp).into()) - }, - } -} diff --git a/src/application.rs b/src/application.rs index c726b4c7..1c7382ae 100644 --- a/src/application.rs +++ b/src/application.rs @@ -7,6 +7,7 @@ use crate::api::ws::ws_scope; use crate::mailer::Mailer; use access_control::access::{enable_access_control, AccessControl}; +use crate::api::ai_tool::ai_tool_scope; use crate::api::chat::chat_scope; use crate::biz::actix_ws::server::RealtimeServerActor; use crate::biz::collab::access_control::{ @@ -139,6 +140,7 @@ pub async fn run_actix_server( .service(ws_scope()) .service(file_storage_scope()) .service(chat_scope()) + .service(ai_tool_scope()) .service(metrics_scope()) .app_data(Data::new(state.metrics.registry.clone())) .app_data(Data::new(state.metrics.request_metrics.clone())) diff --git a/tests/ai_test/complete_text.rs b/tests/ai_test/complete_text.rs new file mode 100644 index 00000000..fbcfd5d3 --- /dev/null +++ b/tests/ai_test/complete_text.rs @@ -0,0 +1,20 @@ +use appflowy_ai_client::dto::CompletionType; +use client_api_test::TestClient; +use shared_entity::dto::ai_dto::CompleteTextParams; + +#[tokio::test] +async fn improve_writing_test() { + let test_client = TestClient::new_user().await; + let workspace_id = test_client.workspace_id().await; + let params = CompleteTextParams { + text: "I feel hungry".to_string(), + completion_type: CompletionType::ImproveWriting, + }; + + let resp = test_client + .api_client + .completion_text(&workspace_id, params) + .await + .unwrap(); + assert!(resp.text.contains("hungry")); +} diff --git a/tests/ai_test/mod.rs b/tests/ai_test/mod.rs index 8a5982f6..b847286d 100644 --- a/tests/ai_test/mod.rs +++ b/tests/ai_test/mod.rs @@ -1 +1,2 @@ +mod complete_text; mod summarize_row;