chore: remove ai model enum (#1207)

This commit is contained in:
Nathan.fooo 2025-02-01 22:47:46 +08:00 committed by GitHub
parent 18b1386bc2
commit 82409199f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 48 additions and 101 deletions

View File

@ -1,8 +1,8 @@
use crate::dto::{
AIModel, CalculateSimilarityParams, ChatAnswer, ChatQuestion, CompleteTextParams,
CreateChatContext, Document, LocalAIConfig, MessageData, ModelList, QuestionMetadata,
RepeatedLocalAIPackage, RepeatedRelatedQuestion, ResponseFormat, SearchDocumentsRequest,
SimilarityResponse, SummarizeRowResponse, TranslateRowData, TranslateRowResponse,
CalculateSimilarityParams, ChatAnswer, ChatQuestion, CompleteTextParams, CreateChatContext,
Document, LocalAIConfig, MessageData, ModelList, QuestionMetadata, RepeatedLocalAIPackage,
RepeatedRelatedQuestion, ResponseFormat, SearchDocumentsRequest, SimilarityResponse,
SummarizeRowResponse, TranslateRowData, TranslateRowResponse,
};
use crate::error::AIError;
@ -44,7 +44,7 @@ impl AppFlowyAIClient {
pub async fn stream_completion_text(
&self,
params: CompleteTextParams,
model: AIModel,
model: &str,
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
if params.text.is_empty() {
return Err(AIError::InvalidRequest("Empty text".to_string()));
@ -53,7 +53,7 @@ impl AppFlowyAIClient {
let url = format!("{}/completion/stream", self.url);
let resp = self
.async_http_client(Method::POST, &url)?
.header(AI_MODEL_HEADER_KEY, model.to_str())
.header(AI_MODEL_HEADER_KEY, model)
.json(&params)
.send()
.await?;
@ -63,7 +63,7 @@ impl AppFlowyAIClient {
pub async fn summarize_row(
&self,
params: &Map<String, Value>,
model: AIModel,
model: &str,
) -> Result<SummarizeRowResponse, AIError> {
if params.is_empty() {
return Err(AIError::InvalidRequest("Empty content".to_string()));
@ -73,7 +73,7 @@ impl AppFlowyAIClient {
trace!("summarize_row url: {}", url);
let resp = self
.async_http_client(Method::POST, &url)?
.header(AI_MODEL_HEADER_KEY, model.to_str())
.header(AI_MODEL_HEADER_KEY, model)
.json(params)
.send()
.await?;
@ -85,12 +85,12 @@ impl AppFlowyAIClient {
pub async fn translate_row(
&self,
data: TranslateRowData,
model: AIModel,
model: &str,
) -> Result<TranslateRowResponse, AIError> {
let url = format!("{}/translate_row", self.url);
let resp = self
.async_http_client(Method::POST, &url)?
.header(AI_MODEL_HEADER_KEY, model.to_str())
.header(AI_MODEL_HEADER_KEY, model)
.json(&data)
.send()
.await?;
@ -131,7 +131,7 @@ impl AppFlowyAIClient {
chat_id: &str,
question_id: i64,
content: &str,
model: &AIModel,
model: &str,
metadata: Option<Value>,
) -> Result<ChatAnswer, AIError> {
let json = ChatQuestion {
@ -150,7 +150,7 @@ impl AppFlowyAIClient {
let url = format!("{}/chat/message", self.url);
let resp = self
.async_http_client(Method::POST, &url)?
.header(AI_MODEL_HEADER_KEY, model.to_str())
.header(AI_MODEL_HEADER_KEY, model)
.json(&json)
.send()
.await?;
@ -166,7 +166,7 @@ impl AppFlowyAIClient {
content: &str,
metadata: Option<Value>,
rag_ids: Vec<String>,
model: &AIModel,
model: &str,
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
let json = ChatQuestion {
chat_id: chat_id.to_string(),
@ -184,7 +184,7 @@ impl AppFlowyAIClient {
let url = format!("{}/chat/message/stream", self.url);
let resp = self
.async_http_client(Method::POST, &url)?
.header(AI_MODEL_HEADER_KEY, model.to_str())
.header(AI_MODEL_HEADER_KEY, model)
.timeout(Duration::from_secs(30))
.json(&json)
.send()
@ -201,7 +201,7 @@ impl AppFlowyAIClient {
content: &str,
metadata: Option<Value>,
rag_ids: Vec<String>,
model: &AIModel,
model: &str,
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
let json = ChatQuestion {
chat_id: chat_id.to_string(),
@ -221,14 +221,14 @@ impl AppFlowyAIClient {
pub async fn stream_question_v3(
&self,
model: &AIModel,
model: &str,
question: ChatQuestion,
timeout_secs: Option<u64>,
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
let url = format!("{}/v2/chat/message/stream", self.url);
let resp = self
.async_http_client(Method::POST, &url)?
.header(AI_MODEL_HEADER_KEY, model.to_str())
.header(AI_MODEL_HEADER_KEY, model)
.json(&question)
.timeout(Duration::from_secs(timeout_secs.unwrap_or(30)))
.send()
@ -240,12 +240,12 @@ impl AppFlowyAIClient {
&self,
chat_id: &str,
message_id: &i64,
model: &AIModel,
model: &str,
) -> Result<RepeatedRelatedQuestion, AIError> {
let url = format!("{}/chat/{chat_id}/{message_id}/related_question", self.url);
let resp = self
.async_http_client(Method::GET, &url)?
.header(AI_MODEL_HEADER_KEY, model.to_str())
.header(AI_MODEL_HEADER_KEY, model)
.timeout(Duration::from_secs(30))
.send()
.await?;

View File

@ -3,8 +3,6 @@ use serde_json::json;
use serde_repr::{Deserialize_repr, Serialize_repr};
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::str::FromStr;
pub const STREAM_METADATA_KEY: &str = "0";
pub const STREAM_ANSWER_KEY: &str = "1";
pub const STREAM_IMAGE_KEY: &str = "2";
@ -340,44 +338,6 @@ impl Display for EmbeddingModel {
}
}
/// Enumerates the AI models known to this crate.
///
/// The type is `#[repr(u8)]` and derives `Serialize_repr` / `Deserialize_repr`,
/// so each variant is (de)serialized as its explicit numeric discriminant
/// rather than by name. `DefaultModel` is the `Default` value.
#[derive(Debug, Clone, Default, Serialize_repr, Deserialize_repr)]
#[repr(u8)]
pub enum AIModel {
/// No specific model chosen; this is what `AIModel::default()` returns.
#[default]
DefaultModel = 0,
/// Discriminant 1.
GPT4oMini = 1,
/// Discriminant 2.
GPT4o = 2,
/// Discriminant 3.
Claude3Sonnet = 3,
/// Discriminant 4.
Claude3Opus = 4,
}
impl AIModel {
/// Returns the fixed string identifier associated with this model.
///
/// The mapping is exhaustive over the enum's variants; `DefaultModel`
/// yields `"default-model"`.
pub fn to_str(&self) -> &str {
match self {
AIModel::DefaultModel => "default-model",
AIModel::GPT4oMini => "gpt-4o-mini",
AIModel::GPT4o => "gpt-4o",
AIModel::Claude3Sonnet => "claude-3-sonnet",
AIModel::Claude3Opus => "claude-3-opus",
}
}
}
/// Parses a model identifier string into an `AIModel`.
///
/// Matching is exact (case-sensitive) against a fixed set of literals.
/// Parsing never fails: every arm returns `Ok`, and any unrecognized string
/// falls through to `AIModel::DefaultModel`, so the declared `Err` type is
/// never actually produced.
impl FromStr for AIModel {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
// "gpt-3.5-turbo" is accepted and mapped to the same variant as "gpt-4o-mini".
"gpt-3.5-turbo" => Ok(AIModel::GPT4oMini),
"gpt-4o-mini" => Ok(AIModel::GPT4oMini),
"gpt-4o" => Ok(AIModel::GPT4o),
"claude-3-sonnet" => Ok(AIModel::Claude3Sonnet),
"claude-3-opus" => Ok(AIModel::Claude3Opus),
// Catch-all: unknown identifiers (including "default-model") become DefaultModel.
_ => Ok(AIModel::DefaultModel),
}
}
}
/// Newtype wrapper around a `Vec<AppFlowyOfflineAI>`; the inner vector is
/// public and the wrapper derives serde `Serialize` / `Deserialize`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RepeatedLocalAIPackage(pub Vec<AppFlowyOfflineAI>);

View File

@ -1,6 +1,6 @@
use crate::appflowy_ai_client;
use appflowy_ai_client::client::collect_stream_text;
use appflowy_ai_client::dto::{AIModel, CompleteTextParams, CompletionType};
use appflowy_ai_client::dto::{CompleteTextParams, CompletionType};
#[tokio::test]
async fn continue_writing_test() {
let client = appflowy_ai_client();
@ -11,7 +11,7 @@ async fn continue_writing_test() {
metadata: None,
};
let stream = client
.stream_completion_text(params, AIModel::GPT4oMini)
.stream_completion_text(params, "gpt-4o-mini")
.await
.unwrap();
let text = collect_stream_text(stream).await;
@ -29,7 +29,7 @@ async fn improve_writing_test() {
metadata: None,
};
let stream = client
.stream_completion_text(params, AIModel::GPT4oMini)
.stream_completion_text(params, "gpt-4o-mini")
.await
.unwrap();
@ -49,7 +49,7 @@ async fn make_text_shorter_text() {
metadata: None,
};
let stream = client
.stream_completion_text(params, AIModel::GPT4oMini)
.stream_completion_text(params, "gpt-4o-mini")
.await
.unwrap();

View File

@ -1,5 +1,5 @@
use crate::appflowy_ai_client;
use appflowy_ai_client::dto::{AIModel, CreateChatContext};
use appflowy_ai_client::dto::CreateChatContext;
#[tokio::test]
async fn create_chat_context_test() {
let client = appflowy_ai_client();
@ -19,7 +19,7 @@ async fn create_chat_context_test() {
&chat_id,
1,
"Where I live?",
&AIModel::GPT4oMini,
"gpt-4o-mini",
None,
)
.await

View File

@ -1,7 +1,5 @@
use crate::appflowy_ai_client;
use appflowy_ai_client::dto::AIModel;
#[tokio::test]
async fn qa_test() {
let client = appflowy_ai_client();
@ -13,7 +11,7 @@ async fn qa_test() {
&chat_id,
1,
"I feel hungry",
&AIModel::GPT4o,
"gpt-4o",
None,
)
.await
@ -21,7 +19,7 @@ async fn qa_test() {
assert!(!resp.content.is_empty());
let questions = client
.get_related_question(&chat_id, &1, &AIModel::GPT4oMini)
.get_related_question(&chat_id, &1, "gpt-4o-mini")
.await
.unwrap()
.items;

View File

@ -1,6 +1,4 @@
use crate::appflowy_ai_client;
use appflowy_ai_client::dto::AIModel;
use serde_json::json;
#[tokio::test]
@ -9,7 +7,7 @@ async fn summarize_row_test() {
let json = json!({"name": "Jack", "age": 25, "city": "New York"});
let result = client
.summarize_row(json.as_object().unwrap(), AIModel::GPT4oMini)
.summarize_row(json.as_object().unwrap(), "gpt-4o-mini")
.await
.unwrap();
result.text.contains("Jack");

View File

@ -1,6 +1,6 @@
use crate::appflowy_ai_client;
use appflowy_ai_client::dto::{AIModel, TranslateItem, TranslateRowData};
use appflowy_ai_client::dto::{TranslateItem, TranslateRowData};
#[tokio::test]
async fn translate_row_test() {
@ -20,9 +20,6 @@ async fn translate_row_test() {
include_header: false,
};
let result = client
.translate_row(data, AIModel::GPT4oMini)
.await
.unwrap();
let result = client.translate_row(data, "gpt-4o-mini").await.unwrap();
assert_eq!(result.items.len(), 2);
}

View File

@ -45,7 +45,6 @@ use crate::retry::{RefreshTokenAction, RefreshTokenRetryCondition};
use crate::ws::ConnectInfo;
use client_api_entity::SignUpResponse::{Authenticated, NotAuthenticated};
use client_api_entity::{GotrueTokenResponse, UpdateGotrueUserParams, User};
use shared_entity::dto::ai_dto::AIModel;
pub const X_COMPRESSION_TYPE: &str = "X-Compression-Type";
pub const X_COMPRESSION_BUFFER_SIZE: &str = "X-Compression-Buffer-Size";
@ -112,7 +111,7 @@ pub struct Client {
pub(crate) is_refreshing_token: Arc<AtomicBool>,
pub(crate) refresh_ret_txs: Arc<RwLock<Vec<RefreshTokenSender>>>,
pub(crate) config: ClientConfiguration,
pub(crate) ai_model: Arc<RwLock<AIModel>>,
pub(crate) ai_model: Arc<RwLock<String>>,
}
pub(crate) type RefreshTokenSender = tokio::sync::oneshot::Sender<Result<(), AppResponseError>>;
@ -176,7 +175,7 @@ impl Client {
);
}
let ai_model = Arc::new(RwLock::new(AIModel::GPT4oMini));
let ai_model = Arc::new(RwLock::new("gpt-4o-mini".to_string()));
Self {
base_url: base_url.to_string(),
@ -205,7 +204,7 @@ impl Client {
&self.gotrue_client.base_url
}
pub fn set_ai_model(&self, model: AIModel) {
pub fn set_ai_model(&self, model: String) {
info!("using ai model: {:?}", model);
*self.ai_model.write() = model;
}
@ -1087,7 +1086,7 @@ impl Client {
("client-version", self.client_version.to_string()),
("client-timestamp", ts_now.to_string()),
("device-id", self.device_id.clone()),
("ai-model", self.ai_model.read().to_str().to_string()),
("ai-model", self.ai_model.read().clone()),
];
trace!(
"start request: {}, method: {}, headers: {:?}",

View File

@ -1,4 +1,3 @@
use appflowy_ai_client::dto::AIModel;
use chrono::{DateTime, Utc};
use infra::validate::validate_not_empty_str;
use serde::{Deserialize, Deserializer, Serialize};
@ -209,7 +208,7 @@ pub struct UpdateChatMessageContentParams {
pub message_id: i64,
pub content: String,
#[serde(default)]
pub model: AIModel,
pub model: String,
}
#[derive(Debug, Clone, Default, Serialize_repr, Deserialize_repr)]

View File

@ -169,7 +169,7 @@ async fn get_related_message_handler(
let ai_model = ai_model_from_header(&req);
let resp = state
.ai_client
.get_related_question(&chat_id, &message_id, &ai_model)
.get_related_question(&chat_id, &message_id, ai_model)
.await
.map_err(|err| AppError::Internal(err.into()))?;
Ok(AppResponse::Ok().with_data(resp).into())
@ -280,7 +280,7 @@ async fn answer_stream_handler(
&content,
Some(metadata),
rag_ids,
&ai_model,
ai_model,
)
.await
{
@ -333,7 +333,7 @@ async fn answer_stream_v2_handler(
&content,
Some(metadata),
rag_ids,
&ai_model,
ai_model,
)
.await
{
@ -393,7 +393,7 @@ async fn answer_stream_v3_handler(
trace!("[Chat] stream v3 {:?}", question);
match state
.ai_client
.stream_question_v3(&ai_model, question, Some(60))
.stream_question_v3(ai_model, question, Some(60))
.await
{
Ok(answer_stream) => {

View File

@ -4,7 +4,6 @@ use actix_web::web::Payload;
use app_error::AppError;
use actix_web::HttpRequest;
use appflowy_ai_client::dto::AIModel;
use async_trait::async_trait;
use byteorder::{ByteOrder, LittleEndian};
use chrono::Utc;
@ -212,15 +211,12 @@ fn copy_buffer(src: &[u8], dest: &mut [u8]) -> usize {
}
#[inline]
pub(crate) fn ai_model_from_header(req: &HttpRequest) -> AIModel {
pub(crate) fn ai_model_from_header(req: &HttpRequest) -> &str {
req
.headers()
.get("ai-model")
.and_then(|header| {
let header = header.to_str().ok()?;
AIModel::from_str(header).ok()
})
.unwrap_or(AIModel::GPT4oMini)
.and_then(|header| header.to_str().ok())
.unwrap_or("Default")
}
#[cfg(test)]

View File

@ -19,7 +19,6 @@ use shared_entity::dto::chat_dto::{
use sqlx::PgPool;
use tracing::{error, info, trace};
use appflowy_ai_client::dto::AIModel;
use validator::Validate;
pub(crate) async fn create_chat(
@ -46,7 +45,7 @@ pub async fn update_chat_message(
pg_pool: &PgPool,
params: UpdateChatMessageContentParams,
ai_client: AppFlowyAIClient,
ai_model: AIModel,
ai_model: &str,
) -> Result<(), AppError> {
let mut txn = pg_pool.begin().await?;
delete_answer_message_by_question_message_id(&mut txn, params.message_id).await?;
@ -65,7 +64,7 @@ pub async fn update_chat_message(
&params.chat_id,
params.message_id,
&params.content,
&ai_model,
ai_model,
None,
)
.await?;
@ -88,7 +87,7 @@ pub async fn generate_chat_message_answer(
ai_client: AppFlowyAIClient,
question_message_id: i64,
chat_id: &str,
ai_model: AIModel,
ai_model: &str,
) -> Result<ChatMessage, AppError> {
let (content, metadata) =
chat::chat_ops::select_chat_message_content(pg_pool, question_message_id).await?;
@ -98,7 +97,7 @@ pub async fn generate_chat_message_answer(
chat_id,
question_message_id,
&content,
&ai_model,
ai_model,
Some(metadata),
)
.await
@ -153,8 +152,9 @@ pub async fn create_chat_message_stream(
chat_id: String,
params: CreateChatMessageParams,
ai_client: AppFlowyAIClient,
ai_model: AIModel,
ai_model: &str,
) -> impl Stream<Item = Result<Bytes, AppError>> {
let ai_model = ai_model.to_string();
let params = params.clone();
let chat_id = chat_id.clone();
let pg_pool = pg_pool.clone();