diff --git a/libs/client-api/src/http_chat.rs b/libs/client-api/src/http_chat.rs index 326e5e78..dcd3013f 100644 --- a/libs/client-api/src/http_chat.rs +++ b/libs/client-api/src/http_chat.rs @@ -1,12 +1,13 @@ use crate::http::log_request_id; use crate::Client; +use bytes::Bytes; use database_entity::dto::{ ChatMessage, CreateAnswerMessageParams, CreateChatMessageParams, CreateChatParams, MessageCursor, RepeatedChatMessage, UpdateChatMessageContentParams, }; use futures_core::Stream; use reqwest::Method; -use shared_entity::dto::ai_dto::{RepeatedRelatedQuestion, StringOrMessage}; +use shared_entity::dto::ai_dto::RepeatedRelatedQuestion; use shared_entity::response::{AppResponse, AppResponseError}; impl Client { @@ -109,7 +110,7 @@ impl Client { workspace_id: &str, chat_id: &str, message_id: i64, - ) -> Result>, AppResponseError> { + ) -> Result>, AppResponseError> { let url = format!( "{}/api/chat/{workspace_id}/{chat_id}/{message_id}/answer/stream", self.base_url diff --git a/libs/shared-entity/src/response_stream.rs b/libs/shared-entity/src/response_stream.rs index 2fabeefe..1ab91138 100644 --- a/libs/shared-entity/src/response_stream.rs +++ b/libs/shared-entity/src/response_stream.rs @@ -44,7 +44,7 @@ where } pub async fn answer_response_stream( resp: reqwest::Response, - ) -> Result>, AppResponseError> { + ) -> Result>, AppResponseError> { let status_code = resp.status(); if !status_code.is_success() { let body = resp.text().await?; @@ -52,7 +52,7 @@ where } let stream = resp.bytes_stream().map_err(AppResponseError::from); - Ok(AnswerStream::new(stream)) + Ok(stream) } } @@ -228,7 +228,6 @@ impl Stream for AnswerStream { const NEW_LINE: &[u8; 1] = b"\n"; if bytes.ends_with(NEW_LINE) { let bytes = &bytes[..bytes.len() - NEW_LINE.len()]; - return match String::from_utf8(bytes.to_vec()) { Ok(value) => Poll::Ready(Some(Ok(StringOrMessage::Left(value)))), Err(err) => Poll::Ready(Some(Err(AppResponseError::from(err)))), diff --git a/src/api/chat.rs b/src/api/chat.rs index 81175541..02b17bf7 100644 --- a/src/api/chat.rs +++ b/src/api/chat.rs @@ -28,7 +28,7 @@ use tokio::task; use database::chat; use database::chat::chat_ops::insert_answer_message; -use tracing::{error, instrument, trace}; +use tracing::{instrument, trace}; use validator::Validate; pub fn chat_scope() -> Scope { @@ -52,7 +52,7 @@ pub fn chat_scope() -> Scope { // Create a question for given chat web::resource("/{chat_id}/message/question").route(web::post().to(create_question_handler)), ) - // create a answer for given chat + // create an answer for given chat .service(web::resource("/{chat_id}/message/answer").route(web::post().to(create_answer_handler))) .service( // Generate answer for given question. @@ -200,40 +200,11 @@ async fn answer_stream_handler( .await .map_err(|err| AppError::Internal(err.into()))?; - let finish_action = move |collected_bytes: Vec| { - task::spawn(async move { - if let Ok(final_message) = String::from_utf8(collected_bytes) { - match chat::chat_ops::insert_answer_message( - &state.pg_pool, - ChatAuthor::ai(), - &chat_id, - final_message, - question_id, - ) - .await - { - Ok(message) => { - let json_bytes = serde_json::to_vec(&message)?; - Ok(Bytes::from(json_bytes)) - }, - Err(err) => { - error!("Failed to insert answer message: {}", err); - Err(AppError::Internal(err.into())) - }, - } - } else { - error!("Stream finished with invalid UTF-8 data."); - Err(AppError::InvalidRequest("Invalid UTF-8 data".to_string())) - } - }) - }; - let new_answer_stream = answer_stream.map_err(AppError::from); - let finish_answer_stream = CollectingStream::new(new_answer_stream, finish_action); Ok( HttpResponse::Ok() .content_type("text/event-stream") - .streaming(finish_answer_stream), + .streaming(new_answer_stream), ) } @@ -312,6 +283,7 @@ where } } +#[allow(dead_code)] #[pin_project] pub struct CollectingStream { #[pin] @@ -352,11 +324,7 @@ where CollectingStreamType::AnswerString => { match this.stream.as_mut().poll_next(cx) { Poll::Ready(Some(Ok(bytes))) => { - if let Some(&b'\n') = bytes.last() { - this.buffer.extend_from_slice(&bytes[..bytes.len() - 1]); - } else { - this.buffer.extend_from_slice(&bytes); - } + this.buffer.extend_from_slice(&bytes); Poll::Ready(Some(Ok(bytes))) }, Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), @@ -385,9 +353,9 @@ where CollectingStreamType::AnswerMessage => { if let Some(receiver) = this.result_receiver.as_mut() { match receiver.poll_unpin(cx) { - Poll::Ready(Ok(result)) => { + Poll::Ready(Ok(_)) => { this.result_receiver.take(); - Poll::Ready(Some(result)) + Poll::Ready(None) }, Poll::Ready(Err(_)) => Poll::Ready(None), Poll::Pending => Poll::Pending, diff --git a/tests/ai_test/chat_test.rs b/tests/ai_test/chat_test.rs index 8446eb30..d659fd4a 100644 --- a/tests/ai_test/chat_test.rs +++ b/tests/ai_test/chat_test.rs @@ -1,7 +1,6 @@ use client_api_test::TestClient; use database_entity::dto::{ChatMessage, CreateChatMessageParams, CreateChatParams, MessageCursor}; use futures_util::StreamExt; -use shared_entity::dto::ai_dto::StringOrMessage; #[tokio::test] async fn create_chat_and_create_messages_test() { @@ -199,27 +198,13 @@ async fn generate_stream_answer_test() { .await .unwrap(); - let mut answer = None; + let mut answer = String::new(); while let Some(message) = answer_stream.next().await { - // println!("message: {:?}", message); let message = message.unwrap(); - match message { - StringOrMessage::Left(_) => {}, - StringOrMessage::Right(message) => { - answer = Some(message); - }, - } + let s = String::from_utf8(message.to_vec()).unwrap(); + answer.push_str(&s); } - assert!(answer.is_some()); - - let remote_messages = test_client - .api_client - .get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 2) - .await - .unwrap() - .messages; - - assert_eq!(remote_messages.len(), 2); + assert!(!answer.is_empty()); } // #[tokio::test]