From 8beac5c85f94217b516b1671e5b74b69d28545b6 Mon Sep 17 00:00:00 2001 From: Bartosz Sypytkowski Date: Tue, 27 Aug 2024 18:13:41 +0200 Subject: [PATCH] chore: reuse group collab for indexing (#737) * chore: reuse group collab for indexing * chore: fix linting errors * chore: post rebase fixes --- .../src/group/persistence.rs | 38 ++++++++--- .../src/indexer/document_indexer.rs | 66 ++++++++----------- .../src/indexer/provider.rs | 43 ++++++++---- src/api/workspace.rs | 13 ++-- 4 files changed, 94 insertions(+), 66 deletions(-) diff --git a/services/appflowy-collaborate/src/group/persistence.rs b/services/appflowy-collaborate/src/group/persistence.rs index bc94224e..8fd185b8 100644 --- a/services/appflowy-collaborate/src/group/persistence.rs +++ b/services/appflowy-collaborate/src/group/persistence.rs @@ -122,18 +122,36 @@ where None => return Err(AppError::Internal(anyhow!("collab has been dropped"))), }; - let mut params = { + let params = { let lock = collab.read().await; - get_encode_collab(&workspace_id, &object_id, &lock, &collab_type)? - }; - if let Some(indexer) = &self.indexer { - if let Ok(embeddings) = indexer - .index(&object_id, params.encoded_collab_v1.clone()) - .await - { - params.embeddings = embeddings; + let mut params = get_encode_collab(&workspace_id, &object_id, &lock, &collab_type)?; + + if let Some(indexer) = &self.indexer { + match indexer.embedding_params(&lock) { + Ok(embedding_params) => { + drop(lock); // we no longer need the lock + match indexer.embeddings(embedding_params).await { + Ok(embeddings) => { + params.embeddings = embeddings; + }, + Err(err) => { + warn!( + "failed to index embeddings from remote service for document {}/{}: {}", + workspace_id, object_id, err + ); + }, + } + }, + Err(err) => { + warn!( + "failed to get embedding params for document {}/{}: {}", + workspace_id, object_id, err + ); + }, + } } - } + params + }; self .storage diff --git a/services/appflowy-collaborate/src/indexer/document_indexer.rs b/services/appflowy-collaborate/src/indexer/document_indexer.rs index 0004ab85..268884c9 100644 --- a/services/appflowy-collaborate/src/indexer/document_indexer.rs +++ b/services/appflowy-collaborate/src/indexer/document_indexer.rs @@ -1,11 +1,9 @@ use std::sync::Arc; +use anyhow::anyhow; use async_trait::async_trait; -use collab::core::collab::DataSource; -use collab::core::origin::CollabOrigin; use collab::preclude::Collab; -use collab_document::document::Document; -use collab_document::error::DocumentError; +use collab_document::document::DocumentBody; use collab_entity::CollabType; use app_error::AppError; @@ -25,12 +23,27 @@ impl DocumentIndexer { pub fn new(ai_client: AppFlowyAIClient) -> Arc { Arc::new(Self { ai_client }) } +} - fn get_document_contents( - document: &Document, - ) -> Result, DocumentError> { - let object_id = document.object_id().to_string(); - let document_data = document.get_document_data()?; +#[async_trait] +impl Indexer for DocumentIndexer { + fn embedding_params(&self, collab: &Collab) -> Result, AppError> { + let object_id = collab.object_id().to_string(); + let document = DocumentBody::from_collab(collab).ok_or_else(|| { + anyhow!( + "Failed to get document body from collab `{}`: schema is missing required fields", + object_id + ) + })?; + let document_data = document + .get_document_data(&collab.transact()) + .map_err(|err| { + anyhow!( + "Failed to get document data from collab `{}`: {}", + object_id, + err + ) + })?; let content = document_data.to_plain_text(); let plain_text_param = AFCollabEmbeddingParams { @@ -44,40 +57,15 @@ impl DocumentIndexer { Ok(vec![plain_text_param]) } -} -#[async_trait] -impl Indexer for DocumentIndexer { - async fn index( + async fn embeddings( &self, - object_id: &str, - doc_state: Vec, + mut params: Vec, ) -> Result, AppError> { - let cloned_object_id = object_id.to_string(); - let collab = tokio::spawn(async move { - Collab::new_with_source( - CollabOrigin::Server, - &cloned_object_id, - DataSource::DocStateV1(doc_state), - vec![], - false, - ) - .map_err(|e| AppError::Internal(e.into())) - }) - .await - .map_err(|e| AppError::Internal(e.into()))??; - - let document = Document::open(collab).map_err(|e| AppError::Internal(e.into()))?; - let mut params = match Self::get_document_contents(&document) { - Ok(result) => result, - Err(err) => { - if cfg!(debug_assertions) { - tracing::warn!("failed to get document:{} error:{}", object_id, err); - } - return Ok(None); - }, + let object_id = match params.first() { + None => return Ok(None), + Some(first) => first.object_id.clone(), }; - let contents: Vec<_> = params .iter() .map(|fragment| fragment.content.clone()) diff --git a/services/appflowy-collaborate/src/indexer/provider.rs b/services/appflowy-collaborate/src/indexer/provider.rs index 4288720a..4f02ca1e 100644 --- a/services/appflowy-collaborate/src/indexer/provider.rs +++ b/services/appflowy-collaborate/src/indexer/provider.rs @@ -5,7 +5,10 @@ use std::sync::Arc; use actix::dev::Stream; use async_stream::try_stream; use async_trait::async_trait; +use collab::core::collab::DataSource; +use collab::core::origin::CollabOrigin; use collab::entity::EncodedCollab; +use collab::preclude::Collab; use collab_entity::CollabType; use sqlx::PgPool; use tokio_stream::StreamExt; @@ -16,17 +19,35 @@ use appflowy_ai_client::client::AppFlowyAIClient; use database::collab::select_blob_from_af_collab; use database::index::{get_collabs_without_embeddings, upsert_collab_embeddings}; use database::workspace::select_workspace_settings; -use database_entity::dto::{AFCollabEmbeddings, CollabParams}; +use database_entity::dto::{AFCollabEmbeddingParams, AFCollabEmbeddings, CollabParams}; use crate::indexer::DocumentIndexer; #[async_trait] pub trait Indexer: Send + Sync { + fn embedding_params(&self, collab: &Collab) -> Result, AppError>; + + async fn embeddings( + &self, + params: Vec, + ) -> Result, AppError>; + async fn index( &self, object_id: &str, - doc_state: Vec, - ) -> Result, AppError>; + encoded_collab: EncodedCollab, + ) -> Result, AppError> { + let collab = Collab::new_with_source( + CollabOrigin::Empty, + object_id, + DataSource::DocStateV1(encoded_collab.doc_state.into()), + vec![], + false, + ) + .map_err(|err| AppError::Internal(err.into()))?; + let embedding_params = self.embedding_params(&collab)?; + self.embeddings(embedding_params).await + } } /// A structure responsible for resolving different [Indexer] types for different [CollabType]s, @@ -119,14 +140,15 @@ impl IndexerProvider { async fn index_collab(&self, unindexed: UnindexedCollab) -> Result<(), AppError> { if let Some(indexer) = self.indexer_cache.get(&unindexed.collab_type) { - if let Some(embeddings) = indexer - .index(&unindexed.object_id, unindexed.collab.doc_state.into()) - .await? - { + let workspace_id = unindexed.workspace_id; + let embeddings = indexer + .index(&unindexed.object_id, unindexed.collab) + .await?; + if let Some(embeddings) = embeddings { let mut tx = self.db.begin().await?; upsert_collab_embeddings( &mut tx, - &unindexed.workspace_id, + &workspace_id, embeddings.tokens_consumed, &embeddings.params, ) @@ -142,9 +164,8 @@ impl IndexerProvider { params: &CollabParams, ) -> Result, AppError> { if let Some(indexer) = self.indexer_for(params.collab_type.clone()) { - let embeddings = indexer - .index(¶ms.object_id, params.encoded_collab_v1.clone()) - .await?; + let encoded_collab = EncodedCollab::decode_from_bytes(¶ms.encoded_collab_v1)?; + let embeddings = indexer.index(¶ms.object_id, encoded_collab).await?; Ok(embeddings) } else { Ok(None) diff --git a/src/api/workspace.rs b/src/api/workspace.rs index 6e2dec72..2c1c628d 100644 --- a/src/api/workspace.rs +++ b/src/api/workspace.rs @@ -920,12 +920,13 @@ async fn update_collab_handler( .can_index_workspace(&workspace_id) .await? { - let encoded = EncodedCollab::decode_from_bytes(¶ms.encoded_collab_v1) - .map_err(|e| AppError::Internal(e.into()))?; - match indexer - .index(¶ms.object_id, encoded.doc_state.into()) - .await - { + let encoded = EncodedCollab::decode_from_bytes(¶ms.encoded_collab_v1).map_err(|err| { + AppError::InvalidRequest(format!( + "Failed to decode collab `{}`: {}", + params.object_id, err + )) + })?; + match indexer.index(¶ms.object_id, encoded).await { Ok(embeddings) => params.embeddings = embeddings, Err(err) => tracing::warn!( "failed to fetch embeddings for document {}: {}",