chore: implement embedding api endpoint
This commit is contained in:
parent
a0376446d0
commit
5bc39729a7
|
|
@ -568,12 +568,14 @@ name = "appflowy-ai-client"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"appflowy-ai-client",
|
||||||
"bytes",
|
"bytes",
|
||||||
"futures",
|
"futures",
|
||||||
"reqwest 0.12.4",
|
"reqwest 0.12.4",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"serde_repr",
|
"serde_repr",
|
||||||
|
"serde_with",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
|
@ -2184,6 +2186,41 @@ dependencies = [
|
||||||
"syn 2.0.48",
|
"syn 2.0.48",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "darling"
|
||||||
|
version = "0.20.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "83b2eb4d90d12bdda5ed17de686c2acb4c57914f8f921b8da7e112b5a36f3fe1"
|
||||||
|
dependencies = [
|
||||||
|
"darling_core",
|
||||||
|
"darling_macro",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "darling_core"
|
||||||
|
version = "0.20.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "622687fe0bac72a04e5599029151f5796111b90f1baaa9b544d807a5e31cd120"
|
||||||
|
dependencies = [
|
||||||
|
"fnv",
|
||||||
|
"ident_case",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"strsim",
|
||||||
|
"syn 2.0.48",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "darling_macro"
|
||||||
|
version = "0.20.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178"
|
||||||
|
dependencies = [
|
||||||
|
"darling_core",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.48",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dashmap"
|
name = "dashmap"
|
||||||
version = "5.5.3"
|
version = "5.5.3"
|
||||||
|
|
@ -3289,6 +3326,12 @@ dependencies = [
|
||||||
"cc",
|
"cc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ident_case"
|
||||||
|
version = "1.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "idna"
|
name = "idna"
|
||||||
version = "0.3.0"
|
version = "0.3.0"
|
||||||
|
|
@ -3358,6 +3401,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"autocfg",
|
"autocfg",
|
||||||
"hashbrown 0.12.3",
|
"hashbrown 0.12.3",
|
||||||
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -3368,6 +3412,7 @@ checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"equivalent",
|
"equivalent",
|
||||||
"hashbrown 0.14.3",
|
"hashbrown 0.14.3",
|
||||||
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -5603,6 +5648,36 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_with"
|
||||||
|
version = "3.8.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0ad483d2ab0149d5a5ebcd9972a3852711e0153d863bf5a5d0391d28883c4a20"
|
||||||
|
dependencies = [
|
||||||
|
"base64 0.22.0",
|
||||||
|
"chrono",
|
||||||
|
"hex",
|
||||||
|
"indexmap 1.9.3",
|
||||||
|
"indexmap 2.2.5",
|
||||||
|
"serde",
|
||||||
|
"serde_derive",
|
||||||
|
"serde_json",
|
||||||
|
"serde_with_macros",
|
||||||
|
"time",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_with_macros"
|
||||||
|
version = "3.8.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "65569b702f41443e8bc8bbb1c5779bd0450bbe723b56198980e80ec45780bce2"
|
||||||
|
dependencies = [
|
||||||
|
"darling",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.48",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serial_test"
|
name = "serial_test"
|
||||||
version = "3.0.0"
|
version = "3.0.0"
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,10 @@ tracing = { version = "0.1", optional = true }
|
||||||
serde_repr = { version = "0.1", optional = true }
|
serde_repr = { version = "0.1", optional = true }
|
||||||
futures = "0.3.30"
|
futures = "0.3.30"
|
||||||
bytes = "1.6.0"
|
bytes = "1.6.0"
|
||||||
|
serde_with = "3.8.1"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
appflowy-ai-client = { path = ".", features = ["dto", "client-api"] }
|
||||||
tokio = { version = "1.37.0", features = ["macros", "test-util"] }
|
tokio = { version = "1.37.0", features = ["macros", "test-util"] }
|
||||||
tracing-subscriber = { version = "0.3.18", features = ["registry", "env-filter", "ansi", "json"] }
|
tracing-subscriber = { version = "0.3.18", features = ["registry", "env-filter", "ansi", "json"] }
|
||||||
uuid = { version = "1.6", features = ["v4"] }
|
uuid = { version = "1.6", features = ["v4"] }
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
use crate::dto::{
|
use crate::dto::{
|
||||||
ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, Document, MessageData,
|
ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, Document, EmbeddingRequest,
|
||||||
RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse, TranslateRowData,
|
EmbeddingResponse, MessageData, RepeatedRelatedQuestion, SearchDocumentsRequest,
|
||||||
TranslateRowResponse,
|
SummarizeRowResponse, TranslateRowData, TranslateRowResponse,
|
||||||
};
|
};
|
||||||
use crate::error::AIError;
|
use crate::error::AIError;
|
||||||
|
|
||||||
|
|
@ -97,6 +97,18 @@ impl AppFlowyAIClient {
|
||||||
.into_data()
|
.into_data()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn embeddings(&self, params: EmbeddingRequest) -> Result<EmbeddingResponse, AIError> {
|
||||||
|
let url = format!("{}/embeddings", self.url);
|
||||||
|
let resp = self
|
||||||
|
.http_client(Method::POST, &url)?
|
||||||
|
.json(¶ms)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
AIResponse::<EmbeddingResponse>::from_response(resp)
|
||||||
|
.await?
|
||||||
|
.into_data()
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn index_documents(&self, documents: &[Document]) -> Result<(), AIError> {
|
pub async fn index_documents(&self, documents: &[Document]) -> Result<(), AIError> {
|
||||||
let url = format!("{}/index_documents", self.url);
|
let url = format!("{}/index_documents", self.url);
|
||||||
let resp = self
|
let resp = self
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
use serde::{Deserialize, Serialize, Serializer};
|
use serde::{Deserialize, Serialize, Serializer};
|
||||||
use serde_repr::{Deserialize_repr, Serialize_repr};
|
use serde_repr::{Deserialize_repr, Serialize_repr};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::fmt::{Display, Formatter};
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
pub struct SummarizeRowResponse {
|
pub struct SummarizeRowResponse {
|
||||||
|
|
@ -114,3 +115,80 @@ pub struct TranslateItem {
|
||||||
pub struct TranslateRowResponse {
|
pub struct TranslateRowResponse {
|
||||||
pub items: Vec<HashMap<String, String>>,
|
pub items: Vec<HashMap<String, String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum EmbeddingInput {
|
||||||
|
/// The string that will be turned into an embedding.
|
||||||
|
String(String),
|
||||||
|
/// The array of strings that will be turned into an embedding.
|
||||||
|
StringArray(Vec<String>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum EmbeddingOutput {
|
||||||
|
Float(Vec<f64>),
|
||||||
|
Base64(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct Embedding {
|
||||||
|
/// An integer representing the index of the embedding in the list of embeddings.
|
||||||
|
pub index: i32,
|
||||||
|
/// The embedding value, which is an instance of `EmbeddingOutput`.
|
||||||
|
pub embedding: EmbeddingOutput,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct EmbeddingResponse {
|
||||||
|
/// A string that is always set to "embedding".
|
||||||
|
pub object: String,
|
||||||
|
/// A list of `Embedding` objects.
|
||||||
|
pub data: Vec<Embedding>,
|
||||||
|
/// A string representing the model used to generate the embeddings.
|
||||||
|
pub model: String,
|
||||||
|
/// An integer representing the total number of tokens used.
|
||||||
|
pub total_tokens: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum EmbeddingEncodingFormat {
|
||||||
|
Float,
|
||||||
|
Base64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct EmbeddingRequest {
|
||||||
|
/// An instance of `EmbeddingInput` containing the data to be embedded.
|
||||||
|
pub input: EmbeddingInput,
|
||||||
|
/// A string representing the model to use for generating embeddings.
|
||||||
|
pub model: String,
|
||||||
|
/// An integer representing the chunk size for processing.
|
||||||
|
pub chunk_size: i32,
|
||||||
|
/// An instance of `EmbeddingEncodingFormat` representing the format of the embedding.
|
||||||
|
pub encoding_format: EmbeddingEncodingFormat,
|
||||||
|
/// An integer representing the number of dimensions for the embedding.
|
||||||
|
pub dimensions: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||||
|
pub enum EmbeddingsModel {
|
||||||
|
#[serde(rename = "text-embedding-3-small")]
|
||||||
|
TextEmbedding3Small,
|
||||||
|
#[serde(rename = "text-embedding-3-large")]
|
||||||
|
TextEmbedding3Large,
|
||||||
|
#[serde(rename = "text-embedding-ada-002")]
|
||||||
|
TextEmbeddingAda002,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for EmbeddingsModel {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
EmbeddingsModel::TextEmbedding3Small => write!(f, "text-embedding-3-small"),
|
||||||
|
EmbeddingsModel::TextEmbedding3Large => write!(f, "text-embedding-3-large"),
|
||||||
|
EmbeddingsModel::TextEmbeddingAda002 => write!(f, "text-embedding-ada-002"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,20 @@
|
||||||
|
use crate::appflowy_ai_client;
|
||||||
|
|
||||||
|
use appflowy_ai_client::dto::{
|
||||||
|
EmbeddingEncodingFormat, EmbeddingInput, EmbeddingRequest, EmbeddingsModel,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn embedding_test() {
|
||||||
|
let client = appflowy_ai_client();
|
||||||
|
let request = EmbeddingRequest {
|
||||||
|
input: EmbeddingInput::String("hello world".to_string()),
|
||||||
|
model: EmbeddingsModel::TextEmbedding3Small.to_string(),
|
||||||
|
chunk_size: 1000,
|
||||||
|
encoding_format: EmbeddingEncodingFormat::Float,
|
||||||
|
dimensions: 1536,
|
||||||
|
};
|
||||||
|
let result = client.embeddings(request).await.unwrap();
|
||||||
|
assert!(result.total_tokens > 0);
|
||||||
|
assert!(result.data.len() > 0);
|
||||||
|
}
|
||||||
|
|
@ -1,2 +1,3 @@
|
||||||
mod completion_test;
|
mod completion_test;
|
||||||
|
mod embedding_test;
|
||||||
mod qa_test;
|
mod qa_test;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue