feat: support chat with markdown (#718)

* chore: store metatdata

* chore: support markdown

* chore: update test
This commit is contained in:
Nathan.fooo 2024-08-08 13:19:19 +08:00 committed by GitHub
parent bb1c93b98a
commit 0b3949152b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 122 additions and 101 deletions

View File

@ -274,21 +274,39 @@ pub struct LocalAIConfig {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CreateTextChatContext {
pub chat_id: String,
/// Only support "txt" and "md" for now
/// Only support "txt" and "markdown" for now
pub content_type: String,
pub text: String,
pub content: String,
pub chunk_size: i32,
pub chunk_overlap: i32,
pub metadata: HashMap<String, serde_json::Value>,
}
impl CreateTextChatContext {
pub fn new(chat_id: String, content_type: String, text: String) -> Self {
CreateTextChatContext {
chat_id,
content_type,
content: text,
chunk_size: 2000,
chunk_overlap: 20,
metadata: HashMap::new(),
}
}
pub fn with_metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
self.metadata = metadata;
self
}
}
impl Display for CreateTextChatContext {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!(
"Create Chat context: {{ chat_id: {}, content_type: {}, content size: {}, metadata: {:?} }}",
self.chat_id,
self.content_type,
self.text.len(),
self.content.len(),
self.metadata
))
}

View File

@ -7,7 +7,7 @@ async fn create_chat_context_test() {
let context = CreateTextChatContext {
chat_id: chat_id.clone(),
content_type: "txt".to_string(),
text: "I have lived in the US for five years".to_string(),
content: "I have lived in the US for five years".to_string(),
chunk_size: 1000,
chunk_overlap: 20,
metadata: Default::default(),

View File

@ -640,7 +640,7 @@ pub struct UpdateChatParams {
#[validate(custom = "validate_not_empty_str")]
pub name: Option<String>,
pub rag_ids: Option<Vec<String>>,
pub metadata: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
@ -648,24 +648,64 @@ pub struct CreateChatMessageParams {
#[validate(custom = "validate_not_empty_str")]
pub content: String,
pub message_type: ChatMessageType,
/// metadata is json array object
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
}
/// [ChatMessageMetadata] is used when creating a new question message.
/// All the properties of [ChatMessageMetadata] except [ChatMetadataData] will be stored as a
/// metadata for specific [ChatMessage]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessageMetadata {
pub data: ChatMetadataData,
/// The id for the metadata. It can be a file_id, view_id
pub id: String,
/// The name for the metadata. For example, @xxx, @xx.txt
pub name: String,
pub source: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub extract: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMetadataData {
/// Don't rename this field, it's used [ops::extract_chat_message_metadata]
content: String,
pub content_type: String,
size: i64,
pub content: String,
pub content_type: ChatMetadataContentType,
pub size: i64,
}
impl ChatMetadataData {
pub fn validate(&self) -> bool {
match self.content_type {
ChatMetadataContentType::Text => self.content.len() == self.size as usize,
ChatMetadataContentType::Markdown => self.content.len() == self.size as usize,
_ => true,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ChatMetadataContentType {
Unknown,
Text,
Markdown,
Pdf,
Custom(String),
}
impl Display for ChatMetadataContentType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ChatMetadataContentType::Unknown => write!(f, "unknown"),
ChatMetadataContentType::Text => write!(f, "txt"),
ChatMetadataContentType::Markdown => write!(f, "markdown"),
ChatMetadataContentType::Pdf => write!(f, "pdf"),
ChatMetadataContentType::Custom(custom) => write!(f, "{}", custom),
}
}
}
impl ChatMetadataData {
@ -673,7 +713,7 @@ impl ChatMetadataData {
let size = content.len();
Self {
content,
content_type: "text".to_string(),
content_type: ChatMetadataContentType::Text,
size: size as i64,
}
}

View File

@ -77,13 +77,12 @@ pub async fn update_chat(
current_param_pos += 1;
}
if let Some(ref rag_ids) = params.rag_ids {
query_parts.push(format!("rag_ids = ${}", current_param_pos));
let rag_ids_json = json!(rag_ids);
if let Some(ref metadata) = params.metadata {
query_parts.push(format!("metadata = metadata || ${}", current_param_pos));
args
.add(rag_ids_json)
.add(json!(metadata))
.map_err(|err| AppError::SqlxArgEncodingError {
desc: format!("unable to encode rag ids json for chat id {}", chat_id),
desc: format!("unable to encode metadata json for chat id {}", chat_id),
err,
})?;
current_param_pos += 1;

View File

@ -1,7 +1,7 @@
use crate::biz::chat::ops::{
create_chat, create_chat_message, create_chat_message_stream, delete_chat,
extract_chat_message_metadata, generate_chat_message_answer, get_chat_messages,
update_chat_message, ExtractChatMetadata,
update_chat_message,
};
use crate::state::AppState;
use actix_web::web::{Data, Json};
@ -175,6 +175,7 @@ async fn get_related_message_handler(
Ok(AppResponse::Ok().with_data(resp).into())
}
#[instrument(level = "debug", skip_all, err)]
async fn create_question_handler(
state: Data<AppState>,
path: web::Path<(String, String)>,
@ -184,25 +185,21 @@ async fn create_question_handler(
let (_workspace_id, chat_id) = path.into_inner();
let mut params = payload.into_inner();
// When create a question, we will extract the metadata from the question content.
// metadata might include user mention file,page,or user. For example, @Get started.
for extract_context in extract_chat_message_metadata(&mut params) {
match extract_context {
ExtractChatMetadata::Text { text, metadata } => {
let context = CreateTextChatContext {
chat_id: chat_id.clone(),
content_type: "txt".to_string(),
text,
chunk_size: 2000,
chunk_overlap: 20,
metadata,
};
trace!("create chat context: {}", context);
state
.ai_client
.create_chat_text_context(context)
.await
.map_err(AppError::from)?;
},
}
let context = CreateTextChatContext::new(
chat_id.clone(),
extract_context.content_type,
extract_context.content,
)
.with_metadata(extract_context.metadata);
trace!("create context for question: {}", context);
state
.ai_client
.create_chat_text_context(context)
.await
.map_err(AppError::from)?;
}
let uid = state.user_cache.get_user_uid(&uuid).await?;

View File

@ -12,8 +12,9 @@ use database::chat::chat_ops::{
select_chat_messages,
};
use database_entity::dto::{
ChatAuthor, ChatAuthorType, ChatMessage, ChatMessageType, CreateChatMessageParams,
CreateChatParams, GetChatMessageParams, RepeatedChatMessage, UpdateChatMessageContentParams,
ChatAuthor, ChatAuthorType, ChatMessage, ChatMessageType, ChatMetadataData,
CreateChatMessageParams, CreateChatParams, GetChatMessageParams, RepeatedChatMessage,
UpdateChatMessageContentParams,
};
use futures::stream::Stream;
use serde_json::Value;
@ -130,19 +131,14 @@ pub async fn create_chat_message(
Ok(question)
}
enum ContextType {
Unknown,
Text,
}
/// Extracts the chat context from the metadata. Currently, we only support text as a context. In
/// the future, we will support other types of context.
pub(crate) enum ExtractChatMetadata {
Text {
text: String,
metadata: HashMap<String, Value>,
},
pub(crate) struct ExtractChatMetadata {
pub(crate) content: String,
pub(crate) content_type: String,
pub(crate) metadata: HashMap<String, Value>,
}
/// Removes the "content" field from the metadata if the "ty" field is equal to "text".
/// The metadata struct is shown below:
/// {
@ -155,68 +151,25 @@ pub(crate) enum ExtractChatMetadata {
/// "name": "name"
/// }
///
/// # Parameters
/// - `params`: A mutable reference to `CreateChatMessageParams` which contains metadata.
///
/// # Returns
/// - `Option<(String, HashMap<String, serde_json::Value>)>`: A tuple containing the removed content and the updated metadata, otherwise `None`.
/// the root json is point to the struct [database_entity::dto::ChatMessageMetadata]
fn extract_message_metadata(
message_metadata: &mut serde_json::Value,
) -> Option<ExtractChatMetadata> {
trace!("Extracting metadata: {:?}", message_metadata);
if let Value::Object(message_metadata) = message_metadata {
let mut context_type = ContextType::Unknown;
if let Some(Value::Object(data)) = message_metadata.get("data") {
if let Some(ty) = data.get("content_type").and_then(|v| v.as_str()) {
match ty {
"text" => context_type = ContextType::Text,
_ => context_type = ContextType::Unknown,
}
}
}
match context_type {
ContextType::Unknown => {
// do nothing
},
ContextType::Text => {
// remove the "data" field from the context if the "ty" field is equal to "text"
let mut text = None;
if let Some(Value::Object(ref mut data)) = message_metadata.remove("data") {
let content = data
.remove("content")
.and_then(|value| {
if let Value::String(s) = value {
Some(s)
} else {
None
}
})
.unwrap_or_default();
let content_size = data
.remove("size")
.and_then(|value| {
if let Value::Number(n) = value {
n.as_i64()
} else {
None
}
})
.unwrap_or(0);
// If the content is not empty and the content size is equal to the length of the content
if !content.is_empty() && content.len() == content_size as usize {
text = Some(content);
}
}
return text.map(|text| ExtractChatMetadata::Text {
text,
// remove the "data" field
if let Some(data) = message_metadata
.remove("data")
.and_then(|value| serde_json::from_value::<ChatMetadataData>(value.clone()).ok())
{
if data.validate() {
return Some(ExtractChatMetadata {
content: data.content,
content_type: data.content_type.to_string(),
metadata: message_metadata.clone().into_iter().collect(),
});
},
}
}
}
@ -227,8 +180,8 @@ pub(crate) fn extract_chat_message_metadata(
params: &mut CreateChatMessageParams,
) -> Vec<ExtractChatMetadata> {
let mut extract_metadatas = vec![];
trace!("chat metadata: {:?}", params.metadata);
if let Some(Value::Array(ref mut list)) = params.metadata {
trace!("Extracting chat metadata: {:?}", list);
for metadata in list {
if let Some(extract_context) = extract_message_metadata(metadata) {
extract_metadatas.push(extract_context);

View File

@ -122,6 +122,7 @@ async fn chat_qa_test() {
id: "123".to_string(),
name: "test context".to_string(),
source: "user added".to_string(),
extract: Some(json!({"created_at": 123})),
};
let params =
@ -131,6 +132,19 @@ async fn chat_qa_test() {
.create_question(&workspace_id, &chat_id, params)
.await
.unwrap();
assert_json_eq!(
question.meta_data,
json!([
{
"id": "123",
"name": "test context",
"source": "user added",
"extract": {
"created_at": 123
}
}
])
);
let answer = test_client
.api_client
@ -275,7 +289,7 @@ async fn create_chat_context_test() {
let context = CreateTextChatContext {
chat_id: chat_id.clone(),
content_type: "txt".to_string(),
text: "Lacus have lived in the US for five years".to_string(),
content: "Lacus have lived in the US for five years".to_string(),
chunk_size: 1000,
chunk_overlap: 20,
metadata: Default::default(),