From 3b320b0619ade9a8622a3b7267978ce3740e7310 Mon Sep 17 00:00:00 2001 From: Khor Shu Heng <32997938+khorshuheng@users.noreply.github.com> Date: Wed, 2 Oct 2024 09:49:55 +0800 Subject: [PATCH] feat: support protobuf serialization and deserialization for CollabParams (#834) --- Cargo.lock | 1 + libs/database-entity/Cargo.toml | 1 + libs/database-entity/src/dto.rs | 209 +++++++++++++++++++++++++++++- libs/database-entity/src/error.rs | 9 ++ libs/database-entity/src/lib.rs | 1 + script/client_api_deps_check.sh | 2 +- 6 files changed, 219 insertions(+), 4 deletions(-) create mode 100644 libs/database-entity/src/error.rs diff --git a/Cargo.lock b/Cargo.lock index 82159a02..823cb7bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2798,6 +2798,7 @@ dependencies = [ "bytes", "chrono", "collab-entity", + "prost", "serde", "serde_json", "serde_repr", diff --git a/libs/database-entity/Cargo.toml b/libs/database-entity/Cargo.toml index 5206caf9..95bd7524 100644 --- a/libs/database-entity/Cargo.toml +++ b/libs/database-entity/Cargo.toml @@ -22,3 +22,4 @@ app-error = { workspace = true } bincode = "1.3.3" appflowy-ai-client = { workspace = true, features = ["dto"] } bytes.workspace = true +prost = "0.12" diff --git a/libs/database-entity/src/dto.rs b/libs/database-entity/src/dto.rs index 5c618f8a..646360ca 100644 --- a/libs/database-entity/src/dto.rs +++ b/libs/database-entity/src/dto.rs @@ -1,8 +1,12 @@ +use crate::error::EntityError; +use crate::error::EntityError::{DeserializationError, InvalidData}; use crate::util::{validate_not_empty_payload, validate_not_empty_str}; use appflowy_ai_client::dto::AIModel; use bytes::Bytes; use chrono::{DateTime, Utc}; +use collab_entity::proto; use collab_entity::CollabType; +use prost::Message; use serde::{Deserialize, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; use std::cmp::Ordering; @@ -62,7 +66,7 @@ impl CreateCollabParams { pub struct CollabIndexParams {} -#[derive(Debug, Clone, Validate, Serialize, Deserialize)] +#[derive(Debug, Clone, Validate, Serialize, Deserialize, PartialEq)] pub struct CollabParams { #[validate(custom = "validate_not_empty_str")] pub object_id: String, @@ -107,7 +111,50 @@ impl CollabParams { }, } } + + pub fn to_proto(&self) -> proto::collab::CollabParams { + proto::collab::CollabParams { + object_id: self.object_id.clone(), + encoded_collab: self.encoded_collab_v1.to_vec(), + collab_type: self.collab_type.to_proto() as i32, + embeddings: self + .embeddings + .as_ref() + .map(|embeddings| embeddings.to_proto()), + } + } + + pub fn to_protobuf_bytes(&self) -> Vec { + self.to_proto().encode_to_vec() + } + + pub fn from_protobuf_bytes(bytes: &[u8]) -> Result { + match proto::collab::CollabParams::decode(bytes) { + Ok(proto) => Self::try_from(proto), + Err(err) => Err(DeserializationError(err.to_string())), + } + } } + +impl TryFrom for CollabParams { + type Error = EntityError; + + fn try_from(proto: proto::collab::CollabParams) -> Result { + let collab_type_proto = proto::collab::CollabType::try_from(proto.collab_type).unwrap(); + let collab_type = CollabType::from_proto(&collab_type_proto); + let embeddings = proto + .embeddings + .map(AFCollabEmbeddings::from_proto) + .transpose()?; + Ok(Self { + object_id: proto.object_id, + encoded_collab_v1: Bytes::from(proto.encoded_collab), + collab_type, + embeddings, + }) + } +} + #[derive(Serialize, Deserialize)] struct CollabParamsV0 { object_id: String, @@ -917,12 +964,72 @@ pub struct AFCollabEmbeddingParams { pub embedding: Option>, } +impl AFCollabEmbeddingParams { + pub fn from_proto(proto: &proto::collab::CollabEmbeddingsParams) -> Result { + let collab_type_proto = proto::collab::CollabType::try_from(proto.collab_type).unwrap(); + let collab_type = CollabType::from_proto(&collab_type_proto); + let content_type_proto = + proto::collab::EmbeddingContentType::try_from(proto.content_type).unwrap(); + let content_type = EmbeddingContentType::from_proto(content_type_proto)?; + let embedding = if proto.embedding.is_empty() { + None + } else { + Some(proto.embedding.clone()) + }; + Ok(Self { + fragment_id: proto.fragment_id.clone(), + object_id: proto.object_id.clone(), + collab_type, + content_type, + content: proto.content.clone(), + embedding, + }) + } + + pub fn to_proto(&self) -> proto::collab::CollabEmbeddingsParams { + proto::collab::CollabEmbeddingsParams { + fragment_id: self.fragment_id.clone(), + object_id: self.object_id.clone(), + collab_type: self.collab_type.to_proto() as i32, + content_type: self.content_type.to_proto() as i32, + content: self.content.clone(), + embedding: self.embedding.clone().unwrap_or_default(), + } + } + + pub fn to_protobuf_bytes(&self) -> Vec { + self.to_proto().encode_to_vec() + } +} + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct AFCollabEmbeddings { pub tokens_consumed: u32, pub params: Vec, } +impl AFCollabEmbeddings { + pub fn from_proto(proto: proto::collab::CollabEmbeddings) -> Result { + let mut params = vec![]; + for param in proto.embeddings { + params.push(AFCollabEmbeddingParams::from_proto(¶m)?); + } + Ok(Self { + tokens_consumed: proto.tokens_consumed, + params, + }) + } + + pub fn to_proto(&self) -> proto::collab::CollabEmbeddings { + let embeddings: Vec = + self.params.iter().map(|param| param.to_proto()).collect(); + proto::collab::CollabEmbeddings { + tokens_consumed: self.tokens_consumed, + embeddings, + } + } +} + /// Type of content stored by the embedding. /// Currently only plain text of the document is supported. /// In the future, we might support other kinds like i.e. PDF, images or image-extracted text. @@ -933,6 +1040,24 @@ pub enum EmbeddingContentType { PlainText = 0, } +impl EmbeddingContentType { + pub fn from_proto(proto: proto::collab::EmbeddingContentType) -> Result { + match proto { + proto::collab::EmbeddingContentType::PlainText => Ok(EmbeddingContentType::PlainText), + proto::collab::EmbeddingContentType::Unknown => Err(InvalidData(format!( + "{} is not a supported embedding type", + proto.as_str_name() + ))), + } + } + + pub fn to_proto(&self) -> proto::collab::EmbeddingContentType { + match self { + EmbeddingContentType::PlainText => proto::collab::EmbeddingContentType::PlainText, + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UpdateChatMessageResponse { pub answer: Option, @@ -1286,8 +1411,13 @@ pub struct ApproveAccessRequestParams { #[cfg(test)] mod test { - use crate::dto::{CollabParams, CollabParamsV0}; - use collab_entity::CollabType; + use crate::dto::{ + AFCollabEmbeddingParams, AFCollabEmbeddings, CollabParams, CollabParamsV0, EmbeddingContentType, + }; + use crate::error::EntityError; + use bytes::Bytes; + use collab_entity::{proto, CollabType}; + use prost::Message; use uuid::Uuid; #[test] @@ -1357,4 +1487,77 @@ mod test { assert_eq!(collab_params.collab_type, v0.collab_type); assert_eq!(collab_params.encoded_collab_v1, v0.encoded_collab_v1); } + + #[test] + fn deserialization_using_protobuf() { + let collab_params_with_embeddings = CollabParams { + object_id: "object_id".to_string(), + collab_type: CollabType::Document, + encoded_collab_v1: Bytes::default(), + embeddings: Some(AFCollabEmbeddings { + tokens_consumed: 100, + params: vec![AFCollabEmbeddingParams { + fragment_id: "fragment_id".to_string(), + object_id: "object_id".to_string(), + collab_type: CollabType::Document, + content_type: EmbeddingContentType::PlainText, + content: "content".to_string(), + embedding: Some(vec![1.0, 2.0, 3.0]), + }], + }), + }; + + let protobuf_encoded = collab_params_with_embeddings.to_protobuf_bytes(); + let collab_params_decoded = CollabParams::from_protobuf_bytes(&protobuf_encoded).unwrap(); + assert_eq!(collab_params_with_embeddings, collab_params_decoded); + } + + #[test] + fn deserialize_collab_params_without_embeddings() { + let collab_params = CollabParams { + object_id: "object_id".to_string(), + collab_type: CollabType::Document, + encoded_collab_v1: Bytes::from(vec![1, 2, 3]), + embeddings: Some(AFCollabEmbeddings { + tokens_consumed: 100, + params: vec![AFCollabEmbeddingParams { + fragment_id: "fragment_id".to_string(), + object_id: "object_id".to_string(), + collab_type: CollabType::Document, + content_type: EmbeddingContentType::PlainText, + content: "content".to_string(), + embedding: None, + }], + }), + }; + + let protobuf_encoded = collab_params.to_protobuf_bytes(); + let collab_params_decoded = CollabParams::from_protobuf_bytes(&protobuf_encoded).unwrap(); + assert_eq!(collab_params, collab_params_decoded); + } + + #[test] + fn deserialize_collab_params_with_unknown_embedding_type() { + let invalid_serialization = proto::collab::CollabParams { + object_id: "object_id".to_string(), + encoded_collab: vec![1, 2, 3], + collab_type: proto::collab::CollabType::Document as i32, + embeddings: Some(proto::collab::CollabEmbeddings { + tokens_consumed: 100, + embeddings: vec![proto::collab::CollabEmbeddingsParams { + fragment_id: "fragment_id".to_string(), + object_id: "object_id".to_string(), + collab_type: proto::collab::CollabType::Document as i32, + content_type: proto::collab::EmbeddingContentType::Unknown as i32, + content: "content".to_string(), + embedding: vec![1.0, 2.0, 3.0], + }], + }), + } + .encode_to_vec(); + + let result = CollabParams::from_protobuf_bytes(&invalid_serialization); + assert!(result.is_err()); + assert!(matches!(result, Err(EntityError::InvalidData(_)))); + } } diff --git a/libs/database-entity/src/error.rs b/libs/database-entity/src/error.rs new file mode 100644 index 00000000..f3ae3b49 --- /dev/null +++ b/libs/database-entity/src/error.rs @@ -0,0 +1,9 @@ +#[derive(Debug, thiserror::Error)] +pub enum EntityError { + #[error("Invalid data: {0}")] + InvalidData(String), + #[error("Deserialization error: {0}")] + DeserializationError(String), + #[error("Serialization error: {0}")] + SerializationError(String), +} diff --git a/libs/database-entity/src/lib.rs b/libs/database-entity/src/lib.rs index 0d3d14ec..1c4fa4cc 100644 --- a/libs/database-entity/src/lib.rs +++ b/libs/database-entity/src/lib.rs @@ -1,3 +1,4 @@ pub mod dto; +pub mod error; pub mod file_dto; mod util; diff --git a/script/client_api_deps_check.sh b/script/client_api_deps_check.sh index 3e831065..6d296d7f 100755 --- a/script/client_api_deps_check.sh +++ b/script/client_api_deps_check.sh @@ -3,7 +3,7 @@ # Generate the current dependency list cargo tree > current_deps.txt -BASELINE_COUNT=620 +BASELINE_COUNT=621 CURRENT_COUNT=$(cat current_deps.txt | wc -l) echo "Expected dependency count (baseline): $BASELINE_COUNT"