feat: stream chat message (#577)

* chore: save author id

* chore: stream response

* chore: stream chat message
This commit is contained in:
Nathan.fooo 2024-05-26 22:44:08 +08:00 committed by GitHub
parent ae3e075475
commit 559d924cd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 206 additions and 87 deletions

1
Cargo.lock generated
View File

@ -5487,6 +5487,7 @@ dependencies = [
"chrono",
"collab-entity",
"database-entity",
"futures",
"gotrue-entity",
"reqwest 0.11.27",
"rust-s3",

View File

@ -1,8 +1,9 @@
use crate::http::log_request_id;
use crate::Client;
use database_entity::dto::{
CreateChatMessageParams, CreateChatParams, MessageCursor, QAChatMessage, RepeatedChatMessage,
ChatMessage, CreateChatMessageParams, CreateChatParams, MessageCursor, RepeatedChatMessage,
};
use futures_core::Stream;
use reqwest::Method;
use shared_entity::response::{AppResponse, AppResponseError};
@ -42,7 +43,7 @@ impl Client {
workspace_id: &str,
chat_id: &str,
params: CreateChatMessageParams,
) -> Result<QAChatMessage, AppResponseError> {
) -> Result<impl Stream<Item = Result<ChatMessage, AppResponseError>>, AppResponseError> {
let url = format!(
"{}/api/chat/{workspace_id}/{chat_id}/message",
self.base_url
@ -54,9 +55,7 @@ impl Client {
.send()
.await?;
log_request_id(&resp);
AppResponse::<QAChatMessage>::from_response(resp)
.await?
.into_data()
AppResponse::<ChatMessage>::stream_response(resp).await
}
pub async fn get_chat_messages(

View File

@ -653,16 +653,38 @@ pub struct RepeatedChatMessage {
#[derive(Debug, Default, Clone, Serialize_repr, Deserialize_repr)]
#[repr(u8)]
pub enum ChatAuthor {
#[default]
pub enum ChatAuthorType {
Unknown = 0,
Human = 1,
#[default]
System = 2,
AI = 3,
}
impl From<serde_json::Value> for ChatAuthor {
fn from(value: serde_json::Value) -> Self {
serde_json::from_value::<ChatAuthor>(value).unwrap_or_default()
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatAuthor {
pub author_id: i64,
#[serde(default)]
pub author_type: ChatAuthorType,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub meta: Option<serde_json::Value>,
}
impl ChatAuthor {
pub fn new(author_id: i64, author_type: ChatAuthorType) -> Self {
Self {
author_id,
author_type,
meta: None,
}
}
pub fn ai() -> Self {
Self {
author_id: 0,
author_type: ChatAuthorType::AI,
meta: None,
}
}
}

View File

@ -13,6 +13,7 @@ use sqlx::postgres::PgArguments;
use sqlx::{Arguments, Executor, Postgres, Transaction};
use std::ops::DerefMut;
use std::str::FromStr;
use tracing::warn;
use uuid::Uuid;
@ -209,11 +210,19 @@ pub async fn select_chat_messages(
let mut messages = rows
.into_iter()
.map(|(message_id, content, created_at, author)| ChatMessage {
author: serde_json::from_value::<ChatAuthor>(author).unwrap_or_default(),
message_id,
content,
created_at,
.flat_map(|(message_id, content, created_at, author)| {
match serde_json::from_value::<ChatAuthor>(author) {
Ok(author) => Some(ChatMessage {
author,
message_id,
content,
created_at,
}),
Err(err) => {
warn!("Failed to deserialize author: {}", err);
None
},
}
})
.collect::<Vec<ChatMessage>>();
@ -280,8 +289,8 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>(
chat_id: &str,
) -> Result<Vec<ChatMessage>, AppError> {
let chat_id = Uuid::from_str(chat_id)?;
let messages: Vec<ChatMessage> = sqlx::query_as!(
ChatMessage,
let rows = sqlx::query!(
// ChatMessage,
r#"
SELECT message_id, content, created_at, author
FROM af_chat_messages
@ -292,5 +301,24 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>(
)
.fetch_all(executor)
.await?;
let messages = rows
.into_iter()
.flat_map(
|row| match serde_json::from_value::<ChatAuthor>(row.author) {
Ok(author) => Some(ChatMessage {
author,
message_id: row.message_id,
content: row.content,
created_at: row.created_at,
}),
Err(err) => {
warn!("Failed to deserialize author: {}", err);
None
},
},
)
.collect::<Vec<ChatMessage>>();
Ok(messages)
}

View File

@ -13,7 +13,7 @@ serde = "1.0.195"
serde_json.workspace = true
serde_repr = "0.1.18"
thiserror = "1.0.56"
reqwest.workspace = true
reqwest = { worspace = true, features = ["stream"] }
uuid = { version = "1.6.1", features = ["v4"] }
gotrue-entity = { path = "../gotrue-entity" }
database-entity.workspace = true
@ -25,6 +25,7 @@ appflowy-ai-client = { workspace = true, default-features = false, features = ["
actix-web = { version = "4.4.1", default-features = false, features = ["http2"], optional = true }
validator = { version = "0.16", features = ["validator_derive", "derive"], optional = true }
rust-s3 = { version = "0.34.0-rc4", optional = true }
futures = "0.3.30"
[features]

View File

@ -3,6 +3,7 @@ use std::borrow::Cow;
use app_error::AppError;
pub use app_error::ErrorCode;
use futures::stream::{Stream, StreamExt};
use std::fmt::{Debug, Display};
#[cfg(feature = "cloud")]
@ -141,6 +142,23 @@ where
let resp = serde_json::from_slice(&bytes)?;
Ok(resp)
}
pub async fn stream_response(
resp: reqwest::Response,
) -> Result<impl Stream<Item = Result<T, AppResponseError>>, AppResponseError> {
let status_code = resp.status();
if !status_code.is_success() {
let body = resp.text().await?;
return Err(AppResponseError::new(ErrorCode::Internal, body));
}
let stream = resp.bytes_stream().map(|item| {
item.map_err(AppResponseError::from).and_then(|bytes| {
serde_json::from_slice::<T>(bytes.as_ref()).map_err(AppResponseError::from)
})
});
Ok(stream)
}
}
#[derive(Clone, Debug, Serialize, Deserialize, thiserror::Error)]
pub struct AppResponseError {

View File

@ -2,14 +2,16 @@ use crate::biz::chat::ops::{create_chat, create_chat_message, delete_chat, get_c
use crate::biz::user::auth::jwt::UserUuid;
use crate::state::AppState;
use actix_web::web::{Data, Json};
use actix_web::{web, Scope};
use actix_web::{web, HttpResponse, Scope};
use app_error::AppError;
use database_entity::dto::{
CreateChatMessageParams, CreateChatParams, GetChatMessageParams, MessageCursor, QAChatMessage,
CreateChatMessageParams, CreateChatParams, GetChatMessageParams, MessageCursor,
RepeatedChatMessage,
};
use shared_entity::response::{AppResponse, JsonAppResponse};
use std::collections::HashMap;
use tracing::trace;
use validator::Validate;
pub fn chat_scope() -> Scope {
web::scope("/api/chat/{workspace_id}")
@ -57,14 +59,28 @@ async fn post_chat_message_handler(
path: web::Path<(String, String)>,
payload: Json<CreateChatMessageParams>,
uuid: UserUuid,
) -> actix_web::Result<JsonAppResponse<QAChatMessage>> {
) -> actix_web::Result<HttpResponse> {
let (_workspace_id, chat_id) = path.into_inner();
let params = payload.into_inner();
if let Err(err) = params.validate() {
return Ok(HttpResponse::from_error(AppError::from(err)));
}
let uid = state.user_cache.get_user_uid(&uuid).await?;
trace!("insert chat message into {}", chat_id);
let message =
create_chat_message(&state.pg_pool, uid, params, &chat_id, &state.ai_client).await?;
Ok(AppResponse::Ok().with_data(message).into())
let message_stream = create_chat_message(
&state.pg_pool,
uid,
chat_id,
params,
state.ai_client.clone(),
)
.await;
Ok(
HttpResponse::Ok()
.content_type("application/json")
.streaming(message_stream),
)
}
async fn get_chat_message_handler(

View File

@ -1,15 +1,17 @@
use anyhow::anyhow;
use actix_web::web::Bytes;
use app_error::AppError;
use appflowy_ai_client::client::AppFlowyAIClient;
use async_stream::stream;
use database::chat;
use database::chat::chat_ops::{insert_chat, insert_chat_message, select_chat_messages};
use database_entity::dto::{
ChatAuthor, ChatMessageType, CreateChatMessageParams, CreateChatParams, GetChatMessageParams,
QAChatMessage, RepeatedChatMessage,
ChatAuthor, ChatAuthorType, ChatMessageType, CreateChatMessageParams, CreateChatParams,
GetChatMessageParams, RepeatedChatMessage,
};
use futures::stream::Stream;
use sqlx::PgPool;
use std::ops::DerefMut;
use tracing::trace;
use validator::Validate;
pub(crate) async fn create_chat(
@ -34,53 +36,73 @@ pub(crate) async fn delete_chat(pg_pool: &PgPool, chat_id: &str) -> Result<(), A
pub async fn create_chat_message(
pg_pool: &PgPool,
_uid: i64,
uid: i64,
chat_id: String,
params: CreateChatMessageParams,
chat_id: &str,
ai_client: &AppFlowyAIClient,
) -> Result<QAChatMessage, AppError> {
params.validate()?;
ai_client: AppFlowyAIClient,
) -> impl Stream<Item = Result<Bytes, AppError>> {
let params = params.clone();
let chat_id = chat_id.clone();
let pg_pool = pg_pool.clone();
let stream = stream! {
// Insert question message
let question = match insert_chat_message(
&pg_pool,
ChatAuthor::new(uid, ChatAuthorType::Human),
&chat_id,
params.content.clone()
).await {
Ok(question) => question,
Err(err) => {
yield Err(err);
return;
}
};
let answer_content = match params.message_type {
ChatMessageType::System => "".to_string(),
ChatMessageType::User => {
let start = std::time::Instant::now();
trace!("[Chat] sending question to AI: {}", params.content);
let content = ai_client
.send_question(chat_id, &params.content)
.await
.map(|answer| answer.content)?;
trace!(
"[Chat] received answer from AI: {}, cost:{} millis",
content,
start.elapsed().as_millis()
);
content
},
let question_bytes = match serde_json::to_vec(&question) {
Ok(bytes) => bytes,
Err(err) => {
yield Err(AppError::from(err));
return;
}
};
yield Ok::<Bytes, AppError>(Bytes::from(question_bytes));
// Insert answer message
match params.message_type {
ChatMessageType::System => {}
ChatMessageType::User => {
let content = match ai_client.send_question(&chat_id, &params.content).await {
Ok(response) => response.content,
Err(err) => {
yield Err(AppError::from(err));
return;
}
};
let answer = match insert_chat_message(&pg_pool, ChatAuthor::ai(), &chat_id, content.clone()).await {
Ok(answer) => answer,
Err(err) => {
yield Err(err);
return;
}
};
let answer_bytes = match serde_json::to_vec(&answer) {
Ok(bytes) => bytes,
Err(err) => {
yield Err(AppError::from(err));
return;
}
};
yield Ok::<Bytes, AppError>(Bytes::from(answer_bytes));
}
}
};
let mut txn = pg_pool.begin().await.map_err(|err| {
AppError::Internal(anyhow!(
"failed to start transaction for inserting chat message: {}",
err
))
})?;
let question =
insert_chat_message(txn.deref_mut(), ChatAuthor::Human, chat_id, params.content).await?;
let answer = match params.message_type {
ChatMessageType::System => None,
ChatMessageType::User => {
Some(insert_chat_message(txn.deref_mut(), ChatAuthor::AI, chat_id, answer_content).await?)
},
};
txn
.commit()
.await
.map_err(|err| AppError::Internal(anyhow!("failed to insert chat message: {}", err)))?;
Ok(QAChatMessage { question, answer })
stream
}
pub async fn get_chat_messages(

View File

@ -1,5 +1,6 @@
use client_api_test::TestClient;
use database_entity::dto::{CreateChatMessageParams, CreateChatParams, MessageCursor};
use database_entity::dto::{ChatMessage, CreateChatMessageParams, CreateChatParams, MessageCursor};
use futures_util::StreamExt;
#[tokio::test]
async fn create_chat_and_create_messages_test() {
@ -26,6 +27,10 @@ async fn create_chat_and_create_messages_test() {
.api_client
.create_chat_message(&workspace_id, &chat_id, params)
.await
.unwrap()
.next()
.await
.unwrap()
.unwrap();
messages.push(message);
}
@ -37,7 +42,7 @@ async fn create_chat_and_create_messages_test() {
.get_chat_messages(
&workspace_id,
&chat_id,
MessageCursor::BeforeMessageId(messages[2].question.message_id),
MessageCursor::BeforeMessageId(messages[2].message_id),
10,
)
.await
@ -47,11 +52,11 @@ async fn create_chat_and_create_messages_test() {
assert_eq!(message_before_third.messages.len(), 2);
assert_eq!(
message_before_third.messages[0].message_id,
messages[0].question.message_id
messages[0].message_id
);
assert_eq!(
message_before_third.messages[1].message_id,
messages[1].question.message_id
messages[1].message_id
);
// get message after third message
@ -60,7 +65,7 @@ async fn create_chat_and_create_messages_test() {
.get_chat_messages(
&workspace_id,
&chat_id,
MessageCursor::AfterMessageId(messages[2].question.message_id),
MessageCursor::AfterMessageId(messages[2].message_id),
2,
)
.await
@ -69,11 +74,11 @@ async fn create_chat_and_create_messages_test() {
assert_eq!(message_after_third.messages.len(), 2);
assert_eq!(
message_after_third.messages[0].message_id,
messages[3].question.message_id
messages[3].message_id
);
assert_eq!(
message_after_third.messages[1].message_id,
messages[4].question.message_id
messages[4].message_id
);
// get all messages after 8th message
@ -82,7 +87,7 @@ async fn create_chat_and_create_messages_test() {
.get_chat_messages(
&workspace_id,
&chat_id,
MessageCursor::AfterMessageId(messages[7].question.message_id),
MessageCursor::AfterMessageId(messages[7].message_id),
100,
)
.await
@ -111,10 +116,12 @@ async fn chat_qa_test() {
.unwrap();
let params = CreateChatMessageParams::new_user("where is singapore?");
let message = test_client
let stream = test_client
.api_client
.create_chat_message(&workspace_id, &chat_id, params)
.await
.unwrap();
assert!(!message.answer.unwrap().content.is_empty());
let messages: Vec<ChatMessage> = stream.map(|message| message.unwrap()).collect().await;
assert_eq!(messages.len(), 2);
}

View File

@ -3,7 +3,7 @@ use database::chat::chat_ops::{
delete_chat, get_all_chat_messages, insert_chat, insert_chat_message, select_chat,
select_chat_messages,
};
use database_entity::dto::{ChatAuthor, CreateChatParams, GetChatMessageParams};
use database_entity::dto::{ChatAuthor, ChatAuthorType, CreateChatParams, GetChatMessageParams};
use serde_json::json;
use sqlx::PgPool;
@ -91,9 +91,14 @@ async fn chat_message_crud_test(pool: PgPool) {
// create chat messages
for i in 0..5 {
let _ = insert_chat_message(&pool, ChatAuthor::Human, &chat_id, format!("message {}", i))
.await
.unwrap();
let _ = insert_chat_message(
&pool,
ChatAuthor::new(0, ChatAuthorType::System),
&chat_id,
format!("message {}", i),
)
.await
.unwrap();
}
{
let params = GetChatMessageParams::next_back(3);