chore: save collab embeddings on collab create

This commit is contained in:
Bartosz Sypytkowski 2024-06-25 12:53:10 +02:00
parent 8bf6aff923
commit 5d3574d643
3 changed files with 112 additions and 24 deletions

View File

@ -243,11 +243,18 @@ where
collab_type collab_type
); );
let indexer = self let mut indexer = self.indexer_provider.indexer_for(collab_type.clone());
.indexer_provider if indexer.is_some() {
.indexer_for(workspace_id, collab_type.clone()) if !self
.await .indexer_provider
.map_err(|err| RealtimeError::Internal(err.into()))?; .can_index_workspace(workspace_id)
.await
.map_err(|e| RealtimeError::Internal(e.into()))?
{
tracing::trace!("workspace {} indexing is disabled", workspace_id);
indexer = None;
}
}
let group = Arc::new( let group = Arc::new(
CollabGroup::new( CollabGroup::new(
user.uid, user.uid,

View File

@ -19,11 +19,27 @@ use appflowy_ai_client::client::AppFlowyAIClient;
use database::collab::select_blob_from_af_collab; use database::collab::select_blob_from_af_collab;
use database::index::{get_collabs_without_embeddings, upsert_collab_embeddings}; use database::index::{get_collabs_without_embeddings, upsert_collab_embeddings};
use database::workspace::select_workspace_settings; use database::workspace::select_workspace_settings;
use database_entity::dto::AFCollabEmbeddings; use database_entity::dto::{AFCollabEmbeddings, CollabParams};
#[async_trait] #[async_trait]
pub trait Indexer: Send + Sync { pub trait Indexer: Send + Sync {
async fn index(&self, collab: MutexCollab) -> Result<AFCollabEmbeddings, AppError>; async fn index(&self, collab: MutexCollab) -> Result<AFCollabEmbeddings, AppError>;
/// Indexes a collab that is only available in its encoded form.
///
/// A [Collab] identified by `object_id` is rebuilt from `encoded_collab.doc_state`
/// (passed as [DataSource::DocStateV1]) with [CollabOrigin::Empty], and then handed to
/// [Indexer::index]. Only the `doc_state` field of `encoded_collab` is read here.
///
/// # Errors
/// A failure to construct the [Collab] is wrapped in `AppError::Internal`; any error
/// returned by [Indexer::index] is passed through unchanged.
async fn index_encoded(
  &self,
  object_id: &str,
  encoded_collab: EncodedCollab,
) -> Result<AFCollabEmbeddings, AppError> {
  let source = DataSource::DocStateV1(encoded_collab.doc_state.into());
  // NOTE(review): the trailing `vec![]` / `false` arguments are presumably "no plugins" and
  // "do not skip GC" — confirm against `Collab::new_with_source`.
  let restored = match Collab::new_with_source(CollabOrigin::Empty, object_id, source, vec![], false)
  {
    Ok(collab) => collab,
    Err(err) => return Err(AppError::Internal(err.into())),
  };
  self.index(MutexCollab::new(restored)).await
}
} }
/// A structure responsible for resolving different [Indexer] types for different [CollabType]s, /// A structure responsible for resolving different [Indexer] types for different [CollabType]s,
@ -43,26 +59,22 @@ impl IndexerProvider {
}) })
} }
/// Returns indexer for a specific type of [Collab] object. pub async fn can_index_workspace(&self, workspace_id: &str) -> Result<bool, AppError> {
/// If collab of given type is not supported or workspace it belongs to has indexing disabled,
/// returns `None`.
pub async fn indexer_for(
&self,
workspace_id: &str,
collab_type: CollabType,
) -> Result<Option<Arc<dyn Indexer>>, AppError> {
let indexer = self.indexer_cache.get(&collab_type).cloned();
if indexer.is_none() {
return Ok(None);
}
let uuid = Uuid::parse_str(workspace_id)?; let uuid = Uuid::parse_str(workspace_id)?;
let settings = select_workspace_settings(&self.db, &uuid).await?; let settings = select_workspace_settings(&self.db, &uuid).await?;
match settings { match settings {
Some(settings) if settings.disable_search_indexing => Ok(None), None => Ok(true),
_ => Ok(indexer), Some(settings) => Ok(!settings.disable_search_indexing),
} }
} }
/// Returns the indexer registered for the given [CollabType], if any.
///
/// This is a plain lookup in the provider's indexer cache: `None` means that no indexer is
/// registered for `collab_type`. Workspace settings are NOT consulted here — whether a
/// workspace has search indexing disabled must be checked separately through
/// [IndexerProvider::can_index_workspace].
pub fn indexer_for(&self, collab_type: CollabType) -> Option<Arc<dyn Indexer>> {
  self.indexer_cache.get(&collab_type).cloned()
}
fn get_unindexed_collabs( fn get_unindexed_collabs(
&self, &self,
) -> Pin<Box<dyn Stream<Item = Result<UnindexedCollab, anyhow::Error>>>> { ) -> Pin<Box<dyn Stream<Item = Result<UnindexedCollab, anyhow::Error>>>> {
@ -140,6 +152,23 @@ impl IndexerProvider {
} }
Ok(()) Ok(())
} }
/// Computes embeddings for a collab that is about to be stored.
///
/// Looks up the indexer for `params.collab_type`. When none is registered, returns `Ok(None)`
/// without decoding the payload. Otherwise `params.encoded_collab_v1` is decoded into an
/// [EncodedCollab] and passed, together with `params.object_id`, to
/// [Indexer::index_encoded]; the resulting embeddings are returned as `Ok(Some(..))`.
///
/// This method does not check workspace settings; the visible callers gate it behind
/// [IndexerProvider::can_index_workspace].
///
/// # Errors
/// Propagates decoding failures of `encoded_collab_v1` and any error reported by the indexer.
pub async fn create_collab_embeddings(
  &self,
  params: &CollabParams,
) -> Result<Option<AFCollabEmbeddings>, AppError> {
  let indexer = match self.indexer_for(params.collab_type.clone()) {
    Some(indexer) => indexer,
    // Unsupported collab type: nothing to embed.
    None => return Ok(None),
  };
  let encoded_collab = EncodedCollab::decode_from_bytes(&params.encoded_collab_v1)?;
  indexer
    .index_encoded(&params.object_id, encoded_collab)
    .await
    .map(Some)
}
} }
struct UnindexedCollab { struct UnindexedCollab {

View File

@ -5,6 +5,7 @@ use actix_web::{web, Scope};
use actix_web::{HttpRequest, Result}; use actix_web::{HttpRequest, Result};
use anyhow::{anyhow, Context}; use anyhow::{anyhow, Context};
use bytes::BytesMut; use bytes::BytesMut;
use collab::entity::EncodedCollab;
use collab_entity::CollabType; use collab_entity::CollabType;
use prost::Message as ProstMessage; use prost::Message as ProstMessage;
use sqlx::types::uuid; use sqlx::types::uuid;
@ -486,7 +487,7 @@ async fn create_collab_handler(
}, },
}; };
let (params, workspace_id) = params.split(); let (mut params, workspace_id) = params.split();
if params.object_id == workspace_id { if params.object_id == workspace_id {
// Only the object with [CollabType::Folder] can have the same object_id as workspace_id. But // Only the object with [CollabType::Folder] can have the same object_id as workspace_id. But
@ -506,6 +507,17 @@ async fn create_collab_handler(
); );
} }
if state
.indexer_provider
.can_index_workspace(&workspace_id)
.await?
{
params.embeddings = state
.indexer_provider
.create_collab_embeddings(&params)
.await?;
}
let mut transaction = state let mut transaction = state
.pg_pool .pg_pool
.begin() .begin()
@ -597,7 +609,11 @@ async fn batch_create_collab_handler(
if collab_params_list.is_empty() { if collab_params_list.is_empty() {
return Err(AppError::InvalidRequest("Empty collab params list".to_string()).into()); return Err(AppError::InvalidRequest("Empty collab params list".to_string()).into());
} }
for params in collab_params_list { let can_index = state
.indexer_provider
.can_index_workspace(&workspace_id)
.await?;
for mut params in collab_params_list {
let object_id = params.object_id.clone(); let object_id = params.object_id.clone();
if validate_encode_collab( if validate_encode_collab(
&params.object_id, &params.object_id,
@ -607,6 +623,15 @@ async fn batch_create_collab_handler(
.await .await
.is_ok() .is_ok()
{ {
params.embeddings = if can_index {
state
.indexer_provider
.create_collab_embeddings(&params)
.await?
} else {
None
};
state state
.collab_access_control_storage .collab_access_control_storage
.insert_new_collab(&workspace_id, &uid, params) .insert_new_collab(&workspace_id, &uid, params)
@ -672,8 +697,18 @@ async fn create_collab_list_handler(
.await .await
.map_err(|err| AppError::Internal(anyhow!("Failed to start inserting collab: {}", err)))?; .map_err(|err| AppError::Internal(anyhow!("Failed to start inserting collab: {}", err)))?;
for params in valid_items { let can_index = state
.indexer_provider
.can_index_workspace(&workspace_id)
.await?;
for mut params in valid_items {
let _object_id = params.object_id.clone(); let _object_id = params.object_id.clone();
if can_index {
params.embeddings = state
.indexer_provider
.create_collab_embeddings(&params)
.await?;
}
state state
.collab_access_control_storage .collab_access_control_storage
.insert_new_collab_with_transaction(&workspace_id, &uid, params, &mut transaction) .insert_new_collab_with_transaction(&workspace_id, &uid, params, &mut transaction)
@ -851,7 +886,24 @@ async fn update_collab_handler(
let uid = state.user_cache.get_user_uid(&user_uuid).await?; let uid = state.user_cache.get_user_uid(&user_uuid).await?;
let create_params = CreateCollabParams::from((workspace_id.to_string(), params)); let create_params = CreateCollabParams::from((workspace_id.to_string(), params));
let (params, workspace_id) = create_params.split(); let (mut params, workspace_id) = create_params.split();
if let Some(indexer) = state
.indexer_provider
.indexer_for(params.collab_type.clone())
{
if state
.indexer_provider
.can_index_workspace(&workspace_id)
.await?
{
let encoded_collab = EncodedCollab::decode_from_bytes(&params.encoded_collab_v1)?;
params.embeddings = Some(
indexer
.index_encoded(&params.object_id, encoded_collab)
.await?,
);
}
}
state state
.collab_access_control_storage .collab_access_control_storage
.insert_or_update_collab(&workspace_id, &uid, params, false) .insert_or_update_collab(&workspace_id, &uid, params, false)