From 1537c4d1f633ac242f8d98ecbda7ce43efd50427 Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Thu, 30 May 2024 17:15:13 +0800 Subject: [PATCH] chore: update chat test (#590) --- libs/database/src/chat/chat_ops.rs | 14 +++----- src/api/chat.rs | 1 + tests/ai_test/chat_test.rs | 53 +++++++++++++----------------- 3 files changed, 29 insertions(+), 39 deletions(-) diff --git a/libs/database/src/chat/chat_ops.rs b/libs/database/src/chat/chat_ops.rs index f8ed14ab..5f96e66a 100644 --- a/libs/database/src/chat/chat_ops.rs +++ b/libs/database/src/chat/chat_ops.rs @@ -183,7 +183,7 @@ pub async fn select_chat_messages( MessageCursor::AfterMessageId(after_message_id) => { query += " AND message_id > $2"; args.add(after_message_id); - query += " ORDER BY message_id ASC LIMIT $3"; + query += " ORDER BY message_id DESC LIMIT $3"; args.add(params.limit as i64); }, MessageCursor::Offset(offset) => { @@ -194,7 +194,7 @@ pub async fn select_chat_messages( MessageCursor::BeforeMessageId(before_message_id) => { query += " AND message_id < $2"; args.add(before_message_id); - query += " ORDER BY message_id ASC LIMIT $3"; + query += " ORDER BY message_id DESC LIMIT $3"; args.add(params.limit as i64); }, MessageCursor::NextBack => { @@ -208,7 +208,7 @@ pub async fn select_chat_messages( .fetch_all(txn.deref_mut()) .await?; - let mut messages = rows + let messages = rows .into_iter() .flat_map(|(message_id, content, created_at, author)| { match serde_json::from_value::(author) { @@ -226,10 +226,6 @@ pub async fn select_chat_messages( }) .collect::>(); - if matches!(params.cursor, MessageCursor::NextBack) { - messages.reverse(); - } - let total = sqlx::query_scalar!( r#" SELECT COUNT(*) @@ -250,7 +246,7 @@ pub async fn select_chat_messages( sqlx::query!( "SELECT EXISTS(SELECT 1 FROM af_chat_messages WHERE chat_id = $1 AND message_id > $2)", &chat_id, - messages.last().as_ref().unwrap().message_id + messages[0].message_id ) .fetch_one(txn.deref_mut()) .await? @@ -266,7 +262,7 @@ pub async fn select_chat_messages( sqlx::query!( "SELECT EXISTS(SELECT 1 FROM af_chat_messages WHERE chat_id = $1 AND message_id < $2)", &chat_id, - messages[0].message_id + messages.last().as_ref().unwrap().message_id ) .fetch_one(txn.deref_mut()) .await? diff --git a/src/api/chat.rs b/src/api/chat.rs index e9ae8700..93539201 100644 --- a/src/api/chat.rs +++ b/src/api/chat.rs @@ -108,5 +108,6 @@ async fn get_chat_message_handler( trace!("get chat messages: {:?}", params); let (_workspace_id, chat_id) = path.into_inner(); let messages = get_chat_messages(&state.pg_pool, params, &chat_id).await?; + trace!("get chat messages: {:?}", messages.messages); Ok(AppResponse::Ok().with_data(messages).into()) } diff --git a/tests/ai_test/chat_test.rs b/tests/ai_test/chat_test.rs index 23b67ed1..a6f78ae9 100644 --- a/tests/ai_test/chat_test.rs +++ b/tests/ai_test/chat_test.rs @@ -34,32 +34,29 @@ async fn create_chat_and_create_messages_test() { .unwrap(); messages.push(message); } + // DESC is the default order + messages.reverse(); // get messages before third message. it should return first two messages even though we asked // for 10 messages + assert_eq!(messages[7].content, "hello world 2"); let message_before_third = test_client .api_client .get_chat_messages( &workspace_id, &chat_id, - MessageCursor::BeforeMessageId(messages[2].message_id), + MessageCursor::BeforeMessageId(messages[7].message_id), 10, ) .await .unwrap(); - assert!(!message_before_third.has_more); assert_eq!(message_before_third.messages.len(), 2); - assert_eq!( - message_before_third.messages[0].message_id, - messages[0].message_id - ); - assert_eq!( - message_before_third.messages[1].message_id, - messages[1].message_id - ); + assert_eq!(message_before_third.messages[0].content, "hello world 1"); + assert_eq!(message_before_third.messages[1].content, "hello world 0"); // get message after third message + assert_eq!(messages[2].content, "hello world 7"); let message_after_third = test_client .api_client .get_chat_messages( @@ -70,32 +67,28 @@ async fn create_chat_and_create_messages_test() { ) .await .unwrap(); - assert!(message_after_third.has_more); + assert!(!message_after_third.has_more); assert_eq!(message_after_third.messages.len(), 2); - assert_eq!( - message_after_third.messages[0].message_id, - messages[3].message_id - ); - assert_eq!( - message_after_third.messages[1].message_id, - messages[4].message_id - ); + assert_eq!(message_after_third.messages[0].content, "hello world 9"); + assert_eq!(message_after_third.messages[1].content, "hello world 8"); - // get all messages after 8th message - let remaining_messages = test_client + let next_back = test_client .api_client - .get_chat_messages( - &workspace_id, - &chat_id, - MessageCursor::AfterMessageId(messages[7].message_id), - 100, - ) + .get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 3) .await .unwrap(); + assert!(next_back.has_more); + assert_eq!(next_back.messages.len(), 3); + assert_eq!(next_back.messages[0].content, "hello world 9"); + assert_eq!(next_back.messages[1].content, "hello world 8"); - // has_more should be false because we only have 10 messages - assert!(!remaining_messages.has_more); - assert_eq!(remaining_messages.messages.len(), 2); + let next_back = test_client + .api_client + .get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 100) + .await + .unwrap(); + assert!(!next_back.has_more); + assert_eq!(next_back.messages.len(), 10); } #[tokio::test]