// AppFlowy-Cloud/libs/infra/src/reqwest.rs
//
// 224 lines
// 6.6 KiB
// Rust

use anyhow::anyhow;
use anyhow::Error;
use bytes::{Bytes, BytesMut};
use futures::{ready, Stream};
use std::marker::PhantomData;
use std::pin::Pin;
use pin_project::pin_project;
use serde::de::DeserializeOwned;
use serde_json::de::SliceRead;
use serde_json::StreamDeserializer;
use std::error::Error as StdError;
use std::task::{Context, Poll};
pub async fn check_response(resp: reqwest::Response) -> Result<(), Error> {
let status_code = resp.status();
if !status_code.is_success() {
let body = resp.text().await?;
anyhow::bail!("got error code: {}, body: {}", status_code, body)
}
resp.bytes().await?;
Ok(())
}
pub async fn from_response<T>(resp: reqwest::Response) -> Result<T, Error>
where
T: serde::de::DeserializeOwned,
{
let status_code = resp.status();
if !status_code.is_success() {
let body = resp.text().await?;
anyhow::bail!("got error code: {}, body: {}", status_code, body)
}
from_body(resp).await
}
/// Reads the full response body and deserializes it as JSON into `T`.
///
/// The status code is not checked here; it is only included in the error message
/// to help debugging. Use [`from_response`] to reject non-2xx responses first.
///
/// # Errors
/// Fails if reading the body fails, or if the body is not valid JSON for `T`.
/// In the latter case, the message contains the serde error, then the status
/// code, then the body (decoded lossily as UTF-8).
pub async fn from_body<T>(resp: reqwest::Response) -> Result<T, Error>
where
    T: serde::de::DeserializeOwned,
{
    let status_code = resp.status();
    let bytes = resp.bytes().await?;
    serde_json::from_slice(&bytes).map_err(|e| {
        // Argument order must match the placeholders: error, then status, then body.
        anyhow!(
            "deserialize error: {}, status: {}, body: {}",
            e,
            status_code,
            String::from_utf8_lossy(&bytes)
        )
    })
}
/// Adapts a stream of raw byte chunks into a stream of JSON values of type `T`.
///
/// Incoming chunks are appended to `buffer` and decoded with serde_json's
/// `StreamDeserializer`, so a single value may span several chunks.
#[pin_project]
pub struct JsonStream<T, E, SE> {
    /// Underlying source of byte chunks.
    #[pin]
    stream: Pin<Box<dyn Stream<Item = Result<Bytes, E>> + Send>>,
    /// Bytes received but not yet consumed by a successful deserialization.
    buffer: Vec<u8>,
    /// Ties the output item type `T` to the struct without storing a value.
    _marker: PhantomData<T>,
    /// Records the `SE` type parameter; no value of it is stored.
    /// NOTE(review): `SE` only appears in the `E: From<SE>` bounds of this file's
    /// impls and looks unused by the decoding logic — confirm whether it is needed.
    _marker_error: PhantomData<SE>,
}
impl<T, E, SE> JsonStream<T, E, SE>
where
E: From<SE> + From<serde_json::Error> + std::error::Error + Send + Sync + 'static,
SE: std::error::Error + Send + Sync + 'static,
{
pub fn new<S>(stream: S) -> Self
where
S: Stream<Item = Result<Bytes, E>> + Send + 'static,
{
JsonStream {
stream: Box::pin(stream),
buffer: Vec::new(),
_marker: PhantomData,
_marker_error: PhantomData,
}
}
}
impl<T, E, SE> Stream for JsonStream<T, E, SE>
where
    T: DeserializeOwned,
    E: From<SE> + From<serde_json::Error> + std::error::Error + Send + Sync + 'static,
    SE: std::error::Error + Send + Sync + 'static,
{
    type Item = Result<T, E>;

    /// Yields the next JSON value of type `T`.
    ///
    /// Values that are already complete in the buffer are returned before the
    /// underlying stream is polled again. A chunk carrying several values therefore
    /// yields all of them without waiting for more data.
    ///
    /// When the buffer holds only part of a value, more bytes are requested from
    /// the underlying stream. `Pending` is returned only when that stream itself is
    /// pending, so a waker is always registered.
    ///
    /// Errors from the underlying stream are passed through unchanged. A non-EOF
    /// deserialization error is returned once and the buffered bytes are discarded,
    /// so the same malformed input is not reported over and over. An incomplete
    /// trailing value left when the underlying stream ends is dropped, and the
    /// stream ends.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        loop {
            // 1. Try to decode a value that is already fully buffered.
            if !this.buffer.is_empty() {
                let mut de = StreamDeserializer::new(SliceRead::new(this.buffer));
                match de.next() {
                    Some(Ok(value)) => {
                        // Remove exactly the bytes consumed by this value.
                        let offset = de.byte_offset();
                        this.buffer.drain(..offset);
                        return Poll::Ready(Some(Ok(value)));
                    },
                    Some(Err(err)) if err.is_eof() => {
                        // Partial value: fall through and read more bytes.
                    },
                    Some(Err(err)) => {
                        // Malformed input cannot be resynchronized; drop it so the
                        // error is not re-reported on every subsequent poll.
                        this.buffer.clear();
                        return Poll::Ready(Some(Err(err.into())));
                    },
                    None => {
                        // Only whitespace remains; nothing worth keeping.
                        this.buffer.clear();
                    },
                }
            }

            // 2. Need more data. `ready!` returns `Pending` here only after the
            //    underlying stream has registered the waker.
            match ready!(this.stream.as_mut().poll_next(cx)) {
                Some(Ok(bytes)) => this.buffer.extend_from_slice(&bytes),
                Some(Err(err)) => return Poll::Ready(Some(Err(err))),
                None => {
                    // Every complete value was already yielded in step 1; whatever
                    // is left is an incomplete value, which is discarded.
                    this.buffer.clear();
                    return Poll::Ready(None);
                },
            }
        }
    }
}
/// Represents a stream of text lines delimited by newlines.
///
/// Byte chunks from the wrapped stream are accumulated in `buffer` and split on
/// `b'\n'`. Each line is decoded as UTF-8 into a `String`, with the newline removed.
#[pin_project]
pub struct NewlineStream<E> {
    /// Underlying source of byte chunks.
    #[pin]
    stream: Pin<Box<dyn Stream<Item = Result<Bytes, E>> + Send>>,
    /// Bytes received but not yet emitted as a line.
    buffer: BytesMut,
    /// Marker for the error type parameter `E`.
    _marker: PhantomData<E>,
}
impl<E> NewlineStream<E> {
    /// Wraps `stream` for line-by-line reading, starting with an empty buffer.
    pub fn new<S>(stream: S) -> Self
    where
        S: Stream<Item = Result<Bytes, E>> + Send + 'static,
    {
        let buffer = BytesMut::new();
        Self {
            stream: Box::pin(stream),
            buffer,
            _marker: PhantomData,
        }
    }
}
impl<E> Stream for NewlineStream<E>
where
    E: StdError + Send + Sync + 'static + From<std::string::FromUtf8Error>,
{
    type Item = Result<String, E>;

    /// Yields the next line, without its trailing `\n`.
    ///
    /// Complete lines already in the buffer are emitted before the underlying
    /// stream is polled again, so a chunk containing several lines yields them one
    /// at a time. When the underlying stream ends, any remaining bytes are emitted
    /// as a final line that has no trailing newline.
    ///
    /// # Errors
    /// - Errors from the underlying stream are passed through.
    /// - A line that is not valid UTF-8 yields `E::from(FromUtf8Error)`. That
    ///   line's bytes are consumed either way, so the error is reported only once.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        loop {
            // Emit a complete line if one is already buffered.
            if let Some(pos) = this.buffer.iter().position(|&b| b == b'\n') {
                let line = this.buffer.split_to(pos + 1);
                // `pos` excludes the newline character itself.
                return Poll::Ready(Some(String::from_utf8(line[..pos].to_vec()).map_err(E::from)));
            }

            match ready!(this.stream.as_mut().poll_next(cx)) {
                Some(Ok(bytes)) => this.buffer.extend_from_slice(&bytes),
                Some(Err(err)) => return Poll::Ready(Some(Err(err))),
                None => {
                    if this.buffer.is_empty() {
                        return Poll::Ready(None);
                    }
                    // Take everything left as the final, unterminated line.
                    // NOTE(review): the next call polls the inner stream again after
                    // it returned `None`; this relies on that stream staying ended.
                    let rest = this.buffer.split();
                    return Poll::Ready(Some(String::from_utf8(rest.to_vec()).map_err(E::from)));
                },
            }
        }
    }
}