475 lines
13 KiB
Rust
475 lines
13 KiB
Rust
use crate::collab_handle::CollabHandle;
|
|
use crate::error::Result;
|
|
use crate::indexer::{Fragment, Indexer, UnindexedCollab};
|
|
use collab::core::collab::DataSource;
|
|
use collab::core::origin::CollabOrigin;
|
|
use collab_document::document::Document;
|
|
use collab_entity::CollabType;
|
|
use collab_stream::client::CollabRedisStream;
|
|
use collab_stream::model::{CollabControlEvent, StreamMessage};
|
|
use collab_stream::stream_group::{ReadOption, StreamGroup};
|
|
use dashmap::mapref::entry::Entry;
|
|
use dashmap::DashMap;
|
|
use database_entity::dto::EmbeddingContentType;
|
|
use futures::StreamExt;
|
|
use std::future::Future;
|
|
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
use std::task::{Context, Poll};
|
|
use std::time::Duration;
|
|
use tokio::time::interval;
|
|
|
|
const CONSUMER_NAME: &str = "open_collab";
|
|
|
|
type Handles = Arc<DashMap<String, CollabHandle>>;
|
|
|
|
pub struct OpenCollabConsumer {
|
|
#[allow(dead_code)]
|
|
handles: Handles,
|
|
consumer_group: tokio::task::JoinHandle<()>,
|
|
}
|
|
|
|
impl OpenCollabConsumer {
|
|
pub(crate) async fn new(
|
|
redis_stream: CollabRedisStream,
|
|
indexer: Arc<dyn Indexer>,
|
|
control_stream_key: &str,
|
|
ingest_interval: Duration,
|
|
preindex: bool,
|
|
) -> Result<Self> {
|
|
let handles = Arc::new(DashMap::new());
|
|
let mut control_group = redis_stream
|
|
.collab_control_stream(control_stream_key, "indexer")
|
|
.await?;
|
|
|
|
// Handle unindexed documents
|
|
if preindex {
|
|
Self::handle_unindexed_collabs(indexer.clone()).await;
|
|
}
|
|
|
|
// Handle stale messages
|
|
let stale_messages = control_group.get_unacked_messages(CONSUMER_NAME).await?;
|
|
Self::handle_messages(
|
|
&mut control_group,
|
|
indexer.clone(),
|
|
&redis_stream,
|
|
handles.clone(),
|
|
stale_messages,
|
|
ingest_interval,
|
|
)
|
|
.await;
|
|
|
|
let weak_handles = Arc::downgrade(&handles);
|
|
let mut interval = interval(Duration::from_secs(1));
|
|
let consumer_group = tokio::spawn(async move {
|
|
loop {
|
|
interval.tick().await;
|
|
if let Ok(messages) = control_group
|
|
.consumer_messages(CONSUMER_NAME, ReadOption::Count(10))
|
|
.await
|
|
{
|
|
if let Some(handles) = weak_handles.upgrade() {
|
|
for message in &messages {
|
|
if let Ok(event) = CollabControlEvent::decode(&message.data) {
|
|
Self::handle_event(
|
|
event,
|
|
&redis_stream,
|
|
indexer.clone(),
|
|
&handles,
|
|
ingest_interval,
|
|
)
|
|
.await;
|
|
}
|
|
}
|
|
if let Err(err) = control_group.ack_messages(&messages).await {
|
|
tracing::error!("failed to ack messages: {:?}", err);
|
|
}
|
|
} else {
|
|
tracing::trace!("consumer handles dropped, exiting");
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
Ok(Self {
|
|
handles,
|
|
consumer_group,
|
|
})
|
|
}
|
|
|
|
async fn handle_unindexed_collabs(indexer: Arc<dyn Indexer>) {
|
|
let mut stream = indexer.get_unindexed_collabs();
|
|
while let Some(result) = stream.next().await {
|
|
match result {
|
|
Ok(collab) => {
|
|
if let Err(err) = Self::index_collab(&indexer, &collab).await {
|
|
tracing::warn!(
|
|
"failed to index collab {}/{}: {}",
|
|
collab.workspace_id,
|
|
collab.object_id,
|
|
err
|
|
);
|
|
}
|
|
},
|
|
Err(err) => {
|
|
tracing::error!("failed to get unindexed document: {}", err);
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn index_collab(indexer: &Arc<dyn Indexer>, collab: &UnindexedCollab) -> Result<()> {
|
|
let fragment = {
|
|
match &collab.collab_type {
|
|
CollabType::Document => {
|
|
let document = Document::from_doc_state(
|
|
CollabOrigin::Empty,
|
|
DataSource::DocStateV1(collab.collab.doc_state.to_vec()),
|
|
&collab.object_id,
|
|
vec![],
|
|
)?;
|
|
let data = document.get_document_data()?;
|
|
let content = crate::extract::document_to_plain_text(&data);
|
|
if content.is_empty() {
|
|
return Ok(());
|
|
}
|
|
Fragment {
|
|
fragment_id: collab.object_id.clone(),
|
|
object_id: collab.object_id.clone(),
|
|
collab_type: collab.collab_type.clone(),
|
|
content_type: EmbeddingContentType::PlainText,
|
|
content,
|
|
}
|
|
},
|
|
collab_type => {
|
|
tracing::warn!(
|
|
"cannot index collab {}/{} because {:?} type is not supported",
|
|
collab.workspace_id,
|
|
collab.object_id,
|
|
collab_type
|
|
);
|
|
return Ok(());
|
|
},
|
|
}
|
|
};
|
|
tracing::trace!(
|
|
"indexing collab {}/{}",
|
|
collab.workspace_id,
|
|
collab.object_id
|
|
);
|
|
indexer
|
|
.update_index(&collab.workspace_id, vec![fragment])
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
|
|
#[inline]
|
|
async fn handle_messages(
|
|
control_group: &mut StreamGroup,
|
|
indexer: Arc<dyn Indexer>,
|
|
redis_stream: &CollabRedisStream,
|
|
handles: Handles,
|
|
messages: Vec<StreamMessage>,
|
|
ingest_interval: Duration,
|
|
) {
|
|
if messages.is_empty() {
|
|
return;
|
|
}
|
|
|
|
tracing::debug!("received {} messages from Redis", messages.len());
|
|
for message in &messages {
|
|
if let Ok(event) = CollabControlEvent::decode(&message.data) {
|
|
Self::handle_event(
|
|
event,
|
|
redis_stream,
|
|
indexer.clone(),
|
|
&handles,
|
|
ingest_interval,
|
|
)
|
|
.await
|
|
}
|
|
}
|
|
|
|
if let Err(err) = control_group.ack_messages(&messages).await {
|
|
tracing::error!("failed to ack stale messages: {}", err);
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
async fn handle_event(
|
|
event: CollabControlEvent,
|
|
redis_stream: &CollabRedisStream,
|
|
indexer: Arc<dyn Indexer>,
|
|
handles: &Handles,
|
|
ingest_interval: Duration,
|
|
) {
|
|
match event {
|
|
CollabControlEvent::Open {
|
|
workspace_id,
|
|
object_id,
|
|
collab_type,
|
|
doc_state,
|
|
} => match handles.entry(object_id.clone()) {
|
|
Entry::Occupied(_) => { /* do nothing */ },
|
|
Entry::Vacant(entry) => {
|
|
// create a new collab document handle, which will subscribe and apply incoming updates
|
|
let result = CollabHandle::open(
|
|
redis_stream,
|
|
indexer,
|
|
object_id.clone(),
|
|
workspace_id.clone(),
|
|
collab_type.clone(),
|
|
doc_state,
|
|
ingest_interval,
|
|
)
|
|
.await;
|
|
match result {
|
|
Ok(Some(handle)) => {
|
|
entry.insert(handle);
|
|
tracing::info!("created a new handle for {}/{}", workspace_id, object_id);
|
|
},
|
|
Ok(None) => {
|
|
tracing::debug!(
|
|
"document {}/{} of type {} is not indexable",
|
|
workspace_id,
|
|
object_id,
|
|
collab_type
|
|
);
|
|
},
|
|
Err(e) => {
|
|
tracing::error!(
|
|
"failed to open handle for {}/{}: {}",
|
|
workspace_id,
|
|
object_id,
|
|
e
|
|
);
|
|
},
|
|
}
|
|
},
|
|
},
|
|
CollabControlEvent::Close { object_id } => {
|
|
if let Some((_, handle)) = handles.remove(&object_id) {
|
|
// trigger shutdown signal and gracefully wait for handle to complete
|
|
tracing::info!("shutting down handle for {}", object_id);
|
|
handle.shutdown().await;
|
|
}
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Future for OpenCollabConsumer {
|
|
type Output = ();
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
let consumer_group = Pin::new(&mut self.consumer_group);
|
|
match consumer_group.poll(cx) {
|
|
Poll::Ready(Ok(())) => Poll::Ready(()),
|
|
Poll::Ready(Err(err)) => {
|
|
tracing::error!("consumer group failed: {}", err);
|
|
Poll::Ready(())
|
|
},
|
|
Poll::Pending => Poll::Pending,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Drop for OpenCollabConsumer {
|
|
fn drop(&mut self) {
|
|
self.consumer_group.abort();
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod test {
|
|
use crate::consumer::OpenCollabConsumer;
|
|
use crate::indexer::PostgresIndexer;
|
|
use crate::test_utils::{
|
|
ai_client, collab_update_forwarder, db_pool, redis_stream, setup_collab,
|
|
};
|
|
use collab::core::collab::MutexCollab;
|
|
use collab::preclude::Collab;
|
|
use collab_document::document::Document;
|
|
use collab_document::document_data::default_document_data;
|
|
use collab_entity::CollabType;
|
|
use collab_stream::model::CollabControlEvent;
|
|
use database::index::get_index_status;
|
|
use serde_json::json;
|
|
use sqlx::Row;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
#[tokio::test]
|
|
async fn graceful_handle_shutdown() {
|
|
let _ = env_logger::builder().is_test(true).try_init();
|
|
|
|
let redis_stream = redis_stream().await;
|
|
let uid = rand::random();
|
|
let object_id = uuid::Uuid::new_v4();
|
|
|
|
let db = db_pool().await;
|
|
let openai = ai_client();
|
|
|
|
let mut collab = Collab::new(
|
|
uid,
|
|
object_id.to_string(),
|
|
"device-1".to_string(),
|
|
vec![],
|
|
false,
|
|
);
|
|
collab.initialize();
|
|
let collab = Arc::new(MutexCollab::new(collab));
|
|
let doc_data = default_document_data();
|
|
|
|
let text_id = doc_data
|
|
.meta
|
|
.text_map
|
|
.as_ref()
|
|
.unwrap()
|
|
.iter()
|
|
.next()
|
|
.unwrap()
|
|
.0
|
|
.clone();
|
|
let document = Document::create_with_data(collab.clone(), doc_data).unwrap();
|
|
let encoded_collab = document.encode_collab().unwrap();
|
|
|
|
let workspace_id = setup_collab(&db, uid, object_id, &encoded_collab).await;
|
|
|
|
let indexer = Arc::new(PostgresIndexer::new(openai, db.clone()));
|
|
|
|
let mut control_stream = redis_stream
|
|
.collab_control_stream("af_collab_control", "indexer")
|
|
.await
|
|
.unwrap();
|
|
|
|
let update_stream = redis_stream
|
|
.collab_update_stream(&workspace_id.to_string(), &object_id.to_string(), "indexer")
|
|
.await
|
|
.unwrap();
|
|
|
|
let _s = collab_update_forwarder(collab.clone(), update_stream.clone());
|
|
let consumer = OpenCollabConsumer::new(
|
|
redis_stream,
|
|
indexer.clone(),
|
|
"af_collab_control",
|
|
Duration::from_secs(1), // interval longer than test timeout
|
|
false,
|
|
)
|
|
.await
|
|
.unwrap();
|
|
|
|
control_stream
|
|
.insert_message(CollabControlEvent::Open {
|
|
workspace_id: workspace_id.to_string(),
|
|
object_id: object_id.to_string(),
|
|
collab_type: CollabType::Document,
|
|
doc_state: encoded_collab.doc_state.to_vec(),
|
|
})
|
|
.await
|
|
.unwrap();
|
|
|
|
document.apply_text_delta(&text_id, json!([{"insert": "test-value"}]).to_string());
|
|
|
|
tokio::time::sleep(Duration::from_millis(1500)).await;
|
|
assert!(
|
|
consumer.handles.contains_key(&object_id.to_string()),
|
|
"in reaction to open control event, a corresponding handle should be created"
|
|
);
|
|
|
|
control_stream
|
|
.insert_message(CollabControlEvent::Close {
|
|
object_id: object_id.to_string(),
|
|
})
|
|
.await
|
|
.unwrap();
|
|
|
|
tokio::time::sleep(Duration::from_millis(2000)).await;
|
|
assert!(
|
|
!consumer.handles.contains_key(&object_id.to_string()),
|
|
"in reaction to close control event, a corresponding handle should be destroyed"
|
|
);
|
|
|
|
let contents = sqlx::query("SELECT content from af_collab_embeddings WHERE oid = $1")
|
|
.bind(object_id.to_string())
|
|
.fetch_all(&db)
|
|
.await
|
|
.unwrap();
|
|
|
|
assert_ne!(contents.len(), 0);
|
|
let content: Option<String> = contents[0].get(0);
|
|
assert_eq!(content.as_deref(), Some("test-value "));
|
|
}
|
|
|
|
#[ignore]
|
|
#[tokio::test]
|
|
async fn index_documents_at_start() {
|
|
let _ = env_logger::builder().is_test(true).try_init();
|
|
|
|
let redis_stream = redis_stream().await;
|
|
let uid = rand::random();
|
|
let object_id = uuid::Uuid::new_v4();
|
|
|
|
let db = db_pool().await;
|
|
let openai = ai_client();
|
|
|
|
let mut collab = Collab::new(
|
|
uid,
|
|
object_id.to_string(),
|
|
"device-1".to_string(),
|
|
vec![],
|
|
false,
|
|
);
|
|
collab.initialize();
|
|
let collab = Arc::new(MutexCollab::new(collab));
|
|
let doc_data = default_document_data();
|
|
let document = Document::create_with_data(collab.clone(), doc_data).unwrap();
|
|
let encoded_collab = document.encode_collab().unwrap();
|
|
|
|
let workspace_id = setup_collab(&db, uid, object_id, &encoded_collab).await;
|
|
|
|
{
|
|
let mut tx = db.begin().await.unwrap();
|
|
let status = get_index_status(
|
|
&mut tx,
|
|
&workspace_id,
|
|
&object_id.to_string(),
|
|
CollabType::Document as i32,
|
|
)
|
|
.await
|
|
.unwrap();
|
|
assert_eq!(
|
|
status,
|
|
Some(false),
|
|
"collab should not have embeddings at start"
|
|
);
|
|
}
|
|
|
|
let indexer = Arc::new(PostgresIndexer::new(openai, db.clone()));
|
|
|
|
let _consumer = OpenCollabConsumer::new(
|
|
redis_stream,
|
|
indexer.clone(),
|
|
"af_collab_control",
|
|
Duration::from_secs(1), // interval longer than test timeout
|
|
true,
|
|
)
|
|
.await
|
|
.unwrap();
|
|
|
|
{
|
|
let mut tx = db.begin().await.unwrap();
|
|
let status = get_index_status(
|
|
&mut tx,
|
|
&workspace_id,
|
|
&object_id.to_string(),
|
|
CollabType::Document as i32,
|
|
)
|
|
.await
|
|
.unwrap();
|
|
assert_eq!(status, Some(true), "collab should be indexed after start");
|
|
}
|
|
}
|
|
}
|