diff --git a/libs/appflowy-ai-client/src/client.rs b/libs/appflowy-ai-client/src/client.rs index 8622cc45..a3d5c2eb 100644 --- a/libs/appflowy-ai-client/src/client.rs +++ b/libs/appflowy-ai-client/src/client.rs @@ -1,6 +1,6 @@ use crate::dto::{ AIModel, ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, CreateTextChatContext, - Document, EmbeddingRequest, EmbeddingResponse, LocalAIConfig, MessageData, + CustomPrompt, Document, EmbeddingRequest, EmbeddingResponse, LocalAIConfig, MessageData, RepeatedLocalAIPackage, RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse, TranslateRowData, TranslateRowResponse, }; @@ -49,19 +49,29 @@ impl AppFlowyAIClient { Ok(()) } - pub async fn completion_text( + pub async fn completion_text>>( &self, text: &str, - completion_type: CompletionType, + completion_type: T, + custom_prompt: Option, model: AIModel, ) -> Result { + let completion_type = completion_type.into(); + + if completion_type.is_some() && custom_prompt.is_some() { + return Err(AIError::InvalidRequest( + "Cannot specify both completion_type and custom_prompt".to_string(), + )); + } + if text.is_empty() { return Err(AIError::InvalidRequest("Empty text".to_string())); } let params = json!({ "text": text, - "type": completion_type as u8, + "type": completion_type.map(|t| t as u8), + "custom_prompt": custom_prompt, }); let url = format!("{}/completion", self.url); @@ -76,19 +86,22 @@ impl AppFlowyAIClient { .into_data() } - pub async fn stream_completion_text( + pub async fn stream_completion_text>>( &self, text: &str, - completion_type: CompletionType, + completion_type: T, + custom_prompt: Option, model: AIModel, ) -> Result>, AIError> { + let completion_type = completion_type.into(); if text.is_empty() { return Err(AIError::InvalidRequest("Empty text".to_string())); } let params = json!({ "text": text, - "type": completion_type as u8, + "type": completion_type.map(|t| t as u8), + "custom_prompt": custom_prompt, }); let url = format!("{}/completion/stream", self.url); @@ -201,7 +214,7 @@ impl AppFlowyAIClient { chat_id: &str, content: &str, model: &AIModel, - metadata: Option, + metadata: Option, ) -> Result { let json = ChatQuestion { chat_id: chat_id.to_string(), @@ -226,7 +239,7 @@ impl AppFlowyAIClient { &self, chat_id: &str, content: &str, - metadata: Option, + metadata: Option, model: &AIModel, ) -> Result>, AIError> { let json = ChatQuestion { @@ -251,7 +264,7 @@ impl AppFlowyAIClient { &self, chat_id: &str, content: &str, - metadata: Option, + metadata: Option, model: &AIModel, ) -> Result>, AIError> { let json = ChatQuestion { @@ -337,7 +350,7 @@ pub struct AIResponse { impl AIResponse where - T: serde::de::DeserializeOwned + 'static, + T: DeserializeOwned + 'static, { pub async fn from_response(resp: reqwest::Response) -> Result { let status_code = resp.status(); diff --git a/libs/appflowy-ai-client/src/dto.rs b/libs/appflowy-ai-client/src/dto.rs index 408bbcbe..b7b78dfa 100644 --- a/libs/appflowy-ai-client/src/dto.rs +++ b/libs/appflowy-ai-client/src/dto.rs @@ -205,7 +205,7 @@ impl Display for EmbeddingsModel { pub enum AIModel { #[default] DefaultModel = 0, - GPT35 = 1, + GPT4oMini = 1, GPT4o = 2, Claude3Sonnet = 3, Claude3Opus = 4, @@ -215,7 +215,7 @@ impl AIModel { pub fn to_str(&self) -> &str { match self { AIModel::DefaultModel => "default-model", - AIModel::GPT35 => "gpt-3.5-turbo", + AIModel::GPT4oMini => "gpt-4o-mini", AIModel::GPT4o => "gpt-4o", AIModel::Claude3Sonnet => "claude-3-sonnet", AIModel::Claude3Opus => "claude-3-opus", @@ -228,7 +228,8 @@ impl FromStr for AIModel { fn from_str(s: &str) -> Result { match s { - "gpt-3.5-turbo" => Ok(AIModel::GPT35), + "gpt-3.5-turbo" => Ok(AIModel::GPT4oMini), + "gpt-4o-mini" => Ok(AIModel::GPT4oMini), "gpt-4o" => Ok(AIModel::GPT4o), "claude-3-sonnet" => Ok(AIModel::Claude3Sonnet), "claude-3-opus" => Ok(AIModel::Claude3Opus), @@ -364,3 +365,9 @@ impl Display for CreateTextChatContext { )) } } + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct CustomPrompt { + pub system: String, + pub user: Option, +} diff --git a/libs/appflowy-ai-client/tests/chat_test/completion_test.rs b/libs/appflowy-ai-client/tests/chat_test/completion_test.rs index 41f8d167..9a6d36a1 100644 --- a/libs/appflowy-ai-client/tests/chat_test/completion_test.rs +++ b/libs/appflowy-ai-client/tests/chat_test/completion_test.rs @@ -8,6 +8,7 @@ async fn continue_writing_test() { .completion_text( "I feel hungry", CompletionType::ContinueWriting, + None, AIModel::Claude3Sonnet, ) .await @@ -23,7 +24,8 @@ async fn improve_writing_test() { .completion_text( "I fell tired because i sleep not very well last night", CompletionType::ImproveWriting, - AIModel::GPT35, + None, + AIModel::GPT4oMini, ) .await .unwrap(); @@ -39,7 +41,8 @@ async fn make_text_shorter_text() { .stream_completion_text( "I have an immense passion and deep-seated affection for Rust, a modern, multi-paradigm, high-performance programming language that I find incredibly satisfying to use due to its focus on safety, speed, and concurrency", CompletionType::MakeShorter, - AIModel::GPT35 + None, + AIModel::GPT4oMini ) .await .unwrap(); diff --git a/libs/appflowy-ai-client/tests/chat_test/context_test.rs b/libs/appflowy-ai-client/tests/chat_test/context_test.rs index e82cda1c..61ba0717 100644 --- a/libs/appflowy-ai-client/tests/chat_test/context_test.rs +++ b/libs/appflowy-ai-client/tests/chat_test/context_test.rs @@ -14,7 +14,7 @@ async fn create_chat_context_test() { }; client.create_chat_text_context(context).await.unwrap(); let resp = client - .send_question(&chat_id, "Where I live?", &AIModel::GPT35, None) + .send_question(&chat_id, "Where I live?", &AIModel::GPT4oMini, None) .await .unwrap(); // response will be something like: diff --git a/libs/appflowy-ai-client/tests/chat_test/qa_test.rs b/libs/appflowy-ai-client/tests/chat_test/qa_test.rs index 54a400e0..a01c2120 100644 --- a/libs/appflowy-ai-client/tests/chat_test/qa_test.rs +++ b/libs/appflowy-ai-client/tests/chat_test/qa_test.rs @@ -10,13 +10,13 @@ async fn qa_test() { client.health_check().await.unwrap(); let chat_id = uuid::Uuid::new_v4().to_string(); let resp = client - .send_question(&chat_id, "I feel hungry", &AIModel::GPT35, None) + .send_question(&chat_id, "I feel hungry", &AIModel::GPT4o, None) .await .unwrap(); assert!(!resp.content.is_empty()); let questions = client - .get_related_question(&chat_id, &1, &AIModel::GPT35) + .get_related_question(&chat_id, &1, &AIModel::GPT4oMini) .await .unwrap() .items; @@ -29,7 +29,7 @@ async fn stop_stream_test() { client.health_check().await.unwrap(); let chat_id = uuid::Uuid::new_v4().to_string(); let mut stream = client - .stream_question(&chat_id, "I feel hungry", None, &AIModel::GPT35) + .stream_question(&chat_id, "I feel hungry", None, &AIModel::GPT4oMini) .await .unwrap(); @@ -51,7 +51,7 @@ async fn stream_test() { client.health_check().await.unwrap(); let chat_id = uuid::Uuid::new_v4().to_string(); let stream = client - .stream_question_v2(&chat_id, "I feel hungry", None, &AIModel::GPT35) + .stream_question_v2(&chat_id, "I feel hungry", None, &AIModel::GPT4oMini) .await .unwrap(); let json_stream = JsonStream::::new(stream); diff --git a/libs/appflowy-ai-client/tests/row_test/summarize_test.rs b/libs/appflowy-ai-client/tests/row_test/summarize_test.rs index 7945d93e..223debe0 100644 --- a/libs/appflowy-ai-client/tests/row_test/summarize_test.rs +++ b/libs/appflowy-ai-client/tests/row_test/summarize_test.rs @@ -9,7 +9,7 @@ async fn summarize_row_test() { let json = json!({"name": "Jack", "age": 25, "city": "New York"}); let result = client - .summarize_row(json.as_object().unwrap(), AIModel::GPT35) + .summarize_row(json.as_object().unwrap(), AIModel::GPT4oMini) .await .unwrap(); result.text.contains("Jack"); diff --git a/libs/appflowy-ai-client/tests/row_test/translate_test.rs b/libs/appflowy-ai-client/tests/row_test/translate_test.rs index 2045fc53..e4951812 100644 --- a/libs/appflowy-ai-client/tests/row_test/translate_test.rs +++ b/libs/appflowy-ai-client/tests/row_test/translate_test.rs @@ -20,6 +20,9 @@ async fn translate_row_test() { include_header: false, }; - let result = client.translate_row(data, AIModel::GPT35).await.unwrap(); + let result = client + .translate_row(data, AIModel::GPT4oMini) + .await + .unwrap(); assert_eq!(result.items.len(), 2); } diff --git a/libs/client-api/src/http.rs b/libs/client-api/src/http.rs index c8c30bbb..01c720a5 100644 --- a/libs/client-api/src/http.rs +++ b/libs/client-api/src/http.rs @@ -172,7 +172,7 @@ impl Client { ); } - let ai_model = Arc::new(RwLock::new(AIModel::GPT35)); + let ai_model = Arc::new(RwLock::new(AIModel::GPT4oMini)); Self { base_url: base_url.to_string(), diff --git a/libs/shared-entity/src/dto/ai_dto.rs b/libs/shared-entity/src/dto/ai_dto.rs index a889a1e4..dcb6253c 100644 --- a/libs/shared-entity/src/dto/ai_dto.rs +++ b/libs/shared-entity/src/dto/ai_dto.rs @@ -34,7 +34,18 @@ pub struct SummarizeRowResponse { #[derive(Clone, Debug, Serialize, Deserialize)] pub struct CompleteTextParams { pub text: String, - pub completion_type: CompletionType, + pub completion_type: Option, + pub custom_prompt: Option, +} + +impl CompleteTextParams { + pub fn new_with_completion_type(text: String, completion_type: CompletionType) -> Self { + Self { + text, + completion_type: Some(completion_type), + custom_prompt: None, + } + } } #[derive(Debug)] diff --git a/src/api/ai.rs b/src/api/ai.rs index 99fe8d4a..109fc59b 100644 --- a/src/api/ai.rs +++ b/src/api/ai.rs @@ -36,7 +36,7 @@ async fn complete_text_handler( let params = payload.into_inner(); let resp = state .ai_client - .completion_text(¶ms.text, params.completion_type, ai_model) + .completion_text(¶ms.text, params.completion_type, None, ai_model) .await .map_err(|err| AppError::Internal(err.into()))?; Ok(AppResponse::Ok().with_data(resp).into()) @@ -51,7 +51,12 @@ async fn stream_complete_text_handler( let params = payload.into_inner(); match state .ai_client - .stream_completion_text(¶ms.text, params.completion_type, ai_model) + .stream_completion_text( + ¶ms.text, + params.completion_type, + params.custom_prompt, + ai_model, + ) .await { Ok(stream) => Ok( diff --git a/src/api/util.rs b/src/api/util.rs index 73df5225..7b57ea62 100644 --- a/src/api/util.rs +++ b/src/api/util.rs @@ -179,5 +179,5 @@ pub(crate) fn ai_model_from_header(req: &HttpRequest) -> AIModel { let header = header.to_str().ok()?; AIModel::from_str(header).ok() }) - .unwrap_or(AIModel::GPT35) + .unwrap_or(AIModel::GPT4oMini) } diff --git a/tests/ai_test/complete_text.rs b/tests/ai_test/complete_text.rs index 8930f58c..a055c421 100644 --- a/tests/ai_test/complete_text.rs +++ b/tests/ai_test/complete_text.rs @@ -8,13 +8,13 @@ async fn improve_writing_test() { return; } let test_client = TestClient::new_user().await; - test_client.api_client.set_ai_model(AIModel::GPT4o); + test_client.api_client.set_ai_model(AIModel::GPT4oMini); let workspace_id = test_client.workspace_id().await; - let params = CompleteTextParams { - text: "I feel hungry".to_string(), - completion_type: CompletionType::ImproveWriting, - }; + let params = CompleteTextParams::new_with_completion_type( + "I feel hungry".to_string(), + CompletionType::ImproveWriting, + ); let resp = test_client .api_client