use crate::dto::{ AIModel, ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, CreateTextChatContext, CustomPrompt, Document, EmbeddingRequest, EmbeddingResponse, LocalAIConfig, MessageData, RepeatedLocalAIPackage, RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse, TranslateRowData, TranslateRowResponse, }; use crate::error::AIError; use bytes::Bytes; use futures::{Stream, StreamExt}; use reqwest; use reqwest::{Method, RequestBuilder, StatusCode}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use serde_json::{json, Map, Value}; use std::borrow::Cow; use std::time::Duration; use tracing::{info, trace}; const AI_MODEL_HEADER_KEY: &str = "ai-model"; #[derive(Clone, Debug)] pub struct AppFlowyAIClient { client: reqwest::Client, url: String, } impl AppFlowyAIClient { pub fn new(url: &str) -> Self { info!("Creating AppFlowyAIClient with url: {}", url); let url = url.to_string(); let client = reqwest::Client::new(); Self { client, url } } pub async fn health_check(&self) -> Result<(), AIError> { let url = format!("{}/health", self.url); let resp = self.http_client(Method::GET, &url)?.send().await?; let text = resp.text().await?; info!("health response: {:?}", text); Ok(()) } pub async fn completion_text>>( &self, text: &str, completion_type: T, custom_prompt: Option, model: AIModel, ) -> Result { let completion_type = completion_type.into(); if completion_type.is_some() && custom_prompt.is_some() { return Err(AIError::InvalidRequest( "Cannot specify both completion_type and custom_prompt".to_string(), )); } if text.is_empty() { return Err(AIError::InvalidRequest("Empty text".to_string())); } let params = json!({ "text": text, "type": completion_type.map(|t| t as u8), "custom_prompt": custom_prompt, }); let url = format!("{}/completion", self.url); let resp = self .http_client(Method::POST, &url)? .header(AI_MODEL_HEADER_KEY, model.to_str()) .json(¶ms) .send() .await?; AIResponse::::from_response(resp) .await? .into_data() } pub async fn stream_completion_text>>( &self, text: &str, completion_type: T, custom_prompt: Option, model: AIModel, ) -> Result>, AIError> { let completion_type = completion_type.into(); if text.is_empty() { return Err(AIError::InvalidRequest("Empty text".to_string())); } let params = json!({ "text": text, "type": completion_type.map(|t| t as u8), "custom_prompt": custom_prompt, }); let url = format!("{}/completion/stream", self.url); let resp = self .http_client(Method::POST, &url)? .header(AI_MODEL_HEADER_KEY, model.to_str()) .json(¶ms) .send() .await?; AIResponse::<()>::stream_response(resp).await } pub async fn summarize_row( &self, params: &Map, model: AIModel, ) -> Result { if params.is_empty() { return Err(AIError::InvalidRequest("Empty content".to_string())); } let url = format!("{}/summarize_row", self.url); trace!("summarize_row url: {}", url); let resp = self .http_client(Method::POST, &url)? .header(AI_MODEL_HEADER_KEY, model.to_str()) .json(params) .send() .await?; AIResponse::::from_response(resp) .await? .into_data() } pub async fn translate_row( &self, data: TranslateRowData, model: AIModel, ) -> Result { let url = format!("{}/translate_row", self.url); let resp = self .http_client(Method::POST, &url)? .header(AI_MODEL_HEADER_KEY, model.to_str()) .json(&data) .send() .await?; AIResponse::::from_response(resp) .await? .into_data() } pub async fn embeddings(&self, params: EmbeddingRequest) -> Result { let url = format!("{}/embeddings", self.url); let resp = self .http_client(Method::POST, &url)? .json(¶ms) .send() .await?; AIResponse::::from_response(resp) .await? .into_data() } pub async fn index_documents(&self, documents: &[Document]) -> Result<(), AIError> { let url = format!("{}/index_documents", self.url); let resp = self .http_client(Method::POST, &url)? .json(&documents) .send() .await?; let status_code = resp.status(); if !status_code.is_success() { let body = resp.text().await?; return Err(anyhow::anyhow!("error: {}, {}", status_code, body).into()); } Ok(()) } pub async fn search_documents( &self, request: &SearchDocumentsRequest, ) -> Result, AIError> { let url = format!("{}/search", self.url); let resp = self .http_client(Method::GET, &url)? .query(&request) .send() .await?; AIResponse::>::from_response(resp) .await? .into_data() } pub async fn create_chat_text_context( &self, context: CreateTextChatContext, ) -> Result<(), AIError> { let url = format!("{}/chat/context/text", self.url); let resp = self .http_client(Method::POST, &url)? .json(&context) .send() .await?; let _ = AIResponse::<()>::from_response(resp).await?; Ok(()) } pub async fn send_question( &self, chat_id: &str, content: &str, model: &AIModel, metadata: Option, ) -> Result { let json = ChatQuestion { chat_id: chat_id.to_string(), data: MessageData { content: content.to_string(), metadata, }, }; let url = format!("{}/chat/message", self.url); let resp = self .http_client(Method::POST, &url)? .header(AI_MODEL_HEADER_KEY, model.to_str()) .json(&json) .send() .await?; AIResponse::::from_response(resp) .await? .into_data() } pub async fn stream_question( &self, chat_id: &str, content: &str, metadata: Option, model: &AIModel, ) -> Result>, AIError> { let json = ChatQuestion { chat_id: chat_id.to_string(), data: MessageData { content: content.to_string(), metadata, }, }; let url = format!("{}/chat/message/stream", self.url); let resp = self .http_client(Method::POST, &url)? .header(AI_MODEL_HEADER_KEY, model.to_str()) .timeout(Duration::from_secs(30)) .json(&json) .send() .await?; AIResponse::<()>::stream_response(resp).await } pub async fn stream_question_v2( &self, chat_id: &str, content: &str, metadata: Option, model: &AIModel, ) -> Result>, AIError> { let json = ChatQuestion { chat_id: chat_id.to_string(), data: MessageData { content: content.to_string(), metadata, }, }; let url = format!("{}/v2/chat/message/stream", self.url); let resp = self .http_client(Method::POST, &url)? .header(AI_MODEL_HEADER_KEY, model.to_str()) .json(&json) .timeout(Duration::from_secs(30)) .send() .await?; AIResponse::<()>::stream_response(resp).await } pub async fn get_related_question( &self, chat_id: &str, message_id: &i64, model: &AIModel, ) -> Result { let url = format!("{}/chat/{chat_id}/{message_id}/related_question", self.url); let resp = self .http_client(Method::GET, &url)? .header(AI_MODEL_HEADER_KEY, model.to_str()) .timeout(Duration::from_secs(30)) .send() .await?; AIResponse::::from_response(resp) .await? .into_data() } pub async fn get_local_ai_package( &self, platform: &str, ) -> Result { let url = format!("{}/local_ai/plugin?platform={platform}", self.url); let resp = self.http_client(Method::GET, &url)?.send().await?; AIResponse::::from_response(resp) .await? .into_data() } pub async fn get_local_ai_config( &self, platform: &str, app_version: Option, ) -> Result { // Start with the base URL including the platform parameter let mut url = format!("{}/local_ai/config?platform={}", self.url, platform); // If app_version is provided, append it as a query parameter if let Some(version) = app_version { url = format!("{}&app_version={}", url, version); } let resp = self.http_client(Method::GET, &url)?.send().await?; AIResponse::::from_response(resp) .await? .into_data() } fn http_client(&self, method: Method, url: &str) -> Result { let request_builder = self.client.request(method, url); Ok(request_builder) } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AIResponse { #[serde(skip_serializing_if = "Option::is_none")] pub data: Option, #[serde(default)] pub message: Cow<'static, str>, } impl AIResponse where T: DeserializeOwned + 'static, { pub async fn from_response(resp: reqwest::Response) -> Result { let status_code = resp.status(); if !status_code.is_success() { let body = resp.text().await?; anyhow::bail!("error code: {}, {}", status_code, body) } let bytes = resp.bytes().await?; let resp = serde_json::from_slice(&bytes)?; Ok(resp) } pub fn into_data(self) -> Result { match self.data { None => Err(AIError::InvalidRequest("Empty payload".to_string())), Some(data) => Ok(data), } } pub async fn stream_response( resp: reqwest::Response, ) -> Result>, AIError> { let status_code = resp.status(); if !status_code.is_success() { let body = resp.text().await?; return Err(AIError::InvalidRequest(body)); } let stream = resp .bytes_stream() .map(|item| item.map_err(|err| AIError::Internal(err.into()))); Ok(stream) } } impl From for AIError { fn from(error: reqwest::Error) -> Self { if error.is_timeout() { return AIError::RequestTimeout(error.to_string()); } if error.is_request() { return if error.status() == Some(StatusCode::PAYLOAD_TOO_LARGE) { AIError::PayloadTooLarge(error.to_string()) } else { AIError::InvalidRequest(format!("{:?}", error)) }; } AIError::Internal(error.into()) } }