chore: implement embedding api endpoint
This commit is contained in:
parent
a0376446d0
commit
5bc39729a7
|
|
@ -568,12 +568,14 @@ name = "appflowy-ai-client"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"appflowy-ai-client",
|
||||
"bytes",
|
||||
"futures",
|
||||
"reqwest 0.12.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_repr",
|
||||
"serde_with",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
|
|
@ -2184,6 +2186,41 @@ dependencies = [
|
|||
"syn 2.0.48",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.20.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "83b2eb4d90d12bdda5ed17de686c2acb4c57914f8f921b8da7e112b5a36f3fe1"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"darling_macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_core"
|
||||
version = "0.20.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "622687fe0bac72a04e5599029151f5796111b90f1baaa9b544d807a5e31cd120"
|
||||
dependencies = [
|
||||
"fnv",
|
||||
"ident_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"strsim",
|
||||
"syn 2.0.48",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_macro"
|
||||
version = "0.20.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "5.5.3"
|
||||
|
|
@ -3289,6 +3326,12 @@ dependencies = [
|
|||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ident_case"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "0.3.0"
|
||||
|
|
@ -3358,6 +3401,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
|
|||
dependencies = [
|
||||
"autocfg",
|
||||
"hashbrown 0.12.3",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -3368,6 +3412,7 @@ checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4"
|
|||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown 0.14.3",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -5603,6 +5648,36 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_with"
|
||||
version = "3.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ad483d2ab0149d5a5ebcd9972a3852711e0153d863bf5a5d0391d28883c4a20"
|
||||
dependencies = [
|
||||
"base64 0.22.0",
|
||||
"chrono",
|
||||
"hex",
|
||||
"indexmap 1.9.3",
|
||||
"indexmap 2.2.5",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"serde_with_macros",
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_with_macros"
|
||||
version = "3.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "65569b702f41443e8bc8bbb1c5779bd0450bbe723b56198980e80ec45780bce2"
|
||||
dependencies = [
|
||||
"darling",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serial_test"
|
||||
version = "3.0.0"
|
||||
|
|
|
|||
|
|
@ -15,8 +15,10 @@ tracing = { version = "0.1", optional = true }
|
|||
serde_repr = { version = "0.1", optional = true }
|
||||
futures = "0.3.30"
|
||||
bytes = "1.6.0"
|
||||
serde_with = "3.8.1"
|
||||
|
||||
[dev-dependencies]
|
||||
appflowy-ai-client = { path = ".", features = ["dto", "client-api"] }
|
||||
tokio = { version = "1.37.0", features = ["macros", "test-util"] }
|
||||
tracing-subscriber = { version = "0.3.18", features = ["registry", "env-filter", "ansi", "json"] }
|
||||
uuid = { version = "1.6", features = ["v4"] }
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use crate::dto::{
|
||||
ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, Document, MessageData,
|
||||
RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse, TranslateRowData,
|
||||
TranslateRowResponse,
|
||||
ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, Document, EmbeddingRequest,
|
||||
EmbeddingResponse, MessageData, RepeatedRelatedQuestion, SearchDocumentsRequest,
|
||||
SummarizeRowResponse, TranslateRowData, TranslateRowResponse,
|
||||
};
|
||||
use crate::error::AIError;
|
||||
|
||||
|
|
@ -97,6 +97,18 @@ impl AppFlowyAIClient {
|
|||
.into_data()
|
||||
}
|
||||
|
||||
pub async fn embeddings(&self, params: EmbeddingRequest) -> Result<EmbeddingResponse, AIError> {
|
||||
let url = format!("{}/embeddings", self.url);
|
||||
let resp = self
|
||||
.http_client(Method::POST, &url)?
|
||||
.json(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
AIResponse::<EmbeddingResponse>::from_response(resp)
|
||||
.await?
|
||||
.into_data()
|
||||
}
|
||||
|
||||
pub async fn index_documents(&self, documents: &[Document]) -> Result<(), AIError> {
|
||||
let url = format!("{}/index_documents", self.url);
|
||||
let resp = self
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use serde::{Deserialize, Serialize, Serializer};
|
||||
use serde_repr::{Deserialize_repr, Serialize_repr};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::{Display, Formatter};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct SummarizeRowResponse {
|
||||
|
|
@ -114,3 +115,80 @@ pub struct TranslateItem {
|
|||
pub struct TranslateRowResponse {
|
||||
pub items: Vec<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum EmbeddingInput {
|
||||
/// The string that will be turned into an embedding.
|
||||
String(String),
|
||||
/// The array of strings that will be turned into an embedding.
|
||||
StringArray(Vec<String>),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum EmbeddingOutput {
|
||||
Float(Vec<f64>),
|
||||
Base64(String),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Embedding {
|
||||
/// An integer representing the index of the embedding in the list of embeddings.
|
||||
pub index: i32,
|
||||
/// The embedding value, which is an instance of `EmbeddingOutput`.
|
||||
pub embedding: EmbeddingOutput,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct EmbeddingResponse {
|
||||
/// A string that is always set to "embedding".
|
||||
pub object: String,
|
||||
/// A list of `Embedding` objects.
|
||||
pub data: Vec<Embedding>,
|
||||
/// A string representing the model used to generate the embeddings.
|
||||
pub model: String,
|
||||
/// An integer representing the total number of tokens used.
|
||||
pub total_tokens: i32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum EmbeddingEncodingFormat {
|
||||
Float,
|
||||
Base64,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct EmbeddingRequest {
|
||||
/// An instance of `EmbeddingInput` containing the data to be embedded.
|
||||
pub input: EmbeddingInput,
|
||||
/// A string representing the model to use for generating embeddings.
|
||||
pub model: String,
|
||||
/// An integer representing the chunk size for processing.
|
||||
pub chunk_size: i32,
|
||||
/// An instance of `EmbeddingEncodingFormat` representing the format of the embedding.
|
||||
pub encoding_format: EmbeddingEncodingFormat,
|
||||
/// An integer representing the number of dimensions for the embedding.
|
||||
pub dimensions: i32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub enum EmbeddingsModel {
|
||||
#[serde(rename = "text-embedding-3-small")]
|
||||
TextEmbedding3Small,
|
||||
#[serde(rename = "text-embedding-3-large")]
|
||||
TextEmbedding3Large,
|
||||
#[serde(rename = "text-embedding-ada-002")]
|
||||
TextEmbeddingAda002,
|
||||
}
|
||||
|
||||
impl Display for EmbeddingsModel {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
EmbeddingsModel::TextEmbedding3Small => write!(f, "text-embedding-3-small"),
|
||||
EmbeddingsModel::TextEmbedding3Large => write!(f, "text-embedding-3-large"),
|
||||
EmbeddingsModel::TextEmbeddingAda002 => write!(f, "text-embedding-ada-002"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
use crate::appflowy_ai_client;
|
||||
|
||||
use appflowy_ai_client::dto::{
|
||||
EmbeddingEncodingFormat, EmbeddingInput, EmbeddingRequest, EmbeddingsModel,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn embedding_test() {
|
||||
let client = appflowy_ai_client();
|
||||
let request = EmbeddingRequest {
|
||||
input: EmbeddingInput::String("hello world".to_string()),
|
||||
model: EmbeddingsModel::TextEmbedding3Small.to_string(),
|
||||
chunk_size: 1000,
|
||||
encoding_format: EmbeddingEncodingFormat::Float,
|
||||
dimensions: 1536,
|
||||
};
|
||||
let result = client.embeddings(request).await.unwrap();
|
||||
assert!(result.total_tokens > 0);
|
||||
assert!(result.data.len() > 0);
|
||||
}
|
||||
|
|
@ -1,2 +1,3 @@
|
|||
mod completion_test;
|
||||
mod embedding_test;
|
||||
mod qa_test;
|
||||
|
|
|
|||
Loading…
Reference in New Issue