diff --git a/.sqlx/query-533ef0f5237ca12ce0c6ca1dc938cc8dd34603b256f36dc683013264018332fc.json b/.sqlx/query-69b0bbc97c37de47ad3f9cfe4f0099496470100d49903e34c8f773c43f094b2f.json similarity index 68% rename from .sqlx/query-533ef0f5237ca12ce0c6ca1dc938cc8dd34603b256f36dc683013264018332fc.json rename to .sqlx/query-69b0bbc97c37de47ad3f9cfe4f0099496470100d49903e34c8f773c43f094b2f.json index 53607eba..74d158de 100644 --- a/.sqlx/query-533ef0f5237ca12ce0c6ca1dc938cc8dd34603b256f36dc683013264018332fc.json +++ b/.sqlx/query-69b0bbc97c37de47ad3f9cfe4f0099496470100d49903e34c8f773c43f094b2f.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT message_id, content, created_at, author\n FROM af_chat_messages\n WHERE chat_id = $1\n ORDER BY created_at ASC\n ", + "query": "\n SELECT message_id, content, created_at, author, meta_data\n FROM af_chat_messages\n WHERE chat_id = $1\n ORDER BY created_at ASC\n ", "describe": { "columns": [ { @@ -22,6 +22,11 @@ "ordinal": 3, "name": "author", "type_info": "Jsonb" + }, + { + "ordinal": 4, + "name": "meta_data", + "type_info": "Jsonb" } ], "parameters": { @@ -33,8 +38,9 @@ false, false, false, + false, false ] }, - "hash": "533ef0f5237ca12ce0c6ca1dc938cc8dd34603b256f36dc683013264018332fc" + "hash": "69b0bbc97c37de47ad3f9cfe4f0099496470100d49903e34c8f773c43f094b2f" } diff --git a/.sqlx/query-fb21df2827de97055cdc1c493b079b29667f75b18169c909c4c8341697fd0105.json b/.sqlx/query-fb21df2827de97055cdc1c493b079b29667f75b18169c909c4c8341697fd0105.json index 60efa542..74407f8c 100644 --- a/.sqlx/query-fb21df2827de97055cdc1c493b079b29667f75b18169c909c4c8341697fd0105.json +++ b/.sqlx/query-fb21df2827de97055cdc1c493b079b29667f75b18169c909c4c8341697fd0105.json @@ -32,6 +32,11 @@ "ordinal": 5, "name": "workspace_id", "type_info": "Uuid" + }, + { + "ordinal": 6, + "name": "meta_data", + "type_info": "Jsonb" } ], "parameters": { @@ -45,6 +50,7 @@ true, false, false, + false, false ] }, diff --git a/Cargo.lock b/Cargo.lock index af1a6fd9..beb14515 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -568,6 +568,8 @@ name = "appflowy-ai-client" version = "0.1.0" dependencies = [ "anyhow", + "bytes", + "futures", "reqwest 0.12.4", "serde", "serde_json", diff --git a/libs/appflowy-ai-client/Cargo.toml b/libs/appflowy-ai-client/Cargo.toml index 992fd39f..62c621b7 100644 --- a/libs/appflowy-ai-client/Cargo.toml +++ b/libs/appflowy-ai-client/Cargo.toml @@ -6,13 +6,15 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -reqwest = { version = "0.12", features = ["json", "rustls-tls", "cookies"], optional = true } +reqwest = { version = "0.12", features = ["json", "rustls-tls", "cookies", "stream"], optional = true } serde = { version = "1.0.199", features = ["derive"], optional = true } serde_json = { version = "1.0", optional = true } thiserror = "1.0.58" anyhow = "1.0.81" tracing = { version = "0.1", optional = true } serde_repr = { version = "0.1", optional = true } +futures = "0.3.30" +bytes = "1.6.0" [dev-dependencies] tokio = { version = "1.37.0", features = ["macros", "test-util"] } diff --git a/libs/appflowy-ai-client/src/client.rs b/libs/appflowy-ai-client/src/client.rs index 084a2ad0..924b36e8 100644 --- a/libs/appflowy-ai-client/src/client.rs +++ b/libs/appflowy-ai-client/src/client.rs @@ -3,10 +3,14 @@ use crate::dto::{ SearchDocumentsRequest, SummarizeRowResponse, TranslateRowResponse, }; use crate::error::AIError; +use anyhow::anyhow; +use futures::{Stream, StreamExt}; +use reqwest; use reqwest::{Method, RequestBuilder, StatusCode}; use serde::{Deserialize, Serialize}; use serde_json::{json, Map, Value}; use std::borrow::Cow; + use tracing::{info, trace}; #[derive(Clone, Debug)] @@ -137,6 +141,26 @@ impl AppFlowyAIClient { .into_data() } + pub async fn stream_question( + &self, + chat_id: &str, + content: &str, + ) -> Result>, AIError> { + let json = ChatQuestion { + chat_id: chat_id.to_string(), + data: MessageData { + content: content.to_string(), + }, + }; + let url = format!("{}/chat/stream_message", self.url); + let resp = self + .http_client(Method::POST, &url)? + .json(&json) + .send() + .await?; + AIResponse::::stream_response(resp).await + } + fn http_client(&self, method: Method, url: &str) -> Result { let request_builder = self.client.request(method, url); Ok(request_builder) @@ -174,8 +198,27 @@ where Some(data) => Ok(data), } } -} + pub async fn stream_response( + resp: reqwest::Response, + ) -> Result>, AIError> { + let status_code = resp.status(); + if !status_code.is_success() { + let body = resp.text().await?; + return Err(AIError::InvalidRequest(body)); + } + let stream = resp.bytes_stream().map(|item| { + item + .map_err(|err| AIError::Internal(err.into())) + .and_then(|bytes| { + String::from_utf8(bytes.to_vec()) + .map(|s| s.replace('\n', "")) + .map_err(|err| AIError::Internal(anyhow!("Parser AI response error: {:?}", err))) + }) + }); + Ok(stream) + } +} impl From for AIError { fn from(error: reqwest::Error) -> Self { if error.is_timeout() { diff --git a/libs/appflowy-ai-client/tests/chat_test/qa_test.rs b/libs/appflowy-ai-client/tests/chat_test/qa_test.rs index d53a0163..056ec847 100644 --- a/libs/appflowy-ai-client/tests/chat_test/qa_test.rs +++ b/libs/appflowy-ai-client/tests/chat_test/qa_test.rs @@ -1,4 +1,5 @@ use crate::appflowy_ai_client; +use futures::stream::StreamExt; #[tokio::test] async fn qa_test() { @@ -11,3 +12,38 @@ async fn qa_test() { .unwrap(); assert!(!resp.content.is_empty()); } +#[tokio::test] + +async fn stop_steam_test() { + let client = appflowy_ai_client(); + client.health_check().await.unwrap(); + let chat_id = uuid::Uuid::new_v4().to_string(); + let mut stream = client + .stream_question(&chat_id, "I feel hungry") + .await + .unwrap(); + + let mut count = 0; + while let Some(message) = stream.next().await { + if count > 1 { + break; + } + count += 1; + println!("message: {:?}", message); + } + + assert_ne!(count, 0); +} + +async fn steam_test() { + let client = appflowy_ai_client(); + client.health_check().await.unwrap(); + let chat_id = uuid::Uuid::new_v4().to_string(); + let mut stream = client + .stream_question(&chat_id, "I feel hungry") + .await + .unwrap(); + + let messages: Vec = stream.map(|message| message.unwrap()).collect().await; + println!("final answer: {}", messages.join("")); +} diff --git a/libs/database-entity/src/dto.rs b/libs/database-entity/src/dto.rs index 912625a5..4c9a73e4 100644 --- a/libs/database-entity/src/dto.rs +++ b/libs/database-entity/src/dto.rs @@ -644,6 +644,7 @@ pub struct ChatMessage { pub message_id: i64, pub content: String, pub created_at: DateTime, + pub meta_data: serde_json::Value, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/libs/database/src/chat/chat_ops.rs b/libs/database/src/chat/chat_ops.rs index 5f96e66a..db55e2ed 100644 --- a/libs/database/src/chat/chat_ops.rs +++ b/libs/database/src/chat/chat_ops.rs @@ -155,6 +155,7 @@ pub async fn insert_chat_message<'a, E: Executor<'a, Database = Postgres>>( message_id: row.message_id, content, created_at: row.created_at, + meta_data: Default::default(), }; Ok(chat_message) } @@ -203,20 +204,26 @@ pub async fn select_chat_messages( }, } - let rows: Vec<(i64, String, DateTime, serde_json::Value)> = - sqlx::query_as_with(&query, args) - .fetch_all(txn.deref_mut()) - .await?; + let rows: Vec<( + i64, + String, + DateTime, + serde_json::Value, + serde_json::Value, + )> = sqlx::query_as_with(&query, args) + .fetch_all(txn.deref_mut()) + .await?; let messages = rows .into_iter() - .flat_map(|(message_id, content, created_at, author)| { + .flat_map(|(message_id, content, created_at, author, meta_data)| { match serde_json::from_value::(author) { Ok(author) => Some(ChatMessage { author, message_id, content, created_at, + meta_data, }), Err(err) => { warn!("Failed to deserialize author: {}", err); @@ -288,7 +295,7 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>( let rows = sqlx::query!( // ChatMessage, r#" - SELECT message_id, content, created_at, author + SELECT message_id, content, created_at, author, meta_data FROM af_chat_messages WHERE chat_id = $1 ORDER BY created_at ASC @@ -307,6 +314,7 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>( message_id: row.message_id, content: row.content, created_at: row.created_at, + meta_data: row.meta_data, }), Err(err) => { warn!("Failed to deserialize author: {}", err); diff --git a/libs/database/src/pg_row.rs b/libs/database/src/pg_row.rs index 2cf33017..9c3d0986 100644 --- a/libs/database/src/pg_row.rs +++ b/libs/database/src/pg_row.rs @@ -204,6 +204,7 @@ pub struct AFChatRow { pub deleted_at: Option>, pub rag_ids: serde_json::Value, pub workspace_id: Uuid, + pub meta_data: serde_json::Value, } #[derive(Debug, Clone, FromRow, Serialize, Deserialize)] pub struct AFChatMessageRow { diff --git a/migrations/20240531031836_chat_message_meta.sql b/migrations/20240531031836_chat_message_meta.sql new file mode 100644 index 00000000..b31467ab --- /dev/null +++ b/migrations/20240531031836_chat_message_meta.sql @@ -0,0 +1,6 @@ +-- Add migration script here +ALTER TABLE af_chat + ADD COLUMN meta_data JSONB DEFAULT '{}' NOT NULL; + +ALTER TABLE af_chat_messages + ADD COLUMN meta_data JSONB DEFAULT '{}' NOT NULL; \ No newline at end of file