fix: byte index 8000 is not a char boundary (#995)

* chore: fix split text boundary error and add related tests

* chore: reduce clone
This commit is contained in:
Nathan.fooo 2024-11-15 13:38:08 +08:00 committed by GitHub
parent d8075a9368
commit 97f9ff3dd8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 341 additions and 36 deletions

1
Cargo.lock generated
View File

@ -752,6 +752,7 @@ dependencies = [
"tokio-util",
"tracing",
"tracing-subscriber",
"unicode-segmentation",
"uuid",
"validator",
"workspace-template",

View File

@ -87,6 +87,7 @@ lazy_static = "1.4.0"
itertools = "0.12.0"
validator = "0.16.1"
rayon.workspace = true
unicode-segmentation = "1.9.0"
[dev-dependencies]
rand = "0.8.5"

View File

@ -13,6 +13,7 @@ use collab_document::document::DocumentBody;
use collab_document::error::DocumentError;
use collab_entity::CollabType;
use database_entity::dto::{AFCollabEmbeddingParams, AFCollabEmbeddings, EmbeddingContentType};
use unicode_segmentation::UnicodeSegmentation;
use uuid::Uuid;
use crate::indexer::{DocumentDataExt, Indexer};
@ -45,42 +46,12 @@ impl Indexer for DocumentIndexer {
match result {
Ok(document_data) => {
let content = document_data.to_plain_text();
let mut result = Vec::with_capacity(1 + content.len() / Self::DOC_CONTENT_SPLIT);
let mut slice = content.as_str();
while slice.len() > Self::DOC_CONTENT_SPLIT {
// we should split document into multiple fragments
let (left, right) = slice.split_at(Self::DOC_CONTENT_SPLIT);
let param = AFCollabEmbeddingParams {
fragment_id: Uuid::new_v4().to_string(),
object_id: object_id.clone(),
collab_type: CollabType::Document,
content_type: EmbeddingContentType::PlainText,
content: left.to_string(),
embedding: None,
};
result.push(param);
slice = right;
}
let content = if slice.len() == content.len() {
content // we didn't slice the content
} else {
slice.to_string()
};
if !content.is_empty() {
let param = AFCollabEmbeddingParams {
fragment_id: object_id.clone(),
object_id: object_id.clone(),
collab_type: CollabType::Document,
content_type: EmbeddingContentType::PlainText,
content,
embedding: None,
};
result.push(param);
}
Ok(result)
create_embedding_params(
object_id,
content,
CollabType::Document,
Self::DOC_CONTENT_SPLIT,
)
},
Err(err) => {
if matches!(err, DocumentError::NoRequiredData) {
@ -141,3 +112,335 @@ impl Indexer for DocumentIndexer {
}))
}
}
/// Splits `content` into embedding fragments of at most `max_content_len` bytes each.
///
/// Splitting is done on extended grapheme cluster boundaries (`graphemes(true)`), so a
/// fragment never ends in the middle of a UTF-8 code point or of a combining / ZWJ sequence.
/// A single grapheme that is itself longer than `max_content_len` bytes is emitted as a
/// fragment of its own; that fragment is the only kind allowed to exceed the limit.
///
/// Fragment ids:
/// - content that fits into one fragment uses `object_id` as its fragment id;
/// - when the content is split, the trailing remainder uses `object_id` and every other
///   fragment gets a freshly generated UUID v4.
///
/// Returns an empty vector when `content` is empty. With `max_content_len == 0` every
/// grapheme becomes its own fragment instead of panicking.
#[inline]
fn create_embedding_params(
    object_id: String,
    content: String,
    collab_type: CollabType,
    max_content_len: usize,
) -> Result<Vec<AFCollabEmbeddingParams>, AppError> {
    if content.is_empty() {
        return Ok(vec![]);
    }

    // Helper function to create AFCollabEmbeddingParams
    fn create_param(
        fragment_id: String,
        object_id: &str,
        collab_type: &CollabType,
        content: String,
    ) -> AFCollabEmbeddingParams {
        AFCollabEmbeddingParams {
            fragment_id,
            object_id: object_id.to_string(),
            collab_type: collab_type.clone(),
            content_type: EmbeddingContentType::PlainText,
            content,
            embedding: None,
        }
    }

    if content.len() <= max_content_len {
        // Content is short enough; return as a single fragment
        let param = create_param(object_id.clone(), &object_id, &collab_type, content);
        return Ok(vec![param]);
    }

    // Content is longer than max_content_len; need to split.
    // `max(1)` keeps the capacity estimate from dividing by zero when the limit is 0.
    let mut result = Vec::with_capacity(1 + content.len() / max_content_len.max(1));
    let mut fragment = String::with_capacity(max_content_len);

    for grapheme in content.graphemes(true) {
        let grapheme_len = grapheme.len();
        if fragment.len() + grapheme_len > max_content_len {
            if !fragment.is_empty() {
                // Move the fragment out to avoid cloning; `fragment` is left empty.
                result.push(create_param(
                    Uuid::new_v4().to_string(),
                    &object_id,
                    &collab_type,
                    std::mem::take(&mut fragment),
                ));
            }

            // Check if the grapheme itself is longer than max_content_len
            if grapheme_len > max_content_len {
                // Push the grapheme as a fragment on its own; it can never fit otherwise.
                result.push(create_param(
                    Uuid::new_v4().to_string(),
                    &object_id,
                    &collab_type,
                    grapheme.to_string(),
                ));
                continue;
            }
        }
        fragment.push_str(grapheme);
    }

    // Add the last fragment if it's not empty
    if !fragment.is_empty() {
        result.push(create_param(
            object_id.clone(),
            &object_id,
            &collab_type,
            fragment,
        ));
    }

    Ok(result)
}
#[cfg(test)]
mod tests {
    use crate::indexer::document_indexer::create_embedding_params;
    use collab_entity::CollabType;

    /// Runs `create_embedding_params` for a document collab and returns only the fragment
    /// contents, in order.
    fn split(content: &str, max_content_len: usize) -> Vec<String> {
        create_embedding_params(
            "test_object".to_string(),
            content.to_string(),
            CollabType::Document,
            max_content_len,
        )
        .unwrap()
        .into_iter()
        .map(|param| param.content)
        .collect()
    }

    #[test]
    fn test_split_at_non_utf8() {
        let max_content_len = 10; // Small number for testing

        // Content with multibyte characters (emojis)
        let content = "Hello 😃 World 🌍! This is a test 🚀.";
        let fragments = split(content, max_content_len);

        // More than one fragment is required for this input.
        assert!(fragments.len() > 1);
        for fragment in &fragments {
            assert!(!fragment.is_empty());
            // Every grapheme here is at most 4 bytes, so no fragment may exceed the limit.
            assert!(fragment.len() <= max_content_len);
        }
        // Nothing was lost or duplicated, i.e. no character was cut in half.
        assert_eq!(fragments.concat(), content);
    }

    #[test]
    fn test_exact_boundary_split() {
        // Content length is exactly a multiple of max_content_len
        let fragments = split("abcdefghij", 5); // 10 characters

        assert_eq!(fragments, vec!["abcde", "fghij"]);
    }

    #[test]
    fn test_content_shorter_than_max_len() {
        let content = "Short content";
        let fragments = split(content, 100);

        assert_eq!(fragments, vec![content]);
    }

    #[test]
    fn test_empty_content() {
        let fragments = split("", 10);

        assert!(fragments.is_empty());
    }

    #[test]
    fn test_content_with_only_multibyte_characters() {
        // Each emoji is 4 bytes in UTF-8, so exactly one fits per fragment.
        let fragments = split("😀😃😄😁😆", 4);

        assert_eq!(fragments, vec!["😀", "😃", "😄", "😁", "😆"]);
    }

    #[test]
    fn test_split_with_combining_characters() {
        // String with combining characters (letters followed by U+0301 COMBINING ACUTE ACCENT).
        // Each grapheme is 3 bytes, so two of them (6 bytes) do not fit into 5 bytes.
        let content = "a\u{0301}e\u{0301}i\u{0301}o\u{0301}u\u{0301}";
        let fragments = split(content, 5);

        // Escapes are used so the combining marks cannot be lost by editors or tooling.
        assert_eq!(
            fragments,
            vec!["a\u{0301}", "e\u{0301}", "i\u{0301}", "o\u{0301}", "u\u{0301}"]
        );
    }

    #[test]
    fn test_large_content() {
        let max_content_len = 1000;

        // Generate a large content string
        let content = "a".repeat(5000); // 5000 characters
        let fragments = split(&content, max_content_len);

        assert_eq!(fragments.len(), 5); // 5000 / 1000 = 5
        for fragment in &fragments {
            assert_eq!(fragment.len(), max_content_len);
        }
    }

    #[test]
    fn test_non_ascii_characters() {
        // Precomposed non-ASCII characters, 2 bytes each in UTF-8.
        let fragments = split("áéíóú", 5);

        // Only two characters (4 bytes) fit per fragment, so three fragments are produced.
        assert_eq!(fragments, vec!["áé", "íó", "ú"]);
    }

    #[test]
    fn test_content_with_leading_and_trailing_whitespace() {
        // Two spaces on each side: 9 bytes in total.
        let fragments = split("  abcde  ", 5);

        // Content should include leading and trailing whitespace
        assert_eq!(fragments, vec!["  abc", "de  "]);
    }

    #[test]
    fn test_content_with_multiple_zero_width_joiners() {
        // Two family emoji, each built from four people joined by U+200D ZERO WIDTH JOINER
        // (25 bytes per sequence). Written with escapes so the joiners stay visible.
        let first = "\u{1F469}\u{200D}\u{1F469}\u{200D}\u{1F467}\u{200D}\u{1F467}";
        let second = "\u{1F468}\u{200D}\u{1F468}\u{200D}\u{1F466}\u{200D}\u{1F466}";
        let content = format!("{first}{second}");
        let fragments = split(&content, 10);

        // Each complex emoji should be treated as a single grapheme, even though it is
        // longer than max_content_len.
        assert_eq!(fragments, vec![first, second]);
    }

    #[test]
    fn test_content_with_long_combining_sequences() {
        // Character with multiple combining marks (11 bytes, a single grapheme)
        let content = "a\u{0300}\u{0301}\u{0302}\u{0303}\u{0304}";
        let fragments = split(content, 5);

        // The entire combining sequence should be in one fragment
        assert_eq!(fragments, vec![content]);
    }
}