From 82409199f8ffa0166f2f5d9403ccd55831890549 Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Sat, 1 Feb 2025 22:47:46 +0800 Subject: [PATCH] chore: remove ai model enum (#1207) --- libs/appflowy-ai-client/src/client.rs | 38 +++++++++--------- libs/appflowy-ai-client/src/dto.rs | 40 ------------------- .../tests/chat_test/completion_test.rs | 8 ++-- .../tests/chat_test/context_test.rs | 4 +- .../tests/chat_test/qa_test.rs | 6 +-- .../tests/row_test/summarize_test.rs | 4 +- .../tests/row_test/translate_test.rs | 7 +--- libs/client-api/src/http.rs | 9 ++--- libs/shared-entity/src/dto/chat_dto.rs | 3 +- src/api/chat.rs | 8 ++-- src/api/util.rs | 10 ++--- src/biz/chat/ops.rs | 12 +++--- 12 files changed, 48 insertions(+), 101 deletions(-) diff --git a/libs/appflowy-ai-client/src/client.rs b/libs/appflowy-ai-client/src/client.rs index 9a08a726..f0bf12f8 100644 --- a/libs/appflowy-ai-client/src/client.rs +++ b/libs/appflowy-ai-client/src/client.rs @@ -1,8 +1,8 @@ use crate::dto::{ - AIModel, CalculateSimilarityParams, ChatAnswer, ChatQuestion, CompleteTextParams, - CreateChatContext, Document, LocalAIConfig, MessageData, ModelList, QuestionMetadata, - RepeatedLocalAIPackage, RepeatedRelatedQuestion, ResponseFormat, SearchDocumentsRequest, - SimilarityResponse, SummarizeRowResponse, TranslateRowData, TranslateRowResponse, + CalculateSimilarityParams, ChatAnswer, ChatQuestion, CompleteTextParams, CreateChatContext, + Document, LocalAIConfig, MessageData, ModelList, QuestionMetadata, RepeatedLocalAIPackage, + RepeatedRelatedQuestion, ResponseFormat, SearchDocumentsRequest, SimilarityResponse, + SummarizeRowResponse, TranslateRowData, TranslateRowResponse, }; use crate::error::AIError; @@ -44,7 +44,7 @@ impl AppFlowyAIClient { pub async fn stream_completion_text( &self, params: CompleteTextParams, - model: AIModel, + model: &str, ) -> Result>, AIError> { if params.text.is_empty() { return Err(AIError::InvalidRequest("Empty 
text".to_string())); @@ -53,7 +53,7 @@ impl AppFlowyAIClient { let url = format!("{}/completion/stream", self.url); let resp = self .async_http_client(Method::POST, &url)? - .header(AI_MODEL_HEADER_KEY, model.to_str()) + .header(AI_MODEL_HEADER_KEY, model) .json(¶ms) .send() .await?; @@ -63,7 +63,7 @@ impl AppFlowyAIClient { pub async fn summarize_row( &self, params: &Map, - model: AIModel, + model: &str, ) -> Result { if params.is_empty() { return Err(AIError::InvalidRequest("Empty content".to_string())); @@ -73,7 +73,7 @@ impl AppFlowyAIClient { trace!("summarize_row url: {}", url); let resp = self .async_http_client(Method::POST, &url)? - .header(AI_MODEL_HEADER_KEY, model.to_str()) + .header(AI_MODEL_HEADER_KEY, model) .json(params) .send() .await?; @@ -85,12 +85,12 @@ impl AppFlowyAIClient { pub async fn translate_row( &self, data: TranslateRowData, - model: AIModel, + model: &str, ) -> Result { let url = format!("{}/translate_row", self.url); let resp = self .async_http_client(Method::POST, &url)? - .header(AI_MODEL_HEADER_KEY, model.to_str()) + .header(AI_MODEL_HEADER_KEY, model) .json(&data) .send() .await?; @@ -131,7 +131,7 @@ impl AppFlowyAIClient { chat_id: &str, question_id: i64, content: &str, - model: &AIModel, + model: &str, metadata: Option, ) -> Result { let json = ChatQuestion { @@ -150,7 +150,7 @@ impl AppFlowyAIClient { let url = format!("{}/chat/message", self.url); let resp = self .async_http_client(Method::POST, &url)? - .header(AI_MODEL_HEADER_KEY, model.to_str()) + .header(AI_MODEL_HEADER_KEY, model) .json(&json) .send() .await?; @@ -166,7 +166,7 @@ impl AppFlowyAIClient { content: &str, metadata: Option, rag_ids: Vec, - model: &AIModel, + model: &str, ) -> Result>, AIError> { let json = ChatQuestion { chat_id: chat_id.to_string(), @@ -184,7 +184,7 @@ impl AppFlowyAIClient { let url = format!("{}/chat/message/stream", self.url); let resp = self .async_http_client(Method::POST, &url)? 
- .header(AI_MODEL_HEADER_KEY, model.to_str()) + .header(AI_MODEL_HEADER_KEY, model) .timeout(Duration::from_secs(30)) .json(&json) .send() @@ -201,7 +201,7 @@ impl AppFlowyAIClient { content: &str, metadata: Option, rag_ids: Vec, - model: &AIModel, + model: &str, ) -> Result>, AIError> { let json = ChatQuestion { chat_id: chat_id.to_string(), @@ -221,14 +221,14 @@ impl AppFlowyAIClient { pub async fn stream_question_v3( &self, - model: &AIModel, + model: &str, question: ChatQuestion, timeout_secs: Option, ) -> Result>, AIError> { let url = format!("{}/v2/chat/message/stream", self.url); let resp = self .async_http_client(Method::POST, &url)? - .header(AI_MODEL_HEADER_KEY, model.to_str()) + .header(AI_MODEL_HEADER_KEY, model) .json(&question) .timeout(Duration::from_secs(timeout_secs.unwrap_or(30))) .send() @@ -240,12 +240,12 @@ impl AppFlowyAIClient { &self, chat_id: &str, message_id: &i64, - model: &AIModel, + model: &str, ) -> Result { let url = format!("{}/chat/{chat_id}/{message_id}/related_question", self.url); let resp = self .async_http_client(Method::GET, &url)? 
- .header(AI_MODEL_HEADER_KEY, model.to_str()) + .header(AI_MODEL_HEADER_KEY, model) .timeout(Duration::from_secs(30)) .send() .await?; diff --git a/libs/appflowy-ai-client/src/dto.rs b/libs/appflowy-ai-client/src/dto.rs index 2ba667de..d4cb6f3f 100644 --- a/libs/appflowy-ai-client/src/dto.rs +++ b/libs/appflowy-ai-client/src/dto.rs @@ -3,8 +3,6 @@ use serde_json::json; use serde_repr::{Deserialize_repr, Serialize_repr}; use std::collections::HashMap; use std::fmt::{Display, Formatter}; -use std::str::FromStr; - pub const STREAM_METADATA_KEY: &str = "0"; pub const STREAM_ANSWER_KEY: &str = "1"; pub const STREAM_IMAGE_KEY: &str = "2"; @@ -340,44 +338,6 @@ impl Display for EmbeddingModel { } } -#[derive(Debug, Clone, Default, Serialize_repr, Deserialize_repr)] -#[repr(u8)] -pub enum AIModel { - #[default] - DefaultModel = 0, - GPT4oMini = 1, - GPT4o = 2, - Claude3Sonnet = 3, - Claude3Opus = 4, -} - -impl AIModel { - pub fn to_str(&self) -> &str { - match self { - AIModel::DefaultModel => "default-model", - AIModel::GPT4oMini => "gpt-4o-mini", - AIModel::GPT4o => "gpt-4o", - AIModel::Claude3Sonnet => "claude-3-sonnet", - AIModel::Claude3Opus => "claude-3-opus", - } - } -} - -impl FromStr for AIModel { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - match s { - "gpt-3.5-turbo" => Ok(AIModel::GPT4oMini), - "gpt-4o-mini" => Ok(AIModel::GPT4oMini), - "gpt-4o" => Ok(AIModel::GPT4o), - "claude-3-sonnet" => Ok(AIModel::Claude3Sonnet), - "claude-3-opus" => Ok(AIModel::Claude3Opus), - _ => Ok(AIModel::DefaultModel), - } - } -} - #[derive(Clone, Debug, Serialize, Deserialize)] pub struct RepeatedLocalAIPackage(pub Vec); diff --git a/libs/appflowy-ai-client/tests/chat_test/completion_test.rs b/libs/appflowy-ai-client/tests/chat_test/completion_test.rs index e3f72686..2efa1965 100644 --- a/libs/appflowy-ai-client/tests/chat_test/completion_test.rs +++ b/libs/appflowy-ai-client/tests/chat_test/completion_test.rs @@ -1,6 +1,6 @@ use crate::appflowy_ai_client; use 
appflowy_ai_client::client::collect_stream_text; -use appflowy_ai_client::dto::{AIModel, CompleteTextParams, CompletionType}; +use appflowy_ai_client::dto::{CompleteTextParams, CompletionType}; #[tokio::test] async fn continue_writing_test() { let client = appflowy_ai_client(); @@ -11,7 +11,7 @@ async fn continue_writing_test() { metadata: None, }; let stream = client - .stream_completion_text(params, AIModel::GPT4oMini) + .stream_completion_text(params, "gpt-4o-mini") .await .unwrap(); let text = collect_stream_text(stream).await; @@ -29,7 +29,7 @@ async fn improve_writing_test() { metadata: None, }; let stream = client - .stream_completion_text(params, AIModel::GPT4oMini) + .stream_completion_text(params, "gpt-4o-mini") .await .unwrap(); @@ -49,7 +49,7 @@ async fn make_text_shorter_text() { metadata: None, }; let stream = client - .stream_completion_text(params, AIModel::GPT4oMini) + .stream_completion_text(params, "gpt-4o-mini") .await .unwrap(); diff --git a/libs/appflowy-ai-client/tests/chat_test/context_test.rs b/libs/appflowy-ai-client/tests/chat_test/context_test.rs index 9938a71a..5b92a082 100644 --- a/libs/appflowy-ai-client/tests/chat_test/context_test.rs +++ b/libs/appflowy-ai-client/tests/chat_test/context_test.rs @@ -1,5 +1,5 @@ use crate::appflowy_ai_client; -use appflowy_ai_client::dto::{AIModel, CreateChatContext}; +use appflowy_ai_client::dto::CreateChatContext; #[tokio::test] async fn create_chat_context_test() { let client = appflowy_ai_client(); @@ -19,7 +19,7 @@ async fn create_chat_context_test() { &chat_id, 1, "Where I live?", - &AIModel::GPT4oMini, + "gpt-4o-mini", None, ) .await diff --git a/libs/appflowy-ai-client/tests/chat_test/qa_test.rs b/libs/appflowy-ai-client/tests/chat_test/qa_test.rs index 38804ae3..693d4440 100644 --- a/libs/appflowy-ai-client/tests/chat_test/qa_test.rs +++ b/libs/appflowy-ai-client/tests/chat_test/qa_test.rs @@ -1,7 +1,5 @@ use crate::appflowy_ai_client; -use appflowy_ai_client::dto::AIModel; - #[tokio::test] 
async fn qa_test() { let client = appflowy_ai_client(); @@ -13,7 +11,7 @@ async fn qa_test() { &chat_id, 1, "I feel hungry", - &AIModel::GPT4o, + "gpt-4o", None, ) .await @@ -21,7 +19,7 @@ async fn qa_test() { assert!(!resp.content.is_empty()); let questions = client - .get_related_question(&chat_id, &1, &AIModel::GPT4oMini) + .get_related_question(&chat_id, &1, "gpt-4o-mini") .await .unwrap() .items; diff --git a/libs/appflowy-ai-client/tests/row_test/summarize_test.rs b/libs/appflowy-ai-client/tests/row_test/summarize_test.rs index 223debe0..72f05294 100644 --- a/libs/appflowy-ai-client/tests/row_test/summarize_test.rs +++ b/libs/appflowy-ai-client/tests/row_test/summarize_test.rs @@ -1,6 +1,4 @@ use crate::appflowy_ai_client; - -use appflowy_ai_client::dto::AIModel; use serde_json::json; #[tokio::test] @@ -9,7 +7,7 @@ async fn summarize_row_test() { let json = json!({"name": "Jack", "age": 25, "city": "New York"}); let result = client - .summarize_row(json.as_object().unwrap(), AIModel::GPT4oMini) + .summarize_row(json.as_object().unwrap(), "gpt-4o-mini") .await .unwrap(); result.text.contains("Jack"); diff --git a/libs/appflowy-ai-client/tests/row_test/translate_test.rs b/libs/appflowy-ai-client/tests/row_test/translate_test.rs index e4951812..75e9c543 100644 --- a/libs/appflowy-ai-client/tests/row_test/translate_test.rs +++ b/libs/appflowy-ai-client/tests/row_test/translate_test.rs @@ -1,6 +1,6 @@ use crate::appflowy_ai_client; -use appflowy_ai_client::dto::{AIModel, TranslateItem, TranslateRowData}; +use appflowy_ai_client::dto::{TranslateItem, TranslateRowData}; #[tokio::test] async fn translate_row_test() { @@ -20,9 +20,6 @@ async fn translate_row_test() { include_header: false, }; - let result = client - .translate_row(data, AIModel::GPT4oMini) - .await - .unwrap(); + let result = client.translate_row(data, "gpt-4o-mini").await.unwrap(); assert_eq!(result.items.len(), 2); } diff --git a/libs/client-api/src/http.rs b/libs/client-api/src/http.rs index 
0a50df35..bfb06adc 100644 --- a/libs/client-api/src/http.rs +++ b/libs/client-api/src/http.rs @@ -45,7 +45,6 @@ use crate::retry::{RefreshTokenAction, RefreshTokenRetryCondition}; use crate::ws::ConnectInfo; use client_api_entity::SignUpResponse::{Authenticated, NotAuthenticated}; use client_api_entity::{GotrueTokenResponse, UpdateGotrueUserParams, User}; -use shared_entity::dto::ai_dto::AIModel; pub const X_COMPRESSION_TYPE: &str = "X-Compression-Type"; pub const X_COMPRESSION_BUFFER_SIZE: &str = "X-Compression-Buffer-Size"; @@ -112,7 +111,7 @@ pub struct Client { pub(crate) is_refreshing_token: Arc, pub(crate) refresh_ret_txs: Arc>>, pub(crate) config: ClientConfiguration, - pub(crate) ai_model: Arc>, + pub(crate) ai_model: Arc>, } pub(crate) type RefreshTokenSender = tokio::sync::oneshot::Sender>; @@ -176,7 +175,7 @@ impl Client { ); } - let ai_model = Arc::new(RwLock::new(AIModel::GPT4oMini)); + let ai_model = Arc::new(RwLock::new("gpt-4o-mini".to_string())); Self { base_url: base_url.to_string(), @@ -205,7 +204,7 @@ impl Client { &self.gotrue_client.base_url } - pub fn set_ai_model(&self, model: AIModel) { + pub fn set_ai_model(&self, model: String) { info!("using ai model: {:?}", model); *self.ai_model.write() = model; } @@ -1087,7 +1086,7 @@ impl Client { ("client-version", self.client_version.to_string()), ("client-timestamp", ts_now.to_string()), ("device-id", self.device_id.clone()), - ("ai-model", self.ai_model.read().to_str().to_string()), + ("ai-model", self.ai_model.read().clone()), ]; trace!( "start request: {}, method: {}, headers: {:?}", diff --git a/libs/shared-entity/src/dto/chat_dto.rs b/libs/shared-entity/src/dto/chat_dto.rs index e0ab8395..a89b836b 100644 --- a/libs/shared-entity/src/dto/chat_dto.rs +++ b/libs/shared-entity/src/dto/chat_dto.rs @@ -1,4 +1,3 @@ -use appflowy_ai_client::dto::AIModel; use chrono::{DateTime, Utc}; use infra::validate::validate_not_empty_str; use serde::{Deserialize, Deserializer, Serialize}; @@ -209,7 +208,7 @@ pub 
struct UpdateChatMessageContentParams { pub message_id: i64, pub content: String, #[serde(default)] - pub model: AIModel, + pub model: String, } #[derive(Debug, Clone, Default, Serialize_repr, Deserialize_repr)] diff --git a/src/api/chat.rs b/src/api/chat.rs index 965e7214..d328d083 100644 --- a/src/api/chat.rs +++ b/src/api/chat.rs @@ -169,7 +169,7 @@ async fn get_related_message_handler( let ai_model = ai_model_from_header(&req); let resp = state .ai_client - .get_related_question(&chat_id, &message_id, &ai_model) + .get_related_question(&chat_id, &message_id, ai_model) .await .map_err(|err| AppError::Internal(err.into()))?; Ok(AppResponse::Ok().with_data(resp).into()) @@ -280,7 +280,7 @@ async fn answer_stream_handler( &content, Some(metadata), rag_ids, - &ai_model, + ai_model, ) .await { @@ -333,7 +333,7 @@ async fn answer_stream_v2_handler( &content, Some(metadata), rag_ids, - &ai_model, + ai_model, ) .await { @@ -393,7 +393,7 @@ async fn answer_stream_v3_handler( trace!("[Chat] stream v3 {:?}", question); match state .ai_client - .stream_question_v3(&ai_model, question, Some(60)) + .stream_question_v3(ai_model, question, Some(60)) .await { Ok(answer_stream) => { diff --git a/src/api/util.rs b/src/api/util.rs index 64bb6ea7..0220bc95 100644 --- a/src/api/util.rs +++ b/src/api/util.rs @@ -4,7 +4,6 @@ use actix_web::web::Payload; use app_error::AppError; use actix_web::HttpRequest; -use appflowy_ai_client::dto::AIModel; use async_trait::async_trait; use byteorder::{ByteOrder, LittleEndian}; use chrono::Utc; @@ -212,15 +211,12 @@ fn copy_buffer(src: &[u8], dest: &mut [u8]) -> usize { } #[inline] -pub(crate) fn ai_model_from_header(req: &HttpRequest) -> AIModel { +pub(crate) fn ai_model_from_header(req: &HttpRequest) -> &str { req .headers() .get("ai-model") - .and_then(|header| { - let header = header.to_str().ok()?; - AIModel::from_str(header).ok() - }) - .unwrap_or(AIModel::GPT4oMini) + .and_then(|header| header.to_str().ok()) + .unwrap_or("Default") } 
#[cfg(test)] diff --git a/src/biz/chat/ops.rs b/src/biz/chat/ops.rs index a56ca6a9..293fdac4 100644 --- a/src/biz/chat/ops.rs +++ b/src/biz/chat/ops.rs @@ -19,7 +19,6 @@ use shared_entity::dto::chat_dto::{ use sqlx::PgPool; use tracing::{error, info, trace}; -use appflowy_ai_client::dto::AIModel; use validator::Validate; pub(crate) async fn create_chat( @@ -46,7 +45,7 @@ pub async fn update_chat_message( pg_pool: &PgPool, params: UpdateChatMessageContentParams, ai_client: AppFlowyAIClient, - ai_model: AIModel, + ai_model: &str, ) -> Result<(), AppError> { let mut txn = pg_pool.begin().await?; delete_answer_message_by_question_message_id(&mut txn, params.message_id).await?; @@ -65,7 +64,7 @@ pub async fn update_chat_message( &params.chat_id, params.message_id, &params.content, - &ai_model, + ai_model, None, ) .await?; @@ -88,7 +87,7 @@ pub async fn generate_chat_message_answer( ai_client: AppFlowyAIClient, question_message_id: i64, chat_id: &str, - ai_model: AIModel, + ai_model: &str, ) -> Result { let (content, metadata) = chat::chat_ops::select_chat_message_content(pg_pool, question_message_id).await?; @@ -98,7 +97,7 @@ pub async fn generate_chat_message_answer( chat_id, question_message_id, &content, - &ai_model, + ai_model, Some(metadata), ) .await @@ -153,8 +152,9 @@ pub async fn create_chat_message_stream( chat_id: String, params: CreateChatMessageParams, ai_client: AppFlowyAIClient, - ai_model: AIModel, + ai_model: &str, ) -> impl Stream> { + let ai_model = ai_model.to_string(); let params = params.clone(); let chat_id = chat_id.clone(); let pg_pool = pg_pool.clone();