feat: stream chat message (#577)
* chore: save author id * chore: stream response * chore: stream chat message
This commit is contained in:
parent
ae3e075475
commit
559d924cd1
|
|
@ -5487,6 +5487,7 @@ dependencies = [
|
|||
"chrono",
|
||||
"collab-entity",
|
||||
"database-entity",
|
||||
"futures",
|
||||
"gotrue-entity",
|
||||
"reqwest 0.11.27",
|
||||
"rust-s3",
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
use crate::http::log_request_id;
|
||||
use crate::Client;
|
||||
use database_entity::dto::{
|
||||
CreateChatMessageParams, CreateChatParams, MessageCursor, QAChatMessage, RepeatedChatMessage,
|
||||
ChatMessage, CreateChatMessageParams, CreateChatParams, MessageCursor, RepeatedChatMessage,
|
||||
};
|
||||
use futures_core::Stream;
|
||||
use reqwest::Method;
|
||||
use shared_entity::response::{AppResponse, AppResponseError};
|
||||
|
||||
|
|
@ -42,7 +43,7 @@ impl Client {
|
|||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
params: CreateChatMessageParams,
|
||||
) -> Result<QAChatMessage, AppResponseError> {
|
||||
) -> Result<impl Stream<Item = Result<ChatMessage, AppResponseError>>, AppResponseError> {
|
||||
let url = format!(
|
||||
"{}/api/chat/{workspace_id}/{chat_id}/message",
|
||||
self.base_url
|
||||
|
|
@ -54,9 +55,7 @@ impl Client {
|
|||
.send()
|
||||
.await?;
|
||||
log_request_id(&resp);
|
||||
AppResponse::<QAChatMessage>::from_response(resp)
|
||||
.await?
|
||||
.into_data()
|
||||
AppResponse::<ChatMessage>::stream_response(resp).await
|
||||
}
|
||||
|
||||
pub async fn get_chat_messages(
|
||||
|
|
|
|||
|
|
@ -653,16 +653,38 @@ pub struct RepeatedChatMessage {
|
|||
|
||||
#[derive(Debug, Default, Clone, Serialize_repr, Deserialize_repr)]
|
||||
#[repr(u8)]
|
||||
pub enum ChatAuthor {
|
||||
#[default]
|
||||
pub enum ChatAuthorType {
|
||||
Unknown = 0,
|
||||
Human = 1,
|
||||
#[default]
|
||||
System = 2,
|
||||
AI = 3,
|
||||
}
|
||||
|
||||
impl From<serde_json::Value> for ChatAuthor {
|
||||
fn from(value: serde_json::Value) -> Self {
|
||||
serde_json::from_value::<ChatAuthor>(value).unwrap_or_default()
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatAuthor {
|
||||
pub author_id: i64,
|
||||
#[serde(default)]
|
||||
pub author_type: ChatAuthorType,
|
||||
#[serde(default)]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl ChatAuthor {
|
||||
pub fn new(author_id: i64, author_type: ChatAuthorType) -> Self {
|
||||
Self {
|
||||
author_id,
|
||||
author_type,
|
||||
meta: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ai() -> Self {
|
||||
Self {
|
||||
author_id: 0,
|
||||
author_type: ChatAuthorType::AI,
|
||||
meta: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ use sqlx::postgres::PgArguments;
|
|||
use sqlx::{Arguments, Executor, Postgres, Transaction};
|
||||
use std::ops::DerefMut;
|
||||
use std::str::FromStr;
|
||||
use tracing::warn;
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
|
|
@ -209,11 +210,19 @@ pub async fn select_chat_messages(
|
|||
|
||||
let mut messages = rows
|
||||
.into_iter()
|
||||
.map(|(message_id, content, created_at, author)| ChatMessage {
|
||||
author: serde_json::from_value::<ChatAuthor>(author).unwrap_or_default(),
|
||||
message_id,
|
||||
content,
|
||||
created_at,
|
||||
.flat_map(|(message_id, content, created_at, author)| {
|
||||
match serde_json::from_value::<ChatAuthor>(author) {
|
||||
Ok(author) => Some(ChatMessage {
|
||||
author,
|
||||
message_id,
|
||||
content,
|
||||
created_at,
|
||||
}),
|
||||
Err(err) => {
|
||||
warn!("Failed to deserialize author: {}", err);
|
||||
None
|
||||
},
|
||||
}
|
||||
})
|
||||
.collect::<Vec<ChatMessage>>();
|
||||
|
||||
|
|
@ -280,8 +289,8 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>(
|
|||
chat_id: &str,
|
||||
) -> Result<Vec<ChatMessage>, AppError> {
|
||||
let chat_id = Uuid::from_str(chat_id)?;
|
||||
let messages: Vec<ChatMessage> = sqlx::query_as!(
|
||||
ChatMessage,
|
||||
let rows = sqlx::query!(
|
||||
// ChatMessage,
|
||||
r#"
|
||||
SELECT message_id, content, created_at, author
|
||||
FROM af_chat_messages
|
||||
|
|
@ -292,5 +301,24 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>(
|
|||
)
|
||||
.fetch_all(executor)
|
||||
.await?;
|
||||
|
||||
let messages = rows
|
||||
.into_iter()
|
||||
.flat_map(
|
||||
|row| match serde_json::from_value::<ChatAuthor>(row.author) {
|
||||
Ok(author) => Some(ChatMessage {
|
||||
author,
|
||||
message_id: row.message_id,
|
||||
content: row.content,
|
||||
created_at: row.created_at,
|
||||
}),
|
||||
Err(err) => {
|
||||
warn!("Failed to deserialize author: {}", err);
|
||||
None
|
||||
},
|
||||
},
|
||||
)
|
||||
.collect::<Vec<ChatMessage>>();
|
||||
|
||||
Ok(messages)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ serde = "1.0.195"
|
|||
serde_json.workspace = true
|
||||
serde_repr = "0.1.18"
|
||||
thiserror = "1.0.56"
|
||||
reqwest.workspace = true
|
||||
reqwest = { worspace = true, features = ["stream"] }
|
||||
uuid = { version = "1.6.1", features = ["v4"] }
|
||||
gotrue-entity = { path = "../gotrue-entity" }
|
||||
database-entity.workspace = true
|
||||
|
|
@ -25,6 +25,7 @@ appflowy-ai-client = { workspace = true, default-features = false, features = ["
|
|||
actix-web = { version = "4.4.1", default-features = false, features = ["http2"], optional = true }
|
||||
validator = { version = "0.16", features = ["validator_derive", "derive"], optional = true }
|
||||
rust-s3 = { version = "0.34.0-rc4", optional = true }
|
||||
futures = "0.3.30"
|
||||
|
||||
|
||||
[features]
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ use std::borrow::Cow;
|
|||
|
||||
use app_error::AppError;
|
||||
pub use app_error::ErrorCode;
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use std::fmt::{Debug, Display};
|
||||
|
||||
#[cfg(feature = "cloud")]
|
||||
|
|
@ -141,6 +142,23 @@ where
|
|||
let resp = serde_json::from_slice(&bytes)?;
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
pub async fn stream_response(
|
||||
resp: reqwest::Response,
|
||||
) -> Result<impl Stream<Item = Result<T, AppResponseError>>, AppResponseError> {
|
||||
let status_code = resp.status();
|
||||
if !status_code.is_success() {
|
||||
let body = resp.text().await?;
|
||||
return Err(AppResponseError::new(ErrorCode::Internal, body));
|
||||
}
|
||||
|
||||
let stream = resp.bytes_stream().map(|item| {
|
||||
item.map_err(AppResponseError::from).and_then(|bytes| {
|
||||
serde_json::from_slice::<T>(bytes.as_ref()).map_err(AppResponseError::from)
|
||||
})
|
||||
});
|
||||
Ok(stream)
|
||||
}
|
||||
}
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, thiserror::Error)]
|
||||
pub struct AppResponseError {
|
||||
|
|
|
|||
|
|
@ -2,14 +2,16 @@ use crate::biz::chat::ops::{create_chat, create_chat_message, delete_chat, get_c
|
|||
use crate::biz::user::auth::jwt::UserUuid;
|
||||
use crate::state::AppState;
|
||||
use actix_web::web::{Data, Json};
|
||||
use actix_web::{web, Scope};
|
||||
use actix_web::{web, HttpResponse, Scope};
|
||||
use app_error::AppError;
|
||||
use database_entity::dto::{
|
||||
CreateChatMessageParams, CreateChatParams, GetChatMessageParams, MessageCursor, QAChatMessage,
|
||||
CreateChatMessageParams, CreateChatParams, GetChatMessageParams, MessageCursor,
|
||||
RepeatedChatMessage,
|
||||
};
|
||||
use shared_entity::response::{AppResponse, JsonAppResponse};
|
||||
use std::collections::HashMap;
|
||||
use tracing::trace;
|
||||
use validator::Validate;
|
||||
|
||||
pub fn chat_scope() -> Scope {
|
||||
web::scope("/api/chat/{workspace_id}")
|
||||
|
|
@ -57,14 +59,28 @@ async fn post_chat_message_handler(
|
|||
path: web::Path<(String, String)>,
|
||||
payload: Json<CreateChatMessageParams>,
|
||||
uuid: UserUuid,
|
||||
) -> actix_web::Result<JsonAppResponse<QAChatMessage>> {
|
||||
) -> actix_web::Result<HttpResponse> {
|
||||
let (_workspace_id, chat_id) = path.into_inner();
|
||||
let params = payload.into_inner();
|
||||
|
||||
if let Err(err) = params.validate() {
|
||||
return Ok(HttpResponse::from_error(AppError::from(err)));
|
||||
}
|
||||
|
||||
let uid = state.user_cache.get_user_uid(&uuid).await?;
|
||||
trace!("insert chat message into {}", chat_id);
|
||||
let message =
|
||||
create_chat_message(&state.pg_pool, uid, params, &chat_id, &state.ai_client).await?;
|
||||
Ok(AppResponse::Ok().with_data(message).into())
|
||||
let message_stream = create_chat_message(
|
||||
&state.pg_pool,
|
||||
uid,
|
||||
chat_id,
|
||||
params,
|
||||
state.ai_client.clone(),
|
||||
)
|
||||
.await;
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.streaming(message_stream),
|
||||
)
|
||||
}
|
||||
|
||||
async fn get_chat_message_handler(
|
||||
|
|
|
|||
|
|
@ -1,15 +1,17 @@
|
|||
use anyhow::anyhow;
|
||||
use actix_web::web::Bytes;
|
||||
|
||||
use app_error::AppError;
|
||||
use appflowy_ai_client::client::AppFlowyAIClient;
|
||||
use async_stream::stream;
|
||||
use database::chat;
|
||||
use database::chat::chat_ops::{insert_chat, insert_chat_message, select_chat_messages};
|
||||
use database_entity::dto::{
|
||||
ChatAuthor, ChatMessageType, CreateChatMessageParams, CreateChatParams, GetChatMessageParams,
|
||||
QAChatMessage, RepeatedChatMessage,
|
||||
ChatAuthor, ChatAuthorType, ChatMessageType, CreateChatMessageParams, CreateChatParams,
|
||||
GetChatMessageParams, RepeatedChatMessage,
|
||||
};
|
||||
use futures::stream::Stream;
|
||||
use sqlx::PgPool;
|
||||
use std::ops::DerefMut;
|
||||
use tracing::trace;
|
||||
|
||||
use validator::Validate;
|
||||
|
||||
pub(crate) async fn create_chat(
|
||||
|
|
@ -34,53 +36,73 @@ pub(crate) async fn delete_chat(pg_pool: &PgPool, chat_id: &str) -> Result<(), A
|
|||
|
||||
pub async fn create_chat_message(
|
||||
pg_pool: &PgPool,
|
||||
_uid: i64,
|
||||
uid: i64,
|
||||
chat_id: String,
|
||||
params: CreateChatMessageParams,
|
||||
chat_id: &str,
|
||||
ai_client: &AppFlowyAIClient,
|
||||
) -> Result<QAChatMessage, AppError> {
|
||||
params.validate()?;
|
||||
ai_client: AppFlowyAIClient,
|
||||
) -> impl Stream<Item = Result<Bytes, AppError>> {
|
||||
let params = params.clone();
|
||||
let chat_id = chat_id.clone();
|
||||
let pg_pool = pg_pool.clone();
|
||||
let stream = stream! {
|
||||
// Insert question message
|
||||
let question = match insert_chat_message(
|
||||
&pg_pool,
|
||||
ChatAuthor::new(uid, ChatAuthorType::Human),
|
||||
&chat_id,
|
||||
params.content.clone()
|
||||
).await {
|
||||
Ok(question) => question,
|
||||
Err(err) => {
|
||||
yield Err(err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let answer_content = match params.message_type {
|
||||
ChatMessageType::System => "".to_string(),
|
||||
ChatMessageType::User => {
|
||||
let start = std::time::Instant::now();
|
||||
trace!("[Chat] sending question to AI: {}", params.content);
|
||||
let content = ai_client
|
||||
.send_question(chat_id, ¶ms.content)
|
||||
.await
|
||||
.map(|answer| answer.content)?;
|
||||
trace!(
|
||||
"[Chat] received answer from AI: {}, cost:{} millis",
|
||||
content,
|
||||
start.elapsed().as_millis()
|
||||
);
|
||||
content
|
||||
},
|
||||
let question_bytes = match serde_json::to_vec(&question) {
|
||||
Ok(bytes) => bytes,
|
||||
Err(err) => {
|
||||
yield Err(AppError::from(err));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
yield Ok::<Bytes, AppError>(Bytes::from(question_bytes));
|
||||
|
||||
// Insert answer message
|
||||
match params.message_type {
|
||||
ChatMessageType::System => {}
|
||||
ChatMessageType::User => {
|
||||
let content = match ai_client.send_question(&chat_id, ¶ms.content).await {
|
||||
Ok(response) => response.content,
|
||||
Err(err) => {
|
||||
yield Err(AppError::from(err));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let answer = match insert_chat_message(&pg_pool, ChatAuthor::ai(), &chat_id, content.clone()).await {
|
||||
Ok(answer) => answer,
|
||||
Err(err) => {
|
||||
yield Err(err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let answer_bytes = match serde_json::to_vec(&answer) {
|
||||
Ok(bytes) => bytes,
|
||||
Err(err) => {
|
||||
yield Err(AppError::from(err));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
yield Ok::<Bytes, AppError>(Bytes::from(answer_bytes));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let mut txn = pg_pool.begin().await.map_err(|err| {
|
||||
AppError::Internal(anyhow!(
|
||||
"failed to start transaction for inserting chat message: {}",
|
||||
err
|
||||
))
|
||||
})?;
|
||||
let question =
|
||||
insert_chat_message(txn.deref_mut(), ChatAuthor::Human, chat_id, params.content).await?;
|
||||
|
||||
let answer = match params.message_type {
|
||||
ChatMessageType::System => None,
|
||||
ChatMessageType::User => {
|
||||
Some(insert_chat_message(txn.deref_mut(), ChatAuthor::AI, chat_id, answer_content).await?)
|
||||
},
|
||||
};
|
||||
|
||||
txn
|
||||
.commit()
|
||||
.await
|
||||
.map_err(|err| AppError::Internal(anyhow!("failed to insert chat message: {}", err)))?;
|
||||
|
||||
Ok(QAChatMessage { question, answer })
|
||||
stream
|
||||
}
|
||||
|
||||
pub async fn get_chat_messages(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use client_api_test::TestClient;
|
||||
use database_entity::dto::{CreateChatMessageParams, CreateChatParams, MessageCursor};
|
||||
use database_entity::dto::{ChatMessage, CreateChatMessageParams, CreateChatParams, MessageCursor};
|
||||
use futures_util::StreamExt;
|
||||
|
||||
#[tokio::test]
|
||||
async fn create_chat_and_create_messages_test() {
|
||||
|
|
@ -26,6 +27,10 @@ async fn create_chat_and_create_messages_test() {
|
|||
.api_client
|
||||
.create_chat_message(&workspace_id, &chat_id, params)
|
||||
.await
|
||||
.unwrap()
|
||||
.next()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
messages.push(message);
|
||||
}
|
||||
|
|
@ -37,7 +42,7 @@ async fn create_chat_and_create_messages_test() {
|
|||
.get_chat_messages(
|
||||
&workspace_id,
|
||||
&chat_id,
|
||||
MessageCursor::BeforeMessageId(messages[2].question.message_id),
|
||||
MessageCursor::BeforeMessageId(messages[2].message_id),
|
||||
10,
|
||||
)
|
||||
.await
|
||||
|
|
@ -47,11 +52,11 @@ async fn create_chat_and_create_messages_test() {
|
|||
assert_eq!(message_before_third.messages.len(), 2);
|
||||
assert_eq!(
|
||||
message_before_third.messages[0].message_id,
|
||||
messages[0].question.message_id
|
||||
messages[0].message_id
|
||||
);
|
||||
assert_eq!(
|
||||
message_before_third.messages[1].message_id,
|
||||
messages[1].question.message_id
|
||||
messages[1].message_id
|
||||
);
|
||||
|
||||
// get message after third message
|
||||
|
|
@ -60,7 +65,7 @@ async fn create_chat_and_create_messages_test() {
|
|||
.get_chat_messages(
|
||||
&workspace_id,
|
||||
&chat_id,
|
||||
MessageCursor::AfterMessageId(messages[2].question.message_id),
|
||||
MessageCursor::AfterMessageId(messages[2].message_id),
|
||||
2,
|
||||
)
|
||||
.await
|
||||
|
|
@ -69,11 +74,11 @@ async fn create_chat_and_create_messages_test() {
|
|||
assert_eq!(message_after_third.messages.len(), 2);
|
||||
assert_eq!(
|
||||
message_after_third.messages[0].message_id,
|
||||
messages[3].question.message_id
|
||||
messages[3].message_id
|
||||
);
|
||||
assert_eq!(
|
||||
message_after_third.messages[1].message_id,
|
||||
messages[4].question.message_id
|
||||
messages[4].message_id
|
||||
);
|
||||
|
||||
// get all messages after 8th message
|
||||
|
|
@ -82,7 +87,7 @@ async fn create_chat_and_create_messages_test() {
|
|||
.get_chat_messages(
|
||||
&workspace_id,
|
||||
&chat_id,
|
||||
MessageCursor::AfterMessageId(messages[7].question.message_id),
|
||||
MessageCursor::AfterMessageId(messages[7].message_id),
|
||||
100,
|
||||
)
|
||||
.await
|
||||
|
|
@ -111,10 +116,12 @@ async fn chat_qa_test() {
|
|||
.unwrap();
|
||||
|
||||
let params = CreateChatMessageParams::new_user("where is singapore?");
|
||||
let message = test_client
|
||||
let stream = test_client
|
||||
.api_client
|
||||
.create_chat_message(&workspace_id, &chat_id, params)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!message.answer.unwrap().content.is_empty());
|
||||
|
||||
let messages: Vec<ChatMessage> = stream.map(|message| message.unwrap()).collect().await;
|
||||
assert_eq!(messages.len(), 2);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ use database::chat::chat_ops::{
|
|||
delete_chat, get_all_chat_messages, insert_chat, insert_chat_message, select_chat,
|
||||
select_chat_messages,
|
||||
};
|
||||
use database_entity::dto::{ChatAuthor, CreateChatParams, GetChatMessageParams};
|
||||
use database_entity::dto::{ChatAuthor, ChatAuthorType, CreateChatParams, GetChatMessageParams};
|
||||
use serde_json::json;
|
||||
use sqlx::PgPool;
|
||||
|
||||
|
|
@ -91,9 +91,14 @@ async fn chat_message_crud_test(pool: PgPool) {
|
|||
|
||||
// create chat messages
|
||||
for i in 0..5 {
|
||||
let _ = insert_chat_message(&pool, ChatAuthor::Human, &chat_id, format!("message {}", i))
|
||||
.await
|
||||
.unwrap();
|
||||
let _ = insert_chat_message(
|
||||
&pool,
|
||||
ChatAuthor::new(0, ChatAuthorType::System),
|
||||
&chat_id,
|
||||
format!("message {}", i),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
{
|
||||
let params = GetChatMessageParams::next_back(3);
|
||||
|
|
|
|||
Loading…
Reference in New Issue