From 559d924cd182d8db16df69c030c96e63153988f7 Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Sun, 26 May 2024 22:44:08 +0800 Subject: [PATCH] feat: stream chat message (#577) * chore: save author id * chore: stream response * chore: stream chat message --- Cargo.lock | 1 + libs/client-api/src/http_chat.rs | 9 +-- libs/database-entity/src/dto.rs | 32 ++++++-- libs/database/src/chat/chat_ops.rs | 42 ++++++++-- libs/shared-entity/Cargo.toml | 3 +- libs/shared-entity/src/response.rs | 18 +++++ src/api/chat.rs | 30 ++++++-- src/biz/chat/ops.rs | 118 +++++++++++++++++------------ tests/ai_test/chat_test.rs | 27 ++++--- tests/sql_test/chat_test.rs | 13 +++- 10 files changed, 206 insertions(+), 87 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5dcf9110..d68f7170 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5487,6 +5487,7 @@ dependencies = [ "chrono", "collab-entity", "database-entity", + "futures", "gotrue-entity", "reqwest 0.11.27", "rust-s3", diff --git a/libs/client-api/src/http_chat.rs b/libs/client-api/src/http_chat.rs index 473d0c61..341bf4a3 100644 --- a/libs/client-api/src/http_chat.rs +++ b/libs/client-api/src/http_chat.rs @@ -1,8 +1,9 @@ use crate::http::log_request_id; use crate::Client; use database_entity::dto::{ - CreateChatMessageParams, CreateChatParams, MessageCursor, QAChatMessage, RepeatedChatMessage, + ChatMessage, CreateChatMessageParams, CreateChatParams, MessageCursor, RepeatedChatMessage, }; +use futures_core::Stream; use reqwest::Method; use shared_entity::response::{AppResponse, AppResponseError}; @@ -42,7 +43,7 @@ impl Client { workspace_id: &str, chat_id: &str, params: CreateChatMessageParams, - ) -> Result { + ) -> Result>, AppResponseError> { let url = format!( "{}/api/chat/{workspace_id}/{chat_id}/message", self.base_url @@ -54,9 +55,7 @@ impl Client { .send() .await?; log_request_id(&resp); - AppResponse::::from_response(resp) - .await? - .into_data() + AppResponse::::stream_response(resp).await } pub async fn get_chat_messages( diff --git a/libs/database-entity/src/dto.rs b/libs/database-entity/src/dto.rs index 276b9ca0..78df41ef 100644 --- a/libs/database-entity/src/dto.rs +++ b/libs/database-entity/src/dto.rs @@ -653,16 +653,38 @@ pub struct RepeatedChatMessage { #[derive(Debug, Default, Clone, Serialize_repr, Deserialize_repr)] #[repr(u8)] -pub enum ChatAuthor { - #[default] +pub enum ChatAuthorType { Unknown = 0, Human = 1, + #[default] System = 2, AI = 3, } -impl From for ChatAuthor { - fn from(value: serde_json::Value) -> Self { - serde_json::from_value::(value).unwrap_or_default() +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatAuthor { + pub author_id: i64, + #[serde(default)] + pub author_type: ChatAuthorType, + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub meta: Option, +} + +impl ChatAuthor { + pub fn new(author_id: i64, author_type: ChatAuthorType) -> Self { + Self { + author_id, + author_type, + meta: None, + } + } + + pub fn ai() -> Self { + Self { + author_id: 0, + author_type: ChatAuthorType::AI, + meta: None, + } } } diff --git a/libs/database/src/chat/chat_ops.rs b/libs/database/src/chat/chat_ops.rs index 3eb8b76e..f8ed14ab 100644 --- a/libs/database/src/chat/chat_ops.rs +++ b/libs/database/src/chat/chat_ops.rs @@ -13,6 +13,7 @@ use sqlx::postgres::PgArguments; use sqlx::{Arguments, Executor, Postgres, Transaction}; use std::ops::DerefMut; use std::str::FromStr; +use tracing::warn; use uuid::Uuid; @@ -209,11 +210,19 @@ pub async fn select_chat_messages( let mut messages = rows .into_iter() - .map(|(message_id, content, created_at, author)| ChatMessage { - author: serde_json::from_value::(author).unwrap_or_default(), - message_id, - content, - created_at, + .flat_map(|(message_id, content, created_at, author)| { + match serde_json::from_value::(author) { + Ok(author) => Some(ChatMessage { + author, + message_id, + content, + created_at, + }), + Err(err) => { + warn!("Failed to deserialize author: {}", err); + None + }, + } }) .collect::>(); @@ -280,8 +289,8 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>( chat_id: &str, ) -> Result, AppError> { let chat_id = Uuid::from_str(chat_id)?; - let messages: Vec = sqlx::query_as!( - ChatMessage, + let rows = sqlx::query!( + // ChatMessage, r#" SELECT message_id, content, created_at, author FROM af_chat_messages @@ -292,5 +301,24 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>( ) .fetch_all(executor) .await?; + + let messages = rows + .into_iter() + .flat_map( + |row| match serde_json::from_value::(row.author) { + Ok(author) => Some(ChatMessage { + author, + message_id: row.message_id, + content: row.content, + created_at: row.created_at, + }), + Err(err) => { + warn!("Failed to deserialize author: {}", err); + None + }, + }, + ) + .collect::>(); + Ok(messages) } diff --git a/libs/shared-entity/Cargo.toml b/libs/shared-entity/Cargo.toml index 8ccc96fb..e83fe663 100644 --- a/libs/shared-entity/Cargo.toml +++ b/libs/shared-entity/Cargo.toml @@ -13,7 +13,7 @@ serde = "1.0.195" serde_json.workspace = true serde_repr = "0.1.18" thiserror = "1.0.56" -reqwest.workspace = true +reqwest = { worspace = true, features = ["stream"] } uuid = { version = "1.6.1", features = ["v4"] } gotrue-entity = { path = "../gotrue-entity" } database-entity.workspace = true @@ -25,6 +25,7 @@ appflowy-ai-client = { workspace = true, default-features = false, features = [" actix-web = { version = "4.4.1", default-features = false, features = ["http2"], optional = true } validator = { version = "0.16", features = ["validator_derive", "derive"], optional = true } rust-s3 = { version = "0.34.0-rc4", optional = true } +futures = "0.3.30" [features] diff --git a/libs/shared-entity/src/response.rs b/libs/shared-entity/src/response.rs index cbc10c6b..9c10ac3d 100644 --- a/libs/shared-entity/src/response.rs +++ b/libs/shared-entity/src/response.rs @@ -3,6 +3,7 @@ use std::borrow::Cow; use app_error::AppError; pub use app_error::ErrorCode; +use futures::stream::{Stream, StreamExt}; use std::fmt::{Debug, Display}; #[cfg(feature = "cloud")] @@ -141,6 +142,23 @@ where let resp = serde_json::from_slice(&bytes)?; Ok(resp) } + + pub async fn stream_response( + resp: reqwest::Response, + ) -> Result>, AppResponseError> { + let status_code = resp.status(); + if !status_code.is_success() { + let body = resp.text().await?; + return Err(AppResponseError::new(ErrorCode::Internal, body)); + } + + let stream = resp.bytes_stream().map(|item| { + item.map_err(AppResponseError::from).and_then(|bytes| { + serde_json::from_slice::(bytes.as_ref()).map_err(AppResponseError::from) + }) + }); + Ok(stream) + } } #[derive(Clone, Debug, Serialize, Deserialize, thiserror::Error)] pub struct AppResponseError { diff --git a/src/api/chat.rs b/src/api/chat.rs index d1701e89..34103d9b 100644 --- a/src/api/chat.rs +++ b/src/api/chat.rs @@ -2,14 +2,16 @@ use crate::biz::chat::ops::{create_chat, create_chat_message, delete_chat, get_c use crate::biz::user::auth::jwt::UserUuid; use crate::state::AppState; use actix_web::web::{Data, Json}; -use actix_web::{web, Scope}; +use actix_web::{web, HttpResponse, Scope}; +use app_error::AppError; use database_entity::dto::{ - CreateChatMessageParams, CreateChatParams, GetChatMessageParams, MessageCursor, QAChatMessage, + CreateChatMessageParams, CreateChatParams, GetChatMessageParams, MessageCursor, RepeatedChatMessage, }; use shared_entity::response::{AppResponse, JsonAppResponse}; use std::collections::HashMap; use tracing::trace; +use validator::Validate; pub fn chat_scope() -> Scope { web::scope("/api/chat/{workspace_id}") @@ -57,14 +59,28 @@ async fn post_chat_message_handler( path: web::Path<(String, String)>, payload: Json, uuid: UserUuid, -) -> actix_web::Result> { +) -> actix_web::Result { let (_workspace_id, chat_id) = path.into_inner(); let params = payload.into_inner(); + + if let Err(err) = params.validate() { + return Ok(HttpResponse::from_error(AppError::from(err))); + } + let uid = state.user_cache.get_user_uid(&uuid).await?; - trace!("insert chat message into {}", chat_id); - let message = - create_chat_message(&state.pg_pool, uid, params, &chat_id, &state.ai_client).await?; - Ok(AppResponse::Ok().with_data(message).into()) + let message_stream = create_chat_message( + &state.pg_pool, + uid, + chat_id, + params, + state.ai_client.clone(), + ) + .await; + Ok( + HttpResponse::Ok() + .content_type("application/json") + .streaming(message_stream), + ) } async fn get_chat_message_handler( diff --git a/src/biz/chat/ops.rs b/src/biz/chat/ops.rs index 1947d73b..033c017a 100644 --- a/src/biz/chat/ops.rs +++ b/src/biz/chat/ops.rs @@ -1,15 +1,17 @@ -use anyhow::anyhow; +use actix_web::web::Bytes; + use app_error::AppError; use appflowy_ai_client::client::AppFlowyAIClient; +use async_stream::stream; use database::chat; use database::chat::chat_ops::{insert_chat, insert_chat_message, select_chat_messages}; use database_entity::dto::{ - ChatAuthor, ChatMessageType, CreateChatMessageParams, CreateChatParams, GetChatMessageParams, - QAChatMessage, RepeatedChatMessage, + ChatAuthor, ChatAuthorType, ChatMessageType, CreateChatMessageParams, CreateChatParams, + GetChatMessageParams, RepeatedChatMessage, }; +use futures::stream::Stream; use sqlx::PgPool; -use std::ops::DerefMut; -use tracing::trace; + use validator::Validate; pub(crate) async fn create_chat( @@ -34,53 +36,73 @@ pub(crate) async fn delete_chat(pg_pool: &PgPool, chat_id: &str) -> Result<(), A pub async fn create_chat_message( pg_pool: &PgPool, - _uid: i64, + uid: i64, + chat_id: String, params: CreateChatMessageParams, - chat_id: &str, - ai_client: &AppFlowyAIClient, -) -> Result { - params.validate()?; + ai_client: AppFlowyAIClient, +) -> impl Stream> { + let params = params.clone(); + let chat_id = chat_id.clone(); + let pg_pool = pg_pool.clone(); + let stream = stream! { + // Insert question message + let question = match insert_chat_message( + &pg_pool, + ChatAuthor::new(uid, ChatAuthorType::Human), + &chat_id, + params.content.clone() + ).await { + Ok(question) => question, + Err(err) => { + yield Err(err); + return; + } + }; - let answer_content = match params.message_type { - ChatMessageType::System => "".to_string(), - ChatMessageType::User => { - let start = std::time::Instant::now(); - trace!("[Chat] sending question to AI: {}", params.content); - let content = ai_client - .send_question(chat_id, ¶ms.content) - .await - .map(|answer| answer.content)?; - trace!( - "[Chat] received answer from AI: {}, cost:{} millis", - content, - start.elapsed().as_millis() - ); - content - }, + let question_bytes = match serde_json::to_vec(&question) { + Ok(bytes) => bytes, + Err(err) => { + yield Err(AppError::from(err)); + return; + } + }; + + yield Ok::(Bytes::from(question_bytes)); + + // Insert answer message + match params.message_type { + ChatMessageType::System => {} + ChatMessageType::User => { + let content = match ai_client.send_question(&chat_id, ¶ms.content).await { + Ok(response) => response.content, + Err(err) => { + yield Err(AppError::from(err)); + return; + } + }; + + let answer = match insert_chat_message(&pg_pool, ChatAuthor::ai(), &chat_id, content.clone()).await { + Ok(answer) => answer, + Err(err) => { + yield Err(err); + return; + } + }; + + let answer_bytes = match serde_json::to_vec(&answer) { + Ok(bytes) => bytes, + Err(err) => { + yield Err(AppError::from(err)); + return; + } + }; + + yield Ok::(Bytes::from(answer_bytes)); + } + } }; - let mut txn = pg_pool.begin().await.map_err(|err| { - AppError::Internal(anyhow!( - "failed to start transaction for inserting chat message: {}", - err - )) - })?; - let question = - insert_chat_message(txn.deref_mut(), ChatAuthor::Human, chat_id, params.content).await?; - - let answer = match params.message_type { - ChatMessageType::System => None, - ChatMessageType::User => { - Some(insert_chat_message(txn.deref_mut(), ChatAuthor::AI, chat_id, answer_content).await?) - }, - }; - - txn - .commit() - .await - .map_err(|err| AppError::Internal(anyhow!("failed to insert chat message: {}", err)))?; - - Ok(QAChatMessage { question, answer }) + stream } pub async fn get_chat_messages( diff --git a/tests/ai_test/chat_test.rs b/tests/ai_test/chat_test.rs index f806b862..23b67ed1 100644 --- a/tests/ai_test/chat_test.rs +++ b/tests/ai_test/chat_test.rs @@ -1,5 +1,6 @@ use client_api_test::TestClient; -use database_entity::dto::{CreateChatMessageParams, CreateChatParams, MessageCursor}; +use database_entity::dto::{ChatMessage, CreateChatMessageParams, CreateChatParams, MessageCursor}; +use futures_util::StreamExt; #[tokio::test] async fn create_chat_and_create_messages_test() { @@ -26,6 +27,10 @@ async fn create_chat_and_create_messages_test() { .api_client .create_chat_message(&workspace_id, &chat_id, params) .await + .unwrap() + .next() + .await + .unwrap() .unwrap(); messages.push(message); } @@ -37,7 +42,7 @@ async fn create_chat_and_create_messages_test() { .get_chat_messages( &workspace_id, &chat_id, - MessageCursor::BeforeMessageId(messages[2].question.message_id), + MessageCursor::BeforeMessageId(messages[2].message_id), 10, ) .await @@ -47,11 +52,11 @@ async fn create_chat_and_create_messages_test() { assert_eq!(message_before_third.messages.len(), 2); assert_eq!( message_before_third.messages[0].message_id, - messages[0].question.message_id + messages[0].message_id ); assert_eq!( message_before_third.messages[1].message_id, - messages[1].question.message_id + messages[1].message_id ); // get message after third message @@ -60,7 +65,7 @@ async fn create_chat_and_create_messages_test() { .get_chat_messages( &workspace_id, &chat_id, - MessageCursor::AfterMessageId(messages[2].question.message_id), + MessageCursor::AfterMessageId(messages[2].message_id), 2, ) .await @@ -69,11 +74,11 @@ async fn create_chat_and_create_messages_test() { assert_eq!(message_after_third.messages.len(), 2); assert_eq!( message_after_third.messages[0].message_id, - messages[3].question.message_id + messages[3].message_id ); assert_eq!( message_after_third.messages[1].message_id, - messages[4].question.message_id + messages[4].message_id ); // get all messages after 8th message @@ -82,7 +87,7 @@ async fn create_chat_and_create_messages_test() { .get_chat_messages( &workspace_id, &chat_id, - MessageCursor::AfterMessageId(messages[7].question.message_id), + MessageCursor::AfterMessageId(messages[7].message_id), 100, ) .await @@ -111,10 +116,12 @@ async fn chat_qa_test() { .unwrap(); let params = CreateChatMessageParams::new_user("where is singapore?"); - let message = test_client + let stream = test_client .api_client .create_chat_message(&workspace_id, &chat_id, params) .await .unwrap(); - assert!(!message.answer.unwrap().content.is_empty()); + + let messages: Vec = stream.map(|message| message.unwrap()).collect().await; + assert_eq!(messages.len(), 2); } diff --git a/tests/sql_test/chat_test.rs b/tests/sql_test/chat_test.rs index 07288194..e593f8c3 100644 --- a/tests/sql_test/chat_test.rs +++ b/tests/sql_test/chat_test.rs @@ -3,7 +3,7 @@ use database::chat::chat_ops::{ delete_chat, get_all_chat_messages, insert_chat, insert_chat_message, select_chat, select_chat_messages, }; -use database_entity::dto::{ChatAuthor, CreateChatParams, GetChatMessageParams}; +use database_entity::dto::{ChatAuthor, ChatAuthorType, CreateChatParams, GetChatMessageParams}; use serde_json::json; use sqlx::PgPool; @@ -91,9 +91,14 @@ async fn chat_message_crud_test(pool: PgPool) { // create chat messages for i in 0..5 { - let _ = insert_chat_message(&pool, ChatAuthor::Human, &chat_id, format!("message {}", i)) - .await - .unwrap(); + let _ = insert_chat_message( + &pool, + ChatAuthor::new(0, ChatAuthorType::System), + &chat_id, + format!("message {}", i), + ) + .await + .unwrap(); } { let params = GetChatMessageParams::next_back(3);