chore: update chat test (#590)

This commit is contained in:
Nathan.fooo 2024-05-30 17:15:13 +08:00 committed by GitHub
parent 06272364b7
commit 1537c4d1f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 29 additions and 39 deletions

View File

@ -183,7 +183,7 @@ pub async fn select_chat_messages(
MessageCursor::AfterMessageId(after_message_id) => {
query += " AND message_id > $2";
args.add(after_message_id);
query += " ORDER BY message_id ASC LIMIT $3";
query += " ORDER BY message_id DESC LIMIT $3";
args.add(params.limit as i64);
},
MessageCursor::Offset(offset) => {
@ -194,7 +194,7 @@ pub async fn select_chat_messages(
MessageCursor::BeforeMessageId(before_message_id) => {
query += " AND message_id < $2";
args.add(before_message_id);
query += " ORDER BY message_id ASC LIMIT $3";
query += " ORDER BY message_id DESC LIMIT $3";
args.add(params.limit as i64);
},
MessageCursor::NextBack => {
@ -208,7 +208,7 @@ pub async fn select_chat_messages(
.fetch_all(txn.deref_mut())
.await?;
let mut messages = rows
let messages = rows
.into_iter()
.flat_map(|(message_id, content, created_at, author)| {
match serde_json::from_value::<ChatAuthor>(author) {
@ -226,10 +226,6 @@ pub async fn select_chat_messages(
})
.collect::<Vec<ChatMessage>>();
if matches!(params.cursor, MessageCursor::NextBack) {
messages.reverse();
}
let total = sqlx::query_scalar!(
r#"
SELECT COUNT(*)
@ -250,7 +246,7 @@ pub async fn select_chat_messages(
sqlx::query!(
"SELECT EXISTS(SELECT 1 FROM af_chat_messages WHERE chat_id = $1 AND message_id > $2)",
&chat_id,
messages.last().as_ref().unwrap().message_id
messages[0].message_id
)
.fetch_one(txn.deref_mut())
.await?
@ -266,7 +262,7 @@ pub async fn select_chat_messages(
sqlx::query!(
"SELECT EXISTS(SELECT 1 FROM af_chat_messages WHERE chat_id = $1 AND message_id < $2)",
&chat_id,
messages[0].message_id
messages.last().as_ref().unwrap().message_id
)
.fetch_one(txn.deref_mut())
.await?

View File

@ -108,5 +108,6 @@ async fn get_chat_message_handler(
trace!("get chat messages: {:?}", params);
let (_workspace_id, chat_id) = path.into_inner();
let messages = get_chat_messages(&state.pg_pool, params, &chat_id).await?;
trace!("get chat messages: {:?}", messages.messages);
Ok(AppResponse::Ok().with_data(messages).into())
}

View File

@ -34,32 +34,29 @@ async fn create_chat_and_create_messages_test() {
.unwrap();
messages.push(message);
}
// DESC is the default order
messages.reverse();
// get messages before third message. it should return first two messages even though we asked
// for 10 messages
assert_eq!(messages[7].content, "hello world 2");
let message_before_third = test_client
.api_client
.get_chat_messages(
&workspace_id,
&chat_id,
MessageCursor::BeforeMessageId(messages[2].message_id),
MessageCursor::BeforeMessageId(messages[7].message_id),
10,
)
.await
.unwrap();
assert!(!message_before_third.has_more);
assert_eq!(message_before_third.messages.len(), 2);
assert_eq!(
message_before_third.messages[0].message_id,
messages[0].message_id
);
assert_eq!(
message_before_third.messages[1].message_id,
messages[1].message_id
);
assert_eq!(message_before_third.messages[0].content, "hello world 1");
assert_eq!(message_before_third.messages[1].content, "hello world 0");
// get message after third message
assert_eq!(messages[2].content, "hello world 7");
let message_after_third = test_client
.api_client
.get_chat_messages(
@ -70,32 +67,28 @@ async fn create_chat_and_create_messages_test() {
)
.await
.unwrap();
assert!(message_after_third.has_more);
assert!(!message_after_third.has_more);
assert_eq!(message_after_third.messages.len(), 2);
assert_eq!(
message_after_third.messages[0].message_id,
messages[3].message_id
);
assert_eq!(
message_after_third.messages[1].message_id,
messages[4].message_id
);
assert_eq!(message_after_third.messages[0].content, "hello world 9");
assert_eq!(message_after_third.messages[1].content, "hello world 8");
// get all messages after 8th message
let remaining_messages = test_client
let next_back = test_client
.api_client
.get_chat_messages(
&workspace_id,
&chat_id,
MessageCursor::AfterMessageId(messages[7].message_id),
100,
)
.get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 3)
.await
.unwrap();
assert!(next_back.has_more);
assert_eq!(next_back.messages.len(), 3);
assert_eq!(next_back.messages[0].content, "hello world 9");
assert_eq!(next_back.messages[1].content, "hello world 8");
// has_more should be false because we only have 10 messages
assert!(!remaining_messages.has_more);
assert_eq!(remaining_messages.messages.len(), 2);
let next_back = test_client
.api_client
.get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 100)
.await
.unwrap();
assert!(!next_back.has_more);
assert_eq!(next_back.messages.len(), 10);
}
#[tokio::test]