624 lines
16 KiB
Rust
624 lines
16 KiB
Rust
use crate::pg_row::AFChatRow;
|
|
use crate::workspace::is_workspace_exist;
|
|
use anyhow::anyhow;
|
|
use app_error::AppError;
|
|
use chrono::{DateTime, Utc};
|
|
use database_entity::dto::{
|
|
ChatAuthor, ChatMessage, CreateChatParams, GetChatMessageParams, MessageCursor,
|
|
RepeatedChatMessage, UpdateChatMessageContentParams, UpdateChatMessageMetaParams,
|
|
UpdateChatParams,
|
|
};
|
|
|
|
use serde_json::json;
|
|
use sqlx::postgres::PgArguments;
|
|
use sqlx::types::JsonValue;
|
|
use sqlx::{Arguments, Executor, PgPool, Postgres, Transaction};
|
|
use std::ops::DerefMut;
|
|
use std::str::FromStr;
|
|
use tracing::warn;
|
|
|
|
use uuid::Uuid;
|
|
|
|
pub async fn insert_chat(
|
|
txn: &mut Transaction<'_, Postgres>,
|
|
workspace_id: &str,
|
|
params: CreateChatParams,
|
|
) -> Result<(), AppError> {
|
|
let chat_id = Uuid::from_str(¶ms.chat_id)?;
|
|
let workspace_id = Uuid::from_str(workspace_id)?;
|
|
if !is_workspace_exist(txn.deref_mut(), &workspace_id).await? {
|
|
return Err(AppError::RecordNotFound(format!(
|
|
"workspace with given id:{} is not found",
|
|
workspace_id
|
|
)));
|
|
}
|
|
let rag_ids = json!(params.rag_ids);
|
|
|
|
sqlx::query!(
|
|
r#"
|
|
INSERT INTO af_chat (chat_id, name, workspace_id, rag_ids)
|
|
VALUES ($1, $2, $3, $4)
|
|
"#,
|
|
chat_id,
|
|
params.name,
|
|
workspace_id,
|
|
rag_ids,
|
|
)
|
|
.execute(txn.deref_mut())
|
|
.await
|
|
.map_err(|err| AppError::Internal(anyhow!("Failed to insert chat: {}", err)))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Updates specific fields of a chat record in the database using transactional queries.
|
|
///
|
|
/// This function dynamically builds an SQL `UPDATE` query based on the provided parameters to
|
|
/// update fields of a specific chat record identified by `chat_id`. It uses a transaction to ensure
|
|
/// that the update operation is atomic.
|
|
///
|
|
pub async fn update_chat(
|
|
txn: &mut Transaction<'_, Postgres>,
|
|
chat_id: &Uuid,
|
|
params: UpdateChatParams,
|
|
) -> Result<(), AppError> {
|
|
let mut query_parts = vec!["UPDATE af_chat SET".to_string()];
|
|
let mut args = PgArguments::default();
|
|
let mut current_param_pos = 1; // Start counting SQL parameters from 1
|
|
|
|
if let Some(ref name) = params.name {
|
|
query_parts.push(format!("name = ${}", current_param_pos));
|
|
args
|
|
.add(name)
|
|
.map_err(|err| AppError::SqlxArgEncodingError {
|
|
desc: format!("unable to encode chat name for chat id {}", chat_id),
|
|
err,
|
|
})?;
|
|
current_param_pos += 1;
|
|
}
|
|
|
|
if let Some(ref metadata) = params.metadata {
|
|
query_parts.push(format!("metadata = metadata || ${}", current_param_pos));
|
|
args
|
|
.add(json!(metadata))
|
|
.map_err(|err| AppError::SqlxArgEncodingError {
|
|
desc: format!("unable to encode metadata json for chat id {}", chat_id),
|
|
err,
|
|
})?;
|
|
current_param_pos += 1;
|
|
}
|
|
|
|
query_parts.push(format!("WHERE chat_id = ${}", current_param_pos));
|
|
args
|
|
.add(chat_id)
|
|
.map_err(|err| AppError::SqlxArgEncodingError {
|
|
desc: format!("unable to encode chat id {}", chat_id),
|
|
err,
|
|
})?;
|
|
|
|
let query = query_parts.join(", ") + ";";
|
|
let query = sqlx::query_with(&query, args);
|
|
query.execute(txn.deref_mut()).await?;
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn delete_chat(
|
|
txn: &mut Transaction<'_, Postgres>,
|
|
chat_id: &str,
|
|
) -> Result<(), AppError> {
|
|
let chat_id = Uuid::from_str(chat_id)?;
|
|
sqlx::query!(
|
|
r#"
|
|
UPDATE af_chat
|
|
SET deleted_at = now()
|
|
WHERE chat_id = $1
|
|
"#,
|
|
chat_id,
|
|
)
|
|
.execute(txn.deref_mut())
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn select_chat<'a, E: Executor<'a, Database = Postgres>>(
|
|
executor: E,
|
|
chat_id: &str,
|
|
) -> Result<AFChatRow, AppError> {
|
|
let chat_id = Uuid::from_str(chat_id)?;
|
|
let row = sqlx::query_as!(
|
|
AFChatRow,
|
|
r#"
|
|
SELECT *
|
|
FROM af_chat
|
|
WHERE chat_id = $1 AND deleted_at IS NULL
|
|
"#,
|
|
&chat_id,
|
|
)
|
|
.fetch_optional(executor)
|
|
.await?;
|
|
match row {
|
|
Some(row) => Ok(row),
|
|
None => Err(AppError::RecordNotFound(format!(
|
|
"chat with given id:{} is not found",
|
|
chat_id
|
|
))),
|
|
}
|
|
}
|
|
|
|
pub async fn insert_answer_message_with_transaction(
|
|
transaction: &mut Transaction<'_, Postgres>,
|
|
author: ChatAuthor,
|
|
chat_id: &str,
|
|
content: String,
|
|
metadata: serde_json::Value,
|
|
question_message_id: i64,
|
|
) -> Result<ChatMessage, AppError> {
|
|
let chat_id = Uuid::from_str(chat_id)?;
|
|
let existing_reply_message_id: Option<i64> = sqlx::query_scalar!(
|
|
r#"
|
|
SELECT reply_message_id
|
|
FROM af_chat_messages
|
|
WHERE message_id = $1
|
|
"#,
|
|
question_message_id
|
|
)
|
|
.fetch_one(transaction.deref_mut())
|
|
.await?;
|
|
|
|
if let Some(reply_id) = existing_reply_message_id {
|
|
// If there is an existing reply_message_id, update the existing message
|
|
sqlx::query!(
|
|
r#"
|
|
UPDATE af_chat_messages
|
|
SET content = $2,
|
|
author = $3,
|
|
created_at = CURRENT_TIMESTAMP,
|
|
meta_data = $4
|
|
WHERE message_id = $1
|
|
"#,
|
|
reply_id,
|
|
&content,
|
|
json!(author),
|
|
metadata,
|
|
)
|
|
.execute(transaction.deref_mut())
|
|
.await
|
|
.map_err(|err| AppError::Internal(anyhow!("Failed to update chat message: {}", err)))?;
|
|
|
|
let row = sqlx::query!(
|
|
r#"
|
|
SELECT message_id, content, created_at, author, meta_data, reply_message_id
|
|
FROM af_chat_messages
|
|
WHERE message_id = $1
|
|
"#,
|
|
reply_id
|
|
)
|
|
.fetch_one(transaction.deref_mut())
|
|
.await
|
|
.map_err(|err| AppError::Internal(anyhow!("Failed to fetch updated message: {}", err)))?;
|
|
|
|
let chat_message = ChatMessage {
|
|
author,
|
|
message_id: row.message_id,
|
|
content: row.content,
|
|
created_at: row.created_at,
|
|
meta_data: row.meta_data,
|
|
reply_message_id: Some(question_message_id),
|
|
};
|
|
|
|
Ok(chat_message)
|
|
} else {
|
|
// Insert a new chat message
|
|
let row = sqlx::query!(
|
|
r#"
|
|
INSERT INTO af_chat_messages (chat_id, author, content, meta_data)
|
|
VALUES ($1, $2, $3, $4)
|
|
RETURNING message_id, created_at
|
|
"#,
|
|
chat_id,
|
|
json!(author),
|
|
&content,
|
|
&metadata,
|
|
)
|
|
.fetch_one(transaction.deref_mut())
|
|
.await
|
|
.map_err(|err| AppError::Internal(anyhow!("Failed to insert chat message: {}", err)))?;
|
|
|
|
// Update the question message with the new reply_message_id
|
|
sqlx::query!(
|
|
r#"
|
|
UPDATE af_chat_messages
|
|
SET reply_message_id = $2
|
|
WHERE message_id = $1
|
|
"#,
|
|
question_message_id,
|
|
row.message_id,
|
|
)
|
|
.execute(transaction.deref_mut())
|
|
.await
|
|
.map_err(|err| AppError::Internal(anyhow!("Failed to update reply_message_id: {}", err)))?;
|
|
|
|
let chat_message = ChatMessage {
|
|
author,
|
|
message_id: row.message_id,
|
|
content,
|
|
created_at: row.created_at,
|
|
meta_data: metadata,
|
|
reply_message_id: None,
|
|
};
|
|
|
|
Ok(chat_message)
|
|
}
|
|
}
|
|
|
|
pub async fn insert_answer_message(
|
|
pg_pool: &PgPool,
|
|
author: ChatAuthor,
|
|
chat_id: &str,
|
|
content: String,
|
|
metadata: Option<serde_json::Value>,
|
|
question_message_id: i64,
|
|
) -> Result<ChatMessage, AppError> {
|
|
let mut txn = pg_pool.begin().await?;
|
|
let chat_message = insert_answer_message_with_transaction(
|
|
&mut txn,
|
|
author,
|
|
chat_id,
|
|
content,
|
|
metadata.unwrap_or_default(),
|
|
question_message_id,
|
|
)
|
|
.await?;
|
|
txn.commit().await.map_err(|err| {
|
|
AppError::Internal(anyhow!(
|
|
"Failed to commit transaction to insert answer message: {}",
|
|
err
|
|
))
|
|
})?;
|
|
Ok(chat_message)
|
|
}
|
|
|
|
pub async fn insert_question_message<'a, E: Executor<'a, Database = Postgres>>(
|
|
executor: E,
|
|
author: ChatAuthor,
|
|
chat_id: &str,
|
|
content: String,
|
|
metadata: Option<JsonValue>,
|
|
) -> Result<ChatMessage, AppError> {
|
|
let metadata = metadata.unwrap_or_else(|| json!({}));
|
|
let chat_id = Uuid::from_str(chat_id)?;
|
|
let row = sqlx::query!(
|
|
r#"
|
|
INSERT INTO af_chat_messages (chat_id, author, content, meta_data)
|
|
VALUES ($1, $2, $3, $4)
|
|
RETURNING message_id, created_at
|
|
"#,
|
|
chat_id,
|
|
json!(author),
|
|
&content,
|
|
&metadata,
|
|
)
|
|
.fetch_one(executor)
|
|
.await
|
|
.map_err(|err| AppError::Internal(anyhow!("Failed to insert chat message: {}", err)))?;
|
|
|
|
let chat_message = ChatMessage {
|
|
author,
|
|
message_id: row.message_id,
|
|
content,
|
|
created_at: row.created_at,
|
|
meta_data: metadata,
|
|
reply_message_id: None,
|
|
};
|
|
Ok(chat_message)
|
|
}
|
|
|
|
pub async fn select_chat_messages(
|
|
txn: &mut Transaction<'_, Postgres>,
|
|
chat_id: &str,
|
|
params: GetChatMessageParams,
|
|
) -> Result<RepeatedChatMessage, AppError> {
|
|
let chat_id = Uuid::from_str(chat_id)?;
|
|
let mut query = r#"
|
|
SELECT message_id, content, created_at, author, meta_data, reply_message_id
|
|
FROM af_chat_messages
|
|
WHERE chat_id = $1
|
|
"#
|
|
.to_string();
|
|
|
|
let mut args = PgArguments::default();
|
|
args
|
|
.add(&chat_id)
|
|
.map_err(|err| AppError::SqlxArgEncodingError {
|
|
desc: format!("unable to encode chat id {}", chat_id),
|
|
err,
|
|
})?;
|
|
|
|
// Message IDs: 1 2 3 4 5
|
|
// AfterMessageId(3, 5): [4] [5] has_more = false
|
|
// BeforeMessageId(3, 5): [1] [2] has_more = false
|
|
// Offset(3, 5): [4] [5] has_more = true
|
|
match params.cursor {
|
|
MessageCursor::AfterMessageId(after_message_id) => {
|
|
query += " AND message_id > $2";
|
|
args
|
|
.add(after_message_id)
|
|
.map_err(|err| AppError::SqlxArgEncodingError {
|
|
desc: format!("unable to encode message id {}", after_message_id),
|
|
err,
|
|
})?;
|
|
query += " ORDER BY message_id DESC LIMIT $3";
|
|
args
|
|
.add(params.limit as i64)
|
|
.map_err(|err| AppError::SqlxArgEncodingError {
|
|
desc: format!("unable to encode row limit {}", params.limit as i64),
|
|
err,
|
|
})?;
|
|
},
|
|
MessageCursor::Offset(offset) => {
|
|
query += " ORDER BY message_id ASC LIMIT $2 OFFSET $3";
|
|
args
|
|
.add(params.limit as i64)
|
|
.map_err(|err| AppError::SqlxArgEncodingError {
|
|
desc: format!("unable to encode row limit {}", params.limit as i64),
|
|
err,
|
|
})?;
|
|
args
|
|
.add(offset as i64)
|
|
.map_err(|err| AppError::SqlxArgEncodingError {
|
|
desc: format!("unable to encode offset {}", offset as i64),
|
|
err,
|
|
})?;
|
|
},
|
|
MessageCursor::BeforeMessageId(before_message_id) => {
|
|
query += " AND message_id < $2";
|
|
args
|
|
.add(before_message_id)
|
|
.map_err(|err| AppError::SqlxArgEncodingError {
|
|
desc: format!("unable to encode message id {}", before_message_id),
|
|
err,
|
|
})?;
|
|
query += " ORDER BY message_id DESC LIMIT $3";
|
|
args
|
|
.add(params.limit as i64)
|
|
.map_err(|err| AppError::SqlxArgEncodingError {
|
|
desc: format!("unable to encode row limit {}", params.limit as i64),
|
|
err,
|
|
})?;
|
|
},
|
|
MessageCursor::NextBack => {
|
|
query += " ORDER BY message_id DESC LIMIT $2";
|
|
args
|
|
.add(params.limit as i64)
|
|
.map_err(|err| AppError::SqlxArgEncodingError {
|
|
desc: format!("unable to encode row limit {}", params.limit as i64),
|
|
err,
|
|
})?;
|
|
},
|
|
}
|
|
|
|
#[allow(clippy::type_complexity)]
|
|
let rows: Vec<(
|
|
i64,
|
|
String,
|
|
DateTime<Utc>,
|
|
serde_json::Value,
|
|
serde_json::Value,
|
|
Option<i64>,
|
|
)> = sqlx::query_as_with(&query, args)
|
|
.fetch_all(txn.deref_mut())
|
|
.await?;
|
|
|
|
let messages = rows
|
|
.into_iter()
|
|
.flat_map(
|
|
|(message_id, content, created_at, author, meta_data, reply_message_id)| {
|
|
match serde_json::from_value::<ChatAuthor>(author) {
|
|
Ok(author) => Some(ChatMessage {
|
|
author,
|
|
message_id,
|
|
content,
|
|
created_at,
|
|
meta_data,
|
|
reply_message_id,
|
|
}),
|
|
Err(err) => {
|
|
warn!("Failed to deserialize author: {}", err);
|
|
None
|
|
},
|
|
}
|
|
},
|
|
)
|
|
.collect::<Vec<ChatMessage>>();
|
|
|
|
let total = sqlx::query_scalar!(
|
|
r#"
|
|
SELECT COUNT(*)
|
|
FROM public.af_chat_messages
|
|
WHERE chat_id = $1
|
|
"#,
|
|
&chat_id
|
|
)
|
|
.fetch_one(txn.deref_mut())
|
|
.await?
|
|
.unwrap_or(0);
|
|
|
|
let has_more = match params.cursor {
|
|
MessageCursor::AfterMessageId(_) => {
|
|
if messages.is_empty() {
|
|
false
|
|
} else {
|
|
sqlx::query!(
|
|
"SELECT EXISTS(SELECT 1 FROM af_chat_messages WHERE chat_id = $1 AND message_id > $2)",
|
|
&chat_id,
|
|
messages[0].message_id
|
|
)
|
|
.fetch_one(txn.deref_mut())
|
|
.await?
|
|
.exists
|
|
.unwrap_or(false)
|
|
}
|
|
},
|
|
MessageCursor::Offset(offset) => (offset + params.limit) < total as u64,
|
|
MessageCursor::BeforeMessageId(_) => {
|
|
if messages.is_empty() {
|
|
false
|
|
} else {
|
|
sqlx::query!(
|
|
"SELECT EXISTS(SELECT 1 FROM af_chat_messages WHERE chat_id = $1 AND message_id < $2)",
|
|
&chat_id,
|
|
messages.last().as_ref().unwrap().message_id
|
|
)
|
|
.fetch_one(txn.deref_mut())
|
|
.await?
|
|
.exists
|
|
.unwrap_or(false)
|
|
}
|
|
},
|
|
MessageCursor::NextBack => params.limit < total as u64,
|
|
};
|
|
|
|
Ok(RepeatedChatMessage {
|
|
messages,
|
|
total,
|
|
has_more,
|
|
})
|
|
}
|
|
|
|
pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>(
|
|
executor: E,
|
|
chat_id: &str,
|
|
) -> Result<Vec<ChatMessage>, AppError> {
|
|
let chat_id = Uuid::from_str(chat_id)?;
|
|
let rows = sqlx::query!(
|
|
// ChatMessage,
|
|
r#"
|
|
SELECT message_id, content, created_at, author, meta_data, reply_message_id
|
|
FROM af_chat_messages
|
|
WHERE chat_id = $1
|
|
ORDER BY created_at ASC
|
|
"#,
|
|
chat_id,
|
|
)
|
|
.fetch_all(executor)
|
|
.await?;
|
|
|
|
let messages = rows
|
|
.into_iter()
|
|
.flat_map(
|
|
|row| match serde_json::from_value::<ChatAuthor>(row.author) {
|
|
Ok(author) => Some(ChatMessage {
|
|
author,
|
|
message_id: row.message_id,
|
|
content: row.content,
|
|
created_at: row.created_at,
|
|
meta_data: row.meta_data,
|
|
reply_message_id: row.reply_message_id,
|
|
}),
|
|
Err(err) => {
|
|
warn!("Failed to deserialize author: {}", err);
|
|
None
|
|
},
|
|
},
|
|
)
|
|
.collect::<Vec<ChatMessage>>();
|
|
|
|
Ok(messages)
|
|
}
|
|
|
|
pub async fn delete_answer_message_by_question_message_id(
|
|
transaction: &mut Transaction<'_, Postgres>,
|
|
message_id: i64,
|
|
) -> Result<(), AppError> {
|
|
// Step 1: Get the reply_message_id of the chat message with the given message_id
|
|
let reply_message_id: Option<i64> = sqlx::query_scalar!(
|
|
r#"
|
|
SELECT reply_message_id
|
|
FROM af_chat_messages
|
|
WHERE message_id = $1
|
|
"#,
|
|
message_id
|
|
)
|
|
.fetch_one(transaction.deref_mut())
|
|
.await?;
|
|
|
|
if let Some(reply_id) = reply_message_id {
|
|
// Step 2: Delete the chat message with the reply_message_id
|
|
sqlx::query!(
|
|
r#"
|
|
DELETE FROM af_chat_messages
|
|
WHERE message_id = $1
|
|
"#,
|
|
reply_id
|
|
)
|
|
.execute(transaction.deref_mut())
|
|
.await?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn update_chat_message_content(
|
|
transaction: &mut Transaction<'_, Postgres>,
|
|
params: &UpdateChatMessageContentParams,
|
|
) -> Result<(), AppError> {
|
|
sqlx::query(
|
|
r#"
|
|
UPDATE af_chat_messages
|
|
SET content = $2,
|
|
edited_at = CURRENT_TIMESTAMP
|
|
WHERE message_id = $1
|
|
"#,
|
|
)
|
|
.bind(params.message_id)
|
|
.bind(¶ms.content)
|
|
.execute(transaction.deref_mut())
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn update_chat_message_meta(
|
|
transaction: &mut Transaction<'_, Postgres>,
|
|
params: &UpdateChatMessageMetaParams,
|
|
) -> Result<(), AppError> {
|
|
for (key, value) in params.meta_data.iter() {
|
|
sqlx::query(
|
|
r#"
|
|
UPDATE af_chat_messages
|
|
SET meta_data = jsonb_set(
|
|
COALESCE(meta_data, '{}'),
|
|
$2,
|
|
$3::jsonb,
|
|
true
|
|
)
|
|
WHERE message_id = $1
|
|
"#,
|
|
)
|
|
.bind(params.message_id)
|
|
.bind(format!("{{{}}}", key))
|
|
.bind(value)
|
|
.execute(transaction.deref_mut())
|
|
.await?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn select_chat_message_content<'a, E: Executor<'a, Database = Postgres>>(
|
|
executor: E,
|
|
message_id: i64,
|
|
) -> Result<(String, serde_json::Value), AppError> {
|
|
let row = sqlx::query!(
|
|
r#"
|
|
SELECT content,meta_data
|
|
FROM af_chat_messages
|
|
WHERE message_id = $1
|
|
"#,
|
|
message_id,
|
|
)
|
|
.fetch_one(executor)
|
|
.await?;
|
|
Ok((row.content, row.meta_data))
|
|
}
|