fix: chat stream (#620)

* chore: fix chat stream
This commit is contained in:
Nathan.fooo 2024-06-13 22:46:49 +08:00 committed by GitHub
parent 9d3d28ad89
commit 430e3e15c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 16 additions and 63 deletions

View File

@ -1,12 +1,13 @@
use crate::http::log_request_id;
use crate::Client;
use bytes::Bytes;
use database_entity::dto::{
ChatMessage, CreateAnswerMessageParams, CreateChatMessageParams, CreateChatParams, MessageCursor,
RepeatedChatMessage, UpdateChatMessageContentParams,
};
use futures_core::Stream;
use reqwest::Method;
use shared_entity::dto::ai_dto::{RepeatedRelatedQuestion, StringOrMessage};
use shared_entity::dto::ai_dto::RepeatedRelatedQuestion;
use shared_entity::response::{AppResponse, AppResponseError};
impl Client {
@ -109,7 +110,7 @@ impl Client {
workspace_id: &str,
chat_id: &str,
message_id: i64,
) -> Result<impl Stream<Item = Result<StringOrMessage, AppResponseError>>, AppResponseError> {
) -> Result<impl Stream<Item = Result<Bytes, AppResponseError>>, AppResponseError> {
let url = format!(
"{}/api/chat/{workspace_id}/{chat_id}/{message_id}/answer/stream",
self.base_url

View File

@ -44,7 +44,7 @@ where
}
pub async fn answer_response_stream(
resp: reqwest::Response,
) -> Result<impl Stream<Item = Result<StringOrMessage, AppResponseError>>, AppResponseError> {
) -> Result<impl Stream<Item = Result<Bytes, AppResponseError>>, AppResponseError> {
let status_code = resp.status();
if !status_code.is_success() {
let body = resp.text().await?;
@ -52,7 +52,7 @@ where
}
let stream = resp.bytes_stream().map_err(AppResponseError::from);
Ok(AnswerStream::new(stream))
Ok(stream)
}
}
@ -228,7 +228,6 @@ impl Stream for AnswerStream {
const NEW_LINE: &[u8; 1] = b"\n";
if bytes.ends_with(NEW_LINE) {
let bytes = &bytes[..bytes.len() - NEW_LINE.len()];
return match String::from_utf8(bytes.to_vec()) {
Ok(value) => Poll::Ready(Some(Ok(StringOrMessage::Left(value)))),
Err(err) => Poll::Ready(Some(Err(AppResponseError::from(err)))),

View File

@ -28,7 +28,7 @@ use tokio::task;
use database::chat;
use database::chat::chat_ops::insert_answer_message;
use tracing::{error, instrument, trace};
use tracing::{instrument, trace};
use validator::Validate;
pub fn chat_scope() -> Scope {
@ -52,7 +52,7 @@ pub fn chat_scope() -> Scope {
// Create a question for given chat
web::resource("/{chat_id}/message/question").route(web::post().to(create_question_handler)),
)
// create a answer for given chat
// create an answer for given chat
.service(web::resource("/{chat_id}/message/answer").route(web::post().to(create_answer_handler)))
.service(
// Generate answer for given question.
@ -200,40 +200,11 @@ async fn answer_stream_handler(
.await
.map_err(|err| AppError::Internal(err.into()))?;
let finish_action = move |collected_bytes: Vec<u8>| {
task::spawn(async move {
if let Ok(final_message) = String::from_utf8(collected_bytes) {
match chat::chat_ops::insert_answer_message(
&state.pg_pool,
ChatAuthor::ai(),
&chat_id,
final_message,
question_id,
)
.await
{
Ok(message) => {
let json_bytes = serde_json::to_vec(&message)?;
Ok(Bytes::from(json_bytes))
},
Err(err) => {
error!("Failed to insert answer message: {}", err);
Err(AppError::Internal(err.into()))
},
}
} else {
error!("Stream finished with invalid UTF-8 data.");
Err(AppError::InvalidRequest("Invalid UTF-8 data".to_string()))
}
})
};
let new_answer_stream = answer_stream.map_err(AppError::from);
let finish_answer_stream = CollectingStream::new(new_answer_stream, finish_action);
Ok(
HttpResponse::Ok()
.content_type("text/event-stream")
.streaming(finish_answer_stream),
.streaming(new_answer_stream),
)
}
@ -312,6 +283,7 @@ where
}
}
#[allow(dead_code)]
#[pin_project]
pub struct CollectingStream<S, F> {
#[pin]
@ -352,11 +324,7 @@ where
CollectingStreamType::AnswerString => {
match this.stream.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(bytes))) => {
if let Some(&b'\n') = bytes.last() {
this.buffer.extend_from_slice(&bytes[..bytes.len() - 1]);
} else {
this.buffer.extend_from_slice(&bytes);
}
this.buffer.extend_from_slice(&bytes);
Poll::Ready(Some(Ok(bytes)))
},
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
@ -385,9 +353,9 @@ where
CollectingStreamType::AnswerMessage => {
if let Some(receiver) = this.result_receiver.as_mut() {
match receiver.poll_unpin(cx) {
Poll::Ready(Ok(result)) => {
Poll::Ready(Ok(_)) => {
this.result_receiver.take();
Poll::Ready(Some(result))
Poll::Ready(None)
},
Poll::Ready(Err(_)) => Poll::Ready(None),
Poll::Pending => Poll::Pending,

View File

@ -1,7 +1,6 @@
use client_api_test::TestClient;
use database_entity::dto::{ChatMessage, CreateChatMessageParams, CreateChatParams, MessageCursor};
use futures_util::StreamExt;
use shared_entity::dto::ai_dto::StringOrMessage;
#[tokio::test]
async fn create_chat_and_create_messages_test() {
@ -199,27 +198,13 @@ async fn generate_stream_answer_test() {
.await
.unwrap();
let mut answer = None;
let mut answer = String::new();
while let Some(message) = answer_stream.next().await {
// println!("message: {:?}", message);
let message = message.unwrap();
match message {
StringOrMessage::Left(_) => {},
StringOrMessage::Right(message) => {
answer = Some(message);
},
}
let s = String::from_utf8(message.to_vec()).unwrap();
answer.push_str(&s);
}
assert!(answer.is_some());
let remote_messages = test_client
.api_client
.get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 2)
.await
.unwrap()
.messages;
assert_eq!(remote_messages.len(), 2);
assert!(!answer.is_empty());
}
// #[tokio::test]