472 lines
15 KiB
Rust
472 lines
15 KiB
Rust
use crate::biz::chat::ops::{
|
|
create_chat, create_chat_message, create_chat_message_stream, delete_chat,
|
|
extract_chat_message_metadata, generate_chat_message_answer, get_chat_messages,
|
|
update_chat_message,
|
|
};
|
|
use crate::state::AppState;
|
|
use actix_web::web::{Data, Json};
|
|
use actix_web::{web, HttpRequest, HttpResponse, Scope};
|
|
|
|
use app_error::AppError;
|
|
use appflowy_ai_client::dto::{CreateTextChatContext, RepeatedRelatedQuestion};
|
|
use authentication::jwt::UserUuid;
|
|
use bytes::Bytes;
|
|
use database_entity::dto::{
|
|
ChatAuthor, ChatMessage, CreateAnswerMessageParams, CreateChatMessageParams, CreateChatParams,
|
|
GetChatMessageParams, MessageCursor, RepeatedChatMessage, UpdateChatMessageContentParams,
|
|
};
|
|
use futures::Stream;
|
|
use futures_util::stream;
|
|
use futures_util::{FutureExt, TryStreamExt};
|
|
use pin_project::pin_project;
|
|
use shared_entity::response::{AppResponse, JsonAppResponse};
|
|
use std::collections::HashMap;
|
|
use std::pin::Pin;
|
|
use std::task::{Context, Poll};
|
|
use tokio::sync::oneshot;
|
|
use tokio::task;
|
|
|
|
use database::chat;
|
|
|
|
use crate::api::util::ai_model_from_header;
|
|
|
|
use database::chat::chat_ops::insert_answer_message;
|
|
use tracing::{instrument, trace};
|
|
use validator::Validate;
|
|
|
|
pub fn chat_scope() -> Scope {
|
|
web::scope("/api/chat/{workspace_id}")
|
|
.service(web::resource("").route(web::post().to(create_chat_handler)))
|
|
.service(
|
|
web::resource("/{chat_id}")
|
|
.route(web::delete().to(delete_chat_handler))
|
|
.route(web::get().to(get_chat_message_handler)),
|
|
)
|
|
.service(
|
|
web::resource("/{chat_id}/{message_id}/related_question")
|
|
.route(web::get().to(get_related_message_handler)),
|
|
)
|
|
.service(
|
|
web::resource("/{chat_id}/message")
|
|
// create_chat_message_handler is deprecated. No long used after frontend application v0.6.2
|
|
.route(web::post().to(create_chat_message_handler))
|
|
.route(web::put().to(update_chat_message_handler)),
|
|
)
|
|
.service(
|
|
// Creating a [ChatMessage] for given content.
|
|
// When client asks a question, it will use this API to create a chat message
|
|
web::resource("/{chat_id}/message/question").route(web::post().to(create_question_handler)),
|
|
)
|
|
// Writing the final answer for a given chat.
|
|
// After the streaming is finished, the client will use this API to save the message to disk.
|
|
.service(web::resource("/{chat_id}/message/answer").route(web::post().to(save_answer_handler)))
|
|
.service(
|
|
// Use AI to generate a response for a specified message ID.
|
|
// To generate an answer for a given question, use "/answer/stream" to receive the answer in a stream.
|
|
web::resource("/{chat_id}/{message_id}/answer").route(web::get().to(answer_handler)),
|
|
)
|
|
// Use AI to generate a response for a specified message ID. This response will be return as a stream.
|
|
.service(
|
|
web::resource("/{chat_id}/{message_id}/answer/stream")
|
|
.route(web::get().to(answer_stream_handler)),
|
|
)
|
|
.service(
|
|
web::resource("/{chat_id}/{message_id}/v2/answer/stream")
|
|
.route(web::get().to(answer_stream_v2_handler)),
|
|
)
|
|
.service(
|
|
// Create chat context for a given chat.
|
|
web::resource("/{chat_id}/context/text")
|
|
.route(web::post().to(create_chat_context_handler))
|
|
)
|
|
}
|
|
async fn create_chat_handler(
|
|
path: web::Path<String>,
|
|
state: Data<AppState>,
|
|
payload: Json<CreateChatParams>,
|
|
) -> actix_web::Result<JsonAppResponse<()>> {
|
|
let workspace_id = path.into_inner();
|
|
let params = payload.into_inner();
|
|
trace!("create new chat: {:?}", params);
|
|
create_chat(&state.pg_pool, params, &workspace_id).await?;
|
|
Ok(AppResponse::Ok().into())
|
|
}
|
|
|
|
async fn delete_chat_handler(
|
|
path: web::Path<(String, String)>,
|
|
state: Data<AppState>,
|
|
) -> actix_web::Result<JsonAppResponse<()>> {
|
|
let (_workspace_id, chat_id) = path.into_inner();
|
|
delete_chat(&state.pg_pool, &chat_id).await?;
|
|
Ok(AppResponse::Ok().into())
|
|
}
|
|
|
|
#[instrument(level = "info", skip_all, err)]
|
|
async fn create_chat_message_handler(
|
|
state: Data<AppState>,
|
|
path: web::Path<(String, String)>,
|
|
payload: Json<CreateChatMessageParams>,
|
|
uuid: UserUuid,
|
|
req: HttpRequest,
|
|
) -> actix_web::Result<HttpResponse> {
|
|
let (_workspace_id, chat_id) = path.into_inner();
|
|
let params = payload.into_inner();
|
|
|
|
if let Err(err) = params.validate() {
|
|
return Ok(HttpResponse::from_error(AppError::from(err)));
|
|
}
|
|
|
|
let ai_model = ai_model_from_header(&req);
|
|
let uid = state.user_cache.get_user_uid(&uuid).await?;
|
|
let message_stream = create_chat_message_stream(
|
|
&state.pg_pool,
|
|
uid,
|
|
chat_id,
|
|
params,
|
|
state.ai_client.clone(),
|
|
ai_model,
|
|
)
|
|
.await;
|
|
|
|
Ok(
|
|
HttpResponse::Ok()
|
|
.content_type("application/json")
|
|
.streaming(message_stream),
|
|
)
|
|
}
|
|
|
|
#[instrument(level = "debug", skip_all, err)]
|
|
async fn create_chat_context_handler(
|
|
state: Data<AppState>,
|
|
payload: Json<CreateTextChatContext>,
|
|
) -> actix_web::Result<JsonAppResponse<()>> {
|
|
let params = payload.into_inner();
|
|
state
|
|
.ai_client
|
|
.create_chat_text_context(params)
|
|
.await
|
|
.map_err(AppError::from)?;
|
|
Ok(AppResponse::Ok().into())
|
|
}
|
|
|
|
async fn update_chat_message_handler(
|
|
state: Data<AppState>,
|
|
payload: Json<UpdateChatMessageContentParams>,
|
|
req: HttpRequest,
|
|
) -> actix_web::Result<JsonAppResponse<()>> {
|
|
let params = payload.into_inner();
|
|
let ai_model = ai_model_from_header(&req);
|
|
update_chat_message(&state.pg_pool, params, state.ai_client.clone(), ai_model).await?;
|
|
Ok(AppResponse::Ok().into())
|
|
}
|
|
|
|
async fn get_related_message_handler(
|
|
path: web::Path<(String, String, i64)>,
|
|
state: Data<AppState>,
|
|
req: HttpRequest,
|
|
) -> actix_web::Result<JsonAppResponse<RepeatedRelatedQuestion>> {
|
|
let (_workspace_id, chat_id, message_id) = path.into_inner();
|
|
let ai_model = ai_model_from_header(&req);
|
|
let resp = state
|
|
.ai_client
|
|
.get_related_question(&chat_id, &message_id, &ai_model)
|
|
.await
|
|
.map_err(|err| AppError::Internal(err.into()))?;
|
|
Ok(AppResponse::Ok().with_data(resp).into())
|
|
}
|
|
|
|
#[instrument(level = "debug", skip_all, err)]
|
|
async fn create_question_handler(
|
|
state: Data<AppState>,
|
|
path: web::Path<(String, String)>,
|
|
payload: Json<CreateChatMessageParams>,
|
|
uuid: UserUuid,
|
|
) -> actix_web::Result<JsonAppResponse<ChatMessage>> {
|
|
let (_workspace_id, chat_id) = path.into_inner();
|
|
let mut params = payload.into_inner();
|
|
|
|
// When create a question, we will extract the metadata from the question content.
|
|
// metadata might include user mention file,page,or user. For example, @Get started.
|
|
for extract_context in extract_chat_message_metadata(&mut params) {
|
|
let context = CreateTextChatContext::new(
|
|
chat_id.clone(),
|
|
extract_context.content_type,
|
|
extract_context.content,
|
|
)
|
|
.with_metadata(extract_context.metadata);
|
|
trace!("create context for question: {}", context);
|
|
state
|
|
.ai_client
|
|
.create_chat_text_context(context)
|
|
.await
|
|
.map_err(AppError::from)?;
|
|
}
|
|
|
|
let uid = state.user_cache.get_user_uid(&uuid).await?;
|
|
let resp = create_chat_message(&state.pg_pool, uid, chat_id, params).await?;
|
|
Ok(AppResponse::Ok().with_data(resp).into())
|
|
}
|
|
|
|
async fn save_answer_handler(
|
|
path: web::Path<(String, String)>,
|
|
payload: Json<CreateAnswerMessageParams>,
|
|
state: Data<AppState>,
|
|
) -> actix_web::Result<JsonAppResponse<ChatMessage>> {
|
|
let payload = payload.into_inner();
|
|
payload.validate().map_err(AppError::from)?;
|
|
|
|
let (_workspace_id, chat_id) = path.into_inner();
|
|
let message = insert_answer_message(
|
|
&state.pg_pool,
|
|
ChatAuthor::ai(),
|
|
&chat_id,
|
|
payload.content,
|
|
payload.metadata,
|
|
payload.question_message_id,
|
|
)
|
|
.await?;
|
|
|
|
Ok(AppResponse::Ok().with_data(message).into())
|
|
}
|
|
async fn answer_handler(
|
|
path: web::Path<(String, String, i64)>,
|
|
state: Data<AppState>,
|
|
req: HttpRequest,
|
|
) -> actix_web::Result<JsonAppResponse<ChatMessage>> {
|
|
let (_workspace_id, chat_id, message_id) = path.into_inner();
|
|
let ai_model = ai_model_from_header(&req);
|
|
let message = generate_chat_message_answer(
|
|
&state.pg_pool,
|
|
state.ai_client.clone(),
|
|
message_id,
|
|
&chat_id,
|
|
ai_model,
|
|
)
|
|
.await?;
|
|
Ok(AppResponse::Ok().with_data(message).into())
|
|
}
|
|
|
|
#[instrument(level = "debug", skip_all, err)]
|
|
async fn answer_stream_handler(
|
|
path: web::Path<(String, String, i64)>,
|
|
state: Data<AppState>,
|
|
req: HttpRequest,
|
|
) -> actix_web::Result<HttpResponse> {
|
|
let (_workspace_id, chat_id, question_id) = path.into_inner();
|
|
let content = chat::chat_ops::select_chat_message_content(&state.pg_pool, question_id).await?;
|
|
let ai_model = ai_model_from_header(&req);
|
|
match state
|
|
.ai_client
|
|
.stream_question(&chat_id, &content, &ai_model)
|
|
.await
|
|
{
|
|
Ok(answer_stream) => {
|
|
let new_answer_stream = answer_stream.map_err(AppError::from);
|
|
Ok(
|
|
HttpResponse::Ok()
|
|
.content_type("text/event-stream")
|
|
.streaming(new_answer_stream),
|
|
)
|
|
},
|
|
Err(err) => Ok(
|
|
HttpResponse::Ok()
|
|
.content_type("text/event-stream")
|
|
.streaming(stream::once(async move {
|
|
Err(AppError::AIServiceUnavailable(err.to_string()))
|
|
})),
|
|
),
|
|
}
|
|
}
|
|
|
|
#[instrument(level = "debug", skip_all, err)]
|
|
async fn answer_stream_v2_handler(
|
|
path: web::Path<(String, String, i64)>,
|
|
state: Data<AppState>,
|
|
req: HttpRequest,
|
|
) -> actix_web::Result<HttpResponse> {
|
|
let (_workspace_id, chat_id, question_id) = path.into_inner();
|
|
let content = chat::chat_ops::select_chat_message_content(&state.pg_pool, question_id).await?;
|
|
let ai_model = ai_model_from_header(&req);
|
|
match state
|
|
.ai_client
|
|
.stream_question_v2(&chat_id, &content, &ai_model)
|
|
.await
|
|
{
|
|
Ok(answer_stream) => {
|
|
let new_answer_stream = answer_stream.map_err(AppError::from);
|
|
Ok(
|
|
HttpResponse::Ok()
|
|
.content_type("text/event-stream")
|
|
.streaming(new_answer_stream),
|
|
)
|
|
},
|
|
Err(err) => Ok(
|
|
HttpResponse::Ok()
|
|
.content_type("text/event-stream")
|
|
.streaming(stream::once(async move {
|
|
Err(AppError::AIServiceUnavailable(err.to_string()))
|
|
})),
|
|
),
|
|
}
|
|
}
|
|
|
|
#[instrument(level = "debug", skip_all, err)]
|
|
async fn get_chat_message_handler(
|
|
path: web::Path<(String, String)>,
|
|
query: web::Query<HashMap<String, String>>,
|
|
state: Data<AppState>,
|
|
) -> actix_web::Result<JsonAppResponse<RepeatedChatMessage>> {
|
|
let mut params = GetChatMessageParams {
|
|
cursor: MessageCursor::Offset(0),
|
|
limit: query
|
|
.get("limit")
|
|
.and_then(|s| s.parse::<u64>().ok())
|
|
.unwrap_or(10),
|
|
};
|
|
if let Some(value) = query.get("offset").and_then(|s| s.parse::<u64>().ok()) {
|
|
params.cursor = MessageCursor::Offset(value);
|
|
} else if let Some(value) = query.get("after").and_then(|s| s.parse::<i64>().ok()) {
|
|
params.cursor = MessageCursor::AfterMessageId(value);
|
|
} else if let Some(value) = query.get("before").and_then(|s| s.parse::<i64>().ok()) {
|
|
params.cursor = MessageCursor::BeforeMessageId(value);
|
|
} else {
|
|
params.cursor = MessageCursor::NextBack;
|
|
}
|
|
|
|
trace!("get chat messages: {:?}", params);
|
|
let (_workspace_id, chat_id) = path.into_inner();
|
|
let messages = get_chat_messages(&state.pg_pool, params, &chat_id).await?;
|
|
Ok(AppResponse::Ok().with_data(messages).into())
|
|
}
|
|
|
|
#[pin_project]
|
|
pub struct FinalAnswerStream<S, F> {
|
|
#[pin]
|
|
stream: S,
|
|
buffer: Vec<u8>,
|
|
action: Option<F>,
|
|
}
|
|
|
|
impl<S, F> FinalAnswerStream<S, F> {
|
|
pub fn new(stream: S, action: F) -> Self {
|
|
FinalAnswerStream {
|
|
stream,
|
|
buffer: Vec::new(),
|
|
action: Some(action),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<S, F> Stream for FinalAnswerStream<S, F>
|
|
where
|
|
S: Stream<Item = Result<String, AppError>>,
|
|
F: FnOnce(Vec<u8>),
|
|
{
|
|
type Item = Result<Bytes, AppError>;
|
|
|
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
let this = self.project();
|
|
|
|
match this.stream.poll_next(cx) {
|
|
Poll::Ready(Some(Ok(item))) => {
|
|
let bytes = item.into_bytes();
|
|
this.buffer.extend_from_slice(&bytes);
|
|
Poll::Ready(Some(Ok(Bytes::from(bytes))))
|
|
},
|
|
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
|
Poll::Ready(None) => {
|
|
if let Some(action) = this.action.take() {
|
|
action(std::mem::take(this.buffer));
|
|
}
|
|
Poll::Ready(None)
|
|
},
|
|
Poll::Pending => Poll::Pending,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
#[pin_project]
|
|
pub struct CollectingStream<S, F> {
|
|
#[pin]
|
|
stream: S,
|
|
buffer: Vec<u8>,
|
|
action: Option<F>,
|
|
state: CollectingStreamType,
|
|
result_receiver: Option<oneshot::Receiver<Result<Bytes, AppError>>>,
|
|
}
|
|
|
|
enum CollectingStreamType {
|
|
AnswerString,
|
|
AnswerMessage,
|
|
}
|
|
|
|
impl<S, F> CollectingStream<S, F> {
|
|
pub fn new(stream: S, action: F) -> Self {
|
|
CollectingStream {
|
|
stream,
|
|
buffer: Vec::new(),
|
|
action: Some(action),
|
|
state: CollectingStreamType::AnswerString,
|
|
result_receiver: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<S, F> Stream for CollectingStream<S, F>
|
|
where
|
|
S: Stream<Item = Result<Bytes, AppError>>,
|
|
F: FnOnce(Vec<u8>) -> task::JoinHandle<Result<Bytes, AppError>> + Send + 'static,
|
|
{
|
|
type Item = Result<Bytes, AppError>;
|
|
|
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
let mut this = self.project();
|
|
match this.state {
|
|
CollectingStreamType::AnswerString => {
|
|
match this.stream.as_mut().poll_next(cx) {
|
|
Poll::Ready(Some(Ok(bytes))) => {
|
|
this.buffer.extend_from_slice(&bytes);
|
|
Poll::Ready(Some(Ok(bytes)))
|
|
},
|
|
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
|
Poll::Ready(None) => {
|
|
if let Some(action) = this.action.take() {
|
|
let buffer = std::mem::take(this.buffer);
|
|
let (sender, receiver) = oneshot::channel();
|
|
*this.result_receiver = Some(receiver);
|
|
*this.state = CollectingStreamType::AnswerMessage;
|
|
|
|
// Spawn the async task to handle the buffer and send the result
|
|
tokio::spawn(async move {
|
|
let result = action(buffer).await;
|
|
sender.send(result.unwrap()).unwrap();
|
|
});
|
|
|
|
Poll::Ready(Some(Ok(Bytes::from(""))))
|
|
} else {
|
|
// If action is None, it means the stream is finished.
|
|
Poll::Ready(None)
|
|
}
|
|
},
|
|
Poll::Pending => Poll::Pending,
|
|
}
|
|
},
|
|
CollectingStreamType::AnswerMessage => {
|
|
if let Some(receiver) = this.result_receiver.as_mut() {
|
|
match receiver.poll_unpin(cx) {
|
|
Poll::Ready(Ok(_)) => {
|
|
this.result_receiver.take();
|
|
Poll::Ready(None)
|
|
},
|
|
Poll::Ready(Err(_)) => Poll::Ready(None),
|
|
Poll::Pending => Poll::Pending,
|
|
}
|
|
} else {
|
|
Poll::Ready(None)
|
|
}
|
|
},
|
|
}
|
|
}
|
|
}
|