From 0ca22f771705b051212f2eb20826cb861ca332ab Mon Sep 17 00:00:00 2001 From: Khor Shu Heng <32997938+khorshuheng@users.noreply.github.com> Date: Wed, 8 Jan 2025 21:49:06 +0800 Subject: [PATCH] fix: obtain unsearchable views before searching for embeddings (#1140) --- libs/database/src/index/search_ops.rs | 18 +++--- src/biz/collab/folder_view.rs | 14 ++++- src/biz/search/ops.rs | 86 +++++++++++++-------------- 3 files changed, 65 insertions(+), 53 deletions(-) diff --git a/libs/database/src/index/search_ops.rs b/libs/database/src/index/search_ops.rs index 589ec50d..97eedc1a 100644 --- a/libs/database/src/index/search_ops.rs +++ b/libs/database/src/index/search_ops.rs @@ -1,8 +1,6 @@ -use std::ops::DerefMut; - use chrono::{DateTime, Utc}; use pgvector::Vector; -use sqlx::Transaction; +use sqlx::{Executor, Postgres}; use uuid::Uuid; /// Logs each search request to track usage by workspace. It either inserts a new record or updates @@ -12,8 +10,8 @@ use uuid::Uuid; /// Searches and retrieves documents based on their similarity to a given search embedding. /// It filters by workspace, user access, and document status, and returns a limited number /// of the most relevant documents, sorted by similarity score. -pub async fn search_documents( - tx: &mut Transaction<'_, sqlx::Postgres>, +pub async fn search_documents<'a, E: Executor<'a, Database = Postgres>>( + executor: E, params: SearchDocumentParams, tokens_used: u32, ) -> Result, sqlx::Error> { @@ -38,9 +36,8 @@ pub async fn search_documents( em.embedding <=> $3 AS score FROM af_collab_embeddings em JOIN af_collab collab ON em.oid = collab.oid AND em.partition_key = collab.partition_key - JOIN af_workspace_member member ON collab.workspace_id = member.workspace_id JOIN af_user u ON collab.owner_uid = u.uid - WHERE member.uid = $1 AND collab.workspace_id = $2 AND collab.deleted_at IS NULL + WHERE collab.workspace_id = $2 AND NOT(collab.oid = ANY($7::text[])) ORDER BY em.embedding <=> $3 LIMIT $5 "#, @@ -50,8 +47,9 @@ pub async fn search_documents( .bind(Vector::from(params.embedding)) .bind(params.preview) .bind(params.limit) - .bind(tokens_used as i64); - let rows = query.fetch_all(tx.deref_mut()).await?; + .bind(tokens_used as i64) + .bind(params.non_viewable_view_ids); + let rows = query.fetch_all(executor).await?; Ok(rows) } @@ -67,6 +65,8 @@ pub struct SearchDocumentParams { pub preview: i32, /// Embedding of the query - generated by OpenAI embedder. pub embedding: Vec, + /// List of view ids which is not supposed to be returned in the search results. + pub non_viewable_view_ids: Vec, } #[derive(Debug, Clone, sqlx::FromRow)] diff --git a/src/biz/collab/folder_view.rs b/src/biz/collab/folder_view.rs index 73252ff3..ce9158cd 100644 --- a/src/biz/collab/folder_view.rs +++ b/src/biz/collab/folder_view.rs @@ -25,11 +25,23 @@ pub fn private_and_nonviewable_view_ids(folder: &Folder) -> PrivateAndNonviewabl if check_if_view_is_space(&private_view) && !my_private_view_ids.contains(&private_section.id) { nonviewable_view_ids.insert(private_section.id); + let private_view_ids_in_space: HashSet = folder + .get_views_belong_to(&private_view.id) + .iter() + .map(|v| v.id.clone()) + .collect(); + nonviewable_view_ids.extend(private_view_ids_in_space); } } } for trash_view in folder.get_all_trash_sections() { - nonviewable_view_ids.insert(trash_view.id); + nonviewable_view_ids.insert(trash_view.id.clone()); + let child_views_for_trash: HashSet = folder + .get_views_belong_to(&trash_view.id) + .iter() + .map(|v| v.id.clone()) + .collect(); + nonviewable_view_ids.extend(child_views_for_trash); } PrivateAndNonviewableViews { my_private_view_ids, diff --git a/src/biz/search/ops.rs b/src/biz/search/ops.rs index 0923f394..2fbe456a 100644 --- a/src/biz/search/ops.rs +++ b/src/biz/search/ops.rs @@ -1,21 +1,20 @@ use crate::api::metrics::RequestMetrics; -use crate::biz::collab::folder_view::{ - check_if_view_ancestors_fulfil_condition, private_and_nonviewable_view_ids, -}; +use crate::biz::collab::folder_view::private_and_nonviewable_view_ids; use crate::biz::collab::utils::get_latest_collab_folder; -use app_error::ErrorCode; +use app_error::AppError; use appflowy_ai_client::dto::{ EmbeddingEncodingFormat, EmbeddingInput, EmbeddingModel, EmbeddingOutput, EmbeddingRequest, }; use appflowy_collaborate::collab::storage::CollabAccessControlStorage; use database::collab::GetCollabOrigin; +use itertools::Itertools; +use std::collections::HashSet; use std::sync::Arc; use database::index::{search_documents, SearchDocumentParams}; use shared_entity::dto::search_dto::{ SearchContentType, SearchDocumentRequest, SearchDocumentResponseItem, }; -use shared_entity::response::AppResponseError; use sqlx::PgPool; use indexer::scheduler::IndexerScheduler; @@ -29,7 +28,7 @@ pub async fn search_document( workspace_id: Uuid, request: SearchDocumentRequest, metrics: &RequestMetrics, -) -> Result, AppResponseError> { +) -> Result, AppError> { let embeddings = indexer_scheduler .create_search_embeddings(EmbeddingRequest { input: EmbeddingInput::String(request.query.clone()), @@ -49,42 +48,16 @@ pub async fn search_document( let embedding = embeddings .data .first() - .ok_or_else(|| AppResponseError::new(ErrorCode::Internal, "OpenAI returned no embeddings"))?; + .ok_or_else(|| AppError::Internal(anyhow::anyhow!("OpenAI returned no embeddings")))?; let embedding = match &embedding.embedding { EmbeddingOutput::Float(vector) => vector.iter().map(|&v| v as f32).collect(), EmbeddingOutput::Base64(_) => { - return Err(AppResponseError::new( - ErrorCode::Internal, - "OpenAI returned embeddings in unsupported format", - )) + return Err(AppError::Internal(anyhow::anyhow!( + "OpenAI returned embeddings in unsupported format" + ))) }, }; - let mut tx = pg_pool - .begin() - .await - .map_err(|e| AppResponseError::new(ErrorCode::Internal, e.to_string()))?; - let results = search_documents( - &mut tx, - SearchDocumentParams { - user_id: uid, - workspace_id, - limit: request.limit.unwrap_or(10) as i32, - preview: request.preview_size.unwrap_or(500) as i32, - embedding, - }, - total_tokens, - ) - .await?; - tx.commit().await?; - tracing::trace!( - "user {} search request in workspace {} returned {} results for query: `{}`", - uid, - workspace_id, - results.len(), - request.query - ); - let folder = get_latest_collab_folder( collab_storage, GetCollabOrigin::User { uid }, @@ -92,15 +65,42 @@ pub async fn search_document( ) .await?; let private_and_nonviewable_views = private_and_nonviewable_view_ids(&folder); - let non_searchable_view_ids = private_and_nonviewable_views.nonviewable_view_ids; - let filtered_results = results.into_iter().filter(|item| { - !check_if_view_ancestors_fulfil_condition(&item.object_id, &folder, |view| { - non_searchable_view_ids.contains(&view.id) - }) - }); + let space_ids: HashSet = folder + .get_view(&workspace_id.to_string()) + .ok_or_else(|| AppError::Internal(anyhow::anyhow!("Workspace view not found in folder")))? + .children + .iter() + .map(|c| c.id.clone()) + .collect(); + + let mut non_searchable_view_ids = private_and_nonviewable_views.nonviewable_view_ids; + non_searchable_view_ids.extend(space_ids); + let results = search_documents( + pg_pool, + SearchDocumentParams { + user_id: uid, + workspace_id, + limit: request.limit.unwrap_or(10) as i32, + preview: request.preview_size.unwrap_or(500) as i32, + embedding, + non_viewable_view_ids: non_searchable_view_ids + .iter() + .map(|uuid| uuid.to_string()) + .collect_vec(), + }, + total_tokens, + ) + .await?; + tracing::trace!( + "user {} search request in workspace {} returned {} results for query: `{}`", + uid, + workspace_id, + results.len(), + request.query + ); Ok( - filtered_results + results .into_iter() .map(|item| SearchDocumentResponseItem { object_id: item.object_id,