diff --git a/libs/app-error/src/lib.rs b/libs/app-error/src/lib.rs index 54b88a00..1367e3d7 100644 --- a/libs/app-error/src/lib.rs +++ b/libs/app-error/src/lib.rs @@ -119,6 +119,9 @@ pub enum AppError { #[error("{0}")] PublishNamespaceAlreadyTaken(String), + + #[error("{0}")] + AIServiceUnavailable(String), } impl AppError { @@ -178,6 +181,7 @@ impl AppError { AppError::OverrideWithIncorrectData(_) => ErrorCode::OverrideWithIncorrectData, AppError::Utf8Error(_) => ErrorCode::Internal, AppError::PublishNamespaceAlreadyTaken(_) => ErrorCode::PublishNamespaceAlreadyTaken, + AppError::AIServiceUnavailable(_) => ErrorCode::AIServiceUnavailable, } } } @@ -288,6 +292,7 @@ pub enum ErrorCode { OverrideWithIncorrectData = 1029, PublishNamespaceNotSet = 1030, PublishNamespaceAlreadyTaken = 1031, + AIServiceUnavailable = 1032, } impl ErrorCode { diff --git a/src/api/ai.rs b/src/api/ai.rs index 8248dd44..e6ad6057 100644 --- a/src/api/ai.rs +++ b/src/api/ai.rs @@ -1,10 +1,12 @@ use crate::api::util::ai_model_from_header; use crate::state::AppState; + use actix_web::web::{Data, Json}; use actix_web::{web, HttpRequest, HttpResponse, Scope}; use app_error::AppError; use appflowy_ai_client::dto::{CompleteTextResponse, TranslateRowParams, TranslateRowResponse}; -use futures_util::TryStreamExt; + +use futures_util::{stream, TryStreamExt}; use shared_entity::dto::ai_dto::{ CompleteTextParams, SummarizeRowData, SummarizeRowParams, SummarizeRowResponse, }; @@ -42,16 +44,24 @@ async fn stream_complete_text_handler( ) -> actix_web::Result { let ai_model = ai_model_from_header(&req); let params = payload.into_inner(); - let stream = state + match state .ai_client .stream_completion_text(¶ms.text, params.completion_type, ai_model) .await - .map_err(|err| AppError::Internal(err.into()))?; - Ok( - HttpResponse::Ok() - .content_type("text/event-stream") - .streaming(stream.map_err(AppError::from)), - ) + { + Ok(stream) => Ok( + HttpResponse::Ok() + .content_type("text/event-stream") + .streaming(stream.map_err(AppError::from)), + ), + Err(err) => Ok( + HttpResponse::Ok() + .content_type("text/event-stream") + .streaming(stream::once(async move { + Err(AppError::AIServiceUnavailable(err.to_string())) + })), + ), + } } #[instrument(level = "debug", skip(state, payload), err)] diff --git a/src/api/chat.rs b/src/api/chat.rs index 880c0e06..a9f42bed 100644 --- a/src/api/chat.rs +++ b/src/api/chat.rs @@ -15,12 +15,12 @@ use database_entity::dto::{ GetChatMessageParams, MessageCursor, RepeatedChatMessage, UpdateChatMessageContentParams, }; use futures::Stream; +use futures_util::stream; use futures_util::{FutureExt, TryStreamExt}; use pin_project::pin_project; use shared_entity::response::{AppResponse, JsonAppResponse}; use std::collections::HashMap; use std::pin::Pin; - use std::task::{Context, Poll}; use tokio::sync::oneshot; use tokio::task; @@ -28,6 +28,7 @@ use tokio::task; use database::chat; use crate::api::util::ai_model_from_header; + use database::chat::chat_ops::insert_answer_message; use tracing::{instrument, trace}; use validator::Validate; @@ -208,18 +209,27 @@ async fn answer_stream_handler( let (_workspace_id, chat_id, question_id) = path.into_inner(); let content = chat::chat_ops::select_chat_message_content(&state.pg_pool, question_id).await?; let ai_model = ai_model_from_header(&req); - let answer_stream = state + match state .ai_client .stream_question(&chat_id, &content, &ai_model) .await - .map_err(|err| AppError::Internal(err.into()))?; - - let new_answer_stream = answer_stream.map_err(AppError::from); - Ok( - HttpResponse::Ok() - .content_type("text/event-stream") - .streaming(new_answer_stream), - ) + { + Ok(answer_stream) => { + let new_answer_stream = answer_stream.map_err(AppError::from); + Ok( + HttpResponse::Ok() + .content_type("text/event-stream") + .streaming(new_answer_stream), + ) + }, + Err(err) => Ok( + HttpResponse::Ok() + .content_type("text/event-stream") + .streaming(stream::once(async move { + Err(AppError::AIServiceUnavailable(err.to_string())) + })), + ), + } } #[instrument(level = "debug", skip_all, err)]