100 lines
2.9 KiB
Rust
100 lines
2.9 KiB
Rust
use crate::api::metrics::RequestMetrics;
|
|
use app_error::ErrorCode;
|
|
use database::index::{search_documents, SearchDocumentParams};
|
|
use openai_dive::v1::models::EmbeddingsEngine;
|
|
use openai_dive::v1::resources::embedding::{
|
|
EmbeddingEncodingFormat, EmbeddingInput, EmbeddingOutput, EmbeddingParameters,
|
|
};
|
|
use shared_entity::dto::search_dto::{
|
|
SearchContentType, SearchDocumentRequest, SearchDocumentResponseItem,
|
|
};
|
|
use shared_entity::response::AppResponseError;
|
|
use sqlx::PgPool;
|
|
use uuid::Uuid;
|
|
|
|
pub async fn search_document(
|
|
pg_pool: &PgPool,
|
|
openai: &openai_dive::v1::api::Client,
|
|
uid: i64,
|
|
workspace_id: Uuid,
|
|
request: SearchDocumentRequest,
|
|
metrics: &RequestMetrics,
|
|
) -> Result<Vec<SearchDocumentResponseItem>, AppResponseError> {
|
|
let embeddings = openai
|
|
.embeddings()
|
|
.create(EmbeddingParameters {
|
|
input: EmbeddingInput::String(request.query.clone()),
|
|
model: EmbeddingsEngine::TextEmbedding3Small.to_string(),
|
|
encoding_format: Some(EmbeddingEncodingFormat::Float),
|
|
dimensions: Some(1536), // text-embedding-3-small default number of dimensions
|
|
user: None,
|
|
})
|
|
.await
|
|
.map_err(|e| AppResponseError::new(ErrorCode::Internal, e.to_string()))?;
|
|
|
|
let tokens_used = if let Some(usage) = embeddings.usage {
|
|
metrics.record_search_tokens_used(usage.total_tokens);
|
|
tracing::info!(
|
|
"workspace {} OpenAI API search tokens used: {}",
|
|
workspace_id,
|
|
usage.total_tokens
|
|
);
|
|
usage.total_tokens
|
|
} else {
|
|
0
|
|
};
|
|
|
|
let embedding = embeddings
|
|
.data
|
|
.first()
|
|
.ok_or_else(|| AppResponseError::new(ErrorCode::Internal, "OpenAI returned no embeddings"))?;
|
|
let embedding = match &embedding.embedding {
|
|
EmbeddingOutput::Float(vector) => vector.iter().map(|&v| v as f32).collect(),
|
|
EmbeddingOutput::Base64(_) => {
|
|
return Err(AppResponseError::new(
|
|
ErrorCode::Internal,
|
|
"OpenAI returned embeddings in unsupported format",
|
|
))
|
|
},
|
|
};
|
|
|
|
let mut tx = pg_pool
|
|
.begin()
|
|
.await
|
|
.map_err(|e| AppResponseError::new(ErrorCode::Internal, e.to_string()))?;
|
|
let results = search_documents(
|
|
&mut tx,
|
|
SearchDocumentParams {
|
|
user_id: uid,
|
|
workspace_id,
|
|
limit: request.limit.unwrap_or(10) as i32,
|
|
preview: request.preview_size.unwrap_or(180) as i32,
|
|
embedding,
|
|
},
|
|
tokens_used,
|
|
)
|
|
.await?;
|
|
tx.commit().await?;
|
|
tracing::trace!(
|
|
"user {} search request in workspace {} returned {} results for query: `{}`",
|
|
uid,
|
|
workspace_id,
|
|
results.len(),
|
|
request.query
|
|
);
|
|
Ok(
|
|
results
|
|
.into_iter()
|
|
.map(|item| SearchDocumentResponseItem {
|
|
object_id: item.object_id,
|
|
workspace_id: item.workspace_id.to_string(),
|
|
score: item.score,
|
|
content_type: SearchContentType::from_record(item.content_type),
|
|
preview: item.content_preview,
|
|
created_by: item.created_by,
|
|
created_at: item.created_at,
|
|
})
|
|
.collect(),
|
|
)
|
|
}
|