412 lines
12 KiB
Rust
412 lines
12 KiB
Rust
use crate::dto::{
|
|
AIModel, CalculateSimilarityParams, ChatAnswer, ChatQuestion, CompleteTextParams,
|
|
CreateChatContext, Document, LocalAIConfig, MessageData, QuestionMetadata,
|
|
RepeatedLocalAIPackage, RepeatedRelatedQuestion, ResponseFormat, SearchDocumentsRequest,
|
|
SimilarityResponse, SummarizeRowResponse, TranslateRowData, TranslateRowResponse,
|
|
};
|
|
use crate::error::AIError;
|
|
|
|
use bytes::Bytes;
|
|
use futures::{Stream, StreamExt};
|
|
use reqwest;
|
|
use reqwest::{Method, RequestBuilder, StatusCode};
|
|
use serde::de::DeserializeOwned;
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::{Map, Value};
|
|
use std::borrow::Cow;
|
|
use std::time::Duration;
|
|
use tracing::{info, trace};
|
|
|
|
const AI_MODEL_HEADER_KEY: &str = "ai-model";
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct AppFlowyAIClient {
|
|
async_client: reqwest::Client,
|
|
url: String,
|
|
}
|
|
|
|
impl AppFlowyAIClient {
|
|
pub fn new(url: &str) -> Self {
|
|
info!("Creating AppFlowyAIClient with url: {}", url);
|
|
let url = url.to_string();
|
|
let async_client = reqwest::Client::new();
|
|
Self { async_client, url }
|
|
}
|
|
|
|
pub async fn health_check(&self) -> Result<(), AIError> {
|
|
let url = format!("{}/health", self.url);
|
|
let resp = self.async_http_client(Method::GET, &url)?.send().await?;
|
|
let text = resp.text().await?;
|
|
info!("health response: {:?}", text);
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn stream_completion_text(
|
|
&self,
|
|
params: CompleteTextParams,
|
|
model: AIModel,
|
|
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
|
|
if params.text.is_empty() {
|
|
return Err(AIError::InvalidRequest("Empty text".to_string()));
|
|
}
|
|
|
|
let url = format!("{}/completion/stream", self.url);
|
|
let resp = self
|
|
.async_http_client(Method::POST, &url)?
|
|
.header(AI_MODEL_HEADER_KEY, model.to_str())
|
|
.json(¶ms)
|
|
.send()
|
|
.await?;
|
|
AIResponse::<()>::stream_response(resp).await
|
|
}
|
|
|
|
pub async fn summarize_row(
|
|
&self,
|
|
params: &Map<String, Value>,
|
|
model: AIModel,
|
|
) -> Result<SummarizeRowResponse, AIError> {
|
|
if params.is_empty() {
|
|
return Err(AIError::InvalidRequest("Empty content".to_string()));
|
|
}
|
|
|
|
let url = format!("{}/summarize_row", self.url);
|
|
trace!("summarize_row url: {}", url);
|
|
let resp = self
|
|
.async_http_client(Method::POST, &url)?
|
|
.header(AI_MODEL_HEADER_KEY, model.to_str())
|
|
.json(params)
|
|
.send()
|
|
.await?;
|
|
AIResponse::<SummarizeRowResponse>::from_reqwest_response(resp)
|
|
.await?
|
|
.into_data()
|
|
}
|
|
|
|
pub async fn translate_row(
|
|
&self,
|
|
data: TranslateRowData,
|
|
model: AIModel,
|
|
) -> Result<TranslateRowResponse, AIError> {
|
|
let url = format!("{}/translate_row", self.url);
|
|
let resp = self
|
|
.async_http_client(Method::POST, &url)?
|
|
.header(AI_MODEL_HEADER_KEY, model.to_str())
|
|
.json(&data)
|
|
.send()
|
|
.await?;
|
|
AIResponse::<TranslateRowResponse>::from_reqwest_response(resp)
|
|
.await?
|
|
.into_data()
|
|
}
|
|
|
|
pub async fn search_documents(
|
|
&self,
|
|
request: &SearchDocumentsRequest,
|
|
) -> Result<Vec<Document>, AIError> {
|
|
let url = format!("{}/search", self.url);
|
|
let resp = self
|
|
.async_http_client(Method::GET, &url)?
|
|
.query(&request)
|
|
.send()
|
|
.await?;
|
|
AIResponse::<Vec<Document>>::from_reqwest_response(resp)
|
|
.await?
|
|
.into_data()
|
|
}
|
|
|
|
pub async fn create_chat_text_context(&self, context: CreateChatContext) -> Result<(), AIError> {
|
|
let url = format!("{}/chat/context/text", self.url);
|
|
let resp = self
|
|
.async_http_client(Method::POST, &url)?
|
|
.json(&context)
|
|
.send()
|
|
.await?;
|
|
let _ = AIResponse::<()>::from_reqwest_response(resp).await?;
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn send_question(
|
|
&self,
|
|
workspace_id: &str,
|
|
chat_id: &str,
|
|
question_id: i64,
|
|
content: &str,
|
|
model: &AIModel,
|
|
metadata: Option<Value>,
|
|
) -> Result<ChatAnswer, AIError> {
|
|
let json = ChatQuestion {
|
|
chat_id: chat_id.to_string(),
|
|
data: MessageData {
|
|
content: content.to_string(),
|
|
metadata,
|
|
message_id: Some(question_id.to_string()),
|
|
},
|
|
format: Default::default(),
|
|
metadata: QuestionMetadata {
|
|
workspace_id: workspace_id.to_string(),
|
|
rag_ids: vec![],
|
|
},
|
|
};
|
|
let url = format!("{}/chat/message", self.url);
|
|
let resp = self
|
|
.async_http_client(Method::POST, &url)?
|
|
.header(AI_MODEL_HEADER_KEY, model.to_str())
|
|
.json(&json)
|
|
.send()
|
|
.await?;
|
|
AIResponse::<ChatAnswer>::from_reqwest_response(resp)
|
|
.await?
|
|
.into_data()
|
|
}
|
|
|
|
pub async fn stream_question(
|
|
&self,
|
|
workspace_id: String,
|
|
chat_id: &str,
|
|
content: &str,
|
|
metadata: Option<Value>,
|
|
rag_ids: Vec<String>,
|
|
model: &AIModel,
|
|
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
|
|
let json = ChatQuestion {
|
|
chat_id: chat_id.to_string(),
|
|
data: MessageData {
|
|
content: content.to_string(),
|
|
metadata,
|
|
message_id: None,
|
|
},
|
|
format: Default::default(),
|
|
metadata: QuestionMetadata {
|
|
workspace_id,
|
|
rag_ids,
|
|
},
|
|
};
|
|
let url = format!("{}/chat/message/stream", self.url);
|
|
let resp = self
|
|
.async_http_client(Method::POST, &url)?
|
|
.header(AI_MODEL_HEADER_KEY, model.to_str())
|
|
.timeout(Duration::from_secs(30))
|
|
.json(&json)
|
|
.send()
|
|
.await?;
|
|
AIResponse::<()>::stream_response(resp).await
|
|
}
|
|
|
|
#[allow(clippy::too_many_arguments)]
|
|
pub async fn stream_question_v2(
|
|
&self,
|
|
workspace_id: String,
|
|
chat_id: &str,
|
|
question_id: i64,
|
|
content: &str,
|
|
metadata: Option<Value>,
|
|
rag_ids: Vec<String>,
|
|
model: &AIModel,
|
|
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
|
|
let json = ChatQuestion {
|
|
chat_id: chat_id.to_string(),
|
|
data: MessageData {
|
|
content: content.to_string(),
|
|
metadata,
|
|
message_id: Some(question_id.to_string()),
|
|
},
|
|
format: ResponseFormat::default(),
|
|
metadata: QuestionMetadata {
|
|
workspace_id,
|
|
rag_ids,
|
|
},
|
|
};
|
|
self.stream_question_v3(model, json, Some(30)).await
|
|
}
|
|
|
|
pub async fn stream_question_v3(
|
|
&self,
|
|
model: &AIModel,
|
|
question: ChatQuestion,
|
|
timeout_secs: Option<u64>,
|
|
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
|
|
let url = format!("{}/v2/chat/message/stream", self.url);
|
|
let resp = self
|
|
.async_http_client(Method::POST, &url)?
|
|
.header(AI_MODEL_HEADER_KEY, model.to_str())
|
|
.json(&question)
|
|
.timeout(Duration::from_secs(timeout_secs.unwrap_or(30)))
|
|
.send()
|
|
.await?;
|
|
AIResponse::<()>::stream_response(resp).await
|
|
}
|
|
|
|
pub async fn get_related_question(
|
|
&self,
|
|
chat_id: &str,
|
|
message_id: &i64,
|
|
model: &AIModel,
|
|
) -> Result<RepeatedRelatedQuestion, AIError> {
|
|
let url = format!("{}/chat/{chat_id}/{message_id}/related_question", self.url);
|
|
let resp = self
|
|
.async_http_client(Method::GET, &url)?
|
|
.header(AI_MODEL_HEADER_KEY, model.to_str())
|
|
.timeout(Duration::from_secs(30))
|
|
.send()
|
|
.await?;
|
|
AIResponse::<RepeatedRelatedQuestion>::from_reqwest_response(resp)
|
|
.await?
|
|
.into_data()
|
|
}
|
|
|
|
pub async fn get_local_ai_package(
|
|
&self,
|
|
platform: &str,
|
|
) -> Result<RepeatedLocalAIPackage, AIError> {
|
|
let url = format!("{}/local_ai/plugin?platform={platform}", self.url);
|
|
let resp = self.async_http_client(Method::GET, &url)?.send().await?;
|
|
AIResponse::<RepeatedLocalAIPackage>::from_reqwest_response(resp)
|
|
.await?
|
|
.into_data()
|
|
}
|
|
|
|
pub async fn get_local_ai_config(
|
|
&self,
|
|
platform: &str,
|
|
app_version: Option<String>,
|
|
) -> Result<LocalAIConfig, AIError> {
|
|
// Start with the base URL including the platform parameter
|
|
let mut url = format!("{}/local_ai/config?platform={}", self.url, platform);
|
|
|
|
// If app_version is provided, append it as a query parameter
|
|
if let Some(version) = app_version {
|
|
url = format!("{}&app_version={}", url, version);
|
|
}
|
|
|
|
let resp = self.async_http_client(Method::GET, &url)?.send().await?;
|
|
AIResponse::<LocalAIConfig>::from_reqwest_response(resp)
|
|
.await?
|
|
.into_data()
|
|
}
|
|
|
|
pub async fn calculate_similarity(
|
|
&self,
|
|
params: CalculateSimilarityParams,
|
|
) -> Result<SimilarityResponse, AIError> {
|
|
let url = format!("{}/similarity", self.url);
|
|
let resp = self
|
|
.async_http_client(Method::POST, &url)?
|
|
.json(¶ms)
|
|
.send()
|
|
.await?;
|
|
AIResponse::<SimilarityResponse>::from_reqwest_response(resp)
|
|
.await?
|
|
.into_data()
|
|
}
|
|
|
|
fn async_http_client(&self, method: Method, url: &str) -> Result<RequestBuilder, AIError> {
|
|
let request_builder = self.async_client.request(method, url);
|
|
Ok(request_builder)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct AIResponse<T> {
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub data: Option<T>,
|
|
|
|
#[serde(default)]
|
|
pub message: Cow<'static, str>,
|
|
}
|
|
|
|
impl<T> AIResponse<T>
|
|
where
|
|
T: DeserializeOwned + 'static,
|
|
{
|
|
pub async fn from_reqwest_response(resp: reqwest::Response) -> Result<Self, anyhow::Error> {
|
|
let status_code = resp.status();
|
|
if !status_code.is_success() {
|
|
let body = resp.text().await?;
|
|
anyhow::bail!("error code: {}, {}", status_code, body)
|
|
}
|
|
|
|
let bytes = resp.bytes().await?;
|
|
let resp = serde_json::from_slice(&bytes)?;
|
|
Ok(resp)
|
|
}
|
|
|
|
pub fn from_ur_response(resp: ureq::Response) -> Result<Self, anyhow::Error> {
|
|
let status_code = resp.status();
|
|
if status_code != 200 {
|
|
let body = resp.into_string()?;
|
|
anyhow::bail!("error code: {}, {}", status_code, body)
|
|
}
|
|
|
|
let resp = resp.into_json()?;
|
|
Ok(resp)
|
|
}
|
|
|
|
pub fn into_data(self) -> Result<T, AIError> {
|
|
match self.data {
|
|
None => Err(AIError::InvalidRequest("Empty payload".to_string())),
|
|
Some(data) => Ok(data),
|
|
}
|
|
}
|
|
|
|
pub async fn stream_response(
|
|
resp: reqwest::Response,
|
|
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
|
|
let status_code = resp.status();
|
|
if status_code.is_server_error() {
|
|
let body = resp.text().await?;
|
|
return Err(AIError::ServiceUnavailable(body));
|
|
}
|
|
|
|
if !status_code.is_success() {
|
|
let body = resp.text().await?;
|
|
return Err(AIError::InvalidRequest(body));
|
|
}
|
|
let stream = resp
|
|
.bytes_stream()
|
|
.map(|item| item.map_err(|err| AIError::Internal(err.into())));
|
|
Ok(stream)
|
|
}
|
|
}
|
|
impl From<reqwest::Error> for AIError {
|
|
fn from(error: reqwest::Error) -> Self {
|
|
if error.is_connect() {
|
|
return AIError::ServiceUnavailable(error.to_string());
|
|
}
|
|
|
|
if error.is_timeout() {
|
|
return AIError::RequestTimeout(error.to_string());
|
|
}
|
|
|
|
// Handle request-related errors
|
|
if let Some(status_code) = error.status() {
|
|
if error.is_request() {
|
|
match status_code {
|
|
StatusCode::PAYLOAD_TOO_LARGE => {
|
|
return AIError::PayloadTooLarge(error.to_string());
|
|
},
|
|
status_code if status_code.is_server_error() => {
|
|
return AIError::ServiceUnavailable(error.to_string());
|
|
},
|
|
_ => {
|
|
return AIError::InvalidRequest(format!("{:?}", error));
|
|
},
|
|
}
|
|
}
|
|
}
|
|
AIError::Internal(error.into())
|
|
}
|
|
}
|
|
|
|
pub async fn collect_stream_text(stream: impl Stream<Item = Result<Bytes, AIError>>) -> String {
|
|
let stream = stream.map(|item| {
|
|
item.map(|bytes| {
|
|
String::from_utf8(bytes.to_vec())
|
|
.map(|s| s.replace('\n', ""))
|
|
.unwrap()
|
|
})
|
|
});
|
|
|
|
let lines: Vec<String> = stream.map(|message| message.unwrap()).collect().await;
|
|
lines.join("")
|
|
}
|