parent
9d3d28ad89
commit
430e3e15c9
|
|
@ -1,12 +1,13 @@
|
|||
use crate::http::log_request_id;
|
||||
use crate::Client;
|
||||
use bytes::Bytes;
|
||||
use database_entity::dto::{
|
||||
ChatMessage, CreateAnswerMessageParams, CreateChatMessageParams, CreateChatParams, MessageCursor,
|
||||
RepeatedChatMessage, UpdateChatMessageContentParams,
|
||||
};
|
||||
use futures_core::Stream;
|
||||
use reqwest::Method;
|
||||
use shared_entity::dto::ai_dto::{RepeatedRelatedQuestion, StringOrMessage};
|
||||
use shared_entity::dto::ai_dto::RepeatedRelatedQuestion;
|
||||
use shared_entity::response::{AppResponse, AppResponseError};
|
||||
|
||||
impl Client {
|
||||
|
|
@ -109,7 +110,7 @@ impl Client {
|
|||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
message_id: i64,
|
||||
) -> Result<impl Stream<Item = Result<StringOrMessage, AppResponseError>>, AppResponseError> {
|
||||
) -> Result<impl Stream<Item = Result<Bytes, AppResponseError>>, AppResponseError> {
|
||||
let url = format!(
|
||||
"{}/api/chat/{workspace_id}/{chat_id}/{message_id}/answer/stream",
|
||||
self.base_url
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ where
|
|||
}
|
||||
pub async fn answer_response_stream(
|
||||
resp: reqwest::Response,
|
||||
) -> Result<impl Stream<Item = Result<StringOrMessage, AppResponseError>>, AppResponseError> {
|
||||
) -> Result<impl Stream<Item = Result<Bytes, AppResponseError>>, AppResponseError> {
|
||||
let status_code = resp.status();
|
||||
if !status_code.is_success() {
|
||||
let body = resp.text().await?;
|
||||
|
|
@ -52,7 +52,7 @@ where
|
|||
}
|
||||
|
||||
let stream = resp.bytes_stream().map_err(AppResponseError::from);
|
||||
Ok(AnswerStream::new(stream))
|
||||
Ok(stream)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -228,7 +228,6 @@ impl Stream for AnswerStream {
|
|||
const NEW_LINE: &[u8; 1] = b"\n";
|
||||
if bytes.ends_with(NEW_LINE) {
|
||||
let bytes = &bytes[..bytes.len() - NEW_LINE.len()];
|
||||
|
||||
return match String::from_utf8(bytes.to_vec()) {
|
||||
Ok(value) => Poll::Ready(Some(Ok(StringOrMessage::Left(value)))),
|
||||
Err(err) => Poll::Ready(Some(Err(AppResponseError::from(err)))),
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ use tokio::task;
|
|||
use database::chat;
|
||||
|
||||
use database::chat::chat_ops::insert_answer_message;
|
||||
use tracing::{error, instrument, trace};
|
||||
use tracing::{instrument, trace};
|
||||
use validator::Validate;
|
||||
|
||||
pub fn chat_scope() -> Scope {
|
||||
|
|
@ -52,7 +52,7 @@ pub fn chat_scope() -> Scope {
|
|||
// Create a question for given chat
|
||||
web::resource("/{chat_id}/message/question").route(web::post().to(create_question_handler)),
|
||||
)
|
||||
// create a answer for given chat
|
||||
// create an answer for given chat
|
||||
.service(web::resource("/{chat_id}/message/answer").route(web::post().to(create_answer_handler)))
|
||||
.service(
|
||||
// Generate answer for given question.
|
||||
|
|
@ -200,40 +200,11 @@ async fn answer_stream_handler(
|
|||
.await
|
||||
.map_err(|err| AppError::Internal(err.into()))?;
|
||||
|
||||
let finish_action = move |collected_bytes: Vec<u8>| {
|
||||
task::spawn(async move {
|
||||
if let Ok(final_message) = String::from_utf8(collected_bytes) {
|
||||
match chat::chat_ops::insert_answer_message(
|
||||
&state.pg_pool,
|
||||
ChatAuthor::ai(),
|
||||
&chat_id,
|
||||
final_message,
|
||||
question_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(message) => {
|
||||
let json_bytes = serde_json::to_vec(&message)?;
|
||||
Ok(Bytes::from(json_bytes))
|
||||
},
|
||||
Err(err) => {
|
||||
error!("Failed to insert answer message: {}", err);
|
||||
Err(AppError::Internal(err.into()))
|
||||
},
|
||||
}
|
||||
} else {
|
||||
error!("Stream finished with invalid UTF-8 data.");
|
||||
Err(AppError::InvalidRequest("Invalid UTF-8 data".to_string()))
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
let new_answer_stream = answer_stream.map_err(AppError::from);
|
||||
let finish_answer_stream = CollectingStream::new(new_answer_stream, finish_action);
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("text/event-stream")
|
||||
.streaming(finish_answer_stream),
|
||||
.streaming(new_answer_stream),
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -312,6 +283,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[pin_project]
|
||||
pub struct CollectingStream<S, F> {
|
||||
#[pin]
|
||||
|
|
@ -352,11 +324,7 @@ where
|
|||
CollectingStreamType::AnswerString => {
|
||||
match this.stream.as_mut().poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(bytes))) => {
|
||||
if let Some(&b'\n') = bytes.last() {
|
||||
this.buffer.extend_from_slice(&bytes[..bytes.len() - 1]);
|
||||
} else {
|
||||
this.buffer.extend_from_slice(&bytes);
|
||||
}
|
||||
this.buffer.extend_from_slice(&bytes);
|
||||
Poll::Ready(Some(Ok(bytes)))
|
||||
},
|
||||
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
||||
|
|
@ -385,9 +353,9 @@ where
|
|||
CollectingStreamType::AnswerMessage => {
|
||||
if let Some(receiver) = this.result_receiver.as_mut() {
|
||||
match receiver.poll_unpin(cx) {
|
||||
Poll::Ready(Ok(result)) => {
|
||||
Poll::Ready(Ok(_)) => {
|
||||
this.result_receiver.take();
|
||||
Poll::Ready(Some(result))
|
||||
Poll::Ready(None)
|
||||
},
|
||||
Poll::Ready(Err(_)) => Poll::Ready(None),
|
||||
Poll::Pending => Poll::Pending,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
use client_api_test::TestClient;
|
||||
use database_entity::dto::{ChatMessage, CreateChatMessageParams, CreateChatParams, MessageCursor};
|
||||
use futures_util::StreamExt;
|
||||
use shared_entity::dto::ai_dto::StringOrMessage;
|
||||
|
||||
#[tokio::test]
|
||||
async fn create_chat_and_create_messages_test() {
|
||||
|
|
@ -199,27 +198,13 @@ async fn generate_stream_answer_test() {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut answer = None;
|
||||
let mut answer = String::new();
|
||||
while let Some(message) = answer_stream.next().await {
|
||||
// println!("message: {:?}", message);
|
||||
let message = message.unwrap();
|
||||
match message {
|
||||
StringOrMessage::Left(_) => {},
|
||||
StringOrMessage::Right(message) => {
|
||||
answer = Some(message);
|
||||
},
|
||||
}
|
||||
let s = String::from_utf8(message.to_vec()).unwrap();
|
||||
answer.push_str(&s);
|
||||
}
|
||||
assert!(answer.is_some());
|
||||
|
||||
let remote_messages = test_client
|
||||
.api_client
|
||||
.get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 2)
|
||||
.await
|
||||
.unwrap()
|
||||
.messages;
|
||||
|
||||
assert_eq!(remote_messages.len(), 2);
|
||||
assert!(!answer.is_empty());
|
||||
}
|
||||
|
||||
// #[tokio::test]
|
||||
|
|
|
|||
Loading…
Reference in New Issue