feat: custom completion prompt (#906)
* feat: custom completion prompt * chore: custom prompt
This commit is contained in:
parent
57c44818e2
commit
2f715c3136
|
|
@ -1,6 +1,6 @@
|
|||
use crate::dto::{
|
||||
AIModel, ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, CreateTextChatContext,
|
||||
Document, EmbeddingRequest, EmbeddingResponse, LocalAIConfig, MessageData,
|
||||
CustomPrompt, Document, EmbeddingRequest, EmbeddingResponse, LocalAIConfig, MessageData,
|
||||
RepeatedLocalAIPackage, RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse,
|
||||
TranslateRowData, TranslateRowResponse,
|
||||
};
|
||||
|
|
@ -49,19 +49,29 @@ impl AppFlowyAIClient {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn completion_text(
|
||||
pub async fn completion_text<T: Into<Option<CompletionType>>>(
|
||||
&self,
|
||||
text: &str,
|
||||
completion_type: CompletionType,
|
||||
completion_type: T,
|
||||
custom_prompt: Option<CustomPrompt>,
|
||||
model: AIModel,
|
||||
) -> Result<CompleteTextResponse, AIError> {
|
||||
let completion_type = completion_type.into();
|
||||
|
||||
if completion_type.is_some() && custom_prompt.is_some() {
|
||||
return Err(AIError::InvalidRequest(
|
||||
"Cannot specify both completion_type and custom_prompt".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if text.is_empty() {
|
||||
return Err(AIError::InvalidRequest("Empty text".to_string()));
|
||||
}
|
||||
|
||||
let params = json!({
|
||||
"text": text,
|
||||
"type": completion_type as u8,
|
||||
"type": completion_type.map(|t| t as u8),
|
||||
"custom_prompt": custom_prompt,
|
||||
});
|
||||
|
||||
let url = format!("{}/completion", self.url);
|
||||
|
|
@ -76,19 +86,22 @@ impl AppFlowyAIClient {
|
|||
.into_data()
|
||||
}
|
||||
|
||||
pub async fn stream_completion_text(
|
||||
pub async fn stream_completion_text<T: Into<Option<CompletionType>>>(
|
||||
&self,
|
||||
text: &str,
|
||||
completion_type: CompletionType,
|
||||
completion_type: T,
|
||||
custom_prompt: Option<CustomPrompt>,
|
||||
model: AIModel,
|
||||
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
|
||||
let completion_type = completion_type.into();
|
||||
if text.is_empty() {
|
||||
return Err(AIError::InvalidRequest("Empty text".to_string()));
|
||||
}
|
||||
|
||||
let params = json!({
|
||||
"text": text,
|
||||
"type": completion_type as u8,
|
||||
"type": completion_type.map(|t| t as u8),
|
||||
"custom_prompt": custom_prompt,
|
||||
});
|
||||
|
||||
let url = format!("{}/completion/stream", self.url);
|
||||
|
|
@ -201,7 +214,7 @@ impl AppFlowyAIClient {
|
|||
chat_id: &str,
|
||||
content: &str,
|
||||
model: &AIModel,
|
||||
metadata: Option<serde_json::Value>,
|
||||
metadata: Option<Value>,
|
||||
) -> Result<ChatAnswer, AIError> {
|
||||
let json = ChatQuestion {
|
||||
chat_id: chat_id.to_string(),
|
||||
|
|
@ -226,7 +239,7 @@ impl AppFlowyAIClient {
|
|||
&self,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
metadata: Option<serde_json::Value>,
|
||||
metadata: Option<Value>,
|
||||
model: &AIModel,
|
||||
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
|
||||
let json = ChatQuestion {
|
||||
|
|
@ -251,7 +264,7 @@ impl AppFlowyAIClient {
|
|||
&self,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
metadata: Option<serde_json::Value>,
|
||||
metadata: Option<Value>,
|
||||
model: &AIModel,
|
||||
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
|
||||
let json = ChatQuestion {
|
||||
|
|
@ -337,7 +350,7 @@ pub struct AIResponse<T> {
|
|||
|
||||
impl<T> AIResponse<T>
|
||||
where
|
||||
T: serde::de::DeserializeOwned + 'static,
|
||||
T: DeserializeOwned + 'static,
|
||||
{
|
||||
pub async fn from_response(resp: reqwest::Response) -> Result<Self, anyhow::Error> {
|
||||
let status_code = resp.status();
|
||||
|
|
|
|||
|
|
@ -205,7 +205,7 @@ impl Display for EmbeddingsModel {
|
|||
pub enum AIModel {
|
||||
#[default]
|
||||
DefaultModel = 0,
|
||||
GPT35 = 1,
|
||||
GPT4oMini = 1,
|
||||
GPT4o = 2,
|
||||
Claude3Sonnet = 3,
|
||||
Claude3Opus = 4,
|
||||
|
|
@ -215,7 +215,7 @@ impl AIModel {
|
|||
pub fn to_str(&self) -> &str {
|
||||
match self {
|
||||
AIModel::DefaultModel => "default-model",
|
||||
AIModel::GPT35 => "gpt-3.5-turbo",
|
||||
AIModel::GPT4oMini => "gpt-4o-mini",
|
||||
AIModel::GPT4o => "gpt-4o",
|
||||
AIModel::Claude3Sonnet => "claude-3-sonnet",
|
||||
AIModel::Claude3Opus => "claude-3-opus",
|
||||
|
|
@ -228,7 +228,8 @@ impl FromStr for AIModel {
|
|||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"gpt-3.5-turbo" => Ok(AIModel::GPT35),
|
||||
"gpt-3.5-turbo" => Ok(AIModel::GPT4oMini),
|
||||
"gpt-4o-mini" => Ok(AIModel::GPT4oMini),
|
||||
"gpt-4o" => Ok(AIModel::GPT4o),
|
||||
"claude-3-sonnet" => Ok(AIModel::Claude3Sonnet),
|
||||
"claude-3-opus" => Ok(AIModel::Claude3Opus),
|
||||
|
|
@ -364,3 +365,9 @@ impl Display for CreateTextChatContext {
|
|||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct CustomPrompt {
|
||||
pub system: String,
|
||||
pub user: Option<String>,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ async fn continue_writing_test() {
|
|||
.completion_text(
|
||||
"I feel hungry",
|
||||
CompletionType::ContinueWriting,
|
||||
None,
|
||||
AIModel::Claude3Sonnet,
|
||||
)
|
||||
.await
|
||||
|
|
@ -23,7 +24,8 @@ async fn improve_writing_test() {
|
|||
.completion_text(
|
||||
"I fell tired because i sleep not very well last night",
|
||||
CompletionType::ImproveWriting,
|
||||
AIModel::GPT35,
|
||||
None,
|
||||
AIModel::GPT4oMini,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -39,7 +41,8 @@ async fn make_text_shorter_text() {
|
|||
.stream_completion_text(
|
||||
"I have an immense passion and deep-seated affection for Rust, a modern, multi-paradigm, high-performance programming language that I find incredibly satisfying to use due to its focus on safety, speed, and concurrency",
|
||||
CompletionType::MakeShorter,
|
||||
AIModel::GPT35
|
||||
None,
|
||||
AIModel::GPT4oMini
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ async fn create_chat_context_test() {
|
|||
};
|
||||
client.create_chat_text_context(context).await.unwrap();
|
||||
let resp = client
|
||||
.send_question(&chat_id, "Where I live?", &AIModel::GPT35, None)
|
||||
.send_question(&chat_id, "Where I live?", &AIModel::GPT4oMini, None)
|
||||
.await
|
||||
.unwrap();
|
||||
// response will be something like:
|
||||
|
|
|
|||
|
|
@ -10,13 +10,13 @@ async fn qa_test() {
|
|||
client.health_check().await.unwrap();
|
||||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
let resp = client
|
||||
.send_question(&chat_id, "I feel hungry", &AIModel::GPT35, None)
|
||||
.send_question(&chat_id, "I feel hungry", &AIModel::GPT4o, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!resp.content.is_empty());
|
||||
|
||||
let questions = client
|
||||
.get_related_question(&chat_id, &1, &AIModel::GPT35)
|
||||
.get_related_question(&chat_id, &1, &AIModel::GPT4oMini)
|
||||
.await
|
||||
.unwrap()
|
||||
.items;
|
||||
|
|
@ -29,7 +29,7 @@ async fn stop_stream_test() {
|
|||
client.health_check().await.unwrap();
|
||||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
let mut stream = client
|
||||
.stream_question(&chat_id, "I feel hungry", None, &AIModel::GPT35)
|
||||
.stream_question(&chat_id, "I feel hungry", None, &AIModel::GPT4oMini)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -51,7 +51,7 @@ async fn stream_test() {
|
|||
client.health_check().await.unwrap();
|
||||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
let stream = client
|
||||
.stream_question_v2(&chat_id, "I feel hungry", None, &AIModel::GPT35)
|
||||
.stream_question_v2(&chat_id, "I feel hungry", None, &AIModel::GPT4oMini)
|
||||
.await
|
||||
.unwrap();
|
||||
let json_stream = JsonStream::<serde_json::Value>::new(stream);
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ async fn summarize_row_test() {
|
|||
let json = json!({"name": "Jack", "age": 25, "city": "New York"});
|
||||
|
||||
let result = client
|
||||
.summarize_row(json.as_object().unwrap(), AIModel::GPT35)
|
||||
.summarize_row(json.as_object().unwrap(), AIModel::GPT4oMini)
|
||||
.await
|
||||
.unwrap();
|
||||
result.text.contains("Jack");
|
||||
|
|
|
|||
|
|
@ -20,6 +20,9 @@ async fn translate_row_test() {
|
|||
include_header: false,
|
||||
};
|
||||
|
||||
let result = client.translate_row(data, AIModel::GPT35).await.unwrap();
|
||||
let result = client
|
||||
.translate_row(data, AIModel::GPT4oMini)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result.items.len(), 2);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -172,7 +172,7 @@ impl Client {
|
|||
);
|
||||
}
|
||||
|
||||
let ai_model = Arc::new(RwLock::new(AIModel::GPT35));
|
||||
let ai_model = Arc::new(RwLock::new(AIModel::GPT4oMini));
|
||||
|
||||
Self {
|
||||
base_url: base_url.to_string(),
|
||||
|
|
|
|||
|
|
@ -34,7 +34,18 @@ pub struct SummarizeRowResponse {
|
|||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct CompleteTextParams {
|
||||
pub text: String,
|
||||
pub completion_type: CompletionType,
|
||||
pub completion_type: Option<CompletionType>,
|
||||
pub custom_prompt: Option<CustomPrompt>,
|
||||
}
|
||||
|
||||
impl CompleteTextParams {
|
||||
pub fn new_with_completion_type(text: String, completion_type: CompletionType) -> Self {
|
||||
Self {
|
||||
text,
|
||||
completion_type: Some(completion_type),
|
||||
custom_prompt: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ async fn complete_text_handler(
|
|||
let params = payload.into_inner();
|
||||
let resp = state
|
||||
.ai_client
|
||||
.completion_text(¶ms.text, params.completion_type, ai_model)
|
||||
.completion_text(¶ms.text, params.completion_type, None, ai_model)
|
||||
.await
|
||||
.map_err(|err| AppError::Internal(err.into()))?;
|
||||
Ok(AppResponse::Ok().with_data(resp).into())
|
||||
|
|
@ -51,7 +51,12 @@ async fn stream_complete_text_handler(
|
|||
let params = payload.into_inner();
|
||||
match state
|
||||
.ai_client
|
||||
.stream_completion_text(¶ms.text, params.completion_type, ai_model)
|
||||
.stream_completion_text(
|
||||
¶ms.text,
|
||||
params.completion_type,
|
||||
params.custom_prompt,
|
||||
ai_model,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(stream) => Ok(
|
||||
|
|
|
|||
|
|
@ -179,5 +179,5 @@ pub(crate) fn ai_model_from_header(req: &HttpRequest) -> AIModel {
|
|||
let header = header.to_str().ok()?;
|
||||
AIModel::from_str(header).ok()
|
||||
})
|
||||
.unwrap_or(AIModel::GPT35)
|
||||
.unwrap_or(AIModel::GPT4oMini)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,13 +8,13 @@ async fn improve_writing_test() {
|
|||
return;
|
||||
}
|
||||
let test_client = TestClient::new_user().await;
|
||||
test_client.api_client.set_ai_model(AIModel::GPT4o);
|
||||
test_client.api_client.set_ai_model(AIModel::GPT4oMini);
|
||||
|
||||
let workspace_id = test_client.workspace_id().await;
|
||||
let params = CompleteTextParams {
|
||||
text: "I feel hungry".to_string(),
|
||||
completion_type: CompletionType::ImproveWriting,
|
||||
};
|
||||
let params = CompleteTextParams::new_with_completion_type(
|
||||
"I feel hungry".to_string(),
|
||||
CompletionType::ImproveWriting,
|
||||
);
|
||||
|
||||
let resp = test_client
|
||||
.api_client
|
||||
|
|
|
|||
Loading…
Reference in New Issue