// AppFlowy-Cloud/services/appflowy-indexer/src/indexer.rs
//
// 343 lines
// 10 KiB
// Rust

use std::pin::Pin;
use async_stream::try_stream;
use async_trait::async_trait;
use collab::entity::EncodedCollab;
use collab::error::CollabError;
use collab_entity::CollabType;
use futures::Stream;
use openai_dive::v1::api::Client;
use openai_dive::v1::models::EmbeddingsEngine;
use openai_dive::v1::resources::embedding::{
EmbeddingEncodingFormat, EmbeddingInput, EmbeddingOutput, EmbeddingParameters,
};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use uuid::Uuid;
use database::collab::select_blob_from_af_collab;
use database::index::{
get_collabs_without_embeddings, get_index_status, remove_collab_embeddings,
upsert_collab_embeddings,
};
use database_entity::dto::{AFCollabEmbeddingParams, EmbeddingContentType};
use crate::error::Result;
/// Interface for maintaining an index over fragments of collab objects.
///
/// Implementations must be `Send + Sync`, as required by the trait bound.
#[async_trait]
pub trait Indexer: Send + Sync {
/// Reports whether the object with the given id already has a corresponding index entry.
async fn index_status(&self, object_id: &str) -> Result<IndexStatus>;
/// Indexes the given fragments for the workspace identified by `workspace_id`.
async fn update_index(&self, workspace_id: &Uuid, documents: Vec<Fragment>) -> Result<()>;
/// Removes the index entries of the fragments with the given ids.
async fn remove(&self, ids: &[FragmentID]) -> Result<()>;
/// Returns a stream of collabs that have not been indexed yet.
fn get_unindexed_collabs(&self) -> Pin<Box<dyn Stream<Item = Result<UnindexedCollab>>>>;
}
/// A collab that has not been indexed yet, as yielded by `Indexer::get_unindexed_collabs`.
pub struct UnindexedCollab {
/// Identifier of the workspace reported for this collab.
pub workspace_id: Uuid,
/// Identifier of the collab object.
pub object_id: String,
/// Type of the collab object.
pub collab_type: CollabType,
/// Collab state in its encoded form.
pub collab: EncodedCollab,
}
/// Indexing state of a single object, as returned by `Indexer::index_status`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum IndexStatus {
/// Document is indexed and up-to-date.
Indexed,
/// Document is not indexed.
NotIndexed,
/// Document should never be indexed.
NotPermitted,
}
/// Identifier of a single fragment; currently a plain `String`.
pub type FragmentID = String;
/// A single piece of indexable data.
/// For example, this can be one part (such as a block) of a larger document.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Fragment {
/// Unique fragment identifier.
pub fragment_id: FragmentID,
/// Object which this fragment belongs to.
pub object_id: String,
/// Type of the collab object this fragment belongs to.
pub collab_type: CollabType,
/// Type of the content that this fragment represents.
pub content_type: EmbeddingContentType,
/// Content of the fragment.
pub content: String,
}
/// Internal counterpart of [`Fragment`] that additionally carries an optional embedding vector.
/// It is created from a [`Fragment`] with `embedding` unset and converted into
/// `AFCollabEmbeddingParams` when stored.
#[derive(Debug, Clone, PartialEq)]
struct EmbedFragment {
/// Unique fragment identifier.
pub fragment_id: FragmentID,
/// Object which this fragment belongs to.
pub object_id: String,
/// Type of the collab object this fragment belongs to.
pub collab_type: CollabType,
/// Content of the fragment.
pub content: String,
/// Type of the content that this fragment represents.
pub content_type: EmbeddingContentType,
/// Embedding vector of `content`; `None` until one has been assigned.
pub embedding: Option<Vec<f32>>,
}
impl From<Fragment> for EmbedFragment {
fn from(fragment: Fragment) -> Self {
EmbedFragment {
fragment_id: fragment.fragment_id,
object_id: fragment.object_id,
collab_type: fragment.collab_type,
content: fragment.content,
content_type: fragment.content_type,
embedding: None,
}
}
}
impl From<EmbedFragment> for AFCollabEmbeddingParams {
fn from(f: EmbedFragment) -> Self {
AFCollabEmbeddingParams {
fragment_id: f.fragment_id,
object_id: f.object_id,
collab_type: f.collab_type,
content_type: f.content_type,
content: f.content,
embedding: f.embedding,
}
}
}
/// [`Indexer`] implementation that requests embeddings through an OpenAI client
/// and persists them through a Postgres connection pool.
pub struct PostgresIndexer {
// Client used to call the OpenAI embeddings endpoint.
openai: Client,
// Connection pool for the Postgres database holding collabs and their embeddings.
db: PgPool,
}
impl PostgresIndexer {
#[allow(dead_code)]
pub async fn open(openai_api_key: &str, pg_conn: &str) -> Result<Self> {
let openai = Client::new(openai_api_key.to_string());
let db = PgPool::connect(pg_conn).await?;
Ok(Self { openai, db })
}
#[allow(dead_code)]
pub fn new(openai: Client, db: PgPool) -> Self {
Self { openai, db }
}
async fn get_embeddings(&self, fragments: Vec<Fragment>) -> Result<Embeddings> {
let inputs: Vec<_> = fragments
.iter()
.map(|fragment| fragment.content.clone())
.collect();
let resp = self
.openai
.embeddings()
.create(EmbeddingParameters {
input: EmbeddingInput::StringArray(inputs),
model: EmbeddingsEngine::TextEmbedding3Small.to_string(),
encoding_format: Some(EmbeddingEncodingFormat::Float),
dimensions: Some(1536), // text-embedding-3-small default number of dimensions
user: None,
})
.await
.map_err(|e| crate::error::Error::OpenAI(e.to_string()))?;
tracing::trace!("fetched {} embeddings", resp.data.len());
let tokens_used = if let Some(usage) = resp.usage {
tracing::info!("OpenAI API index tokens used: {}", usage.total_tokens);
usage.total_tokens
} else {
0
};
let mut fragments: Vec<_> = fragments.into_iter().map(EmbedFragment::from).collect();
for e in resp.data.into_iter() {
let embedding = match e.embedding {
EmbeddingOutput::Float(embedding) => embedding
.into_iter()
.map(|f| f as f32)
.collect::<Vec<f32>>(),
EmbeddingOutput::Base64(_) => unreachable!("Unexpected base64 encoding"),
};
fragments[e.index as usize].embedding = Some(embedding);
}
Ok(Embeddings {
tokens_used,
fragments,
})
}
async fn store_embeddings(&self, workspace_id: &Uuid, embeddings: Embeddings) -> Result<()> {
tracing::trace!(
"storing {} embeddings inside of vector database",
embeddings.fragments.len()
);
let mut tx = self.db.begin().await?;
upsert_collab_embeddings(
&mut tx,
workspace_id,
embeddings.tokens_used,
embeddings
.fragments
.into_iter()
.map(EmbedFragment::into)
.collect(),
)
.await?;
tx.commit().await?;
Ok(())
}
}
/// Result of a single embeddings request.
struct Embeddings {
// Total tokens reported by the API for the request; 0 when no usage was reported.
tokens_used: u32,
// Submitted fragments, each with the embedding assigned from the response (if any).
fragments: Vec<EmbedFragment>,
}
#[async_trait]
impl Indexer for PostgresIndexer {
    /// Looks up the index status of `object_id` and maps the lookup result:
    /// no result means indexing is not permitted, otherwise the flag tells
    /// whether the object is already indexed.
    async fn index_status(&self, object_id: &str) -> Result<IndexStatus> {
        let mut txn = self.db.begin().await?;
        let status = match get_index_status(&mut txn, object_id).await? {
            Some(true) => IndexStatus::Indexed,
            Some(false) => IndexStatus::NotIndexed,
            None => IndexStatus::NotPermitted,
        };
        Ok(status)
    }

    /// Fetches embeddings for `documents` and stores them for the given workspace.
    async fn update_index(&self, workspace_id: &Uuid, documents: Vec<Fragment>) -> Result<()> {
        let computed = self.get_embeddings(documents).await?;
        self.store_embeddings(workspace_id, computed).await
    }

    /// Deletes the embeddings of the given fragments inside of one committed transaction.
    async fn remove(&self, ids: &[FragmentID]) -> Result<()> {
        let mut txn = self.db.begin().await?;
        remove_collab_embeddings(&mut txn, ids).await?;
        txn.commit().await?;
        Ok(())
    }

    /// Streams all collabs without embeddings, loading and decoding the stored blob of each.
    ///
    /// Only document collabs are yielded; every other collab type is skipped.
    /// The first failing load or decode ends the stream with that error.
    fn get_unindexed_collabs(&self) -> Pin<Box<dyn Stream<Item = Result<UnindexedCollab>>>> {
        // The stream outlives `&self`, so it works on its own handle to the pool.
        let pool = self.db.clone();
        Box::pin(try_stream! {
            let pending = get_collabs_without_embeddings(&pool).await?;
            let pending_count = pending.len();
            if pending_count > 0 {
                tracing::trace!("found {} unindexed collabs", pending_count);
            }
            for entry in pending {
                // Exhaustive on purpose: a new `CollabType` variant must be handled explicitly.
                match &entry.collab_type {
                    CollabType::Document => {
                        let blob =
                            select_blob_from_af_collab(&pool, &CollabType::Document, &entry.object_id)
                                .await?;
                        let collab = EncodedCollab::decode_from_bytes(&blob)
                            .map_err(|err| crate::error::Error::Collab(CollabError::Internal(err)))?;
                        yield UnindexedCollab {
                            workspace_id: entry.workspace_id,
                            object_id: entry.object_id,
                            collab_type: entry.collab_type,
                            collab,
                        };
                    },
                    CollabType::Database
                    | CollabType::WorkspaceDatabase
                    | CollabType::Folder
                    | CollabType::DatabaseRow
                    | CollabType::UserAwareness
                    | CollabType::Unknown => { /* atm. only document types are supported */ },
                }
            }
        })
    }
}
#[cfg(test)]
mod test {
    use axum::body::Bytes;
    use collab::entity::{EncodedCollab, EncoderVersion};
    use pgvector::Vector;
    use sqlx::Row;
    use database_entity::dto::EmbeddingContentType;
    use crate::indexer::{Indexer, PostgresIndexer};
    use crate::test_utils::{db_pool, openai_client, setup_collab};

    /// Round trip through the indexer: fetch an embedding for one fragment, store it,
    /// read it back from `af_collab_embeddings`, remove it and verify that it is gone.
    ///
    /// Uses the database pool and OpenAI client provided by `crate::test_utils`.
    #[tokio::test]
    async fn test_indexing_embeddings() {
        let _ = env_logger::builder().is_test(true).try_init();
        let db = db_pool().await;
        let object_id = uuid::Uuid::new_v4();
        let uid = rand::random();
        let workspace_id = setup_collab(
            &db,
            uid,
            object_id,
            &EncodedCollab {
                state_vector: Bytes::from(vec![1, 2, 3]),
                doc_state: Bytes::from(vec![4, 5, 6]),
                version: EncoderVersion::V1,
            },
        )
        .await;
        let openai = openai_client();
        let indexer = PostgresIndexer::new(openai, db);
        let fragment_id = uuid::Uuid::new_v4().to_string();
        let fragments = vec![super::Fragment {
            fragment_id: fragment_id.clone(),
            object_id: object_id.to_string(),
            collab_type: collab_entity::CollabType::Document,
            content_type: EmbeddingContentType::PlainText,
            content: "Hello, world!".to_string(),
        }];

        // resolve embeddings from OpenAI
        let embeddings = indexer.get_embeddings(fragments).await.unwrap();
        assert!(embeddings.fragments[0].embedding.is_some());

        // store embeddings in DB
        indexer
            .store_embeddings(&workspace_id, embeddings)
            .await
            .unwrap();

        // search for embedding
        let mut tx = indexer.db.begin().await.unwrap();
        let row =
            sqlx::query("SELECT content, embedding FROM af_collab_embeddings WHERE fragment_id = $1")
                .bind(&fragment_id)
                .fetch_one(&mut *tx)
                .await
                .unwrap();
        tx.commit().await.unwrap();
        let content: String = row.get(0);
        assert_eq!(&content, "Hello, world!");
        let embedding: Option<Vector> = row.get(1);
        assert!(embedding.is_some());

        // remove embeddings
        indexer.remove(&[fragment_id.clone()]).await.unwrap();
        let mut tx = indexer.db.begin().await.unwrap();
        // `fetch_optional` separates "no such row" from a failing query: an unrelated
        // database error must fail the test instead of being taken for a removal.
        let row =
            sqlx::query("SELECT content, embedding FROM af_collab_embeddings WHERE fragment_id = $1")
                .bind(&fragment_id)
                .fetch_optional(&mut *tx)
                .await
                .unwrap();
        tx.commit().await.unwrap();
        assert!(row.is_none(), "embedding row should be gone after removal");
    }
}