diff --git a/libs/appflowy-ai-client/src/client.rs b/libs/appflowy-ai-client/src/client.rs index f2f79b89..1a39c6fa 100644 --- a/libs/appflowy-ai-client/src/client.rs +++ b/libs/appflowy-ai-client/src/client.rs @@ -67,6 +67,31 @@ impl AppFlowyAIClient { .into_data() } + pub async fn stream_completion_text( + &self, + text: &str, + completion_type: CompletionType, + model: AIModel, + ) -> Result>, AIError> { + if text.is_empty() { + return Err(AIError::InvalidRequest("Empty text".to_string())); + } + + let params = json!({ + "text": text, + "type": completion_type as u8, + }); + + let url = format!("{}/completion/stream", self.url); + let resp = self + .http_client(Method::POST, &url)? + .header(AI_MODEL_HEADER_KEY, model.to_str()) + .json(&params) + .send() + .await?; + AIResponse::<()>::stream_response(resp).await + } + pub async fn summarize_row( &self, params: &Map, diff --git a/libs/appflowy-ai-client/src/dto.rs b/libs/appflowy-ai-client/src/dto.rs index fb5f0085..28a0521a 100644 --- a/libs/appflowy-ai-client/src/dto.rs +++ b/libs/appflowy-ai-client/src/dto.rs @@ -203,7 +203,7 @@ pub enum AIModel { GPT4o = 2, Claude3Sonnet = 3, Claude3Opus = 4, - Local = 5, + Local = 5, // work in progress } impl AIModel { @@ -212,7 +212,7 @@ impl AIModel { AIModel::DefaultModel => "default-model", AIModel::GPT35 => "gpt-3.5-turbo", AIModel::GPT4o => "gpt-4o", - AIModel::Claude3Sonnet => "claude-3-sonnet-20240229", + AIModel::Claude3Sonnet => "claude-3-sonnet", AIModel::Claude3Opus => "claude-3-opus", AIModel::Local => "local", } diff --git a/libs/appflowy-ai-client/tests/chat_test/completion_test.rs b/libs/appflowy-ai-client/tests/chat_test/completion_test.rs index 826bdc9d..41f8d167 100644 --- a/libs/appflowy-ai-client/tests/chat_test/completion_test.rs +++ b/libs/appflowy-ai-client/tests/chat_test/completion_test.rs @@ -1,6 +1,6 @@ use crate::appflowy_ai_client; use appflowy_ai_client::dto::{AIModel, CompletionType}; - +use futures::stream::StreamExt; #[tokio::test] async fn 
continue_writing_test() { let client = appflowy_ai_client(); @@ -35,8 +35,8 @@ async fn improve_writing_test() { #[tokio::test] async fn make_text_shorter_text() { let client = appflowy_ai_client(); - let resp = client - .completion_text( + let stream = client + .stream_completion_text( "I have an immense passion and deep-seated affection for Rust, a modern, multi-paradigm, high-performance programming language that I find incredibly satisfying to use due to its focus on safety, speed, and concurrency", CompletionType::MakeShorter, AIModel::GPT35 @@ -44,8 +44,19 @@ async fn make_text_shorter_text() { .await .unwrap(); + let stream = stream.map(|item| { + item.map(|bytes| { + String::from_utf8(bytes.to_vec()) + .map(|s| s.replace('\n', "")) + .unwrap() + }) + }); + + let lines: Vec = stream.map(|message| message.unwrap()).collect().await; + let text = lines.join(""); + // the response would be something like: // I'm deeply passionate about Rust, a modern, high-performance programming language, due to its emphasis on safety, speed, and concurrency - assert!(!resp.text.is_empty()); - println!("{}", resp.text); + assert!(!text.is_empty()); + println!("{}", text); } diff --git a/libs/client-api/src/http.rs b/libs/client-api/src/http.rs index 8940d812..81ab9ea8 100644 --- a/libs/client-api/src/http.rs +++ b/libs/client-api/src/http.rs @@ -815,15 +815,27 @@ impl Client { .await?; let access_token = self.access_token()?; - trace!("start request: {}, method: {}", url, method); - let request_builder = self + let headers = [ + ("client-version", self.client_version.to_string()), + ("client-timestamp", ts_now.to_string()), + ("device_id", self.device_id.clone()), + ("ai-model", self.ai_model.read().to_str().to_string()), + ]; + trace!( + "start request: {}, method: {}, headers: {:?}", + url, + method, + headers + ); + + let mut request_builder = self .cloud_client .request(method, url) - .header("client-version", self.client_version.to_string()) - .header("client-timestamp", 
ts_now.to_string()) - .header("device_id", self.device_id.clone()) - .header("ai-model", self.ai_model.read().to_str()) .bearer_auth(access_token); + + for header in headers { + request_builder = request_builder.header(header.0, header.1); + } Ok(request_builder) } diff --git a/libs/client-api/src/http_ai.rs b/libs/client-api/src/http_ai.rs index 44426c42..908c465c 100644 --- a/libs/client-api/src/http_ai.rs +++ b/libs/client-api/src/http_ai.rs @@ -61,7 +61,7 @@ impl Client { workspace_id: &str, params: CompleteTextParams, ) -> Result { - let url = format!("{}/api/ai/{}/complete_text", self.base_url, workspace_id); + let url = format!("{}/api/ai/{}/complete", self.base_url, workspace_id); let resp = self .http_client_with_auth(Method::POST, &url) .await? diff --git a/libs/client-api/src/native/http_native.rs b/libs/client-api/src/native/http_native.rs index 71d87250..18b4ea44 100644 --- a/libs/client-api/src/native/http_native.rs +++ b/libs/client-api/src/native/http_native.rs @@ -32,8 +32,25 @@ use tokio_retry::{Retry, RetryIf}; use tracing::{event, info, instrument, trace}; pub use infra::file_util::ChunkedBytes; +use shared_entity::dto::ai_dto::CompleteTextParams; impl Client { + pub async fn stream_completion_text( + &self, + workspace_id: &str, + params: CompleteTextParams, + ) -> Result>, AppResponseError> { + let url = format!("{}/api/ai/{}/complete/stream", self.base_url, workspace_id); + let resp = self + .http_client_with_auth(Method::POST, &url) + .await? 
+ .json(&params) + .send() + .await?; + log_request_id(&resp); + AppResponse::<()>::answer_response_stream(resp).await + } + pub async fn create_upload( &self, workspace_id: &str, diff --git a/src/api/ai.rs b/src/api/ai.rs index 2182f155..8248dd44 100644 --- a/src/api/ai.rs +++ b/src/api/ai.rs @@ -1,9 +1,10 @@ use crate::api::util::ai_model_from_header; use crate::state::AppState; use actix_web::web::{Data, Json}; -use actix_web::{web, HttpRequest, Scope}; +use actix_web::{web, HttpRequest, HttpResponse, Scope}; use app_error::AppError; use appflowy_ai_client::dto::{CompleteTextResponse, TranslateRowParams, TranslateRowResponse}; +use futures_util::TryStreamExt; use shared_entity::dto::ai_dto::{ CompleteTextParams, SummarizeRowData, SummarizeRowParams, SummarizeRowResponse, }; @@ -13,7 +14,8 @@ use tracing::{error, instrument}; pub fn ai_completion_scope() -> Scope { web::scope("/api/ai/{workspace_id}") - .service(web::resource("/complete_text").route(web::post().to(complete_text_handler))) + .service(web::resource("/complete").route(web::post().to(complete_text_handler))) + .service(web::resource("/complete/stream").route(web::post().to(stream_complete_text_handler))) .service(web::resource("/summarize_row").route(web::post().to(summarize_row_handler))) .service(web::resource("/translate_row").route(web::post().to(translate_row_handler))) } @@ -33,6 +35,25 @@ async fn complete_text_handler( Ok(AppResponse::Ok().with_data(resp).into()) } +async fn stream_complete_text_handler( + state: Data, + payload: Json, + req: HttpRequest, +) -> actix_web::Result { + let ai_model = ai_model_from_header(&req); + let params = payload.into_inner(); + let stream = state + .ai_client + .stream_completion_text(&params.text, params.completion_type, ai_model) + .await + .map_err(|err| AppError::Internal(err.into()))?; + Ok( + HttpResponse::Ok() + .content_type("text/event-stream") + .streaming(stream.map_err(AppError::from)), + ) +} + #[instrument(level = "debug", skip(state, payload), err)] 
async fn summarize_row_handler( state: Data,