AppFlowy-Cloud/src/api/ai.rs

161 lines
4.9 KiB
Rust

use crate::api::util::ai_model_from_header;
use crate::state::AppState;
use actix_web::web::{Data, Json};
use actix_web::{web, HttpRequest, HttpResponse, Scope};
use app_error::AppError;
use appflowy_ai_client::dto::{
CompleteTextResponse, LocalAIConfig, TranslateRowParams, TranslateRowResponse,
};
use futures_util::{stream, TryStreamExt};
use serde::Deserialize;
use shared_entity::dto::ai_dto::{
CompleteTextParams, SummarizeRowData, SummarizeRowParams, SummarizeRowResponse,
};
use shared_entity::response::{AppResponse, JsonAppResponse};
use tracing::{error, instrument, trace};
/// Builds the actix `Scope` that mounts every AI endpoint under `/api/ai/{workspace_id}`.
///
/// Routes:
/// - `POST /complete`        → `complete_text_handler`
/// - `POST /complete/stream` → `stream_complete_text_handler`
/// - `POST /summarize_row`   → `summarize_row_handler`
/// - `POST /translate_row`   → `translate_row_handler`
/// - `GET  /local/config`    → `local_ai_config_handler`
pub fn ai_completion_scope() -> Scope {
  // Build each resource first, then attach them in the original registration order.
  let complete = web::resource("/complete").route(web::post().to(complete_text_handler));
  let complete_stream =
    web::resource("/complete/stream").route(web::post().to(stream_complete_text_handler));
  let summarize_row =
    web::resource("/summarize_row").route(web::post().to(summarize_row_handler));
  let translate_row =
    web::resource("/translate_row").route(web::post().to(translate_row_handler));
  let local_config =
    web::resource("/local/config").route(web::get().to(local_ai_config_handler));

  web::scope("/api/ai/{workspace_id}")
    .service(complete)
    .service(complete_stream)
    .service(summarize_row)
    .service(translate_row)
    .service(local_config)
}
async fn complete_text_handler(
state: Data<AppState>,
payload: Json<CompleteTextParams>,
req: HttpRequest,
) -> actix_web::Result<JsonAppResponse<CompleteTextResponse>> {
let ai_model = ai_model_from_header(&req);
let params = payload.into_inner();
let resp = state
.ai_client
.completion_text(&params.text, params.completion_type, ai_model)
.await
.map_err(|err| AppError::Internal(err.into()))?;
Ok(AppResponse::Ok().with_data(resp).into())
}
/// Streams a text completion back to the client as a `text/event-stream` body.
///
/// The HTTP status is always 200. If the AI client fails to open the stream, the failure is
/// delivered as the single item of the body stream (`AppError::AIServiceUnavailable`) rather
/// than as an HTTP error status.
async fn stream_complete_text_handler(
  state: Data<AppState>,
  payload: Json<CompleteTextParams>,
  req: HttpRequest,
) -> actix_web::Result<HttpResponse> {
  let ai_model = ai_model_from_header(&req);
  let params = payload.into_inner();
  let opened = state
    .ai_client
    .stream_completion_text(&params.text, params.completion_type, ai_model)
    .await;

  // Both outcomes share the same status and content type; only the body stream differs.
  let mut builder = HttpResponse::Ok();
  builder.content_type("text/event-stream");

  let response = match opened {
    Ok(chunks) => builder.streaming(chunks.map_err(AppError::from)),
    Err(err) => builder.streaming(stream::once(async move {
      Err(AppError::AIServiceUnavailable(err.to_string()))
    })),
  };
  Ok(response)
}
/// Summarizes the content of a database row using the AI service.
///
/// Only `SummarizeRowData::Content` is supported. `SummarizeRowData::Identity` is rejected with
/// `AppError::InvalidRequest`. Empty content returns the summary `"No content"` without
/// contacting the AI service. A failure from the AI client is logged and also answered with
/// `"No content"`: this is best-effort, so the request still succeeds.
///
/// All arguments are skipped by the tracing span. Recording `req` would attach the `Debug`
/// form of the whole `HttpRequest` (request line, query and headers) to every span.
#[instrument(level = "debug", skip_all, err)]
async fn summarize_row_handler(
  state: Data<AppState>,
  payload: Json<SummarizeRowParams>,
  req: HttpRequest,
) -> actix_web::Result<Json<AppResponse<SummarizeRowResponse>>> {
  let params = payload.into_inner();
  let content = match params.data {
    SummarizeRowData::Identity { .. } => {
      return Err(AppError::InvalidRequest("Identity data is not supported".to_string()).into());
    },
    SummarizeRowData::Content(content) => content,
  };

  let text = if content.is_empty() {
    "No content".to_string()
  } else {
    let ai_model = ai_model_from_header(&req);
    match state.ai_client.summarize_row(&content, ai_model).await {
      Ok(resp) => resp.text,
      Err(err) => {
        // Best-effort: log the failure and fall back to the placeholder summary.
        error!("Failed to summarize row: {:?}", err);
        "No content".to_string()
      },
    }
  };

  Ok(AppResponse::Ok().with_data(SummarizeRowResponse { text }).into())
}
/// Translates the content of a database row using the AI service.
///
/// This is best-effort. If the AI client fails, the error is logged and a default (empty)
/// `TranslateRowResponse` is returned with a success status.
///
/// All arguments are skipped by the tracing span. Recording `req` would attach the `Debug`
/// form of the whole `HttpRequest` (request line, query and headers) to every span.
#[instrument(level = "debug", skip_all, err)]
async fn translate_row_handler(
  state: web::Data<AppState>,
  payload: web::Json<TranslateRowParams>,
  req: HttpRequest,
) -> actix_web::Result<Json<AppResponse<TranslateRowResponse>>> {
  let params = payload.into_inner();
  let ai_model = ai_model_from_header(&req);
  let resp = state
    .ai_client
    .translate_row(params.data, ai_model)
    .await
    .unwrap_or_else(|err| {
      error!("Failed to translate row: {:?}", err);
      TranslateRowResponse::default()
    });
  Ok(AppResponse::Ok().with_data(resp).into())
}
/// Query string accepted by `GET /api/ai/{workspace_id}/local/config`.
#[derive(Debug, Deserialize)]
struct ConfigQuery {
  /// Client platform name. `local_ai_config_handler` accepts `macos`, `linux`, `ubuntu` or
  /// `windows`.
  platform: String,
  /// Optional client application version, passed through to `get_local_ai_config`.
  app_version: Option<String>,
}
/// Returns the local-AI configuration for the client's platform.
///
/// `platform` must be one of `macos`, `linux`, `ubuntu` or `windows`. `linux` is treated as an
/// alias for `ubuntu`. Any other value is rejected with `AppError::InvalidRequest`, and the
/// message includes the rejected value. A failure from the AI client is reported as
/// `AppError::AIServiceUnavailable`.
#[instrument(level = "debug", skip_all, err)]
async fn local_ai_config_handler(
  state: web::Data<AppState>,
  query: web::Query<ConfigQuery>,
) -> actix_web::Result<Json<AppResponse<LocalAIConfig>>> {
  let query = query.into_inner();
  trace!("query ai configuration: {:?}", query);
  let platform = match query.platform.as_str() {
    "macos" => "macos",
    "linux" | "ubuntu" => "ubuntu",
    "windows" => "windows",
    unsupported => {
      // Echo the offending value so clients can tell what was rejected.
      return Err(AppError::InvalidRequest(format!("Invalid platform: {}", unsupported)).into());
    },
  };
  let config = state
    .ai_client
    .get_local_ai_config(platform, query.app_version)
    .await
    .map_err(|err| AppError::AIServiceUnavailable(err.to_string()))?;
  Ok(AppResponse::Ok().with_data(config).into())
}