From ecadf8e287204c9f05d9dfe63b88b47f8d446068 Mon Sep 17 00:00:00 2001 From: Richard Shiue <71320345+richardshiue@users.noreply.github.com> Date: Thu, 19 Dec 2024 00:12:53 +0800 Subject: [PATCH] chore: find question message from reply message (#1085) * chore: find question message from answer message id * chore: sqlx * test: fix tests * test: fix test * chore: apply code suggestions to 2 files --- ...8c14137dd09b11be73442a7f46b2f938b8445.json | 53 +++++++++++++ libs/client-api/src/http_chat.rs | 22 ++++++ libs/database/src/chat/chat_ops.rs | 37 +++++++++ src/api/chat.rs | 23 +++++- src/biz/chat/ops.rs | 14 +++- tests/ai_test/chat_test.rs | 76 +++++++++++++++++-- 6 files changed, 216 insertions(+), 9 deletions(-) create mode 100644 .sqlx/query-794c4ced16801b3e98a62eb44c18c14137dd09b11be73442a7f46b2f938b8445.json diff --git a/.sqlx/query-794c4ced16801b3e98a62eb44c18c14137dd09b11be73442a7f46b2f938b8445.json b/.sqlx/query-794c4ced16801b3e98a62eb44c18c14137dd09b11be73442a7f46b2f938b8445.json new file mode 100644 index 00000000..1148f8cd --- /dev/null +++ b/.sqlx/query-794c4ced16801b3e98a62eb44c18c14137dd09b11be73442a7f46b2f938b8445.json @@ -0,0 +1,53 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT message_id, content, created_at, author, meta_data, reply_message_id\n FROM af_chat_messages\n WHERE chat_id = $1\n AND reply_message_id = $2\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "message_id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "content", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 3, + "name": "author", + "type_info": "Jsonb" + }, + { + "ordinal": 4, + "name": "meta_data", + "type_info": "Jsonb" + }, + { + "ordinal": 5, + "name": "reply_message_id", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Int8" + ] + }, + "nullable": [ + false, + false, + false, + false, + false, + true + ] + }, + "hash": "794c4ced16801b3e98a62eb44c18c14137dd09b11be73442a7f46b2f938b8445" +} diff --git a/libs/client-api/src/http_chat.rs b/libs/client-api/src/http_chat.rs index 36f015a8..dc540121 100644 --- a/libs/client-api/src/http_chat.rs +++ b/libs/client-api/src/http_chat.rs @@ -262,6 +262,28 @@ impl Client { .into_data() } + pub async fn get_question_message_from_answer_id( + &self, + workspace_id: &str, + chat_id: &str, + answer_message_id: i64, + ) -> Result, AppResponseError> { + let url = format!( + "{}/api/chat/{workspace_id}/{chat_id}/message/find_question", + self.base_url + ); + + let resp = self + .http_client_with_auth(Method::GET, &url) + .await? + .query(&[("answer_message_id", answer_message_id)]) + .send() + .await?; + AppResponse::>::from_response(resp) + .await? + .into_data() + } + pub async fn calculate_similarity( &self, params: CalculateSimilarityParams, diff --git a/libs/database/src/chat/chat_ops.rs b/libs/database/src/chat/chat_ops.rs index 7bb625ce..d415fa67 100644 --- a/libs/database/src/chat/chat_ops.rs +++ b/libs/database/src/chat/chat_ops.rs @@ -669,3 +669,40 @@ pub async fn select_chat_message_content<'a, E: Executor<'a, Database = Postgres .await?; Ok((row.content, row.meta_data)) } + +pub async fn select_chat_message_matching_reply_message_id( + txn: &mut Transaction<'_, Postgres>, + chat_id: &str, + reply_message_id: i64, +) -> Result, AppError> { + let chat_id = Uuid::from_str(chat_id)?; + let row = sqlx::query!( + r#" + SELECT message_id, content, created_at, author, meta_data, reply_message_id + FROM af_chat_messages + WHERE chat_id = $1 + AND reply_message_id = $2 + "#, + &chat_id, + reply_message_id + ) + .fetch_one(txn.deref_mut()) + .await?; + + let message = match serde_json::from_value::(row.author) { + Ok(author) => Some(ChatMessage { + author, + message_id: row.message_id, + content: row.content, + created_at: row.created_at, + meta_data: row.meta_data, + reply_message_id: row.reply_message_id, + }), + Err(err) => { + warn!("Failed to deserialize author: {}", err); + None + }, + }; + + Ok(message) +} diff --git a/src/api/chat.rs b/src/api/chat.rs index c44ebba4..67dc1fab 100644 --- a/src/api/chat.rs +++ b/src/api/chat.rs @@ -1,10 +1,11 @@ use crate::biz::chat::ops::{ create_chat, create_chat_message, delete_chat, generate_chat_message_answer, get_chat_messages, - update_chat_message, + get_question_message, update_chat_message, }; use crate::state::AppState; use actix_web::web::{Data, Json}; use actix_web::{web, HttpRequest, HttpResponse, Scope}; +use serde::Deserialize; use crate::api::util::ai_model_from_header; use app_error::AppError; @@ -69,6 +70,10 @@ pub fn chat_scope() -> Scope { web::resource("/{chat_id}/message/answer") .route(web::post().to(save_answer_handler)) ) + .service( + web::resource("/{chat_id}/message/find_question") + .route(web::get().to(get_chat_question_message_handler)) + ) // AI response generation .service( @@ -349,6 +354,17 @@ async fn get_chat_message_handler( Ok(AppResponse::Ok().with_data(messages).into()) } +#[instrument(level = "debug", skip_all, err)] +async fn get_chat_question_message_handler( + path: web::Path<(String, String)>, + query: web::Query, + state: Data, +) -> actix_web::Result>> { + let (_workspace_id, chat_id) = path.into_inner(); + let message = get_question_message(&state.pg_pool, &chat_id, query.0.answer_message_id).await?; + Ok(AppResponse::Ok().with_data(message).into()) +} + #[instrument(level = "debug", skip_all, err)] async fn get_chat_settings_handler( path: web::Path<(String, String)>, @@ -501,3 +517,8 @@ where } } } + +#[derive(Debug, Deserialize)] +struct FindQuestionParams { + answer_message_id: i64, +} diff --git a/src/biz/chat/ops.rs b/src/biz/chat/ops.rs index 65a5c125..ff6e2541 100644 --- a/src/biz/chat/ops.rs +++ b/src/biz/chat/ops.rs @@ -8,7 +8,7 @@ use database::chat; use database::chat::chat_ops::{ delete_answer_message_by_question_message_id, insert_answer_message, insert_answer_message_with_transaction, insert_chat, insert_question_message, - select_chat_messages, + select_chat_message_matching_reply_message_id, select_chat_messages, }; use futures::stream::Stream; use serde_json::json; @@ -232,3 +232,15 @@ pub async fn get_chat_messages( txn.commit().await?; Ok(messages) } + +pub async fn get_question_message( + pg_pool: &PgPool, + chat_id: &str, + answer_message_id: i64, +) -> Result, AppError> { + let mut txn = pg_pool.begin().await?; + let message = + select_chat_message_matching_reply_message_id(&mut txn, chat_id, answer_message_id).await?; + txn.commit().await?; + Ok(message) +} diff --git a/tests/ai_test/chat_test.rs b/tests/ai_test/chat_test.rs index 46da750d..df21dc33 100644 --- a/tests/ai_test/chat_test.rs +++ b/tests/ai_test/chat_test.rs @@ -6,8 +6,8 @@ use client_api_test::{ai_test_enabled, TestClient}; use futures_util::StreamExt; use serde_json::json; use shared_entity::dto::chat_dto::{ - ChatMessageMetadata, ChatRAGData, CreateChatMessageParams, CreateChatParams, MessageCursor, - UpdateChatParams, + ChatMessageMetadata, ChatRAGData, CreateAnswerMessageParams, CreateChatMessageParams, + CreateChatParams, MessageCursor, UpdateChatParams, }; #[tokio::test] @@ -344,6 +344,10 @@ async fn create_chat_context_test() { // #[tokio::test] // async fn update_chat_message_test() { +// if !ai_test_enabled() { +// return; +// } + // let test_client = TestClient::new_user_without_ws_conn().await; // let workspace_id = test_client.workspace_id().await; // let chat_id = uuid::Uuid::new_v4().to_string(); @@ -352,13 +356,13 @@ async fn create_chat_context_test() { // name: "my second chat".to_string(), // rag_ids: vec![], // }; -// + // test_client // .api_client // .create_chat(&workspace_id, params) // .await // .unwrap(); -// + // let params = CreateChatMessageParams::new_user("where is singapore?"); // let stream = test_client // .api_client @@ -367,7 +371,7 @@ async fn create_chat_context_test() { // .unwrap(); // let messages: Vec = stream.map(|message| message.unwrap()).collect().await; // assert_eq!(messages.len(), 2); -// + // let params = UpdateChatMessageContentParams { // chat_id: chat_id.clone(), // message_id: messages[0].message_id, @@ -378,7 +382,7 @@ async fn create_chat_context_test() { // .update_chat_message(&workspace_id, &chat_id, params) // .await // .unwrap(); -// + // let remote_messages = test_client // .api_client // .get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 2) @@ -387,11 +391,69 @@ async fn create_chat_context_test() { // .messages; // assert_eq!(remote_messages[0].content, "where is China?"); // assert_eq!(remote_messages.len(), 2); -// + // // when the question was updated, the answer should be different // assert_ne!(remote_messages[1].content, messages[1].content); // } +#[tokio::test] +async fn get_question_message_test() { + if !ai_test_enabled() { + return; + } + + let test_client = TestClient::new_user_without_ws_conn().await; + let workspace_id = test_client.workspace_id().await; + let chat_id = uuid::Uuid::new_v4().to_string(); + let params = CreateChatParams { + chat_id: chat_id.clone(), + name: "my ai chat".to_string(), + rag_ids: vec![], + }; + + test_client + .api_client + .create_chat(&workspace_id, params) + .await + .unwrap(); + + let params = CreateChatMessageParams::new_user("where is singapore?"); + let question = test_client + .api_client + .create_question(&workspace_id, &chat_id, params) + .await + .unwrap(); + + let answer = test_client + .api_client + .get_answer(&workspace_id, &chat_id, question.message_id) + .await + .unwrap(); + + test_client + .api_client + .save_answer( + &workspace_id, + &chat_id, + CreateAnswerMessageParams { + content: answer.content, + metadata: None, + question_message_id: question.message_id, + }, + ) + .await + .unwrap(); + + let find_question = test_client + .api_client + .get_question_message_from_answer_id(&workspace_id, &chat_id, answer.message_id) + .await + .unwrap() + .unwrap(); + + assert_eq!(find_question.reply_message_id.unwrap(), answer.message_id); +} + async fn collect_answer(mut stream: QuestionStream) -> String { let mut answer = String::new(); while let Some(value) = stream.next().await {