diff --git a/libs/database/src/index/collab_embeddings_ops.rs b/libs/database/src/index/collab_embeddings_ops.rs index fdfe3005..58c8d0bc 100644 --- a/libs/database/src/index/collab_embeddings_ops.rs +++ b/libs/database/src/index/collab_embeddings_ops.rs @@ -2,6 +2,7 @@ use std::ops::DerefMut; use pgvector::Vector; use sqlx::Transaction; +use uuid::Uuid; use database_entity::dto::AFCollabEmbeddingParams; @@ -20,8 +21,20 @@ pub async fn has_collab_embeddings( pub async fn upsert_collab_embeddings( tx: &mut Transaction<'_, sqlx::Postgres>, + workspace_id: &Uuid, + tokens_used: u32, records: Vec, ) -> Result<(), sqlx::Error> { + if tokens_used > 0 { + sqlx::query( + "UPDATE af_workspace SET index_token_usage = index_token_usage + $2 WHERE workspace_id = $1", + ) + .bind(workspace_id) + .bind(tokens_used as i64) + .execute(tx.deref_mut()) + .await?; + } + for r in records { sqlx::query( r#"INSERT INTO af_collab_embeddings (fragment_id, oid, partition_key, content_type, content, embedding, indexed_at) diff --git a/libs/database/src/index/search_ops.rs b/libs/database/src/index/search_ops.rs index 312b5206..5afeb743 100644 --- a/libs/database/src/index/search_ops.rs +++ b/libs/database/src/index/search_ops.rs @@ -8,13 +8,21 @@ use uuid::Uuid; pub async fn search_documents( tx: &mut Transaction<'_, sqlx::Postgres>, params: SearchDocumentParams, + tokens_used: u32, ) -> Result, sqlx::Error> { let query = sqlx::query_as::<_, SearchDocumentItem>( r#" + WITH workspace AS ( + UPDATE af_workspace + SET search_token_usage = search_token_usage + $6 + WHERE workspace_id = $2 + RETURNING workspace_id + ) SELECT em.oid AS object_id, collab.workspace_id, em.partition_key AS collab_type, + em.content_type, LEFT(em.content, $4) AS content_preview, u.name AS created_by, collab.created_at AS created_at, @@ -32,7 +40,8 @@ pub async fn search_documents( .bind(params.workspace_id) .bind(Vector::from(params.embedding)) .bind(params.preview) - .bind(params.limit); + .bind(params.limit) + .bind(tokens_used as i64); let rows = query.fetch_all(tx.deref_mut()).await?; Ok(rows) } diff --git a/migrations/20240529054858_workspace_add_token_usage.sql b/migrations/20240529054858_workspace_add_token_usage.sql new file mode 100644 index 00000000..82fd5ce1 --- /dev/null +++ b/migrations/20240529054858_workspace_add_token_usage.sql @@ -0,0 +1,3 @@ +-- Add migration script here +ALTER TABLE af_workspace ADD COLUMN search_token_usage BIGINT NOT NULL DEFAULT 0; +ALTER TABLE af_workspace ADD COLUMN index_token_usage BIGINT NOT NULL DEFAULT 0; \ No newline at end of file diff --git a/services/appflowy-indexer/src/collab_handle.rs b/services/appflowy-indexer/src/collab_handle.rs index 981f34d6..babaa15a 100644 --- a/services/appflowy-indexer/src/collab_handle.rs +++ b/services/appflowy-indexer/src/collab_handle.rs @@ -16,6 +16,7 @@ use tokio::task::JoinSet; use tokio::time::interval; use tokio_util::sync::CancellationToken; use tracing::instrument; +use uuid::Uuid; use collab_stream::client::CollabRedisStream; use collab_stream::model::{CollabUpdateEvent, StreamMessage}; @@ -75,6 +76,8 @@ impl CollabHandle { if !messages.is_empty() { Self::handle_collab_updates(&mut update_stream, content.get_collab(), messages).await?; } + let workspace_id = + Uuid::parse_str(&workspace_id).map_err(|e| crate::error::Error::InvalidWorkspace(e))?; let mut tasks = JoinSet::new(); tasks.spawn(Self::receive_collab_updates( @@ -108,7 +111,7 @@ impl CollabHandle { mut update_stream: StreamGroup, content: Weak, object_id: String, - workspace_id: String, + workspace_id: Uuid, ingest_interval: Duration, closing: CancellationToken, ) { @@ -175,7 +178,7 @@ impl CollabHandle { mut updates: Pin + Send + Sync>>, indexer: Arc, object_id: String, - workspace_id: String, + workspace_id: Uuid, ingest_interval: Duration, token: CancellationToken, ) { @@ -186,14 +189,14 @@ impl CollabHandle { loop { select! { _ = interval.tick() => { - match Self::publish_updates(&indexer, &mut inserts, &mut removals).await { + match Self::publish_updates(&indexer, &workspace_id, &mut inserts, &mut removals).await { Ok(_) => last_update = Instant::now(), Err(err) => tracing::error!("document {}/{} watcher failed to publish fragment updates: {}", workspace_id, object_id, err), } } _ = token.cancelled() => { tracing::trace!("document {}/{} watcher closing signal received, flushing remaining updates", workspace_id, object_id); - if let Err(err) = Self::publish_updates(&indexer, &mut inserts, &mut removals).await { + if let Err(err) = Self::publish_updates(&indexer, &workspace_id, &mut inserts, &mut removals).await { tracing::error!("document {}/{} watcher failed to publish fragment updates: {}", workspace_id, object_id, err); } return; @@ -215,7 +218,7 @@ impl CollabHandle { let now = Instant::now(); if now.duration_since(last_update) > ingest_interval { - match Self::publish_updates(&indexer, &mut inserts, &mut removals).await { + match Self::publish_updates(&indexer, &workspace_id, &mut inserts, &mut removals).await { Ok(_) => last_update = now, Err(err) => tracing::error!("document {}/{} watcher failed to publish fragment updates: {}", workspace_id, object_id, err), } @@ -227,6 +230,7 @@ impl CollabHandle { async fn publish_updates( indexer: &Arc, + workspace_id: &Uuid, inserts: &mut HashMap, removals: &mut HashSet, ) -> Result<()> { @@ -236,7 +240,7 @@ impl CollabHandle { let inserts: Vec<_> = inserts.drain().map(|(_, doc)| doc).collect(); if !inserts.is_empty() { tracing::info!("updating indexes for {} fragments", inserts.len()); - indexer.update_index(inserts).await?; + indexer.update_index(workspace_id, inserts).await?; } if !removals.is_empty() { @@ -346,5 +350,14 @@ mod test { .collect::>(); assert_eq!(contents.len(), 1); + + let tokens: i64 = + sqlx::query("SELECT index_token_usage from af_workspace WHERE workspace_id = $1") + .bind(&workspace_id) + .fetch_one(&db) + .await + .unwrap() + .get(0); + assert_ne!(tokens, 0); } } diff --git a/services/appflowy-indexer/src/error.rs b/services/appflowy-indexer/src/error.rs index 3032fc3b..91bc3827 100644 --- a/services/appflowy-indexer/src/error.rs +++ b/services/appflowy-indexer/src/error.rs @@ -14,6 +14,8 @@ pub enum Error { Sql(#[from] sqlx::Error), #[error("OpenAI failed to process request: {0}")] OpenAI(String), + #[error("invalid workspace ID: {0}")] + InvalidWorkspace(uuid::Error), } pub type Result = std::result::Result; diff --git a/services/appflowy-indexer/src/indexer.rs b/services/appflowy-indexer/src/indexer.rs index 5c90e609..9c843c8c 100644 --- a/services/appflowy-indexer/src/indexer.rs +++ b/services/appflowy-indexer/src/indexer.rs @@ -7,6 +7,7 @@ use openai_dive::v1::resources::embedding::{ }; use serde::{Deserialize, Serialize}; use sqlx::PgPool; +use uuid::Uuid; use database::index::{has_collab_embeddings, remove_collab_embeddings, upsert_collab_embeddings}; use database_entity::dto::{AFCollabEmbeddingParams, EmbeddingContentType}; @@ -17,7 +18,7 @@ use crate::error::Result; pub trait Indexer: Send + Sync { /// Check if document with given id has been already a corresponding index entry. async fn was_indexed(&self, object_id: &str) -> Result; - async fn update_index(&self, documents: Vec) -> Result<()>; + async fn update_index(&self, workspace_id: &Uuid, documents: Vec) -> Result<()>; async fn remove(&self, ids: &[FragmentID]) -> Result<()>; } @@ -99,7 +100,7 @@ impl PostgresIndexer { Self { openai, db } } - async fn get_embeddings(&self, fragments: Vec) -> Result> { + async fn get_embeddings(&self, fragments: Vec) -> Result { let inputs: Vec<_> = fragments .iter() .map(|fragment| fragment.content.clone()) @@ -118,10 +119,12 @@ impl PostgresIndexer { .map_err(|e| crate::error::Error::OpenAI(e.to_string()))?; tracing::trace!("fetched {} embeddings", resp.data.len()); - if let Some(usage) = resp.usage { - tracing::info!("OpenAI API usage: {}", usage.total_tokens); - //TODO: report usage statistics - } + let tokens_used = if let Some(usage) = resp.usage { + tracing::info!("OpenAI API index tokens used: {}", usage.total_tokens); + usage.total_tokens + } else { + 0 + }; let mut fragments: Vec<_> = fragments.into_iter().map(EmbedFragment::from).collect(); for e in resp.data.into_iter() { @@ -135,18 +138,27 @@ impl PostgresIndexer { }; fragments[e.index as usize].embedding = Some(embedding); } - Ok(fragments) + Ok(Embeddings { + tokens_used, + fragments, + }) } - async fn store_embeddings(&self, fragments: Vec) -> Result<()> { + async fn store_embeddings(&self, workspace_id: &Uuid, embeddings: Embeddings) -> Result<()> { tracing::trace!( "storing {} embeddings inside of vector database", - fragments.len() + embeddings.fragments.len() ); let mut tx = self.db.begin().await?; upsert_collab_embeddings( &mut tx, - fragments.into_iter().map(EmbedFragment::into).collect(), + workspace_id, + embeddings.tokens_used, + embeddings + .fragments + .into_iter() + .map(EmbedFragment::into) + .collect(), ) .await?; tx.commit().await?; @@ -154,6 +166,11 @@ impl PostgresIndexer { } } +struct Embeddings { + tokens_used: u32, + fragments: Vec, +} + #[async_trait] impl Indexer for PostgresIndexer { async fn was_indexed(&self, object_id: &str) -> Result { @@ -161,9 +178,9 @@ impl Indexer for PostgresIndexer { Ok(found) } - async fn update_index(&self, documents: Vec) -> Result<()> { + async fn update_index(&self, workspace_id: &Uuid, documents: Vec) -> Result<()> { let embeddings = self.get_embeddings(documents).await?; - self.store_embeddings(embeddings).await?; + self.store_embeddings(workspace_id, embeddings).await?; Ok(()) } @@ -191,7 +208,7 @@ mod test { let db = db_pool().await; let object_id = uuid::Uuid::new_v4(); let uid = rand::random(); - setup_collab(&db, uid, object_id, vec![]).await; + let workspace_id = setup_collab(&db, uid, object_id, vec![]).await; let openai = openai_client(); @@ -209,10 +226,13 @@ mod test { // resolve embeddings from OpenAI let embeddings = indexer.get_embeddings(fragments).await.unwrap(); - assert_eq!(embeddings[0].embedding.is_some(), true); + assert_eq!(embeddings.fragments[0].embedding.is_some(), true); // store embeddings in DB - indexer.store_embeddings(embeddings).await.unwrap(); + indexer + .store_embeddings(&workspace_id, embeddings) + .await + .unwrap(); // search for embedding let mut tx = indexer.db.begin().await.unwrap(); diff --git a/src/biz/search/ops.rs b/src/biz/search/ops.rs index 952dbf4d..07384a52 100644 --- a/src/biz/search/ops.rs +++ b/src/biz/search/ops.rs @@ -30,9 +30,16 @@ pub async fn search_document( .await .map_err(|e| AppResponseError::new(ErrorCode::Internal, e.to_string()))?; - if let Some(usage) = embeddings.usage { - tracing::info!("OpenAI API usage: {}", usage.total_tokens) - } + let tokens_used = if let Some(usage) = embeddings.usage { + tracing::info!( + "workspace {} OpenAI API search tokens used: {}", + workspace_id, + usage.total_tokens + ); + usage.total_tokens + } else { + 0 + }; let embedding = embeddings .data @@ -61,8 +68,10 @@ pub async fn search_document( preview: request.preview_size.unwrap_or(180) as i32, embedding, }, + tokens_used, ) .await?; + tx.commit().await?; Ok( results .into_iter()