feat: support protobuf serialization and deserialization for CollabParams (#834)

This commit is contained in:
Khor Shu Heng 2024-10-02 09:49:55 +08:00 committed by GitHub
parent 96d7ae8b95
commit 3b320b0619
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 219 additions and 4 deletions

1
Cargo.lock generated
View File

@ -2798,6 +2798,7 @@ dependencies = [
"bytes",
"chrono",
"collab-entity",
"prost",
"serde",
"serde_json",
"serde_repr",

View File

@ -22,3 +22,4 @@ app-error = { workspace = true }
bincode = "1.3.3"
appflowy-ai-client = { workspace = true, features = ["dto"] }
bytes.workspace = true
prost = "0.12"

View File

@ -1,8 +1,12 @@
use crate::error::EntityError;
use crate::error::EntityError::{DeserializationError, InvalidData};
use crate::util::{validate_not_empty_payload, validate_not_empty_str};
use appflowy_ai_client::dto::AIModel;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use collab_entity::proto;
use collab_entity::CollabType;
use prost::Message;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use std::cmp::Ordering;
@ -62,7 +66,7 @@ impl CreateCollabParams {
pub struct CollabIndexParams {}
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
#[derive(Debug, Clone, Validate, Serialize, Deserialize, PartialEq)]
pub struct CollabParams {
#[validate(custom = "validate_not_empty_str")]
pub object_id: String,
@ -107,7 +111,50 @@ impl CollabParams {
},
}
}
pub fn to_proto(&self) -> proto::collab::CollabParams {
proto::collab::CollabParams {
object_id: self.object_id.clone(),
encoded_collab: self.encoded_collab_v1.to_vec(),
collab_type: self.collab_type.to_proto() as i32,
embeddings: self
.embeddings
.as_ref()
.map(|embeddings| embeddings.to_proto()),
}
}
pub fn to_protobuf_bytes(&self) -> Vec<u8> {
self.to_proto().encode_to_vec()
}
pub fn from_protobuf_bytes(bytes: &[u8]) -> Result<Self, EntityError> {
match proto::collab::CollabParams::decode(bytes) {
Ok(proto) => Self::try_from(proto),
Err(err) => Err(DeserializationError(err.to_string())),
}
}
}
impl TryFrom<proto::collab::CollabParams> for CollabParams {
type Error = EntityError;
fn try_from(proto: proto::collab::CollabParams) -> Result<Self, Self::Error> {
let collab_type_proto = proto::collab::CollabType::try_from(proto.collab_type).unwrap();
let collab_type = CollabType::from_proto(&collab_type_proto);
let embeddings = proto
.embeddings
.map(AFCollabEmbeddings::from_proto)
.transpose()?;
Ok(Self {
object_id: proto.object_id,
encoded_collab_v1: Bytes::from(proto.encoded_collab),
collab_type,
embeddings,
})
}
}
#[derive(Serialize, Deserialize)]
struct CollabParamsV0 {
object_id: String,
@ -917,12 +964,72 @@ pub struct AFCollabEmbeddingParams {
pub embedding: Option<Vec<f32>>,
}
impl AFCollabEmbeddingParams {
pub fn from_proto(proto: &proto::collab::CollabEmbeddingsParams) -> Result<Self, EntityError> {
let collab_type_proto = proto::collab::CollabType::try_from(proto.collab_type).unwrap();
let collab_type = CollabType::from_proto(&collab_type_proto);
let content_type_proto =
proto::collab::EmbeddingContentType::try_from(proto.content_type).unwrap();
let content_type = EmbeddingContentType::from_proto(content_type_proto)?;
let embedding = if proto.embedding.is_empty() {
None
} else {
Some(proto.embedding.clone())
};
Ok(Self {
fragment_id: proto.fragment_id.clone(),
object_id: proto.object_id.clone(),
collab_type,
content_type,
content: proto.content.clone(),
embedding,
})
}
pub fn to_proto(&self) -> proto::collab::CollabEmbeddingsParams {
proto::collab::CollabEmbeddingsParams {
fragment_id: self.fragment_id.clone(),
object_id: self.object_id.clone(),
collab_type: self.collab_type.to_proto() as i32,
content_type: self.content_type.to_proto() as i32,
content: self.content.clone(),
embedding: self.embedding.clone().unwrap_or_default(),
}
}
pub fn to_protobuf_bytes(&self) -> Vec<u8> {
self.to_proto().encode_to_vec()
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AFCollabEmbeddings {
pub tokens_consumed: u32,
pub params: Vec<AFCollabEmbeddingParams>,
}
impl AFCollabEmbeddings {
pub fn from_proto(proto: proto::collab::CollabEmbeddings) -> Result<Self, EntityError> {
let mut params = vec![];
for param in proto.embeddings {
params.push(AFCollabEmbeddingParams::from_proto(&param)?);
}
Ok(Self {
tokens_consumed: proto.tokens_consumed,
params,
})
}
pub fn to_proto(&self) -> proto::collab::CollabEmbeddings {
let embeddings: Vec<proto::collab::CollabEmbeddingsParams> =
self.params.iter().map(|param| param.to_proto()).collect();
proto::collab::CollabEmbeddings {
tokens_consumed: self.tokens_consumed,
embeddings,
}
}
}
/// Type of content stored by the embedding.
/// Currently only plain text of the document is supported.
/// In the future, we might support other kinds like i.e. PDF, images or image-extracted text.
@ -933,6 +1040,24 @@ pub enum EmbeddingContentType {
PlainText = 0,
}
impl EmbeddingContentType {
pub fn from_proto(proto: proto::collab::EmbeddingContentType) -> Result<Self, EntityError> {
match proto {
proto::collab::EmbeddingContentType::PlainText => Ok(EmbeddingContentType::PlainText),
proto::collab::EmbeddingContentType::Unknown => Err(InvalidData(format!(
"{} is not a supported embedding type",
proto.as_str_name()
))),
}
}
pub fn to_proto(&self) -> proto::collab::EmbeddingContentType {
match self {
EmbeddingContentType::PlainText => proto::collab::EmbeddingContentType::PlainText,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateChatMessageResponse {
pub answer: Option<ChatMessage>,
@ -1286,8 +1411,13 @@ pub struct ApproveAccessRequestParams {
#[cfg(test)]
mod test {
use crate::dto::{CollabParams, CollabParamsV0};
use collab_entity::CollabType;
use crate::dto::{
AFCollabEmbeddingParams, AFCollabEmbeddings, CollabParams, CollabParamsV0, EmbeddingContentType,
};
use crate::error::EntityError;
use bytes::Bytes;
use collab_entity::{proto, CollabType};
use prost::Message;
use uuid::Uuid;
#[test]
@ -1357,4 +1487,77 @@ mod test {
assert_eq!(collab_params.collab_type, v0.collab_type);
assert_eq!(collab_params.encoded_collab_v1, v0.encoded_collab_v1);
}
#[test]
fn deserialization_using_protobuf() {
let collab_params_with_embeddings = CollabParams {
object_id: "object_id".to_string(),
collab_type: CollabType::Document,
encoded_collab_v1: Bytes::default(),
embeddings: Some(AFCollabEmbeddings {
tokens_consumed: 100,
params: vec![AFCollabEmbeddingParams {
fragment_id: "fragment_id".to_string(),
object_id: "object_id".to_string(),
collab_type: CollabType::Document,
content_type: EmbeddingContentType::PlainText,
content: "content".to_string(),
embedding: Some(vec![1.0, 2.0, 3.0]),
}],
}),
};
let protobuf_encoded = collab_params_with_embeddings.to_protobuf_bytes();
let collab_params_decoded = CollabParams::from_protobuf_bytes(&protobuf_encoded).unwrap();
assert_eq!(collab_params_with_embeddings, collab_params_decoded);
}
#[test]
fn deserialize_collab_params_without_embeddings() {
let collab_params = CollabParams {
object_id: "object_id".to_string(),
collab_type: CollabType::Document,
encoded_collab_v1: Bytes::from(vec![1, 2, 3]),
embeddings: Some(AFCollabEmbeddings {
tokens_consumed: 100,
params: vec![AFCollabEmbeddingParams {
fragment_id: "fragment_id".to_string(),
object_id: "object_id".to_string(),
collab_type: CollabType::Document,
content_type: EmbeddingContentType::PlainText,
content: "content".to_string(),
embedding: None,
}],
}),
};
let protobuf_encoded = collab_params.to_protobuf_bytes();
let collab_params_decoded = CollabParams::from_protobuf_bytes(&protobuf_encoded).unwrap();
assert_eq!(collab_params, collab_params_decoded);
}
#[test]
fn deserialize_collab_params_with_unknown_embedding_type() {
let invalid_serialization = proto::collab::CollabParams {
object_id: "object_id".to_string(),
encoded_collab: vec![1, 2, 3],
collab_type: proto::collab::CollabType::Document as i32,
embeddings: Some(proto::collab::CollabEmbeddings {
tokens_consumed: 100,
embeddings: vec![proto::collab::CollabEmbeddingsParams {
fragment_id: "fragment_id".to_string(),
object_id: "object_id".to_string(),
collab_type: proto::collab::CollabType::Document as i32,
content_type: proto::collab::EmbeddingContentType::Unknown as i32,
content: "content".to_string(),
embedding: vec![1.0, 2.0, 3.0],
}],
}),
}
.encode_to_vec();
let result = CollabParams::from_protobuf_bytes(&invalid_serialization);
assert!(result.is_err());
assert!(matches!(result, Err(EntityError::InvalidData(_))));
}
}

View File

@ -0,0 +1,9 @@
#[derive(Debug, thiserror::Error)]
pub enum EntityError {
#[error("Invalid data: {0}")]
InvalidData(String),
#[error("Deserialization error: {0}")]
DeserializationError(String),
#[error("Serialization error: {0}")]
SerializationError(String),
}

View File

@ -1,3 +1,4 @@
pub mod dto;
pub mod error;
pub mod file_dto;
mod util;

View File

@ -3,7 +3,7 @@
# Generate the current dependency list
cargo tree > current_deps.txt
BASELINE_COUNT=620
BASELINE_COUNT=621
CURRENT_COUNT=$(cat current_deps.txt | wc -l)
echo "Expected dependency count (baseline): $BASELINE_COUNT"