chore: update client api (#593)
This commit is contained in:
parent
edfcb5c1ea
commit
1cc5b58254
|
|
@ -1,6 +1,6 @@
|
|||
use crate::dto::{
|
||||
ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, Document, MessageData,
|
||||
SearchDocumentsRequest, SummarizeRowResponse, TranslateRowResponse,
|
||||
RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse, TranslateRowResponse,
|
||||
};
|
||||
use crate::error::AIError;
|
||||
use anyhow::anyhow;
|
||||
|
|
@ -161,6 +161,18 @@ impl AppFlowyAIClient {
|
|||
AIResponse::<String>::stream_response(resp).await
|
||||
}
|
||||
|
||||
pub async fn get_related_question(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
message_id: &i64,
|
||||
) -> Result<RepeatedRelatedQuestion, AIError> {
|
||||
let url = format!("{}/chat/{chat_id}/{message_id}/related_question", self.url);
|
||||
let resp = self.http_client(Method::GET, &url)?.send().await?;
|
||||
AIResponse::<RepeatedRelatedQuestion>::from_response(resp)
|
||||
.await?
|
||||
.into_data()
|
||||
}
|
||||
|
||||
fn http_client(&self, method: Method, url: &str) -> Result<RequestBuilder, AIError> {
|
||||
let request_builder = self.client.request(method, url);
|
||||
Ok(request_builder)
|
||||
|
|
|
|||
|
|
@ -27,6 +27,20 @@ pub struct ChatAnswer {
|
|||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct RepeatedRelatedQuestion {
|
||||
pub message_id: i64,
|
||||
pub items: Vec<RelatedQuestion>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct RelatedQuestion {
|
||||
pub content: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct CompleteTextResponse {
|
||||
pub text: String,
|
||||
|
|
|
|||
|
|
@ -11,9 +11,16 @@ async fn qa_test() {
|
|||
.await
|
||||
.unwrap();
|
||||
assert!(!resp.content.is_empty());
|
||||
|
||||
let questions = client
|
||||
.get_related_question(&chat_id, &1)
|
||||
.await
|
||||
.unwrap()
|
||||
.items;
|
||||
println!("questions: {:?}", questions);
|
||||
assert_eq!(questions.len(), 3)
|
||||
}
|
||||
#[tokio::test]
|
||||
|
||||
async fn stop_steam_test() {
|
||||
let client = appflowy_ai_client();
|
||||
client.health_check().await.unwrap();
|
||||
|
|
@ -35,11 +42,12 @@ async fn stop_steam_test() {
|
|||
assert_ne!(count, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn steam_test() {
|
||||
let client = appflowy_ai_client();
|
||||
client.health_check().await.unwrap();
|
||||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
let mut stream = client
|
||||
let stream = client
|
||||
.stream_question(&chat_id, "I feel hungry")
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use database_entity::dto::{
|
|||
};
|
||||
use futures_core::Stream;
|
||||
use reqwest::Method;
|
||||
use shared_entity::dto::ai_dto::RepeatedRelatedQuestion;
|
||||
use shared_entity::response::{AppResponse, AppResponseError};
|
||||
|
||||
impl Client {
|
||||
|
|
@ -57,6 +58,26 @@ impl Client {
|
|||
log_request_id(&resp);
|
||||
AppResponse::<ChatMessage>::stream_response(resp).await
|
||||
}
|
||||
pub async fn get_chat_related_question(
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
message_id: i64,
|
||||
) -> Result<RepeatedRelatedQuestion, AppResponseError> {
|
||||
let url = format!(
|
||||
"{}/api/chat/{workspace_id}/{chat_id}/{message_id}/related_question",
|
||||
self.base_url
|
||||
);
|
||||
let resp = self
|
||||
.http_client_with_auth(Method::GET, &url)
|
||||
.await?
|
||||
.send()
|
||||
.await?;
|
||||
log_request_id(&resp);
|
||||
AppResponse::<RepeatedRelatedQuestion>::from_response(resp)
|
||||
.await?
|
||||
.into_data()
|
||||
}
|
||||
|
||||
pub async fn get_chat_messages(
|
||||
&self,
|
||||
|
|
|
|||
|
|
@ -572,6 +572,12 @@ pub struct CreateChatMessageParams {
|
|||
pub message_type: ChatMessageType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UpdateChatMessageParams {
|
||||
pub message_id: i64,
|
||||
pub meta_data: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize_repr, Deserialize_repr)]
|
||||
#[repr(u8)]
|
||||
pub enum ChatMessageType {
|
||||
|
|
|
|||
|
|
@ -5,12 +5,12 @@ use app_error::AppError;
|
|||
use chrono::{DateTime, Utc};
|
||||
use database_entity::dto::{
|
||||
ChatAuthor, ChatMessage, CreateChatParams, GetChatMessageParams, MessageCursor,
|
||||
RepeatedChatMessage, UpdateChatParams,
|
||||
RepeatedChatMessage, UpdateChatMessageParams, UpdateChatParams,
|
||||
};
|
||||
|
||||
use serde_json::json;
|
||||
use sqlx::postgres::PgArguments;
|
||||
use sqlx::{Arguments, Executor, Postgres, Transaction};
|
||||
use sqlx::{Arguments, Executor, PgPool, Postgres, Transaction};
|
||||
use std::ops::DerefMut;
|
||||
use std::str::FromStr;
|
||||
use tracing::warn;
|
||||
|
|
@ -167,7 +167,7 @@ pub async fn select_chat_messages(
|
|||
) -> Result<RepeatedChatMessage, AppError> {
|
||||
let chat_id = Uuid::from_str(chat_id)?;
|
||||
let mut query = r#"
|
||||
SELECT message_id, content, created_at, author
|
||||
SELECT message_id, content, created_at, author, meta_data
|
||||
FROM af_chat_messages
|
||||
WHERE chat_id = $1
|
||||
"#
|
||||
|
|
@ -326,3 +326,30 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>(
|
|||
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
pub async fn update_chat_message(
|
||||
pg_pool: &PgPool,
|
||||
params: UpdateChatMessageParams,
|
||||
) -> Result<(), AppError> {
|
||||
for (key, value) in params.meta_data.iter() {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE af_chat_messages
|
||||
SET meta_data = jsonb_set(
|
||||
COALESCE(meta_data, '{}'),
|
||||
$2,
|
||||
$3::jsonb,
|
||||
true
|
||||
)
|
||||
WHERE id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(params.message_id)
|
||||
.bind(format!("{{{}}}", key))
|
||||
.bind(value)
|
||||
.execute(pg_pool)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -77,14 +77,14 @@ impl CollabHandle {
|
|||
Self::handle_collab_updates(&mut update_stream, content.get_collab(), messages).await?;
|
||||
}
|
||||
let workspace_id =
|
||||
Uuid::parse_str(&workspace_id).map_err(|e| crate::error::Error::InvalidWorkspace(e))?;
|
||||
Uuid::parse_str(&workspace_id).map_err(crate::error::Error::InvalidWorkspace)?;
|
||||
|
||||
let mut tasks = JoinSet::new();
|
||||
tasks.spawn(Self::receive_collab_updates(
|
||||
update_stream,
|
||||
Arc::downgrade(&content),
|
||||
object_id.clone(),
|
||||
workspace_id.clone(),
|
||||
workspace_id,
|
||||
ingest_interval,
|
||||
closing.clone(),
|
||||
));
|
||||
|
|
@ -253,8 +253,8 @@ impl CollabHandle {
|
|||
}
|
||||
|
||||
pub async fn shutdown(mut self) {
|
||||
let _ = self.closing.cancel();
|
||||
while let Some(_) = self.tasks.join_next().await { /* wait for all tasks to finish */ }
|
||||
self.closing.cancel();
|
||||
while self.tasks.join_next().await.is_some() { /* wait for all tasks to finish */ }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -353,7 +353,7 @@ mod test {
|
|||
|
||||
let tokens: i64 =
|
||||
sqlx::query("SELECT index_token_usage from af_workspace WHERE workspace_id = $1")
|
||||
.bind(&workspace_id)
|
||||
.bind(workspace_id)
|
||||
.fetch_one(&db)
|
||||
.await
|
||||
.unwrap()
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ impl From<Fragment> for EmbedFragment {
|
|||
impl From<EmbedFragment> for AFCollabEmbeddingParams {
|
||||
fn from(f: EmbedFragment) -> Self {
|
||||
AFCollabEmbeddingParams {
|
||||
fragment_id: f.fragment_id.into(),
|
||||
fragment_id: f.fragment_id,
|
||||
object_id: f.object_id,
|
||||
collab_type: f.collab_type,
|
||||
content_type: f.content_type,
|
||||
|
|
@ -91,7 +91,7 @@ impl PostgresIndexer {
|
|||
#[allow(dead_code)]
|
||||
pub async fn open(openai_api_key: &str, pg_conn: &str) -> Result<Self> {
|
||||
let openai = Client::new(openai_api_key.to_string());
|
||||
let db = PgPool::connect(&pg_conn).await?;
|
||||
let db = PgPool::connect(pg_conn).await?;
|
||||
Ok(Self { openai, db })
|
||||
}
|
||||
|
||||
|
|
@ -226,7 +226,7 @@ mod test {
|
|||
|
||||
// resolve embeddings from OpenAI
|
||||
let embeddings = indexer.get_embeddings(fragments).await.unwrap();
|
||||
assert_eq!(embeddings.fragments[0].embedding.is_some(), true);
|
||||
assert!(embeddings.fragments[0].embedding.is_some());
|
||||
|
||||
// store embeddings in DB
|
||||
indexer
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ pub async fn setup_collab(db: &PgPool, uid: i64, object_id: Uuid, encoded_collab
|
|||
let mut tx = db.begin().await.unwrap();
|
||||
let user_uuid = Uuid::new_v4();
|
||||
sqlx::query("INSERT INTO auth.users(id) VALUES($1)")
|
||||
.bind(&user_uuid)
|
||||
.bind(user_uuid)
|
||||
.execute(tx.deref_mut())
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -48,7 +48,7 @@ pub async fn setup_collab(db: &PgPool, uid: i64, object_id: Uuid, encoded_collab
|
|||
&mut tx,
|
||||
&uid,
|
||||
&workspace_id.to_string(),
|
||||
&CollabParams::new(object_id.clone(), CollabType::Document, encoded_collab),
|
||||
&CollabParams::new(object_id, CollabType::Document, encoded_collab),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ impl DocumentWatcher {
|
|||
if block.external_type.as_deref() == Some("text") {
|
||||
if let Some(text_id) = block.external_id.as_deref() {
|
||||
if let Some(json) = text_map.get(text_id) {
|
||||
match serde_json::from_str::<Vec<TextDelta>>(&json) {
|
||||
match serde_json::from_str::<Vec<TextDelta>>(json) {
|
||||
Ok(deltas) => {
|
||||
for delta in deltas {
|
||||
if let TextDelta::Inserted(text, _) = delta {
|
||||
|
|
@ -151,7 +151,7 @@ impl DocumentWatcher {
|
|||
|
||||
impl Indexable for DocumentWatcher {
|
||||
fn get_collab(&self) -> &MutexCollab {
|
||||
&*self.content.get_collab()
|
||||
self.content.get_collab()
|
||||
}
|
||||
|
||||
fn changes(&self) -> Pin<Box<dyn Stream<Item = FragmentUpdate> + Send + Sync>> {
|
||||
|
|
@ -253,9 +253,7 @@ mod test {
|
|||
let blocks = &data.blocks;
|
||||
let (_, block) = blocks
|
||||
.iter()
|
||||
//TODO: Block::external_type should probably be an enum
|
||||
.filter(|(_, b)| b.external_type.as_deref() == Some("text"))
|
||||
.next()
|
||||
.find(|(_, b)| b.external_type.as_deref() == Some("text"))
|
||||
.unwrap();
|
||||
block.clone()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,14 +3,17 @@ use crate::state::AppState;
|
|||
use actix_web::web::{Data, Json};
|
||||
use actix_web::{web, HttpResponse, Scope};
|
||||
use app_error::AppError;
|
||||
use appflowy_ai_client::dto::RepeatedRelatedQuestion;
|
||||
use authentication::jwt::UserUuid;
|
||||
use database::chat::chat_ops::update_chat_message;
|
||||
use database_entity::dto::{
|
||||
CreateChatMessageParams, CreateChatParams, GetChatMessageParams, MessageCursor,
|
||||
RepeatedChatMessage,
|
||||
RepeatedChatMessage, UpdateChatMessageParams,
|
||||
};
|
||||
use shared_entity::response::{AppResponse, JsonAppResponse};
|
||||
use std::collections::HashMap;
|
||||
use tracing::trace;
|
||||
|
||||
use tracing::{instrument, trace};
|
||||
use validator::Validate;
|
||||
|
||||
pub fn chat_scope() -> Scope {
|
||||
|
|
@ -22,7 +25,15 @@ pub fn chat_scope() -> Scope {
|
|||
.route(web::post().to(update_chat_handler))
|
||||
.route(web::get().to(get_chat_message_handler)),
|
||||
)
|
||||
.service(web::resource("/{chat_id}/message").route(web::post().to(post_chat_message_handler)))
|
||||
.service(
|
||||
web::resource("/{chat_id}/{message_id}/related_question")
|
||||
.route(web::get().to(get_related_message_handler)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/{chat_id}/message")
|
||||
.route(web::post().to(post_chat_message_handler))
|
||||
.route(web::put().to(update_chat_message_handler)),
|
||||
)
|
||||
}
|
||||
async fn create_chat_handler(
|
||||
path: web::Path<String>,
|
||||
|
|
@ -83,6 +94,29 @@ async fn post_chat_message_handler(
|
|||
)
|
||||
}
|
||||
|
||||
async fn update_chat_message_handler(
|
||||
state: Data<AppState>,
|
||||
payload: Json<UpdateChatMessageParams>,
|
||||
) -> actix_web::Result<JsonAppResponse<()>> {
|
||||
let params = payload.into_inner();
|
||||
update_chat_message(&state.pg_pool, params).await?;
|
||||
Ok(AppResponse::Ok().into())
|
||||
}
|
||||
|
||||
async fn get_related_message_handler(
|
||||
path: web::Path<(String, String, i64)>,
|
||||
state: Data<AppState>,
|
||||
) -> actix_web::Result<JsonAppResponse<RepeatedRelatedQuestion>> {
|
||||
let (_workspace_id, chat_id, message_id) = path.into_inner();
|
||||
let resp = state
|
||||
.ai_client
|
||||
.get_related_question(&chat_id, &message_id)
|
||||
.await
|
||||
.map_err(|err| AppError::Internal(err.into()))?;
|
||||
Ok(AppResponse::Ok().with_data(resp).into())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn get_chat_message_handler(
|
||||
path: web::Path<(String, String)>,
|
||||
query: web::Query<HashMap<String, String>>,
|
||||
|
|
|
|||
|
|
@ -117,4 +117,12 @@ async fn chat_qa_test() {
|
|||
|
||||
let messages: Vec<ChatMessage> = stream.map(|message| message.unwrap()).collect().await;
|
||||
assert_eq!(messages.len(), 2);
|
||||
|
||||
let related_questions = test_client
|
||||
.api_client
|
||||
.get_chat_related_question(&workspace_id, &chat_id, messages[1].message_id)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(related_questions.items.len(), 3);
|
||||
println!("related questions: {:?}", related_questions.items);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue