From 18b1386bc2d16851d4b5f42d28f23b8c333d02db Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Sat, 1 Feb 2025 12:45:20 +0800 Subject: [PATCH] chore: return model list (#1206) --- libs/appflowy-ai-client/src/client.rs | 11 ++++++++++- libs/appflowy-ai-client/src/dto.rs | 5 +++++ libs/appflowy-ai-client/tests/chat_test/mod.rs | 1 + .../tests/chat_test/model_config_test.rs | 8 ++++++++ libs/client-api/src/http_ai.rs | 18 ++++++++++++++++-- src/api/ai.rs | 16 +++++++++++++++- tests/ai_test/chat_test.rs | 18 ++++++++++++++++++ 7 files changed, 73 insertions(+), 4 deletions(-) create mode 100644 libs/appflowy-ai-client/tests/chat_test/model_config_test.rs diff --git a/libs/appflowy-ai-client/src/client.rs b/libs/appflowy-ai-client/src/client.rs index e94608bc..9a08a726 100644 --- a/libs/appflowy-ai-client/src/client.rs +++ b/libs/appflowy-ai-client/src/client.rs @@ -1,6 +1,6 @@ use crate::dto::{ AIModel, CalculateSimilarityParams, ChatAnswer, ChatQuestion, CompleteTextParams, - CreateChatContext, Document, LocalAIConfig, MessageData, QuestionMetadata, + CreateChatContext, Document, LocalAIConfig, MessageData, ModelList, QuestionMetadata, RepeatedLocalAIPackage, RepeatedRelatedQuestion, ResponseFormat, SearchDocumentsRequest, SimilarityResponse, SummarizeRowResponse, TranslateRowData, TranslateRowResponse, }; @@ -284,6 +284,15 @@ impl AppFlowyAIClient { .into_data() } + pub async fn get_model_list(&self) -> Result { + let url = format!("{}/model/list", self.url); + + let resp = self.async_http_client(Method::GET, &url)?.send().await?; + AIResponse::::from_reqwest_response(resp) + .await? + .into_data() + } + pub async fn calculate_similarity( &self, params: CalculateSimilarityParams, diff --git a/libs/appflowy-ai-client/src/dto.rs b/libs/appflowy-ai-client/src/dto.rs index b602d9dc..2ba667de 100644 --- a/libs/appflowy-ai-client/src/dto.rs +++ b/libs/appflowy-ai-client/src/dto.rs @@ -414,6 +414,11 @@ pub struct LocalAIConfig { pub plugin: AppFlowyOfflineAI, } +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ModelList { + pub models: Vec, +} + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct CreateChatContext { pub chat_id: String, diff --git a/libs/appflowy-ai-client/tests/chat_test/mod.rs b/libs/appflowy-ai-client/tests/chat_test/mod.rs index c7835a6a..247cd26c 100644 --- a/libs/appflowy-ai-client/tests/chat_test/mod.rs +++ b/libs/appflowy-ai-client/tests/chat_test/mod.rs @@ -1,3 +1,4 @@ mod completion_test; mod context_test; +mod model_config_test; mod qa_test; diff --git a/libs/appflowy-ai-client/tests/chat_test/model_config_test.rs b/libs/appflowy-ai-client/tests/chat_test/model_config_test.rs new file mode 100644 index 00000000..71eb550d --- /dev/null +++ b/libs/appflowy-ai-client/tests/chat_test/model_config_test.rs @@ -0,0 +1,8 @@ +use crate::appflowy_ai_client; + +#[tokio::test] +async fn get_model_list_test() { + let client = appflowy_ai_client(); + let models = client.get_model_list().await.unwrap().models; + assert!(models.len() >= 5, "models.len() = {}", models.len()); +} diff --git a/libs/client-api/src/http_ai.rs b/libs/client-api/src/http_ai.rs index 6910a74a..93ccfd17 100644 --- a/libs/client-api/src/http_ai.rs +++ b/libs/client-api/src/http_ai.rs @@ -4,8 +4,8 @@ use bytes::Bytes; use futures_core::Stream; use reqwest::Method; use shared_entity::dto::ai_dto::{ - CompleteTextParams, LocalAIConfig, SummarizeRowParams, SummarizeRowResponse, TranslateRowParams, - TranslateRowResponse, + CompleteTextParams, LocalAIConfig, ModelList, SummarizeRowParams, SummarizeRowResponse, + TranslateRowParams, TranslateRowResponse, }; use shared_entity::response::{AppResponse, AppResponseError}; use std::time::Duration; @@ -96,4 +96,18 @@ impl Client { .await? .into_data() } + + #[instrument(level = "debug", skip_all, err)] + pub async fn get_model_list(&self, workspace_id: &str) -> Result { + let url = format!("{}/api/ai/{workspace_id}/model/list", self.base_url); + let resp = self + .http_client_with_auth(Method::GET, &url) + .await? + .send() + .await?; + log_request_id(&resp); + AppResponse::::from_response(resp) + .await? + .into_data() + } } diff --git a/src/api/ai.rs b/src/api/ai.rs index 484f75e9..0bf1e374 100644 --- a/src/api/ai.rs +++ b/src/api/ai.rs @@ -5,7 +5,7 @@ use actix_web::web::{Data, Json}; use actix_web::{web, HttpRequest, HttpResponse, Scope}; use app_error::AppError; use appflowy_ai_client::dto::{ - CalculateSimilarityParams, LocalAIConfig, SimilarityResponse, TranslateRowParams, + CalculateSimilarityParams, LocalAIConfig, ModelList, SimilarityResponse, TranslateRowParams, TranslateRowResponse, }; @@ -28,6 +28,7 @@ pub fn ai_completion_scope() -> Scope { .service( web::resource("/calculate_similarity").route(web::post().to(calculate_similarity_handler)), ) + .service(web::resource("/model/list").route(web::get().to(model_list_handler))) } async fn stream_complete_text_handler( @@ -126,6 +127,7 @@ struct ConfigQuery { platform: String, app_version: Option, } + #[instrument(level = "debug", skip_all, err)] async fn local_ai_config_handler( state: web::Data, @@ -165,3 +167,15 @@ async fn calculate_similarity_handler( .map_err(|err| AppError::AIServiceUnavailable(err.to_string()))?; Ok(AppResponse::Ok().with_data(response).into()) } + +#[instrument(level = "debug", skip_all, err)] +async fn model_list_handler( + state: web::Data, +) -> actix_web::Result>> { + let model_list = state + .ai_client + .get_model_list() + .await + .map_err(|err| AppError::AIServiceUnavailable(err.to_string()))?; + Ok(AppResponse::Ok().with_data(model_list).into()) +} diff --git a/tests/ai_test/chat_test.rs b/tests/ai_test/chat_test.rs index 45ec97dc..5ef1f9ea 100644 --- a/tests/ai_test/chat_test.rs +++ b/tests/ai_test/chat_test.rs @@ -484,6 +484,24 @@ async fn get_question_message_test() { assert_eq!(find_question.reply_message_id.unwrap(), answer.message_id); } +#[tokio::test] +async fn get_model_list_test() { + if !ai_test_enabled() { + return; + } + let test_client = TestClient::new_user().await; + let workspace_id = test_client.workspace_id().await; + let models = test_client + .api_client + .get_model_list(&workspace_id) + .await + .unwrap() + .models; + assert!(!models.is_empty()); + assert!(models.len() >= 5, "models.len() = {}", models.len()); + println!("models: {:?}", models); +} + async fn collect_answer(mut stream: QuestionStream) -> String { let mut answer = String::new(); while let Some(value) = stream.next().await {