chore: update client api (#593)

This commit is contained in:
Nathan.fooo 2024-06-02 20:20:14 +08:00 committed by GitHub
parent edfcb5c1ea
commit 1cc5b58254
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 152 additions and 24 deletions

View File

@ -1,6 +1,6 @@
use crate::dto::{
ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, Document, MessageData,
SearchDocumentsRequest, SummarizeRowResponse, TranslateRowResponse,
RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse, TranslateRowResponse,
};
use crate::error::AIError;
use anyhow::anyhow;
@ -161,6 +161,18 @@ impl AppFlowyAIClient {
AIResponse::<String>::stream_response(resp).await
}
pub async fn get_related_question(
&self,
chat_id: &str,
message_id: &i64,
) -> Result<RepeatedRelatedQuestion, AIError> {
let url = format!("{}/chat/{chat_id}/{message_id}/related_question", self.url);
let resp = self.http_client(Method::GET, &url)?.send().await?;
AIResponse::<RepeatedRelatedQuestion>::from_response(resp)
.await?
.into_data()
}
fn http_client(&self, method: Method, url: &str) -> Result<RequestBuilder, AIError> {
let request_builder = self.client.request(method, url);
Ok(request_builder)

View File

@ -27,6 +27,20 @@ pub struct ChatAnswer {
pub content: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RepeatedRelatedQuestion {
pub message_id: i64,
pub items: Vec<RelatedQuestion>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedQuestion {
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompleteTextResponse {
pub text: String,

View File

@ -11,9 +11,16 @@ async fn qa_test() {
.await
.unwrap();
assert!(!resp.content.is_empty());
let questions = client
.get_related_question(&chat_id, &1)
.await
.unwrap()
.items;
println!("questions: {:?}", questions);
assert_eq!(questions.len(), 3)
}
#[tokio::test]
async fn stop_steam_test() {
let client = appflowy_ai_client();
client.health_check().await.unwrap();
@ -35,11 +42,12 @@ async fn stop_steam_test() {
assert_ne!(count, 0);
}
#[tokio::test]
async fn steam_test() {
let client = appflowy_ai_client();
client.health_check().await.unwrap();
let chat_id = uuid::Uuid::new_v4().to_string();
let mut stream = client
let stream = client
.stream_question(&chat_id, "I feel hungry")
.await
.unwrap();

View File

@ -5,6 +5,7 @@ use database_entity::dto::{
};
use futures_core::Stream;
use reqwest::Method;
use shared_entity::dto::ai_dto::RepeatedRelatedQuestion;
use shared_entity::response::{AppResponse, AppResponseError};
impl Client {
@ -57,6 +58,26 @@ impl Client {
log_request_id(&resp);
AppResponse::<ChatMessage>::stream_response(resp).await
}
pub async fn get_chat_related_question(
&self,
workspace_id: &str,
chat_id: &str,
message_id: i64,
) -> Result<RepeatedRelatedQuestion, AppResponseError> {
let url = format!(
"{}/api/chat/{workspace_id}/{chat_id}/{message_id}/related_question",
self.base_url
);
let resp = self
.http_client_with_auth(Method::GET, &url)
.await?
.send()
.await?;
log_request_id(&resp);
AppResponse::<RepeatedRelatedQuestion>::from_response(resp)
.await?
.into_data()
}
pub async fn get_chat_messages(
&self,

View File

@ -572,6 +572,12 @@ pub struct CreateChatMessageParams {
pub message_type: ChatMessageType,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateChatMessageParams {
pub message_id: i64,
pub meta_data: HashMap<String, String>,
}
#[derive(Debug, Clone, Default, Serialize_repr, Deserialize_repr)]
#[repr(u8)]
pub enum ChatMessageType {

View File

@ -5,12 +5,12 @@ use app_error::AppError;
use chrono::{DateTime, Utc};
use database_entity::dto::{
ChatAuthor, ChatMessage, CreateChatParams, GetChatMessageParams, MessageCursor,
RepeatedChatMessage, UpdateChatParams,
RepeatedChatMessage, UpdateChatMessageParams, UpdateChatParams,
};
use serde_json::json;
use sqlx::postgres::PgArguments;
use sqlx::{Arguments, Executor, Postgres, Transaction};
use sqlx::{Arguments, Executor, PgPool, Postgres, Transaction};
use std::ops::DerefMut;
use std::str::FromStr;
use tracing::warn;
@ -167,7 +167,7 @@ pub async fn select_chat_messages(
) -> Result<RepeatedChatMessage, AppError> {
let chat_id = Uuid::from_str(chat_id)?;
let mut query = r#"
SELECT message_id, content, created_at, author
SELECT message_id, content, created_at, author, meta_data
FROM af_chat_messages
WHERE chat_id = $1
"#
@ -326,3 +326,30 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>(
Ok(messages)
}
pub async fn update_chat_message(
pg_pool: &PgPool,
params: UpdateChatMessageParams,
) -> Result<(), AppError> {
for (key, value) in params.meta_data.iter() {
sqlx::query(
r#"
UPDATE af_chat_messages
SET meta_data = jsonb_set(
COALESCE(meta_data, '{}'),
$2,
$3::jsonb,
true
)
WHERE id = $1
"#,
)
.bind(params.message_id)
.bind(format!("{{{}}}", key))
.bind(value)
.execute(pg_pool)
.await?;
}
Ok(())
}

View File

@ -77,14 +77,14 @@ impl CollabHandle {
Self::handle_collab_updates(&mut update_stream, content.get_collab(), messages).await?;
}
let workspace_id =
Uuid::parse_str(&workspace_id).map_err(|e| crate::error::Error::InvalidWorkspace(e))?;
Uuid::parse_str(&workspace_id).map_err(crate::error::Error::InvalidWorkspace)?;
let mut tasks = JoinSet::new();
tasks.spawn(Self::receive_collab_updates(
update_stream,
Arc::downgrade(&content),
object_id.clone(),
workspace_id.clone(),
workspace_id,
ingest_interval,
closing.clone(),
));
@ -253,8 +253,8 @@ impl CollabHandle {
}
pub async fn shutdown(mut self) {
let _ = self.closing.cancel();
while let Some(_) = self.tasks.join_next().await { /* wait for all tasks to finish */ }
self.closing.cancel();
while self.tasks.join_next().await.is_some() { /* wait for all tasks to finish */ }
}
}
@ -353,7 +353,7 @@ mod test {
let tokens: i64 =
sqlx::query("SELECT index_token_usage from af_workspace WHERE workspace_id = $1")
.bind(&workspace_id)
.bind(workspace_id)
.fetch_one(&db)
.await
.unwrap()

View File

@ -72,7 +72,7 @@ impl From<Fragment> for EmbedFragment {
impl From<EmbedFragment> for AFCollabEmbeddingParams {
fn from(f: EmbedFragment) -> Self {
AFCollabEmbeddingParams {
fragment_id: f.fragment_id.into(),
fragment_id: f.fragment_id,
object_id: f.object_id,
collab_type: f.collab_type,
content_type: f.content_type,
@ -91,7 +91,7 @@ impl PostgresIndexer {
#[allow(dead_code)]
pub async fn open(openai_api_key: &str, pg_conn: &str) -> Result<Self> {
let openai = Client::new(openai_api_key.to_string());
let db = PgPool::connect(&pg_conn).await?;
let db = PgPool::connect(pg_conn).await?;
Ok(Self { openai, db })
}
@ -226,7 +226,7 @@ mod test {
// resolve embeddings from OpenAI
let embeddings = indexer.get_embeddings(fragments).await.unwrap();
assert_eq!(embeddings.fragments[0].embedding.is_some(), true);
assert!(embeddings.fragments[0].embedding.is_some());
// store embeddings in DB
indexer

View File

@ -31,7 +31,7 @@ pub async fn setup_collab(db: &PgPool, uid: i64, object_id: Uuid, encoded_collab
let mut tx = db.begin().await.unwrap();
let user_uuid = Uuid::new_v4();
sqlx::query("INSERT INTO auth.users(id) VALUES($1)")
.bind(&user_uuid)
.bind(user_uuid)
.execute(tx.deref_mut())
.await
.unwrap();
@ -48,7 +48,7 @@ pub async fn setup_collab(db: &PgPool, uid: i64, object_id: Uuid, encoded_collab
&mut tx,
&uid,
&workspace_id.to_string(),
&CollabParams::new(object_id.clone(), CollabType::Document, encoded_collab),
&CollabParams::new(object_id, CollabType::Document, encoded_collab),
)
.await
.unwrap();

View File

@ -120,7 +120,7 @@ impl DocumentWatcher {
if block.external_type.as_deref() == Some("text") {
if let Some(text_id) = block.external_id.as_deref() {
if let Some(json) = text_map.get(text_id) {
match serde_json::from_str::<Vec<TextDelta>>(&json) {
match serde_json::from_str::<Vec<TextDelta>>(json) {
Ok(deltas) => {
for delta in deltas {
if let TextDelta::Inserted(text, _) = delta {
@ -151,7 +151,7 @@ impl DocumentWatcher {
impl Indexable for DocumentWatcher {
fn get_collab(&self) -> &MutexCollab {
&*self.content.get_collab()
self.content.get_collab()
}
fn changes(&self) -> Pin<Box<dyn Stream<Item = FragmentUpdate> + Send + Sync>> {
@ -253,9 +253,7 @@ mod test {
let blocks = &data.blocks;
let (_, block) = blocks
.iter()
//TODO: Block::external_type should probably be an enum
.filter(|(_, b)| b.external_type.as_deref() == Some("text"))
.next()
.find(|(_, b)| b.external_type.as_deref() == Some("text"))
.unwrap();
block.clone()
}

View File

@ -3,14 +3,17 @@ use crate::state::AppState;
use actix_web::web::{Data, Json};
use actix_web::{web, HttpResponse, Scope};
use app_error::AppError;
use appflowy_ai_client::dto::RepeatedRelatedQuestion;
use authentication::jwt::UserUuid;
use database::chat::chat_ops::update_chat_message;
use database_entity::dto::{
CreateChatMessageParams, CreateChatParams, GetChatMessageParams, MessageCursor,
RepeatedChatMessage,
RepeatedChatMessage, UpdateChatMessageParams,
};
use shared_entity::response::{AppResponse, JsonAppResponse};
use std::collections::HashMap;
use tracing::trace;
use tracing::{instrument, trace};
use validator::Validate;
pub fn chat_scope() -> Scope {
@ -22,7 +25,15 @@ pub fn chat_scope() -> Scope {
.route(web::post().to(update_chat_handler))
.route(web::get().to(get_chat_message_handler)),
)
.service(web::resource("/{chat_id}/message").route(web::post().to(post_chat_message_handler)))
.service(
web::resource("/{chat_id}/{message_id}/related_question")
.route(web::get().to(get_related_message_handler)),
)
.service(
web::resource("/{chat_id}/message")
.route(web::post().to(post_chat_message_handler))
.route(web::put().to(update_chat_message_handler)),
)
}
async fn create_chat_handler(
path: web::Path<String>,
@ -83,6 +94,29 @@ async fn post_chat_message_handler(
)
}
async fn update_chat_message_handler(
state: Data<AppState>,
payload: Json<UpdateChatMessageParams>,
) -> actix_web::Result<JsonAppResponse<()>> {
let params = payload.into_inner();
update_chat_message(&state.pg_pool, params).await?;
Ok(AppResponse::Ok().into())
}
async fn get_related_message_handler(
path: web::Path<(String, String, i64)>,
state: Data<AppState>,
) -> actix_web::Result<JsonAppResponse<RepeatedRelatedQuestion>> {
let (_workspace_id, chat_id, message_id) = path.into_inner();
let resp = state
.ai_client
.get_related_question(&chat_id, &message_id)
.await
.map_err(|err| AppError::Internal(err.into()))?;
Ok(AppResponse::Ok().with_data(resp).into())
}
#[instrument(level = "debug", skip_all, err)]
async fn get_chat_message_handler(
path: web::Path<(String, String)>,
query: web::Query<HashMap<String, String>>,

View File

@ -117,4 +117,12 @@ async fn chat_qa_test() {
let messages: Vec<ChatMessage> = stream.map(|message| message.unwrap()).collect().await;
assert_eq!(messages.len(), 2);
let related_questions = test_client
.api_client
.get_chat_related_question(&workspace_id, &chat_id, messages[1].message_id)
.await
.unwrap();
assert_eq!(related_questions.items.len(), 3);
println!("related questions: {:?}", related_questions.items);
}