From 5bc39729a7dae5240b780fd9c8ac5b5c0beef2ae Mon Sep 17 00:00:00 2001 From: nathan Date: Mon, 17 Jun 2024 22:06:58 +0800 Subject: [PATCH] chore: implement embedding api endpoint --- Cargo.lock | 75 ++++++++++++++++++ libs/appflowy-ai-client/Cargo.toml | 2 + libs/appflowy-ai-client/src/client.rs | 18 ++++- libs/appflowy-ai-client/src/dto.rs | 78 +++++++++++++++++++ .../tests/chat_test/embedding_test.rs | 20 +++++ .../appflowy-ai-client/tests/chat_test/mod.rs | 1 + 6 files changed, 191 insertions(+), 3 deletions(-) create mode 100644 libs/appflowy-ai-client/tests/chat_test/embedding_test.rs diff --git a/Cargo.lock b/Cargo.lock index b7b76e91..82c528cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -568,12 +568,14 @@ name = "appflowy-ai-client" version = "0.1.0" dependencies = [ "anyhow", + "appflowy-ai-client", "bytes", "futures", "reqwest 0.12.4", "serde", "serde_json", "serde_repr", + "serde_with", "thiserror", "tokio", "tracing", @@ -2184,6 +2186,41 @@ dependencies = [ "syn 2.0.48", ] +[[package]] +name = "darling" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83b2eb4d90d12bdda5ed17de686c2acb4c57914f8f921b8da7e112b5a36f3fe1" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622687fe0bac72a04e5599029151f5796111b90f1baaa9b544d807a5e31cd120" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.48", +] + +[[package]] +name = "darling_macro" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.48", +] + [[package]] name = "dashmap" version = "5.5.3" @@ -3289,6 +3326,12 @@ dependencies = [ "cc", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.3.0" @@ -3358,6 +3401,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", "hashbrown 0.12.3", + "serde", ] [[package]] @@ -3368,6 +3412,7 @@ checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4" dependencies = [ "equivalent", "hashbrown 0.14.3", + "serde", ] [[package]] @@ -5603,6 +5648,36 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_with" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ad483d2ab0149d5a5ebcd9972a3852711e0153d863bf5a5d0391d28883c4a20" +dependencies = [ + "base64 0.22.0", + "chrono", + "hex", + "indexmap 1.9.3", + "indexmap 2.2.5", + "serde", + "serde_derive", + "serde_json", + "serde_with_macros", + "time", +] + +[[package]] +name = "serde_with_macros" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65569b702f41443e8bc8bbb1c5779bd0450bbe723b56198980e80ec45780bce2" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "serial_test" version = "3.0.0" diff --git a/libs/appflowy-ai-client/Cargo.toml b/libs/appflowy-ai-client/Cargo.toml index 62c621b7..9de0e794 100644 --- a/libs/appflowy-ai-client/Cargo.toml +++ b/libs/appflowy-ai-client/Cargo.toml @@ -15,8 +15,10 @@ tracing = { version = "0.1", optional = true } serde_repr = { version = "0.1", optional = true } futures = "0.3.30" bytes = "1.6.0" +serde_with = "3.8.1" [dev-dependencies] +appflowy-ai-client = { path = ".", features = ["dto", "client-api"] } tokio = { version = "1.37.0", features = ["macros", "test-util"] } tracing-subscriber = { version = "0.3.18", features = ["registry", "env-filter", "ansi", "json"] } uuid = { version = "1.6", features = ["v4"] } diff --git a/libs/appflowy-ai-client/src/client.rs b/libs/appflowy-ai-client/src/client.rs index 73aff13e..49363fda 100644 --- a/libs/appflowy-ai-client/src/client.rs +++ b/libs/appflowy-ai-client/src/client.rs @@ -1,7 +1,7 @@ use crate::dto::{ - ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, Document, MessageData, - RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse, TranslateRowData, - TranslateRowResponse, + ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, Document, EmbeddingRequest, + EmbeddingResponse, MessageData, RepeatedRelatedQuestion, SearchDocumentsRequest, + SummarizeRowResponse, TranslateRowData, TranslateRowResponse, }; use crate::error::AIError; @@ -97,6 +97,18 @@ impl AppFlowyAIClient { .into_data() } + pub async fn embeddings(&self, params: EmbeddingRequest) -> Result { + let url = format!("{}/embeddings", self.url); + let resp = self + .http_client(Method::POST, &url)? + .json(¶ms) + .send() + .await?; + AIResponse::::from_response(resp) + .await? + .into_data() + } + pub async fn index_documents(&self, documents: &[Document]) -> Result<(), AIError> { let url = format!("{}/index_documents", self.url); let resp = self diff --git a/libs/appflowy-ai-client/src/dto.rs b/libs/appflowy-ai-client/src/dto.rs index 6d2fa482..6cbb61f3 100644 --- a/libs/appflowy-ai-client/src/dto.rs +++ b/libs/appflowy-ai-client/src/dto.rs @@ -1,6 +1,7 @@ use serde::{Deserialize, Serialize, Serializer}; use serde_repr::{Deserialize_repr, Serialize_repr}; use std::collections::HashMap; +use std::fmt::{Display, Formatter}; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SummarizeRowResponse { @@ -114,3 +115,80 @@ pub struct TranslateItem { pub struct TranslateRowResponse { pub items: Vec>, } + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(untagged)] +pub enum EmbeddingInput { + /// The string that will be turned into an embedding. + String(String), + /// The array of strings that will be turned into an embedding. + StringArray(Vec), +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(untagged)] +pub enum EmbeddingOutput { + Float(Vec), + Base64(String), +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Embedding { + /// An integer representing the index of the embedding in the list of embeddings. + pub index: i32, + /// The embedding value, which is an instance of `EmbeddingOutput`. + pub embedding: EmbeddingOutput, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct EmbeddingResponse { + /// A string that is always set to "embedding". + pub object: String, + /// A list of `Embedding` objects. + pub data: Vec, + /// A string representing the model used to generate the embeddings. + pub model: String, + /// An integer representing the total number of tokens used. + pub total_tokens: i32, +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "lowercase")] +pub enum EmbeddingEncodingFormat { + Float, + Base64, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct EmbeddingRequest { + /// An instance of `EmbeddingInput` containing the data to be embedded. + pub input: EmbeddingInput, + /// A string representing the model to use for generating embeddings. + pub model: String, + /// An integer representing the chunk size for processing. + pub chunk_size: i32, + /// An instance of `EmbeddingEncodingFormat` representing the format of the embedding. + pub encoding_format: EmbeddingEncodingFormat, + /// An integer representing the number of dimensions for the embedding. + pub dimensions: i32, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub enum EmbeddingsModel { + #[serde(rename = "text-embedding-3-small")] + TextEmbedding3Small, + #[serde(rename = "text-embedding-3-large")] + TextEmbedding3Large, + #[serde(rename = "text-embedding-ada-002")] + TextEmbeddingAda002, +} + +impl Display for EmbeddingsModel { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + EmbeddingsModel::TextEmbedding3Small => write!(f, "text-embedding-3-small"), + EmbeddingsModel::TextEmbedding3Large => write!(f, "text-embedding-3-large"), + EmbeddingsModel::TextEmbeddingAda002 => write!(f, "text-embedding-ada-002"), + } + } +} diff --git a/libs/appflowy-ai-client/tests/chat_test/embedding_test.rs b/libs/appflowy-ai-client/tests/chat_test/embedding_test.rs new file mode 100644 index 00000000..e5565a11 --- /dev/null +++ b/libs/appflowy-ai-client/tests/chat_test/embedding_test.rs @@ -0,0 +1,20 @@ +use crate::appflowy_ai_client; + +use appflowy_ai_client::dto::{ + EmbeddingEncodingFormat, EmbeddingInput, EmbeddingRequest, EmbeddingsModel, +}; + +#[tokio::test] +async fn embedding_test() { + let client = appflowy_ai_client(); + let request = EmbeddingRequest { + input: EmbeddingInput::String("hello world".to_string()), + model: EmbeddingsModel::TextEmbedding3Small.to_string(), + chunk_size: 1000, + encoding_format: EmbeddingEncodingFormat::Float, + dimensions: 1536, + }; + let result = client.embeddings(request).await.unwrap(); + assert!(result.total_tokens > 0); + assert!(result.data.len() > 0); +} diff --git a/libs/appflowy-ai-client/tests/chat_test/mod.rs b/libs/appflowy-ai-client/tests/chat_test/mod.rs index fb191592..430fb1a7 100644 --- a/libs/appflowy-ai-client/tests/chat_test/mod.rs +++ b/libs/appflowy-ai-client/tests/chat_test/mod.rs @@ -1,2 +1,3 @@ mod completion_test; +mod embedding_test; mod qa_test;