|
|
|
|
@ -1,8 +1,12 @@
|
|
|
|
|
use crate::error::EntityError;
|
|
|
|
|
use crate::error::EntityError::{DeserializationError, InvalidData};
|
|
|
|
|
use crate::util::{validate_not_empty_payload, validate_not_empty_str};
|
|
|
|
|
use appflowy_ai_client::dto::AIModel;
|
|
|
|
|
use bytes::Bytes;
|
|
|
|
|
use chrono::{DateTime, Utc};
|
|
|
|
|
use collab_entity::proto;
|
|
|
|
|
use collab_entity::CollabType;
|
|
|
|
|
use prost::Message;
|
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
|
use serde_repr::{Deserialize_repr, Serialize_repr};
|
|
|
|
|
use std::cmp::Ordering;
|
|
|
|
|
@ -62,7 +66,7 @@ impl CreateCollabParams {
|
|
|
|
|
|
|
|
|
|
pub struct CollabIndexParams {}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
|
|
|
|
|
#[derive(Debug, Clone, Validate, Serialize, Deserialize, PartialEq)]
|
|
|
|
|
pub struct CollabParams {
|
|
|
|
|
#[validate(custom = "validate_not_empty_str")]
|
|
|
|
|
pub object_id: String,
|
|
|
|
|
@ -107,7 +111,50 @@ impl CollabParams {
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn to_proto(&self) -> proto::collab::CollabParams {
|
|
|
|
|
proto::collab::CollabParams {
|
|
|
|
|
object_id: self.object_id.clone(),
|
|
|
|
|
encoded_collab: self.encoded_collab_v1.to_vec(),
|
|
|
|
|
collab_type: self.collab_type.to_proto() as i32,
|
|
|
|
|
embeddings: self
|
|
|
|
|
.embeddings
|
|
|
|
|
.as_ref()
|
|
|
|
|
.map(|embeddings| embeddings.to_proto()),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn to_protobuf_bytes(&self) -> Vec<u8> {
|
|
|
|
|
self.to_proto().encode_to_vec()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn from_protobuf_bytes(bytes: &[u8]) -> Result<Self, EntityError> {
|
|
|
|
|
match proto::collab::CollabParams::decode(bytes) {
|
|
|
|
|
Ok(proto) => Self::try_from(proto),
|
|
|
|
|
Err(err) => Err(DeserializationError(err.to_string())),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl TryFrom<proto::collab::CollabParams> for CollabParams {
|
|
|
|
|
type Error = EntityError;
|
|
|
|
|
|
|
|
|
|
fn try_from(proto: proto::collab::CollabParams) -> Result<Self, Self::Error> {
|
|
|
|
|
let collab_type_proto = proto::collab::CollabType::try_from(proto.collab_type).unwrap();
|
|
|
|
|
let collab_type = CollabType::from_proto(&collab_type_proto);
|
|
|
|
|
let embeddings = proto
|
|
|
|
|
.embeddings
|
|
|
|
|
.map(AFCollabEmbeddings::from_proto)
|
|
|
|
|
.transpose()?;
|
|
|
|
|
Ok(Self {
|
|
|
|
|
object_id: proto.object_id,
|
|
|
|
|
encoded_collab_v1: Bytes::from(proto.encoded_collab),
|
|
|
|
|
collab_type,
|
|
|
|
|
embeddings,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Serialize, Deserialize)]
|
|
|
|
|
struct CollabParamsV0 {
|
|
|
|
|
object_id: String,
|
|
|
|
|
@ -917,12 +964,72 @@ pub struct AFCollabEmbeddingParams {
|
|
|
|
|
pub embedding: Option<Vec<f32>>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl AFCollabEmbeddingParams {
|
|
|
|
|
pub fn from_proto(proto: &proto::collab::CollabEmbeddingsParams) -> Result<Self, EntityError> {
|
|
|
|
|
let collab_type_proto = proto::collab::CollabType::try_from(proto.collab_type).unwrap();
|
|
|
|
|
let collab_type = CollabType::from_proto(&collab_type_proto);
|
|
|
|
|
let content_type_proto =
|
|
|
|
|
proto::collab::EmbeddingContentType::try_from(proto.content_type).unwrap();
|
|
|
|
|
let content_type = EmbeddingContentType::from_proto(content_type_proto)?;
|
|
|
|
|
let embedding = if proto.embedding.is_empty() {
|
|
|
|
|
None
|
|
|
|
|
} else {
|
|
|
|
|
Some(proto.embedding.clone())
|
|
|
|
|
};
|
|
|
|
|
Ok(Self {
|
|
|
|
|
fragment_id: proto.fragment_id.clone(),
|
|
|
|
|
object_id: proto.object_id.clone(),
|
|
|
|
|
collab_type,
|
|
|
|
|
content_type,
|
|
|
|
|
content: proto.content.clone(),
|
|
|
|
|
embedding,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn to_proto(&self) -> proto::collab::CollabEmbeddingsParams {
|
|
|
|
|
proto::collab::CollabEmbeddingsParams {
|
|
|
|
|
fragment_id: self.fragment_id.clone(),
|
|
|
|
|
object_id: self.object_id.clone(),
|
|
|
|
|
collab_type: self.collab_type.to_proto() as i32,
|
|
|
|
|
content_type: self.content_type.to_proto() as i32,
|
|
|
|
|
content: self.content.clone(),
|
|
|
|
|
embedding: self.embedding.clone().unwrap_or_default(),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn to_protobuf_bytes(&self) -> Vec<u8> {
|
|
|
|
|
self.to_proto().encode_to_vec()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
|
|
|
|
pub struct AFCollabEmbeddings {
|
|
|
|
|
pub tokens_consumed: u32,
|
|
|
|
|
pub params: Vec<AFCollabEmbeddingParams>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl AFCollabEmbeddings {
|
|
|
|
|
pub fn from_proto(proto: proto::collab::CollabEmbeddings) -> Result<Self, EntityError> {
|
|
|
|
|
let mut params = vec![];
|
|
|
|
|
for param in proto.embeddings {
|
|
|
|
|
params.push(AFCollabEmbeddingParams::from_proto(¶m)?);
|
|
|
|
|
}
|
|
|
|
|
Ok(Self {
|
|
|
|
|
tokens_consumed: proto.tokens_consumed,
|
|
|
|
|
params,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn to_proto(&self) -> proto::collab::CollabEmbeddings {
|
|
|
|
|
let embeddings: Vec<proto::collab::CollabEmbeddingsParams> =
|
|
|
|
|
self.params.iter().map(|param| param.to_proto()).collect();
|
|
|
|
|
proto::collab::CollabEmbeddings {
|
|
|
|
|
tokens_consumed: self.tokens_consumed,
|
|
|
|
|
embeddings,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Type of content stored by the embedding.
|
|
|
|
|
/// Currently only plain text of the document is supported.
|
|
|
|
|
/// In the future, we might support other kinds like i.e. PDF, images or image-extracted text.
|
|
|
|
|
@ -933,6 +1040,24 @@ pub enum EmbeddingContentType {
|
|
|
|
|
PlainText = 0,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl EmbeddingContentType {
|
|
|
|
|
pub fn from_proto(proto: proto::collab::EmbeddingContentType) -> Result<Self, EntityError> {
|
|
|
|
|
match proto {
|
|
|
|
|
proto::collab::EmbeddingContentType::PlainText => Ok(EmbeddingContentType::PlainText),
|
|
|
|
|
proto::collab::EmbeddingContentType::Unknown => Err(InvalidData(format!(
|
|
|
|
|
"{} is not a supported embedding type",
|
|
|
|
|
proto.as_str_name()
|
|
|
|
|
))),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn to_proto(&self) -> proto::collab::EmbeddingContentType {
|
|
|
|
|
match self {
|
|
|
|
|
EmbeddingContentType::PlainText => proto::collab::EmbeddingContentType::PlainText,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
|
|
pub struct UpdateChatMessageResponse {
|
|
|
|
|
pub answer: Option<ChatMessage>,
|
|
|
|
|
@ -1286,8 +1411,13 @@ pub struct ApproveAccessRequestParams {
|
|
|
|
|
|
|
|
|
|
#[cfg(test)]
|
|
|
|
|
mod test {
|
|
|
|
|
use crate::dto::{CollabParams, CollabParamsV0};
|
|
|
|
|
use collab_entity::CollabType;
|
|
|
|
|
use crate::dto::{
|
|
|
|
|
AFCollabEmbeddingParams, AFCollabEmbeddings, CollabParams, CollabParamsV0, EmbeddingContentType,
|
|
|
|
|
};
|
|
|
|
|
use crate::error::EntityError;
|
|
|
|
|
use bytes::Bytes;
|
|
|
|
|
use collab_entity::{proto, CollabType};
|
|
|
|
|
use prost::Message;
|
|
|
|
|
use uuid::Uuid;
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
@ -1357,4 +1487,77 @@ mod test {
|
|
|
|
|
assert_eq!(collab_params.collab_type, v0.collab_type);
|
|
|
|
|
assert_eq!(collab_params.encoded_collab_v1, v0.encoded_collab_v1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn deserialization_using_protobuf() {
|
|
|
|
|
let collab_params_with_embeddings = CollabParams {
|
|
|
|
|
object_id: "object_id".to_string(),
|
|
|
|
|
collab_type: CollabType::Document,
|
|
|
|
|
encoded_collab_v1: Bytes::default(),
|
|
|
|
|
embeddings: Some(AFCollabEmbeddings {
|
|
|
|
|
tokens_consumed: 100,
|
|
|
|
|
params: vec![AFCollabEmbeddingParams {
|
|
|
|
|
fragment_id: "fragment_id".to_string(),
|
|
|
|
|
object_id: "object_id".to_string(),
|
|
|
|
|
collab_type: CollabType::Document,
|
|
|
|
|
content_type: EmbeddingContentType::PlainText,
|
|
|
|
|
content: "content".to_string(),
|
|
|
|
|
embedding: Some(vec![1.0, 2.0, 3.0]),
|
|
|
|
|
}],
|
|
|
|
|
}),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let protobuf_encoded = collab_params_with_embeddings.to_protobuf_bytes();
|
|
|
|
|
let collab_params_decoded = CollabParams::from_protobuf_bytes(&protobuf_encoded).unwrap();
|
|
|
|
|
assert_eq!(collab_params_with_embeddings, collab_params_decoded);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn deserialize_collab_params_without_embeddings() {
|
|
|
|
|
let collab_params = CollabParams {
|
|
|
|
|
object_id: "object_id".to_string(),
|
|
|
|
|
collab_type: CollabType::Document,
|
|
|
|
|
encoded_collab_v1: Bytes::from(vec![1, 2, 3]),
|
|
|
|
|
embeddings: Some(AFCollabEmbeddings {
|
|
|
|
|
tokens_consumed: 100,
|
|
|
|
|
params: vec![AFCollabEmbeddingParams {
|
|
|
|
|
fragment_id: "fragment_id".to_string(),
|
|
|
|
|
object_id: "object_id".to_string(),
|
|
|
|
|
collab_type: CollabType::Document,
|
|
|
|
|
content_type: EmbeddingContentType::PlainText,
|
|
|
|
|
content: "content".to_string(),
|
|
|
|
|
embedding: None,
|
|
|
|
|
}],
|
|
|
|
|
}),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let protobuf_encoded = collab_params.to_protobuf_bytes();
|
|
|
|
|
let collab_params_decoded = CollabParams::from_protobuf_bytes(&protobuf_encoded).unwrap();
|
|
|
|
|
assert_eq!(collab_params, collab_params_decoded);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn deserialize_collab_params_with_unknown_embedding_type() {
|
|
|
|
|
let invalid_serialization = proto::collab::CollabParams {
|
|
|
|
|
object_id: "object_id".to_string(),
|
|
|
|
|
encoded_collab: vec![1, 2, 3],
|
|
|
|
|
collab_type: proto::collab::CollabType::Document as i32,
|
|
|
|
|
embeddings: Some(proto::collab::CollabEmbeddings {
|
|
|
|
|
tokens_consumed: 100,
|
|
|
|
|
embeddings: vec![proto::collab::CollabEmbeddingsParams {
|
|
|
|
|
fragment_id: "fragment_id".to_string(),
|
|
|
|
|
object_id: "object_id".to_string(),
|
|
|
|
|
collab_type: proto::collab::CollabType::Document as i32,
|
|
|
|
|
content_type: proto::collab::EmbeddingContentType::Unknown as i32,
|
|
|
|
|
content: "content".to_string(),
|
|
|
|
|
embedding: vec![1.0, 2.0, 3.0],
|
|
|
|
|
}],
|
|
|
|
|
}),
|
|
|
|
|
}
|
|
|
|
|
.encode_to_vec();
|
|
|
|
|
|
|
|
|
|
let result = CollabParams::from_protobuf_bytes(&invalid_serialization);
|
|
|
|
|
assert!(result.is_err());
|
|
|
|
|
assert!(matches!(result, Err(EntityError::InvalidData(_))));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|