chore: support split by text len (#1002)
* chore: support split by text len * chore: update docs * chore: update tests
This commit is contained in:
parent
dcbc84dacc
commit
d798c81ba4
|
|
@ -754,6 +754,7 @@ dependencies = [
|
|||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"unicode-normalization",
|
||||
"unicode-segmentation",
|
||||
"uuid",
|
||||
"validator",
|
||||
"workspace-template",
|
||||
|
|
|
|||
|
|
@ -88,6 +88,8 @@ itertools = "0.12.0"
|
|||
validator = "0.16.1"
|
||||
rayon.workspace = true
|
||||
tiktoken-rs = "0.6.0"
|
||||
unicode-segmentation = "1.9.0"
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
rand = "0.8.5"
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ use collab_document::error::DocumentError;
|
|||
use collab_entity::CollabType;
|
||||
use database_entity::dto::{AFCollabEmbeddingParams, AFCollabEmbeddings, EmbeddingContentType};
|
||||
|
||||
use crate::config::get_env_var;
|
||||
use crate::indexer::open_ai::{split_text_by_max_content_len, split_text_by_max_tokens};
|
||||
use tiktoken_rs::CoreBPE;
|
||||
use tracing::trace;
|
||||
use uuid::Uuid;
|
||||
|
|
@ -54,12 +56,11 @@ impl Indexer for DocumentIndexer {
|
|||
match result {
|
||||
Ok(document_data) => {
|
||||
let content = document_data.to_plain_text();
|
||||
let max_tokens = self.embedding_model.default_dimensions() as usize;
|
||||
create_embedding(
|
||||
object_id,
|
||||
content,
|
||||
CollabType::Document,
|
||||
max_tokens,
|
||||
&self.embedding_model,
|
||||
self.tokenizer.clone(),
|
||||
)
|
||||
.await
|
||||
|
|
@ -129,47 +130,35 @@ impl Indexer for DocumentIndexer {
|
|||
}
|
||||
}
|
||||
|
||||
/// ## Execution Time Comparison Results
|
||||
///
|
||||
/// The following results were observed when running `execution_time_comparison_tests`:
|
||||
///
|
||||
/// | Content Size (chars) | Direct Time (ms) | spawn_blocking Time (ms) |
|
||||
/// |-----------------------|------------------|--------------------------|
|
||||
/// | 500 | 1 | 1 |
|
||||
/// | 1000 | 2 | 2 |
|
||||
/// | 2000 | 5 | 5 |
|
||||
/// | 5000 | 11 | 11 |
|
||||
/// | 20000 | 49 | 48 |
|
||||
///
|
||||
/// ## Guidelines for Using `spawn_blocking`
|
||||
///
|
||||
/// - **Short Tasks (< 1 ms)**:
|
||||
/// Use direct execution on the async runtime. The minimal execution time has negligible impact.
|
||||
///
|
||||
/// - **Moderate Tasks (1–10 ms)**:
|
||||
/// - For infrequent or low-concurrency tasks, direct execution is acceptable.
|
||||
/// - For frequent or high-concurrency tasks, consider using `spawn_blocking` to avoid delays.
|
||||
///
|
||||
/// - **Long Tasks (> 10 ms)**:
|
||||
/// Always offload to a blocking thread with `spawn_blocking` to maintain runtime efficiency and responsiveness.
|
||||
///
|
||||
/// Related blog:
|
||||
/// https://tokio.rs/blog/2020-04-preemption
|
||||
/// https://ryhl.io/blog/async-what-is-blocking/
|
||||
async fn create_embedding(
|
||||
object_id: String,
|
||||
content: String,
|
||||
collab_type: CollabType,
|
||||
max_tokens: usize,
|
||||
embedding_model: &EmbeddingModel,
|
||||
tokenizer: Arc<CoreBPE>,
|
||||
) -> Result<Vec<AFCollabEmbeddingParams>, AppError> {
|
||||
let split_contents = if content.len() < 500 {
|
||||
split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref())?
|
||||
let use_tiktoken = get_env_var("APPFLOWY_AI_CONTENT_SPLITTER_TIKTOKEN", "false")
|
||||
.parse::<bool>()
|
||||
.unwrap_or(false);
|
||||
|
||||
let split_contents = if use_tiktoken {
|
||||
let max_tokens = embedding_model.default_dimensions() as usize;
|
||||
if content.len() < 500 {
|
||||
split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref())?
|
||||
} else {
|
||||
tokio::task::spawn_blocking(move || {
|
||||
split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref())
|
||||
})
|
||||
.await??
|
||||
}
|
||||
} else {
|
||||
tokio::task::spawn_blocking(move || {
|
||||
split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref())
|
||||
})
|
||||
.await??
|
||||
debug_assert!(matches!(
|
||||
embedding_model,
|
||||
EmbeddingModel::TextEmbedding3Small
|
||||
));
|
||||
// We assume that every token is ~4 bytes. We're going to split document content into fragments
|
||||
// of ~2000 tokens each.
|
||||
split_text_by_max_content_len(content, 8000)?
|
||||
};
|
||||
|
||||
Ok(
|
||||
|
|
@ -186,264 +175,3 @@ async fn create_embedding(
|
|||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn split_text_by_max_tokens(
|
||||
content: String,
|
||||
max_tokens: usize,
|
||||
tokenizer: &CoreBPE,
|
||||
) -> Result<Vec<String>, AppError> {
|
||||
if content.is_empty() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
let token_ids = tokenizer.encode_ordinary(&content);
|
||||
let total_tokens = token_ids.len();
|
||||
if total_tokens <= max_tokens {
|
||||
return Ok(vec![content]);
|
||||
}
|
||||
|
||||
let mut chunks = Vec::new();
|
||||
let mut start_idx = 0;
|
||||
while start_idx < total_tokens {
|
||||
let mut end_idx = (start_idx + max_tokens).min(total_tokens);
|
||||
let mut decoded = false;
|
||||
// Try to decode the chunk, adjust end_idx if decoding fails
|
||||
while !decoded {
|
||||
let token_chunk = &token_ids[start_idx..end_idx];
|
||||
// Attempt to decode the current chunk
|
||||
match tokenizer.decode(token_chunk.to_vec()) {
|
||||
Ok(chunk_text) => {
|
||||
chunks.push(chunk_text);
|
||||
start_idx = end_idx;
|
||||
decoded = true;
|
||||
},
|
||||
Err(_) => {
|
||||
// If we can extend the chunk, do so
|
||||
if end_idx < total_tokens {
|
||||
end_idx += 1;
|
||||
} else if start_idx + 1 < total_tokens {
|
||||
// Skip the problematic token at start_idx
|
||||
start_idx += 1;
|
||||
end_idx = (start_idx + max_tokens).min(total_tokens);
|
||||
} else {
|
||||
// Cannot decode any further, break to avoid infinite loop
|
||||
start_idx = total_tokens;
|
||||
break;
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(chunks)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
  use crate::indexer::document_indexer::split_text_by_max_tokens;

  use tiktoken_rs::cl100k_base;

  /// Splits `text` into fragments of up to `limit` tokens, panicking if the splitter fails.
  fn by_tokens(bpe: &tiktoken_rs::CoreBPE, text: &str, limit: usize) -> Vec<String> {
    split_text_by_max_tokens(text.to_owned(), limit, bpe).unwrap()
  }

  /// Number of fragments needed to hold `total` tokens at `per_fragment` tokens each.
  fn ceil_div(total: usize, per_fragment: usize) -> usize {
    (total + per_fragment - 1) / per_fragment
  }

  #[test]
  fn test_split_at_non_utf8() {
    // Multibyte characters (emojis) combined with a small token budget.
    let bpe = cl100k_base().unwrap();
    let pieces = by_tokens(&bpe, "Hello 😃 World 🌍! This is a test 🚀.", 10);

    // No fragment may start or end in the middle of a multibyte character.
    for piece in pieces {
      assert!(piece.is_char_boundary(0));
      assert!(piece.is_char_boundary(piece.len()));
    }
  }

  #[test]
  fn test_exact_boundary_split() {
    let limit = 5;
    let text = "The quick brown fox jumps over the lazy dog";
    let bpe = cl100k_base().unwrap();

    let expected = ceil_div(bpe.encode_ordinary(text).len(), limit);
    assert_eq!(by_tokens(&bpe, text, limit).len(), expected);
  }

  #[test]
  fn test_content_shorter_than_max_len() {
    let text = "Short content";
    let bpe = cl100k_base().unwrap();
    let pieces = by_tokens(&bpe, text, 100);

    assert_eq!(pieces.len(), 1);
    assert_eq!(pieces[0], text);
  }

  #[test]
  fn test_empty_content() {
    let bpe = cl100k_base().unwrap();
    assert!(by_tokens(&bpe, "", 10).is_empty());
  }

  #[test]
  fn test_content_with_only_multibyte_characters() {
    let text = "😀😃😄😁😆";
    let bpe = cl100k_base().unwrap();
    let pieces = by_tokens(&bpe, text, 1);

    // Each produced fragment is compared with the emoji at the same position.
    pieces
      .iter()
      .zip(text.chars())
      .for_each(|(piece, emoji)| assert_eq!(*piece, emoji.to_string()));
  }

  #[test]
  fn test_split_with_combining_characters() {
    // "áéíóú" written with combining acute accents.
    let text = "a\u{0301}e\u{0301}i\u{0301}o\u{0301}u\u{0301}";
    let bpe = cl100k_base().unwrap();
    let pieces = by_tokens(&bpe, text, 1);

    assert_eq!(pieces.len(), bpe.encode_ordinary(text).len());
    assert_eq!(pieces.concat(), text);
  }

  #[test]
  fn test_large_content() {
    let limit = 1000;
    let text = "a".repeat(5000); // 5000 characters
    let bpe = cl100k_base().unwrap();

    let expected = ceil_div(bpe.encode_ordinary(&text).len(), limit);
    assert_eq!(by_tokens(&bpe, &text, limit).len(), expected);
  }

  #[test]
  fn test_non_ascii_characters() {
    let limit = 2;
    let text = "áéíóú";
    let bpe = cl100k_base().unwrap();
    let pieces = by_tokens(&bpe, text, limit);

    assert_eq!(pieces.len(), ceil_div(bpe.encode_ordinary(text).len(), limit));
    assert_eq!(pieces.concat(), text);
  }

  #[test]
  fn test_content_with_leading_and_trailing_whitespace() {
    let limit = 3;
    let text = " abcde ";
    let bpe = cl100k_base().unwrap();
    let pieces = by_tokens(&bpe, text, limit);

    assert_eq!(pieces.len(), ceil_div(bpe.encode_ordinary(text).len(), limit));
    assert_eq!(pieces.concat(), text);
  }

  #[test]
  fn test_content_with_multiple_zero_width_joiners() {
    let text = "👩👩👧👧👨👨👦👦";
    let bpe = cl100k_base().unwrap();

    assert_eq!(by_tokens(&bpe, text, 1).concat(), text);
  }

  #[test]
  fn test_content_with_long_combining_sequences() {
    let text = "a\u{0300}\u{0301}\u{0302}\u{0303}\u{0304}";
    let bpe = cl100k_base().unwrap();

    assert_eq!(by_tokens(&bpe, text, 1).concat(), text);
  }
}
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod execution_time_comparison_tests {
|
||||
// use crate::indexer::document_indexer::split_text_by_max_tokens;
|
||||
// use rand::distributions::Alphanumeric;
|
||||
// use rand::{thread_rng, Rng};
|
||||
// use std::sync::Arc;
|
||||
// use std::time::Instant;
|
||||
// use tiktoken_rs::{cl100k_base, CoreBPE};
|
||||
//
|
||||
// #[tokio::test]
|
||||
// async fn test_execution_time_comparison() {
|
||||
// let tokenizer = Arc::new(cl100k_base().unwrap());
|
||||
// let max_tokens = 100;
|
||||
//
|
||||
// let sizes = vec![500, 1000, 2000, 5000, 20000]; // Content sizes to test
|
||||
// for size in sizes {
|
||||
// let content = generate_random_string(size);
|
||||
//
|
||||
// // Measure direct execution time
|
||||
// let direct_time = measure_direct_execution(content.clone(), max_tokens, &tokenizer);
|
||||
//
|
||||
// // Measure spawn_blocking execution time
|
||||
// let spawn_blocking_time =
|
||||
// measure_spawn_blocking_execution(content, max_tokens, Arc::clone(&tokenizer)).await;
|
||||
//
|
||||
// println!(
|
||||
// "Content Size: {} | Direct Time: {}ms | spawn_blocking Time: {}ms",
|
||||
// size, direct_time, spawn_blocking_time
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // Measure direct execution time
|
||||
// fn measure_direct_execution(content: String, max_tokens: usize, tokenizer: &CoreBPE) -> u128 {
|
||||
// let start = Instant::now();
|
||||
// split_text_by_max_tokens(content, max_tokens, tokenizer).unwrap();
|
||||
// start.elapsed().as_millis()
|
||||
// }
|
||||
//
|
||||
// // Measure `spawn_blocking` execution time
|
||||
// async fn measure_spawn_blocking_execution(
|
||||
// content: String,
|
||||
// max_tokens: usize,
|
||||
// tokenizer: Arc<CoreBPE>,
|
||||
// ) -> u128 {
|
||||
// let start = Instant::now();
|
||||
// tokio::task::spawn_blocking(move || {
|
||||
// split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref()).unwrap()
|
||||
// })
|
||||
// .await
|
||||
// .unwrap();
|
||||
// start.elapsed().as_millis()
|
||||
// }
|
||||
//
|
||||
// pub fn generate_random_string(len: usize) -> String {
|
||||
// let rng = thread_rng();
|
||||
// rng
|
||||
// .sample_iter(&Alphanumeric)
|
||||
// .take(len)
|
||||
// .map(char::from)
|
||||
// .collect()
|
||||
// }
|
||||
// }
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
mod document_indexer;
|
||||
mod ext;
|
||||
mod open_ai;
|
||||
mod provider;
|
||||
|
||||
pub use document_indexer::DocumentIndexer;
|
||||
pub use ext::DocumentDataExt;
|
||||
pub use provider::*;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,361 @@
|
|||
use app_error::AppError;
|
||||
use tiktoken_rs::CoreBPE;
|
||||
use unicode_segmentation::UnicodeSegmentation;
|
||||
|
||||
/// ## Execution Time Comparison Results
|
||||
///
|
||||
/// The following results were observed when running `execution_time_comparison_tests`:
|
||||
///
|
||||
/// | Content Size (chars) | Direct Time (ms) | spawn_blocking Time (ms) |
|
||||
/// |-----------------------|------------------|--------------------------|
|
||||
/// | 500 | 1 | 1 |
|
||||
/// | 1000 | 2 | 2 |
|
||||
/// | 2000 | 5 | 5 |
|
||||
/// | 5000 | 11 | 11 |
|
||||
/// | 20000 | 49 | 48 |
|
||||
///
|
||||
/// ## Guidelines for Using `spawn_blocking`
|
||||
///
|
||||
/// - **Short Tasks (< 1 ms)**:
|
||||
/// Use direct execution on the async runtime. The minimal execution time has negligible impact.
|
||||
///
|
||||
/// - **Moderate Tasks (1–10 ms)**:
|
||||
/// - For infrequent or low-concurrency tasks, direct execution is acceptable.
|
||||
/// - For frequent or high-concurrency tasks, consider using `spawn_blocking` to avoid delays.
|
||||
///
|
||||
/// - **Long Tasks (> 10 ms)**:
|
||||
/// Always offload to a blocking thread with `spawn_blocking` to maintain runtime efficiency and responsiveness.
|
||||
///
|
||||
/// Related blog:
|
||||
/// https://tokio.rs/blog/2020-04-preemption
|
||||
/// https://ryhl.io/blog/async-what-is-blocking/
|
||||
#[inline]
|
||||
pub fn split_text_by_max_tokens(
|
||||
content: String,
|
||||
max_tokens: usize,
|
||||
tokenizer: &CoreBPE,
|
||||
) -> Result<Vec<String>, AppError> {
|
||||
if content.is_empty() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
let token_ids = tokenizer.encode_ordinary(&content);
|
||||
let total_tokens = token_ids.len();
|
||||
if total_tokens <= max_tokens {
|
||||
return Ok(vec![content]);
|
||||
}
|
||||
|
||||
let mut chunks = Vec::new();
|
||||
let mut start_idx = 0;
|
||||
while start_idx < total_tokens {
|
||||
let mut end_idx = (start_idx + max_tokens).min(total_tokens);
|
||||
let mut decoded = false;
|
||||
// Try to decode the chunk, adjust end_idx if decoding fails
|
||||
while !decoded {
|
||||
let token_chunk = &token_ids[start_idx..end_idx];
|
||||
// Attempt to decode the current chunk
|
||||
match tokenizer.decode(token_chunk.to_vec()) {
|
||||
Ok(chunk_text) => {
|
||||
chunks.push(chunk_text);
|
||||
start_idx = end_idx;
|
||||
decoded = true;
|
||||
},
|
||||
Err(_) => {
|
||||
// If we can extend the chunk, do so
|
||||
if end_idx < total_tokens {
|
||||
end_idx += 1;
|
||||
} else if start_idx + 1 < total_tokens {
|
||||
// Skip the problematic token at start_idx
|
||||
start_idx += 1;
|
||||
end_idx = (start_idx + max_tokens).min(total_tokens);
|
||||
} else {
|
||||
// Cannot decode any further, break to avoid infinite loop
|
||||
start_idx = total_tokens;
|
||||
break;
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(chunks)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn split_text_by_max_content_len(
|
||||
content: String,
|
||||
max_content_len: usize,
|
||||
) -> Result<Vec<String>, AppError> {
|
||||
if content.is_empty() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
if content.len() <= max_content_len {
|
||||
return Ok(vec![content]);
|
||||
}
|
||||
|
||||
// Content is longer than max_content_len; need to split
|
||||
let mut result = Vec::with_capacity(1 + content.len() / max_content_len);
|
||||
let mut fragment = String::with_capacity(max_content_len);
|
||||
let mut current_len = 0;
|
||||
|
||||
for grapheme in content.graphemes(true) {
|
||||
let grapheme_len = grapheme.len();
|
||||
if current_len + grapheme_len > max_content_len {
|
||||
if !fragment.is_empty() {
|
||||
result.push(std::mem::take(&mut fragment));
|
||||
}
|
||||
current_len = 0;
|
||||
|
||||
if grapheme_len > max_content_len {
|
||||
// Push the grapheme as a fragment on its own
|
||||
result.push(grapheme.to_string());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
fragment.push_str(grapheme);
|
||||
current_len += grapheme_len;
|
||||
}
|
||||
|
||||
// Add the last fragment if it's not empty
|
||||
if !fragment.is_empty() {
|
||||
result.push(fragment);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {

  use crate::indexer::open_ai::{split_text_by_max_content_len, split_text_by_max_tokens};
  use tiktoken_rs::cl100k_base;

  /// Splits `text` into fragments of up to `limit` tokens, panicking if the splitter fails.
  fn by_tokens(bpe: &tiktoken_rs::CoreBPE, text: &str, limit: usize) -> Vec<String> {
    split_text_by_max_tokens(text.to_owned(), limit, bpe).unwrap()
  }

  /// Splits `text` into fragments of up to `limit` bytes, panicking if the splitter fails.
  fn by_len(text: &str, limit: usize) -> Vec<String> {
    split_text_by_max_content_len(text.to_owned(), limit).unwrap()
  }

  /// Number of fragments needed to hold `total` tokens at `per_fragment` tokens each.
  fn ceil_div(total: usize, per_fragment: usize) -> usize {
    (total + per_fragment - 1) / per_fragment
  }

  #[test]
  fn test_split_at_non_utf8() {
    // Multibyte characters (emojis) combined with a small budget.
    let limit = 10;
    let text = "Hello 😃 World 🌍! This is a test 🚀.";
    let bpe = cl100k_base().unwrap();

    for piece in by_tokens(&bpe, text, limit) {
      assert!(piece.is_char_boundary(0));
      assert!(piece.is_char_boundary(piece.len()));
    }

    for piece in by_len(text, limit) {
      assert!(piece.is_char_boundary(0));
      assert!(piece.is_char_boundary(piece.len()));
    }
  }

  #[test]
  fn test_exact_boundary_split() {
    let limit = 5;
    let text = "The quick brown fox jumps over the lazy dog";
    let bpe = cl100k_base().unwrap();

    let expected = ceil_div(bpe.encode_ordinary(text).len(), limit);
    assert_eq!(by_tokens(&bpe, text, limit).len(), expected);
  }

  #[test]
  fn test_content_shorter_than_max_len() {
    let text = "Short content";
    let bpe = cl100k_base().unwrap();
    let pieces = by_tokens(&bpe, text, 100);

    assert_eq!(pieces.len(), 1);
    assert_eq!(pieces[0], text);
  }

  #[test]
  fn test_empty_content() {
    let limit = 10;
    let bpe = cl100k_base().unwrap();

    assert!(by_tokens(&bpe, "", limit).is_empty());
    assert!(by_len("", limit).is_empty());
  }

  #[test]
  fn test_content_with_only_multibyte_characters() {
    let limit = 1;
    let text = "😀😃😄😁😆";
    let bpe = cl100k_base().unwrap();

    // Each produced fragment is compared with the emoji at the same position.
    by_tokens(&bpe, text, limit)
      .iter()
      .zip(text.chars())
      .for_each(|(piece, emoji)| assert_eq!(*piece, emoji.to_string()));

    by_len(text, limit)
      .iter()
      .zip(text.chars())
      .for_each(|(piece, emoji)| assert_eq!(*piece, emoji.to_string()));
  }

  #[test]
  fn test_split_with_combining_characters() {
    // "áéíóú" written with combining acute accents.
    let limit = 1;
    let text = "a\u{0301}e\u{0301}i\u{0301}o\u{0301}u\u{0301}";
    let bpe = cl100k_base().unwrap();

    let pieces = by_tokens(&bpe, text, limit);
    assert_eq!(pieces.len(), bpe.encode_ordinary(text).len());
    assert_eq!(pieces.concat(), text);

    assert_eq!(by_len(text, limit).concat(), text);
  }

  #[test]
  fn test_large_content() {
    let limit = 1000;
    let text = "a".repeat(5000); // 5000 characters
    let bpe = cl100k_base().unwrap();

    let expected = ceil_div(bpe.encode_ordinary(&text).len(), limit);
    assert_eq!(by_tokens(&bpe, &text, limit).len(), expected);
  }

  #[test]
  fn test_non_ascii_characters() {
    let limit = 2;
    let text = "áéíóú";
    let bpe = cl100k_base().unwrap();

    let pieces = by_tokens(&bpe, text, limit);
    assert_eq!(pieces.len(), ceil_div(bpe.encode_ordinary(text).len(), limit));
    assert_eq!(pieces.concat(), text);

    assert_eq!(by_len(text, limit).concat(), text);
  }

  #[test]
  fn test_content_with_leading_and_trailing_whitespace() {
    let limit = 3;
    let text = " abcde ";
    let bpe = cl100k_base().unwrap();

    let pieces = by_tokens(&bpe, text, limit);
    assert_eq!(pieces.len(), ceil_div(bpe.encode_ordinary(text).len(), limit));
    assert_eq!(pieces.concat(), text);

    assert_eq!(by_len(text, limit).concat(), text);
  }

  #[test]
  fn test_content_with_multiple_zero_width_joiners() {
    let limit = 1;
    let text = "👩👩👧👧👨👨👦👦";
    let bpe = cl100k_base().unwrap();

    assert_eq!(by_tokens(&bpe, text, limit).concat(), text);
    assert_eq!(by_len(text, limit).concat(), text);
  }

  #[test]
  fn test_content_with_long_combining_sequences() {
    let limit = 1;
    let text = "a\u{0300}\u{0301}\u{0302}\u{0303}\u{0304}";
    let bpe = cl100k_base().unwrap();

    assert_eq!(by_tokens(&bpe, text, limit).concat(), text);
    assert_eq!(by_len(text, limit).concat(), text);
  }
}
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod execution_time_comparison_tests {
|
||||
// use crate::indexer::open_ai::split_text_by_max_tokens;
|
||||
// use rand::distributions::Alphanumeric;
|
||||
// use rand::{thread_rng, Rng};
|
||||
// use std::sync::Arc;
|
||||
// use std::time::Instant;
|
||||
// use tiktoken_rs::{cl100k_base, CoreBPE};
|
||||
//
|
||||
// #[tokio::test]
|
||||
// async fn test_execution_time_comparison() {
|
||||
// let tokenizer = Arc::new(cl100k_base().unwrap());
|
||||
// let max_tokens = 100;
|
||||
//
|
||||
// let sizes = vec![500, 1000, 2000, 5000, 20000]; // Content sizes to test
|
||||
// for size in sizes {
|
||||
// let content = generate_random_string(size);
|
||||
//
|
||||
// // Measure direct execution time
|
||||
// let direct_time = measure_direct_execution(content.clone(), max_tokens, &tokenizer);
|
||||
//
|
||||
// // Measure spawn_blocking execution time
|
||||
// let spawn_blocking_time =
|
||||
// measure_spawn_blocking_execution(content, max_tokens, Arc::clone(&tokenizer)).await;
|
||||
//
|
||||
// println!(
|
||||
// "Content Size: {} | Direct Time: {}ms | spawn_blocking Time: {}ms",
|
||||
// size, direct_time, spawn_blocking_time
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // Measure direct execution time
|
||||
// fn measure_direct_execution(content: String, max_tokens: usize, tokenizer: &CoreBPE) -> u128 {
|
||||
// let start = Instant::now();
|
||||
// split_text_by_max_tokens(content, max_tokens, tokenizer).unwrap();
|
||||
// start.elapsed().as_millis()
|
||||
// }
|
||||
//
|
||||
// // Measure `spawn_blocking` execution time
|
||||
// async fn measure_spawn_blocking_execution(
|
||||
// content: String,
|
||||
// max_tokens: usize,
|
||||
// tokenizer: Arc<CoreBPE>,
|
||||
// ) -> u128 {
|
||||
// let start = Instant::now();
|
||||
// tokio::task::spawn_blocking(move || {
|
||||
// split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref()).unwrap()
|
||||
// })
|
||||
// .await
|
||||
// .unwrap();
|
||||
// start.elapsed().as_millis()
|
||||
// }
|
||||
//
|
||||
// pub fn generate_random_string(len: usize) -> String {
|
||||
// let rng = thread_rng();
|
||||
// rng
|
||||
// .sample_iter(&Alphanumeric)
|
||||
// .take(len)
|
||||
// .map(char::from)
|
||||
// .collect()
|
||||
// }
|
||||
// }
|
||||
|
|
@ -245,7 +245,6 @@ async fn process_upcoming_tasks(
|
|||
}
|
||||
}
|
||||
}
|
||||
info!("[Import] stop reading tasks from stream");
|
||||
}
|
||||
#[derive(Clone)]
|
||||
struct TaskContext {
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ The Five Dysfunctions of a Team by Patrick Lencioni The Five Dysfunctions of a T
|
|||
|
||||
let params = CalculateSimilarityParams {
|
||||
workspace_id,
|
||||
input: answer,
|
||||
input: answer.clone(),
|
||||
expected: r#"
|
||||
Kathryn Petersen is the newly appointed CEO of DecisionTech, a struggling Silicon Valley startup.
|
||||
She steps into a role facing a dysfunctional executive team characterized by poor communication,
|
||||
|
|
@ -131,7 +131,7 @@ The Five Dysfunctions of a Team by Patrick Lencioni The Five Dysfunctions of a T
|
|||
.unwrap()
|
||||
.score;
|
||||
|
||||
assert!(score > 0.9, "score: {}", score);
|
||||
assert!(score > 0.9, "score: {}, input:{}", score, answer);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue