diff --git a/Cargo.lock b/Cargo.lock index beb14515..577141eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4302,18 +4302,18 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.4" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0302c4a0442c456bd56f841aee5c3bfd17967563f6fadc9ceb9f9c23cf3807e0" +checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.4" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "266c042b60c9c76b8d53061e52b2e0d1116abc57cefc8c5cd671619a56ac3690" +checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", @@ -5680,11 +5680,13 @@ dependencies = [ "anyhow", "app-error", "appflowy-ai-client", + "bytes", "chrono", "collab-entity", "database-entity", "futures", "gotrue-entity", + "pin-project", "reqwest 0.11.27", "rust-s3", "serde", diff --git a/libs/client-api/src/lib.rs b/libs/client-api/src/lib.rs index 6c2a64d0..69b700f4 100644 --- a/libs/client-api/src/lib.rs +++ b/libs/client-api/src/lib.rs @@ -20,6 +20,8 @@ pub use native::*; mod wasm; #[cfg(target_arch = "wasm32")] pub use wasm::*; + +#[cfg(not(target_arch = "wasm32"))] mod http_chat; mod http_search; pub mod ws; diff --git a/libs/shared-entity/Cargo.toml b/libs/shared-entity/Cargo.toml index 27cfd819..d6e7136f 100644 --- a/libs/shared-entity/Cargo.toml +++ b/libs/shared-entity/Cargo.toml @@ -21,11 +21,13 @@ collab-entity = { workspace = true } app-error = { workspace = true } chrono = "0.4.31" appflowy-ai-client = { workspace = true, default-features = false, features = ["dto"] } +pin-project = "1.1.5" actix-web = { version = "4.4.1", default-features = false, features = ["http2"], optional = true } validator = { version = "0.16", features = ["validator_derive", "derive"], optional = true } rust-s3 = { version = "0.34.0-rc4", optional = true } futures = "0.3.30" +bytes = "1.6.0" [features] diff --git a/libs/shared-entity/src/json_stream.rs b/libs/shared-entity/src/json_stream.rs new file mode 100644 index 00000000..67c4cefc --- /dev/null +++ b/libs/shared-entity/src/json_stream.rs @@ -0,0 +1,84 @@ +use crate::response::{AppResponse, AppResponseError}; +use app_error::ErrorCode; +use futures::{Stream, TryStreamExt}; +use serde::de::DeserializeOwned; +use serde_json::StreamDeserializer; +use std::pin::Pin; +use std::task::{Context, Poll}; + +impl AppResponse +where + T: DeserializeOwned + 'static, +{ + pub async fn stream_response( + resp: reqwest::Response, + ) -> Result>, AppResponseError> { + let status_code = resp.status(); + if !status_code.is_success() { + let body = resp.text().await?; + return Err(AppResponseError::new(ErrorCode::Internal, body)); + } + + let stream = resp.bytes_stream().map_err(AppResponseError::from); + Ok(JsonStream::new(stream)) + } +} + +#[pin_project::pin_project] +pub struct JsonStream { + stream: Pin> + Send>>, + buffer: Vec, + _marker: std::marker::PhantomData, +} + +impl JsonStream { + pub fn new(stream: S) -> Self + where + S: Stream> + Send + 'static, + { + JsonStream { + stream: Box::pin(stream), + buffer: Vec::new(), + _marker: std::marker::PhantomData, + } + } +} + +impl Stream for JsonStream +where + T: DeserializeOwned, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + loop { + match futures::ready!(this.stream.as_mut().poll_next(cx)) { + Some(Ok(bytes)) => { + this.buffer.extend_from_slice(&bytes); + let de = StreamDeserializer::new(serde_json::de::SliceRead::new(this.buffer)); + let mut iter = de.into_iter(); + if let Some(result) = iter.next() { + match result { + Ok(value) => { + let remaining = iter.byte_offset(); + this.buffer.drain(0..remaining); + return Poll::Ready(Some(Ok(value))); + }, + Err(err) => { + if err.is_eof() { + continue; + } else { + return Poll::Ready(Some(Err(AppResponseError::from(err)))); + } + }, + } + } + }, + Some(Err(err)) => return Poll::Ready(Some(Err(err))), + None => return Poll::Ready(None), + } + } + } +} diff --git a/libs/shared-entity/src/lib.rs b/libs/shared-entity/src/lib.rs index 7302c0ea..2c74ef3b 100644 --- a/libs/shared-entity/src/lib.rs +++ b/libs/shared-entity/src/lib.rs @@ -1,6 +1,9 @@ pub mod response; pub mod dto; + +#[cfg(not(target_arch = "wasm32"))] +mod json_stream; mod request; #[cfg(feature = "cloud")] mod response_actix; diff --git a/libs/shared-entity/src/response.rs b/libs/shared-entity/src/response.rs index 9c10ac3d..99434e64 100644 --- a/libs/shared-entity/src/response.rs +++ b/libs/shared-entity/src/response.rs @@ -3,7 +3,7 @@ use std::borrow::Cow; use app_error::AppError; pub use app_error::ErrorCode; -use futures::stream::{Stream, StreamExt}; +use serde::de::DeserializeOwned; use std::fmt::{Debug, Display}; #[cfg(feature = "cloud")] @@ -129,7 +129,7 @@ where impl AppResponse where - T: serde::de::DeserializeOwned + 'static, + T: DeserializeOwned + 'static, { pub async fn from_response(resp: reqwest::Response) -> Result { let status_code = resp.status(); @@ -142,24 +142,8 @@ where let resp = serde_json::from_slice(&bytes)?; Ok(resp) } - - pub async fn stream_response( - resp: reqwest::Response, - ) -> Result>, AppResponseError> { - let status_code = resp.status(); - if !status_code.is_success() { - let body = resp.text().await?; - return Err(AppResponseError::new(ErrorCode::Internal, body)); - } - - let stream = resp.bytes_stream().map(|item| { - item.map_err(AppResponseError::from).and_then(|bytes| { - serde_json::from_slice::(bytes.as_ref()).map_err(AppResponseError::from) - }) - }); - Ok(stream) - } } + #[derive(Clone, Debug, Serialize, Deserialize, thiserror::Error)] pub struct AppResponseError { #[serde(deserialize_with = "default_error_code")]