diff --git a/libs/appflowy-ai-client/src/client.rs b/libs/appflowy-ai-client/src/client.rs index 924b36e8..8cfabcc4 100644 --- a/libs/appflowy-ai-client/src/client.rs +++ b/libs/appflowy-ai-client/src/client.rs @@ -1,6 +1,6 @@ use crate::dto::{ ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, Document, MessageData, - SearchDocumentsRequest, SummarizeRowResponse, TranslateRowResponse, + RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse, TranslateRowResponse, }; use crate::error::AIError; use anyhow::anyhow; @@ -161,6 +161,18 @@ impl AppFlowyAIClient { AIResponse::::stream_response(resp).await } + pub async fn get_related_question( + &self, + chat_id: &str, + message_id: &i64, + ) -> Result { + let url = format!("{}/chat/{chat_id}/{message_id}/related_question", self.url); + let resp = self.http_client(Method::GET, &url)?.send().await?; + AIResponse::::from_response(resp) + .await? + .into_data() + } + fn http_client(&self, method: Method, url: &str) -> Result { let request_builder = self.client.request(method, url); Ok(request_builder) diff --git a/libs/appflowy-ai-client/src/dto.rs b/libs/appflowy-ai-client/src/dto.rs index b7eea222..26b13467 100644 --- a/libs/appflowy-ai-client/src/dto.rs +++ b/libs/appflowy-ai-client/src/dto.rs @@ -27,6 +27,20 @@ pub struct ChatAnswer { pub content: String, } +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RepeatedRelatedQuestion { + pub message_id: i64, + pub items: Vec, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RelatedQuestion { + pub content: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct CompleteTextResponse { pub text: String, diff --git a/libs/appflowy-ai-client/tests/chat_test/qa_test.rs b/libs/appflowy-ai-client/tests/chat_test/qa_test.rs index 056ec847..0befed54 100644 --- a/libs/appflowy-ai-client/tests/chat_test/qa_test.rs +++ b/libs/appflowy-ai-client/tests/chat_test/qa_test.rs @@ -11,9 +11,16 @@ async fn qa_test() { .await .unwrap(); assert!(!resp.content.is_empty()); + + let questions = client + .get_related_question(&chat_id, &1) + .await + .unwrap() + .items; + println!("questions: {:?}", questions); + assert_eq!(questions.len(), 3) } #[tokio::test] - async fn stop_steam_test() { let client = appflowy_ai_client(); client.health_check().await.unwrap(); @@ -35,11 +42,12 @@ async fn stop_steam_test() { assert_ne!(count, 0); } +#[tokio::test] async fn steam_test() { let client = appflowy_ai_client(); client.health_check().await.unwrap(); let chat_id = uuid::Uuid::new_v4().to_string(); - let mut stream = client + let stream = client .stream_question(&chat_id, "I feel hungry") .await .unwrap(); diff --git a/libs/client-api/src/http_chat.rs b/libs/client-api/src/http_chat.rs index 341bf4a3..a371149e 100644 --- a/libs/client-api/src/http_chat.rs +++ b/libs/client-api/src/http_chat.rs @@ -5,6 +5,7 @@ use database_entity::dto::{ }; use futures_core::Stream; use reqwest::Method; +use shared_entity::dto::ai_dto::RepeatedRelatedQuestion; use shared_entity::response::{AppResponse, AppResponseError}; impl Client { @@ -57,6 +58,26 @@ impl Client { log_request_id(&resp); AppResponse::::stream_response(resp).await } + pub async fn get_chat_related_question( + &self, + workspace_id: &str, + chat_id: &str, + message_id: i64, + ) -> Result { + let url = format!( + "{}/api/chat/{workspace_id}/{chat_id}/{message_id}/related_question", + self.base_url + ); + let resp = self + .http_client_with_auth(Method::GET, &url) + .await? + .send() + .await?; + log_request_id(&resp); + AppResponse::::from_response(resp) + .await? + .into_data() + } pub async fn get_chat_messages( &self, diff --git a/libs/database-entity/src/dto.rs b/libs/database-entity/src/dto.rs index 4c9a73e4..f18cb8b7 100644 --- a/libs/database-entity/src/dto.rs +++ b/libs/database-entity/src/dto.rs @@ -572,6 +572,12 @@ pub struct CreateChatMessageParams { pub message_type: ChatMessageType, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UpdateChatMessageParams { + pub message_id: i64, + pub meta_data: HashMap, +} + #[derive(Debug, Clone, Default, Serialize_repr, Deserialize_repr)] #[repr(u8)] pub enum ChatMessageType { diff --git a/libs/database/src/chat/chat_ops.rs b/libs/database/src/chat/chat_ops.rs index db55e2ed..8a911d69 100644 --- a/libs/database/src/chat/chat_ops.rs +++ b/libs/database/src/chat/chat_ops.rs @@ -5,12 +5,12 @@ use app_error::AppError; use chrono::{DateTime, Utc}; use database_entity::dto::{ ChatAuthor, ChatMessage, CreateChatParams, GetChatMessageParams, MessageCursor, - RepeatedChatMessage, UpdateChatParams, + RepeatedChatMessage, UpdateChatMessageParams, UpdateChatParams, }; use serde_json::json; use sqlx::postgres::PgArguments; -use sqlx::{Arguments, Executor, Postgres, Transaction}; +use sqlx::{Arguments, Executor, PgPool, Postgres, Transaction}; use std::ops::DerefMut; use std::str::FromStr; use tracing::warn; @@ -167,7 +167,7 @@ pub async fn select_chat_messages( ) -> Result { let chat_id = Uuid::from_str(chat_id)?; let mut query = r#" - SELECT message_id, content, created_at, author + SELECT message_id, content, created_at, author, meta_data FROM af_chat_messages WHERE chat_id = $1 "# @@ -326,3 +326,30 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>( Ok(messages) } + +pub async fn update_chat_message( + pg_pool: &PgPool, + params: UpdateChatMessageParams, +) -> Result<(), AppError> { + for (key, value) in params.meta_data.iter() { + sqlx::query( + r#" + UPDATE af_chat_messages + SET meta_data = jsonb_set( + COALESCE(meta_data, '{}'), + $2, + $3::jsonb, + true + ) + WHERE id = $1 + "#, + ) + .bind(params.message_id) + .bind(format!("{{{}}}", key)) + .bind(value) + .execute(pg_pool) + .await?; + } + + Ok(()) +} diff --git a/services/appflowy-indexer/src/collab_handle.rs b/services/appflowy-indexer/src/collab_handle.rs index babaa15a..fd8a7346 100644 --- a/services/appflowy-indexer/src/collab_handle.rs +++ b/services/appflowy-indexer/src/collab_handle.rs @@ -77,14 +77,14 @@ impl CollabHandle { Self::handle_collab_updates(&mut update_stream, content.get_collab(), messages).await?; } let workspace_id = - Uuid::parse_str(&workspace_id).map_err(|e| crate::error::Error::InvalidWorkspace(e))?; + Uuid::parse_str(&workspace_id).map_err(crate::error::Error::InvalidWorkspace)?; let mut tasks = JoinSet::new(); tasks.spawn(Self::receive_collab_updates( update_stream, Arc::downgrade(&content), object_id.clone(), - workspace_id.clone(), + workspace_id, ingest_interval, closing.clone(), )); @@ -253,8 +253,8 @@ impl CollabHandle { } pub async fn shutdown(mut self) { - let _ = self.closing.cancel(); - while let Some(_) = self.tasks.join_next().await { /* wait for all tasks to finish */ } + self.closing.cancel(); + while self.tasks.join_next().await.is_some() { /* wait for all tasks to finish */ } } } @@ -353,7 +353,7 @@ mod test { let tokens: i64 = sqlx::query("SELECT index_token_usage from af_workspace WHERE workspace_id = $1") - .bind(&workspace_id) + .bind(workspace_id) .fetch_one(&db) .await .unwrap() diff --git a/services/appflowy-indexer/src/indexer.rs b/services/appflowy-indexer/src/indexer.rs index 9c843c8c..291cb0c5 100644 --- a/services/appflowy-indexer/src/indexer.rs +++ b/services/appflowy-indexer/src/indexer.rs @@ -72,7 +72,7 @@ impl From for EmbedFragment { impl From for AFCollabEmbeddingParams { fn from(f: EmbedFragment) -> Self { AFCollabEmbeddingParams { - fragment_id: f.fragment_id.into(), + fragment_id: f.fragment_id, object_id: f.object_id, collab_type: f.collab_type, content_type: f.content_type, @@ -91,7 +91,7 @@ impl PostgresIndexer { #[allow(dead_code)] pub async fn open(openai_api_key: &str, pg_conn: &str) -> Result { let openai = Client::new(openai_api_key.to_string()); - let db = PgPool::connect(&pg_conn).await?; + let db = PgPool::connect(pg_conn).await?; Ok(Self { openai, db }) } @@ -226,7 +226,7 @@ mod test { // resolve embeddings from OpenAI let embeddings = indexer.get_embeddings(fragments).await.unwrap(); - assert_eq!(embeddings.fragments[0].embedding.is_some(), true); + assert!(embeddings.fragments[0].embedding.is_some()); // store embeddings in DB indexer diff --git a/services/appflowy-indexer/src/test_utils.rs b/services/appflowy-indexer/src/test_utils.rs index ddd43ae5..298e0717 100644 --- a/services/appflowy-indexer/src/test_utils.rs +++ b/services/appflowy-indexer/src/test_utils.rs @@ -31,7 +31,7 @@ pub async fn setup_collab(db: &PgPool, uid: i64, object_id: Uuid, encoded_collab let mut tx = db.begin().await.unwrap(); let user_uuid = Uuid::new_v4(); sqlx::query("INSERT INTO auth.users(id) VALUES($1)") - .bind(&user_uuid) + .bind(user_uuid) .execute(tx.deref_mut()) .await .unwrap(); @@ -48,7 +48,7 @@ pub async fn setup_collab(db: &PgPool, uid: i64, object_id: Uuid, encoded_collab &mut tx, &uid, &workspace_id.to_string(), - &CollabParams::new(object_id.clone(), CollabType::Document, encoded_collab), + &CollabParams::new(object_id, CollabType::Document, encoded_collab), ) .await .unwrap(); diff --git a/services/appflowy-indexer/src/watchers/document_watcher.rs b/services/appflowy-indexer/src/watchers/document_watcher.rs index 713cae6d..03de86d2 100644 --- a/services/appflowy-indexer/src/watchers/document_watcher.rs +++ b/services/appflowy-indexer/src/watchers/document_watcher.rs @@ -120,7 +120,7 @@ impl DocumentWatcher { if block.external_type.as_deref() == Some("text") { if let Some(text_id) = block.external_id.as_deref() { if let Some(json) = text_map.get(text_id) { - match serde_json::from_str::>(&json) { + match serde_json::from_str::>(json) { Ok(deltas) => { for delta in deltas { if let TextDelta::Inserted(text, _) = delta { @@ -151,7 +151,7 @@ impl DocumentWatcher { impl Indexable for DocumentWatcher { fn get_collab(&self) -> &MutexCollab { - &*self.content.get_collab() + self.content.get_collab() } fn changes(&self) -> Pin + Send + Sync>> { @@ -253,9 +253,7 @@ mod test { let blocks = &data.blocks; let (_, block) = blocks .iter() - //TODO: Block::external_type should probably be an enum - .filter(|(_, b)| b.external_type.as_deref() == Some("text")) - .next() + .find(|(_, b)| b.external_type.as_deref() == Some("text")) .unwrap(); block.clone() } diff --git a/src/api/chat.rs b/src/api/chat.rs index e9ae8700..f61fe8a3 100644 --- a/src/api/chat.rs +++ b/src/api/chat.rs @@ -3,14 +3,17 @@ use crate::state::AppState; use actix_web::web::{Data, Json}; use actix_web::{web, HttpResponse, Scope}; use app_error::AppError; +use appflowy_ai_client::dto::RepeatedRelatedQuestion; use authentication::jwt::UserUuid; +use database::chat::chat_ops::update_chat_message; use database_entity::dto::{ CreateChatMessageParams, CreateChatParams, GetChatMessageParams, MessageCursor, - RepeatedChatMessage, + RepeatedChatMessage, UpdateChatMessageParams, }; use shared_entity::response::{AppResponse, JsonAppResponse}; use std::collections::HashMap; -use tracing::trace; + +use tracing::{instrument, trace}; use validator::Validate; pub fn chat_scope() -> Scope { @@ -22,7 +25,15 @@ pub fn chat_scope() -> Scope { .route(web::post().to(update_chat_handler)) .route(web::get().to(get_chat_message_handler)), ) - .service(web::resource("/{chat_id}/message").route(web::post().to(post_chat_message_handler))) + .service( + web::resource("/{chat_id}/{message_id}/related_question") + .route(web::get().to(get_related_message_handler)), + ) + .service( + web::resource("/{chat_id}/message") + .route(web::post().to(post_chat_message_handler)) + .route(web::put().to(update_chat_message_handler)), + ) } async fn create_chat_handler( path: web::Path, @@ -83,6 +94,29 @@ async fn post_chat_message_handler( ) } +async fn update_chat_message_handler( + state: Data, + payload: Json, +) -> actix_web::Result> { + let params = payload.into_inner(); + update_chat_message(&state.pg_pool, params).await?; + Ok(AppResponse::Ok().into()) +} + +async fn get_related_message_handler( + path: web::Path<(String, String, i64)>, + state: Data, +) -> actix_web::Result> { + let (_workspace_id, chat_id, message_id) = path.into_inner(); + let resp = state + .ai_client + .get_related_question(&chat_id, &message_id) + .await + .map_err(|err| AppError::Internal(err.into()))?; + Ok(AppResponse::Ok().with_data(resp).into()) +} + +#[instrument(level = "debug", skip_all, err)] async fn get_chat_message_handler( path: web::Path<(String, String)>, query: web::Query>, diff --git a/tests/ai_test/chat_test.rs b/tests/ai_test/chat_test.rs index a6f78ae9..32c9c113 100644 --- a/tests/ai_test/chat_test.rs +++ b/tests/ai_test/chat_test.rs @@ -117,4 +117,12 @@ async fn chat_qa_test() { let messages: Vec = stream.map(|message| message.unwrap()).collect().await; assert_eq!(messages.len(), 2); + + let related_questions = test_client + .api_client + .get_chat_related_question(&workspace_id, &chat_id, messages[1].message_id) + .await + .unwrap(); + assert_eq!(related_questions.items.len(), 3); + println!("related questions: {:?}", related_questions.items); }