diff --git a/libs/appflowy-ai-client/src/dto.rs b/libs/appflowy-ai-client/src/dto.rs index 58ec6bb6..d5b7a336 100644 --- a/libs/appflowy-ai-client/src/dto.rs +++ b/libs/appflowy-ai-client/src/dto.rs @@ -274,21 +274,39 @@ pub struct LocalAIConfig { #[derive(Clone, Debug, Serialize, Deserialize)] pub struct CreateTextChatContext { pub chat_id: String, - /// Only support "txt" and "md" for now + /// Only support "txt" and "markdown" for now pub content_type: String, - pub text: String, + pub content: String, pub chunk_size: i32, pub chunk_overlap: i32, pub metadata: HashMap, } +impl CreateTextChatContext { + pub fn new(chat_id: String, content_type: String, text: String) -> Self { + CreateTextChatContext { + chat_id, + content_type, + content: text, + chunk_size: 2000, + chunk_overlap: 20, + metadata: HashMap::new(), + } + } + + pub fn with_metadata(mut self, metadata: HashMap) -> Self { + self.metadata = metadata; + self + } +} + impl Display for CreateTextChatContext { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!( "Create Chat context: {{ chat_id: {}, content_type: {}, content size: {}, metadata: {:?} }}", self.chat_id, self.content_type, - self.text.len(), + self.content.len(), self.metadata )) } diff --git a/libs/appflowy-ai-client/tests/chat_test/context_test.rs b/libs/appflowy-ai-client/tests/chat_test/context_test.rs index 7f74adb7..529696bd 100644 --- a/libs/appflowy-ai-client/tests/chat_test/context_test.rs +++ b/libs/appflowy-ai-client/tests/chat_test/context_test.rs @@ -7,7 +7,7 @@ async fn create_chat_context_test() { let context = CreateTextChatContext { chat_id: chat_id.clone(), content_type: "txt".to_string(), - text: "I have lived in the US for five years".to_string(), + content: "I have lived in the US for five years".to_string(), chunk_size: 1000, chunk_overlap: 20, metadata: Default::default(), diff --git a/libs/database-entity/src/dto.rs b/libs/database-entity/src/dto.rs index bafd8ee4..bbfc28e3 100644 --- a/libs/database-entity/src/dto.rs +++ b/libs/database-entity/src/dto.rs @@ -640,7 +640,7 @@ pub struct UpdateChatParams { #[validate(custom = "validate_not_empty_str")] pub name: Option, - pub rag_ids: Option>, + pub metadata: Option, } #[derive(Debug, Clone, Validate, Serialize, Deserialize)] @@ -648,24 +648,64 @@ pub struct CreateChatMessageParams { #[validate(custom = "validate_not_empty_str")] pub content: String, pub message_type: ChatMessageType, + + /// metadata is json array object #[serde(skip_serializing_if = "Option::is_none")] pub metadata: Option, } +/// [ChatMessageMetadata] is used when creating a new question message. +/// All the properties of [ChatMessageMetadata] except [ChatMetadataData] will be stored as a +/// metadata for specific [ChatMessage] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatMessageMetadata { pub data: ChatMetadataData, + /// The id for the metadata. It can be a file_id, view_id pub id: String, + /// The name for the metadata. For example, @xxx, @xx.txt pub name: String, pub source: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub extract: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatMetadataData { /// Don't rename this field, it's used [ops::extract_chat_message_metadata] - content: String, - pub content_type: String, - size: i64, + pub content: String, + pub content_type: ChatMetadataContentType, + pub size: i64, +} + +impl ChatMetadataData { + pub fn validate(&self) -> bool { + match self.content_type { + ChatMetadataContentType::Text => self.content.len() == self.size as usize, + ChatMetadataContentType::Markdown => self.content.len() == self.size as usize, + _ => true, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ChatMetadataContentType { + Unknown, + Text, + Markdown, + Pdf, + Custom(String), +} + +impl Display for ChatMetadataContentType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ChatMetadataContentType::Unknown => write!(f, "unknown"), + ChatMetadataContentType::Text => write!(f, "txt"), + ChatMetadataContentType::Markdown => write!(f, "markdown"), + ChatMetadataContentType::Pdf => write!(f, "pdf"), + ChatMetadataContentType::Custom(custom) => write!(f, "{}", custom), + } + } } impl ChatMetadataData { @@ -673,7 +713,7 @@ impl ChatMetadataData { let size = content.len(); Self { content, - content_type: "text".to_string(), + content_type: ChatMetadataContentType::Text, size: size as i64, } } diff --git a/libs/database/src/chat/chat_ops.rs b/libs/database/src/chat/chat_ops.rs index 7ed94872..fe0a6009 100644 --- a/libs/database/src/chat/chat_ops.rs +++ b/libs/database/src/chat/chat_ops.rs @@ -77,13 +77,12 @@ pub async fn update_chat( current_param_pos += 1; } - if let Some(ref rag_ids) = params.rag_ids { - query_parts.push(format!("rag_ids = ${}", current_param_pos)); - let rag_ids_json = json!(rag_ids); + if let Some(ref metadata) = params.metadata { + query_parts.push(format!("metadata = metadata || ${}", current_param_pos)); args - .add(rag_ids_json) + .add(json!(metadata)) .map_err(|err| AppError::SqlxArgEncodingError { - desc: format!("unable to encode rag ids json for chat id {}", chat_id), + desc: format!("unable to encode metadata json for chat id {}", chat_id), err, })?; current_param_pos += 1; diff --git a/src/api/chat.rs b/src/api/chat.rs index 14eadfcd..d838be25 100644 --- a/src/api/chat.rs +++ b/src/api/chat.rs @@ -1,7 +1,7 @@ use crate::biz::chat::ops::{ create_chat, create_chat_message, create_chat_message_stream, delete_chat, extract_chat_message_metadata, generate_chat_message_answer, get_chat_messages, - update_chat_message, ExtractChatMetadata, + update_chat_message, }; use crate::state::AppState; use actix_web::web::{Data, Json}; @@ -175,6 +175,7 @@ async fn get_related_message_handler( Ok(AppResponse::Ok().with_data(resp).into()) } +#[instrument(level = "debug", skip_all, err)] async fn create_question_handler( state: Data, path: web::Path<(String, String)>, @@ -184,25 +185,21 @@ async fn create_question_handler( let (_workspace_id, chat_id) = path.into_inner(); let mut params = payload.into_inner(); + // When create a question, we will extract the metadata from the question content. + // metadata might include user mention file,page,or user. For example, @Get started. for extract_context in extract_chat_message_metadata(&mut params) { - match extract_context { - ExtractChatMetadata::Text { text, metadata } => { - let context = CreateTextChatContext { - chat_id: chat_id.clone(), - content_type: "txt".to_string(), - text, - chunk_size: 2000, - chunk_overlap: 20, - metadata, - }; - trace!("create chat context: {}", context); - state - .ai_client - .create_chat_text_context(context) - .await - .map_err(AppError::from)?; - }, - } + let context = CreateTextChatContext::new( + chat_id.clone(), + extract_context.content_type, + extract_context.content, + ) + .with_metadata(extract_context.metadata); + trace!("create context for question: {}", context); + state + .ai_client + .create_chat_text_context(context) + .await + .map_err(AppError::from)?; } let uid = state.user_cache.get_user_uid(&uuid).await?; diff --git a/src/biz/chat/ops.rs b/src/biz/chat/ops.rs index b19bed39..0dfaf18a 100644 --- a/src/biz/chat/ops.rs +++ b/src/biz/chat/ops.rs @@ -12,8 +12,9 @@ use database::chat::chat_ops::{ select_chat_messages, }; use database_entity::dto::{ - ChatAuthor, ChatAuthorType, ChatMessage, ChatMessageType, CreateChatMessageParams, - CreateChatParams, GetChatMessageParams, RepeatedChatMessage, UpdateChatMessageContentParams, + ChatAuthor, ChatAuthorType, ChatMessage, ChatMessageType, ChatMetadataData, + CreateChatMessageParams, CreateChatParams, GetChatMessageParams, RepeatedChatMessage, + UpdateChatMessageContentParams, }; use futures::stream::Stream; use serde_json::Value; @@ -130,19 +131,14 @@ pub async fn create_chat_message( Ok(question) } -enum ContextType { - Unknown, - Text, -} - /// Extracts the chat context from the metadata. Currently, we only support text as a context. In /// the future, we will support other types of context. -pub(crate) enum ExtractChatMetadata { - Text { - text: String, - metadata: HashMap, - }, +pub(crate) struct ExtractChatMetadata { + pub(crate) content: String, + pub(crate) content_type: String, + pub(crate) metadata: HashMap, } + /// Removes the "content" field from the metadata if the "ty" field is equal to "text". /// The metadata struct is shown below: /// { @@ -155,68 +151,25 @@ pub(crate) enum ExtractChatMetadata { /// "name": "name" /// } /// -/// # Parameters -/// - `params`: A mutable reference to `CreateChatMessageParams` which contains metadata. -/// -/// # Returns -/// - `Option<(String, HashMap)>`: A tuple containing the removed content and the updated metadata, otherwise `None`. +/// the root json is point to the struct [database_entity::dto::ChatMessageMetadata] fn extract_message_metadata( message_metadata: &mut serde_json::Value, ) -> Option { trace!("Extracting metadata: {:?}", message_metadata); if let Value::Object(message_metadata) = message_metadata { - let mut context_type = ContextType::Unknown; - if let Some(Value::Object(data)) = message_metadata.get("data") { - if let Some(ty) = data.get("content_type").and_then(|v| v.as_str()) { - match ty { - "text" => context_type = ContextType::Text, - _ => context_type = ContextType::Unknown, - } - } - } - - match context_type { - ContextType::Unknown => { - // do nothing - }, - ContextType::Text => { - // remove the "data" field from the context if the "ty" field is equal to "text" - let mut text = None; - if let Some(Value::Object(ref mut data)) = message_metadata.remove("data") { - let content = data - .remove("content") - .and_then(|value| { - if let Value::String(s) = value { - Some(s) - } else { - None - } - }) - .unwrap_or_default(); - - let content_size = data - .remove("size") - .and_then(|value| { - if let Value::Number(n) = value { - n.as_i64() - } else { - None - } - }) - .unwrap_or(0); - - // If the content is not empty and the content size is equal to the length of the content - if !content.is_empty() && content.len() == content_size as usize { - text = Some(content); - } - } - - return text.map(|text| ExtractChatMetadata::Text { - text, + // remove the "data" field + if let Some(data) = message_metadata + .remove("data") + .and_then(|value| serde_json::from_value::(value.clone()).ok()) + { + if data.validate() { + return Some(ExtractChatMetadata { + content: data.content, + content_type: data.content_type.to_string(), metadata: message_metadata.clone().into_iter().collect(), }); - }, + } } } @@ -227,8 +180,8 @@ pub(crate) fn extract_chat_message_metadata( params: &mut CreateChatMessageParams, ) -> Vec { let mut extract_metadatas = vec![]; + trace!("chat metadata: {:?}", params.metadata); if let Some(Value::Array(ref mut list)) = params.metadata { - trace!("Extracting chat metadata: {:?}", list); for metadata in list { if let Some(extract_context) = extract_message_metadata(metadata) { extract_metadatas.push(extract_context); diff --git a/tests/ai_test/chat_test.rs b/tests/ai_test/chat_test.rs index 09bf7488..254097d5 100644 --- a/tests/ai_test/chat_test.rs +++ b/tests/ai_test/chat_test.rs @@ -122,6 +122,7 @@ async fn chat_qa_test() { id: "123".to_string(), name: "test context".to_string(), source: "user added".to_string(), + extract: Some(json!({"created_at": 123})), }; let params = @@ -131,6 +132,19 @@ async fn chat_qa_test() { .create_question(&workspace_id, &chat_id, params) .await .unwrap(); + assert_json_eq!( + question.meta_data, + json!([ + { + "id": "123", + "name": "test context", + "source": "user added", + "extract": { + "created_at": 123 + } + } + ]) + ); let answer = test_client .api_client @@ -275,7 +289,7 @@ async fn create_chat_context_test() { let context = CreateTextChatContext { chat_id: chat_id.clone(), content_type: "txt".to_string(), - text: "Lacus have lived in the US for five years".to_string(), + content: "Lacus have lived in the US for five years".to_string(), chunk_size: 1000, chunk_overlap: 20, metadata: Default::default(),