AppFlowy-Cloud/services/appflowy-indexer/src/collab_handle.rs

377 lines
12 KiB
Rust

use std::collections::{HashMap, HashSet};
use std::pin::Pin;
use std::sync::{Arc, Weak};
use std::time::{Duration, Instant};
use collab::core::collab::TransactionMutExt;
use collab::core::collab::{DataSource, MutexCollab};
use collab::core::origin::CollabOrigin;
use collab::preclude::updates::decoder::Decode;
use collab::preclude::Update;
use collab_document::document::Document;
use collab_entity::CollabType;
use futures::{Stream, StreamExt};
use tokio::select;
use tokio::task::JoinSet;
use tokio::time::interval;
use tokio_util::sync::CancellationToken;
use tracing::instrument;
use uuid::Uuid;
use collab_stream::client::CollabRedisStream;
use collab_stream::model::{CollabUpdateEvent, StreamMessage};
use collab_stream::stream_group::{ReadOption, StreamGroup};
use crate::error::Result;
use crate::indexer::{Fragment, FragmentID, IndexStatus, Indexer};
use crate::watchers::DocumentWatcher;
const CONSUMER_NAME: &str = "open_collab_handle";
pub trait Indexable: Send + Sync {
fn get_collab(&self) -> &MutexCollab;
fn changes(&self) -> Pin<Box<dyn Stream<Item = FragmentUpdate> + Send + Sync>>;
}
#[allow(dead_code)]
pub struct CollabHandle {
content: Arc<dyn Indexable>,
tasks: JoinSet<()>,
closing: CancellationToken,
}
impl CollabHandle {
pub(crate) async fn open(
redis_stream: &CollabRedisStream,
indexer: Arc<dyn Indexer>,
object_id: String,
workspace_id: String,
collab_type: CollabType,
doc_state: Vec<u8>,
ingest_interval: Duration,
) -> Result<Option<Self>> {
let closing = CancellationToken::new();
let was_indexed = match indexer.index_status(&object_id).await? {
IndexStatus::Indexed => true,
IndexStatus::NotIndexed => false,
IndexStatus::NotPermitted => {
tracing::trace!(
"document {}/{} is not permitted to be indexed",
workspace_id,
object_id
);
return Ok(None);
},
};
let content: Arc<dyn Indexable> = match collab_type {
CollabType::Document => {
let content = Document::from_doc_state(
CollabOrigin::Empty,
DataSource::DocStateV1(doc_state),
&object_id,
vec![],
)?;
let watcher = DocumentWatcher::new(object_id.clone(), content, !was_indexed)?;
Arc::new(watcher)
},
_ => return Ok(None),
};
let group_name = format!("indexer_{}:{}", workspace_id, object_id);
let mut update_stream = redis_stream
.collab_update_stream(&workspace_id, &object_id, &group_name)
.await
.unwrap();
let messages = update_stream.get_unacked_messages(CONSUMER_NAME).await?;
if !messages.is_empty() {
Self::handle_collab_updates(&mut update_stream, content.get_collab(), messages).await?;
}
let workspace_id =
Uuid::parse_str(&workspace_id).map_err(crate::error::Error::InvalidWorkspace)?;
let mut tasks = JoinSet::new();
tasks.spawn(Self::receive_collab_updates(
update_stream,
Arc::downgrade(&content),
object_id.clone(),
workspace_id,
ingest_interval,
closing.clone(),
));
tasks.spawn(Self::process_content_changes(
content.changes(),
indexer,
object_id,
workspace_id,
ingest_interval,
closing.clone(),
));
Ok(Some(Self {
content,
tasks,
closing,
}))
}
/// In regular time intervals, receive yrs updates and apply them to the locall in-memory collab
/// representation. This should emit index content events, which we listen to on
/// [Self::receive_index_events].
async fn receive_collab_updates(
mut update_stream: StreamGroup,
content: Weak<dyn Indexable>,
object_id: String,
workspace_id: Uuid,
ingest_interval: Duration,
closing: CancellationToken,
) {
let mut interval = interval(ingest_interval);
loop {
select! {
_ = closing.cancelled() => {
tracing::trace!("document {}/{} watcher cancelled, stopping.", workspace_id, object_id);
return;
},
_ = interval.tick() => {
let result = update_stream
.consumer_messages(CONSUMER_NAME, ReadOption::Count(100))
.await;
match result {
Ok(messages) => {
if let Some(content) = content.upgrade() {
// check if we received empty message batch, if not: update the collab
if !messages.is_empty() {
if let Err(err) = Self::handle_collab_updates(&mut update_stream, content.get_collab(), messages).await {
tracing::error!("document {}/{} watcher failed to handle updates: {}", workspace_id, object_id, err);
}
}
} else {
tracing::trace!("collab dropped, stopping consumer");
return;
}
},
Err(err) => {
tracing::error!("document {}/{} watcher failed to receive messages: {}", workspace_id, object_id, err);
},
}
}
}
}
}
#[instrument(skip(update_stream, collab, messages), fields(messages = messages.len()))]
async fn handle_collab_updates(
update_stream: &mut StreamGroup,
collab: &MutexCollab,
messages: Vec<StreamMessage>,
) -> Result<()> {
if let Some(collab) = collab.try_lock() {
let mut txn = collab.try_transaction_mut()?;
for message in &messages {
match CollabUpdateEvent::decode(&message.data) {
Ok(CollabUpdateEvent::UpdateV1 { encode_update }) => {
let update = Update::decode_v1(&encode_update)?;
txn.try_apply_update(update)?;
},
Err(err) => tracing::error!("failed to decode update event: {}", err),
}
}
} else {
tracing::warn!("failed to obtain a collab lock");
};
update_stream.ack_messages(&messages).await?;
Ok(())
}
async fn process_content_changes(
mut updates: Pin<Box<dyn Stream<Item = FragmentUpdate> + Send + Sync>>,
indexer: Arc<dyn Indexer>,
object_id: String,
workspace_id: Uuid,
ingest_interval: Duration,
token: CancellationToken,
) {
let mut last_update = Instant::now();
let mut inserts = HashMap::new();
let mut removals = HashSet::new();
let mut interval = interval(ingest_interval);
loop {
select! {
_ = interval.tick() => {
match Self::publish_updates(&indexer, &workspace_id, &mut inserts, &mut removals).await {
Ok(_) => last_update = Instant::now(),
Err(err) => tracing::error!("document {}/{} watcher failed to publish fragment updates: {}", workspace_id, object_id, err),
}
}
_ = token.cancelled() => {
tracing::trace!("document {}/{} watcher closing signal received, flushing remaining updates", workspace_id, object_id);
if let Err(err) = Self::publish_updates(&indexer, &workspace_id, &mut inserts, &mut removals).await {
tracing::error!("document {}/{} watcher failed to publish fragment updates: {}", workspace_id, object_id, err);
}
return;
},
Some(update) = updates.next() => {
match update {
FragmentUpdate::Update(doc) => {
if doc.content.is_empty() {
// we count empty blocks as removals
removals.insert(doc.fragment_id.clone());
} else {
inserts.insert(doc.fragment_id.clone(), doc);
}
}
FragmentUpdate::Removed(id) => {
removals.insert(id);
}
}
let now = Instant::now();
if now.duration_since(last_update) > ingest_interval {
match Self::publish_updates(&indexer, &workspace_id, &mut inserts, &mut removals).await {
Ok(_) => last_update = now,
Err(err) => tracing::error!("document {}/{} watcher failed to publish fragment updates: {}", workspace_id, object_id, err),
}
}
}
}
}
}
async fn publish_updates(
indexer: &Arc<dyn Indexer>,
workspace_id: &Uuid,
inserts: &mut HashMap<FragmentID, Fragment>,
removals: &mut HashSet<FragmentID>,
) -> Result<()> {
if inserts.is_empty() && removals.is_empty() {
return Ok(());
}
let inserts: Vec<_> = inserts.drain().map(|(_, doc)| doc).collect();
if !inserts.is_empty() {
tracing::info!("updating indexes for {} fragments", inserts.len());
indexer.update_index(workspace_id, inserts).await?;
}
if !removals.is_empty() {
tracing::info!("removing indexes for {} fragments", removals.len());
indexer
.remove(&removals.drain().collect::<Vec<_>>())
.await?;
}
Ok(())
}
pub async fn shutdown(mut self) {
self.closing.cancel();
while self.tasks.join_next().await.is_some() { /* wait for all tasks to finish */ }
}
}
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum FragmentUpdate {
Update(Fragment),
Removed(FragmentID),
}
#[cfg(test)]
mod test {
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use collab::core::collab::MutexCollab;
use collab::preclude::Collab;
use collab_document::document::Document;
use collab_entity::CollabType;
use sqlx::{Postgres, Row};
use workspace_template::document::get_started::get_started_document_data;
use crate::collab_handle::CollabHandle;
use crate::indexer::{Indexer, PostgresIndexer};
use crate::test_utils::{
ai_client, collab_update_forwarder, db_pool, redis_stream, setup_collab,
};
#[tokio::test]
async fn test_indexing_for_new_initialized_document() {
let _ = env_logger::builder().is_test(true).try_init();
let redis_stream = redis_stream().await;
let uid = rand::random();
let object_id = uuid::Uuid::new_v4();
let mut collab = Collab::new(
uid,
object_id.to_string(),
"device-1".to_string(),
vec![],
false,
);
collab.initialize();
let collab = Arc::new(MutexCollab::new(collab));
let encoded_collab = {
let doc_data = get_started_document_data().unwrap();
let document = Document::create_with_data(collab.clone(), doc_data).unwrap();
document.encode_collab().unwrap()
};
let db = db_pool().await;
let workspace_id = setup_collab(&db, uid, object_id, &encoded_collab).await;
let object_id = object_id.to_string();
let openai = ai_client();
let indexer: Arc<dyn Indexer> = Arc::new(PostgresIndexer::new(openai, db));
let stream_group = redis_stream
.collab_update_stream(&workspace_id.to_string(), &object_id, "indexer")
.await
.unwrap();
let _s = collab_update_forwarder(collab, stream_group.clone());
let _handle = CollabHandle::open(
&redis_stream,
indexer.clone(),
object_id.clone(),
workspace_id.to_string(),
CollabType::Document,
encoded_collab.doc_state.to_vec(),
Duration::from_millis(50),
)
.await
.unwrap();
tokio::time::sleep(Duration::from_millis(2000)).await;
let db = db_pool().await;
let contents = sqlx::query("SELECT content from af_collab_embeddings WHERE oid = $1")
.bind(&object_id)
.fetch_all(&db)
.await
.unwrap();
let contents = contents
.into_iter()
.map(|r| r.get::<String, _>("content"))
.collect::<HashSet<_>>();
assert_eq!(contents.len(), 1);
let tokens: i64 = sqlx::query_scalar::<Postgres, Option<i64>>(
"SELECT index_tokens_consumed from af_workspace_ai_usage WHERE workspace_id = $1",
)
.bind(workspace_id)
.fetch_one(&db)
.await
.unwrap()
.unwrap_or(0);
assert_ne!(tokens, 0);
}
}