diff --git a/libs/appflowy-ai-client/src/client.rs b/libs/appflowy-ai-client/src/client.rs index ceb98c98..78df90ef 100644 --- a/libs/appflowy-ai-client/src/client.rs +++ b/libs/appflowy-ai-client/src/client.rs @@ -1,9 +1,8 @@ use crate::dto::{ - AIModel, CalculateSimilarityParams, ChatAnswer, ChatQuestion, CompleteTextResponse, - CompletionType, CreateChatContext, CustomPrompt, Document, LocalAIConfig, MessageData, - QuestionMetadata, RepeatedLocalAIPackage, RepeatedRelatedQuestion, ResponseFormat, - SearchDocumentsRequest, SimilarityResponse, SummarizeRowResponse, TranslateRowData, - TranslateRowResponse, + AIModel, CalculateSimilarityParams, ChatAnswer, ChatQuestion, CompletionType, CreateChatContext, + CustomPrompt, Document, LocalAIConfig, MessageData, QuestionMetadata, RepeatedLocalAIPackage, + RepeatedRelatedQuestion, ResponseFormat, SearchDocumentsRequest, SimilarityResponse, + SummarizeRowResponse, TranslateRowData, TranslateRowResponse, }; use crate::error::AIError; @@ -42,43 +41,6 @@ impl AppFlowyAIClient { Ok(()) } - pub async fn completion_text>>( - &self, - text: &str, - completion_type: T, - custom_prompt: Option, - model: AIModel, - ) -> Result { - let completion_type = completion_type.into(); - - if completion_type.is_some() && custom_prompt.is_some() { - return Err(AIError::InvalidRequest( - "Cannot specify both completion_type and custom_prompt".to_string(), - )); - } - - if text.is_empty() { - return Err(AIError::InvalidRequest("Empty text".to_string())); - } - - let params = json!({ - "text": text, - "type": completion_type.map(|t| t as u8), - "custom_prompt": custom_prompt, - }); - - let url = format!("{}/completion", self.url); - let resp = self - .async_http_client(Method::POST, &url)? - .header(AI_MODEL_HEADER_KEY, model.to_str()) - .json(¶ms) - .send() - .await?; - AIResponse::::from_reqwest_response(resp) - .await? - .into_data() - } - pub async fn stream_completion_text>>( &self, text: &str, @@ -443,3 +405,16 @@ impl From for AIError { AIError::Internal(error.into()) } } + +pub async fn collect_stream_text(stream: impl Stream>) -> String { + let stream = stream.map(|item| { + item.map(|bytes| { + String::from_utf8(bytes.to_vec()) + .map(|s| s.replace('\n', "")) + .unwrap() + }) + }); + + let lines: Vec = stream.map(|message| message.unwrap()).collect().await; + lines.join("") +} diff --git a/libs/appflowy-ai-client/tests/chat_test/completion_test.rs b/libs/appflowy-ai-client/tests/chat_test/completion_test.rs index 9a6d36a1..38532a12 100644 --- a/libs/appflowy-ai-client/tests/chat_test/completion_test.rs +++ b/libs/appflowy-ai-client/tests/chat_test/completion_test.rs @@ -1,11 +1,11 @@ use crate::appflowy_ai_client; +use appflowy_ai_client::client::collect_stream_text; use appflowy_ai_client::dto::{AIModel, CompletionType}; -use futures::stream::StreamExt; #[tokio::test] async fn continue_writing_test() { let client = appflowy_ai_client(); - let resp = client - .completion_text( + let stream = client + .stream_completion_text( "I feel hungry", CompletionType::ContinueWriting, None, @@ -13,15 +13,16 @@ async fn continue_writing_test() { ) .await .unwrap(); - assert!(!resp.text.is_empty()); - println!("{}", resp.text); + let text = collect_stream_text(stream).await; + assert!(!text.is_empty()); + println!("{}", text); } #[tokio::test] async fn improve_writing_test() { let client = appflowy_ai_client(); - let resp = client - .completion_text( + let stream = client + .stream_completion_text( "I fell tired because i sleep not very well last night", CompletionType::ImproveWriting, None, @@ -30,9 +31,11 @@ async fn improve_writing_test() { .await .unwrap(); + let text = collect_stream_text(stream).await; + // the response would be something like: I feel exhausted due to a restless night of sleep. - assert!(!resp.text.is_empty()); - println!("{}", resp.text); + assert!(!text.is_empty()); + println!("{}", text); } #[tokio::test] async fn make_text_shorter_text() { @@ -47,16 +50,7 @@ async fn make_text_shorter_text() { .await .unwrap(); - let stream = stream.map(|item| { - item.map(|bytes| { - String::from_utf8(bytes.to_vec()) - .map(|s| s.replace('\n', "")) - .unwrap() - }) - }); - - let lines: Vec = stream.map(|message| message.unwrap()).collect().await; - let text = lines.join(""); + let text = collect_stream_text(stream).await; // the response would be something like: // I'm deeply passionate about Rust, a modern, high-performance programming language, due to its emphasis on safety, speed, and concurrency diff --git a/libs/client-api/src/http_ai.rs b/libs/client-api/src/http_ai.rs index a5c314d0..6910a74a 100644 --- a/libs/client-api/src/http_ai.rs +++ b/libs/client-api/src/http_ai.rs @@ -4,8 +4,8 @@ use bytes::Bytes; use futures_core::Stream; use reqwest::Method; use shared_entity::dto::ai_dto::{ - CompleteTextParams, CompleteTextResponse, LocalAIConfig, SummarizeRowParams, - SummarizeRowResponse, TranslateRowParams, TranslateRowResponse, + CompleteTextParams, LocalAIConfig, SummarizeRowParams, SummarizeRowResponse, TranslateRowParams, + TranslateRowResponse, }; use shared_entity::response::{AppResponse, AppResponseError}; use std::time::Duration; @@ -75,26 +75,6 @@ impl Client { .into_data() } - #[instrument(level = "info", skip_all)] - pub async fn completion_text( - &self, - workspace_id: &str, - params: CompleteTextParams, - ) -> Result { - let url = format!("{}/api/ai/{}/complete", self.base_url, workspace_id); - let resp = self - .http_client_with_auth(Method::POST, &url) - .await? - .json(¶ms) - .timeout(Duration::from_secs(30)) - .send() - .await?; - log_request_id(&resp); - AppResponse::::from_response(resp) - .await? - .into_data() - } - #[instrument(level = "info", skip_all, err)] pub async fn get_local_ai_config( &self, diff --git a/src/api/ai.rs b/src/api/ai.rs index 60ba4cdb..949eb3da 100644 --- a/src/api/ai.rs +++ b/src/api/ai.rs @@ -5,8 +5,8 @@ use actix_web::web::{Data, Json}; use actix_web::{web, HttpRequest, HttpResponse, Scope}; use app_error::AppError; use appflowy_ai_client::dto::{ - CalculateSimilarityParams, CompleteTextResponse, LocalAIConfig, SimilarityResponse, - TranslateRowParams, TranslateRowResponse, + CalculateSimilarityParams, LocalAIConfig, SimilarityResponse, TranslateRowParams, + TranslateRowResponse, }; use futures_util::{stream, TryStreamExt}; @@ -15,13 +15,12 @@ use serde::Deserialize; use shared_entity::dto::ai_dto::{ CompleteTextParams, SummarizeRowData, SummarizeRowParams, SummarizeRowResponse, }; -use shared_entity::response::{AppResponse, JsonAppResponse}; +use shared_entity::response::AppResponse; use tracing::{error, instrument, trace}; pub fn ai_completion_scope() -> Scope { web::scope("/api/ai/{workspace_id}") - .service(web::resource("/complete").route(web::post().to(complete_text_handler))) .service(web::resource("/complete/stream").route(web::post().to(stream_complete_text_handler))) .service(web::resource("/summarize_row").route(web::post().to(summarize_row_handler))) .service(web::resource("/translate_row").route(web::post().to(translate_row_handler))) @@ -31,21 +30,6 @@ pub fn ai_completion_scope() -> Scope { ) } -async fn complete_text_handler( - state: Data, - payload: Json, - req: HttpRequest, -) -> actix_web::Result> { - let ai_model = ai_model_from_header(&req); - let params = payload.into_inner(); - let resp = state - .ai_client - .completion_text(¶ms.text, params.completion_type, None, ai_model) - .await - .map_err(|err| AppError::Internal(err.into()))?; - Ok(AppResponse::Ok().with_data(resp).into()) -} - async fn stream_complete_text_handler( state: Data, payload: Json, diff --git a/tests/ai_test/complete_text.rs b/tests/ai_test/complete_text.rs deleted file mode 100644 index f3684c5e..00000000 --- a/tests/ai_test/complete_text.rs +++ /dev/null @@ -1,25 +0,0 @@ -use appflowy_ai_client::dto::{AIModel, CompletionType}; -use client_api_test::{ai_test_enabled, TestClient}; -use shared_entity::dto::ai_dto::CompleteTextParams; - -#[tokio::test] -async fn improve_writing_test() { - if !ai_test_enabled() { - return; - } - let test_client = TestClient::new_user().await; - test_client.api_client.set_ai_model(AIModel::GPT4oMini); - - let workspace_id = test_client.workspace_id().await; - let params = CompleteTextParams::new_with_completion_type( - "I feel hungry".to_string(), - CompletionType::ImproveWriting, - ); - - let resp = test_client - .api_client - .completion_text(&workspace_id, params) - .await - .unwrap(); - assert!(!resp.text.is_empty()); -} diff --git a/tests/ai_test/mod.rs b/tests/ai_test/mod.rs index bbcc4d7d..08e391a3 100644 --- a/tests/ai_test/mod.rs +++ b/tests/ai_test/mod.rs @@ -1,5 +1,4 @@ mod chat_test; -mod complete_text; // mod local_ai_test; mod summarize_row; mod util;