chore: remove ai model enum (#1207)
This commit is contained in:
parent
18b1386bc2
commit
82409199f8
|
|
@ -1,8 +1,8 @@
|
|||
use crate::dto::{
|
||||
AIModel, CalculateSimilarityParams, ChatAnswer, ChatQuestion, CompleteTextParams,
|
||||
CreateChatContext, Document, LocalAIConfig, MessageData, ModelList, QuestionMetadata,
|
||||
RepeatedLocalAIPackage, RepeatedRelatedQuestion, ResponseFormat, SearchDocumentsRequest,
|
||||
SimilarityResponse, SummarizeRowResponse, TranslateRowData, TranslateRowResponse,
|
||||
CalculateSimilarityParams, ChatAnswer, ChatQuestion, CompleteTextParams, CreateChatContext,
|
||||
Document, LocalAIConfig, MessageData, ModelList, QuestionMetadata, RepeatedLocalAIPackage,
|
||||
RepeatedRelatedQuestion, ResponseFormat, SearchDocumentsRequest, SimilarityResponse,
|
||||
SummarizeRowResponse, TranslateRowData, TranslateRowResponse,
|
||||
};
|
||||
use crate::error::AIError;
|
||||
|
||||
|
|
@ -44,7 +44,7 @@ impl AppFlowyAIClient {
|
|||
pub async fn stream_completion_text(
|
||||
&self,
|
||||
params: CompleteTextParams,
|
||||
model: AIModel,
|
||||
model: &str,
|
||||
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
|
||||
if params.text.is_empty() {
|
||||
return Err(AIError::InvalidRequest("Empty text".to_string()));
|
||||
|
|
@ -53,7 +53,7 @@ impl AppFlowyAIClient {
|
|||
let url = format!("{}/completion/stream", self.url);
|
||||
let resp = self
|
||||
.async_http_client(Method::POST, &url)?
|
||||
.header(AI_MODEL_HEADER_KEY, model.to_str())
|
||||
.header(AI_MODEL_HEADER_KEY, model)
|
||||
.json(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
|
|
@ -63,7 +63,7 @@ impl AppFlowyAIClient {
|
|||
pub async fn summarize_row(
|
||||
&self,
|
||||
params: &Map<String, Value>,
|
||||
model: AIModel,
|
||||
model: &str,
|
||||
) -> Result<SummarizeRowResponse, AIError> {
|
||||
if params.is_empty() {
|
||||
return Err(AIError::InvalidRequest("Empty content".to_string()));
|
||||
|
|
@ -73,7 +73,7 @@ impl AppFlowyAIClient {
|
|||
trace!("summarize_row url: {}", url);
|
||||
let resp = self
|
||||
.async_http_client(Method::POST, &url)?
|
||||
.header(AI_MODEL_HEADER_KEY, model.to_str())
|
||||
.header(AI_MODEL_HEADER_KEY, model)
|
||||
.json(params)
|
||||
.send()
|
||||
.await?;
|
||||
|
|
@ -85,12 +85,12 @@ impl AppFlowyAIClient {
|
|||
pub async fn translate_row(
|
||||
&self,
|
||||
data: TranslateRowData,
|
||||
model: AIModel,
|
||||
model: &str,
|
||||
) -> Result<TranslateRowResponse, AIError> {
|
||||
let url = format!("{}/translate_row", self.url);
|
||||
let resp = self
|
||||
.async_http_client(Method::POST, &url)?
|
||||
.header(AI_MODEL_HEADER_KEY, model.to_str())
|
||||
.header(AI_MODEL_HEADER_KEY, model)
|
||||
.json(&data)
|
||||
.send()
|
||||
.await?;
|
||||
|
|
@ -131,7 +131,7 @@ impl AppFlowyAIClient {
|
|||
chat_id: &str,
|
||||
question_id: i64,
|
||||
content: &str,
|
||||
model: &AIModel,
|
||||
model: &str,
|
||||
metadata: Option<Value>,
|
||||
) -> Result<ChatAnswer, AIError> {
|
||||
let json = ChatQuestion {
|
||||
|
|
@ -150,7 +150,7 @@ impl AppFlowyAIClient {
|
|||
let url = format!("{}/chat/message", self.url);
|
||||
let resp = self
|
||||
.async_http_client(Method::POST, &url)?
|
||||
.header(AI_MODEL_HEADER_KEY, model.to_str())
|
||||
.header(AI_MODEL_HEADER_KEY, model)
|
||||
.json(&json)
|
||||
.send()
|
||||
.await?;
|
||||
|
|
@ -166,7 +166,7 @@ impl AppFlowyAIClient {
|
|||
content: &str,
|
||||
metadata: Option<Value>,
|
||||
rag_ids: Vec<String>,
|
||||
model: &AIModel,
|
||||
model: &str,
|
||||
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
|
||||
let json = ChatQuestion {
|
||||
chat_id: chat_id.to_string(),
|
||||
|
|
@ -184,7 +184,7 @@ impl AppFlowyAIClient {
|
|||
let url = format!("{}/chat/message/stream", self.url);
|
||||
let resp = self
|
||||
.async_http_client(Method::POST, &url)?
|
||||
.header(AI_MODEL_HEADER_KEY, model.to_str())
|
||||
.header(AI_MODEL_HEADER_KEY, model)
|
||||
.timeout(Duration::from_secs(30))
|
||||
.json(&json)
|
||||
.send()
|
||||
|
|
@ -201,7 +201,7 @@ impl AppFlowyAIClient {
|
|||
content: &str,
|
||||
metadata: Option<Value>,
|
||||
rag_ids: Vec<String>,
|
||||
model: &AIModel,
|
||||
model: &str,
|
||||
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
|
||||
let json = ChatQuestion {
|
||||
chat_id: chat_id.to_string(),
|
||||
|
|
@ -221,14 +221,14 @@ impl AppFlowyAIClient {
|
|||
|
||||
pub async fn stream_question_v3(
|
||||
&self,
|
||||
model: &AIModel,
|
||||
model: &str,
|
||||
question: ChatQuestion,
|
||||
timeout_secs: Option<u64>,
|
||||
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
|
||||
let url = format!("{}/v2/chat/message/stream", self.url);
|
||||
let resp = self
|
||||
.async_http_client(Method::POST, &url)?
|
||||
.header(AI_MODEL_HEADER_KEY, model.to_str())
|
||||
.header(AI_MODEL_HEADER_KEY, model)
|
||||
.json(&question)
|
||||
.timeout(Duration::from_secs(timeout_secs.unwrap_or(30)))
|
||||
.send()
|
||||
|
|
@ -240,12 +240,12 @@ impl AppFlowyAIClient {
|
|||
&self,
|
||||
chat_id: &str,
|
||||
message_id: &i64,
|
||||
model: &AIModel,
|
||||
model: &str,
|
||||
) -> Result<RepeatedRelatedQuestion, AIError> {
|
||||
let url = format!("{}/chat/{chat_id}/{message_id}/related_question", self.url);
|
||||
let resp = self
|
||||
.async_http_client(Method::GET, &url)?
|
||||
.header(AI_MODEL_HEADER_KEY, model.to_str())
|
||||
.header(AI_MODEL_HEADER_KEY, model)
|
||||
.timeout(Duration::from_secs(30))
|
||||
.send()
|
||||
.await?;
|
||||
|
|
|
|||
|
|
@ -3,8 +3,6 @@ use serde_json::json;
|
|||
use serde_repr::{Deserialize_repr, Serialize_repr};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::str::FromStr;
|
||||
|
||||
pub const STREAM_METADATA_KEY: &str = "0";
|
||||
pub const STREAM_ANSWER_KEY: &str = "1";
|
||||
pub const STREAM_IMAGE_KEY: &str = "2";
|
||||
|
|
@ -340,44 +338,6 @@ impl Display for EmbeddingModel {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize_repr, Deserialize_repr)]
|
||||
#[repr(u8)]
|
||||
pub enum AIModel {
|
||||
#[default]
|
||||
DefaultModel = 0,
|
||||
GPT4oMini = 1,
|
||||
GPT4o = 2,
|
||||
Claude3Sonnet = 3,
|
||||
Claude3Opus = 4,
|
||||
}
|
||||
|
||||
impl AIModel {
|
||||
pub fn to_str(&self) -> &str {
|
||||
match self {
|
||||
AIModel::DefaultModel => "default-model",
|
||||
AIModel::GPT4oMini => "gpt-4o-mini",
|
||||
AIModel::GPT4o => "gpt-4o",
|
||||
AIModel::Claude3Sonnet => "claude-3-sonnet",
|
||||
AIModel::Claude3Opus => "claude-3-opus",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for AIModel {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"gpt-3.5-turbo" => Ok(AIModel::GPT4oMini),
|
||||
"gpt-4o-mini" => Ok(AIModel::GPT4oMini),
|
||||
"gpt-4o" => Ok(AIModel::GPT4o),
|
||||
"claude-3-sonnet" => Ok(AIModel::Claude3Sonnet),
|
||||
"claude-3-opus" => Ok(AIModel::Claude3Opus),
|
||||
_ => Ok(AIModel::DefaultModel),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct RepeatedLocalAIPackage(pub Vec<AppFlowyOfflineAI>);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use crate::appflowy_ai_client;
|
||||
use appflowy_ai_client::client::collect_stream_text;
|
||||
use appflowy_ai_client::dto::{AIModel, CompleteTextParams, CompletionType};
|
||||
use appflowy_ai_client::dto::{CompleteTextParams, CompletionType};
|
||||
#[tokio::test]
|
||||
async fn continue_writing_test() {
|
||||
let client = appflowy_ai_client();
|
||||
|
|
@ -11,7 +11,7 @@ async fn continue_writing_test() {
|
|||
metadata: None,
|
||||
};
|
||||
let stream = client
|
||||
.stream_completion_text(params, AIModel::GPT4oMini)
|
||||
.stream_completion_text(params, "gpt-4o-mini")
|
||||
.await
|
||||
.unwrap();
|
||||
let text = collect_stream_text(stream).await;
|
||||
|
|
@ -29,7 +29,7 @@ async fn improve_writing_test() {
|
|||
metadata: None,
|
||||
};
|
||||
let stream = client
|
||||
.stream_completion_text(params, AIModel::GPT4oMini)
|
||||
.stream_completion_text(params, "gpt-4o-mini")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -49,7 +49,7 @@ async fn make_text_shorter_text() {
|
|||
metadata: None,
|
||||
};
|
||||
let stream = client
|
||||
.stream_completion_text(params, AIModel::GPT4oMini)
|
||||
.stream_completion_text(params, "gpt-4o-mini")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::appflowy_ai_client;
|
||||
use appflowy_ai_client::dto::{AIModel, CreateChatContext};
|
||||
use appflowy_ai_client::dto::CreateChatContext;
|
||||
#[tokio::test]
|
||||
async fn create_chat_context_test() {
|
||||
let client = appflowy_ai_client();
|
||||
|
|
@ -19,7 +19,7 @@ async fn create_chat_context_test() {
|
|||
&chat_id,
|
||||
1,
|
||||
"Where I live?",
|
||||
&AIModel::GPT4oMini,
|
||||
"gpt-4o-mini",
|
||||
None,
|
||||
)
|
||||
.await
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
use crate::appflowy_ai_client;
|
||||
|
||||
use appflowy_ai_client::dto::AIModel;
|
||||
|
||||
#[tokio::test]
|
||||
async fn qa_test() {
|
||||
let client = appflowy_ai_client();
|
||||
|
|
@ -13,7 +11,7 @@ async fn qa_test() {
|
|||
&chat_id,
|
||||
1,
|
||||
"I feel hungry",
|
||||
&AIModel::GPT4o,
|
||||
"gpt-4o",
|
||||
None,
|
||||
)
|
||||
.await
|
||||
|
|
@ -21,7 +19,7 @@ async fn qa_test() {
|
|||
assert!(!resp.content.is_empty());
|
||||
|
||||
let questions = client
|
||||
.get_related_question(&chat_id, &1, &AIModel::GPT4oMini)
|
||||
.get_related_question(&chat_id, &1, "gpt-4o-mini")
|
||||
.await
|
||||
.unwrap()
|
||||
.items;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
use crate::appflowy_ai_client;
|
||||
|
||||
use appflowy_ai_client::dto::AIModel;
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -9,7 +7,7 @@ async fn summarize_row_test() {
|
|||
let json = json!({"name": "Jack", "age": 25, "city": "New York"});
|
||||
|
||||
let result = client
|
||||
.summarize_row(json.as_object().unwrap(), AIModel::GPT4oMini)
|
||||
.summarize_row(json.as_object().unwrap(), "gpt-4o-mini")
|
||||
.await
|
||||
.unwrap();
|
||||
result.text.contains("Jack");
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use crate::appflowy_ai_client;
|
||||
|
||||
use appflowy_ai_client::dto::{AIModel, TranslateItem, TranslateRowData};
|
||||
use appflowy_ai_client::dto::{TranslateItem, TranslateRowData};
|
||||
|
||||
#[tokio::test]
|
||||
async fn translate_row_test() {
|
||||
|
|
@ -20,9 +20,6 @@ async fn translate_row_test() {
|
|||
include_header: false,
|
||||
};
|
||||
|
||||
let result = client
|
||||
.translate_row(data, AIModel::GPT4oMini)
|
||||
.await
|
||||
.unwrap();
|
||||
let result = client.translate_row(data, "gpt-4o-mini").await.unwrap();
|
||||
assert_eq!(result.items.len(), 2);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -45,7 +45,6 @@ use crate::retry::{RefreshTokenAction, RefreshTokenRetryCondition};
|
|||
use crate::ws::ConnectInfo;
|
||||
use client_api_entity::SignUpResponse::{Authenticated, NotAuthenticated};
|
||||
use client_api_entity::{GotrueTokenResponse, UpdateGotrueUserParams, User};
|
||||
use shared_entity::dto::ai_dto::AIModel;
|
||||
|
||||
pub const X_COMPRESSION_TYPE: &str = "X-Compression-Type";
|
||||
pub const X_COMPRESSION_BUFFER_SIZE: &str = "X-Compression-Buffer-Size";
|
||||
|
|
@ -112,7 +111,7 @@ pub struct Client {
|
|||
pub(crate) is_refreshing_token: Arc<AtomicBool>,
|
||||
pub(crate) refresh_ret_txs: Arc<RwLock<Vec<RefreshTokenSender>>>,
|
||||
pub(crate) config: ClientConfiguration,
|
||||
pub(crate) ai_model: Arc<RwLock<AIModel>>,
|
||||
pub(crate) ai_model: Arc<RwLock<String>>,
|
||||
}
|
||||
|
||||
pub(crate) type RefreshTokenSender = tokio::sync::oneshot::Sender<Result<(), AppResponseError>>;
|
||||
|
|
@ -176,7 +175,7 @@ impl Client {
|
|||
);
|
||||
}
|
||||
|
||||
let ai_model = Arc::new(RwLock::new(AIModel::GPT4oMini));
|
||||
let ai_model = Arc::new(RwLock::new("gpt-4o-mini".to_string()));
|
||||
|
||||
Self {
|
||||
base_url: base_url.to_string(),
|
||||
|
|
@ -205,7 +204,7 @@ impl Client {
|
|||
&self.gotrue_client.base_url
|
||||
}
|
||||
|
||||
pub fn set_ai_model(&self, model: AIModel) {
|
||||
pub fn set_ai_model(&self, model: String) {
|
||||
info!("using ai model: {:?}", model);
|
||||
*self.ai_model.write() = model;
|
||||
}
|
||||
|
|
@ -1087,7 +1086,7 @@ impl Client {
|
|||
("client-version", self.client_version.to_string()),
|
||||
("client-timestamp", ts_now.to_string()),
|
||||
("device-id", self.device_id.clone()),
|
||||
("ai-model", self.ai_model.read().to_str().to_string()),
|
||||
("ai-model", self.ai_model.read().clone()),
|
||||
];
|
||||
trace!(
|
||||
"start request: {}, method: {}, headers: {:?}",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
use appflowy_ai_client::dto::AIModel;
|
||||
use chrono::{DateTime, Utc};
|
||||
use infra::validate::validate_not_empty_str;
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
|
|
@ -209,7 +208,7 @@ pub struct UpdateChatMessageContentParams {
|
|||
pub message_id: i64,
|
||||
pub content: String,
|
||||
#[serde(default)]
|
||||
pub model: AIModel,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize_repr, Deserialize_repr)]
|
||||
|
|
|
|||
|
|
@ -169,7 +169,7 @@ async fn get_related_message_handler(
|
|||
let ai_model = ai_model_from_header(&req);
|
||||
let resp = state
|
||||
.ai_client
|
||||
.get_related_question(&chat_id, &message_id, &ai_model)
|
||||
.get_related_question(&chat_id, &message_id, ai_model)
|
||||
.await
|
||||
.map_err(|err| AppError::Internal(err.into()))?;
|
||||
Ok(AppResponse::Ok().with_data(resp).into())
|
||||
|
|
@ -280,7 +280,7 @@ async fn answer_stream_handler(
|
|||
&content,
|
||||
Some(metadata),
|
||||
rag_ids,
|
||||
&ai_model,
|
||||
ai_model,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
|
@ -333,7 +333,7 @@ async fn answer_stream_v2_handler(
|
|||
&content,
|
||||
Some(metadata),
|
||||
rag_ids,
|
||||
&ai_model,
|
||||
ai_model,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
|
@ -393,7 +393,7 @@ async fn answer_stream_v3_handler(
|
|||
trace!("[Chat] stream v3 {:?}", question);
|
||||
match state
|
||||
.ai_client
|
||||
.stream_question_v3(&ai_model, question, Some(60))
|
||||
.stream_question_v3(ai_model, question, Some(60))
|
||||
.await
|
||||
{
|
||||
Ok(answer_stream) => {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ use actix_web::web::Payload;
|
|||
use app_error::AppError;
|
||||
|
||||
use actix_web::HttpRequest;
|
||||
use appflowy_ai_client::dto::AIModel;
|
||||
use async_trait::async_trait;
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
use chrono::Utc;
|
||||
|
|
@ -212,15 +211,12 @@ fn copy_buffer(src: &[u8], dest: &mut [u8]) -> usize {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn ai_model_from_header(req: &HttpRequest) -> AIModel {
|
||||
pub(crate) fn ai_model_from_header(req: &HttpRequest) -> &str {
|
||||
req
|
||||
.headers()
|
||||
.get("ai-model")
|
||||
.and_then(|header| {
|
||||
let header = header.to_str().ok()?;
|
||||
AIModel::from_str(header).ok()
|
||||
})
|
||||
.unwrap_or(AIModel::GPT4oMini)
|
||||
.and_then(|header| header.to_str().ok())
|
||||
.unwrap_or("Default")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ use shared_entity::dto::chat_dto::{
|
|||
use sqlx::PgPool;
|
||||
use tracing::{error, info, trace};
|
||||
|
||||
use appflowy_ai_client::dto::AIModel;
|
||||
use validator::Validate;
|
||||
|
||||
pub(crate) async fn create_chat(
|
||||
|
|
@ -46,7 +45,7 @@ pub async fn update_chat_message(
|
|||
pg_pool: &PgPool,
|
||||
params: UpdateChatMessageContentParams,
|
||||
ai_client: AppFlowyAIClient,
|
||||
ai_model: AIModel,
|
||||
ai_model: &str,
|
||||
) -> Result<(), AppError> {
|
||||
let mut txn = pg_pool.begin().await?;
|
||||
delete_answer_message_by_question_message_id(&mut txn, params.message_id).await?;
|
||||
|
|
@ -65,7 +64,7 @@ pub async fn update_chat_message(
|
|||
¶ms.chat_id,
|
||||
params.message_id,
|
||||
¶ms.content,
|
||||
&ai_model,
|
||||
ai_model,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
|
@ -88,7 +87,7 @@ pub async fn generate_chat_message_answer(
|
|||
ai_client: AppFlowyAIClient,
|
||||
question_message_id: i64,
|
||||
chat_id: &str,
|
||||
ai_model: AIModel,
|
||||
ai_model: &str,
|
||||
) -> Result<ChatMessage, AppError> {
|
||||
let (content, metadata) =
|
||||
chat::chat_ops::select_chat_message_content(pg_pool, question_message_id).await?;
|
||||
|
|
@ -98,7 +97,7 @@ pub async fn generate_chat_message_answer(
|
|||
chat_id,
|
||||
question_message_id,
|
||||
&content,
|
||||
&ai_model,
|
||||
ai_model,
|
||||
Some(metadata),
|
||||
)
|
||||
.await
|
||||
|
|
@ -153,8 +152,9 @@ pub async fn create_chat_message_stream(
|
|||
chat_id: String,
|
||||
params: CreateChatMessageParams,
|
||||
ai_client: AppFlowyAIClient,
|
||||
ai_model: AIModel,
|
||||
ai_model: &str,
|
||||
) -> impl Stream<Item = Result<Bytes, AppError>> {
|
||||
let ai_model = ai_model.to_string();
|
||||
let params = params.clone();
|
||||
let chat_id = chat_id.clone();
|
||||
let pg_pool = pg_pool.clone();
|
||||
|
|
|
|||
Loading…
Reference in New Issue