chore: implement embedding api endpoint

This commit is contained in:
nathan 2024-06-17 22:06:58 +08:00
parent a0376446d0
commit 5bc39729a7
6 changed files with 191 additions and 3 deletions

75
Cargo.lock generated
View File

@ -568,12 +568,14 @@ name = "appflowy-ai-client"
version = "0.1.0"
dependencies = [
"anyhow",
"appflowy-ai-client",
"bytes",
"futures",
"reqwest 0.12.4",
"serde",
"serde_json",
"serde_repr",
"serde_with",
"thiserror",
"tokio",
"tracing",
@ -2184,6 +2186,41 @@ dependencies = [
"syn 2.0.48",
]
[[package]]
name = "darling"
version = "0.20.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83b2eb4d90d12bdda5ed17de686c2acb4c57914f8f921b8da7e112b5a36f3fe1"
dependencies = [
"darling_core",
"darling_macro",
]
[[package]]
name = "darling_core"
version = "0.20.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "622687fe0bac72a04e5599029151f5796111b90f1baaa9b544d807a5e31cd120"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim",
"syn 2.0.48",
]
[[package]]
name = "darling_macro"
version = "0.20.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178"
dependencies = [
"darling_core",
"quote",
"syn 2.0.48",
]
[[package]]
name = "dashmap"
version = "5.5.3"
@ -3289,6 +3326,12 @@ dependencies = [
"cc",
]
[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "idna"
version = "0.3.0"
@ -3358,6 +3401,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
dependencies = [
"autocfg",
"hashbrown 0.12.3",
"serde",
]
[[package]]
@ -3368,6 +3412,7 @@ checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4"
dependencies = [
"equivalent",
"hashbrown 0.14.3",
"serde",
]
[[package]]
@ -5603,6 +5648,36 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_with"
version = "3.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ad483d2ab0149d5a5ebcd9972a3852711e0153d863bf5a5d0391d28883c4a20"
dependencies = [
"base64 0.22.0",
"chrono",
"hex",
"indexmap 1.9.3",
"indexmap 2.2.5",
"serde",
"serde_derive",
"serde_json",
"serde_with_macros",
"time",
]
[[package]]
name = "serde_with_macros"
version = "3.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65569b702f41443e8bc8bbb1c5779bd0450bbe723b56198980e80ec45780bce2"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 2.0.48",
]
[[package]]
name = "serial_test"
version = "3.0.0"

View File

@ -15,8 +15,10 @@ tracing = { version = "0.1", optional = true }
serde_repr = { version = "0.1", optional = true }
futures = "0.3.30"
bytes = "1.6.0"
serde_with = "3.8.1"
[dev-dependencies]
appflowy-ai-client = { path = ".", features = ["dto", "client-api"] }
tokio = { version = "1.37.0", features = ["macros", "test-util"] }
tracing-subscriber = { version = "0.3.18", features = ["registry", "env-filter", "ansi", "json"] }
uuid = { version = "1.6", features = ["v4"] }

View File

@ -1,7 +1,7 @@
use crate::dto::{
ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, Document, MessageData,
RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse, TranslateRowData,
TranslateRowResponse,
ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, Document, EmbeddingRequest,
EmbeddingResponse, MessageData, RepeatedRelatedQuestion, SearchDocumentsRequest,
SummarizeRowResponse, TranslateRowData, TranslateRowResponse,
};
use crate::error::AIError;
@ -97,6 +97,18 @@ impl AppFlowyAIClient {
.into_data()
}
pub async fn embeddings(&self, params: EmbeddingRequest) -> Result<EmbeddingResponse, AIError> {
let url = format!("{}/embeddings", self.url);
let resp = self
.http_client(Method::POST, &url)?
.json(&params)
.send()
.await?;
AIResponse::<EmbeddingResponse>::from_response(resp)
.await?
.into_data()
}
pub async fn index_documents(&self, documents: &[Document]) -> Result<(), AIError> {
let url = format!("{}/index_documents", self.url);
let resp = self

View File

@ -1,6 +1,7 @@
use serde::{Deserialize, Serialize, Serializer};
use serde_repr::{Deserialize_repr, Serialize_repr};
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SummarizeRowResponse {
@ -114,3 +115,80 @@ pub struct TranslateItem {
pub struct TranslateRowResponse {
pub items: Vec<HashMap<String, String>>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(untagged)]
pub enum EmbeddingInput {
/// The string that will be turned into an embedding.
String(String),
/// The array of strings that will be turned into an embedding.
StringArray(Vec<String>),
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(untagged)]
pub enum EmbeddingOutput {
Float(Vec<f64>),
Base64(String),
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Embedding {
/// An integer representing the index of the embedding in the list of embeddings.
pub index: i32,
/// The embedding value, which is an instance of `EmbeddingOutput`.
pub embedding: EmbeddingOutput,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct EmbeddingResponse {
/// A string that is always set to "embedding".
pub object: String,
/// A list of `Embedding` objects.
pub data: Vec<Embedding>,
/// A string representing the model used to generate the embeddings.
pub model: String,
/// An integer representing the total number of tokens used.
pub total_tokens: i32,
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingEncodingFormat {
Float,
Base64,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct EmbeddingRequest {
/// An instance of `EmbeddingInput` containing the data to be embedded.
pub input: EmbeddingInput,
/// A string representing the model to use for generating embeddings.
pub model: String,
/// An integer representing the chunk size for processing.
pub chunk_size: i32,
/// An instance of `EmbeddingEncodingFormat` representing the format of the embedding.
pub encoding_format: EmbeddingEncodingFormat,
/// An integer representing the number of dimensions for the embedding.
pub dimensions: i32,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub enum EmbeddingsModel {
#[serde(rename = "text-embedding-3-small")]
TextEmbedding3Small,
#[serde(rename = "text-embedding-3-large")]
TextEmbedding3Large,
#[serde(rename = "text-embedding-ada-002")]
TextEmbeddingAda002,
}
impl Display for EmbeddingsModel {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
EmbeddingsModel::TextEmbedding3Small => write!(f, "text-embedding-3-small"),
EmbeddingsModel::TextEmbedding3Large => write!(f, "text-embedding-3-large"),
EmbeddingsModel::TextEmbeddingAda002 => write!(f, "text-embedding-ada-002"),
}
}
}

View File

@ -0,0 +1,20 @@
use crate::appflowy_ai_client;
use appflowy_ai_client::dto::{
EmbeddingEncodingFormat, EmbeddingInput, EmbeddingRequest, EmbeddingsModel,
};
#[tokio::test]
async fn embedding_test() {
let client = appflowy_ai_client();
let request = EmbeddingRequest {
input: EmbeddingInput::String("hello world".to_string()),
model: EmbeddingsModel::TextEmbedding3Small.to_string(),
chunk_size: 1000,
encoding_format: EmbeddingEncodingFormat::Float,
dimensions: 1536,
};
let result = client.embeddings(request).await.unwrap();
assert!(result.total_tokens > 0);
assert!(result.data.len() > 0);
}

View File

@ -1,2 +1,3 @@
mod completion_test;
mod embedding_test;
mod qa_test;